// serde_arrow 0.14.0
//
// Convert sequences of Rust objects to Arrow arrays and back again.
// Documentation
use marrow::view::{BitsWithOffset, StructView};
use serde::de::{value::StrDeserializer, DeserializeSeed, MapAccess, SeqAccess, Visitor};

use crate::internal::{
    error::{fail, set_default, Context, ContextSupport, Error, Result},
    schema::get_strategy_from_metadata,
    utils::ChildName,
};

use super::{
    array_deserializer::ArrayDeserializer, random_access_deserializer::RandomAccessDeserializer,
    utils::bitset_is_set,
};

/// Random-access deserializer for an Arrow struct array.
///
/// Each row is exposed to serde as a map (field name -> value) or as a
/// sequence of field values, with one child deserializer per struct field.
pub struct StructDeserializer<'a> {
    /// Path of this array; reported as the `field` annotation in errors and
    /// used as the prefix when building child paths.
    pub path: String,
    /// `(field name, child deserializer)` pairs, in the order the fields are
    /// yielded during deserialization.
    pub fields: Vec<(String, ArrayDeserializer<'a>)>,
    /// Optional validity bitmap; when `None`, every row is treated as non-null.
    pub validity: Option<BitsWithOffset<'a>>,
    /// Number of rows in the array; `is_some` rejects indices `>= len`.
    pub len: usize,
}

impl<'a> StructDeserializer<'a> {
    /// Build a deserializer for a struct array view.
    ///
    /// Every child array gets its own [`ArrayDeserializer`]. The child's path is
    /// this deserializer's path extended by the child's name, and its strategy is
    /// read from the child's field metadata. Construction stops at the first child
    /// whose strategy lookup or deserializer construction fails, and that error is
    /// returned.
    pub fn new(path: String, view: StructView<'a>) -> Result<Self> {
        let fields = view
            .fields
            .into_iter()
            .map(|(meta, child_view)| -> Result<(String, ArrayDeserializer<'a>)> {
                let child_path = format!("{path}.{child}", child = ChildName(&meta.name));
                let strategy = get_strategy_from_metadata(&meta.metadata)?;
                let child = ArrayDeserializer::new(child_path, strategy.as_ref(), child_view)?;
                Ok((meta.name, child))
            })
            .collect::<Result<Vec<_>>>()?;

        Ok(Self::from_parts(path, fields, view.validity, view.len))
    }

    /// Assemble a deserializer from already-built child deserializers.
    pub fn from_parts(
        path: String,
        fields: Vec<(String, ArrayDeserializer<'a>)>,
        validity: Option<BitsWithOffset<'a>>,
        len: usize,
    ) -> Self {
        Self { path, fields, validity, len }
    }
}

impl Context for StructDeserializer<'_> {
    /// Record this array's data type (`Struct`) and path (`field`) in the error
    /// annotations via `set_default`.
    fn annotate(&self, annotations: &mut std::collections::BTreeMap<String, String>) {
        set_default(annotations, "data_type", "Struct");
        set_default(annotations, "field", &self.path);
    }
}

impl<'de> RandomAccessDeserializer<'de> for StructDeserializer<'de> {
    fn is_some(&self, idx: usize) -> Result<bool> {
        if idx >= self.len {
            fail!("Out of bounds access");
        }
        if let Some(validity) = self.validity.as_ref() {
            Ok(bitset_is_set(validity, idx)?)
        } else {
            Ok(true)
        }
    }

    fn deserialize_any_some<V: Visitor<'de>>(&self, visitor: V, idx: usize) -> Result<V::Value> {
        if idx >= self.len {
            fail!("Exhausted deserializer");
        }
        visitor.visit_map(StructItemDeserializer::new(self, idx))
    }

    fn deserialize_map<V: Visitor<'de>>(&self, visitor: V, idx: usize) -> Result<V::Value> {
        visitor
            .visit_map(StructItemDeserializer::new(self, idx))
            .ctx(self)
    }

    fn deserialize_struct<V: Visitor<'de>>(
        &self,
        _: &'static str,
        _: &'static [&'static str],
        visitor: V,
        idx: usize,
    ) -> Result<V::Value> {
        visitor
            .visit_map(StructItemDeserializer::new(self, idx))
            .ctx(self)
    }

    fn deserialize_tuple<V: Visitor<'de>>(
        &self,
        _: usize,
        visitor: V,
        idx: usize,
    ) -> Result<V::Value> {
        visitor
            .visit_seq(StructItemDeserializer::new(self, idx))
            .ctx(self)
    }

    fn deserialize_tuple_struct<V: Visitor<'de>>(
        &self,
        _: &'static str,
        _: usize,
        visitor: V,
        idx: usize,
    ) -> Result<V::Value> {
        visitor
            .visit_seq(StructItemDeserializer::new(self, idx))
            .ctx(self)
    }
}

