//! PostgreSQL-compatible JSON aggregate UDFs (F-SCHEMA-011).
//!
//! - [`JsonAgg`] — `json_agg(expr) -> jsonb` — collects values into a JSON array
//! - [`JsonObjectAgg`] — `json_object_agg(key, value) -> jsonb` — collects
//!   key-value pairs into a JSON object
use std::any::Any;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

use arrow::datatypes::DataType;
use arrow_array::{Array, ArrayRef, LargeBinaryArray, StringArray};
use arrow_schema::Field;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, TypeSignature, Volatility};

use super::json_types;

// ══════════════════════════════════════════════════════════════════
// json_agg(expression) -> jsonb
// ══════════════════════════════════════════════════════════════════

24/// `json_agg(expression) -> jsonb`
25///
26/// Collects all values of an expression into a JSON array.
27/// Executes in Ring 1 via the DataFusion aggregate bridge.
28#[derive(Debug)]
29pub struct JsonAgg {
30    signature: Signature,
31}
32
33impl JsonAgg {
34    /// Creates a new `json_agg` UDAF.
35    #[must_use]
36    pub fn new() -> Self {
37        Self {
38            signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
39        }
40    }
41}
42
43impl Default for JsonAgg {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl PartialEq for JsonAgg {
50    fn eq(&self, _other: &Self) -> bool {
51        true
52    }
53}
54
55impl Eq for JsonAgg {}
56
57impl Hash for JsonAgg {
58    fn hash<H: Hasher>(&self, state: &mut H) {
59        "json_agg".hash(state);
60    }
61}
62
63impl AggregateUDFImpl for JsonAgg {
64    fn as_any(&self) -> &dyn Any {
65        self
66    }
67
68    fn name(&self) -> &'static str {
69        "json_agg"
70    }
71
72    fn signature(&self) -> &Signature {
73        &self.signature
74    }
75
76    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
77        Ok(DataType::LargeBinary)
78    }
79
80    fn state_fields(
81        &self,
82        _args: datafusion_expr::function::StateFieldsArgs,
83    ) -> Result<Vec<Arc<Field>>> {
84        // State is a single LargeBinary holding concatenated JSONB values
85        Ok(vec![Arc::new(Field::new(
86            "json_agg_state",
87            DataType::LargeBinary,
88            true,
89        ))])
90    }
91
92    fn accumulator(&self, _args: AccumulatorArgs<'_>) -> Result<Box<dyn Accumulator>> {
93        Ok(Box::new(JsonAggAccumulator::new()))
94    }
95}
96
/// Accumulator for `json_agg`.
///
/// Maintains a `Vec` of JSON values. On `evaluate`, serializes to JSONB array.
#[derive(Debug)]
struct JsonAggAccumulator {
    // Values collected so far, in input order (SQL NULL stored as JSON null).
    values: Vec<serde_json::Value>,
}

105impl JsonAggAccumulator {
106    fn new() -> Self {
107        Self { values: Vec::new() }
108    }
109}
110
111impl Accumulator for JsonAggAccumulator {
112    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
113        let arr = &values[0];
114        for i in 0..arr.len() {
115            if arr.is_null(i) {
116                self.values.push(serde_json::Value::Null);
117            } else {
118                self.values.push(array_value_to_json(arr, i));
119            }
120        }
121        Ok(())
122    }
123
124    fn evaluate(&mut self) -> Result<ScalarValue> {
125        let json_arr = serde_json::Value::Array(self.values.clone());
126        let bytes = json_types::encode_jsonb(&json_arr);
127        Ok(ScalarValue::LargeBinary(Some(bytes)))
128    }
129
130    fn size(&self) -> usize {
131        std::mem::size_of::<Self>()
132            + self.values.capacity() * std::mem::size_of::<serde_json::Value>()
133    }
134
135    fn state(&mut self) -> Result<Vec<ScalarValue>> {
136        // Serialize current state as a JSONB array
137        let json_arr = serde_json::Value::Array(self.values.clone());
138        let bytes = json_types::encode_jsonb(&json_arr);
139        Ok(vec![ScalarValue::LargeBinary(Some(bytes))])
140    }
141
142    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
143        let arr = states[0]
144            .as_any()
145            .downcast_ref::<LargeBinaryArray>()
146            .ok_or_else(|| {
147                datafusion_common::DataFusionError::Internal(
148                    "json_agg: merge state must be LargeBinary".into(),
149                )
150            })?;
151        for i in 0..arr.len() {
152            if !arr.is_null(i) {
153                let bytes = arr.value(i);
154                // Decode JSONB array and merge elements
155                if let Some(json_str) = json_types::jsonb_to_text(bytes) {
156                    if let Ok(serde_json::Value::Array(elems)) =
157                        serde_json::from_str::<serde_json::Value>(&json_str)
158                    {
159                        self.values.extend(elems);
160                    }
161                }
162            }
163        }
164        Ok(())
165    }
166}
167
// ══════════════════════════════════════════════════════════════════
// json_object_agg(key, value) -> jsonb
// ══════════════════════════════════════════════════════════════════

