// gdelt 0.1.0
//
// CLI for the GDELT Project, optimized for agentic usage with local data caching.
//! Topic comparison analytics.

#![allow(dead_code)]

use crate::db::AnalyticsDb;
use crate::error::Result;
use serde::{Deserialize, Serialize};

/// Outcome of a two-topic comparison: the statistics gathered for each topic
/// plus the metrics derived from putting them side by side.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompareResult {
    /// Statistics for the first topic (`CompareConfig::topic_a`).
    pub topic_a: TopicStats,
    /// Statistics for the second topic (`CompareConfig::topic_b`).
    pub topic_b: TopicStats,
    /// Metrics relating `topic_a` to `topic_b`.
    pub comparison: ComparisonMetrics,
}

/// Aggregate statistics for the events matching a single topic.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TopicStats {
    /// The topic string exactly as it was passed to the query.
    pub name: String,
    /// Number of matching events (`0` when the count could not be read).
    pub event_count: i64,
    /// Mean of the `avg_tone` column over the matching events (`0.0` when unavailable).
    pub avg_tone: f64,
    /// Mean of the `goldstein_scale` column over the matching events (`0.0` when unavailable).
    pub avg_goldstein: f64,
    /// Up to ten country codes with the most matching events, most frequent first.
    pub top_countries: Vec<CountryCount>,
    /// Up to ten event root codes with the most matching events, most frequent first.
    pub event_types: Vec<EventTypeCount>,
    /// Smallest and largest `sql_date` among the matching events, rendered as
    /// strings; `None` when either bound is missing.
    pub date_range: Option<(String, String)>,
}

/// Number of a topic's events attributed to one country.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CountryCount {
    /// Country code taken from the `action_geo_country_code` column.
    pub country: String,
    /// Number of the topic's events carrying this country code.
    pub count: i64,
    /// `count` as a percentage (0-100) of the topic's total event count.
    pub percentage: f64,
}

/// Number of a topic's events that share one event root code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EventTypeCount {
    /// Code taken from the `event_root_code` column.
    pub event_code: String,
    /// Human-readable label for `event_code`, as produced by `get_event_description`.
    pub description: String,
    /// Number of the topic's events carrying this root code.
    pub count: i64,
    /// `count` as a percentage (0-100) of the topic's total event count.
    pub percentage: f64,
}

/// Metrics relating two topics, always expressed as "A relative to B".
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonMetrics {
    /// Event count of A divided by event count of B; `0.0` when B has no events.
    pub count_ratio: f64,
    /// Average tone of A minus average tone of B.
    pub tone_diff: f64,
    /// Average Goldstein scale of A minus average Goldstein scale of B.
    pub goldstein_diff: f64,
    /// Jaccard similarity (0.0-1.0) between the two topics' `top_countries`
    /// country codes; only the listed top entries take part.
    pub country_similarity: f64,
    /// Jaccard similarity (0.0-1.0) between the two topics' `event_types`
    /// event codes; only the listed top entries take part.
    pub event_type_similarity: f64,
}

/// Input for [`compare_topics`]: the two topics and an optional date window
/// that is applied to both of them.
#[derive(Debug, Clone)]
pub struct CompareConfig {
    /// Search text for the first topic.
    pub topic_a: String,
    /// Search text for the second topic.
    pub topic_b: String,
    /// Optional inclusive lower bound on `sql_date`; hyphens are removed
    /// before the value is compared (e.g. `2024-01-01` is read as `20240101`).
    pub start_date: Option<String>,
    /// Optional inclusive upper bound on `sql_date`, in the same format as
    /// `start_date`.
    pub end_date: Option<String>,
}

/// Gathers statistics for both configured topics and derives the metrics
/// that relate them.
///
/// The same optional date window from `config` is applied to both topics.
/// Topic A is analysed first; if that fails, topic B is never queried.
///
/// # Errors
///
/// Returns the first error produced while analysing either topic.
pub fn compare_topics(db: &AnalyticsDb, config: &CompareConfig) -> Result<CompareResult> {
    // Both topics share the database handle and the date window.
    let analyze = |topic: &str| analyze_topic(db, topic, &config.start_date, &config.end_date);

    let stats_a = analyze(config.topic_a.as_str())?;
    let stats_b = analyze(config.topic_b.as_str())?;
    let comparison = calculate_comparison(&stats_a, &stats_b);

    Ok(CompareResult {
        topic_a: stats_a,
        topic_b: stats_b,
        comparison,
    })
}

