rstsr_core/device_faer/
creation.rs

1use crate::prelude_dev::*;
2use num::{complex::ComplexFloat, Num};
3
4// for creation, we use most of the functions from DeviceCpuSerial
5impl<T> DeviceCreationAnyAPI<T> for DeviceFaer
6where
7    Self: DeviceRawAPI<T, Raw = Vec<T>> + DeviceRawAPI<MaybeUninit<T>, Raw = Vec<MaybeUninit<T>>>,
8{
9    unsafe fn empty_impl(&self, len: usize) -> Result<Storage<DataOwned<Vec<T>>, T, Self>> {
10        let storage = DeviceCpuSerial::default().empty_impl(len)?;
11        let (data, _) = storage.into_raw_parts();
12        Ok(Storage::new(data, self.clone()))
13    }
14
15    fn full_impl(&self, len: usize, fill: T) -> Result<Storage<DataOwned<Vec<T>>, T, Self>>
16    where
17        T: Clone,
18    {
19        let storage = DeviceCpuSerial::default().full_impl(len, fill)?;
20        let (data, _) = storage.into_raw_parts();
21        Ok(Storage::new(data, self.clone()))
22    }
23
24    fn outof_cpu_vec(&self, vec: Vec<T>) -> Result<Storage<DataOwned<Vec<T>>, T, Self>> {
25        Ok(Storage::new(DataOwned::from(vec), self.clone()))
26    }
27
28    fn from_cpu_vec(&self, vec: &[T]) -> Result<Storage<DataOwned<Vec<T>>, T, Self>>
29    where
30        T: Clone,
31    {
32        let raw = vec.to_vec();
33        Ok(Storage::new(DataOwned::from(raw), self.clone()))
34    }
35
36    fn uninit_impl(&self, len: usize) -> Result<Storage<DataOwned<Vec<MaybeUninit<T>>>, MaybeUninit<T>, Self>> {
37        let raw = unsafe { uninitialized_vec(len) }?;
38        Ok(Storage::new(raw.into(), self.clone()))
39    }
40
41    unsafe fn assume_init_impl(
42        storage: Storage<DataOwned<Vec<MaybeUninit<T>>>, MaybeUninit<T>, Self>,
43    ) -> Result<Storage<DataOwned<Vec<T>>, T, Self>>
44    where
45        Self: DeviceRawAPI<MaybeUninit<T>>,
46    {
47        let (data, device) = storage.into_raw_parts();
48        let vec = data.into_raw();
49        // transmute `Vec<MaybeUninit<T>>` to `Vec<T>`
50        let vec = core::mem::transmute::<Vec<MaybeUninit<T>>, Vec<T>>(vec);
51        let data = vec.into();
52        Ok(Storage::new(data, device))
53    }
54}
55
56impl<T> DeviceCreationNumAPI<T> for DeviceFaer
57where
58    T: Num + Clone,
59    Self: DeviceRawAPI<T, Raw = Vec<T>>,
60{
61    fn zeros_impl(&self, len: usize) -> Result<Storage<DataOwned<Vec<T>>, T, Self>> {
62        let storage = DeviceCpuSerial::default().zeros_impl(len)?;
63        let (data, _) = storage.into_raw_parts();
64        Ok(Storage::new(data, self.clone()))
65    }
66
67    fn ones_impl(&self, len: usize) -> Result<Storage<DataOwned<Vec<T>>, T, Self>> {
68        let storage = DeviceCpuSerial::default().ones_impl(len)?;
69        let (data, _) = storage.into_raw_parts();
70        Ok(Storage::new(data, self.clone()))
71    }
72
73    fn arange_int_impl(&self, len: usize) -> Result<Storage<DataOwned<Vec<T>>, T, Self>> {
74        let storage = DeviceCpuSerial::default().arange_int_impl(len)?;
75        let (data, _) = storage.into_raw_parts();
76        Ok(Storage::new(data, self.clone()))
77    }
78}
79
80impl<T> DeviceCreationPartialOrdNumAPI<T> for DeviceFaer
81where
82    T: Num + PartialOrd + Clone,
83    Self: DeviceRawAPI<T, Raw = Vec<T>>,
84{
85    fn arange_impl(&self, start: T, end: T, step: T) -> Result<Storage<DataOwned<Vec<T>>, T, Self>> {
86        let storage = DeviceCpuSerial::default().arange_impl(start, end, step)?;
87        let (data, _) = storage.into_raw_parts();
88        Ok(Storage::new(data, self.clone()))
89    }
90}
91
92impl<T> DeviceCreationComplexFloatAPI<T> for DeviceFaer
93where
94    T: ComplexFloat + Clone + Send + Sync,
95    Self: DeviceRawAPI<T, Raw = Vec<T>>,
96{
97    fn linspace_impl(&self, start: T, end: T, n: usize, endpoint: bool) -> Result<Storage<DataOwned<Vec<T>>, T, Self>> {
98        let storage = DeviceCpuSerial::default().linspace_impl(start, end, n, endpoint)?;
99        let (data, _) = storage.into_raw_parts();
100        Ok(Storage::new(data, self.clone()))
101    }
102}
103
104impl<T> DeviceCreationTriAPI<T> for DeviceFaer
105where
106    T: Num + Clone,
107    Self: DeviceRawAPI<T, Raw = Vec<T>>,
108{
109    fn tril_impl<D>(&self, raw: &mut Self::Raw, layout: &Layout<D>, k: isize) -> Result<()>
110    where
111        D: DimAPI,
112    {
113        DeviceCpuSerial::default().tril_impl(raw, layout, k)
114    }
115
116    fn triu_impl<D>(&self, raw: &mut Self::Raw, layout: &Layout<D>, k: isize) -> Result<()>
117    where
118        D: DimAPI,
119    {
120        DeviceCpuSerial::default().triu_impl(raw, layout, k)
121    }
122}
123
#[cfg(test)]
mod test {
    use super::*;

    /// `linspace` from 1.0 to 5.0 with 5 points on a default `DeviceFaer`
    /// must yield the raw buffer `[1, 2, 3, 4, 5]`.
    #[test]
    fn test_linspace() {
        let faer = DeviceFaer::default();
        let expected = vec![1., 2., 3., 4., 5.];
        let tensor = linspace((1.0, 5.0, 5, &faer));
        assert_eq!(tensor.raw(), &expected);
    }
}