serde_arrow 0.14.1

Convert sequences of Rust objects to Arrow arrays and back again
Documentation
use std::collections::{BTreeMap, HashMap};

use half::f16;
use marrow::{
    array::{Array, PrimitiveArray},
    datatypes::FieldMeta,
};
use serde::{Serialize, Serializer};

use crate::internal::{
    error::{set_default, try_, Context, ContextSupport, Result},
    serialization::utils::impl_serializer,
    utils::array_ext::{ArrayExt, ScalarArrayExt},
};

use super::array_builder::ArrayBuilder;

/// Floating point element types that a [`FloatBuilder`] can accumulate.
///
/// Implementors provide the constructors used to wrap a builder / finished
/// array into the crate-level enums, a data type name for error annotations,
/// and conversions from every numeric and string input the serializer accepts.
pub trait FloatPrimitive: Sized + Copy + Default + 'static {
    /// Wraps a builder into an [`ArrayBuilder`]; used by [`FloatBuilder::take`].
    const ARRAY_BUILDER_VARIANT: fn(FloatBuilder<Self>) -> ArrayBuilder;
    /// Wraps the finished values into an [`Array`]; used when finishing a builder.
    const ARRAY_VARIANT: fn(PrimitiveArray<Self>) -> Array;
    /// Data type name reported under the `data_type` error annotation.
    const NAME: &'static str;

    /// Converts an `i8` input into `Self`.
    fn from_i8(value: i8) -> Self;
    /// Converts an `i16` input into `Self`.
    fn from_i16(value: i16) -> Self;
    /// Converts an `i32` input into `Self`.
    fn from_i32(value: i32) -> Self;
    /// Converts an `i64` input into `Self`.
    fn from_i64(value: i64) -> Self;
    /// Converts a `u8` input into `Self`.
    fn from_u8(value: u8) -> Self;
    /// Converts a `u16` input into `Self`.
    fn from_u16(value: u16) -> Self;
    /// Converts a `u32` input into `Self`.
    fn from_u32(value: u32) -> Self;
    /// Converts a `u64` input into `Self`.
    fn from_u64(value: u64) -> Self;
    /// Converts an `f32` input into `Self`.
    fn from_f32(value: f32) -> Self;
    /// Converts an `f64` input into `Self`.
    fn from_f64(value: f64) -> Self;
    /// Parses a string input into `Self`.
    ///
    /// # Errors
    ///
    /// Returns an error when `value` cannot be parsed. The implementations in
    /// this module ignore `_` digit separators.
    fn from_str(value: &str) -> Result<Self>;
}

/// Accumulates floating point values of type `F` for a single named field.
#[derive(Debug, Clone)]
pub struct FloatBuilder<F> {
    /// Field name, reported in the produced field metadata and in error context.
    pub name: String,
    /// Values pushed so far.
    array: PrimitiveArray<F>,
    /// Key/value metadata copied into the produced field metadata.
    metadata: HashMap<String, String>,
}

impl<F: FloatPrimitive> FloatBuilder<F> {
    /// Creates a builder for field `name`, backed by `PrimitiveArray::new(is_nullable)`.
    pub fn new(name: String, is_nullable: bool, metadata: HashMap<String, String>) -> Self {
        let array = PrimitiveArray::new(is_nullable);
        Self {
            name,
            array,
            metadata,
        }
    }

    /// Reports whether the backing array was created as nullable.
    pub fn is_nullable(&self) -> bool {
        let backing = &self.array;
        backing.is_nullable()
    }

    /// Reserves capacity for `len` additional values in the backing array.
    pub fn reserve(&mut self, len: usize) {
        let additional = len;
        self.array.reserve(additional);
    }

    /// Moves the accumulated values out via `PrimitiveArray::take` into a new
    /// builder wrapped in [`ArrayBuilder`].
    ///
    /// The name and metadata are cloned, so `self` keeps them for further use.
    pub fn take(&mut self) -> ArrayBuilder {
        let detached = Self {
            name: self.name.clone(),
            metadata: self.metadata.clone(),
            array: self.array.take(),
        };
        F::ARRAY_BUILDER_VARIANT(detached)
    }

    /// Consumes the builder and returns the finished [`Array`] plus its [`FieldMeta`].
    ///
    /// The field's `nullable` flag is taken from the backing array. This body
    /// always returns `Ok`.
    pub fn into_array_and_field_meta(self) -> Result<(Array, FieldMeta)> {
        let Self {
            name,
            array,
            metadata,
        } = self;
        let nullable = array.is_nullable();
        let field_meta = FieldMeta {
            name,
            metadata,
            nullable,
        };
        Ok((F::ARRAY_VARIANT(array), field_meta))
    }

    /// Pushes the array's default scalar, annotating any error with this builder's context.
    pub fn serialize_default_value(&mut self) -> Result<()> {
        let outcome = try_(|| self.array.push_scalar_default());
        outcome.ctx(self)
    }

    /// Serializes `value` through this builder's `Serializer` impl,
    /// annotating any error with this builder's context.
    pub fn serialize_value<V: Serialize>(&mut self, value: V) -> Result<()> {
        let outcome = value.serialize(&mut *self);
        outcome.ctx(self)
    }
}

