1use std::cell::RefCell;
25use std::sync::Arc;
26
27use arrow_array::ArrayRef;
28use arrow_schema::{DataType, Field, FieldRef, Schema};
29use datafusion::execution::FunctionRegistry;
30use datafusion_common::ScalarValue;
31use datafusion_expr::function::AccumulatorArgs;
32use datafusion_expr::AggregateUDF;
33
34use laminar_core::operator::window::{DynAccumulator, DynAggregatorFactory, ScalarResult};
35use laminar_core::operator::Event;
36
37#[must_use]
43pub fn scalar_value_to_result(sv: &ScalarValue) -> ScalarResult {
44 match sv {
45 ScalarValue::Int64(Some(v)) => ScalarResult::Int64(*v),
46 ScalarValue::Int64(None) => ScalarResult::OptionalInt64(None),
47 ScalarValue::Float64(Some(v)) => ScalarResult::Float64(*v),
48 ScalarValue::Float64(None) | ScalarValue::Float32(None) => {
49 ScalarResult::OptionalFloat64(None)
50 }
51 ScalarValue::UInt64(Some(v)) => ScalarResult::UInt64(*v),
52 ScalarValue::Int8(Some(v)) => ScalarResult::Int64(i64::from(*v)),
54 ScalarValue::Int16(Some(v)) => ScalarResult::Int64(i64::from(*v)),
55 ScalarValue::Int32(Some(v)) => ScalarResult::Int64(i64::from(*v)),
56 ScalarValue::UInt8(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
57 ScalarValue::UInt16(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
58 ScalarValue::UInt32(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
59 ScalarValue::Float32(Some(v)) => ScalarResult::Float64(f64::from(*v)),
61 _ => ScalarResult::Null,
62 }
63}
64
65#[must_use]
67pub fn result_to_scalar_value(sr: &ScalarResult) -> ScalarValue {
68 match sr {
69 ScalarResult::Int64(v) => ScalarValue::Int64(Some(*v)),
70 ScalarResult::Float64(v) => ScalarValue::Float64(Some(*v)),
71 ScalarResult::UInt64(v) => ScalarValue::UInt64(Some(*v)),
72 ScalarResult::OptionalInt64(v) => ScalarValue::Int64(*v),
73 ScalarResult::OptionalFloat64(v) => ScalarValue::Float64(*v),
74 ScalarResult::Null => ScalarValue::Null,
75 }
76}
77
/// Adapter that exposes a DataFusion accumulator through the
/// [`DynAccumulator`] interface.
pub struct DataFusionAccumulatorAdapter {
    /// The wrapped DataFusion accumulator. It sits in a `RefCell` because
    /// `&self` methods of this type (`result_scalar`, `serialize`) call
    /// `evaluate`/`state` on it through `borrow_mut`.
    inner: RefCell<Box<dyn datafusion_expr::Accumulator>>,
    /// For each aggregate argument, the index of the column to read from an
    /// event's record batch.
    column_indices: Vec<usize>,
    /// Data type of each argument, looked up by argument position; used to
    /// build a typed null array when a column index is out of range.
    input_types: Vec<DataType>,
    /// Name of the aggregate function; shown in `Debug` output and used as
    /// the name of the result field.
    function_name: String,
}
96
// SAFETY: NOTE(review) — the justification for this impl is not visible in
// this file. `RefCell<T>` is `Send` whenever `T` is, so this impl only matters
// if `Box<dyn datafusion_expr::Accumulator>` is not already `Send`. Confirm the
// trait's supertraits: if `Accumulator: Send`, this impl is redundant and can
// be removed; otherwise document here why cross-thread moves are sound.
unsafe impl Send for DataFusionAccumulatorAdapter {}
100
101impl std::fmt::Debug for DataFusionAccumulatorAdapter {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 f.debug_struct("DataFusionAccumulatorAdapter")
104 .field("function_name", &self.function_name)
105 .field("column_indices", &self.column_indices)
106 .field("input_types", &self.input_types)
107 .finish_non_exhaustive()
108 }
109}
110
111impl DataFusionAccumulatorAdapter {
112 #[must_use]
114 pub fn new(
115 inner: Box<dyn datafusion_expr::Accumulator>,
116 column_indices: Vec<usize>,
117 input_types: Vec<DataType>,
118 function_name: String,
119 ) -> Self {
120 Self {
121 inner: RefCell::new(inner),
122 column_indices,
123 input_types,
124 function_name,
125 }
126 }
127
128 #[must_use]
130 pub fn function_name(&self) -> &str {
131 &self.function_name
132 }
133
134 fn extract_columns(&self, batch: &arrow_array::RecordBatch) -> Vec<ArrayRef> {
136 self.column_indices
137 .iter()
138 .enumerate()
139 .map(|(arg_idx, &col_idx)| {
140 if col_idx < batch.num_columns() {
141 Arc::clone(batch.column(col_idx))
142 } else {
143 let dt = self
144 .input_types
145 .get(arg_idx)
146 .cloned()
147 .unwrap_or(DataType::Int64);
148 arrow_array::new_null_array(&dt, batch.num_rows())
149 }
150 })
151 .collect()
152 }
153}
154
155impl DynAccumulator for DataFusionAccumulatorAdapter {
156 fn add_event(&mut self, event: &Event) {
157 let columns = self.extract_columns(&event.data);
158 let _ = self.inner.borrow_mut().update_batch(&columns);
159 }
160
161 fn merge_dyn(&mut self, other: &dyn DynAccumulator) {
162 let other = other
163 .as_any()
164 .downcast_ref::<DataFusionAccumulatorAdapter>()
165 .expect("merge_dyn: type mismatch, expected DataFusionAccumulatorAdapter");
166
167 if let Ok(state_values) = other.inner.borrow_mut().state() {
168 let state_arrays: Vec<ArrayRef> = state_values
169 .iter()
170 .filter_map(|sv| sv.to_array().ok())
171 .collect();
172 if !state_arrays.is_empty() {
173 let _ = self.inner.borrow_mut().merge_batch(&state_arrays);
174 }
175 }
176 }
177
178 fn result_scalar(&self) -> ScalarResult {
179 match self.inner.borrow_mut().evaluate() {
180 Ok(sv) => scalar_value_to_result(&sv),
181 Err(_) => ScalarResult::Null,
182 }
183 }
184
185 fn is_empty(&self) -> bool {
186 self.inner.borrow().size() <= std::mem::size_of::<Self>()
187 }
188
189 fn clone_box(&self) -> Box<dyn DynAccumulator> {
190 panic!(
191 "clone_box not supported for DataFusion adapter '{}'; \
192 use the factory to create new accumulators",
193 self.function_name
194 )
195 }
196
197 #[allow(clippy::cast_possible_truncation)] fn serialize(&self) -> Vec<u8> {
199 match self.inner.borrow_mut().state() {
200 Ok(state_values) => {
201 let mut buf = Vec::new();
202 let count = state_values.len() as u32;
203 buf.extend_from_slice(&count.to_le_bytes());
204 for sv in &state_values {
205 let bytes = sv.to_string();
206 let len = bytes.len() as u32;
207 buf.extend_from_slice(&len.to_le_bytes());
208 buf.extend_from_slice(bytes.as_bytes());
209 }
210 buf
211 }
212 Err(_) => Vec::new(),
213 }
214 }
215
216 fn result_field(&self) -> Field {
217 let result = self.result_scalar();
218 let dt = result.data_type();
219 let dt = if dt == DataType::Null {
220 DataType::Float64
221 } else {
222 dt
223 };
224 Field::new(&self.function_name, dt, true)
225 }
226
227 fn type_tag(&self) -> &'static str {
228 "datafusion_adapter"
229 }
230
231 fn as_any(&self) -> &dyn std::any::Any {
232 self
233 }
234}
235
/// Factory that builds [`DataFusionAccumulatorAdapter`]s for one DataFusion
/// aggregate function.
pub struct DataFusionAggregateFactory {
    /// The DataFusion aggregate UDF that supplies the accumulators.
    udf: Arc<AggregateUDF>,
    /// For each aggregate argument, the batch column index handed to every
    /// accumulator this factory creates.
    column_indices: Vec<usize>,
    /// Data type of each argument; used to resolve the return type and to
    /// describe the inputs when creating an accumulator.
    input_types: Vec<DataType>,
}
250
251impl std::fmt::Debug for DataFusionAggregateFactory {
252 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253 f.debug_struct("DataFusionAggregateFactory")
254 .field("name", &self.udf.name())
255 .field("column_indices", &self.column_indices)
256 .field("input_types", &self.input_types)
257 .finish()
258 }
259}
260
261impl DataFusionAggregateFactory {
262 #[must_use]
264 pub fn new(
265 udf: Arc<AggregateUDF>,
266 column_indices: Vec<usize>,
267 input_types: Vec<DataType>,
268 ) -> Self {
269 Self {
270 udf,
271 column_indices,
272 input_types,
273 }
274 }
275
276 #[must_use]
278 pub fn name(&self) -> &str {
279 self.udf.name()
280 }
281
282 const COL_NAMES: [&str; 8] = [
284 "col_0", "col_1", "col_2", "col_3", "col_4", "col_5", "col_6", "col_7",
285 ];
286
287 fn col_name(i: usize) -> &'static str {
289 Self::COL_NAMES.get(i).copied().unwrap_or("col_n")
290 }
291
292 fn create_df_accumulator(&self) -> Box<dyn datafusion_expr::Accumulator> {
294 let return_type = self
295 .udf
296 .return_type(&self.input_types)
297 .unwrap_or(DataType::Float64);
298 let return_field: FieldRef = Arc::new(Field::new(self.udf.name(), return_type, true));
299 let schema = Schema::new(
300 self.input_types
301 .iter()
302 .enumerate()
303 .map(|(i, dt)| Field::new(Self::col_name(i), dt.clone(), true))
304 .collect::<Vec<_>>(),
305 );
306 let expr_fields: Vec<FieldRef> = self
307 .input_types
308 .iter()
309 .enumerate()
310 .map(|(i, dt)| Arc::new(Field::new(Self::col_name(i), dt.clone(), true)) as FieldRef)
311 .collect();
312 let args = AccumulatorArgs {
313 return_field,
314 schema: &schema,
315 ignore_nulls: false,
316 order_bys: &[],
317 is_reversed: false,
318 name: self.udf.name(),
319 is_distinct: false,
320 exprs: &[],
321 expr_fields: &expr_fields,
322 };
323 self.udf
324 .accumulator(args)
325 .expect("Failed to create DataFusion accumulator")
326 }
327}
328
329impl DynAggregatorFactory for DataFusionAggregateFactory {
330 fn create_accumulator(&self) -> Box<dyn DynAccumulator> {
331 let inner = self.create_df_accumulator();
332 Box::new(DataFusionAccumulatorAdapter::new(
333 inner,
334 self.column_indices.clone(),
335 self.input_types.clone(),
336 self.udf.name().to_string(),
337 ))
338 }
339
340 fn result_field(&self) -> Field {
341 let return_type = self
342 .udf
343 .return_type(&self.input_types)
344 .unwrap_or(DataType::Float64);
345 Field::new(self.udf.name(), return_type, true)
346 }
347
348 fn clone_box(&self) -> Box<dyn DynAggregatorFactory> {
349 Box::new(DataFusionAggregateFactory {
350 udf: Arc::clone(&self.udf),
351 column_indices: self.column_indices.clone(),
352 input_types: self.input_types.clone(),
353 })
354 }
355
356 fn type_tag(&self) -> &'static str {
357 "datafusion_factory"
358 }
359}
360
361#[must_use]
367pub fn lookup_aggregate_udf(
368 ctx: &datafusion::prelude::SessionContext,
369 name: &str,
370) -> Option<Arc<AggregateUDF>> {
371 let normalized = name.to_lowercase();
372 ctx.udaf(&normalized).ok()
373}
374
375#[must_use]
379pub fn create_aggregate_factory(
380 ctx: &datafusion::prelude::SessionContext,
381 name: &str,
382 column_indices: Vec<usize>,
383 input_types: Vec<DataType>,
384) -> Option<DataFusionAggregateFactory> {
385 lookup_aggregate_udf(ctx, name)
386 .map(|udf| DataFusionAggregateFactory::new(udf, column_indices, input_types))
387}
388
/// Unit tests for the scalar conversions, the aggregate factory and the
/// accumulator adapter, run against a default `SessionContext`.
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Float64Array, Int64Array, RecordBatch};
    use datafusion::prelude::SessionContext;

    /// Builds an event with one non-nullable `Float64` column named `value`.
    fn float_event(ts: i64, values: Vec<f64>) -> Event {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Float64,
            false,
        )]));
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(Float64Array::from(values))]).unwrap();
        Event::new(ts, batch)
    }

    /// Builds an event with one non-nullable `Int64` column named `value`.
    fn int_event(ts: i64, values: Vec<i64>) -> Event {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Int64,
            false,
        )]));
        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(values))]).unwrap();
        Event::new(ts, batch)
    }

    /// Builds an event with two non-nullable `Float64` columns, `x` and `y`.
    fn two_col_float_event(ts: i64, col0: Vec<f64>, col1: Vec<f64>) -> Event {
        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Float64, false),
            Field::new("y", DataType::Float64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Float64Array::from(col0)),
                Arc::new(Float64Array::from(col1)),
            ],
        )
        .unwrap();
        Event::new(ts, batch)
    }

    // ---- ScalarValue <-> ScalarResult conversions ----

    #[test]
    fn test_scalar_value_to_result_int64() {
        let sv = ScalarValue::Int64(Some(42));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Int64(42));
    }

    #[test]
    fn test_scalar_value_to_result_float64() {
        let sv = ScalarValue::Float64(Some(3.125));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Float64(3.125));
    }

    #[test]
    fn test_scalar_value_to_result_uint64() {
        let sv = ScalarValue::UInt64(Some(100));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::UInt64(100));
    }

    #[test]
    fn test_scalar_value_to_result_null_int64() {
        let sv = ScalarValue::Int64(None);
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalInt64(None)
        );
    }

    #[test]
    fn test_scalar_value_to_result_null_float64() {
        let sv = ScalarValue::Float64(None);
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalFloat64(None)
        );
    }

    /// Narrow integer types widen to the 64-bit variant of the same signedness.
    #[test]
    fn test_scalar_value_to_result_smaller_ints() {
        assert_eq!(
            scalar_value_to_result(&ScalarValue::Int8(Some(8))),
            ScalarResult::Int64(8)
        );
        assert_eq!(
            scalar_value_to_result(&ScalarValue::Int16(Some(16))),
            ScalarResult::Int64(16)
        );
        assert_eq!(
            scalar_value_to_result(&ScalarValue::Int32(Some(32))),
            ScalarResult::Int64(32)
        );
        assert_eq!(
            scalar_value_to_result(&ScalarValue::UInt8(Some(8))),
            ScalarResult::UInt64(8)
        );
    }

    #[test]
    fn test_scalar_value_to_result_float32() {
        let sv = ScalarValue::Float32(Some(2.5));
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::Float64(f64::from(2.5f32))
        );
    }

    /// Non-numeric values have no mapping and become `Null`.
    #[test]
    fn test_scalar_value_to_result_unsupported() {
        let sv = ScalarValue::Utf8(Some("hello".to_string()));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Null);
    }

    #[test]
    fn test_result_to_scalar_value_roundtrip() {
        // Non-optional variants round-trip to exactly the same variant.
        let exact_cases = vec![
            ScalarResult::Int64(42),
            ScalarResult::Float64(3.125),
            ScalarResult::UInt64(100),
        ];
        for sr in &exact_cases {
            let sv = result_to_scalar_value(sr);
            let back = scalar_value_to_result(&sv);
            assert_eq!(&back, sr, "Roundtrip failed for {sr:?}");
        }

        // Optional variants holding a value come back as the non-optional
        // variant, so the round trip is not exact for them.
        let sv = result_to_scalar_value(&ScalarResult::OptionalInt64(Some(7)));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Int64(7));

        let sv = result_to_scalar_value(&ScalarResult::OptionalFloat64(Some(2.72)));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Float64(2.72));

        // Optional variants holding `None` keep their variant.
        let sv = result_to_scalar_value(&ScalarResult::OptionalInt64(None));
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalInt64(None)
        );

        let sv = result_to_scalar_value(&ScalarResult::OptionalFloat64(None));
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalFloat64(None)
        );

        // `Null` round-trips to `Null`.
        let sv = result_to_scalar_value(&ScalarResult::Null);
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Null);
    }

    // ---- Factory creation ----

    #[test]
    fn test_factory_count() {
        let ctx = SessionContext::new();
        let factory = create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]);
        assert!(factory.is_some(), "count should be a recognized aggregate");
        assert_eq!(factory.unwrap().name(), "count");
    }

    #[test]
    fn test_factory_sum() {
        let ctx = SessionContext::new();
        let factory = create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]);
        assert!(factory.is_some());
        assert_eq!(factory.unwrap().name(), "sum");
    }

    #[test]
    fn test_factory_avg() {
        let ctx = SessionContext::new();
        let factory = create_aggregate_factory(&ctx, "avg", vec![0], vec![DataType::Float64]);
        assert!(factory.is_some());
    }

    #[test]
    fn test_factory_stddev() {
        let ctx = SessionContext::new();
        let factory = create_aggregate_factory(&ctx, "stddev", vec![0], vec![DataType::Float64]);
        assert!(
            factory.is_some(),
            "stddev should be available in DataFusion"
        );
    }

    /// An unregistered function name yields no factory.
    #[test]
    fn test_factory_unknown() {
        let ctx = SessionContext::new();
        let factory = create_aggregate_factory(
            &ctx,
            "nonexistent_aggregate_xyz",
            vec![0],
            vec![DataType::Int64],
        );
        assert!(factory.is_none());
    }

    #[test]
    fn test_factory_result_field() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let field = factory.result_field();
        assert_eq!(field.name(), "sum");
        assert_eq!(field.data_type(), &DataType::Float64);
    }

    #[test]
    fn test_factory_clone_box() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
        let cloned = factory.clone_box();
        assert_eq!(cloned.type_tag(), "datafusion_factory");
    }

    // ---- Built-in aggregates through the adapter ----

    /// The count assertions accept either the signed or the unsigned result
    /// variant.
    #[test]
    fn test_adapter_count_basic() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
        let mut acc = factory.create_accumulator();

        // A fresh accumulator reports zero.
        let result = acc.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(0) | ScalarResult::UInt64(0)),
            "Expected 0, got {result:?}"
        );

        acc.add_event(&int_event(1000, vec![10, 20, 30]));
        let result = acc.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(3) | ScalarResult::UInt64(3)),
            "Expected 3, got {result:?}"
        );

        // A second event accumulates on top of the first.
        acc.add_event(&int_event(2000, vec![40, 50]));
        let result = acc.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(5) | ScalarResult::UInt64(5)),
            "Expected 5, got {result:?}"
        );
    }

    #[test]
    fn test_adapter_sum_float64() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(1000, vec![1.5, 2.5, 3.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(7.0));
    }

    #[test]
    fn test_adapter_avg_float64() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "avg", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(1000, vec![10.0, 20.0, 30.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(20.0));
    }

    #[test]
    fn test_adapter_min_float64() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "min", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(1000, vec![30.0, 10.0, 20.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(10.0));
    }

    #[test]
    fn test_adapter_max_float64() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "max", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(1000, vec![30.0, 10.0, 20.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(30.0));
    }

    #[test]
    fn test_adapter_sum_int64() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Int64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&int_event(1000, vec![10, 20, 30]));
        assert_eq!(acc.result_scalar(), ScalarResult::Int64(60));
    }

    #[test]
    fn test_adapter_type_tag() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let acc = factory.create_accumulator();
        assert_eq!(acc.type_tag(), "datafusion_adapter");
    }

    #[test]
    fn test_adapter_result_field() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();
        acc.add_event(&float_event(1000, vec![1.0]));
        assert_eq!(acc.result_field().name(), "sum");
    }

    // ---- Merging accumulators ----

    #[test]
    fn test_adapter_merge_sum() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&float_event(1000, vec![1.0, 2.0]));

        let mut acc2 = factory.create_accumulator();
        acc2.add_event(&float_event(2000, vec![3.0, 4.0]));

        acc1.merge_dyn(acc2.as_ref());
        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(10.0));
    }

    #[test]
    fn test_adapter_merge_count() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&int_event(1000, vec![1, 2, 3]));

        let mut acc2 = factory.create_accumulator();
        acc2.add_event(&int_event(2000, vec![4, 5]));

        acc1.merge_dyn(acc2.as_ref());
        let result = acc1.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(5) | ScalarResult::UInt64(5)),
            "Expected 5 after merge, got {result:?}"
        );
    }

    /// The merged average is over all three inputs: (10 + 20 + 30) / 3 = 20.
    #[test]
    fn test_adapter_merge_avg() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "avg", vec![0], vec![DataType::Float64]).unwrap();

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&float_event(1000, vec![10.0, 20.0]));

        let mut acc2 = factory.create_accumulator();
        acc2.add_event(&float_event(2000, vec![30.0]));

        acc1.merge_dyn(acc2.as_ref());
        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(20.0));
    }

    /// Merging an accumulator that received no events leaves the result
    /// unchanged.
    #[test]
    fn test_adapter_merge_empty() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&float_event(1000, vec![5.0]));

        let acc2 = factory.create_accumulator();
        acc1.merge_dyn(acc2.as_ref());
        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(5.0));
    }

    // ---- Statistical aggregates ----

    /// The inputs have mean 5 and squared deviations summing to 32, so the
    /// sample standard deviation is sqrt(32 / 7) ~= 2.138.
    #[test]
    fn test_adapter_stddev() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "stddev", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(
            1000,
            vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0],
        ));
        let result = acc.result_scalar();
        if let ScalarResult::Float64(v) = result {
            assert!((v - 2.138).abs() < 0.01, "Expected ~2.138, got {v}");
        } else {
            panic!("Expected Float64 result, got {result:?}");
        }
    }

    /// Sample variance of the same inputs: 32 / 7 ~= 4.571.
    ///
    /// NOTE(review): passes vacuously when `var_samp` is not registered or the
    /// result is not `Float64`.
    #[test]
    fn test_adapter_variance() {
        let ctx = SessionContext::new();
        if let Some(factory) =
            create_aggregate_factory(&ctx, "var_samp", vec![0], vec![DataType::Float64])
        {
            let mut acc = factory.create_accumulator();
            acc.add_event(&float_event(
                1000,
                vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0],
            ));
            if let ScalarResult::Float64(v) = acc.result_scalar() {
                assert!((v - 4.571).abs() < 0.01, "Expected ~4.571, got {v}");
            }
        }
    }

    /// NOTE(review): passes vacuously when `median` is not registered.
    #[test]
    fn test_adapter_median() {
        let ctx = SessionContext::new();
        if let Some(factory) =
            create_aggregate_factory(&ctx, "median", vec![0], vec![DataType::Float64])
        {
            let mut acc = factory.create_accumulator();
            acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0, 4.0, 5.0]));
            assert_eq!(acc.result_scalar(), ScalarResult::Float64(3.0));
        }
    }

    // ---- Serialization ----

    #[test]
    fn test_adapter_serialize() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();
        acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0]));
        assert!(!acc.serialize().is_empty());
    }

    /// Even without events the output is non-empty: `serialize` writes the
    /// state-value count prefix whenever the state can be read.
    #[test]
    fn test_adapter_serialize_empty() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let acc = factory.create_accumulator();
        assert!(!acc.serialize().is_empty());
    }

    // ---- UDF lookup ----

    #[test]
    fn test_lookup_common_aggregates() {
        let ctx = SessionContext::new();
        for name in &["count", "sum", "min", "max", "avg"] {
            assert!(
                lookup_aggregate_udf(&ctx, name).is_some(),
                "Expected '{name}' to be a recognized aggregate"
            );
        }
    }

    /// NOTE(review): this test asserts nothing — the lookup result is
    /// discarded, so it only checks that the lookup does not panic.
    #[test]
    fn test_lookup_statistical_aggregates() {
        let ctx = SessionContext::new();
        for name in &["stddev", "stddev_pop", "median"] {
            let _ = lookup_aggregate_udf(&ctx, name);
        }
    }

    /// Names are lowercased before lookup, so any casing resolves.
    #[test]
    fn test_lookup_case_insensitive() {
        let ctx = SessionContext::new();
        assert!(lookup_aggregate_udf(&ctx, "COUNT").is_some());
        assert!(lookup_aggregate_udf(&ctx, "Sum").is_some());
        assert!(lookup_aggregate_udf(&ctx, "AVG").is_some());
    }

    // ---- Multi-column aggregates ----

    /// With x == y == [1..5] the sample covariance equals the sample
    /// variance: 10 / 4 = 2.5.
    ///
    /// NOTE(review): passes vacuously when `covar_samp` is not registered or
    /// the result is not `Float64`.
    #[test]
    fn test_adapter_multi_column_covar() {
        let ctx = SessionContext::new();
        if let Some(factory) = create_aggregate_factory(
            &ctx,
            "covar_samp",
            vec![0, 1],
            vec![DataType::Float64, DataType::Float64],
        ) {
            let mut acc = factory.create_accumulator();
            acc.add_event(&two_col_float_event(
                1000,
                vec![1.0, 2.0, 3.0, 4.0, 5.0],
                vec![1.0, 2.0, 3.0, 4.0, 5.0],
            ));
            if let ScalarResult::Float64(v) = acc.result_scalar() {
                assert!((v - 2.5).abs() < 0.01, "Expected covar ~2.5, got {v}");
            }
        }
    }

    // ---- Public API surface ----

    #[test]
    fn test_create_aggregate_factory_api() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
        let acc = factory.create_accumulator();
        assert_eq!(acc.type_tag(), "datafusion_adapter");
    }

    /// Accumulators from the same factory do not share state.
    #[test]
    fn test_factory_creates_independent_accumulators() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();

        let mut acc1 = factory.create_accumulator();
        let mut acc2 = factory.create_accumulator();

        acc1.add_event(&float_event(1000, vec![10.0]));
        acc2.add_event(&float_event(2000, vec![20.0]));

        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(10.0));
        assert_eq!(acc2.result_scalar(), ScalarResult::Float64(20.0));
    }

    #[test]
    fn test_adapter_function_name() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let acc = factory.create_accumulator();
        let adapter = acc
            .as_any()
            .downcast_ref::<DataFusionAccumulatorAdapter>()
            .expect("should be adapter");
        assert_eq!(adapter.function_name(), "sum");
    }
}