/// Cursor over the fields of a single row of a [`StructDeserializer`].
///
/// Used both as a serde `MapAccess` (name/value pairs) and as a `SeqAccess`
/// (values only).
struct StructItemDeserializer<'a, 'de> {
    /// The struct array whose row is being visited.
    deserializer: &'a StructDeserializer<'de>,
    /// Row index within the struct array.
    item: usize,
    /// Index of the next field to yield; advanced after each successfully
    /// deserialized value.
    field: usize,
}

impl<'a, 'de> StructItemDeserializer<'a, 'de> {
    /// Start a cursor over row `item` of `deserializer`, positioned at its first field.
    pub fn new(deserializer: &'a StructDeserializer<'de>, item: usize) -> Self {
        let first_field = 0;
        Self {
            field: first_field,
            item,
            deserializer,
        }
    }
}

/// Exposes one row of a struct array as a serde map of field name -> value.
///
/// Entries follow the order of the parent's `fields`. The cursor advances only
/// after a value has been deserialized successfully.
impl<'de> MapAccess<'de> for StructItemDeserializer<'_, 'de> {
    type Error = Error;

    /// Yield the current field's name, or `None` once every field has been visited.
    fn next_key_seed<K: DeserializeSeed<'de>>(&mut self, seed: K) -> Result<Option<K::Value>> {
        let Some((field_name, _)) = self.deserializer.fields.get(self.field) else {
            return Ok(None);
        };

        let key = seed.deserialize(StrDeserializer::<Error>::new(field_name))?;
        Ok(Some(key))
    }

    /// Deserialize the current field's value for this row, then move to the next field.
    ///
    /// Fails if called after all fields have been consumed.
    fn next_value_seed<V: DeserializeSeed<'de>>(&mut self, seed: V) -> Result<V::Value> {
        let Some((_, field_deserializer)) = self.deserializer.fields.get(self.field) else {
            fail!("Invalid state in struct deserializer");
        };

        let res = seed.deserialize(field_deserializer.at(self.item))?;
        self.field += 1;

        Ok(res)
    }

    /// Exact number of entries still to be visited.
    ///
    /// Lets visitors preallocate (serde's default reports `None`). `field` never
    /// exceeds the field count; `saturating_sub` is only a guard.
    fn size_hint(&self) -> Option<usize> {
        Some(self.deserializer.fields.len().saturating_sub(self.field))
    }
}

/// Exposes one row of a struct array as a serde sequence of field values.
///
/// Values follow the order of the parent's `fields`; field names are ignored.
impl<'de> SeqAccess<'de> for StructItemDeserializer<'_, 'de> {
    type Error = Error;

    /// Deserialize the next field's value for this row, or return `None` once
    /// every field has been visited.
    fn next_element_seed<T: DeserializeSeed<'de>>(&mut self, seed: T) -> Result<Option<T::Value>> {
        let Some((_, field_deserializer)) = self.deserializer.fields.get(self.field) else {
            return Ok(None);
        };

        let res = seed.deserialize(field_deserializer.at(self.item))?;
        self.field += 1;

        Ok(Some(res))
    }

    /// Exact number of elements still to be visited.
    ///
    /// Lets visitors preallocate (serde's default reports `None`). `field` never
    /// exceeds the field count; `saturating_sub` is only a guard.
    fn size_hint(&self) -> Option<usize> {
        Some(self.deserializer.fields.len().saturating_sub(self.field))
    }
}