use psi_detector::{
DetectorBuilder, ProtocolDetector, ProtocolType,
core::{ProbeStrategy, ProbeContext, ProtocolProbe, ProtocolInfo},
error::{DetectorError, Result},
};
use std::time::{Duration, Instant};
fn main() -> Result<()> {
println!("🔌 PSI-Detector 插件系统演示");
println!("\n📡 1. DNS 协议探测插件演示");
demonstrate_dns_plugin()?;
println!("\n🔧 2. 多插件集成演示");
demonstrate_multi_plugin_integration()?;
println!("\n⚡ 3. 插件优先级演示");
demonstrate_plugin_priority()?;
println!("\n🎉 插件系统演示完成!");
Ok(())
}
/// Example plugin that recognises DNS messages and scores them with a
/// heuristic confidence in `[0.0, 1.0]`.
#[derive(Debug)]
struct DnsProbe {
    /// Name reported through `ProtocolProbe::name`.
    name: &'static str,
    /// Priority reported through `ProtocolProbe::priority`.
    priority: u8,
    /// Minimum packet length the probe will look at. Header validation always
    /// requires at least the 12-byte DNS header, even if this is set lower.
    min_packet_size: usize,
}

impl DnsProbe {
    /// Length of the fixed DNS message header (RFC 1035 §4.1.1).
    const HEADER_LEN: usize = 12;
    /// Longest permitted label (RFC 1035 §2.3.4).
    const MAX_LABEL_LEN: usize = 63;
    /// Longest permitted encoded name, counting the length bytes and the
    /// terminating zero byte (RFC 1035 §2.3.4).
    const MAX_NAME_LEN: usize = 255;

    /// Creates the default probe: name `DNS-UDP-Probe`, priority 60, and a
    /// minimum packet size equal to the DNS header length.
    pub fn new() -> Self {
        Self {
            name: "DNS-UDP-Probe",
            priority: 60,
            min_packet_size: Self::HEADER_LEN,
        }
    }

    /// Smallest input length accepted: the configured minimum, but never less
    /// than a full header, so the fixed-offset header reads cannot panic.
    fn required_len(&self) -> usize {
        self.min_packet_size.max(Self::HEADER_LEN)
    }

    /// Returns `true` when the start of `data` looks like a plausible DNS header.
    ///
    /// Rejects input shorter than `required_len()`, opcodes above 5, queries
    /// (QR bit clear) that carry a non-zero RCODE or no questions, and
    /// QDCOUNT values above 100.
    fn validate_dns_header(&self, data: &[u8]) -> bool {
        if data.len() < self.required_len() {
            return false;
        }
        let flags = u16::from_be_bytes([data[2], data[3]]);
        let questions = u16::from_be_bytes([data[4], data[5]]);
        let is_query = flags & 0x8000 == 0; // QR bit (bit 15)
        let opcode = (flags >> 11) & 0xF;
        let rcode = flags & 0xF;

        if opcode > 5 {
            return false;
        }
        // A query carries no response code and must ask at least one question.
        if is_query && (rcode != 0 || questions == 0) {
            return false;
        }
        // Heuristic sanity cap on QDCOUNT.
        questions <= 100
    }

    /// Scores `data` as a DNS message.
    ///
    /// Returns 0.0 if the header check fails. Otherwise the score is 0.6, plus
    /// 0.2 for opcode 0, plus 0.1 for 1–10 questions, plus 0.1 if the bytes
    /// after the header begin with a well-formed domain name. The result is
    /// capped at 1.0.
    fn calculate_confidence(&self, data: &[u8]) -> f32 {
        // validate_dns_header also enforces the minimum length.
        if !self.validate_dns_header(data) {
            return 0.0;
        }
        let flags = u16::from_be_bytes([data[2], data[3]]);
        let questions = u16::from_be_bytes([data[4], data[5]]);
        let opcode = (flags >> 11) & 0xF;

        let mut confidence: f32 = 0.6;
        if opcode == 0 {
            confidence += 0.2;
        }
        if (1..=10).contains(&questions) {
            confidence += 0.1;
        }
        if data.len() > Self::HEADER_LEN && self.validate_domain_name(&data[Self::HEADER_LEN..]) {
            confidence += 0.1;
        }
        confidence.min(1.0)
    }

