use crate::core::protocol::ProtocolType;
use crate::error::{DetectorError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A named, weighted set of [`FingerprintRule`]s that together identify one protocol.
///
/// A fingerprint matches a payload when it is enabled, none of its `required` rules fail,
/// and at least one of its rules matches.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProtocolFingerprint {
    /// Protocol reported in [`FingerprintMatch::protocol`] when this fingerprint matches.
    pub protocol: ProtocolType,
    /// Human-readable name, copied into [`FingerprintMatch::fingerprint_name`].
    pub name: String,
    /// Free-form description of what the fingerprint detects.
    pub description: String,
    /// Rules evaluated, in order, against each payload.
    pub rules: Vec<FingerprintRule>,
    /// Multiplier applied to the normalized rule score before it is capped at `1.0`.
    pub weight: f32,
    /// When `false`, [`ProtocolFingerprint::matches`] always reports no match.
    pub enabled: bool,
}

impl ProtocolFingerprint {
    /// Creates an enabled fingerprint with weight `1.0` and no rules.
    ///
    /// `name` and `description` are converted independently, so they may be different
    /// string-like types (for example a `String` name with a `&str` description).
    pub fn new<N: Into<String>, D: Into<String>>(
        protocol: ProtocolType,
        name: N,
        description: D,
    ) -> Self {
        Self {
            protocol,
            name: name.into(),
            description: description.into(),
            rules: Vec::new(),
            weight: 1.0,
            enabled: true,
        }
    }

    /// Appends `rule` to the rule list (builder style).
    pub fn add_rule(mut self, rule: FingerprintRule) -> Self {
        self.rules.push(rule);
        self
    }

    /// Sets the fingerprint weight; negative values (and NaN) are clamped to `0.0`.
    pub fn with_weight(mut self, weight: f32) -> Self {
        self.weight = weight.max(0.0);
        self
    }

    /// Marks the fingerprint as enabled (builder style).
    pub fn enable(mut self) -> Self {
        self.enabled = true;
        self
    }

    /// Marks the fingerprint as disabled, so it never matches (builder style).
    pub fn disable(mut self) -> Self {
        self.enabled = false;
        self
    }

    /// Evaluates every rule against `data` and combines the results.
    ///
    /// The score is the weighted average of the per-rule scores: each rule's score times
    /// its `weight`, divided by the sum of all rule weights. That average is multiplied by
    /// the fingerprint `weight` and capped at `1.0`. With non-negative weights, a
    /// fingerprint whose rules all match therefore scores `self.weight.min(1.0)`,
    /// whatever the individual rule weights are.
    ///
    /// Returns [`FingerprintMatch::no_match`] in three cases:
    /// - the fingerprint is disabled;
    /// - a `required` rule does not match;
    /// - no rule matches at all (including when there are no rules).
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by a rule's [`FingerprintRule::matches`].
    pub fn matches(&self, data: &[u8]) -> Result<FingerprintMatch> {
        if !self.enabled {
            return Ok(FingerprintMatch::no_match());
        }

        let mut weighted_score = 0.0_f32;
        let mut total_weight = 0.0_f32;
        let mut rule_matches = Vec::new();

        for rule in &self.rules {
            let rule_match = rule.matches(data)?;
            // Every rule counts toward the denominator, matched or not. A missing optional
            // rule therefore lowers the score in proportion to its weight.
            total_weight += rule.weight;
            if rule_match.matched {
                weighted_score += rule_match.score * rule.weight;
                rule_matches.push(rule_match);
            } else if rule.required {
                // A failed required rule vetoes the fingerprint outright.
                return Ok(FingerprintMatch::no_match());
            }
        }

        if rule_matches.is_empty() {
            return Ok(FingerprintMatch::no_match());
        }

        // Weighted average of the rule scores. When every weight is zero, use 0.0 instead
        // of computing 0/0, which would be NaN.
        let normalized = if total_weight > 0.0 {
            weighted_score / total_weight
        } else {
            0.0
        };

        Ok(FingerprintMatch {
            matched: true,
            score: (normalized * self.weight).min(1.0),
            fingerprint_name: self.name.clone(),
            protocol: self.protocol,
            rule_matches,
        })
    }
}
/// A single named check that contributes to a [`ProtocolFingerprint`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FingerprintRule {
    /// Name echoed back in [`RuleMatch::rule_name`].
    pub name: String,
    /// The predicate applied to the payload.
    pub rule_type: RuleType,
    /// Multiplier applied to this rule's score by the owning fingerprint.
    pub weight: f32,
    /// When `true`, failing this rule makes the owning fingerprint report no match.
    pub required: bool,
}

impl FingerprintRule {
    /// Builds an optional rule with weight `1.0`.
    pub fn new<S: Into<String>>(name: S, rule_type: RuleType) -> Self {
        let name = name.into();
        Self {
            name,
            rule_type,
            weight: 1.0,
            required: false,
        }
    }

    /// Returns the rule with its weight replaced; negative values (and NaN) become `0.0`.
    pub fn with_weight(self, weight: f32) -> Self {
        Self {
            weight: weight.max(0.0),
            ..self
        }
    }

