lumen-io 0.5.0

A tiny ML framework
Documentation
use std::io::{Read, Write};
use lumen_core::{DType, DynTensor, Shape, Tensor};

pub fn write_tensor<W: Write>(writer: &mut W, tensor: &DynTensor) -> lumen_core::Result<()> {
    match tensor {
        DynTensor::Bool(t) => {
            for b in t.iter()? {
                writer.write_all(&[if b {1} else {0}])?;
            }
        }
        DynTensor::U8(t) => {
            for b in t.iter()? {
                writer.write_all(&[b])?;
            }
        }
        DynTensor::F32(t) => {
            for v in t.iter()? {
                writer.write_all(&f32::to_le_bytes(v))?;
            }
        }
        DynTensor::F64(t) => {
            for v in t.iter()? {
                writer.write_all(&f64::to_le_bytes(v))?;
            }
        }
        DynTensor::U32(t) => {
            for v in t.iter()? {
                writer.write_all(&u32::to_le_bytes(v))?;
            }
        }
        DynTensor::I32(t) => {
            for v in t.iter()? {
                writer.write_all(&i32::to_le_bytes(v))?;
            }
        }
    }

    Ok(())
}

pub fn load_tensor_reader<R: Read>(dtype: DType, shape: impl Into<Shape>, reader: &mut R) -> lumen_core::Result<DynTensor> {
    let shape: Shape = shape.into();
    let element_count = shape.element_count();
    let type_size = dtype.size_of();
    let mut bytes = vec![0u8; element_count * type_size];
    reader.read_exact(&mut bytes)?;
    load_tensor(dtype, shape, &bytes)
}

pub fn load_tensor(dtype: DType, shape: impl Into<Shape>, bytes: &[u8]) -> lumen_core::Result<DynTensor> {
    let shape: Shape = shape.into();
    let element_count = shape.element_count();
    let type_size = dtype.size_of();

    if bytes.len() != element_count * type_size {
        return Err(lumen_core::Error::Msg(format!(
            "Byte length mismatch. Expected {}, got {}", 
            element_count * type_size, 
            bytes.len()
        )))?;
    }

    match dtype {
        DType::Bool => {
            let data: Vec<bool> = bytes.iter().map(|&b| b != 0).collect();
            Ok(DynTensor::Bool(Tensor::<bool>::from_vec(data, shape)?))
        }
        DType::U8 => {
            let data: Vec<u8> = bytes.to_vec();
            Ok(DynTensor::U8(Tensor::<u8>::from_vec(data, shape)?))
        }
        DType::F32 => {
            let data: Vec<f32> = bytes
                .chunks_exact(4) 
                .map(|chunk| {
                    let arr: [u8; 4] = chunk.try_into().unwrap(); 
                    f32::from_le_bytes(arr)
                })
                .collect();
            Ok(DynTensor::F32(Tensor::<f32>::from_vec(data, shape)?))
        }
        DType::F64 => {
            let data: Vec<f64> = bytes
                .chunks_exact(8)
                .map(|chunk| {
                    let arr: [u8; 8] = chunk.try_into().unwrap();
                    f64::from_le_bytes(arr)
                })
                .collect();
            Ok(DynTensor::F64(Tensor::<f64>::from_vec(data, shape)?))
        }
        DType::I32 => {
            let data: Vec<i32> = bytes
                .chunks_exact(4)
                .map(|chunk| {
                    let arr: [u8; 4] = chunk.try_into().unwrap();
                    i32::from_le_bytes(arr)
                })
                .collect();
            Ok(DynTensor::I32(Tensor::<i32>::from_vec(data, shape)?))
        }
        DType::U32 => {
            let data: Vec<u32> = bytes
                .chunks_exact(4)
                .map(|chunk| {
                    let arr: [u8; 4] = chunk.try_into().unwrap();
                    u32::from_le_bytes(arr)
                })
                .collect();
            Ok(DynTensor::U32(Tensor::<u32>::from_vec(data, shape)?))
        }
    }
}