//! gdelt 0.1.0
//!
//! CLI for GDELT Project - optimized for agentic usage with local data caching
//! Sentiment analysis for GDELT data.

#![allow(dead_code)]

use crate::db::AnalyticsDb;
use crate::error::Result;
use serde::{Deserialize, Serialize};

/// Dimension for sentiment aggregation
/// Dimension for sentiment aggregation
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SentimentDimension {
    /// Bucket tone by time period (hour/day/week/month, per `Granularity`).
    Time,
    /// Group tone by `action_geo_country_code` from the events table.
    Region,
    /// Group tone by publishing source (`source_common_name` in the GKG table).
    Source,
    /// Group tone by actor (`actor1_name` falling back to `actor1_code`).
    Entity,
}

/// A single sentiment data point
/// A single sentiment data point for one dimension value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SentimentPoint {
    /// Dimension value (time bucket, region code, source name, or entity name)
    pub dimension: String,
    /// Average tone (-10 to +10)
    pub avg_tone: f64,
    /// Positive tone score (only populated for GKG-backed source analysis)
    pub positive_score: Option<f64>,
    /// Negative tone score (only populated for GKG-backed source analysis)
    pub negative_score: Option<f64>,
    /// Number of records aggregated into this point
    pub count: i64,
    /// Polarity (abs(positive - negative)); `None` when either score is missing
    pub polarity: Option<f64>,
}

/// Result of sentiment analysis
/// Result of sentiment analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SentimentResult {
    /// Topic analyzed (echoed from the config)
    pub topic: String,
    /// Sentiment data points, one per dimension value
    pub points: Vec<SentimentPoint>,
    /// Overall average tone, weighted by each point's record count
    pub overall_tone: f64,
    /// Tone trend (positive = improving, negative = declining); only
    /// meaningful for time-dimension analysis, 0.0 otherwise
    pub tone_trend: f64,
    /// Total records analyzed across all points
    pub total_count: i64,
}

/// Configuration for sentiment analysis
/// Configuration for sentiment analysis
#[derive(Debug, Clone)]
pub struct SentimentConfig {
    /// Topic to match (substring match against actor names and location)
    pub topic: String,
    /// Which dimension to aggregate over
    pub dimension: SentimentDimension,
    /// Inclusive start date, "YYYY-MM-DD" form (dashes are stripped)
    pub start_date: Option<String>,
    /// Inclusive end date, "YYYY-MM-DD" form (dashes are stripped)
    pub end_date: Option<String>,
    /// Time-bucket size used when `dimension` is `Time`
    pub granularity: super::trends::Granularity,
    /// Optional second topic for comparison
    /// NOTE(review): not consumed by any function in this module — confirm callers use it
    pub compare_topic: Option<String>,
}

impl Default for SentimentConfig {
    fn default() -> Self {
        Self {
            topic: String::new(),
            dimension: SentimentDimension::Time,
            start_date: None,
            end_date: None,
            granularity: super::trends::Granularity::Day,
            compare_topic: None,
        }
    }
}

/// Analyze sentiment for a topic
pub fn analyze_sentiment(db: &AnalyticsDb, config: &SentimentConfig) -> Result<SentimentResult> {
    match config.dimension {
        SentimentDimension::Time => analyze_sentiment_over_time(db, config),
        SentimentDimension::Region => analyze_sentiment_by_region(db, config),
        SentimentDimension::Source => analyze_sentiment_by_source(db, config),
        SentimentDimension::Entity => analyze_sentiment_by_entity(db, config),
    }
}

/// Build a SQL predicate matching `topic` (case-insensitive substring) against
/// both actor names and the action location full name.
///
/// Single quotes in `topic` are doubled so the interpolated value cannot break
/// out of the SQL string literal (basic injection hardening). LIKE wildcards
/// (`%`, `_`) are left intact so callers may pass patterns deliberately.
fn build_topic_filter(topic: &str) -> String {
    let escaped = topic.replace('\'', "''");
    format!(
        "(actor1_name ILIKE '%{}%' OR actor2_name ILIKE '%{}%' OR action_geo_fullname ILIKE '%{}%')",
        escaped, escaped, escaped
    )
}

fn build_date_conditions(config: &SentimentConfig) -> Vec<String> {
    let mut conditions = Vec::new();

    if let Some(ref start) = config.start_date {
        conditions.push(format!("sql_date >= {}", start.replace('-', "")));
    }
    if let Some(ref end) = config.end_date {
        conditions.push(format!("sql_date <= {}", end.replace('-', "")));
    }

    conditions
}

