1use std::collections::HashMap;
29use std::sync::Arc;
30
31use arrow_array::{Array, ArrayRef};
32use arrow_schema::DataType;
33use datafusion::logical_expr::Volatility;
34use datafusion::scalar::ScalarValue;
35use semver::Version;
36use uni_common::Value;
37use uni_cypher::ast::Expr;
38use uni_cypher::parse_expression;
39use uni_plugin::traits::aggregate::{AggSignature, AggregatePluginFn, PluginAccumulator};
40use uni_plugin::traits::scalar::ArgType;
41use uni_plugin::{
42 AbiRange, Capability, CapabilitySet, Determinism, FnError, Plugin, PluginError, PluginId,
43 PluginManifest, PluginRegistrar, PluginRegistry, ProvidedSurfaces, QName, Scope, SideEffects,
44};
45
46use crate::decode::{
47 array_value_at, declared_plugin_id, eval_err_to_fn, local_part, map_plugin_error, stringify,
48 type_str_to_arrow,
49};
50use crate::eval::eval_expr;
51use crate::{CustomError, DeclaredPlugin};
52
/// Binding name under which the running accumulator state is exposed to the
/// `update` and `finalize` expressions (written `$state` in the expressions
/// used by this module's tests).
const STATE_PARAM: &str = "state";
56
/// A user-declared aggregate (apoc.custom analogue) defined by three Cypher
/// expressions: `init` seeds the state, `update` folds one input row into the
/// state, and `finalize` turns the final state into the result value.
pub struct DeclaredAggregateFn {
    /// Evaluated with no bindings to produce the starting state.
    init_expr: Arc<Expr>,
    /// Evaluated once per input row with `state` and every argument name bound;
    /// its result becomes the new state.
    update_expr: Arc<Expr>,
    /// Evaluated with only `state` bound to produce the aggregate's output.
    finalize_expr: Arc<Expr>,
    /// Binding names for the positional arguments, in call order.
    arg_names: Vec<String>,
    /// Arrow type the finalized value is coerced to (see `value_to_scalar`).
    return_dt: DataType,
    /// Signature derived from `arg_names` and `return_dt` at construction time.
    signature: AggSignature,
}
70
71impl std::fmt::Debug for DeclaredAggregateFn {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 f.debug_struct("DeclaredAggregateFn")
74 .field("arg_names", &self.arg_names)
75 .field("return_type", &self.return_dt)
76 .finish_non_exhaustive()
77 }
78}
79
80impl DeclaredAggregateFn {
81 #[must_use]
83 pub fn new(
84 init_expr: Expr,
85 update_expr: Expr,
86 finalize_expr: Expr,
87 arg_names: Vec<String>,
88 return_dt: DataType,
89 ) -> Self {
90 let signature = Self::build_signature(return_dt.clone(), &arg_names);
91 Self {
92 init_expr: Arc::new(init_expr),
93 update_expr: Arc::new(update_expr),
94 finalize_expr: Arc::new(finalize_expr),
95 arg_names,
96 return_dt,
97 signature,
98 }
99 }
100
101 #[must_use]
108 pub fn build_signature(returns: DataType, arg_names: &[String]) -> AggSignature {
109 AggSignature {
110 args: arg_names
111 .iter()
112 .map(|_| ArgType::Primitive(DataType::Utf8))
113 .collect(),
114 returns: ArgType::Primitive(returns),
115 state_fields: Vec::new(),
116 volatility: Volatility::Volatile,
117 supports_partial: false,
118 }
119 }
120
121 #[must_use]
123 pub fn return_dt(&self) -> &DataType {
124 &self.return_dt
125 }
126}
127
128impl AggregatePluginFn for DeclaredAggregateFn {
129 fn signature(&self) -> &AggSignature {
130 &self.signature
131 }
132
133 fn create_accumulator(&self) -> Box<dyn PluginAccumulator> {
134 Box::new(DeclaredAccumulator {
135 init_expr: Arc::clone(&self.init_expr),
136 update_expr: Arc::clone(&self.update_expr),
137 finalize_expr: Arc::clone(&self.finalize_expr),
138 arg_names: self.arg_names.clone(),
139 return_dt: self.return_dt.clone(),
140 state: None,
141 })
142 }
143}
144
/// Running per-group state for a [`DeclaredAggregateFn`].
#[derive(Debug)]
struct DeclaredAccumulator {
    /// Shared with the parent function; evaluated with no bindings to seed the state.
    init_expr: Arc<Expr>,
    /// Shared with the parent function; evaluated once per input row.
    update_expr: Arc<Expr>,
    /// Shared with the parent function; evaluated once in `evaluate`.
    finalize_expr: Arc<Expr>,
    /// Binding names for the argument columns, positionally matched to `values`.
    arg_names: Vec<String>,
    /// Target type for the finalized value.
    return_dt: DataType,
    /// Current fold state; `None` until the first `update_batch` seeds it from `init`.
    state: Option<Value>,
}
155
156impl DeclaredAccumulator {
157 fn ensure_state(&mut self) -> Result<(), FnError> {
159 if self.state.is_none() {
160 let bindings: HashMap<String, Value> = HashMap::new();
161 let v = eval_expr(&self.init_expr, &bindings).map_err(eval_err_to_fn)?;
162 self.state = Some(v);
163 }
164 Ok(())
165 }
166}
167
168impl PluginAccumulator for DeclaredAccumulator {
169 fn update_batch(&mut self, values: &[ArrayRef]) -> Result<(), FnError> {
170 if values.len() != self.arg_names.len() {
171 return Err(FnError::new(
172 FnError::CODE_TYPE_COERCION,
173 format!(
174 "declared aggregate expected {} args, got {}",
175 self.arg_names.len(),
176 values.len()
177 ),
178 ));
179 }
180 self.ensure_state()?;
181 let rows = values.first().map_or(0, |a| a.len());
182 for row in 0..rows {
183 let mut bindings: HashMap<String, Value> = HashMap::with_capacity(values.len() + 1);
184 let st = self.state.clone().unwrap_or(Value::Null);
187 bindings.insert(STATE_PARAM.to_owned(), st);
188 for (i, col) in values.iter().enumerate() {
189 bindings.insert(self.arg_names[i].clone(), array_value_at(col, row)?);
190 }
191 let next = eval_expr(&self.update_expr, &bindings).map_err(eval_err_to_fn)?;
192 self.state = Some(next);
193 }
194 Ok(())
195 }
196
197 fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<(), FnError> {
198 Err(FnError::new(
199 FnError::CODE_TYPE_COERCION,
200 "declared aggregates do not support partial / distributed aggregation".to_owned(),
201 ))
202 }
203
204 fn state(&self) -> Result<Vec<ScalarValue>, FnError> {
205 Ok(Vec::new())
207 }
208
209 fn evaluate(&self) -> Result<ScalarValue, FnError> {
210 let st = match &self.state {
213 Some(v) => v.clone(),
214 None => eval_expr(&self.init_expr, &HashMap::new()).map_err(eval_err_to_fn)?,
215 };
216 let mut bindings: HashMap<String, Value> = HashMap::with_capacity(1);
217 bindings.insert(STATE_PARAM.to_owned(), st);
218 let out = eval_expr(&self.finalize_expr, &bindings).map_err(eval_err_to_fn)?;
219 value_to_scalar(&out, &self.return_dt)
220 }
221
222 fn size(&self) -> usize {
223 std::mem::size_of::<Self>()
224 }
225}
226
227pub(crate) fn value_to_scalar(v: &Value, target: &DataType) -> Result<ScalarValue, FnError> {
234 match (target, v) {
235 (DataType::Utf8, Value::Null) => Ok(ScalarValue::Utf8(None)),
236 (DataType::Int64, Value::Null) => Ok(ScalarValue::Int64(None)),
237 (DataType::Float64, Value::Null) => Ok(ScalarValue::Float64(None)),
238 (DataType::Boolean, Value::Null) => Ok(ScalarValue::Boolean(None)),
239 (DataType::Utf8, Value::String(s)) => Ok(ScalarValue::Utf8(Some(s.clone()))),
240 (DataType::Utf8, other) => Ok(ScalarValue::Utf8(Some(stringify(other)))),
241 (DataType::Int64, Value::Int(i)) => Ok(ScalarValue::Int64(Some(*i))),
242 #[expect(
243 clippy::cast_possible_truncation,
244 reason = "explicit narrowing on user request"
245 )]
246 (DataType::Int64, Value::Float(f)) => Ok(ScalarValue::Int64(Some(*f as i64))),
247 (DataType::Int64, Value::Bool(b)) => Ok(ScalarValue::Int64(Some(i64::from(*b)))),
248 (DataType::Float64, Value::Float(f)) => Ok(ScalarValue::Float64(Some(*f))),
249 #[expect(
250 clippy::cast_precision_loss,
251 reason = "i64→f64 widening at user request"
252 )]
253 (DataType::Float64, Value::Int(i)) => Ok(ScalarValue::Float64(Some(*i as f64))),
254 (DataType::Boolean, Value::Bool(b)) => Ok(ScalarValue::Boolean(Some(*b))),
255 (dt, other) => Err(FnError::new(
256 FnError::CODE_TYPE_COERCION,
257 format!("declared aggregate cannot coerce {other:?} to {dt:?}"),
258 )),
259 }
260}
261
262pub fn install_aggregate_into_registry(
282 registry: &Arc<PluginRegistry>,
283 record: &DeclaredPlugin,
284) -> Result<(), CustomError> {
285 let sig_meta: serde_json::Value = serde_json::from_str(&record.signature_json)
286 .map_err(|e| CustomError::BodyParse(format!("signature_json: {e}")))?;
287 let init_src = sig_meta
288 .get("init")
289 .and_then(|v| v.as_str())
290 .ok_or_else(|| CustomError::BodyParse("declareAggregate: missing `init`".to_owned()))?;
291 let update_src = sig_meta
292 .get("update")
293 .and_then(|v| v.as_str())
294 .ok_or_else(|| CustomError::BodyParse("declareAggregate: missing `update`".to_owned()))?;
295 let finalize_src = sig_meta
296 .get("finalize")
297 .and_then(|v| v.as_str())
298 .ok_or_else(|| CustomError::BodyParse("declareAggregate: missing `finalize`".to_owned()))?;
299 let return_type_str = sig_meta
300 .get("return_type")
301 .and_then(|v| v.as_str())
302 .unwrap_or("float");
303 let arg_names: Vec<String> = sig_meta
304 .get("arg_names")
305 .and_then(|v| v.as_array())
306 .map(|arr| {
307 arr.iter()
308 .filter_map(|v| v.as_str().map(str::to_owned))
309 .collect()
310 })
311 .unwrap_or_default();
312
313 let return_dt = type_str_to_arrow(return_type_str).ok_or_else(|| {
314 CustomError::BodyParse(format!("unknown return type `{return_type_str}`"))
315 })?;
316
317 let init =
318 parse_expression(init_src).map_err(|e| CustomError::BodyParse(format!("init: {e:?}")))?;
319 let update = parse_expression(update_src)
320 .map_err(|e| CustomError::BodyParse(format!("update: {e:?}")))?;
321 let finalize = parse_expression(finalize_src)
322 .map_err(|e| CustomError::BodyParse(format!("finalize: {e:?}")))?;
323
324 let agg = DeclaredAggregateFn::new(init, update, finalize, arg_names, return_dt);
325 let signature = agg.signature().clone();
326
327 let qname = QName::new(
328 declared_plugin_id(&record.qname),
329 local_part(&record.qname).to_ascii_lowercase(),
330 );
331 let plugin = SyntheticAggregatePlugin {
332 plugin_id: PluginId::new(declared_plugin_id(&record.qname)),
333 qname: qname.clone(),
334 signature,
335 function: Arc::new(agg) as Arc<dyn AggregatePluginFn>,
336 };
337 let manifest = plugin.manifest_owned();
338 let caps = manifest.capabilities.clone();
339 let mut r = PluginRegistrar::new(manifest.id, &caps, registry);
340 plugin
341 .register(&mut r)
342 .map_err(|e| map_plugin_error(e, &record.qname))?;
343 r.commit_to_registry()
344 .map_err(|e| map_plugin_error(e, &record.qname))?;
345 uni_cypher::register_plugin_aggregate(format!("{}.{}", qname.namespace(), qname.local()));
349 Ok(())
350}
351
352struct SyntheticAggregatePlugin {
354 plugin_id: PluginId,
355 qname: QName,
356 signature: AggSignature,
357 function: Arc<dyn AggregatePluginFn>,
358}
359
360impl std::fmt::Debug for SyntheticAggregatePlugin {
361 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362 f.debug_struct("SyntheticAggregatePlugin")
363 .field("plugin_id", &self.plugin_id)
364 .field("qname", &self.qname)
365 .finish_non_exhaustive()
366 }
367}
368
369impl SyntheticAggregatePlugin {
370 fn manifest_owned(&self) -> PluginManifest {
371 PluginManifest {
372 id: self.plugin_id.clone(),
373 version: Version::new(0, 0, 1),
374 abi: AbiRange::parse("^1").expect("manifest ABI range is valid"),
375 depends_on: vec![],
376 capabilities: CapabilitySet::from_iter_of([Capability::AggregateFn]),
377 determinism: Determinism::Pure,
378 side_effects: SideEffects::ReadOnly,
379 scope: Scope::Instance,
380 hash: None,
381 signature: None,
382 provides: ProvidedSurfaces::default(),
383 docs: "Declared aggregate function (apoc.custom analogue).".to_owned(),
384 metadata: std::collections::BTreeMap::new(),
385 }
386 }
387}
388
389impl Plugin for SyntheticAggregatePlugin {
390 fn manifest(&self) -> &PluginManifest {
391 Box::leak(Box::new(self.manifest_owned()))
395 }
396
397 fn register(&self, r: &mut PluginRegistrar<'_>) -> Result<(), PluginError> {
398 r.aggregate_fn(
399 self.qname.clone(),
400 self.signature.clone(),
401 Arc::clone(&self.function),
402 )?;
403 Ok(())
404 }
405}
406
#[cfg(test)]
mod tests {
    use arrow_array::Int64Array;

