hippox-drivers 0.3.3

🦛 All indivisible atomic driver units in Hippox.
//! Port scanning skill

use crate::DriverCallback;
use crate::DriverContext;
use crate::{
    DriverCategory,
    common::net::{get_service_name, parse_ports, resolve_host},
    types::{Driver, DriverParameter},
};
use anyhow::Result;
use serde_json::{Value, json};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::Semaphore;
use tokio::time::timeout;

/// Driver that scans TCP ports on a target host.
///
/// This is a field-less unit struct; its behaviour comes from the [`Driver`]
/// implementation in this file.
#[derive(Debug)]
pub struct PortScanDriver;

// `Driver` implementation for `PortScanDriver`: static metadata describing the
// driver and its parameters, plus `execute`, which runs a bounded-concurrency
// TCP connect scan and renders a plain-text report.
#[async_trait::async_trait]
impl Driver for PortScanDriver {
    /// Identifier of this driver.
    fn name(&self) -> &str {
        "port_scan"
    }

    /// One-line human-readable summary of what the driver does.
    fn description(&self) -> &str {
        "Scan ports on a target host to discover open ports and services"
    }

    /// Guidance on when this driver is the right one to pick.
    fn usage_hint(&self) -> &str {
        "Use this skill to find open ports, check service availability, or perform network reconnaissance"
    }

    /// Describes the accepted parameters: `target` (required), and the optional
    /// `ports`, `timeout` and `concurrency`, each with its default and an example.
    fn parameters(&self) -> Vec<DriverParameter> {
        vec![
            DriverParameter {
                name: "target".to_string(),
                param_type: "string".to_string(),
                description: "Target hostname or IP address".to_string(),
                required: true,
                default: None,
                example: Some(Value::String("scanme.nmap.org".to_string())),
                enum_values: None,
            },
            DriverParameter {
                name: "ports".to_string(),
                param_type: "string".to_string(),
                description: "Ports to scan (e.g., '80', '1-1024', '22,80,443')".to_string(),
                required: false,
                default: Some(Value::String("1-1024".to_string())),
                example: Some(Value::String("22,80,443".to_string())),
                enum_values: None,
            },
            DriverParameter {
                name: "timeout".to_string(),
                param_type: "integer".to_string(),
                description: "Connection timeout in seconds".to_string(),
                required: false,
                default: Some(Value::Number(2.into())),
                example: Some(Value::Number(3.into())),
                enum_values: None,
            },
            DriverParameter {
                name: "concurrency".to_string(),
                param_type: "integer".to_string(),
                description: "Number of concurrent connection attempts".to_string(),
                required: false,
                default: Some(Value::Number(100.into())),
                example: Some(Value::Number(50.into())),
                enum_values: None,
            },
        ]
    }

    /// Sample invocation payload for this driver.
    fn example_call(&self) -> Value {
        json!({
            "action": "port_scan",
            "parameters": {
                "target": "localhost",
                "ports": "1-1000"
            }
        })
    }

    /// Sample of the text a scan may produce.
    fn example_output(&self) -> String {
        "Scanning localhost (127.0.0.1)\nPort 22: Open - SSH\nPort 80: Open - HTTP\nPort 443: Open - HTTPS\nTotal open ports: 3\nScan completed in 2.5 seconds".to_string()
    }

    /// This driver belongs to the network category.
    fn category(&self) -> DriverCategory {
        DriverCategory::Network
    }

