use core::mem;
use core::ptr;
/// Maximum number of dimensions an `NDArray` can describe; `shape` and
/// `strides` are stored inline as fixed-size arrays of this length.
pub const MAX_DIMS: usize = 10;
/// A C-layout view of an N-dimensional array of `T` stored behind a raw
/// byte pointer.
///
/// The type is `Copy`, so it cannot own or free the buffer it points at.
/// Only the first `dims` entries of `shape` and `strides` are meaningful.
/// `strides` are measured in elements, while `size` is measured in bytes.
#[repr(C)]
pub struct NDArray<T> {
    /// Start of the element buffer (null for an array built with `new`).
    data: *mut u8,
    /// Extent of each dimension; entries at index `>= dims` are unused.
    shape: [i64; MAX_DIMS],
    /// Per-dimension step, in elements; entries at index `>= dims` are unused.
    strides: [i64; MAX_DIMS],
    /// Number of meaningful entries in `shape` / `strides`.
    dims: i64,
    /// Total size of the elements in bytes.
    size: i64,
    /// Size of one element in bytes.
    elem_size: i64,
    /// Total number of elements.
    num_elements: i64,
    /// Ties the view to the element type without storing a `T`.
    _phantom: core::marker::PhantomData<T>,
}

// `Clone`, `Copy` and `Debug` are written by hand because `#[derive]` would add
// `T: Clone` / `T: Copy` / `T: Debug` bounds, even though `T` is only used
// through `PhantomData`. Those bounds needlessly made e.g. `NDArray<String>`
// neither `Copy` nor `Debug`.
impl<T> Clone for NDArray<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for NDArray<T> {}

impl<T> core::fmt::Debug for NDArray<T> {
    /// Formats every field exactly as `#[derive(Debug)]` would.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("NDArray")
            .field("data", &self.data)
            .field("shape", &self.shape)
            .field("strides", &self.strides)
            .field("dims", &self.dims)
            .field("size", &self.size)
            .field("elem_size", &self.elem_size)
            .field("num_elements", &self.num_elements)
            .field("_phantom", &self._phantom)
            .finish()
    }
}
impl<T> Default for NDArray<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> NDArray<T> {
    /// Creates an empty array with a null data pointer and zero dimensions.
    ///
    /// `elem_size` is `size_of::<T>()`; `size` and `num_elements` are both 0.
    /// Because the data pointer is null, element accessors such as
    /// [`at`](Self::at) and [`scalar`](Self::scalar) return `None`.
    pub fn new() -> Self {
        NDArray {
            data: ptr::null_mut(),
            shape: [0; MAX_DIMS],
            strides: [0; MAX_DIMS],
            dims: 0,
            size: 0,
            elem_size: mem::size_of::<T>() as i64,
            num_elements: 0,
            _phantom: core::marker::PhantomData,
        }
    }

    /// Builds a contiguous row-major view over `data` with the extents in
    /// `shape_list`. `elem_size` is `size_of::<T>()`.
    ///
    /// Only the first [`MAX_DIMS`] entries of `shape_list` are used; any extra
    /// entries are silently dropped, so the computed strides then describe a
    /// layout with fewer dimensions than the caller's buffer.
    ///
    /// NOTE(review): this is a safe fn that stores an arbitrary pointer which
    /// the safe accessors (`at`, `scalar`, ...) later dereference. Soundness
    /// therefore rests on callers passing a valid, aligned buffer of at least
    /// `num_elements()` initialized values. Consider making it `unsafe`.
    pub fn from_raw(data: *mut T, shape_list: &[i64]) -> Self {
        let dims = shape_list.len().min(MAX_DIMS);
        let mut shape = [0i64; MAX_DIMS];
        shape[..dims].copy_from_slice(&shape_list[..dims]);
        let mut arr = NDArray {
            data: data.cast::<u8>(),
            shape,
            strides: [0; MAX_DIMS],
            dims: dims as i64,
            size: 0,
            elem_size: mem::size_of::<T>() as i64,
            num_elements: 0,
            _phantom: core::marker::PhantomData,
        };
        arr.calculate_size_and_strides();
        arr
    }

    /// Builds a contiguous row-major view over a byte buffer whose elements are
    /// `elem_size` bytes each.
    ///
    /// `shape` is zero-padded, so only *trailing* zeros are treated as padding.
    /// The dimension count is one past the last non-zero entry. A zero before
    /// that point is kept as a genuine zero-length dimension, which makes the
    /// array empty. An all-zero `shape` yields a zero-dimensional (scalar) array.
    ///
    /// The same soundness caveat as [`from_raw`](Self::from_raw) applies.
    /// Additionally, `elem_size` should equal `size_of::<T>()` if the typed
    /// accessors are used.
    pub fn from_bytes(data: *mut u8, shape: [i64; MAX_DIMS], elem_size: i64) -> Self {
        let dims = shape
            .iter()
            .rposition(|&extent| extent != 0)
            .map_or(0, |last| last + 1);
        let mut arr = NDArray {
            data,
            shape,
            strides: [0; MAX_DIMS],
            dims: dims as i64,
            size: 0,
            elem_size,
            num_elements: 0,
            _phantom: core::marker::PhantomData,
        };
        arr.calculate_size_and_strides();
        arr
    }

    /// Recomputes `num_elements`, `size` (bytes) and contiguous row-major
    /// element strides from `shape[..dims]` and `elem_size`.
    fn calculate_size_and_strides(&mut self) {
        let dims = self.dims as usize;
        // Count elements directly instead of dividing the byte size by
        // `elem_size`, which panicked for zero-sized element types.
        self.num_elements = self.shape[..dims].iter().product();
        self.size = self.num_elements * self.elem_size;
        // The innermost (last) dimension is contiguous; each outer stride is the
        // product of all inner extents.
        let mut stride = 1;
        for (slot, &extent) in self.strides[..dims]
            .iter_mut()
            .zip(&self.shape[..dims])
            .rev()
        {
            *slot = stride;
            stride *= extent;
        }
    }