    /// Returns the rule flagged as required.
    pub fn required(self) -> Self {
        Self {
            required: true,
            ..self
        }
    }

    /// Applies the rule's predicate to `data`.
    ///
    /// The returned [`RuleMatch`] has score `1.0` on a match and `0.0` otherwise, and
    /// carries copies of this rule's name and predicate.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`RuleType::matches`].
    pub fn matches(&self, data: &[u8]) -> Result<RuleMatch> {
        self.rule_type.matches(data).map(|matched| RuleMatch {
            matched,
            score: if matched { 1.0 } else { 0.0 },
            rule_name: self.name.clone(),
            rule_type: self.rule_type.clone(),
        })
    }
}
/// The predicate a [`FingerprintRule`] applies to a payload.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum RuleType {
    /// Matches when `pattern` appears in the payload starting exactly at byte `offset`.
    ByteSequence {
        pattern: Vec<u8>,
        offset: usize,
    },
    /// NOTE(review): evaluated as a plain substring search over the lossily decoded UTF-8
    /// payload. `pattern` is NOT interpreted as a regular expression.
    Regex {
        pattern: String,
    },
    /// Matches when `pattern` occurs anywhere in the lossily decoded UTF-8 payload. When
    /// `case_sensitive` is `false`, both sides are lowercased first.
    String {
        pattern: String,
        case_sensitive: bool,
    },
    /// Matches when the payload length is within the inclusive bounds; `None` means unbounded.
    Length {
        min: Option<usize>,
        max: Option<usize>,
    },
    /// Not evaluated against payload bytes; [`RuleType::matches`] always returns `false`.
    Port {
        port: u16,
    },
    /// Same check as `ByteSequence`: `magic` must appear exactly at byte `offset`.
    MagicBytes {
        magic: Vec<u8>,
        offset: usize,
    },
    /// Placeholder for externally defined checks; [`RuleType::matches`] always returns `false`.
    Custom {
        name: String,
    },
}

impl RuleType {
    /// Reports whether `data` satisfies this rule.
    ///
    /// `Port` and `Custom` rules cannot be decided from the payload alone and always yield
    /// `Ok(false)`. A `required` rule of either kind therefore prevents its fingerprint from
    /// ever matching.
    ///
    /// # Errors
    ///
    /// No variant currently returns an error.
    pub fn matches(&self, data: &[u8]) -> Result<bool> {
        let matched = match self {
            Self::ByteSequence {
                pattern: expected,
                offset,
            }
            | Self::MagicBytes {
                magic: expected,
                offset,
            } => Self::bytes_at(data, *offset, expected),
            Self::Regex { pattern } => String::from_utf8_lossy(data).contains(pattern.as_str()),
            Self::String {
                pattern,
                case_sensitive,
            } => {
                let text = String::from_utf8_lossy(data);
                if *case_sensitive {
                    text.contains(pattern.as_str())
                } else {
                    text.to_lowercase()
                        .contains(pattern.to_lowercase().as_str())
                }
            }
            Self::Length { min, max } => {
                let len = data.len();
                min.map_or(true, |lo| len >= lo) && max.map_or(true, |hi| len <= hi)
            }
            Self::Port { .. } | Self::Custom { .. } => false,
        };
        Ok(matched)
    }

    /// Returns `true` when `expected` occurs in `data` starting exactly at `offset`.
    ///
    /// Uses checked slicing, so an out-of-range or huge `offset` yields `false` instead of
    /// panicking. This matters for values from deserialized config: the previous
    /// `offset + len` arithmetic could overflow.
    fn bytes_at(data: &[u8], offset: usize, expected: &[u8]) -> bool {
        matches!(data.get(offset..), Some(tail) if tail.starts_with(expected))
    }
}
/// Outcome of evaluating one [`ProtocolFingerprint`] against a payload.
#[derive(Debug, Clone, PartialEq)]
pub struct FingerprintMatch {
    /// Whether the fingerprint matched at all.
    pub matched: bool,
    /// Combined score; `0.0` for [`FingerprintMatch::no_match`].
    pub score: f32,
    /// Name of the fingerprint that produced this result; empty for a non-match.
    pub fingerprint_name: String,
    /// Detected protocol; `ProtocolType::Unknown` for a non-match.
    pub protocol: ProtocolType,
    /// Per-rule results collected while evaluating the fingerprint.
    pub rule_matches: Vec<RuleMatch>,
}

impl FingerprintMatch {
    /// Score at or above which [`FingerprintMatch::is_high_score`] reports `true`.
    const HIGH_SCORE_THRESHOLD: f32 = 0.8;

    /// The canonical "nothing matched" result: unmatched, zero score, unknown protocol.
    pub fn no_match() -> Self {
        Self {
            matched: false,
            score: 0.0,
            fingerprint_name: String::default(),
            protocol: ProtocolType::Unknown,
            rule_matches: vec![],
        }
    }

    /// `true` for a match whose score is at least `0.8`.
    pub fn is_high_score(&self) -> bool {
        self.is_acceptable(Self::HIGH_SCORE_THRESHOLD)
    }

