rmpca 0.2.0

Enterprise-grade unified CLI for rmp.ca operations - Rust port
//! Extract OSM data command
//!
//! This command downloads and converts OSM data to GeoJSON
//! for a given bounding box using the Overpass API.

use crate::config::Config;
use anyhow::{Context, Result};
use clap::Args as ClapArgs;
use geojson::{Feature, FeatureCollection, GeoJson, Geometry, Value};
use std::collections::HashMap;
use std::path::PathBuf;

// Command-line arguments for the extract-OSM command.
// NOTE: clap's derive uses the `///` doc comments on these fields as `--help` text,
// so editing them changes user-visible output; maintainer notes use `//` instead.
#[derive(Debug, ClapArgs)]
pub struct Args {
    // Parsed and validated by `parse_bbox` in `run` before any network request is made.
    /// Bounding box: MIN_LON,MIN_LAT,MAX_LON,MAX_LAT
    #[arg(long)]
    pub bbox: String,

    // When `None`, `run` prints the GeoJSON to stdout.
    /// Output file (default: stdout)
    #[arg(short, long)]
    pub output: Option<PathBuf>,

    // Used to build the Overpass query and again to filter parsed ways by their `highway`
    // tag; when `None`, ways are kept only if their class is in `DRIVABLE_HIGHWAYS`.
    /// Highway class filter (comma-separated, e.g. "primary,secondary,trunk")
    #[arg(long)]
    pub highway: Option<String>,

    // `run` POSTs the Overpass query to this URL as the form field `data`.
    /// Overpass API endpoint (default: https://overpass-api.de/api/interpreter)
    #[arg(long, default_value = "https://overpass-api.de/api/interpreter")]
    pub overpass_url: String,
}

/// OSM `highway=*` values treated as drivable by motor vehicles.
///
/// This is the default whitelist applied to parsed ways when no explicit highway filter
/// is supplied.
const DRIVABLE_HIGHWAYS: &[&str] = &[
    // Main road hierarchy.
    "motorway",
    "trunk",
    "primary",
    "secondary",
    "tertiary",
    "unclassified",
    "residential",
    // Link (ramp / slip) roads.
    "motorway_link",
    "trunk_link",
    "primary_link",
    "secondary_link",
    "tertiary_link",
    // Other roads open to vehicles.
    "living_street",
    "service",
];

/// Returns `true` when `highway` (an OSM `highway` tag value) is one of `DRIVABLE_HIGHWAYS`.
///
/// The comparison is exact and case-sensitive.
fn is_drivable(highway: &str) -> bool {
    DRIVABLE_HIGHWAYS
        .iter()
        .any(|&candidate| candidate == highway)
}

/// Parse a bbox string into `(min_lon, min_lat, max_lon, max_lat)`.
///
/// The input must be four comma-separated numbers in the order
/// `MIN_LON,MIN_LAT,MAX_LON,MAX_LAT`; whitespace around each number is ignored.
///
/// # Errors
/// Returns an error if any value is not a number, if there are not exactly four values,
/// if a longitude lies outside `[-180, 180]` or a latitude outside `[-90, 90]` (which also
/// rejects `NaN` and infinities), or if a minimum is not strictly less than its maximum.
fn parse_bbox(bbox: &str) -> Result<(f64, f64, f64, f64)> {
    let parts: Vec<f64> = bbox
        .split(',')
        .map(|s| s.trim().parse::<f64>())
        .collect::<Result<Vec<f64>, _>>()
        .context("Invalid bbox format. Use MIN_LON,MIN_LAT,MAX_LON,MAX_LAT")?;

    if parts.len() != 4 {
        anyhow::bail!("bbox must have 4 values: MIN_LON,MIN_LAT,MAX_LON,MAX_LAT");
    }
    let (min_lon, min_lat, max_lon, max_lat) = (parts[0], parts[1], parts[2], parts[3]);

    // `f64::parse` accepts "NaN" and "inf". NaN makes every comparison false, so it would
    // slip through the ordering check below; the range checks reject NaN and infinities.
    let lon_ok = |v: f64| (-180.0..=180.0).contains(&v);
    let lat_ok = |v: f64| (-90.0..=90.0).contains(&v);
    if !(lon_ok(min_lon) && lon_ok(max_lon)) {
        anyhow::bail!("Invalid bbox: longitude must be within [-180, 180]");
    }
    if !(lat_ok(min_lat) && lat_ok(max_lat)) {
        anyhow::bail!("Invalid bbox: latitude must be within [-90, 90]");
    }

    if min_lon >= max_lon || min_lat >= max_lat {
        anyhow::bail!("Invalid bbox: min must be less than max in each dimension");
    }

    Ok((min_lon, min_lat, max_lon, max_lat))
}

