Skip to main content

laminar_sql/datafusion/
aggregate_bridge.rs

1//! F075: DataFusion Aggregate Bridge
2//!
3//! Bridges DataFusion's 50+ built-in aggregate functions into LaminarDB's
4//! `DynAccumulator` / `DynAggregatorFactory` traits. This avoids
5//! reimplementing statistical functions (STDDEV, VARIANCE, PERCENTILE, etc.)
6//! that DataFusion already provides.
7//!
8//! # Architecture
9//!
10//! ```text
11//! DataFusion World                 LaminarDB World
12//! ┌───────────────────┐           ┌──────────────────────┐
13//! │ AggregateUDF      │           │ DynAggregatorFactory │
14//! │   └─▶ Accumulator │──bridge──▶│   └─▶ DynAccumulator │
//! │      (ScalarValue)│           │       (ScalarResult) │
16//! └───────────────────┘           └──────────────────────┘
17//! ```
18//!
19//! # Ring Architecture
20//!
21//! This bridge is Ring 1 (allocates, uses dynamic dispatch). Ring 0 workloads
22//! continue to use the hand-written static-dispatch aggregators.
23
24use std::cell::RefCell;
25use std::sync::Arc;
26
27use arrow_array::ArrayRef;
28use arrow_schema::{DataType, Field, FieldRef, Schema};
29use datafusion::execution::FunctionRegistry;
30use datafusion_common::ScalarValue;
31use datafusion_expr::function::AccumulatorArgs;
32use datafusion_expr::AggregateUDF;
33
34use laminar_core::operator::window::{DynAccumulator, DynAggregatorFactory, ScalarResult};
35use laminar_core::operator::Event;
36
37// Type Conversion: ScalarValue <-> ScalarResult
38
39/// Converts a DataFusion [`ScalarValue`] to a LaminarDB [`ScalarResult`].
40///
41/// Handles the common numeric types. Non-numeric types map to [`ScalarResult::Null`].
42#[must_use]
43pub fn scalar_value_to_result(sv: &ScalarValue) -> ScalarResult {
44    match sv {
45        ScalarValue::Int64(Some(v)) => ScalarResult::Int64(*v),
46        ScalarValue::Int64(None) => ScalarResult::OptionalInt64(None),
47        ScalarValue::Float64(Some(v)) => ScalarResult::Float64(*v),
48        ScalarValue::Float64(None) | ScalarValue::Float32(None) => {
49            ScalarResult::OptionalFloat64(None)
50        }
51        ScalarValue::UInt64(Some(v)) => ScalarResult::UInt64(*v),
52        // Widen smaller int types
53        ScalarValue::Int8(Some(v)) => ScalarResult::Int64(i64::from(*v)),
54        ScalarValue::Int16(Some(v)) => ScalarResult::Int64(i64::from(*v)),
55        ScalarValue::Int32(Some(v)) => ScalarResult::Int64(i64::from(*v)),
56        ScalarValue::UInt8(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
57        ScalarValue::UInt16(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
58        ScalarValue::UInt32(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
59        // Widen smaller float types
60        ScalarValue::Float32(Some(v)) => ScalarResult::Float64(f64::from(*v)),
61        _ => ScalarResult::Null,
62    }
63}
64
65/// Converts a [`ScalarResult`] to a DataFusion [`ScalarValue`].
66#[must_use]
67pub fn result_to_scalar_value(sr: &ScalarResult) -> ScalarValue {
68    match sr {
69        ScalarResult::Int64(v) => ScalarValue::Int64(Some(*v)),
70        ScalarResult::Float64(v) => ScalarValue::Float64(Some(*v)),
71        ScalarResult::UInt64(v) => ScalarValue::UInt64(Some(*v)),
72        ScalarResult::OptionalInt64(v) => ScalarValue::Int64(*v),
73        ScalarResult::OptionalFloat64(v) => ScalarValue::Float64(*v),
74        ScalarResult::Null => ScalarValue::Null,
75    }
76}
77
78// DataFusion Accumulator Adapter
79
80/// Adapts a DataFusion [`datafusion_expr::Accumulator`] into LaminarDB's
81/// [`DynAccumulator`] trait.
82///
83/// Uses `RefCell` for interior mutability since DataFusion's `evaluate()`
84/// and `state()` require `&mut self` but LaminarDB's `result_scalar()`
85/// and `serialize()` take `&self`.
86pub struct DataFusionAccumulatorAdapter {
87    /// The wrapped DataFusion accumulator (RefCell for interior mutability)
88    inner: RefCell<Box<dyn datafusion_expr::Accumulator>>,
89    /// Column indices to extract from events
90    column_indices: Vec<usize>,
91    /// Input types (for creating arrays during merge)
92    input_types: Vec<DataType>,
93    /// Function name (for type_tag/debug)
94    function_name: String,
95}
96
97// SAFETY: DataFusion accumulators are Send. RefCell is Send when T is Send.
98// The adapter is only accessed from a single thread (Ring 1 processing).
99unsafe impl Send for DataFusionAccumulatorAdapter {}
100
101impl std::fmt::Debug for DataFusionAccumulatorAdapter {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        f.debug_struct("DataFusionAccumulatorAdapter")
104            .field("function_name", &self.function_name)
105            .field("column_indices", &self.column_indices)
106            .field("input_types", &self.input_types)
107            .finish_non_exhaustive()
108    }
109}
110
111impl DataFusionAccumulatorAdapter {
112    /// Creates a new adapter wrapping a DataFusion accumulator.
113    #[must_use]
114    pub fn new(
115        inner: Box<dyn datafusion_expr::Accumulator>,
116        column_indices: Vec<usize>,
117        input_types: Vec<DataType>,
118        function_name: String,
119    ) -> Self {
120        Self {
121            inner: RefCell::new(inner),
122            column_indices,
123            input_types,
124            function_name,
125        }
126    }
127
128    /// Returns the wrapped function name.
129    #[must_use]
130    pub fn function_name(&self) -> &str {
131        &self.function_name
132    }
133
134    /// Extracts the relevant columns from a `RecordBatch`.
135    fn extract_columns(&self, batch: &arrow_array::RecordBatch) -> Vec<ArrayRef> {
136        self.column_indices
137            .iter()
138            .enumerate()
139            .map(|(arg_idx, &col_idx)| {
140                if col_idx < batch.num_columns() {
141                    Arc::clone(batch.column(col_idx))
142                } else {
143                    let dt = self
144                        .input_types
145                        .get(arg_idx)
146                        .cloned()
147                        .unwrap_or(DataType::Int64);
148                    arrow_array::new_null_array(&dt, batch.num_rows())
149                }
150            })
151            .collect()
152    }
153}
154
155impl DynAccumulator for DataFusionAccumulatorAdapter {
156    fn add_event(&mut self, event: &Event) {
157        let columns = self.extract_columns(&event.data);
158        let _ = self.inner.borrow_mut().update_batch(&columns);
159    }
160
161    fn merge_dyn(&mut self, other: &dyn DynAccumulator) {
162        let other = other
163            .as_any()
164            .downcast_ref::<DataFusionAccumulatorAdapter>()
165            .expect("merge_dyn: type mismatch, expected DataFusionAccumulatorAdapter");
166
167        if let Ok(state_values) = other.inner.borrow_mut().state() {
168            let state_arrays: Vec<ArrayRef> = state_values
169                .iter()
170                .filter_map(|sv| sv.to_array().ok())
171                .collect();
172            if !state_arrays.is_empty() {
173                let _ = self.inner.borrow_mut().merge_batch(&state_arrays);
174            }
175        }
176    }
177
178    fn result_scalar(&self) -> ScalarResult {
179        match self.inner.borrow_mut().evaluate() {
180            Ok(sv) => scalar_value_to_result(&sv),
181            Err(_) => ScalarResult::Null,
182        }
183    }
184
185    fn is_empty(&self) -> bool {
186        self.inner.borrow().size() <= std::mem::size_of::<Self>()
187    }
188
189    fn clone_box(&self) -> Box<dyn DynAccumulator> {
190        panic!(
191            "clone_box not supported for DataFusion adapter '{}'; \
192             use the factory to create new accumulators",
193            self.function_name
194        )
195    }
196
197    #[allow(clippy::cast_possible_truncation)] // Wire format uses fixed-width integers
198    fn serialize(&self) -> Vec<u8> {
199        match self.inner.borrow_mut().state() {
200            Ok(state_values) => {
201                let mut buf = Vec::new();
202                let count = state_values.len() as u32;
203                buf.extend_from_slice(&count.to_le_bytes());
204                for sv in &state_values {
205                    let bytes = sv.to_string();
206                    let len = bytes.len() as u32;
207                    buf.extend_from_slice(&len.to_le_bytes());
208                    buf.extend_from_slice(bytes.as_bytes());
209                }
210                buf
211            }
212            Err(_) => Vec::new(),
213        }
214    }
215
216    fn result_field(&self) -> Field {
217        let result = self.result_scalar();
218        let dt = result.data_type();
219        let dt = if dt == DataType::Null {
220            DataType::Float64
221        } else {
222            dt
223        };
224        Field::new(&self.function_name, dt, true)
225    }
226
227    fn type_tag(&self) -> &'static str {
228        "datafusion_adapter"
229    }
230
231    fn as_any(&self) -> &dyn std::any::Any {
232        self
233    }
234}
235
236// DataFusion Aggregate Factory
237
238/// Factory for creating [`DataFusionAccumulatorAdapter`] instances.
239///
240/// Wraps a DataFusion [`AggregateUDF`] and provides the [`DynAggregatorFactory`]
241/// interface for use with `CompositeAggregator`.
242pub struct DataFusionAggregateFactory {
243    /// The DataFusion aggregate UDF
244    udf: Arc<AggregateUDF>,
245    /// Column indices to extract from events
246    column_indices: Vec<usize>,
247    /// Input types for the aggregate
248    input_types: Vec<DataType>,
249}
250
251impl std::fmt::Debug for DataFusionAggregateFactory {
252    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253        f.debug_struct("DataFusionAggregateFactory")
254            .field("name", &self.udf.name())
255            .field("column_indices", &self.column_indices)
256            .field("input_types", &self.input_types)
257            .finish()
258    }
259}
260
261impl DataFusionAggregateFactory {
262    /// Creates a new factory for the given DataFusion aggregate UDF.
263    #[must_use]
264    pub fn new(
265        udf: Arc<AggregateUDF>,
266        column_indices: Vec<usize>,
267        input_types: Vec<DataType>,
268    ) -> Self {
269        Self {
270            udf,
271            column_indices,
272            input_types,
273        }
274    }
275
276    /// Returns the name of the wrapped aggregate function.
277    #[must_use]
278    pub fn name(&self) -> &str {
279        self.udf.name()
280    }
281
282    /// Creates a DataFusion accumulator from the UDF.
283    fn create_df_accumulator(&self) -> Box<dyn datafusion_expr::Accumulator> {
284        let return_type = self
285            .udf
286            .return_type(&self.input_types)
287            .unwrap_or(DataType::Float64);
288        let return_field: FieldRef = Arc::new(Field::new(self.udf.name(), return_type, true));
289        let schema = Schema::new(
290            self.input_types
291                .iter()
292                .enumerate()
293                .map(|(i, dt)| Field::new(format!("col_{i}"), dt.clone(), true))
294                .collect::<Vec<_>>(),
295        );
296        let expr_fields: Vec<FieldRef> = self
297            .input_types
298            .iter()
299            .enumerate()
300            .map(|(i, dt)| Arc::new(Field::new(format!("col_{i}"), dt.clone(), true)) as FieldRef)
301            .collect();
302        let args = AccumulatorArgs {
303            return_field,
304            schema: &schema,
305            ignore_nulls: false,
306            order_bys: &[],
307            is_reversed: false,
308            name: self.udf.name(),
309            is_distinct: false,
310            exprs: &[],
311            expr_fields: &expr_fields,
312        };
313        self.udf
314            .accumulator(args)
315            .expect("Failed to create DataFusion accumulator")
316    }
317}
318
319impl DynAggregatorFactory for DataFusionAggregateFactory {
320    fn create_accumulator(&self) -> Box<dyn DynAccumulator> {
321        let inner = self.create_df_accumulator();
322        Box::new(DataFusionAccumulatorAdapter::new(
323            inner,
324            self.column_indices.clone(),
325            self.input_types.clone(),
326            self.udf.name().to_string(),
327        ))
328    }
329
330    fn result_field(&self) -> Field {
331        let return_type = self
332            .udf
333            .return_type(&self.input_types)
334            .unwrap_or(DataType::Float64);
335        Field::new(self.udf.name(), return_type, true)
336    }
337
338    fn clone_box(&self) -> Box<dyn DynAggregatorFactory> {
339        Box::new(DataFusionAggregateFactory {
340            udf: Arc::clone(&self.udf),
341            column_indices: self.column_indices.clone(),
342            input_types: self.input_types.clone(),
343        })
344    }
345
346    fn type_tag(&self) -> &'static str {
347        "datafusion_factory"
348    }
349}
350
351// Built-in Aggregate Lookup
352
353/// Looks up a DataFusion built-in aggregate function by name.
354///
355/// Returns `None` if the function is not a recognized DataFusion aggregate.
356#[must_use]
357pub fn lookup_aggregate_udf(
358    ctx: &datafusion::prelude::SessionContext,
359    name: &str,
360) -> Option<Arc<AggregateUDF>> {
361    let normalized = name.to_lowercase();
362    ctx.udaf(&normalized).ok()
363}
364
365/// Creates a [`DataFusionAggregateFactory`] for a named built-in aggregate.
366///
367/// Returns `None` if the function name is not recognized.
368#[must_use]
369pub fn create_aggregate_factory(
370    ctx: &datafusion::prelude::SessionContext,
371    name: &str,
372    column_indices: Vec<usize>,
373    input_types: Vec<DataType>,
374) -> Option<DataFusionAggregateFactory> {
375    lookup_aggregate_udf(ctx, name)
376        .map(|udf| DataFusionAggregateFactory::new(udf, column_indices, input_types))
377}
378
379// Tests
380
#[cfg(test)]
mod tests {
    // Unit tests for the DataFusion aggregate bridge. Statistical aggregates
    // (stddev, var_samp, median, covar_samp) are registered by a default
    // `SessionContext`, so tests require them rather than skipping silently.
    use super::*;
    use arrow_array::{Float64Array, Int64Array, RecordBatch};
    use datafusion::prelude::SessionContext;

