// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::{
    collections::HashMap,
    ops::Bound,
    sync::{Arc, Mutex},
};

use arrow_array::{Array, LargeBinaryArray, RecordBatch, StructArray, UInt8Array};
use arrow_schema::{DataType, Field, Field as ArrowField, Schema};
use async_trait::async_trait;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::{
    execution::SendableRecordBatchStream,
    physical_plan::{projection::ProjectionExec, ExecutionPlan},
};
use datafusion_common::{config::ConfigOptions, ScalarValue};
use datafusion_expr::{Expr, Operator, ScalarUDF};
use datafusion_physical_expr::{
    expressions::{Column, Literal},
    PhysicalExpr, ScalarFunctionExpr,
};
use deepsize::DeepSizeOf;
use futures::StreamExt;
use lance_datafusion::exec::{get_session_context, LanceExecutionOptions, OneShotExec};
use lance_datafusion::udf::json::JsonbType;
use prost::Message;
use roaring::RoaringBitmap;
use serde::{Deserialize, Serialize};
use snafu::location;

use lance_core::{cache::LanceCache, error::LanceOptionExt, Error, Result, ROW_ID};

use crate::{
    frag_reuse::FragReuseIndex,
    metrics::MetricsCollector,
    registry::IndexPluginRegistry,
    scalar::{
        expression::{IndexedExpression, ScalarIndexExpr, ScalarIndexSearch, ScalarQueryParser},
        registry::{ScalarIndexPlugin, TrainingCriteria, TrainingRequest, VALUE_COLUMN_NAME},
        AnyQuery, CreatedIndex, IndexStore, ScalarIndex, SearchResult, UpdateCriteria,
    },
    Index, IndexType,
};

const JSON_INDEX_VERSION: u32 = 0;

/// A JSON index that indexes a field in a JSON column
///
/// The underlying index can be any other type of scalar index
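///
/// For example (a sketch; the recognized extraction functions are listed in
/// `JsonQueryParser::is_valid_reference` below), a filter such as
/// `json_get_int(col, '$.age') > 21` can be answered by a target index trained
/// on the values extracted at the `$.age` path.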
#[derive(Debug)]
pub struct JsonIndex {
    target_index: Arc<dyn ScalarIndex>,
    path: String,
}

impl JsonIndex {
    pub fn new(target_index: Arc<dyn ScalarIndex>, path: String) -> Self {
        Self { target_index, path }
    }
}

impl DeepSizeOf for JsonIndex {
    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
        self.target_index.deep_size_of_children(context) + self.path.deep_size_of_children(context)
    }
}

#[async_trait]
impl Index for JsonIndex {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
        self
    }

    fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn crate::vector::VectorIndex>> {
        unimplemented!()
    }

    fn index_type(&self) -> IndexType {
        // TODO: This causes the index to appear as a btree index in the list_indices call.
        // list_indices needs better logic that uses the details instead of index_type.
        IndexType::Scalar
    }

    async fn prewarm(&self) -> Result<()> {
        self.target_index.prewarm().await
    }

    fn statistics(&self) -> Result<serde_json::Value> {
        todo!()
    }

    async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
        self.target_index.calculate_included_frags().await
    }
}

#[async_trait]
impl ScalarIndex for JsonIndex {
    async fn search(
        &self,
        query: &dyn AnyQuery,
        metrics: &dyn MetricsCollector,
    ) -> Result<SearchResult> {
        let query = query.as_any().downcast_ref::<JsonQuery>().unwrap();
        self.target_index
            .search(query.target_query.as_ref(), metrics)
            .await
    }

    fn can_remap(&self) -> bool {
        self.target_index.can_remap()
    }

    async fn remap(
        &self,
        mapping: &HashMap<u64, Option<u64>>,
        dest_store: &dyn IndexStore,
    ) -> Result<CreatedIndex> {
        let target_created = self.target_index.remap(mapping, dest_store).await?;
        let json_details = crate::pb::JsonIndexDetails {
            path: self.path.clone(),
            target_details: Some(target_created.index_details),
        };
        Ok(CreatedIndex {
            index_details: prost_types::Any::from_msg(&json_details)?,
            // TODO: We should store the target index version in the details
            index_version: JSON_INDEX_VERSION,
        })
    }

    async fn update(
        &self,
        new_data: SendableRecordBatchStream,
        dest_store: &dyn IndexStore,
    ) -> Result<CreatedIndex> {
        let target_created = self.target_index.update(new_data, dest_store).await?;
        let json_details = crate::pb::JsonIndexDetails {
            path: self.path.clone(),
            target_details: Some(target_created.index_details),
        };
        Ok(CreatedIndex {
            index_details: prost_types::Any::from_msg(&json_details)?,
            // TODO: We should store the target index version in the details
            index_version: JSON_INDEX_VERSION,
        })
    }

    fn update_criteria(&self) -> UpdateCriteria {
        self.target_index.update_criteria()
    }

    fn derive_index_params(&self) -> Result<super::ScalarIndexParams> {
        self.target_index.derive_index_params()
    }
}

/// Parameters for a [`JsonIndex`]
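///
/// Deserialized from a JSON string, e.g. (a sketch; valid `target_index_type`
/// names depend on what is registered in the plugin registry):
///
/// ```text
/// { "target_index_type": "BTree", "target_index_parameters": null, "path": "$.user.age" }
/// ```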
#[derive(Debug, Serialize, Deserialize)]
pub struct JsonIndexParameters {
    target_index_type: String,
    target_index_parameters: Option<String>,
    path: String,
}

// TODO: Do we really need to wrap the query or could we just return the target query directly?
//
// I think the only thing we really gain is a different format impl (e.g. it shows up as a json query
// in the explain plan) but I don't know if that helps the user much.
#[derive(Debug, Clone)]
pub struct JsonQuery {
    target_query: Arc<dyn AnyQuery>,
    path: String,
}

impl JsonQuery {
    pub fn new(target_query: Arc<dyn AnyQuery>, path: String) -> Self {
        Self { target_query, path }
    }
}

impl PartialEq for JsonQuery {
    fn eq(&self, other: &Self) -> bool {
        self.target_query.dyn_eq(other.target_query.as_ref()) && self.path == other.path
    }
}

impl AnyQuery for JsonQuery {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn format(&self, col: &str) -> String {
        format!("Json({}->{})", self.target_query.format(col), self.path)
    }

    fn to_expr(&self, _col: String) -> Expr {
        todo!()
    }

    fn dyn_eq(&self, other: &dyn AnyQuery) -> bool {
        match other.as_any().downcast_ref::<Self>() {
            Some(o) => self == o,
            None => false,
        }
    }
}

#[derive(Debug)]
pub struct JsonQueryParser {
    path: String,
    target_parser: Box<dyn ScalarQueryParser>,
}

impl JsonQueryParser {
    pub fn new(path: String, target_parser: Box<dyn ScalarQueryParser>) -> Self {
        Self {
            path,
            target_parser,
        }
    }

    fn wrap_search(&self, target_expr: IndexedExpression) -> IndexedExpression {
        if let Some(scalar_query) = target_expr.scalar_query {
            let scalar_query = match scalar_query {
                ScalarIndexExpr::Query(ScalarIndexSearch {
                    column,
                    index_name,
                    query,
                    needs_recheck,
                }) => ScalarIndexExpr::Query(ScalarIndexSearch {
                    column,
                    index_name,
                    query: Arc::new(JsonQuery::new(query, self.path.clone())),
                    needs_recheck,
                }),
                // This code path should only be hit on leaf expressions
                _ => unreachable!(),
            };
            IndexedExpression {
                scalar_query: Some(scalar_query),
                refine_expr: target_expr.refine_expr,
            }
        } else {
            target_expr
        }
    }
}

impl ScalarQueryParser for JsonQueryParser {
    fn visit_between(
        &self,
        column: &str,
        low: &Bound<ScalarValue>,
        high: &Bound<ScalarValue>,
    ) -> Option<IndexedExpression> {
        self.target_parser
            .visit_between(column, low, high)
            .map(|target_expr| self.wrap_search(target_expr))
    }
    fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option<IndexedExpression> {
        self.target_parser
            .visit_in_list(column, in_list)
            .map(|target_expr| self.wrap_search(target_expr))
    }
    fn visit_is_bool(&self, column: &str, value: bool) -> Option<IndexedExpression> {
        self.target_parser
            .visit_is_bool(column, value)
            .map(|target_expr| self.wrap_search(target_expr))
    }
    fn visit_is_null(&self, column: &str) -> Option<IndexedExpression> {
        self.target_parser
            .visit_is_null(column)
            .map(|target_expr| self.wrap_search(target_expr))
    }
    fn visit_comparison(
        &self,
        column: &str,
        value: &ScalarValue,
        op: &Operator,
    ) -> Option<IndexedExpression> {
        self.target_parser
            .visit_comparison(column, value, op)
            .map(|target_expr| self.wrap_search(target_expr))
    }
    fn visit_scalar_function(
        &self,
        column: &str,
        data_type: &DataType,
        func: &ScalarUDF,
        args: &[Expr],
    ) -> Option<IndexedExpression> {
        self.target_parser
            .visit_scalar_function(column, data_type, func, args)
            .map(|target_expr| self.wrap_search(target_expr))
    }
    // TODO: maybe we should address this via https://github.com/lance-format/lance/issues/4624
    fn is_valid_reference(&self, func: &Expr, _data_type: &DataType) -> Option<DataType> {
        match func {
            Expr::ScalarFunction(udf) => {
                // Support multiple JSON extraction functions
                let json_functions = [
                    "json_extract",
                    "json_get",
                    "json_get_int",
                    "json_get_float",
                    "json_get_bool",
                    "json_get_string",
                ];
                if !json_functions.contains(&udf.name()) {
                    return None;
                }
                if udf.args.len() != 2 {
                    return None;
                }
                // We already know index 0 is a column reference to the column so we just need to
                // ensure that index 1 matches our path
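                // e.g. `json_get_int(col, '$.age')` is a match when `self.path` is `$.age`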
                match &udf.args[1] {
                    Expr::Literal(ScalarValue::Utf8(Some(path)), _) => {
                        if path == &self.path {
                            // Return the appropriate type based on the function
                            match udf.name() {
                                "json_get_int" => Some(DataType::Int64),
                                "json_get_float" => Some(DataType::Float64),
                                "json_get_bool" => Some(DataType::Boolean),
                                "json_get_string" | "json_extract" => Some(DataType::Utf8),
                                _ => None,
                            }
                        } else {
                            None
                        }
                    }
                    _ => None,
                }
            }
            _ => None,
        }
    }
}

pub struct JsonTrainingRequest {
    parameters: JsonIndexParameters,
    target_request: Box<dyn TrainingRequest>,
}

impl JsonTrainingRequest {
    pub fn new(parameters: JsonIndexParameters, target_request: Box<dyn TrainingRequest>) -> Self {
        Self {
            parameters,
            target_request,
        }
    }
}

impl TrainingRequest for JsonTrainingRequest {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn criteria(&self) -> &TrainingCriteria {
        self.target_request.criteria()
    }
}

/// Plugin implementation for a [`JsonIndex`]
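///
/// A minimal training sketch (`registry`, `stream`, and `store` are assumed to
/// exist here, and "BTree" is a hypothetical target index name):
///
/// ```ignore
/// let plugin = JsonIndexPlugin::default();
/// plugin.attach_registry(registry);
/// let request = plugin.new_training_request(
///     r#"{"target_index_type": "BTree", "path": "$.age"}"#,
///     &Field::new("col", DataType::LargeBinary, true),
/// )?;
/// let created = plugin.train_index(stream, store.as_ref(), request, None).await?;
/// ```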
#[derive(Default)]
pub struct JsonIndexPlugin {
    registry: Mutex<Option<Arc<IndexPluginRegistry>>>,
}

impl std::fmt::Debug for JsonIndexPlugin {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "JsonIndexPlugin")
    }
}

impl JsonIndexPlugin {
    fn registry(&self) -> Result<Arc<IndexPluginRegistry>> {
        Ok(self.registry.lock().unwrap().as_ref().expect_ok()?.clone())
    }

    /// Extract JSON with type information using the new UDF
    async fn extract_json_with_type_info(
        data: SendableRecordBatchStream,
        path: String,
    ) -> Result<(SendableRecordBatchStream, DataType)> {
        let input = Arc::new(OneShotExec::new(data));
        let input_schema = input.schema();
        let value_column_idx = input_schema
            .column_with_name(VALUE_COLUMN_NAME)
            .expect_ok()?
            .0;
        let row_id_column_idx = input_schema.column_with_name(ROW_ID).expect_ok()?.0;

        // Call the json_extract_with_type UDF
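        // The UDF is expected to return a struct { value: LargeBinary (JSONB), type_tag: UInt8 };
        // the type inference here and the conversion below rely on those field names.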
        let exprs = vec![
            (
                Arc::new(ScalarFunctionExpr::try_new(
                    Arc::new(lance_datafusion::udf::json::json_extract_with_type_udf()),
                    vec![
                        Arc::new(Column::new(VALUE_COLUMN_NAME, value_column_idx)),
                        Arc::new(Literal::new(ScalarValue::Utf8(Some(path)))),
                    ],
                    &input_schema,
                    Arc::new(ConfigOptions::default()),
                )?) as Arc<dyn PhysicalExpr>,
                "json_result".to_string(),
            ),
            (
                Arc::new(Column::new(ROW_ID, row_id_column_idx)) as Arc<dyn PhysicalExpr>,
                ROW_ID.to_string(),
            ),
        ];

        let project = ProjectionExec::try_new(exprs, input)?;
        let ctx = get_session_context(&LanceExecutionOptions::default());
        let mut stream = project.execute(0, ctx.task_ctx())?;

        // Collect batches and determine type from first non-null value
        let mut all_batches = Vec::new();
        let mut inferred_type: Option<DataType> = None;

        while let Some(batch_result) = stream.next().await {
            let batch = batch_result?;

            // Determine type from first non-null value if not yet set
            if inferred_type.is_none() {
                if let Some(json_result_column) = batch.column_by_name("json_result") {
                    if let Some(struct_array) =
                        json_result_column.as_any().downcast_ref::<StructArray>()
                    {
                        if let Some(type_array) = struct_array.column_by_name("type_tag") {
                            if let Some(uint8_array) =
                                type_array.as_any().downcast_ref::<UInt8Array>()
                            {
                                // Find first non-null value to determine type
                                for i in 0..uint8_array.len() {
                                    if !uint8_array.is_null(i) {
                                        let type_tag = uint8_array.value(i);
                                        let jsonb_type =
                                            JsonbType::from_u8(type_tag).ok_or_else(|| {
                                                Error::InvalidInput {
                                                    source: format!(
                                                        "Invalid type tag: {}",
                                                        type_tag
                                                    )
                                                    .into(),
                                                    location: location!(),
                                                }
                                            })?;

                                        // Map JsonbType to Arrow DataType
                                        inferred_type = Some(match jsonb_type {
                                            JsonbType::Null => continue, // Skip null values
                                            JsonbType::Boolean => DataType::Boolean,
                                            JsonbType::Int64 => DataType::Int64,
                                            JsonbType::Float64 => DataType::Float64,
                                            JsonbType::String => DataType::Utf8,
                                            JsonbType::Array => DataType::LargeBinary,
                                            JsonbType::Object => DataType::LargeBinary,
                                        });
                                        break;
                                    }
                                }
                            }
                        }
                    }
                }
            }

            all_batches.push(batch);
        }

        // If no type was inferred (all nulls), default to String
        let inferred_type = inferred_type.unwrap_or(DataType::Utf8);

        // Recreate stream from collected batches
        let schema =
            all_batches
                .first()
                .map(|b| b.schema())
                .ok_or_else(|| Error::InvalidInput {
                    source: "No batches in stream".into(),
                    location: location!(),
                })?;

        let recreated_stream = Box::pin(RecordBatchStreamAdapter::new(
            schema,
            futures::stream::iter(all_batches.into_iter().map(Ok)),
        )) as SendableRecordBatchStream;

        Ok((recreated_stream, inferred_type))
    }

    /// Convert the stream with JSONB values and type tags to properly typed values
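    ///
    /// A sketch of the mapping (it mirrors the match below): Boolean, Int64, Float64,
    /// and Utf8 values are deserialized from the raw JSONB, while arrays and objects
    /// are kept as raw JSONB `LargeBinary`.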
    async fn convert_stream_by_type(
        data: SendableRecordBatchStream,
        target_type: DataType,
    ) -> Result<SendableRecordBatchStream> {
        let input = Arc::new(OneShotExec::new(data));
        let _input_schema = input.schema();
        let ctx = get_session_context(&LanceExecutionOptions::default());
        let mut stream = input.execute(0, ctx.task_ctx())?;

        let mut converted_batches = Vec::new();

        while let Some(batch_result) = stream.next().await {
            let batch = batch_result?;

            // Extract the struct column containing value and type_tag
            let json_result_column =
                batch
                    .column_by_name("json_result")
                    .ok_or_else(|| Error::InvalidInput {
                        source: "Missing json_result column".into(),
                        location: location!(),
                    })?;

            let struct_array = json_result_column
                .as_any()
                .downcast_ref::<StructArray>()
                .ok_or_else(|| Error::InvalidInput {
                    source: "json_result is not a struct".into(),
                    location: location!(),
                })?;

            let value_array =
                struct_array
                    .column_by_name("value")
                    .ok_or_else(|| Error::InvalidInput {
                        source: "Missing value column in struct".into(),
                        location: location!(),
                    })?;

            let binary_array = value_array
                .as_any()
                .downcast_ref::<LargeBinaryArray>()
                .ok_or_else(|| Error::InvalidInput {
                    source: "value is not LargeBinary".into(),
                    location: location!(),
                })?;

            // Convert based on target type using serde deserialization
            let converted_array: Arc<dyn Array> = match target_type {
                DataType::Boolean => {
                    let mut builder =
                        arrow_array::builder::BooleanBuilder::with_capacity(binary_array.len());
                    for i in 0..binary_array.len() {
                        if binary_array.is_null(i) {
                            builder.append_null();
                        } else {
                            let raw_jsonb = jsonb::RawJsonb::new(binary_array.value(i));
                            // Try to deserialize directly to bool
                            match jsonb::from_raw_jsonb::<bool>(&raw_jsonb) {
                                Ok(bool_val) => builder.append_value(bool_val),
                                Err(e) => {
                                    return Err(Error::InvalidInput {
                                        source: format!(
                                            "Failed to deserialize JSONB to bool at index {}: {}",
                                            i, e
                                        )
                                        .into(),
                                        location: location!(),
                                    });
                                }
                            }
                        }
                    }
                    Arc::new(builder.finish())
                }
                DataType::Int64 => {
                    let mut builder =
                        arrow_array::builder::Int64Builder::with_capacity(binary_array.len());
                    for i in 0..binary_array.len() {
                        if binary_array.is_null(i) {
                            builder.append_null();
                        } else {
                            let raw_jsonb = jsonb::RawJsonb::new(binary_array.value(i));
                            // Try to deserialize directly to i64
                            match jsonb::from_raw_jsonb::<i64>(&raw_jsonb) {
                                Ok(int_val) => builder.append_value(int_val),
                                Err(e) => {
                                    return Err(Error::InvalidInput {
                                        source: format!(
                                            "Failed to deserialize JSONB to i64 at index {}: {}",
                                            i, e
                                        )
                                        .into(),
                                        location: location!(),
                                    });
                                }
                            }
                        }
                    }
                    Arc::new(builder.finish())
                }
                DataType::Float64 => {
                    let mut builder =
                        arrow_array::builder::Float64Builder::with_capacity(binary_array.len());
                    for i in 0..binary_array.len() {
                        if binary_array.is_null(i) {
                            builder.append_null();
                        } else {
                            let raw_jsonb = jsonb::RawJsonb::new(binary_array.value(i));
                            // Try to deserialize directly to f64 (serde handles int->float conversion)
                            match jsonb::from_raw_jsonb::<f64>(&raw_jsonb) {
                                Ok(float_val) => builder.append_value(float_val),
                                Err(e) => {
                                    return Err(Error::InvalidInput {
                                        source: format!(
                                            "Failed to deserialize JSONB to f64 at index {}: {}",
                                            i, e
                                        )
                                        .into(),
                                        location: location!(),
                                    });
                                }
                            }
                        }
                    }
                    Arc::new(builder.finish())
                }
                DataType::Utf8 => {
                    let mut builder = arrow_array::builder::StringBuilder::with_capacity(
                        binary_array.len(),
                        1024,
                    );
                    for i in 0..binary_array.len() {
                        if binary_array.is_null(i) {
                            builder.append_null();
                        } else {
                            let raw_jsonb = jsonb::RawJsonb::new(binary_array.value(i));
                            // Try to deserialize to String, or use to_string() for any type
                            match jsonb::from_raw_jsonb::<String>(&raw_jsonb) {
                                Ok(str_val) => builder.append_value(&str_val),
                                Err(_) => {
                                    // For non-string types, convert to string representation
                                    builder.append_value(raw_jsonb.to_string());
                                }
                            }
                        }
                    }
                    Arc::new(builder.finish())
                }
                DataType::LargeBinary => {
                    // Keep as binary for array/object types
                    value_array.clone()
                }
                _ => {
                    return Err(Error::InvalidInput {
                        source: format!("Unsupported target type: {:?}", target_type).into(),
                        location: location!(),
                    });
                }
            };

            // Get row_id column
            let row_id_column = batch
                .column_by_name(ROW_ID)
                .ok_or_else(|| Error::InvalidInput {
                    source: "Missing row_id column".into(),
                    location: location!(),
                })?
                .clone();

            // Create new batch with converted values
            let new_schema = Arc::new(Schema::new(vec![
                ArrowField::new(VALUE_COLUMN_NAME, target_type.clone(), true),
                ArrowField::new(ROW_ID, DataType::UInt64, false),
            ]));

            let new_batch =
                RecordBatch::try_new(new_schema.clone(), vec![converted_array, row_id_column])?;

            converted_batches.push(new_batch);
        }

        // Create stream from converted batches
        let schema = converted_batches
            .first()
            .map(|b| b.schema())
            .ok_or_else(|| Error::InvalidInput {
                source: "No batches to convert".into(),
                location: location!(),
            })?;

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            schema,
            futures::stream::iter(converted_batches.into_iter().map(Ok)),
        )))
    }
}

#[async_trait]
impl ScalarIndexPlugin for JsonIndexPlugin {
    fn name(&self) -> &str {
        "Json"
    }

    fn new_training_request(
        &self,
        params: &str,
        field: &Field,
    ) -> Result<Box<dyn TrainingRequest>> {
        if !matches!(field.data_type(), DataType::Binary | DataType::LargeBinary) {
            return Err(Error::InvalidInput {
                source: "A JSON index can only be created on a Binary or LargeBinary field.".into(),
                location: location!(),
            });
        }

        // Initially use Utf8; this will be refined during training via type inference
        let target_type = DataType::Utf8;

        let params = serde_json::from_str::<JsonIndexParameters>(params)?;
        let registry = self.registry()?;
        let target_plugin = registry.get_plugin_by_name(&params.target_index_type)?;
        let target_request = target_plugin.new_training_request(
            params.target_index_parameters.as_deref().unwrap_or("{}"),
            &Field::new("", target_type, true),
        )?;

        Ok(Box::new(JsonTrainingRequest::new(params, target_request)))
    }

    fn provides_exact_answer(&self) -> bool {
        // TODO: Need to look up the target plugin via the details to figure this out correctly
        true
    }

    fn attach_registry(&self, registry: Arc<IndexPluginRegistry>) {
        let mut reg_ref = self.registry.lock().unwrap();
        *reg_ref = Some(registry);
    }

    fn version(&self) -> u32 {
        JSON_INDEX_VERSION
    }

    fn new_query_parser(
        &self,
        index_name: String,
        index_details: &prost_types::Any,
    ) -> Option<Box<dyn ScalarQueryParser>> {
        // TODO: Allow returning a Result here
        let registry = self.registry().unwrap();
        let json_details =
            crate::pb::JsonIndexDetails::decode(index_details.value.as_slice()).unwrap();
        let target_details = json_details.target_details.as_ref().expect_ok().unwrap();
        let target_plugin = registry.get_plugin_by_details(target_details).unwrap();
        // TODO: Use something like ${index_name}_${path} for the index name?  We don't have
        // access to the path here though
        let target_parser = target_plugin.new_query_parser(index_name, target_details)?;
        Some(Box::new(JsonQueryParser::new(
            json_details.path.clone(),
            target_parser,
        )) as Box<dyn ScalarQueryParser>)
    }

    async fn train_index(
        &self,
        data: SendableRecordBatchStream,
        index_store: &dyn IndexStore,
        request: Box<dyn TrainingRequest>,
        fragment_ids: Option<Vec<u32>>,
    ) -> Result<CreatedIndex> {
        let request = (request as Box<dyn std::any::Any>)
            .downcast::<JsonTrainingRequest>()
            .unwrap();
        let path = request.parameters.path.clone();

        // Extract JSON with type information
        let (data_stream, inferred_type) =
            Self::extract_json_with_type_info(data, path.clone()).await?;

        // Convert the stream to properly typed values based on inferred type
        let converted_stream =
            Self::convert_stream_by_type(data_stream, inferred_type.clone()).await?;

        // Update the target request with inferred type
        let registry = self.registry()?;
        let target_plugin = registry.get_plugin_by_name(&request.parameters.target_index_type)?;

        // Create a new training request with the inferred type
        let target_request = target_plugin.new_training_request(
            request
                .parameters
                .target_index_parameters
                .as_deref()
                .unwrap_or("{}"),
            &Field::new("", inferred_type, true),
        )?;

        let target_index = target_plugin
            .train_index(converted_stream, index_store, target_request, fragment_ids)
            .await?;

        let index_details = crate::pb::JsonIndexDetails {
            path,
            target_details: Some(target_index.index_details),
        };
        Ok(CreatedIndex {
            index_details: prost_types::Any::from_msg(&index_details)?,
            index_version: JSON_INDEX_VERSION,
        })
    }
    async fn load_index(
        &self,
        index_store: Arc<dyn IndexStore>,
        index_details: &prost_types::Any,
        frag_reuse_index: Option<Arc<FragReuseIndex>>,
        cache: &LanceCache,
    ) -> Result<Arc<dyn ScalarIndex>> {
        let registry = self.registry()?;
        let json_details = crate::pb::JsonIndexDetails::decode(index_details.value.as_slice())?;
        let target_details = json_details.target_details.as_ref().expect_ok()?;
        let target_plugin = registry.get_plugin_by_details(target_details).unwrap();
        let target_index = target_plugin
            .load_index(index_store, target_details, frag_reuse_index, cache)
            .await?;
        Ok(Arc::new(JsonIndex::new(target_index, json_details.path)))
    }

    fn details_as_json(&self, details: &prost_types::Any) -> Result<serde_json::Value> {
        let registry = self.registry()?;
        let json_details = crate::pb::JsonIndexDetails::decode(details.value.as_slice())?;
        let target_details = json_details.target_details.as_ref().expect_ok()?;
        let target_plugin = registry.get_plugin_by_details(target_details).unwrap();
        let target_details_json = target_plugin.details_as_json(target_details)?;
        Ok(serde_json::json!({
            "path": json_details.path,
            "target_details": target_details_json,
        }))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{ArrayRef, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    // Note: The old test_detect_json_value_type test has been removed as we now use
    // JSONB's inherent type information instead of string-based type detection

    #[tokio::test]
    async fn test_json_extract_with_type_info() {
        use arrow_array::{LargeBinaryArray, UInt64Array};
        use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
        use futures::stream;

        // Create test JSONB data
        let json_data = vec![
            r#"{"name": "Alice", "age": 30, "active": true}"#,
            r#"{"name": "Bob", "age": 25, "active": false}"#,
            r#"{"name": "Charlie", "age": 35, "active": true}"#,
        ];

        // Convert JSON strings to JSONB binary format
        let mut jsonb_values = Vec::new();
        for json_str in &json_data {
            let owned_jsonb: jsonb::OwnedJsonb = json_str.parse().unwrap();
            jsonb_values.push(Some(owned_jsonb.to_vec()));
        }

        // Create test batch with JSONB data
        let schema = Arc::new(Schema::new(vec![
            Field::new(VALUE_COLUMN_NAME, DataType::LargeBinary, true),
            Field::new(ROW_ID, DataType::UInt64, false),
        ]));

        let jsonb_array = LargeBinaryArray::from(
            jsonb_values
                .iter()
                .map(|v| v.as_deref())
                .collect::<Vec<_>>(),
        );
        let row_ids = UInt64Array::from(vec![1, 2, 3]);

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(jsonb_array) as ArrayRef,
                Arc::new(row_ids) as ArrayRef,
            ],
        )
        .unwrap();

        let stream = Box::pin(RecordBatchStreamAdapter::new(
            schema.clone(),
            stream::iter(vec![Ok(batch)]),
        )) as SendableRecordBatchStream;

        // Test type inference for integer field
        let (_result_stream, inferred_type) =
            JsonIndexPlugin::extract_json_with_type_info(stream, "$.age".to_string())
                .await
                .unwrap();

        assert_eq!(inferred_type, DataType::Int64);

        // Create new test stream for boolean field
        let batch2 = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(LargeBinaryArray::from(vec![
                    json_data[0]
                        .parse::<jsonb::OwnedJsonb>()
                        .ok()
                        .map(|j| j.to_vec())
                        .as_deref(),
                    json_data[1]
                        .parse::<jsonb::OwnedJsonb>()
                        .ok()
                        .map(|j| j.to_vec())
                        .as_deref(),
                    json_data[2]
                        .parse::<jsonb::OwnedJsonb>()
                        .ok()
                        .map(|j| j.to_vec())
                        .as_deref(),
                ])) as ArrayRef,
                Arc::new(UInt64Array::from(vec![1, 2, 3])) as ArrayRef,
            ],
        )
        .unwrap();

        let stream2 = Box::pin(RecordBatchStreamAdapter::new(
            schema.clone(),
            stream::iter(vec![Ok(batch2)]),
        )) as SendableRecordBatchStream;

        // Test type inference for boolean field
        let (_, inferred_type) =
            JsonIndexPlugin::extract_json_with_type_info(stream2, "$.active".to_string())
                .await
                .unwrap();

        assert_eq!(inferred_type, DataType::Boolean);

        // Create test stream for string field
        let batch3 = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(LargeBinaryArray::from(vec![
                    json_data[0]
                        .parse::<jsonb::OwnedJsonb>()
                        .ok()
                        .map(|j| j.to_vec())
                        .as_deref(),
                    json_data[1]
                        .parse::<jsonb::OwnedJsonb>()
                        .ok()
                        .map(|j| j.to_vec())
                        .as_deref(),
                    json_data[2]
                        .parse::<jsonb::OwnedJsonb>()
                        .ok()
                        .map(|j| j.to_vec())
                        .as_deref(),
                ])) as ArrayRef,
                Arc::new(UInt64Array::from(vec![1, 2, 3])) as ArrayRef,
            ],
        )
        .unwrap();

        let stream3 = Box::pin(RecordBatchStreamAdapter::new(
            schema,
            stream::iter(vec![Ok(batch3)]),
        )) as SendableRecordBatchStream;

        // Test type inference for string field
        let (_, inferred_type) =
            JsonIndexPlugin::extract_json_with_type_info(stream3, "$.name".to_string())
                .await
                .unwrap();

        assert_eq!(inferred_type, DataType::Utf8);
    }
}