/// Gathers aggregate statistics for all events that match `topic`.
///
/// An event matches when `topic` appears (case-insensitively, via `ILIKE`) in
/// `actor1_name`, `actor2_name` or `action_geo_fullname`, and its `sql_date`
/// lies within the optional inclusive `start_date` / `end_date` bounds. Three
/// queries are issued against the `events` table: the overall aggregates, the
/// ten most frequent `action_geo_country_code` values and the ten most
/// frequent `event_root_code` values.
///
/// Caller-supplied text is neutralised before it is spliced into the SQL:
/// single quotes in `topic` are doubled so the text cannot terminate the
/// string literal, and only the ASCII digits of each date are kept (so
/// `2024-01-01` becomes `20240101`). `%` and `_` inside `topic` are left
/// untouched and therefore still act as `ILIKE` wildcards.
///
/// Aggregates that cannot be read from the result default to `0` / `0.0`;
/// `date_range` is `None` unless both the minimum and maximum date are present.
///
/// # Errors
///
/// Returns the error of the first failing `db.query` call. A date that
/// contains no digits yields an incomplete predicate, which is expected to be
/// rejected by the database instead of being silently ignored.
fn analyze_topic(
    db: &AnalyticsDb,
    topic: &str,
    start_date: &Option<String>,
    end_date: &Option<String>,
) -> Result<TopicStats> {
    // `sql_date` is compared against a bare integer literal, so nothing but
    // digits may reach the query text.
    let date_literal =
        |raw: &str| -> String { raw.chars().filter(|c| c.is_ascii_digit()).collect() };

    let mut conditions = Vec::new();

    if let Some(start) = start_date {
        conditions.push(format!("sql_date >= {}", date_literal(start.as_str())));
    }
    if let Some(end) = end_date {
        conditions.push(format!("sql_date <= {}", date_literal(end.as_str())));
    }

    // Topic filter. Doubling single quotes is the standard SQL way to embed a
    // quote inside a string literal; without it a topic such as "O'Brien"
    // breaks the statement (or lets the caller inject SQL).
    let escaped_topic = topic.replace('\'', "''");
    conditions.push(format!(
        "(actor1_name ILIKE '%{0}%' OR actor2_name ILIKE '%{0}%' OR action_geo_fullname ILIKE '%{0}%')",
        escaped_topic
    ));

    let where_clause = format!("WHERE {}", conditions.join(" AND "));

    // Get overall stats
    let stats_sql = format!(
        r#"
        SELECT
            COUNT(*) as cnt,
            AVG(avg_tone) as tone,
            AVG(goldstein_scale) as goldstein,
            MIN(sql_date) as min_date,
            MAX(sql_date) as max_date
        FROM events
        {}
        "#,
        where_clause
    );

    let stats_result = db.query(&stats_sql)?;
    let stats_row = stats_result.rows.first();

    let event_count = stats_row
        .and_then(|r| r.get(0))
        .and_then(|v| v.as_i64())
        .unwrap_or(0);
    let avg_tone = stats_row
        .and_then(|r| r.get(1))
        .and_then(|v| v.as_f64())
        .unwrap_or(0.0);
    let avg_goldstein = stats_row
        .and_then(|r| r.get(2))
        .and_then(|v| v.as_f64())
        .unwrap_or(0.0);
    let min_date = stats_row
        .and_then(|r| r.get(3))
        .and_then(|v| v.as_i64())
        .map(|d| d.to_string());
    let max_date = stats_row
        .and_then(|r| r.get(4))
        .and_then(|v| v.as_i64())
        .map(|d| d.to_string());

    // Only report a range when both ends are known.
    let date_range = min_date.zip(max_date);

    // Share of the topic's events, in percent; guards the empty-topic case
    // against a division by zero.
    let share_of_total = |count: i64| -> f64 {
        if event_count > 0 {
            count as f64 / event_count as f64 * 100.0
        } else {
            0.0
        }
    };

    // Get top countries
    let countries_sql = format!(
        r#"
        SELECT
            action_geo_country_code,
            COUNT(*) as cnt
        FROM events
        {}
        AND action_geo_country_code IS NOT NULL
        GROUP BY action_geo_country_code
        ORDER BY cnt DESC
        LIMIT 10
        "#,
        where_clause
    );

    let countries_result = db.query(&countries_sql)?;
    let top_countries: Vec<CountryCount> = countries_result
        .rows
        .iter()
        .map(|row| {
            let country = row.get(0).and_then(|v| v.as_str()).unwrap_or("").to_string();
            let count = row.get(1).and_then(|v| v.as_i64()).unwrap_or(0);
            CountryCount {
                country,
                count,
                percentage: share_of_total(count),
            }
        })
        .collect();

    // Get event type distribution
    let events_sql = format!(
        r#"
        SELECT
            event_root_code,
            COUNT(*) as cnt
        FROM events
        {}
        AND event_root_code IS NOT NULL
        GROUP BY event_root_code
        ORDER BY cnt DESC
        LIMIT 10
        "#,
        where_clause
    );

    let events_result = db.query(&events_sql)?;
    let event_types: Vec<EventTypeCount> = events_result
        .rows
        .iter()
        .map(|row| {
            let event_code = row.get(0).and_then(|v| v.as_str()).unwrap_or("").to_string();
            let count = row.get(1).and_then(|v| v.as_i64()).unwrap_or(0);
            EventTypeCount {
                // Built before `event_code` is moved into the struct, so no
                // clone of the code is needed.
                description: get_event_description(&event_code),
                event_code,
                count,
                percentage: share_of_total(count),
            }
        })
        .collect();

    Ok(TopicStats {
        name: topic.to_string(),
        event_count,
        avg_tone,
        avg_goldstein,
        top_countries,
        event_types,
        date_range,
    })
}

