1use std::cell::RefCell;
25use std::sync::Arc;
26
27use arrow_array::ArrayRef;
28use arrow_schema::{DataType, Field, FieldRef, Schema};
29use datafusion::execution::FunctionRegistry;
30use datafusion_common::{DataFusionError, ScalarValue};
31use datafusion_expr::function::AccumulatorArgs;
32use datafusion_expr::AggregateUDF;
33
34use laminar_core::operator::window::{DynAccumulator, DynAggregatorFactory, ScalarResult};
35use laminar_core::operator::Event;
36
37#[must_use]
43pub fn scalar_value_to_result(sv: &ScalarValue) -> ScalarResult {
44 match sv {
45 ScalarValue::Int64(Some(v)) => ScalarResult::Int64(*v),
46 ScalarValue::Int64(None) => ScalarResult::OptionalInt64(None),
47 ScalarValue::Float64(Some(v)) => ScalarResult::Float64(*v),
48 ScalarValue::Float64(None) | ScalarValue::Float32(None) => {
49 ScalarResult::OptionalFloat64(None)
50 }
51 ScalarValue::UInt64(Some(v)) => ScalarResult::UInt64(*v),
52 ScalarValue::Int8(Some(v)) => ScalarResult::Int64(i64::from(*v)),
54 ScalarValue::Int16(Some(v)) => ScalarResult::Int64(i64::from(*v)),
55 ScalarValue::Int32(Some(v)) => ScalarResult::Int64(i64::from(*v)),
56 ScalarValue::UInt8(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
57 ScalarValue::UInt16(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
58 ScalarValue::UInt32(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
59 ScalarValue::Float32(Some(v)) => ScalarResult::Float64(f64::from(*v)),
61 _ => ScalarResult::Null,
62 }
63}
64
65#[must_use]
67pub fn result_to_scalar_value(sr: &ScalarResult) -> ScalarValue {
68 match sr {
69 ScalarResult::Int64(v) => ScalarValue::Int64(Some(*v)),
70 ScalarResult::Float64(v) => ScalarValue::Float64(Some(*v)),
71 ScalarResult::UInt64(v) => ScalarValue::UInt64(Some(*v)),
72 ScalarResult::OptionalInt64(v) => ScalarValue::Int64(*v),
73 ScalarResult::OptionalFloat64(v) => ScalarValue::Float64(*v),
74 ScalarResult::Null => ScalarValue::Null,
75 }
76}
77
78pub struct DataFusionAccumulatorAdapter {
87 inner: RefCell<Box<dyn datafusion_expr::Accumulator>>,
89 column_indices: Vec<usize>,
91 input_types: Vec<DataType>,
93 function_name: String,
95 factory: Arc<DataFusionAggregateFactory>,
97}
98
99unsafe impl Send for DataFusionAccumulatorAdapter {}
102
103impl std::fmt::Debug for DataFusionAccumulatorAdapter {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 f.debug_struct("DataFusionAccumulatorAdapter")
106 .field("function_name", &self.function_name)
107 .field("column_indices", &self.column_indices)
108 .field("input_types", &self.input_types)
109 .finish_non_exhaustive()
110 }
111}
112
113impl DataFusionAccumulatorAdapter {
114 #[must_use]
116 pub fn new(
117 inner: Box<dyn datafusion_expr::Accumulator>,
118 column_indices: Vec<usize>,
119 input_types: Vec<DataType>,
120 function_name: String,
121 factory: Arc<DataFusionAggregateFactory>,
122 ) -> Self {
123 Self {
124 inner: RefCell::new(inner),
125 column_indices,
126 input_types,
127 function_name,
128 factory,
129 }
130 }
131
132 #[must_use]
134 pub fn function_name(&self) -> &str {
135 &self.function_name
136 }
137
138 fn extract_columns(&self, batch: &arrow_array::RecordBatch) -> Vec<ArrayRef> {
140 self.column_indices
141 .iter()
142 .enumerate()
143 .map(|(arg_idx, &col_idx)| {
144 if col_idx < batch.num_columns() {
145 Arc::clone(batch.column(col_idx))
146 } else {
147 let dt = self
148 .input_types
149 .get(arg_idx)
150 .cloned()
151 .unwrap_or(DataType::Int64);
152 arrow_array::new_null_array(&dt, batch.num_rows())
153 }
154 })
155 .collect()
156 }
157}
158
159impl DynAccumulator for DataFusionAccumulatorAdapter {
160 fn add_event(&mut self, event: &Event) {
161 let columns = self.extract_columns(&event.data);
162 if let Err(e) = self.inner.borrow_mut().update_batch(&columns) {
163 tracing::warn!(
164 func = %self.function_name,
165 error = %e,
166 "Accumulator update_batch failed"
167 );
168 }
169 }
170
171 fn merge_dyn(&mut self, other: &dyn DynAccumulator) {
172 let Some(other) = other
173 .as_any()
174 .downcast_ref::<DataFusionAccumulatorAdapter>()
175 else {
176 tracing::error!("merge_dyn: type mismatch, expected DataFusionAccumulatorAdapter");
177 return;
178 };
179
180 match other.inner.borrow_mut().state() {
181 Ok(state_values) => {
182 let mut failed_conversions = 0u32;
183 let state_arrays: Vec<ArrayRef> = state_values
184 .iter()
185 .filter_map(|sv| {
186 if let Ok(arr) = sv.to_array() {
187 Some(arr)
188 } else {
189 failed_conversions += 1;
190 None
191 }
192 })
193 .collect();
194 if failed_conversions > 0 {
195 tracing::warn!(
196 func = %self.function_name,
197 count = failed_conversions,
198 "ScalarValue to_array conversions failed during merge"
199 );
200 }
201 if !state_arrays.is_empty() {
202 if let Err(e) = self.inner.borrow_mut().merge_batch(&state_arrays) {
203 tracing::warn!(
204 func = %self.function_name,
205 error = %e,
206 "Accumulator merge_batch failed"
207 );
208 }
209 }
210 }
211 Err(e) => {
212 tracing::warn!(
213 func = %self.function_name,
214 error = %e,
215 "Failed to extract state for merge"
216 );
217 }
218 }
219 }
220
221 fn result_scalar(&self) -> ScalarResult {
222 match self.inner.borrow_mut().evaluate() {
223 Ok(sv) => scalar_value_to_result(&sv),
224 Err(_) => ScalarResult::Null,
225 }
226 }
227
228 fn is_empty(&self) -> bool {
229 self.inner.borrow().size() <= std::mem::size_of::<Self>()
230 }
231
232 fn clone_box(&self) -> Box<dyn DynAccumulator> {
233 let create = || {
234 self.factory.create_df_accumulator().unwrap_or_else(|e| {
235 panic!(
236 "failed to create DataFusion accumulator '{}': {e}",
237 self.function_name
238 );
239 })
240 };
241 let new_inner = create();
242 if let Ok(state_values) = self.inner.borrow_mut().state() {
244 let state_arrays: Vec<ArrayRef> = state_values
245 .iter()
246 .filter_map(|sv| sv.to_array().ok())
247 .collect();
248 if !state_arrays.is_empty() {
249 let mut new_acc = new_inner;
250 if new_acc.merge_batch(&state_arrays).is_ok() {
251 return Box::new(DataFusionAccumulatorAdapter {
252 inner: RefCell::new(new_acc),
253 column_indices: self.column_indices.clone(),
254 input_types: self.input_types.clone(),
255 function_name: self.function_name.clone(),
256 factory: Arc::clone(&self.factory),
257 });
258 }
259 }
260 }
261 Box::new(DataFusionAccumulatorAdapter {
263 inner: RefCell::new(create()),
264 column_indices: self.column_indices.clone(),
265 input_types: self.input_types.clone(),
266 function_name: self.function_name.clone(),
267 factory: Arc::clone(&self.factory),
268 })
269 }
270
271 #[allow(clippy::cast_possible_truncation)] fn serialize(&self) -> Vec<u8> {
273 match self.inner.borrow_mut().state() {
274 Ok(state_values) => {
275 let mut buf = Vec::new();
276 let count = state_values.len() as u32;
277 buf.extend_from_slice(&count.to_le_bytes());
278 for sv in &state_values {
279 let bytes = sv.to_string();
280 let len = bytes.len() as u32;
281 buf.extend_from_slice(&len.to_le_bytes());
282 buf.extend_from_slice(bytes.as_bytes());
283 }
284 buf
285 }
286 Err(_) => Vec::new(),
287 }
288 }
289
290 fn result_field(&self) -> Field {
291 let result = self.result_scalar();
292 let dt = result.data_type();
293 let dt = if dt == DataType::Null {
294 DataType::Float64
295 } else {
296 dt
297 };
298 Field::new(&self.function_name, dt, true)
299 }
300
301 fn type_tag(&self) -> &'static str {
302 "datafusion_adapter"
303 }
304
305 fn as_any(&self) -> &dyn std::any::Any {
306 self
307 }
308}
309
310pub struct DataFusionAggregateFactory {
317 udf: Arc<AggregateUDF>,
319 column_indices: Vec<usize>,
321 input_types: Vec<DataType>,
323 is_distinct: bool,
325}
326
327impl std::fmt::Debug for DataFusionAggregateFactory {
328 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
329 f.debug_struct("DataFusionAggregateFactory")
330 .field("name", &self.udf.name())
331 .field("column_indices", &self.column_indices)
332 .field("input_types", &self.input_types)
333 .field("is_distinct", &self.is_distinct)
334 .finish()
335 }
336}
337
338impl DataFusionAggregateFactory {
339 #[must_use]
341 pub fn new(
342 udf: Arc<AggregateUDF>,
343 column_indices: Vec<usize>,
344 input_types: Vec<DataType>,
345 ) -> Self {
346 Self {
347 udf,
348 column_indices,
349 input_types,
350 is_distinct: false,
351 }
352 }
353
354 #[must_use]
356 pub fn with_distinct(mut self, distinct: bool) -> Self {
357 self.is_distinct = distinct;
358 self
359 }
360
361 #[must_use]
363 pub fn name(&self) -> &str {
364 self.udf.name()
365 }
366
367 const COL_NAMES: [&str; 8] = [
369 "col_0", "col_1", "col_2", "col_3", "col_4", "col_5", "col_6", "col_7",
370 ];
371
372 fn col_name(i: usize) -> &'static str {
374 Self::COL_NAMES.get(i).copied().unwrap_or("col_n")
375 }
376
377 fn create_df_accumulator(
379 &self,
380 ) -> std::result::Result<Box<dyn datafusion_expr::Accumulator>, DataFusionError> {
381 let return_type = self
382 .udf
383 .return_type(&self.input_types)
384 .unwrap_or(DataType::Float64);
385 let return_field: FieldRef = Arc::new(Field::new(self.udf.name(), return_type, true));
386 let schema = Schema::new(
387 self.input_types
388 .iter()
389 .enumerate()
390 .map(|(i, dt)| Field::new(Self::col_name(i), dt.clone(), true))
391 .collect::<Vec<_>>(),
392 );
393 let expr_fields: Vec<FieldRef> = self
394 .input_types
395 .iter()
396 .enumerate()
397 .map(|(i, dt)| Arc::new(Field::new(Self::col_name(i), dt.clone(), true)) as FieldRef)
398 .collect();
399 let args = AccumulatorArgs {
400 return_field,
401 schema: &schema,
402 ignore_nulls: false,
403 order_bys: &[],
404 is_reversed: false,
405 name: self.udf.name(),
406 is_distinct: self.is_distinct,
407 exprs: &[],
408 expr_fields: &expr_fields,
409 };
410 self.udf.accumulator(args)
411 }
412}
413
414impl DataFusionAggregateFactory {
415 pub fn create_accumulator_with_factory(
423 self: &Arc<Self>,
424 ) -> std::result::Result<Box<dyn DynAccumulator>, DataFusionError> {
425 let inner = self.create_df_accumulator()?;
426 Ok(Box::new(DataFusionAccumulatorAdapter::new(
427 inner,
428 self.column_indices.clone(),
429 self.input_types.clone(),
430 self.udf.name().to_string(),
431 Arc::clone(self),
432 )))
433 }
434}
435
436impl DynAggregatorFactory for DataFusionAggregateFactory {
437 fn create_accumulator(&self) -> Box<dyn DynAccumulator> {
438 let factory_arc = Arc::new(DataFusionAggregateFactory {
441 udf: Arc::clone(&self.udf),
442 column_indices: self.column_indices.clone(),
443 input_types: self.input_types.clone(),
444 is_distinct: self.is_distinct,
445 });
446 let inner = self.create_df_accumulator().unwrap_or_else(|e| {
447 panic!(
448 "failed to create DataFusion accumulator '{}': {e}",
449 self.udf.name()
450 );
451 });
452 Box::new(DataFusionAccumulatorAdapter::new(
453 inner,
454 self.column_indices.clone(),
455 self.input_types.clone(),
456 self.udf.name().to_string(),
457 factory_arc,
458 ))
459 }
460
461 fn result_field(&self) -> Field {
462 let return_type = self
463 .udf
464 .return_type(&self.input_types)
465 .unwrap_or(DataType::Float64);
466 Field::new(self.udf.name(), return_type, true)
467 }
468
469 fn clone_box(&self) -> Box<dyn DynAggregatorFactory> {
470 Box::new(DataFusionAggregateFactory {
471 udf: Arc::clone(&self.udf),
472 column_indices: self.column_indices.clone(),
473 input_types: self.input_types.clone(),
474 is_distinct: self.is_distinct,
475 })
476 }
477
478 fn type_tag(&self) -> &'static str {
479 "datafusion_factory"
480 }
481}
482
483#[must_use]
489pub fn lookup_aggregate_udf(
490 ctx: &datafusion::prelude::SessionContext,
491 name: &str,
492) -> Option<Arc<AggregateUDF>> {
493 let normalized = name.to_lowercase();
494 ctx.udaf(&normalized).ok()
495}
496
497#[must_use]
501pub fn create_aggregate_factory(
502 ctx: &datafusion::prelude::SessionContext,
503 name: &str,
504 column_indices: Vec<usize>,
505 input_types: Vec<DataType>,
506) -> Option<DataFusionAggregateFactory> {
507 lookup_aggregate_udf(ctx, name)
508 .map(|udf| DataFusionAggregateFactory::new(udf, column_indices, input_types))
509}
510
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::create_session_context;
    use arrow_array::{Float64Array, Int64Array, RecordBatch};

    /// Event at `ts` with a single non-null `value: Float64` column.
    fn float_event(ts: i64, values: Vec<f64>) -> Event {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Float64,
            false,
        )]));
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(Float64Array::from(values))]).unwrap();
        Event::new(ts, batch)
    }

    /// Event at `ts` with a single non-null `value: Int64` column.
    fn int_event(ts: i64, values: Vec<i64>) -> Event {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Int64,
            false,
        )]));
        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(values))]).unwrap();
        Event::new(ts, batch)
    }

    /// Event at `ts` with two non-null Float64 columns, `x` and `y`.
    fn two_col_float_event(ts: i64, col0: Vec<f64>, col1: Vec<f64>) -> Event {
        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Float64, false),
            Field::new("y", DataType::Float64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Float64Array::from(col0)),
                Arc::new(Float64Array::from(col1)),
            ],
        )
        .unwrap();
        Event::new(ts, batch)
    }

    // ---- ScalarValue <-> ScalarResult conversion ----

    #[test]
    fn test_scalar_value_to_result_int64() {
        let sv = ScalarValue::Int64(Some(42));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Int64(42));
    }

    #[test]
    fn test_scalar_value_to_result_float64() {
        let sv = ScalarValue::Float64(Some(3.125));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Float64(3.125));
    }

    #[test]
    fn test_scalar_value_to_result_uint64() {
        let sv = ScalarValue::UInt64(Some(100));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::UInt64(100));
    }

    #[test]
    fn test_scalar_value_to_result_null_int64() {
        let sv = ScalarValue::Int64(None);
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalInt64(None)
        );
    }

    #[test]
    fn test_scalar_value_to_result_null_float64() {
        let sv = ScalarValue::Float64(None);
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalFloat64(None)
        );
    }

    #[test]
    fn test_scalar_value_to_result_smaller_ints() {
        assert_eq!(
            scalar_value_to_result(&ScalarValue::Int8(Some(8))),
            ScalarResult::Int64(8)
        );
        assert_eq!(
            scalar_value_to_result(&ScalarValue::Int16(Some(16))),
            ScalarResult::Int64(16)
        );
        assert_eq!(
            scalar_value_to_result(&ScalarValue::Int32(Some(32))),
            ScalarResult::Int64(32)
        );
        assert_eq!(
            scalar_value_to_result(&ScalarValue::UInt8(Some(8))),
            ScalarResult::UInt64(8)
        );
        assert_eq!(
            scalar_value_to_result(&ScalarValue::UInt16(Some(16))),
            ScalarResult::UInt64(16)
        );
        assert_eq!(
            scalar_value_to_result(&ScalarValue::UInt32(Some(32))),
            ScalarResult::UInt64(32)
        );
    }

    #[test]
    fn test_scalar_value_to_result_float32() {
        let sv = ScalarValue::Float32(Some(2.5));
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::Float64(f64::from(2.5f32))
        );
    }

    #[test]
    fn test_scalar_value_to_result_unsupported() {
        let sv = ScalarValue::Utf8(Some("hello".to_string()));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Null);
    }

    #[test]
    fn test_result_to_scalar_value_roundtrip() {
        // Non-optional variants round-trip exactly.
        let exact_cases = vec![
            ScalarResult::Int64(42),
            ScalarResult::Float64(3.125),
            ScalarResult::UInt64(100),
        ];
        for sr in &exact_cases {
            let sv = result_to_scalar_value(sr);
            let back = scalar_value_to_result(&sv);
            assert_eq!(&back, sr, "Roundtrip failed for {sr:?}");
        }

        // `Optional*(Some(_))` comes back as the non-optional variant.
        let sv = result_to_scalar_value(&ScalarResult::OptionalInt64(Some(7)));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Int64(7));

        let sv = result_to_scalar_value(&ScalarResult::OptionalFloat64(Some(2.72)));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Float64(2.72));

        // `Optional*(None)` keeps its typed NULL.
        let sv = result_to_scalar_value(&ScalarResult::OptionalInt64(None));
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalInt64(None)
        );

        let sv = result_to_scalar_value(&ScalarResult::OptionalFloat64(None));
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalFloat64(None)
        );

        // The untyped NULL stays untyped.
        let sv = result_to_scalar_value(&ScalarResult::Null);
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Null);
    }

    // ---- Factory construction ----

    #[test]
    fn test_factory_count() {
        let ctx = create_session_context();
        let factory = create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]);
        assert!(factory.is_some(), "count should be a recognized aggregate");
        assert_eq!(factory.unwrap().name(), "count");
    }

    #[test]
    fn test_factory_sum() {
        let ctx = create_session_context();
        let factory = create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]);
        assert!(factory.is_some());
        assert_eq!(factory.unwrap().name(), "sum");
    }

    #[test]
    fn test_factory_avg() {
        let ctx = create_session_context();
        let factory = create_aggregate_factory(&ctx, "avg", vec![0], vec![DataType::Float64]);
        assert!(factory.is_some());
    }

    #[test]
    fn test_factory_stddev() {
        let ctx = create_session_context();
        let factory = create_aggregate_factory(&ctx, "stddev", vec![0], vec![DataType::Float64]);
        assert!(
            factory.is_some(),
            "stddev should be available in DataFusion"
        );
    }

    #[test]
    fn test_factory_unknown() {
        let ctx = create_session_context();
        let factory = create_aggregate_factory(
            &ctx,
            "nonexistent_aggregate_xyz",
            vec![0],
            vec![DataType::Int64],
        );
        assert!(factory.is_none());
    }

    #[test]
    fn test_factory_result_field() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let field = factory.result_field();
        assert_eq!(field.name(), "sum");
        assert_eq!(field.data_type(), &DataType::Float64);
    }

    #[test]
    fn test_factory_clone_box() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
        let cloned = factory.clone_box();
        assert_eq!(cloned.type_tag(), "datafusion_factory");
    }

    // ---- Adapter: accumulate and evaluate ----

    #[test]
    fn test_adapter_count_basic() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
        let mut acc = factory.create_accumulator();

        let result = acc.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(0) | ScalarResult::UInt64(0)),
            "Expected 0, got {result:?}"
        );

        acc.add_event(&int_event(1000, vec![10, 20, 30]));
        let result = acc.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(3) | ScalarResult::UInt64(3)),
            "Expected 3, got {result:?}"
        );

        acc.add_event(&int_event(2000, vec![40, 50]));
        let result = acc.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(5) | ScalarResult::UInt64(5)),
            "Expected 5, got {result:?}"
        );
    }

    #[test]
    fn test_adapter_sum_float64() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(1000, vec![1.5, 2.5, 3.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(7.0));
    }

    #[test]
    fn test_adapter_avg_float64() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "avg", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(1000, vec![10.0, 20.0, 30.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(20.0));
    }

    #[test]
    fn test_adapter_min_float64() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "min", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(1000, vec![30.0, 10.0, 20.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(10.0));
    }

    #[test]
    fn test_adapter_max_float64() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "max", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(1000, vec![30.0, 10.0, 20.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(30.0));
    }

    #[test]
    fn test_adapter_sum_int64() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Int64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&int_event(1000, vec![10, 20, 30]));
        assert_eq!(acc.result_scalar(), ScalarResult::Int64(60));
    }

    #[test]
    fn test_adapter_type_tag() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let acc = factory.create_accumulator();
        assert_eq!(acc.type_tag(), "datafusion_adapter");
    }

    #[test]
    fn test_adapter_result_field() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();
        acc.add_event(&float_event(1000, vec![1.0]));
        assert_eq!(acc.result_field().name(), "sum");
    }

    // ---- Adapter: merging partial aggregates ----

    #[test]
    fn test_adapter_merge_sum() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&float_event(1000, vec![1.0, 2.0]));

        let mut acc2 = factory.create_accumulator();
        acc2.add_event(&float_event(2000, vec![3.0, 4.0]));

        acc1.merge_dyn(acc2.as_ref());
        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(10.0));
    }

    #[test]
    fn test_adapter_merge_count() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&int_event(1000, vec![1, 2, 3]));

        let mut acc2 = factory.create_accumulator();
        acc2.add_event(&int_event(2000, vec![4, 5]));

        acc1.merge_dyn(acc2.as_ref());
        let result = acc1.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(5) | ScalarResult::UInt64(5)),
            "Expected 5 after merge, got {result:?}"
        );
    }

    #[test]
    fn test_adapter_merge_avg() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "avg", vec![0], vec![DataType::Float64]).unwrap();

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&float_event(1000, vec![10.0, 20.0]));

        let mut acc2 = factory.create_accumulator();
        acc2.add_event(&float_event(2000, vec![30.0]));

        acc1.merge_dyn(acc2.as_ref());
        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(20.0));
    }

    #[test]
    fn test_adapter_merge_empty() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&float_event(1000, vec![5.0]));

        let acc2 = factory.create_accumulator();
        acc1.merge_dyn(acc2.as_ref());
        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(5.0));
    }

    // ---- Adapter: statistical aggregates ----

    #[test]
    fn test_adapter_stddev() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "stddev", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(
            1000,
            vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0],
        ));
        let result = acc.result_scalar();
        if let ScalarResult::Float64(v) = result {
            assert!((v - 2.138).abs() < 0.01, "Expected ~2.138, got {v}");
        } else {
            panic!("Expected Float64 result, got {result:?}");
        }
    }

    #[test]
    fn test_adapter_variance() {
        let ctx = create_session_context();
        if let Some(factory) =
            create_aggregate_factory(&ctx, "var_samp", vec![0], vec![DataType::Float64])
        {
            let mut acc = factory.create_accumulator();
            acc.add_event(&float_event(
                1000,
                vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0],
            ));
            // If the function exists, a non-Float64 result is a failure rather
            // than a silent pass.
            let result = acc.result_scalar();
            if let ScalarResult::Float64(v) = result {
                assert!((v - 4.571).abs() < 0.01, "Expected ~4.571, got {v}");
            } else {
                panic!("Expected Float64 variance, got {result:?}");
            }
        }
    }

    #[test]
    fn test_adapter_median() {
        let ctx = create_session_context();
        if let Some(factory) =
            create_aggregate_factory(&ctx, "median", vec![0], vec![DataType::Float64])
        {
            let mut acc = factory.create_accumulator();
            acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0, 4.0, 5.0]));
            assert_eq!(acc.result_scalar(), ScalarResult::Float64(3.0));
        }
    }

    // ---- Adapter: serialization ----

    #[test]
    fn test_adapter_serialize() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();
        acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0]));
        assert!(!acc.serialize().is_empty());
    }

    #[test]
    fn test_adapter_serialize_empty() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let acc = factory.create_accumulator();
        // Even with no input, the state-count prefix is written.
        assert!(!acc.serialize().is_empty());
    }

    // ---- UDF lookup ----

    #[test]
    fn test_lookup_common_aggregates() {
        let ctx = create_session_context();
        for name in &["count", "sum", "min", "max", "avg"] {
            assert!(
                lookup_aggregate_udf(&ctx, name).is_some(),
                "Expected '{name}' to be a recognized aggregate"
            );
        }
    }

    #[test]
    fn test_lookup_statistical_aggregates() {
        let ctx = create_session_context();
        // These may not be registered in every build. When one is present, it must
        // produce a working accumulator with a Float64 result.
        for name in &["stddev", "stddev_pop", "median"] {
            if let Some(factory) =
                create_aggregate_factory(&ctx, name, vec![0], vec![DataType::Float64])
            {
                let mut acc = factory.create_accumulator();
                acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0]));
                let result = acc.result_scalar();
                assert!(
                    matches!(result, ScalarResult::Float64(_)),
                    "Expected Float64 result from '{name}', got {result:?}"
                );
            }
        }
    }

    #[test]
    fn test_lookup_case_insensitive() {
        let ctx = create_session_context();
        assert!(lookup_aggregate_udf(&ctx, "COUNT").is_some());
        assert!(lookup_aggregate_udf(&ctx, "Sum").is_some());
        assert!(lookup_aggregate_udf(&ctx, "AVG").is_some());
    }

    // ---- Multi-column aggregates ----

    #[test]
    fn test_adapter_multi_column_covar() {
        let ctx = create_session_context();
        if let Some(factory) = create_aggregate_factory(
            &ctx,
            "covar_samp",
            vec![0, 1],
            vec![DataType::Float64, DataType::Float64],
        ) {
            let mut acc = factory.create_accumulator();
            acc.add_event(&two_col_float_event(
                1000,
                vec![1.0, 2.0, 3.0, 4.0, 5.0],
                vec![1.0, 2.0, 3.0, 4.0, 5.0],
            ));
            let result = acc.result_scalar();
            if let ScalarResult::Float64(v) = result {
                assert!((v - 2.5).abs() < 0.01, "Expected covar ~2.5, got {v}");
            } else {
                panic!("Expected Float64 covariance, got {result:?}");
            }
        }
    }

    // ---- Public API / misc ----

    #[test]
    fn test_create_aggregate_factory_api() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
        let acc = factory.create_accumulator();
        assert_eq!(acc.type_tag(), "datafusion_adapter");
    }

    #[test]
    fn test_factory_creates_independent_accumulators() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();

        let mut acc1 = factory.create_accumulator();
        let mut acc2 = factory.create_accumulator();

        acc1.add_event(&float_event(1000, vec![10.0]));
        acc2.add_event(&float_event(2000, vec![20.0]));

        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(10.0));
        assert_eq!(acc2.result_scalar(), ScalarResult::Float64(20.0));
    }

    #[test]
    fn test_adapter_function_name() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let acc = factory.create_accumulator();
        let adapter = acc
            .as_any()
            .downcast_ref::<DataFusionAccumulatorAdapter>()
            .expect("should be adapter");
        assert_eq!(adapter.function_name(), "sum");
    }

    #[test]
    fn test_clone_box_does_not_panic() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();
        acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0]));

        // The clone carries over the accumulated state.
        let cloned = acc.clone_box();
        assert_eq!(cloned.result_scalar(), ScalarResult::Float64(6.0));
    }

    #[test]
    fn test_clone_box_is_independent() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();
        acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0]));

        let mut cloned = acc.clone_box();
        cloned.add_event(&float_event(2000, vec![4.0]));

        assert_eq!(cloned.result_scalar(), ScalarResult::Float64(10.0));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(6.0));
    }

    #[test]
    fn test_clone_box_empty_accumulator() {
        let ctx = create_session_context();
        let factory =
            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
        let acc = factory.create_accumulator();

        // Cloning an accumulator that has seen no rows yields an equivalent one.
        let cloned = acc.clone_box();
        let result = cloned.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(0) | ScalarResult::UInt64(0)),
            "Expected 0, got {result:?}"
        );
    }

    #[test]
    fn test_distinct_factory() {
        let ctx = create_session_context();
        let udf = lookup_aggregate_udf(&ctx, "count").unwrap();
        let factory = DataFusionAggregateFactory::new(udf, vec![0], vec![DataType::Int64])
            .with_distinct(true);
        assert!(factory.is_distinct);
        // Creation must not panic in DISTINCT mode.
        let _acc = factory.create_accumulator();
    }
}