lance_index/scalar/inverted/json.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use arrow_array::{Array, ArrayRef, LargeBinaryArray, RecordBatch};
5use arrow_schema::{DataType, Field, Schema, SchemaRef};
6use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream};
7use futures::Stream;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11
/// Transform jsonb stream into json text stream.
///
/// Wraps an inner [`SendableRecordBatchStream`] and, for each batch, rewrites
/// the column named `jsonb_col` from jsonb binary into json text
/// (`LargeUtf8`), passing all other columns through unchanged.
pub struct JsonTextStream {
    // Upstream stream of record batches containing the jsonb-encoded column.
    inner: SendableRecordBatchStream,
    // Name of the column holding jsonb binary data to be converted.
    jsonb_col: String,
}
17
18impl JsonTextStream {
19    pub fn new(inner: SendableRecordBatchStream, jsonb_col: String) -> Self {
20        Self { inner, jsonb_col }
21    }
22}
23
24impl Stream for JsonTextStream {
25    type Item = datafusion_common::Result<RecordBatch>;
26
27    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
28        match Pin::new(&mut self.inner).poll_next(cx) {
29            Poll::Ready(Some(Ok(batch))) => {
30                let cols: Vec<ArrayRef> = batch
31                    .schema()
32                    .fields()
33                    .iter()
34                    .enumerate()
35                    .map(|(idx, col)| {
36                        if col.name().as_str() == self.jsonb_col {
37                            Ok(jsonb_to_json(batch.column(idx), &self.jsonb_col)?)
38                        } else {
39                            Ok(batch.column(idx).clone())
40                        }
41                    })
42                    .collect::<lance_core::Result<Vec<ArrayRef>>>()?;
43
44                let new_schema = batch
45                    .schema()
46                    .fields()
47                    .iter()
48                    .map(|col| {
49                        if col.name().as_str() == self.jsonb_col {
50                            Field::new(&self.jsonb_col, DataType::LargeUtf8, true)
51                        } else {
52                            col.as_ref().clone()
53                        }
54                    })
55                    .collect::<Vec<Field>>();
56                let new_schema = Arc::new(Schema::new(new_schema));
57                let mapped = RecordBatch::try_new(new_schema, cols).unwrap();
58                Poll::Ready(Some(Ok(mapped)))
59            }
60            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
61            Poll::Ready(None) => Poll::Ready(None),
62            Poll::Pending => Poll::Pending,
63        }
64    }
65}
66
67impl RecordBatchStream for JsonTextStream {
68    fn schema(&self) -> SchemaRef {
69        Arc::new(Schema::new(vec![Field::new(
70            &self.jsonb_col,
71            DataType::Utf8,
72            true,
73        )]))
74    }
75}
76
77pub fn jsonb_to_json(col: &ArrayRef, col_name: &str) -> lance_core::Result<ArrayRef> {
78    let binary_array = col
79        .as_any()
80        .downcast_ref::<LargeBinaryArray>()
81        .unwrap_or_else(|| panic!("column {} is not a large binary array", col_name));
82    let mut builder =
83        arrow_array::builder::LargeStringBuilder::with_capacity(binary_array.len(), 1024);
84    for i in 0..binary_array.len() {
85        if binary_array.is_null(i) {
86            builder.append_null();
87        } else if let Some(bytes) = binary_array.value(i).into() {
88            let raw_jsonb = jsonb::RawJsonb::new(bytes);
89            let json_text = raw_jsonb.to_string();
90            builder.append_value(json_text);
91        } else {
92            unreachable!("jsonb value is not valid");
93        }
94    }
95    Ok(Arc::new(builder.finish()))
96}
97
#[cfg(test)]
mod tests {
    use crate::scalar::inverted::json::JsonTextStream;
    use arrow_array::builder::{LargeBinaryBuilder, UInt64Builder};
    use arrow_array::cast::AsArray;
    use arrow_array::{ArrayRef, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
    use futures::{stream, TryStreamExt};
    use serde_json::Value;
    use std::sync::Arc;

    /// End-to-end: a batch with a jsonb column and a row-id column comes out
    /// with the jsonb column converted to `LargeUtf8` json text while the
    /// row-id column passes through untouched.
    #[tokio::test]
    async fn test_json_text_stream() {
        let inputs = [
            r#"{"a": 1, "b": "hello"}"#,
            r#"{"c": [1, 2, 3], "d": {"e": true}}"#,
            r#"{"f": null}"#,
        ];

        // Encode each document as jsonb bytes, paired with a synthetic row id.
        let mut binary_builder = LargeBinaryBuilder::new();
        let mut id_builder = UInt64Builder::new();
        for (row, text) in inputs.iter().enumerate() {
            let encoded = jsonb::parse_value(text.as_bytes()).unwrap().to_vec();
            binary_builder.append_value(encoded);
            id_builder.append_value(row as u64);
        }

        let input_schema = Arc::new(Schema::new(vec![
            Field::new("json_col", DataType::LargeBinary, true),
            Field::new("rowid", DataType::UInt64, false),
        ]));
        let batch = RecordBatch::try_new(
            input_schema.clone(),
            vec![
                Arc::new(binary_builder.finish()) as ArrayRef,
                Arc::new(id_builder.finish()) as ArrayRef,
            ],
        )
        .unwrap();

        // Feed a single-batch stream through the transformer.
        let inner = Box::pin(RecordBatchStreamAdapter::new(
            input_schema.clone(),
            stream::once(async { Ok(batch) }),
        ));
        let transformed = JsonTextStream::new(inner, "json_col".to_string());

        let batches: Vec<RecordBatch> = transformed.try_collect().await.unwrap();
        assert_eq!(batches.len(), 1);
        let out = &batches[0];

        // The jsonb column must come back as LargeUtf8; rowid is unchanged.
        let expected_schema = Arc::new(Schema::new(vec![
            Field::new("json_col", DataType::LargeUtf8, true),
            Field::new("rowid", DataType::UInt64, false),
        ]));
        assert_eq!(out.schema(), expected_schema);

        let text_col = out.column_by_name("json_col").unwrap().as_string::<i64>();

        // Compare parsed values rather than raw strings, since the jsonb
        // round trip may normalize formatting.
        for (row, original) in inputs.iter().enumerate() {
            let expected: Value = serde_json::from_str(original).unwrap();
            let actual: Value = serde_json::from_str(text_col.value(row)).unwrap();
            assert_eq!(expected, actual);
        }
    }
}