use std::{ops::Range, sync::Arc};
use arrow_array::{Array, ArrayRef, ListArray, MapArray};
use arrow_schema::DataType;
use futures::future::BoxFuture;
use lance_arrow::deepcopy::deep_copy_nulls;
use lance_arrow::list::ListArrayExt;
use lance_core::{Error, Result};
use crate::{
decoder::{
DecodedArray, FilterExpression, ScheduledScanLine, SchedulerContext,
StructuralDecodeArrayTask, StructuralFieldDecoder, StructuralFieldScheduler,
StructuralSchedulingJob,
},
encoder::{EncodeTask, FieldEncoder, OutOfLineBuffers},
repdef::RepDefBuilder,
};
/// Field encoder for Arrow `MapArray` columns.
///
/// The encoder itself writes no data: `maybe_encode` records the map offsets and validity
/// in the `RepDefBuilder` and hands the map entries to `child`.
pub struct MapStructuralEncoder {
    /// Chooses how the incoming validity buffer is passed to the repdef builder:
    /// cloned as-is when `true`, routed through `deep_copy_nulls` when `false`.
    keep_original_array: bool,
    /// Encoder that receives the map entries.
    child: Box<dyn FieldEncoder>,
}
impl MapStructuralEncoder {
pub fn new(keep_original_array: bool, child: Box<dyn FieldEncoder>) -> Self {
Self {
keep_original_array,
child,
}
}
}
impl FieldEncoder for MapStructuralEncoder {
    /// Registers the map's offsets/validity with `repdef` and forwards the entries to the
    /// child encoder.
    ///
    /// # Panics
    ///
    /// Panics if `array` is not a `MapArray`.
    fn maybe_encode(
        &mut self,
        array: ArrayRef,
        external_buffers: &mut OutOfLineBuffers,
        mut repdef: RepDefBuilder,
        row_number: u64,
        num_rows: u64,
    ) -> Result<Vec<EncodeTask>> {
        let maps = array
            .as_any()
            .downcast_ref::<MapArray>()
            .expect("MapEncoder used for non-map data");

        // Either reuse the caller's validity buffer or take a deep copy of it.
        let validity = if self.keep_original_array {
            array.nulls().cloned()
        } else {
            deep_copy_nulls(array.nulls())
        };
        // The returned flag is treated below as "the list still contains garbage values
        // that must be filtered out" -- presumably entries sitting under null maps.
        let has_garbage_values = repdef.add_offsets(maps.offsets().clone(), validity);

        // View the map as a list so the `ListArrayExt` helpers can be applied.
        let as_list: ListArray = maps.clone().into();
        let entries = if has_garbage_values {
            as_list.filter_garbage_nulls().trimmed_values()
        } else {
            as_list.trimmed_values()
        };

        self.child
            .maybe_encode(entries, external_buffers, repdef, row_number, num_rows)
    }

    /// Flushes the child encoder; this encoder keeps no buffered data of its own.
    fn flush(&mut self, external_buffers: &mut OutOfLineBuffers) -> Result<Vec<EncodeTask>> {
        self.child.flush(external_buffers)
    }

    /// Reports the column count of the child encoder.
    fn num_columns(&self) -> u32 {
        self.child.num_columns()
    }

