Skip to main content

engram/search/
aggregation.rs

1//! Aggregation queries for memory statistics (RML-880)
2//!
3//! Supports grouping by tags, type, time periods with various metrics.
4
5use rusqlite::Connection;
6use serde::{Deserialize, Serialize};
7
8use crate::error::Result;
9
10/// Aggregation result
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct AggregationResult {
13    /// Group key
14    pub group: String,
15    /// Metrics for this group
16    pub metrics: AggregationMetrics,
17}
18
19/// Metrics calculated per group
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct AggregationMetrics {
22    pub count: i64,
23    pub avg_importance: Option<f32>,
24    pub total_access_count: Option<i64>,
25    pub oldest: Option<String>,
26    pub newest: Option<String>,
27}
28
29/// Group by options
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
31#[serde(rename_all = "snake_case")]
32pub enum GroupBy {
33    Type,
34    Tags,
35    Month,
36    Week,
37    Visibility,
38}
39
40impl GroupBy {
41    pub fn as_sql_expr(&self) -> &'static str {
42        match self {
43            GroupBy::Type => "memory_type",
44            GroupBy::Tags => "t.name",
45            GroupBy::Month => "strftime('%Y-%m', created_at)",
46            GroupBy::Week => "strftime('%Y-W%W', created_at)",
47            GroupBy::Visibility => "visibility",
48        }
49    }
50}
51
52/// Metrics to calculate
53#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
54#[serde(rename_all = "snake_case")]
55pub enum Metric {
56    Count,
57    AvgImportance,
58    TotalAccessCount,
59    DateRange,
60}
61
62/// Perform aggregation query
63pub fn aggregate_memories(
64    conn: &Connection,
65    group_by: GroupBy,
66    metrics: &[Metric],
67) -> Result<Vec<AggregationResult>> {
68    // Build SELECT clause
69    let mut select_parts = vec![format!("{} as group_key", group_by.as_sql_expr())];
70
71    for metric in metrics {
72        match metric {
73            Metric::Count => select_parts.push("COUNT(*) as cnt".to_string()),
74            Metric::AvgImportance => select_parts.push("AVG(importance) as avg_imp".to_string()),
75            Metric::TotalAccessCount => {
76                select_parts.push("SUM(access_count) as total_access".to_string())
77            }
78            Metric::DateRange => {
79                select_parts.push("MIN(created_at) as oldest".to_string());
80                select_parts.push("MAX(created_at) as newest".to_string());
81            }
82        }
83    }
84
85    // Build FROM clause (join tags table if grouping by tags)
86    let from_clause = if group_by == GroupBy::Tags {
87        "memories m
88         JOIN memory_tags mt ON m.id = mt.memory_id
89         JOIN tags t ON mt.tag_id = t.id"
90    } else {
91        "memories m"
92    };
93
94    let sql = format!(
95        "SELECT {} FROM {} WHERE m.valid_to IS NULL GROUP BY group_key ORDER BY cnt DESC",
96        select_parts.join(", "),
97        from_clause
98    );
99
100    let mut stmt = conn.prepare(&sql)?;
101    let mut results = Vec::new();
102
103    let rows = stmt.query_map([], |row| {
104        let group: String = row.get("group_key")?;
105
106        let count: i64 = row.get("cnt").unwrap_or(0);
107        let avg_importance: Option<f64> = row.get("avg_imp").ok();
108        let total_access: Option<i64> = row.get("total_access").ok();
109        let oldest: Option<String> = row.get("oldest").ok();
110        let newest: Option<String> = row.get("newest").ok();
111
112        Ok(AggregationResult {
113            group,
114            metrics: AggregationMetrics {
115                count,
116                avg_importance: avg_importance.map(|f| f as f32),
117                total_access_count: total_access,
118                oldest,
119                newest,
120            },
121        })
122    })?;
123
124    for row in rows {
125        results.push(row?);
126    }
127
128    Ok(results)
129}
130
131/// Get tag distribution
132pub fn get_tag_distribution(conn: &Connection, limit: i64) -> Result<Vec<(String, i64)>> {
133    let mut stmt = conn.prepare(
134        "SELECT t.name, COUNT(*) as cnt
135         FROM tags t
136         JOIN memory_tags mt ON t.id = mt.tag_id
137         JOIN memories m ON mt.memory_id = m.id
138         WHERE m.valid_to IS NULL
139         GROUP BY t.name
140         ORDER BY cnt DESC
141         LIMIT ?",
142    )?;
143
144    let results: Vec<(String, i64)> = stmt
145        .query_map([limit], |row| Ok((row.get(0)?, row.get(1)?)))?
146        .filter_map(|r| r.ok())
147        .collect();
148
149    Ok(results)
150}
151
152/// Get type distribution
153pub fn get_type_distribution(conn: &Connection) -> Result<Vec<(String, i64)>> {
154    let mut stmt = conn.prepare(
155        "SELECT memory_type, COUNT(*) as cnt
156         FROM memories
157         WHERE valid_to IS NULL
158         GROUP BY memory_type
159         ORDER BY cnt DESC",
160    )?;
161
162    let results: Vec<(String, i64)> = stmt
163        .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?
164        .filter_map(|r| r.ok())
165        .collect();
166
167    Ok(results)
168}
169
170/// Get memories created over time (for trend analysis)
171pub fn get_creation_trend(
172    conn: &Connection,
173    interval: &str, // "day", "week", "month"
174) -> Result<Vec<(String, i64)>> {
175    let format_str = match interval {
176        "day" => "%Y-%m-%d",
177        "week" => "%Y-W%W",
178        "month" => "%Y-%m",
179        "year" => "%Y",
180        _ => "%Y-%m-%d",
181    };
182
183    let sql = format!(
184        "SELECT strftime('{}', created_at) as period, COUNT(*) as cnt
185         FROM memories
186         WHERE valid_to IS NULL
187         GROUP BY period
188         ORDER BY period DESC
189         LIMIT 100",
190        format_str
191    );
192
193    let mut stmt = conn.prepare(&sql)?;
194    let results: Vec<(String, i64)> = stmt
195        .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?
196        .filter_map(|r| r.ok())
197        .collect();
198
199    Ok(results)
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `GroupBy` variant maps to its expected SQL grouping expression.
    #[test]
    fn test_group_by_sql_expr() {
        let cases = [
            (GroupBy::Type, "memory_type"),
            (GroupBy::Tags, "t.name"),
            (GroupBy::Month, "strftime('%Y-%m', created_at)"),
            (GroupBy::Week, "strftime('%Y-W%W', created_at)"),
            (GroupBy::Visibility, "visibility"),
        ];

        for &(group_by, expected) in cases.iter() {
            assert_eq!(
                group_by.as_sql_expr(),
                expected,
                "unexpected SQL expression for {:?}",
                group_by
            );
        }
    }
}