ferrite/tensor/storage/creation.rs

1use crate::*;
2
/// Dispatches a storage operation to the backend that matches a `Device`.
///
/// Two forms are accepted. Each takes the device expression first, then a
/// comma, then the remaining tokens, which are pasted verbatim after the
/// `CpuStorage::` path (typically `method(args...)`):
///
/// * `match_device!(storage device, method(args...))` expands to a call of
///   `CpuStorage::method(args...)` wrapped in `Storage::Cpu(..)`.
/// * `match_device!(call device, method(args...))` expands to a call of
///   `CpuStorage::method(args...)` whose result is returned unwrapped.
///
/// All paths are `$crate`-qualified, so the expansion does not depend on
/// `Device`, `Storage` or `CpuStorage` being imported at the call site.
///
/// # Panics
///
/// Panics via `unimplemented!` for any device that has no arm below.
macro_rules! match_device {
  // `$device` is evaluated exactly once, as the `match` scrutinee.
  // `$($x:tt)*` captures everything after the comma as raw token trees, so
  // arbitrary `method(args...)` calls (including generic ones) pass through.
  (storage $device:expr, $($x:tt)*) => {
    match $device {
      $crate::Device::Cpu => $crate::Storage::Cpu($crate::CpuStorage::$($x)*),
      // Add other devices here
      _ => unimplemented!("Device not supported"),
    }
  };

  // Same dispatch, but the backend result is returned as-is (not wrapped).
  (call $device:expr, $($x:tt)*) => {
    match $device {
      $crate::Device::Cpu => $crate::CpuStorage::$($x)*,
      // Add other devices here
      _ => unimplemented!("Device not supported"),
    }
  };
}
21
22impl DeviceStorageCreation for Storage {
23  fn zeros(shape: Vec<usize>, device: Option<Device>, _requires_grad: Option<bool>) -> Self {
24    let device = device.expect("Storage: device must be non-null!");
25    match_device!(storage device, zeros(shape, None, None))
26  }
27
28  fn ones(shape: Vec<usize>, device: Option<Device>, _requires_grad: Option<bool>) -> Self {
29    let device = device.expect("Storage: device must be non-null!");
30    match_device!(storage device, ones(shape, None, None))
31  }
32
33  fn from_ndarray<S, D, T>(data: &ndarray::ArrayBase<S, D>, device: Option<Device>, _requires_grad: Option<bool>) -> Self
34  where 
35    S: ndarray::Data<Elem = T>,
36    T: num_traits::AsPrimitive<f32>,
37    D: ndarray::Dimension 
38  {
39    let device = device.expect("Storage: device must be non-null!");
40    match_device!(storage device, from_ndarray(data, None, None))
41  }
42
43  fn uniform(l_bound: f32, r_bound: f32, shape: Vec<usize>, device: Option<Device>, _requires_grad: Option<bool>) -> Self {
44    let device = device.expect("Storage: device must be non-null!");
45    match_device!(storage device, uniform(l_bound, r_bound, shape, None, None))
46  }
47}