serde_arrow 0.14.0

Convert sequences of Rust objects to Arrow arrays and back again
Documentation
use marrow::view::{BytesView, BytesViewView, PrimitiveView, View};

use crate::internal::{
    deserialization::utils::bitset_is_set,
    error::{fail, Error, ErrorKind, Result},
};

use super::Offset;

/// Type-independent queries on an array view.
///
/// Both methods are fallible so that an implementation can report view kinds it
/// does not know how to handle instead of panicking.
pub trait ViewExt {
    /// Report whether the view is nullable.
    fn is_nullable(&self) -> Result<bool>;
    /// Report the number of elements in the view.
    fn len(&self) -> Result<usize>;
}

impl ViewExt for View<'_> {
    fn is_nullable(&self) -> Result<bool> {
        use View as V;
        match self {
            V::Null(_) => Ok(true),
            V::Union(_) => Ok(false),
            V::Boolean(view) => Ok(view.validity.is_some()),
            V::Int8(view) => Ok(view.validity.is_some()),
            V::Int16(view) => Ok(view.validity.is_some()),
            V::Int32(view) => Ok(view.validity.is_some()),
            V::Int64(view) => Ok(view.validity.is_some()),
            V::UInt8(view) => Ok(view.validity.is_some()),
            V::UInt16(view) => Ok(view.validity.is_some()),
            V::UInt32(view) => Ok(view.validity.is_some()),
            V::UInt64(view) => Ok(view.validity.is_some()),
            V::Float16(view) => Ok(view.validity.is_some()),
            V::Float32(view) => Ok(view.validity.is_some()),
            V::Float64(view) => Ok(view.validity.is_some()),
            V::Date32(view) => Ok(view.validity.is_some()),
            V::Date64(view) => Ok(view.validity.is_some()),
            V::Time32(view) => Ok(view.validity.is_some()),
            V::Time64(view) => Ok(view.validity.is_some()),
            V::Timestamp(view) => Ok(view.validity.is_some()),
            V::Duration(view) => Ok(view.validity.is_some()),
            V::Decimal128(view) => Ok(view.validity.is_some()),
            V::Utf8(view) => Ok(view.validity.is_some()),
            V::Utf8View(view) => Ok(view.validity.is_some()),
            V::LargeUtf8(view) => Ok(view.validity.is_some()),
            V::Binary(view) => Ok(view.validity.is_some()),
            V::LargeBinary(view) => Ok(view.validity.is_some()),
            V::BinaryView(view) => Ok(view.validity.is_some()),
            V::FixedSizeBinary(view) => Ok(view.validity.is_some()),
            V::FixedSizeList(view) => Ok(view.validity.is_some()),
            V::List(view) => Ok(view.validity.is_some()),
            V::LargeList(view) => Ok(view.validity.is_some()),
            V::Map(view) => Ok(view.validity.is_some()),
            V::Struct(view) => Ok(view.validity.is_some()),
            V::Dictionary(view) => view.keys.is_nullable(),
            _ => fail!("Unknown view type"),
        }
    }

    fn len(&self) -> Result<usize> {
        use View as V;
        match self {
            V::Null(view) => Ok(view.len),
            V::Boolean(view) => Ok(view.len),
            V::Int8(view) => Ok(view.values.len()),
            V::Int16(view) => Ok(view.values.len()),
            V::Int32(view) => Ok(view.values.len()),
            V::Int64(view) => Ok(view.values.len()),
            V::UInt8(view) => Ok(view.values.len()),
            V::UInt16(view) => Ok(view.values.len()),
            V::UInt32(view) => Ok(view.values.len()),
            V::UInt64(view) => Ok(view.values.len()),
            V::Float16(view) => Ok(view.values.len()),
            V::Float32(view) => Ok(view.values.len()),
            V::Float64(view) => Ok(view.values.len()),
            V::Date32(view) => Ok(view.values.len()),
            V::Date64(view) => Ok(view.values.len()),
            V::Time32(view) => Ok(view.values.len()),
            V::Time64(view) => Ok(view.values.len()),
            V::Timestamp(view) => Ok(view.values.len()),
            V::Duration(view) => Ok(view.values.len()),
            V::Decimal128(view) => Ok(view.values.len()),
            V::Utf8(view) => Ok(view.offsets.len().saturating_sub(1)),
            V::Utf8View(view) => Ok(view.data.len()),
            V::LargeUtf8(view) => Ok(view.offsets.len().saturating_sub(1)),
            V::Binary(view) => Ok(view.offsets.len().saturating_sub(1)),
            V::LargeBinary(view) => Ok(view.offsets.len().saturating_sub(1)),
            V::BinaryView(view) => Ok(view.data.len()),
            V::FixedSizeBinary(view) => match usize::try_from(view.n) {
                Ok(n) if n > 0 => Ok(view.data.len() / n),
                _ => Ok(0),
            },
            V::FixedSizeList(view) => Ok(view.len),
            V::List(view) => Ok(view.offsets.len().saturating_sub(1)),
            V::LargeList(view) => Ok(view.offsets.len().saturating_sub(1)),
            V::Union(view) => Ok(view.types.len()),
            V::Map(view) => Ok(view.offsets.len().saturating_sub(1)),
            V::Struct(view) => Ok(view.len),
            V::Dictionary(view) => view.keys.len(),
            _ => fail!("Unknown view type"),
        }
    }
}

pub trait ViewAccess<'a, Item: ?Sized + 'a> {
    fn get(&self, idx: usize) -> Result<Option<&'a Item>>;