/// Aggregate average tone into time buckets and compute a tone trend.
fn analyze_sentiment_over_time(db: &AnalyticsDb, config: &SentimentConfig) -> Result<SentimentResult> {
    let mut conditions = build_date_conditions(config);
    conditions.push(build_topic_filter(&config.topic));

    let where_clause = format!("WHERE {}", conditions.join(" AND "));

    // SQL expression mapping sql_date (YYYYMMDD) to the requested bucket key.
    let time_bucket = match config.granularity {
        super::trends::Granularity::Hour => "CAST(sql_date AS VARCHAR) || LPAD(CAST((date_added / 10000) % 100 AS VARCHAR), 2, '0')",
        super::trends::Granularity::Day => "CAST(sql_date AS VARCHAR)",
        super::trends::Granularity::Week => "CAST(DATE_TRUNC('week', TO_DATE(CAST(sql_date AS VARCHAR), 'YYYYMMDD')) AS VARCHAR)",
        super::trends::Granularity::Month => "SUBSTR(CAST(sql_date AS VARCHAR), 1, 6)",
    };

    let sql = format!(
        r#"
        SELECT
            {} as time_bucket,
            AVG(avg_tone) as tone,
            COUNT(*) as cnt
        FROM events
        {}
        GROUP BY time_bucket
        ORDER BY time_bucket
        "#,
        time_bucket, where_clause
    );

    let query_result = db.query(&sql)?;

    let mut points = Vec::with_capacity(query_result.rows.len());
    let mut weighted_tone = 0.0_f64;
    let mut total_count = 0_i64;

    for row in &query_result.rows {
        let bucket = row.get(0).and_then(|v| v.as_str()).unwrap_or("").to_string();
        let tone = row.get(1).and_then(|v| v.as_f64()).unwrap_or(0.0);
        let cnt = row.get(2).and_then(|v| v.as_i64()).unwrap_or(0);

        // Accumulate a count-weighted tone so the overall average reflects
        // record volume, not just a mean of bucket means.
        weighted_tone += tone * cnt as f64;
        total_count += cnt;

        points.push(SentimentPoint {
            dimension: bucket,
            avg_tone: tone,
            positive_score: None,
            negative_score: None,
            count: cnt,
            polarity: None,
        });
    }

    let overall_tone = if total_count == 0 {
        0.0
    } else {
        weighted_tone / total_count as f64
    };

    // Trend = linear-regression slope of tone over the ordered buckets.
    let tone_trend = calculate_tone_trend(&points);

    Ok(SentimentResult {
        topic: config.topic.clone(),
        points,
        overall_tone,
        tone_trend,
        total_count,
    })
}

/// Aggregate average tone per action country (top 50 by record count).
fn analyze_sentiment_by_region(db: &AnalyticsDb, config: &SentimentConfig) -> Result<SentimentResult> {
    let mut conditions = build_date_conditions(config);
    conditions.push(build_topic_filter(&config.topic));
    conditions.push("action_geo_country_code IS NOT NULL".to_string());

    let where_clause = format!("WHERE {}", conditions.join(" AND "));

    let sql = format!(
        r#"
        SELECT
            action_geo_country_code as region,
            AVG(avg_tone) as tone,
            COUNT(*) as cnt
        FROM events
        {}
        GROUP BY region
        ORDER BY cnt DESC
        LIMIT 50
        "#,
        where_clause
    );

    let query_result = db.query(&sql)?;

    let mut points = Vec::with_capacity(query_result.rows.len());
    let mut weighted_tone = 0.0_f64;
    let mut total_count = 0_i64;

    for row in &query_result.rows {
        let region = row.get(0).and_then(|v| v.as_str()).unwrap_or("").to_string();
        let tone = row.get(1).and_then(|v| v.as_f64()).unwrap_or(0.0);
        let cnt = row.get(2).and_then(|v| v.as_i64()).unwrap_or(0);

        weighted_tone += tone * cnt as f64;
        total_count += cnt;

        points.push(SentimentPoint {
            dimension: region,
            avg_tone: tone,
            positive_score: None,
            negative_score: None,
            count: cnt,
            polarity: None,
        });
    }

    let overall_tone = if total_count == 0 {
        0.0
    } else {
        weighted_tone / total_count as f64
    };

    Ok(SentimentResult {
        topic: config.topic.clone(),
        points,
        overall_tone,
        tone_trend: 0.0, // Not applicable for region
        total_count,
    })
}

