ferrite/tensor/storage/
creation.rs1use crate::*;
2
/// Dispatches a `CpuStorage` associated-function call on a `Device` value.
///
/// Two forms are accepted:
///
/// * `match_device!(storage <device>, <fn>(<args>))` evaluates
///   `CpuStorage::<fn>(<args>)` and wraps the result in `Storage::Cpu`.
/// * `match_device!(call <device>, <fn>(<args>))` evaluates
///   `CpuStorage::<fn>(<args>)` and returns the result unwrapped.
///
/// `<device>` is evaluated exactly once, before the call. Only `Device::Cpu`
/// is handled. Any other device panics via `unimplemented!` with the message
/// "Device not supported".
macro_rules! match_device {
    // Same dispatch as `call`, with the CPU result lifted into `Storage`.
    (storage $dev:expr, $($rest:tt)*) => {
        Storage::Cpu(match_device!(call $dev, $($rest)*))
    };

    // Core dispatch: the trailing tokens are spliced after `CpuStorage::`.
    (call $dev:expr, $($rest:tt)*) => {
        match $dev {
            Device::Cpu => CpuStorage::$($rest)*,
            _ => unimplemented!("Device not supported"),
        }
    };
}
21
22impl DeviceStorageCreation for Storage {
23 fn zeros(shape: Vec<usize>, device: Option<Device>, _requires_grad: Option<bool>) -> Self {
24 let device = device.expect("Storage: device must be non-null!");
25 match_device!(storage device, zeros(shape, None, None))
26 }
27
28 fn ones(shape: Vec<usize>, device: Option<Device>, _requires_grad: Option<bool>) -> Self {
29 let device = device.expect("Storage: device must be non-null!");
30 match_device!(storage device, ones(shape, None, None))
31 }
32
33 fn from_ndarray<S, D, T>(data: &ndarray::ArrayBase<S, D>, device: Option<Device>, _requires_grad: Option<bool>) -> Self
34 where
35 S: ndarray::Data<Elem = T>,
36 T: num_traits::AsPrimitive<f32>,
37 D: ndarray::Dimension
38 {
39 let device = device.expect("Storage: device must be non-null!");
40 match_device!(storage device, from_ndarray(data, None, None))
41 }
42
43 fn uniform(l_bound: f32, r_bound: f32, shape: Vec<usize>, device: Option<Device>, _requires_grad: Option<bool>) -> Self {
44 let device = device.expect("Storage: device must be non-null!");
45 match_device!(storage device, uniform(l_bound, r_bound, shape, None, None))
46 }
47}