/// `json_object_agg(key, value) -> jsonb`
///
/// Collects key-value pairs into a JSON object. Duplicate keys use
/// last-value-wins semantics (consistent with PostgreSQL's `jsonb_object_agg`;
/// note PostgreSQL's `json_object_agg` itself preserves duplicates).
#[derive(Debug)]
pub struct JsonObjectAgg {
    // Accepts any (key, value) argument pair; immutable volatility.
    signature: Signature,
}

181impl JsonObjectAgg {
182    /// Creates a new `json_object_agg` UDAF.
183    #[must_use]
184    pub fn new() -> Self {
185        Self {
186            signature: Signature::new(TypeSignature::Any(2), Volatility::Immutable),
187        }
188    }
189}
190
191impl Default for JsonObjectAgg {
192    fn default() -> Self {
193        Self::new()
194    }
195}
196
197impl PartialEq for JsonObjectAgg {
198    fn eq(&self, _other: &Self) -> bool {
199        true
200    }
201}
202
203impl Eq for JsonObjectAgg {}
204
205impl Hash for JsonObjectAgg {
206    fn hash<H: Hasher>(&self, state: &mut H) {
207        "json_object_agg".hash(state);
208    }
209}
210
211impl AggregateUDFImpl for JsonObjectAgg {
212    fn as_any(&self) -> &dyn Any {
213        self
214    }
215
216    fn name(&self) -> &'static str {
217        "json_object_agg"
218    }
219
220    fn signature(&self) -> &Signature {
221        &self.signature
222    }
223
224    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
225        Ok(DataType::LargeBinary)
226    }
227
228    fn state_fields(
229        &self,
230        _args: datafusion_expr::function::StateFieldsArgs,
231    ) -> Result<Vec<Arc<Field>>> {
232        Ok(vec![Arc::new(Field::new(
233            "json_object_agg_state",
234            DataType::LargeBinary,
235            true,
236        ))])
237    }
238
239    fn accumulator(&self, _args: AccumulatorArgs<'_>) -> Result<Box<dyn Accumulator>> {
240        Ok(Box::new(JsonObjectAggAccumulator::new()))
241    }
242}
243
/// Accumulator for `json_object_agg`.
///
/// Maintains an ordered map of key-value pairs.
#[derive(Debug)]
struct JsonObjectAggAccumulator {
    // NOTE(review): serde_json::Map is insertion-ordered only when the
    // crate's `preserve_order` feature is enabled; otherwise keys are
    // sorted alphabetically — confirm the crate configuration.
    entries: serde_json::Map<String, serde_json::Value>,
}

252impl JsonObjectAggAccumulator {
253    fn new() -> Self {
254        Self {
255            entries: serde_json::Map::new(),
256        }
257    }
258}
259
260impl Accumulator for JsonObjectAggAccumulator {
261    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
262        let key_arr = &values[0];
263        let val_arr = &values[1];
264
265        for i in 0..key_arr.len() {
266            if key_arr.is_null(i) {
267                continue; // Skip null keys (PostgreSQL behavior)
268            }
269            let key = array_value_to_string(key_arr, i)?;
270            let val = if val_arr.is_null(i) {
271                serde_json::Value::Null
272            } else {
273                array_value_to_json(val_arr, i)
274            };
275            self.entries.insert(key, val); // last-value-wins
276        }
277        Ok(())
278    }
279
280    fn evaluate(&mut self) -> Result<ScalarValue> {
281        let obj = serde_json::Value::Object(self.entries.clone());
282        let bytes = json_types::encode_jsonb(&obj);
283        Ok(ScalarValue::LargeBinary(Some(bytes)))
284    }
285
286    fn size(&self) -> usize {
287        std::mem::size_of::<Self>() + self.entries.len() * 64 // rough estimate
288    }
289
290    fn state(&mut self) -> Result<Vec<ScalarValue>> {
291        let obj = serde_json::Value::Object(self.entries.clone());
292        let bytes = json_types::encode_jsonb(&obj);
293        Ok(vec![ScalarValue::LargeBinary(Some(bytes))])
294    }
295
296    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
297        let arr = states[0]
298            .as_any()
299            .downcast_ref::<LargeBinaryArray>()
300            .ok_or_else(|| {
301                datafusion_common::DataFusionError::Internal(
302                    "json_object_agg: merge state must be LargeBinary".into(),
303                )
304            })?;
305        for i in 0..arr.len() {
306            if !arr.is_null(i) {
307                let bytes = arr.value(i);
308                if let Some(json_str) = json_types::jsonb_to_text(bytes) {
309                    if let Ok(serde_json::Value::Object(map)) =
310                        serde_json::from_str::<serde_json::Value>(&json_str)
311                    {
312                        for (k, v) in map {
313                            self.entries.insert(k, v);
314                        }
315                    }
316                }
317            }
318        }
319        Ok(())
320    }
321}
322
// ── Helpers ──────────────────────────────────────────────────────