/// Error context for float builders: records which field and which float
/// data type an error occurred in.
impl<F: FloatPrimitive> Context for FloatBuilder<F> {
    /// Adds `field` (the builder's name) and `data_type` (`F::NAME`) through
    /// `set_default`; presumably existing entries are left untouched — confirm
    /// against `set_default`.
    fn annotate(&self, annotations: &mut BTreeMap<String, String>) {
        set_default(annotations, "field", &self.name);
        set_default(annotations, "data_type", F::NAME);
    }
}

/// Serde entry point: every numeric input is converted to `F` with the
/// matching `FloatPrimitive::from_*` hook and pushed onto the backing array.
/// Strings are parsed with `FloatPrimitive::from_str`. All other serde methods
/// come from `impl_serializer!`.
impl<'a, F: FloatPrimitive> Serializer for &'a mut FloatBuilder<F> {
    impl_serializer!(
        'a, FloatBuilder;
        override serialize_none,
        override serialize_i8,
        override serialize_i16,
        override serialize_i32,
        override serialize_i64,
        override serialize_u8,
        override serialize_u16,
        override serialize_u32,
        override serialize_u64,
        override serialize_f32,
        override serialize_f64,
        override serialize_str,
    );

    /// Forwards `None` to `push_scalar_none` on the backing array.
    fn serialize_none(self) -> Result<()> {
        let target = &mut self.array;
        target.push_scalar_none()
    }

    fn serialize_i8(self, value: i8) -> Result<()> {
        let converted = F::from_i8(value);
        self.array.push_scalar_value(converted)
    }

    fn serialize_i16(self, value: i16) -> Result<()> {
        let converted = F::from_i16(value);
        self.array.push_scalar_value(converted)
    }

    fn serialize_i32(self, value: i32) -> Result<()> {
        let converted = F::from_i32(value);
        self.array.push_scalar_value(converted)
    }

    fn serialize_i64(self, value: i64) -> Result<()> {
        let converted = F::from_i64(value);
        self.array.push_scalar_value(converted)
    }

    fn serialize_u8(self, value: u8) -> Result<()> {
        let converted = F::from_u8(value);
        self.array.push_scalar_value(converted)
    }

    fn serialize_u16(self, value: u16) -> Result<()> {
        let converted = F::from_u16(value);
        self.array.push_scalar_value(converted)
    }

    fn serialize_u32(self, value: u32) -> Result<()> {
        let converted = F::from_u32(value);
        self.array.push_scalar_value(converted)
    }

    fn serialize_u64(self, value: u64) -> Result<()> {
        let converted = F::from_u64(value);
        self.array.push_scalar_value(converted)
    }

    fn serialize_f32(self, value: f32) -> Result<()> {
        let converted = F::from_f32(value);
        self.array.push_scalar_value(converted)
    }

    fn serialize_f64(self, value: f64) -> Result<()> {
        let converted = F::from_f64(value);
        self.array.push_scalar_value(converted)
    }

    /// Parses `value` first; on a parse error nothing is pushed.
    fn serialize_str(self, value: &str) -> Result<()> {
        let parsed = F::from_str(value)?;
        self.array.push_scalar_value(parsed)
    }
}

/// Half precision support: every input is widened to `f64` (or taken as
/// `f32`) and then narrowed with the `half` crate's inherent conversions.
impl FloatPrimitive for f16 {
    const ARRAY_BUILDER_VARIANT: fn(FloatBuilder<Self>) -> ArrayBuilder = ArrayBuilder::F16;
    const ARRAY_VARIANT: fn(PrimitiveArray<Self>) -> Array = Array::Float16;
    const NAME: &'static str = "Float16";

    // Lossless widening via `From` where std provides it.
    fn from_i8(value: i8) -> Self {
        f16::from_f64(f64::from(value))
    }

    fn from_i16(value: i16) -> Self {
        f16::from_f64(f64::from(value))
    }

    fn from_i32(value: i32) -> Self {
        f16::from_f64(f64::from(value))
    }

    // std has no `From<i64> for f64`, so an `as` cast is used.
    fn from_i64(value: i64) -> Self {
        let widened = value as f64;
        f16::from_f64(widened)
    }

    fn from_u8(value: u8) -> Self {
        f16::from_f64(f64::from(value))
    }

    fn from_u16(value: u16) -> Self {
        f16::from_f64(f64::from(value))
    }

    fn from_u32(value: u32) -> Self {
        f16::from_f64(f64::from(value))
    }

    // std has no `From<u64> for f64`, so an `as` cast is used.
    fn from_u64(value: u64) -> Self {
        let widened = value as f64;
        f16::from_f64(widened)
    }

    fn from_f32(value: f32) -> Self {
        f16::from_f32(value)
    }

    fn from_f64(value: f64) -> Self {
        f16::from_f64(value)
    }

    /// Parses as `f64` (ignoring `_` separators) and narrows the result to `f16`.
    fn from_str(value: &str) -> Result<Self> {
        let wide: f64 = parse_float_with_underscores(value)?;
        Ok(f16::from_f64(wide))
    }
}

