1use crate::StreamEvent;
9use anyhow::{anyhow, Result};
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet};
12
/// Describes which aggregation should be computed over a stream of events.
///
/// Each variant is paired with an `AggregationState` that accumulates the
/// running value; the `field` payloads name the event field the aggregation
/// reads.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregateFunction {
    /// Counts events.
    Count,
    /// Sums the numeric value read from `field`.
    Sum { field: String },
    /// Averages the numeric value read from `field`.
    Average { field: String },
    /// Tracks the smallest numeric value read from `field`.
    Min { field: String },
    /// Tracks the largest numeric value read from `field`.
    Max { field: String },
    /// Keeps a single event; by its name, intended to be the earliest one seen.
    First,
    /// Keeps the most recently seen event.
    Last,
    /// Collects the distinct string values read from `field`.
    Distinct { field: String },
    /// Evaluates `expression` per event and accumulates the outcome in a counter.
    /// `name` is not read by the code visible in this file.
    Custom { name: String, expression: String },
}
26
/// Running accumulator for one aggregation.
///
/// The variant in use is chosen by `AggregationState::new` from the
/// corresponding `AggregateFunction`.
#[derive(Debug, Clone)]
pub enum AggregationState {
    /// Running counter (also used as the accumulator for custom aggregations).
    Count(u64),
    /// Running total.
    Sum(f64),
    /// Running total plus the number of values that contributed to it.
    Average { sum: f64, count: u64 },
    /// Smallest value seen so far; starts at positive infinity.
    Min(f64),
    /// Largest value seen so far; starts at negative infinity.
    Max(f64),
    /// Event held for the `First` aggregation; starts as an empty placeholder triple.
    First(StreamEvent),
    /// Event held for the `Last` aggregation; starts as an empty placeholder triple.
    Last(StreamEvent),
    /// Set of distinct string values seen so far.
    Distinct(HashSet<String>),
}
39
40impl AggregationState {
41 pub fn new(function: &AggregateFunction) -> Self {
43 match function {
44 AggregateFunction::Count => AggregationState::Count(0),
45 AggregateFunction::Sum { .. } => AggregationState::Sum(0.0),
46 AggregateFunction::Average { .. } => AggregationState::Average { sum: 0.0, count: 0 },
47 AggregateFunction::Min { .. } => AggregationState::Min(f64::INFINITY),
48 AggregateFunction::Max { .. } => AggregationState::Max(f64::NEG_INFINITY),
49 AggregateFunction::First => AggregationState::First(StreamEvent::TripleAdded {
50 subject: String::new(),
51 predicate: String::new(),
52 object: String::new(),
53 graph: None,
54 metadata: crate::event::EventMetadata::default(),
55 }),
56 AggregateFunction::Last => AggregationState::Last(StreamEvent::TripleAdded {
57 subject: String::new(),
58 predicate: String::new(),
59 object: String::new(),
60 graph: None,
61 metadata: crate::event::EventMetadata::default(),
62 }),
63 AggregateFunction::Distinct { .. } => AggregationState::Distinct(HashSet::new()),
64 AggregateFunction::Custom { .. } => AggregationState::Count(0), }
66 }
67
68 pub fn update(&mut self, event: &StreamEvent, function: &AggregateFunction) -> Result<()> {
70 match (self, function) {
71 (AggregationState::Count(count), AggregateFunction::Count) => {
72 *count += 1;
73 }
74 (AggregationState::Sum(sum), AggregateFunction::Sum { field }) => {
75 if let Some(value) = extract_numeric_field(event, field)? {
76 *sum += value;
77 }
78 }
79 (AggregationState::Average { sum, count }, AggregateFunction::Average { field }) => {
80 if let Some(value) = extract_numeric_field(event, field)? {
81 *sum += value;
82 *count += 1;
83 }
84 }
85 (AggregationState::Min(min), AggregateFunction::Min { field }) => {
86 if let Some(value) = extract_numeric_field(event, field)? {
87 *min = min.min(value);
88 }
89 }
90 (AggregationState::Max(max), AggregateFunction::Max { field }) => {
91 if let Some(value) = extract_numeric_field(event, field)? {
92 *max = max.max(value);
93 }
94 }
95 (AggregationState::First(first), AggregateFunction::First) => {
96 *first = event.clone();
97 }
98 (AggregationState::Last(last), AggregateFunction::Last) => {
99 *last = event.clone();
100 }
101 (AggregationState::Distinct(set), AggregateFunction::Distinct { field }) => {
102 if let Some(value) = extract_string_field(event, field)? {
103 set.insert(value);
104 }
105 }
106 (AggregationState::Count(count), AggregateFunction::Custom { expression, .. }) => {
107 if let Some(result) = evaluate_custom_expression(expression, event)? {
109 *count += result as u64;
110 }
111 }
112 _ => return Err(anyhow!("Mismatched aggregation state and function")),
113 }
114 Ok(())
115 }
116
117 pub fn result(&self) -> Result<serde_json::Value> {
119 match self {
120 AggregationState::Count(count) => Ok(serde_json::Value::Number((*count).into())),
121 AggregationState::Sum(sum) => Ok(serde_json::Value::Number(
122 serde_json::Number::from_f64(*sum).unwrap_or(0.into()),
123 )),
124 AggregationState::Average { sum, count } => {
125 if *count > 0 {
126 let avg = *sum / (*count as f64);
127 Ok(serde_json::Value::Number(
128 serde_json::Number::from_f64(avg).unwrap_or(0.into()),
129 ))
130 } else {
131 Ok(serde_json::Value::Number(0.into()))
132 }
133 }
134 AggregationState::Min(min) => {
135 if min.is_finite() {
136 Ok(serde_json::Value::Number(
137 serde_json::Number::from_f64(*min).unwrap_or(0.into()),
138 ))
139 } else {
140 Ok(serde_json::Value::Null)
141 }
142 }
143 AggregationState::Max(max) => {
144 if max.is_finite() {
145 Ok(serde_json::Value::Number(
146 serde_json::Number::from_f64(*max).unwrap_or(0.into()),
147 ))
148 } else {
149 Ok(serde_json::Value::Null)
150 }
151 }
152 AggregationState::First(event) => Ok(serde_json::to_value(event)?),
153 AggregationState::Last(event) => Ok(serde_json::to_value(event)?),
154 AggregationState::Distinct(set) => Ok(serde_json::Value::Number(set.len().into())),
155 }
156 }
157}
158
159fn extract_numeric_field(event: &StreamEvent, _field: &str) -> Result<Option<f64>> {
161 match event {
164 StreamEvent::SparqlUpdate { .. } => Ok(None),
165 StreamEvent::TripleAdded { .. } => Ok(None),
166 StreamEvent::TripleRemoved { .. } => Ok(None),
167 StreamEvent::QuadAdded { .. } => Ok(None),
168 StreamEvent::QuadRemoved { .. } => Ok(None),
169 StreamEvent::GraphCreated { .. } => Ok(None),
170 StreamEvent::GraphCleared { .. } => Ok(None),
171 StreamEvent::GraphDeleted { .. } => Ok(None),
172 StreamEvent::TransactionBegin { .. } => Ok(None),
173 StreamEvent::TransactionCommit { .. } => Ok(None),
174 StreamEvent::TransactionAbort { .. } => Ok(None),
175 _ => Ok(None),
176 }
177}
178
179fn extract_string_field(event: &StreamEvent, field: &str) -> Result<Option<String>> {
181 match event {
184 StreamEvent::TripleAdded {
185 subject,
186 predicate,
187 object,
188 ..
189 } => match field {
190 "subject" => Ok(Some(subject.clone())),
191 "predicate" => Ok(Some(predicate.clone())),
192 "object" => Ok(Some(object.clone())),
193 _ => Ok(None),
194 },
195 StreamEvent::TripleRemoved {
196 subject,
197 predicate,
198 object,
199 ..
200 } => match field {
201 "subject" => Ok(Some(subject.clone())),
202 "predicate" => Ok(Some(predicate.clone())),
203 "object" => Ok(Some(object.clone())),
204 _ => Ok(None),
205 },
206 StreamEvent::QuadAdded {
207 subject,
208 predicate,
209 object,
210 graph,
211 ..
212 } => match field {
213 "subject" => Ok(Some(subject.clone())),
214 "predicate" => Ok(Some(predicate.clone())),
215 "object" => Ok(Some(object.clone())),
216 "graph" => Ok(Some(graph.clone())),
217 _ => Ok(None),
218 },
219 StreamEvent::QuadRemoved {
220 subject,
221 predicate,
222 object,
223 graph,
224 ..
225 } => match field {
226 "subject" => Ok(Some(subject.clone())),
227 "predicate" => Ok(Some(predicate.clone())),
228 "object" => Ok(Some(object.clone())),
229 "graph" => Ok(Some(graph.clone())),
230 _ => Ok(None),
231 },
232 _ => Ok(None),
233 }
234}
235
236fn evaluate_custom_expression(expression: &str, event: &StreamEvent) -> Result<Option<f64>> {
238 match expression {
241 expr if expr.starts_with("field:") => {
242 let field = expr
243 .strip_prefix("field:")
244 .expect("strip_prefix should succeed after starts_with check");
245 extract_numeric_field(event, field)
246 }
247 expr if expr.starts_with("const:") => {
248 let value = expr
249 .strip_prefix("const:")
250 .expect("strip_prefix should succeed after starts_with check");
251 match value.parse::<f64>() {
252 Ok(n) => Ok(Some(n)),
253 Err(_) => Ok(None),
254 }
255 }
256 expr if expr.contains('+') => {
257 let parts: Vec<&str> = expr.split('+').collect();
258 if parts.len() == 2 {
259 let left = evaluate_custom_expression(parts[0].trim(), event)?;
260 let right = evaluate_custom_expression(parts[1].trim(), event)?;
261 match (left, right) {
262 (Some(l), Some(r)) => Ok(Some(l + r)),
263 _ => Ok(None),
264 }
265 } else {
266 Ok(None)
267 }
268 }
269 expr if expr.contains('*') => {
270 let parts: Vec<&str> = expr.split('*').collect();
271 if parts.len() == 2 {
272 let left = evaluate_custom_expression(parts[0].trim(), event)?;
273 let right = evaluate_custom_expression(parts[1].trim(), event)?;
274 match (left, right) {
275 (Some(l), Some(r)) => Ok(Some(l * r)),
276 _ => Ok(None),
277 }
278 } else {
279 Ok(None)
280 }
281 }
282 _ => Ok(None),
283 }
284}
285
286pub struct AggregationManager {
288 aggregations: HashMap<String, (AggregateFunction, AggregationState)>,
289}
290
291impl AggregationManager {
292 pub fn new() -> Self {
294 Self {
295 aggregations: HashMap::new(),
296 }
297 }
298
299 pub fn add_aggregation(&mut self, name: String, function: AggregateFunction) {
301 let state = AggregationState::new(&function);
302 self.aggregations.insert(name, (function, state));
303 }
304
305 pub fn update(&mut self, event: &StreamEvent) -> Result<()> {
307 for (_, (function, state)) in self.aggregations.iter_mut() {
308 state.update(event, function)?;
309 }
310 Ok(())
311 }
312
313 pub fn results(&self) -> Result<HashMap<String, serde_json::Value>> {
315 let mut results = HashMap::new();
316 for (name, (_, state)) in &self.aggregations {
317 results.insert(name.clone(), state.result()?);
318 }
319 Ok(results)
320 }
321}
322
323impl Default for AggregationManager {
324 fn default() -> Self {
325 Self::new()
326 }
327}