use crate::builder::{ArrayBuilder, BufferBuilder};
use crate::{Array, ArrayRef, MapArray, StructArray};
use arrow_buffer::Buffer;
use arrow_buffer::{NullBuffer, NullBufferBuilder};
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType, Field, FieldRef};
use std::any::Any;
use std::sync::Arc;
/// Builder for [`MapArray`].
///
/// Keys and values are accumulated in the two child builders. Each call to
/// `append` closes one map slot by recording the current length of the key
/// builder as the slot's end offset, together with the slot's validity bit.
#[derive(Debug)]
pub struct MapBuilder<K: ArrayBuilder, V: ArrayBuilder> {
/// End offsets of the map slots; always seeded with a leading `0`.
offsets_builder: BufferBuilder<i32>,
/// Validity (null / non-null) of each appended map slot.
null_buffer_builder: NullBufferBuilder,
/// Names used for the generated entries, key and value fields.
field_names: MapFieldNames,
/// Child builder collecting the keys of every slot.
key_builder: K,
/// Child builder collecting the values of every slot.
value_builder: V,
/// Caller-supplied keys field; when `None`, one is derived from `field_names`.
key_field: Option<FieldRef>,
/// Caller-supplied values field; when `None`, one is derived from `field_names`.
value_field: Option<FieldRef>,
}
/// Names of the fields that make up a map's data type.
///
/// The `Default` implementation uses `"entries"`, `"keys"` and `"values"`.
#[derive(Debug, Clone)]
pub struct MapFieldNames {
/// Name of the struct field that wraps each key/value pair.
pub entry: String,
/// Name of the keys field inside the entries struct.
pub key: String,
/// Name of the values field inside the entries struct.
pub value: String,
}
impl Default for MapFieldNames {
    /// Returns the default field names: `"entries"`, `"keys"` and `"values"`.
    fn default() -> Self {
        let entry = String::from("entries");
        let key = String::from("keys");
        let value = String::from("values");
        Self { entry, key, value }
    }
}
impl<K: ArrayBuilder, V: ArrayBuilder> MapBuilder<K, V> {
pub fn new(field_names: Option<MapFieldNames>, key_builder: K, value_builder: V) -> Self {
let capacity = key_builder.len();
Self::with_capacity(field_names, key_builder, value_builder, capacity)
}
pub fn with_capacity(
field_names: Option<MapFieldNames>,
key_builder: K,
value_builder: V,
capacity: usize,
) -> Self {
let mut offsets_builder = BufferBuilder::<i32>::new(capacity + 1);
offsets_builder.append(0);
Self {
offsets_builder,
null_buffer_builder: NullBufferBuilder::new(capacity),
field_names: field_names.unwrap_or_default(),
key_builder,
value_builder,
key_field: None,
value_field: None,
}
}
pub fn with_keys_field(self, field: impl Into<FieldRef>) -> Self {
Self {
key_field: Some(field.into()),
..self
}
}
pub fn with_values_field(self, field: impl Into<FieldRef>) -> Self {
Self {
value_field: Some(field.into()),
..self
}
}
pub fn keys(&mut self) -> &mut K {
&mut self.key_builder
}
pub fn values(&mut self) -> &mut V {
&mut self.value_builder
}
pub fn entries(&mut self) -> (&mut K, &mut V) {
(&mut self.key_builder, &mut self.value_builder)
}
#[inline]
pub fn append(&mut self, is_valid: bool) -> Result<(), ArrowError> {
if self.key_builder.len() != self.value_builder.len() {
return Err(ArrowError::InvalidArgumentError(format!(
"Cannot append to a map builder when its keys and values have unequal lengths of {} and {}",
self.key_builder.len(),
self.value_builder.len()
)));
}
self.offsets_builder.append(self.key_builder.len() as i32);
self.null_buffer_builder.append(is_valid);
Ok(())
}
pub fn finish(&mut self) -> MapArray {
let len = self.len();
let keys_arr = self.key_builder.finish();
let values_arr = self.value_builder.finish();
let offset_buffer = self.offsets_builder.finish();
self.offsets_builder.append(0);
let null_bit_buffer = self.null_buffer_builder.finish();
self.finish_helper(keys_arr, values_arr, offset_buffer, null_bit_buffer, len)
}
pub fn finish_cloned(&self) -> MapArray {
let len = self.len();
let keys_arr = self.key_builder.finish_cloned();
let values_arr = self.value_builder.finish_cloned();
let offset_buffer = Buffer::from_slice_ref(self.offsets_builder.as_slice());
let nulls = self.null_buffer_builder.finish_cloned();
self.finish_helper(keys_arr, values_arr, offset_buffer, nulls, len)
}
fn finish_helper(
&self,
keys_arr: Arc<dyn Array>,
values_arr: Arc<dyn Array>,
offset_buffer: Buffer,
nulls: Option<NullBuffer>,
len: usize,
) -> MapArray {
assert!(
keys_arr.null_count() == 0,
"Keys array must have no null values, found {} null value(s)",
keys_arr.null_count()
);
let keys_field = match &self.key_field {
Some(f) => {
assert!(!f.is_nullable(), "Keys field must not be nullable");
f.clone()
}
None => Arc::new(Field::new(
self.field_names.key.as_str(),
keys_arr.data_type().clone(),
false, )),
};
let values_field = match &self.value_field {
Some(f) => f.clone(),
None => Arc::new(Field::new(
self.field_names.value.as_str(),
values_arr.data_type().clone(),
true,
)),
};
let struct_array =
StructArray::from(vec![(keys_field, keys_arr), (values_field, values_arr)]);
let map_field = Arc::new(Field::new(
self.field_names.entry.as_str(),
struct_array.data_type().clone(),
false, ));
let array_data = ArrayData::builder(DataType::Map(map_field, false)) .len(len)
.add_buffer(offset_buffer)
.add_child_data(struct_array.into_data())
.nulls(nulls);
let array_data = unsafe { array_data.build_unchecked() };
MapArray::from(array_data)
}
pub fn validity_slice(&self) -> Option<&[u8]> {
self.null_buffer_builder.as_slice()
}
}
impl<K: ArrayBuilder, V: ArrayBuilder> ArrayBuilder for MapBuilder<K, V> {
fn len(&self) -> usize {
self.null_buffer_builder.len()
}
fn finish(&mut self) -> ArrayRef {
Arc::new(self.finish())
}
fn finish_cloned(&self) -> ArrayRef {
Arc::new(self.finish_cloned())
}
fn as_any(&self) -> &dyn Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn Any {
self
}
fn into_box_any(self: Box<Self>) -> Box<dyn Any> {
self
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::{make_builder, Int32Builder, StringBuilder};
    use crate::{Int32Array, StringArray};
    use std::collections::HashMap;

    /// Builds the `DataType::Map` expected from a builder that uses the default
    /// `"entries"` name together with the given key and value fields.
    fn expected_map_type(key_field: FieldRef, value_field: FieldRef) -> DataType {
        let entries = DataType::Struct(vec![key_field, value_field].into());
        DataType::Map(Arc::new(Field::new("entries", entries, false)), false)
    }

    #[test]
    #[should_panic(expected = "Keys array must have no null values, found 1 null value(s)")]
    fn test_map_builder_with_null_keys_panics() {
        let mut map_builder = MapBuilder::new(None, StringBuilder::new(), Int32Builder::new());
        // A null key is accepted by the child builder but rejected when finishing.
        map_builder.keys().append_null();
        map_builder.values().append_value(42);
        map_builder.append(true).unwrap();
        map_builder.finish();
    }

    #[test]
    fn test_boxed_map_builder() {
        let boxed_keys = make_builder(&DataType::Utf8, 5);
        let boxed_values = make_builder(&DataType::Int32, 5);
        let mut map_builder = MapBuilder::new(None, boxed_keys, boxed_values);

        // Boxed child builders must be downcast before values can be appended.
        let string_keys = map_builder
            .keys()
            .as_any_mut()
            .downcast_mut::<StringBuilder>()
            .expect("should be an StringBuilder");
        string_keys.append_value("1");

        let int_values = map_builder
            .values()
            .as_any_mut()
            .downcast_mut::<Int32Builder>()
            .expect("should be an Int32Builder");
        int_values.append_value(42);

        map_builder.append(true).unwrap();
        let map_array = map_builder.finish();

        let key_strings = map_array
            .keys()
            .as_any()
            .downcast_ref::<StringArray>()
            .expect("should be an StringArray");
        assert_eq!(key_strings.value(0), "1");

        let value_ints = map_array
            .values()
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("should be an Int32Array");
        assert_eq!(value_ints.value(0), 42);
    }

    #[test]
    fn test_with_values_field() {
        let value_field = Arc::new(Field::new("bars", DataType::Int32, false));
        let expected_type = expected_map_type(
            Arc::new(Field::new("keys", DataType::Int32, false)),
            value_field.clone(),
        );
        let mut map_builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new())
            .with_values_field(value_field);

        // First slot: one entry.
        map_builder.keys().append_value(1);
        map_builder.values().append_value(2);
        map_builder.append(true).unwrap();
        // Second slot: a null map without entries.
        map_builder.append(false).unwrap();
        // Third slot: one entry.
        map_builder.keys().append_value(3);
        map_builder.values().append_value(4);
        map_builder.append(true).unwrap();

        let first = map_builder.finish();
        assert_eq!(first.len(), 3);
        assert_eq!(first.data_type(), &expected_type);

        // The builder is reusable after `finish` and keeps the custom values field.
        map_builder.keys().append_value(5);
        map_builder.values().append_value(6);
        map_builder.append(true).unwrap();

        let second = map_builder.finish();
        assert_eq!(second.len(), 1);
        assert_eq!(second.data_type(), &expected_type);
    }

