rstsr_openblas/
conversion.rs

1use crate::prelude_dev::*;
2
/// Generates a [`DeviceChangeAPI`] implementation that re-homes a tensor from
/// device `$DevA` onto device `$DevB`.
///
/// The generated impl is restricted to data representations whose underlying
/// data is a `Vec<T>` (`R: DataCloneAPI<Data = Vec<T>>`). Switching devices
/// moves the existing data representation into a new storage; only the device
/// handle attached to that storage changes.
macro_rules! impl_change_device {
    ($DevA: ty, $DevB: ty) => {
        impl<'a, R, T, D> DeviceChangeAPI<'a, $DevB, R, T, D> for $DevA
        where
            R: DataCloneAPI<Data = Vec<T>>,
            T: Clone + Send + Sync + 'a,
            D: DimAPI,
        {
            type Repr = R;
            type ReprTo = DataRef<'a, Vec<T>>;

            /// Consumes `tensor` and rebuilds it on `device`, keeping the same
            /// data representation and layout.
            fn change_device(
                tensor: TensorAny<R, T, $DevA, D>,
                device: &$DevB,
            ) -> Result<TensorAny<Self::Repr, T, $DevB, D>> {
                // Take the tensor apart; the source device handle is discarded.
                let (old_storage, layout) = tensor.into_raw_parts();
                let (data, _) = old_storage.into_raw_parts();
                // Reassemble around a clone of the target device handle.
                Ok(TensorAny::new(Storage::new(data, device.clone()), layout))
            }

            /// Converts `tensor` into an owned tensor first, then moves it to
            /// `device` via `change_device`.
            fn into_device(
                tensor: TensorAny<R, T, $DevA, D>,
                device: &$DevB,
            ) -> Result<TensorAny<DataOwned<Vec<T>>, T, $DevB, D>> {
                DeviceChangeAPI::change_device(tensor.into_owned(), device)
            }

            /// Takes a view of `tensor` and attaches that view to `device`
            /// via `change_device`.
            fn to_device(
                tensor: &'a TensorAny<R, T, $DevA, D>,
                device: &$DevB,
            ) -> Result<TensorView<'a, T, $DevB, D>> {
                DeviceChangeAPI::change_device(tensor.view(), device)
            }
        }
    };
}
40
41impl_change_device!(DeviceCpuSerial, DeviceBLAS);
42impl_change_device!(DeviceBLAS, DeviceCpuSerial);
43impl_change_device!(DeviceBLAS, DeviceBLAS);
44#[cfg(feature = "faer")]
45impl_change_device!(DeviceFaer, DeviceBLAS);
46#[cfg(feature = "faer")]
47impl_change_device!(DeviceBLAS, DeviceFaer);
48
#[cfg(test)]
mod test {
    use super::*;

    /// Converts tensors between the BLAS device and the serial CPU device in
    /// both directions and prints the results.
    #[test]
    fn test_device_conversion_cpu_serial() {
        let device_serial = DeviceCpuSerial::default();
        let device = DeviceBLAS::new(0);

        // BLAS -> serial CPU, called directly on the tensor returned by `linspace`.
        let on_blas = linspace((1.0, 5.0, 5, &device));
        let b = on_blas.to_device(&device_serial);
        println!("{b:?}");

        // Serial CPU -> BLAS, called on a view. The view is bound to a local
        // because the converted result borrows from it.
        let on_serial = linspace((1.0, 5.0, 5, &device_serial));
        let serial_view = on_serial.view();
        let b = serial_view.to_device(&device);
        println!("{b:?}");
    }

    /// Converts tensors between the BLAS device and the Faer device in both
    /// directions and prints the results.
    #[cfg(feature = "faer")]
    #[test]
    fn test_device_conversion_faer() {
        let device_faer = DeviceFaer::new(0);
        let device = DeviceBLAS::new(0);

        // BLAS -> Faer, called directly on the tensor returned by `linspace`.
        let on_blas = linspace((1.0, 5.0, 5, &device));
        let b = on_blas.to_device(&device_faer);
        println!("{b:?}");

        // Faer -> BLAS, called on a view. The view is bound to a local
        // because the converted result borrows from it.
        let on_faer = linspace((1.0, 5.0, 5, &device_faer));
        let faer_view = on_faer.view();
        let b = faer_view.to_device(&device);
        println!("{b:?}");
    }
}