    /// Finishes the child encoder and returns its encoded columns.
    fn finish(
        &mut self,
        external_buffers: &mut OutOfLineBuffers,
    ) -> BoxFuture<'_, Result<Vec<crate::encoder::EncodedColumn>>> {
        self.child.finish(external_buffers)
    }
}
/// Scheduler for map fields.
///
/// It only wraps the scheduler of the child field; see the `StructuralFieldScheduler`
/// impl, which forwards every call to `child`.
#[derive(Debug)]
pub struct StructuralMapScheduler {
    /// Scheduler of the wrapped child field.
    child: Box<dyn StructuralFieldScheduler>,
}
impl StructuralMapScheduler {
pub fn new(child: Box<dyn StructuralFieldScheduler>) -> Self {
Self { child }
}
}
impl StructuralFieldScheduler for StructuralMapScheduler {
    /// Schedules `ranges` on the child scheduler and wraps the resulting job in a
    /// `StructuralMapSchedulingJob`.
    fn schedule_ranges<'a>(
        &'a self,
        ranges: &[Range<u64>],
        filter: &FilterExpression,
    ) -> Result<Box<dyn StructuralSchedulingJob + 'a>> {
        let child_job = self.child.schedule_ranges(ranges, filter)?;
        let map_job = StructuralMapSchedulingJob::new(child_job);
        Ok(Box::new(map_job))
    }

    /// Initializes the child scheduler; this wrapper has no state of its own to set up.
    fn initialize<'a>(
        &'a mut self,
        filter: &'a FilterExpression,
        context: &'a SchedulerContext,
    ) -> BoxFuture<'a, Result<()>> {
        self.child.initialize(filter, context)
    }
}
/// Scheduling job produced by `StructuralMapScheduler::schedule_ranges`.
///
/// It holds the job created by the child scheduler and forwards `schedule_next` to it.
#[derive(Debug)]
struct StructuralMapSchedulingJob<'a> {
    /// Job returned by the child scheduler.
    child: Box<dyn StructuralSchedulingJob + 'a>,
}
impl<'a> StructuralMapSchedulingJob<'a> {
fn new(child: Box<dyn StructuralSchedulingJob + 'a>) -> Self {
Self { child }
}
}
impl StructuralSchedulingJob for StructuralMapSchedulingJob<'_> {
    /// Returns whatever the child job schedules next; nothing is added or removed here.
    fn schedule_next(&mut self, context: &mut SchedulerContext) -> Result<Vec<ScheduledScanLine>> {
        let Self { child } = self;
        child.schedule_next(context)
    }
}
/// Decoder for map fields.
///
/// Pages are passed straight to `child`; `drain` pairs the child's decode task with
/// `data_type` in a `StructuralMapDecodeTask`.
#[derive(Debug)]
pub struct StructuralMapDecoder {
    /// Decoder of the wrapped child field.
    child: Box<dyn StructuralFieldDecoder>,
    /// Data type reported by `data_type()` and handed to every decode task.
    data_type: DataType,
}
impl StructuralMapDecoder {
pub fn new(child: Box<dyn StructuralFieldDecoder>, data_type: DataType) -> Self {
Self { child, data_type }
}
}
impl StructuralFieldDecoder for StructuralMapDecoder {
    /// Hands the loaded page to the child decoder.
    fn accept_page(&mut self, child: crate::decoder::LoadedPageShard) -> Result<()> {
        self.child.accept_page(child)
    }

    /// Drains `num_rows` rows from the child decoder and wraps the resulting task so that
    /// its output is turned back into a map when decoded.
    fn drain(&mut self, num_rows: u64) -> Result<Box<dyn StructuralDecodeArrayTask>> {
        let entries_task = self.child.drain(num_rows)?;
        let map_task = StructuralMapDecodeTask::new(entries_task, self.data_type.clone());
        Ok(Box::new(map_task))
    }

    /// Data type this decoder was constructed with.
    fn data_type(&self) -> &DataType {
        &self.data_type
    }
}
/// Decode task created by `StructuralMapDecoder::drain`.
///
/// `decode` runs `child_task` and rebuilds a `MapArray` of type `data_type` from its output.
#[derive(Debug)]
struct StructuralMapDecodeTask {
    /// Task that decodes the map entries.
    child_task: Box<dyn StructuralDecodeArrayTask>,
    /// Target type; `decode` requires this to be a `DataType::Map`.
    data_type: DataType,
}
impl StructuralMapDecodeTask {
fn new(child_task: Box<dyn StructuralDecodeArrayTask>, data_type: DataType) -> Self {
Self {
child_task,
data_type,
}
}
}
impl StructuralDecodeArrayTask for StructuralMapDecodeTask {
fn decode(self: Box<Self>) -> Result<DecodedArray> {
let DecodedArray { array, mut repdef } = self.child_task.decode()?;
let (offsets, validity) = repdef.unravel_offsets::<i32>()?;
let (entries_field, keys_sorted) = match &self.data_type {
DataType::Map(field, keys_sorted) => {
if *keys_sorted {
return Err(Error::not_supported_source(
"Map type decoder does not support keys_sorted=true now"
.to_string()
.into(),
));
}
(field.clone(), *keys_sorted)
}
_ => {
return Err(Error::schema(
"Map decoder did not have a map field".to_string(),
));
}
};
let entries = array
.as_any()
.downcast_ref::<arrow_array::StructArray>()
.ok_or_else(|| Error::schema("Map entries should be a StructArray".to_string()))?
.clone();
let map_array = MapArray::new(entries_field, offsets, entries, validity, keys_sorted);
Ok(DecodedArray {
array: Arc::new(map_array),
repdef,
})
}
}
#[cfg(test)]
mod tests {
use std::{collections::HashMap, sync::Arc};
use arrow_array::{
Array, Int32Array, MapArray, StringArray, StructArray,
builder::{Int32Builder, MapBuilder, StringBuilder},
};
use arrow_buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
use arrow_schema::{DataType, Field, Fields};
use crate::encoder::{ColumnIndexSequence, EncodingOptions, default_encoding_strategy};
use crate::{
testing::{TestCases, check_round_trip_encoding_of_data},
version::LanceFileVersion,
};
use arrow_schema::Field as ArrowField;
use lance_core::datatypes::Field as LanceField;
/// Builds an unsorted `DataType::Map` whose non-nullable "entries" struct has a
/// non-nullable "keys" field of `key_type` and a nullable "values" field of `value_type`.
fn make_map_type(key_type: DataType, value_type: DataType) -> DataType {
    let key_field = Field::new("keys", key_type, false);
    let value_field = Field::new("values", value_type, true);
    let entries_type = DataType::Struct(Fields::from(vec![key_field, value_field]));
    let entries_field = Field::new("entries", entries_type, false);
    DataType::Map(Arc::new(entries_field), false)
}
/// Round-trips two non-null maps with string keys and non-null int values.
#[test_log::test(tokio::test)]
async fn test_simple_map() {
    let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());