/// Build the Overpass QL query for highway ways inside `bbox`.
///
/// `bbox` uses the CLI order `MIN_LON,MIN_LAT,MAX_LON,MAX_LAT`. Overpass expects
/// `south,west,north,east`, so the values are reordered. The query outputs the matching
/// ways (`out body`) followed by their nodes (`>; out skel qt`).
///
/// * `highway_filter = None`, or a list with no non-blank entries: every way carrying a
///   `highway` tag is requested, and narrowing to drivable classes is left to the caller.
/// * `highway_filter = Some("a,b")`: one statement per class inside a union, so ways of
///   *any* listed class are returned. Class names are escaped for the QL string literal.
///
/// Returns an empty string if `bbox` cannot be parsed.
fn build_overpass_query(bbox: &str, highway_filter: &Option<String>) -> String {
    let (min_lon, min_lat, max_lon, max_lat) = match parse_bbox(bbox) {
        Ok(v) => v,
        Err(_) => return String::new(),
    };

    // Overpass bounding boxes are (south, west, north, east).
    let bbox_str = format!("{},{},{},{}", min_lat, min_lon, max_lat, max_lon);

    // Blank entries (e.g. from a trailing comma) would otherwise become `["highway"=""]`.
    let classes: Vec<&str> = highway_filter
        .as_deref()
        .map(|filter| {
            filter
                .split(',')
                .map(str::trim)
                .filter(|class| !class.is_empty())
                .collect::<Vec<&str>>()
        })
        .unwrap_or_default();

    if classes.is_empty() {
        // All highway ways; drivability is decided after parsing.
        return format!(
            "[out:xml][timeout:60];\nway[\"highway\"]({});\nout body;\n>;\nout skel qt;",
            bbox_str
        );
    }

    // Chained tag filters on one statement are ANDed in Overpass QL, so
    // `way["highway"="a"]["highway"="b"]` can never match. A union of one statement per
    // class yields the intended "any of these classes" semantics.
    let statements: String = classes
        .iter()
        .map(|class| {
            format!(
                "way[\"highway\"=\"{}\"]({});\n",
                escape_ql_string(class),
                bbox_str
            )
        })
        .collect();

    format!(
        "[out:xml][timeout:60];\n(\n{});\nout body;\n>;\nout skel qt;",
        statements
    )
}

/// Escape `value` for use inside a double-quoted Overpass QL string literal.
fn escape_ql_string(value: &str) -> String {
    // Backslashes first, so the escapes added for quotes are not doubled.
    value.replace('\\', "\\\\").replace('"', "\\\"")
}

/// Parse Overpass XML response into nodes and ways
fn parse_overpass_xml(xml: &str) -> Result<(HashMap<i64, (f64, f64)>, Vec<OverpassWay>)> {
    use quick_xml::events::Event;
    use quick_xml::Reader;

    let mut reader = Reader::from_str(xml);
    let mut nodes: HashMap<i64, (f64, f64)> = HashMap::new();
    let mut ways: Vec<OverpassWay> = Vec::new();

    let mut current_way: Option<OverpassWay> = None;
    let mut current_tags: HashMap<String, String> = HashMap::new();
    let mut buf = Vec::new();

    loop {
        match reader.read_event_into(&mut buf) {
            Ok(Event::Start(ref e)) | Ok(Event::Empty(ref e)) => {
                match e.local_name().as_ref() {
                    b"node" => {
                        let mut id: i64 = 0;
                        let mut lat: f64 = 0.0;
                        let mut lon: f64 = 0.0;
                        for attr in e.attributes().flatten() {
                            match attr.key.as_ref() {
                                b"id" => id = String::from_utf8_lossy(&attr.value).parse().unwrap_or(0),
                                b"lat" => lat = String::from_utf8_lossy(&attr.value).parse().unwrap_or(0.0),
                                b"lon" => lon = String::from_utf8_lossy(&attr.value).parse().unwrap_or(0.0),
                                _ => {}
                            }
                        }
                        nodes.insert(id, (lat, lon));
                    }
                    b"way" => {
                        // Save previous way if any
                        if let Some(mut w) = current_way.take() {
                            w.tags = current_tags.clone();
                            ways.push(w);
                        }
                        current_tags.clear();
                        let mut id: i64 = 0;
                        for attr in e.attributes().flatten() {
                            if attr.key.as_ref() == b"id" {
                                id = String::from_utf8_lossy(&attr.value).parse().unwrap_or(0);
                            }
                        }
                        current_way = Some(OverpassWay {
                            id,
                            node_ids: Vec::new(),
                            tags: HashMap::new(),
                        });
                    }
                    b"nd" => {
                        if let Some(ref mut way) = current_way {
                            for attr in e.attributes().flatten() {
                                if attr.key.as_ref() == b"ref" {
                                    if let Ok(ref_id) = String::from_utf8_lossy(&attr.value).parse::<i64>() {
                                        way.node_ids.push(ref_id);
                                    }
                                }
                            }
                        }
                    }
                    b"tag" => {
                        let mut k = String::new();
                        let mut v = String::new();
                        for attr in e.attributes().flatten() {
                            match attr.key.as_ref() {
                                b"k" => k = String::from_utf8_lossy(&attr.value).into_owned(),
                                b"v" => v = String::from_utf8_lossy(&attr.value).into_owned(),
                                _ => {}
                            }
                        }
                        if !k.is_empty() {
                            current_tags.insert(k, v);
                        }
                    }
                    _ => {}
                }
            }
            Ok(Event::End(ref e)) => {
                match e.local_name().as_ref() {
                    b"way" => {
                        if let Some(mut w) = current_way.take() {
                            w.tags = current_tags.clone();
                            ways.push(w);
                        }
                        current_tags.clear();
                    }
                    _ => {}
                }
            }
            Ok(Event::Eof) => break,
            Err(e) => {
                return Err(anyhow::anyhow!("XML parse error: {}", e));
            }
            _ => {}
        }
        buf.clear();
    }

    Ok((nodes, ways))
}

