1use crate::error::Result;
4use crate::graph::Node;
5use serde_json::Value;
6use std::collections::HashMap;
7
/// Aggregate operation that can be applied to a collection of nodes.
///
/// Every variant except [`AggregateFunction::Count`] carries the name of the
/// node property the operation reads.
#[derive(Debug, Clone, PartialEq)]
pub enum AggregateFunction {
    /// Number of nodes in the input.
    Count,
    /// Sum of the named property's numeric values.
    Sum(String),
    /// Arithmetic mean of the named property's numeric values.
    Avg(String),
    /// Smallest numeric value of the named property.
    Min(String),
    /// Largest numeric value of the named property.
    Max(String),
    /// Set of distinct values of the named property.
    Distinct(String),
    /// Partition of the nodes keyed by the named property's value.
    GroupBy(String),
}
44
45#[derive(Debug, Clone, PartialEq)]
47pub enum AggregateResult {
48 Number(f64),
50 String(String),
52 List(Vec<Value>),
54 Groups(HashMap<String, Vec<Node>>),
56}
57
58pub struct Aggregator;
60
61impl Aggregator {
62 pub fn aggregate(nodes: &[Node], function: &AggregateFunction) -> Result<AggregateResult> {
64 match function {
65 AggregateFunction::Count => Ok(AggregateResult::Number(nodes.len() as f64)),
66
67 AggregateFunction::Sum(property) => {
68 let sum = nodes
69 .iter()
70 .filter_map(|node| node.get_property(property))
71 .filter_map(|value| value.as_f64())
72 .sum();
73 Ok(AggregateResult::Number(sum))
74 }
75
76 AggregateFunction::Avg(property) => {
77 let values: Vec<f64> = nodes
78 .iter()
79 .filter_map(|node| node.get_property(property))
80 .filter_map(|value| value.as_f64())
81 .collect();
82
83 if values.is_empty() {
84 Ok(AggregateResult::Number(0.0))
85 } else {
86 let avg = values.iter().sum::<f64>() / values.len() as f64;
87 Ok(AggregateResult::Number(avg))
88 }
89 }
90
91 AggregateFunction::Min(property) => {
92 let min = nodes
93 .iter()
94 .filter_map(|node| node.get_property(property))
95 .filter_map(|value| value.as_f64())
96 .fold(f64::INFINITY, f64::min);
97
98 if min.is_infinite() {
99 Ok(AggregateResult::Number(0.0))
100 } else {
101 Ok(AggregateResult::Number(min))
102 }
103 }
104
105 AggregateFunction::Max(property) => {
106 let max = nodes
107 .iter()
108 .filter_map(|node| node.get_property(property))
109 .filter_map(|value| value.as_f64())
110 .fold(f64::NEG_INFINITY, f64::max);
111
112 if max.is_infinite() {
113 Ok(AggregateResult::Number(0.0))
114 } else {
115 Ok(AggregateResult::Number(max))
116 }
117 }
118
119 AggregateFunction::Distinct(property) => {
120 let mut unique_values = std::collections::HashSet::new();
121
122 for node in nodes {
123 if let Some(value) = node.get_property(property) {
124 unique_values.insert(value.clone());
125 }
126 }
127
128 let list: Vec<Value> = unique_values.into_iter().collect();
129 Ok(AggregateResult::List(list))
130 }
131
132 AggregateFunction::GroupBy(property) => {
133 let mut groups = HashMap::new();
134
135 for node in nodes {
136 let group_key = if let Some(value) = node.get_property(property) {
137 match value {
138 Value::String(s) => s.clone(),
139 Value::Number(n) => n.to_string(),
140 Value::Bool(b) => b.to_string(),
141 _ => "null".to_string(),
142 }
143 } else {
144 "null".to_string()
145 };
146
147 groups
148 .entry(group_key)
149 .or_insert_with(Vec::new)
150 .push(node.clone());
151 }
152
153 Ok(AggregateResult::Groups(groups))
154 }
155 }
156 }
157
158 pub fn multi_aggregate(
160 nodes: &[Node],
161 functions: &[AggregateFunction],
162 ) -> Result<Vec<AggregateResult>> {
163 functions
164 .iter()
165 .map(|func| Self::aggregate(nodes, func))
166 .collect()
167 }
168}
169
170pub struct Statistics;
172
173impl Statistics {
174 pub fn calculate(nodes: &[Node], property: &str) -> Result<StatisticsResult> {
176 let values: Vec<f64> = nodes
177 .iter()
178 .filter_map(|node| node.get_property(property))
179 .filter_map(|value| value.as_f64())
180 .collect();
181
182 if values.is_empty() {
183 return Ok(StatisticsResult::default());
184 }
185
186 let count = values.len() as f64;
187 let sum = values.iter().sum::<f64>();
188 let mean = sum / count;
189
190 let min = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
191 let max = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
192
193 let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / count;
195 let std_dev = variance.sqrt();
196
197 let mut sorted_values = values.clone();
199 sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
200 let median = if sorted_values.len().is_multiple_of(2) {
201 let mid = sorted_values.len() / 2;
202 (sorted_values[mid - 1] + sorted_values[mid]) / 2.0
203 } else {
204 sorted_values[sorted_values.len() / 2]
205 };
206
207 Ok(StatisticsResult {
208 count,
209 sum,
210 mean,
211 median,
212 min,
213 max,
214 variance,
215 std_dev,
216 })
217 }
218}
219
/// Summary statistics for a set of numeric values.
///
/// The derived `Default` sets every field to `0.0`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct StatisticsResult {
    /// Number of values that contributed to the statistics.
    pub count: f64,
    /// Sum of the values.
    pub sum: f64,
    /// Arithmetic mean of the values.
    pub mean: f64,
    /// Median of the values.
    pub median: f64,
    /// Smallest value.
    pub min: f64,
    /// Largest value.
    pub max: f64,
    /// Variance of the values.
    pub variance: f64,
    /// Standard deviation of the values.
    pub std_dev: f64,
}
255
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds the shared fixture: three nodes with `age`, `name` and `category`.
    fn create_test_nodes() -> Vec<Node> {
        [
            (1, 25, "Alice", "A"),
            (2, 30, "Bob", "B"),
            (3, 35, "Charlie", "A"),
        ]
        .into_iter()
        .map(|(id, age, name, category)| {
            Node::new(
                id,
                [
                    (String::from("age"), json!(age)),
                    (String::from("name"), json!(name)),
                    (String::from("category"), json!(category)),
                ]
                .into(),
            )
        })
        .collect()
    }

    #[test]
    fn test_count_aggregation() {
        let nodes = create_test_nodes();
        let counted = Aggregator::aggregate(&nodes, &AggregateFunction::Count).unwrap();
        assert_eq!(counted, AggregateResult::Number(3.0));
    }

    #[test]
    fn test_sum_aggregation() {
        let nodes = create_test_nodes();
        let function = AggregateFunction::Sum("age".to_string());
        let total = Aggregator::aggregate(&nodes, &function).unwrap();
        assert_eq!(total, AggregateResult::Number(90.0));
    }

    #[test]
    fn test_avg_aggregation() {
        let nodes = create_test_nodes();
        let function = AggregateFunction::Avg("age".to_string());
        let average = Aggregator::aggregate(&nodes, &function).unwrap();
        assert_eq!(average, AggregateResult::Number(30.0));
    }

    #[test]
    fn test_group_by_aggregation() {
        let nodes = create_test_nodes();
        let function = AggregateFunction::GroupBy("category".to_string());
        let result = Aggregator::aggregate(&nodes, &function).unwrap();

        let AggregateResult::Groups(groups) = result else {
            panic!("Expected Groups result");
        };

        assert_eq!(groups.len(), 2);
        assert_eq!(groups.get("A").map(Vec::len), Some(2));
        assert_eq!(groups.get("B").map(Vec::len), Some(1));
    }

    #[test]
    fn test_statistics() {
        let nodes = create_test_nodes();
        let stats = Statistics::calculate(&nodes, "age").unwrap();

        // Same six checks as before, grouped by what they describe.
        assert_eq!((stats.count, stats.sum), (3.0, 90.0));
        assert_eq!((stats.mean, stats.median), (30.0, 30.0));
        assert_eq!((stats.min, stats.max), (25.0, 35.0));
    }
}