    #[test]
    fn test_with_keys_field() {
        let mut key_metadata = HashMap::new();
        key_metadata.insert("foo".to_string(), "bar".to_string());
        let key_field = Arc::new(
            Field::new("keys", DataType::Int32, false).with_metadata(key_metadata.clone()),
        );

        let mut map_builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new())
            .with_keys_field(key_field.clone());
        map_builder.keys().append_value(1);
        map_builder.values().append_value(2);
        map_builder.append(true).unwrap();

        let map = map_builder.finish();
        assert_eq!(map.len(), 1);

        // The keys field must carry the metadata; the values field is the generated default.
        let expected_type = expected_map_type(
            Arc::new(Field::new("keys", DataType::Int32, false).with_metadata(key_metadata)),
            Arc::new(Field::new("values", DataType::Int32, true)),
        );
        assert_eq!(map.data_type(), &expected_type);
    }

    #[test]
    #[should_panic(expected = "Keys field must not be nullable")]
    fn test_with_nullable_keys_field() {
        let nullable_keys = Arc::new(Field::new("keys", DataType::Int32, true));
        let mut map_builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new())
            .with_keys_field(nullable_keys);
        map_builder.keys().append_value(1);
        map_builder.values().append_value(2);
        map_builder.append(true).unwrap();
        map_builder.finish();
    }

    #[test]
    #[should_panic(expected = "Incorrect datatype")]
    fn test_keys_field_type_mismatch() {
        // The builder produces Int32 keys, but the supplied field claims Utf8.
        let mismatched_keys = Arc::new(Field::new("keys", DataType::Utf8, false));
        let mut map_builder = MapBuilder::new(None, Int32Builder::new(), Int32Builder::new())
            .with_keys_field(mismatched_keys);
        map_builder.keys().append_value(1);
        map_builder.values().append_value(2);
        map_builder.append(true).unwrap();
        map_builder.finish();
    }
}