rstsr_core/device_cpu_serial/
creation.rs

1use crate::prelude_dev::*;
2use num::{complex::ComplexFloat, Num};
3
4impl<T> DeviceCreationAnyAPI<T> for DeviceCpuSerial
5where
6    Self: DeviceRawAPI<T, Raw = Vec<T>> + DeviceRawAPI<MaybeUninit<T>, Raw = Vec<MaybeUninit<T>>>,
7{
8    unsafe fn empty_impl(&self, len: usize) -> Result<Storage<DataOwned<Vec<T>>, T, DeviceCpuSerial>> {
9        let raw = uninitialized_vec(len)?;
10        Ok(Storage::new(raw.into(), self.clone()))
11    }
12
13    fn full_impl(&self, len: usize, fill: T) -> Result<Storage<DataOwned<Vec<T>>, T, DeviceCpuSerial>>
14    where
15        T: Clone,
16    {
17        let raw = vec![fill; len];
18        Ok(Storage::new(raw.into(), self.clone()))
19    }
20
21    fn outof_cpu_vec(&self, vec: Vec<T>) -> Result<Storage<DataOwned<Vec<T>>, T, DeviceCpuSerial>> {
22        Ok(Storage::new(vec.into(), self.clone()))
23    }
24
25    fn from_cpu_vec(&self, vec: &[T]) -> Result<Storage<DataOwned<Vec<T>>, T, DeviceCpuSerial>>
26    where
27        T: Clone,
28    {
29        Ok(Storage::new(vec.to_vec().into(), self.clone()))
30    }
31
32    fn uninit_impl(&self, len: usize) -> Result<Storage<DataOwned<Vec<MaybeUninit<T>>>, MaybeUninit<T>, Self>>
33    where
34        Self: DeviceRawAPI<MaybeUninit<T>>,
35    {
36        let raw = unsafe { uninitialized_vec(len) }?;
37        Ok(Storage::new(raw.into(), self.clone()))
38    }
39
40    unsafe fn assume_init_impl(
41        storage: Storage<DataOwned<Vec<MaybeUninit<T>>>, MaybeUninit<T>, Self>,
42    ) -> Result<Storage<DataOwned<Vec<T>>, T, Self>>
43    where
44        Self: DeviceRawAPI<MaybeUninit<T>>,
45    {
46        let (data, device) = storage.into_raw_parts();
47        let vec = data.into_raw();
48        // transmute `Vec<MaybeUninit<T>>` to `Vec<T>`
49        let vec = core::mem::transmute::<Vec<MaybeUninit<T>>, Vec<T>>(vec);
50        let data = vec.into();
51        Ok(Storage::new(data, device))
52    }
53}
54
55impl<T> DeviceCreationNumAPI<T> for DeviceCpuSerial
56where
57    T: Num + Clone,
58    DeviceCpuSerial: DeviceRawAPI<T, Raw = Vec<T>>,
59{
60    fn zeros_impl(&self, len: usize) -> Result<Storage<DataOwned<Vec<T>>, T, DeviceCpuSerial>> {
61        let raw = vec![T::zero(); len];
62        Ok(Storage::new(raw.into(), self.clone()))
63    }
64
65    fn ones_impl(&self, len: usize) -> Result<Storage<DataOwned<Vec<T>>, T, DeviceCpuSerial>> {
66        let raw = vec![T::one(); len];
67        Ok(Storage::new(raw.into(), self.clone()))
68    }
69
70    fn arange_int_impl(&self, len: usize) -> Result<Storage<DataOwned<Vec<T>>, T, DeviceCpuSerial>> {
71        let mut raw = Vec::with_capacity(len);
72        let mut v = T::zero();
73        for _ in 0..len {
74            raw.push(v.clone());
75            v = v + T::one();
76        }
77        Ok(Storage::new(raw.into(), self.clone()))
78    }
79}
80
81impl<T> DeviceCreationPartialOrdNumAPI<T> for DeviceCpuSerial
82where
83    T: Num + PartialOrd + Clone,
84    DeviceCpuSerial: DeviceRawAPI<T, Raw = Vec<T>>,
85{
86    fn arange_impl(&self, start: T, end: T, step: T) -> Result<Storage<DataOwned<Vec<T>>, T, DeviceCpuSerial>> {
87        rstsr_assert!(step != T::zero(), InvalidValue)?;
88        let mut raw = Vec::new();
89        let mut current = start.clone();
90        while current < end {
91            raw.push(current.clone());
92            current = current + step.clone();
93        }
94        Ok(Storage::new(raw.into(), self.clone()))
95    }
96}
97
98impl<T> DeviceCreationComplexFloatAPI<T> for DeviceCpuSerial
99where
100    T: ComplexFloat + Clone,
101    DeviceCpuSerial: DeviceRawAPI<T, Raw = Vec<T>>,
102{
103    fn linspace_impl(
104        &self,
105        start: T,
106        end: T,
107        n: usize,
108        endpoint: bool,
109    ) -> Result<Storage<DataOwned<Vec<T>>, T, DeviceCpuSerial>> {
110        // handle special cases
111        if n == 0 {
112            return Ok(Storage::new(vec![].into(), self.clone()));
113        } else if n == 1 {
114            return Ok(Storage::new(vec![start].into(), self.clone()));
115        }
116
117        let mut raw = Vec::with_capacity(n);
118        let step = match endpoint {
119            true => (end - start) / T::from(n - 1).unwrap(),
120            false => (end - start) / T::from(n).unwrap(),
121        };
122        let mut v = start;
123        for _ in 0..n {
124            raw.push(v);
125            v = v + step;
126        }
127        Ok(Storage::new(raw.into(), self.clone()))
128    }
129}
130
131impl<T> DeviceCreationTriAPI<T> for DeviceCpuSerial
132where
133    T: Num + Clone,
134{
135    fn tril_impl<D>(&self, raw: &mut Vec<T>, layout: &Layout<D>, k: isize) -> Result<()>
136    where
137        D: DimAPI,
138    {
139        tril_cpu_serial(raw, layout, k)
140    }
141
142    fn triu_impl<D>(&self, raw: &mut Vec<T>, layout: &Layout<D>, k: isize) -> Result<()>
143    where
144        D: DimAPI,
145    {
146        triu_cpu_serial(raw, layout, k)
147    }
148}
149
#[cfg(test)]
mod test {
    /// Smoke test for the creation routines of `DeviceCpuSerial`.
    #[test]
    fn test_creation() {
        use super::*;
        use num::Complex;

        let device = DeviceCpuSerial::default();

        let storage: Storage<_, f64, _> = device.zeros_impl(10).unwrap();
        println!("{storage:?}");
        assert_eq!(storage.into_raw_parts().0.into_raw(), vec![0.0; 10]);

        let storage: Storage<_, f64, _> = device.ones_impl(10).unwrap();
        println!("{storage:?}");
        assert_eq!(storage.into_raw_parts().0.into_raw(), vec![1.0; 10]);

        let storage: Storage<_, f64, _> = device.arange_int_impl(10).unwrap();
        println!("{storage:?}");
        assert_eq!(storage.into_raw_parts().0.into_raw(), vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);

        // The contents of an `empty_impl` storage are uninitialized, so they must not be read
        // (formatting the storage would do exactly that). Only the length is inspected.
        let storage: Storage<_, f64, _> = unsafe { device.empty_impl(10).unwrap() };
        assert_eq!(storage.into_raw_parts().0.into_raw().len(), 10);

        let storage = device.from_cpu_vec(&[1.0; 10]).unwrap();
        println!("{storage:?}");
        assert_eq!(storage.into_raw_parts().0.into_raw(), vec![1.0; 10]);

        let storage = device.outof_cpu_vec(vec![1.0; 10]).unwrap();
        println!("{storage:?}");
        assert_eq!(storage.into_raw_parts().0.into_raw(), vec![1.0; 10]);

        let storage = device.linspace_impl(0.0, 1.0, 10, true).unwrap();
        println!("{storage:?}");
        let storage = device.linspace_impl(Complex::new(1.0, 2.0), Complex::new(3.5, 4.7), 10, true).unwrap();
        println!("{storage:?}");
        let storage = device.arange_impl(0.0, 1.0, 0.1).unwrap();
        println!("{storage:?}");

        // tril/triu
        let mut vec = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        let layout = [3, 3].c();
        device.tril_impl(&mut vec, &layout, -1).unwrap();
        println!("{vec:?}");
        assert_eq!(vec, vec![0, 0, 0, 4, 0, 0, 7, 8, 0]);
        let mut vec = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        device.triu_impl(&mut vec, &layout, -1).unwrap();
        println!("{vec:?}");
        assert_eq!(vec, vec![1, 2, 3, 4, 5, 6, 0, 8, 9]);
    }
}