Skip to main content

oxirs_stream/processing/
aggregation.rs

1//! Aggregation functions and state management for event processing
2//!
3//! This module provides aggregation capabilities including:
4//! - Basic aggregations (count, sum, average, min, max)
5//! - Complex aggregations (distinct, custom expressions)
6//! - Aggregation state management
7
8use crate::StreamEvent;
9use anyhow::{anyhow, Result};
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet};
12
/// Aggregation functions for window processing
///
/// Each variant selects how [`AggregationState`] is initialised and how it is
/// folded over incoming events. Variants carrying a `field` read that named
/// field from every event; events that do not provide a value for the field
/// are skipped by the update logic.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregateFunction {
    /// Counts events: the state is incremented by one per event.
    Count,
    /// Sums the numeric value of `field` over all events that provide one.
    Sum { field: String },
    /// Arithmetic mean of the numeric value of `field`; only events that
    /// provide a value contribute to the divisor.
    Average { field: String },
    /// Smallest numeric value of `field` seen so far.
    Min { field: String },
    /// Largest numeric value of `field` seen so far.
    Max { field: String },
    /// Event-valued aggregation; by its name, meant to report the first event
    /// observed. The result is the stored event serialized to JSON.
    First,
    /// Event-valued aggregation reporting the most recent event observed. The
    /// result is the stored event serialized to JSON.
    Last,
    /// Collects the distinct string values of `field`; the result is the
    /// number of distinct values, not the values themselves.
    Distinct { field: String },
    /// User-defined aggregation. `expression` is evaluated against every event
    /// and the numeric outcome is added to an integer counter. `name` is not
    /// read by the update logic in this module.
    Custom { name: String, expression: String },
}
26
/// Aggregation state for maintaining running calculations
///
/// One variant exists per kind of running value. `AggregationState::new`
/// picks the variant (and its initial value) that matches an
/// [`AggregateFunction`]; `update` rejects any state/function pair that does
/// not belong together.
#[derive(Debug, Clone)]
pub enum AggregationState {
    /// Running counter. Also used as the state for `AggregateFunction::Custom`.
    Count(u64),
    /// Running sum, starting at `0.0`.
    Sum(f64),
    /// Running sum together with the number of values that contributed to it.
    Average { sum: f64, count: u64 },
    /// Running minimum, starting at `f64::INFINITY`.
    Min(f64),
    /// Running maximum, starting at `f64::NEG_INFINITY`.
    Max(f64),
    /// Stored event for `AggregateFunction::First`; starts as an empty
    /// `TripleAdded` placeholder.
    First(StreamEvent),
    /// Stored event for `AggregateFunction::Last`; starts as an empty
    /// `TripleAdded` placeholder.
    Last(StreamEvent),
    /// Set of distinct string values collected so far.
    Distinct(HashSet<String>),
}
39
40impl AggregationState {
41    /// Create new aggregation state for a function
42    pub fn new(function: &AggregateFunction) -> Self {
43        match function {
44            AggregateFunction::Count => AggregationState::Count(0),
45            AggregateFunction::Sum { .. } => AggregationState::Sum(0.0),
46            AggregateFunction::Average { .. } => AggregationState::Average { sum: 0.0, count: 0 },
47            AggregateFunction::Min { .. } => AggregationState::Min(f64::INFINITY),
48            AggregateFunction::Max { .. } => AggregationState::Max(f64::NEG_INFINITY),
49            AggregateFunction::First => AggregationState::First(StreamEvent::TripleAdded {
50                subject: String::new(),
51                predicate: String::new(),
52                object: String::new(),
53                graph: None,
54                metadata: crate::event::EventMetadata::default(),
55            }),
56            AggregateFunction::Last => AggregationState::Last(StreamEvent::TripleAdded {
57                subject: String::new(),
58                predicate: String::new(),
59                object: String::new(),
60                graph: None,
61                metadata: crate::event::EventMetadata::default(),
62            }),
63            AggregateFunction::Distinct { .. } => AggregationState::Distinct(HashSet::new()),
64            AggregateFunction::Custom { .. } => AggregationState::Count(0), // Default for custom
65        }
66    }
67
68    /// Update aggregation state with new event
69    pub fn update(&mut self, event: &StreamEvent, function: &AggregateFunction) -> Result<()> {
70        match (self, function) {
71            (AggregationState::Count(count), AggregateFunction::Count) => {
72                *count += 1;
73            }
74            (AggregationState::Sum(sum), AggregateFunction::Sum { field }) => {
75                if let Some(value) = extract_numeric_field(event, field)? {
76                    *sum += value;
77                }
78            }
79            (AggregationState::Average { sum, count }, AggregateFunction::Average { field }) => {
80                if let Some(value) = extract_numeric_field(event, field)? {
81                    *sum += value;
82                    *count += 1;
83                }
84            }
85            (AggregationState::Min(min), AggregateFunction::Min { field }) => {
86                if let Some(value) = extract_numeric_field(event, field)? {
87                    *min = min.min(value);
88                }
89            }
90            (AggregationState::Max(max), AggregateFunction::Max { field }) => {
91                if let Some(value) = extract_numeric_field(event, field)? {
92                    *max = max.max(value);
93                }
94            }
95            (AggregationState::First(first), AggregateFunction::First) => {
96                *first = event.clone();
97            }
98            (AggregationState::Last(last), AggregateFunction::Last) => {
99                *last = event.clone();
100            }
101            (AggregationState::Distinct(set), AggregateFunction::Distinct { field }) => {
102                if let Some(value) = extract_string_field(event, field)? {
103                    set.insert(value);
104                }
105            }
106            (AggregationState::Count(count), AggregateFunction::Custom { expression, .. }) => {
107                // Custom aggregation evaluation
108                if let Some(result) = evaluate_custom_expression(expression, event)? {
109                    *count += result as u64;
110                }
111            }
112            _ => return Err(anyhow!("Mismatched aggregation state and function")),
113        }
114        Ok(())
115    }
116
117    /// Get the current aggregation result
118    pub fn result(&self) -> Result<serde_json::Value> {
119        match self {
120            AggregationState::Count(count) => Ok(serde_json::Value::Number((*count).into())),
121            AggregationState::Sum(sum) => Ok(serde_json::Value::Number(
122                serde_json::Number::from_f64(*sum).unwrap_or(0.into()),
123            )),
124            AggregationState::Average { sum, count } => {
125                if *count > 0 {
126                    let avg = *sum / (*count as f64);
127                    Ok(serde_json::Value::Number(
128                        serde_json::Number::from_f64(avg).unwrap_or(0.into()),
129                    ))
130                } else {
131                    Ok(serde_json::Value::Number(0.into()))
132                }
133            }
134            AggregationState::Min(min) => {
135                if min.is_finite() {
136                    Ok(serde_json::Value::Number(
137                        serde_json::Number::from_f64(*min).unwrap_or(0.into()),
138                    ))
139                } else {
140                    Ok(serde_json::Value::Null)
141                }
142            }
143            AggregationState::Max(max) => {
144                if max.is_finite() {
145                    Ok(serde_json::Value::Number(
146                        serde_json::Number::from_f64(*max).unwrap_or(0.into()),
147                    ))
148                } else {
149                    Ok(serde_json::Value::Null)
150                }
151            }
152            AggregationState::First(event) => Ok(serde_json::to_value(event)?),
153            AggregationState::Last(event) => Ok(serde_json::to_value(event)?),
154            AggregationState::Distinct(set) => Ok(serde_json::Value::Number(set.len().into())),
155        }
156    }
157}
158
159/// Extract numeric field from event
160fn extract_numeric_field(event: &StreamEvent, _field: &str) -> Result<Option<f64>> {
161    // Implementation would depend on StreamEvent structure
162    // This is a simplified version for the actual StreamEvent variants
163    match event {
164        StreamEvent::SparqlUpdate { .. } => Ok(None),
165        StreamEvent::TripleAdded { .. } => Ok(None),
166        StreamEvent::TripleRemoved { .. } => Ok(None),
167        StreamEvent::QuadAdded { .. } => Ok(None),
168        StreamEvent::QuadRemoved { .. } => Ok(None),
169        StreamEvent::GraphCreated { .. } => Ok(None),
170        StreamEvent::GraphCleared { .. } => Ok(None),
171        StreamEvent::GraphDeleted { .. } => Ok(None),
172        StreamEvent::TransactionBegin { .. } => Ok(None),
173        StreamEvent::TransactionCommit { .. } => Ok(None),
174        StreamEvent::TransactionAbort { .. } => Ok(None),
175        _ => Ok(None),
176    }
177}
178
179/// Extract string field from event
180fn extract_string_field(event: &StreamEvent, field: &str) -> Result<Option<String>> {
181    // Implementation would depend on StreamEvent structure
182    // This is a simplified version for the actual StreamEvent variants
183    match event {
184        StreamEvent::TripleAdded {
185            subject,
186            predicate,
187            object,
188            ..
189        } => match field {
190            "subject" => Ok(Some(subject.clone())),
191            "predicate" => Ok(Some(predicate.clone())),
192            "object" => Ok(Some(object.clone())),
193            _ => Ok(None),
194        },
195        StreamEvent::TripleRemoved {
196            subject,
197            predicate,
198            object,
199            ..
200        } => match field {
201            "subject" => Ok(Some(subject.clone())),
202            "predicate" => Ok(Some(predicate.clone())),
203            "object" => Ok(Some(object.clone())),
204            _ => Ok(None),
205        },
206        StreamEvent::QuadAdded {
207            subject,
208            predicate,
209            object,
210            graph,
211            ..
212        } => match field {
213            "subject" => Ok(Some(subject.clone())),
214            "predicate" => Ok(Some(predicate.clone())),
215            "object" => Ok(Some(object.clone())),
216            "graph" => Ok(Some(graph.clone())),
217            _ => Ok(None),
218        },
219        StreamEvent::QuadRemoved {
220            subject,
221            predicate,
222            object,
223            graph,
224            ..
225        } => match field {
226            "subject" => Ok(Some(subject.clone())),
227            "predicate" => Ok(Some(predicate.clone())),
228            "object" => Ok(Some(object.clone())),
229            "graph" => Ok(Some(graph.clone())),
230            _ => Ok(None),
231        },
232        _ => Ok(None),
233    }
234}
235
236/// Evaluate custom expression
237fn evaluate_custom_expression(expression: &str, event: &StreamEvent) -> Result<Option<f64>> {
238    // Parse and evaluate custom expressions
239    // This is a simplified implementation
240    match expression {
241        expr if expr.starts_with("field:") => {
242            let field = expr
243                .strip_prefix("field:")
244                .expect("strip_prefix should succeed after starts_with check");
245            extract_numeric_field(event, field)
246        }
247        expr if expr.starts_with("const:") => {
248            let value = expr
249                .strip_prefix("const:")
250                .expect("strip_prefix should succeed after starts_with check");
251            match value.parse::<f64>() {
252                Ok(n) => Ok(Some(n)),
253                Err(_) => Ok(None),
254            }
255        }
256        expr if expr.contains('+') => {
257            let parts: Vec<&str> = expr.split('+').collect();
258            if parts.len() == 2 {
259                let left = evaluate_custom_expression(parts[0].trim(), event)?;
260                let right = evaluate_custom_expression(parts[1].trim(), event)?;
261                match (left, right) {
262                    (Some(l), Some(r)) => Ok(Some(l + r)),
263                    _ => Ok(None),
264                }
265            } else {
266                Ok(None)
267            }
268        }
269        expr if expr.contains('*') => {
270            let parts: Vec<&str> = expr.split('*').collect();
271            if parts.len() == 2 {
272                let left = evaluate_custom_expression(parts[0].trim(), event)?;
273                let right = evaluate_custom_expression(parts[1].trim(), event)?;
274                match (left, right) {
275                    (Some(l), Some(r)) => Ok(Some(l * r)),
276                    _ => Ok(None),
277                }
278            } else {
279                Ok(None)
280            }
281        }
282        _ => Ok(None),
283    }
284}
285
286/// Aggregation manager for handling multiple aggregations
287pub struct AggregationManager {
288    aggregations: HashMap<String, (AggregateFunction, AggregationState)>,
289}
290
291impl AggregationManager {
292    /// Create new aggregation manager
293    pub fn new() -> Self {
294        Self {
295            aggregations: HashMap::new(),
296        }
297    }
298
299    /// Add aggregation function
300    pub fn add_aggregation(&mut self, name: String, function: AggregateFunction) {
301        let state = AggregationState::new(&function);
302        self.aggregations.insert(name, (function, state));
303    }
304
305    /// Update all aggregations with new event
306    pub fn update(&mut self, event: &StreamEvent) -> Result<()> {
307        for (_, (function, state)) in self.aggregations.iter_mut() {
308            state.update(event, function)?;
309        }
310        Ok(())
311    }
312
313    /// Get all aggregation results
314    pub fn results(&self) -> Result<HashMap<String, serde_json::Value>> {
315        let mut results = HashMap::new();
316        for (name, (_, state)) in &self.aggregations {
317            results.insert(name.clone(), state.result()?);
318        }
319        Ok(results)
320    }
321}
322
323impl Default for AggregationManager {
324    fn default() -> Self {
325        Self::new()
326    }
327}