agentic_connect/engine/
protocol_detect.rs1use std::time::Duration;
4use tokio::net::TcpStream;
5use tokio::time::timeout;
6
7use crate::types::{ConnectResult, Protocol};
8
9#[derive(Debug, Clone, serde::Serialize)]
11pub struct ProbeResult {
12 pub host: String,
13 pub port: u16,
14 pub protocol: Option<Protocol>,
15 pub reachable: bool,
16 pub latency_ms: u64,
17 pub banner: Option<String>,
18 pub tls: bool,
19}
20
21pub async fn probe_host(host: &str, port: u16, timeout_ms: u64) -> ProbeResult {
23 let addr = format!("{}:{}", host, port);
24 let start = std::time::Instant::now();
25 let dur = Duration::from_millis(timeout_ms);
26
27 let stream = match timeout(dur, TcpStream::connect(&addr)).await {
29 Ok(Ok(s)) => s,
30 _ => {
31 return ProbeResult {
32 host: host.into(), port, protocol: None,
33 reachable: false, latency_ms: start.elapsed().as_millis() as u64,
34 banner: None, tls: false,
35 };
36 }
37 };
38
39 let latency_ms = start.elapsed().as_millis() as u64;
40
41 let banner = read_banner(&stream, Duration::from_millis(2000)).await;
43
44 let protocol = detect_from_banner_and_port(port, banner.as_deref());
46 let tls = matches!(protocol, Some(Protocol::Https | Protocol::Wss | Protocol::Sftp | Protocol::Imap));
47
48 ProbeResult {
49 host: host.into(), port, protocol, reachable: true,
50 latency_ms, banner, tls,
51 }
52}
53
54pub async fn scan_host(host: &str, timeout_ms: u64) -> Vec<ProbeResult> {
56 let ports = [22, 80, 443, 3306, 5432, 6379, 8080, 8443, 9090, 27017];
57 let mut results = Vec::new();
58 for &port in &ports {
59 let r = probe_host(host, port, timeout_ms).await;
60 if r.reachable {
61 results.push(r);
62 }
63 }
64 results
65}
66
67async fn read_banner(stream: &TcpStream, timeout_dur: Duration) -> Option<String> {
68 use tokio::io::AsyncReadExt;
69 let mut buf = vec![0u8; 512];
70 match timeout(timeout_dur, stream.readable()).await {
71 Ok(Ok(())) => {
72 match stream.try_read(&mut buf) {
73 Ok(n) if n > 0 => {
74 let s = String::from_utf8_lossy(&buf[..n]).to_string();
75 Some(s.trim().to_string())
76 }
77 _ => None,
78 }
79 }
80 _ => None,
81 }
82}
83
84fn detect_from_banner_and_port(port: u16, banner: Option<&str>) -> Option<Protocol> {
85 if let Some(b) = banner {
87 let lower = b.to_lowercase();
88 if lower.starts_with("ssh-") { return Some(Protocol::Ssh); }
89 if lower.starts_with("220 ") && lower.contains("smtp") { return Some(Protocol::Smtp); }
90 if lower.starts_with("* ok") || lower.starts_with("* preauth") { return Some(Protocol::Imap); }
91 if lower.starts_with("+ok") { return Some(Protocol::Ftp); } if lower.starts_with("220 ") && lower.contains("ftp") { return Some(Protocol::Ftp); }
93 if lower.starts_with("+pong") || lower.starts_with("-err") { return Some(Protocol::Redis); }
94 if lower.contains("mysql") { return Some(Protocol::Mysql); }
95 if lower.starts_with("http/") { return Some(Protocol::Http); }
96 }
97
98 match port {
100 22 => Some(Protocol::Ssh),
101 80 | 8080 => Some(Protocol::Http),
102 443 | 8443 => Some(Protocol::Https),
103 21 => Some(Protocol::Ftp),
104 25 | 587 => Some(Protocol::Smtp),
105 993 => Some(Protocol::Imap),
106 53 => Some(Protocol::Dns),
107 1883 => Some(Protocol::Mqtt),
108 5672 => Some(Protocol::Amqp),
109 6379 => Some(Protocol::Redis),
110 5432 => Some(Protocol::Postgres),
111 3306 => Some(Protocol::Mysql),
112 _ => None,
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_ssh_banner() {
        let proto = detect_from_banner_and_port(22, Some("SSH-2.0-OpenSSH_8.9"));
        assert_eq!(proto, Some(Protocol::Ssh));
    }

    #[test]
    fn test_detect_redis_banner() {
        let proto = detect_from_banner_and_port(6379, Some("+PONG"));
        assert_eq!(proto, Some(Protocol::Redis));
    }

    #[test]
    fn test_detect_by_port_no_banner() {
        assert_eq!(detect_from_banner_and_port(443, None), Some(Protocol::Https));
        assert_eq!(detect_from_banner_and_port(5432, None), Some(Protocol::Postgres));
        assert_eq!(detect_from_banner_and_port(9999, None), None);
    }

    /// Greetings on non-standard ports must be recognised from the banner alone.
    #[test]
    fn test_detect_text_protocol_greetings() {
        let cases = [
            (2525, "220 mx.example.com ESMTP Postfix", Protocol::Smtp),
            (1143, "* OK IMAP4rev1 Service Ready", Protocol::Imap),
            (2121, "220 (vsFTPd 3.0.3)", Protocol::Ftp),
        ];
        for (port, banner, expected) in cases {
            assert_eq!(
                detect_from_banner_and_port(port, Some(banner)),
                Some(expected),
                "banner {:?} on port {}",
                banner,
                port
            );
        }
    }

    /// An unrecognised greeting must not mask the well-known-port fallback.
    #[test]
    fn test_unrecognised_banner_falls_back_to_port() {
        assert_eq!(
            detect_from_banner_and_port(5432, Some("garbage")),
            Some(Protocol::Postgres)
        );
        assert_eq!(detect_from_banner_and_port(9999, Some("hello")), None);
    }
}