// rstsr_openblas/conversion.rs
use crate::prelude_dev::*;
2
/// Generates a `DeviceChangeAPI` implementation that converts tensors living
/// on device `$Src` into tensors living on device `$Dst`.
///
/// The implementation is restricted to representations whose raw data is a
/// `Vec<T>` (`R: DataCloneAPI<Data = Vec<T>>`).
macro_rules! impl_change_device {
    ($Src: ty, $Dst: ty) => {
        impl<'a, R, T, D> DeviceChangeAPI<'a, $Dst, R, T, D> for $Src
        where
            R: DataCloneAPI<Data = Vec<T>>,
            D: DimAPI,
            T: Clone + Send + Sync + 'a,
        {
            type Repr = R;
            type ReprTo = DataRef<'a, Vec<T>>;

            /// Rebinds `tensor` to `device`, keeping its data representation
            /// and layout as they are.
            fn change_device(
                tensor: TensorAny<R, T, $Src, D>,
                device: &$Dst,
            ) -> Result<TensorAny<Self::Repr, T, $Dst, D>> {
                // Take the tensor apart; the source device handle stored in
                // the old storage is discarded.
                let (old_storage, layout) = tensor.into_raw_parts();
                let (raw, _) = old_storage.into_raw_parts();
                // Reassemble around a clone of the target device handle.
                Ok(TensorAny::new(Storage::new(raw, device.clone()), layout))
            }

            /// Converts `tensor` with `into_owned` and then rebinds the
            /// result to `device`.
            fn into_device(
                tensor: TensorAny<R, T, $Src, D>,
                device: &$Dst,
            ) -> Result<TensorAny<DataOwned<Vec<T>>, T, $Dst, D>> {
                DeviceChangeAPI::change_device(tensor.into_owned(), device)
            }

            /// Takes a view of `tensor` and rebinds that view to `device`;
            /// the returned view borrows from `tensor` for `'a`.
            fn to_device(tensor: &'a TensorAny<R, T, $Src, D>, device: &$Dst) -> Result<TensorView<'a, T, $Dst, D>> {
                DeviceChangeAPI::change_device(tensor.view(), device)
            }
        }
    };
}
40
41impl_change_device!(DeviceCpuSerial, DeviceBLAS);
42impl_change_device!(DeviceBLAS, DeviceCpuSerial);
43impl_change_device!(DeviceBLAS, DeviceBLAS);
44#[cfg(feature = "faer")]
45impl_change_device!(DeviceFaer, DeviceBLAS);
46#[cfg(feature = "faer")]
47impl_change_device!(DeviceBLAS, DeviceFaer);
48
#[cfg(test)]
mod test {
    use super::*;

    /// Smoke test for conversions between `DeviceBLAS` and `DeviceCpuSerial`.
    ///
    /// The converted tensors are only printed; nothing is asserted about them.
    #[test]
    fn test_device_conversion_cpu_serial() {
        let serial = DeviceCpuSerial::default();
        let blas = DeviceBLAS::new(0);

        // BLAS -> serial, starting from a tensor created on the BLAS device.
        let src = linspace((1.0, 5.0, 5, &blas));
        let b = src.to_device(&serial);
        println!("{b:?}");

        // Serial -> BLAS, starting from a view. The view stays in a named
        // local because `to_device` borrows its receiver.
        let src = linspace((1.0, 5.0, 5, &serial));
        let src_view = src.view();
        let b = src_view.to_device(&blas);
        println!("{b:?}");
    }

    /// Smoke test for conversions between `DeviceBLAS` and `DeviceFaer`.
    ///
    /// Only built with the `faer` feature; results are printed, not asserted.
    #[test]
    #[cfg(feature = "faer")]
    fn test_device_conversion_faer() {
        let faer = DeviceFaer::new(0);
        let blas = DeviceBLAS::new(0);

        // BLAS -> faer, starting from a tensor created on the BLAS device.
        let src = linspace((1.0, 5.0, 5, &blas));
        let b = src.to_device(&faer);
        println!("{b:?}");

        // faer -> BLAS, starting from a view. The view stays in a named
        // local because `to_device` borrows its receiver.
        let src = linspace((1.0, 5.0, 5, &faer));
        let src_view = src.view();
        let b = src_view.to_device(&blas);
        println!("{b:?}");
    }
}