use std::marker::PhantomData;
use arrow::array::{Array, ArrayRef, MapArray, StructArray};
use arrow::datatypes::ArrowNativeType as _;
use arrow::datatypes::{DataType, Field, Fields};
use crate::datatype::{ColumnError, InfallibleBuild, LogicalType, downcast_array};
/// Type-level marker for a map column whose keys have logical type `K` and whose values
/// have logical type `V`.
///
/// The only field is a zero-sized `PhantomData`, so the struct itself stores no key or
/// value data.
pub struct Map<K, V> {
// `fn() -> (K, V)` names `K` and `V` without storing a value of either type.
_marker: PhantomData<fn() -> (K, V)>,
}
/// A `MapArray` bundled with typed views of its key and value child columns.
///
/// This is the `Typed` representation of `Map<K, V>` and is produced by its `downcast`.
pub struct TypedMap<K: LogicalType, V: LogicalType> {
/// The underlying map array; read for row validity and per-row entry offsets.
map: MapArray,
/// Typed view of the map's key column (the result of `K::downcast` on `map.keys()`).
keys: K::Typed,
/// Typed view of the map's value column (the result of `V::downcast` on `map.values()`).
values: V::Typed,
}
/// Hand-written `Clone`: only the stored `MapArray`, `K::Typed` and `V::Typed` need to be
/// cloneable, so no `Clone` bound is placed on `K` or `V` themselves.
impl<K: LogicalType, V: LogicalType> Clone for TypedMap<K, V> {
    /// Clones the map array together with both typed child views.
    fn clone(&self) -> Self {
        let Self { map, keys, values } = self;
        Self {
            map: map.clone(),
            keys: keys.clone(),
            values: values.clone(),
        }
    }
}
/// Returns the two child fields of a map's entry struct, in order: `"keys"` then `"values"`.
///
/// The key field is always declared non-nullable; the value field's nullability is taken
/// from `V::NULLABLE`.
fn entry_fields<K: crate::ConcreteType, V: crate::ConcreteType>() -> Fields {
    let key_field = Field::new("keys", K::datatype(), false);
    let value_field = Field::new("values", V::datatype(), V::NULLABLE);
    Fields::from(vec![key_field, value_field])
}
/// `Map<K, V>` as a logical column type backed by an Arrow `MapArray`.
impl<K: LogicalType + 'static, V: LogicalType + 'static> LogicalType for Map<K, V> {
    type Typed = TypedMap<K, V>;
    type Value<'a>
        = MapValue<'a, K, V>
    where
        Self: 'a;
    type Owned = Vec<(K::Owned, V::Owned)>;

    /// Downcasts `array` to a `MapArray` and pairs it with typed views of its children.
    ///
    /// # Errors
    ///
    /// * Propagates the error from `downcast_array` when `array` cannot be viewed as a
    ///   `MapArray`.
    /// * Returns `ColumnError::UnexpectedNulls` when a child whose logical type is not
    ///   nullable has null entries reachable through valid map rows. Keys are checked
    ///   before values.
    /// * Propagates any error from `K::downcast` / `V::downcast` on the child columns.
    fn downcast(array: &dyn Array) -> Result<Self::Typed, ColumnError> {
        let map = downcast_array::<MapArray>(array, || "Map(…)".to_owned())?;

        // Only children declared non-nullable are inspected; the array order keeps the
        // key check ahead of the value check.
        for (nullable, child) in [(K::NULLABLE, map.keys()), (V::NULLABLE, map.values())] {
            if nullable {
                continue;
            }
            let null_count = logical_child_null_count(&map, child);
            if null_count != 0 {
                return Err(ColumnError::UnexpectedNulls { null_count });
            }
        }

        let keys = K::downcast(&**map.keys())?;
        let values = V::downcast(&**map.values())?;
        Ok(TypedMap { map, keys, values })
    }

