1use crate::prelude_dev::*;
4use core::mem::ManuallyDrop;
5use faer::prelude::*;
6use faer_ext::IntoFaer;
7
8impl<'a, T, B> IntoFaer for TensorView<'a, T, B, Ix2>
11where
12 B: DeviceAPI<T, Raw = Vec<T>>,
13{
14 type Faer = MatRef<'a, T>;
15
16 fn into_faer(self) -> Self::Faer {
17 let [nrows, ncols] = *self.shape();
18 let [row_stride, col_stride] = *self.stride();
19 let offset = self.offset();
20 let ptr = unsafe { self.raw().as_ptr().add(offset) };
21 unsafe { MatRef::from_raw_parts(ptr, nrows, ncols, row_stride, col_stride) }
22 }
23}
24
25impl<'a, T, B> IntoFaer for TensorViewMut<'a, T, B, Ix2>
26where
27 B: DeviceAPI<T, Raw = Vec<T>>,
28{
29 type Faer = MatMut<'a, T>;
30
31 fn into_faer(mut self) -> Self::Faer {
32 let [nrows, ncols] = *self.shape();
33 let [row_stride, col_stride] = *self.stride();
34 let offset = self.offset();
35 let ptr = unsafe { self.raw_mut().as_mut_ptr().add(offset) };
36 unsafe { MatMut::from_raw_parts_mut(ptr, nrows, ncols, row_stride, col_stride) }
37 }
38}
39
40impl<'a, T> IntoRSTSR for MatRef<'a, T> {
41 type RSTSR = TensorView<'a, T, DeviceFaer, Ix2>;
42
43 fn into_rstsr(self) -> Self::RSTSR {
44 let nrows = self.nrows();
45 let ncols = self.ncols();
46 let row_stride = self.row_stride();
47 let col_stride = self.col_stride();
48 let ptr = self.as_ptr();
49
50 let layout = Layout::new([nrows, ncols], [row_stride, col_stride], 0).unwrap();
51 let (_, upper_bound) = layout.bounds_index().unwrap();
52 let raw = unsafe { Vec::from_raw_parts(ptr as *mut T, upper_bound, upper_bound) };
53 let data = DataRef::from_manually_drop(ManuallyDrop::new(raw));
54 let storage = Storage::new(data, DeviceFaer::default());
55 let tensor = unsafe { TensorView::new_unchecked(storage, layout) };
56 return tensor;
57 }
58}
59
60impl<T> IntoRSTSR for Mat<T> {
61 type RSTSR = Tensor<T, DeviceFaer, Ix2>;
62
63 fn into_rstsr(self) -> Self::RSTSR {
64 let nrows = self.nrows();
65 let ncols = self.ncols();
66 let row_stride = self.row_stride();
67 let col_stride = self.col_stride();
68 let ptr = self.as_ptr();
69 core::mem::forget(self); let layout = Layout::new([nrows, ncols], [row_stride, col_stride], 0).unwrap();
72 let (_, upper_bound) = layout.bounds_index().unwrap();
73 let raw = unsafe { Vec::from_raw_parts(ptr as *mut T, upper_bound, upper_bound) };
74 let data = DataOwned::from(raw);
75 let storage = Storage::new(data, DeviceFaer::default());
76 let tensor = unsafe { Tensor::new_unchecked(storage, layout) };
77 return tensor;
78 }
79}
80
81impl<'a, T> IntoRSTSR for ColRef<'a, T> {
82 type RSTSR = TensorView<'a, T, DeviceFaer, Ix1>;
83
84 fn into_rstsr(self) -> Self::RSTSR {
85 let nrows = self.nrows();
86 let stride = self.row_stride();
87 let ptr = self.as_ptr();
88
89 let layout = Layout::new([nrows], [stride], 0).unwrap();
90 let (_, upper_bound) = layout.bounds_index().unwrap();
91 let raw = unsafe { Vec::from_raw_parts(ptr as *mut T, upper_bound, upper_bound) };
92 let data = DataRef::from_manually_drop(ManuallyDrop::new(raw));
93 let storage = Storage::new(data, DeviceFaer::default());
94 let tensor = unsafe { TensorView::new_unchecked(storage, layout) };
95 return tensor;
96 }
97}
98
99impl<'a, T> IntoRSTSR for MatMut<'a, T> {
100 type RSTSR = TensorViewMut<'a, T, DeviceFaer, Ix2>;
101
102 fn into_rstsr(self) -> Self::RSTSR {
103 let nrows = self.nrows();
104 let ncols = self.ncols();
105 let row_stride = self.row_stride();
106 let col_stride = self.col_stride();
107 let ptr = self.as_ptr();
108
109 let layout = Layout::new([nrows, ncols], [row_stride, col_stride], 0).unwrap();
110 let (_, upper_bound) = layout.bounds_index().unwrap();
111 let raw = unsafe { Vec::from_raw_parts(ptr as *mut T, upper_bound, upper_bound) };
112 let data = DataMut::from_manually_drop(ManuallyDrop::new(raw));
113 let storage = Storage::new(data, DeviceFaer::default());
114 let tensor = unsafe { TensorMut::new_unchecked(storage, layout) };
115 return tensor;
116 }
117}
118
119#[duplicate_item(
124 DevA DevB;
125 [DeviceFaer ] [DeviceCpuSerial];
126 [DeviceCpuSerial] [DeviceFaer ];
127 [DeviceFaer ] [DeviceFaer ];
128)]
129impl<'a, R, T, D> DeviceChangeAPI<'a, DevB, R, T, D> for DevA
130where
131 T: Clone + Send + Sync + 'a,
132 D: DimAPI,
133 R: DataCloneAPI<Data = Vec<T>>,
134{
135 type Repr = R;
136 type ReprTo = DataRef<'a, Vec<T>>;
137
138 fn change_device(tensor: TensorAny<R, T, DevA, D>, device: &DevB) -> Result<TensorAny<Self::Repr, T, DevB, D>> {
139 let (storage, layout) = tensor.into_raw_parts();
140 let (data, _) = storage.into_raw_parts();
141 let storage = Storage::new(data, device.clone());
142 let tensor = TensorAny::new(storage, layout);
143 Ok(tensor)
144 }
145
146 fn into_device(
147 tensor: TensorAny<R, T, DevA, D>,
148 device: &DevB,
149 ) -> Result<TensorAny<DataOwned<Vec<T>>, T, DevB, D>> {
150 let tensor = tensor.into_owned();
151 DeviceChangeAPI::change_device(tensor, device)
152 }
153
154 fn to_device(tensor: &'a TensorAny<R, T, DevA, D>, device: &DevB) -> Result<TensorView<'a, T, DevB, D>> {
155 let view = tensor.view();
156 DeviceChangeAPI::change_device(view, device)
157 }
158}
159
#[cfg(test)]
mod test {
    use super::*;

    /// Move 1-D tensors between the faer and serial CPU devices in both directions.
    #[test]
    fn test_device_conversion() {
        let dev_serial = DeviceCpuSerial::default();
        let dev_faer = DeviceFaer::new(0);

        // faer -> serial, starting from an owned tensor.
        let src = linspace((1.0, 5.0, 5, &dev_faer));
        let moved = src.to_device(&dev_serial);
        println!("{moved:?}");

        // serial -> faer, starting from a view (kept in a binding because the
        // result borrows from it).
        let src = linspace((1.0, 5.0, 5, &dev_serial));
        let src_view = src.view();
        let moved = src_view.to_device(&dev_faer);
        println!("{moved:?}");
    }

    /// Move 1-D tensors between two differently configured faer devices.
    #[test]
    fn test_self_conversion() {
        let dev_one = DeviceFaer::new(1);
        let dev_zero = DeviceFaer::new(0);

        let src = linspace((1.0, 5.0, 5, &dev_zero));
        let moved = src.to_device(&dev_one);
        println!("{moved:?}");

        let src = linspace((1.0, 5.0, 5, &dev_one));
        let src_view = src.view();
        let moved = src_view.to_device(&dev_zero);
        println!("{moved:?}");
    }
}