1use std::cell::RefCell;
25use std::sync::Arc;
26
27use arrow_array::ArrayRef;
28use arrow_schema::{DataType, Field, FieldRef, Schema};
29use datafusion::execution::FunctionRegistry;
30use datafusion_common::ScalarValue;
31use datafusion_expr::function::AccumulatorArgs;
32use datafusion_expr::AggregateUDF;
33
34use laminar_core::operator::window::{DynAccumulator, DynAggregatorFactory, ScalarResult};
35use laminar_core::operator::Event;
36
37#[must_use]
43pub fn scalar_value_to_result(sv: &ScalarValue) -> ScalarResult {
44 match sv {
45 ScalarValue::Int64(Some(v)) => ScalarResult::Int64(*v),
46 ScalarValue::Int64(None) => ScalarResult::OptionalInt64(None),
47 ScalarValue::Float64(Some(v)) => ScalarResult::Float64(*v),
48 ScalarValue::Float64(None) | ScalarValue::Float32(None) => {
49 ScalarResult::OptionalFloat64(None)
50 }
51 ScalarValue::UInt64(Some(v)) => ScalarResult::UInt64(*v),
52 ScalarValue::Int8(Some(v)) => ScalarResult::Int64(i64::from(*v)),
54 ScalarValue::Int16(Some(v)) => ScalarResult::Int64(i64::from(*v)),
55 ScalarValue::Int32(Some(v)) => ScalarResult::Int64(i64::from(*v)),
56 ScalarValue::UInt8(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
57 ScalarValue::UInt16(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
58 ScalarValue::UInt32(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
59 ScalarValue::Float32(Some(v)) => ScalarResult::Float64(f64::from(*v)),
61 _ => ScalarResult::Null,
62 }
63}
64
65#[must_use]
67pub fn result_to_scalar_value(sr: &ScalarResult) -> ScalarValue {
68 match sr {
69 ScalarResult::Int64(v) => ScalarValue::Int64(Some(*v)),
70 ScalarResult::Float64(v) => ScalarValue::Float64(Some(*v)),
71 ScalarResult::UInt64(v) => ScalarValue::UInt64(Some(*v)),
72 ScalarResult::OptionalInt64(v) => ScalarValue::Int64(*v),
73 ScalarResult::OptionalFloat64(v) => ScalarValue::Float64(*v),
74 ScalarResult::Null => ScalarValue::Null,
75 }
76}
77
78pub struct DataFusionAccumulatorAdapter {
87 inner: RefCell<Box<dyn datafusion_expr::Accumulator>>,
89 column_indices: Vec<usize>,
91 input_types: Vec<DataType>,
93 function_name: String,
95}
96
97unsafe impl Send for DataFusionAccumulatorAdapter {}
100
101impl std::fmt::Debug for DataFusionAccumulatorAdapter {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 f.debug_struct("DataFusionAccumulatorAdapter")
104 .field("function_name", &self.function_name)
105 .field("column_indices", &self.column_indices)
106 .field("input_types", &self.input_types)
107 .finish_non_exhaustive()
108 }
109}
110
111impl DataFusionAccumulatorAdapter {
112 #[must_use]
114 pub fn new(
115 inner: Box<dyn datafusion_expr::Accumulator>,
116 column_indices: Vec<usize>,
117 input_types: Vec<DataType>,
118 function_name: String,
119 ) -> Self {
120 Self {
121 inner: RefCell::new(inner),
122 column_indices,
123 input_types,
124 function_name,
125 }
126 }
127
128 #[must_use]
130 pub fn function_name(&self) -> &str {
131 &self.function_name
132 }
133
134 fn extract_columns(&self, batch: &arrow_array::RecordBatch) -> Vec<ArrayRef> {
136 self.column_indices
137 .iter()
138 .enumerate()
139 .map(|(arg_idx, &col_idx)| {
140 if col_idx < batch.num_columns() {
141 Arc::clone(batch.column(col_idx))
142 } else {
143 let dt = self
144 .input_types
145 .get(arg_idx)
146 .cloned()
147 .unwrap_or(DataType::Int64);
148 arrow_array::new_null_array(&dt, batch.num_rows())
149 }
150 })
151 .collect()
152 }
153}
154
155impl DynAccumulator for DataFusionAccumulatorAdapter {
156 fn add_event(&mut self, event: &Event) {
157 let columns = self.extract_columns(&event.data);
158 let _ = self.inner.borrow_mut().update_batch(&columns);
159 }
160
161 fn merge_dyn(&mut self, other: &dyn DynAccumulator) {
162 let other = other
163 .as_any()
164 .downcast_ref::<DataFusionAccumulatorAdapter>()
165 .expect("merge_dyn: type mismatch, expected DataFusionAccumulatorAdapter");
166
167 if let Ok(state_values) = other.inner.borrow_mut().state() {
168 let state_arrays: Vec<ArrayRef> = state_values
169 .iter()
170 .filter_map(|sv| sv.to_array().ok())
171 .collect();
172 if !state_arrays.is_empty() {
173 let _ = self.inner.borrow_mut().merge_batch(&state_arrays);
174 }
175 }
176 }
177
178 fn result_scalar(&self) -> ScalarResult {
179 match self.inner.borrow_mut().evaluate() {
180 Ok(sv) => scalar_value_to_result(&sv),
181 Err(_) => ScalarResult::Null,
182 }
183 }
184
185 fn is_empty(&self) -> bool {
186 self.inner.borrow().size() <= std::mem::size_of::<Self>()
187 }
188
189 fn clone_box(&self) -> Box<dyn DynAccumulator> {
190 panic!(
191 "clone_box not supported for DataFusion adapter '{}'; \
192 use the factory to create new accumulators",
193 self.function_name
194 )
195 }
196
197 #[allow(clippy::cast_possible_truncation)] fn serialize(&self) -> Vec<u8> {
199 match self.inner.borrow_mut().state() {
200 Ok(state_values) => {
201 let mut buf = Vec::new();
202 let count = state_values.len() as u32;
203 buf.extend_from_slice(&count.to_le_bytes());
204 for sv in &state_values {
205 let bytes = sv.to_string();
206 let len = bytes.len() as u32;
207 buf.extend_from_slice(&len.to_le_bytes());
208 buf.extend_from_slice(bytes.as_bytes());
209 }
210 buf
211 }
212 Err(_) => Vec::new(),
213 }
214 }
215
216 fn result_field(&self) -> Field {
217 let result = self.result_scalar();
218 let dt = result.data_type();
219 let dt = if dt == DataType::Null {
220 DataType::Float64
221 } else {
222 dt
223 };
224 Field::new(&self.function_name, dt, true)
225 }
226
227 fn type_tag(&self) -> &'static str {
228 "datafusion_adapter"
229 }
230
231 fn as_any(&self) -> &dyn std::any::Any {
232 self
233 }
234}
235
236pub struct DataFusionAggregateFactory {
243 udf: Arc<AggregateUDF>,
245 column_indices: Vec<usize>,
247 input_types: Vec<DataType>,
249}
250
251impl std::fmt::Debug for DataFusionAggregateFactory {
252 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253 f.debug_struct("DataFusionAggregateFactory")
254 .field("name", &self.udf.name())
255 .field("column_indices", &self.column_indices)
256 .field("input_types", &self.input_types)
257 .finish()
258 }
259}
260
261impl DataFusionAggregateFactory {
262 #[must_use]
264 pub fn new(
265 udf: Arc<AggregateUDF>,
266 column_indices: Vec<usize>,
267 input_types: Vec<DataType>,
268 ) -> Self {
269 Self {
270 udf,
271 column_indices,
272 input_types,
273 }
274 }
275
276 #[must_use]
278 pub fn name(&self) -> &str {
279 self.udf.name()
280 }
281
282 fn create_df_accumulator(&self) -> Box<dyn datafusion_expr::Accumulator> {
284 let return_type = self
285 .udf
286 .return_type(&self.input_types)
287 .unwrap_or(DataType::Float64);
288 let return_field: FieldRef = Arc::new(Field::new(self.udf.name(), return_type, true));
289 let schema = Schema::new(
290 self.input_types
291 .iter()
292 .enumerate()
293 .map(|(i, dt)| Field::new(format!("col_{i}"), dt.clone(), true))
294 .collect::<Vec<_>>(),
295 );
296 let expr_fields: Vec<FieldRef> = self
297 .input_types
298 .iter()
299 .enumerate()
300 .map(|(i, dt)| Arc::new(Field::new(format!("col_{i}"), dt.clone(), true)) as FieldRef)
301 .collect();
302 let args = AccumulatorArgs {
303 return_field,
304 schema: &schema,
305 ignore_nulls: false,
306 order_bys: &[],
307 is_reversed: false,
308 name: self.udf.name(),
309 is_distinct: false,
310 exprs: &[],
311 expr_fields: &expr_fields,
312 };
313 self.udf
314 .accumulator(args)
315 .expect("Failed to create DataFusion accumulator")
316 }
317}
318
319impl DynAggregatorFactory for DataFusionAggregateFactory {
320 fn create_accumulator(&self) -> Box<dyn DynAccumulator> {
321 let inner = self.create_df_accumulator();
322 Box::new(DataFusionAccumulatorAdapter::new(
323 inner,
324 self.column_indices.clone(),
325 self.input_types.clone(),
326 self.udf.name().to_string(),
327 ))
328 }
329
330 fn result_field(&self) -> Field {
331 let return_type = self
332 .udf
333 .return_type(&self.input_types)
334 .unwrap_or(DataType::Float64);
335 Field::new(self.udf.name(), return_type, true)
336 }
337
338 fn clone_box(&self) -> Box<dyn DynAggregatorFactory> {
339 Box::new(DataFusionAggregateFactory {
340 udf: Arc::clone(&self.udf),
341 column_indices: self.column_indices.clone(),
342 input_types: self.input_types.clone(),
343 })
344 }
345
346 fn type_tag(&self) -> &'static str {
347 "datafusion_factory"
348 }
349}
350
351#[must_use]
357pub fn lookup_aggregate_udf(
358 ctx: &datafusion::prelude::SessionContext,
359 name: &str,
360) -> Option<Arc<AggregateUDF>> {
361 let normalized = name.to_lowercase();
362 ctx.udaf(&normalized).ok()
363}
364
365#[must_use]
369pub fn create_aggregate_factory(
370 ctx: &datafusion::prelude::SessionContext,
371 name: &str,
372 column_indices: Vec<usize>,
373 input_types: Vec<DataType>,
374) -> Option<DataFusionAggregateFactory> {
375 lookup_aggregate_udf(ctx, name)
376 .map(|udf| DataFusionAggregateFactory::new(udf, column_indices, input_types))
377}
378
#[cfg(test)]
mod tests {
    // Unit tests for the ScalarValue <-> ScalarResult conversions, the
    // DataFusion aggregate factory, and the accumulator adapter.
    use super::*;
    use arrow_array::{Float64Array, Int64Array, RecordBatch};
    use datafusion::prelude::SessionContext;