    /// Runs a TCP connect scan against `target` and returns a plain-text report.
    ///
    /// Parameters read from `parameters`:
    /// - `target` (required string): host passed to `resolve_host`.
    /// - `ports` (string, default `"1-1024"`): spec passed to `parse_ports`.
    /// - `timeout` (unsigned integer, default `2`): seconds allowed per connection attempt.
    /// - `concurrency` (unsigned integer, default `100`): maximum simultaneous attempts.
    ///   The value is clamped to at least 1 and at most the number of ports to scan.
    ///
    /// `callback` and `context` are not used by this driver.
    ///
    /// # Errors
    ///
    /// Returns an error when `target` is absent or not a string, when
    /// `resolve_host` or `parse_ports` fails, or when a semaphore permit cannot
    /// be acquired.
    async fn execute(
        &self,
        parameters: &HashMap<String, Value>,
        callback: Option<&dyn DriverCallback>,
        context: Option<&DriverContext>,
    ) -> Result<String> {
        let target = get_param_string(parameters, "target")?;
        let ports_spec = parameters
            .get("ports")
            .and_then(Value::as_str)
            .unwrap_or("1-1024");
        // The timeout is the same for every port, so build it once.
        let timeout_dur = Duration::from_secs(get_param_u64(parameters, "timeout", 2));
        let requested_concurrency = get_param_u64(parameters, "concurrency", 100);

        let ip = resolve_host(&target)?;
        let ports = parse_ports(ports_spec)?;
        let total_ports = ports.len();
        let start_time = std::time::Instant::now();

        // A semaphore with zero permits would make the first `acquire_owned`
        // wait forever, and an oversized permit count makes `Semaphore::new`
        // panic. More permits than ports brings no benefit, so clamp the
        // request into `1..=max(total_ports, 1)`; the checked conversion also
        // avoids a truncating `as` cast on targets where `usize` is narrower
        // than `u64`.
        let concurrency = usize::try_from(requested_concurrency)
            .unwrap_or(usize::MAX)
            .clamp(1, total_ports.max(1));

        let semaphore = Arc::new(Semaphore::new(concurrency));
        let mut tasks = Vec::with_capacity(total_ports);

        for port in ports {
            // Acquire before spawning so at most `concurrency` probes are in flight.
            let permit = Arc::clone(&semaphore).acquire_owned().await?;

            tasks.push(tokio::spawn(async move {
                let is_open = scan_port(ip, port, timeout_dur).await;
                // Release the slot as soon as the probe finishes.
                drop(permit);
                (port, is_open)
            }));
        }

        // Best effort: a task that failed to join is treated like a closed port.
        let mut open_ports = Vec::new();
        for task in tasks {
            if let Ok((port, true)) = task.await {
                open_ports.push(port);
            }
        }

        open_ports.sort_unstable();
        let duration = start_time.elapsed();

        let mut result = format!("Scanning {} ({})\n", target, ip);
        result.push_str(&format!("Total ports scanned: {}\n", total_ports));

        if open_ports.is_empty() {
            result.push_str("\nNo open ports found\n");
        } else {
            result.push_str(&format!("\nOpen ports: {}\n", open_ports.len()));
            for port in &open_ports {
                result.push_str(&format!(
                    "  Port {}: Open - {}\n",
                    port,
                    get_service_name(*port)
                ));
            }
        }

        result.push_str(&format!(
            "\nScan completed in {:.2} seconds",
            duration.as_secs_f64()
        ));
        Ok(result)
    }
}

/// Attempts a TCP connection to `ip:port`, giving up after `timeout_dur`.
///
/// Returns `true` only when the connection is established before the timeout
/// elapses; a connection error and an elapsed timeout both yield `false`.
async fn scan_port(ip: std::net::IpAddr, port: u16, timeout_dur: Duration) -> bool {
    let endpoint = std::net::SocketAddr::new(ip, port);
    let attempt = timeout(timeout_dur, TcpStream::connect(&endpoint)).await;
    // Outer `Ok`: finished in time; inner `Ok`: the connection succeeded.
    matches!(attempt, Ok(Ok(_)))
}

fn get_param_string(params: &HashMap<String, Value>, name: &str) -> Result<String> {
    params
        .get(name)
        .and_then(|v| v.as_str())
        .map(|s| s.to_string())
        .ok_or_else(|| anyhow::anyhow!("Missing parameter: {}", name))
}

/// Fetches parameter `name` as a `u64`, falling back to `default` when the key
/// is absent or `as_u64` yields nothing for its value.
fn get_param_u64(params: &HashMap<String, Value>, name: &str, default: u64) -> u64 {
    match params.get(name) {
        Some(value) => value.as_u64().unwrap_or(default),
        None => default,
    }
}