Skip to main content

uni_plugin_custom/
aggregate.rs

1// Rust guideline compliant
2//! `DeclaredAggregateFn` — an [`AggregatePluginFn`] that evaluates three
3//! parsed Cypher expression bodies (init / update / finalize) over the
4//! [`crate::eval::eval_expr`] interpreter.
5//!
6//! Structurally mirrors [`crate::scalar::DeclaredScalarFn`]: the
7//! `uni.plugin.declareAggregate` procedure parses each body once at
8//! declare time, wraps the result in [`DeclaredAggregateFn`], and
9//! registers a synthetic [`uni_plugin::Plugin`] through
10//! [`install_aggregate_into_registry`] under a per-namespace plugin id.
11//!
12//! # State model
13//!
14//! Each per-group accumulator (returned by
15//! [`AggregatePluginFn::create_accumulator`]) carries a single
16//! [`uni_common::Value`] in `state`. The `state` is bound under the
17//! `$state` parameter when evaluating `update_expr` (per row) and
18//! `finalize_expr` (once at group end). `init_expr` runs once with no
19//! bindings on first row (or on `evaluate` for empty groups).
20//!
21//! # Partial aggregation
22//!
23//! M9 declared aggregates ship without distributed-aggregation support:
24//! `AggSignature.state_fields` is empty, `supports_partial = false`,
25//! and `merge_batch` errors out. Encoding `uni_common::Value` into a
26//! transport-stable Arrow representation is a separate lane.
27
28use std::collections::HashMap;
29use std::sync::Arc;
30
31use arrow_array::{Array, ArrayRef};
32use arrow_schema::DataType;
33use datafusion::logical_expr::Volatility;
34use datafusion::scalar::ScalarValue;
35use semver::Version;
36use uni_common::Value;
37use uni_cypher::ast::Expr;
38use uni_cypher::parse_expression;
39use uni_plugin::traits::aggregate::{AggSignature, AggregatePluginFn, PluginAccumulator};
40use uni_plugin::traits::scalar::ArgType;
41use uni_plugin::{
42    AbiRange, Capability, CapabilitySet, Determinism, FnError, Plugin, PluginError, PluginId,
43    PluginManifest, PluginRegistrar, PluginRegistry, ProvidedSurfaces, QName, Scope, SideEffects,
44};
45
46use crate::decode::{
47    array_value_at, declared_plugin_id, eval_err_to_fn, local_part, map_plugin_error, stringify,
48    type_str_to_arrow,
49};
50use crate::eval::eval_expr;
51use crate::{CustomError, DeclaredPlugin};
52
53/// Parameter name under which the accumulator's running state is bound
54/// when evaluating `update_expr` / `finalize_expr`.
55const STATE_PARAM: &str = "state";
56
57/// A Cypher-declared aggregate function.
58///
59/// Holds three pre-parsed [`Expr`] bodies (`init`, `update`,
60/// `finalize`), the positional argument names of the `update` body,
61/// the declared return type, and a precomputed [`AggSignature`].
62pub struct DeclaredAggregateFn {
63    init_expr: Arc<Expr>,
64    update_expr: Arc<Expr>,
65    finalize_expr: Arc<Expr>,
66    arg_names: Vec<String>,
67    return_dt: DataType,
68    signature: AggSignature,
69}
70
71impl std::fmt::Debug for DeclaredAggregateFn {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        f.debug_struct("DeclaredAggregateFn")
74            .field("arg_names", &self.arg_names)
75            .field("return_type", &self.return_dt)
76            .finish_non_exhaustive()
77    }
78}
79
80impl DeclaredAggregateFn {
81    /// Construct a declared aggregate from pre-parsed Cypher bodies.
82    #[must_use]
83    pub fn new(
84        init_expr: Expr,
85        update_expr: Expr,
86        finalize_expr: Expr,
87        arg_names: Vec<String>,
88        return_dt: DataType,
89    ) -> Self {
90        let signature = Self::build_signature(return_dt.clone(), &arg_names);
91        Self {
92            init_expr: Arc::new(init_expr),
93            update_expr: Arc::new(update_expr),
94            finalize_expr: Arc::new(finalize_expr),
95            arg_names,
96            return_dt,
97            signature,
98        }
99    }
100
101    /// Build a default [`AggSignature`] for a declared aggregate.
102    ///
103    /// All `update` args are declared `Utf8` (the M9 declared-scalar
104    /// convention — promotions happen at row-decode time). The returned
105    /// signature disables partial aggregation: `state_fields` is empty
106    /// and `supports_partial = false`.
107    #[must_use]
108    pub fn build_signature(returns: DataType, arg_names: &[String]) -> AggSignature {
109        AggSignature {
110            args: arg_names
111                .iter()
112                .map(|_| ArgType::Primitive(DataType::Utf8))
113                .collect(),
114            returns: ArgType::Primitive(returns),
115            state_fields: Vec::new(),
116            volatility: Volatility::Volatile,
117            supports_partial: false,
118        }
119    }
120
121    /// The configured return [`DataType`].
122    #[must_use]
123    pub fn return_dt(&self) -> &DataType {
124        &self.return_dt
125    }
126}
127
128impl AggregatePluginFn for DeclaredAggregateFn {
129    fn signature(&self) -> &AggSignature {
130        &self.signature
131    }
132
133    fn create_accumulator(&self) -> Box<dyn PluginAccumulator> {
134        Box::new(DeclaredAccumulator {
135            init_expr: Arc::clone(&self.init_expr),
136            update_expr: Arc::clone(&self.update_expr),
137            finalize_expr: Arc::clone(&self.finalize_expr),
138            arg_names: self.arg_names.clone(),
139            return_dt: self.return_dt.clone(),
140            state: None,
141        })
142    }
143}
144
145/// Per-group accumulator backed by the [`crate::eval`] interpreter.
146#[derive(Debug)]
147struct DeclaredAccumulator {
148    init_expr: Arc<Expr>,
149    update_expr: Arc<Expr>,
150    finalize_expr: Arc<Expr>,
151    arg_names: Vec<String>,
152    return_dt: DataType,
153    state: Option<Value>,
154}
155
156impl DeclaredAccumulator {
157    /// Run `init_expr` if state hasn't been initialized yet.
158    fn ensure_state(&mut self) -> Result<(), FnError> {
159        if self.state.is_none() {
160            let bindings: HashMap<String, Value> = HashMap::new();
161            let v = eval_expr(&self.init_expr, &bindings).map_err(eval_err_to_fn)?;
162            self.state = Some(v);
163        }
164        Ok(())
165    }
166}
167
168impl PluginAccumulator for DeclaredAccumulator {
169    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<(), FnError> {
170        if values.len() != self.arg_names.len() {
171            return Err(FnError::new(
172                FnError::CODE_TYPE_COERCION,
173                format!(
174                    "declared aggregate expected {} args, got {}",
175                    self.arg_names.len(),
176                    values.len()
177                ),
178            ));
179        }
180        self.ensure_state()?;
181        let rows = values.first().map_or(0, |a| a.len());
182        for row in 0..rows {
183            let mut bindings: HashMap<String, Value> = HashMap::with_capacity(values.len() + 1);
184            // `clone()` is unavoidable here — `eval_expr` takes a HashMap
185            // by reference and we replace `state` after each row.
186            let st = self.state.clone().unwrap_or(Value::Null);
187            bindings.insert(STATE_PARAM.to_owned(), st);
188            for (i, col) in values.iter().enumerate() {
189                bindings.insert(self.arg_names[i].clone(), array_value_at(col, row)?);
190            }
191            let next = eval_expr(&self.update_expr, &bindings).map_err(eval_err_to_fn)?;
192            self.state = Some(next);
193        }
194        Ok(())
195    }
196
197    fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<(), FnError> {
198        Err(FnError::new(
199            FnError::CODE_TYPE_COERCION,
200            "declared aggregates do not support partial / distributed aggregation".to_owned(),
201        ))
202    }
203
204    fn state(&self) -> Result<Vec<ScalarValue>, FnError> {
205        // Empty state matches `AggSignature.state_fields == vec![]`.
206        Ok(Vec::new())
207    }
208
209    fn evaluate(&self) -> Result<ScalarValue, FnError> {
210        // For empty groups, `update_batch` was never called; evaluate
211        // `init_expr` on the fly so `finalize_expr` still sees a state.
212        let st = match &self.state {
213            Some(v) => v.clone(),
214            None => eval_expr(&self.init_expr, &HashMap::new()).map_err(eval_err_to_fn)?,
215        };
216        let mut bindings: HashMap<String, Value> = HashMap::with_capacity(1);
217        bindings.insert(STATE_PARAM.to_owned(), st);
218        let out = eval_expr(&self.finalize_expr, &bindings).map_err(eval_err_to_fn)?;
219        value_to_scalar(&out, &self.return_dt)
220    }
221
222    fn size(&self) -> usize {
223        std::mem::size_of::<Self>()
224    }
225}
226
227/// Convert a [`uni_common::Value`] to a [`ScalarValue`] of the requested
228/// Arrow type.
229///
230/// # Errors
231///
232/// Returns [`FnError`] when the value cannot be coerced to `target`.
233pub(crate) fn value_to_scalar(v: &Value, target: &DataType) -> Result<ScalarValue, FnError> {
234    match (target, v) {
235        (DataType::Utf8, Value::Null) => Ok(ScalarValue::Utf8(None)),
236        (DataType::Int64, Value::Null) => Ok(ScalarValue::Int64(None)),
237        (DataType::Float64, Value::Null) => Ok(ScalarValue::Float64(None)),
238        (DataType::Boolean, Value::Null) => Ok(ScalarValue::Boolean(None)),
239        (DataType::Utf8, Value::String(s)) => Ok(ScalarValue::Utf8(Some(s.clone()))),
240        (DataType::Utf8, other) => Ok(ScalarValue::Utf8(Some(stringify(other)))),
241        (DataType::Int64, Value::Int(i)) => Ok(ScalarValue::Int64(Some(*i))),
242        #[expect(
243            clippy::cast_possible_truncation,
244            reason = "explicit narrowing on user request"
245        )]
246        (DataType::Int64, Value::Float(f)) => Ok(ScalarValue::Int64(Some(*f as i64))),
247        (DataType::Int64, Value::Bool(b)) => Ok(ScalarValue::Int64(Some(i64::from(*b)))),
248        (DataType::Float64, Value::Float(f)) => Ok(ScalarValue::Float64(Some(*f))),
249        #[expect(
250            clippy::cast_precision_loss,
251            reason = "i64→f64 widening at user request"
252        )]
253        (DataType::Float64, Value::Int(i)) => Ok(ScalarValue::Float64(Some(*i as f64))),
254        (DataType::Boolean, Value::Bool(b)) => Ok(ScalarValue::Boolean(Some(*b))),
255        (dt, other) => Err(FnError::new(
256            FnError::CODE_TYPE_COERCION,
257            format!("declared aggregate cannot coerce {other:?} to {dt:?}"),
258        )),
259    }
260}
261
262// ---------------------------------------------------------------
263// Synthesis / registry installation
264// ---------------------------------------------------------------
265
266/// Compile a declared-aggregate record into a [`DeclaredAggregateFn`]
267/// and register it into `registry` under a synthetic plugin id derived
268/// from the qname's namespace.
269///
270/// `record.signature_json` must contain `{init, update, finalize,
271/// return_type, arg_names}` keys (as encoded by
272/// `DeclareAggregateProcedure::invoke`).
273///
274/// # Errors
275///
276/// * [`CustomError::BodyParse`] — `signature_json` is malformed or any
277///   of the three Cypher bodies fails to parse.
278/// * [`CustomError::NativeShadow`] — the qname is already registered as
279///   a native aggregate in `registry`.
280/// * [`CustomError::Registration`] — other registrar failures.
281pub fn install_aggregate_into_registry(
282    registry: &Arc<PluginRegistry>,
283    record: &DeclaredPlugin,
284) -> Result<(), CustomError> {
285    let sig_meta: serde_json::Value = serde_json::from_str(&record.signature_json)
286        .map_err(|e| CustomError::BodyParse(format!("signature_json: {e}")))?;
287    let init_src = sig_meta
288        .get("init")
289        .and_then(|v| v.as_str())
290        .ok_or_else(|| CustomError::BodyParse("declareAggregate: missing `init`".to_owned()))?;
291    let update_src = sig_meta
292        .get("update")
293        .and_then(|v| v.as_str())
294        .ok_or_else(|| CustomError::BodyParse("declareAggregate: missing `update`".to_owned()))?;
295    let finalize_src = sig_meta
296        .get("finalize")
297        .and_then(|v| v.as_str())
298        .ok_or_else(|| CustomError::BodyParse("declareAggregate: missing `finalize`".to_owned()))?;
299    let return_type_str = sig_meta
300        .get("return_type")
301        .and_then(|v| v.as_str())
302        .unwrap_or("float");
303    let arg_names: Vec<String> = sig_meta
304        .get("arg_names")
305        .and_then(|v| v.as_array())
306        .map(|arr| {
307            arr.iter()
308                .filter_map(|v| v.as_str().map(str::to_owned))
309                .collect()
310        })
311        .unwrap_or_default();
312
313    let return_dt = type_str_to_arrow(return_type_str).ok_or_else(|| {
314        CustomError::BodyParse(format!("unknown return type `{return_type_str}`"))
315    })?;
316
317    let init =
318        parse_expression(init_src).map_err(|e| CustomError::BodyParse(format!("init: {e:?}")))?;
319    let update = parse_expression(update_src)
320        .map_err(|e| CustomError::BodyParse(format!("update: {e:?}")))?;
321    let finalize = parse_expression(finalize_src)
322        .map_err(|e| CustomError::BodyParse(format!("finalize: {e:?}")))?;
323
324    let agg = DeclaredAggregateFn::new(init, update, finalize, arg_names, return_dt);
325    let signature = agg.signature().clone();
326
327    let qname = QName::new(
328        declared_plugin_id(&record.qname),
329        local_part(&record.qname).to_ascii_lowercase(),
330    );
331    let plugin = SyntheticAggregatePlugin {
332        plugin_id: PluginId::new(declared_plugin_id(&record.qname)),
333        qname: qname.clone(),
334        signature,
335        function: Arc::new(agg) as Arc<dyn AggregatePluginFn>,
336    };
337    let manifest = plugin.manifest_owned();
338    let caps = manifest.capabilities.clone();
339    let mut r = PluginRegistrar::new(manifest.id, &caps, registry);
340    plugin
341        .register(&mut r)
342        .map_err(|e| map_plugin_error(e, &record.qname))?;
343    r.commit_to_registry()
344        .map_err(|e| map_plugin_error(e, &record.qname))?;
345    // Publish the qname to the Cypher planner's plugin-aggregate hint
346    // set so `RETURN myAgg(x)` routes through aggregate translation
347    // instead of scalar UDF resolution.
348    uni_cypher::register_plugin_aggregate(format!("{}.{}", qname.namespace(), qname.local()));
349    Ok(())
350}
351
352/// Synthetic [`Plugin`] wrapping a single declared aggregate.
353struct SyntheticAggregatePlugin {
354    plugin_id: PluginId,
355    qname: QName,
356    signature: AggSignature,
357    function: Arc<dyn AggregatePluginFn>,
358}
359
360impl std::fmt::Debug for SyntheticAggregatePlugin {
361    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362        f.debug_struct("SyntheticAggregatePlugin")
363            .field("plugin_id", &self.plugin_id)
364            .field("qname", &self.qname)
365            .finish_non_exhaustive()
366    }
367}
368
369impl SyntheticAggregatePlugin {
370    fn manifest_owned(&self) -> PluginManifest {
371        PluginManifest {
372            id: self.plugin_id.clone(),
373            version: Version::new(0, 0, 1),
374            abi: AbiRange::parse("^1").expect("manifest ABI range is valid"),
375            depends_on: vec![],
376            capabilities: CapabilitySet::from_iter_of([Capability::AggregateFn]),
377            determinism: Determinism::Pure,
378            side_effects: SideEffects::ReadOnly,
379            scope: Scope::Instance,
380            hash: None,
381            signature: None,
382            provides: ProvidedSurfaces::default(),
383            docs: "Declared aggregate function (apoc.custom analogue).".to_owned(),
384            metadata: std::collections::BTreeMap::new(),
385        }
386    }
387}
388
389impl Plugin for SyntheticAggregatePlugin {
390    fn manifest(&self) -> &PluginManifest {
391        // M-UNSAFE: no `unsafe` used — `Box::leak` is the safe API.
392        // Mirrors `SyntheticScalarPlugin::manifest_cell` in lib.rs; the
393        // leak is bounded by declared-plugin count.
394        Box::leak(Box::new(self.manifest_owned()))
395    }
396
397    fn register(&self, r: &mut PluginRegistrar<'_>) -> Result<(), PluginError> {
398        r.aggregate_fn(
399            self.qname.clone(),
400            self.signature.clone(),
401            Arc::clone(&self.function),
402        )?;
403        Ok(())
404    }
405}
406
#[cfg(test)]
mod tests {
    use arrow_array::Int64Array;