    /// `true` for a match whose score is at least `threshold`. Non-matches are never acceptable.
    pub fn is_acceptable(&self, threshold: f32) -> bool {
        if !self.matched {
            return false;
        }
        self.score >= threshold
    }
}
/// Outcome of evaluating a single [`FingerprintRule`] against a payload.
#[derive(Debug, Clone, PartialEq)]
pub struct RuleMatch {
    /// Whether the rule's predicate held.
    pub matched: bool,
    /// Set by `FingerprintRule::matches` to `1.0` when matched and `0.0` otherwise.
    pub score: f32,
    /// Copy of the originating rule's name.
    pub rule_name: String,
    /// Copy of the originating rule's predicate.
    pub rule_type: RuleType,
}
#[derive(Debug, Clone)]
pub struct FingerprintDatabase {
fingerprints: HashMap<ProtocolType, Vec<ProtocolFingerprint>>,
}
impl FingerprintDatabase {
pub fn new() -> Self {
Self {
fingerprints: HashMap::new(),
}
}
pub fn add_fingerprint(&mut self, fingerprint: ProtocolFingerprint) {
self.fingerprints
.entry(fingerprint.protocol)
.or_insert_with(Vec::new)
.push(fingerprint);
}
pub fn add_fingerprints(&mut self, fingerprints: Vec<ProtocolFingerprint>) {
for fingerprint in fingerprints {
self.add_fingerprint(fingerprint);
}
}
pub fn match_protocol(&self, data: &[u8]) -> Result<Vec<FingerprintMatch>> {
let mut matches = Vec::new();
for fingerprints in self.fingerprints.values() {
for fingerprint in fingerprints {
let fingerprint_match = fingerprint.matches(data)?;
if fingerprint_match.matched {
matches.push(fingerprint_match);
}
}
}
matches.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
Ok(matches)
}
pub fn match_specific_protocol(
&self,
protocol: ProtocolType,
data: &[u8],
) -> Result<Vec<FingerprintMatch>> {
let mut matches = Vec::new();
if let Some(fingerprints) = self.fingerprints.get(&protocol) {
for fingerprint in fingerprints {
let fingerprint_match = fingerprint.matches(data)?;
if fingerprint_match.matched {
matches.push(fingerprint_match);
}
}
}
matches.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
Ok(matches)
}
pub fn best_match(&self, data: &[u8]) -> Result<Option<FingerprintMatch>> {
let matches = self.match_protocol(data)?;
Ok(matches.into_iter().next())
}
pub fn supported_protocols(&self) -> Vec<ProtocolType> {
self.fingerprints.keys().copied().collect()
}
pub fn fingerprint_count(&self) -> usize {
self.fingerprints.values().map(|v| v.len()).sum()
}
pub fn clear(&mut self) {
self.fingerprints.clear();
}
pub fn load_default_fingerprints(&mut self) {
let http11_fingerprint = ProtocolFingerprint::new(
ProtocolType::HTTP1_1,
"HTTP/1.1 Request",
"HTTP/1.1 request detection",
)
.add_rule(
FingerprintRule::new(
"HTTP Method",
RuleType::String {
pattern: "GET ".to_string(),
case_sensitive: true,
},
)
.required(),
)
.add_rule(
FingerprintRule::new(
"HTTP Version",
RuleType::String {
pattern: "HTTP/1.1".to_string(),
case_sensitive: false,
},
)
.with_weight(0.8),
);
let http2_fingerprint = ProtocolFingerprint::new(
ProtocolType::HTTP2,
"HTTP/2 Connection Preface",
"HTTP/2 connection preface detection",
)
.add_rule(
FingerprintRule::new(
"HTTP/2 Preface",
RuleType::ByteSequence {
pattern: b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec(),
offset: 0,
},
)
.required(),
);
let tls_fingerprint = ProtocolFingerprint::new(
ProtocolType::TLS,
"TLS Handshake",
"TLS handshake detection",
)
.add_rule(
FingerprintRule::new(
"TLS Record Type",
RuleType::ByteSequence {
pattern: vec![0x16], offset: 0,
},
)
.required(),
)
.add_rule(
FingerprintRule::new(
"TLS Version",
RuleType::ByteSequence {
pattern: vec![0x03, 0x01], offset: 1,
},
)
.with_weight(0.7),
);
let ssh_fingerprint = ProtocolFingerprint::new(
ProtocolType::SSH,
"SSH Protocol",
"SSH protocol detection",
)
.add_rule(
FingerprintRule::new(
"SSH Banner",
RuleType::String {
pattern: "SSH-".to_string(),
case_sensitive: false,
},
)
.required(),
);
self.add_fingerprint(http11_fingerprint);
self.add_fingerprint(http2_fingerprint);
self.add_fingerprint(tls_fingerprint);
self.add_fingerprint(ssh_fingerprint);
}
}
impl Default for FingerprintDatabase {
fn default() -> Self {
let mut db = Self::new();
db.load_default_fingerprints();
db
}
}