// rstsr_openblas/linalg_auto_impl/solve_general.rs

1use crate::DeviceBLAS;
2use rstsr_blas_traits::prelude::*;
3use rstsr_core::prelude_dev::*;
4use rstsr_linalg_traits::prelude_dev::*;
5
6#[duplicate_item(
7    ImplType                                                            TrA                                 TrB                               ;
8   [T, DA, DB, Ra: DataAPI<Data = Vec<T>>, Rb: DataAPI<Data = Vec<T>>] [&TensorAny<Ra, T, DeviceBLAS, DA>] [&TensorAny<Rb, T, DeviceBLAS, DB>];
9   [T, DA, DB, R: DataAPI<Data = Vec<T>>                             ] [&TensorAny<R, T, DeviceBLAS, DA> ] [TensorView<'_, T, DeviceBLAS, DB>];
10   [T, DA, DB, R: DataAPI<Data = Vec<T>>                             ] [TensorView<'_, T, DeviceBLAS, DA>] [&TensorAny<R, T, DeviceBLAS, DB> ];
11   [T, DA, DB,                                                       ] [TensorView<'_, T, DeviceBLAS, DA>] [TensorView<'_, T, DeviceBLAS, DB>];
12)]
13impl<ImplType> SolveGeneralAPI<DeviceBLAS> for (TrA, TrB)
14where
15    T: BlasFloat,
16    DA: DimAPI,
17    DB: DimAPI,
18    DeviceBLAS: LapackDriverAPI<T>,
19{
20    type Out = Tensor<T, DeviceBLAS, DB>;
21    fn solve_general_f(self) -> Result<Self::Out> {
22        let (a, b) = self;
23        rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
24        rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?;
25        let is_b_vec = b.ndim() == 1;
26        let a_view = a.view().into_dim::<Ix2>();
27        let b_view = match is_b_vec {
28            true => b.i((.., None)).into_dim::<Ix2>(),
29            false => b.view().into_dim::<Ix2>(),
30        };
31        let result = ref_impl_solve_general_f(a_view.into(), b_view.into())?;
32        let result = result.into_owned().into_dim::<IxD>();
33        match is_b_vec {
34            true => Ok(result.into_shape(-1).into_dim::<DB>()),
35            false => Ok(result.into_dim::<DB>()),
36        }
37    }
38}
39
40#[duplicate_item(
41    ImplType                                   TrA                                 TrB                              ;
42   ['b, T, DA, DB, R: DataAPI<Data = Vec<T>>] [&TensorAny<R, T, DeviceBLAS, DA> ] [TensorMut<'b, T, DeviceBLAS, DB>];
43   ['b, T, DA, DB,                          ] [TensorView<'_, T, DeviceBLAS, DA>] [TensorMut<'b, T, DeviceBLAS, DB>];
44   [    T, DA, DB, R: DataAPI<Data = Vec<T>>] [&TensorAny<R, T, DeviceBLAS, DA> ] [Tensor<T, DeviceBLAS, DB>       ];
45   [    T, DA, DB,                          ] [TensorView<'_, T, DeviceBLAS, DA>] [Tensor<T, DeviceBLAS, DB>       ];
46)]
47impl<ImplType> SolveGeneralAPI<DeviceBLAS> for (TrA, TrB)
48where
49    T: BlasFloat,
50    DA: DimAPI,
51    DB: DimAPI,
52    DeviceBLAS: LapackDriverAPI<T>,
53{
54    type Out = TrB;
55    fn solve_general_f(self) -> Result<Self::Out> {
56        let (a, mut b) = self;
57        rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
58        rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?;
59        let is_b_vec = b.ndim() == 1;
60        let a_view = a.view().into_dim::<Ix2>();
61        let b_view = match is_b_vec {
62            true => b.i_mut((.., None)).into_dim::<Ix2>(),
63            false => b.view_mut().into_dim::<Ix2>(),
64        };
65        let result = ref_impl_solve_general_f(a_view.into(), b_view.into())?;
66        result.clone_to_mut();
67        Ok(b)
68    }
69}
70
71#[duplicate_item(
72    ImplType                               TrA                                TrB                               ;
73   [T, DA, DB, R: DataAPI<Data = Vec<T>>] [TensorMut<'_, T, DeviceBLAS, DA>] [&TensorAny<R, T, DeviceBLAS, DB> ];
74   [T, DA, DB,                          ] [TensorMut<'_, T, DeviceBLAS, DA>] [TensorView<'_, T, DeviceBLAS, DB>];
75   [T, DA, DB, R: DataAPI<Data = Vec<T>>] [Tensor<T, DeviceBLAS, DA>       ] [&TensorAny<R, T, DeviceBLAS, DB> ];
76   [T, DA, DB,                          ] [Tensor<T, DeviceBLAS, DA>       ] [TensorView<'_, T, DeviceBLAS, DB>];
77)]
78impl<ImplType> SolveGeneralAPI<DeviceBLAS> for (TrA, TrB)
79where
80    T: BlasFloat,
81    DA: DimAPI,
82    DB: DimAPI,
83    DeviceBLAS: LapackDriverAPI<T>,
84{
85    type Out = Tensor<T, DeviceBLAS, DB>;
86    fn solve_general_f(self) -> Result<Self::Out> {
87        let (mut a, b) = self;
88        rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
89        rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?;
90        let is_b_vec = b.ndim() == 1;
91        let a_view = a.view_mut().into_dim::<Ix2>();
92        let b_view = match is_b_vec {
93            true => b.i((.., None)).into_dim::<Ix2>(),
94            false => b.view().into_dim::<Ix2>(),
95        };
96        let result = ref_impl_solve_general_f(a_view.into(), b_view.into())?;
97        let result = result.into_owned().into_dim::<IxD>();
98        match is_b_vec {
99            true => Ok(result.into_shape(-1).into_dim::<DB>()),
100            false => Ok(result.into_dim::<DB>()),
101        }
102    }
103}
104
105#[duplicate_item(
106    ImplType        TrA                                TrB                              ;
107   ['b, T, DA, DB] [TensorMut<'_, T, DeviceBLAS, DA>] [TensorMut<'b, T, DeviceBLAS, DB>];
108   [    T, DA, DB] [TensorMut<'_, T, DeviceBLAS, DA>] [Tensor<T, DeviceBLAS, DB>       ];
109   ['b, T, DA, DB] [Tensor<T, DeviceBLAS, DA>       ] [TensorMut<'b, T, DeviceBLAS, DB>];
110   [    T, DA, DB] [Tensor<T, DeviceBLAS, DA>       ] [Tensor<T, DeviceBLAS, DB>       ];
111)]
112impl<ImplType> SolveGeneralAPI<DeviceBLAS> for (TrA, TrB)
113where
114    T: BlasFloat,
115    DA: DimAPI,
116    DB: DimAPI,
117    DeviceBLAS: LapackDriverAPI<T>,
118{
119    type Out = TrB;
120    fn solve_general_f(self) -> Result<Self::Out> {
121        let (mut a, mut b) = self;
122        rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
123        rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?;
124        let is_b_vec = b.ndim() == 1;
125        let a_view = a.view_mut().into_dim::<Ix2>();
126        let b_view = match is_b_vec {
127            true => b.i_mut((.., None)).into_dim::<Ix2>(),
128            false => b.view_mut().into_dim::<Ix2>(),
129        };
130        let result = ref_impl_solve_general_f(a_view.into(), b_view.into())?;
131        result.clone_to_mut();
132        Ok(b)
133    }
134}