#![allow(dead_code)]
use crate::db::AnalyticsDb;
use crate::error::Result;
use serde::{Deserialize, Serialize};
/// Outcome of comparing two topics: the statistics gathered for each topic
/// plus the metrics derived from putting them side by side.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompareResult {
/// Statistics for the first topic (`CompareConfig::topic_a`).
pub topic_a: TopicStats,
/// Statistics for the second topic (`CompareConfig::topic_b`).
pub topic_b: TopicStats,
/// Metrics computed from `topic_a` relative to `topic_b`.
pub comparison: ComparisonMetrics,
}
/// Aggregate statistics for the events that matched a single topic.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TopicStats {
/// The topic string exactly as it was supplied by the caller.
pub name: String,
/// Number of matching events (`COUNT(*)`); `0` when the count could not be read.
pub event_count: i64,
/// Mean of the `avg_tone` column over matching events; `0.0` when unavailable.
pub avg_tone: f64,
/// Mean of the `goldstein_scale` column over matching events; `0.0` when unavailable.
pub avg_goldstein: f64,
/// Up to ten country codes with the most matching events, most frequent first.
pub top_countries: Vec<CountryCount>,
/// Up to ten event root codes with the most matching events, most frequent first.
pub event_types: Vec<EventTypeCount>,
/// Smallest and largest `sql_date` among matching events, rendered as strings.
/// `None` unless both ends could be read as integers.
pub date_range: Option<(String, String)>,
}
/// Number of matching events attributed to one country.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CountryCount {
/// Value of `action_geo_country_code` for this group (empty if it was not a string).
pub country: String,
/// Number of matching events in this group.
pub count: i64,
/// `count` as a percentage (0–100 scale) of the topic's total event count;
/// `0.0` when the topic has no events.
pub percentage: f64,
}
/// Number of matching events that share one event root code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EventTypeCount {
/// Value of `event_root_code` for this group (empty if it was not a string).
pub event_code: String,
/// Human-readable label for `event_code`, as produced by `get_event_description`.
pub description: String,
/// Number of matching events in this group.
pub count: i64,
/// `count` as a percentage (0–100 scale) of the topic's total event count;
/// `0.0` when the topic has no events.
pub percentage: f64,
}
/// Metrics that relate the statistics of topic A to those of topic B.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonMetrics {
/// Topic A's event count divided by topic B's; `0.0` when topic B has no events.
pub count_ratio: f64,
/// Topic A's average tone minus topic B's.
pub tone_diff: f64,
/// Topic A's average Goldstein value minus topic B's.
pub goldstein_diff: f64,
/// Jaccard similarity (0.0–1.0) between the two topics' top-country code sets.
pub country_similarity: f64,
/// Jaccard similarity (0.0–1.0) between the two topics' top event-code sets.
pub event_type_similarity: f64,
}
/// Input parameters for `compare_topics`.
#[derive(Debug, Clone)]
pub struct CompareConfig {
/// Text searched for (as a substring) to select the first topic's events.
pub topic_a: String,
/// Text searched for (as a substring) to select the second topic's events.
pub topic_b: String,
/// Optional inclusive lower bound on `sql_date`; hyphens are removed before use,
/// so both `2024-01-31` and `20240131` are accepted.
pub start_date: Option<String>,
/// Optional inclusive upper bound on `sql_date`; hyphens are removed before use.
pub end_date: Option<String>,
}
pub fn compare_topics(db: &AnalyticsDb, config: &CompareConfig) -> Result<CompareResult> {
let topic_a = analyze_topic(db, &config.topic_a, &config.start_date, &config.end_date)?;
let topic_b = analyze_topic(db, &config.topic_b, &config.start_date, &config.end_date)?;
let comparison = calculate_comparison(&topic_a, &topic_b);
Ok(CompareResult {
topic_a,
topic_b,
comparison,
})
}
/// Gathers aggregate statistics for every event that mentions `topic`.
///
/// An event matches when `topic` occurs as a substring (via `ILIKE`) of
/// `actor1_name`, `actor2_name` or `action_geo_fullname`, and its `sql_date`
/// lies within the optional inclusive `start_date` / `end_date` bounds.
/// Three queries are issued: overall aggregates, the ten most frequent
/// country codes, and the ten most frequent event root codes.
///
/// Values that are missing or of an unexpected type fall back to `0`, `0.0`
/// or an empty string rather than failing.
///
/// # Errors
///
/// Propagates any error returned by `db.query`.
fn analyze_topic(
    db: &AnalyticsDb,
    topic: &str,
    start_date: &Option<String>,
    end_date: &Option<String>,
) -> Result<TopicStats> {
    /// Doubles every single quote so `value` cannot terminate the SQL string
    /// literal it is embedded in.
    fn escape_sql_string(value: &str) -> String {
        value.replace('\'', "''")
    }

    /// Turns a caller-supplied date into a safe SQL operand.
    ///
    /// Hyphens are dropped; an all-digit remainder is emitted bare, exactly as
    /// before. Anything else is emitted as an escaped string literal so it can
    /// never change the structure of the query.
    /// NOTE(review): the database is expected to reject a non-numeric bound
    /// when comparing it against `sql_date` — confirm for the engine in use.
    fn sql_date_operand(date: &str) -> String {
        let compact = date.replace('-', "");
        if !compact.is_empty() && compact.bytes().all(|b| b.is_ascii_digit()) {
            compact
        } else {
            format!("'{}'", escape_sql_string(&compact))
        }
    }

    let mut conditions = Vec::with_capacity(3);
    if let Some(start) = start_date {
        conditions.push(format!("sql_date >= {}", sql_date_operand(start)));
    }
    if let Some(end) = end_date {
        conditions.push(format!("sql_date <= {}", sql_date_operand(end)));
    }

    // The topic is user-controlled text placed inside a quoted literal, so it
    // must be escaped before interpolation.
    // NOTE(review): `%` and `_` inside `topic` still act as LIKE wildcards —
    // confirm whether that is intended.
    let needle = escape_sql_string(topic);
    conditions.push(format!(
        "(actor1_name ILIKE '%{}%' OR actor2_name ILIKE '%{}%' OR action_geo_fullname ILIKE '%{}%')",
        needle, needle, needle
    ));

    // The topic filter is always present, so the WHERE clause is never empty
    // and the per-query `AND ...` suffixes below are always valid.
    let where_clause = format!("WHERE {}", conditions.join(" AND "));

    let stats_sql = format!(
        r#"
SELECT
COUNT(*) as cnt,
AVG(avg_tone) as tone,
AVG(goldstein_scale) as goldstein,
MIN(sql_date) as min_date,
MAX(sql_date) as max_date
FROM events
{}
"#,
        where_clause
    );
    let stats_result = db.query(&stats_sql)?;
    let stats_row = stats_result.rows.first();

    let event_count = stats_row
        .and_then(|r| r.get(0))
        .and_then(|v| v.as_i64())
        .unwrap_or(0);
    let avg_tone = stats_row
        .and_then(|r| r.get(1))
        .and_then(|v| v.as_f64())
        .unwrap_or(0.0);
    let avg_goldstein = stats_row
        .and_then(|r| r.get(2))
        .and_then(|v| v.as_f64())
        .unwrap_or(0.0);
    let min_date = stats_row
        .and_then(|r| r.get(3))
        .and_then(|v| v.as_i64())
        .map(|d| d.to_string());
    let max_date = stats_row
        .and_then(|r| r.get(4))
        .and_then(|v| v.as_i64())
        .map(|d| d.to_string());
    // Only report a range when both ends are known.
    let date_range = min_date.zip(max_date);

    // Share of the topic's events, on a 0–100 scale; guards against dividing
    // by zero when nothing matched.
    let share_of_total = |count: i64| -> f64 {
        if event_count > 0 {
            count as f64 / event_count as f64 * 100.0
        } else {
            0.0
        }
    };

    let countries_sql = format!(
        r#"
SELECT
action_geo_country_code,
COUNT(*) as cnt
FROM events
{}
AND action_geo_country_code IS NOT NULL
GROUP BY action_geo_country_code
ORDER BY cnt DESC
LIMIT 10
"#,
        where_clause
    );
    let countries_result = db.query(&countries_sql)?;
    let top_countries: Vec<CountryCount> = countries_result
        .rows
        .iter()
        .map(|row| {
            let country = row.get(0).and_then(|v| v.as_str()).unwrap_or("").to_string();
            let count = row.get(1).and_then(|v| v.as_i64()).unwrap_or(0);
            CountryCount {
                country,
                count,
                percentage: share_of_total(count),
            }
        })
        .collect();

    let events_sql = format!(
        r#"
SELECT
event_root_code,
COUNT(*) as cnt
FROM events
{}
AND event_root_code IS NOT NULL
GROUP BY event_root_code
ORDER BY cnt DESC
LIMIT 10
"#,
        where_clause
    );
    let events_result = db.query(&events_sql)?;
    let event_types: Vec<EventTypeCount> = events_result
        .rows
        .iter()
        .map(|row| {
            let event_code = row.get(0).and_then(|v| v.as_str()).unwrap_or("").to_string();
            let count = row.get(1).and_then(|v| v.as_i64()).unwrap_or(0);
            // Look the label up before moving `event_code` into the struct.
            let description = get_event_description(&event_code);
            EventTypeCount {
                event_code,
                description,
                count,
                percentage: share_of_total(count),
            }
        })
        .collect();

    Ok(TopicStats {
        name: topic.to_string(),
        event_count,
        avg_tone,
        avg_goldstein,
        top_countries,
        event_types,
        date_range,
    })
}
/// Derives the comparison metrics for two already-analyzed topics.
///
/// Differences are computed as "A minus B". The count ratio is A's event count
/// over B's and is reported as `0.0` when B has no events. The two similarity
/// scores are Jaccard similarities over the distinct country codes and the
/// distinct event codes listed in each topic's top lists.
fn calculate_comparison(topic_a: &TopicStats, topic_b: &TopicStats) -> ComparisonMetrics {
    use std::collections::HashSet;

    let country_similarity = {
        let left: HashSet<&String> = topic_a
            .top_countries
            .iter()
            .map(|entry| &entry.country)
            .collect();
        let right: HashSet<&String> = topic_b
            .top_countries
            .iter()
            .map(|entry| &entry.country)
            .collect();
        jaccard_similarity(&left, &right)
    };

    let event_type_similarity = {
        let left: HashSet<&String> = topic_a
            .event_types
            .iter()
            .map(|entry| &entry.event_code)
            .collect();
        let right: HashSet<&String> = topic_b
            .event_types
            .iter()
            .map(|entry| &entry.event_code)
            .collect();
        jaccard_similarity(&left, &right)
    };

    // Avoid dividing by zero (or a negative count) when B matched nothing.
    let count_ratio = match topic_b.event_count {
        denominator if denominator > 0 => topic_a.event_count as f64 / denominator as f64,
        _ => 0.0,
    };

    ComparisonMetrics {
        count_ratio,
        tone_diff: topic_a.avg_tone - topic_b.avg_tone,
        goldstein_diff: topic_a.avg_goldstein - topic_b.avg_goldstein,
        country_similarity,
        event_type_similarity,
    }
}
/// Jaccard similarity of two sets: the number of shared elements divided by
/// the number of distinct elements across both sets.
///
/// Returns a value in `0.0..=1.0`. Two empty sets yield `0.0`.
fn jaccard_similarity<T: std::hash::Hash + Eq>(
    set_a: &std::collections::HashSet<T>,
    set_b: &std::collections::HashSet<T>,
) -> f64 {
    let shared = set_a.iter().filter(|item| set_b.contains(*item)).count();
    // |A ∪ B| = |A| + |B| − |A ∩ B|
    let combined = set_a.len() + set_b.len() - shared;
    match combined {
        0 => 0.0,
        total => shared as f64 / total as f64,
    }
}
/// Returns the human-readable label for a two-digit event root code.
///
/// Codes `"01"` through `"20"` map to fixed labels. Any other input,
/// including codes without the leading zero, yields `"Event code <code>"`.
fn get_event_description(code: &str) -> String {
    const ROOT_CODE_LABELS: [(&str, &str); 20] = [
        ("01", "Make public statement"),
        ("02", "Appeal"),
        ("03", "Express intent to cooperate"),
        ("04", "Consult"),
        ("05", "Engage in diplomatic cooperation"),
        ("06", "Engage in material cooperation"),
        ("07", "Provide aid"),
        ("08", "Yield"),
        ("09", "Investigate"),
        ("10", "Demand"),
        ("11", "Disapprove"),
        ("12", "Reject"),
        ("13", "Threaten"),
        ("14", "Protest"),
        ("15", "Exhibit force posture"),
        ("16", "Reduce relations"),
        ("17", "Coerce"),
        ("18", "Assault"),
        ("19", "Fight"),
        ("20", "Use unconventional mass violence"),
    ];

    ROOT_CODE_LABELS
        .iter()
        .find(|(root, _)| *root == code)
        .map(|(_, label)| (*label).to_string())
        .unwrap_or_else(|| format!("Event code {}", code))
}