/// Aggregate tone per publishing source (top 50 by record count), using GKG.
///
/// NOTE(review): unlike the other dimensions, this query applies neither the
/// topic filter nor the configured date bounds — it aggregates the entire GKG
/// table and uses `config` only for the topic label echoed in the result.
/// Confirm whether that limitation is intentional.
fn analyze_sentiment_by_source(db: &AnalyticsDb, config: &SentimentConfig) -> Result<SentimentResult> {
    // Use GKG for source analysis
    let sql = r#"
        SELECT
            source_common_name as source,
            AVG(tone) as tone,
            AVG(positive_score) as pos,
            AVG(negative_score) as neg,
            COUNT(*) as cnt
        FROM gkg
        WHERE source_common_name IS NOT NULL
        GROUP BY source
        ORDER BY cnt DESC
        LIMIT 50
    "#;

    let query_result = db.query(sql)?;

    let mut points = Vec::with_capacity(query_result.rows.len());
    let mut weighted_tone = 0.0_f64;
    let mut total_count = 0_i64;

    for row in &query_result.rows {
        let source = row.get(0).and_then(|v| v.as_str()).unwrap_or("").to_string();
        let tone = row.get(1).and_then(|v| v.as_f64()).unwrap_or(0.0);
        let pos = row.get(2).and_then(|v| v.as_f64());
        let neg = row.get(3).and_then(|v| v.as_f64());
        let cnt = row.get(4).and_then(|v| v.as_i64()).unwrap_or(0);

        weighted_tone += tone * cnt as f64;
        total_count += cnt;

        points.push(SentimentPoint {
            dimension: source,
            avg_tone: tone,
            positive_score: pos,
            negative_score: neg,
            count: cnt,
            // Polarity only when both component scores are present.
            polarity: pos.zip(neg).map(|(p, n)| (p - n).abs()),
        });
    }

    let overall_tone = if total_count == 0 {
        0.0
    } else {
        weighted_tone / total_count as f64
    };

    Ok(SentimentResult {
        topic: config.topic.clone(),
        points,
        overall_tone,
        tone_trend: 0.0, // trend is time-specific; sources have no ordering
        total_count,
    })
}

/// Aggregate average tone per actor entity (top 50 by record count).
fn analyze_sentiment_by_entity(db: &AnalyticsDb, config: &SentimentConfig) -> Result<SentimentResult> {
    let mut conditions = build_date_conditions(config);
    conditions.push(build_topic_filter(&config.topic));

    let where_clause = format!("WHERE {}", conditions.join(" AND "));

    // Analyze by actor, preferring the human-readable name over the code.
    let sql = format!(
        r#"
        SELECT
            COALESCE(actor1_name, actor1_code) as entity,
            AVG(avg_tone) as tone,
            COUNT(*) as cnt
        FROM events
        {}
        AND actor1_code IS NOT NULL
        GROUP BY entity
        ORDER BY cnt DESC
        LIMIT 50
        "#,
        where_clause
    );

    let query_result = db.query(&sql)?;

    let mut points = Vec::with_capacity(query_result.rows.len());
    let mut weighted_tone = 0.0_f64;
    let mut total_count = 0_i64;

    for row in &query_result.rows {
        let entity = row.get(0).and_then(|v| v.as_str()).unwrap_or("").to_string();
        let tone = row.get(1).and_then(|v| v.as_f64()).unwrap_or(0.0);
        let cnt = row.get(2).and_then(|v| v.as_i64()).unwrap_or(0);

        weighted_tone += tone * cnt as f64;
        total_count += cnt;

        points.push(SentimentPoint {
            dimension: entity,
            avg_tone: tone,
            positive_score: None,
            negative_score: None,
            count: cnt,
            polarity: None,
        });
    }

    let overall_tone = if total_count == 0 {
        0.0
    } else {
        weighted_tone / total_count as f64
    };

    Ok(SentimentResult {
        topic: config.topic.clone(),
        points,
        overall_tone,
        tone_trend: 0.0, // trend is time-specific; entities have no ordering
        total_count,
    })
}

/// Least-squares slope of `avg_tone` against the point index (0, 1, 2, …).
///
/// A positive slope means tone is improving across the ordered points, a
/// negative one that it is declining. Returns 0.0 for fewer than two points
/// or when the regression is degenerate.
fn calculate_tone_trend(points: &[SentimentPoint]) -> f64 {
    if points.len() < 2 {
        return 0.0;
    }

    let n = points.len() as f64;

    // Simple linear regression over (index, tone) pairs.
    let sum_x: f64 = (0..points.len()).map(|i| i as f64).sum();
    let sum_y: f64 = points.iter().map(|p| p.avg_tone).sum();
    let sum_xy: f64 = points.iter().enumerate().map(|(i, p)| i as f64 * p.avg_tone).sum();
    let sum_xx: f64 = (0..points.len()).map(|i| (i as f64).powi(2)).sum();

    let slope = (n * sum_xy - sum_x * sum_y) / (n * sum_xx - sum_x.powi(2));

    // Guard against NaN *and* ±inf: the old `is_nan()` check alone would let
    // an infinite slope (zero denominator, nonzero numerator) leak through.
    if slope.is_finite() {
        slope
    } else {
        0.0
    }
}