Skip to main content

rnn/tensor/
tensor.rs

1use core::fmt;
2
/// A mutable, non-owning view over a flat `f32` buffer, addressed as a
/// five-dimensional tensor.
///
/// Both fields are public, so nothing at construction time forces `shape` to
/// agree with `data`; `is_valid_layout` reports whether the product of `shape`
/// equals `data.len()`.
pub struct TensorView<'buf> {
    /// Backing storage, borrowed mutably for as long as the view exists.
    pub data: &'buf mut [f32],
    /// Extent of each of the five dimensions.
    pub shape: [usize; 5],
}
7
8impl<'a> TensorView<'a> {
9    pub fn is_valid_layout(&self) -> bool {
10        self.len() == self.data.len()
11    }
12
13    pub fn len(&self) -> usize {
14        let mut n = 1usize;
15        for &d in &self.shape { n = n.saturating_mul(d); }
16        n
17    }
18
19    pub fn idx_linear(&self, n:usize,c:usize,d:usize,h:usize,w:usize) -> Option<usize> {
20        let [n_dim,c_dim,d_dim,h_dim,w_dim] = self.shape;
21        if n>=n_dim||c>=c_dim||d>=d_dim||h>=h_dim||w>=w_dim { return None; }
22        let idx = n
23            .checked_mul(c_dim)?
24            .checked_add(c)?
25            .checked_mul(d_dim)?
26            .checked_add(d)?
27            .checked_mul(h_dim)?
28            .checked_add(h)?
29            .checked_mul(w_dim)?
30            .checked_add(w)?;
31        if idx < self.data.len() { Some(idx) } else { None }
32    }
33
34    pub fn get(&self, n:usize,c:usize,d:usize,h:usize,w:usize) -> Option<f32> {
35        let i = self.idx_linear(n,c,d,h,w)?;
36        Some(self.data[i])
37    }
38
39    pub fn get_mut(&mut self, n:usize,c:usize,d:usize,h:usize,w:usize) -> Option<&mut f32> {
40        let i = self.idx_linear(n,c,d,h,w)?;
41        Some(&mut self.data[i])
42    }
43
44    pub fn as_ptr(&self) -> *const f32 { self.data.as_ptr() }
45
46    pub fn as_mut_ptr(&mut self) -> *mut f32 { self.data.as_mut_ptr() }
47}
48
49impl<'a> fmt::Debug for TensorView<'a> {
50    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51        f.debug_struct("TensorView").field("shape", &self.shape).finish()
52    }
53}