Skip to main content

laminar_sql/datafusion/
aggregate_bridge.rs

1//! DataFusion Aggregate Bridge
2//!
3//! Bridges DataFusion's 50+ built-in aggregate functions into LaminarDB's
4//! `DynAccumulator` / `DynAggregatorFactory` traits. This avoids
5//! reimplementing statistical functions (STDDEV, VARIANCE, PERCENTILE, etc.)
6//! that DataFusion already provides.
7//!
8//! # Architecture
9//!
10//! ```text
11//! DataFusion World                 LaminarDB World
12//! ┌───────────────────┐           ┌──────────────────────┐
13//! │ AggregateUDF      │           │ DynAggregatorFactory │
14//! │   └─▶ Accumulator │──bridge──▶│   └─▶ DynAccumulator │
15//! │       (ScalarValue)│           │       (ScalarResult) │
16//! └───────────────────┘           └──────────────────────┘
17//! ```
18//!
19//! # Ring Architecture
20//!
21//! This bridge is Ring 1 (allocates, uses dynamic dispatch). Ring 0 workloads
22//! continue to use the hand-written static-dispatch aggregators.
23
24use std::cell::RefCell;
25use std::sync::Arc;
26
27use arrow_array::ArrayRef;
28use arrow_schema::{DataType, Field, FieldRef, Schema};
29use datafusion::execution::FunctionRegistry;
30use datafusion_common::ScalarValue;
31use datafusion_expr::function::AccumulatorArgs;
32use datafusion_expr::AggregateUDF;
33
34use laminar_core::operator::window::{DynAccumulator, DynAggregatorFactory, ScalarResult};
35use laminar_core::operator::Event;
36
37// Type Conversion: ScalarValue <-> ScalarResult
38
39/// Converts a DataFusion [`ScalarValue`] to a LaminarDB [`ScalarResult`].
40///
41/// Handles the common numeric types. Non-numeric types map to [`ScalarResult::Null`].
42#[must_use]
43pub fn scalar_value_to_result(sv: &ScalarValue) -> ScalarResult {
44    match sv {
45        ScalarValue::Int64(Some(v)) => ScalarResult::Int64(*v),
46        ScalarValue::Int64(None) => ScalarResult::OptionalInt64(None),
47        ScalarValue::Float64(Some(v)) => ScalarResult::Float64(*v),
48        ScalarValue::Float64(None) | ScalarValue::Float32(None) => {
49            ScalarResult::OptionalFloat64(None)
50        }
51        ScalarValue::UInt64(Some(v)) => ScalarResult::UInt64(*v),
52        // Widen smaller int types
53        ScalarValue::Int8(Some(v)) => ScalarResult::Int64(i64::from(*v)),
54        ScalarValue::Int16(Some(v)) => ScalarResult::Int64(i64::from(*v)),
55        ScalarValue::Int32(Some(v)) => ScalarResult::Int64(i64::from(*v)),
56        ScalarValue::UInt8(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
57        ScalarValue::UInt16(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
58        ScalarValue::UInt32(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
59        // Widen smaller float types
60        ScalarValue::Float32(Some(v)) => ScalarResult::Float64(f64::from(*v)),
61        _ => ScalarResult::Null,
62    }
63}
64
65/// Converts a [`ScalarResult`] to a DataFusion [`ScalarValue`].
66#[must_use]
67pub fn result_to_scalar_value(sr: &ScalarResult) -> ScalarValue {
68    match sr {
69        ScalarResult::Int64(v) => ScalarValue::Int64(Some(*v)),
70        ScalarResult::Float64(v) => ScalarValue::Float64(Some(*v)),
71        ScalarResult::UInt64(v) => ScalarValue::UInt64(Some(*v)),
72        ScalarResult::OptionalInt64(v) => ScalarValue::Int64(*v),
73        ScalarResult::OptionalFloat64(v) => ScalarValue::Float64(*v),
74        ScalarResult::Null => ScalarValue::Null,
75    }
76}
77
78// DataFusion Accumulator Adapter
79
80/// Adapts a DataFusion [`datafusion_expr::Accumulator`] into LaminarDB's
81/// [`DynAccumulator`] trait.
82///
83/// Uses `RefCell` for interior mutability since DataFusion's `evaluate()`
84/// and `state()` require `&mut self` but LaminarDB's `result_scalar()`
85/// and `serialize()` take `&self`.
86pub struct DataFusionAccumulatorAdapter {
87    /// The wrapped DataFusion accumulator (RefCell for interior mutability)
88    inner: RefCell<Box<dyn datafusion_expr::Accumulator>>,
89    /// Column indices to extract from events
90    column_indices: Vec<usize>,
91    /// Input types (for creating arrays during merge)
92    input_types: Vec<DataType>,
93    /// Function name (for type_tag/debug)
94    function_name: String,
95}
96
97// SAFETY: DataFusion accumulators are Send. RefCell is Send when T is Send.
98// The adapter is only accessed from a single thread (Ring 1 processing).
99unsafe impl Send for DataFusionAccumulatorAdapter {}
100
101impl std::fmt::Debug for DataFusionAccumulatorAdapter {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        f.debug_struct("DataFusionAccumulatorAdapter")
104            .field("function_name", &self.function_name)
105            .field("column_indices", &self.column_indices)
106            .field("input_types", &self.input_types)
107            .finish_non_exhaustive()
108    }
109}
110
111impl DataFusionAccumulatorAdapter {
112    /// Creates a new adapter wrapping a DataFusion accumulator.
113    #[must_use]
114    pub fn new(
115        inner: Box<dyn datafusion_expr::Accumulator>,
116        column_indices: Vec<usize>,
117        input_types: Vec<DataType>,
118        function_name: String,
119    ) -> Self {
120        Self {
121            inner: RefCell::new(inner),
122            column_indices,
123            input_types,
124            function_name,
125        }
126    }
127
128    /// Returns the wrapped function name.
129    #[must_use]
130    pub fn function_name(&self) -> &str {
131        &self.function_name
132    }
133
134    /// Extracts the relevant columns from a `RecordBatch`.
135    fn extract_columns(&self, batch: &arrow_array::RecordBatch) -> Vec<ArrayRef> {
136        self.column_indices
137            .iter()
138            .enumerate()
139            .map(|(arg_idx, &col_idx)| {
140                if col_idx < batch.num_columns() {
141                    Arc::clone(batch.column(col_idx))
142                } else {
143                    let dt = self
144                        .input_types
145                        .get(arg_idx)
146                        .cloned()
147                        .unwrap_or(DataType::Int64);
148                    arrow_array::new_null_array(&dt, batch.num_rows())
149                }
150            })
151            .collect()
152    }
153}
154
155impl DynAccumulator for DataFusionAccumulatorAdapter {
156    fn add_event(&mut self, event: &Event) {
157        let columns = self.extract_columns(&event.data);
158        let _ = self.inner.borrow_mut().update_batch(&columns);
159    }
160
161    fn merge_dyn(&mut self, other: &dyn DynAccumulator) {
162        let other = other
163            .as_any()
164            .downcast_ref::<DataFusionAccumulatorAdapter>()
165            .expect("merge_dyn: type mismatch, expected DataFusionAccumulatorAdapter");
166
167        if let Ok(state_values) = other.inner.borrow_mut().state() {
168            let state_arrays: Vec<ArrayRef> = state_values
169                .iter()
170                .filter_map(|sv| sv.to_array().ok())
171                .collect();
172            if !state_arrays.is_empty() {
173                let _ = self.inner.borrow_mut().merge_batch(&state_arrays);
174            }
175        }
176    }
177
178    fn result_scalar(&self) -> ScalarResult {
179        match self.inner.borrow_mut().evaluate() {
180            Ok(sv) => scalar_value_to_result(&sv),
181            Err(_) => ScalarResult::Null,
182        }
183    }
184
185    fn is_empty(&self) -> bool {
186        self.inner.borrow().size() <= std::mem::size_of::<Self>()
187    }
188
189    fn clone_box(&self) -> Box<dyn DynAccumulator> {
190        panic!(
191            "clone_box not supported for DataFusion adapter '{}'; \
192             use the factory to create new accumulators",
193            self.function_name
194        )
195    }
196
197    #[allow(clippy::cast_possible_truncation)] // Wire format uses fixed-width integers
198    fn serialize(&self) -> Vec<u8> {
199        match self.inner.borrow_mut().state() {
200            Ok(state_values) => {
201                let mut buf = Vec::new();
202                let count = state_values.len() as u32;
203                buf.extend_from_slice(&count.to_le_bytes());
204                for sv in &state_values {
205                    let bytes = sv.to_string();
206                    let len = bytes.len() as u32;
207                    buf.extend_from_slice(&len.to_le_bytes());
208                    buf.extend_from_slice(bytes.as_bytes());
209                }
210                buf
211            }
212            Err(_) => Vec::new(),
213        }
214    }
215
216    fn result_field(&self) -> Field {
217        let result = self.result_scalar();
218        let dt = result.data_type();
219        let dt = if dt == DataType::Null {
220            DataType::Float64
221        } else {
222            dt
223        };
224        Field::new(&self.function_name, dt, true)
225    }
226
227    fn type_tag(&self) -> &'static str {
228        "datafusion_adapter"
229    }
230
231    fn as_any(&self) -> &dyn std::any::Any {
232        self
233    }
234}
235
236// DataFusion Aggregate Factory
237
238/// Factory for creating [`DataFusionAccumulatorAdapter`] instances.
239///
240/// Wraps a DataFusion [`AggregateUDF`] and provides the [`DynAggregatorFactory`]
241/// interface for use with `CompositeAggregator`.
242pub struct DataFusionAggregateFactory {
243    /// The DataFusion aggregate UDF
244    udf: Arc<AggregateUDF>,
245    /// Column indices to extract from events
246    column_indices: Vec<usize>,
247    /// Input types for the aggregate
248    input_types: Vec<DataType>,
249}
250
251impl std::fmt::Debug for DataFusionAggregateFactory {
252    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253        f.debug_struct("DataFusionAggregateFactory")
254            .field("name", &self.udf.name())
255            .field("column_indices", &self.column_indices)
256            .field("input_types", &self.input_types)
257            .finish()
258    }
259}
260
261impl DataFusionAggregateFactory {
262    /// Creates a new factory for the given DataFusion aggregate UDF.
263    #[must_use]
264    pub fn new(
265        udf: Arc<AggregateUDF>,
266        column_indices: Vec<usize>,
267        input_types: Vec<DataType>,
268    ) -> Self {
269        Self {
270            udf,
271            column_indices,
272            input_types,
273        }
274    }
275
276    /// Returns the name of the wrapped aggregate function.
277    #[must_use]
278    pub fn name(&self) -> &str {
279        self.udf.name()
280    }
281
282    /// Pre-defined column names to avoid `format!()` per accumulator creation.
283    const COL_NAMES: [&str; 8] = [
284        "col_0", "col_1", "col_2", "col_3", "col_4", "col_5", "col_6", "col_7",
285    ];
286
287    /// Returns a cached column name for the given index.
288    fn col_name(i: usize) -> &'static str {
289        Self::COL_NAMES.get(i).copied().unwrap_or("col_n")
290    }
291
292    /// Creates a DataFusion accumulator from the UDF.
293    fn create_df_accumulator(&self) -> Box<dyn datafusion_expr::Accumulator> {
294        let return_type = self
295            .udf
296            .return_type(&self.input_types)
297            .unwrap_or(DataType::Float64);
298        let return_field: FieldRef = Arc::new(Field::new(self.udf.name(), return_type, true));
299        let schema = Schema::new(
300            self.input_types
301                .iter()
302                .enumerate()
303                .map(|(i, dt)| Field::new(Self::col_name(i), dt.clone(), true))
304                .collect::<Vec<_>>(),
305        );
306        let expr_fields: Vec<FieldRef> = self
307            .input_types
308            .iter()
309            .enumerate()
310            .map(|(i, dt)| Arc::new(Field::new(Self::col_name(i), dt.clone(), true)) as FieldRef)
311            .collect();
312        let args = AccumulatorArgs {
313            return_field,
314            schema: &schema,
315            ignore_nulls: false,
316            order_bys: &[],
317            is_reversed: false,
318            name: self.udf.name(),
319            is_distinct: false,
320            exprs: &[],
321            expr_fields: &expr_fields,
322        };
323        self.udf
324            .accumulator(args)
325            .expect("Failed to create DataFusion accumulator")
326    }
327}
328
329impl DynAggregatorFactory for DataFusionAggregateFactory {
330    fn create_accumulator(&self) -> Box<dyn DynAccumulator> {
331        let inner = self.create_df_accumulator();
332        Box::new(DataFusionAccumulatorAdapter::new(
333            inner,
334            self.column_indices.clone(),
335            self.input_types.clone(),
336            self.udf.name().to_string(),
337        ))
338    }
339
340    fn result_field(&self) -> Field {
341        let return_type = self
342            .udf
343            .return_type(&self.input_types)
344            .unwrap_or(DataType::Float64);
345        Field::new(self.udf.name(), return_type, true)
346    }
347
348    fn clone_box(&self) -> Box<dyn DynAggregatorFactory> {
349        Box::new(DataFusionAggregateFactory {
350            udf: Arc::clone(&self.udf),
351            column_indices: self.column_indices.clone(),
352            input_types: self.input_types.clone(),
353        })
354    }
355
356    fn type_tag(&self) -> &'static str {
357        "datafusion_factory"
358    }
359}
360
361// Built-in Aggregate Lookup
362
363/// Looks up a DataFusion built-in aggregate function by name.
364///
365/// Returns `None` if the function is not a recognized DataFusion aggregate.
366#[must_use]
367pub fn lookup_aggregate_udf(
368    ctx: &datafusion::prelude::SessionContext,
369    name: &str,
370) -> Option<Arc<AggregateUDF>> {
371    let normalized = name.to_lowercase();
372    ctx.udaf(&normalized).ok()
373}
374
375/// Creates a [`DataFusionAggregateFactory`] for a named built-in aggregate.
376///
377/// Returns `None` if the function name is not recognized.
378#[must_use]
379pub fn create_aggregate_factory(
380    ctx: &datafusion::prelude::SessionContext,
381    name: &str,
382    column_indices: Vec<usize>,
383    input_types: Vec<DataType>,
384) -> Option<DataFusionAggregateFactory> {
385    lookup_aggregate_udf(ctx, name)
386        .map(|udf| DataFusionAggregateFactory::new(udf, column_indices, input_types))
387}
388
389// Tests
390
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Float64Array, Int64Array, RecordBatch};
    use datafusion::prelude::SessionContext;

    /// Builds an event holding one non-nullable `Float64` column named `value`.
    fn float_event(ts: i64, values: Vec<f64>) -> Event {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Float64,
            false,
        )]));
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(Float64Array::from(values))]).unwrap();
        Event::new(ts, batch)
    }

    /// Builds an event holding one non-nullable `Int64` column named `value`.
    fn int_event(ts: i64, values: Vec<i64>) -> Event {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Int64,
            false,
        )]));
        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(values))]).unwrap();
        Event::new(ts, batch)
    }

    /// Builds an event holding two non-nullable `Float64` columns, `x` and `y`.
    fn two_col_float_event(ts: i64, col0: Vec<f64>, col1: Vec<f64>) -> Event {
        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Float64, false),
            Field::new("y", DataType::Float64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Float64Array::from(col0)),
                Arc::new(Float64Array::from(col1)),
            ],
        )
        .unwrap();
        Event::new(ts, batch)
    }

    /// Creates a factory for `name` that reads column 0 (of type `dt`), using a
    /// default `SessionContext`. Fails the test if the aggregate is unknown.
    fn single_col_factory(name: &str, dt: DataType) -> DataFusionAggregateFactory {
        let ctx = SessionContext::new();
        create_aggregate_factory(&ctx, name, vec![0], vec![dt])
            .unwrap_or_else(|| panic!("'{name}' should be a recognized aggregate"))
    }

    /// Unwraps a `Float64` result. Any other variant fails the test, so a type
    /// mismatch cannot silently skip the numeric assertion.
    fn expect_f64(result: ScalarResult) -> f64 {
        match result {
            ScalarResult::Float64(v) => v,
            other => panic!("Expected Float64 result, got {other:?}"),
        }
    }

    // ── ScalarValue Conversion Tests ────────────────────────────────────

    #[test]
    fn test_scalar_value_to_result_int64() {
        let sv = ScalarValue::Int64(Some(42));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Int64(42));
    }

    #[test]
    fn test_scalar_value_to_result_float64() {
        let sv = ScalarValue::Float64(Some(3.125));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Float64(3.125));
    }

    #[test]
    fn test_scalar_value_to_result_uint64() {
        let sv = ScalarValue::UInt64(Some(100));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::UInt64(100));
    }

    #[test]
    fn test_scalar_value_to_result_null_int64() {
        let sv = ScalarValue::Int64(None);
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalInt64(None)
        );
    }

    #[test]
    fn test_scalar_value_to_result_null_float64() {
        let sv = ScalarValue::Float64(None);
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalFloat64(None)
        );
    }

    #[test]
    fn test_scalar_value_to_result_smaller_ints() {
        assert_eq!(
            scalar_value_to_result(&ScalarValue::Int8(Some(8))),
            ScalarResult::Int64(8)
        );
        assert_eq!(
            scalar_value_to_result(&ScalarValue::Int16(Some(16))),
            ScalarResult::Int64(16)
        );
        assert_eq!(
            scalar_value_to_result(&ScalarValue::Int32(Some(32))),
            ScalarResult::Int64(32)
        );
        assert_eq!(
            scalar_value_to_result(&ScalarValue::UInt8(Some(8))),
            ScalarResult::UInt64(8)
        );
    }

    #[test]
    fn test_scalar_value_to_result_float32() {
        let sv = ScalarValue::Float32(Some(2.5));
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::Float64(f64::from(2.5f32))
        );
    }

    #[test]
    fn test_scalar_value_to_result_unsupported() {
        let sv = ScalarValue::Utf8(Some("hello".to_string()));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Null);
    }

    #[test]
    fn test_result_to_scalar_value_roundtrip() {
        // Exact roundtrip for non-optional variants
        let exact_cases = vec![
            ScalarResult::Int64(42),
            ScalarResult::Float64(3.125),
            ScalarResult::UInt64(100),
        ];
        for sr in &exact_cases {
            let sv = result_to_scalar_value(sr);
            let back = scalar_value_to_result(&sv);
            assert_eq!(&back, sr, "Roundtrip failed for {sr:?}");
        }

        // Optional(Some(v)) normalizes to non-optional through ScalarValue
        // because ScalarValue::Int64(Some(7)) maps back to ScalarResult::Int64(7)
        let sv = result_to_scalar_value(&ScalarResult::OptionalInt64(Some(7)));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Int64(7));

        let sv = result_to_scalar_value(&ScalarResult::OptionalFloat64(Some(2.72)));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Float64(2.72));

        // Optional None roundtrips back to OptionalNone (ScalarValue preserves type)
        let sv = result_to_scalar_value(&ScalarResult::OptionalInt64(None));
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalInt64(None)
        );

        let sv = result_to_scalar_value(&ScalarResult::OptionalFloat64(None));
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalFloat64(None)
        );

        // Null roundtrips correctly
        let sv = result_to_scalar_value(&ScalarResult::Null);
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Null);
    }

    // ── Factory Tests ───────────────────────────────────────────────────

    #[test]
    fn test_factory_count() {
        let factory = single_col_factory("count", DataType::Int64);
        assert_eq!(factory.name(), "count");
    }

    #[test]
    fn test_factory_sum() {
        let factory = single_col_factory("sum", DataType::Float64);
        assert_eq!(factory.name(), "sum");
    }

    #[test]
    fn test_factory_avg() {
        let _factory = single_col_factory("avg", DataType::Float64);
    }

    #[test]
    fn test_factory_stddev() {
        let _factory = single_col_factory("stddev", DataType::Float64);
    }

    #[test]
    fn test_factory_unknown() {
        let ctx = SessionContext::new();
        let factory = create_aggregate_factory(
            &ctx,
            "nonexistent_aggregate_xyz",
            vec![0],
            vec![DataType::Int64],
        );
        assert!(factory.is_none());
    }

    #[test]
    fn test_factory_result_field() {
        let factory = single_col_factory("sum", DataType::Float64);
        let field = factory.result_field();
        assert_eq!(field.name(), "sum");
        assert_eq!(field.data_type(), &DataType::Float64);
    }

    #[test]
    fn test_factory_clone_box() {
        let factory = single_col_factory("count", DataType::Int64);
        let cloned = factory.clone_box();
        assert_eq!(cloned.type_tag(), "datafusion_factory");
    }

    // ── Adapter Basics ──────────────────────────────────────────────────

    #[test]
    fn test_adapter_count_basic() {
        let factory = single_col_factory("count", DataType::Int64);
        let mut acc = factory.create_accumulator();

        let result = acc.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(0) | ScalarResult::UInt64(0)),
            "Expected 0, got {result:?}"
        );

        acc.add_event(&int_event(1000, vec![10, 20, 30]));
        let result = acc.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(3) | ScalarResult::UInt64(3)),
            "Expected 3, got {result:?}"
        );

        acc.add_event(&int_event(2000, vec![40, 50]));
        let result = acc.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(5) | ScalarResult::UInt64(5)),
            "Expected 5, got {result:?}"
        );
    }

    #[test]
    fn test_adapter_sum_float64() {
        let factory = single_col_factory("sum", DataType::Float64);
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(1000, vec![1.5, 2.5, 3.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(7.0));
    }

    #[test]
    fn test_adapter_avg_float64() {
        let factory = single_col_factory("avg", DataType::Float64);
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(1000, vec![10.0, 20.0, 30.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(20.0));
    }

    #[test]
    fn test_adapter_min_float64() {
        let factory = single_col_factory("min", DataType::Float64);
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(1000, vec![30.0, 10.0, 20.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(10.0));
    }

    #[test]
    fn test_adapter_max_float64() {
        let factory = single_col_factory("max", DataType::Float64);
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(1000, vec![30.0, 10.0, 20.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(30.0));
    }

    #[test]
    fn test_adapter_sum_int64() {
        let factory = single_col_factory("sum", DataType::Int64);
        let mut acc = factory.create_accumulator();

        acc.add_event(&int_event(1000, vec![10, 20, 30]));
        assert_eq!(acc.result_scalar(), ScalarResult::Int64(60));
    }

    #[test]
    fn test_adapter_type_tag() {
        let factory = single_col_factory("sum", DataType::Float64);
        let acc = factory.create_accumulator();
        assert_eq!(acc.type_tag(), "datafusion_adapter");
    }

    #[test]
    fn test_adapter_result_field() {
        let factory = single_col_factory("sum", DataType::Float64);
        let mut acc = factory.create_accumulator();
        acc.add_event(&float_event(1000, vec![1.0]));
        assert_eq!(acc.result_field().name(), "sum");
    }

    // ── Merge Tests ─────────────────────────────────────────────────────

    #[test]
    fn test_adapter_merge_sum() {
        let factory = single_col_factory("sum", DataType::Float64);

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&float_event(1000, vec![1.0, 2.0]));

        let mut acc2 = factory.create_accumulator();
        acc2.add_event(&float_event(2000, vec![3.0, 4.0]));

        acc1.merge_dyn(acc2.as_ref());
        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(10.0));
    }

    #[test]
    fn test_adapter_merge_count() {
        let factory = single_col_factory("count", DataType::Int64);

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&int_event(1000, vec![1, 2, 3]));

        let mut acc2 = factory.create_accumulator();
        acc2.add_event(&int_event(2000, vec![4, 5]));

        acc1.merge_dyn(acc2.as_ref());
        let result = acc1.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(5) | ScalarResult::UInt64(5)),
            "Expected 5 after merge, got {result:?}"
        );
    }

    #[test]
    fn test_adapter_merge_avg() {
        let factory = single_col_factory("avg", DataType::Float64);

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&float_event(1000, vec![10.0, 20.0]));

        let mut acc2 = factory.create_accumulator();
        acc2.add_event(&float_event(2000, vec![30.0]));

        acc1.merge_dyn(acc2.as_ref());
        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(20.0));
    }

    #[test]
    fn test_adapter_merge_empty() {
        let factory = single_col_factory("sum", DataType::Float64);

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&float_event(1000, vec![5.0]));

        let acc2 = factory.create_accumulator();
        acc1.merge_dyn(acc2.as_ref());
        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(5.0));
    }

    // ── Built-in Aggregate Pass-Through Tests ───────────────────────────

    #[test]
    fn test_adapter_stddev() {
        let factory = single_col_factory("stddev", DataType::Float64);
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(
            1000,
            vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0],
        ));
        let v = expect_f64(acc.result_scalar());
        assert!((v - 2.138).abs() < 0.01, "Expected ~2.138, got {v}");
    }

    #[test]
    fn test_adapter_variance() {
        // Previously wrapped in `if let`, so a missing function or a
        // non-Float64 result made the test pass without asserting anything.
        let factory = single_col_factory("var_samp", DataType::Float64);
        let mut acc = factory.create_accumulator();
        acc.add_event(&float_event(
            1000,
            vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0],
        ));
        let v = expect_f64(acc.result_scalar());
        assert!((v - 4.571).abs() < 0.01, "Expected ~4.571, got {v}");
    }

    #[test]
    fn test_adapter_median() {
        let factory = single_col_factory("median", DataType::Float64);
        let mut acc = factory.create_accumulator();
        acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0, 4.0, 5.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(3.0));
    }

    // ── Serialize Tests ─────────────────────────────────────────────────

    #[test]
    fn test_adapter_serialize() {
        let factory = single_col_factory("sum", DataType::Float64);
        let mut acc = factory.create_accumulator();
        acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0]));
        assert!(!acc.serialize().is_empty());
    }

    #[test]
    fn test_adapter_serialize_empty() {
        let factory = single_col_factory("sum", DataType::Float64);
        let acc = factory.create_accumulator();
        assert!(!acc.serialize().is_empty());
    }

    // ── Lookup Tests ────────────────────────────────────────────────────

    #[test]
    fn test_lookup_common_aggregates() {
        let ctx = SessionContext::new();
        for name in &["count", "sum", "min", "max", "avg"] {
            assert!(
                lookup_aggregate_udf(&ctx, name).is_some(),
                "Expected '{name}' to be a recognized aggregate"
            );
        }
    }

    #[test]
    fn test_lookup_statistical_aggregates() {
        let ctx = SessionContext::new();
        for name in &["stddev", "stddev_pop", "median"] {
            assert!(
                lookup_aggregate_udf(&ctx, name).is_some(),
                "Expected '{name}' to be a recognized aggregate"
            );
        }
    }

    #[test]
    fn test_lookup_case_insensitive() {
        let ctx = SessionContext::new();
        assert!(lookup_aggregate_udf(&ctx, "COUNT").is_some());
        assert!(lookup_aggregate_udf(&ctx, "Sum").is_some());
        assert!(lookup_aggregate_udf(&ctx, "AVG").is_some());
    }

    // ── Multi-column Tests ──────────────────────────────────────────────

    #[test]
    fn test_adapter_multi_column_covar() {
        let ctx = SessionContext::new();
        let factory = create_aggregate_factory(
            &ctx,
            "covar_samp",
            vec![0, 1],
            vec![DataType::Float64, DataType::Float64],
        )
        .expect("covar_samp should be a recognized aggregate");
        let mut acc = factory.create_accumulator();
        acc.add_event(&two_col_float_event(
            1000,
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
        ));
        let v = expect_f64(acc.result_scalar());
        assert!((v - 2.5).abs() < 0.01, "Expected covar ~2.5, got {v}");
    }

    // ── Registration Tests ──────────────────────────────────────────────

    #[test]
    fn test_create_aggregate_factory_api() {
        let factory = single_col_factory("count", DataType::Int64);
        let acc = factory.create_accumulator();
        assert_eq!(acc.type_tag(), "datafusion_adapter");
    }

    #[test]
    fn test_factory_creates_independent_accumulators() {
        let factory = single_col_factory("sum", DataType::Float64);

        let mut acc1 = factory.create_accumulator();
        let mut acc2 = factory.create_accumulator();

        acc1.add_event(&float_event(1000, vec![10.0]));
        acc2.add_event(&float_event(2000, vec![20.0]));

        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(10.0));
        assert_eq!(acc2.result_scalar(), ScalarResult::Float64(20.0));
    }

    #[test]
    fn test_adapter_function_name() {
        let factory = single_col_factory("sum", DataType::Float64);
        let acc = factory.create_accumulator();
        let adapter = acc
            .as_any()
            .downcast_ref::<DataFusionAccumulatorAdapter>()
            .expect("should be adapter");
        assert_eq!(adapter.function_name(), "sum");
    }
}