murr 0.2.0-rc2

Columnar in-memory cache for AI/ML inference workloads
Documentation
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray, PrimitiveBuilder};
use bytemuck::{NoUninit, Pod};

use crate::{
    core::MurrError,
    io::{
        column::{ColumnDecoder, ColumnEncoder, downcast},
        row::{read::ReadRow, write::WriteRow},
        schema::SegmentColumnSchema,
    },
};

pub struct PrimitiveEncoder<T: ArrowPrimitiveType>
where
    T::Native: Pod,
{
    column: SegmentColumnSchema,
    builder: PrimitiveBuilder<T>,
}

impl<T> PrimitiveEncoder<T>
where
    T: ArrowPrimitiveType,
    T::Native: Pod,
{
    pub fn new(column: SegmentColumnSchema, rows: usize) -> Self {
        Self {
            column,
            builder: PrimitiveBuilder::<T>::with_capacity(rows),
        }
    }
}

impl<T> ColumnEncoder for PrimitiveEncoder<T>
where
    T: ArrowPrimitiveType + 'static,
    T::Native: Pod,
{
    fn add_row(&mut self, row: &ReadRow) -> Result<(), MurrError> {
        if row.is_null(&self.column) {
            self.builder.append_null();
        } else {
            self.builder
                .append_value(row.read_static::<T::Native>(&self.column));
        }
        Ok(())
    }

    fn add_empty(&mut self) -> Result<(), MurrError> {
        self.builder.append_null();
        Ok(())
    }

    fn build(&mut self) -> ArrayRef {
        Arc::new(self.builder.finish())
    }
}

pub struct PrimitiveDecoder<T: ArrowPrimitiveType>
where
    T::Native: NoUninit,
{
    column: SegmentColumnSchema,
    array: PrimitiveArray<T>,
}

impl<T> PrimitiveDecoder<T>
where
    T: ArrowPrimitiveType + 'static,
    T::Native: NoUninit,
{
    pub fn new(column: SegmentColumnSchema, array: &dyn Array) -> Result<Self, MurrError> {
        let typed = downcast::<PrimitiveArray<T>>(array, &format!("{:?}", T::DATA_TYPE))?;
        Ok(Self {
            column,
            array: typed.clone(),
        })
    }
}

impl<T> ColumnDecoder for PrimitiveDecoder<T>
where
    T: ArrowPrimitiveType + 'static,
    T::Native: NoUninit,
{
    fn write_to_row(&self, index: usize, row: &mut WriteRow) {
        if !self.array.is_null(index) {
            row.write_static(&self.column, self.array.value(index));
        }
    }
}

#[cfg(test)]
mod tests {
    use arrow::array::{Float32Array, Float64Array, StringArray};
    use arrow::datatypes::{Float32Type, Float64Type};

    use super::*;
    use crate::{
        core::DType,
        io::{row::read::ReadRow, schema::SegmentSchema},
    };

    fn single(dtype: DType) -> (SegmentSchema, SegmentColumnSchema) {
        let c = SegmentColumnSchema {
            index: 0,
            dtype,
            name: "v".into(),
            offset: 0,
        };
        (SegmentSchema::new(std::slice::from_ref(&c)), c)
    }

    #[test]
    fn f32_roundtrip_with_nulls_and_nan() {
        let (schema, c) = single(DType::Float32);
        let input = Float32Array::from(vec![
            Some(1.5_f32),
            None,
            Some(-2.5),
            Some(0.0),
            Some(f32::NAN),
        ]);

        let dec = PrimitiveDecoder::<Float32Type>::new(c.clone(), &input).unwrap();
        let bufs: Vec<Vec<u8>> = (0..input.len())
            .map(|i| {
                let mut w = WriteRow::new(&schema, "");
                dec.write_to_row(i, &mut w);
                w.bytes
            })
            .collect();

        let mut enc = PrimitiveEncoder::<Float32Type>::new(c, input.len());
        for b in &bufs {
            enc.add_row(&ReadRow::new(&schema, b)).unwrap();
        }
        let out_arr = enc.build();
        let out = out_arr.as_any().downcast_ref::<Float32Array>().unwrap();

        assert_eq!(out.len(), input.len());
        for i in 0..input.len() {
            assert_eq!(out.is_null(i), input.is_null(i));
            if !input.is_null(i) {
                let (v, vb) = (input.value(i), out.value(i));
                if v.is_nan() {
                    assert!(vb.is_nan());
                } else {
                    assert_eq!(v, vb);
                }
            }
        }
    }

    #[test]
    fn f64_roundtrip() {
        let (schema, c) = single(DType::Float64);
        let input = Float64Array::from(vec![Some(1.0), Some(-1e10), None, Some(0.0)]);

        let dec = PrimitiveDecoder::<Float64Type>::new(c.clone(), &input).unwrap();
        let bufs: Vec<Vec<u8>> = (0..input.len())
            .map(|i| {
                let mut w = WriteRow::new(&schema, "");
                dec.write_to_row(i, &mut w);
                w.bytes
            })
            .collect();

        let mut enc = PrimitiveEncoder::<Float64Type>::new(c, input.len());
        for b in &bufs {
            enc.add_row(&ReadRow::new(&schema, b)).unwrap();
        }
        let out_arr = enc.build();
        assert_eq!(
            out_arr.as_any().downcast_ref::<Float64Array>().unwrap(),
            &input
        );
    }

    #[test]
    fn decoder_rejects_wrong_array_type() {
        let (_schema, c) = single(DType::Float32);
        let wrong = StringArray::from(vec!["not a float"]);
        let err = PrimitiveDecoder::<Float32Type>::new(c, &wrong);
        assert!(matches!(err, Err(MurrError::SegmentError(_))));
    }
}