/// Single precision support: small integers convert losslessly via `From`;
/// wider integers and `f64` use `as`, which rounds to the nearest `f32`.
impl FloatPrimitive for f32 {
    const ARRAY_BUILDER_VARIANT: fn(FloatBuilder<Self>) -> ArrayBuilder = ArrayBuilder::F32;
    const ARRAY_VARIANT: fn(PrimitiveArray<Self>) -> Array = Array::Float32;
    const NAME: &'static str = "Float32";

    fn from_i8(value: i8) -> Self {
        f32::from(value)
    }

    fn from_i16(value: i16) -> Self {
        f32::from(value)
    }

    // No lossless `From<i32>`; `as` rounds large magnitudes.
    fn from_i32(value: i32) -> Self {
        value as Self
    }

    fn from_i64(value: i64) -> Self {
        value as Self
    }

    fn from_u8(value: u8) -> Self {
        f32::from(value)
    }

    fn from_u16(value: u16) -> Self {
        f32::from(value)
    }

    // No lossless `From<u32>`; `as` rounds large magnitudes.
    fn from_u32(value: u32) -> Self {
        value as Self
    }

    fn from_u64(value: u64) -> Self {
        value as Self
    }

    fn from_f32(value: f32) -> Self {
        value
    }

    fn from_f64(value: f64) -> Self {
        value as Self
    }

    /// Parses directly as `f32`, ignoring `_` separators.
    fn from_str(value: &str) -> Result<Self> {
        parse_float_with_underscores::<Self>(value)
    }
}

/// Double precision support: 8/16/32-bit integers and `f32` convert losslessly
/// via `From`; 64-bit integers use `as`, which rounds to the nearest `f64`.
impl FloatPrimitive for f64 {
    const ARRAY_BUILDER_VARIANT: fn(FloatBuilder<Self>) -> ArrayBuilder = ArrayBuilder::F64;
    const ARRAY_VARIANT: fn(PrimitiveArray<Self>) -> Array = Array::Float64;
    const NAME: &'static str = "Float64";

    fn from_i8(value: i8) -> Self {
        f64::from(value)
    }

    fn from_i16(value: i16) -> Self {
        f64::from(value)
    }

    fn from_i32(value: i32) -> Self {
        f64::from(value)
    }

    // No lossless `From<i64>`; `as` rounds magnitudes above 2^53.
    fn from_i64(value: i64) -> Self {
        value as Self
    }

    fn from_u8(value: u8) -> Self {
        f64::from(value)
    }

    fn from_u16(value: u16) -> Self {
        f64::from(value)
    }

    fn from_u32(value: u32) -> Self {
        f64::from(value)
    }

    // No lossless `From<u64>`; `as` rounds magnitudes above 2^53.
    fn from_u64(value: u64) -> Self {
        value as Self
    }

    fn from_f32(value: f32) -> Self {
        f64::from(value)
    }

    fn from_f64(value: f64) -> Self {
        value
    }

    /// Parses directly as `f64`, ignoring `_` separators.
    fn from_str(value: &str) -> Result<Self> {
        parse_float_with_underscores::<Self>(value)
    }
}

/// Parses a floating point number while ignoring every `_` in `value`.
///
/// Input without underscores is parsed as-is. Otherwise the non-underscore
/// bytes are copied into a scratch buffer and that copy is parsed. The buffer
/// is a 64-byte stack array for inputs of at most 64 bytes and a heap `Vec`
/// for longer ones.
///
/// # Errors
///
/// Returns the `ParseFloatError` (converted via `?`) when the sanitized text is
/// not a valid float, e.g. for an empty string or `"_"`.
fn parse_float_with_underscores<T>(value: &str) -> Result<T>
where
    T: std::str::FromStr<Err = std::num::ParseFloatError>,
{
    const STACK_BUFFER_LEN: usize = 64;

    // Fast path: nothing to strip.
    if !value.bytes().any(|byte| byte == b'_') {
        return Ok(value.parse()?);
    }

    // Both storages are declared up front so whichever one is initialized
    // lives as long as `scratch`; only the chosen one is ever initialized.
    let mut inline_storage;
    let mut spilled_storage;
    let scratch: &mut [u8] = if value.len() > STACK_BUFFER_LEN {
        spilled_storage = vec![0_u8; value.len()];
        spilled_storage.as_mut_slice()
    } else {
        inline_storage = [0_u8; STACK_BUFFER_LEN];
        &mut inline_storage[..]
    };

    // `scratch` is at least `value.len()` long, so the filtered bytes always fit.
    let mut written = 0;
    let kept_bytes = value.bytes().filter(|&byte| byte != b'_');
    for (slot, byte) in scratch.iter_mut().zip(kept_bytes) {
        *slot = byte;
        written += 1;
    }

    // `_` is a single ASCII byte, so dropping it keeps the text valid UTF-8.
    match std::str::from_utf8(&scratch[..written]) {
        Ok(sanitized) => Ok(sanitized.parse()?),
        Err(_) => unreachable!("removing _ does not make a string invalid utf8"),
    }
}