Skip to main content

lance_index/scalar/
json.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{
5    collections::HashMap,
6    ops::Bound,
7    sync::{Arc, Mutex},
8};
9
10use arrow_array::{Array, LargeBinaryArray, RecordBatch, StructArray, UInt8Array};
11use arrow_schema::{DataType, Field, Field as ArrowField, Schema};
12use async_trait::async_trait;
13use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
14use datafusion::{
15    execution::SendableRecordBatchStream,
16    physical_plan::{ExecutionPlan, projection::ProjectionExec},
17};
18use datafusion_common::{ScalarValue, config::ConfigOptions};
19use datafusion_expr::{Expr, Operator, ScalarUDF};
20use datafusion_physical_expr::{
21    PhysicalExpr, ScalarFunctionExpr,
22    expressions::{Column, Literal},
23};
24use deepsize::DeepSizeOf;
25use futures::StreamExt;
26use lance_datafusion::exec::{LanceExecutionOptions, OneShotExec, get_session_context};
27use lance_datafusion::udf::json::JsonbType;
28use prost::Message;
29use roaring::RoaringBitmap;
30use serde::{Deserialize, Serialize};
31
32use lance_core::{Error, ROW_ID, Result, cache::LanceCache, error::LanceOptionExt};
33
34use crate::{
35    Index, IndexType,
36    frag_reuse::FragReuseIndex,
37    metrics::MetricsCollector,
38    registry::IndexPluginRegistry,
39    scalar::{
40        AnyQuery, CreatedIndex, IndexStore, ScalarIndex, SearchResult, UpdateCriteria,
41        expression::{IndexedExpression, ScalarIndexExpr, ScalarIndexSearch, ScalarQueryParser},
42        registry::{ScalarIndexPlugin, TrainingCriteria, TrainingRequest, VALUE_COLUMN_NAME},
43    },
44};
45
/// Version number written into the `index_version` of every `CreatedIndex` produced for a JSON
/// index and returned by the plugin's `version` method.
const JSON_INDEX_VERSION: u32 = 0;
47
48/// A JSON index that indexes a field in a JSON column
49///
50/// The underlying index can be any other type of scalar index
51#[derive(Debug)]
52pub struct JsonIndex {
53    target_index: Arc<dyn ScalarIndex>,
54    path: String,
55}
56
57impl JsonIndex {
58    pub fn new(target_index: Arc<dyn ScalarIndex>, path: String) -> Self {
59        Self { target_index, path }
60    }
61}
62
63impl DeepSizeOf for JsonIndex {
64    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
65        self.target_index.deep_size_of_children(context) + self.path.deep_size_of_children(context)
66    }
67}
68
69#[async_trait]
70impl Index for JsonIndex {
71    fn as_any(&self) -> &dyn std::any::Any {
72        self
73    }
74
75    fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
76        self
77    }
78
79    fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn crate::vector::VectorIndex>> {
80        unimplemented!()
81    }
82
83    fn index_type(&self) -> IndexType {
84        // TODO: This causes the index to appear as btree in list_indices call.  Need better logic
85        // in list_indices to use details instead of index_type.
86        IndexType::Scalar
87    }
88
89    async fn prewarm(&self) -> Result<()> {
90        self.target_index.prewarm().await
91    }
92
93    fn statistics(&self) -> Result<serde_json::Value> {
94        todo!()
95    }
96
97    async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
98        self.target_index.calculate_included_frags().await
99    }
100}
101
102#[async_trait]
103impl ScalarIndex for JsonIndex {
104    async fn search(
105        &self,
106        query: &dyn AnyQuery,
107        metrics: &dyn MetricsCollector,
108    ) -> Result<SearchResult> {
109        let query = query.as_any().downcast_ref::<JsonQuery>().unwrap();
110        self.target_index
111            .search(query.target_query.as_ref(), metrics)
112            .await
113    }
114
115    fn can_remap(&self) -> bool {
116        self.target_index.can_remap()
117    }
118
119    async fn remap(
120        &self,
121        mapping: &HashMap<u64, Option<u64>>,
122        dest_store: &dyn IndexStore,
123    ) -> Result<CreatedIndex> {
124        let target_created = self.target_index.remap(mapping, dest_store).await?;
125        let json_details = crate::pb::JsonIndexDetails {
126            path: self.path.clone(),
127            target_details: Some(target_created.index_details),
128        };
129        Ok(CreatedIndex {
130            index_details: prost_types::Any::from_msg(&json_details)?,
131            // TODO: We should store the target index version in the details
132            index_version: JSON_INDEX_VERSION,
133            files: Some(dest_store.list_files_with_sizes().await?),
134        })
135    }
136
137    async fn update(
138        &self,
139        new_data: SendableRecordBatchStream,
140        dest_store: &dyn IndexStore,
141        old_data_filter: Option<super::OldIndexDataFilter>,
142    ) -> Result<CreatedIndex> {
143        let target_created = self
144            .target_index
145            .update(new_data, dest_store, old_data_filter)
146            .await?;
147        let json_details = crate::pb::JsonIndexDetails {
148            path: self.path.clone(),
149            target_details: Some(target_created.index_details),
150        };
151        Ok(CreatedIndex {
152            index_details: prost_types::Any::from_msg(&json_details)?,
153            // TODO: We should store the target index version in the details
154            index_version: JSON_INDEX_VERSION,
155            files: Some(dest_store.list_files_with_sizes().await?),
156        })
157    }
158
159    fn update_criteria(&self) -> UpdateCriteria {
160        self.target_index.update_criteria()
161    }
162
163    fn derive_index_params(&self) -> Result<super::ScalarIndexParams> {
164        self.target_index.derive_index_params()
165    }
166}
167
/// Parameters for a [`JsonIndex`]
///
/// Deserialized from the JSON parameter string given to the plugin's `new_training_request`.
#[derive(Debug, Serialize, Deserialize)]
pub struct JsonIndexParameters {
    /// Name of the plugin (looked up in the plugin registry) used to build the target index
    target_index_type: String,
    /// Parameter string forwarded to the target plugin; `"{}"` is used when this is absent
    target_index_parameters: Option<String>,
    /// JSON path that is extracted from the column during training and recorded in the
    /// index details
    path: String,
}
175
176// TODO: Do we really need to wrap the query or could we just return the target query directly?
177//
178// I think the only thing we really gain is a different format impl (e.g. it shows up as a json query
179// in the explain plan) but I don't know if that helps the user much.
180#[derive(Debug, Clone)]
181pub struct JsonQuery {
182    target_query: Arc<dyn AnyQuery>,
183    path: String,
184}
185
186impl JsonQuery {
187    pub fn new(target_query: Arc<dyn AnyQuery>, path: String) -> Self {
188        Self { target_query, path }
189    }
190}
191
192impl PartialEq for JsonQuery {
193    fn eq(&self, other: &Self) -> bool {
194        self.target_query.dyn_eq(other.target_query.as_ref()) && self.path == other.path
195    }
196}
197
198impl AnyQuery for JsonQuery {
199    fn as_any(&self) -> &dyn std::any::Any {
200        self
201    }
202
203    fn format(&self, col: &str) -> String {
204        format!("Json({}->{})", self.target_query.format(col), self.path)
205    }
206
207    fn to_expr(&self, _col: String) -> Expr {
208        todo!()
209    }
210
211    fn dyn_eq(&self, other: &dyn AnyQuery) -> bool {
212        match other.as_any().downcast_ref::<Self>() {
213            Some(o) => self == o,
214            None => false,
215        }
216    }
217}
218
219#[derive(Debug)]
220pub struct JsonQueryParser {
221    path: String,
222    target_parser: Box<dyn ScalarQueryParser>,
223}
224
225impl JsonQueryParser {
226    pub fn new(path: String, target_parser: Box<dyn ScalarQueryParser>) -> Self {
227        Self {
228            path,
229            target_parser,
230        }
231    }
232
233    fn wrap_search(&self, target_expr: IndexedExpression) -> IndexedExpression {
234        if let Some(scalar_query) = target_expr.scalar_query {
235            let scalar_query = match scalar_query {
236                ScalarIndexExpr::Query(ScalarIndexSearch {
237                    column,
238                    index_name,
239                    index_type,
240                    query,
241                    needs_recheck,
242                }) => ScalarIndexExpr::Query(ScalarIndexSearch {
243                    column,
244                    index_name,
245                    index_type,
246                    query: Arc::new(JsonQuery::new(query, self.path.clone())),
247                    needs_recheck,
248                }),
249                // This code path should only be hit on leaf expr
250                _ => unreachable!(),
251            };
252            IndexedExpression {
253                scalar_query: Some(scalar_query),
254                refine_expr: target_expr.refine_expr,
255            }
256        } else {
257            target_expr
258        }
259    }
260}
261
262impl ScalarQueryParser for JsonQueryParser {
263    fn visit_between(
264        &self,
265        column: &str,
266        low: &Bound<ScalarValue>,
267        high: &Bound<ScalarValue>,
268    ) -> Option<IndexedExpression> {
269        self.target_parser
270            .visit_between(column, low, high)
271            .map(|target_expr| self.wrap_search(target_expr))
272    }
273    fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option<IndexedExpression> {
274        self.target_parser
275            .visit_in_list(column, in_list)
276            .map(|target_expr| self.wrap_search(target_expr))
277    }
278    fn visit_is_bool(&self, column: &str, value: bool) -> Option<IndexedExpression> {
279        self.target_parser
280            .visit_is_bool(column, value)
281            .map(|target_expr| self.wrap_search(target_expr))
282    }
283    fn visit_is_null(&self, column: &str) -> Option<IndexedExpression> {
284        self.target_parser
285            .visit_is_null(column)
286            .map(|target_expr| self.wrap_search(target_expr))
287    }
288    fn visit_comparison(
289        &self,
290        column: &str,
291        value: &ScalarValue,
292        op: &Operator,
293    ) -> Option<IndexedExpression> {
294        self.target_parser
295            .visit_comparison(column, value, op)
296            .map(|target_expr| self.wrap_search(target_expr))
297    }
298    fn visit_scalar_function(
299        &self,
300        column: &str,
301        data_type: &DataType,
302        func: &ScalarUDF,
303        args: &[Expr],
304    ) -> Option<IndexedExpression> {
305        self.target_parser
306            .visit_scalar_function(column, data_type, func, args)
307            .map(|target_expr| self.wrap_search(target_expr))
308    }
309
310    // TODO: maybe we should address it by https://github.com/lance-format/lance/issues/4624
311    fn is_valid_reference(&self, func: &Expr, _data_type: &DataType) -> Option<DataType> {
312        match func {
313            Expr::ScalarFunction(udf) => {
314                // Support multiple JSON extraction functions
315                let json_functions = [
316                    "json_extract",
317                    "json_get",
318                    "json_get_int",
319                    "json_get_float",
320                    "json_get_bool",
321                    "json_get_string",
322                ];
323                if !json_functions.contains(&udf.name()) {
324                    return None;
325                }
326                if udf.args.len() != 2 {
327                    return None;
328                }
329                // We already know index 0 is a column reference to the column so we just need to
330                // ensure that index 1 matches our path
331                match &udf.args[1] {
332                    Expr::Literal(ScalarValue::Utf8(Some(path)), _) => {
333                        if path == &self.path {
334                            // Return the appropriate type based on the function
335                            match udf.name() {
336                                "json_get_int" => Some(DataType::Int64),
337                                "json_get_float" => Some(DataType::Float64),
338                                "json_get_bool" => Some(DataType::Boolean),
339                                "json_get_string" | "json_extract" => Some(DataType::Utf8),
340                                _ => None,
341                            }
342                        } else {
343                            None
344                        }
345                    }
346                    _ => None,
347                }
348            }
349            _ => None,
350        }
351    }
352}
353
354pub struct JsonTrainingRequest {
355    parameters: JsonIndexParameters,
356    target_request: Box<dyn TrainingRequest>,
357}
358
359impl JsonTrainingRequest {
360    pub fn new(parameters: JsonIndexParameters, target_request: Box<dyn TrainingRequest>) -> Self {
361        Self {
362            parameters,
363            target_request,
364        }
365    }
366}
367
368impl TrainingRequest for JsonTrainingRequest {
369    fn as_any(&self) -> &dyn std::any::Any {
370        self
371    }
372
373    fn criteria(&self) -> &TrainingCriteria {
374        self.target_request.criteria()
375    }
376}
377
378/// Plugin implementation for a [`JsonIndex`]
379#[derive(Default)]
380pub struct JsonIndexPlugin {
381    registry: Mutex<Option<Arc<IndexPluginRegistry>>>,
382}
383
384impl std::fmt::Debug for JsonIndexPlugin {
385    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386        write!(f, "JsonIndexPlugin")
387    }
388}
389
390impl JsonIndexPlugin {
391    fn registry(&self) -> Result<Arc<IndexPluginRegistry>> {
392        Ok(self.registry.lock().unwrap().as_ref().expect_ok()?.clone())
393    }
394
395    /// Extract JSON with type information using the new UDF
396    async fn extract_json_with_type_info(
397        data: SendableRecordBatchStream,
398        path: String,
399    ) -> Result<(SendableRecordBatchStream, DataType)> {
400        let input = Arc::new(OneShotExec::new(data));
401        let input_schema = input.schema();
402        let value_column_idx = input_schema
403            .column_with_name(VALUE_COLUMN_NAME)
404            .expect_ok()?
405            .0;
406        let row_id_column_idx = input_schema.column_with_name(ROW_ID).expect_ok()?.0;
407
408        // Call json_extract_with_type UDF
409        let exprs = vec![
410            (
411                Arc::new(ScalarFunctionExpr::try_new(
412                    Arc::new(lance_datafusion::udf::json::json_extract_with_type_udf()),
413                    vec![
414                        Arc::new(Column::new(VALUE_COLUMN_NAME, value_column_idx)),
415                        Arc::new(Literal::new(ScalarValue::Utf8(Some(path)))),
416                    ],
417                    &input_schema,
418                    Arc::new(ConfigOptions::default()),
419                )?) as Arc<dyn PhysicalExpr>,
420                "json_result".to_string(),
421            ),
422            (
423                Arc::new(Column::new(ROW_ID, row_id_column_idx)) as Arc<dyn PhysicalExpr>,
424                ROW_ID.to_string(),
425            ),
426        ];
427
428        let project = ProjectionExec::try_new(exprs, input)?;
429        let ctx = get_session_context(&LanceExecutionOptions::default());
430        let mut stream = project.execute(0, ctx.task_ctx())?;
431
432        // Collect batches and determine type from first non-null value
433        let mut all_batches = Vec::new();
434        let mut inferred_type: Option<DataType> = None;
435
436        while let Some(batch_result) = stream.next().await {
437            let batch = batch_result?;
438
439            // Determine type from first non-null value if not yet set
440            if inferred_type.is_none()
441                && let Some(json_result_column) = batch.column_by_name("json_result")
442                && let Some(struct_array) =
443                    json_result_column.as_any().downcast_ref::<StructArray>()
444                && let Some(type_array) = struct_array.column_by_name("type_tag")
445                && let Some(uint8_array) = type_array.as_any().downcast_ref::<UInt8Array>()
446            {
447                // Find first non-null value to determine type
448                for i in 0..uint8_array.len() {
449                    if !uint8_array.is_null(i) {
450                        let type_tag = uint8_array.value(i);
451                        let jsonb_type = JsonbType::from_u8(type_tag).ok_or_else(|| {
452                            Error::invalid_input_source(
453                                format!("Invalid type tag: {}", type_tag).into(),
454                            )
455                        })?;
456
457                        // Map JsonbType to Arrow DataType
458                        inferred_type = Some(match jsonb_type {
459                            JsonbType::Null => continue, // Skip null values
460                            JsonbType::Boolean => DataType::Boolean,
461                            JsonbType::Int64 => DataType::Int64,
462                            JsonbType::Float64 => DataType::Float64,
463                            JsonbType::String => DataType::Utf8,
464                            JsonbType::Array => DataType::LargeBinary,
465                            JsonbType::Object => DataType::LargeBinary,
466                        });
467                        break;
468                    }
469                }
470            }
471
472            all_batches.push(batch);
473        }
474
475        // If no type was inferred (all nulls), default to String
476        let inferred_type = inferred_type.unwrap_or(DataType::Utf8);
477
478        // Recreate stream from collected batches
479        let schema = all_batches
480            .first()
481            .map(|b| b.schema())
482            .ok_or_else(|| Error::invalid_input_source("No batches in stream".into()))?;
483
484        let recreated_stream = Box::pin(RecordBatchStreamAdapter::new(
485            schema,
486            futures::stream::iter(all_batches.into_iter().map(Ok)),
487        )) as SendableRecordBatchStream;
488
489        Ok((recreated_stream, inferred_type))
490    }
491
492    /// Convert the stream with JSONB values and type tags to properly typed values
493    async fn convert_stream_by_type(
494        data: SendableRecordBatchStream,
495        target_type: DataType,
496    ) -> Result<SendableRecordBatchStream> {
497        let input = Arc::new(OneShotExec::new(data));
498        let _input_schema = input.schema();
499        let ctx = get_session_context(&LanceExecutionOptions::default());
500        let mut stream = input.execute(0, ctx.task_ctx())?;
501
502        let mut converted_batches = Vec::new();
503
504        while let Some(batch_result) = stream.next().await {
505            let batch = batch_result?;
506
507            // Extract the struct column containing value and type_tag
508            let json_result_column = batch
509                .column_by_name("json_result")
510                .ok_or_else(|| Error::invalid_input_source("Missing json_result column".into()))?;
511
512            let struct_array = json_result_column
513                .as_any()
514                .downcast_ref::<StructArray>()
515                .ok_or_else(|| Error::invalid_input_source("json_result is not a struct".into()))?;
516
517            let value_array = struct_array.column_by_name("value").ok_or_else(|| {
518                Error::invalid_input_source("Missing value column in struct".into())
519            })?;
520
521            let binary_array = value_array
522                .as_any()
523                .downcast_ref::<LargeBinaryArray>()
524                .ok_or_else(|| Error::invalid_input_source("value is not LargeBinary".into()))?;
525
526            // Convert based on target type using serde deserialization
527            let converted_array: Arc<dyn Array> =
528                match target_type {
529                    DataType::Boolean => {
530                        let mut builder =
531                            arrow_array::builder::BooleanBuilder::with_capacity(binary_array.len());
532                        for i in 0..binary_array.len() {
533                            if binary_array.is_null(i) {
534                                builder.append_null();
535                            } else if let Some(bytes) = binary_array.value(i).into() {
536                                let raw_jsonb = jsonb::RawJsonb::new(bytes);
537                                // Try to deserialize directly to bool
538                                match jsonb::from_raw_jsonb::<bool>(&raw_jsonb) {
539                                    Ok(bool_val) => builder.append_value(bool_val),
540                                    Err(e) => {
541                                        return Err(Error::invalid_input_source(format!(
542                                        "Failed to deserialize JSONB to bool at index {}: {}",
543                                        i, e
544                                    )
545                                    .into()));
546                                    }
547                                }
548                            } else {
549                                builder.append_null();
550                            }
551                        }
552                        Arc::new(builder.finish())
553                    }
554                    DataType::Int64 => {
555                        let mut builder =
556                            arrow_array::builder::Int64Builder::with_capacity(binary_array.len());
557                        for i in 0..binary_array.len() {
558                            if binary_array.is_null(i) {
559                                builder.append_null();
560                            } else if let Some(bytes) = binary_array.value(i).into() {
561                                let raw_jsonb = jsonb::RawJsonb::new(bytes);
562                                // Try to deserialize directly to i64
563                                match jsonb::from_raw_jsonb::<i64>(&raw_jsonb) {
564                                    Ok(int_val) => builder.append_value(int_val),
565                                    Err(e) => {
566                                        return Err(Error::invalid_input_source(format!(
567                                        "Failed to deserialize JSONB to i64 at index {}: {}",
568                                        i, e
569                                    )
570                                    .into()));
571                                    }
572                                }
573                            } else {
574                                builder.append_null();
575                            }
576                        }
577                        Arc::new(builder.finish())
578                    }
579                    DataType::Float64 => {
580                        let mut builder =
581                            arrow_array::builder::Float64Builder::with_capacity(binary_array.len());
582                        for i in 0..binary_array.len() {
583                            if binary_array.is_null(i) {
584                                builder.append_null();
585                            } else if let Some(bytes) = binary_array.value(i).into() {
586                                let raw_jsonb = jsonb::RawJsonb::new(bytes);
587                                // Try to deserialize directly to f64 (serde handles int->float conversion)
588                                match jsonb::from_raw_jsonb::<f64>(&raw_jsonb) {
589                                    Ok(float_val) => builder.append_value(float_val),
590                                    Err(e) => {
591                                        return Err(Error::invalid_input_source(format!(
592                                        "Failed to deserialize JSONB to f64 at index {}: {}",
593                                        i, e
594                                    )
595                                    .into()));
596                                    }
597                                }
598                            } else {
599                                builder.append_null();
600                            }
601                        }
602                        Arc::new(builder.finish())
603                    }
604                    DataType::Utf8 => {
605                        let mut builder = arrow_array::builder::StringBuilder::with_capacity(
606                            binary_array.len(),
607                            1024,
608                        );
609                        for i in 0..binary_array.len() {
610                            if binary_array.is_null(i) {
611                                builder.append_null();
612                            } else if let Some(bytes) = binary_array.value(i).into() {
613                                let raw_jsonb = jsonb::RawJsonb::new(bytes);
614                                // Try to deserialize to String, or use to_string() for any type
615                                match jsonb::from_raw_jsonb::<String>(&raw_jsonb) {
616                                    Ok(str_val) => builder.append_value(&str_val),
617                                    Err(_) => {
618                                        // For non-string types, convert to string representation
619                                        builder.append_value(raw_jsonb.to_string());
620                                    }
621                                }
622                            } else {
623                                builder.append_null();
624                            }
625                        }
626                        Arc::new(builder.finish())
627                    }
628                    DataType::LargeBinary => {
629                        // Keep as binary for array/object types
630                        value_array.clone()
631                    }
632                    _ => {
633                        return Err(Error::invalid_input_source(
634                            format!("Unsupported target type: {:?}", target_type).into(),
635                        ));
636                    }
637                };
638
639            // Get row_id column
640            let row_id_column = batch
641                .column_by_name(ROW_ID)
642                .ok_or_else(|| Error::invalid_input_source("Missing row_id column".into()))?
643                .clone();
644
645            // Create new batch with converted values
646            let new_schema = Arc::new(Schema::new(vec![
647                ArrowField::new(VALUE_COLUMN_NAME, target_type.clone(), true),
648                ArrowField::new(ROW_ID, DataType::UInt64, false),
649            ]));
650
651            let new_batch =
652                RecordBatch::try_new(new_schema.clone(), vec![converted_array, row_id_column])?;
653
654            converted_batches.push(new_batch);
655        }
656
657        // Create stream from converted batches
658        let schema = converted_batches
659            .first()
660            .map(|b| b.schema())
661            .ok_or_else(|| Error::invalid_input_source("No batches to convert".into()))?;
662
663        Ok(Box::pin(RecordBatchStreamAdapter::new(
664            schema,
665            futures::stream::iter(converted_batches.into_iter().map(Ok)),
666        )))
667    }
668}
669
670#[async_trait]
671impl ScalarIndexPlugin for JsonIndexPlugin {
672    fn name(&self) -> &str {
673        "Json"
674    }
675
676    fn new_training_request(
677        &self,
678        params: &str,
679        field: &Field,
680    ) -> Result<Box<dyn TrainingRequest>> {
681        if !matches!(field.data_type(), DataType::Binary | DataType::LargeBinary) {
682            return Err(Error::invalid_input_source(
683                "A JSON index can only be created on a Binary or LargeBinary field.".into(),
684            ));
685        }
686
687        // Initially use Utf8, will be refined during training with type inference
688        let target_type = DataType::Utf8;
689
690        let params = serde_json::from_str::<JsonIndexParameters>(params)?;
691        let registry = self.registry()?;
692        let target_plugin = registry.get_plugin_by_name(&params.target_index_type)?;
693        let target_request = target_plugin.new_training_request(
694            params.target_index_parameters.as_deref().unwrap_or("{}"),
695            &Field::new("", target_type, true),
696        )?;
697
698        Ok(Box::new(JsonTrainingRequest::new(params, target_request)))
699    }
700
701    fn provides_exact_answer(&self) -> bool {
702        // TODO: Need to lookup target plugin via details to figure this out correctly
703        true
704    }
705
706    fn attach_registry(&self, registry: Arc<IndexPluginRegistry>) {
707        let mut reg_ref = self.registry.lock().unwrap();
708        *reg_ref = Some(registry);
709    }
710
711    fn version(&self) -> u32 {
712        JSON_INDEX_VERSION
713    }
714
715    fn new_query_parser(
716        &self,
717        index_name: String,
718        index_details: &prost_types::Any,
719    ) -> Option<Box<dyn ScalarQueryParser>> {
720        // TODO: Allow return Result here
721        let registry = self.registry().unwrap();
722        let json_details =
723            crate::pb::JsonIndexDetails::decode(index_details.value.as_slice()).unwrap();
724        let target_details = json_details.target_details.as_ref().expect_ok().unwrap();
725        let target_plugin = registry.get_plugin_by_details(target_details).unwrap();
726        // TODO: Use something like ${index_name}_${path} for the index name?  Don't have access to path here tho
727        let target_parser = target_plugin.new_query_parser(index_name, index_details)?;
728        Some(Box::new(JsonQueryParser::new(
729            json_details.path.clone(),
730            target_parser,
731        )) as Box<dyn ScalarQueryParser>)
732    }
733
734    async fn train_index(
735        &self,
736        data: SendableRecordBatchStream,
737        index_store: &dyn IndexStore,
738        request: Box<dyn TrainingRequest>,
739        fragment_ids: Option<Vec<u32>>,
740        progress: Arc<dyn crate::progress::IndexBuildProgress>,
741    ) -> Result<CreatedIndex> {
742        let request = (request as Box<dyn std::any::Any>)
743            .downcast::<JsonTrainingRequest>()
744            .unwrap();
745        let path = request.parameters.path.clone();
746
747        // Extract JSON with type information
748        let (data_stream, inferred_type) =
749            Self::extract_json_with_type_info(data, path.clone()).await?;
750
751        // Convert the stream to properly typed values based on inferred type
752        let converted_stream =
753            Self::convert_stream_by_type(data_stream, inferred_type.clone()).await?;
754
755        // Update the target request with inferred type
756        let registry = self.registry()?;
757        let target_plugin = registry.get_plugin_by_name(&request.parameters.target_index_type)?;
758
759        // Create a new training request with the inferred type
760        let target_request = target_plugin.new_training_request(
761            request
762                .parameters
763                .target_index_parameters
764                .as_deref()
765                .unwrap_or("{}"),
766            &Field::new("", inferred_type, true),
767        )?;
768
769        let target_index = target_plugin
770            .train_index(
771                converted_stream,
772                index_store,
773                target_request,
774                fragment_ids,
775                progress,
776            )
777            .await?;
778
779        let index_details = crate::pb::JsonIndexDetails {
780            path,
781            target_details: Some(target_index.index_details),
782        };
783        Ok(CreatedIndex {
784            index_details: prost_types::Any::from_msg(&index_details)?,
785            index_version: JSON_INDEX_VERSION,
786            files: Some(index_store.list_files_with_sizes().await?),
787        })
788    }
789
790    async fn load_index(
791        &self,
792        index_store: Arc<dyn IndexStore>,
793        index_details: &prost_types::Any,
794        frag_reuse_index: Option<Arc<FragReuseIndex>>,
795        cache: &LanceCache,
796    ) -> Result<Arc<dyn ScalarIndex>> {
797        let registry = self.registry().unwrap();
798        let json_details = crate::pb::JsonIndexDetails::decode(index_details.value.as_slice())?;
799        let target_details = json_details.target_details.as_ref().expect_ok()?;
800        let target_plugin = registry.get_plugin_by_details(target_details).unwrap();
801        let target_index = target_plugin
802            .load_index(index_store, target_details, frag_reuse_index, cache)
803            .await?;
804        Ok(Arc::new(JsonIndex::new(target_index, json_details.path)))
805    }
806
807    fn details_as_json(&self, details: &prost_types::Any) -> Result<serde_json::Value> {
808        let registry = self.registry().unwrap();
809        let json_details = crate::pb::JsonIndexDetails::decode(details.value.as_slice())?;
810        let target_details = json_details.target_details.as_ref().expect_ok()?;
811        let target_plugin = registry.get_plugin_by_details(target_details).unwrap();
812        let target_details_json = target_plugin.details_as_json(target_details)?;
813        Ok(serde_json::json!({
814            "path": json_details.path,
815            "target_details": target_details_json,
816        }))
817    }
818}
819
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{ArrayRef, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    // Note: The old test_detect_json_value_type test has been removed as we now use
    // JSONB's inherent type information instead of string-based type detection

    /// Builds a one-batch stream with the two columns the test feeds to
    /// `extract_json_with_type_info`: a `LargeBinary` value column holding each
    /// document encoded as JSONB, and a `UInt64` row-id column numbered from 1.
    ///
    /// Panics if a document is not valid JSON, so a broken fixture fails at
    /// the point of construction instead of silently becoming a null row.
    fn jsonb_stream(json_docs: &[&str]) -> SendableRecordBatchStream {
        use arrow_array::{LargeBinaryArray, UInt64Array};
        use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
        use futures::stream;

        let schema = Arc::new(Schema::new(vec![
            Field::new(VALUE_COLUMN_NAME, DataType::LargeBinary, true),
            Field::new(ROW_ID, DataType::UInt64, false),
        ]));

        // Convert JSON strings to JSONB binary format.
        let jsonb_values: Vec<Vec<u8>> = json_docs
            .iter()
            .map(|doc| {
                doc.parse::<jsonb::OwnedJsonb>()
                    .expect("test fixture must be valid JSON")
                    .to_vec()
            })
            .collect();
        let jsonb_array = LargeBinaryArray::from(
            jsonb_values
                .iter()
                .map(|bytes| Some(bytes.as_slice()))
                .collect::<Vec<_>>(),
        );
        let row_ids = UInt64Array::from_iter_values(1..=json_docs.len() as u64);

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(jsonb_array) as ArrayRef,
                Arc::new(row_ids) as ArrayRef,
            ],
        )
        .unwrap();

        Box::pin(RecordBatchStreamAdapter::new(
            schema,
            stream::iter(vec![Ok(batch)]),
        )) as SendableRecordBatchStream
    }

    #[tokio::test]
    async fn test_json_extract_with_type_info() {
        let json_data = [
            r#"{"name": "Alice", "age": 30, "active": true}"#,
            r#"{"name": "Bob", "age": 25, "active": false}"#,
            r#"{"name": "Charlie", "age": 35, "active": true}"#,
        ];

        // Each JSON path should be inferred as the Arrow type of the values it selects.
        let cases = [
            ("$.age", DataType::Int64),
            ("$.active", DataType::Boolean),
            ("$.name", DataType::Utf8),
        ];

        for (path, expected_type) in cases {
            // The input stream is consumed by the call, so build a fresh one per path.
            let (_result_stream, inferred_type) =
                JsonIndexPlugin::extract_json_with_type_info(
                    jsonb_stream(&json_data),
                    path.to_string(),
                )
                .await
                .unwrap();

            assert_eq!(
                inferred_type, expected_type,
                "unexpected type inferred for path {path}"
            );
        }
    }
}