    // Row 0: {"key1": 10, "key2": 20}
    for (key, value) in [("key1", 10), ("key2", 20)] {
        builder.keys().append_value(key);
        builder.values().append_value(value);
    }
    builder.append(true).unwrap();

    // Row 1: {"key3": 30}
    builder.keys().append_value("key3");
    builder.values().append_value(30);
    builder.append(true).unwrap();

    let map_array = builder.finish();

    let test_cases = TestCases::default()
        .with_range(0..2)
        .with_min_file_version(LanceFileVersion::V2_2);
    check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new())
        .await;
}
/// Round-trips a mix of populated, empty and null maps.
#[test_log::test(tokio::test)]
async fn test_empty_maps() {
    let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());

    // Row 0: {"a": 1}
    builder.keys().append_value("a");
    builder.values().append_value(1);
    builder.append(true).unwrap();

    // Rows 1-3 carry no entries: empty map, null map, empty map.
    for is_valid in [true, false, true] {
        builder.append(is_valid).unwrap();
    }

    let map_array = builder.finish();

    let test_cases = TestCases::default()
        .with_range(0..4)
        .with_indices(vec![1])
        .with_indices(vec![2])
        .with_min_file_version(LanceFileVersion::V2_2);
    check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new())
        .await;
}
/// Round-trips maps whose values (not the maps themselves) are partly null.
#[test_log::test(tokio::test)]
async fn test_map_with_null_values() {
    let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());

    // Row 0: {"key1": 10, "key2": null}
    builder.keys().append_value("key1");
    builder.values().append_value(10);
    builder.keys().append_value("key2");
    builder.values().append_null();
    builder.append(true).unwrap();

    // Row 1: {"key3": null}
    builder.keys().append_value("key3");
    builder.values().append_null();
    builder.append(true).unwrap();

    let map_array = builder.finish();

    let test_cases = TestCases::default()
        .with_range(0..2)
        .with_indices(vec![0])
        .with_indices(vec![1])
        .with_min_file_version(LanceFileVersion::V2_2);
    check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new())
        .await;
}
/// Round-trips a struct with an `id` column and a nullable string-to-string map column.
#[test_log::test(tokio::test)]
async fn test_map_in_struct() {
    let mut props = MapBuilder::new(None, StringBuilder::new(), StringBuilder::new());

    // Row 0: {"name": "Alice", "city": "NYC"}
    for (key, value) in [("name", "Alice"), ("city", "NYC")] {
        props.keys().append_value(key);
        props.values().append_value(value);
    }
    props.append(true).unwrap();

    // Row 1: {"name": "Bob"}
    props.keys().append_value("name");
    props.values().append_value("Bob");
    props.append(true).unwrap();

    // Row 2: null map
    props.append(false).unwrap();

    let map_array = Arc::new(props.finish());
    let id_array = Arc::new(Int32Array::from(vec![1, 2, 3]));

    let struct_fields = Fields::from(vec![
        Field::new("id", DataType::Int32, false),
        Field::new(
            "properties",
            make_map_type(DataType::Utf8, DataType::Utf8),
            true,
        ),
    ]);
    let struct_array = StructArray::new(struct_fields, vec![id_array, map_array], None);

    let test_cases = TestCases::default()
        .with_range(0..3)
        .with_indices(vec![0, 2])
        .with_min_file_version(LanceFileVersion::V2_2);
    check_round_trip_encoding_of_data(
        vec![Arc::new(struct_array)],
        &test_cases,
        HashMap::new(),
    )
    .await;
}
/// Round-trips a nullable struct whose null row still has a populated map underneath.
#[test_log::test(tokio::test)]
async fn test_map_in_nullable_struct() {
    let kv_fields = Fields::from(vec![
        Field::new("keys", DataType::Utf8, false),
        Field::new("values", DataType::Int32, true),
    ]);
    let entries_field = Arc::new(Field::new(
        "entries",
        DataType::Struct(kv_fields.clone()),
        false,
    ));

    // The middle entry ("garbage" -> 999) belongs to the row the struct marks as null.
    let keys = Arc::new(StringArray::from(vec!["a", "garbage", "b"]));
    let values = Arc::new(Int32Array::from(vec![1, 999, 2]));
    let map_entries = StructArray::new(kv_fields, vec![keys, values], None);

    // One entry per row; the map itself has no validity buffer.
    let map_array: Arc<dyn Array> = Arc::new(MapArray::new(
        entries_field,
        OffsetBuffer::new(ScalarBuffer::from(vec![0, 1, 2, 3])),
        map_entries,
        None,
        false,
    ));

    let struct_fields = Fields::from(vec![
        Field::new("id", DataType::Int32, true),
        Field::new("props", map_array.data_type().clone(), true),
    ]);
    // Row 1 of the struct is null.
    let struct_validity = NullBuffer::from(vec![true, false, true]);
    let struct_array = StructArray::new(
        struct_fields,
        vec![
            Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])),
            map_array,
        ],
        Some(struct_validity),
    );

    let test_cases = TestCases::default()
        .with_range(0..3)
        .with_min_file_version(LanceFileVersion::V2_2);
    check_round_trip_encoding_of_data(
        vec![Arc::new(struct_array)],
        &test_cases,
        HashMap::new(),
    )
    .await;
}
/// Round-trips a list column whose items are maps, including an empty list.
#[test_log::test(tokio::test)]
async fn test_list_of_maps() {
    use arrow_array::builder::ListBuilder;

    let maps = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());
    let mut lists = ListBuilder::new(maps);

    // Row 0: [{"a": 1}, {"b": 2}]
    for (key, value) in [("a", 1), ("b", 2)] {
        lists.values().keys().append_value(key);
        lists.values().values().append_value(value);
        lists.values().append(true).unwrap();
    }
    lists.append(true);

    // Row 1: [{"c": 3}]
    lists.values().keys().append_value("c");
    lists.values().values().append_value(3);
    lists.values().append(true).unwrap();
    lists.append(true);

    // Row 2: []
    lists.append(true);

    let list_array = lists.finish();

    let test_cases = TestCases::default()
        .with_range(0..3)
        .with_indices(vec![0, 2])
        .with_min_file_version(LanceFileVersion::V2_2);
    check_round_trip_encoding_of_data(vec![Arc::new(list_array)], &test_cases, HashMap::new())
        .await;
}
/// Round-trips a single outer map whose values are themselves maps.
#[test_log::test(tokio::test)]
async fn test_nested_map() {
    let mut inner = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());

    // Inner map 0: {"x": 10}
    inner.keys().append_value("x");
    inner.values().append_value(10);
    inner.append(true).unwrap();

    // Inner map 1: {"y": 20, "z": 30}
    for (key, value) in [("y", 20), ("z", 30)] {
        inner.keys().append_value(key);
        inner.values().append_value(value);
    }
    inner.append(true).unwrap();

    let inner_maps = Arc::new(inner.finish());
    let outer_keys = Arc::new(StringArray::from(vec!["key1", "key2"]));

    let outer_entry_fields = Fields::from(vec![
        Field::new("key", DataType::Utf8, false),
        Field::new(
            "value",
            make_map_type(DataType::Utf8, DataType::Int32),
            true,
        ),
    ]);
    let entries_struct = StructArray::new(outer_entry_fields, vec![outer_keys, inner_maps], None);

    // A single outer row spanning both entries.
    let offsets = OffsetBuffer::new(ScalarBuffer::<i32>::from(vec![0, 2]));
    let entries_field = Field::new("entries", entries_struct.data_type().clone(), false);
    let outer_map = MapArray::new(
        Arc::new(entries_field),
        offsets,
        entries_struct,
        None,
        false,
    );

    let test_cases = TestCases::default()
        .with_range(0..1)
        .with_min_file_version(LanceFileVersion::V2_2);
    check_round_trip_encoding_of_data(vec![Arc::new(outer_map)], &test_cases, HashMap::new())
        .await;
}
/// Round-trips maps keyed by integers with string values.
#[test_log::test(tokio::test)]
async fn test_map_different_key_types() {
    let mut builder = MapBuilder::new(None, Int32Builder::new(), StringBuilder::new());

    // Row 0: {1: "one", 2: "two"}
    for (key, value) in [(1, "one"), (2, "two")] {
        builder.keys().append_value(key);
        builder.values().append_value(value);
    }
    builder.append(true).unwrap();

    // Row 1: {3: "three"}
    builder.keys().append_value(3);
    builder.values().append_value("three");
    builder.append(true).unwrap();

    let map_array = builder.finish();

    let test_cases = TestCases::default()
        .with_range(0..2)
        .with_indices(vec![0, 1])
        .with_min_file_version(LanceFileVersion::V2_2);
    check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new())
        .await;
}
/// Round-trips a 100-entry map followed by an empty map.
#[test_log::test(tokio::test)]
async fn test_map_with_extreme_sizes() {
    let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());

    // Row 0: 100 entries, "key0" -> 0 through "key99" -> 99.
    for idx in 0..100 {
        builder.keys().append_value(format!("key{}", idx));
        builder.values().append_value(idx);
    }
    builder.append(true).unwrap();

    // Row 1: empty map.
    builder.append(true).unwrap();

    let map_array = builder.finish();

    let test_cases = TestCases::default()
        .with_range(0..2)
        .with_min_file_version(LanceFileVersion::V2_2);
    check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new())
        .await;
}
/// Round-trips a column in which every map is null.
#[test_log::test(tokio::test)]
async fn test_map_all_null() {
    let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());

    // Two rows, both null.
    for _ in 0..2 {
        builder.append(false).unwrap();
    }

    let map_array = builder.finish();

    let test_cases = TestCases::default()
        .with_range(0..2)
        .with_min_file_version(LanceFileVersion::V2_2);
    check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new())
        .await;
}
/// Round-trips a column that combines a null value, a null map and a plain map.
#[test_log::test(tokio::test)]
async fn test_map_encoder_keep_original_array_scenarios() {
    let mut builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());

    // Row 0: {"key1": 10, "key2": null}
    builder.keys().append_value("key1");
    builder.values().append_value(10);
    builder.keys().append_value("key2");
    builder.values().append_null();
    builder.append(true).unwrap();

    // Row 1: null map
    builder.append(false).unwrap();

    // Row 2: {"key3": 30}
    builder.keys().append_value("key3");
    builder.values().append_value(30);
    builder.append(true).unwrap();

    let map_array = builder.finish();

    let test_cases = TestCases::default()
        .with_range(0..3)
        .with_indices(vec![0, 1, 2])
        .with_min_file_version(LanceFileVersion::V2_2);
    check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new())
        .await;
}
/// Creating a field encoder for a map field with the V2_1 strategy must fail, and the
/// error text must mention both version 2.2 and the Map data type.
#[test]
fn test_map_not_supported_write_in_v2_1() {
    let arrow_field = ArrowField::new(
        "map_field",
        make_map_type(DataType::Utf8, DataType::Int32),
        true,
    );
    let lance_field = LanceField::try_from(&arrow_field).unwrap();

    let strategy = default_encoding_strategy(LanceFileVersion::V2_1);
    let mut column_index = ColumnIndexSequence::default();
    let options = EncodingOptions::default();

    let encoder_result = strategy.create_field_encoder(
        strategy.as_ref(),
        &lance_field,
        &mut column_index,
        &options,
    );
    assert!(
        encoder_result.is_err(),
        "Map type should not be supported in V2_1 for encoder"
    );

    let encoder_err_msg = match encoder_result {
        Err(encoder_err) => encoder_err.to_string(),
        Ok(_) => panic!("Expected error but got Ok"),
    };

    assert!(
        encoder_err_msg.contains("2.2"),
        "Encoder error message should mention version 2.2, got: {}",
        encoder_err_msg
    );
    assert!(
        encoder_err_msg.contains("Map data type"),
        "Encoder error message should mention Map data type, got: {}",
        encoder_err_msg
    );
}
}