// ferrite/tensor/storage/base.rs

1use std::sync::{Arc, RwLock};
2
3use crate::*;
4
5
/// Dispatches a method call on a `Storage` value to the backend it wraps.
///
/// Both forms take the storage expression, a comma, and then an arbitrary
/// token sequence (`$($x:tt)*`) that is spliced verbatim after `backend.`.
/// In practice that sequence is a method call such as `shape()` or
/// `set(indices, value)`.
///
/// * `match_self!(storage s, method(args))` calls the method on the wrapped
///   backend and re-wraps the result in the same `Storage` variant. Use it for
///   methods whose backend result is itself a backend storage (e.g. `view`).
/// * `match_self!(call s, method(args))` calls the method on the wrapped
///   backend and returns its result unchanged.
///
/// `s` may be a `Storage`, `&Storage`, or `&mut Storage`. When it is a
/// reference, the backend is bound by reference through default binding modes.
///
/// # Panics
///
/// Panics via `unimplemented!("Device not supported")` if `s` is a variant
/// that has no arm here.
macro_rules! match_self {
  (storage $self:expr, $($x:tt)*) => {
    match $self {
      Storage::Cpu(cpu) => Storage::Cpu(cpu.$($x)*),
      // Add other devices here.
      // If `Cpu` is currently the only variant (as the note above suggests),
      // this arm is unreachable. The allow keeps each expansion from emitting
      // an `unreachable_patterns` warning.
      #[allow(unreachable_patterns)]
      _ => unimplemented!("Device not supported"),
    }
  };

  (call $self:expr, $($x:tt)*) => {
    match $self {
      Storage::Cpu(cpu) => cpu.$($x)*,
      // Add other devices here.
      // See the `storage` arm for why the lint is allowed.
      #[allow(unreachable_patterns)]
      _ => unimplemented!("Device not supported"),
    }
  };
}
24
25impl DeviceStorage for Storage {
26  fn view(&self, new_shape: Vec<usize>) -> Self where Self: Sized {
27    match_self!(storage self, view(new_shape))
28  }
29
30  fn data(&self) -> Arc<RwLock<Vec<f32>>> {
31    match_self!(call self, data())
32  }
33
34  fn data_mut(&self) -> std::sync::RwLockWriteGuard<Vec<f32>> {
35    match_self!(call self, data_mut())
36  }
37
38  fn set_data(&mut self, data: Vec<f32>) {
39    match_self!(call self, set_data(data));
40  }
41
42  fn shape(&self) -> &Vec<usize> {
43    match_self!(call self, shape())
44  }
45
46  fn set_shape(&mut self, shape: Vec<usize>) {
47    match_self!(call self, set_shape(shape));
48  }
49
50  fn stride(&self) -> &Vec<usize> {
51    match_self!(call self, stride())
52  }
53
54  fn set_stride(&mut self, stride: Vec<usize>) {
55    match_self!(call self, set_stride(stride));
56  }
57
58  fn offset(&self) -> usize {
59    match_self!(call self, offset())
60  }
61
62  fn get(&self, indices: &[usize]) -> f32 {
63    match_self!(call self, get(indices))
64  }
65
66  fn set(&mut self, indices: &[usize], value: f32) {
67    match_self!(call self, set(indices, value));
68  }
69
70  fn make_contiguous(&self) -> (Vec<f32>, i32) {
71    match_self!(call self, make_contiguous())
72  }
73
74  fn is_contiguous(&self) -> bool {
75    match_self!(call self, is_contiguous())
76  }
77}