rstsr_openblas/linalg_auto_impl/
solve_symmetric.rs

1use crate::DeviceBLAS;
2use rstsr_blas_traits::prelude::*;
3use rstsr_core::prelude_dev::*;
4use rstsr_linalg_traits::prelude_dev::*;
5
6/* #region full-args */
7
8#[duplicate_item(
9    ImplType                                                            TrA                                 TrB                               ;
10   [T, DA, DB, Ra: DataAPI<Data = Vec<T>>, Rb: DataAPI<Data = Vec<T>>] [&TensorAny<Ra, T, DeviceBLAS, DA>] [&TensorAny<Rb, T, DeviceBLAS, DB>];
11   [T, DA, DB, R: DataAPI<Data = Vec<T>>                             ] [&TensorAny<R, T, DeviceBLAS, DA> ] [TensorView<'_, T, DeviceBLAS, DB>];
12   [T, DA, DB, R: DataAPI<Data = Vec<T>>                             ] [TensorView<'_, T, DeviceBLAS, DA>] [&TensorAny<R, T, DeviceBLAS, DB> ];
13   [T, DA, DB,                                                       ] [TensorView<'_, T, DeviceBLAS, DA>] [TensorView<'_, T, DeviceBLAS, DB>];
14)]
15impl<ImplType> SolveSymmetricAPI<DeviceBLAS> for (TrA, TrB, bool, Option<FlagUpLo>)
16where
17    T: BlasFloat,
18    DA: DimAPI,
19    DB: DimAPI,
20    DeviceBLAS: LapackDriverAPI<T>,
21{
22    type Out = Tensor<T, DeviceBLAS, DB>;
23    fn solve_symmetric_f(self) -> Result<Self::Out> {
24        let (a, b, hermi, uplo) = self;
25        rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
26        rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?;
27        let is_b_vec = b.ndim() == 1;
28        let a_view = a.view().into_dim::<Ix2>();
29        let b_view = match is_b_vec {
30            true => b.i((.., None)).into_dim::<Ix2>(),
31            false => b.view().into_dim::<Ix2>(),
32        };
33        let result = ref_impl_solve_symmetric_f(a_view.into(), b_view.into(), hermi, uplo)?;
34        let result = result.into_owned().into_dim::<IxD>();
35        match is_b_vec {
36            true => Ok(result.into_shape(-1).into_dim::<DB>()),
37            false => Ok(result.into_dim::<DB>()),
38        }
39    }
40}
41
42#[duplicate_item(
43    ImplType                                   TrA                                 TrB                              ;
44   ['b, T, DA, DB, R: DataAPI<Data = Vec<T>>] [&TensorAny<R, T, DeviceBLAS, DA> ] [TensorMut<'b, T, DeviceBLAS, DB>];
45   ['b, T, DA, DB,                          ] [TensorView<'_, T, DeviceBLAS, DA>] [TensorMut<'b, T, DeviceBLAS, DB>];
46   [    T, DA, DB, R: DataAPI<Data = Vec<T>>] [&TensorAny<R, T, DeviceBLAS, DA> ] [Tensor<T, DeviceBLAS, DB>       ];
47   [    T, DA, DB,                          ] [TensorView<'_, T, DeviceBLAS, DA>] [Tensor<T, DeviceBLAS, DB>       ];
48)]
49impl<ImplType> SolveSymmetricAPI<DeviceBLAS> for (TrA, TrB, bool, Option<FlagUpLo>)
50where
51    T: BlasFloat,
52    DA: DimAPI,
53    DB: DimAPI,
54    DeviceBLAS: LapackDriverAPI<T>,
55{
56    type Out = TrB;
57    fn solve_symmetric_f(self) -> Result<Self::Out> {
58        let (a, mut b, hermi, uplo) = self;
59        rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
60        rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?;
61        let is_b_vec = b.ndim() == 1;
62        let a_view = a.view().into_dim::<Ix2>();
63        let b_view = match is_b_vec {
64            true => b.i_mut((.., None)).into_dim::<Ix2>(),
65            false => b.view_mut().into_dim::<Ix2>(),
66        };
67        let result = ref_impl_solve_symmetric_f(a_view.into(), b_view.into(), hermi, uplo)?;
68        result.clone_to_mut();
69        Ok(b)
70    }
71}
72
73#[duplicate_item(
74    ImplType                               TrA                                TrB                               ;
75   [T, DA, DB, R: DataAPI<Data = Vec<T>>] [TensorMut<'_, T, DeviceBLAS, DA>] [&TensorAny<R, T, DeviceBLAS, DB> ];
76   [T, DA, DB,                          ] [TensorMut<'_, T, DeviceBLAS, DA>] [TensorView<'_, T, DeviceBLAS, DB>];
77   [T, DA, DB, R: DataAPI<Data = Vec<T>>] [Tensor<T, DeviceBLAS, DA>       ] [&TensorAny<R, T, DeviceBLAS, DB> ];
78   [T, DA, DB,                          ] [Tensor<T, DeviceBLAS, DA>       ] [TensorView<'_, T, DeviceBLAS, DB>];
79)]
80impl<ImplType> SolveSymmetricAPI<DeviceBLAS> for (TrA, TrB, bool, Option<FlagUpLo>)
81where
82    T: BlasFloat,
83    DA: DimAPI,
84    DB: DimAPI,
85    DeviceBLAS: LapackDriverAPI<T>,
86{
87    type Out = Tensor<T, DeviceBLAS, DB>;
88    fn solve_symmetric_f(self) -> Result<Self::Out> {
89        let (mut a, b, hermi, uplo) = self;
90        rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
91        rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?;
92        let is_b_vec = b.ndim() == 1;
93        let a_view = a.view_mut().into_dim::<Ix2>();
94        let b_view = match is_b_vec {
95            true => b.i((.., None)).into_dim::<Ix2>(),
96            false => b.view().into_dim::<Ix2>(),
97        };
98        let result = ref_impl_solve_symmetric_f(a_view.into(), b_view.into(), hermi, uplo)?;
99        let result = result.into_owned().into_dim::<IxD>();
100        match is_b_vec {
101            true => Ok(result.into_shape(-1).into_dim::<DB>()),
102            false => Ok(result.into_dim::<DB>()),
103        }
104    }
105}
106
107#[duplicate_item(
108    ImplType        TrA                                TrB                              ;
109   ['b, T, DA, DB] [TensorMut<'_, T, DeviceBLAS, DA>] [TensorMut<'b, T, DeviceBLAS, DB>];
110   [    T, DA, DB] [TensorMut<'_, T, DeviceBLAS, DA>] [Tensor<T, DeviceBLAS, DB>       ];
111   ['b, T, DA, DB] [Tensor<T, DeviceBLAS, DA>       ] [TensorMut<'b, T, DeviceBLAS, DB>];
112   [    T, DA, DB] [Tensor<T, DeviceBLAS, DA>       ] [Tensor<T, DeviceBLAS, DB>       ];
113)]
114impl<ImplType> SolveSymmetricAPI<DeviceBLAS> for (TrA, TrB, bool, Option<FlagUpLo>)
115where
116    T: BlasFloat,
117    DA: DimAPI,
118    DB: DimAPI,
119    DeviceBLAS: LapackDriverAPI<T>,
120{
121    type Out = TrB;
122    fn solve_symmetric_f(self) -> Result<Self::Out> {
123        let (mut a, mut b, hermi, uplo) = self;
124        rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
125        rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?;
126        let is_b_vec = b.ndim() == 1;
127        let a_view = a.view_mut().into_dim::<Ix2>();
128        let b_view = match is_b_vec {
129            true => b.i_mut((.., None)).into_dim::<Ix2>(),
130            false => b.view_mut().into_dim::<Ix2>(),
131        };
132        let result = ref_impl_solve_symmetric_f(a_view.into(), b_view.into(), hermi, uplo)?;
133        result.clone_to_mut();
134        Ok(b)
135    }
136}
137
138/* #endregion */
139
140/* #region sub-args */
141
142#[duplicate_item(
143    ImplStruct                   args_tuple            internal_tuple            ;
144   [(TrA, TrB, bool, FlagUpLo)] [(a, b, hermi, uplo)] [(a, b, hermi, Some(uplo))];
145   [(TrA, TrB, bool,         )] [(a, b, hermi,     )] [(a, b, hermi, None      )];
146   [(TrA, TrB,       FlagUpLo)] [(a, b,        uplo)] [(a, b, true , Some(uplo))];
147   [(TrA, TrB,               )] [(a, b,            )] [(a, b, true , None      )];
148)]
149impl<TrA, TrB> SolveSymmetricAPI<DeviceBLAS> for ImplStruct
150where
151    (TrA, TrB, bool, Option<FlagUpLo>): SolveSymmetricAPI<DeviceBLAS>,
152{
153    type Out = <(TrA, TrB, bool, Option<FlagUpLo>) as SolveSymmetricAPI<DeviceBLAS>>::Out;
154    fn solve_symmetric_f(self) -> Result<Self::Out> {
155        let args_tuple = self;
156        SolveSymmetricAPI::<DeviceBLAS>::solve_symmetric_f(internal_tuple)
157    }
158}
159
160/* #endregion */