Skip to main content

lance_index/scalar/
json.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::{
5    collections::HashMap,
6    ops::Bound,
7    sync::{Arc, Mutex},
8};
9
10use arrow_array::{Array, LargeBinaryArray, RecordBatch, StructArray, UInt8Array};
11use arrow_schema::{DataType, Field, Field as ArrowField, Schema};
12use async_trait::async_trait;
13use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
14use datafusion::{
15    execution::SendableRecordBatchStream,
16    physical_plan::{ExecutionPlan, projection::ProjectionExec},
17};
18use datafusion_common::{ScalarValue, config::ConfigOptions};
19use datafusion_expr::{Expr, Operator, ScalarUDF};
20use datafusion_physical_expr::{
21    PhysicalExpr, ScalarFunctionExpr,
22    expressions::{Column, Literal},
23};
24use deepsize::DeepSizeOf;
25use futures::StreamExt;
26use lance_datafusion::exec::{LanceExecutionOptions, OneShotExec, get_session_context};
27use lance_datafusion::udf::json::JsonbType;
28use prost::Message;
29use roaring::RoaringBitmap;
30use serde::{Deserialize, Serialize};
31
32use lance_core::{Error, ROW_ID, Result, cache::LanceCache, error::LanceOptionExt};
33
34use crate::{
35    Index, IndexType,
36    frag_reuse::FragReuseIndex,
37    metrics::MetricsCollector,
38    registry::IndexPluginRegistry,
39    scalar::{
40        AnyQuery, CreatedIndex, IndexStore, ScalarIndex, SearchResult, UpdateCriteria,
41        expression::{IndexedExpression, ScalarIndexExpr, ScalarIndexSearch, ScalarQueryParser},
42        registry::{ScalarIndexPlugin, TrainingCriteria, TrainingRequest, VALUE_COLUMN_NAME},
43    },
44};
45
46const JSON_INDEX_VERSION: u32 = 0;
47
48/// A JSON index that indexes a field in a JSON column
49///
50/// The underlying index can be any other type of scalar index
51#[derive(Debug)]
52pub struct JsonIndex {
53    target_index: Arc<dyn ScalarIndex>,
54    path: String,
55}
56
57impl JsonIndex {
58    pub fn new(target_index: Arc<dyn ScalarIndex>, path: String) -> Self {
59        Self { target_index, path }
60    }
61}
62
63impl DeepSizeOf for JsonIndex {
64    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
65        self.target_index.deep_size_of_children(context) + self.path.deep_size_of_children(context)
66    }
67}
68
69#[async_trait]
70impl Index for JsonIndex {
71    fn as_any(&self) -> &dyn std::any::Any {
72        self
73    }
74
75    fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
76        self
77    }
78
79    fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn crate::vector::VectorIndex>> {
80        unimplemented!()
81    }
82
83    fn index_type(&self) -> IndexType {
84        // TODO: This causes the index to appear as btree in list_indices call.  Need better logic
85        // in list_indices to use details instead of index_type.
86        IndexType::Scalar
87    }
88
89    async fn prewarm(&self) -> Result<()> {
90        self.target_index.prewarm().await
91    }
92
93    fn statistics(&self) -> Result<serde_json::Value> {
94        todo!()
95    }
96
97    async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
98        self.target_index.calculate_included_frags().await
99    }
100}
101
102#[async_trait]
103impl ScalarIndex for JsonIndex {
104    async fn search(
105        &self,
106        query: &dyn AnyQuery,
107        metrics: &dyn MetricsCollector,
108    ) -> Result<SearchResult> {
109        let query = query.as_any().downcast_ref::<JsonQuery>().unwrap();
110        self.target_index
111            .search(query.target_query.as_ref(), metrics)
112            .await
113    }
114
115    fn can_remap(&self) -> bool {
116        self.target_index.can_remap()
117    }
118
119    async fn remap(
120        &self,
121        mapping: &HashMap<u64, Option<u64>>,
122        dest_store: &dyn IndexStore,
123    ) -> Result<CreatedIndex> {
124        let target_created = self.target_index.remap(mapping, dest_store).await?;
125        let json_details = crate::pb::JsonIndexDetails {
126            path: self.path.clone(),
127            target_details: Some(target_created.index_details),
128        };
129        Ok(CreatedIndex {
130            index_details: prost_types::Any::from_msg(&json_details)?,
131            // TODO: We should store the target index version in the details
132            index_version: JSON_INDEX_VERSION,
133            files: Some(dest_store.list_files_with_sizes().await?),
134        })
135    }
136
137    async fn update(
138        &self,
139        new_data: SendableRecordBatchStream,
140        dest_store: &dyn IndexStore,
141        old_data_filter: Option<super::OldIndexDataFilter>,
142    ) -> Result<CreatedIndex> {
143        let target_created = self
144            .target_index
145            .update(new_data, dest_store, old_data_filter)
146            .await?;
147        let json_details = crate::pb::JsonIndexDetails {
148            path: self.path.clone(),
149            target_details: Some(target_created.index_details),
150        };
151        Ok(CreatedIndex {
152            index_details: prost_types::Any::from_msg(&json_details)?,
153            // TODO: We should store the target index version in the details
154            index_version: JSON_INDEX_VERSION,
155            files: Some(dest_store.list_files_with_sizes().await?),
156        })
157    }
158
159    fn update_criteria(&self) -> UpdateCriteria {
160        self.target_index.update_criteria()
161    }
162
163    fn derive_index_params(&self) -> Result<super::ScalarIndexParams> {
164        self.target_index.derive_index_params()
165    }
166}
167
168/// Parameters for a [`JsonIndex`]
169#[derive(Debug, Serialize, Deserialize)]
170pub struct JsonIndexParameters {
171    target_index_type: String,
172    target_index_parameters: Option<String>,
173    path: String,
174}
175
176// TODO: Do we really need to wrap the query or could we just return the target query directly?
177//
178// I think the only thing we really gain is a different format impl (e.g. it shows up as a json query
179// in the explain plan) but I don't know if that helps the user much.
180#[derive(Debug, Clone)]
181pub struct JsonQuery {
182    target_query: Arc<dyn AnyQuery>,
183    path: String,
184}
185
186impl JsonQuery {
187    pub fn new(target_query: Arc<dyn AnyQuery>, path: String) -> Self {
188        Self { target_query, path }
189    }
190}
191
192impl PartialEq for JsonQuery {
193    fn eq(&self, other: &Self) -> bool {
194        self.target_query.dyn_eq(other.target_query.as_ref()) && self.path == other.path
195    }
196}
197
198impl AnyQuery for JsonQuery {
199    fn as_any(&self) -> &dyn std::any::Any {
200        self
201    }
202
203    fn format(&self, col: &str) -> String {
204        format!("Json({}->{})", self.target_query.format(col), self.path)
205    }
206
207    fn to_expr(&self, _col: String) -> Expr {
208        todo!()
209    }
210
211    fn dyn_eq(&self, other: &dyn AnyQuery) -> bool {
212        match other.as_any().downcast_ref::<Self>() {
213            Some(o) => self == o,
214            None => false,
215        }
216    }
217}
218
219#[derive(Debug)]
220pub struct JsonQueryParser {
221    path: String,
222    target_parser: Box<dyn ScalarQueryParser>,
223}
224
225impl JsonQueryParser {
226    pub fn new(path: String, target_parser: Box<dyn ScalarQueryParser>) -> Self {
227        Self {
228            path,
229            target_parser,
230        }
231    }
232
233    fn wrap_search(&self, target_expr: IndexedExpression) -> IndexedExpression {
234        if let Some(scalar_query) = target_expr.scalar_query {
235            let scalar_query = match scalar_query {
236                ScalarIndexExpr::Query(ScalarIndexSearch {
237                    column,
238                    index_name,
239                    query,
240                    needs_recheck,
241                }) => ScalarIndexExpr::Query(ScalarIndexSearch {
242                    column,
243                    index_name,
244                    query: Arc::new(JsonQuery::new(query, self.path.clone())),
245                    needs_recheck,
246                }),
247                // This code path should only be hit on leaf expr
248                _ => unreachable!(),
249            };
250            IndexedExpression {
251                scalar_query: Some(scalar_query),
252                refine_expr: target_expr.refine_expr,
253            }
254        } else {
255            target_expr
256        }
257    }
258}
259
260impl ScalarQueryParser for JsonQueryParser {
261    fn visit_between(
262        &self,
263        column: &str,
264        low: &Bound<ScalarValue>,
265        high: &Bound<ScalarValue>,
266    ) -> Option<IndexedExpression> {
267        self.target_parser
268            .visit_between(column, low, high)
269            .map(|target_expr| self.wrap_search(target_expr))
270    }
271    fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option<IndexedExpression> {
272        self.target_parser
273            .visit_in_list(column, in_list)
274            .map(|target_expr| self.wrap_search(target_expr))
275    }
276    fn visit_is_bool(&self, column: &str, value: bool) -> Option<IndexedExpression> {
277        self.target_parser
278            .visit_is_bool(column, value)
279            .map(|target_expr| self.wrap_search(target_expr))
280    }
281    fn visit_is_null(&self, column: &str) -> Option<IndexedExpression> {
282        self.target_parser
283            .visit_is_null(column)
284            .map(|target_expr| self.wrap_search(target_expr))
285    }
286    fn visit_comparison(
287        &self,
288        column: &str,
289        value: &ScalarValue,
290        op: &Operator,
291    ) -> Option<IndexedExpression> {
292        self.target_parser
293            .visit_comparison(column, value, op)
294            .map(|target_expr| self.wrap_search(target_expr))
295    }
296    fn visit_scalar_function(
297        &self,
298        column: &str,
299        data_type: &DataType,
300        func: &ScalarUDF,
301        args: &[Expr],
302    ) -> Option<IndexedExpression> {
303        self.target_parser
304            .visit_scalar_function(column, data_type, func, args)
305            .map(|target_expr| self.wrap_search(target_expr))
306    }
307
308    // TODO: maybe we should address it by https://github.com/lance-format/lance/issues/4624
309    fn is_valid_reference(&self, func: &Expr, _data_type: &DataType) -> Option<DataType> {
310        match func {
311            Expr::ScalarFunction(udf) => {
312                // Support multiple JSON extraction functions
313                let json_functions = [
314                    "json_extract",
315                    "json_get",
316                    "json_get_int",
317                    "json_get_float",
318                    "json_get_bool",
319                    "json_get_string",
320                ];
321                if !json_functions.contains(&udf.name()) {
322                    return None;
323                }
324                if udf.args.len() != 2 {
325                    return None;
326                }
327                // We already know index 0 is a column reference to the column so we just need to
328                // ensure that index 1 matches our path
329                match &udf.args[1] {
330                    Expr::Literal(ScalarValue::Utf8(Some(path)), _) => {
331                        if path == &self.path {
332                            // Return the appropriate type based on the function
333                            match udf.name() {
334                                "json_get_int" => Some(DataType::Int64),
335                                "json_get_float" => Some(DataType::Float64),
336                                "json_get_bool" => Some(DataType::Boolean),
337                                "json_get_string" | "json_extract" => Some(DataType::Utf8),
338                                _ => None,
339                            }
340                        } else {
341                            None
342                        }
343                    }
344                    _ => None,
345                }
346            }
347            _ => None,
348        }
349    }
350}
351
352pub struct JsonTrainingRequest {
353    parameters: JsonIndexParameters,
354    target_request: Box<dyn TrainingRequest>,
355}
356
357impl JsonTrainingRequest {
358    pub fn new(parameters: JsonIndexParameters, target_request: Box<dyn TrainingRequest>) -> Self {
359        Self {
360            parameters,
361            target_request,
362        }
363    }
364}
365
366impl TrainingRequest for JsonTrainingRequest {
367    fn as_any(&self) -> &dyn std::any::Any {
368        self
369    }
370
371    fn criteria(&self) -> &TrainingCriteria {
372        self.target_request.criteria()
373    }
374}
375
376/// Plugin implementation for a [`JsonIndex`]
377#[derive(Default)]
378pub struct JsonIndexPlugin {
379    registry: Mutex<Option<Arc<IndexPluginRegistry>>>,
380}
381
382impl std::fmt::Debug for JsonIndexPlugin {
383    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
384        write!(f, "JsonIndexPlugin")
385    }
386}
387
388impl JsonIndexPlugin {
389    fn registry(&self) -> Result<Arc<IndexPluginRegistry>> {
390        Ok(self.registry.lock().unwrap().as_ref().expect_ok()?.clone())
391    }
392
393    /// Extract JSON with type information using the new UDF
394    async fn extract_json_with_type_info(
395        data: SendableRecordBatchStream,
396        path: String,
397    ) -> Result<(SendableRecordBatchStream, DataType)> {
398        let input = Arc::new(OneShotExec::new(data));
399        let input_schema = input.schema();
400        let value_column_idx = input_schema
401            .column_with_name(VALUE_COLUMN_NAME)
402            .expect_ok()?
403            .0;
404        let row_id_column_idx = input_schema.column_with_name(ROW_ID).expect_ok()?.0;
405
406        // Call json_extract_with_type UDF
407        let exprs = vec![
408            (
409                Arc::new(ScalarFunctionExpr::try_new(
410                    Arc::new(lance_datafusion::udf::json::json_extract_with_type_udf()),
411                    vec![
412                        Arc::new(Column::new(VALUE_COLUMN_NAME, value_column_idx)),
413                        Arc::new(Literal::new(ScalarValue::Utf8(Some(path)))),
414                    ],
415                    &input_schema,
416                    Arc::new(ConfigOptions::default()),
417                )?) as Arc<dyn PhysicalExpr>,
418                "json_result".to_string(),
419            ),
420            (
421                Arc::new(Column::new(ROW_ID, row_id_column_idx)) as Arc<dyn PhysicalExpr>,
422                ROW_ID.to_string(),
423            ),
424        ];
425
426        let project = ProjectionExec::try_new(exprs, input)?;
427        let ctx = get_session_context(&LanceExecutionOptions::default());
428        let mut stream = project.execute(0, ctx.task_ctx())?;
429
430        // Collect batches and determine type from first non-null value
431        let mut all_batches = Vec::new();
432        let mut inferred_type: Option<DataType> = None;
433
434        while let Some(batch_result) = stream.next().await {
435            let batch = batch_result?;
436
437            // Determine type from first non-null value if not yet set
438            if inferred_type.is_none()
439                && let Some(json_result_column) = batch.column_by_name("json_result")
440                && let Some(struct_array) =
441                    json_result_column.as_any().downcast_ref::<StructArray>()
442                && let Some(type_array) = struct_array.column_by_name("type_tag")
443                && let Some(uint8_array) = type_array.as_any().downcast_ref::<UInt8Array>()
444            {
445                // Find first non-null value to determine type
446                for i in 0..uint8_array.len() {
447                    if !uint8_array.is_null(i) {
448                        let type_tag = uint8_array.value(i);
449                        let jsonb_type = JsonbType::from_u8(type_tag).ok_or_else(|| {
450                            Error::invalid_input_source(
451                                format!("Invalid type tag: {}", type_tag).into(),
452                            )
453                        })?;
454
455                        // Map JsonbType to Arrow DataType
456                        inferred_type = Some(match jsonb_type {
457                            JsonbType::Null => continue, // Skip null values
458                            JsonbType::Boolean => DataType::Boolean,
459                            JsonbType::Int64 => DataType::Int64,
460                            JsonbType::Float64 => DataType::Float64,
461                            JsonbType::String => DataType::Utf8,
462                            JsonbType::Array => DataType::LargeBinary,
463                            JsonbType::Object => DataType::LargeBinary,
464                        });
465                        break;
466                    }
467                }
468            }
469
470            all_batches.push(batch);
471        }
472
473        // If no type was inferred (all nulls), default to String
474        let inferred_type = inferred_type.unwrap_or(DataType::Utf8);
475
476        // Recreate stream from collected batches
477        let schema = all_batches
478            .first()
479            .map(|b| b.schema())
480            .ok_or_else(|| Error::invalid_input_source("No batches in stream".into()))?;
481
482        let recreated_stream = Box::pin(RecordBatchStreamAdapter::new(
483            schema,
484            futures::stream::iter(all_batches.into_iter().map(Ok)),
485        )) as SendableRecordBatchStream;
486
487        Ok((recreated_stream, inferred_type))
488    }
489
490    /// Convert the stream with JSONB values and type tags to properly typed values
491    async fn convert_stream_by_type(
492        data: SendableRecordBatchStream,
493        target_type: DataType,
494    ) -> Result<SendableRecordBatchStream> {
495        let input = Arc::new(OneShotExec::new(data));
496        let _input_schema = input.schema();
497        let ctx = get_session_context(&LanceExecutionOptions::default());
498        let mut stream = input.execute(0, ctx.task_ctx())?;
499
500        let mut converted_batches = Vec::new();
501
502        while let Some(batch_result) = stream.next().await {
503            let batch = batch_result?;
504
505            // Extract the struct column containing value and type_tag
506            let json_result_column = batch
507                .column_by_name("json_result")
508                .ok_or_else(|| Error::invalid_input_source("Missing json_result column".into()))?;
509
510            let struct_array = json_result_column
511                .as_any()
512                .downcast_ref::<StructArray>()
513                .ok_or_else(|| Error::invalid_input_source("json_result is not a struct".into()))?;
514
515            let value_array = struct_array.column_by_name("value").ok_or_else(|| {
516                Error::invalid_input_source("Missing value column in struct".into())
517            })?;
518
519            let binary_array = value_array
520                .as_any()
521                .downcast_ref::<LargeBinaryArray>()
522                .ok_or_else(|| Error::invalid_input_source("value is not LargeBinary".into()))?;
523
524            // Convert based on target type using serde deserialization
525            let converted_array: Arc<dyn Array> =
526                match target_type {
527                    DataType::Boolean => {
528                        let mut builder =
529                            arrow_array::builder::BooleanBuilder::with_capacity(binary_array.len());
530                        for i in 0..binary_array.len() {
531                            if binary_array.is_null(i) {
532                                builder.append_null();
533                            } else if let Some(bytes) = binary_array.value(i).into() {
534                                let raw_jsonb = jsonb::RawJsonb::new(bytes);
535                                // Try to deserialize directly to bool
536                                match jsonb::from_raw_jsonb::<bool>(&raw_jsonb) {
537                                    Ok(bool_val) => builder.append_value(bool_val),
538                                    Err(e) => {
539                                        return Err(Error::invalid_input_source(format!(
540                                        "Failed to deserialize JSONB to bool at index {}: {}",
541                                        i, e
542                                    )
543                                    .into()));
544                                    }
545                                }
546                            } else {
547                                builder.append_null();
548                            }
549                        }
550                        Arc::new(builder.finish())
551                    }
552                    DataType::Int64 => {
553                        let mut builder =
554                            arrow_array::builder::Int64Builder::with_capacity(binary_array.len());
555                        for i in 0..binary_array.len() {
556                            if binary_array.is_null(i) {
557                                builder.append_null();
558                            } else if let Some(bytes) = binary_array.value(i).into() {
559                                let raw_jsonb = jsonb::RawJsonb::new(bytes);
560                                // Try to deserialize directly to i64
561                                match jsonb::from_raw_jsonb::<i64>(&raw_jsonb) {
562                                    Ok(int_val) => builder.append_value(int_val),
563                                    Err(e) => {
564                                        return Err(Error::invalid_input_source(format!(
565                                        "Failed to deserialize JSONB to i64 at index {}: {}",
566                                        i, e
567                                    )
568                                    .into()));
569                                    }
570                                }
571                            } else {
572                                builder.append_null();
573                            }
574                        }
575                        Arc::new(builder.finish())
576                    }
577                    DataType::Float64 => {
578                        let mut builder =
579                            arrow_array::builder::Float64Builder::with_capacity(binary_array.len());
580                        for i in 0..binary_array.len() {
581                            if binary_array.is_null(i) {
582                                builder.append_null();
583                            } else if let Some(bytes) = binary_array.value(i).into() {
584                                let raw_jsonb = jsonb::RawJsonb::new(bytes);
585                                // Try to deserialize directly to f64 (serde handles int->float conversion)
586                                match jsonb::from_raw_jsonb::<f64>(&raw_jsonb) {
587                                    Ok(float_val) => builder.append_value(float_val),
588                                    Err(e) => {
589                                        return Err(Error::invalid_input_source(format!(
590                                        "Failed to deserialize JSONB to f64 at index {}: {}",
591                                        i, e
592                                    )
593                                    .into()));
594                                    }
595                                }
596                            } else {
597                                builder.append_null();
598                            }
599                        }
600                        Arc::new(builder.finish())
601                    }
602                    DataType::Utf8 => {
603                        let mut builder = arrow_array::builder::StringBuilder::with_capacity(
604                            binary_array.len(),
605                            1024,
606                        );
607                        for i in 0..binary_array.len() {
608                            if binary_array.is_null(i) {
609                                builder.append_null();
610                            } else if let Some(bytes) = binary_array.value(i).into() {
611                                let raw_jsonb = jsonb::RawJsonb::new(bytes);
612                                // Try to deserialize to String, or use to_string() for any type
613                                match jsonb::from_raw_jsonb::<String>(&raw_jsonb) {
614                                    Ok(str_val) => builder.append_value(&str_val),
615                                    Err(_) => {
616                                        // For non-string types, convert to string representation
617                                        builder.append_value(raw_jsonb.to_string());
618                                    }
619                                }
620                            } else {
621                                builder.append_null();
622                            }
623                        }
624                        Arc::new(builder.finish())
625                    }
626                    DataType::LargeBinary => {
627                        // Keep as binary for array/object types
628                        value_array.clone()
629                    }
630                    _ => {
631                        return Err(Error::invalid_input_source(
632                            format!("Unsupported target type: {:?}", target_type).into(),
633                        ));
634                    }
635                };
636
637            // Get row_id column
638            let row_id_column = batch
639                .column_by_name(ROW_ID)
640                .ok_or_else(|| Error::invalid_input_source("Missing row_id column".into()))?
641                .clone();
642
643            // Create new batch with converted values
644            let new_schema = Arc::new(Schema::new(vec![
645                ArrowField::new(VALUE_COLUMN_NAME, target_type.clone(), true),
646                ArrowField::new(ROW_ID, DataType::UInt64, false),
647            ]));
648
649            let new_batch =
650                RecordBatch::try_new(new_schema.clone(), vec![converted_array, row_id_column])?;
651
652            converted_batches.push(new_batch);
653        }
654
655        // Create stream from converted batches
656        let schema = converted_batches
657            .first()
658            .map(|b| b.schema())
659            .ok_or_else(|| Error::invalid_input_source("No batches to convert".into()))?;
660
661        Ok(Box::pin(RecordBatchStreamAdapter::new(
662            schema,
663            futures::stream::iter(converted_batches.into_iter().map(Ok)),
664        )))
665    }
666}
667
668#[async_trait]
669impl ScalarIndexPlugin for JsonIndexPlugin {
670    fn name(&self) -> &str {
671        "Json"
672    }
673
674    fn new_training_request(
675        &self,
676        params: &str,
677        field: &Field,
678    ) -> Result<Box<dyn TrainingRequest>> {
679        if !matches!(field.data_type(), DataType::Binary | DataType::LargeBinary) {
680            return Err(Error::invalid_input_source(
681                "A JSON index can only be created on a Binary or LargeBinary field.".into(),
682            ));
683        }
684
685        // Initially use Utf8, will be refined during training with type inference
686        let target_type = DataType::Utf8;
687
688        let params = serde_json::from_str::<JsonIndexParameters>(params)?;
689        let registry = self.registry()?;
690        let target_plugin = registry.get_plugin_by_name(&params.target_index_type)?;
691        let target_request = target_plugin.new_training_request(
692            params.target_index_parameters.as_deref().unwrap_or("{}"),
693            &Field::new("", target_type, true),
694        )?;
695
696        Ok(Box::new(JsonTrainingRequest::new(params, target_request)))
697    }
698
699    fn provides_exact_answer(&self) -> bool {
700        // TODO: Need to lookup target plugin via details to figure this out correctly
701        true
702    }
703
704    fn attach_registry(&self, registry: Arc<IndexPluginRegistry>) {
705        let mut reg_ref = self.registry.lock().unwrap();
706        *reg_ref = Some(registry);
707    }
708
709    fn version(&self) -> u32 {
710        JSON_INDEX_VERSION
711    }
712
713    fn new_query_parser(
714        &self,
715        index_name: String,
716        index_details: &prost_types::Any,
717    ) -> Option<Box<dyn ScalarQueryParser>> {
718        // TODO: Allow return Result here
719        let registry = self.registry().unwrap();
720        let json_details =
721            crate::pb::JsonIndexDetails::decode(index_details.value.as_slice()).unwrap();
722        let target_details = json_details.target_details.as_ref().expect_ok().unwrap();
723        let target_plugin = registry.get_plugin_by_details(target_details).unwrap();
724        // TODO: Use something like ${index_name}_${path} for the index name?  Don't have access to path here tho
725        let target_parser = target_plugin.new_query_parser(index_name, index_details)?;
726        Some(Box::new(JsonQueryParser::new(
727            json_details.path.clone(),
728            target_parser,
729        )) as Box<dyn ScalarQueryParser>)
730    }
731
732    async fn train_index(
733        &self,
734        data: SendableRecordBatchStream,
735        index_store: &dyn IndexStore,
736        request: Box<dyn TrainingRequest>,
737        fragment_ids: Option<Vec<u32>>,
738        progress: Arc<dyn crate::progress::IndexBuildProgress>,
739    ) -> Result<CreatedIndex> {
740        let request = (request as Box<dyn std::any::Any>)
741            .downcast::<JsonTrainingRequest>()
742            .unwrap();
743        let path = request.parameters.path.clone();
744
745        // Extract JSON with type information
746        let (data_stream, inferred_type) =
747            Self::extract_json_with_type_info(data, path.clone()).await?;
748
749        // Convert the stream to properly typed values based on inferred type
750        let converted_stream =
751            Self::convert_stream_by_type(data_stream, inferred_type.clone()).await?;
752
753        // Update the target request with inferred type
754        let registry = self.registry()?;
755        let target_plugin = registry.get_plugin_by_name(&request.parameters.target_index_type)?;
756
757        // Create a new training request with the inferred type
758        let target_request = target_plugin.new_training_request(
759            request
760                .parameters
761                .target_index_parameters
762                .as_deref()
763                .unwrap_or("{}"),
764            &Field::new("", inferred_type, true),
765        )?;
766
767        let target_index = target_plugin
768            .train_index(
769                converted_stream,
770                index_store,
771                target_request,
772                fragment_ids,
773                progress,
774            )
775            .await?;
776
777        let index_details = crate::pb::JsonIndexDetails {
778            path,
779            target_details: Some(target_index.index_details),
780        };
781        Ok(CreatedIndex {
782            index_details: prost_types::Any::from_msg(&index_details)?,
783            index_version: JSON_INDEX_VERSION,
784            files: Some(index_store.list_files_with_sizes().await?),
785        })
786    }
787
788    async fn load_index(
789        &self,
790        index_store: Arc<dyn IndexStore>,
791        index_details: &prost_types::Any,
792        frag_reuse_index: Option<Arc<FragReuseIndex>>,
793        cache: &LanceCache,
794    ) -> Result<Arc<dyn ScalarIndex>> {
795        let registry = self.registry().unwrap();
796        let json_details = crate::pb::JsonIndexDetails::decode(index_details.value.as_slice())?;
797        let target_details = json_details.target_details.as_ref().expect_ok()?;
798        let target_plugin = registry.get_plugin_by_details(target_details).unwrap();
799        let target_index = target_plugin
800            .load_index(index_store, target_details, frag_reuse_index, cache)
801            .await?;
802        Ok(Arc::new(JsonIndex::new(target_index, json_details.path)))
803    }
804
805    fn details_as_json(&self, details: &prost_types::Any) -> Result<serde_json::Value> {
806        let registry = self.registry().unwrap();
807        let json_details = crate::pb::JsonIndexDetails::decode(details.value.as_slice())?;
808        let target_details = json_details.target_details.as_ref().expect_ok()?;
809        let target_plugin = registry.get_plugin_by_details(target_details).unwrap();
810        let target_details_json = target_plugin.details_as_json(target_details)?;
811        Ok(serde_json::json!({
812            "path": json_details.path,
813            "target_details": target_details_json,
814        }))
815    }
816}
817
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{ArrayRef, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    // Note: The old test_detect_json_value_type test has been removed as we now use
    // JSONB's inherent type information instead of string-based type detection

    /// Builds a one-batch stream from JSON documents.
    ///
    /// Each document is parsed into JSONB bytes and stored in the `VALUE_COLUMN_NAME`
    /// column; the `ROW_ID` column holds consecutive ids starting at 1.
    fn jsonb_stream(docs: &[&str]) -> SendableRecordBatchStream {
        use arrow_array::{LargeBinaryArray, UInt64Array};
        use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
        use futures::stream;

        let schema = Arc::new(Schema::new(vec![
            Field::new(VALUE_COLUMN_NAME, DataType::LargeBinary, true),
            Field::new(ROW_ID, DataType::UInt64, false),
        ]));

        // Convert JSON strings to JSONB binary format
        let encoded: Vec<Vec<u8>> = docs
            .iter()
            .map(|doc| doc.parse::<jsonb::OwnedJsonb>().unwrap().to_vec())
            .collect();
        let values: Vec<Option<&[u8]>> = encoded.iter().map(|bytes| Some(bytes.as_slice())).collect();

        let columns = vec![
            Arc::new(LargeBinaryArray::from(values)) as ArrayRef,
            Arc::new(UInt64Array::from_iter_values((1u64..).take(docs.len()))) as ArrayRef,
        ];
        let batch = RecordBatch::try_new(schema.clone(), columns).unwrap();

        Box::pin(RecordBatchStreamAdapter::new(
            schema,
            stream::iter(vec![Ok(batch)]),
        )) as SendableRecordBatchStream
    }

    #[tokio::test]
    async fn test_json_extract_with_type_info() {
        // Create test JSONB data
        let docs = [
            r#"{"name": "Alice", "age": 30, "active": true}"#,
            r#"{"name": "Bob", "age": 25, "active": false}"#,
            r#"{"name": "Charlie", "age": 35, "active": true}"#,
        ];

        // Each JSON path paired with the Arrow type that should be inferred for it:
        // integer, boolean and string fields respectively.
        let cases = [
            ("$.age", DataType::Int64),
            ("$.active", DataType::Boolean),
            ("$.name", DataType::Utf8),
        ];

        for (path, expected_type) in cases {
            // A stream can only be consumed once, so build a fresh one per path.
            let (_stream, inferred_type) =
                JsonIndexPlugin::extract_json_with_type_info(jsonb_stream(&docs), path.to_string())
                    .await
                    .unwrap();

            assert_eq!(inferred_type, expected_type);
        }
    }
}