rstsr_core/device_faer/
conversion.rs

1//! Conversion to/from Faer
2
3use crate::prelude_dev::*;
4use core::mem::ManuallyDrop;
5use faer::prelude::*;
6use faer_ext::IntoFaer;
7
8/* #region conversion to Faer objects */
9
10impl<'a, T, B> IntoFaer for TensorView<'a, T, B, Ix2>
11where
12    B: DeviceAPI<T, Raw = Vec<T>>,
13{
14    type Faer = MatRef<'a, T>;
15
16    fn into_faer(self) -> Self::Faer {
17        let [nrows, ncols] = *self.shape();
18        let [row_stride, col_stride] = *self.stride();
19        let offset = self.offset();
20        let ptr = unsafe { self.raw().as_ptr().add(offset) };
21        unsafe { MatRef::from_raw_parts(ptr, nrows, ncols, row_stride, col_stride) }
22    }
23}
24
25impl<'a, T, B> IntoFaer for TensorViewMut<'a, T, B, Ix2>
26where
27    B: DeviceAPI<T, Raw = Vec<T>>,
28{
29    type Faer = MatMut<'a, T>;
30
31    fn into_faer(mut self) -> Self::Faer {
32        let [nrows, ncols] = *self.shape();
33        let [row_stride, col_stride] = *self.stride();
34        let offset = self.offset();
35        let ptr = unsafe { self.raw_mut().as_mut_ptr().add(offset) };
36        unsafe { MatMut::from_raw_parts_mut(ptr, nrows, ncols, row_stride, col_stride) }
37    }
38}
39
40impl<'a, T> IntoRSTSR for MatRef<'a, T> {
41    type RSTSR = TensorView<'a, T, DeviceFaer, Ix2>;
42
43    fn into_rstsr(self) -> Self::RSTSR {
44        let nrows = self.nrows();
45        let ncols = self.ncols();
46        let row_stride = self.row_stride();
47        let col_stride = self.col_stride();
48        let ptr = self.as_ptr();
49
50        let layout = Layout::new([nrows, ncols], [row_stride, col_stride], 0).unwrap();
51        let (_, upper_bound) = layout.bounds_index().unwrap();
52        let raw = unsafe { Vec::from_raw_parts(ptr as *mut T, upper_bound, upper_bound) };
53        let data = DataRef::from_manually_drop(ManuallyDrop::new(raw));
54        let storage = Storage::new(data, DeviceFaer::default());
55        let tensor = unsafe { TensorView::new_unchecked(storage, layout) };
56        return tensor;
57    }
58}
59
60impl<T> IntoRSTSR for Mat<T> {
61    type RSTSR = Tensor<T, DeviceFaer, Ix2>;
62
63    fn into_rstsr(self) -> Self::RSTSR {
64        let nrows = self.nrows();
65        let ncols = self.ncols();
66        let row_stride = self.row_stride();
67        let col_stride = self.col_stride();
68        let ptr = self.as_ptr();
69        core::mem::forget(self); // prevent double free
70
71        let layout = Layout::new([nrows, ncols], [row_stride, col_stride], 0).unwrap();
72        let (_, upper_bound) = layout.bounds_index().unwrap();
73        let raw = unsafe { Vec::from_raw_parts(ptr as *mut T, upper_bound, upper_bound) };
74        let data = DataOwned::from(raw);
75        let storage = Storage::new(data, DeviceFaer::default());
76        let tensor = unsafe { Tensor::new_unchecked(storage, layout) };
77        return tensor;
78    }
79}
80
81impl<'a, T> IntoRSTSR for ColRef<'a, T> {
82    type RSTSR = TensorView<'a, T, DeviceFaer, Ix1>;
83
84    fn into_rstsr(self) -> Self::RSTSR {
85        let nrows = self.nrows();
86        let stride = self.row_stride();
87        let ptr = self.as_ptr();
88
89        let layout = Layout::new([nrows], [stride], 0).unwrap();
90        let (_, upper_bound) = layout.bounds_index().unwrap();
91        let raw = unsafe { Vec::from_raw_parts(ptr as *mut T, upper_bound, upper_bound) };
92        let data = DataRef::from_manually_drop(ManuallyDrop::new(raw));
93        let storage = Storage::new(data, DeviceFaer::default());
94        let tensor = unsafe { TensorView::new_unchecked(storage, layout) };
95        return tensor;
96    }
97}
98
99impl<'a, T> IntoRSTSR for MatMut<'a, T> {
100    type RSTSR = TensorViewMut<'a, T, DeviceFaer, Ix2>;
101
102    fn into_rstsr(self) -> Self::RSTSR {
103        let nrows = self.nrows();
104        let ncols = self.ncols();
105        let row_stride = self.row_stride();
106        let col_stride = self.col_stride();
107        let ptr = self.as_ptr();
108
109        let layout = Layout::new([nrows, ncols], [row_stride, col_stride], 0).unwrap();
110        let (_, upper_bound) = layout.bounds_index().unwrap();
111        let raw = unsafe { Vec::from_raw_parts(ptr as *mut T, upper_bound, upper_bound) };
112        let data = DataMut::from_manually_drop(ManuallyDrop::new(raw));
113        let storage = Storage::new(data, DeviceFaer::default());
114        let tensor = unsafe { TensorMut::new_unchecked(storage, layout) };
115        return tensor;
116    }
117}
118
119/* #endregion */
120
121/* #region device conversion */
122
123#[duplicate_item(
124    DevA DevB;
125   [DeviceFaer     ] [DeviceCpuSerial];
126   [DeviceCpuSerial] [DeviceFaer     ];
127   [DeviceFaer     ] [DeviceFaer     ];
128)]
129impl<'a, R, T, D> DeviceChangeAPI<'a, DevB, R, T, D> for DevA
130where
131    T: Clone + Send + Sync + 'a,
132    D: DimAPI,
133    R: DataCloneAPI<Data = Vec<T>>,
134{
135    type Repr = R;
136    type ReprTo = DataRef<'a, Vec<T>>;
137
138    fn change_device(tensor: TensorAny<R, T, DevA, D>, device: &DevB) -> Result<TensorAny<Self::Repr, T, DevB, D>> {
139        let (storage, layout) = tensor.into_raw_parts();
140        let (data, _) = storage.into_raw_parts();
141        let storage = Storage::new(data, device.clone());
142        let tensor = TensorAny::new(storage, layout);
143        Ok(tensor)
144    }
145
146    fn into_device(
147        tensor: TensorAny<R, T, DevA, D>,
148        device: &DevB,
149    ) -> Result<TensorAny<DataOwned<Vec<T>>, T, DevB, D>> {
150        let tensor = tensor.into_owned();
151        DeviceChangeAPI::change_device(tensor, device)
152    }
153
154    fn to_device(tensor: &'a TensorAny<R, T, DevA, D>, device: &DevB) -> Result<TensorView<'a, T, DevB, D>> {
155        let view = tensor.view();
156        DeviceChangeAPI::change_device(view, device)
157    }
158}
159
160/* #endregion */
161
#[cfg(test)]
mod test {
    use super::*;

    /// Moves a 5-element `linspace` tensor from `DeviceFaer` to `DeviceCpuSerial`, then moves a
    /// view of another one from `DeviceCpuSerial` to `DeviceFaer`, printing each result.
    #[test]
    fn test_device_conversion() {
        let cpu = DeviceCpuSerial::default();
        let dev_faer = DeviceFaer::new(0);

        let src_faer = linspace((1.0, 5.0, 5, &dev_faer));
        let on_cpu = src_faer.to_device(&cpu);
        println!("{on_cpu:?}");

        let src_cpu = linspace((1.0, 5.0, 5, &cpu));
        // Keep the view in a binding: `to_device` returns data borrowed from it.
        let src_cpu_view = src_cpu.view();
        let on_faer = src_cpu_view.to_device(&dev_faer);
        println!("{on_faer:?}");
    }

    /// Same round trip, but between two `DeviceFaer` handles built with
    /// `DeviceFaer::new(1)` and `DeviceFaer::new(0)`.
    #[test]
    fn test_self_conversion() {
        let faer_one = DeviceFaer::new(1);
        let faer_zero = DeviceFaer::new(0);

        let src = linspace((1.0, 5.0, 5, &faer_zero));
        let moved = src.to_device(&faer_one);
        println!("{moved:?}");

        let src = linspace((1.0, 5.0, 5, &faer_one));
        let src_view = src.view();
        let moved = src_view.to_device(&faer_zero);
        println!("{moved:?}");
    }
}