use core::fmt;
pub const MAX_DIMS: usize = 4;
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct Shape {
dims: [usize; MAX_DIMS],
ndim: u8,
}
impl Shape {
pub fn new(dims: &[usize]) -> Self {
assert!(
dims.len() <= MAX_DIMS,
"wasmicro::Shape supports up to {} dimensions, got {}",
MAX_DIMS,
dims.len()
);
let mut out = [0usize; MAX_DIMS];
out[..dims.len()].copy_from_slice(dims);
Self {
dims: out,
ndim: dims.len() as u8,
}
}
#[inline]
pub fn ndim(&self) -> usize {
self.ndim as usize
}
#[inline]
pub fn as_slice(&self) -> &[usize] {
&self.dims[..self.ndim()]
}
#[inline]
pub fn numel(&self) -> usize {
self.as_slice().iter().product()
}
}
impl fmt::Debug for Shape {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self.as_slice())
}
}
#[derive(Clone)]
pub struct Tensor {
data: Vec<f32>,
shape: Shape,
}
impl Tensor {
pub fn from_vec(data: Vec<f32>, shape: &[usize]) -> Self {
let shape = Shape::new(shape);
assert_eq!(
data.len(),
shape.numel(),
"data length {} does not match shape {:?} (expected {} elements)",
data.len(),
shape.as_slice(),
shape.numel()
);
Self { data, shape }
}
pub fn zeros(shape: &[usize]) -> Self {
let shape = Shape::new(shape);
Self {
data: vec![0.0; shape.numel()],
shape,
}
}
#[inline]
pub fn data(&self) -> &[f32] {
&self.data
}
#[inline]
pub fn data_mut(&mut self) -> &mut [f32] {
&mut self.data
}
#[inline]
pub fn shape(&self) -> &Shape {
&self.shape
}
#[inline]
pub fn numel(&self) -> usize {
self.shape.numel()
}
}
impl fmt::Debug for Tensor {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Tensor")
.field("shape", &self.shape)
.field("numel", &self.numel())
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn shape_basic() {
let s = Shape::new(&[2, 3, 4]);
assert_eq!(s.ndim(), 3);
assert_eq!(s.as_slice(), &[2, 3, 4]);
assert_eq!(s.numel(), 24);
}
#[test]
fn tensor_zeros() {
let t = Tensor::zeros(&[2, 3]);
assert_eq!(t.numel(), 6);
assert_eq!(t.data(), &[0.0; 6]);
}
#[test]
fn tensor_from_vec() {
let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
assert_eq!(t.shape().as_slice(), &[2, 2]);
assert_eq!(t.data(), &[1.0, 2.0, 3.0, 4.0]);
}
#[test]
#[should_panic]
fn tensor_shape_mismatch_panics() {
let _ = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[2, 2]);
}
}