    fn float_event(ts: i64, values: Vec<f64>) -> Event {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Float64,
            false,
        )]));
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(Float64Array::from(values))]).unwrap();
        Event::new(ts, batch)
    }

    fn int_event(ts: i64, values: Vec<i64>) -> Event {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Int64,
            false,
        )]));
        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(values))]).unwrap();
        Event::new(ts, batch)
    }

    fn two_col_float_event(ts: i64, col0: Vec<f64>, col1: Vec<f64>) -> Event {
        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Float64, false),
            Field::new("y", DataType::Float64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Float64Array::from(col0)),
                Arc::new(Float64Array::from(col1)),
            ],
        )
        .unwrap();
        Event::new(ts, batch)
    }

    // ── ScalarValue Conversion Tests ────────────────────────────────────

    #[test]
    fn test_scalar_value_to_result_int64() {
        let sv = ScalarValue::Int64(Some(42));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Int64(42));
    }

    #[test]
    fn test_scalar_value_to_result_float64() {
        let sv = ScalarValue::Float64(Some(3.125));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Float64(3.125));
    }

    #[test]
    fn test_scalar_value_to_result_uint64() {
        let sv = ScalarValue::UInt64(Some(100));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::UInt64(100));
    }

    #[test]
    fn test_scalar_value_to_result_null_int64() {
        let sv = ScalarValue::Int64(None);
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalInt64(None)
        );
    }

    #[test]
    fn test_scalar_value_to_result_null_float64() {
        let sv = ScalarValue::Float64(None);
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalFloat64(None)
        );
    }

    #[test]
    fn test_scalar_value_to_result_smaller_ints() {
        assert_eq!(
            scalar_value_to_result(&ScalarValue::Int8(Some(8))),
            ScalarResult::Int64(8)
        );
        assert_eq!(
            scalar_value_to_result(&ScalarValue::Int16(Some(16))),
            ScalarResult::Int64(16)
        );
        assert_eq!(
            scalar_value_to_result(&ScalarValue::Int32(Some(32))),
            ScalarResult::Int64(32)
        );
        assert_eq!(
            scalar_value_to_result(&ScalarValue::UInt8(Some(8))),
            ScalarResult::UInt64(8)
        );
    }

    #[test]
    fn test_scalar_value_to_result_float32() {
        let sv = ScalarValue::Float32(Some(2.5));
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::Float64(f64::from(2.5f32))
        );
    }

    #[test]
    fn test_scalar_value_to_result_unsupported() {
        let sv = ScalarValue::Utf8(Some("hello".to_string()));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Null);
    }

    #[test]
    fn test_result_to_scalar_value_roundtrip() {
        // Exact roundtrip for non-optional variants
        let exact_cases = vec![
            ScalarResult::Int64(42),
            ScalarResult::Float64(3.125),
            ScalarResult::UInt64(100),
        ];
        for sr in &exact_cases {
            let sv = result_to_scalar_value(sr);
            let back = scalar_value_to_result(&sv);
            assert_eq!(&back, sr, "Roundtrip failed for {sr:?}");
        }

        // Optional(Some(v)) normalizes to non-optional through ScalarValue
        // because ScalarValue::Int64(Some(7)) maps back to ScalarResult::Int64(7)
        let sv = result_to_scalar_value(&ScalarResult::OptionalInt64(Some(7)));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Int64(7));

        let sv = result_to_scalar_value(&ScalarResult::OptionalFloat64(Some(2.72)));
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Float64(2.72));

        // Optional None roundtrips back to OptionalNone (ScalarValue preserves type)
        let sv = result_to_scalar_value(&ScalarResult::OptionalInt64(None));
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalInt64(None)
        );

        let sv = result_to_scalar_value(&ScalarResult::OptionalFloat64(None));
        assert_eq!(
            scalar_value_to_result(&sv),
            ScalarResult::OptionalFloat64(None)
        );

        // Null roundtrips correctly
        let sv = result_to_scalar_value(&ScalarResult::Null);
        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Null);
    }

    // ── Factory Tests ───────────────────────────────────────────────────

    #[test]
    fn test_factory_count() {
        let ctx = SessionContext::new();
        let factory = create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]);
        assert!(factory.is_some(), "count should be a recognized aggregate");
        assert_eq!(factory.unwrap().name(), "count");
    }

    #[test]
    fn test_factory_sum() {
        let ctx = SessionContext::new();
        let factory = create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]);
        assert!(factory.is_some());
        assert_eq!(factory.unwrap().name(), "sum");
    }

    #[test]
    fn test_factory_avg() {
        let ctx = SessionContext::new();
        let factory = create_aggregate_factory(&ctx, "avg", vec![0], vec![DataType::Float64]);
        assert!(factory.is_some());
    }

    #[test]
    fn test_factory_stddev() {
        let ctx = SessionContext::new();
        let factory = create_aggregate_factory(&ctx, "stddev", vec![0], vec![DataType::Float64]);
        assert!(
            factory.is_some(),
            "stddev should be available in DataFusion"
        );
    }

    #[test]
    fn test_factory_unknown() {
        let ctx = SessionContext::new();
        let factory = create_aggregate_factory(
            &ctx,
            "nonexistent_aggregate_xyz",
            vec![0],
            vec![DataType::Int64],
        );
        assert!(factory.is_none());
    }

    #[test]
    fn test_factory_result_field() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let field = factory.result_field();
        assert_eq!(field.name(), "sum");
        assert_eq!(field.data_type(), &DataType::Float64);
    }

    #[test]
    fn test_factory_clone_box() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
        let cloned = factory.clone_box();
        assert_eq!(cloned.type_tag(), "datafusion_factory");
    }

    // ── Adapter Basics ──────────────────────────────────────────────────

    #[test]
    fn test_adapter_count_basic() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
        let mut acc = factory.create_accumulator();

        let result = acc.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(0) | ScalarResult::UInt64(0)),
            "Expected 0, got {result:?}"
        );

        acc.add_event(&int_event(1000, vec![10, 20, 30]));
        let result = acc.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(3) | ScalarResult::UInt64(3)),
            "Expected 3, got {result:?}"
        );

        acc.add_event(&int_event(2000, vec![40, 50]));
        let result = acc.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(5) | ScalarResult::UInt64(5)),
            "Expected 5, got {result:?}"
        );
    }

    #[test]
    fn test_adapter_sum_float64() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(1000, vec![1.5, 2.5, 3.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(7.0));
    }

    #[test]
    fn test_adapter_avg_float64() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "avg", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(1000, vec![10.0, 20.0, 30.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(20.0));
    }

    #[test]
    fn test_adapter_min_float64() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "min", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(1000, vec![30.0, 10.0, 20.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(10.0));
    }

    #[test]
    fn test_adapter_max_float64() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "max", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(1000, vec![30.0, 10.0, 20.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(30.0));
    }

    #[test]
    fn test_adapter_sum_int64() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Int64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&int_event(1000, vec![10, 20, 30]));
        assert_eq!(acc.result_scalar(), ScalarResult::Int64(60));
    }

    #[test]
    fn test_adapter_type_tag() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let acc = factory.create_accumulator();
        assert_eq!(acc.type_tag(), "datafusion_adapter");
    }

    #[test]
    fn test_adapter_result_field() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();
        acc.add_event(&float_event(1000, vec![1.0]));
        assert_eq!(acc.result_field().name(), "sum");
    }

    // ── Merge Tests ─────────────────────────────────────────────────────

    #[test]
    fn test_adapter_merge_sum() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&float_event(1000, vec![1.0, 2.0]));

        let mut acc2 = factory.create_accumulator();
        acc2.add_event(&float_event(2000, vec![3.0, 4.0]));

        acc1.merge_dyn(acc2.as_ref());
        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(10.0));
    }

    #[test]
    fn test_adapter_merge_count() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&int_event(1000, vec![1, 2, 3]));

        let mut acc2 = factory.create_accumulator();
        acc2.add_event(&int_event(2000, vec![4, 5]));

        acc1.merge_dyn(acc2.as_ref());
        let result = acc1.result_scalar();
        assert!(
            matches!(result, ScalarResult::Int64(5) | ScalarResult::UInt64(5)),
            "Expected 5 after merge, got {result:?}"
        );
    }

    #[test]
    fn test_adapter_merge_avg() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "avg", vec![0], vec![DataType::Float64]).unwrap();

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&float_event(1000, vec![10.0, 20.0]));

        let mut acc2 = factory.create_accumulator();
        acc2.add_event(&float_event(2000, vec![30.0]));

        acc1.merge_dyn(acc2.as_ref());
        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(20.0));
    }

    #[test]
    fn test_adapter_merge_empty() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();

        let mut acc1 = factory.create_accumulator();
        acc1.add_event(&float_event(1000, vec![5.0]));

        let acc2 = factory.create_accumulator();
        acc1.merge_dyn(acc2.as_ref());
        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(5.0));
    }

    // ── Built-in Aggregate Pass-Through Tests ───────────────────────────

    #[test]
    fn test_adapter_stddev() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "stddev", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();

        acc.add_event(&float_event(
            1000,
            vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0],
        ));
        let result = acc.result_scalar();
        if let ScalarResult::Float64(v) = result {
            assert!((v - 2.138).abs() < 0.01, "Expected ~2.138, got {v}");
        } else {
            panic!("Expected Float64 result, got {result:?}");
        }
    }

    #[test]
    fn test_adapter_variance() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "var_samp", vec![0], vec![DataType::Float64])
                .expect("var_samp should be available in DataFusion");
        let mut acc = factory.create_accumulator();
        acc.add_event(&float_event(
            1000,
            vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0],
        ));
        match acc.result_scalar() {
            ScalarResult::Float64(v) => {
                assert!((v - 4.571).abs() < 0.01, "Expected ~4.571, got {v}");
            }
            other => panic!("Expected Float64 result, got {other:?}"),
        }
    }

    #[test]
    fn test_adapter_median() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "median", vec![0], vec![DataType::Float64])
                .expect("median should be available in DataFusion");
        let mut acc = factory.create_accumulator();
        acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0, 4.0, 5.0]));
        assert_eq!(acc.result_scalar(), ScalarResult::Float64(3.0));
    }

    // ── Serialize Tests ─────────────────────────────────────────────────

    #[test]
    fn test_adapter_serialize() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let mut acc = factory.create_accumulator();
        acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0]));
        assert!(!acc.serialize().is_empty());
    }

    #[test]
    fn test_adapter_serialize_empty() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let acc = factory.create_accumulator();
        assert!(!acc.serialize().is_empty());
    }

    // ── Lookup Tests ────────────────────────────────────────────────────

    #[test]
    fn test_lookup_common_aggregates() {
        let ctx = SessionContext::new();
        for name in &["count", "sum", "min", "max", "avg"] {
            assert!(
                lookup_aggregate_udf(&ctx, name).is_some(),
                "Expected '{name}' to be a recognized aggregate"
            );
        }
    }

    #[test]
    fn test_lookup_statistical_aggregates() {
        let ctx = SessionContext::new();
        for name in &["stddev", "stddev_pop", "median"] {
            assert!(
                lookup_aggregate_udf(&ctx, name).is_some(),
                "Expected '{name}' to be a recognized aggregate"
            );
        }
    }

    #[test]
    fn test_lookup_case_insensitive() {
        let ctx = SessionContext::new();
        assert!(lookup_aggregate_udf(&ctx, "COUNT").is_some());
        assert!(lookup_aggregate_udf(&ctx, "Sum").is_some());
        assert!(lookup_aggregate_udf(&ctx, "AVG").is_some());
    }

    // ── Multi-column Tests ──────────────────────────────────────────────

    #[test]
    fn test_adapter_multi_column_covar() {
        let ctx = SessionContext::new();
        let factory = create_aggregate_factory(
            &ctx,
            "covar_samp",
            vec![0, 1],
            vec![DataType::Float64, DataType::Float64],
        )
        .expect("covar_samp should be available in DataFusion");
        let mut acc = factory.create_accumulator();
        acc.add_event(&two_col_float_event(
            1000,
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
        ));
        match acc.result_scalar() {
            ScalarResult::Float64(v) => {
                assert!((v - 2.5).abs() < 0.01, "Expected covar ~2.5, got {v}");
            }
            other => panic!("Expected Float64 result, got {other:?}"),
        }
    }

    // ── Registration Tests ──────────────────────────────────────────────

    #[test]
    fn test_create_aggregate_factory_api() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
        let acc = factory.create_accumulator();
        assert_eq!(acc.type_tag(), "datafusion_adapter");
    }

    #[test]
    fn test_factory_creates_independent_accumulators() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();

        let mut acc1 = factory.create_accumulator();
        let mut acc2 = factory.create_accumulator();

        acc1.add_event(&float_event(1000, vec![10.0]));
        acc2.add_event(&float_event(2000, vec![20.0]));

        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(10.0));
        assert_eq!(acc2.result_scalar(), ScalarResult::Float64(20.0));
    }

    #[test]
    fn test_adapter_function_name() {
        let ctx = SessionContext::new();
        let factory =
            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
        let acc = factory.create_accumulator();
        let adapter = acc
            .as_any()
            .downcast_ref::<DataFusionAccumulatorAdapter>()
            .expect("should be adapter");
        assert_eq!(adapter.function_name(), "sum");
    }
}