Skip to main content

drasi_core/evaluation/functions/aggregation/
collect.rs

1// Copyright 2024 The Drasi Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{fmt::Debug, sync::Arc};
16
17use async_trait::async_trait;
18use drasi_query_ast::ast;
19
20use crate::{
21    evaluation::{
22        variable_value::VariableValue, ExpressionEvaluationContext, FunctionError,
23        FunctionEvaluationError,
24    },
25    interface::ResultIndex,
26    models::ElementValue,
27};
28
29use super::{super::AggregatingFunction, Accumulator, ValueAccumulator};
30
/// Aggregating function behind the `collect(expr)` query function.
///
/// The accumulator holds the non-null argument values in the order they were
/// applied. Duplicates are kept, and reverting a value removes the first equal
/// entry. Null arguments are skipped.
pub struct Collect {}
33
34#[async_trait]
35impl AggregatingFunction for Collect {
36    fn initialize_accumulator(
37        &self,
38        _context: &ExpressionEvaluationContext,
39        _expression: &ast::FunctionExpression,
40        _grouping_keys: &Vec<VariableValue>,
41        _index: Arc<dyn ResultIndex>,
42    ) -> Accumulator {
43        // Initialize with an empty list
44        Accumulator::Value(ValueAccumulator::Value(ElementValue::List(vec![])))
45    }
46
47    fn accumulator_is_lazy(&self) -> bool {
48        false
49    }
50
51    async fn apply(
52        &self,
53        _context: &ExpressionEvaluationContext,
54        args: Vec<VariableValue>,
55        accumulator: &mut Accumulator,
56    ) -> Result<VariableValue, FunctionError> {
57        if args.len() != 1 {
58            return Err(FunctionError {
59                function_name: "Collect".to_string(),
60                error: FunctionEvaluationError::InvalidArgumentCount,
61            });
62        }
63
64        let list = match accumulator {
65            Accumulator::Value(ValueAccumulator::Value(ElementValue::List(list))) => list,
66            _ => {
67                return Err(FunctionError {
68                    function_name: "Collect".to_string(),
69                    error: FunctionEvaluationError::CorruptData,
70                })
71            }
72        };
73
74        // Convert VariableValue to ElementValue and add to list
75        // Skip null values (similar to how other aggregation functions handle nulls)
76        if !args[0].is_null() {
77            if let Ok(elem_value) = (&args[0]).try_into() {
78                list.push(elem_value);
79            }
80        }
81
82        // Return current list as VariableValue
83        Ok((&ElementValue::List(list.clone())).into())
84    }
85
86    async fn revert(
87        &self,
88        _context: &ExpressionEvaluationContext,
89        args: Vec<VariableValue>,
90        accumulator: &mut Accumulator,
91    ) -> Result<VariableValue, FunctionError> {
92        if args.len() != 1 {
93            return Err(FunctionError {
94                function_name: "Collect".to_string(),
95                error: FunctionEvaluationError::InvalidArgumentCount,
96            });
97        }
98
99        let list = match accumulator {
100            Accumulator::Value(ValueAccumulator::Value(ElementValue::List(list))) => list,
101            _ => {
102                return Err(FunctionError {
103                    function_name: "Collect".to_string(),
104                    error: FunctionEvaluationError::CorruptData,
105                })
106            }
107        };
108
109        // For revert, we need to remove the value from the list
110        // This is tricky because we need to find and remove the exact value
111        // For now, we'll remove the first occurrence
112        if !args[0].is_null() {
113            if let Ok(elem_value) = (&args[0]).try_into() {
114                // Find and remove the first matching value
115                if let Some(pos) = list.iter().position(|x| x == &elem_value) {
116                    list.remove(pos);
117                }
118            }
119        }
120
121        // Return current list as VariableValue
122        Ok((&ElementValue::List(list.clone())).into())
123    }
124
125    async fn snapshot(
126        &self,
127        _context: &ExpressionEvaluationContext,
128        _args: Vec<VariableValue>,
129        accumulator: &Accumulator,
130    ) -> Result<VariableValue, FunctionError> {
131        let list = match accumulator {
132            Accumulator::Value(ValueAccumulator::Value(ElementValue::List(list))) => list,
133            _ => {
134                return Err(FunctionError {
135                    function_name: "Collect".to_string(),
136                    error: FunctionEvaluationError::CorruptData,
137                })
138            }
139        };
140
141        Ok((&ElementValue::List(list.clone())).into())
142    }
143}
144
145impl Debug for Collect {
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        write!(f, "Collect")
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        evaluation::{
            context::QueryVariables, variable_value::VariableValue, ExpressionEvaluationContext,
            InstantQueryClock,
        },
        in_memory_index::in_memory_result_index::InMemoryResultIndex,
    };
    use drasi_query_ast::ast;

    /// Builds an empty `collect` accumulator backed by a fresh in-memory
    /// result index.
    fn empty_accumulator(
        collect: &Collect,
        context: &ExpressionEvaluationContext,
    ) -> Accumulator {
        let expression = ast::FunctionExpression {
            name: "collect".into(),
            args: vec![],
            position_in_query: 10,
        };
        let index = Arc::new(InMemoryResultIndex::new());
        collect.initialize_accumulator(context, &expression, &vec![], index)
    }

    /// Unwraps a list result, failing the test for any other variant.
    fn list_items(result: VariableValue) -> Vec<VariableValue> {
        match result {
            VariableValue::List(items) => items,
            _ => panic!("Expected list result"),
        }
    }

    #[tokio::test]
    async fn test_collect_basic() {
        let collect = Collect {};
        let variables = QueryVariables::new();
        let context =
            ExpressionEvaluationContext::new(&variables, Arc::new(InstantQueryClock::new(0, 0)));
        let mut accumulator = empty_accumulator(&collect, &context);

        let expected = vec![
            VariableValue::String("hello".into()),
            VariableValue::Integer(42.into()),
            VariableValue::String("world".into()),
        ];
        for value in &expected {
            collect
                .apply(&context, vec![value.clone()], &mut accumulator)
                .await
                .unwrap();
        }

        // The snapshot reports every applied value in arrival order.
        let snapshot = collect
            .snapshot(&context, vec![], &accumulator)
            .await
            .unwrap();
        assert_eq!(list_items(snapshot), expected);
    }

    #[tokio::test]
    async fn test_collect_with_revert() {
        let collect = Collect {};
        let variables = QueryVariables::new();
        let context =
            ExpressionEvaluationContext::new(&variables, Arc::new(InstantQueryClock::new(0, 0)));
        let mut accumulator = empty_accumulator(&collect, &context);

        let hello = VariableValue::String("hello".into());
        let answer = VariableValue::Integer(42.into());
        for value in [&hello, &answer] {
            collect
                .apply(&context, vec![value.clone()], &mut accumulator)
                .await
                .unwrap();
        }

        // Retract "hello"; only 42 should remain.
        collect
            .revert(&context, vec![hello.clone()], &mut accumulator)
            .await
            .unwrap();

        let snapshot = collect
            .snapshot(&context, vec![], &accumulator)
            .await
            .unwrap();
        assert_eq!(list_items(snapshot), vec![answer]);
    }

    #[tokio::test]
    async fn test_collect_null_values() {
        let collect = Collect {};
        let variables = QueryVariables::new();
        let context =
            ExpressionEvaluationContext::new(&variables, Arc::new(InstantQueryClock::new(0, 0)));
        let mut accumulator = empty_accumulator(&collect, &context);

        // Nulls are interleaved with real values and must be ignored.
        let inputs = [
            VariableValue::Null,
            VariableValue::Integer(42.into()),
            VariableValue::Null,
            VariableValue::String("test".into()),
        ];
        for value in inputs {
            collect
                .apply(&context, vec![value], &mut accumulator)
                .await
                .unwrap();
        }

        let snapshot = collect
            .snapshot(&context, vec![], &accumulator)
            .await
            .unwrap();
        let items = list_items(snapshot);
        assert_eq!(items.len(), 2, "Null values should be ignored");
        assert_eq!(
            items,
            vec![
                VariableValue::Integer(42.into()),
                VariableValue::String("test".into()),
            ]
        );
    }

    #[tokio::test]
    async fn test_collect_empty_list() {
        let collect = Collect {};
        let variables = QueryVariables::new();
        let context =
            ExpressionEvaluationContext::new(&variables, Arc::new(InstantQueryClock::new(0, 0)));
        let accumulator = empty_accumulator(&collect, &context);

        // A freshly initialized accumulator snapshots to an empty list.
        let snapshot = collect
            .snapshot(&context, vec![], &accumulator)
            .await
            .unwrap();
        let items = list_items(snapshot);
        assert_eq!(items.len(), 0, "Empty accumulator should return empty list");
    }

    #[tokio::test]
    async fn test_collect_duplicate_values() {
        let collect = Collect {};
        let variables = QueryVariables::new();
        let context =
            ExpressionEvaluationContext::new(&variables, Arc::new(InstantQueryClock::new(0, 0)));
        let mut accumulator = empty_accumulator(&collect, &context);

        // The same value applied three times is kept three times.
        let val = VariableValue::Integer(42.into());
        for _ in 0..3 {
            collect
                .apply(&context, vec![val.clone()], &mut accumulator)
                .await
                .unwrap();
        }

        let snapshot = collect
            .snapshot(&context, vec![], &accumulator)
            .await
            .unwrap();
        let items = list_items(snapshot);
        assert_eq!(items.len(), 3, "Duplicate values should all be collected");
        assert_eq!(items, vec![val.clone(); 3]);
    }

    #[tokio::test]
    async fn test_collect_different_types() {
        let collect = Collect {};
        let variables = QueryVariables::new();
        let context =
            ExpressionEvaluationContext::new(&variables, Arc::new(InstantQueryClock::new(0, 0)));
        let mut accumulator = empty_accumulator(&collect, &context);

        // Heterogeneous values share one list.
        let inputs = vec![
            VariableValue::Integer(42.into()),
            VariableValue::Float(3.125.into()),
            VariableValue::String("hello".into()),
            VariableValue::Bool(true),
        ];
        for value in &inputs {
            collect
                .apply(&context, vec![value.clone()], &mut accumulator)
                .await
                .unwrap();
        }

        let snapshot = collect
            .snapshot(&context, vec![], &accumulator)
            .await
            .unwrap();
        let items = list_items(snapshot);
        assert_eq!(items.len(), 4, "Should collect values of different types");
        assert_eq!(items, inputs);
    }

    #[tokio::test]
    async fn test_collect_revert_multiple() {
        let collect = Collect {};
        let variables = QueryVariables::new();
        let context =
            ExpressionEvaluationContext::new(&variables, Arc::new(InstantQueryClock::new(0, 0)));
        let mut accumulator = empty_accumulator(&collect, &context);

        let one = VariableValue::Integer(1.into());
        let two = VariableValue::Integer(2.into());
        for value in [&one, &two, &one, &two] {
            collect
                .apply(&context, vec![value.clone()], &mut accumulator)
                .await
                .unwrap();
        }

        // Retract a single `1`: only its first occurrence should disappear.
        collect
            .revert(&context, vec![one.clone()], &mut accumulator)
            .await
            .unwrap();

        let snapshot = collect
            .snapshot(&context, vec![], &accumulator)
            .await
            .unwrap();
        let items = list_items(snapshot);
        assert_eq!(items.len(), 3, "Should have removed only first occurrence");
        assert_eq!(items, vec![two.clone(), one, two]);
    }

    #[tokio::test]
    async fn test_collect_revert_nonexistent() {
        let collect = Collect {};
        let variables = QueryVariables::new();
        let context =
            ExpressionEvaluationContext::new(&variables, Arc::new(InstantQueryClock::new(0, 0)));
        let mut accumulator = empty_accumulator(&collect, &context);

        let one = VariableValue::Integer(1.into());
        let two = VariableValue::Integer(2.into());
        collect
            .apply(&context, vec![one.clone()], &mut accumulator)
            .await
            .unwrap();

        // `2` was never applied, so reverting it must be a no-op.
        collect
            .revert(&context, vec![two.clone()], &mut accumulator)
            .await
            .unwrap();

        let snapshot = collect
            .snapshot(&context, vec![], &accumulator)
            .await
            .unwrap();
        let items = list_items(snapshot);
        assert_eq!(
            items.len(),
            1,
            "Should not affect list if value doesn't exist"
        );
        assert_eq!(items, vec![one]);
    }

    #[tokio::test]
    async fn test_collect_error_cases() {
        let collect = Collect {};
        let variables = QueryVariables::new();
        let context =
            ExpressionEvaluationContext::new(&variables, Arc::new(InstantQueryClock::new(0, 0)));
        let mut accumulator = empty_accumulator(&collect, &context);

        // `collect` takes exactly one argument.
        let no_args = collect.apply(&context, vec![], &mut accumulator).await;
        assert!(no_args.is_err(), "Should error with no arguments");

        let two_args = collect
            .apply(
                &context,
                vec![
                    VariableValue::Integer(1.into()),
                    VariableValue::Integer(2.into()),
                ],
                &mut accumulator,
            )
            .await;
        assert!(two_args.is_err(), "Should error with too many arguments");
    }
}