rstsr_openblas/linalg_auto_impl/eigvalsh.rs

1use crate::DeviceBLAS;
2use rstsr_blas_traits::prelude::*;
3use rstsr_core::prelude_dev::*;
4use rstsr_linalg_traits::prelude_dev::*;

/* #region simple eigh */

8#[duplicate_item(
9    ImplType                          Tr                               ;
10   [T, D, R: DataAPI<Data = Vec<T>>] [&TensorAny<R, T, DeviceBLAS, D> ];
11   [T, D                           ] [TensorView<'_, T, DeviceBLAS, D>];
12)]
13impl<ImplType> EigvalshAPI<DeviceBLAS> for (Tr, FlagUpLo)
14where
15    T: BlasFloat,
16    D: DimAPI + DimSmallerOneAPI,
17    D::SmallerOne: DimAPI,
18    DeviceBLAS: LapackDriverAPI<T>,
19{
20    type Out = Tensor<T::Real, DeviceBLAS, D::SmallerOne>;
21    fn eigvalsh_f(self) -> Result<Self::Out> {
22        let (a, uplo) = self;
23        rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
24        let a_view = a.view().into_dim::<Ix2>();
25        let eigh_args = EighArgs::default().a(a_view).uplo(uplo).eigvals_only(true).build()?;
26        let (vals, _) = ref_impl_eigh_simple_f(eigh_args)?;
27        let vals = vals.into_dim::<IxD>().into_dim::<D::SmallerOne>();
28        return Ok(vals);
29    }
30}
32#[duplicate_item(
33    ImplType                          Tr                               ;
34   [T, D, R: DataAPI<Data = Vec<T>>] [&TensorAny<R, T, DeviceBLAS, D> ];
35   [T, D                           ] [TensorView<'_, T, DeviceBLAS, D>];
36)]
37impl<ImplType> EigvalshAPI<DeviceBLAS> for Tr
38where
39    T: BlasFloat,
40    D: DimAPI + DimSmallerOneAPI,
41    D::SmallerOne: DimAPI,
42    DeviceBLAS: LapackDriverAPI<T>,
43{
44    type Out = Tensor<T::Real, DeviceBLAS, D::SmallerOne>;
45    fn eigvalsh_f(self) -> Result<Self::Out> {
46        let a = self;
47        let uplo = match a.device().default_order() {
48            RowMajor => Lower,
49            ColMajor => Upper,
50        };
51        EigvalshAPI::<DeviceBLAS>::eigvalsh_f((a, uplo))
52    }
53}
55#[duplicate_item(
56    ImplType   Tr                              ;
57   ['a, T, D] [TensorMut<'a, T, DeviceBLAS, D>];
58   [    T, D] [Tensor<T, DeviceBLAS, D>       ];
59)]
60impl<ImplType> EigvalshAPI<DeviceBLAS> for (Tr, FlagUpLo)
61where
62    T: BlasFloat,
63    D: DimAPI + DimSmallerOneAPI,
64    D::SmallerOne: DimAPI,
65    DeviceBLAS: LapackDriverAPI<T>,
66{
67    type Out = Tensor<T::Real, DeviceBLAS, D::SmallerOne>;
68    fn eigvalsh_f(self) -> Result<Self::Out> {
69        let (mut a, uplo) = self;
70        rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
71        let a_view = a.view_mut().into_dim::<Ix2>();
72        let eigh_args = EighArgs::default().a(a_view).uplo(uplo).eigvals_only(true).build()?;
73        let (vals, _) = ref_impl_eigh_simple_f(eigh_args)?;
74        let vals = vals.into_dim::<IxD>().into_dim::<D::SmallerOne>();
75        return Ok(vals);
76    }
77}
79#[duplicate_item(
80    ImplType   Tr                              ;
81   ['a, T, D] [TensorMut<'a, T, DeviceBLAS, D>];
82   [    T, D] [Tensor<T, DeviceBLAS, D>       ];
83)]
84impl<ImplType> EigvalshAPI<DeviceBLAS> for Tr
85where
86    T: BlasFloat,
87    D: DimAPI + DimSmallerOneAPI,
88    D::SmallerOne: DimAPI,
89    DeviceBLAS: LapackDriverAPI<T>,
90{
91    type Out = Tensor<T::Real, DeviceBLAS, D::SmallerOne>;
92    fn eigvalsh_f(self) -> Result<Self::Out> {
93        let a = self;
94        let uplo = match a.device().default_order() {
95            RowMajor => Lower,
96            ColMajor => Upper,
97        };
98        EigvalshAPI::<DeviceBLAS>::eigvalsh_f((a, uplo))
99    }
100}

/* #endregion */

/* #region general eigh */

