1use std::{ops::Range, sync::Arc};
5
6use arrow_array::{Array, ArrayRef, ListArray, MapArray};
7use arrow_schema::DataType;
8use futures::future::BoxFuture;
9use lance_arrow::deepcopy::deep_copy_nulls;
10use lance_arrow::list::ListArrayExt;
11use lance_core::{Error, Result};
12
13use crate::{
14 decoder::{
15 DecodedArray, FilterExpression, ScheduledScanLine, SchedulerContext,
16 StructuralDecodeArrayTask, StructuralFieldDecoder, StructuralFieldScheduler,
17 StructuralSchedulingJob,
18 },
19 encoder::{EncodeTask, FieldEncoder, OutOfLineBuffers},
20 repdef::RepDefBuilder,
21};
22
/// Field encoder for Arrow map arrays.
///
/// The map's offsets and validity are recorded in the repetition/definition
/// builder, and the map's entries are forwarded to `child` for encoding.
pub struct MapStructuralEncoder {
    /// When `true`, the incoming array's null buffer is cloned as-is when the
    /// offsets are recorded; when `false`, it is passed through
    /// `deep_copy_nulls` first.
    keep_original_array: bool,
    /// Encoder that receives the map's entries.
    child: Box<dyn FieldEncoder>,
}
32
33impl MapStructuralEncoder {
34 pub fn new(keep_original_array: bool, child: Box<dyn FieldEncoder>) -> Self {
35 Self {
36 keep_original_array,
37 child,
38 }
39 }
40}
41
42impl FieldEncoder for MapStructuralEncoder {
43 fn maybe_encode(
44 &mut self,
45 array: ArrayRef,
46 external_buffers: &mut OutOfLineBuffers,
47 mut repdef: RepDefBuilder,
48 row_number: u64,
49 num_rows: u64,
50 ) -> Result<Vec<EncodeTask>> {
51 let map_array = array
52 .as_any()
53 .downcast_ref::<MapArray>()
54 .expect("MapEncoder used for non-map data");
55
56 let has_garbage_values = if self.keep_original_array {
58 repdef.add_offsets(map_array.offsets().clone(), array.nulls().cloned())
59 } else {
60 repdef.add_offsets(map_array.offsets().clone(), deep_copy_nulls(array.nulls()))
61 };
62
63 let list_array: ListArray = map_array.clone().into();
65 let entries = if has_garbage_values {
66 list_array.filter_garbage_nulls().trimmed_values()
67 } else {
68 list_array.trimmed_values()
69 };
70
71 self.child
72 .maybe_encode(entries, external_buffers, repdef, row_number, num_rows)
73 }
74
75 fn flush(&mut self, external_buffers: &mut OutOfLineBuffers) -> Result<Vec<EncodeTask>> {
76 self.child.flush(external_buffers)
77 }
78
79 fn num_columns(&self) -> u32 {
80 self.child.num_columns()
81 }
82
83 fn finish(
84 &mut self,
85 external_buffers: &mut OutOfLineBuffers,
86 ) -> BoxFuture<'_, Result<Vec<crate::encoder::EncodedColumn>>> {
87 self.child.finish(external_buffers)
88 }
89}
90
/// Structural scheduler for map fields.
///
/// All scheduling and initialization work is forwarded to `child`.
#[derive(Debug)]
pub struct StructuralMapScheduler {
    /// Scheduler that every call is delegated to.
    child: Box<dyn StructuralFieldScheduler>,
}
95
96impl StructuralMapScheduler {
97 pub fn new(child: Box<dyn StructuralFieldScheduler>) -> Self {
98 Self { child }
99 }
100}
101
102impl StructuralFieldScheduler for StructuralMapScheduler {
103 fn schedule_ranges<'a>(
104 &'a self,
105 ranges: &[Range<u64>],
106 filter: &FilterExpression,
107 ) -> Result<Box<dyn StructuralSchedulingJob + 'a>> {
108 let child = self.child.schedule_ranges(ranges, filter)?;
109
110 Ok(Box::new(StructuralMapSchedulingJob::new(child)))
111 }
112
113 fn initialize<'a>(
114 &'a mut self,
115 filter: &'a FilterExpression,
116 context: &'a SchedulerContext,
117 ) -> BoxFuture<'a, Result<()>> {
118 self.child.initialize(filter, context)
119 }
120}
121
/// Scheduling job for a map field.
///
/// It forwards `schedule_next` to the job produced by the child scheduler.
#[derive(Debug)]
struct StructuralMapSchedulingJob<'a> {
    /// Job obtained from the child scheduler.
    child: Box<dyn StructuralSchedulingJob + 'a>,
}
130
131impl<'a> StructuralMapSchedulingJob<'a> {
132 fn new(child: Box<dyn StructuralSchedulingJob + 'a>) -> Self {
133 Self { child }
134 }
135}
136
137impl StructuralSchedulingJob for StructuralMapSchedulingJob<'_> {
138 fn schedule_next(&mut self, context: &mut SchedulerContext) -> Result<Vec<ScheduledScanLine>> {
139 self.child.schedule_next(context)
140 }
141}
142
/// Structural decoder for map fields.
///
/// Pages are passed to `child`; draining produces a task that rebuilds a
/// `MapArray` of type `data_type` from the child's decoded output.
#[derive(Debug)]
pub struct StructuralMapDecoder {
    /// Decoder that accepts the pages and produces the child decode task.
    child: Box<dyn StructuralFieldDecoder>,
    /// Data type reported by this decoder and handed to each decode task.
    data_type: DataType,
}
148
149impl StructuralMapDecoder {
150 pub fn new(child: Box<dyn StructuralFieldDecoder>, data_type: DataType) -> Self {
151 Self { child, data_type }
152 }
153}
154
155impl StructuralFieldDecoder for StructuralMapDecoder {
156 fn accept_page(&mut self, child: crate::decoder::LoadedPageShard) -> Result<()> {
157 self.child.accept_page(child)
158 }
159
160 fn drain(&mut self, num_rows: u64) -> Result<Box<dyn StructuralDecodeArrayTask>> {
161 let child_task = self.child.drain(num_rows)?;
162 Ok(Box::new(StructuralMapDecodeTask::new(
163 child_task,
164 self.data_type.clone(),
165 )))
166 }
167
168 fn data_type(&self) -> &DataType {
169 &self.data_type
170 }
171}
172
/// Decode task that turns the child's decoded array into a `MapArray`.
#[derive(Debug)]
struct StructuralMapDecodeTask {
    /// Task whose decoded array is used as the map's entries.
    child_task: Box<dyn StructuralDecodeArrayTask>,
    /// Expected to be `DataType::Map`; decoding returns an error otherwise.
    data_type: DataType,
}
178
179impl StructuralMapDecodeTask {
180 fn new(child_task: Box<dyn StructuralDecodeArrayTask>, data_type: DataType) -> Self {
181 Self {
182 child_task,
183 data_type,
184 }
185 }
186}
187
188impl StructuralDecodeArrayTask for StructuralMapDecodeTask {
189 fn decode(self: Box<Self>) -> Result<DecodedArray> {
190 let DecodedArray { array, mut repdef } = self.child_task.decode()?;
191
192 let (offsets, validity) = repdef.unravel_offsets::<i32>()?;
194
195 let (entries_field, keys_sorted) = match &self.data_type {
197 DataType::Map(field, keys_sorted) => {
198 if *keys_sorted {
199 return Err(Error::not_supported_source(
200 "Map type decoder does not support keys_sorted=true now"
201 .to_string()
202 .into(),
203 ));
204 }
205 (field.clone(), *keys_sorted)
206 }
207 _ => {
208 return Err(Error::schema(
209 "Map decoder did not have a map field".to_string(),
210 ));
211 }
212 };
213
214 let entries = array
216 .as_any()
217 .downcast_ref::<arrow_array::StructArray>()
218 .ok_or_else(|| Error::schema("Map entries should be a StructArray".to_string()))?
219 .clone();
220
221 let map_array = MapArray::new(entries_field, offsets, entries, validity, keys_sorted);
223
224 Ok(DecodedArray {
225 array: Arc::new(map_array),
226 repdef,
227 })
228 }
229}
230
231#[cfg(test)]
232mod tests {
233 use std::{collections::HashMap, sync::Arc};
234
235 use arrow_array::{
236 Array, Int32Array, MapArray, StringArray, StructArray,
237 builder::{Int32Builder, MapBuilder, StringBuilder},
238 };
239 use arrow_buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
240 use arrow_schema::{DataType, Field, Fields};
241
242 use crate::encoder::{ColumnIndexSequence, EncodingOptions, default_encoding_strategy};
243 use crate::{
244 testing::{TestCases, check_round_trip_encoding_of_data},
245 version::LanceFileVersion,
246 };
247 use arrow_schema::Field as ArrowField;
248 use lance_core::datatypes::Field as LanceField;
249
250 fn make_map_type(key_type: DataType, value_type: DataType) -> DataType {
251 let entries = Field::new(
253 "entries",
254 DataType::Struct(Fields::from(vec![
255 Field::new("keys", key_type, false),
256 Field::new("values", value_type, true),
257 ])),
258 false,
259 );
260 DataType::Map(Arc::new(entries), false)
261 }
262
263 #[test_log::test(tokio::test)]
264 async fn test_simple_map() {
265 let string_builder = StringBuilder::new();
267 let int_builder = Int32Builder::new();
268 let mut map_builder = MapBuilder::new(None, string_builder, int_builder);
269
270 map_builder.keys().append_value("key1");
272 map_builder.values().append_value(10);
273 map_builder.keys().append_value("key2");
274 map_builder.values().append_value(20);
275 map_builder.append(true).unwrap();
276
277 map_builder.keys().append_value("key3");
279 map_builder.values().append_value(30);
280 map_builder.append(true).unwrap();
281
282 let map_array = map_builder.finish();
283
284 let test_cases = TestCases::default()
285 .with_range(0..2)
286 .with_min_file_version(LanceFileVersion::V2_2);
287
288 check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new())
289 .await;
290 }
291
292 #[test_log::test(tokio::test)]
293 async fn test_empty_maps() {
294 let string_builder = StringBuilder::new();
296 let int_builder = Int32Builder::new();
297 let mut map_builder = MapBuilder::new(None, string_builder, int_builder);
298
299 map_builder.keys().append_value("a");
301 map_builder.values().append_value(1);
302 map_builder.append(true).unwrap();
303
304 map_builder.append(true).unwrap();
306
307 map_builder.append(false).unwrap();
309
310 map_builder.append(true).unwrap();
312
313 let map_array = map_builder.finish();
314
315 let test_cases = TestCases::default()
316 .with_range(0..4)
317 .with_indices(vec![1])
318 .with_indices(vec![2])
319 .with_min_file_version(LanceFileVersion::V2_2);
320
321 check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new())
322 .await;
323 }
324
325 #[test_log::test(tokio::test)]
326 async fn test_map_with_null_values() {
327 let string_builder = StringBuilder::new();
329 let int_builder = Int32Builder::new();
330 let mut map_builder = MapBuilder::new(None, string_builder, int_builder);
331
332 map_builder.keys().append_value("key1");
334 map_builder.values().append_value(10);
335 map_builder.keys().append_value("key2");
336 map_builder.values().append_null();
337 map_builder.append(true).unwrap();
338
339 map_builder.keys().append_value("key3");
341 map_builder.values().append_null();
342 map_builder.append(true).unwrap();
343
344 let map_array = map_builder.finish();
345
346 let test_cases = TestCases::default()
347 .with_range(0..2)
348 .with_indices(vec![0])
349 .with_indices(vec![1])
350 .with_min_file_version(LanceFileVersion::V2_2);
351
352 check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new())
353 .await;
354 }
355
356 #[test_log::test(tokio::test)]
357 async fn test_map_in_struct() {
358 let string_key_builder = StringBuilder::new();
362 let string_val_builder = StringBuilder::new();
363 let mut map_builder = MapBuilder::new(None, string_key_builder, string_val_builder);
364
365 map_builder.keys().append_value("name");
367 map_builder.values().append_value("Alice");
368 map_builder.keys().append_value("city");
369 map_builder.values().append_value("NYC");
370 map_builder.append(true).unwrap();
371
372 map_builder.keys().append_value("name");
374 map_builder.values().append_value("Bob");
375 map_builder.append(true).unwrap();
376
377 map_builder.append(false).unwrap();
379
380 let map_array = Arc::new(map_builder.finish());
381 let id_array = Arc::new(Int32Array::from(vec![1, 2, 3]));
382
383 let struct_array = StructArray::new(
384 Fields::from(vec![
385 Field::new("id", DataType::Int32, false),
386 Field::new(
387 "properties",
388 make_map_type(DataType::Utf8, DataType::Utf8),
389 true,
390 ),
391 ]),
392 vec![id_array, map_array],
393 None,
394 );
395
396 let test_cases = TestCases::default()
397 .with_range(0..3)
398 .with_indices(vec![0, 2])
399 .with_min_file_version(LanceFileVersion::V2_2);
400
401 check_round_trip_encoding_of_data(
402 vec![Arc::new(struct_array)],
403 &test_cases,
404 HashMap::new(),
405 )
406 .await;
407 }
408
409 #[test_log::test(tokio::test)]
410 async fn test_map_in_nullable_struct() {
411 let entries_fields = Fields::from(vec![
414 Field::new("keys", DataType::Utf8, false),
415 Field::new("values", DataType::Int32, true),
416 ]);
417 let entries_field = Arc::new(Field::new(
418 "entries",
419 DataType::Struct(entries_fields.clone()),
420 false,
421 ));
422 let map_entries = StructArray::new(
423 entries_fields,
424 vec![
425 Arc::new(StringArray::from(vec!["a", "garbage", "b"])),
426 Arc::new(Int32Array::from(vec![1, 999, 2])),
427 ],
428 None,
429 );
430 let map_array: Arc<dyn Array> = Arc::new(MapArray::new(
432 entries_field,
433 OffsetBuffer::new(ScalarBuffer::from(vec![0, 1, 2, 3])),
434 map_entries,
435 None, false,
437 ));
438
439 let struct_array = StructArray::new(
440 Fields::from(vec![
441 Field::new("id", DataType::Int32, true),
442 Field::new("props", map_array.data_type().clone(), true),
443 ]),
444 vec![
445 Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])),
446 map_array,
447 ],
448 Some(NullBuffer::from(vec![true, false, true])), );
450
451 let test_cases = TestCases::default()
452 .with_range(0..3)
453 .with_min_file_version(LanceFileVersion::V2_2);
454
455 check_round_trip_encoding_of_data(
456 vec![Arc::new(struct_array)],
457 &test_cases,
458 HashMap::new(),
459 )
460 .await;
461 }
462
463 #[test_log::test(tokio::test)]
464 async fn test_list_of_maps() {
465 use arrow_array::builder::ListBuilder;
467
468 let string_builder = StringBuilder::new();
469 let int_builder = Int32Builder::new();
470 let map_builder = MapBuilder::new(None, string_builder, int_builder);
471 let mut list_builder = ListBuilder::new(map_builder);
472
473 list_builder.values().keys().append_value("a");
475 list_builder.values().values().append_value(1);
476 list_builder.values().append(true).unwrap();
477
478 list_builder.values().keys().append_value("b");
479 list_builder.values().values().append_value(2);
480 list_builder.values().append(true).unwrap();
481
482 list_builder.append(true);
483
484 list_builder.values().keys().append_value("c");
486 list_builder.values().values().append_value(3);
487 list_builder.values().append(true).unwrap();
488
489 list_builder.append(true);
490
491 list_builder.append(true);
493
494 let list_array = list_builder.finish();
495
496 let test_cases = TestCases::default()
497 .with_range(0..3)
498 .with_indices(vec![0, 2])
499 .with_min_file_version(LanceFileVersion::V2_2);
500
501 check_round_trip_encoding_of_data(vec![Arc::new(list_array)], &test_cases, HashMap::new())
502 .await;
503 }
504
505 #[test_log::test(tokio::test)]
506 async fn test_nested_map() {
507 let inner_string_builder = StringBuilder::new();
512 let inner_int_builder = Int32Builder::new();
513 let mut inner_map_builder1 = MapBuilder::new(None, inner_string_builder, inner_int_builder);
514
515 inner_map_builder1.keys().append_value("x");
517 inner_map_builder1.values().append_value(10);
518 inner_map_builder1.append(true).unwrap();
519
520 inner_map_builder1.keys().append_value("y");
522 inner_map_builder1.values().append_value(20);
523 inner_map_builder1.keys().append_value("z");
524 inner_map_builder1.values().append_value(30);
525 inner_map_builder1.append(true).unwrap();
526
527 let inner_maps = Arc::new(inner_map_builder1.finish());
528
529 let outer_keys = Arc::new(StringArray::from(vec!["key1", "key2"]));
531
532 let entries_struct = StructArray::new(
534 Fields::from(vec![
535 Field::new("key", DataType::Utf8, false),
536 Field::new(
537 "value",
538 make_map_type(DataType::Utf8, DataType::Int32),
539 true,
540 ),
541 ]),
542 vec![outer_keys, inner_maps],
543 None,
544 );
545
546 let offsets = OffsetBuffer::new(ScalarBuffer::<i32>::from(vec![0, 2]));
547 let entries_field = Field::new("entries", entries_struct.data_type().clone(), false);
548
549 let outer_map = MapArray::new(
550 Arc::new(entries_field),
551 offsets,
552 entries_struct,
553 None,
554 false,
555 );
556
557 let test_cases = TestCases::default()
558 .with_range(0..1)
559 .with_min_file_version(LanceFileVersion::V2_2);
560
561 check_round_trip_encoding_of_data(vec![Arc::new(outer_map)], &test_cases, HashMap::new())
562 .await;
563 }
564
565 #[test_log::test(tokio::test)]
566 async fn test_map_different_key_types() {
567 let int_builder = Int32Builder::new();
569 let string_builder = StringBuilder::new();
570 let mut map_builder = MapBuilder::new(None, int_builder, string_builder);
571
572 map_builder.keys().append_value(1);
574 map_builder.values().append_value("one");
575 map_builder.keys().append_value(2);
576 map_builder.values().append_value("two");
577 map_builder.append(true).unwrap();
578
579 map_builder.keys().append_value(3);
581 map_builder.values().append_value("three");
582 map_builder.append(true).unwrap();
583
584 let map_array = map_builder.finish();
585
586 let test_cases = TestCases::default()
587 .with_range(0..2)
588 .with_indices(vec![0, 1])
589 .with_min_file_version(LanceFileVersion::V2_2);
590
591 check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new())
592 .await;
593 }
594
595 #[test_log::test(tokio::test)]
596 async fn test_map_with_extreme_sizes() {
597 let string_builder = StringBuilder::new();
599 let int_builder = Int32Builder::new();
600 let mut map_builder = MapBuilder::new(None, string_builder, int_builder);
601
602 for i in 0..100 {
604 map_builder.keys().append_value(format!("key{}", i));
605 map_builder.values().append_value(i);
606 }
607 map_builder.append(true).unwrap();
608
609 map_builder.append(true).unwrap();
611
612 let map_array = map_builder.finish();
613
614 let test_cases = TestCases::default()
615 .with_range(0..2)
616 .with_min_file_version(LanceFileVersion::V2_2);
617
618 check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new())
619 .await;
620 }
621
622 #[test_log::test(tokio::test)]
623 async fn test_map_all_null() {
624 let string_builder = StringBuilder::new();
626 let int_builder = Int32Builder::new();
627 let mut map_builder = MapBuilder::new(None, string_builder, int_builder);
628
629 map_builder.append(false).unwrap(); map_builder.append(false).unwrap(); let map_array = map_builder.finish();
634
635 let test_cases = TestCases::default()
636 .with_range(0..2)
637 .with_min_file_version(LanceFileVersion::V2_2);
638
639 check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new())
640 .await;
641 }
642
643 #[test_log::test(tokio::test)]
644 async fn test_map_encoder_keep_original_array_scenarios() {
645 let string_builder = StringBuilder::new();
648 let int_builder = Int32Builder::new();
649 let mut map_builder = MapBuilder::new(None, string_builder, int_builder);
650
651 map_builder.keys().append_value("key1");
654 map_builder.values().append_value(10);
655 map_builder.keys().append_value("key2");
656 map_builder.values().append_null();
657 map_builder.append(true).unwrap();
658
659 map_builder.append(false).unwrap();
661
662 map_builder.keys().append_value("key3");
664 map_builder.values().append_value(30);
665 map_builder.append(true).unwrap();
666
667 let map_array = map_builder.finish();
668
669 let test_cases = TestCases::default()
670 .with_range(0..3)
671 .with_indices(vec![0, 1, 2])
672 .with_min_file_version(LanceFileVersion::V2_2);
673
674 check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new())
677 .await;
678 }
679
680 #[test]
681 fn test_map_not_supported_write_in_v2_1() {
682 let map_arrow_field = ArrowField::new(
684 "map_field",
685 make_map_type(DataType::Utf8, DataType::Int32),
686 true,
687 );
688 let map_field = LanceField::try_from(&map_arrow_field).unwrap();
689
690 let encoder_strategy = default_encoding_strategy(LanceFileVersion::V2_1);
692 let mut column_index = ColumnIndexSequence::default();
693 let options = EncodingOptions::default();
694
695 let encoder_result = encoder_strategy.create_field_encoder(
696 encoder_strategy.as_ref(),
697 &map_field,
698 &mut column_index,
699 &options,
700 );
701
702 assert!(
703 encoder_result.is_err(),
704 "Map type should not be supported in V2_1 for encoder"
705 );
706 let Err(encoder_err) = encoder_result else {
707 panic!("Expected error but got Ok")
708 };
709
710 let encoder_err_msg = format!("{}", encoder_err);
711 assert!(
712 encoder_err_msg.contains("2.2"),
713 "Encoder error message should mention version 2.2, got: {}",
714 encoder_err_msg
715 );
716 assert!(
717 encoder_err_msg.contains("Map data type"),
718 "Encoder error message should mention Map data type, got: {}",
719 encoder_err_msg
720 );
721 }
722}