1use std::{fmt::Debug, sync::Arc};
16
17use async_trait::async_trait;
18use drasi_query_ast::ast;
19
20use crate::{
21 evaluation::{
22 variable_value::VariableValue, ExpressionEvaluationContext, FunctionError,
23 FunctionEvaluationError,
24 },
25 interface::ResultIndex,
26 models::ElementValue,
27};
28
29use super::{super::AggregatingFunction, Accumulator, ValueAccumulator};
30
/// The `collect` aggregating function.
///
/// Stateless marker type: all per-group state lives in the `Accumulator`
/// produced by `initialize_accumulator`.
pub struct Collect;
33
34#[async_trait]
35impl AggregatingFunction for Collect {
36 fn initialize_accumulator(
37 &self,
38 _context: &ExpressionEvaluationContext,
39 _expression: &ast::FunctionExpression,
40 _grouping_keys: &Vec<VariableValue>,
41 _index: Arc<dyn ResultIndex>,
42 ) -> Accumulator {
43 Accumulator::Value(ValueAccumulator::Value(ElementValue::List(vec![])))
45 }
46
47 fn accumulator_is_lazy(&self) -> bool {
48 false
49 }
50
51 async fn apply(
52 &self,
53 _context: &ExpressionEvaluationContext,
54 args: Vec<VariableValue>,
55 accumulator: &mut Accumulator,
56 ) -> Result<VariableValue, FunctionError> {
57 if args.len() != 1 {
58 return Err(FunctionError {
59 function_name: "Collect".to_string(),
60 error: FunctionEvaluationError::InvalidArgumentCount,
61 });
62 }
63
64 let list = match accumulator {
65 Accumulator::Value(ValueAccumulator::Value(ElementValue::List(list))) => list,
66 _ => {
67 return Err(FunctionError {
68 function_name: "Collect".to_string(),
69 error: FunctionEvaluationError::CorruptData,
70 })
71 }
72 };
73
74 if !args[0].is_null() {
77 if let Ok(elem_value) = (&args[0]).try_into() {
78 list.push(elem_value);
79 }
80 }
81
82 Ok((&ElementValue::List(list.clone())).into())
84 }
85
86 async fn revert(
87 &self,
88 _context: &ExpressionEvaluationContext,
89 args: Vec<VariableValue>,
90 accumulator: &mut Accumulator,
91 ) -> Result<VariableValue, FunctionError> {
92 if args.len() != 1 {
93 return Err(FunctionError {
94 function_name: "Collect".to_string(),
95 error: FunctionEvaluationError::InvalidArgumentCount,
96 });
97 }
98
99 let list = match accumulator {
100 Accumulator::Value(ValueAccumulator::Value(ElementValue::List(list))) => list,
101 _ => {
102 return Err(FunctionError {
103 function_name: "Collect".to_string(),
104 error: FunctionEvaluationError::CorruptData,
105 })
106 }
107 };
108
109 if !args[0].is_null() {
113 if let Ok(elem_value) = (&args[0]).try_into() {
114 if let Some(pos) = list.iter().position(|x| x == &elem_value) {
116 list.remove(pos);
117 }
118 }
119 }
120
121 Ok((&ElementValue::List(list.clone())).into())
123 }
124
125 async fn snapshot(
126 &self,
127 _context: &ExpressionEvaluationContext,
128 _args: Vec<VariableValue>,
129 accumulator: &Accumulator,
130 ) -> Result<VariableValue, FunctionError> {
131 let list = match accumulator {
132 Accumulator::Value(ValueAccumulator::Value(ElementValue::List(list))) => list,
133 _ => {
134 return Err(FunctionError {
135 function_name: "Collect".to_string(),
136 error: FunctionEvaluationError::CorruptData,
137 })
138 }
139 };
140
141 Ok((&ElementValue::List(list.clone())).into())
142 }
143}
144
145impl Debug for Collect {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 write!(f, "Collect")
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        evaluation::{
            context::QueryVariables, variable_value::VariableValue, ExpressionEvaluationContext,
            InstantQueryClock,
        },
        in_memory_index::in_memory_result_index::InMemoryResultIndex,
    };
    use drasi_query_ast::ast;

    /// Creates the empty accumulator that `collect` starts each group with.
    fn fresh_accumulator(
        collect: &Collect,
        context: &ExpressionEvaluationContext,
    ) -> Accumulator {
        let expression = ast::FunctionExpression {
            name: "collect".into(),
            args: vec![],
            position_in_query: 10,
        };
        collect.initialize_accumulator(
            context,
            &expression,
            &vec![],
            Arc::new(InMemoryResultIndex::new()),
        )
    }

    /// Unwraps a `VariableValue::List`, panicking on any other variant.
    fn expect_list(result: VariableValue) -> Vec<VariableValue> {
        match result {
            VariableValue::List(list) => list,
            _ => panic!("Expected list result"),
        }
    }

    #[tokio::test]
    async fn test_collect_basic() {
        let collect = Collect {};
        let variables = QueryVariables::new();
        let context =
            ExpressionEvaluationContext::new(&variables, Arc::new(InstantQueryClock::new(0, 0)));
        let mut accumulator = fresh_accumulator(&collect, &context);

        let values = vec![
            VariableValue::String("hello".into()),
            VariableValue::Integer(42.into()),
            VariableValue::String("world".into()),
        ];
        let mut last_returned = None;
        for value in &values {
            let returned = collect
                .apply(&context, vec![value.clone()], &mut accumulator)
                .await
                .unwrap();
            last_returned = Some(returned);
        }

        let snapshot = collect
            .snapshot(&context, vec![], &accumulator)
            .await
            .unwrap();
        assert_eq!(expect_list(snapshot), values);
        // `apply` hands back the running list, which must agree with the snapshot.
        let returned = last_returned.expect("at least one value was applied");
        assert_eq!(expect_list(returned), values);
    }

    #[tokio::test]
    async fn test_collect_with_revert() {
        let collect = Collect {};
        let variables = QueryVariables::new();
        let context =
            ExpressionEvaluationContext::new(&variables, Arc::new(InstantQueryClock::new(0, 0)));
        let mut accumulator = fresh_accumulator(&collect, &context);

        let val1 = VariableValue::String("hello".into());
        let val2 = VariableValue::Integer(42.into());
        for value in [&val1, &val2] {
            collect
                .apply(&context, vec![value.clone()], &mut accumulator)
                .await
                .unwrap();
        }

        let returned = collect
            .revert(&context, vec![val1.clone()], &mut accumulator)
            .await
            .unwrap();
        assert_eq!(expect_list(returned), vec![val2.clone()]);

        let snapshot = collect
            .snapshot(&context, vec![], &accumulator)
            .await
            .unwrap();
        assert_eq!(expect_list(snapshot), vec![val2]);
    }

    #[tokio::test]
    async fn test_collect_null_values() {
        let collect = Collect {};
        let variables = QueryVariables::new();
        let context =
            ExpressionEvaluationContext::new(&variables, Arc::new(InstantQueryClock::new(0, 0)));
        let mut accumulator = fresh_accumulator(&collect, &context);

        let inputs = [
            VariableValue::Null,
            VariableValue::Integer(42.into()),
            VariableValue::Null,
            VariableValue::String("test".into()),
        ];
        for value in &inputs {
            collect
                .apply(&context, vec![value.clone()], &mut accumulator)
                .await
                .unwrap();
        }

        let expected = vec![
            VariableValue::Integer(42.into()),
            VariableValue::String("test".into()),
        ];
        let snapshot = collect
            .snapshot(&context, vec![], &accumulator)
            .await
            .unwrap();
        assert_eq!(expect_list(snapshot), expected, "Null values should be ignored");

        // Reverting a null must be a no-op, mirroring how `apply` ignores it.
        collect
            .revert(&context, vec![VariableValue::Null], &mut accumulator)
            .await
            .unwrap();
        let snapshot = collect
            .snapshot(&context, vec![], &accumulator)
            .await
            .unwrap();
        assert_eq!(expect_list(snapshot), expected, "Reverting null should be a no-op");
    }

    #[tokio::test]
    async fn test_collect_empty_list() {
        let collect = Collect {};
        let variables = QueryVariables::new();
        let context =
            ExpressionEvaluationContext::new(&variables, Arc::new(InstantQueryClock::new(0, 0)));
        let accumulator = fresh_accumulator(&collect, &context);

        let snapshot = collect
            .snapshot(&context, vec![], &accumulator)
            .await
            .unwrap();
        assert!(
            expect_list(snapshot).is_empty(),
            "Empty accumulator should return empty list"
        );
    }

    #[tokio::test]
    async fn test_collect_duplicate_values() {
        let collect = Collect {};
        let variables = QueryVariables::new();
        let context =
            ExpressionEvaluationContext::new(&variables, Arc::new(InstantQueryClock::new(0, 0)));
        let mut accumulator = fresh_accumulator(&collect, &context);

        let val = VariableValue::Integer(42.into());
        for _ in 0..3 {
            collect
                .apply(&context, vec![val.clone()], &mut accumulator)
                .await
                .unwrap();
        }

        let snapshot = collect
            .snapshot(&context, vec![], &accumulator)
            .await
            .unwrap();
        assert_eq!(
            expect_list(snapshot),
            vec![val.clone(), val.clone(), val],
            "Duplicate values should all be collected"
        );
    }

    #[tokio::test]
    async fn test_collect_different_types() {
        let collect = Collect {};
        let variables = QueryVariables::new();
        let context =
            ExpressionEvaluationContext::new(&variables, Arc::new(InstantQueryClock::new(0, 0)));
        let mut accumulator = fresh_accumulator(&collect, &context);

        let values = vec![
            VariableValue::Integer(42.into()),
            VariableValue::Float(3.125.into()),
            VariableValue::String("hello".into()),
            VariableValue::Bool(true),
        ];
        for value in &values {
            collect
                .apply(&context, vec![value.clone()], &mut accumulator)
                .await
                .unwrap();
        }

        let snapshot = collect
            .snapshot(&context, vec![], &accumulator)
            .await
            .unwrap();
        assert_eq!(
            expect_list(snapshot),
            values,
            "Should collect values of different types"
        );
    }

    #[tokio::test]
    async fn test_collect_revert_multiple() {
        let collect = Collect {};
        let variables = QueryVariables::new();
        let context =
            ExpressionEvaluationContext::new(&variables, Arc::new(InstantQueryClock::new(0, 0)));
        let mut accumulator = fresh_accumulator(&collect, &context);

        let val1 = VariableValue::Integer(1.into());
        let val2 = VariableValue::Integer(2.into());
        for value in [&val1, &val2, &val1, &val2] {
            collect
                .apply(&context, vec![value.clone()], &mut accumulator)
                .await
                .unwrap();
        }

        collect
            .revert(&context, vec![val1.clone()], &mut accumulator)
            .await
            .unwrap();

        let snapshot = collect
            .snapshot(&context, vec![], &accumulator)
            .await
            .unwrap();
        assert_eq!(
            expect_list(snapshot),
            vec![val2.clone(), val1, val2],
            "Should have removed only first occurrence"
        );
    }

    #[tokio::test]
    async fn test_collect_revert_nonexistent() {
        let collect = Collect {};
        let variables = QueryVariables::new();
        let context =
            ExpressionEvaluationContext::new(&variables, Arc::new(InstantQueryClock::new(0, 0)));
        let mut accumulator = fresh_accumulator(&collect, &context);

        let val1 = VariableValue::Integer(1.into());
        let val2 = VariableValue::Integer(2.into());
        collect
            .apply(&context, vec![val1.clone()], &mut accumulator)
            .await
            .unwrap();
        collect
            .revert(&context, vec![val2], &mut accumulator)
            .await
            .unwrap();

        let snapshot = collect
            .snapshot(&context, vec![], &accumulator)
            .await
            .unwrap();
        assert_eq!(
            expect_list(snapshot),
            vec![val1],
            "Should not affect list if value doesn't exist"
        );
    }

    #[tokio::test]
    async fn test_collect_error_cases() {
        let collect = Collect {};
        let variables = QueryVariables::new();
        let context =
            ExpressionEvaluationContext::new(&variables, Arc::new(InstantQueryClock::new(0, 0)));
        let mut accumulator = fresh_accumulator(&collect, &context);

        let two_args = || {
            vec![
                VariableValue::Integer(1.into()),
                VariableValue::Integer(2.into()),
            ]
        };

        let result = collect.apply(&context, vec![], &mut accumulator).await;
        assert!(
            matches!(
                result,
                Err(FunctionError {
                    error: FunctionEvaluationError::InvalidArgumentCount,
                    ..
                })
            ),
            "apply should error with no arguments"
        );

        let result = collect.apply(&context, two_args(), &mut accumulator).await;
        assert!(
            matches!(
                result,
                Err(FunctionError {
                    error: FunctionEvaluationError::InvalidArgumentCount,
                    ..
                })
            ),
            "apply should error with too many arguments"
        );

        let result = collect.revert(&context, vec![], &mut accumulator).await;
        assert!(
            matches!(
                result,
                Err(FunctionError {
                    error: FunctionEvaluationError::InvalidArgumentCount,
                    ..
                })
            ),
            "revert should error with no arguments"
        );

        let result = collect.revert(&context, two_args(), &mut accumulator).await;
        assert!(
            matches!(
                result,
                Err(FunctionError {
                    error: FunctionEvaluationError::InvalidArgumentCount,
                    ..
                })
            ),
            "revert should error with too many arguments"
        );

        // Rejected calls must leave the accumulator untouched.
        let snapshot = collect
            .snapshot(&context, vec![], &accumulator)
            .await
            .unwrap();
        assert!(expect_list(snapshot).is_empty());
    }

    #[tokio::test]
    async fn test_collect_corrupt_accumulator() {
        let collect = Collect {};
        let variables = QueryVariables::new();
        let context =
            ExpressionEvaluationContext::new(&variables, Arc::new(InstantQueryClock::new(0, 0)));
        // A value accumulator that does not hold a list is not one `collect` created.
        let mut accumulator = Accumulator::Value(ValueAccumulator::Value(ElementValue::Null));
        let value = VariableValue::Integer(1.into());

        let result = collect
            .apply(&context, vec![value.clone()], &mut accumulator)
            .await;
        assert!(matches!(
            result,
            Err(FunctionError {
                error: FunctionEvaluationError::CorruptData,
                ..
            })
        ));

        let result = collect
            .revert(&context, vec![value], &mut accumulator)
            .await;
        assert!(matches!(
            result,
            Err(FunctionError {
                error: FunctionEvaluationError::CorruptData,
                ..
            })
        ));

        let result = collect.snapshot(&context, vec![], &accumulator).await;
        assert!(matches!(
            result,
            Err(FunctionError {
                error: FunctionEvaluationError::CorruptData,
                ..
            })
        ));
    }
}