    fn get_required(&self, idx: usize) -> Result<&'a Item>
    where
        Self: 'a,
    {
        if let Some(val) = self.get(idx)? {
            Ok(val)
        } else {
            Err(Error::new(
                ErrorKind::NullabilityViolation { field: None },
                "Required item was not present".into(),
            ))
        }
    }

    fn is_some(&self, idx: usize) -> Result<bool>
    where
        Self: 'a,
    {
        Ok(self.get(idx)?.is_some())
    }
}

impl<'a, T> ViewAccess<'a, T> for PrimitiveView<'a, T> {
    /// Fetch the value at `idx`.
    ///
    /// Fails when `idx` lies beyond the values slice. Returns `Ok(None)` when a
    /// validity buffer is present and its bit for `idx` is not set.
    fn get(&self, idx: usize) -> Result<Option<&'a T>> {
        // Bounds are checked first, so the validity lookup only runs for
        // indices that exist in the values slice
        let Some(item) = self.values.get(idx) else {
            fail!("Access beyond array length");
        };

        // Without a validity buffer every element counts as present
        let present = match self.validity.as_ref() {
            Some(bits) => bitset_is_set(bits, idx)?,
            None => true,
        };
        Ok(present.then_some(item))
    }
}

impl<'a, O: Offset> ViewAccess<'a, [u8]> for BytesView<'a, O> {
    /// Fetch the byte slice of the element at `idx`.
    ///
    /// The view stores one more offset than it has elements: element `idx`
    /// spans `offsets[idx]..offsets[idx + 1]` of `data`.
    ///
    /// # Errors
    ///
    /// Fails when `idx` is not a valid element index, when the validity lookup
    /// fails, when an offset cannot be converted to `usize`, or when the offsets
    /// describe a range that does not lie within `data`.
    ///
    /// Returns `Ok(None)` when a validity buffer is present and its bit for
    /// `idx` is not set.
    fn get(&self, idx: usize) -> Result<Option<&'a [u8]>> {
        // An array with `len` elements has `len + 1` offsets, so both
        // `offsets[idx]` and `offsets[idx + 1]` exist only for `idx < len`
        let len = self.offsets.len().saturating_sub(1);
        if idx >= len {
            fail!(
                "Invalid access: tried to get element {idx} of array with {len} elements",
                len = len
            );
        }

        if let Some(validity) = &self.validity {
            if !bitset_is_set(validity, idx)? {
                return Ok(None);
            }
        }

        let start = self.offsets[idx].try_into_usize()?;
        let end = self.offsets[idx + 1].try_into_usize()?;

        // Checked slicing: inconsistent offsets (start > end, or end beyond the
        // data buffer) are reported as errors rather than causing a panic
        let Some(bytes) = self.data.get(start..end) else {
            fail!(
                "Invalid offsets: range {start}..{end} is out of bounds for data of length {data_len}",
                data_len = self.data.len()
            );
        };
        Ok(Some(bytes))
    }
}

impl<'a> ViewAccess<'a, [u8]> for BytesViewView<'a> {
    /// Fetch the byte slice of the element at `idx`.
    ///
    /// Each element is described by a 128-bit descriptor whose lowest 32 bits
    /// hold the length. Payloads of up to 12 bytes are read from the
    /// descriptor's own bytes starting at byte 4; longer payloads are read from
    /// one of the view's buffers, selected by bits 64..96 with the start offset
    /// in bits 96..128.
    ///
    /// Fails when `idx` has no descriptor, when the validity lookup fails, or
    /// when the descriptor does not resolve to a valid byte range. Returns
    /// `Ok(None)` when a validity buffer is present and its bit for `idx` is
    /// not set.
    fn get(&self, idx: usize) -> Result<Option<&'a [u8]>> {
        let Some(desc) = self.data.get(idx) else {
            fail!(
                "Invalid access: tried to get element {idx} of array with {len} elements",
                len = self.data.len()
            );
        };

        let is_null = match &self.validity {
            Some(bits) => !bitset_is_set(bits, idx)?,
            None => false,
        };
        if is_null {
            return Ok(None);
        }

        let len = (*desc as u32) as usize;
        let payload: Option<&'a [u8]> = if len <= 12 {
            // Short payload: stored inline in the descriptor after the 4 length bytes
            bytemuck::try_cast_slice::<_, u8>(std::slice::from_ref(desc))
                .ok()
                .and_then(|raw| raw.get(4..4 + len))
        } else {
            // Long payload: stored in an external buffer
            let buf_idx = ((*desc >> 64) as u32) as usize;
            let offset = ((*desc >> 96) as u32) as usize;
            self.buffers
                .get(buf_idx)
                .and_then(|buf| buf.get(offset..offset + len))
        };

        match payload {
            Some(bytes) => Ok(Some(bytes)),
            None => fail!("invalid state in bytes deserialization"),
        }
    }
}

impl<'a, V> ViewAccess<'a, str> for V
where
    V: ViewAccess<'a, [u8]>,
{
    /// Fetch the element at `idx` as a string slice.
    ///
    /// Delegates to the byte-level accessor and validates the bytes as UTF-8;
    /// absent elements and lookup errors are passed through unchanged.
    fn get(&self, idx: usize) -> Result<Option<&'a str>> {
        let Some(raw) = ViewAccess::<[u8]>::get(self, idx)? else {
            return Ok(None);
        };
        Ok(Some(std::str::from_utf8(raw)?))
    }
}