/// A `<way>` element decoded from an Overpass XML response.
#[derive(Debug, Clone, Default)]
struct OverpassWay {
    /// OSM way id, taken from the element's `id` attribute.
    id: i64,
    /// Node ids referenced by the way's `<nd ref=…>` children, in document order.
    node_ids: Vec<i64>,
    /// The way's `<tag k=… v=…>` pairs.
    tags: HashMap<String, String>,
}

/// Extract OSM data to GeoJSON
pub async fn run(args: Args) -> Result<()> {
    let config = Config::load().unwrap_or_default();
    config.init_logging();

    // Validate bbox
    let _bbox = parse_bbox(&args.bbox)?;

    tracing::info!("Extracting OSM data for bbox: {}", args.bbox);

    // Build Overpass query
    let query = build_overpass_query(&args.bbox, &args.highway);
    if query.is_empty() {
        anyhow::bail!("Failed to build Overpass query");
    }

    tracing::debug!("Overpass query:\n{}", query);

    // Send request to Overpass API
    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(config.timeout_secs))
        .build()?;

    tracing::info!("Sending request to Overpass API...");

    let response = client
        .post(&args.overpass_url)
        .form(&[("data", &query)])
        .send()
        .await
        .context("Failed to connect to Overpass API")?;

    if !response.status().is_success() {
        let status = response.status();
        let body = response.text().await.unwrap_or_default();
        anyhow::bail!("Overpass API returned {}: {}", status, body);
    }

    let xml = response
        .text()
        .await
        .context("Failed to read Overpass API response")?;

    tracing::info!("Received {} bytes from Overpass API", xml.len());

    // Parse the response
    let (nodes, ways) = parse_overpass_xml(&xml)?;

    tracing::info!("Parsed {} nodes, {} ways", nodes.len(), ways.len());

    // Filter ways by highway type
    let highway_filter: Option<Vec<String>> = args
        .highway
        .as_ref()
        .map(|h| h.split(',').map(|s| s.trim().to_string()).collect());

    let filtered_ways: Vec<&OverpassWay> = ways
        .iter()
        .filter(|way| {
            let highway = way.tags.get("highway").map(|s| s.as_str()).unwrap_or("");
            if highway.is_empty() {
                return false;
            }
            if let Some(ref allowed) = highway_filter {
                allowed.iter().any(|h| h == highway)
            } else {
                is_drivable(highway)
            }
        })
        .collect();

    tracing::info!("After filtering: {} ways", filtered_ways.len());

    // Convert to GeoJSON
    let mut features = Vec::new();
    for way in &filtered_ways {
        let coords: Vec<Vec<f64>> = way
            .node_ids
            .iter()
            .filter_map(|nid| nodes.get(nid))
            .map(|(lat, lon)| vec![*lon, *lat])
            .collect();

        if coords.len() < 2 {
            continue;
        }

        let geometry = Geometry::new(Value::LineString(coords));

        let mut properties = serde_json::Map::new();
        properties.insert("id".to_string(), serde_json::Value::String(way.id.to_string()));
        for (k, v) in &way.tags {
            properties.insert(k.clone(), serde_json::Value::String(v.clone()));
        }

        features.push(Feature {
            geometry: Some(geometry),
            properties: Some(properties),
            ..Default::default()
        });
    }

    let fc = FeatureCollection {
        features,
        bbox: None,
        foreign_members: None,
    };

    let geojson = GeoJson::from(fc);
    let json = serde_json::to_string_pretty(&geojson)
        .context("Failed to serialize GeoJSON")?;

    match &args.output {
        Some(path) => {
            std::fs::write(path, &json)
                .with_context(|| format!("Failed to write to {}", path.display()))?;
            tracing::info!("GeoJSON written to {}", path.display());
        }
        None => println!("{}", json),
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Montreal-area bbox in MIN_LON,MIN_LAT,MAX_LON,MAX_LAT order, shared by several tests.
    const SAMPLE_BBOX: &str = "-73.59,45.49,-73.55,45.52";

    #[test]
    fn test_extract_osm_args() {
        let args = Args {
            bbox: SAMPLE_BBOX.to_string(),
            output: None,
            highway: Some("primary,secondary".to_string()),
            overpass_url: "https://overpass-api.de/api/interpreter".to_string(),
        };
        assert_eq!(args.bbox, "-73.59,45.49,-73.55,45.52");
        assert_eq!(args.highway, Some("primary,secondary".to_string()));
    }

    #[test]
    fn test_parse_bbox() {
        let bbox = parse_bbox(SAMPLE_BBOX).unwrap();
        assert_eq!(bbox, (-73.59, 45.49, -73.55, 45.52));

        // Whitespace around values is ignored.
        assert_eq!(parse_bbox(" -73.59 , 45.49 ,-73.55, 45.52 ").unwrap(), bbox);

        // Invalid: too few / too many values
        assert!(parse_bbox("1,2,3").is_err());
        assert!(parse_bbox("1,2,3,4,5").is_err());

        // Invalid: non-numeric or empty
        assert!(parse_bbox("a,b,c,d").is_err());
        assert!(parse_bbox("").is_err());

        // Invalid: min >= max
        assert!(parse_bbox("1,2,0,3").is_err());
        assert!(parse_bbox("1,2,1,3").is_err());
    }

    #[test]
    fn test_build_overpass_query() {
        let query = build_overpass_query(SAMPLE_BBOX, &None);
        assert!(query.contains("[out:xml]"));
        assert!(query.contains("way[\"highway\"]"));
        // Overpass bboxes are south,west,north,east.
        assert!(query.contains("(45.49,-73.59,45.52,-73.55)"));

        let query = build_overpass_query(SAMPLE_BBOX, &Some("primary".to_string()));
        assert!(query.contains("[\"highway\"=\"primary\"]"));

        let query = build_overpass_query(SAMPLE_BBOX, &Some("primary,secondary".to_string()));
        assert!(query.contains("[\"highway\"=\"primary\"]"));
        assert!(query.contains("[\"highway\"=\"secondary\"]"));

        // An unparsable bbox yields an empty query.
        assert!(build_overpass_query("not,a,bbox", &None).is_empty());
    }

    #[test]
    fn test_parse_overpass_xml() {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
<osm version="0.6">
  <way id="10">
    <nd ref="1"/>
    <nd ref="2"/>
    <tag k="highway" v="primary"/>
  </way>
  <node id="1" lat="45.5" lon="-73.6"/>
  <node id="2" lat="45.6" lon="-73.5"/>
</osm>"#;
        let (nodes, ways) = parse_overpass_xml(xml).unwrap();

        assert_eq!(nodes.len(), 2);
        assert_eq!(nodes.get(&1), Some(&(45.5, -73.6)));
        assert_eq!(nodes.get(&2), Some(&(45.6, -73.5)));

        assert_eq!(ways.len(), 1);
        assert_eq!(ways[0].id, 10);
        assert_eq!(ways[0].node_ids, vec![1, 2]);
        assert_eq!(
            ways[0].tags.get("highway").map(String::as_str),
            Some("primary")
        );
    }

    #[test]
    fn test_parse_overpass_xml_empty_document() {
        let (nodes, ways) = parse_overpass_xml("<osm/>").unwrap();
        assert!(nodes.is_empty());
        assert!(ways.is_empty());
    }

    #[test]
    fn test_drivable_highways() {
        assert!(is_drivable("motorway"));
        assert!(is_drivable("residential"));
        assert!(!is_drivable("footway"));
        assert!(!is_drivable("cycleway"));
    }
}