ferrite/tensor/device/cpu/storage/utils.rs

1use crate::*;  // Import from parent module's base.rs
2
3use std::fmt;
4
/// Text-rendering helpers for tensor storage.
///
/// Note: despite the name, this trait is unrelated to `std::fmt::Display`.
pub trait Display {
  /// Prints a summary of the storage to stdout.
  fn print(&self);

  /// Prints the storage's elements to stdout.
  fn print_data(&self);

  /// Renders `data`, interpreted through `shape` and `stride`, and returns the text.
  ///
  /// This is an associated function (no `self`), so it can be applied to any
  /// slice/shape/stride triple.
  fn print_data_recursive<'a>(data: &'a [f32], shape: &'a [usize], stride: &'a [usize]) -> String;
}
10
11impl fmt::Display for CpuStorage {
12  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13    write!(f, "{}", Self::print_data_recursive(&self.data().read().unwrap(), self.shape(), self.stride()))
14  }
15}
16
17impl fmt::Debug for CpuStorage {
18  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19    write!(f, "{}", Self::print_data_recursive(&self.data().read().unwrap(), self.shape(), self.stride()))
20  }
21}
22
23impl Display for CpuStorage {
24  fn print(&self) {
25    println!("Data: {:?}", self.data());
26    println!("Shape: {:?}", self.shape());
27    println!("Strides: {:?}", self.stride());
28  }
29
30  fn print_data_recursive<'a>(data: &'a [f32], shape: &'a [usize], stride: &'a [usize]) -> String {
31    let mut res = String::new();
32    res += "[";
33    if shape.len() == 1 {
34      for i in 0..shape[0] {
35        res += &format!("{}", data[i*stride[0]]);
36
37        if i < shape[0] - 1 {
38          res += ", ";
39        }
40        
41      }
42    } else {
43      for i in 0..shape[0] {
44        let start = i*stride[0];
45        let sub_res = Self::print_data_recursive(&data[start..], &shape[1..], &stride[1..]);
46        res += &sub_res;
47        if i < shape[0] - 1 {
48          res += ", ";
49        }
50      }
51    }
52
53    res += "]";
54    res
55  }
56
57  fn print_data(&self) {
58    let res = Self::print_data_recursive(&self.data().read().unwrap(), self.shape(), self.stride());
59    println!("{}", res);
60  }
61}