1use std::cell::RefCell;
25use std::sync::Arc;
26
27use arrow_array::ArrayRef;
28use arrow_schema::{DataType, Field, FieldRef, Schema};
29use datafusion::execution::FunctionRegistry;
30use datafusion_common::ScalarValue;
31use datafusion_expr::function::AccumulatorArgs;
32use datafusion_expr::AggregateUDF;
33
34use laminar_core::operator::window::{DynAccumulator, DynAggregatorFactory, ScalarResult};
35use laminar_core::operator::Event;
36
37#[must_use]
43pub fn scalar_value_to_result(sv: &ScalarValue) -> ScalarResult {
44 match sv {
45 ScalarValue::Int64(Some(v)) => ScalarResult::Int64(*v),
46 ScalarValue::Int64(None) => ScalarResult::OptionalInt64(None),
47 ScalarValue::Float64(Some(v)) => ScalarResult::Float64(*v),
48 ScalarValue::Float64(None) | ScalarValue::Float32(None) => {
49 ScalarResult::OptionalFloat64(None)
50 }
51 ScalarValue::UInt64(Some(v)) => ScalarResult::UInt64(*v),
52 ScalarValue::Int8(Some(v)) => ScalarResult::Int64(i64::from(*v)),
54 ScalarValue::Int16(Some(v)) => ScalarResult::Int64(i64::from(*v)),
55 ScalarValue::Int32(Some(v)) => ScalarResult::Int64(i64::from(*v)),
56 ScalarValue::UInt8(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
57 ScalarValue::UInt16(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
58 ScalarValue::UInt32(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
59 ScalarValue::Float32(Some(v)) => ScalarResult::Float64(f64::from(*v)),
61 _ => ScalarResult::Null,
62 }
63}
64
65#[must_use]
67pub fn result_to_scalar_value(sr: &ScalarResult) -> ScalarValue {
68 match sr {
69 ScalarResult::Int64(v) => ScalarValue::Int64(Some(*v)),
70 ScalarResult::Float64(v) => ScalarValue::Float64(Some(*v)),
71 ScalarResult::UInt64(v) => ScalarValue::UInt64(Some(*v)),
72 ScalarResult::OptionalInt64(v) => ScalarValue::Int64(*v),
73 ScalarResult::OptionalFloat64(v) => ScalarValue::Float64(*v),
74 ScalarResult::Null => ScalarValue::Null,
75 }
76}
77
/// Wraps a DataFusion accumulator so it can be driven through the
/// [`DynAccumulator`] interface of the window operator.
pub struct DataFusionAccumulatorAdapter {
    /// The wrapped DataFusion accumulator. It sits in a `RefCell` because some
    /// `&self` methods of this adapter call accumulator methods through
    /// `borrow_mut()`.
    inner: RefCell<Box<dyn datafusion_expr::Accumulator>>,
    /// Index, within each event's record batch, of the column supplying each
    /// aggregate argument (one entry per argument, in argument order).
    column_indices: Vec<usize>,
    /// Declared data type of each argument; used to build an all-null column
    /// when the corresponding index is out of range for a batch.
    input_types: Vec<DataType>,
    /// Aggregate function name, used in log messages and as the result field name.
    function_name: String,
    /// Factory that produced this adapter; used to build a fresh inner
    /// accumulator when the adapter is cloned.
    factory: Arc<DataFusionAggregateFactory>,
}
98
// SAFETY: NOTE(review): the justification for this manual impl is not stated here.
// It presumably relies on the boxed accumulator and every other field being `Send`
// (`RefCell<T>` is `Send` whenever `T: Send`) — TODO confirm that
// `datafusion_expr::Accumulator` requires `Send`. If it does, the compiler may
// already derive `Send` automatically and this `unsafe impl` could be dropped.
unsafe impl Send for DataFusionAccumulatorAdapter {}
102
103impl std::fmt::Debug for DataFusionAccumulatorAdapter {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 f.debug_struct("DataFusionAccumulatorAdapter")
106 .field("function_name", &self.function_name)
107 .field("column_indices", &self.column_indices)
108 .field("input_types", &self.input_types)
109 .finish_non_exhaustive()
110 }
111}
112
113impl DataFusionAccumulatorAdapter {
114 #[must_use]
116 pub fn new(
117 inner: Box<dyn datafusion_expr::Accumulator>,
118 column_indices: Vec<usize>,
119 input_types: Vec<DataType>,
120 function_name: String,
121 factory: Arc<DataFusionAggregateFactory>,
122 ) -> Self {
123 Self {
124 inner: RefCell::new(inner),
125 column_indices,
126 input_types,
127 function_name,
128 factory,
129 }
130 }
131
132 #[must_use]
134 pub fn function_name(&self) -> &str {
135 &self.function_name
136 }
137
138 fn extract_columns(&self, batch: &arrow_array::RecordBatch) -> Vec<ArrayRef> {
140 self.column_indices
141 .iter()
142 .enumerate()
143 .map(|(arg_idx, &col_idx)| {
144 if col_idx < batch.num_columns() {
145 Arc::clone(batch.column(col_idx))
146 } else {
147 let dt = self
148 .input_types
149 .get(arg_idx)
150 .cloned()
151 .unwrap_or(DataType::Int64);
152 arrow_array::new_null_array(&dt, batch.num_rows())
153 }
154 })
155 .collect()
156 }
157}
158
159impl DynAccumulator for DataFusionAccumulatorAdapter {
160 fn add_event(&mut self, event: &Event) {
161 let columns = self.extract_columns(&event.data);
162 if let Err(e) = self.inner.borrow_mut().update_batch(&columns) {
163 tracing::warn!(
164 func = %self.function_name,
165 error = %e,
166 "Accumulator update_batch failed"
167 );
168 }
169 }
170
171 fn merge_dyn(&mut self, other: &dyn DynAccumulator) {
172 let other = other
173 .as_any()
174 .downcast_ref::<DataFusionAccumulatorAdapter>()
175 .expect("merge_dyn: type mismatch, expected DataFusionAccumulatorAdapter");
176
177 match other.inner.borrow_mut().state() {
178 Ok(state_values) => {
179 let mut failed_conversions = 0u32;
180 let state_arrays: Vec<ArrayRef> = state_values
181 .iter()
182 .filter_map(|sv| {
183 if let Ok(arr) = sv.to_array() {
184 Some(arr)
185 } else {
186 failed_conversions += 1;
187 None
188 }
189 })
190 .collect();
191 if failed_conversions > 0 {
192 tracing::warn!(
193 func = %self.function_name,
194 count = failed_conversions,
195 "ScalarValue to_array conversions failed during merge"
196 );
197 }
198 if !state_arrays.is_empty() {
199 if let Err(e) = self.inner.borrow_mut().merge_batch(&state_arrays) {
200 tracing::warn!(
201 func = %self.function_name,
202 error = %e,
203 "Accumulator merge_batch failed"
204 );
205 }
206 }
207 }
208 Err(e) => {
209 tracing::warn!(
210 func = %self.function_name,
211 error = %e,
212 "Failed to extract state for merge"
213 );
214 }
215 }
216 }
217
218 fn result_scalar(&self) -> ScalarResult {
219 match self.inner.borrow_mut().evaluate() {
220 Ok(sv) => scalar_value_to_result(&sv),
221 Err(_) => ScalarResult::Null,
222 }
223 }
224
225 fn is_empty(&self) -> bool {
226 self.inner.borrow().size() <= std::mem::size_of::<Self>()
227 }
228
229 fn clone_box(&self) -> Box<dyn DynAccumulator> {
230 let new_inner = self.factory.create_df_accumulator();
231 if let Ok(state_values) = self.inner.borrow_mut().state() {
233 let state_arrays: Vec<ArrayRef> = state_values
234 .iter()
235 .filter_map(|sv| sv.to_array().ok())
236 .collect();
237 if !state_arrays.is_empty() {
238 let mut new_acc = new_inner;
239 if new_acc.merge_batch(&state_arrays).is_ok() {
240 return Box::new(DataFusionAccumulatorAdapter {
241 inner: RefCell::new(new_acc),
242 column_indices: self.column_indices.clone(),
243 input_types: self.input_types.clone(),
244 function_name: self.function_name.clone(),
245 factory: Arc::clone(&self.factory),
246 });
247 }
248 }
249 }
250 Box::new(DataFusionAccumulatorAdapter {
252 inner: RefCell::new(self.factory.create_df_accumulator()),
253 column_indices: self.column_indices.clone(),
254 input_types: self.input_types.clone(),
255 function_name: self.function_name.clone(),
256 factory: Arc::clone(&self.factory),
257 })
258 }
259
260 #[allow(clippy::cast_possible_truncation)] fn serialize(&self) -> Vec<u8> {
262 match self.inner.borrow_mut().state() {
263 Ok(state_values) => {
264 let mut buf = Vec::new();
265 let count = state_values.len() as u32;
266 buf.extend_from_slice(&count.to_le_bytes());
267 for sv in &state_values {
268 let bytes = sv.to_string();
269 let len = bytes.len() as u32;
270 buf.extend_from_slice(&len.to_le_bytes());
271 buf.extend_from_slice(bytes.as_bytes());
272 }
273 buf
274 }
275 Err(_) => Vec::new(),
276 }
277 }
278
279 fn result_field(&self) -> Field {
280 let result = self.result_scalar();
281 let dt = result.data_type();
282 let dt = if dt == DataType::Null {
283 DataType::Float64
284 } else {
285 dt
286 };
287 Field::new(&self.function_name, dt, true)
288 }
289
290 fn type_tag(&self) -> &'static str {
291 "datafusion_adapter"
292 }
293
294 fn as_any(&self) -> &dyn std::any::Any {
295 self
296 }
297}
298
/// Factory that creates [`DataFusionAccumulatorAdapter`]s for one DataFusion
/// aggregate function applied to a fixed set of input columns.
pub struct DataFusionAggregateFactory {
    /// The DataFusion aggregate function that supplies the accumulators.
    udf: Arc<AggregateUDF>,
    /// Batch column index for each aggregate argument, in argument order.
    column_indices: Vec<usize>,
    /// Declared data type of each aggregate argument.
    input_types: Vec<DataType>,
    /// Forwarded as `is_distinct` when an accumulator is requested from the UDF.
    is_distinct: bool,
}
315
316impl std::fmt::Debug for DataFusionAggregateFactory {
317 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318 f.debug_struct("DataFusionAggregateFactory")
319 .field("name", &self.udf.name())
320 .field("column_indices", &self.column_indices)
321 .field("input_types", &self.input_types)
322 .field("is_distinct", &self.is_distinct)
323 .finish()
324 }
325}
326
327impl DataFusionAggregateFactory {
328 #[must_use]
330 pub fn new(
331 udf: Arc<AggregateUDF>,
332 column_indices: Vec<usize>,
333 input_types: Vec<DataType>,
334 ) -> Self {
335 Self {
336 udf,
337 column_indices,
338 input_types,
339 is_distinct: false,
340 }
341 }
342
343 #[must_use]
345 pub fn with_distinct(mut self, distinct: bool) -> Self {
346 self.is_distinct = distinct;
347 self
348 }
349
350 #[must_use]
352 pub fn name(&self) -> &str {
353 self.udf.name()
354 }
355
356 const COL_NAMES: [&str; 8] = [
358 "col_0", "col_1", "col_2", "col_3", "col_4", "col_5", "col_6", "col_7",
359 ];
360
361 fn col_name(i: usize) -> &'static str {
363 Self::COL_NAMES.get(i).copied().unwrap_or("col_n")
364 }
365
366 fn create_df_accumulator(&self) -> Box<dyn datafusion_expr::Accumulator> {
368 let return_type = self
369 .udf
370 .return_type(&self.input_types)
371 .unwrap_or(DataType::Float64);
372 let return_field: FieldRef = Arc::new(Field::new(self.udf.name(), return_type, true));
373 let schema = Schema::new(
374 self.input_types
375 .iter()
376 .enumerate()
377 .map(|(i, dt)| Field::new(Self::col_name(i), dt.clone(), true))
378 .collect::<Vec<_>>(),
379 );
380 let expr_fields: Vec<FieldRef> = self
381 .input_types
382 .iter()
383 .enumerate()
384 .map(|(i, dt)| Arc::new(Field::new(Self::col_name(i), dt.clone(), true)) as FieldRef)
385 .collect();
386 let args = AccumulatorArgs {
387 return_field,
388 schema: &schema,
389 ignore_nulls: false,
390 order_bys: &[],
391 is_reversed: false,
392 name: self.udf.name(),
393 is_distinct: self.is_distinct,
394 exprs: &[],
395 expr_fields: &expr_fields,
396 };
397 self.udf
398 .accumulator(args)
399 .expect("Failed to create DataFusion accumulator")
400 }
401}
402
403impl DataFusionAggregateFactory {
404 #[must_use]
409 pub fn create_accumulator_with_factory(self: &Arc<Self>) -> Box<dyn DynAccumulator> {
410 let inner = self.create_df_accumulator();
411 Box::new(DataFusionAccumulatorAdapter::new(
412 inner,
413 self.column_indices.clone(),
414 self.input_types.clone(),
415 self.udf.name().to_string(),
416 Arc::clone(self),
417 ))
418 }
419}
420
421impl DynAggregatorFactory for DataFusionAggregateFactory {
422 fn create_accumulator(&self) -> Box<dyn DynAccumulator> {
423 let factory_arc = Arc::new(DataFusionAggregateFactory {
426 udf: Arc::clone(&self.udf),
427 column_indices: self.column_indices.clone(),
428 input_types: self.input_types.clone(),
429 is_distinct: self.is_distinct,
430 });
431 let inner = self.create_df_accumulator();
432 Box::new(DataFusionAccumulatorAdapter::new(
433 inner,
434 self.column_indices.clone(),
435 self.input_types.clone(),
436 self.udf.name().to_string(),
437 factory_arc,
438 ))
439 }
440
441 fn result_field(&self) -> Field {
442 let return_type = self
443 .udf
444 .return_type(&self.input_types)
445 .unwrap_or(DataType::Float64);
446 Field::new(self.udf.name(), return_type, true)
447 }
448
449 fn clone_box(&self) -> Box<dyn DynAggregatorFactory> {
450 Box::new(DataFusionAggregateFactory {
451 udf: Arc::clone(&self.udf),
452 column_indices: self.column_indices.clone(),
453 input_types: self.input_types.clone(),
454 is_distinct: self.is_distinct,
455 })
456 }
457
458 fn type_tag(&self) -> &'static str {
459 "datafusion_factory"
460 }
461}
462
463#[must_use]
469pub fn lookup_aggregate_udf(
470 ctx: &datafusion::prelude::SessionContext,
471 name: &str,
472) -> Option<Arc<AggregateUDF>> {
473 let normalized = name.to_lowercase();
474 ctx.udaf(&normalized).ok()
475}
476
477#[must_use]
481pub fn create_aggregate_factory(
482 ctx: &datafusion::prelude::SessionContext,
483 name: &str,
484 column_indices: Vec<usize>,
485 input_types: Vec<DataType>,
486) -> Option<DataFusionAggregateFactory> {
487 lookup_aggregate_udf(ctx, name)
488 .map(|udf| DataFusionAggregateFactory::new(udf, column_indices, input_types))
489}
490
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::create_session_context;
    use arrow_array::{Float64Array, Int64Array, RecordBatch};

    /// Sample shared by the dispersion tests: mean 5, sample variance 32/7.
    const DISPERSION_SAMPLE: [f64; 8] = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];

    // ---- fixtures -------------------------------------------------------

    /// Event with a single non-nullable `Float64` column named `value`.
    fn float_event(ts: i64, values: Vec<f64>) -> Event {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Float64,
            false,
        )]));
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(Float64Array::from(values))]).unwrap();
        Event::new(ts, batch)
    }

    /// Event with a single non-nullable `Int64` column named `value`.
    fn int_event(ts: i64, values: Vec<i64>) -> Event {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Int64,
            false,
        )]));
        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(values))]).unwrap();
        Event::new(ts, batch)
    }

    /// Event with two non-nullable `Float64` columns named `x` and `y`.
    fn two_col_float_event(ts: i64, col0: Vec<f64>, col1: Vec<f64>) -> Event {
        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Float64, false),
            Field::new("y", DataType::Float64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Float64Array::from(col0)),
                Arc::new(Float64Array::from(col1)),
            ],
        )
        .unwrap();
        Event::new(ts, batch)
    }

    /// Builds a factory for `name` whose arguments are columns `0..n` of the given
    /// types, or `None` when the aggregate is not registered.
    fn try_factory(name: &str, input_types: Vec<DataType>) -> Option<DataFusionAggregateFactory> {
        let ctx = create_session_context();
        let column_indices: Vec<usize> = (0..input_types.len()).collect();
        create_aggregate_factory(&ctx, name, column_indices, input_types)
    }

    /// Builds a single-argument factory for an aggregate that must be registered.
    fn factory(name: &str, input_type: DataType) -> DataFusionAggregateFactory {
        try_factory(name, vec![input_type])
            .unwrap_or_else(|| panic!("aggregate '{name}' should be registered"))
    }

    /// Asserts a count-like result, accepting either integer signedness.
    fn assert_count(result: &ScalarResult, expected: u32) {
        let matches_expected = match *result {
            ScalarResult::Int64(v) => v == i64::from(expected),
            ScalarResult::UInt64(v) => v == u64::from(expected),
            _ => false,
        };
        assert!(matches_expected, "Expected {expected}, got {result:?}");
    }

    /// Asserts a `Float64` result within 0.01 of `expected`; any other variant
    /// fails the test instead of being silently accepted.
    fn assert_float_close(result: ScalarResult, expected: f64) {
        match result {
            ScalarResult::Float64(v) => {
                assert!((v - expected).abs() < 0.01, "Expected ~{expected}, got {v}");
            }
            other => panic!("Expected Float64 result, got {other:?}"),
        }
    }

    // ---- scalar conversions ---------------------------------------------

    #[test]
    fn test_scalar_value_to_result_int64() {
        let sv = ScalarValue::Int64(Some(42));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Int64(42));
    }

    #[test]
    fn test_scalar_value_to_result_float64() {
        let sv = ScalarValue::Float64(Some(3.125));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Float64(3.125));
    }

    #[test]
    fn test_scalar_value_to_result_uint64() {
        let sv = ScalarValue::UInt64(Some(100));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::UInt64(100));
    }

    #[test]
    fn test_scalar_value_to_result_null_int64() {
        let sv = ScalarValue::Int64(None);
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalInt64(None)
        );
    }

    #[test]
    fn test_scalar_value_to_result_null_float64() {
        let sv = ScalarValue::Float64(None);
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalFloat64(None)
        );
    }

    #[test]
    fn test_scalar_value_to_result_smaller_ints() {
        let cases = [
            (ScalarValue::Int8(Some(8)), ScalarResult::Int64(8)),
            (ScalarValue::Int16(Some(16)), ScalarResult::Int64(16)),
            (ScalarValue::Int32(Some(32)), ScalarResult::Int64(32)),
            (ScalarValue::UInt8(Some(8)), ScalarResult::UInt64(8)),
            (ScalarValue::UInt16(Some(16)), ScalarResult::UInt64(16)),
            (ScalarValue::UInt32(Some(32)), ScalarResult::UInt64(32)),
        ];
        for (input, expected) in &cases {
            assert_eq!(&scalar_value_to_result(input), expected, "input {input:?}");
        }
    }

    #[test]
    fn test_scalar_value_to_result_float32() {
        let sv = ScalarValue::Float32(Some(2.5));
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::Float64(f64::from(2.5f32))
        );
    }

    #[test]
    fn test_scalar_value_to_result_unsupported() {
        let sv = ScalarValue::Utf8(Some("hello".to_string()));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Null);
    }

    #[test]
    fn test_result_to_scalar_value_roundtrip() {
        // Plain variants survive the round trip unchanged.
        let exact_cases = vec![
            ScalarResult::Int64(42),
            ScalarResult::Float64(3.125),
            ScalarResult::UInt64(100),
        ];
        for sr in &exact_cases {
            let sv = result_to_scalar_value(sr);
            let back = scalar_value_to_result(&sv);
            assert_eq!(&back, sr, "Roundtrip failed for {sr:?}");
        }

        // Optional variants holding a value come back as the plain variant.
        let sv = result_to_scalar_value(&ScalarResult::OptionalInt64(Some(7)));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Int64(7));

        let sv = result_to_scalar_value(&ScalarResult::OptionalFloat64(Some(2.72)));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Float64(2.72));

        // Absent optionals stay absent.
        let sv = result_to_scalar_value(&ScalarResult::OptionalInt64(None));
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalInt64(None)
        );

        let sv = result_to_scalar_value(&ScalarResult::OptionalFloat64(None));
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalFloat64(None)
        );

        let sv = result_to_scalar_value(&ScalarResult::Null);
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Null);
    }

    // ---- factory --------------------------------------------------------

    #[test]
    fn test_factory_count() {
        let factory = try_factory("count", vec![DataType::Int64]);
        assert!(factory.is_some(), "count should be a recognized aggregate");
        assert_eq!(factory.unwrap().name(), "count");
    }

    #[test]
    fn test_factory_sum() {
        let factory = try_factory("sum", vec![DataType::Float64]);
        assert!(factory.is_some());
        assert_eq!(factory.unwrap().name(), "sum");
    }

    #[test]
    fn test_factory_avg() {
        assert!(try_factory("avg", vec![DataType::Float64]).is_some());
    }

    #[test]
    fn test_factory_stddev() {
        assert!(
            try_factory("stddev", vec![DataType::Float64]).is_some(),
            "stddev should be available in DataFusion"
        );
    }

    #[test]
    fn test_factory_unknown() {
        assert!(try_factory("nonexistent_aggregate_xyz", vec![DataType::Int64]).is_none());
    }

    #[test]
    fn test_factory_result_field() {
        let field = factory("sum", DataType::Float64).result_field();
        assert_eq!(field.name(), "sum");
        assert_eq!(field.data_type(), &DataType::Float64);
    }

    #[test]
    fn test_factory_clone_box() {
        let cloned = factory("count", DataType::Int64).clone_box();
        assert_eq!(cloned.type_tag(), "datafusion_factory");
    }

    // ---- adapter: basic aggregation -------------------------------------

    #[test]
    fn test_adapter_count_basic() {
        let mut acc = factory("count", DataType::Int64).create_accumulator();
        assert_count(&acc.result_scalar(), 0);

        acc.add_event(&int_event(1000, vec![10, 20, 30]));
        assert_count(&acc.result_scalar(), 3);

        acc.add_event(&int_event(2000, vec![40, 50]));
        assert_count(&acc.result_scalar(), 5);
    }

    #[test]
    fn test_adapter_sum_float64() {
        let mut acc = factory("sum", DataType::Float64).create_accumulator();
        acc.add_event(&float_event(1000, vec![1.5, 2.5, 3.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(7.0));
    }

    #[test]
    fn test_adapter_avg_float64() {
        let mut acc = factory("avg", DataType::Float64).create_accumulator();
        acc.add_event(&float_event(1000, vec![10.0, 20.0, 30.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(20.0));
    }

    #[test]
    fn test_adapter_min_float64() {
        let mut acc = factory("min", DataType::Float64).create_accumulator();
        acc.add_event(&float_event(1000, vec![30.0, 10.0, 20.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(10.0));
    }

    #[test]
    fn test_adapter_max_float64() {
        let mut acc = factory("max", DataType::Float64).create_accumulator();
        acc.add_event(&float_event(1000, vec![30.0, 10.0, 20.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(30.0));
    }

    #[test]
    fn test_adapter_sum_int64() {
        let mut acc = factory("sum", DataType::Int64).create_accumulator();
        acc.add_event(&int_event(1000, vec![10, 20, 30]));
        assert_eq!(acc.result_scalar(), ScalarResult::Int64(60));
    }

    #[test]
    fn test_adapter_type_tag() {
        let acc = factory("sum", DataType::Float64).create_accumulator();
        assert_eq!(acc.type_tag(), "datafusion_adapter");
    }

    #[test]
    fn test_adapter_result_field() {
        let mut acc = factory("sum", DataType::Float64).create_accumulator();
        acc.add_event(&float_event(1000, vec![1.0]));
        assert_eq!(acc.result_field().name(), "sum");
    }

    // ---- adapter: merging -----------------------------------------------

    #[test]
    fn test_adapter_merge_sum() {
        let factory = factory("sum", DataType::Float64);

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&float_event(1000, vec![1.0, 2.0]));

        let mut acc2 = factory.create_accumulator();
        acc2.add_event(&float_event(2000, vec![3.0, 4.0]));

        acc1.merge_dyn(acc2.as_ref());
        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(10.0));
    }

    #[test]
    fn test_adapter_merge_count() {
        let factory = factory("count", DataType::Int64);

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&int_event(1000, vec![1, 2, 3]));

        let mut acc2 = factory.create_accumulator();
        acc2.add_event(&int_event(2000, vec![4, 5]));

        acc1.merge_dyn(acc2.as_ref());
        assert_count(&acc1.result_scalar(), 5);
    }

    #[test]
    fn test_adapter_merge_avg() {
        let factory = factory("avg", DataType::Float64);

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&float_event(1000, vec![10.0, 20.0]));

        let mut acc2 = factory.create_accumulator();
        acc2.add_event(&float_event(2000, vec![30.0]));

        acc1.merge_dyn(acc2.as_ref());
        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(20.0));
    }

    #[test]
    fn test_adapter_merge_empty() {
        let factory = factory("sum", DataType::Float64);

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&float_event(1000, vec![5.0]));

        // Merging an accumulator that saw no events must not change the result.
        let acc2 = factory.create_accumulator();
        acc1.merge_dyn(acc2.as_ref());
        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(5.0));
    }

    // ---- adapter: statistical aggregates --------------------------------

    #[test]
    fn test_adapter_stddev() {
        let mut acc = factory("stddev", DataType::Float64).create_accumulator();
        acc.add_event(&float_event(1000, DISPERSION_SAMPLE.to_vec()));
        assert_float_close(acc.result_scalar(), 2.138);
    }

    #[test]
    fn test_adapter_variance() {
        // Skipped when `var_samp` is not registered; once it is, a non-Float64
        // result is a failure rather than a silent pass.
        if let Some(factory) = try_factory("var_samp", vec![DataType::Float64]) {
            let mut acc = factory.create_accumulator();
            acc.add_event(&float_event(1000, DISPERSION_SAMPLE.to_vec()));
            assert_float_close(acc.result_scalar(), 4.571);
        }
    }

    #[test]
    fn test_adapter_median() {
        if let Some(factory) = try_factory("median", vec![DataType::Float64]) {
            let mut acc = factory.create_accumulator();
            acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0, 4.0, 5.0]));
            assert_eq!(acc.result_scalar(), ScalarResult::Float64(3.0));
        }
    }

    // ---- adapter: serialization -----------------------------------------

    #[test]
    fn test_adapter_serialize() {
        let mut acc = factory("sum", DataType::Float64).create_accumulator();
        acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0]));
        assert!(!acc.serialize().is_empty());
    }

    #[test]
    fn test_adapter_serialize_empty() {
        // Even with no events the buffer carries the state-count header.
        let acc = factory("sum", DataType::Float64).create_accumulator();
        assert!(!acc.serialize().is_empty());
    }

    // ---- registry lookup ------------------------------------------------

    #[test]
    fn test_lookup_common_aggregates() {
        let ctx = create_session_context();
        for name in &["count", "sum", "min", "max", "avg"] {
            assert!(
                lookup_aggregate_udf(&ctx, name).is_some(),
                "Expected '{name}' to be a recognized aggregate"
            );
        }
    }

    #[test]
    fn test_lookup_statistical_aggregates() {
        // Smoke test only: the lookup must not panic. Availability of these
        // aggregates is deliberately not asserted here.
        let ctx = create_session_context();
        for name in &["stddev", "stddev_pop", "median"] {
            let _ = lookup_aggregate_udf(&ctx, name);
        }
    }

    #[test]
    fn test_lookup_case_insensitive() {
        let ctx = create_session_context();
        assert!(lookup_aggregate_udf(&ctx, "COUNT").is_some());
        assert!(lookup_aggregate_udf(&ctx, "Sum").is_some());
        assert!(lookup_aggregate_udf(&ctx, "AVG").is_some());
    }

    // ---- multi-column aggregates ----------------------------------------

    #[test]
    fn test_adapter_multi_column_covar() {
        // Skipped when `covar_samp` is not registered; once it is, a non-Float64
        // result is a failure rather than a silent pass.
        if let Some(factory) =
            try_factory("covar_samp", vec![DataType::Float64, DataType::Float64])
        {
            let mut acc = factory.create_accumulator();
            acc.add_event(&two_col_float_event(
                1000,
                vec![1.0, 2.0, 3.0, 4.0, 5.0],
                vec![1.0, 2.0, 3.0, 4.0, 5.0],
            ));
            assert_float_close(acc.result_scalar(), 2.5);
        }
    }

    // ---- public API and cloning -----------------------------------------

    #[test]
    fn test_create_aggregate_factory_api() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
        let acc = factory.create_accumulator();
        assert_eq!(acc.type_tag(), "datafusion_adapter");
    }

    #[test]
    fn test_factory_creates_independent_accumulators() {
        let factory = factory("sum", DataType::Float64);

        let mut acc1 = factory.create_accumulator();
        let mut acc2 = factory.create_accumulator();

        acc1.add_event(&float_event(1000, vec![10.0]));
        acc2.add_event(&float_event(2000, vec![20.0]));

        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(10.0));
        assert_eq!(acc2.result_scalar(), ScalarResult::Float64(20.0));
    }

    #[test]
    fn test_adapter_function_name() {
        let acc = factory("sum", DataType::Float64).create_accumulator();
        let adapter = acc
            .as_any()
            .downcast_ref::<DataFusionAccumulatorAdapter>()
            .expect("should be adapter");
        assert_eq!(adapter.function_name(), "sum");
    }

    #[test]
    fn test_clone_box_does_not_panic() {
        let mut acc = factory("sum", DataType::Float64).create_accumulator();
        acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0]));

        // The clone must carry the accumulated state, not just exist.
        let cloned = acc.clone_box();
        assert_eq!(cloned.result_scalar(), ScalarResult::Float64(6.0));
    }

    #[test]
    fn test_clone_box_empty_accumulator() {
        let acc = factory("count", DataType::Int64).create_accumulator();
        let cloned = acc.clone_box();
        assert_count(&cloned.result_scalar(), 0);
    }

    #[test]
    fn test_distinct_factory() {
        let ctx = create_session_context();
        let udf = lookup_aggregate_udf(&ctx, "count").unwrap();
        let factory = DataFusionAggregateFactory::new(udf, vec![0], vec![DataType::Int64])
            .with_distinct(true);
        assert!(factory.is_distinct);
        // Creating the accumulator must not panic with `is_distinct` set.
        let _acc = factory.create_accumulator();
    }
}
1090}