Skip to main content

laminar_sql/datafusion/
json_tvf.rs

1//! JSON table-valued functions (F-SCHEMA-012).
2//!
3//! Implements PostgreSQL-compatible JSON TVFs as DataFusion table functions:
4//!
5//! - `jsonb_array_elements(jsonb)` → set of JSONB values
6//! - `jsonb_array_elements_text(jsonb)` → set of text values
7//! - `jsonb_each(jsonb)` → set of (key, value) pairs
8//! - `jsonb_each_text(jsonb)` → set of (key, text_value) pairs
9//! - `jsonb_object_keys(jsonb)` → set of text keys
10//!
11//! Each TVF implements `TableFunctionImpl`, producing a `MemTable`-backed
12//! `TableProvider` that holds the expanded rows.
13
14use std::sync::Arc;
15
16use arrow::datatypes::{DataType, Field, Schema};
17use arrow::record_batch::RecordBatch;
18use arrow_array::{Int64Array, LargeBinaryArray, StringArray};
19use datafusion::catalog::{TableFunctionImpl, TableProvider};
20use datafusion::datasource::MemTable;
21use datafusion_common::{plan_err, Result, ScalarValue};
22use datafusion_expr::Expr;
23
24use super::json_types;
25
26// ── Helpers ─────────────────────────────────────────────────────────
27
/// Produces the row numbers `1, 2, …, n` used to fill an `ordinality` column.
///
/// A position that cannot be represented as an `i64` saturates to `i64::MAX`
/// instead of wrapping or panicking.
fn ordinality_vec(n: usize) -> Vec<i64> {
    let mut ordinals = Vec::with_capacity(n);
    for position in 1..=n {
        ordinals.push(i64::try_from(position).unwrap_or(i64::MAX));
    }
    ordinals
}
34
/// Leading type-tag bytes of encoded JSONB containers (mirrored from `json_types`).
///
/// The parsers in this file compare the first byte of a JSONB buffer against
/// these values to decide whether it can be expanded.
mod tags {
    /// First byte of an encoded JSONB array.
    pub const ARRAY: u8 = 0x06;
    /// First byte of an encoded JSONB object.
    pub const OBJECT: u8 = 0x07;
}
40
41/// Extracts a literal JSONB binary from an expression.
42fn extract_jsonb_literal(expr: &Expr) -> Result<Option<Vec<u8>>> {
43    match expr {
44        Expr::Literal(ScalarValue::LargeBinary(bytes), _) => Ok(bytes.clone()),
45        Expr::Literal(ScalarValue::Null | ScalarValue::Utf8(None), _) => Ok(None),
46        // Also accept Utf8 string literals (parse as JSON then encode to JSONB)
47        Expr::Literal(ScalarValue::Utf8(Some(s)), _) => {
48            let json_val: serde_json::Value = serde_json::from_str(s).map_err(|e| {
49                datafusion_common::DataFusionError::Plan(format!("invalid JSON literal: {e}"))
50            })?;
51            Ok(Some(json_types::encode_jsonb(&json_val)))
52        }
53        other => plan_err!(
54            "JSON TVF argument must be a JSONB (LargeBinary) or JSON string literal, got {other:?}"
55        ),
56    }
57}
58
59/// Iterates over JSONB array elements, returning each element as a bounded byte vec.
60///
61/// Returns `None` if the input is not a JSONB array.
62fn jsonb_array_elements_iter(data: &[u8]) -> Option<Vec<Vec<u8>>> {
63    if data.is_empty() || data[0] != tags::ARRAY {
64        return None;
65    }
66    if data.len() < 5 {
67        return None;
68    }
69    let count = u32::from_le_bytes([data[1], data[2], data[3], data[4]]) as usize;
70    // Offset table: count * 4 bytes (each offset is u32 LE, relative to data_start)
71    let offsets_start = 5;
72    let data_start = offsets_start + count * 4;
73    if data.len() < data_start {
74        return None;
75    }
76
77    let mut elements = Vec::with_capacity(count);
78    for i in 0..count {
79        let off_pos = offsets_start + i * 4;
80        let offset = u32::from_le_bytes([
81            data[off_pos],
82            data[off_pos + 1],
83            data[off_pos + 2],
84            data[off_pos + 3],
85        ]) as usize;
86
87        let abs_start = data_start + offset;
88        // Element end: next element's absolute start, or end of data
89        let abs_end = if i + 1 < count {
90            let next_pos = offsets_start + (i + 1) * 4;
91            data_start
92                + u32::from_le_bytes([
93                    data[next_pos],
94                    data[next_pos + 1],
95                    data[next_pos + 2],
96                    data[next_pos + 3],
97                ]) as usize
98        } else {
99            data.len()
100        };
101
102        if abs_start <= abs_end && abs_end <= data.len() {
103            elements.push(data[abs_start..abs_end].to_vec());
104        }
105    }
106    Some(elements)
107}
108
109/// Iterates over JSONB object key-value pairs.
110///
111/// Returns `None` if the input is not a JSONB object.
112fn jsonb_object_entries(data: &[u8]) -> Option<Vec<(String, Vec<u8>)>> {
113    if data.is_empty() || data[0] != tags::OBJECT {
114        return None;
115    }
116    if data.len() < 5 {
117        return None;
118    }
119    let count = u32::from_le_bytes([data[1], data[2], data[3], data[4]]) as usize;
120    // Offset table: count * 8 bytes (key_off u32 + val_off u32, relative to data_start)
121    let offsets_start = 5;
122    let data_start = offsets_start + count * 8;
123    if data.len() < data_start {
124        return None;
125    }
126
127    let mut entries = Vec::with_capacity(count);
128    for i in 0..count {
129        let base = offsets_start + i * 8;
130        let key_off =
131            u32::from_le_bytes([data[base], data[base + 1], data[base + 2], data[base + 3]])
132                as usize;
133        let val_off = u32::from_le_bytes([
134            data[base + 4],
135            data[base + 5],
136            data[base + 6],
137            data[base + 7],
138        ]) as usize;
139
140        // Key at data_start + key_off: u16 LE length + UTF-8 bytes
141        let key_abs = data_start + key_off;
142        if key_abs + 2 > data.len() {
143            continue;
144        }
145        let key_len = u16::from_le_bytes([data[key_abs], data[key_abs + 1]]) as usize;
146        let key_start = key_abs + 2;
147        let key_end = key_start + key_len;
148        if key_end > data.len() {
149            continue;
150        }
151        let key = String::from_utf8_lossy(&data[key_start..key_end]).to_string();
152
153        // Value at data_start + val_off, ending at next key's abs position
154        let val_abs = data_start + val_off;
155        let val_end = if i + 1 < count {
156            let next_base = offsets_start + (i + 1) * 8;
157            data_start
158                + u32::from_le_bytes([
159                    data[next_base],
160                    data[next_base + 1],
161                    data[next_base + 2],
162                    data[next_base + 3],
163                ]) as usize
164        } else {
165            data.len()
166        };
167
168        if val_abs <= val_end && val_end <= data.len() {
169            entries.push((key, data[val_abs..val_end].to_vec()));
170        }
171    }
172    Some(entries)
173}
174
175// ── jsonb_array_elements ─────────────────────────────────────────────
176
177/// `jsonb_array_elements(jsonb) → setof jsonb`
178///
179/// Expands a JSONB array into a set of JSONB values (one row per element).
180#[derive(Debug)]
181pub struct JsonbArrayElementsTvf;
182
183impl TableFunctionImpl for JsonbArrayElementsTvf {
184    fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
185        if args.len() != 1 {
186            return plan_err!("jsonb_array_elements requires exactly 1 argument");
187        }
188        let schema = Arc::new(Schema::new(vec![
189            Field::new("value", DataType::LargeBinary, true),
190            Field::new("ordinality", DataType::Int64, false),
191        ]));
192
193        let bytes = extract_jsonb_literal(&args[0])?;
194        let elements = bytes.as_deref().and_then(jsonb_array_elements_iter);
195
196        match elements {
197            Some(elems) if !elems.is_empty() => {
198                let values: Vec<Option<&[u8]>> = elems.iter().map(|e| Some(e.as_slice())).collect();
199                let ordinality = ordinality_vec(elems.len());
200                let batch = RecordBatch::try_new(
201                    Arc::clone(&schema),
202                    vec![
203                        Arc::new(LargeBinaryArray::from(values)),
204                        Arc::new(Int64Array::from(ordinality)),
205                    ],
206                )?;
207                Ok(Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))
208            }
209            _ => Ok(Arc::new(MemTable::try_new(schema, vec![vec![]])?)),
210        }
211    }
212}
213
214// ── jsonb_array_elements_text ────────────────────────────────────────
215
216/// `jsonb_array_elements_text(jsonb) → setof text`
217///
218/// Same as `jsonb_array_elements` but returns each element as text.
219#[derive(Debug)]
220pub struct JsonbArrayElementsTextTvf;
221
222impl TableFunctionImpl for JsonbArrayElementsTextTvf {
223    fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
224        if args.len() != 1 {
225            return plan_err!("jsonb_array_elements_text requires exactly 1 argument");
226        }
227        let schema = Arc::new(Schema::new(vec![
228            Field::new("value", DataType::Utf8, true),
229            Field::new("ordinality", DataType::Int64, false),
230        ]));
231
232        let bytes = extract_jsonb_literal(&args[0])?;
233        let elements = bytes.as_deref().and_then(jsonb_array_elements_iter);
234
235        match elements {
236            Some(elems) if !elems.is_empty() => {
237                let texts: Vec<Option<String>> =
238                    elems.iter().map(|e| json_types::jsonb_to_text(e)).collect();
239                let ordinality = ordinality_vec(elems.len());
240                let batch = RecordBatch::try_new(
241                    Arc::clone(&schema),
242                    vec![
243                        Arc::new(StringArray::from(texts)),
244                        Arc::new(Int64Array::from(ordinality)),
245                    ],
246                )?;
247                Ok(Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))
248            }
249            _ => Ok(Arc::new(MemTable::try_new(schema, vec![vec![]])?)),
250        }
251    }
252}
253
254// ── jsonb_each ──────────────────────────────────────────────────────
255
256/// `jsonb_each(jsonb) → setof (key text, value jsonb)`
257///
258/// Expands a JSONB object into a set of key-value pairs.
259#[derive(Debug)]
260pub struct JsonbEachTvf;
261
262impl TableFunctionImpl for JsonbEachTvf {
263    fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
264        if args.len() != 1 {
265            return plan_err!("jsonb_each requires exactly 1 argument");
266        }
267        let schema = Arc::new(Schema::new(vec![
268            Field::new("key", DataType::Utf8, false),
269            Field::new("value", DataType::LargeBinary, true),
270            Field::new("ordinality", DataType::Int64, false),
271        ]));
272
273        let bytes = extract_jsonb_literal(&args[0])?;
274        let entries = bytes.as_deref().and_then(jsonb_object_entries);
275
276        match entries {
277            Some(kvs) if !kvs.is_empty() => {
278                let keys: Vec<&str> = kvs.iter().map(|(k, _)| k.as_str()).collect();
279                let values: Vec<Option<&[u8]>> =
280                    kvs.iter().map(|(_, v)| Some(v.as_slice())).collect();
281                let ordinality = ordinality_vec(kvs.len());
282                let batch = RecordBatch::try_new(
283                    Arc::clone(&schema),
284                    vec![
285                        Arc::new(StringArray::from(keys)),
286                        Arc::new(LargeBinaryArray::from(values)),
287                        Arc::new(Int64Array::from(ordinality)),
288                    ],
289                )?;
290                Ok(Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))
291            }
292            _ => Ok(Arc::new(MemTable::try_new(schema, vec![vec![]])?)),
293        }
294    }
295}
296
297// ── jsonb_each_text ─────────────────────────────────────────────────
298
299/// `jsonb_each_text(jsonb) → setof (key text, value text)`
300///
301/// Same as `jsonb_each` but casts each value to text.
302#[derive(Debug)]
303pub struct JsonbEachTextTvf;
304
305impl TableFunctionImpl for JsonbEachTextTvf {
306    fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
307        if args.len() != 1 {
308            return plan_err!("jsonb_each_text requires exactly 1 argument");
309        }
310        let schema = Arc::new(Schema::new(vec![
311            Field::new("key", DataType::Utf8, false),
312            Field::new("value", DataType::Utf8, true),
313            Field::new("ordinality", DataType::Int64, false),
314        ]));
315
316        let bytes = extract_jsonb_literal(&args[0])?;
317        let entries = bytes.as_deref().and_then(jsonb_object_entries);
318
319        match entries {
320            Some(kvs) if !kvs.is_empty() => {
321                let keys: Vec<&str> = kvs.iter().map(|(k, _)| k.as_str()).collect();
322                let texts: Vec<Option<String>> = kvs
323                    .iter()
324                    .map(|(_, v)| json_types::jsonb_to_text(v))
325                    .collect();
326                let ordinality = ordinality_vec(kvs.len());
327                let batch = RecordBatch::try_new(
328                    Arc::clone(&schema),
329                    vec![
330                        Arc::new(StringArray::from(keys)),
331                        Arc::new(StringArray::from(texts)),
332                        Arc::new(Int64Array::from(ordinality)),
333                    ],
334                )?;
335                Ok(Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))
336            }
337            _ => Ok(Arc::new(MemTable::try_new(schema, vec![vec![]])?)),
338        }
339    }
340}
341
342// ── jsonb_object_keys ───────────────────────────────────────────────
343
344/// `jsonb_object_keys(jsonb) → setof text`
345///
346/// Returns all keys of a JSONB object as text rows.
347#[derive(Debug)]
348pub struct JsonbObjectKeysTvf;
349
350impl TableFunctionImpl for JsonbObjectKeysTvf {
351    fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> {
352        if args.len() != 1 {
353            return plan_err!("jsonb_object_keys requires exactly 1 argument");
354        }
355        let schema = Arc::new(Schema::new(vec![
356            Field::new("key", DataType::Utf8, false),
357            Field::new("ordinality", DataType::Int64, false),
358        ]));
359
360        let bytes = extract_jsonb_literal(&args[0])?;
361        let entries = bytes.as_deref().and_then(jsonb_object_entries);
362
363        match entries {
364            Some(kvs) if !kvs.is_empty() => {
365                let keys: Vec<&str> = kvs.iter().map(|(k, _)| k.as_str()).collect();
366                let ordinality = ordinality_vec(kvs.len());
367                let batch = RecordBatch::try_new(
368                    Arc::clone(&schema),
369                    vec![
370                        Arc::new(StringArray::from(keys)),
371                        Arc::new(Int64Array::from(ordinality)),
372                    ],
373                )?;
374                Ok(Arc::new(MemTable::try_new(schema, vec![vec![batch]])?))
375            }
376            _ => Ok(Arc::new(MemTable::try_new(schema, vec![vec![]])?)),
377        }
378    }
379}
380
381// ── Registration ────────────────────────────────────────────────────
382
383/// Registers all JSON table-valued functions with the `SessionContext`.
384pub fn register_json_table_functions(ctx: &datafusion::prelude::SessionContext) {
385    ctx.register_udtf("jsonb_array_elements", Arc::new(JsonbArrayElementsTvf));
386    ctx.register_udtf(
387        "jsonb_array_elements_text",
388        Arc::new(JsonbArrayElementsTextTvf),
389    );
390    ctx.register_udtf("jsonb_each", Arc::new(JsonbEachTvf));
391    ctx.register_udtf("jsonb_each_text", Arc::new(JsonbEachTextTvf));
392    ctx.register_udtf("jsonb_object_keys", Arc::new(JsonbObjectKeysTvf));
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::create_session_context;

    /// Encodes JSON text to JSONB and wraps it as a `LargeBinary` literal.
    fn make_jsonb_expr(json_str: &str) -> Expr {
        let parsed: serde_json::Value = serde_json::from_str(json_str).unwrap();
        let encoded = json_types::encode_jsonb(&parsed);
        Expr::Literal(ScalarValue::LargeBinary(Some(encoded)), None)
    }

    /// Runs `sql` on a fresh session with the JSON TVFs registered.
    async fn run_sql(sql: &str) -> Vec<RecordBatch> {
        let ctx = create_session_context();
        register_json_table_functions(&ctx);
        let df = ctx.sql(sql).await.unwrap();
        df.collect().await.unwrap()
    }

    /// Total number of rows across all collected batches.
    fn row_count(batches: &[RecordBatch]) -> usize {
        batches.iter().map(RecordBatch::num_rows).sum()
    }

    // ── jsonb_array_elements tests ──

    #[test]
    fn test_array_elements_basic() {
        let provider = JsonbArrayElementsTvf
            .call(&[make_jsonb_expr("[1, 2, 3]")])
            .unwrap();
        let schema = provider.schema();
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.field(0).name(), "value");
        assert_eq!(schema.field(1).name(), "ordinality");
    }

    #[tokio::test]
    async fn test_array_elements_via_sql() {
        let batches =
            run_sql("SELECT value, ordinality FROM jsonb_array_elements('[10, 20, 30]')").await;
        assert_eq!(row_count(&batches), 3);

        // Ordinality must count up from 1 in element order.
        let ordinality = batches[0]
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        for (row, expected) in (1_i64..=3).enumerate() {
            assert_eq!(ordinality.value(row), expected);
        }
    }

    #[tokio::test]
    async fn test_array_elements_empty() {
        let batches = run_sql("SELECT value FROM jsonb_array_elements('[]')").await;
        assert_eq!(row_count(&batches), 0);
    }

    #[tokio::test]
    async fn test_array_elements_not_array() {
        let batches = run_sql("SELECT value FROM jsonb_array_elements('{\"a\":1}')").await;
        // Graceful: a non-array argument returns 0 rows instead of failing.
        assert_eq!(row_count(&batches), 0);
    }

    // ── jsonb_array_elements_text tests ──

    #[tokio::test]
    async fn test_array_elements_text_strings() {
        let batches =
            run_sql("SELECT value FROM jsonb_array_elements_text('[\"a\", \"b\", \"c\"]')").await;
        assert_eq!(row_count(&batches), 3);

        let values = batches[0]
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        for (row, expected) in ["a", "b", "c"].iter().enumerate() {
            assert_eq!(values.value(row), *expected);
        }
    }

    #[tokio::test]
    async fn test_array_elements_text_mixed() {
        let batches =
            run_sql("SELECT value FROM jsonb_array_elements_text('[1, \"hello\", true]')").await;
        assert_eq!(row_count(&batches), 3);
    }

    // ── jsonb_each tests ──

    #[tokio::test]
    async fn test_each_basic() {
        let batches = run_sql("SELECT key, ordinality FROM jsonb_each('{\"a\":1,\"b\":2}')").await;
        assert_eq!(row_count(&batches), 2);
    }

    #[tokio::test]
    async fn test_each_empty() {
        let batches = run_sql("SELECT key FROM jsonb_each('{}')").await;
        assert_eq!(row_count(&batches), 0);
    }

    // ── jsonb_each_text tests ──

    #[tokio::test]
    async fn test_each_text_basic() {
        let batches =
            run_sql("SELECT key, value FROM jsonb_each_text('{\"x\":\"hello\",\"y\":42}')").await;
        assert_eq!(row_count(&batches), 2);
    }

    // ── jsonb_object_keys tests ──

    #[tokio::test]
    async fn test_object_keys_basic() {
        let batches =
            run_sql("SELECT key FROM jsonb_object_keys('{\"a\":1,\"b\":2,\"c\":3}')").await;
        assert_eq!(row_count(&batches), 3);
    }

    #[tokio::test]
    async fn test_object_keys_empty() {
        let batches = run_sql("SELECT key FROM jsonb_object_keys('{}')").await;
        assert_eq!(row_count(&batches), 0);
    }

    // ── Registration test ──

    #[test]
    fn test_registration() {
        let ctx = create_session_context();
        register_json_table_functions(&ctx);
        for name in [
            "jsonb_array_elements",
            "jsonb_array_elements_text",
            "jsonb_each",
            "jsonb_each_text",
            "jsonb_object_keys",
        ]
        .iter()
        {
            assert!(ctx.table_function(name).is_ok());
        }
    }
}