1use anti_common::{resolve_hostnames, PingError, PingResult};
2use anti_ping::icmp::IcmpPinger;
3use anti_ping::PingConfig;
4use futures::future::BoxFuture;
5use futures::stream::{FuturesUnordered, StreamExt};
6use ipnet::IpNet;
7use std::collections::HashMap;
8use std::io::Write;
9use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::net::{TcpStream, UdpSocket};
13use tokio::sync::Semaphore;
14use tokio::time::timeout;
15
/// Upper bound on probes in flight at once (TCP/UDP port probes and
/// OS-detection pings alike); `scan_targets` sizes its shared semaphore with it.
const CONCURRENT_LIMIT: usize = 512;
17
18pub fn parse_port_range(range: &str) -> PingResult<Vec<u16>> {
20 if !range.contains('-') {
21 let port: u16 = range.parse().map_err(|_| PingError::Configuration {
23 message: format!("Invalid port: {}", range),
24 })?;
25 if port == 0 {
26 return Err(PingError::Configuration {
27 message: "Port cannot be 0".into(),
28 });
29 }
30 return Ok(vec![port]);
31 }
32
33 let parts: Vec<&str> = range.split('-').collect();
34 if parts.len() != 2 {
35 return Err(PingError::Configuration {
36 message: format!("Invalid port range: {}", range),
37 });
38 }
39 let start: u16 = parts[0].parse().map_err(|_| PingError::Configuration {
40 message: format!("Invalid port: {}", parts[0]),
41 })?;
42 let end: u16 = parts[1].parse().map_err(|_| PingError::Configuration {
43 message: format!("Invalid port: {}", parts[1]),
44 })?;
45 if start == 0 || start > end {
46 return Err(PingError::Configuration {
47 message: format!("Invalid port range: {}", range),
48 });
49 }
50 Ok((start..=end).collect())
51}
52
53pub fn parse_ip_range(target: &str) -> PingResult<Vec<IpAddr>> {
55 if target.contains('/') {
56 let net: IpNet = target.parse().map_err(|_| PingError::Configuration {
57 message: format!("Invalid CIDR notation: {}", target),
58 })?;
59 Ok(net.hosts().collect())
60 } else {
61 resolve_hostnames(target)
62 }
63}
64
/// Well-known service names keyed by port number, used to label open ports in
/// scan output.
///
/// The map is protocol-agnostic. Besides the common TCP services it labels
/// every port in `common_udp_ports` (DHCP, TFTP, NTP, SNMP, IKE, syslog).
/// Without those entries, UDP hits were printed with an empty service column.
fn service_map() -> HashMap<u16, &'static str> {
    HashMap::from([
        (20, "FTP"),
        (21, "FTP"),
        (22, "SSH"),
        (23, "Telnet"),
        (25, "SMTP"),
        (53, "DNS"),
        (67, "DHCP"),
        (68, "DHCP"),
        (69, "TFTP"),
        (80, "HTTP"),
        (110, "POP3"),
        (123, "NTP"),
        (143, "IMAP"),
        (161, "SNMP"),
        (443, "HTTPS"),
        (500, "IKE"),
        (514, "Syslog"),
        (3306, "MySQL"),
        (5432, "Postgres"),
        (6379, "Redis"),
    ])
}
82
/// Default set of UDP ports probed when the caller asks for "common" ports.
///
/// Ports 53, 123 and 161 receive protocol-specific probes from `udp_payload`;
/// the rest get the generic one-byte datagram.
fn common_udp_ports() -> Vec<u16> {
    const PORTS: [u16; 8] = [53, 67, 68, 69, 123, 161, 500, 514];
    PORTS.to_vec()
}
94
/// Builds the datagram sent when probing `port` over UDP.
///
/// Many UDP services ignore anything that is not a well-formed request, so the
/// well-known ports get a minimal valid one:
/// * 53: DNS standard query (recursion desired) for the root zone's NS records.
/// * 123: NTP client request (LI = 0, version 3, mode 3), zero-filled to 48 bytes.
/// * 161: SNMP GetRequest, version field 1 (v2c), community "public",
///   OID 1.3.6.1.2.1.
///
/// Every other port gets a single zero byte.
fn udp_payload(port: u16) -> Vec<u8> {
    match port {
        // The previous query declared one question but omitted its QTYPE and
        // QCLASS, so strict servers dropped it as malformed and the port
        // looked closed.
        53 => vec![
            0x00, 0x00, // ID
            0x01, 0x00, // flags: standard query, RD set
            0x00, 0x01, // QDCOUNT = 1
            0x00, 0x00, // ANCOUNT = 0
            0x00, 0x00, // NSCOUNT = 0
            0x00, 0x00, // ARCOUNT = 0
            0x00, // QNAME: root (zero-length label)
            0x00, 0x02, // QTYPE = NS
            0x00, 0x01, // QCLASS = IN
        ],
        123 => {
            let mut pkt = vec![0u8; 48];
            // 0x1b = 0b00_011_011: leap indicator 0, version 3, mode 3 (client).
            pkt[0] = 0x1b;
            pkt
        }
        161 => vec![
            0x30, 0x26, 0x02, 0x01, 0x01, 0x04, 0x06, b'p', b'u', b'b', b'l', b'i', b'c', 0xa0,
            0x19, 0x02, 0x04, 0x70, 0x4b, 0x3a, 0x7f, 0x02, 0x01, 0x00, 0x02, 0x01, 0x00, 0x30,
            0x0b, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x06, 0x01, 0x02, 0x01, 0x05, 0x00,
        ],
        _ => vec![0],
    }
}
116
117async fn detect_os(ip: IpAddr, timeout: Duration) -> Option<String> {
118 if let IpAddr::V4(v4) = ip {
119 let config = PingConfig {
120 target: v4,
121 count: 1,
122 timeout,
123 ..Default::default()
124 };
125 if let Ok(pinger) = IcmpPinger::new(config) {
126 if let Ok(reply) = pinger.ping(1) {
127 if let Some(ttl) = reply.ttl {
128 return Some(match ttl {
129 0..=64 => "unix".to_string(),
130 65..=128 => "windows".to_string(),
131 129..=255 => "network".to_string(),
132 });
133 }
134 }
135 }
136 }
137 None
138}
139
140async fn scan_tcp_port(ip: IpAddr, port: u16, timeout_dur: Duration) -> Option<u16> {
142 let addr = SocketAddr::new(ip, port);
143 match timeout(timeout_dur, TcpStream::connect(addr)).await {
144 Ok(Ok(_stream)) => Some(port),
145 _ => None,
146 }
147}
148
149async fn scan_udp_port(ip: IpAddr, port: u16, timeout_dur: Duration) -> Option<u16> {
156 let addr = SocketAddr::new(ip, port);
157 let bind_addr = match ip {
158 IpAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
159 IpAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
160 };
161 if let Ok(sock) = UdpSocket::bind(bind_addr).await {
162 if sock.connect(addr).await.is_ok() && sock.send(&udp_payload(port)).await.is_ok() {
163 let mut buf = [0u8; 512];
164 match timeout(timeout_dur, sock.recv(&mut buf)).await {
165 Ok(Ok(n)) if n > 0 => return Some(port),
166 Ok(Err(e)) if e.kind() == std::io::ErrorKind::ConnectionRefused => return None,
167 _ => {}
168 }
169 }
170 }
171 None
172}
173
/// Transport(s) a scan should probe.
///
/// `Both` is a request-side value: `scan_targets` probes each port over TCP
/// and over UDP, and the results it returns are tagged `Tcp` or `Udp`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScanProtocol {
    /// TCP connect probes.
    Tcp,
    /// UDP datagram probes.
    Udp,
    /// Both TCP and UDP probes for every port.
    Both,
}
181
182pub fn common_ports_for(proto: ScanProtocol) -> Vec<u16> {
183 let mut ports = Vec::new();
184 if matches!(proto, ScanProtocol::Tcp | ScanProtocol::Both) {
185 ports.extend(anti_ping::tcp::get_common_tcp_ports());
186 }
187 if matches!(proto, ScanProtocol::Udp | ScanProtocol::Both) {
188 ports.extend(common_udp_ports());
189 }
190 ports
191}
192
193pub async fn scan_targets(
195 target: &str,
196 ports: &[u16],
197 proto: ScanProtocol,
198 timeout: Duration,
199 progress: bool,
200 os_detect: bool,
201) -> PingResult<(Vec<(IpAddr, u16, ScanProtocol)>, HashMap<IpAddr, String>)> {
202 let ips = parse_ip_range(target)?;
203 let svc_map = service_map();
204 let semaphore = Arc::new(Semaphore::new(CONCURRENT_LIMIT));
205 let mut os_map = HashMap::new();
206
207 enum TaskResult {
208 Port(IpAddr, u16, ScanProtocol, bool),
209 }
210
211 let mut futures: FuturesUnordered<BoxFuture<'static, TaskResult>> = FuturesUnordered::new();
212
213 for &ip in &ips {
214 for &port in ports {
215 if matches!(proto, ScanProtocol::Tcp | ScanProtocol::Both) {
216 let sem = semaphore.clone();
217 futures.push(Box::pin(async move {
218 let _permit = sem.acquire().await.unwrap();
219 let open = scan_tcp_port(ip, port, timeout).await.is_some();
220 TaskResult::Port(ip, port, ScanProtocol::Tcp, open)
221 }));
222 }
223 if matches!(proto, ScanProtocol::Udp | ScanProtocol::Both) {
224 let sem = semaphore.clone();
225 futures.push(Box::pin(async move {
226 let _permit = sem.acquire().await.unwrap();
227 let open = scan_udp_port(ip, port, timeout).await.is_some();
228 TaskResult::Port(ip, port, ScanProtocol::Udp, open)
229 }));
230 }
231 }
232 }
233
234 let mut spinner_idx = 0usize;
235 let spinner = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'];
236 let total = futures.len();
237 let bar_width = 20usize;
238 let mut completed = 0usize;
239 let mut results = Vec::new();
240
241 while let Some(item) = futures.next().await {
242 completed += 1;
243 match item {
244 TaskResult::Port(ip, port, proto, open) => {
245 if open {
246 let svc = svc_map.get(&port).copied().unwrap_or("");
247 let label = match proto {
248 ScanProtocol::Tcp => "tcp",
249 ScanProtocol::Udp => "udp",
250 ScanProtocol::Both => "both",
251 };
252 print!("\r{:<80}\r", "");
253 println!("{}:{}/{} {}", ip, port, label, svc);
254 results.push((ip, port, proto));
255 }
256
257 if progress {
258 let filled = bar_width * completed / total;
259 let bar = format!("[{}{}]", "█".repeat(filled), "░".repeat(bar_width - filled));
260 let text = format!(" {}:{} {}/{}", ip, port, completed, total);
261 let line = format!("{} {}", spinner[spinner_idx % spinner.len()], bar);
262 let out = format!("{}{}", line, text);
263 print!("\r{:<60}", out);
264 std::io::stdout().flush().ok();
265 spinner_idx += 1;
266 }
267 }
268 }
269 }
270
271 if progress {
272 print!("\r{:<80}\r", "");
273 }
274
275 results.sort_by_key(|k| (k.0, k.1));
276
277 let mut unique_ips = std::collections::HashSet::new();
278 for (ip, _, _) in &results {
279 unique_ips.insert(*ip);
280 }
281
282 if os_detect && !unique_ips.is_empty() {
283 let mut os_futures: FuturesUnordered<BoxFuture<'static, (IpAddr, Option<String>)>> =
284 FuturesUnordered::new();
285 for ip in unique_ips {
286 let sem = semaphore.clone();
287 os_futures.push(Box::pin(async move {
288 let _permit = sem.acquire().await.unwrap();
289 let os = detect_os(ip, timeout).await;
290 (ip, os)
291 }));
292 }
293
294 let total_os = os_futures.len();
295 let mut completed_os = 0usize;
296 spinner_idx = 0;
297
298 while let Some((ip, os)) = os_futures.next().await {
299 completed_os += 1;
300 if let Some(name) = os {
301 os_map.insert(ip, name);
302 }
303 if progress {
304 let filled = bar_width * completed_os / total_os;
305 let bar = format!("[{}{}]", "█".repeat(filled), "░".repeat(bar_width - filled));
306 let text = format!(" {} OS {}/{}", ip, completed_os, total_os);
307 let line = format!("{} {}", bar, spinner[spinner_idx % spinner.len()]);
308 let out = format!("{}{}", line, text);
309 print!("\r{:<80}", out);
310 std::io::stdout().flush().ok();
311 spinner_idx += 1;
312 }
313 }
314
315 if progress {
316 print!("\r{:<80}\r", "");
317 }
318 }
319
320 if progress {
321 println!("\n\nResults:");
322 println!(
323 "{:<39} {:<6} {:<4} {:<8} {}",
324 "IP", "PORT", "PROTO", "SERVICE", "OS"
325 );
326 for (ip, port, proto) in &results {
327 let svc = svc_map.get(port).copied().unwrap_or("");
328 let label = match proto {
329 ScanProtocol::Tcp => "tcp",
330 ScanProtocol::Udp => "udp",
331 ScanProtocol::Both => "both",
332 };
333 let os = os_map.get(ip).cloned().unwrap_or_default();
334 println!("{:<39} {:<6} {:<4} {:<8} {}", ip, port, label, svc, os);
335 }
336 }
337
338 Ok((results, os_map))
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{TcpListener, UdpSocket};
    use tokio::runtime::Runtime;

    /// 127.0.0.1, the address every test binds its listeners and sockets on.
    const LOCALHOST: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);

    #[test]
    fn test_parse_port_range() {
        assert_eq!(parse_port_range("20-22").unwrap(), vec![20, 21, 22]);
    }

    #[test]
    fn test_parse_port_single() {
        assert_eq!(parse_port_range("80").unwrap(), vec![80]);
    }

    #[test]
    fn test_parse_port_range_rejects_invalid_input() {
        for bad in &["0", "0-10", "22-20", "1-2-3", "abc", "70000", "80-", ""] {
            assert!(parse_port_range(bad).is_err(), "accepted {:?}", bad);
        }
    }

    #[test]
    fn test_parse_ip_range_cidr() {
        let ips = parse_ip_range("10.0.0.0/30").unwrap();
        let expected: Vec<IpAddr> = vec!["10.0.0.1".parse().unwrap(), "10.0.0.2".parse().unwrap()];
        assert_eq!(ips, expected);
    }

    #[test]
    fn test_scan_tcp_finds_open_port() {
        let rt = Runtime::new().unwrap();
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        let found = rt.block_on(scan_tcp_port(LOCALHOST, port, Duration::from_secs(1)));
        assert_eq!(found, Some(port));
    }

    #[test]
    fn test_scan_targets_range_tcp() {
        let rt = Runtime::new().unwrap();
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        let (results, _os) = rt
            .block_on(scan_targets(
                "127.0.0.0/30",
                &[port],
                ScanProtocol::Tcp,
                Duration::from_secs(1),
                false,
                false,
            ))
            .unwrap();
        assert!(results.contains(&(LOCALHOST, port, ScanProtocol::Tcp)));
    }

    #[test]
    fn test_scan_udp_silent_port_is_not_reported() {
        // A bound socket that never answers looks the same as a filtered
        // port, so the probe times out and reports nothing.
        let rt = Runtime::new().unwrap();
        let sock = UdpSocket::bind("127.0.0.1:0").unwrap();
        let port = sock.local_addr().unwrap().port();
        let result = rt.block_on(scan_udp_port(LOCALHOST, port, Duration::from_millis(100)));
        assert_eq!(result, None);
    }

    #[test]
    fn test_scan_udp_finds_responding_port() {
        let rt = Runtime::new().unwrap();
        let responder = UdpSocket::bind("127.0.0.1:0").unwrap();
        let port = responder.local_addr().unwrap().port();
        responder
            .set_read_timeout(Some(Duration::from_secs(2)))
            .unwrap();
        // Echo the first datagram back to its sender.
        let echo = std::thread::spawn(move || {
            let mut buf = [0u8; 512];
            if let Ok((n, peer)) = responder.recv_from(&mut buf) {
                let _ = responder.send_to(&buf[..n], peer);
            }
        });
        let result = rt.block_on(scan_udp_port(LOCALHOST, port, Duration::from_secs(1)));
        echo.join().unwrap();
        assert_eq!(result, Some(port));
    }
}