    use super::*;

    fn parse(src: &str) -> Expr {
        parse_expression(src).expect("parse")
    }

    /// `sum(x * x)` over an `Int64` column.
    fn build_int_sum_squares() -> DeclaredAggregateFn {
        DeclaredAggregateFn::new(
            parse("0"),
            parse("$state + ($x * $x)"),
            parse("$state"),
            vec!["x".to_owned()],
            DataType::Int64,
        )
    }

    fn int_column(values: Vec<i64>) -> ArrayRef {
        Arc::new(Int64Array::from(values))
    }

    #[test]
    fn accumulator_handles_empty_group() {
        let agg = build_int_sum_squares();
        let acc = agg.create_accumulator();
        let out = acc.evaluate().expect("evaluate");
        assert_eq!(out, ScalarValue::Int64(Some(0)));
    }

    #[test]
    fn accumulator_runs_init_only_once() {
        let agg = build_int_sum_squares();
        let mut acc = agg.create_accumulator();
        let col: ArrayRef = Arc::new(Int64Array::from(vec![1_i64, 2, 3]));
        acc.update_batch(&[col]).expect("update");
        let out = acc.evaluate().expect("evaluate");
        // 1 + 4 + 9 = 14
        assert_eq!(out, ScalarValue::Int64(Some(14)));
    }

