Skip to main content

rstsr_core/device_cpu_serial/
creation.rs

1use crate::prelude_dev::*;
2use num::{complex::ComplexFloat, Num, Zero};
3
4impl<T> DeviceCreationAnyAPI<T> for DeviceCpuSerial
5where
6    Self: DeviceRawAPI<T, Raw = Vec<T>> + DeviceRawAPI<MaybeUninit<T>, Raw = Vec<MaybeUninit<T>>>,
7{
8    unsafe fn empty_impl(&self, len: usize) -> Result<Storage<DataOwned<Vec<T>>, T, DeviceCpuSerial>> {
9        let raw = uninitialized_vec(len)?;
10        Ok(Storage::new(raw.into(), self.clone()))
11    }
12
13    fn full_impl(&self, len: usize, fill: T) -> Result<Storage<DataOwned<Vec<T>>, T, DeviceCpuSerial>>
14    where
15        T: Clone,
16    {
17        let raw = vec![fill; len];
18        Ok(Storage::new(raw.into(), self.clone()))
19    }
20
21    fn outof_cpu_vec(&self, vec: Vec<T>) -> Result<Storage<DataOwned<Vec<T>>, T, DeviceCpuSerial>> {
22        Ok(Storage::new(vec.into(), self.clone()))
23    }
24
25    fn from_cpu_vec(&self, vec: &[T]) -> Result<Storage<DataOwned<Vec<T>>, T, DeviceCpuSerial>>
26    where
27        T: Clone,
28    {
29        Ok(Storage::new(vec.to_vec().into(), self.clone()))
30    }
31
32    fn uninit_impl(&self, len: usize) -> Result<Storage<DataOwned<Vec<MaybeUninit<T>>>, MaybeUninit<T>, Self>>
33    where
34        Self: DeviceRawAPI<MaybeUninit<T>>,
35    {
36        let raw = unsafe { uninitialized_vec(len) }?;
37        Ok(Storage::new(raw.into(), self.clone()))
38    }
39
40    unsafe fn assume_init_impl(
41        storage: Storage<DataOwned<Vec<MaybeUninit<T>>>, MaybeUninit<T>, Self>,
42    ) -> Result<Storage<DataOwned<Vec<T>>, T, Self>>
43    where
44        Self: DeviceRawAPI<MaybeUninit<T>>,
45    {
46        let (data, device) = storage.into_raw_parts();
47        let vec = data.into_raw();
48        // transmute `Vec<MaybeUninit<T>>` to `Vec<T>`
49        let vec = core::mem::transmute::<Vec<MaybeUninit<T>>, Vec<T>>(vec);
50        let data = vec.into();
51        Ok(Storage::new(data, device))
52    }
53}
54
55impl<T> DeviceCreationNumAPI<T> for DeviceCpuSerial
56where
57    T: Num + Clone,
58    DeviceCpuSerial: DeviceRawAPI<T, Raw = Vec<T>>,
59{
60    fn zeros_impl(&self, len: usize) -> Result<Storage<DataOwned<Vec<T>>, T, DeviceCpuSerial>> {
61        let raw = vec![T::zero(); len];
62        Ok(Storage::new(raw.into(), self.clone()))
63    }
64
65    fn ones_impl(&self, len: usize) -> Result<Storage<DataOwned<Vec<T>>, T, DeviceCpuSerial>> {
66        let raw = vec![T::one(); len];
67        Ok(Storage::new(raw.into(), self.clone()))
68    }
69}
70
71impl<T> DeviceCreationArangeAPI<T> for DeviceCpuSerial
72where
73    T: PartialOrd + Clone + Add<Output = T> + Zero + 'static,
74    DeviceCpuSerial: DeviceRawAPI<T, Raw = Vec<T>>,
75{
76    fn arange_impl(&self, start: T, end: T, step: T) -> Result<Storage<DataOwned<Vec<T>>, T, DeviceCpuSerial>> {
77        rstsr_assert!(step != T::zero(), InvalidValue)?;
78        let raw = arange_cpu_serial(start, end, step);
79        Ok(Storage::new(raw.into(), self.clone()))
80    }
81}
82
83impl<T> DeviceCreationComplexFloatAPI<T> for DeviceCpuSerial
84where
85    T: ComplexFloat + Clone,
86    DeviceCpuSerial: DeviceRawAPI<T, Raw = Vec<T>>,
87{
88    fn linspace_impl(
89        &self,
90        start: T,
91        end: T,
92        n: usize,
93        endpoint: bool,
94    ) -> Result<Storage<DataOwned<Vec<T>>, T, DeviceCpuSerial>> {
95        // handle special cases
96        if n == 0 {
97            return Ok(Storage::new(vec![].into(), self.clone()));
98        } else if n == 1 {
99            return Ok(Storage::new(vec![start].into(), self.clone()));
100        }
101
102        let mut raw = Vec::with_capacity(n);
103        let step = match endpoint {
104            true => (end - start) / T::from(n - 1).unwrap(),
105            false => (end - start) / T::from(n).unwrap(),
106        };
107        let mut v = start;
108        for _ in 0..n {
109            raw.push(v);
110            v = v + step;
111        }
112        Ok(Storage::new(raw.into(), self.clone()))
113    }
114}
115
116impl<T> DeviceCreationTriAPI<T> for DeviceCpuSerial
117where
118    T: Num + Clone,
119{
120    fn tril_impl<D>(&self, raw: &mut Vec<T>, layout: &Layout<D>, k: isize) -> Result<()>
121    where
122        D: DimAPI,
123    {
124        tril_cpu_serial(raw, layout, k)
125    }
126
127    fn triu_impl<D>(&self, raw: &mut Vec<T>, layout: &Layout<D>, k: isize) -> Result<()>
128    where
129        D: DimAPI,
130    {
131        triu_cpu_serial(raw, layout, k)
132    }
133}
134
#[cfg(test)]
mod test {
    #[test]
    fn test_creation() {
        use super::*;
        use num::Complex;

        let device = DeviceCpuSerial::default();

        let storage: Storage<_, f64, _> = device.zeros_impl(10).unwrap();
        println!("{storage:?}");
        let (data, _) = storage.into_raw_parts();
        assert_eq!(data.into_raw(), vec![0.0_f64; 10]);

        let storage: Storage<_, f64, _> = device.ones_impl(10).unwrap();
        println!("{storage:?}");
        let (data, _) = storage.into_raw_parts();
        assert_eq!(data.into_raw(), vec![1.0_f64; 10]);

        // `empty_impl` returns uninitialized elements: reading them (e.g. via `{:?}`) would be
        // undefined behavior, so only check that the allocation succeeds.
        let _storage: Storage<_, f64, _> = unsafe { device.empty_impl(10).unwrap() };

        let storage = device.from_cpu_vec(&[1.0; 10]).unwrap();
        println!("{storage:?}");
        let storage = device.outof_cpu_vec(vec![1.0; 10]).unwrap();
        println!("{storage:?}");

        // linspace with endpoint: first value is `start`, last is (close to) `end`
        let storage = device.linspace_impl(0.0_f64, 1.0, 10, true).unwrap();
        println!("{storage:?}");
        let (data, _) = storage.into_raw_parts();
        let raw: Vec<f64> = data.into_raw();
        assert_eq!(raw.len(), 10);
        assert_eq!(raw[0], 0.0);
        assert!((raw[9] - 1.0).abs() < 1e-12);

        // linspace without endpoint excludes `end`
        let (data, _) = device.linspace_impl(0.0_f64, 1.0, 4, false).unwrap().into_raw_parts();
        assert_eq!(data.into_raw(), vec![0.0_f64, 0.25, 0.5, 0.75]);

        // linspace special cases
        let (data, _) = device.linspace_impl(2.0_f64, 3.0, 1, true).unwrap().into_raw_parts();
        assert_eq!(data.into_raw(), vec![2.0_f64]);
        let (data, _) = device.linspace_impl(2.0_f64, 3.0, 0, true).unwrap().into_raw_parts();
        assert!(data.into_raw().is_empty());

        let storage = device.linspace_impl(Complex::new(1.0, 2.0), Complex::new(3.5, 4.7), 10, true).unwrap();
        println!("{storage:?}");
        let storage = device.arange_impl(0.0, 1.0, 0.1).unwrap();
        println!("{storage:?}");

        // tril/triu
        let mut vec = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        let layout = [3, 3].c();
        device.tril_impl(&mut vec, &layout, -1).unwrap();
        println!("{vec:?}");
        assert_eq!(vec, vec![0, 0, 0, 4, 0, 0, 7, 8, 0]);
        let mut vec = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        device.triu_impl(&mut vec, &layout, -1).unwrap();
        println!("{vec:?}");
        assert_eq!(vec, vec![1, 2, 3, 4, 5, 6, 0, 8, 9]);
    }
}