use psi_detector::{
DetectorBuilder, ProtocolDetector, ProtocolType,
core::ProbeStrategy
};
use std::time::Duration;
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("🔄 PSI-Detector 协议升级示例");
let detector = DetectorBuilder::new()
.enable_http()
.enable_http2()
.enable_http3()
.enable_tls()
.enable_ssh()
.with_strategy(ProbeStrategy::Passive)
.with_timeout(Duration::from_millis(200))
.build()?;
let test_scenarios = vec![
TestScenario {
name: "HTTP/1.1 升级到 H2C 请求",
data: b"GET / HTTP/1.1\r\nHost: example.com\r\nUpgrade: h2c\r\nConnection: Upgrade\r\nHTTP2-Settings: AAMAAABkAARAAAAAAAIAAAAA\r\n\r\n".to_vec(),
expected_protocol: ProtocolType::HTTP1_1,
},
TestScenario {
name: "HTTP/2 连接前言",
data: b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec(),
expected_protocol: ProtocolType::HTTP2,
},
TestScenario {
name: "HTTP/3 over QUIC 连接",
data: vec![
0xc0, 0x00, 0x00, 0x00, 0x01, 0x08, 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x00, 0x44, 0x01, 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x00, 0x06, 0x00, 0x3c, 0x01, 0x00, 0x00, 0x38, 0x03, 0x03, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10,
0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
0x00, 0x00, 0x02, 0x13, 0x01, 0x01, 0x00, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, 0x68, 0x33, ],
expected_protocol: ProtocolType::HTTP3,
},
TestScenario {
name: "WebSocket 升级请求",
data: b"GET /chat HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n".to_vec(),
expected_protocol: ProtocolType::HTTP1_1,
},
TestScenario {
name: "TLS 握手",
data: vec![
0x16, 0x03, 0x01, 0x00, 0x2f, 0x01, 0x00, 0x00, 0x2b, 0x03, 0x03, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10,
],
expected_protocol: ProtocolType::TLS,
},
TestScenario {
name: "SSH 连接",
data: b"SSH-2.0-OpenSSH_8.0\r\n".to_vec(),
expected_protocol: ProtocolType::SSH,
},
];
println!("\n🧪 开始协议探测测试...");
for scenario in &test_scenarios {
println!("\n📋 测试场景: {}", scenario.name);
match detector.detect(&scenario.data) {
Ok(result) => {
println!(" 探测结果: {:?} (置信度: {:.2})",
result.protocol_info.protocol_type, result.confidence());
if result.protocol_info.protocol_type == scenario.expected_protocol {
println!(" ✅ 协议识别正确!");
} else {
println!(" ⚠️ 协议识别不匹配,期望: {:?}", scenario.expected_protocol);
}
if scenario.name.contains("升级") {
let data_str = String::from_utf8_lossy(&scenario.data);
if data_str.contains("Upgrade:") {
println!(" 🔄 检测到升级请求头");
}
}
},
Err(e) => {
println!(" ❌ 探测失败: {}", e);
}
}
}
println!("\n🎉 协议探测示例完成!");
Ok(())
}
#[derive(Debug)]
struct TestScenario {
name: &'static str,
data: Vec<u8>,
expected_protocol: ProtocolType,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_http_to_websocket_upgrade() {
let mut upgrade_manager = UpgradeManager::default();
let websocket_request = b"GET /chat HTTP/1.1\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: test\r\n\r\n";
let result = upgrade_manager.upgrade_protocol(
ProtocolType::HTTP1_1,
ProtocolType::WebSocket,
websocket_request,
);
match result {
Ok(upgrade_result) => {
assert_eq!(upgrade_result.target_protocol, ProtocolType::WebSocket);
}
Err(_) => {
println!("WebSocket 升级器未实现,这是正常的");
}
}
}
}