use psi_detector::{
DetectorBuilder, ProtocolDetector, ProtocolType,
core::ProbeStrategy
};
use std::time::{Duration, Instant};
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("⚡ PSI-Detector SIMD 高性能探测示例");
let simd_detector = DetectorBuilder::new()
.enable_http()
.enable_http2()
.enable_http3() .enable_tls()
.enable_ssh()
.with_strategy(ProbeStrategy::Passive)
.with_timeout(Duration::from_millis(50))
.build()?;
let standard_detector = DetectorBuilder::new()
.enable_http()
.enable_http2()
.enable_http3() .enable_tls()
.enable_ssh()
.with_strategy(ProbeStrategy::Passive)
.with_timeout(Duration::from_millis(50))
.build()?;
let test_data = generate_test_data();
println!("\n🧪 性能测试开始...");
println!("测试数据量: {} 个样本", test_data.len());
println!("\n⚡ SIMD 加速探测:");
let simd_results = benchmark_detection(&simd_detector, &test_data, "SIMD");
println!("\n🔧 标准探测:");
let standard_results = benchmark_detection(&standard_detector, &test_data, "Standard");
println!("\n📊 性能对比报告:");
let speedup = standard_results.total_time.as_nanos() as f64 / simd_results.total_time.as_nanos() as f64;
println!(" 🚀 SIMD 加速倍数: {:.2}x", speedup);
println!(" ⏱️ 平均延迟减少: {:.2}μs",
(standard_results.avg_latency.as_nanos() as f64 - simd_results.avg_latency.as_nanos() as f64) / 1000.0);
println!(" 🎯 准确率对比: SIMD {:.1}% vs Standard {:.1}%",
simd_results.accuracy * 100.0, standard_results.accuracy * 100.0);
println!("\n📈 协议探测分布:");
for (protocol, count) in &simd_results.protocol_stats {
println!(" {:?}: {} 次", protocol, count);
}
println!("\n🎉 SIMD 性能测试完成!");
Ok(())
}
#[derive(Debug)]
struct BenchmarkResult {
total_time: Duration,
avg_latency: Duration,
accuracy: f64,
protocol_stats: std::collections::HashMap<ProtocolType, usize>,
}
fn benchmark_detection(
detector: &dyn ProtocolDetector,
test_data: &[(String, Vec<u8>)],
name: &str,
) -> BenchmarkResult {
let mut total_time = Duration::new(0, 0);
let mut successful_detections = 0;
let mut protocol_stats = std::collections::HashMap::new();
let start_time = Instant::now();
println!("\n🔍 {} 详细检测结果:", name);
for (expected_protocol, data) in test_data {
let detection_start = Instant::now();
match detector.detect(data) {
Ok(result) => {
let detection_time = detection_start.elapsed();
total_time += detection_time;
*protocol_stats.entry(result.protocol_type()).or_insert(0) += 1;
let is_correct = is_correct_detection(expected_protocol, result.protocol_type());
if is_correct {
successful_detections += 1;
}
let status = if is_correct { "✅" } else { "❌" };
println!(" {} {} -> {:?} ({:.1}%) [预期: {}]",
status,
expected_protocol,
result.protocol_type(),
result.confidence() * 100.0,
expected_protocol);
}
Err(e) => {
let is_correct = expected_protocol == "Unknown";
let status = if is_correct { "✅" } else { "❌" };
println!(" {} {} -> 检测失败: {:?}", status, expected_protocol, e);
if is_correct {
successful_detections += 1;
}
*protocol_stats.entry(ProtocolType::Unknown).or_insert(0) += 1;
}
}
}
let total_benchmark_time = start_time.elapsed();
let avg_latency = total_time / test_data.len() as u32;
let accuracy = successful_detections as f64 / test_data.len() as f64;
println!(" ⏱️ 总时间: {:?}", total_benchmark_time);
println!(" 📊 平均延迟: {:?}", avg_latency);
println!(" 🎯 准确率: {:.1}%", accuracy * 100.0);
println!(" 📈 吞吐量: {:.0} 检测/秒",
test_data.len() as f64 / total_benchmark_time.as_secs_f64());
BenchmarkResult {
total_time: total_benchmark_time,
avg_latency,
accuracy,
protocol_stats,
}
}
fn generate_test_data() -> Vec<(String, Vec<u8>)> {
vec![
("HTTP1_1".to_string(), b"GET /api/data HTTP/1.1\r\nHost: api.example.com\r\nAccept: application/json\r\nUser-Agent: TestClient/1.0\r\n\r\n".to_vec()),
("HTTP1_1".to_string(), b"POST /submit HTTP/1.1\r\nContent-Type: application/json\r\nContent-Length: 25\r\nAuthorization: Bearer token123\r\n\r\n{\"key\": \"value\"}".to_vec()),
("HTTP2".to_string(), {
let mut data = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec(); data.extend_from_slice(&[0x00, 0x00, 0x12, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00]); data.extend_from_slice(&[0x00, 0x03, 0x00, 0x00, 0x00, 0x64]); data
}),
("HTTP3".to_string(), {
let mut data = vec![
0xc0, 0x00, 0x00, 0x00, 0x01, 0x08, 0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x40, 0x74, 0x01, 0x00, 0x00, 0xed, 0x03, ];
data.extend_from_slice(&[0x03, 0x68, 0x33, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]); data
}),
("TLS".to_string(), {
let mut data = vec![
0x16, 0x03, 0x01, 0x00, 0x2f, 0x01, 0x00, 0x00, 0x2b, 0x03, 0x03, ];
data.extend_from_slice(&[0x00; 21]); data
}),
("SSH".to_string(), b"SSH-2.0-OpenSSH_8.9 Ubuntu-3ubuntu0.1\r\n".to_vec()),
("SSH".to_string(), b"SSH-1.99-Cisco-1.25 (protocol 2.0)\r\n".to_vec()),
("WebSocket".to_string(), b"GET /chat HTTP/1.1\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n".to_vec()),
("QUIC".to_string(), {
let mut data = vec![0xc0, 0x00, 0x00, 0x00, 0x01]; data.extend_from_slice(&[0x08, 0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57]); data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]); data
}),
("Unknown".to_string(), {
let mut data = vec![0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07];
data.extend_from_slice(&[0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11]); data
}),
("Unknown".to_string(), {
let mut data = vec![0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA];
data.extend_from_slice(&[0xF9, 0xF8, 0xF7, 0xF6, 0xF5, 0xF4, 0xF3, 0xF2, 0xF1, 0xF0, 0xEF, 0xEE]); data
}),
]
}
fn is_correct_detection(expected: &str, detected: ProtocolType) -> bool {
match (expected, detected) {
("HTTP1_1", ProtocolType::HTTP1_1) => true,
("HTTP2", ProtocolType::HTTP2) => true,
("HTTP3", ProtocolType::HTTP3) => true,
("HTTP3", ProtocolType::QUIC) => true, ("TLS", ProtocolType::TLS) => true,
("SSH", ProtocolType::SSH) => true,
("WebSocket", ProtocolType::WebSocket) => true,
("WebSocket", ProtocolType::HTTP1_1) => true, ("QUIC", ProtocolType::QUIC) => true,
("QUIC", ProtocolType::HTTP3) => true, ("Unknown", ProtocolType::Unknown) => true,
("Unknown", ProtocolType::Unknown) => true,
_ => false,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_simd_vs_standard_performance() {
let simd_detector = DetectorBuilder::new()
.enable_http()
.enable_http2()
.enable_http3() .build()
.expect("Failed to build SIMD detector");
let standard_detector = DetectorBuilder::new()
.enable_http()
.enable_http2()
.enable_http3() .build()
.expect("Failed to build standard detector");
let test_data = vec![
("HTTP1_1".to_string(), b"GET / HTTP/1.1\r\n\r\n".to_vec()),
("HTTP3".to_string(), vec![
0xc0, 0x00, 0x00, 0x00, 0x01, 0x08, 0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x03, 0x68, 0x33, ]),
];
let simd_result = benchmark_detection(&simd_detector, &test_data, "SIMD");
let standard_result = benchmark_detection(&standard_detector, &test_data, "Standard");
let performance_ratio = simd_result.total_time.as_nanos() as f64 / standard_result.total_time.as_nanos() as f64;
assert!(performance_ratio <= 3.0, "SIMD performance ratio: {:.2}x", performance_ratio); }
#[test]
fn test_http3_detection_accuracy() {
let detector = DetectorBuilder::new()
.enable_http3()
.build()
.expect("Failed to build HTTP/3 detector");
let http3_data = vec![
0xc0, 0x00, 0x00, 0x00, 0x01, 0x08, 0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x03, 0x68, 0x33, ];
match detector.detect(&http3_data) {
Ok(result) => {
assert!(matches!(result.protocol_type(), ProtocolType::HTTP3 | ProtocolType::QUIC));
assert!(result.confidence() > 0.5);
}
Err(_) => {
println!("HTTP/3 detection failed, which may be expected for simplified test data");
}
}
}
}