1use crate::prelude_dev::*;
2use num::{complex::ComplexFloat, Num};
3
4impl<T> DeviceCreationAnyAPI<T> for DeviceFaer
6where
7 Self: DeviceRawAPI<T, Raw = Vec<T>> + DeviceRawAPI<MaybeUninit<T>, Raw = Vec<MaybeUninit<T>>>,
8{
9 unsafe fn empty_impl(&self, len: usize) -> Result<Storage<DataOwned<Vec<T>>, T, Self>> {
10 let storage = DeviceCpuSerial::default().empty_impl(len)?;
11 let (data, _) = storage.into_raw_parts();
12 Ok(Storage::new(data, self.clone()))
13 }
14
15 fn full_impl(&self, len: usize, fill: T) -> Result<Storage<DataOwned<Vec<T>>, T, Self>>
16 where
17 T: Clone,
18 {
19 let storage = DeviceCpuSerial::default().full_impl(len, fill)?;
20 let (data, _) = storage.into_raw_parts();
21 Ok(Storage::new(data, self.clone()))
22 }
23
24 fn outof_cpu_vec(&self, vec: Vec<T>) -> Result<Storage<DataOwned<Vec<T>>, T, Self>> {
25 Ok(Storage::new(DataOwned::from(vec), self.clone()))
26 }
27
28 fn from_cpu_vec(&self, vec: &[T]) -> Result<Storage<DataOwned<Vec<T>>, T, Self>>
29 where
30 T: Clone,
31 {
32 let raw = vec.to_vec();
33 Ok(Storage::new(DataOwned::from(raw), self.clone()))
34 }
35
36 fn uninit_impl(&self, len: usize) -> Result<Storage<DataOwned<Vec<MaybeUninit<T>>>, MaybeUninit<T>, Self>> {
37 let raw = unsafe { uninitialized_vec(len) }?;
38 Ok(Storage::new(raw.into(), self.clone()))
39 }
40
41 unsafe fn assume_init_impl(
42 storage: Storage<DataOwned<Vec<MaybeUninit<T>>>, MaybeUninit<T>, Self>,
43 ) -> Result<Storage<DataOwned<Vec<T>>, T, Self>>
44 where
45 Self: DeviceRawAPI<MaybeUninit<T>>,
46 {
47 let (data, device) = storage.into_raw_parts();
48 let vec = data.into_raw();
49 let vec = core::mem::transmute::<Vec<MaybeUninit<T>>, Vec<T>>(vec);
51 let data = vec.into();
52 Ok(Storage::new(data, device))
53 }
54}
55
56impl<T> DeviceCreationNumAPI<T> for DeviceFaer
57where
58 T: Num + Clone,
59 Self: DeviceRawAPI<T, Raw = Vec<T>>,
60{
61 fn zeros_impl(&self, len: usize) -> Result<Storage<DataOwned<Vec<T>>, T, Self>> {
62 let storage = DeviceCpuSerial::default().zeros_impl(len)?;
63 let (data, _) = storage.into_raw_parts();
64 Ok(Storage::new(data, self.clone()))
65 }
66
67 fn ones_impl(&self, len: usize) -> Result<Storage<DataOwned<Vec<T>>, T, Self>> {
68 let storage = DeviceCpuSerial::default().ones_impl(len)?;
69 let (data, _) = storage.into_raw_parts();
70 Ok(Storage::new(data, self.clone()))
71 }
72
73 fn arange_int_impl(&self, len: usize) -> Result<Storage<DataOwned<Vec<T>>, T, Self>> {
74 let storage = DeviceCpuSerial::default().arange_int_impl(len)?;
75 let (data, _) = storage.into_raw_parts();
76 Ok(Storage::new(data, self.clone()))
77 }
78}
79
80impl<T> DeviceCreationPartialOrdNumAPI<T> for DeviceFaer
81where
82 T: Num + PartialOrd + Clone,
83 Self: DeviceRawAPI<T, Raw = Vec<T>>,
84{
85 fn arange_impl(&self, start: T, end: T, step: T) -> Result<Storage<DataOwned<Vec<T>>, T, Self>> {
86 let storage = DeviceCpuSerial::default().arange_impl(start, end, step)?;
87 let (data, _) = storage.into_raw_parts();
88 Ok(Storage::new(data, self.clone()))
89 }
90}
91
92impl<T> DeviceCreationComplexFloatAPI<T> for DeviceFaer
93where
94 T: ComplexFloat + Clone + Send + Sync,
95 Self: DeviceRawAPI<T, Raw = Vec<T>>,
96{
97 fn linspace_impl(&self, start: T, end: T, n: usize, endpoint: bool) -> Result<Storage<DataOwned<Vec<T>>, T, Self>> {
98 let storage = DeviceCpuSerial::default().linspace_impl(start, end, n, endpoint)?;
99 let (data, _) = storage.into_raw_parts();
100 Ok(Storage::new(data, self.clone()))
101 }
102}
103
104impl<T> DeviceCreationTriAPI<T> for DeviceFaer
105where
106 T: Num + Clone,
107 Self: DeviceRawAPI<T, Raw = Vec<T>>,
108{
109 fn tril_impl<D>(&self, raw: &mut Self::Raw, layout: &Layout<D>, k: isize) -> Result<()>
110 where
111 D: DimAPI,
112 {
113 DeviceCpuSerial::default().tril_impl(raw, layout, k)
114 }
115
116 fn triu_impl<D>(&self, raw: &mut Self::Raw, layout: &Layout<D>, k: isize) -> Result<()>
117 where
118 D: DimAPI,
119 {
120 DeviceCpuSerial::default().triu_impl(raw, layout, k)
121 }
122}
123
#[cfg(test)]
mod test {
    use super::*;

    /// `linspace` on the Faer device: five points from 1 to 5, both ends included.
    #[test]
    fn test_linspace() {
        let device = DeviceFaer::default();
        let expected = vec![1., 2., 3., 4., 5.];
        let a = linspace((1.0, 5.0, 5, &device));
        assert_eq!(a.raw(), &expected);
    }
}