rstsr_openblas/
device.rs

1use crate::prelude_dev::*;
2use num::{complex::ComplexFloat, Num};
3use rstsr_dtype_traits::DTypeIntoFloatAPI;
4
5impl DeviceBLAS {
6    pub fn new(num_threads: usize) -> Self {
7        DeviceBLAS { base: DeviceCpuRayon::new(num_threads) }
8    }
9}
10
11impl DeviceRayonAPI for DeviceBLAS {
12    #[inline]
13    fn set_num_threads(&mut self, num_threads: usize) {
14        self.base.set_num_threads(num_threads);
15    }
16
17    #[inline]
18    fn get_num_threads(&self) -> usize {
19        self.base.get_num_threads()
20    }
21
22    #[inline]
23    fn get_pool(&self) -> &ThreadPool {
24        self.base.get_pool()
25    }
26
27    #[inline]
28    fn get_current_pool(&self) -> Option<&ThreadPool> {
29        self.base.get_current_pool()
30    }
31}
32
33impl Default for DeviceBLAS {
34    fn default() -> Self {
35        DeviceBLAS::new(0)
36    }
37}
38
39impl DeviceBaseAPI for DeviceBLAS {
40    fn same_device(&self, other: &Self) -> bool {
41        let same_num_threads = self.get_num_threads() == other.get_num_threads();
42        let same_default_order = self.default_order() == other.default_order();
43        same_num_threads && same_default_order
44    }
45
46    fn default_order(&self) -> FlagOrder {
47        self.base.default_order()
48    }
49
50    fn set_default_order(&mut self, order: FlagOrder) {
51        self.base.set_default_order(order);
52    }
53}
54
55impl<T> DeviceRawAPI<T> for DeviceBLAS {
56    type Raw = Vec<T>;
57}
58
59impl<T> DeviceStorageAPI<T> for DeviceBLAS {
60    fn len<R>(storage: &Storage<R, T, Self>) -> usize
61    where
62        R: DataAPI<Data = Self::Raw>,
63    {
64        storage.raw().len()
65    }
66
67    fn to_cpu_vec<R>(storage: &Storage<R, T, Self>) -> Result<Vec<T>>
68    where
69        Self::Raw: Clone,
70        R: DataAPI<Data = Self::Raw>,
71    {
72        Ok(storage.raw().clone())
73    }
74
75    fn into_cpu_vec<R>(storage: Storage<R, T, Self>) -> Result<Vec<T>>
76    where
77        Self::Raw: Clone,
78        R: DataCloneAPI<Data = Self::Raw>,
79    {
80        let (raw, _) = storage.into_raw_parts();
81        Ok(raw.into_owned().into_raw())
82    }
83
84    #[inline]
85    fn get_index<R>(storage: &Storage<R, T, Self>, index: usize) -> T
86    where
87        T: Clone,
88        R: DataAPI<Data = Self::Raw>,
89    {
90        storage.raw()[index].clone()
91    }
92
93    #[inline]
94    fn get_index_ptr<R>(storage: &Storage<R, T, Self>, index: usize) -> *const T
95    where
96        R: DataAPI<Data = Self::Raw>,
97    {
98        &storage.raw()[index] as *const T
99    }
100
101    #[inline]
102    fn get_index_mut_ptr<R>(storage: &mut Storage<R, T, Self>, index: usize) -> *mut T
103    where
104        R: DataMutAPI<Data = Self::Raw>,
105    {
106        storage.raw_mut().get_mut(index).unwrap() as *mut T
107    }
108
109    #[inline]
110    fn set_index<R>(storage: &mut Storage<R, T, Self>, index: usize, value: T)
111    where
112        R: DataMutAPI<Data = Self::Raw>,
113    {
114        storage.raw_mut()[index] = value;
115    }
116}
117
118impl<T> DeviceAPI<T> for DeviceBLAS {}
119
120impl<T, D> DeviceComplexFloatAPI<T, D> for DeviceBLAS
121where
122    T: ComplexFloat + DTypeIntoFloatAPI<FloatType = T> + Send + Sync,
123    T::Real: DTypeIntoFloatAPI<FloatType = T::Real> + Send + Sync,
124    D: DimAPI,
125{
126}
127
128impl<T, D> DeviceNumAPI<T, D> for DeviceBLAS
129where
130    T: Clone + Num + Send + Sync,
131    D: DimAPI,
132{
133}