uni_plugin/traits/aggregate.rs
//! Cypher aggregate plugin functions.
//!
//! Aggregates accumulate state across many input rows and produce a single
//! result. This module splits the user-facing signature ([`AggregatePluginFn`])
//! from the per-group state machine ([`PluginAccumulator`]), mirroring
//! DataFusion's `AggregateUDFImpl` / `Accumulator` split so that plugin
//! aggregates can run inside DataFusion's partial-aggregation flow.
8
9use arrow_array::ArrayRef;
10use arrow_schema::Field;
11use datafusion::logical_expr::Volatility;
12use datafusion::scalar::ScalarValue;
13
14use crate::errors::FnError;
15use crate::traits::scalar::ArgType;
16
17/// A Cypher aggregate function plugin.
18pub trait AggregatePluginFn: Send + Sync {
19 /// Static signature.
20 fn signature(&self) -> &AggSignature;
21
22 /// Construct a fresh per-group accumulator.
23 fn create_accumulator(&self) -> Box<dyn PluginAccumulator>;
24}
25
26/// Per-group state machine for an aggregate function.
27///
28/// One `PluginAccumulator` instance is created per group. The host calls
29/// `update_batch` repeatedly with the group's rows, then `evaluate` for the
30/// final value. For distributed aggregation, the host calls `state` on
31/// partial accumulators and `merge_batch` on the final accumulator.
32pub trait PluginAccumulator: Send {
33 /// Ingest a batch of input rows into the accumulator.
34 ///
35 /// `values[i]` is the `i`-th argument's column, all of equal length.
36 ///
37 /// # Errors
38 ///
39 /// Returns [`FnError`] if the input cannot be accumulated (type
40 /// mismatch, resource exhaustion).
41 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<(), FnError>;
42
43 /// Merge per-partition partial states into this accumulator.
44 ///
45 /// `states[i]` is the `i`-th state field across partial accumulators.
46 ///
47 /// # Errors
48 ///
49 /// Returns [`FnError`] if the merge cannot proceed.
50 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<(), FnError>;
51
52 /// Return the current accumulator state as scalar values, for transport
53 /// across the partial / final aggregation boundary.
54 ///
55 /// # Errors
56 ///
57 /// Returns [`FnError`] if the state cannot be serialized.
58 fn state(&self) -> Result<Vec<ScalarValue>, FnError>;
59
60 /// Produce the final aggregate value.
61 ///
62 /// # Errors
63 ///
64 /// Returns [`FnError`] if the final value cannot be computed (e.g.,
65 /// undefined for empty input and the aggregate forbids it).
66 fn evaluate(&self) -> Result<ScalarValue, FnError>;
67
68 /// Approximate in-memory size, in bytes — used for memory accounting.
69 fn size(&self) -> usize;
70}
71
72/// Static signature of an aggregate function plugin.
73#[derive(Clone, Debug)]
74pub struct AggSignature {
75 /// Argument types, in declaration order.
76 pub args: Vec<ArgType>,
77 /// Final return type.
78 pub returns: ArgType,
79 /// Schema of the per-partition partial state.
80 pub state_fields: Vec<Field>,
81 /// DataFusion volatility.
82 pub volatility: Volatility,
83 /// `true` if this aggregate supports partial aggregation (the common
84 /// case). `false` aggregates only run in a single physical pass.
85 pub supports_partial: bool,
86}
87
88impl AggSignature {
89 /// Convenience constructor for partial-aggregation-capable signatures.
90 #[must_use]
91 pub fn new(
92 args: Vec<ArgType>,
93 returns: ArgType,
94 state_fields: Vec<Field>,
95 volatility: Volatility,
96 ) -> Self {
97 Self {
98 args,
99 returns,
100 state_fields,
101 volatility,
102 supports_partial: true,
103 }
104 }
105}