/// Derives the A-versus-B metrics from two already computed topic statistics.
///
/// The count ratio falls back to `0.0` when topic B has no events. Both
/// similarity scores are Jaccard indices over the entries listed in
/// `top_countries` and `event_types` respectively.
fn calculate_comparison(topic_a: &TopicStats, topic_b: &TopicStats) -> ComparisonMetrics {
    use std::collections::HashSet;

    // A / B, guarding against an empty (or otherwise non-positive) B.
    let count_ratio = match topic_b.event_count {
        total_b if total_b > 0 => topic_a.event_count as f64 / total_b as f64,
        _ => 0.0,
    };

    // Overlap between the country codes listed for each topic.
    let country_similarity = {
        let left: HashSet<&String> = topic_a
            .top_countries
            .iter()
            .map(|entry| &entry.country)
            .collect();
        let right: HashSet<&String> = topic_b
            .top_countries
            .iter()
            .map(|entry| &entry.country)
            .collect();
        jaccard_similarity(&left, &right)
    };

    // Overlap between the event codes listed for each topic.
    let event_type_similarity = {
        let left: HashSet<&String> = topic_a
            .event_types
            .iter()
            .map(|entry| &entry.event_code)
            .collect();
        let right: HashSet<&String> = topic_b
            .event_types
            .iter()
            .map(|entry| &entry.event_code)
            .collect();
        jaccard_similarity(&left, &right)
    };

    ComparisonMetrics {
        count_ratio,
        tone_diff: topic_a.avg_tone - topic_b.avg_tone,
        goldstein_diff: topic_a.avg_goldstein - topic_b.avg_goldstein,
        country_similarity,
        event_type_similarity,
    }
}

/// Jaccard index of two sets: the number of shared elements divided by the
/// number of distinct elements across both sets.
///
/// Returns `0.0` when both sets are empty.
fn jaccard_similarity<T: std::hash::Hash + Eq>(
    set_a: &std::collections::HashSet<T>,
    set_b: &std::collections::HashSet<T>,
) -> f64 {
    let shared = set_a.iter().filter(|item| set_b.contains(*item)).count();
    // |A ∪ B| = |A| + |B| - |A ∩ B|
    let combined = set_a.len() + set_b.len() - shared;

    match combined {
        0 => 0.0,
        total => shared as f64 / total as f64,
    }
}

/// Maps a two-digit CAMEO root event code to its short description.
///
/// Codes must match exactly (e.g. `"01"`, not `"1"`); anything not in the
/// table is rendered as `"Event code <code>"`.
fn get_event_description(code: &str) -> String {
    // Basic CAMEO event code descriptions
    const CAMEO_ROOT_CODES: [(&str, &str); 20] = [
        ("01", "Make public statement"),
        ("02", "Appeal"),
        ("03", "Express intent to cooperate"),
        ("04", "Consult"),
        ("05", "Engage in diplomatic cooperation"),
        ("06", "Engage in material cooperation"),
        ("07", "Provide aid"),
        ("08", "Yield"),
        ("09", "Investigate"),
        ("10", "Demand"),
        ("11", "Disapprove"),
        ("12", "Reject"),
        ("13", "Threaten"),
        ("14", "Protest"),
        ("15", "Exhibit force posture"),
        ("16", "Reduce relations"),
        ("17", "Coerce"),
        ("18", "Assault"),
        ("19", "Fight"),
        ("20", "Use unconventional mass violence"),
    ];

    CAMEO_ROOT_CODES
        .iter()
        .find(|entry| entry.0 == code)
        .map(|entry| entry.1.to_string())
        .unwrap_or_else(|| format!("Event code {}", code))
}