//! gdelt 0.1.0
//!
//! CLI for GDELT Project - optimized for agentic usage with local data caching.
//!
//! Time-series trend analysis for GDELT data.

#![allow(dead_code)]

use crate::db::AnalyticsDb;
use crate::error::Result;
use serde::{Deserialize, Serialize};

/// Time-bucket granularity for trend analysis.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Granularity {
    Hour,
    Day,
    Week,
    Month,
}

impl Granularity {
    /// Returns the SQL expression that truncates `sql_date` to this
    /// granularity. Daily is the native resolution of `sql_date`, so it
    /// passes the column through unchanged; the other granularities parse
    /// the YYYYMMDD value into a timestamp and `date_trunc` it.
    pub fn sql_trunc(&self) -> &'static str {
        match *self {
            Granularity::Day => "sql_date",
            Granularity::Hour => "date_trunc('hour', to_timestamp(sql_date::varchar, 'YYYYMMDD'))",
            Granularity::Week => "date_trunc('week', to_timestamp(sql_date::varchar, 'YYYYMMDD'))",
            Granularity::Month => "date_trunc('month', to_timestamp(sql_date::varchar, 'YYYYMMDD'))",
        }
    }
}

/// A single point in a trend time-series.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrendPoint {
    /// Time bucket label; format depends on granularity (e.g. "YYYYMMDD"
    /// for daily, "YYYYMM" for monthly buckets).
    pub time: String,
    /// Number of events aggregated into this bucket.
    pub count: i64,
    /// Mean `avg_tone` over the bucket's events; `None` when unavailable.
    pub avg_tone: Option<f64>,
    /// `count` rescaled to 0-100 against the busiest bucket; populated
    /// only when the analysis runs with `normalize = true`.
    pub normalized: Option<f64>,
    /// Z-score of `count` against the series mean; populated only when
    /// anomaly detection is enabled and the std deviation is non-zero.
    pub z_score: Option<f64>,
}

/// Result of trend analysis for one topic.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrendsResult {
    /// Topic/query analyzed.
    pub topic: String,
    /// Time-series data points, ordered by time bucket.
    pub points: Vec<TrendPoint>,
    /// Total event count summed across all points.
    pub total_count: i64,
    /// Mean event count per time bucket (0.0 when there are no points).
    pub avg_count: f64,
    /// Sample standard deviation of per-bucket counts (n - 1 denominator;
    /// 0.0 when fewer than two points).
    pub std_dev: f64,
    /// Indices into `points` whose |z-score| met the anomaly threshold.
    pub anomalies: Vec<usize>,
}

/// Configuration for trend analysis.
#[derive(Debug, Clone)]
pub struct TrendsConfig {
    /// Topics to analyze; one `TrendsResult` is produced per topic.
    pub topics: Vec<String>,
    /// Size of each time bucket.
    pub granularity: Granularity,
    /// Inclusive lower date bound ("YYYY-MM-DD"); `None` for unbounded.
    pub start_date: Option<String>,
    /// Inclusive upper date bound ("YYYY-MM-DD"); `None` for unbounded.
    pub end_date: Option<String>,
    /// When true, each point also gets a 0-100 normalized value.
    pub normalize: bool,
    /// When true, z-scores are computed and outliers are flagged.
    pub detect_anomalies: bool,
    /// Minimum |z-score| for a point to be reported as an anomaly.
    pub anomaly_threshold: f64,
}

impl Default for TrendsConfig {
    fn default() -> Self {
        Self {
            topics: Vec::new(),
            granularity: Granularity::Day,
            start_date: None,
            end_date: None,
            normalize: false,
            detect_anomalies: false,
            anomaly_threshold: 2.0, // Z-score threshold
        }
    }
}

/// Analyze trends for every topic in `config.topics`.
///
/// Topics are processed in order; the first per-topic error aborts the
/// whole analysis and is returned to the caller.
pub fn analyze_trends(db: &AnalyticsDb, config: &TrendsConfig) -> Result<Vec<TrendsResult>> {
    config
        .topics
        .iter()
        .map(|topic| analyze_topic_trends(db, topic, config))
        .collect()
}

