graph_d/query/
aggregation.rs

1//! Aggregation operations for graph queries.
2
3use crate::error::Result;
4use crate::graph::Node;
5use serde_json::Value;
6use std::collections::HashMap;
7
/// Aggregation functions that can be applied to a collection of query result nodes.
///
/// Every variant except [`AggregateFunction::Count`] carries the property key
/// whose values are aggregated. The type is `Eq + Hash`, so a function can be
/// used as a map key or set member (for example to de-duplicate requested
/// aggregations).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AggregateFunction {
    /// Count the number of nodes in the collection.
    Count,
    /// Sum the numeric values stored under the property key.
    ///
    /// Nodes where the property is missing or not numeric are skipped.
    Sum(
        /// Property key to sum
        String,
    ),
    /// Arithmetic mean of the numeric values stored under the property key.
    ///
    /// Nodes where the property is missing or not numeric are skipped; the
    /// result is `0.0` when no numeric value is found.
    Avg(
        /// Property key to average
        String,
    ),
    /// Smallest numeric value stored under the property key.
    ///
    /// The result is `0.0` when no numeric value is found.
    Min(
        /// Property key to find minimum
        String,
    ),
    /// Largest numeric value stored under the property key.
    ///
    /// The result is `0.0` when no numeric value is found.
    Max(
        /// Property key to find maximum
        String,
    ),
    /// Collect the unique values stored under the property key.
    Distinct(
        /// Property key to collect distinct values
        String,
    ),
    /// Group nodes by the value stored under the property key.
    ///
    /// String, number and boolean values are rendered to a string group key;
    /// nodes with any other value, or without the property, are grouped under
    /// the key `"null"`.
    GroupBy(
        /// Property key to group by
        String,
    ),
}
44
45/// Result of an aggregation operation.
46#[derive(Debug, Clone, PartialEq)]
47pub enum AggregateResult {
48    /// Single numeric value
49    Number(f64),
50    /// Single string value
51    String(String),
52    /// List of unique values
53    List(Vec<Value>),
54    /// Grouped results
55    Groups(HashMap<String, Vec<Node>>),
56}
57
58/// Aggregator for performing aggregation operations on node collections.
59pub struct Aggregator;
60
61impl Aggregator {
62    /// Apply an aggregation function to a collection of nodes.
63    pub fn aggregate(nodes: &[Node], function: &AggregateFunction) -> Result<AggregateResult> {
64        match function {
65            AggregateFunction::Count => Ok(AggregateResult::Number(nodes.len() as f64)),
66
67            AggregateFunction::Sum(property) => {
68                let sum = nodes
69                    .iter()
70                    .filter_map(|node| node.get_property(property))
71                    .filter_map(|value| value.as_f64())
72                    .sum();
73                Ok(AggregateResult::Number(sum))
74            }
75
76            AggregateFunction::Avg(property) => {
77                let values: Vec<f64> = nodes
78                    .iter()
79                    .filter_map(|node| node.get_property(property))
80                    .filter_map(|value| value.as_f64())
81                    .collect();
82
83                if values.is_empty() {
84                    Ok(AggregateResult::Number(0.0))
85                } else {
86                    let avg = values.iter().sum::<f64>() / values.len() as f64;
87                    Ok(AggregateResult::Number(avg))
88                }
89            }
90
91            AggregateFunction::Min(property) => {
92                let min = nodes
93                    .iter()
94                    .filter_map(|node| node.get_property(property))
95                    .filter_map(|value| value.as_f64())
96                    .fold(f64::INFINITY, f64::min);
97
98                if min.is_infinite() {
99                    Ok(AggregateResult::Number(0.0))
100                } else {
101                    Ok(AggregateResult::Number(min))
102                }
103            }
104
105            AggregateFunction::Max(property) => {
106                let max = nodes
107                    .iter()
108                    .filter_map(|node| node.get_property(property))
109                    .filter_map(|value| value.as_f64())
110                    .fold(f64::NEG_INFINITY, f64::max);
111
112                if max.is_infinite() {
113                    Ok(AggregateResult::Number(0.0))
114                } else {
115                    Ok(AggregateResult::Number(max))
116                }
117            }
118
119            AggregateFunction::Distinct(property) => {
120                let mut unique_values = std::collections::HashSet::new();
121
122                for node in nodes {
123                    if let Some(value) = node.get_property(property) {
124                        unique_values.insert(value.clone());
125                    }
126                }
127
128                let list: Vec<Value> = unique_values.into_iter().collect();
129                Ok(AggregateResult::List(list))
130            }
131
132            AggregateFunction::GroupBy(property) => {
133                let mut groups = HashMap::new();
134
135                for node in nodes {
136                    let group_key = if let Some(value) = node.get_property(property) {
137                        match value {
138                            Value::String(s) => s.clone(),
139                            Value::Number(n) => n.to_string(),
140                            Value::Bool(b) => b.to_string(),
141                            _ => "null".to_string(),
142                        }
143                    } else {
144                        "null".to_string()
145                    };
146
147                    groups
148                        .entry(group_key)
149                        .or_insert_with(Vec::new)
150                        .push(node.clone());
151                }
152
153                Ok(AggregateResult::Groups(groups))
154            }
155        }
156    }
157
158    /// Apply multiple aggregation functions to a collection of nodes.
159    pub fn multi_aggregate(
160        nodes: &[Node],
161        functions: &[AggregateFunction],
162    ) -> Result<Vec<AggregateResult>> {
163        functions
164            .iter()
165            .map(|func| Self::aggregate(nodes, func))
166            .collect()
167    }
168}
169
170/// Statistical aggregations for numeric properties.
171pub struct Statistics;
172
173impl Statistics {
174    /// Calculate basic statistics for a numeric property.
175    pub fn calculate(nodes: &[Node], property: &str) -> Result<StatisticsResult> {
176        let values: Vec<f64> = nodes
177            .iter()
178            .filter_map(|node| node.get_property(property))
179            .filter_map(|value| value.as_f64())
180            .collect();
181
182        if values.is_empty() {
183            return Ok(StatisticsResult::default());
184        }
185
186        let count = values.len() as f64;
187        let sum = values.iter().sum::<f64>();
188        let mean = sum / count;
189
190        let min = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
191        let max = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
192
193        // Calculate variance and standard deviation
194        let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / count;
195        let std_dev = variance.sqrt();
196
197        // Calculate median
198        let mut sorted_values = values.clone();
199        sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
200        let median = if sorted_values.len().is_multiple_of(2) {
201            let mid = sorted_values.len() / 2;
202            (sorted_values[mid - 1] + sorted_values[mid]) / 2.0
203        } else {
204            sorted_values[sorted_values.len() / 2]
205        };
206
207        Ok(StatisticsResult {
208            count,
209            sum,
210            mean,
211            median,
212            min,
213            max,
214            variance,
215            std_dev,
216        })
217    }
218}
219
/// Result of statistical calculations.
///
/// The derived [`Default`] value has every field set to `0.0`; it is what
/// `Statistics::calculate` returns when no numeric values are found.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct StatisticsResult {
    /// Number of values in the dataset
    pub count: f64,
    /// Sum of all values
    pub sum: f64,
    /// Arithmetic mean (average)
    pub mean: f64,
    /// Middle value when sorted
    pub median: f64,
    /// Smallest value in the dataset
    pub min: f64,
    /// Largest value in the dataset
    pub max: f64,
    /// Population variance
    pub variance: f64,
    /// Standard deviation (square root of variance)
    pub std_dev: f64,
}
255
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build the shared fixture: three nodes with `age`, `name` and
    /// `category` properties (ages 25/30/35, categories A/B/A).
    fn create_test_nodes() -> Vec<Node> {
        [
            (1, 25, "Alice", "A"),
            (2, 30, "Bob", "B"),
            (3, 35, "Charlie", "A"),
        ]
        .into_iter()
        .map(|(id, age, name, category)| {
            Node::new(
                id,
                [
                    ("age".to_string(), json!(age)),
                    ("name".to_string(), json!(name)),
                    ("category".to_string(), json!(category)),
                ]
                .into(),
            )
        })
        .collect()
    }

    #[test]
    fn test_count_aggregation() {
        let fixture = create_test_nodes();
        let outcome = Aggregator::aggregate(&fixture, &AggregateFunction::Count);
        assert_eq!(outcome.unwrap(), AggregateResult::Number(3.0));
    }

    #[test]
    fn test_sum_aggregation() {
        let fixture = create_test_nodes();
        let function = AggregateFunction::Sum("age".to_string());
        let outcome = Aggregator::aggregate(&fixture, &function);
        assert_eq!(outcome.unwrap(), AggregateResult::Number(90.0));
    }

    #[test]
    fn test_avg_aggregation() {
        let fixture = create_test_nodes();
        let function = AggregateFunction::Avg("age".to_string());
        let outcome = Aggregator::aggregate(&fixture, &function);
        assert_eq!(outcome.unwrap(), AggregateResult::Number(30.0));
    }

    #[test]
    fn test_group_by_aggregation() {
        let fixture = create_test_nodes();
        let function = AggregateFunction::GroupBy("category".to_string());
        let outcome = Aggregator::aggregate(&fixture, &function).unwrap();

        // Anything other than a grouped result is a test failure.
        let AggregateResult::Groups(buckets) = outcome else {
            panic!("Expected Groups result");
        };

        assert_eq!(buckets.len(), 2);
        assert_eq!(buckets.get("A").unwrap().len(), 2);
        assert_eq!(buckets.get("B").unwrap().len(), 1);
    }

    #[test]
    fn test_statistics() {
        let fixture = create_test_nodes();
        let summary = Statistics::calculate(&fixture, "age").unwrap();

        assert_eq!(summary.count, 3.0);
        assert_eq!(summary.sum, 90.0);
        assert_eq!(summary.mean, 30.0);
        assert_eq!(summary.median, 30.0);
        assert_eq!(summary.min, 25.0);
        assert_eq!(summary.max, 35.0);
    }
}