    #[test]
    fn accumulator_carries_state_across_batches() {
        let agg = build_int_sum_squares();
        let mut acc = agg.create_accumulator();
        acc.update_batch(&[int_column(vec![1, 2])]).expect("first batch");
        acc.update_batch(&[int_column(vec![3])]).expect("second batch");
        // State from the first batch must survive into the second: 1 + 4 + 9.
        assert_eq!(acc.evaluate().expect("evaluate"), ScalarValue::Int64(Some(14)));
    }

    #[test]
    fn update_batch_rejects_wrong_arg_count() {
        let agg = build_int_sum_squares();
        let mut acc = agg.create_accumulator();
        let err = acc.update_batch(&[]).unwrap_err();
        assert_eq!(err.code, FnError::CODE_TYPE_COERCION);
    }

    #[test]
    fn merge_batch_is_rejected() {
        let agg = build_int_sum_squares();
        let mut acc = agg.create_accumulator();
        let col: ArrayRef = Arc::new(Int64Array::from(vec![1_i64]));
        let err = acc.merge_batch(&[col]).unwrap_err();
        assert_eq!(err.code, FnError::CODE_TYPE_COERCION);
    }

    #[test]
    fn signature_default_disables_partial() {
        let agg = build_int_sum_squares();
        let sig = agg.signature();
        assert!(!sig.supports_partial);
        assert!(sig.state_fields.is_empty());
    }