325/// Convert an Arrow array element to a `serde_json::Value`.
326fn array_value_to_json(arr: &ArrayRef, row: usize) -> serde_json::Value {
327    if arr.is_null(row) {
328        return serde_json::Value::Null;
329    }
330    if let Some(a) = arr.as_any().downcast_ref::<StringArray>() {
331        return serde_json::Value::String(a.value(row).to_owned());
332    }
333    if let Some(a) = arr.as_any().downcast_ref::<arrow_array::Int64Array>() {
334        return serde_json::Value::Number(a.value(row).into());
335    }
336    if let Some(a) = arr.as_any().downcast_ref::<arrow_array::Int32Array>() {
337        return serde_json::Value::Number(i64::from(a.value(row)).into());
338    }
339    if let Some(a) = arr.as_any().downcast_ref::<arrow_array::Float64Array>() {
340        if let Some(n) = serde_json::Number::from_f64(a.value(row)) {
341            return serde_json::Value::Number(n);
342        }
343        return serde_json::Value::Null;
344    }
345    if let Some(a) = arr.as_any().downcast_ref::<arrow_array::BooleanArray>() {
346        return serde_json::Value::Bool(a.value(row));
347    }
348    // Fallback
349    let scalar = ScalarValue::try_from_array(arr, row).ok();
350    match scalar {
351        Some(s) => serde_json::Value::String(s.to_string()),
352        None => serde_json::Value::Null,
353    }
354}
355
356/// Extract a string key from an Arrow array.
357fn array_value_to_string(arr: &ArrayRef, row: usize) -> Result<String> {
358    if let Some(a) = arr.as_any().downcast_ref::<StringArray>() {
359        return Ok(a.value(row).to_owned());
360    }
361    // Fallback via ScalarValue display
362    let sv = ScalarValue::try_from_array(arr, row)?;
363    Ok(sv.to_string())
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Int64Array;

    /// Builds a non-null Utf8 array from string literals.
    fn make_string_array(vals: &[&str]) -> StringArray {
        StringArray::from(vals.to_vec())
    }

    /// json_agg over ints yields a JSONB array preserving input order.
    #[test]
    fn test_json_agg_basic() {
        let mut acc = JsonAggAccumulator::new();
        let vals = Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef;
        acc.update_batch(&[vals]).unwrap();
        let result = acc.evaluate().unwrap();
        match result {
            ScalarValue::LargeBinary(Some(bytes)) => {
                // Top-level container must be a JSONB array.
                assert_eq!(json_types::jsonb_type_name(&bytes), Some("array"));
                let e0 = json_types::jsonb_array_get(&bytes, 0).unwrap();
                assert_eq!(json_types::jsonb_to_text(e0), Some("1".to_owned()));
                let e2 = json_types::jsonb_array_get(&bytes, 2).unwrap();
                assert_eq!(json_types::jsonb_to_text(e2), Some("3".to_owned()));
            }
            other => panic!("Expected LargeBinary, got {other:?}"),
        }
    }

    /// String inputs are stored as JSON strings, not re-parsed.
    #[test]
    fn test_json_agg_strings() {
        let mut acc = JsonAggAccumulator::new();
        let vals = Arc::new(make_string_array(&["a", "b", "c"])) as ArrayRef;
        acc.update_batch(&[vals]).unwrap();
        let result = acc.evaluate().unwrap();
        match result {
            ScalarValue::LargeBinary(Some(bytes)) => {
                let e0 = json_types::jsonb_array_get(&bytes, 0).unwrap();
                assert_eq!(json_types::jsonb_to_text(e0), Some("a".to_owned()));
            }
            other => panic!("Expected LargeBinary, got {other:?}"),
        }
    }

    /// Successive update_batch calls append; order across batches is kept.
    #[test]
    fn test_json_agg_multiple_batches() {
        let mut acc = JsonAggAccumulator::new();
        let v1 = Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef;
        let v2 = Arc::new(Int64Array::from(vec![3])) as ArrayRef;
        acc.update_batch(&[v1]).unwrap();
        acc.update_batch(&[v2]).unwrap();
        let result = acc.evaluate().unwrap();
        match result {
            ScalarValue::LargeBinary(Some(bytes)) => {
                // Should have 3 elements total
                let text = json_types::jsonb_to_text(&bytes).unwrap();
                assert_eq!(text, "[1,2,3]");
            }
            other => panic!("Expected LargeBinary, got {other:?}"),
        }
    }

    /// json_object_agg builds a JSONB object keyed by the first argument.
    #[test]
    fn test_json_object_agg_basic() {
        let mut acc = JsonObjectAggAccumulator::new();
        let keys = Arc::new(make_string_array(&["a", "b", "c"])) as ArrayRef;
        let vals = Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef;
        acc.update_batch(&[keys, vals]).unwrap();
        let result = acc.evaluate().unwrap();
        match result {
            ScalarValue::LargeBinary(Some(bytes)) => {
                assert_eq!(json_types::jsonb_type_name(&bytes), Some("object"));
                let a = json_types::jsonb_get_field(&bytes, "a").unwrap();
                assert_eq!(json_types::jsonb_to_text(a), Some("1".to_owned()));
                let c = json_types::jsonb_get_field(&bytes, "c").unwrap();
                assert_eq!(json_types::jsonb_to_text(c), Some("3".to_owned()));
            }
            other => panic!("Expected LargeBinary, got {other:?}"),
        }
    }

    /// Duplicate keys: the later value overwrites the earlier one.
    #[test]
    fn test_json_object_agg_last_value_wins() {
        let mut acc = JsonObjectAggAccumulator::new();
        let keys = Arc::new(make_string_array(&["a", "a"])) as ArrayRef;
        let vals = Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef;
        acc.update_batch(&[keys, vals]).unwrap();
        let result = acc.evaluate().unwrap();
        match result {
            ScalarValue::LargeBinary(Some(bytes)) => {
                let a = json_types::jsonb_get_field(&bytes, "a").unwrap();
                assert_eq!(json_types::jsonb_to_text(a), Some("2".to_owned()));
            }
            other => panic!("Expected LargeBinary, got {other:?}"),
        }
    }

    /// Round-trips state() through merge_batch(); merged elements are
    /// appended after the receiving accumulator's own values.
    #[test]
    fn test_json_agg_state_merge() {
        let mut acc1 = JsonAggAccumulator::new();
        let v1 = Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef;
        acc1.update_batch(&[v1]).unwrap();
        let state = acc1.state().unwrap();

        let mut acc2 = JsonAggAccumulator::new();
        let v2 = Arc::new(Int64Array::from(vec![3])) as ArrayRef;
        acc2.update_batch(&[v2]).unwrap();

        // Merge state from acc1 into acc2
        let state_arr: ArrayRef = match &state[0] {
            ScalarValue::LargeBinary(Some(b)) => {
                Arc::new(LargeBinaryArray::from_iter_values(vec![b.as_slice()]))
            }
            _ => panic!("expected LargeBinary state"),
        };
        acc2.merge_batch(&[state_arr]).unwrap();

        let result = acc2.evaluate().unwrap();
        match result {
            ScalarValue::LargeBinary(Some(bytes)) => {
                let text = json_types::jsonb_to_text(&bytes).unwrap();
                assert_eq!(text, "[3,1,2]");
            }
            other => panic!("Expected LargeBinary, got {other:?}"),
        }
    }

    /// Both UDAFs register under their PostgreSQL-compatible names.
    #[test]
    fn test_udaf_registration() {
        let json_agg = datafusion_expr::AggregateUDF::new_from_impl(JsonAgg::new());
        assert_eq!(json_agg.name(), "json_agg");

        let json_obj_agg = datafusion_expr::AggregateUDF::new_from_impl(JsonObjectAgg::new());
        assert_eq!(json_obj_agg.name(), "json_object_agg");
    }
}