anti_scan/
lib.rs

1use anti_common::{resolve_hostnames, PingError, PingResult};
2use anti_ping::icmp::IcmpPinger;
3use anti_ping::PingConfig;
4use futures::future::BoxFuture;
5use futures::stream::{FuturesUnordered, StreamExt};
6use ipnet::IpNet;
7use std::collections::HashMap;
8use std::io::Write;
9use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::net::{TcpStream, UdpSocket};
13use tokio::sync::Semaphore;
14use tokio::time::timeout;
15
/// Maximum number of network probes that `scan_targets` keeps in flight at once.
const CONCURRENT_LIMIT: usize = 512;
17
18/// Parse a port range string like "20-80" into a list of ports.
19pub fn parse_port_range(range: &str) -> PingResult<Vec<u16>> {
20    if !range.contains('-') {
21        // Single port
22        let port: u16 = range.parse().map_err(|_| PingError::Configuration {
23            message: format!("Invalid port: {}", range),
24        })?;
25        if port == 0 {
26            return Err(PingError::Configuration {
27                message: "Port cannot be 0".into(),
28            });
29        }
30        return Ok(vec![port]);
31    }
32
33    let parts: Vec<&str> = range.split('-').collect();
34    if parts.len() != 2 {
35        return Err(PingError::Configuration {
36            message: format!("Invalid port range: {}", range),
37        });
38    }
39    let start: u16 = parts[0].parse().map_err(|_| PingError::Configuration {
40        message: format!("Invalid port: {}", parts[0]),
41    })?;
42    let end: u16 = parts[1].parse().map_err(|_| PingError::Configuration {
43        message: format!("Invalid port: {}", parts[1]),
44    })?;
45    if start == 0 || start > end {
46        return Err(PingError::Configuration {
47            message: format!("Invalid port range: {}", range),
48        });
49    }
50    Ok((start..=end).collect())
51}
52
53/// Parse a single IP address or CIDR range into a list of IP addresses.
54pub fn parse_ip_range(target: &str) -> PingResult<Vec<IpAddr>> {
55    if target.contains('/') {
56        let net: IpNet = target.parse().map_err(|_| PingError::Configuration {
57            message: format!("Invalid CIDR notation: {}", target),
58        })?;
59        Ok(net.hosts().collect())
60    } else {
61        resolve_hostnames(target)
62    }
63}
64
/// Well-known service names keyed by port, used to annotate scan output.
fn service_map() -> HashMap<u16, &'static str> {
    const KNOWN_SERVICES: [(u16, &str); 13] = [
        (20, "FTP"),
        (21, "FTP"),
        (22, "SSH"),
        (23, "Telnet"),
        (25, "SMTP"),
        (53, "DNS"),
        (80, "HTTP"),
        (110, "POP3"),
        (143, "IMAP"),
        (443, "HTTPS"),
        (3306, "MySQL"),
        (5432, "Postgres"),
        (6379, "Redis"),
    ];
    KNOWN_SERVICES.iter().copied().collect()
}
82
/// UDP ports probed when "common" ports are requested for UDP scans.
fn common_udp_ports() -> Vec<u16> {
    const PORTS: [u16; 8] = [
        53,  // DNS
        67,  // DHCP (server)
        68,  // DHCP (client)
        69,  // TFTP
        123, // NTP
        161, // SNMP
        500, // IKE
        514, // Syslog
    ];
    PORTS.to_vec()
}
94
/// Build the probe datagram sent to `port` during a UDP scan.
///
/// For DNS (53), NTP (123) and SNMP (161) a minimal well-formed request is
/// sent, so a listening service has something it can answer. Every other port
/// gets a single zero byte.
fn udp_payload(port: u16) -> Vec<u8> {
    match port {
        53 => vec![
            // Header: ID=0, flags=RD, QDCOUNT=1, ANCOUNT=0, NSCOUNT=0, ARCOUNT=0.
            0x00, 0x00, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
            // Question: QNAME = root ("."), QTYPE = A (1), QCLASS = IN (1).
            // QTYPE/QCLASS are mandatory; without them the header promises a
            // question that is truncated, i.e. a malformed query.
            0x00, 0x00, 0x01, 0x00, 0x01,
        ],
        123 => {
            // 48-byte NTP packet with LI=0, VN=3, Mode=3 (client).
            let mut pkt = vec![0; 48];
            pkt[0] = 0x1b;
            pkt
        }
        // SNMPv2c GetRequest, community "public", single varbind for OID
        // 1.3.6.1.2.1 with a NULL value.
        161 => vec![
            0x30, 0x26, 0x02, 0x01, 0x01, 0x04, 0x06, b'p', b'u', b'b', b'l', b'i', b'c', 0xa0,
            0x19, 0x02, 0x04, 0x70, 0x4b, 0x3a, 0x7f, 0x02, 0x01, 0x00, 0x02, 0x01, 0x00, 0x30,
            0x0b, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x06, 0x01, 0x02, 0x01, 0x05, 0x00,
        ],
        _ => vec![0],
    }
}
116
117async fn detect_os(ip: IpAddr, timeout: Duration) -> Option<String> {
118    if let IpAddr::V4(v4) = ip {
119        let config = PingConfig {
120            target: v4,
121            count: 1,
122            timeout,
123            ..Default::default()
124        };
125        if let Ok(pinger) = IcmpPinger::new(config) {
126            if let Ok(reply) = pinger.ping(1) {
127                if let Some(ttl) = reply.ttl {
128                    return Some(match ttl {
129                        0..=64 => "unix".to_string(),
130                        65..=128 => "windows".to_string(),
131                        129..=255 => "network".to_string(),
132                    });
133                }
134            }
135        }
136    }
137    None
138}
139
140/// Scan the given IP address for an open TCP port.
141async fn scan_tcp_port(ip: IpAddr, port: u16, timeout_dur: Duration) -> Option<u16> {
142    let addr = SocketAddr::new(ip, port);
143    match timeout(timeout_dur, TcpStream::connect(addr)).await {
144        Ok(Ok(_stream)) => Some(port),
145        _ => None,
146    }
147}
148
149/// Scan the given IPv4 address for an open UDP port.
150///
151/// This implementation is intentionally conservative and only reports a port as
152/// open when a datagram is successfully received back from the target. This
153/// avoids the large number of false positives the previous naive check
154/// produced.
155async fn scan_udp_port(ip: IpAddr, port: u16, timeout_dur: Duration) -> Option<u16> {
156    let addr = SocketAddr::new(ip, port);
157    let bind_addr = match ip {
158        IpAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
159        IpAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
160    };
161    if let Ok(sock) = UdpSocket::bind(bind_addr).await {
162        if sock.connect(addr).await.is_ok() && sock.send(&udp_payload(port)).await.is_ok() {
163            let mut buf = [0u8; 512];
164            match timeout(timeout_dur, sock.recv(&mut buf)).await {
165                Ok(Ok(n)) if n > 0 => return Some(port),
166                Ok(Err(e)) if e.kind() == std::io::ErrorKind::ConnectionRefused => return None,
167                _ => {}
168            }
169        }
170    }
171    None
172}
173
/// Transport protocol(s) to probe during a scan.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScanProtocol {
    /// TCP connect probes only.
    Tcp,
    /// UDP request/response probes only.
    Udp,
    /// Probe every port with both TCP and UDP.
    Both,
}
181
182pub fn common_ports_for(proto: ScanProtocol) -> Vec<u16> {
183    let mut ports = Vec::new();
184    if matches!(proto, ScanProtocol::Tcp | ScanProtocol::Both) {
185        ports.extend(anti_ping::tcp::get_common_tcp_ports());
186    }
187    if matches!(proto, ScanProtocol::Udp | ScanProtocol::Both) {
188        ports.extend(common_udp_ports());
189    }
190    ports
191}
192
193/// Scan a host or CIDR range for open ports using the specified protocol.
194pub async fn scan_targets(
195    target: &str,
196    ports: &[u16],
197    proto: ScanProtocol,
198    timeout: Duration,
199    progress: bool,
200    os_detect: bool,
201) -> PingResult<(Vec<(IpAddr, u16, ScanProtocol)>, HashMap<IpAddr, String>)> {
202    let ips = parse_ip_range(target)?;
203    let svc_map = service_map();
204    let semaphore = Arc::new(Semaphore::new(CONCURRENT_LIMIT));
205    let mut os_map = HashMap::new();
206
207    enum TaskResult {
208        Port(IpAddr, u16, ScanProtocol, bool),
209    }
210
211    let mut futures: FuturesUnordered<BoxFuture<'static, TaskResult>> = FuturesUnordered::new();
212
213    for &ip in &ips {
214        for &port in ports {
215            if matches!(proto, ScanProtocol::Tcp | ScanProtocol::Both) {
216                let sem = semaphore.clone();
217                futures.push(Box::pin(async move {
218                    let _permit = sem.acquire().await.unwrap();
219                    let open = scan_tcp_port(ip, port, timeout).await.is_some();
220                    TaskResult::Port(ip, port, ScanProtocol::Tcp, open)
221                }));
222            }
223            if matches!(proto, ScanProtocol::Udp | ScanProtocol::Both) {
224                let sem = semaphore.clone();
225                futures.push(Box::pin(async move {
226                    let _permit = sem.acquire().await.unwrap();
227                    let open = scan_udp_port(ip, port, timeout).await.is_some();
228                    TaskResult::Port(ip, port, ScanProtocol::Udp, open)
229                }));
230            }
231        }
232    }
233
234    let mut spinner_idx = 0usize;
235    let spinner = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'];
236    let total = futures.len();
237    let bar_width = 20usize;
238    let mut completed = 0usize;
239    let mut results = Vec::new();
240
241    while let Some(item) = futures.next().await {
242        completed += 1;
243        match item {
244            TaskResult::Port(ip, port, proto, open) => {
245                if open {
246                    let svc = svc_map.get(&port).copied().unwrap_or("");
247                    let label = match proto {
248                        ScanProtocol::Tcp => "tcp",
249                        ScanProtocol::Udp => "udp",
250                        ScanProtocol::Both => "both",
251                    };
252                    print!("\r{:<80}\r", "");
253                    println!("{}:{}/{} {}", ip, port, label, svc);
254                    results.push((ip, port, proto));
255                }
256
257                if progress {
258                    let filled = bar_width * completed / total;
259                    let bar = format!("[{}{}]", "█".repeat(filled), "░".repeat(bar_width - filled));
260                    let text = format!(" {}:{} {}/{}", ip, port, completed, total);
261                    let line = format!("{} {}", spinner[spinner_idx % spinner.len()], bar);
262                    let out = format!("{}{}", line, text);
263                    print!("\r{:<60}", out);
264                    std::io::stdout().flush().ok();
265                    spinner_idx += 1;
266                }
267            }
268        }
269    }
270
271    if progress {
272        print!("\r{:<80}\r", "");
273    }
274
275    results.sort_by_key(|k| (k.0, k.1));
276
277    let mut unique_ips = std::collections::HashSet::new();
278    for (ip, _, _) in &results {
279        unique_ips.insert(*ip);
280    }
281
282    if os_detect && !unique_ips.is_empty() {
283        let mut os_futures: FuturesUnordered<BoxFuture<'static, (IpAddr, Option<String>)>> =
284            FuturesUnordered::new();
285        for ip in unique_ips {
286            let sem = semaphore.clone();
287            os_futures.push(Box::pin(async move {
288                let _permit = sem.acquire().await.unwrap();
289                let os = detect_os(ip, timeout).await;
290                (ip, os)
291            }));
292        }
293
294        let total_os = os_futures.len();
295        let mut completed_os = 0usize;
296        spinner_idx = 0;
297
298        while let Some((ip, os)) = os_futures.next().await {
299            completed_os += 1;
300            if let Some(name) = os {
301                os_map.insert(ip, name);
302            }
303            if progress {
304                let filled = bar_width * completed_os / total_os;
305                let bar = format!("[{}{}]", "█".repeat(filled), "░".repeat(bar_width - filled));
306                let text = format!(" {} OS {}/{}", ip, completed_os, total_os);
307                let line = format!("{} {}", bar, spinner[spinner_idx % spinner.len()]);
308                let out = format!("{}{}", line, text);
309                print!("\r{:<80}", out);
310                std::io::stdout().flush().ok();
311                spinner_idx += 1;
312            }
313        }
314
315        if progress {
316            print!("\r{:<80}\r", "");
317        }
318    }
319
320    if progress {
321        println!("\n\nResults:");
322        println!(
323            "{:<39} {:<6} {:<4} {:<8} {}",
324            "IP", "PORT", "PROTO", "SERVICE", "OS"
325        );
326        for (ip, port, proto) in &results {
327            let svc = svc_map.get(port).copied().unwrap_or("");
328            let label = match proto {
329                ScanProtocol::Tcp => "tcp",
330                ScanProtocol::Udp => "udp",
331                ScanProtocol::Both => "both",
332            };
333            let os = os_map.get(ip).cloned().unwrap_or_default();
334            println!("{:<39} {:<6} {:<4} {:<8} {}", ip, port, label, svc, os);
335        }
336    }
337
338    Ok((results, os_map))
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{TcpListener, UdpSocket};
    use tokio::runtime::Runtime;

    /// 127.0.0.1 as an `IpAddr`.
    fn localhost() -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))
    }

    #[test]
    fn test_parse_port_range() {
        let ports = parse_port_range("20-22").unwrap();
        assert_eq!(ports, vec![20, 21, 22]);
    }

    #[test]
    fn test_parse_port_single() {
        let ports = parse_port_range("80").unwrap();
        assert_eq!(ports, vec![80]);
    }

    #[test]
    fn test_parse_port_range_rejects_invalid_input() {
        for bad in ["0", "0-5", "22-20", "1-2-3", "abc", "80-", "-80", "70000"] {
            assert!(
                parse_port_range(bad).is_err(),
                "{:?} should be rejected",
                bad
            );
        }
    }

    #[test]
    fn test_parse_ip_range_cidr() {
        let ips = parse_ip_range("10.0.0.0/30").unwrap();
        let expected: Vec<IpAddr> = vec!["10.0.0.1".parse().unwrap(), "10.0.0.2".parse().unwrap()];
        assert_eq!(ips, expected);
    }

    #[test]
    fn test_common_ports_for_udp_matches_udp_list() {
        assert_eq!(common_ports_for(ScanProtocol::Udp), common_udp_ports());
    }

    #[test]
    fn test_scan_tcp_finds_open_port() {
        let rt = Runtime::new().unwrap();
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        let found = rt.block_on(scan_tcp_port(localhost(), port, Duration::from_secs(1)));
        assert_eq!(found, Some(port));
    }

    #[test]
    fn test_scan_targets_range_tcp() {
        let rt = Runtime::new().unwrap();
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        let (results, _os) = rt
            .block_on(scan_targets(
                "127.0.0.0/30",
                &[port],
                ScanProtocol::Tcp,
                Duration::from_secs(1),
                false,
                false,
            ))
            .unwrap();
        assert!(results.contains(&(localhost(), port, ScanProtocol::Tcp)));
    }

    #[test]
    fn test_scan_udp_silent_port_not_reported() {
        let rt = Runtime::new().unwrap();
        // Bound but never replies: the conservative scanner must not report it.
        let sock = UdpSocket::bind("127.0.0.1:0").unwrap();
        let port = sock.local_addr().unwrap().port();
        let result = rt.block_on(scan_udp_port(localhost(), port, Duration::from_millis(100)));
        assert_eq!(result, None);
    }

    #[test]
    fn test_scan_udp_detects_responding_service() {
        let rt = Runtime::new().unwrap();
        let server = UdpSocket::bind("127.0.0.1:0").unwrap();
        // Bound the echo thread's lifetime in case no probe ever arrives.
        server
            .set_read_timeout(Some(Duration::from_secs(2)))
            .unwrap();
        let port = server.local_addr().unwrap().port();
        let echo = std::thread::spawn(move || {
            let mut buf = [0u8; 512];
            if let Ok((n, peer)) = server.recv_from(&mut buf) {
                let _ = server.send_to(&buf[..n], peer);
            }
        });
        let result = rt.block_on(scan_udp_port(localhost(), port, Duration::from_secs(1)));
        echo.join().unwrap();
        assert_eq!(result, Some(port));
    }
}