    /// Reports whether the map row at `index` is null.
    #[inline]
    fn is_null(typed: &Self::Typed, index: usize) -> bool {
        typed.map.is_null(index)
    }

    /// Unchecked variant of `is_null`; forwards to `leaf_is_null_unchecked` on the map array.
    ///
    /// # Safety
    ///
    /// The caller must satisfy whatever `crate::datatype::leaf_is_null_unchecked` requires
    /// for `index` (presumably `index < typed.map.len()` — TODO confirm against that helper).
    #[inline]
    unsafe fn is_null_unchecked(typed: &Self::Typed, index: usize) -> bool {
        // SAFETY: this function's own safety contract is passed through unchanged to the
        // helper; nothing is checked here.
        unsafe { crate::datatype::leaf_is_null_unchecked(&typed.map, index) }
    }

    /// Returns an iterator over the `(key, value)` entries of the row at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index + 1` is outside the map's offsets slice.
    #[inline]
    fn value(typed: &Self::Typed, index: usize) -> Self::Value<'_> {
        let TypedMap { map, keys, values } = typed;
        let offsets = map.value_offsets();
        // Row `index` owns the child entries in `offsets[index]..offsets[index + 1]`.
        let first = offsets[index].as_usize();
        let end = offsets[index + 1].as_usize();
        MapValue {
            keys,
            values,
            index: first,
            end,
        }
    }

    /// Converts every borrowed entry of `value` into its owned form, preserving order.
    fn to_owned_value(value: Self::Value<'_>) -> Self::Owned {
        value
            .map(|(borrowed_key, borrowed_value)| {
                (
                    K::to_owned_value(borrowed_key),
                    V::to_owned_value(borrowed_value),
                )
            })
            .collect::<Vec<_>>()
    }
}
/// `Map<K, V>` as a buildable column type: Arrow schema plus array construction.
impl<K: crate::ConcreteType + 'static, V: crate::ConcreteType + 'static> crate::ConcreteType
    for Map<K, V>
{
    /// Arrow data type of the column: a `Map` whose non-nullable `"entries"` struct holds
    /// the fields produced by `entry_fields::<K, V>()`.
    fn datatype() -> DataType {
        let entry_struct = DataType::Struct(entry_fields::<K, V>());
        let entries_field = Field::new("entries", entry_struct, false);
        // The second component of `DataType::Map` is Arrow's "keys sorted" flag.
        DataType::Map(std::sync::Arc::new(entries_field), false)
    }

    /// Builds a `MapArray` from an iterator of optional rows.
    ///
    /// `None` becomes a null row with zero entries; `Some(pairs)` becomes a valid row whose
    /// entries are appended, in order, to the flattened key and value columns.
    ///
    /// # Errors
    ///
    /// Propagates errors from `K::build` / `V::build`, and wraps failures of
    /// `StructArray::try_new` / `MapArray::try_new` in `ColumnError::Build`.
    fn build(values: impl Iterator<Item = Option<Self::Owned>>) -> Result<ArrayRef, ColumnError> {
        let mut validity = Vec::new();
        let mut lengths = Vec::new();
        let mut flat_keys = Vec::new();
        let mut flat_values = Vec::new();

        for row in values {
            validity.push(row.is_some());
            // A null row contributes an empty (non-allocating) entry list.
            let pairs = row.unwrap_or_default();
            lengths.push(pairs.len());
            for (key, value) in pairs {
                flat_keys.push(key);
                flat_values.push(value);
            }
        }

        let fields = entry_fields::<K, V>();
        let keys_array = K::build(flat_keys.into_iter().map(Some))?;
        let values_array = V::build(flat_values.into_iter().map(Some))?;
        let children = vec![keys_array, values_array];
        let entries =
            StructArray::try_new(fields.clone(), children, None).map_err(ColumnError::Build)?;

        let offsets = arrow::buffer::OffsetBuffer::<i32>::from_lengths(lengths);

        // A validity buffer is attached only when at least one row is null.
        let nulls = if validity.iter().all(|&is_valid| is_valid) {
            None
        } else {
            Some(arrow::buffer::NullBuffer::from(validity))
        };

        let entry_struct = DataType::Struct(fields);
        let entries_field = std::sync::Arc::new(Field::new("entries", entry_struct, false));

        // The trailing `false` is the same flag value used in `datatype()`.
        MapArray::try_new(entries_field, offsets, entries, nulls, false)
            .map(|map| std::sync::Arc::new(map) as ArrayRef)
            .map_err(ColumnError::Build)
    }
}
// `Map<K, V>` implements `InfallibleBuild` whenever both `K` and `V` do.
// NOTE(review): the trait's contract is defined in `crate::datatype`; presumably it promises
// that `build` never returns `Err` — confirm there.
impl<K: InfallibleBuild + 'static, V: InfallibleBuild + 'static> InfallibleBuild for Map<K, V> {}
/// Iterator over the `(key, value)` entries of a single map row.
///
/// Borrows the typed key and value columns and walks the half-open entry range
/// `index..end` within them.
pub struct MapValue<'a, K: LogicalType, V: LogicalType> {
/// Typed key column the entries are read from.
keys: &'a K::Typed,
/// Typed value column the entries are read from.
values: &'a V::Typed,
/// Position of the next entry to yield.
index: usize,
/// One past the position of the last entry to yield.
end: usize,
}
/// Walks the entries of one map row front to back.
impl<'a, K: LogicalType + 'a, V: LogicalType + 'a> Iterator for MapValue<'a, K, V> {
    type Item = (K::Value<'a>, V::Value<'a>);