    /// Returns `true` if `data` starts with an uncompressed, length-prefixed
    /// domain name that has at least one label.
    ///
    /// Each label must be 1–63 bytes of ASCII alphanumerics, `-` or `_`. The
    /// whole encoded name, including the terminating zero byte, must fit in
    /// 255 bytes. A byte must follow every label (the next length byte or the
    /// terminator). Compression pointers and the bare root name are rejected.
    fn validate_domain_name(&self, data: &[u8]) -> bool {
        let mut pos = 0;
        let mut labels = 0usize;
        while let Some(&len_byte) = data.get(pos) {
            let len = len_byte as usize;
            if len == 0 {
                return labels > 0;
            }
            if len > Self::MAX_LABEL_LEN {
                return false;
            }
            // Index of the byte that must follow this label.
            let next = pos + 1 + len;
            // That byte must exist, and even if it is the terminator the name
            // would occupy next + 1 bytes, which must not exceed the limit.
            if next >= data.len() || next >= Self::MAX_NAME_LEN {
                return false;
            }
            let label_ok = data[pos + 1..next]
                .iter()
                .all(|&c| c.is_ascii_alphanumeric() || c == b'-' || c == b'_');
            if !label_ok {
                return false;
            }
            pos = next;
            labels += 1;
        }
        false
    }
}
/// Plugs [`DnsProbe`] into the detector's probe interface.
impl ProtocolProbe for DnsProbe {
    /// Returns the configured probe name.
    fn name(&self) -> &'static str {
        self.name
    }

    /// DNS is reported as `ProtocolType::Custom`. The concrete protocol is
    /// identified through the `protocol_name` metadata set in `probe`.
    fn supported_protocols(&self) -> Vec<ProtocolType> {
        vec![ProtocolType::Custom]
    }

    /// Scores `data` as DNS using `calculate_confidence`.
    ///
    /// If the score exceeds 0.5, builds a `ProtocolInfo` with DNS features and
    /// metadata, registers a copy as a candidate on `context`, and returns it.
    /// Input shorter than `min_packet_size`, or scoring 0.5 or less, yields
    /// `Ok(None)`.
    fn probe(&self, data: &[u8], context: &mut ProbeContext) -> Result<Option<ProtocolInfo>> {
        // Cheap early exit before any header parsing.
        if data.len() < self.min_packet_size {
            return Ok(None);
        }
        let confidence = self.calculate_confidence(data);
        if confidence > 0.5 {
            let mut protocol_info = ProtocolInfo::new(ProtocolType::Custom, confidence);
            protocol_info.add_feature("DNS-UDP");
            protocol_info.add_feature(format!("confidence-{:.1}%", confidence * 100.0));
            protocol_info.add_metadata("transport", "UDP");
            protocol_info.add_metadata("protocol_name", "DNS");
            protocol_info.add_metadata(
                "details",
                format!("DNS packet detected (UDP), confidence: {:.1}%", confidence * 100.0),
            );
            context.add_candidate(protocol_info.clone());
            Ok(Some(protocol_info))
        } else {
            Ok(None)
        }
    }

    /// Returns the configured priority.
    fn priority(&self) -> u8 {
        self.priority
    }

    /// More data is needed while the input is shorter than `min_packet_size`.
    fn needs_more_data(&self, data: &[u8]) -> bool {
        data.len() < self.min_packet_size
    }
}
/// Example plugin that recognises MQTT CONNECT packets (as sent over TCP).
#[derive(Debug)]
struct MqttProbe {
    /// Name reported through `ProtocolProbe::name`.
    name: &'static str,
    /// Priority reported through `ProtocolProbe::priority`.
    priority: u8,
}

impl MqttProbe {
    /// Minimum number of bytes inspected before a CONNECT can be recognised.
    const MIN_PACKET_LEN: usize = 10;
    /// Bytes that follow the protocol name in a CONNECT variable header:
    /// protocol level (1), connect flags (1), keep-alive (2).
    const VARIABLE_HEADER_TAIL: usize = 4;

