1use core::fmt;
2
/// A mutable, borrowed view of a dense rank-5 `f32` tensor.
///
/// Elements are addressed by an `(n, c, d, h, w)` coordinate and laid out in
/// row-major order: `w` varies fastest and `n` slowest (see
/// [`TensorView::idx_linear`] for the exact offset formula).
pub struct TensorView<'a> {
    /// Backing storage. For a well-formed view its length equals the product
    /// of `shape`; check this with [`TensorView::is_valid_layout`].
    pub data: &'a mut [f32],
    /// Extent of each axis, in `[n, c, d, h, w]` order.
    pub shape: [usize; 5],
}

impl<'a> TensorView<'a> {
    /// Returns `true` when `data` holds exactly as many elements as `shape`
    /// describes.
    ///
    /// A shape whose element count overflows `usize` is never valid, because
    /// [`TensorView::len`] then saturates to `usize::MAX`, which no `f32`
    /// slice length can reach.
    #[must_use]
    pub fn is_valid_layout(&self) -> bool {
        self.len() == self.data.len()
    }

    /// Number of elements described by `shape`, i.e. the product of its extents.
    ///
    /// This is computed from `shape`, not from `data`; the two differ when the
    /// layout is invalid. On overflow the product saturates at `usize::MAX`.
    #[must_use]
    pub fn len(&self) -> usize {
        // Saturating (not checked) multiplication keeps the result exact when
        // it fits, and still yields 0 if any axis is empty, even if an earlier
        // partial product already overflowed (`usize::MAX * 0 == 0`).
        self.shape
            .iter()
            .fold(1usize, |acc, &dim| acc.saturating_mul(dim))
    }

    /// Returns `true` when `shape` describes zero elements (some axis has
    /// extent 0).
    ///
    /// Companion to [`TensorView::len`]; like `len`, it looks only at `shape`.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Converts an `(n, c, d, h, w)` coordinate into an offset into `data`.
    ///
    /// With `shape == [N, C, D, H, W]` the offset is
    /// `(((n * C + c) * D + d) * H + h) * W + w`.
    ///
    /// Returns `None` if:
    /// - any coordinate is out of range for its axis,
    /// - the offset computation overflows `usize`, or
    /// - the offset falls outside `data` (possible when the layout is invalid).
    #[must_use]
    pub fn idx_linear(&self, n: usize, c: usize, d: usize, h: usize, w: usize) -> Option<usize> {
        let [n_dim, c_dim, d_dim, h_dim, w_dim] = self.shape;
        if n >= n_dim || c >= c_dim || d >= d_dim || h >= h_dim || w >= w_dim {
            return None;
        }
        // Horner-style accumulation. The checked ops reject shapes whose
        // element count does not fit in `usize` instead of wrapping silently.
        let idx = n
            .checked_mul(c_dim)?
            .checked_add(c)?
            .checked_mul(d_dim)?
            .checked_add(d)?
            .checked_mul(h_dim)?
            .checked_add(h)?
            .checked_mul(w_dim)?
            .checked_add(w)?;
        if idx < self.data.len() {
            Some(idx)
        } else {
            None
        }
    }

    /// Reads the element at `(n, c, d, h, w)`.
    ///
    /// Returns `None` whenever [`TensorView::idx_linear`] does.
    #[must_use]
    pub fn get(&self, n: usize, c: usize, d: usize, h: usize, w: usize) -> Option<f32> {
        let i = self.idx_linear(n, c, d, h, w)?;
        // `idx_linear` already bounds-checks against `data`; `.get` keeps this
        // path panic-free without relying on that invariant.
        self.data.get(i).copied()
    }

    /// Mutable reference to the element at `(n, c, d, h, w)`.
    ///
    /// Returns `None` whenever [`TensorView::idx_linear`] does.
    pub fn get_mut(&mut self, n: usize, c: usize, d: usize, h: usize, w: usize) -> Option<&mut f32> {
        let i = self.idx_linear(n, c, d, h, w)?;
        self.data.get_mut(i)
    }

    /// Raw const pointer to the first element of `data`.
    ///
    /// Dereferencing it requires `unsafe`. It addresses `data.len()` elements
    /// and must not outlive the borrow held by this view.
    pub fn as_ptr(&self) -> *const f32 {
        self.data.as_ptr()
    }

    /// Raw mutable pointer to the first element of `data`.
    ///
    /// Dereferencing it requires `unsafe`. It addresses `data.len()` elements
    /// and must not outlive the borrow held by this view.
    pub fn as_mut_ptr(&mut self) -> *mut f32 {
        self.data.as_mut_ptr()
    }
}
48
49impl<'a> fmt::Debug for TensorView<'a> {
50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51 f.debug_struct("TensorView").field("shape", &self.shape).finish()
52 }
53}