    /// Typed pointer to the start of the element buffer (may be null).
    pub fn data(&self) -> *const T {
        self.data.cast()
    }

    /// Mutable typed pointer to the start of the element buffer (may be null).
    pub fn data_mut(&self) -> *mut T {
        self.data.cast()
    }

    /// Full fixed-size extent array; only the first `dims()` entries are used.
    pub fn shape(&self) -> &[i64; MAX_DIMS] {
        &self.shape
    }

    /// Full fixed-size stride array (in elements); only the first `dims()`
    /// entries are used.
    pub fn strides(&self) -> &[i64; MAX_DIMS] {
        &self.strides
    }

    /// Number of dimensions.
    pub fn dims(&self) -> i64 {
        self.dims
    }

    /// Total size of the elements in bytes (`num_elements() * elem_size()`).
    pub fn size(&self) -> i64 {
        self.size
    }

    /// Size of one element in bytes.
    pub fn elem_size(&self) -> i64 {
        self.elem_size
    }

    /// Number of elements. This is the product of the extents: 1 for a
    /// zero-dimensional array built by `from_raw`/`from_bytes`, and 0 for `new()`.
    pub fn num_elements(&self) -> i64 {
        self.num_elements
    }

    /// Element (not byte) offset of `idxs` from the start of the buffer.
    ///
    /// Panics if `idxs` has fewer than `dims` entries; extra entries are ignored.
    fn compute_offset(&self, idxs: &[i64]) -> i64 {
        let dims = self.dims as usize;
        idxs[..dims]
            .iter()
            .zip(&self.strides[..dims])
            .map(|(&index, &stride)| index * stride)
            .sum()
    }

    /// Returns `true` when `idxs` addresses an existing element. That requires
    /// a non-null buffer, exactly one index per dimension, and every index in
    /// `0..shape[i]`.
    fn in_bounds(&self, idxs: &[i64]) -> bool {
        !self.data.is_null()
            && idxs.len() == self.dims as usize
            && idxs
                .iter()
                .zip(&self.shape)
                .all(|(&index, &extent)| (0..extent).contains(&index))
    }

    /// Returns a reference to the element at `idxs` without bounds checks.
    ///
    /// # Panics
    ///
    /// Panics if `idxs` has fewer than `dims()` entries.
    ///
    /// # Safety
    ///
    /// The caller must ensure that:
    /// - every index lies within the corresponding extent of `shape()`;
    /// - the data pointer is non-null, aligned for `T`, and valid for reads of
    ///   the addressed element, which must be an initialized `T`;
    /// - `elem_size()` equals the in-memory size of `T` (relevant for arrays
    ///   built with `from_bytes`);
    /// - no mutable reference to the same element is alive while the returned
    ///   reference is. Copies of this view share the same buffer.
    pub unsafe fn get_unchecked(&self, idxs: &[i64]) -> &T {
        let byte_offset = self.compute_offset(idxs) as usize * self.elem_size as usize;
        // SAFETY: the caller guarantees (see `# Safety`) that `byte_offset` stays
        // inside a live, aligned buffer holding an initialized `T`.
        unsafe { &*self.data.add(byte_offset).cast::<T>() }
    }

    /// Returns a mutable reference to the element at `idxs` without bounds
    /// checks.
    ///
    /// # Panics
    ///
    /// Panics if `idxs` has fewer than `dims()` entries.
    ///
    /// # Safety
    ///
    /// Same requirements as [`get_unchecked`](Self::get_unchecked), plus the
    /// buffer must be valid for writes. No other reference to the same element
    /// may be alive while the returned one is, including references obtained
    /// through copies of this view.
    pub unsafe fn get_unchecked_mut(&mut self, idxs: &[i64]) -> &mut T {
        let byte_offset = self.compute_offset(idxs) as usize * self.elem_size as usize;
        // SAFETY: the caller guarantees (see `# Safety`) that `byte_offset` stays
        // inside a live, aligned, writable buffer holding an initialized `T`.
        unsafe { &mut *self.data.add(byte_offset).cast::<T>() }
    }

    /// Bounds-checked element access.
    ///
    /// Returns `None` if the buffer is null, if `idxs.len() != dims()`, or if
    /// any index is negative or not less than its extent.
    pub fn at(&self, idxs: &[i64]) -> Option<&T> {
        if !self.in_bounds(idxs) {
            return None;
        }
        // SAFETY: `in_bounds` verified a non-null buffer, the index count and
        // every index against its extent. Validity and alignment of the buffer
        // are the constructor caller's responsibility (see `from_raw`).
        Some(unsafe { self.get_unchecked(idxs) })
    }

    /// Bounds-checked mutable element access; returns `None` under the same
    /// conditions as [`at`](Self::at).
    pub fn at_mut(&mut self, idxs: &[i64]) -> Option<&mut T> {
        if !self.in_bounds(idxs) {
            return None;
        }
        // SAFETY: as in `at`, with the buffer additionally assumed writable.
        Some(unsafe { self.get_unchecked_mut(idxs) })
    }

    /// Returns the single element of a zero-dimensional array.
    ///
    /// Returns `None` for arrays with dimensions or a null buffer.
    pub fn scalar(&self) -> Option<&T> {
        // A zero-dimensional array is addressed by the empty index list.
        self.at(&[])
    }

    /// Mutable counterpart of [`scalar`](Self::scalar).
    pub fn scalar_mut(&mut self) -> Option<&mut T> {
        self.at_mut(&[])
    }
}