    /// Yields the next `(key, value)` pair, or `None` once `index` has reached `end`.
    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.end {
            return None;
        }
        let position = self.index;
        // Read both halves before advancing, so the cursor only moves after a successful read.
        let entry = (
            K::value(self.keys, position),
            V::value(self.values, position),
        );
        self.index = position + 1;
        Some(entry)
    }

    /// Exact number of entries still to be yielded (lower and upper bound coincide).
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.end - self.index;
        (left, Some(left))
    }
}
// `size_hint` returns `end - index` as both the lower and the upper bound, which is the
// property `ExactSizeIterator` relies on.
impl<'a, K: LogicalType + 'a, V: LogicalType + 'a> ExactSizeIterator for MapValue<'a, K, V> {}
/// Counts the null entries of `child` that are reachable through valid rows of `map`.
///
/// `child` is expected to be one of `map`'s entry columns (its keys or its values), so that
/// the map's value offsets index directly into it. Entries outside the window
/// `offsets[0]..offsets[map.len()]`, and entries that belong to null map rows, are ignored.
///
/// Returns `0` when `child` exposes no validity buffer.
fn logical_child_null_count(map: &MapArray, child: &dyn Array) -> usize {
    let Some(child_nulls) = child.nulls() else {
        return 0;
    };

    let offsets = map.value_offsets();
    let window_start = offsets[0].as_usize();
    let window_end = offsets[map.len()].as_usize();

    // Computed once: this is both the fast-path test and, when the map has no row-level
    // validity, the final answer.
    let window_null_count = child_nulls
        .slice(window_start, window_end - window_start)
        .null_count();
    if window_null_count == 0 {
        return 0;
    }

    let Some(row_validity) = map.nulls() else {
        // Every row is valid, so every null inside the window is reachable.
        return window_null_count;
    };

    // Some rows are null: only count child nulls that sit under a valid row.
    (0..map.len())
        .filter(|&row| row_validity.is_valid(row))
        .map(|row| {
            let start = offsets[row].as_usize();
            let end = offsets[row + 1].as_usize();
            if start == end {
                0
            } else {
                child_nulls.slice(start, end - start).null_count()
            }
        })
        .sum()
}