    /// Creates the default probe: name `MQTT-TCP-Probe`, priority 55.
    pub fn new() -> Self {
        Self {
            name: "MQTT-TCP-Probe",
            priority: 55,
        }
    }

    /// Decodes an MQTT "Remaining Length" variable-length integer.
    ///
    /// Each byte contributes 7 bits, least-significant group first. The high
    /// bit marks continuation, and at most 4 bytes are allowed. Returns
    /// `(value, bytes_consumed)`, or `None` if the encoding is truncated or
    /// longer than 4 bytes.
    fn decode_remaining_length(bytes: &[u8]) -> Option<(usize, usize)> {
        let mut value = 0usize;
        for (index, &byte) in bytes.iter().take(4).enumerate() {
            value |= usize::from(byte & 0x7F) << (7 * index);
            if byte & 0x80 == 0 {
                return Some((value, index + 1));
            }
        }
        None
    }

    /// Returns `true` if `data` begins with an MQTT CONNECT packet.
    ///
    /// Requires:
    /// - at least `MIN_PACKET_LEN` bytes and a fixed-header byte of `0x10`;
    /// - a well-formed 1–4 byte Remaining Length;
    /// - a length-prefixed protocol name of `MQTT` (3.1.1 / 5.0) or `MQIsdp` (3.1);
    /// - the level, flags and keep-alive bytes present after the name.
    ///
    /// Only the prefix is examined, so the rest of the packet need not have
    /// arrived yet.
    fn is_mqtt_connect(&self, data: &[u8]) -> bool {
        if data.len() < Self::MIN_PACKET_LEN || data[0] != 0x10 {
            return false;
        }
        // The Remaining Length field is variable-width, so the variable header
        // does not always start at offset 2.
        let length_bytes = match Self::decode_remaining_length(&data[1..]) {
            Some((_, consumed)) => consumed,
            None => return false,
        };
        let name_len_at = 1 + length_bytes;
        let name_len = match data.get(name_len_at..name_len_at + 2) {
            Some(b) => usize::from(u16::from_be_bytes([b[0], b[1]])),
            None => return false,
        };
        let name_start = name_len_at + 2;
        let name_end = name_start + name_len;
        if data.len() < name_end + Self::VARIABLE_HEADER_TAIL {
            return false;
        }
        let protocol_name = &data[name_start..name_end];
        protocol_name == b"MQTT" || protocol_name == b"MQIsdp"
    }
}
/// Plugs [`MqttProbe`] into the detector's probe interface.
impl ProtocolProbe for MqttProbe {
    /// Returns the configured probe name.
    fn name(&self) -> &'static str {
        self.name
    }

    /// MQTT is reported as `ProtocolType::Custom`. The concrete protocol is
    /// identified through the `protocol_name` metadata set in `probe`.
    fn supported_protocols(&self) -> Vec<ProtocolType> {
        vec![ProtocolType::Custom]
    }

    /// On a recognised CONNECT packet, builds a `ProtocolInfo` with fixed
    /// confidence 0.9, registers a copy as a candidate on `context`, and
    /// returns it. Any other input yields `Ok(None)`.
    fn probe(&self, data: &[u8], context: &mut ProbeContext) -> Result<Option<ProtocolInfo>> {
        if !self.is_mqtt_connect(data) {
            return Ok(None);
        }

        let mut info = ProtocolInfo::new(ProtocolType::Custom, 0.9);
        info.add_feature("MQTT-CONNECT");
        info.add_metadata("transport", "TCP");
        info.add_metadata("protocol_name", "MQTT");
        info.add_metadata("details", "MQTT CONNECT packet detected");

        context.add_candidate(info.clone());
        Ok(Some(info))
    }

    /// Returns the configured priority.
    fn priority(&self) -> u8 {
        self.priority
    }

    /// Asks for more input until at least 10 bytes are available.
    fn needs_more_data(&self, data: &[u8]) -> bool {
        let have = data.len();
        have < 10
    }
}
/// Demo 1: builds a detector with the HTTP and TLS probes plus a custom
/// [`DnsProbe`] (passive strategy, 100 ms timeout). It then runs the detector
/// on a synthetic DNS query and a plain HTTP request, printing each result.
///
/// # Errors
/// Propagates errors from building the detector or from either detection.
fn demonstrate_dns_plugin() -> Result<()> {
    println!(" 🔍 创建带有 DNS 插件的探测器");
    let detector = DetectorBuilder::new()
        .enable_http()
        .enable_tls()
        .enable_custom()
        .add_custom_probe(Box::new(DnsProbe::new()))
        .with_strategy(ProbeStrategy::Passive)
        .with_timeout(Duration::from_millis(100))
        .build()?;

    // --- DNS query ---
    let query = create_dns_query_packet();
    println!(" 📦 测试 DNS 查询包 ({} bytes)", query.len());
    let dns_outcome = detector.detect(&query)?;
    println!(" ✅ DNS 探测结果:");
    println!(" 协议类型: {:?}", dns_outcome.protocol_type());
    println!(" 置信度: {:.1}%", dns_outcome.confidence() * 100.0);
    if let Some(details) = dns_outcome.protocol_info.metadata.get("details") {
        println!(" 详情: {}", details);
    }
    println!(" 探测时间: {:?}", dns_outcome.detection_time);

    // --- HTTP request, for comparison ---
    let request = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n";
    println!("\n 📦 测试 HTTP 数据包 ({} bytes)", request.len());
    let http_outcome = detector.detect(request)?;
    println!(" ✅ HTTP 探测结果:");
    println!(" 协议类型: {:?}", http_outcome.protocol_type());
    println!(" 置信度: {:.1}%", http_outcome.confidence() * 100.0);

    Ok(())
}
/// Demo 2: builds a detector with the HTTP, TLS and SSH probes plus the custom
/// DNS and MQTT probes. It runs the detector on one sample packet per protocol
/// and prints the result for each.
///
/// Per-sample detection failures are printed and do not stop the loop.
///
/// # Errors
/// Propagates only a failure to build the detector.
fn demonstrate_multi_plugin_integration() -> Result<()> {
    println!(" 🔧 创建多插件集成探测器");
    let detector = DetectorBuilder::new()
        .enable_http()
        .enable_tls()
        .enable_ssh()
        .enable_custom()
        .add_custom_probe(Box::new(DnsProbe::new()))
        .add_custom_probe(Box::new(MqttProbe::new()))
        .with_strategy(ProbeStrategy::Passive)
        .build()?;

    let samples = [
        ("DNS 查询", create_dns_query_packet()),
        ("MQTT 连接", create_mqtt_connect_packet()),
        ("HTTP 请求", b"GET /api HTTP/1.1\r\nHost: test.com\r\n\r\n".to_vec()),
        ("SSH 握手", b"SSH-2.0-OpenSSH_8.0\r\n".to_vec()),
    ];

    for (label, bytes) in &samples {
        println!("\n 📦 测试 {} ({} bytes)", label, bytes.len());
        let outcome = match detector.detect(bytes) {
            Ok(outcome) => outcome,
            Err(err) => {
                println!(" ❌ 探测失败: {}", err);
                continue;
            }
        };
        println!(" ✅ 探测成功:");
        println!(" 协议: {:?}", outcome.protocol_type());
        println!(" 置信度: {:.1}%", outcome.confidence() * 100.0);
        println!(" 探测器: {}", outcome.detector_name);
        if let Some(details) = outcome.protocol_info.metadata.get("details") {
            println!(" 详情: {}", details);
        }
    }

    Ok(())
}
/// Demo 3: registers two otherwise identical DNS probes with priorities 30 and
/// 90 (the low-priority one first), runs the detector on a DNS query, and
/// prints which probe produced the result.
///
/// # Errors
/// Propagates errors from building the detector or from detection.
fn demonstrate_plugin_priority() -> Result<()> {
    println!(" ⚡ 测试插件优先级机制");

    let dns_probe_with = |name: &'static str, priority: u8| DnsProbe {
        name,
        priority,
        min_packet_size: 12,
    };
    let high = dns_probe_with("High-Priority-DNS", 90);
    let low = dns_probe_with("Low-Priority-DNS", 30);

    // Low priority is registered first, presumably to show that priority
    // rather than registration order picks the winner. TODO confirm against
    // DetectorBuilder's ordering rules.
    let detector = DetectorBuilder::new()
        .enable_http()
        .enable_custom()
        .add_custom_probe(Box::new(low))
        .add_custom_probe(Box::new(high))
        .build()?;

    let packet = create_dns_query_packet();
    let outcome = detector.detect(&packet)?;
    println!(" 📊 优先级测试结果:");
    println!(" 使用的探测器: {}", outcome.detector_name);
    println!(" 协议类型: {:?}", outcome.protocol_type());
    println!(" 置信度: {:.1}%", outcome.confidence() * 100.0);

    Ok(())
}
/// Builds a sample DNS query for `example.com` with QTYPE 1 (A) and
/// QCLASS 1 (IN). The result is 29 bytes.
fn create_dns_query_packet() -> Vec<u8> {
    // Header: ID 0x1234; flags 0x0100 (query, opcode 0, recursion desired);
    // QDCOUNT = 1; ANCOUNT = NSCOUNT = ARCOUNT = 0.
    const HEADER: [u8; 12] = [
        0x12, 0x34, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    ];
    let mut packet = HEADER.to_vec();

    // QNAME: length-prefixed labels, then a zero byte for the root.
    for label in ["example", "com"].iter() {
        packet.push(label.len() as u8);
        packet.extend_from_slice(label.as_bytes());
    }
    packet.push(0);

    // QTYPE and QCLASS.
    packet.extend_from_slice(&[0x00, 0x01, 0x00, 0x01]);
    packet
}
/// Builds a sample MQTT 3.1.1 CONNECT packet with client identifier `test`.
/// The result is 18 bytes.
fn create_mqtt_connect_packet() -> Vec<u8> {
    let mut body: Vec<u8> = Vec::with_capacity(16);
    // Variable header: protocol name "MQTT" (length-prefixed).
    body.extend_from_slice(&[0x00, 0x04]);
    body.extend_from_slice(b"MQTT");
    // Protocol level 4, connect flags 0x02 (clean session), keep-alive 0x003C (60 s).
    body.extend_from_slice(&[0x04, 0x02, 0x00, 0x3C]);
    // Payload: length-prefixed client identifier.
    body.extend_from_slice(&[0x00, 0x04]);
    body.extend_from_slice(b"test");

    // Fixed header: CONNECT type byte, then the remaining length. The body is
    // 16 bytes, so the length fits in a single byte.
    let mut packet = vec![0x10, body.len() as u8];
    packet.extend(body);
    packet
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default DNS probe reports its built-in name and priority.
    #[test]
    fn test_dns_probe_creation() {
        let dns = DnsProbe::new();
        let (name, priority) = (dns.name(), dns.priority());
        assert_eq!(name, "DNS-UDP-Probe");
        assert_eq!(priority, 60);
    }

    /// The sample query passes header validation and scores above 0.5.
    #[test]
    fn test_dns_packet_validation() {
        let dns = DnsProbe::new();
        let query = create_dns_query_packet();
        assert!(dns.validate_dns_header(&query));
        let score = dns.calculate_confidence(&query);
        assert!(score > 0.5);
    }

    /// The default MQTT probe reports its built-in name and priority.
    #[test]
    fn test_mqtt_probe_creation() {
        let mqtt = MqttProbe::new();
        let (name, priority) = (mqtt.name(), mqtt.priority());
        assert_eq!(name, "MQTT-TCP-Probe");
        assert_eq!(priority, 55);
    }

    /// The sample CONNECT packet is recognised.
    #[test]
    fn test_mqtt_packet_detection() {
        let mqtt = MqttProbe::new();
        let connect = create_mqtt_connect_packet();
        assert!(mqtt.is_mqtt_connect(&connect));
    }
}