use crate::engine::interfaces::{
    GenericMiddleware, Middleware, MiddlewareOutput, ParamDef, ParamType, Plugin, ResolvedInputs,
};
use anyhow::{Context, Result, anyhow};
use async_trait::async_trait;
use serde_json::Value;
use std::{any::Any, borrow::Cow};
/// Request-line prefixes recognised as HTTP/1.x: the methods of RFC 9110 §9 plus
/// `PATCH` (RFC 5789). The trailing space separates the method from the request
/// target, so e.g. `GETX ...` is not matched.
const HTTP_METHOD_PREFIXES: &[&[u8]] = &[
    b"GET ",
    b"HEAD ",
    b"POST ",
    b"PUT ",
    b"DELETE ",
    b"CONNECT ",
    b"OPTIONS ",
    b"TRACE ",
    b"PATCH ",
];

/// Length of a TLS record header: content type (1), legacy version (2), length (2).
const TLS_RECORD_HEADER_LEN: usize = 5;
/// TLS record content type carrying handshake messages.
const TLS_CONTENT_TYPE_HANDSHAKE: u8 = 0x16;
/// Highest record-layer minor version accepted (0x0300 = SSL 3.0 up to 0x0304 = TLS 1.3).
const TLS_MAX_MINOR_VERSION: u8 = 0x04;

/// Size of the fixed DNS message header (ID, flags, four 16-bit counts).
const DNS_HEADER_LEN: usize = 12;
/// Mask over the first flags byte selecting the QR bit (0x80) and the Opcode (0x78).
const DNS_QR_OPCODE_MASK: u8 = 0xF8;

/// Heuristic minimum length for a QUIC long-header packet; shorter payloads are rejected.
const QUIC_MIN_LEN: usize = 20;
/// Header Form (0x80) and Fixed Bit (0x40) of the first byte; both are set on long headers.
const QUIC_LONG_HEADER_BITS: u8 = 0xC0;
/// QUIC version 1 (RFC 9000).
const QUIC_VERSION_1: u32 = 0x0000_0001;
/// QUIC version 2 (RFC 9369). Note this is *not* the integer 2.
const QUIC_VERSION_2: u32 = 0x6b33_43cf;

/// Heuristically checks whether `payload` looks like the start of a `method` message.
///
/// `method` selects the protocol: `"http"`, `"tls"`, `"dns"` or `"quic"`. Any other
/// value, and an empty payload, yields `false`. Only the first few header bytes are
/// inspected, so `true` means "plausibly this protocol", not "validated".
fn detect(payload: &[u8], method: &str) -> bool {
    if payload.is_empty() {
        return false;
    }
    match method {
        // HTTP/1.x request line: "<METHOD> <target> HTTP/1.x".
        "http" => HTTP_METHOD_PREFIXES
            .iter()
            .any(|prefix| payload.starts_with(prefix)),
        // TLS record header: handshake content type, major version 3 and a known minor
        // version. The whole 5-byte header is required so the length field is present.
        "tls" => {
            payload.len() >= TLS_RECORD_HEADER_LEN
                && payload[0] == TLS_CONTENT_TYPE_HANDSHAKE
                && payload[1] == 0x03
                && payload[2] <= TLS_MAX_MINOR_VERSION
        }
        "dns" => {
            if payload.len() < DNS_HEADER_LEN {
                return false;
            }
            // Only standard queries: QR = 0 (query) and Opcode = 0 (QUERY).
            if (payload[2] & DNS_QR_OPCODE_MASK) != 0 {
                return false;
            }
            // QDCOUNT (bytes 4..6, big-endian) must announce at least one question.
            let qdcount = u16::from_be_bytes([payload[4], payload[5]]);
            qdcount > 0
        }
        "quic" => {
            if payload.len() < QUIC_MIN_LEN {
                return false;
            }
            if (payload[0] & QUIC_LONG_HEADER_BITS) != QUIC_LONG_HEADER_BITS {
                return false;
            }
            // Long headers carry the 32-bit version right after the first byte.
            let version = u32::from_be_bytes([payload[1], payload[2], payload[3], payload[4]]);
            matches!(version, QUIC_VERSION_1 | QUIC_VERSION_2)
        }
        _ => false,
    }
}

/// Middleware plugin (`internal.protocol.detect`) that routes to the `"true"` or
/// `"false"` branch depending on whether a hex-encoded `payload` input looks like the
/// protocol named by the `method` input.
pub struct ProtocolDetectPlugin;

impl Plugin for ProtocolDetectPlugin {
    fn name(&self) -> &'static str {
        "internal.protocol.detect"
    }

    /// Both inputs are required: the protocol name and the payload to inspect.
    fn params(&self) -> Vec<ParamDef> {
        vec![
            ParamDef {
                name: "method".into(),
                required: true,
                param_type: ParamType::String,
            },
            ParamDef {
                name: "payload".into(),
                required: true,
                param_type: ParamType::Bytes,
            },
        ]
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn as_middleware(&self) -> Option<&dyn Middleware> {
        Some(self)
    }

    fn as_generic_middleware(&self) -> Option<&dyn GenericMiddleware> {
        Some(self)
    }
}

