Skip to main content

laminar_sql/datafusion/
aggregate_bridge.rs

1//! DataFusion Aggregate Bridge
2//!
3//! Bridges DataFusion's 50+ built-in aggregate functions into LaminarDB's
4//! `DynAccumulator` / `DynAggregatorFactory` traits. This avoids
5//! reimplementing statistical functions (STDDEV, VARIANCE, PERCENTILE, etc.)
6//! that DataFusion already provides.
7//!
8//! # Architecture
9//!
10//! ```text
11//! DataFusion World                 LaminarDB World
12//! ┌───────────────────┐           ┌──────────────────────┐
13//! │ AggregateUDF      │           │ DynAggregatorFactory │
14//! │   └─▶ Accumulator │──bridge──▶│   └─▶ DynAccumulator │
15//! │       (ScalarValue)│           │       (ScalarResult) │
16//! └───────────────────┘           └──────────────────────┘
17//! ```
18//!
19//! # Ring Architecture
20//!
21//! This bridge is Ring 1 (allocates, uses dynamic dispatch). Ring 0 workloads
22//! continue to use the hand-written static-dispatch aggregators.
23
24use std::cell::RefCell;
25use std::sync::Arc;
26
27use arrow_array::ArrayRef;
28use arrow_schema::{DataType, Field, FieldRef, Schema};
29use datafusion::execution::FunctionRegistry;
30use datafusion_common::ScalarValue;
31use datafusion_expr::function::AccumulatorArgs;
32use datafusion_expr::AggregateUDF;
33
34use laminar_core::operator::window::{DynAccumulator, DynAggregatorFactory, ScalarResult};
35use laminar_core::operator::Event;
36
37// Type Conversion: ScalarValue <-> ScalarResult
38
39/// Converts a DataFusion [`ScalarValue`] to a LaminarDB [`ScalarResult`].
40///
41/// Handles the common numeric types. Non-numeric types map to [`ScalarResult::Null`].
42#[must_use]
43pub fn scalar_value_to_result(sv: &ScalarValue) -> ScalarResult {
44    match sv {
45        ScalarValue::Int64(Some(v)) => ScalarResult::Int64(*v),
46        ScalarValue::Int64(None) => ScalarResult::OptionalInt64(None),
47        ScalarValue::Float64(Some(v)) => ScalarResult::Float64(*v),
48        ScalarValue::Float64(None) | ScalarValue::Float32(None) => {
49            ScalarResult::OptionalFloat64(None)
50        }
51        ScalarValue::UInt64(Some(v)) => ScalarResult::UInt64(*v),
52        // Widen smaller int types
53        ScalarValue::Int8(Some(v)) => ScalarResult::Int64(i64::from(*v)),
54        ScalarValue::Int16(Some(v)) => ScalarResult::Int64(i64::from(*v)),
55        ScalarValue::Int32(Some(v)) => ScalarResult::Int64(i64::from(*v)),
56        ScalarValue::UInt8(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
57        ScalarValue::UInt16(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
58        ScalarValue::UInt32(Some(v)) => ScalarResult::UInt64(u64::from(*v)),
59        // Widen smaller float types
60        ScalarValue::Float32(Some(v)) => ScalarResult::Float64(f64::from(*v)),
61        _ => ScalarResult::Null,
62    }
63}
64
65/// Converts a [`ScalarResult`] to a DataFusion [`ScalarValue`].
66#[must_use]
67pub fn result_to_scalar_value(sr: &ScalarResult) -> ScalarValue {
68    match sr {
69        ScalarResult::Int64(v) => ScalarValue::Int64(Some(*v)),
70        ScalarResult::Float64(v) => ScalarValue::Float64(Some(*v)),
71        ScalarResult::UInt64(v) => ScalarValue::UInt64(Some(*v)),
72        ScalarResult::OptionalInt64(v) => ScalarValue::Int64(*v),
73        ScalarResult::OptionalFloat64(v) => ScalarValue::Float64(*v),
74        ScalarResult::Null => ScalarValue::Null,
75    }
76}
77
78// DataFusion Accumulator Adapter
79
80/// Adapts a DataFusion [`datafusion_expr::Accumulator`] into LaminarDB's
81/// [`DynAccumulator`] trait.
82///
83/// Uses `RefCell` for interior mutability since DataFusion's `evaluate()`
84/// and `state()` require `&mut self` but LaminarDB's `result_scalar()`
85/// and `serialize()` take `&self`.
86pub struct DataFusionAccumulatorAdapter {
87    /// The wrapped DataFusion accumulator (RefCell for interior mutability)
88    inner: RefCell<Box<dyn datafusion_expr::Accumulator>>,
89    /// Column indices to extract from events
90    column_indices: Vec<usize>,
91    /// Input types (for creating arrays during merge)
92    input_types: Vec<DataType>,
93    /// Function name (for type_tag/debug)
94    function_name: String,
95    /// Factory for creating fresh accumulators (enables clone_box)
96    factory: Arc<DataFusionAggregateFactory>,
97}
98
99// SAFETY: DataFusion accumulators are Send. RefCell is Send when T is Send.
100// The adapter is only accessed from a single thread (Ring 1 processing).
101unsafe impl Send for DataFusionAccumulatorAdapter {}
102
103impl std::fmt::Debug for DataFusionAccumulatorAdapter {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        f.debug_struct("DataFusionAccumulatorAdapter")
106            .field("function_name", &self.function_name)
107            .field("column_indices", &self.column_indices)
108            .field("input_types", &self.input_types)
109            .finish_non_exhaustive()
110    }
111}
112
113impl DataFusionAccumulatorAdapter {
114    /// Creates a new adapter wrapping a DataFusion accumulator.
115    #[must_use]
116    pub fn new(
117        inner: Box<dyn datafusion_expr::Accumulator>,
118        column_indices: Vec<usize>,
119        input_types: Vec<DataType>,
120        function_name: String,
121        factory: Arc<DataFusionAggregateFactory>,
122    ) -> Self {
123        Self {
124            inner: RefCell::new(inner),
125            column_indices,
126            input_types,
127            function_name,
128            factory,
129        }
130    }
131
132    /// Returns the wrapped function name.
133    #[must_use]
134    pub fn function_name(&self) -> &str {
135        &self.function_name
136    }
137
138    /// Extracts the relevant columns from a `RecordBatch`.
139    fn extract_columns(&self, batch: &arrow_array::RecordBatch) -> Vec<ArrayRef> {
140        self.column_indices
141            .iter()
142            .enumerate()
143            .map(|(arg_idx, &col_idx)| {
144                if col_idx < batch.num_columns() {
145                    Arc::clone(batch.column(col_idx))
146                } else {
147                    let dt = self
148                        .input_types
149                        .get(arg_idx)
150                        .cloned()
151                        .unwrap_or(DataType::Int64);
152                    arrow_array::new_null_array(&dt, batch.num_rows())
153                }
154            })
155            .collect()
156    }
157}
158
159impl DynAccumulator for DataFusionAccumulatorAdapter {
160    fn add_event(&mut self, event: &Event) {
161        let columns = self.extract_columns(&event.data);
162        if let Err(e) = self.inner.borrow_mut().update_batch(&columns) {
163            tracing::warn!(
164                func = %self.function_name,
165                error = %e,
166                "Accumulator update_batch failed"
167            );
168        }
169    }
170
171    fn merge_dyn(&mut self, other: &dyn DynAccumulator) {
172        let other = other
173            .as_any()
174            .downcast_ref::<DataFusionAccumulatorAdapter>()
175            .expect("merge_dyn: type mismatch, expected DataFusionAccumulatorAdapter");
176
177        match other.inner.borrow_mut().state() {
178            Ok(state_values) => {
179                let mut failed_conversions = 0u32;
180                let state_arrays: Vec<ArrayRef> = state_values
181                    .iter()
182                    .filter_map(|sv| {
183                        if let Ok(arr) = sv.to_array() {
184                            Some(arr)
185                        } else {
186                            failed_conversions += 1;
187                            None
188                        }
189                    })
190                    .collect();
191                if failed_conversions > 0 {
192                    tracing::warn!(
193                        func = %self.function_name,
194                        count = failed_conversions,
195                        "ScalarValue to_array conversions failed during merge"
196                    );
197                }
198                if !state_arrays.is_empty() {
199                    if let Err(e) = self.inner.borrow_mut().merge_batch(&state_arrays) {
200                        tracing::warn!(
201                            func = %self.function_name,
202                            error = %e,
203                            "Accumulator merge_batch failed"
204                        );
205                    }
206                }
207            }
208            Err(e) => {
209                tracing::warn!(
210                    func = %self.function_name,
211                    error = %e,
212                    "Failed to extract state for merge"
213                );
214            }
215        }
216    }
217
218    fn result_scalar(&self) -> ScalarResult {
219        match self.inner.borrow_mut().evaluate() {
220            Ok(sv) => scalar_value_to_result(&sv),
221            Err(_) => ScalarResult::Null,
222        }
223    }
224
225    fn is_empty(&self) -> bool {
226        self.inner.borrow().size() <= std::mem::size_of::<Self>()
227    }
228
229    fn clone_box(&self) -> Box<dyn DynAccumulator> {
230        let new_inner = self.factory.create_df_accumulator();
231        // Merge current state into the fresh accumulator
232        if let Ok(state_values) = self.inner.borrow_mut().state() {
233            let state_arrays: Vec<ArrayRef> = state_values
234                .iter()
235                .filter_map(|sv| sv.to_array().ok())
236                .collect();
237            if !state_arrays.is_empty() {
238                let mut new_acc = new_inner;
239                if new_acc.merge_batch(&state_arrays).is_ok() {
240                    return Box::new(DataFusionAccumulatorAdapter {
241                        inner: RefCell::new(new_acc),
242                        column_indices: self.column_indices.clone(),
243                        input_types: self.input_types.clone(),
244                        function_name: self.function_name.clone(),
245                        factory: Arc::clone(&self.factory),
246                    });
247                }
248            }
249        }
250        // Fallback: return a fresh empty accumulator
251        Box::new(DataFusionAccumulatorAdapter {
252            inner: RefCell::new(self.factory.create_df_accumulator()),
253            column_indices: self.column_indices.clone(),
254            input_types: self.input_types.clone(),
255            function_name: self.function_name.clone(),
256            factory: Arc::clone(&self.factory),
257        })
258    }
259
260    #[allow(clippy::cast_possible_truncation)] // Wire format uses fixed-width integers
261    fn serialize(&self) -> Vec<u8> {
262        match self.inner.borrow_mut().state() {
263            Ok(state_values) => {
264                let mut buf = Vec::new();
265                let count = state_values.len() as u32;
266                buf.extend_from_slice(&count.to_le_bytes());
267                for sv in &state_values {
268                    let bytes = sv.to_string();
269                    let len = bytes.len() as u32;
270                    buf.extend_from_slice(&len.to_le_bytes());
271                    buf.extend_from_slice(bytes.as_bytes());
272                }
273                buf
274            }
275            Err(_) => Vec::new(),
276        }
277    }
278
279    fn result_field(&self) -> Field {
280        let result = self.result_scalar();
281        let dt = result.data_type();
282        let dt = if dt == DataType::Null {
283            DataType::Float64
284        } else {
285            dt
286        };
287        Field::new(&self.function_name, dt, true)
288    }
289
290    fn type_tag(&self) -> &'static str {
291        "datafusion_adapter"
292    }
293
294    fn as_any(&self) -> &dyn std::any::Any {
295        self
296    }
297}
298
299// DataFusion Aggregate Factory
300
301/// Factory for creating [`DataFusionAccumulatorAdapter`] instances.
302///
303/// Wraps a DataFusion [`AggregateUDF`] and provides the [`DynAggregatorFactory`]
304/// interface for use with `CompositeAggregator`.
305pub struct DataFusionAggregateFactory {
306    /// The DataFusion aggregate UDF
307    udf: Arc<AggregateUDF>,
308    /// Column indices to extract from events
309    column_indices: Vec<usize>,
310    /// Input types for the aggregate
311    input_types: Vec<DataType>,
312    /// Whether this is a DISTINCT aggregate
313    is_distinct: bool,
314}
315
316impl std::fmt::Debug for DataFusionAggregateFactory {
317    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
318        f.debug_struct("DataFusionAggregateFactory")
319            .field("name", &self.udf.name())
320            .field("column_indices", &self.column_indices)
321            .field("input_types", &self.input_types)
322            .field("is_distinct", &self.is_distinct)
323            .finish()
324    }
325}
326
327impl DataFusionAggregateFactory {
328    /// Creates a new factory for the given DataFusion aggregate UDF.
329    #[must_use]
330    pub fn new(
331        udf: Arc<AggregateUDF>,
332        column_indices: Vec<usize>,
333        input_types: Vec<DataType>,
334    ) -> Self {
335        Self {
336            udf,
337            column_indices,
338            input_types,
339            is_distinct: false,
340        }
341    }
342
343    /// Creates a new factory with the DISTINCT flag set.
344    #[must_use]
345    pub fn with_distinct(mut self, distinct: bool) -> Self {
346        self.is_distinct = distinct;
347        self
348    }
349
350    /// Returns the name of the wrapped aggregate function.
351    #[must_use]
352    pub fn name(&self) -> &str {
353        self.udf.name()
354    }
355
356    /// Pre-defined column names to avoid `format!()` per accumulator creation.
357    const COL_NAMES: [&str; 8] = [
358        "col_0", "col_1", "col_2", "col_3", "col_4", "col_5", "col_6", "col_7",
359    ];
360
361    /// Returns a cached column name for the given index.
362    fn col_name(i: usize) -> &'static str {
363        Self::COL_NAMES.get(i).copied().unwrap_or("col_n")
364    }
365
366    /// Creates a DataFusion accumulator from the UDF.
367    fn create_df_accumulator(&self) -> Box<dyn datafusion_expr::Accumulator> {
368        let return_type = self
369            .udf
370            .return_type(&self.input_types)
371            .unwrap_or(DataType::Float64);
372        let return_field: FieldRef = Arc::new(Field::new(self.udf.name(), return_type, true));
373        let schema = Schema::new(
374            self.input_types
375                .iter()
376                .enumerate()
377                .map(|(i, dt)| Field::new(Self::col_name(i), dt.clone(), true))
378                .collect::<Vec<_>>(),
379        );
380        let expr_fields: Vec<FieldRef> = self
381            .input_types
382            .iter()
383            .enumerate()
384            .map(|(i, dt)| Arc::new(Field::new(Self::col_name(i), dt.clone(), true)) as FieldRef)
385            .collect();
386        let args = AccumulatorArgs {
387            return_field,
388            schema: &schema,
389            ignore_nulls: false,
390            order_bys: &[],
391            is_reversed: false,
392            name: self.udf.name(),
393            is_distinct: self.is_distinct,
394            exprs: &[],
395            expr_fields: &expr_fields,
396        };
397        self.udf
398            .accumulator(args)
399            .expect("Failed to create DataFusion accumulator")
400    }
401}
402
403impl DataFusionAggregateFactory {
404    /// Creates an accumulator adapter with a back-reference to this factory.
405    ///
406    /// The factory must be wrapped in an `Arc` for the adapter to support
407    /// `clone_box()`.
408    #[must_use]
409    pub fn create_accumulator_with_factory(self: &Arc<Self>) -> Box<dyn DynAccumulator> {
410        let inner = self.create_df_accumulator();
411        Box::new(DataFusionAccumulatorAdapter::new(
412            inner,
413            self.column_indices.clone(),
414            self.input_types.clone(),
415            self.udf.name().to_string(),
416            Arc::clone(self),
417        ))
418    }
419}
420
421impl DynAggregatorFactory for DataFusionAggregateFactory {
422    fn create_accumulator(&self) -> Box<dyn DynAccumulator> {
423        // Without an Arc<Self>, we create a temporary Arc for the adapter.
424        // This is fine — the adapter only uses it for clone_box().
425        let factory_arc = Arc::new(DataFusionAggregateFactory {
426            udf: Arc::clone(&self.udf),
427            column_indices: self.column_indices.clone(),
428            input_types: self.input_types.clone(),
429            is_distinct: self.is_distinct,
430        });
431        let inner = self.create_df_accumulator();
432        Box::new(DataFusionAccumulatorAdapter::new(
433            inner,
434            self.column_indices.clone(),
435            self.input_types.clone(),
436            self.udf.name().to_string(),
437            factory_arc,
438        ))
439    }
440
441    fn result_field(&self) -> Field {
442        let return_type = self
443            .udf
444            .return_type(&self.input_types)
445            .unwrap_or(DataType::Float64);
446        Field::new(self.udf.name(), return_type, true)
447    }
448
449    fn clone_box(&self) -> Box<dyn DynAggregatorFactory> {
450        Box::new(DataFusionAggregateFactory {
451            udf: Arc::clone(&self.udf),
452            column_indices: self.column_indices.clone(),
453            input_types: self.input_types.clone(),
454            is_distinct: self.is_distinct,
455        })
456    }
457
458    fn type_tag(&self) -> &'static str {
459        "datafusion_factory"
460    }
461}
462
463// Built-in Aggregate Lookup
464
465/// Looks up a DataFusion built-in aggregate function by name.
466///
467/// Returns `None` if the function is not a recognized DataFusion aggregate.
468#[must_use]
469pub fn lookup_aggregate_udf(
470    ctx: &datafusion::prelude::SessionContext,
471    name: &str,
472) -> Option<Arc<AggregateUDF>> {
473    let normalized = name.to_lowercase();
474    ctx.udaf(&normalized).ok()
475}
476
477/// Creates a [`DataFusionAggregateFactory`] for a named built-in aggregate.
478///
479/// Returns `None` if the function name is not recognized.
480#[must_use]
481pub fn create_aggregate_factory(
482    ctx: &datafusion::prelude::SessionContext,
483    name: &str,
484    column_indices: Vec<usize>,
485    input_types: Vec<DataType>,
486) -> Option<DataFusionAggregateFactory> {
487    lookup_aggregate_udf(ctx, name)
488        .map(|udf| DataFusionAggregateFactory::new(udf, column_indices, input_types))
489}
490
491// Tests
492
493#[cfg(test)]
494mod tests {
495    use super::*;
496    use crate::datafusion::create_session_context;
497    use arrow_array::{Float64Array, Int64Array, RecordBatch};
498
499    fn float_event(ts: i64, values: Vec<f64>) -> Event {
500        let schema = Arc::new(Schema::new(vec![Field::new(
501            "value",
502            DataType::Float64,
503            false,
504        )]));
505        let batch =
506            RecordBatch::try_new(schema, vec![Arc::new(Float64Array::from(values))]).unwrap();
507        Event::new(ts, batch)
508    }
509
510    fn int_event(ts: i64, values: Vec<i64>) -> Event {
511        let schema = Arc::new(Schema::new(vec![Field::new(
512            "value",
513            DataType::Int64,
514            false,
515        )]));
516        let batch = RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(values))]).unwrap();
517        Event::new(ts, batch)
518    }
519
520    fn two_col_float_event(ts: i64, col0: Vec<f64>, col1: Vec<f64>) -> Event {
521        let schema = Arc::new(Schema::new(vec![
522            Field::new("x", DataType::Float64, false),
523            Field::new("y", DataType::Float64, false),
524        ]));
525        let batch = RecordBatch::try_new(
526            schema,
527            vec![
528                Arc::new(Float64Array::from(col0)),
529                Arc::new(Float64Array::from(col1)),
530            ],
531        )
532        .unwrap();
533        Event::new(ts, batch)
534    }
535
536    // ── ScalarValue Conversion Tests ────────────────────────────────────
537
538    #[test]
539    fn test_scalar_value_to_result_int64() {
540        let sv = ScalarValue::Int64(Some(42));
541        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Int64(42));
542    }
543
544    #[test]
545    fn test_scalar_value_to_result_float64() {
546        let sv = ScalarValue::Float64(Some(3.125));
547        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Float64(3.125));
548    }
549
550    #[test]
551    fn test_scalar_value_to_result_uint64() {
552        let sv = ScalarValue::UInt64(Some(100));
553        assert_eq!(scalar_value_to_result(&sv), ScalarResult::UInt64(100));
554    }
555
556    #[test]
557    fn test_scalar_value_to_result_null_int64() {
558        let sv = ScalarValue::Int64(None);
559        assert_eq!(
560            scalar_value_to_result(&sv),
561            ScalarResult::OptionalInt64(None)
562        );
563    }
564
565    #[test]
566    fn test_scalar_value_to_result_null_float64() {
567        let sv = ScalarValue::Float64(None);
568        assert_eq!(
569            scalar_value_to_result(&sv),
570            ScalarResult::OptionalFloat64(None)
571        );
572    }
573
574    #[test]
575    fn test_scalar_value_to_result_smaller_ints() {
576        assert_eq!(
577            scalar_value_to_result(&ScalarValue::Int8(Some(8))),
578            ScalarResult::Int64(8)
579        );
580        assert_eq!(
581            scalar_value_to_result(&ScalarValue::Int16(Some(16))),
582            ScalarResult::Int64(16)
583        );
584        assert_eq!(
585            scalar_value_to_result(&ScalarValue::Int32(Some(32))),
586            ScalarResult::Int64(32)
587        );
588        assert_eq!(
589            scalar_value_to_result(&ScalarValue::UInt8(Some(8))),
590            ScalarResult::UInt64(8)
591        );
592    }
593
594    #[test]
595    fn test_scalar_value_to_result_float32() {
596        let sv = ScalarValue::Float32(Some(2.5));
597        assert_eq!(
598            scalar_value_to_result(&sv),
599            ScalarResult::Float64(f64::from(2.5f32))
600        );
601    }
602
603    #[test]
604    fn test_scalar_value_to_result_unsupported() {
605        let sv = ScalarValue::Utf8(Some("hello".to_string()));
606        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Null);
607    }
608
609    #[test]
610    fn test_result_to_scalar_value_roundtrip() {
611        // Exact roundtrip for non-optional variants
612        let exact_cases = vec![
613            ScalarResult::Int64(42),
614            ScalarResult::Float64(3.125),
615            ScalarResult::UInt64(100),
616        ];
617        for sr in &exact_cases {
618            let sv = result_to_scalar_value(sr);
619            let back = scalar_value_to_result(&sv);
620            assert_eq!(&back, sr, "Roundtrip failed for {sr:?}");
621        }
622
623        // Optional(Some(v)) normalizes to non-optional through ScalarValue
624        // because ScalarValue::Int64(Some(7)) maps back to ScalarResult::Int64(7)
625        let sv = result_to_scalar_value(&ScalarResult::OptionalInt64(Some(7)));
626        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Int64(7));
627
628        let sv = result_to_scalar_value(&ScalarResult::OptionalFloat64(Some(2.72)));
629        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Float64(2.72));
630
631        // Optional None roundtrips back to OptionalNone (ScalarValue preserves type)
632        let sv = result_to_scalar_value(&ScalarResult::OptionalInt64(None));
633        assert_eq!(
634            scalar_value_to_result(&sv),
635            ScalarResult::OptionalInt64(None)
636        );
637
638        let sv = result_to_scalar_value(&ScalarResult::OptionalFloat64(None));
639        assert_eq!(
640            scalar_value_to_result(&sv),
641            ScalarResult::OptionalFloat64(None)
642        );
643
644        // Null roundtrips correctly
645        let sv = result_to_scalar_value(&ScalarResult::Null);
646        assert_eq!(scalar_value_to_result(&sv), ScalarResult::Null);
647    }
648
649    // ── Factory Tests ───────────────────────────────────────────────────
650
651    #[test]
652    fn test_factory_count() {
653        let ctx = create_session_context();
654        let factory = create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]);
655        assert!(factory.is_some(), "count should be a recognized aggregate");
656        assert_eq!(factory.unwrap().name(), "count");
657    }
658
659    #[test]
660    fn test_factory_sum() {
661        let ctx = create_session_context();
662        let factory = create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]);
663        assert!(factory.is_some());
664        assert_eq!(factory.unwrap().name(), "sum");
665    }
666
667    #[test]
668    fn test_factory_avg() {
669        let ctx = create_session_context();
670        let factory = create_aggregate_factory(&ctx, "avg", vec![0], vec![DataType::Float64]);
671        assert!(factory.is_some());
672    }
673
674    #[test]
675    fn test_factory_stddev() {
676        let ctx = create_session_context();
677        let factory = create_aggregate_factory(&ctx, "stddev", vec![0], vec![DataType::Float64]);
678        assert!(
679            factory.is_some(),
680            "stddev should be available in DataFusion"
681        );
682    }
683
684    #[test]
685    fn test_factory_unknown() {
686        let ctx = create_session_context();
687        let factory = create_aggregate_factory(
688            &ctx,
689            "nonexistent_aggregate_xyz",
690            vec![0],
691            vec![DataType::Int64],
692        );
693        assert!(factory.is_none());
694    }
695
696    #[test]
697    fn test_factory_result_field() {
698        let ctx = create_session_context();
699        let factory =
700            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
701        let field = factory.result_field();
702        assert_eq!(field.name(), "sum");
703        assert_eq!(field.data_type(), &DataType::Float64);
704    }
705
706    #[test]
707    fn test_factory_clone_box() {
708        let ctx = create_session_context();
709        let factory =
710            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
711        let cloned = factory.clone_box();
712        assert_eq!(cloned.type_tag(), "datafusion_factory");
713    }
714
715    // ── Adapter Basics ──────────────────────────────────────────────────
716
717    #[test]
718    fn test_adapter_count_basic() {
719        let ctx = create_session_context();
720        let factory =
721            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
722        let mut acc = factory.create_accumulator();
723
724        let result = acc.result_scalar();
725        assert!(
726            matches!(result, ScalarResult::Int64(0) | ScalarResult::UInt64(0)),
727            "Expected 0, got {result:?}"
728        );
729
730        acc.add_event(&int_event(1000, vec![10, 20, 30]));
731        let result = acc.result_scalar();
732        assert!(
733            matches!(result, ScalarResult::Int64(3) | ScalarResult::UInt64(3)),
734            "Expected 3, got {result:?}"
735        );
736
737        acc.add_event(&int_event(2000, vec![40, 50]));
738        let result = acc.result_scalar();
739        assert!(
740            matches!(result, ScalarResult::Int64(5) | ScalarResult::UInt64(5)),
741            "Expected 5, got {result:?}"
742        );
743    }
744
745    #[test]
746    fn test_adapter_sum_float64() {
747        let ctx = create_session_context();
748        let factory =
749            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
750        let mut acc = factory.create_accumulator();
751
752        acc.add_event(&float_event(1000, vec![1.5, 2.5, 3.0]));
753        assert_eq!(acc.result_scalar(), ScalarResult::Float64(7.0));
754    }
755
756    #[test]
757    fn test_adapter_avg_float64() {
758        let ctx = create_session_context();
759        let factory =
760            create_aggregate_factory(&ctx, "avg", vec![0], vec![DataType::Float64]).unwrap();
761        let mut acc = factory.create_accumulator();
762
763        acc.add_event(&float_event(1000, vec![10.0, 20.0, 30.0]));
764        assert_eq!(acc.result_scalar(), ScalarResult::Float64(20.0));
765    }
766
767    #[test]
768    fn test_adapter_min_float64() {
769        let ctx = create_session_context();
770        let factory =
771            create_aggregate_factory(&ctx, "min", vec![0], vec![DataType::Float64]).unwrap();
772        let mut acc = factory.create_accumulator();
773
774        acc.add_event(&float_event(1000, vec![30.0, 10.0, 20.0]));
775        assert_eq!(acc.result_scalar(), ScalarResult::Float64(10.0));
776    }
777
778    #[test]
779    fn test_adapter_max_float64() {
780        let ctx = create_session_context();
781        let factory =
782            create_aggregate_factory(&ctx, "max", vec![0], vec![DataType::Float64]).unwrap();
783        let mut acc = factory.create_accumulator();
784
785        acc.add_event(&float_event(1000, vec![30.0, 10.0, 20.0]));
786        assert_eq!(acc.result_scalar(), ScalarResult::Float64(30.0));
787    }
788
789    #[test]
790    fn test_adapter_sum_int64() {
791        let ctx = create_session_context();
792        let factory =
793            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Int64]).unwrap();
794        let mut acc = factory.create_accumulator();
795
796        acc.add_event(&int_event(1000, vec![10, 20, 30]));
797        assert_eq!(acc.result_scalar(), ScalarResult::Int64(60));
798    }
799
800    #[test]
801    fn test_adapter_type_tag() {
802        let ctx = create_session_context();
803        let factory =
804            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
805        let acc = factory.create_accumulator();
806        assert_eq!(acc.type_tag(), "datafusion_adapter");
807    }
808
809    #[test]
810    fn test_adapter_result_field() {
811        let ctx = create_session_context();
812        let factory =
813            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
814        let mut acc = factory.create_accumulator();
815        acc.add_event(&float_event(1000, vec![1.0]));
816        assert_eq!(acc.result_field().name(), "sum");
817    }
818
819    // ── Merge Tests ─────────────────────────────────────────────────────
820
821    #[test]
822    fn test_adapter_merge_sum() {
823        let ctx = create_session_context();
824        let factory =
825            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
826
827        let mut acc1 = factory.create_accumulator();
828        acc1.add_event(&float_event(1000, vec![1.0, 2.0]));
829
830        let mut acc2 = factory.create_accumulator();
831        acc2.add_event(&float_event(2000, vec![3.0, 4.0]));
832
833        acc1.merge_dyn(acc2.as_ref());
834        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(10.0));
835    }
836
837    #[test]
838    fn test_adapter_merge_count() {
839        let ctx = create_session_context();
840        let factory =
841            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
842
843        let mut acc1 = factory.create_accumulator();
844        acc1.add_event(&int_event(1000, vec![1, 2, 3]));
845
846        let mut acc2 = factory.create_accumulator();
847        acc2.add_event(&int_event(2000, vec![4, 5]));
848
849        acc1.merge_dyn(acc2.as_ref());
850        let result = acc1.result_scalar();
851        assert!(
852            matches!(result, ScalarResult::Int64(5) | ScalarResult::UInt64(5)),
853            "Expected 5 after merge, got {result:?}"
854        );
855    }
856
857    #[test]
858    fn test_adapter_merge_avg() {
859        let ctx = create_session_context();
860        let factory =
861            create_aggregate_factory(&ctx, "avg", vec![0], vec![DataType::Float64]).unwrap();
862
863        let mut acc1 = factory.create_accumulator();
864        acc1.add_event(&float_event(1000, vec![10.0, 20.0]));
865
866        let mut acc2 = factory.create_accumulator();
867        acc2.add_event(&float_event(2000, vec![30.0]));
868
869        acc1.merge_dyn(acc2.as_ref());
870        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(20.0));
871    }
872
873    #[test]
874    fn test_adapter_merge_empty() {
875        let ctx = create_session_context();
876        let factory =
877            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
878
879        let mut acc1 = factory.create_accumulator();
880        acc1.add_event(&float_event(1000, vec![5.0]));
881
882        let acc2 = factory.create_accumulator();
883        acc1.merge_dyn(acc2.as_ref());
884        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(5.0));
885    }
886
887    // ── Built-in Aggregate Pass-Through Tests ───────────────────────────
888
889    #[test]
890    fn test_adapter_stddev() {
891        let ctx = create_session_context();
892        let factory =
893            create_aggregate_factory(&ctx, "stddev", vec![0], vec![DataType::Float64]).unwrap();
894        let mut acc = factory.create_accumulator();
895
896        acc.add_event(&float_event(
897            1000,
898            vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0],
899        ));
900        let result = acc.result_scalar();
901        if let ScalarResult::Float64(v) = result {
902            assert!((v - 2.138).abs() < 0.01, "Expected ~2.138, got {v}");
903        } else {
904            panic!("Expected Float64 result, got {result:?}");
905        }
906    }
907
908    #[test]
909    fn test_adapter_variance() {
910        let ctx = create_session_context();
911        if let Some(factory) =
912            create_aggregate_factory(&ctx, "var_samp", vec![0], vec![DataType::Float64])
913        {
914            let mut acc = factory.create_accumulator();
915            acc.add_event(&float_event(
916                1000,
917                vec![2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0],
918            ));
919            if let ScalarResult::Float64(v) = acc.result_scalar() {
920                assert!((v - 4.571).abs() < 0.01, "Expected ~4.571, got {v}");
921            }
922        }
923    }
924
925    #[test]
926    fn test_adapter_median() {
927        let ctx = create_session_context();
928        if let Some(factory) =
929            create_aggregate_factory(&ctx, "median", vec![0], vec![DataType::Float64])
930        {
931            let mut acc = factory.create_accumulator();
932            acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0, 4.0, 5.0]));
933            assert_eq!(acc.result_scalar(), ScalarResult::Float64(3.0));
934        }
935    }
936
937    // ── Serialize Tests ─────────────────────────────────────────────────
938
939    #[test]
940    fn test_adapter_serialize() {
941        let ctx = create_session_context();
942        let factory =
943            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
944        let mut acc = factory.create_accumulator();
945        acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0]));
946        assert!(!acc.serialize().is_empty());
947    }
948
949    #[test]
950    fn test_adapter_serialize_empty() {
951        let ctx = create_session_context();
952        let factory =
953            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
954        let acc = factory.create_accumulator();
955        assert!(!acc.serialize().is_empty());
956    }
957
958    // ── Lookup Tests ────────────────────────────────────────────────────
959
960    #[test]
961    fn test_lookup_common_aggregates() {
962        let ctx = create_session_context();
963        for name in &["count", "sum", "min", "max", "avg"] {
964            assert!(
965                lookup_aggregate_udf(&ctx, name).is_some(),
966                "Expected '{name}' to be a recognized aggregate"
967            );
968        }
969    }
970
971    #[test]
972    fn test_lookup_statistical_aggregates() {
973        let ctx = create_session_context();
974        for name in &["stddev", "stddev_pop", "median"] {
975            // Just verify lookup doesn't panic
976            let _ = lookup_aggregate_udf(&ctx, name);
977        }
978    }
979
980    #[test]
981    fn test_lookup_case_insensitive() {
982        let ctx = create_session_context();
983        assert!(lookup_aggregate_udf(&ctx, "COUNT").is_some());
984        assert!(lookup_aggregate_udf(&ctx, "Sum").is_some());
985        assert!(lookup_aggregate_udf(&ctx, "AVG").is_some());
986    }
987
988    // ── Multi-column Tests ──────────────────────────────────────────────
989
990    #[test]
991    fn test_adapter_multi_column_covar() {
992        let ctx = create_session_context();
993        if let Some(factory) = create_aggregate_factory(
994            &ctx,
995            "covar_samp",
996            vec![0, 1],
997            vec![DataType::Float64, DataType::Float64],
998        ) {
999            let mut acc = factory.create_accumulator();
1000            acc.add_event(&two_col_float_event(
1001                1000,
1002                vec![1.0, 2.0, 3.0, 4.0, 5.0],
1003                vec![1.0, 2.0, 3.0, 4.0, 5.0],
1004            ));
1005            if let ScalarResult::Float64(v) = acc.result_scalar() {
1006                assert!((v - 2.5).abs() < 0.01, "Expected covar ~2.5, got {v}");
1007            }
1008        }
1009    }
1010
1011    // ── Registration Tests ──────────────────────────────────────────────
1012
1013    #[test]
1014    fn test_create_aggregate_factory_api() {
1015        let ctx = create_session_context();
1016        let factory =
1017            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
1018        let acc = factory.create_accumulator();
1019        assert_eq!(acc.type_tag(), "datafusion_adapter");
1020    }
1021
1022    #[test]
1023    fn test_factory_creates_independent_accumulators() {
1024        let ctx = create_session_context();
1025        let factory =
1026            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
1027
1028        let mut acc1 = factory.create_accumulator();
1029        let mut acc2 = factory.create_accumulator();
1030
1031        acc1.add_event(&float_event(1000, vec![10.0]));
1032        acc2.add_event(&float_event(2000, vec![20.0]));
1033
1034        assert_eq!(acc1.result_scalar(), ScalarResult::Float64(10.0));
1035        assert_eq!(acc2.result_scalar(), ScalarResult::Float64(20.0));
1036    }
1037
1038    #[test]
1039    fn test_adapter_function_name() {
1040        let ctx = create_session_context();
1041        let factory =
1042            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
1043        let acc = factory.create_accumulator();
1044        let adapter = acc
1045            .as_any()
1046            .downcast_ref::<DataFusionAccumulatorAdapter>()
1047            .expect("should be adapter");
1048        assert_eq!(adapter.function_name(), "sum");
1049    }
1050
1051    #[test]
1052    fn test_clone_box_does_not_panic() {
1053        let ctx = create_session_context();
1054        let factory =
1055            create_aggregate_factory(&ctx, "sum", vec![0], vec![DataType::Float64]).unwrap();
1056        let mut acc = factory.create_accumulator();
1057        acc.add_event(&float_event(1000, vec![1.0, 2.0, 3.0]));
1058
1059        // clone_box should not panic and should preserve state
1060        let cloned = acc.clone_box();
1061        assert_eq!(cloned.result_scalar(), ScalarResult::Float64(6.0));
1062    }
1063
1064    #[test]
1065    fn test_clone_box_empty_accumulator() {
1066        let ctx = create_session_context();
1067        let factory =
1068            create_aggregate_factory(&ctx, "count", vec![0], vec![DataType::Int64]).unwrap();
1069        let acc = factory.create_accumulator();
1070
1071        // clone_box on empty accumulator should work
1072        let cloned = acc.clone_box();
1073        let result = cloned.result_scalar();
1074        assert!(
1075            matches!(result, ScalarResult::Int64(0) | ScalarResult::UInt64(0)),
1076            "Expected 0, got {result:?}"
1077        );
1078    }
1079
1080    #[test]
1081    fn test_distinct_factory() {
1082        let ctx = create_session_context();
1083        let udf = lookup_aggregate_udf(&ctx, "count").unwrap();
1084        let factory = DataFusionAggregateFactory::new(udf, vec![0], vec![DataType::Int64])
1085            .with_distinct(true);
1086        assert!(factory.is_distinct);
1087        // Should create accumulator successfully
1088        let _acc = factory.create_accumulator();
1089    }
1090}