Skip to main content

agentic_connect/engine/
protocol_detect.rs

1//! Protocol detection engine — probes hosts to identify protocols.
2
3use std::time::Duration;
4use tokio::net::TcpStream;
5use tokio::time::timeout;
6
7use crate::types::{ConnectResult, Protocol};
8
9/// Result of a protocol probe.
10#[derive(Debug, Clone, serde::Serialize)]
11pub struct ProbeResult {
12    pub host: String,
13    pub port: u16,
14    pub protocol: Option<Protocol>,
15    pub reachable: bool,
16    pub latency_ms: u64,
17    pub banner: Option<String>,
18    pub tls: bool,
19}
20
21/// Probe a host:port to detect the protocol by reading the server banner.
22pub async fn probe_host(host: &str, port: u16, timeout_ms: u64) -> ProbeResult {
23    let addr = format!("{}:{}", host, port);
24    let start = std::time::Instant::now();
25    let dur = Duration::from_millis(timeout_ms);
26
27    // Try TCP connect
28    let stream = match timeout(dur, TcpStream::connect(&addr)).await {
29        Ok(Ok(s)) => s,
30        _ => {
31            return ProbeResult {
32                host: host.into(), port, protocol: None,
33                reachable: false, latency_ms: start.elapsed().as_millis() as u64,
34                banner: None, tls: false,
35            };
36        }
37    };
38
39    let latency_ms = start.elapsed().as_millis() as u64;
40
41    // Try reading a banner (some protocols send one on connect)
42    let banner = read_banner(&stream, Duration::from_millis(2000)).await;
43
44    // Detect protocol from port + banner
45    let protocol = detect_from_banner_and_port(port, banner.as_deref());
46    let tls = matches!(protocol, Some(Protocol::Https | Protocol::Wss | Protocol::Sftp | Protocol::Imap));
47
48    ProbeResult {
49        host: host.into(), port, protocol, reachable: true,
50        latency_ms, banner, tls,
51    }
52}
53
54/// Probe multiple common ports on a host.
55pub async fn scan_host(host: &str, timeout_ms: u64) -> Vec<ProbeResult> {
56    let ports = [22, 80, 443, 3306, 5432, 6379, 8080, 8443, 9090, 27017];
57    let mut results = Vec::new();
58    for &port in &ports {
59        let r = probe_host(host, port, timeout_ms).await;
60        if r.reachable {
61            results.push(r);
62        }
63    }
64    results
65}
66
67async fn read_banner(stream: &TcpStream, timeout_dur: Duration) -> Option<String> {
68    use tokio::io::AsyncReadExt;
69    let mut buf = vec![0u8; 512];
70    match timeout(timeout_dur, stream.readable()).await {
71        Ok(Ok(())) => {
72            match stream.try_read(&mut buf) {
73                Ok(n) if n > 0 => {
74                    let s = String::from_utf8_lossy(&buf[..n]).to_string();
75                    Some(s.trim().to_string())
76                }
77                _ => None,
78            }
79        }
80        _ => None,
81    }
82}
83
84fn detect_from_banner_and_port(port: u16, banner: Option<&str>) -> Option<Protocol> {
85    // Banner-based detection first
86    if let Some(b) = banner {
87        let lower = b.to_lowercase();
88        if lower.starts_with("ssh-") { return Some(Protocol::Ssh); }
89        if lower.starts_with("220 ") && lower.contains("smtp") { return Some(Protocol::Smtp); }
90        if lower.starts_with("* ok") || lower.starts_with("* preauth") { return Some(Protocol::Imap); }
91        if lower.starts_with("+ok") { return Some(Protocol::Ftp); } // POP3 actually, but close
92        if lower.starts_with("220 ") && lower.contains("ftp") { return Some(Protocol::Ftp); }
93        if lower.starts_with("+pong") || lower.starts_with("-err") { return Some(Protocol::Redis); }
94        if lower.contains("mysql") { return Some(Protocol::Mysql); }
95        if lower.starts_with("http/") { return Some(Protocol::Http); }
96    }
97
98    // Port-based fallback
99    match port {
100        22 => Some(Protocol::Ssh),
101        80 | 8080 => Some(Protocol::Http),
102        443 | 8443 => Some(Protocol::Https),
103        21 => Some(Protocol::Ftp),
104        25 | 587 => Some(Protocol::Smtp),
105        993 => Some(Protocol::Imap),
106        53 => Some(Protocol::Dns),
107        1883 => Some(Protocol::Mqtt),
108        5672 => Some(Protocol::Amqp),
109        6379 => Some(Protocol::Redis),
110        5432 => Some(Protocol::Postgres),
111        3306 => Some(Protocol::Mysql),
112        _ => None,
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_ssh_banner() {
        let proto = detect_from_banner_and_port(22, Some("SSH-2.0-OpenSSH_8.9"));
        assert_eq!(proto, Some(Protocol::Ssh));
    }

    #[test]
    fn test_detect_redis_banner() {
        let proto = detect_from_banner_and_port(6379, Some("+PONG"));
        assert_eq!(proto, Some(Protocol::Redis));
    }

    #[test]
    fn test_detect_by_port_no_banner() {
        assert_eq!(detect_from_banner_and_port(443, None), Some(Protocol::Https));
        assert_eq!(detect_from_banner_and_port(5432, None), Some(Protocol::Postgres));
        assert_eq!(detect_from_banner_and_port(9999, None), None);
    }

    #[test]
    fn test_detect_greeting_banners() {
        assert_eq!(
            detect_from_banner_and_port(25, Some("220 mx.example.com ESMTP Postfix")),
            Some(Protocol::Smtp)
        );
        assert_eq!(
            detect_from_banner_and_port(21, Some("220 (vsFTPd 3.0.3)")),
            Some(Protocol::Ftp)
        );
        assert_eq!(
            detect_from_banner_and_port(143, Some("* OK [CAPABILITY IMAP4rev1] Dovecot ready.")),
            Some(Protocol::Imap)
        );
    }

    #[test]
    fn test_banner_overrides_port() {
        // An SSH server listening on an HTTP port is still identified as SSH.
        assert_eq!(
            detect_from_banner_and_port(80, Some("SSH-2.0-OpenSSH_9.6")),
            Some(Protocol::Ssh)
        );
    }

    #[test]
    fn test_unrecognised_banner_falls_back_to_port() {
        assert_eq!(detect_from_banner_and_port(8080, Some("hello")), Some(Protocol::Http));
        assert_eq!(detect_from_banner_and_port(9999, Some("hello")), None);
    }
}