//! mu_lib 0.2.2
//!
//! XCENA mu Library documentation.
//! Equivalent to `mu::ndarray::NDArray<T>` in C++
//! Same memory layout and API patterns

use core::mem;
use core::ptr;

/// Maximum number of dimensions an `NDArray` can describe; shape and
/// stride arrays are fixed at this length, with unused slots zeroed.
pub const MAX_DIMS: usize = 10;

/// N-dimensional array view over a raw, non-owning data pointer.
///
/// `#[repr(C)]` keeps the field layout compatible with the C++
/// `mu::ndarray::NDArray<T>` counterpart (see module docs) — do not
/// reorder, add, or resize fields.
///
/// NOTE(review): the derives add `T: Copy`/`Clone`/`Debug` bounds even
/// though only `PhantomData<T>` is stored; manual impls would lift that
/// restriction if non-`Copy` element types are ever needed.
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct NDArray<T> {
    /// Raw byte pointer to element storage (null for a default array).
    data: *mut u8,
    /// Extent of each dimension; entries past `dims` are zero padding.
    shape: [i64; MAX_DIMS],
    /// Row-major strides measured in elements; entries past `dims` unused.
    strides: [i64; MAX_DIMS],
    /// Number of dimensions (rank); 0 denotes a scalar.
    dims: i64,
    /// Total size of the data in bytes.
    size: i64,
    /// Size of one element in bytes.
    elem_size: i64,
    /// Total number of elements (product of the used shape entries).
    num_elements: i64,
    /// Ties the untyped `data` pointer to `T` without storing a `T`.
    _phantom: core::marker::PhantomData<T>,
}

impl<T> Default for NDArray<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> NDArray<T> {
    /// Creates an empty 0-dimensional array with a null data pointer.
    ///
    /// `elem_size` is taken from `size_of::<T>()`; everything else is
    /// zeroed. Accessors like [`NDArray::scalar`] and [`NDArray::at`]
    /// return `None` until a real data pointer is attached.
    pub fn new() -> Self {
        NDArray {
            data: ptr::null_mut(),
            shape: [0; MAX_DIMS],
            strides: [0; MAX_DIMS],
            dims: 0,
            size: 0,
            elem_size: mem::size_of::<T>() as i64,
            num_elements: 0,
            _phantom: core::marker::PhantomData,
        }
    }

    /// Builds an array view over `data` using the runtime shape `shape_list`.
    ///
    /// At most [`MAX_DIMS`] leading entries of `shape_list` are used; extra
    /// entries are silently dropped. Strides (row-major), byte size, and
    /// the element count are derived from the shape.
    pub fn from_raw(data: *mut T, shape_list: &[i64]) -> Self {
        let mut shape = [0i64; MAX_DIMS];
        // Truncate to MAX_DIMS; dimensions beyond that are ignored.
        let dims = shape_list.len().min(MAX_DIMS);
        shape[..dims].copy_from_slice(&shape_list[..dims]);

        let mut arr = NDArray {
            data: data as *mut u8,
            shape,
            strides: [0; MAX_DIMS],
            dims: dims as i64,
            size: 0,
            elem_size: mem::size_of::<T>() as i64,
            num_elements: 0,
            _phantom: core::marker::PhantomData,
        };
        arr.calculate_size_and_strides();
        arr
    }

    /// Builds an array view over raw bytes with an explicit shape array and
    /// element size.
    ///
    /// The rank is inferred by counting the non-zero entries of `shape`
    /// (zero entries are treated as unused padding).
    /// NOTE(review): a legitimately empty dimension (a 0 in the middle of
    /// the shape) is indistinguishable from padding here — confirm this
    /// matches the C++ original.
    pub fn from_bytes(data: *mut u8, shape: [i64; MAX_DIMS], elem_size: i64) -> Self {
        // Count used dimensions; unused trailing slots are expected to be 0.
        let dims = shape.iter().filter(|&&s| s != 0).count() as i64;
        let mut arr = NDArray {
            data,
            shape,
            strides: [0; MAX_DIMS],
            dims,
            size: 0,
            elem_size,
            num_elements: 0,
            _phantom: core::marker::PhantomData,
        };
        arr.calculate_size_and_strides();
        arr
    }

    /// Recomputes `num_elements`, `size` (bytes), and row-major `strides`
    /// from the current `shape`, `dims`, and `elem_size`.
    fn calculate_size_and_strides(&mut self) {
        // Element count first, then bytes. The original computed bytes and
        // divided by elem_size to get the count, which panics
        // (divide-by-zero) for zero-sized T; this form is identical for
        // elem_size > 0 (a 0-dim array still yields num_elements == 1).
        let mut count: i64 = 1;
        for i in 0..self.dims as usize {
            count *= self.shape[i];
        }
        self.num_elements = count;
        self.size = count * self.elem_size;

        // Row-major (C-order) strides, measured in elements: the last
        // dimension is contiguous.
        let mut stride = 1;
        for i in (0..self.dims as usize).rev() {
            self.strides[i] = stride;
            stride *= self.shape[i];
        }
    }

