Skip to main content

uni_plugin/traits/
aggregate.rs

1//! Cypher aggregate plugin functions.
2//!
3//! Aggregates accumulate state across many input rows and produce a single
4//! result. The trait splits the user-facing signature ([`AggregatePluginFn`])
5//! from the per-group state machine ([`PluginAccumulator`]), matching
6//! DataFusion's `AggregateUDFImpl` / `Accumulator` split so plugin
7//! aggregates can run inside DataFusion's partial-aggregation flow.
8
9use arrow_array::ArrayRef;
10use arrow_schema::Field;
11use datafusion::logical_expr::Volatility;
12use datafusion::scalar::ScalarValue;
13
14use crate::errors::FnError;
15use crate::traits::scalar::ArgType;
16
17/// A Cypher aggregate function plugin.
18pub trait AggregatePluginFn: Send + Sync {
19    /// Static signature.
20    fn signature(&self) -> &AggSignature;
21
22    /// Construct a fresh per-group accumulator.
23    fn create_accumulator(&self) -> Box<dyn PluginAccumulator>;
24}
25
26/// Per-group state machine for an aggregate function.
27///
28/// One `PluginAccumulator` instance is created per group. The host calls
29/// `update_batch` repeatedly with the group's rows, then `evaluate` for the
30/// final value. For distributed aggregation, the host calls `state` on
31/// partial accumulators and `merge_batch` on the final accumulator.
32pub trait PluginAccumulator: Send {
33    /// Ingest a batch of input rows into the accumulator.
34    ///
35    /// `values[i]` is the `i`-th argument's column, all of equal length.
36    ///
37    /// # Errors
38    ///
39    /// Returns [`FnError`] if the input cannot be accumulated (type
40    /// mismatch, resource exhaustion).
41    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<(), FnError>;
42
43    /// Merge per-partition partial states into this accumulator.
44    ///
45    /// `states[i]` is the `i`-th state field across partial accumulators.
46    ///
47    /// # Errors
48    ///
49    /// Returns [`FnError`] if the merge cannot proceed.
50    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<(), FnError>;
51
52    /// Return the current accumulator state as scalar values, for transport
53    /// across the partial / final aggregation boundary.
54    ///
55    /// # Errors
56    ///
57    /// Returns [`FnError`] if the state cannot be serialized.
58    fn state(&self) -> Result<Vec<ScalarValue>, FnError>;
59
60    /// Produce the final aggregate value.
61    ///
62    /// # Errors
63    ///
64    /// Returns [`FnError`] if the final value cannot be computed (e.g.,
65    /// undefined for empty input and the aggregate forbids it).
66    fn evaluate(&self) -> Result<ScalarValue, FnError>;
67
68    /// Approximate in-memory size, in bytes — used for memory accounting.
69    fn size(&self) -> usize;
70}
71
72/// Static signature of an aggregate function plugin.
73#[derive(Clone, Debug)]
74pub struct AggSignature {
75    /// Argument types, in declaration order.
76    pub args: Vec<ArgType>,
77    /// Final return type.
78    pub returns: ArgType,
79    /// Schema of the per-partition partial state.
80    pub state_fields: Vec<Field>,
81    /// DataFusion volatility.
82    pub volatility: Volatility,
83    /// `true` if this aggregate supports partial aggregation (the common
84    /// case). `false` aggregates only run in a single physical pass.
85    pub supports_partial: bool,
86}
87
88impl AggSignature {
89    /// Convenience constructor for partial-aggregation-capable signatures.
90    #[must_use]
91    pub fn new(
92        args: Vec<ArgType>,
93        returns: ArgType,
94        state_fields: Vec<Field>,
95        volatility: Volatility,
96    ) -> Self {
97        Self {
98            args,
99            returns,
100            state_fields,
101            volatility,
102            supports_partial: true,
103        }
104    }
105}