    #[test]
    fn signature_declares_one_arg_per_name() {
        let agg = build_int_sum_squares();
        assert_eq!(agg.signature().args.len(), 1);
        assert_eq!(agg.return_dt(), &DataType::Int64);
    }

    #[test]
    fn value_to_scalar_coerces_int_to_float() {
        let sv = value_to_scalar(&Value::Int(7), &DataType::Float64).unwrap();
        assert_eq!(sv, ScalarValue::Float64(Some(7.0)));
    }

    #[test]
    fn value_to_scalar_maps_null_to_typed_null() {
        assert_eq!(
            value_to_scalar(&Value::Null, &DataType::Utf8).unwrap(),
            ScalarValue::Utf8(None)
        );
        assert_eq!(
            value_to_scalar(&Value::Null, &DataType::Int64).unwrap(),
            ScalarValue::Int64(None)
        );
    }

    #[test]
    fn value_to_scalar_truncates_float_to_int() {
        let sv = value_to_scalar(&Value::Float(7.9), &DataType::Int64).unwrap();
        assert_eq!(sv, ScalarValue::Int64(Some(7)));
        let sv = value_to_scalar(&Value::Float(-7.9), &DataType::Int64).unwrap();
        assert_eq!(sv, ScalarValue::Int64(Some(-7)));
    }

    #[test]
    fn value_to_scalar_rejects_incompatible_target() {
        let err = value_to_scalar(&Value::Bool(true), &DataType::Float64).unwrap_err();
        assert_eq!(err.code, FnError::CODE_TYPE_COERCION);
    }
}