/// Analyze the trend time-series for a single `topic`.
///
/// Builds an aggregate query over the `events` table, bucketing rows by the
/// configured granularity, then derives summary statistics plus optional
/// 0-100 normalization (against the busiest bucket) and optional z-score
/// anomaly detection.
///
/// # Errors
/// Returns any error produced by the underlying database query.
fn analyze_topic_trends(
    db: &AnalyticsDb,
    topic: &str,
    config: &TrendsConfig,
) -> Result<TrendsResult> {
    let mut conditions: Vec<String> = Vec::new();

    // Optional date bounds. Dates arrive as "YYYY-MM-DD" while `sql_date` is
    // compared numerically as YYYYMMDD, so keep only the digits. This also
    // stops a malformed date string from injecting SQL fragments, and skips
    // the condition entirely when no digits remain (the previous code could
    // emit invalid SQL such as `sql_date >= `).
    for (prefix, bound) in [
        ("sql_date >=", config.start_date.as_deref()),
        ("sql_date <=", config.end_date.as_deref()),
    ] {
        if let Some(raw) = bound {
            let digits: String = raw.chars().filter(|c| c.is_ascii_digit()).collect();
            if !digits.is_empty() {
                conditions.push(format!("{} {}", prefix, digits));
            }
        }
    }

    // Topic filter over actor names and the action geo name. `topic` is
    // caller-supplied text interpolated into SQL, so double any single
    // quotes to keep the string literal closed (standard SQL escaping).
    // NOTE(review): `%` and `_` in the topic still act as LIKE wildcards;
    // that appears intentional (substring search) but is worth confirming.
    let safe_topic = topic.replace('\'', "''");
    conditions.push(format!(
        "(actor1_name ILIKE '%{t}%' OR actor2_name ILIKE '%{t}%' OR action_geo_fullname ILIKE '%{t}%')",
        t = safe_topic
    ));

    // The topic filter is always pushed, so the WHERE clause is never empty.
    let where_clause = format!("WHERE {}", conditions.join(" AND "));

    // Bucket expression per granularity. For hourly buckets the hour is
    // recovered from `date_added` (assumes YYYYMMDDHHMMSS — TODO confirm:
    // / 10000 drops MMSS, % 100 keeps HH).
    let time_bucket = match config.granularity {
        Granularity::Hour => "CAST(sql_date AS VARCHAR) || LPAD(CAST((date_added / 10000) % 100 AS VARCHAR), 2, '0')",
        Granularity::Day => "CAST(sql_date AS VARCHAR)",
        Granularity::Week => "CAST(DATE_TRUNC('week', TO_DATE(CAST(sql_date AS VARCHAR), 'YYYYMMDD')) AS VARCHAR)",
        Granularity::Month => "SUBSTR(CAST(sql_date AS VARCHAR), 1, 6)",
    };

    let sql = format!(
        r#"
        SELECT
            {} as time_bucket,
            COUNT(*) as event_count,
            AVG(avg_tone) as avg_tone
        FROM events
        {}
        GROUP BY time_bucket
        ORDER BY time_bucket
        "#,
        time_bucket, where_clause
    );

    let query_result = db.query(&sql)?;

    // Convert result rows (time_bucket, event_count, avg_tone) to points.
    // Missing/unparseable cells degrade to "" / 0 / None rather than erroring.
    let mut points: Vec<TrendPoint> = Vec::with_capacity(query_result.rows.len());
    let mut total_count: i64 = 0;

    for row in &query_result.rows {
        let time = row
            .get(0)
            .and_then(|v| v.as_str())
            .unwrap_or("")
            .to_string();
        let count = row.get(1).and_then(|v| v.as_i64()).unwrap_or(0);
        let avg_tone = row.get(2).and_then(|v| v.as_f64());

        total_count += count;
        points.push(TrendPoint {
            time,
            count,
            avg_tone,
            normalized: None,
            z_score: None,
        });
    }

    // Summary statistics: mean and sample standard deviation (n - 1
    // denominator) of the per-bucket counts; both 0.0 when undefined.
    let n = points.len() as f64;
    let avg_count = if n > 0.0 { total_count as f64 / n } else { 0.0 };
    let variance = if n > 1.0 {
        points
            .iter()
            .map(|p| (p.count as f64 - avg_count).powi(2))
            .sum::<f64>()
            / (n - 1.0)
    } else {
        0.0
    };
    let std_dev = variance.sqrt();

    // Optional per-point normalization and z-score anomaly flagging.
    // COUNT(*) per group is >= 1, so max_count can't be 0 for non-empty data;
    // the unwrap_or(1) only guards the empty series.
    let max_count = points.iter().map(|p| p.count).max().unwrap_or(1);
    let mut anomalies = Vec::new();

    for (i, point) in points.iter_mut().enumerate() {
        if config.normalize {
            point.normalized = Some(point.count as f64 / max_count as f64 * 100.0);
        }

        // Z-scores are meaningless (division by zero) with zero spread.
        if config.detect_anomalies && std_dev > 0.0 {
            let z = (point.count as f64 - avg_count) / std_dev;
            point.z_score = Some(z);
            if z.abs() >= config.anomaly_threshold {
                anomalies.push(i);
            }
        }
    }

    Ok(TrendsResult {
        topic: topic.to_string(),
        points,
        total_count,
        avg_count,
        std_dev,
        anomalies,
    })
}