Skip to main content

laminar_sql/datafusion/
aggregate_bridge.rs

1//! DataFusion Aggregate Bridge
2//!
3//! Bridges DataFusion's 50+ built-in aggregate functions into LaminarDB's
4//! `DynAccumulator` / `DynAggregatorFactory` traits. This avoids
5//! reimplementing statistical functions (STDDEV, VARIANCE, PERCENTILE, etc.)
6//! that DataFusion already provides.
7//!
8//! # Architecture
9//!
10//! ```text
11//! DataFusion World                 LaminarDB World
12//! ┌───────────────────┐           ┌──────────────────────┐
13//! │ AggregateUDF      │           │ DynAggregatorFactory │
14//! │   └─▶ Accumulator │──bridge──▶│   └─▶ DynAccumulator │
15//! │       (ScalarValue)│           │       (ScalarResult) │
16//! └───────────────────┘           └──────────────────────┘
17//! ```
18//!
19//! # Ring Architecture
20//!
21//! This bridge is Ring 1 (allocates, uses dynamic dispatch). Ring 0 workloads
22//! continue to use the hand-written static-dispatch aggregators.
23
24use std::cell::RefCell;
25use std::sync::Arc;
26
27use arrow_array::ArrayRef;
28use arrow_schema::{DataType, Field, FieldRef, Schema};
29use datafusion::execution::FunctionRegistry;
30use datafusion_common::{DataFusionError, ScalarValue};
31use datafusion_expr::function::AccumulatorArgs;
32use datafusion_expr::AggregateUDF;
33
34use laminar_core::operator::window::{DynAccumulator, DynAggregatorFactory, ScalarResult};
35use laminar_core::operator::Event;
36
37// Type Conversion: ScalarValue <-> ScalarResult
38
39/// Converts a DataFusion [`ScalarValue`] to a LaminarDB [`ScalarResult`].
40///
41/// Handles the common numeric types. Non-numeric types map to [`ScalarResult::Null`].
42#[must_use]
43pub fn scalar_value_to_result(sv: &ScalarValue) -> ScalarResult {
44    match sv {
45        ScalarValue::Int64(Some(v)) => ScalarResult::Int64(*v),
46        ScalarValue::Int64(None) => ScalarResult::OptionalInt64(None),
47        ScalarValue::Float64(Some(v)) => ScalarResult::Float64(*v),
48        ScalarValue::Float64(None) | ScalarValue::Float32(None) => {
49            ScalarResult::OptionalFloat64(None)
50        }
51        ScalarValue::UInt64(Some(v)) => ScalarResult::UInt64(*v),
52        // Widen smaller int types
53        ScalarValue::Int8(Some(v)) => ScalarResult::Int64(i64::from(*v)),
54        ScalarValue::Int16(Some(v)) => ScalarResult::Int64(i64::from(*v)),
55        ScalarValue::Int32(Some(v)) => ScalarResult::Int64(i64::from(*v)),
56        ScalarValue::UInt8(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
57        ScalarValue::UInt16(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
58        ScalarValue::UInt32(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
59        // Widen smaller float types
60        ScalarValue::Float32(Some(v)) => ScalarResult::Float64(f64::from(*v)),
61        _ => ScalarResult::Null,
62    }
63}
64
65/// Converts a [`ScalarResult`] to a DataFusion [`ScalarValue`].
66#[must_use]
67pub fn result_to_scalar_value(sr: &ScalarResult) -> ScalarValue {
68    match sr {
69        ScalarResult::Int64(v) => ScalarValue::Int64(Some(*v)),
70        ScalarResult::Float64(v) => ScalarValue::Float64(Some(*v)),
71        ScalarResult::UInt64(v) => ScalarValue::UInt64(Some(*v)),
72        ScalarResult::OptionalInt64(v) => ScalarValue::Int64(*v),
73        ScalarResult::OptionalFloat64(v) => ScalarValue::Float64(*v),
74        ScalarResult::Null => ScalarValue::Null,
75    }
76}
77
78// DataFusion Accumulator Adapter
79
80/// Adapts a DataFusion [`datafusion_expr::Accumulator`] into LaminarDB's
81/// [`DynAccumulator`] trait.
82///
83/// Uses `RefCell` for interior mutability since DataFusion's `evaluate()`
84/// and `state()` require `&mut self` but LaminarDB's `result_scalar()`
85/// and `serialize()` take `&self`.
86pub struct DataFusionAccumulatorAdapter {
87    /// The wrapped DataFusion accumulator (RefCell for interior mutability)
88    inner: RefCell<Box<dyn datafusion_expr::Accumulator>>,
89    /// Column indices to extract from events
90    column_indices: Vec<usize>,
91    /// Input types (for creating arrays during merge)
92    input_types: Vec<DataType>,
93    /// Function name (for type_tag/debug)
94    function_name: String,
95    /// Factory for creating fresh accumulators (enables clone_box)
96    factory: Arc<DataFusionAggregateFactory>,
97}
98
99// SAFETY: DataFusion accumulators are Send. RefCell is Send when T is Send.
100// The adapter is only accessed from a single thread (Ring 1 processing).
101unsafe impl Send for DataFusionAccumulatorAdapter {}
102
103impl std::fmt::Debug for DataFusionAccumulatorAdapter {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        f.debug_struct("DataFusionAccumulatorAdapter")
106            .field("function_name", &self.function_name)
107            .field("column_indices", &self.column_indices)
108            .field("input_types", &self.input_types)
109            .finish_non_exhaustive()
110    }
111}
112
113impl DataFusionAccumulatorAdapter {
114    /// Creates a new adapter wrapping a DataFusion accumulator.
115    #[must_use]
116    pub fn new(
117        inner: Box<dyn datafusion_expr::Accumulator>,
118        column_indices: Vec<usize>,
119        input_types: Vec<DataType>,
120        function_name: String,
121        factory: Arc<DataFusionAggregateFactory>,
122    ) -> Self {
123        Self {
124            inner: RefCell::new(inner),
125            column_indices,
126            input_types,
127            function_name,
128            factory,
129        }
130    }
131
132    /// Returns the wrapped function name.
133    #[must_use]
134    pub fn function_name(&self) -> &str {
135        &self.function_name
136    }
137
138    /// Extracts the relevant columns from a `RecordBatch`.
139    fn extract_columns(&self, batch: &arrow_array::RecordBatch) -> Vec<ArrayRef> {
140        self.column_indices
141            .iter()
142            .enumerate()
143            .map(|(arg_idx, &col_idx)| {
144                if col_idx < batch.num_columns() {
145                    Arc::clone(batch.column(col_idx))
146                } else {
147                    let dt = self
148                        .input_types
149                        .get(arg_idx)
150                        .cloned()
151                        .unwrap_or(DataType::Int64);
152                    arrow_array::new_null_array(&dt, batch.num_rows())
153                }
154            })
155            .collect()
156    }
157}
158
159impl DynAccumulator for DataFusionAccumulatorAdapter {
160    fn add_event(&mut self, event: &Event) {
161        let columns = self.extract_columns(&event.data);
162        if let Err(e) = self.inner.borrow_mut().update_batch(&columns) {
163            tracing::warn!(
164                func = %self.function_name,
165                error = %e,
166                "Accumulator update_batch failed"
167            );
168        }
169    }
170
171    fn merge_dyn(&mut self, other: &dyn DynAccumulator) {
172        let Some(other) = other
173            .as_any()
174            .downcast_ref::<DataFusionAccumulatorAdapter>()
175        else {
176            tracing::error!("merge_dyn: type mismatch, expected DataFusionAccumulatorAdapter");
177            return;
178        };
179
180        match other.inner.borrow_mut().state() {
181            Ok(state_values) => {
182                let mut failed_conversions = 0u32;
183                let state_arrays: Vec<ArrayRef> = state_values
184                    .iter()
185                    .filter_map(|sv| {
186                        if let Ok(arr) = sv.to_array() {
187                            Some(arr)
188                        } else {
189                            failed_conversions += 1;
190                            None
191                        }
192                    })
193                    .collect();
194                if failed_conversions > 0 {
195                    tracing::warn!(
196                        func = %self.function_name,
197                        count = failed_conversions,
198                        "ScalarValue to_array conversions failed during merge"
199                    );
200                }
201                if !state_arrays.is_empty() {
202                    if let Err(e) = self.inner.borrow_mut().merge_batch(&state_arrays) {
203                        tracing::warn!(
204                            func = %self.function_name,
205                            error = %e,
206                            "Accumulator merge_batch failed"
207                        );
208                    }
209                }
210            }
211            Err(e) => {
212                tracing::warn!(
213                    func = %self.function_name,
214                    error = %e,
215                    "Failed to extract state for merge"
216                );
217            }
218        }
219    }
220
221    fn result_scalar(&self) -> ScalarResult {
222        match self.inner.borrow_mut().evaluate() {
223            Ok(sv) => scalar_value_to_result(&sv),
224            Err(_) => ScalarResult::Null,
225        }
226    }
227
228    fn is_empty(&self) -> bool {
229        self.inner.borrow().size() <= std::mem::size_of::<Self>()
230    }
231
232    fn clone_box(&self) -> Box<dyn DynAccumulator> {
233        let create = || {
234            self.factory.create_df_accumulator().unwrap_or_else(|e| {
235                panic!(
236                    "failed to create DataFusion accumulator '{}': {e}",
237                    self.function_name
238                );
239            })
240        };
241        let new_inner = create();
242        // Merge current state into the fresh accumulator
243        if let Ok(state_values) = self.inner.borrow_mut().state() {
244            let state_arrays: Vec<ArrayRef> = state_values
245                .iter()
246                .filter_map(|sv| sv.to_array().ok())
247                .collect();
248            if !state_arrays.is_empty() {
249                let mut new_acc = new_inner;
250                if new_acc.merge_batch(&state_arrays).is_ok() {
251                    return Box::new(DataFusionAccumulatorAdapter {
252                        inner: RefCell::new(new_acc),
253                        column_indices: self.column_indices.clone(),
254                        input_types: self.input_types.clone(),
255                        function_name: self.function_name.clone(),
256                        factory: Arc::clone(&self.factory),
257                    });
258                }
259            }
260        }
261        // Fallback: return a fresh empty accumulator
262        Box::new(DataFusionAccumulatorAdapter {
263            inner: RefCell::new(create()),
264            column_indices: self.column_indices.clone(),
265            input_types: self.input_types.clone(),
266            function_name: self.function_name.clone(),
267            factory: Arc::clone(&self.factory),
268        })
269    }
270
271    #[allow(clippy::cast_possible_truncation)] // Wire format uses fixed-width integers
272    fn serialize(&self) -> Vec<u8> {
273        match self.inner.borrow_mut().state() {
274            Ok(state_values) => {
275                let mut buf = Vec::new();
276                let count = state_values.len() as u32;
277                buf.extend_from_slice(&count.to_le_bytes());
278                for sv in &state_values {
279                    let bytes = sv.to_string();
280                    let len = bytes.len() as u32;
281                    buf.extend_from_slice(&len.to_le_bytes());
282                    buf.extend_from_slice(bytes.as_bytes());
283                }
284                buf
285            }
286            Err(_) => Vec::new(),
287        }
288    }
289
290    fn result_field(&self) -> Field {
291        let result = self.result_scalar();
292        let dt = result.data_type();
293        let dt = if dt == DataType::Null {
294            DataType::Float64
295        } else {
296            dt
297        };
298        Field::new(&self.function_name, dt, true)
299    }
300
301    fn type_tag(&self) -> &'static str {
302        "datafusion_adapter"
303    }
304
305    fn as_any(&self) -> &dyn std::any::Any {
306        self
307    }
308}
309
310// DataFusion Aggregate Factory
311
312/// Factory for creating [`DataFusionAccumulatorAdapter`] instances.
313///
314/// Wraps a DataFusion [`AggregateUDF`] and provides the [`DynAggregatorFactory`]
315/// interface for use with `CompositeAggregator`.
316pub struct DataFusionAggregateFactory {
317    /// The DataFusion aggregate UDF
318    udf: Arc<AggregateUDF>,
319    /// Column indices to extract from events
320    column_indices: Vec<usize>,
321    /// Input types for the aggregate
322    input_types: Vec<DataType>,
323    /// Whether this is a DISTINCT aggregate
324    is_distinct: bool,
325}
326
327impl std::fmt::Debug for DataFusionAggregateFactory {
328    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
329        f.debug_struct("DataFusionAggregateFactory")
330            .field("name", &self.udf.name())
331            .field("column_indices", &self.column_indices)
332            .field("input_types", &self.input_types)
333            .field("is_distinct", &self.is_distinct)
334            .finish()
335    }
336}
337
338impl DataFusionAggregateFactory {
339    /// Creates a new factory for the given DataFusion aggregate UDF.
340    #[must_use]
341    pub fn new(
342        udf: Arc<AggregateUDF>,
343        column_indices: Vec<usize>,
344        input_types: Vec<DataType>,
345    ) -> Self {
346        Self {
347            udf,
348            column_indices,
349            input_types,
350            is_distinct: false,
351        }
352    }
353
354    /// Creates a new factory with the DISTINCT flag set.
355    #[must_use]
356    pub fn with_distinct(mut self, distinct: bool) -> Self {
357        self.is_distinct = distinct;
358        self
359    }
360
361    /// Returns the name of the wrapped aggregate function.
362    #[must_use]
363    pub fn name(&self) -> &str {
364        self.udf.name()
365    }
366
367    /// Pre-defined column names to avoid `format!()` per accumulator creation.
368    const COL_NAMES: [&str; 8] = [
369        "col_0", "col_1", "col_2", "col_3", "col_4", "col_5", "col_6", "col_7",
370    ];
371
372    /// Returns a cached column name for the given index.
373    fn col_name(i: usize) -> &'static str {
374        Self::COL_NAMES.get(i).copied().unwrap_or("col_n")
375    }
376
377    /// Creates a DataFusion accumulator from the UDF.
378    fn create_df_accumulator(
379        &self,
380    ) -> std::result::Result<Box<dyn datafusion_expr::Accumulator>, DataFusionError> {
381        let return_type = self
382            .udf
383            .return_type(&self.input_types)
384            .unwrap_or(DataType::Float64);
385        let return_field: FieldRef = Arc::new(Field::new(self.udf.name(), return_type, true));
386        let schema = Schema::new(
387            self.input_types
388                .iter()
389                .enumerate()
390                .map(|(i, dt)| Field::new(Self::col_name(i), dt.clone(), true))
391                .collect::<Vec<_>>(),
392        );
393        let expr_fields: Vec<FieldRef> = self
394            .input_types
395            .iter()
396            .enumerate()
397            .map(|(i, dt)| Arc::new(Field::new(Self::col_name(i), dt.clone(), true)) as FieldRef)
398            .collect();
399        let args = AccumulatorArgs {
400            return_field,
401            schema: &schema,
402            ignore_nulls: false,
403            order_bys: &[],
404            is_reversed: false,
405            name: self.udf.name(),
406            is_distinct: self.is_distinct,
407            exprs: &[],
408            expr_fields: &expr_fields,
409        };
410        self.udf.accumulator(args)
411    }
412}
413
414impl DataFusionAggregateFactory {
415    /// Creates an accumulator adapter with a back-reference to this factory.
416    ///
417    /// The factory must be wrapped in an `Arc` for the adapter to support
418    /// `clone_box()`.
419    /// # Errors
420    ///
421    /// Returns `DataFusionError` if the underlying UDF fails to create an accumulator.
422    pub fn create_accumulator_with_factory(
423        self: &Arc<Self>,
424    ) -> std::result::Result<Box<dyn DynAccumulator>, DataFusionError> {
425        let inner = self.create_df_accumulator()?;
426        Ok(Box::new(DataFusionAccumulatorAdapter::new(
427            inner,
428            self.column_indices.clone(),
429            self.input_types.clone(),
430            self.udf.name().to_string(),
431            Arc::clone(self),
432        )))
433    }
434}
435
436impl DynAggregatorFactory for DataFusionAggregateFactory {
437    fn create_accumulator(&self) -> Box<dyn DynAccumulator> {
438        // Without an Arc<Self>, we create a temporary Arc for the adapter.
439        // This is fine — the adapter only uses it for clone_box().
440        let factory_arc = Arc::new(DataFusionAggregateFactory {
441            udf: Arc::clone(&self.udf),
442            column_indices: self.column_indices.clone(),
443            input_types: self.input_types.clone(),
444            is_distinct: self.is_distinct,
445        });
446        let inner = self.create_df_accumulator().unwrap_or_else(|e| {
447            panic!(
448                "failed to create DataFusion accumulator '{}': {e}",
449                self.udf.name()
450            );
451        });
452        Box::new(DataFusionAccumulatorAdapter::new(
453            inner,
454            self.column_indices.clone(),
455            self.input_types.clone(),
456            self.udf.name().to_string(),
457            factory_arc,
458        ))
459    }
460
461    fn result_field(&self) -> Field {
462        let return_type = self
463            .udf
464            .return_type(&self.input_types)
465            .unwrap_or(DataType::Float64);
466        Field::new(self.udf.name(), return_type, true)
467    }
468
469    fn clone_box(&self) -> Box<dyn DynAggregatorFactory> {
470        Box::new(DataFusionAggregateFactory {
471            udf: Arc::clone(&self.udf),
472            column_indices: self.column_indices.clone(),
473            input_types: self.input_types.clone(),
474            is_distinct: self.is_distinct,
475        })
476    }
477
478    fn type_tag(&self) -> &'static str {
479        "datafusion_factory"
480    }
481}
482
483// Built-in Aggregate Lookup
484
485/// Looks up a DataFusion built-in aggregate function by name.
486///
487/// Returns `None` if the function is not a recognized DataFusion aggregate.
488#[must_use]
489pub fn lookup_aggregate_udf(
490    ctx: &datafusion::prelude::SessionContext,
491    name: &str,
492) -> Option<Arc<AggregateUDF>> {
493    let normalized = name.to_lowercase();
494    ctx.udaf(&normalized).ok()
495}
496
497/// Creates a [`DataFusionAggregateFactory`] for a named built-in aggregate.
498///
499/// Returns `None` if the function name is not recognized.
500#[must_use]
501pub fn create_aggregate_factory(
502    ctx: &datafusion::prelude::SessionContext,
503    name: &str,
504    column_indices: Vec<usize>,
505    input_types: Vec<DataType>,
506) -> Option<DataFusionAggregateFactory> {
507    lookup_aggregate_udf(ctx, name)
508        .map(|udf| DataFusionAggregateFactory::new(udf, column_indices, input_types))
509}
510
511// Tests
512
513#[cfg(test)]
514mod tests {
515    use super::*;
516    use crate::datafusion::create_session_context;
517    use arrow_array::{Float64Array, Int64Array, RecordBatch};
518
519    fn float_event(ts: i64, values: Vec<f64>) -> Event {
520        let schema = Arc::new(Schema::new(vec![Field::new(
521            "value",
522            DataType::Float64,
523            false,
524        )]));
525        let batch =
526            RecordBatch::try_new(schema, vec![Arc::new(Float64Array::from(values))]).unwrap();
527        Event::new(ts, batch)
528    }
529
530    fn int_event(ts: i64, values: Vec<i64>) -> Event {
531        let schema = Arc::new(Schema::new(vec![Field::new(
532            "value",
533            DataType::Int64,
534            false,
535        )]));
536        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(values))]).unwrap();
537        Event::new(ts, batch)
538    }
539
540    fn two_col_float_event(ts: i64, col0: Vec<f64>, col1: Vec<f64>) -> Event {
541        let schema = Arc::new(Schema::new(vec![
542            Field::new("x", DataType::Float64, false),
543            Field::new("y", DataType::Float64, false),
544        ]));
545        let batch = RecordBatch::try_new(
546            schema,
547            vec![
548                Arc::new(Float64Array::from(col0)),
549                Arc::new(Float64Array::from(col1)),
550            ],
551        )
552        .unwrap();
553        Event::new(ts, batch)
554    }
555
556    // ── ScalarValue Conversion Tests ────────────────────────────────────
557
558    #[test]
559    fn test_scalar_value_to_result_int64() {
560        let sv = ScalarValue::Int64(Some(42));
561        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Int64(42));
562    }
563
564    #[test]
565    fn test_scalar_value_to_result_float64() {
566        let sv = ScalarValue::Float64(Some(3.125));
567        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Float64(3.125));
568    }
569
570    #[test]
571    fn test_scalar_value_to_result_uint64() {
572        let sv = ScalarValue::UInt64(Some(100));
573        assert_eq!(scalar_value_to_result(&sv), ScalarResult::UInt64(100));
574    }
575
576    #[test]
577    fn test_scalar_value_to_result_null_int64() {
578        let sv = ScalarValue::Int64(None);
579        assert_eq!(
580            scalar_value_to_result(&sv),
581            ScalarResult::OptionalInt64(None)
582        );
583    }
584
585    #[test]
586    fn test_scalar_value_to_result_null_float64() {
587        let sv = ScalarValue::Float64(None);
588        assert_eq!(
589            scalar_value_to_result(&sv),
590            ScalarResult::OptionalFloat64(None)
591        );
592    }
593
594    #[test]
595    fn test_scalar_value_to_result_smaller_ints() {
596        assert_eq!(
597            scalar_value_to_result(&ScalarValue::Int8(Some(8))),
598            ScalarResult::Int64(8)
599        );
600        assert_eq!(
601            scalar_value_to_result(&ScalarValue::Int16(Some(16))),
602            ScalarResult::Int64(16)
603        );
604        assert_eq!(
605            scalar_value_to_result(&ScalarValue::Int32(Some(32))),
606            ScalarResult::Int64(32)
607        );
608        assert_eq!(
609            scalar_value_to_result(&ScalarValue::UInt8(Some(8))),
610            ScalarResult::UInt64(8)
611        );
612    }
613
614    #[test]
615    fn test_scalar_value_to_result_float32() {
616        let sv = ScalarValue::Float32(Some(2.5));
617        assert_eq!(
618            scalar_value_to_result(&sv),
619            ScalarResult::Float64(f64::from(2.5f32))
620        );
621    }
622
623    #[test]
624    fn test_scalar_value_to_result_unsupported() {
625        let sv = ScalarValue::Utf8(Some("hello".to_string()));
626        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Null);
627    }
628
629    #[test]
630    fn test_result_to_scalar_value_roundtrip() {
631        // Exact roundtrip for non-optional variants
632        let exact_cases = vec![
633            ScalarResult::Int64(42),
634            ScalarResult::Float64(3.125),
635            ScalarResult::UInt64(100),
636        ];
637        for sr in &exact_cases {
638            let sv = result_to_scalar_value(sr);
639            let back = scalar_value_to_result(&sv);
640            assert_eq!(&back, sr, "Roundtrip failed for {sr:?}");
641        }
642
643        // Optional(Some(v)) normalizes to non-optional through ScalarValue
644        // because ScalarValue::Int64(Some(7)) maps back to ScalarResult::Int64(7)
645        let sv = result_to_scalar_value(&ScalarResult::OptionalInt64(Some(7)));
646        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Int64(7));
647
648        let sv = result_to_scalar_value(&ScalarResult::OptionalFloat64(Some(2.72)));
649        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Float64(2.72));
650
651        // Optional None roundtrips back to OptionalNone (ScalarValue preserves type)
652        let sv = result_to_scalar_value(&ScalarResult::OptionalInt64(None));
653        assert_eq!(
654            scalar_value_to_result(&sv),
655            ScalarResult::OptionalInt64(None)
656        );
657
658        let sv = result_to_scalar_value(&ScalarResult::OptionalFloat64(None));
659        assert_eq!(
660            scalar_value_to_result(&sv),
661            ScalarResult::OptionalFloat64(None)
662        );
663
664        // Null roundtrips correctly
665        let sv = result_to_scalar_value(&ScalarResult::Null);
666        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Null);
667    }
668
669    // ── Factory Tests ───────────────────────────────────────────────────
670
671    #[test]
672    fn test_factory_count() {
673        let ctx = create_session_context();
674        let factory = create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]);
675        assert!(factory.is_some(), "count should be a recognized aggregate");
676        assert_eq!(factory.unwrap().name(), "count");
677    }
678
679    #[test]
680    fn test_factory_sum() {
681        let ctx = create_session_context();
682        let factory = create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]);
683        assert!(factory.is_some());
684        assert_eq!(factory.unwrap().name(), "sum");
685    }
686
687    #[test]
688    fn test_factory_avg() {
689        let ctx = create_session_context();
690        let factory = create_aggregate_factory(&ctx, "avg", vec![0], vec![DataType::Float64]);
691        assert!(factory.is_some());
692    }
693
694    #[test]
695    fn test_factory_stddev() {
696        let ctx = create_session_context();
697        let factory = create_aggregate_factory(&ctx, "stddev", vec![0], vec![DataType::Float64]);
698        assert!(
699            factory.is_some(),
700            "stddev should be available in DataFusion"
701        );
702    }
703
704    #[test]
705    fn test_factory_unknown() {
706        let ctx = create_session_context();
707        let factory = create_aggregate_factory(
708            &ctx,
709            "nonexistent_aggregate_xyz",
710            vec![0],
711            vec![DataType::Int64],
712        );
713        assert!(factory.is_none());
714    }
715
716    #[test]
717    fn test_factory_result_field() {
718        let ctx = create_session_context();
719        let factory =
720            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
721        let field = factory.result_field();
722        assert_eq!(field.name(), "sum");
723        assert_eq!(field.data_type(), &DataType::Float64);
724    }
725
726    #[test]
727    fn test_factory_clone_box() {
728        let ctx = create_session_context();
729        let factory =
730            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
731        let cloned = factory.clone_box();
732        assert_eq!(cloned.type_tag(), "datafusion_factory");
733    }
734
735    // ── Adapter Basics ──────────────────────────────────────────────────
736
737    #[test]
738    fn test_adapter_count_basic() {
739        let ctx = create_session_context();
740        let factory =
741            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
742        let mut acc = factory.create_accumulator();
743
744        let result = acc.result_scalar();
745        assert!(
746            matches!(result, ScalarResult::Int64(0) | ScalarResult::UInt64(0)),
747            "Expected 0, got {result:?}"
748        );
749
750        acc.add_event(&int_event(1000, vec![10, 20, 30]));
751        let result = acc.result_scalar();
752        assert!(
753            matches!(result, ScalarResult::Int64(3) | ScalarResult::UInt64(3)),
754            "Expected 3, got {result:?}"
755        );
756
757        acc.add_event(&int_event(2000, vec![40, 50]));
758        let result = acc.result_scalar();
759        assert!(
760            matches!(result, ScalarResult::Int64(5) | ScalarResult::UInt64(5)),
761            "Expected 5, got {result:?}"
762        );
763    }
764
765    #[test]
766    fn test_adapter_sum_float64() {
767        let ctx = create_session_context();
768        let factory =
769            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
770        let mut acc = factory.create_accumulator();
771
772        acc.add_event(&float_event(1000, vec![1.5, 2.5, 3.0]));
773        assert_eq!(acc.result_scalar(), ScalarResult::Float64(7.0));
774    }
775
776    #[test]
777    fn test_adapter_avg_float64() {
778        let ctx = create_session_context();
779        let factory =
780            create_aggregate_factory(&ctx, "avg", vec![0], vec![DataType::Float64]).unwrap();
781        let mut acc = factory.create_accumulator();
782
783        acc.add_event(&float_event(1000, vec![10.0, 20.0, 30.0]));
784        assert_eq!(acc.result_scalar(), ScalarResult::Float64(20.0));
785    }
786
787    #[test]
788    fn test_adapter_min_float64() {
789        let ctx = create_session_context();
790        let factory =
791            create_aggregate_factory(&ctx, "min", vec![0], vec![DataType::Float64]).unwrap();
792        let mut acc = factory.create_accumulator();
793
794        acc.add_event(&float_event(1000, vec![30.0, 10.0, 20.0]));
795        assert_eq!(acc.result_scalar(), ScalarResult::Float64(10.0));
796    }
797
798    #[test]
799    fn test_adapter_max_float64() {
800        let ctx = create_session_context();
801        let factory =
802            create_aggregate_factory(&ctx, "max", vec![0], vec![DataType::Float64]).unwrap();
803        let mut acc = factory.create_accumulator();
804
805        acc.add_event(&float_event(1000, vec![30.0, 10.0, 20.0]));
806        assert_eq!(acc.result_scalar(), ScalarResult::Float64(30.0));
807    }
808
809    #[test]
810    fn test_adapter_sum_int64() {
811        let ctx = create_session_context();
812        let factory =
813            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Int64]).unwrap();
814        let mut acc = factory.create_accumulator();
815
816        acc.add_event(&int_event(1000, vec![10, 20, 30]));
817        assert_eq!(acc.result_scalar(), ScalarResult::Int64(60));
818    }
819
820    #[test]
821    fn test_adapter_type_tag() {
822        let ctx = create_session_context();
823        let factory =
824            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
825        let acc = factory.create_accumulator();
826        assert_eq!(acc.type_tag(), "datafusion_adapter");
827    }
828
829    #[test]
830    fn test_adapter_result_field() {
831        let ctx = create_session_context();
832        let factory =
833            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
834        let mut acc = factory.create_accumulator();
835        acc.add_event(&float_event(1000, vec![1.0]));
836        assert_eq!(acc.result_field().name(), "sum");
837    }
838
839    // ── Merge Tests ─────────────────────────────────────────────────────
840
841    #[test]
842    fn test_adapter_merge_sum() {
843        let ctx = create_session_context();
844        let factory =
845            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
846
847        let mut acc1 = factory.create_accumulator();
848        acc1.add_event(&float_event(1000, vec![1.0, 2.0]));
849
850        let mut acc2 = factory.create_accumulator();
851        acc2.add_event(&float_event(2000, vec![3.0, 4.0]));
852
853        acc1.merge_dyn(acc2.as_ref());
854        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(10.0));
855    }
856
857    #[test]
858    fn test_adapter_merge_count() {
859        let ctx = create_session_context();
860        let factory =
861            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
862
863        let mut acc1 = factory.create_accumulator();
864        acc1.add_event(&int_event(1000, vec![1, 2, 3]));
865
866        let mut acc2 = factory.create_accumulator();
867        acc2.add_event(&int_event(2000, vec![4, 5]));
868
869        acc1.merge_dyn(acc2.as_ref());
870        let result = acc1.result_scalar();
871        assert!(
872            matches!(result, ScalarResult::Int64(5) | ScalarResult::UInt64(5)),
873            "Expected 5 after merge, got {result:?}"
874        );
875    }
876
877    #[test]
878    fn test_adapter_merge_avg() {
879        let ctx = create_session_context();
880        let factory =
881            create_aggregate_factory(&ctx, "avg", vec![0], vec![DataType::Float64]).unwrap();
882
883        let mut acc1 = factory.create_accumulator();
884        acc1.add_event(&float_event(1000, vec![10.0, 20.0]));
885
886        let mut acc2 = factory.create_accumulator();
887        acc2.add_event(&float_event(2000, vec![30.0]));
888
889        acc1.merge_dyn(acc2.as_ref());
890        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(20.0));
891    }
892
893    #[test]
894    fn test_adapter_merge_empty() {
895        let ctx = create_session_context();
896        let factory =
897            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
898
899        let mut acc1 = factory.create_accumulator();
900        acc1.add_event(&float_event(1000, vec![5.0]));
901
902        let acc2 = factory.create_accumulator();
903        acc1.merge_dyn(acc2.as_ref());
904        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(5.0));
905    }
906
907    // ── Built-in Aggregate Pass-Through Tests ───────────────────────────
908
909    #[test]
910    fn test_adapter_stddev() {
911        let ctx = create_session_context();
912        let factory =
913            create_aggregate_factory(&ctx, "stddev", vec![0], vec![DataType::Float64]).unwrap();
914        let mut acc = factory.create_accumulator();
915
916        acc.add_event(&float_event(
917            1000,
918            vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0],
919        ));
920        let result = acc.result_scalar();
921        if let ScalarResult::Float64(v) = result {
922            assert!((v - 2.138).abs() < 0.01, "Expected ~2.138, got {v}");
923        } else {
924            panic!("Expected Float64 result, got {result:?}");
925        }
926    }
927
928    #[test]
929    fn test_adapter_variance() {
930        let ctx = create_session_context();
931        if let Some(factory) =
932            create_aggregate_factory(&ctx, "var_samp", vec![0], vec![DataType::Float64])
933        {
934            let mut acc = factory.create_accumulator();
935            acc.add_event(&float_event(
936                1000,
937                vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0],
938            ));
939            if let ScalarResult::Float64(v) = acc.result_scalar() {
940                assert!((v - 4.571).abs() < 0.01, "Expected ~4.571, got {v}");
941            }
942        }
943    }
944
945    #[test]
946    fn test_adapter_median() {
947        let ctx = create_session_context();
948        if let Some(factory) =
949            create_aggregate_factory(&ctx, "median", vec![0], vec![DataType::Float64])
950        {
951            let mut acc = factory.create_accumulator();
952            acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0, 4.0, 5.0]));
953            assert_eq!(acc.result_scalar(), ScalarResult::Float64(3.0));
954        }
955    }
956
957    // ── Serialize Tests ─────────────────────────────────────────────────
958
959    #[test]
960    fn test_adapter_serialize() {
961        let ctx = create_session_context();
962        let factory =
963            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
964        let mut acc = factory.create_accumulator();
965        acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0]));
966        assert!(!acc.serialize().is_empty());
967    }
968
969    #[test]
970    fn test_adapter_serialize_empty() {
971        let ctx = create_session_context();
972        let factory =
973            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
974        let acc = factory.create_accumulator();
975        assert!(!acc.serialize().is_empty());
976    }
977
978    // ── Lookup Tests ────────────────────────────────────────────────────
979
980    #[test]
981    fn test_lookup_common_aggregates() {
982        let ctx = create_session_context();
983        for name in &["count", "sum", "min", "max", "avg"] {
984            assert!(
985                lookup_aggregate_udf(&ctx, name).is_some(),
986                "Expected '{name}' to be a recognized aggregate"
987            );
988        }
989    }
990
991    #[test]
992    fn test_lookup_statistical_aggregates() {
993        let ctx = create_session_context();
994        for name in &["stddev", "stddev_pop", "median"] {
995            // Just verify lookup doesn't panic
996            let _ = lookup_aggregate_udf(&ctx, name);
997        }
998    }
999
1000    #[test]
1001    fn test_lookup_case_insensitive() {
1002        let ctx = create_session_context();
1003        assert!(lookup_aggregate_udf(&ctx, "COUNT").is_some());
1004        assert!(lookup_aggregate_udf(&ctx, "Sum").is_some());
1005        assert!(lookup_aggregate_udf(&ctx, "AVG").is_some());
1006    }
1007
1008    // ── Multi-column Tests ──────────────────────────────────────────────
1009
1010    #[test]
1011    fn test_adapter_multi_column_covar() {
1012        let ctx = create_session_context();
1013        if let Some(factory) = create_aggregate_factory(
1014            &ctx,
1015            "covar_samp",
1016            vec![0, 1],
1017            vec![DataType::Float64, DataType::Float64],
1018        ) {
1019            let mut acc = factory.create_accumulator();
1020            acc.add_event(&two_col_float_event(
1021                1000,
1022                vec![1.0, 2.0, 3.0, 4.0, 5.0],
1023                vec![1.0, 2.0, 3.0, 4.0, 5.0],
1024            ));
1025            if let ScalarResult::Float64(v) = acc.result_scalar() {
1026                assert!((v - 2.5).abs() < 0.01, "Expected covar ~2.5, got {v}");
1027            }
1028        }
1029    }
1030
1031    // ── Registration Tests ──────────────────────────────────────────────
1032
1033    #[test]
1034    fn test_create_aggregate_factory_api() {
1035        let ctx = create_session_context();
1036        let factory =
1037            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
1038        let acc = factory.create_accumulator();
1039        assert_eq!(acc.type_tag(), "datafusion_adapter");
1040    }
1041
1042    #[test]
1043    fn test_factory_creates_independent_accumulators() {
1044        let ctx = create_session_context();
1045        let factory =
1046            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
1047
1048        let mut acc1 = factory.create_accumulator();
1049        let mut acc2 = factory.create_accumulator();
1050
1051        acc1.add_event(&float_event(1000, vec![10.0]));
1052        acc2.add_event(&float_event(2000, vec![20.0]));
1053
1054        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(10.0));
1055        assert_eq!(acc2.result_scalar(), ScalarResult::Float64(20.0));
1056    }
1057
1058    #[test]
1059    fn test_adapter_function_name() {
1060        let ctx = create_session_context();
1061        let factory =
1062            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
1063        let acc = factory.create_accumulator();
1064        let adapter = acc
1065            .as_any()
1066            .downcast_ref::<DataFusionAccumulatorAdapter>()
1067            .expect("should be adapter");
1068        assert_eq!(adapter.function_name(), "sum");
1069    }
1070
1071    #[test]
1072    fn test_clone_box_does_not_panic() {
1073        let ctx = create_session_context();
1074        let factory =
1075            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
1076        let mut acc = factory.create_accumulator();
1077        acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0]));
1078
1079        // clone_box should not panic and should preserve state
1080        let cloned = acc.clone_box();
1081        assert_eq!(cloned.result_scalar(), ScalarResult::Float64(6.0));
1082    }
1083
1084    #[test]
1085    fn test_clone_box_empty_accumulator() {
1086        let ctx = create_session_context();
1087        let factory =
1088            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
1089        let acc = factory.create_accumulator();
1090
1091        // clone_box on empty accumulator should work
1092        let cloned = acc.clone_box();
1093        let result = cloned.result_scalar();
1094        assert!(
1095            matches!(result, ScalarResult::Int64(0) | ScalarResult::UInt64(0)),
1096            "Expected 0, got {result:?}"
1097        );
1098    }
1099
1100    #[test]
1101    fn test_distinct_factory() {
1102        let ctx = create_session_context();
1103        let udf = lookup_aggregate_udf(&ctx, "count").unwrap();
1104        let factory = DataFusionAggregateFactory::new(udf, vec![0], vec![DataType::Int64])
1105            .with_distinct(true);
1106        assert!(factory.is_distinct);
1107        // Should create accumulator successfully
1108        let _acc = factory.create_accumulator();
1109    }
1110}