106#[duplicate_item(
107    ImplType                                                       TrA                                TrB                              ;
108   [T, D, Ra: DataAPI<Data = Vec<T>>, Rb: DataAPI<Data = Vec<T>>] [&TensorAny<Ra, T, DeviceBLAS, D>] [&TensorAny<Rb, T, DeviceBLAS, D>];
109   [T, D, R: DataAPI<Data = Vec<T>>                             ] [&TensorAny<R, T, DeviceBLAS, D> ] [TensorView<'_, T, DeviceBLAS, D>];
110   [T, D, R: DataAPI<Data = Vec<T>>                             ] [TensorView<'_, T, DeviceBLAS, D>] [&TensorAny<R, T, DeviceBLAS, D> ];
111   [T, D,                                                       ] [TensorView<'_, T, DeviceBLAS, D>] [TensorView<'_, T, DeviceBLAS, D>];
112)]
113impl<ImplType> EigvalshAPI<DeviceBLAS> for (TrA, TrB, FlagUpLo, i32)
114where
115    T: BlasFloat,
116    D: DimAPI + DimSmallerOneAPI,
117    D::SmallerOne: DimAPI,
118    DeviceBLAS: LapackDriverAPI<T>,
119{
120    type Out = Tensor<T::Real, DeviceBLAS, D::SmallerOne>;
121    fn eigvalsh_f(self) -> Result<Self::Out> {
122        let (a, b, uplo, eig_type) = self;
123        rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
124        rstsr_assert_eq!(b.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
125        rstsr_pattern!(eig_type, 1..=3, InvalidLayout, "Only eig_type = 1, 2, or 3 allowed.")?;
126        let a_view = a.view().into_dim::<Ix2>();
127        let b_view = b.view().into_dim::<Ix2>();
128        let eigh_args =
129            EighArgs::default().a(a_view).b(b_view).uplo(uplo).eig_type(eig_type).eigvals_only(true).build()?;
130        let (vals, _) = ref_impl_eigh_simple_f(eigh_args)?;
131        let vals = vals.into_dim::<IxD>().into_dim::<D::SmallerOne>();
132        return Ok(vals);
133    }
134}
136#[duplicate_item(
137    ImplType                                                       TrA                                TrB                              ;
138   [T, D, Ra: DataAPI<Data = Vec<T>>, Rb: DataAPI<Data = Vec<T>>] [&TensorAny<Ra, T, DeviceBLAS, D>] [&TensorAny<Rb, T, DeviceBLAS, D>];
139   [T, D, R: DataAPI<Data = Vec<T>>                             ] [&TensorAny<R, T, DeviceBLAS, D> ] [TensorView<'_, T, DeviceBLAS, D>];
140   [T, D, R: DataAPI<Data = Vec<T>>                             ] [TensorView<'_, T, DeviceBLAS, D>] [&TensorAny<R, T, DeviceBLAS, D> ];
141   [T, D,                                                       ] [TensorView<'_, T, DeviceBLAS, D>] [TensorView<'_, T, DeviceBLAS, D>];
142)]
143impl<ImplType> EigvalshAPI<DeviceBLAS> for (TrA, TrB, FlagUpLo)
144where
145    T: BlasFloat,
146    D: DimAPI + DimSmallerOneAPI,
147    D::SmallerOne: DimAPI,
148    DeviceBLAS: LapackDriverAPI<T>,
149{
150    type Out = Tensor<T::Real, DeviceBLAS, D::SmallerOne>;
151    fn eigvalsh_f(self) -> Result<Self::Out> {
152        let (a, b, uplo) = self;
153        EigvalshAPI::<DeviceBLAS>::eigvalsh_f((a, b, uplo, 1))
154    }
155}
157#[duplicate_item(
158    ImplType                                                       TrA                                TrB                              ;
159   [T, D, Ra: DataAPI<Data = Vec<T>>, Rb: DataAPI<Data = Vec<T>>] [&TensorAny<Ra, T, DeviceBLAS, D>] [&TensorAny<Rb, T, DeviceBLAS, D>];
160   [T, D, R: DataAPI<Data = Vec<T>>                             ] [&TensorAny<R, T, DeviceBLAS, D> ] [TensorView<'_, T, DeviceBLAS, D>];
161   [T, D, R: DataAPI<Data = Vec<T>>                             ] [TensorView<'_, T, DeviceBLAS, D>] [&TensorAny<R, T, DeviceBLAS, D> ];
162   [T, D,                                                       ] [TensorView<'_, T, DeviceBLAS, D>] [TensorView<'_, T, DeviceBLAS, D>];
163)]
164impl<ImplType> EigvalshAPI<DeviceBLAS> for (TrA, TrB)
165where
166    T: BlasFloat,
167    D: DimAPI + DimSmallerOneAPI,
168    D::SmallerOne: DimAPI,
169    DeviceBLAS: LapackDriverAPI<T>,
170{
171    type Out = Tensor<T::Real, DeviceBLAS, D::SmallerOne>;
172    fn eigvalsh_f(self) -> Result<Self::Out> {
173        let (a, b) = self;
174        let uplo = match a.device().default_order() {
175            RowMajor => Lower,
176            ColMajor => Upper,
177        };
178        EigvalshAPI::<DeviceBLAS>::eigvalsh_f((a, b, uplo, 1))
179    }
180}

/* #endregion */

/* #region EighArgs implementation */

186impl<'a, 'b, T> EigvalshAPI<DeviceBLAS> for EighArgs<'a, 'b, DeviceBLAS, T>
187where
188    T: BlasFloat,
189    DeviceBLAS: LapackDriverAPI<T>,
190{
191    type Out = Tensor<T::Real, DeviceBLAS, Ix1>;
192    fn eigvalsh_f(self) -> Result<Self::Out> {
193        let args = self.build()?;
194        rstsr_assert!(args.eigvals_only, InvalidValue, "Eigvalsh only supports eigvals_only = true.")?;
195        let (vals, _) = ref_impl_eigh_simple_f(args)?;
196        Ok(vals)
197    }
198}
200impl<'a, 'b, T> EigvalshAPI<DeviceBLAS> for EighArgs_<'a, 'b, DeviceBLAS, T>
201where
202    T: BlasFloat,
203    DeviceBLAS: LapackDriverAPI<T>,
204{
205    type Out = Tensor<T::Real, DeviceBLAS, Ix1>;
206    fn eigvalsh_f(self) -> Result<Self::Out> {
207        let args = self;
208        rstsr_assert!(args.eigvals_only, InvalidValue, "Eigvalsh only supports eigvals_only = true.")?;
209        let (vals, _) = ref_impl_eigh_simple_f(args)?;
210        Ok(vals)
211    }
212}

/* #endregion */