1use rusqlite::Connection;
6use serde::{Deserialize, Serialize};
7
8use crate::error::Result;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct AggregationResult {
13 pub group: String,
15 pub metrics: AggregationMetrics,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct AggregationMetrics {
22 pub count: i64,
23 pub avg_importance: Option<f32>,
24 pub total_access_count: Option<i64>,
25 pub oldest: Option<String>,
26 pub newest: Option<String>,
27}
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
31#[serde(rename_all = "snake_case")]
32pub enum GroupBy {
33 Type,
34 Tags,
35 Month,
36 Week,
37 Visibility,
38}
39
40impl GroupBy {
41 pub fn as_sql_expr(&self) -> &'static str {
42 match self {
43 GroupBy::Type => "memory_type",
44 GroupBy::Tags => "t.name",
45 GroupBy::Month => "strftime('%Y-%m', created_at)",
46 GroupBy::Week => "strftime('%Y-W%W', created_at)",
47 GroupBy::Visibility => "visibility",
48 }
49 }
50}
51
52#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
54#[serde(rename_all = "snake_case")]
55pub enum Metric {
56 Count,
57 AvgImportance,
58 TotalAccessCount,
59 DateRange,
60}
61
62pub fn aggregate_memories(
64 conn: &Connection,
65 group_by: GroupBy,
66 metrics: &[Metric],
67) -> Result<Vec<AggregationResult>> {
68 let mut select_parts = vec![format!("{} as group_key", group_by.as_sql_expr())];
70
71 for metric in metrics {
72 match metric {
73 Metric::Count => select_parts.push("COUNT(*) as cnt".to_string()),
74 Metric::AvgImportance => select_parts.push("AVG(importance) as avg_imp".to_string()),
75 Metric::TotalAccessCount => {
76 select_parts.push("SUM(access_count) as total_access".to_string())
77 }
78 Metric::DateRange => {
79 select_parts.push("MIN(created_at) as oldest".to_string());
80 select_parts.push("MAX(created_at) as newest".to_string());
81 }
82 }
83 }
84
85 let from_clause = if group_by == GroupBy::Tags {
87 "memories m
88 JOIN memory_tags mt ON m.id = mt.memory_id
89 JOIN tags t ON mt.tag_id = t.id"
90 } else {
91 "memories m"
92 };
93
94 let sql = format!(
95 "SELECT {} FROM {} WHERE m.valid_to IS NULL GROUP BY group_key ORDER BY cnt DESC",
96 select_parts.join(", "),
97 from_clause
98 );
99
100 let mut stmt = conn.prepare(&sql)?;
101 let mut results = Vec::new();
102
103 let rows = stmt.query_map([], |row| {
104 let group: String = row.get("group_key")?;
105
106 let count: i64 = row.get("cnt").unwrap_or(0);
107 let avg_importance: Option<f64> = row.get("avg_imp").ok();
108 let total_access: Option<i64> = row.get("total_access").ok();
109 let oldest: Option<String> = row.get("oldest").ok();
110 let newest: Option<String> = row.get("newest").ok();
111
112 Ok(AggregationResult {
113 group,
114 metrics: AggregationMetrics {
115 count,
116 avg_importance: avg_importance.map(|f| f as f32),
117 total_access_count: total_access,
118 oldest,
119 newest,
120 },
121 })
122 })?;
123
124 for row in rows {
125 results.push(row?);
126 }
127
128 Ok(results)
129}
130
131pub fn get_tag_distribution(conn: &Connection, limit: i64) -> Result<Vec<(String, i64)>> {
133 let mut stmt = conn.prepare(
134 "SELECT t.name, COUNT(*) as cnt
135 FROM tags t
136 JOIN memory_tags mt ON t.id = mt.tag_id
137 JOIN memories m ON mt.memory_id = m.id
138 WHERE m.valid_to IS NULL
139 GROUP BY t.name
140 ORDER BY cnt DESC
141 LIMIT ?",
142 )?;
143
144 let results: Vec<(String, i64)> = stmt
145 .query_map([limit], |row| Ok((row.get(0)?, row.get(1)?)))?
146 .filter_map(|r| r.ok())
147 .collect();
148
149 Ok(results)
150}
151
152pub fn get_type_distribution(conn: &Connection) -> Result<Vec<(String, i64)>> {
154 let mut stmt = conn.prepare(
155 "SELECT memory_type, COUNT(*) as cnt
156 FROM memories
157 WHERE valid_to IS NULL
158 GROUP BY memory_type
159 ORDER BY cnt DESC",
160 )?;
161
162 let results: Vec<(String, i64)> = stmt
163 .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?
164 .filter_map(|r| r.ok())
165 .collect();
166
167 Ok(results)
168}
169
170pub fn get_creation_trend(
172 conn: &Connection,
173 interval: &str, ) -> Result<Vec<(String, i64)>> {
175 let format_str = match interval {
176 "day" => "%Y-%m-%d",
177 "week" => "%Y-W%W",
178 "month" => "%Y-%m",
179 "year" => "%Y",
180 _ => "%Y-%m-%d",
181 };
182
183 let sql = format!(
184 "SELECT strftime('{}', created_at) as period, COUNT(*) as cnt
185 FROM memories
186 WHERE valid_to IS NULL
187 GROUP BY period
188 ORDER BY period DESC
189 LIMIT 100",
190 format_str
191 );
192
193 let mut stmt = conn.prepare(&sql)?;
194 let results: Vec<(String, i64)> = stmt
195 .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?
196 .filter_map(|r| r.ok())
197 .collect();
198
199 Ok(results)
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// In-memory database with only the columns the distribution queries read.
    ///
    /// Memory 4 is superseded (`valid_to` set), so it must never be counted.
    fn test_connection() -> Connection {
        let conn = Connection::open_in_memory().expect("open in-memory database");
        conn.execute_batch(
            "CREATE TABLE memories (
                 id INTEGER PRIMARY KEY,
                 memory_type TEXT NOT NULL,
                 created_at TEXT NOT NULL,
                 valid_to TEXT
             );
             CREATE TABLE tags (id INTEGER PRIMARY KEY, name TEXT NOT NULL);
             CREATE TABLE memory_tags (memory_id INTEGER NOT NULL, tag_id INTEGER NOT NULL);
             INSERT INTO memories (id, memory_type, created_at, valid_to) VALUES
                 (1, 'note', '2024-01-05 10:00:00', NULL),
                 (2, 'note', '2024-01-20 10:00:00', NULL),
                 (3, 'fact', '2024-02-01 10:00:00', NULL),
                 (4, 'note', '2024-02-02 10:00:00', '2024-03-01 00:00:00');
             INSERT INTO tags (id, name) VALUES (1, 'rust'), (2, 'sql');
             INSERT INTO memory_tags (memory_id, tag_id) VALUES
                 (1, 1), (2, 1), (3, 2), (4, 2), (4, 1);",
        )
        .expect("create test schema");
        conn
    }

    #[test]
    fn test_group_by_sql_expr() {
        assert_eq!(GroupBy::Type.as_sql_expr(), "memory_type");
        assert_eq!(GroupBy::Tags.as_sql_expr(), "t.name");
        assert_eq!(
            GroupBy::Month.as_sql_expr(),
            "strftime('%Y-%m', created_at)"
        );
        assert_eq!(
            GroupBy::Week.as_sql_expr(),
            "strftime('%Y-W%W', created_at)"
        );
        assert_eq!(GroupBy::Visibility.as_sql_expr(), "visibility");
    }

    #[test]
    fn test_type_distribution_ignores_superseded() {
        let conn = test_connection();
        let dist = get_type_distribution(&conn).expect("type distribution");
        assert_eq!(
            dist,
            vec![("note".to_string(), 2), ("fact".to_string(), 1)]
        );
    }

    #[test]
    fn test_tag_distribution_respects_limit() {
        let conn = test_connection();
        let all = get_tag_distribution(&conn, 10).expect("tag distribution");
        assert_eq!(all, vec![("rust".to_string(), 2), ("sql".to_string(), 1)]);

        let top = get_tag_distribution(&conn, 1).expect("limited tag distribution");
        assert_eq!(top, vec![("rust".to_string(), 2)]);
    }

    #[test]
    fn test_creation_trend_by_month() {
        let conn = test_connection();
        let trend = get_creation_trend(&conn, "month").expect("creation trend");
        assert_eq!(
            trend,
            vec![("2024-02".to_string(), 1), ("2024-01".to_string(), 2)]
        );
    }
}