    /// Immutable data pointer as `*const T`.
    pub fn data(&self) -> *const T {
        self.data as *const T
    }
    /// Mutable data pointer as `*mut T`.
    ///
    /// Takes `&self` because mutability lives in the raw pointer, not the
    /// Rust borrow; callers are responsible for upholding aliasing rules.
    pub fn data_mut(&self) -> *mut T {
        self.data as *mut T
    }

    /// Shape array; only the first `dims()` entries are meaningful.
    pub fn shape(&self) -> &[i64; MAX_DIMS] {
        &self.shape
    }
    /// Row-major strides in elements; only the first `dims()` entries are
    /// meaningful.
    pub fn strides(&self) -> &[i64; MAX_DIMS] {
        &self.strides
    }
    /// Number of dimensions (rank).
    pub fn dims(&self) -> i64 {
        self.dims
    }
    /// Total size in bytes (`num_elements * elem_size`).
    pub fn size(&self) -> i64 {
        self.size
    }
    /// Size of one element in bytes.
    pub fn elem_size(&self) -> i64 {
        self.elem_size
    }
    /// Total number of elements (product of the used shape entries).
    pub fn num_elements(&self) -> i64 {
        self.num_elements
    }

    /// Computes the linear element offset for `idxs`.
    ///
    /// Panics if `idxs` holds fewer than `dims` entries; performs no range
    /// checking on the index values themselves.
    fn compute_offset(&self, idxs: &[i64]) -> i64 {
        let mut offset = 0;
        for i in 0..(self.dims as usize) {
            offset += idxs[i] * self.strides[i];
        }
        offset
    }

    /// Index without bounds checking.
    ///
    /// # Safety
    /// `data` must be non-null and valid for reads, `idxs` must contain at
    /// least `dims` in-range indices, and the addressed element must be a
    /// properly initialized `T`.
    pub unsafe fn get_unchecked(&self, idxs: &[i64]) -> &T {
        let offset = self.compute_offset(idxs);
        &*self
            .data
            .add(offset as usize * self.elem_size as usize)
            .cast()
    }

    /// Mutable index without bounds checking.
    ///
    /// # Safety
    /// Same requirements as [`NDArray::get_unchecked`], plus `data` must be
    /// valid for writes.
    pub unsafe fn get_unchecked_mut(&mut self, idxs: &[i64]) -> &mut T {
        let offset = self.compute_offset(idxs);
        &mut *self
            .data
            .add(offset as usize * self.elem_size as usize)
            .cast()
    }

    /// Index with bounds checking; returns `None` when the data pointer is
    /// null, the index count differs from `dims`, or any index is out of
    /// range.
    pub fn at(&self, idxs: &[i64]) -> Option<&T> {
        // FIX: a default-constructed array has dims == 0 and a null data
        // pointer; without the null guard, `at(&[])` passed all checks and
        // dereferenced null from safe code (UB).
        if self.data.is_null() || idxs.len() != self.dims as usize {
            return None;
        }
        for i in 0..idxs.len() {
            if idxs[i] < 0 || idxs[i] >= self.shape[i] {
                return None;
            }
        }
        // SAFETY: pointer is non-null and every index was range-checked.
        unsafe { Some(self.get_unchecked(idxs)) }
    }

    /// Mutable bounds-checked indexing; see [`NDArray::at`].
    pub fn at_mut(&mut self, idxs: &[i64]) -> Option<&mut T> {
        // FIX: null guard, same soundness hole as `at`.
        if self.data.is_null() || idxs.len() != self.dims as usize {
            return None;
        }
        for i in 0..idxs.len() {
            if idxs[i] < 0 || idxs[i] >= self.shape[i] {
                return None;
            }
        }
        // SAFETY: pointer is non-null and every index was range-checked.
        unsafe { Some(self.get_unchecked_mut(idxs)) }
    }

    /// Scalar access for 0-dimensional arrays.
    ///
    /// Returns `None` for non-scalars or when no data is attached (null
    /// pointer, e.g. a default-constructed array).
    pub fn scalar(&self) -> Option<&T> {
        if self.dims == 0 && !self.data.is_null() {
            // SAFETY: non-null pointer to a single element of type T.
            unsafe { Some(&*(self.data as *const T)) }
        } else {
            None
        }
    }

    /// Mutable scalar access for 0-dimensional arrays; see [`NDArray::scalar`].
    pub fn scalar_mut(&mut self) -> Option<&mut T> {
        if self.dims == 0 && !self.data.is_null() {
            // SAFETY: non-null pointer to a single element of type T.
            unsafe { Some(&mut *(self.data as *mut T)) }
        } else {
            None
        }
    }
}