    /// Builds an event whose batch has one non-nullable Float64 column "value".
    fn float_event(ts: i64, values: Vec<f64>) -> Event {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Float64,
            false,
        )]));
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(Float64Array::from(values))]).unwrap();
        Event::new(ts, batch)
    }

    /// Builds an event whose batch has one non-nullable Int64 column "value".
    fn int_event(ts: i64, values: Vec<i64>) -> Event {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Int64,
            false,
        )]));
        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(values))]).unwrap();
        Event::new(ts, batch)
    }

    /// Builds an event with two non-nullable Float64 columns "x" and "y",
    /// used for two-argument aggregates.
    fn two_col_float_event(ts: i64, col0: Vec<f64>, col1: Vec<f64>) -> Event {
        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Float64, false),
            Field::new("y", DataType::Float64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Float64Array::from(col0)),
                Arc::new(Float64Array::from(col1)),
            ],
        )
        .unwrap();
        Event::new(ts, batch)
    }

    // ---- ScalarValue -> ScalarResult conversion ----

    #[test]
    fn test_scalar_value_to_result_int64() {
        let sv = ScalarValue::Int64(Some(42));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Int64(42));
    }

    #[test]
    fn test_scalar_value_to_result_float64() {
        let sv = ScalarValue::Float64(Some(3.125));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Float64(3.125));
    }

    #[test]
    fn test_scalar_value_to_result_uint64() {
        let sv = ScalarValue::UInt64(Some(100));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::UInt64(100));
    }

    #[test]
    fn test_scalar_value_to_result_null_int64() {
        let sv = ScalarValue::Int64(None);
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalInt64(None)
        );
    }

    #[test]
    fn test_scalar_value_to_result_null_float64() {
        let sv = ScalarValue::Float64(None);
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalFloat64(None)
        );
    }

    // Narrow integer types are widened to Int64 / UInt64.
    #[test]
    fn test_scalar_value_to_result_smaller_ints() {
        assert_eq!(
            scalar_value_to_result(&ScalarValue::Int8(Some(8))),
            ScalarResult::Int64(8)
        );
        assert_eq!(
            scalar_value_to_result(&ScalarValue::Int16(Some(16))),
            ScalarResult::Int64(16)
        );
        assert_eq!(
            scalar_value_to_result(&ScalarValue::Int32(Some(32))),
            ScalarResult::Int64(32)
        );
        assert_eq!(
            scalar_value_to_result(&ScalarValue::UInt8(Some(8))),
            ScalarResult::UInt64(8)
        );
    }

    #[test]
    fn test_scalar_value_to_result_float32() {
        let sv = ScalarValue::Float32(Some(2.5));
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::Float64(f64::from(2.5f32))
        );
    }

    // Types with no ScalarResult counterpart map to Null.
    #[test]
    fn test_scalar_value_to_result_unsupported() {
        let sv = ScalarValue::Utf8(Some("hello".to_string()));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Null);
    }

    // ScalarResult -> ScalarValue -> ScalarResult. Non-null Optional* values
    // come back as the plain (non-optional) variant; nulls keep their type.
    #[test]
    fn test_result_to_scalar_value_roundtrip() {
        let exact_cases = vec![
            ScalarResult::Int64(42),
            ScalarResult::Float64(3.125),
            ScalarResult::UInt64(100),
        ];
        for sr in &exact_cases {
            let sv = result_to_scalar_value(sr);
            let back = scalar_value_to_result(&sv);
            assert_eq!(&back, sr, "Roundtrip failed for {sr:?}");
        }

        let sv = result_to_scalar_value(&ScalarResult::OptionalInt64(Some(7)));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Int64(7));

        let sv = result_to_scalar_value(&ScalarResult::OptionalFloat64(Some(2.72)));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Float64(2.72));

        let sv = result_to_scalar_value(&ScalarResult::OptionalInt64(None));
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalInt64(None)
        );

        let sv = result_to_scalar_value(&ScalarResult::OptionalFloat64(None));
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalFloat64(None)
        );

        let sv = result_to_scalar_value(&ScalarResult::Null);
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Null);
    }

    // ---- Factory lookup and metadata ----

    #[test]
    fn test_factory_count() {
        let ctx = SessionContext::new();
        let factory = create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]);
        assert!(factory.is_some(), "count should be a recognized aggregate");
        assert_eq!(factory.unwrap().name(), "count");
    }

    #[test]
    fn test_factory_sum() {
        let ctx = SessionContext::new();
        let factory = create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]);
        assert!(factory.is_some());
        assert_eq!(factory.unwrap().name(), "sum");
    }

    #[test]
    fn test_factory_avg() {
        let ctx = SessionContext::new();
        let factory = create_aggregate_factory(&ctx, "avg", vec![0], vec![DataType::Float64]);
        assert!(factory.is_some());
    }

    #[test]
    fn test_factory_stddev() {
        let ctx = SessionContext::new();
        let factory = create_aggregate_factory(&ctx, "stddev", vec![0], vec![DataType::Float64]);
        assert!(
            factory.is_some(),
            "stddev should be available in DataFusion"
        );
    }

    #[test]
    fn test_factory_unknown() {
        let ctx = SessionContext::new();
        let factory = create_aggregate_factory(
            &ctx,
            "nonexistent_aggregate_xyz",
            vec![0],
            vec![DataType::Int64],
        );
        assert!(factory.is_none());
    }

    #[test]
    fn test_factory_result_field() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let field = factory.result_field();
        assert_eq!(field.name(), "sum");
        assert_eq!(field.data_type(), &DataType::Float64);
    }

    #[test]
    fn test_factory_clone_box() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
        let cloned = factory.clone_box();
        assert_eq!(cloned.type_tag(), "datafusion_factory");
    }

    // ---- Adapter: basic aggregation ----

    // count's result is accepted as either Int64 or UInt64, so the test does
    // not depend on which integer type DataFusion reports.
    #[test]
    fn test_adapter_count_basic() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
        let mut acc = factory.create_accumulator();

        let result = acc.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(0) | ScalarResult::UInt64(0)),
            "Expected 0, got {result:?}"
        );

        acc.add_event(&int_event(1000, vec![10, 20, 30]));
        let result = acc.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(3) | ScalarResult::UInt64(3)),
            "Expected 3, got {result:?}"
        );

        acc.add_event(&int_event(2000, vec![40, 50]));
        let result = acc.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(5) | ScalarResult::UInt64(5)),
            "Expected 5, got {result:?}"
        );
    }

    #[test]
    fn test_adapter_sum_float64() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(1000, vec![1.5, 2.5, 3.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(7.0));
    }

    #[test]
    fn test_adapter_avg_float64() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "avg", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(1000, vec![10.0, 20.0, 30.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(20.0));
    }

    #[test]
    fn test_adapter_min_float64() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "min", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(1000, vec![30.0, 10.0, 20.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(10.0));
    }

    #[test]
    fn test_adapter_max_float64() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "max", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(1000, vec![30.0, 10.0, 20.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(30.0));
    }

    #[test]
    fn test_adapter_sum_int64() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Int64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&int_event(1000, vec![10, 20, 30]));
        assert_eq!(acc.result_scalar(), ScalarResult::Int64(60));
    }

    #[test]
    fn test_adapter_type_tag() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let acc = factory.create_accumulator();
        assert_eq!(acc.type_tag(), "datafusion_adapter");
    }

    #[test]
    fn test_adapter_result_field() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();
        acc.add_event(&float_event(1000, vec![1.0]));
        assert_eq!(acc.result_field().name(), "sum");
    }

    // ---- Adapter: merging partial accumulators ----

    #[test]
    fn test_adapter_merge_sum() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&float_event(1000, vec![1.0, 2.0]));

        let mut acc2 = factory.create_accumulator();
        acc2.add_event(&float_event(2000, vec![3.0, 4.0]));

        acc1.merge_dyn(acc2.as_ref());
        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(10.0));
    }

    #[test]
    fn test_adapter_merge_count() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&int_event(1000, vec![1, 2, 3]));

        let mut acc2 = factory.create_accumulator();
        acc2.add_event(&int_event(2000, vec![4, 5]));

        acc1.merge_dyn(acc2.as_ref());
        let result = acc1.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(5) | ScalarResult::UInt64(5)),
            "Expected 5 after merge, got {result:?}"
        );
    }

    // avg must merge its multi-column state (sum and count) correctly:
    // (10 + 20 + 30) / 3 = 20.
    #[test]
    fn test_adapter_merge_avg() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "avg", vec![0], vec![DataType::Float64]).unwrap();

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&float_event(1000, vec![10.0, 20.0]));

        let mut acc2 = factory.create_accumulator();
        acc2.add_event(&float_event(2000, vec![30.0]));

        acc1.merge_dyn(acc2.as_ref());
        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(20.0));
    }

    // Merging an accumulator that received no events must not change the result.
    #[test]
    fn test_adapter_merge_empty() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&float_event(1000, vec![5.0]));

        let acc2 = factory.create_accumulator();
        acc1.merge_dyn(acc2.as_ref());
        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(5.0));
    }

    // ---- Adapter: statistical aggregates ----

    // Sample standard deviation of the dataset below is ~2.138.
    #[test]
    fn test_adapter_stddev() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "stddev", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(
            1000,
            vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0],
        ));
        let result = acc.result_scalar();
        if let ScalarResult::Float64(v) = result {
            assert!((v - 2.138).abs() < 0.01, "Expected ~2.138, got {v}");
        } else {
            panic!("Expected Float64 result, got {result:?}");
        }
    }

    // NOTE(review): passes vacuously if "var_samp" is not registered or the
    // result is not Float64; consider asserting both once availability is
    // confirmed.
    #[test]
    fn test_adapter_variance() {
        let ctx = SessionContext::new();
        if let Some(factory) =
            create_aggregate_factory(&ctx, "var_samp", vec![0], vec![DataType::Float64])
        {
            let mut acc = factory.create_accumulator();
            acc.add_event(&float_event(
                1000,
                vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0],
            ));
            if let ScalarResult::Float64(v) = acc.result_scalar() {
                assert!((v - 4.571).abs() < 0.01, "Expected ~4.571, got {v}");
            }
        }
    }

    // NOTE(review): passes vacuously if "median" is not registered.
    #[test]
    fn test_adapter_median() {
        let ctx = SessionContext::new();
        if let Some(factory) =
            create_aggregate_factory(&ctx, "median", vec![0], vec![DataType::Float64])
        {
            let mut acc = factory.create_accumulator();
            acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0, 4.0, 5.0]));
            assert_eq!(acc.result_scalar(), ScalarResult::Float64(3.0));
        }
    }

    // ---- Adapter: serialization ----

    #[test]
    fn test_adapter_serialize() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();
        acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0]));
        assert!(!acc.serialize().is_empty());
    }

    // Even with no input, `state()` is expected to succeed, so at least the
    // 4-byte state-count header is written.
    #[test]
    fn test_adapter_serialize_empty() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let acc = factory.create_accumulator();
        assert!(!acc.serialize().is_empty());
    }

    // ---- UDF lookup ----

    #[test]
    fn test_lookup_common_aggregates() {
        let ctx = SessionContext::new();
        for name in &["count", "sum", "min", "max", "avg"] {
            assert!(
                lookup_aggregate_udf(&ctx, name).is_some(),
                "Expected '{name}' to be a recognized aggregate"
            );
        }
    }

    // NOTE(review): this only checks that the lookups do not panic; it asserts
    // nothing about availability.
    #[test]
    fn test_lookup_statistical_aggregates() {
        let ctx = SessionContext::new();
        for name in &["stddev", "stddev_pop", "median"] {
            let _ = lookup_aggregate_udf(&ctx, name);
        }
    }

    #[test]
    fn test_lookup_case_insensitive() {
        let ctx = SessionContext::new();
        assert!(lookup_aggregate_udf(&ctx, "COUNT").is_some());
        assert!(lookup_aggregate_udf(&ctx, "Sum").is_some());
        assert!(lookup_aggregate_udf(&ctx, "AVG").is_some());
    }

    // ---- Multi-column aggregates and misc API ----

    // Sample covariance of identical series 1..=5 is 2.5.
    // NOTE(review): passes vacuously if "covar_samp" is not registered or the
    // result is not Float64.
    #[test]
    fn test_adapter_multi_column_covar() {
        let ctx = SessionContext::new();
        if let Some(factory) = create_aggregate_factory(
            &ctx,
            "covar_samp",
            vec![0, 1],
            vec![DataType::Float64, DataType::Float64],
        ) {
            let mut acc = factory.create_accumulator();
            acc.add_event(&two_col_float_event(
                1000,
                vec![1.0, 2.0, 3.0, 4.0, 5.0],
                vec![1.0, 2.0, 3.0, 4.0, 5.0],
            ));
            if let ScalarResult::Float64(v) = acc.result_scalar() {
                assert!((v - 2.5).abs() < 0.01, "Expected covar ~2.5, got {v}");
            }
        }
    }

    #[test]
    fn test_create_aggregate_factory_api() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
        let acc = factory.create_accumulator();
        assert_eq!(acc.type_tag(), "datafusion_adapter");
    }

    // Accumulators from the same factory must not share state.
    #[test]
    fn test_factory_creates_independent_accumulators() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();

        let mut acc1 = factory.create_accumulator();
        let mut acc2 = factory.create_accumulator();

        acc1.add_event(&float_event(1000, vec![10.0]));
        acc2.add_event(&float_event(2000, vec![20.0]));

        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(10.0));
        assert_eq!(acc2.result_scalar(), ScalarResult::Float64(20.0));
    }

    #[test]
    fn test_adapter_function_name() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let acc = factory.create_accumulator();
        let adapter = acc
            .as_any()
            .downcast_ref::<DataFusionAccumulatorAdapter>()
            .expect("should be adapter");
        assert_eq!(adapter.function_name(), "sum");
    }
}