    use super::*;

    /// Parses a Cypher expression, panicking on malformed test input.
    fn parse(src: &str) -> Expr {
        parse_expression(src).expect("parse")
    }

    /// Declared sum of squares over one Int64 argument `x`: init `0`,
    /// update `$state + ($x * $x)`, finalize `$state`.
    fn build_int_sum_squares() -> DeclaredAggregateFn {
        let (init, update, finalize) = (parse("0"), parse("$state + ($x * $x)"), parse("$state"));
        DeclaredAggregateFn::new(init, update, finalize, vec!["x".to_owned()], DataType::Int64)
    }

    /// With no input rows, `evaluate` finalizes the `init` value (0).
    #[test]
    fn accumulator_handles_empty_group() {
        let acc = build_int_sum_squares().create_accumulator();
        let result = acc.evaluate().expect("evaluate");
        assert_eq!(result, ScalarValue::Int64(Some(0)));
    }

    /// 1² + 2² + 3² = 14. Re-running `init` per row would lose earlier rows.
    #[test]
    fn accumulator_runs_init_only_once() {
        let mut acc = build_int_sum_squares().create_accumulator();
        let column: ArrayRef = Arc::new(Int64Array::from(vec![1_i64, 2, 3]));
        acc.update_batch(&[column]).expect("update");
        let result = acc.evaluate().expect("evaluate");
        assert_eq!(result, ScalarValue::Int64(Some(14)));
    }

    /// Partial aggregation is unsupported, so merging must fail.
    #[test]
    fn merge_batch_is_rejected() {
        let mut acc = build_int_sum_squares().create_accumulator();
        let partial: ArrayRef = Arc::new(Int64Array::from(vec![1_i64]));
        let err = acc
            .merge_batch(&[partial])
            .expect_err("merge_batch must be rejected");
        assert_eq!(err.code, FnError::CODE_TYPE_COERCION);
    }

    /// The default signature opts out of two-phase aggregation entirely.
    #[test]
    fn signature_default_disables_partial() {
        let agg = build_int_sum_squares();
        let sig = agg.signature();
        assert!(!sig.supports_partial, "partial aggregation must be off");
        assert!(sig.state_fields.is_empty(), "no state fields expected");
    }

    /// An integer result widens to the Float64 return type.
    #[test]
    fn value_to_scalar_coerces_int_to_float() {
        let coerced = value_to_scalar(&Value::Int(7), &DataType::Float64).unwrap();
        assert_eq!(coerced, ScalarValue::Float64(Some(7.0)));
    }
}