#[async_trait]
impl GenericMiddleware for ProtocolDetectPlugin {
    /// Branch names this middleware can emit.
    fn output(&self) -> Vec<Cow<'static, str>> {
        vec!["true".into(), "false".into()]
    }

    /// Reads `method` (string) and `payload` (hex string) from `inputs`, runs the
    /// detector, and selects branch `"true"` or `"false"`. Nothing is stored.
    ///
    /// # Errors
    /// Fails if either input is missing or not a string, or if `payload` is not valid hex.
    async fn execute(&self, inputs: ResolvedInputs) -> Result<MiddlewareOutput> {
        let method = inputs
            .get("method")
            .and_then(Value::as_str)
            .ok_or_else(|| anyhow!("Resolved input 'method' is missing or not a string"))?;
        let payload_hex = inputs
            .get("payload")
            .and_then(Value::as_str)
            .ok_or_else(|| anyhow!("Resolved input 'payload' is missing or not a string"))?;
        // Keep the underlying hex error as the source, but say which input was bad.
        let payload =
            hex::decode(payload_hex).context("Resolved input 'payload' is not valid hex")?;
        let branch = if detect(&payload, method) { "true" } else { "false" };
        Ok(MiddlewareOutput {
            branch: branch.into(),
            store: None,
        })
    }
}

/// Delegates to the [`GenericMiddleware`] implementation.
#[async_trait]
impl Middleware for ProtocolDetectPlugin {
    fn output(&self) -> Vec<Cow<'static, str>> {
        <Self as GenericMiddleware>::output(self)
    }

    async fn execute(&self, inputs: ResolvedInputs) -> Result<MiddlewareOutput> {
        <Self as GenericMiddleware>::execute(self, inputs).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a long-header-sized QUIC packet: first byte, 4-byte version, 20 zero bytes.
    fn quic_packet(first_byte: u8, version: u32) -> Vec<u8> {
        let mut packet = vec![first_byte];
        packet.extend_from_slice(&version.to_be_bytes());
        packet.extend_from_slice(&[0x00; 20]);
        packet
    }

    #[test]
    fn test_dns_detection() {
        let mut valid_dns = vec![
            0x12, 0x34, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];
        valid_dns.extend_from_slice(&[
            0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00,
            0x01, 0x00, 0x01,
        ]);
        assert!(
            detect(&valid_dns, "dns"),
            "Valid DNS query should be detected"
        );
        let mut response = valid_dns.clone();
        response[2] = 0x81;
        assert!(!detect(&response, "dns"), "DNS response should be rejected");
        let mut bad_opcode = valid_dns.clone();
        bad_opcode[2] = 0x09;
        assert!(
            !detect(&bad_opcode, "dns"),
            "Non-standard Opcode should be rejected"
        );
        let mut zero_questions = valid_dns.clone();
        zero_questions[4] = 0x00;
        zero_questions[5] = 0x00;
        assert!(
            !detect(&zero_questions, "dns"),
            "QDCOUNT=0 should be rejected"
        );
        assert!(!detect(&valid_dns[..10], "dns"), "Truncated header");
    }

    #[test]
    fn test_http_detection() {
        assert!(detect(b"GET / HTTP/1.1\r\n", "http"));
        assert!(detect(b"POST /api/v1/submit HTTP/1.1\r\n", "http"));
        assert!(detect(b"HEAD /index.html HTTP/1.1\r\n", "http"));
        assert!(detect(b"CONNECT example.com:443 HTTP/1.1\r\n", "http"));
        assert!(detect(b"TRACE / HTTP/1.1\r\n", "http"));
        assert!(!detect(b"HELLO WORLD", "http"));
        assert!(!detect(b"SSH-2.0", "http"));
        assert!(
            !detect(b"GETX / HTTP/1.1\r\n", "http"),
            "Method must be followed by a space"
        );
    }

    #[test]
    fn test_tls_detection() {
        let tls_handshake = [0x16, 0x03, 0x01, 0x00, 0x50];
        assert!(detect(&tls_handshake, "tls"));
        let sslv3 = [0x16, 0x03, 0x00, 0x00, 0x10];
        assert!(detect(&sslv3, "tls"));
        assert!(!detect(&[0x00, 0x01, 0x02], "tls"));
        assert!(
            !detect(&tls_handshake[..4], "tls"),
            "Truncated record header rejected"
        );
        assert!(
            !detect(&[0x16, 0x03, 0x7F, 0x00, 0x10], "tls"),
            "Unknown minor version rejected"
        );
    }

    #[test]
    fn test_quic_detection() {
        assert!(detect(&quic_packet(0xC0, QUIC_VERSION_1), "quic"));
        assert!(
            detect(&quic_packet(0xC0, 0x6b33_43cf), "quic"),
            "QUIC v2 (RFC 9369) accepted"
        );
        assert!(
            !detect(&quic_packet(0xC0, 0x0000_0002), "quic"),
            "0x00000002 is not QUIC v2"
        );
        assert!(
            !detect(&quic_packet(0x40, QUIC_VERSION_1), "quic"),
            "Short header rejected"
        );
        assert!(
            !detect(&quic_packet(0xC0, 0x0000_0000), "quic"),
            "Version 0 rejected"
        );
        assert!(
            !detect(&quic_packet(0xC0, QUIC_VERSION_1)[..10], "quic"),
            "Too short rejected"
        );
    }

    #[test]
    fn test_unknown_method_and_empty_payload() {
        assert!(!detect(b"GET / HTTP/1.1\r\n", "ftp"));
        assert!(!detect(b"", "http"));
    }
}