// npu_rs/tensor.rs

use ndarray::{ArrayD, IxDyn};

/// A simple multi-dimensional tensor for our NPU framework.
///
/// Internally uses [`ndarray::ArrayD<f32>`] for dynamically-ranked
/// (n-dimensional) storage; all elements are `f32`.
///
/// The backing array is public, so callers may operate on `data`
/// directly when the convenience methods on [`Tensor`] are not enough.
#[derive(Clone, Debug)]
pub struct Tensor {
    // Dynamically-dimensioned f32 array; shape is carried by ndarray itself.
    pub data: ArrayD<f32>,
}
9
10impl Tensor {
11    /// Create a new tensor from a Vec and a shape.
12    /// Example: Tensor::new(vec![1.0, 2.0, 3.0], &[3])
13    pub fn new(data: Vec<f32>, shape: &[usize]) -> Self {
14        Self {
15            data: ArrayD::from_shape_vec(IxDyn(shape), data)
16                .expect("Shape does not match data length"),
17        }
18    }
19
20    /// Create a tensor filled with zeros.
21    pub fn zeros(shape: &[usize]) -> Self {
22        Self {
23            data: ArrayD::zeros(IxDyn(shape)),
24        }
25    }
26
27    /// Create a tensor filled with ones.
28    pub fn ones(shape: &[usize]) -> Self {
29        Self {
30            data: ArrayD::from_elem(IxDyn(shape), 1.0),
31        }
32    }
33
34    /// Create a tensor with random values between 0 and 1.
35    pub fn random(shape: &[usize]) -> Self {
36        use rand::distributions::Uniform;
37        use rand::Rng;
38        
39        let size: usize = shape.iter().product();
40        let mut rng = rand::thread_rng();
41        let dist = Uniform::new(0.0, 1.0);
42        let data: Vec<f32> = (0..size).map(|_| rng.sample(&dist)).collect();
43        Self::new(data, shape)
44    }
45
46    /// Create a scalar tensor (0-D tensor).
47    pub fn from_scalar(value: f32) -> Self {
48        Self {
49            data: ArrayD::from_elem(IxDyn(&[]), value),
50        }
51    }
52
53    /// Return the shape of the tensor as a slice.
54    pub fn shape(&self) -> &[usize] {
55        self.data.shape()
56    }
57
58    /// Compute the sum of all elements.
59    pub fn sum(&self) -> f32 {
60        self.data.sum()
61    }
62
63    /// Pretty-print tensor contents.
64    pub fn print(&self) {
65        println!("{:?}", self.data);
66    }
67
68    /// Element-wise addition.
69    pub fn add(&self, other: &Self) -> Self {
70        Self {
71            data: &self.data + &other.data,
72        }
73    }
74
75    /// Element-wise subtraction.
76    pub fn sub(&self, other: &Self) -> Self {
77        Self {
78            data: &self.data - &other.data,
79        }
80    }
81
82    /// Element-wise multiplication.
83    pub fn mul(&self, other: &Self) -> Self {
84        Self {
85            data: &self.data * &other.data,
86        }
87    }
88
89    /// Element-wise division.
90    pub fn div(&self, other: &Self) -> Self {
91        Self {
92            data: &self.data / &other.data,
93        }
94    }
95
96    // === Activation functions ===
97
98    /// ReLU activation function.
99    pub fn relu(&self) -> Self {
100        Self {
101            data: self.data.mapv(|x| if x > 0.0 { x } else { 0.0 }),
102        }
103    }
104
105    /// Sigmoid activation function.
106    pub fn sigmoid(&self) -> Self {
107        Self {
108            data: self.data.mapv(|x| 1.0 / (1.0 + (-x).exp())),
109        }
110    }
111}