rstsr_openblas/linalg_auto_impl/eigh.rs

1use crate::DeviceBLAS;
2use rstsr_blas_traits::prelude::*;
3use rstsr_core::prelude_dev::*;
4use rstsr_linalg_traits::prelude_dev::*;

/* #region simple eigh */

8#[duplicate_item(
9    ImplType                          Tr                               ;
10   [T, D, R: DataAPI<Data = Vec<T>>] [&TensorAny<R, T, DeviceBLAS, D> ];
11   [T, D                           ] [TensorView<'_, T, DeviceBLAS, D>];
12)]
13impl<ImplType> EighAPI<DeviceBLAS> for (Tr, FlagUpLo)
14where
15    T: BlasFloat,
16    D: DimAPI + DimSmallerOneAPI,
17    D::SmallerOne: DimAPI,
18    DeviceBLAS: LapackDriverAPI<T>,
19{
20    type Out = EighResult<Tensor<T::Real, DeviceBLAS, D::SmallerOne>, Tensor<T, DeviceBLAS, D>>;
21    fn eigh_f(self) -> Result<Self::Out> {
22        let (a, uplo) = self;
23        rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
24        let a_view = a.view().into_dim::<Ix2>();
25        let eigh_args = EighArgs::default().a(a_view).uplo(uplo).build()?;
26        let (vals, vecs) = ref_impl_eigh_simple_f(eigh_args)?;
27        let vals = vals.into_dim::<IxD>().into_dim::<D::SmallerOne>();
28        let vecs = vecs.unwrap().into_owned().into_dim::<IxD>().into_dim::<D>();
29        return Ok(EighResult { eigenvalues: vals, eigenvectors: vecs });
30    }
31}
33#[duplicate_item(
34    ImplType                          Tr                               ;
35   [T, D, R: DataAPI<Data = Vec<T>>] [&TensorAny<R, T, DeviceBLAS, D> ];
36   [T, D                           ] [TensorView<'_, T, DeviceBLAS, D>];
37)]
38impl<ImplType> EighAPI<DeviceBLAS> for Tr
39where
40    T: BlasFloat,
41    D: DimAPI + DimSmallerOneAPI,
42    D::SmallerOne: DimAPI,
43    DeviceBLAS: LapackDriverAPI<T>,
44{
45    type Out = EighResult<Tensor<T::Real, DeviceBLAS, D::SmallerOne>, Tensor<T, DeviceBLAS, D>>;
46    fn eigh_f(self) -> Result<Self::Out> {
47        let a = self;
48        let uplo = match a.device().default_order() {
49            RowMajor => Lower,
50            ColMajor => Upper,
51        };
52        EighAPI::<DeviceBLAS>::eigh_f((a, uplo))
53    }
54}
56#[duplicate_item(
57    ImplType   Tr                              ;
58   ['a, T, D] [TensorMut<'a, T, DeviceBLAS, D>];
59   [    T, D] [Tensor<T, DeviceBLAS, D>       ];
60)]
61impl<ImplType> EighAPI<DeviceBLAS> for (Tr, FlagUpLo)
62where
63    T: BlasFloat,
64    D: DimAPI + DimSmallerOneAPI,
65    D::SmallerOne: DimAPI,
66    DeviceBLAS: LapackDriverAPI<T>,
67{
68    type Out = EighResult<Tensor<T::Real, DeviceBLAS, D::SmallerOne>, Tr>;
69    fn eigh_f(self) -> Result<Self::Out> {
70        let (mut a, uplo) = self;
71        rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
72        let a_view = a.view_mut().into_dim::<Ix2>();
73        let eigh_args = EighArgs::default().a(a_view).uplo(uplo).build()?;
74        let (vals, vecs) = ref_impl_eigh_simple_f(eigh_args)?;
75        let vals = vals.into_dim::<IxD>().into_dim::<D::SmallerOne>();
76        vecs.unwrap().clone_to_mut();
77        return Ok(EighResult { eigenvalues: vals, eigenvectors: a });
78    }
79}
81#[duplicate_item(
82    ImplType   Tr                              ;
83   ['a, T, D] [TensorMut<'a, T, DeviceBLAS, D>];
84   [    T, D] [Tensor<T, DeviceBLAS, D>       ];
85)]
86impl<ImplType> EighAPI<DeviceBLAS> for Tr
87where
88    T: BlasFloat,
89    D: DimAPI + DimSmallerOneAPI,
90    D::SmallerOne: DimAPI,
91    DeviceBLAS: LapackDriverAPI<T>,
92{
93    type Out = EighResult<Tensor<T::Real, DeviceBLAS, D::SmallerOne>, Tr>;
94    fn eigh_f(self) -> Result<Self::Out> {
95        let a = self;
96        let uplo = match a.device().default_order() {
97            RowMajor => Lower,
98            ColMajor => Upper,
99        };
100        EighAPI::<DeviceBLAS>::eigh_f((a, uplo))
101    }
102}

/* #endregion */

/* #region general eigh */

108#[duplicate_item(
109    ImplType                                                       TrA                                TrB                              ;
110   [T, D, Ra: DataAPI<Data = Vec<T>>, Rb: DataAPI<Data = Vec<T>>] [&TensorAny<Ra, T, DeviceBLAS, D>] [&TensorAny<Rb, T, DeviceBLAS, D>];
111   [T, D, R: DataAPI<Data = Vec<T>>                             ] [&TensorAny<R, T, DeviceBLAS, D> ] [TensorView<'_, T, DeviceBLAS, D>];
112   [T, D, R: DataAPI<Data = Vec<T>>                             ] [TensorView<'_, T, DeviceBLAS, D>] [&TensorAny<R, T, DeviceBLAS, D> ];
113   [T, D,                                                       ] [TensorView<'_, T, DeviceBLAS, D>] [TensorView<'_, T, DeviceBLAS, D>];
114)]
115impl<ImplType> EighAPI<DeviceBLAS> for (TrA, TrB, FlagUpLo, i32)
116where
117    T: BlasFloat,
118    D: DimAPI + DimSmallerOneAPI,
119    D::SmallerOne: DimAPI,
120    DeviceBLAS: LapackDriverAPI<T>,
121{
122    type Out = EighResult<Tensor<T::Real, DeviceBLAS, D::SmallerOne>, Tensor<T, DeviceBLAS, D>>;
123    fn eigh_f(self) -> Result<Self::Out> {
124        let (a, b, uplo, eig_type) = self;
125        rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
126        rstsr_assert_eq!(b.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
127        rstsr_pattern!(eig_type, 1..=3, InvalidLayout, "Only eig_type = 1, 2, or 3 allowed.")?;
128        let a_view = a.view().into_dim::<Ix2>();
129        let b_view = b.view().into_dim::<Ix2>();
130        let eigh_args = EighArgs::default().a(a_view).b(b_view).uplo(uplo).eig_type(eig_type).build()?;
131        let (vals, vecs) = ref_impl_eigh_simple_f(eigh_args)?;
132        let vals = vals.into_dim::<IxD>().into_dim::<D::SmallerOne>();
133        let vecs = vecs.unwrap().into_owned().into_dim::<IxD>().into_dim::<D>();
134        return Ok(EighResult { eigenvalues: vals, eigenvectors: vecs });
135    }
136}
138#[duplicate_item(
139    ImplType                                                       TrA                                TrB                              ;
140   [T, D, Ra: DataAPI<Data = Vec<T>>, Rb: DataAPI<Data = Vec<T>>] [&TensorAny<Ra, T, DeviceBLAS, D>] [&TensorAny<Rb, T, DeviceBLAS, D>];
141   [T, D, R: DataAPI<Data = Vec<T>>                             ] [&TensorAny<R, T, DeviceBLAS, D> ] [TensorView<'_, T, DeviceBLAS, D>];
142   [T, D, R: DataAPI<Data = Vec<T>>                             ] [TensorView<'_, T, DeviceBLAS, D>] [&TensorAny<R, T, DeviceBLAS, D> ];
143   [T, D,                                                       ] [TensorView<'_, T, DeviceBLAS, D>] [TensorView<'_, T, DeviceBLAS, D>];
144)]
145impl<ImplType> EighAPI<DeviceBLAS> for (TrA, TrB, FlagUpLo)
146where
147    T: BlasFloat,
148    D: DimAPI + DimSmallerOneAPI,
149    D::SmallerOne: DimAPI,
150    DeviceBLAS: LapackDriverAPI<T>,
151{
152    type Out = EighResult<Tensor<T::Real, DeviceBLAS, D::SmallerOne>, Tensor<T, DeviceBLAS, D>>;
153    fn eigh_f(self) -> Result<Self::Out> {
154        let (a, b, uplo) = self;
155        EighAPI::<DeviceBLAS>::eigh_f((a, b, uplo, 1))
156    }
157}
159#[duplicate_item(
160    ImplType                                                       TrA                                TrB                              ;
161   [T, D, Ra: DataAPI<Data = Vec<T>>, Rb: DataAPI<Data = Vec<T>>] [&TensorAny<Ra, T, DeviceBLAS, D>] [&TensorAny<Rb, T, DeviceBLAS, D>];
162   [T, D, R: DataAPI<Data = Vec<T>>                             ] [&TensorAny<R, T, DeviceBLAS, D> ] [TensorView<'_, T, DeviceBLAS, D>];
163   [T, D, R: DataAPI<Data = Vec<T>>                             ] [TensorView<'_, T, DeviceBLAS, D>] [&TensorAny<R, T, DeviceBLAS, D> ];
164   [T, D,                                                       ] [TensorView<'_, T, DeviceBLAS, D>] [TensorView<'_, T, DeviceBLAS, D>];
165)]
166impl<ImplType> EighAPI<DeviceBLAS> for (TrA, TrB)
167where
168    T: BlasFloat,
169    D: DimAPI + DimSmallerOneAPI,
170    D::SmallerOne: DimAPI,
171    DeviceBLAS: LapackDriverAPI<T>,
172{
173    type Out = EighResult<Tensor<T::Real, DeviceBLAS, D::SmallerOne>, Tensor<T, DeviceBLAS, D>>;
174    fn eigh_f(self) -> Result<Self::Out> {
175        let (a, b) = self;
176        let uplo = match a.device().default_order() {
177            RowMajor => Lower,
178            ColMajor => Upper,
179        };
180        EighAPI::<DeviceBLAS>::eigh_f((a, b, uplo, 1))
181    }
182}

/* #endregion */

/* #region EighArgs implementation */

188impl<'a, 'b, T> EighAPI<DeviceBLAS> for EighArgs<'a, 'b, DeviceBLAS, T>
189where
190    T: BlasFloat,
191    DeviceBLAS: LapackDriverAPI<T>,
192{
193    type Out = EighResult<Tensor<T::Real, DeviceBLAS, Ix1>, TensorMutable<'a, T, DeviceBLAS, Ix2>>;
194    fn eigh_f(self) -> Result<Self::Out> {
195        let args = self.build()?;
196        rstsr_assert!(!args.eigvals_only, InvalidValue, "Eigh only supports eigvals_only = false.")?;
197        let (vals, vecs) = ref_impl_eigh_simple_f(args)?;
198        Ok(EighResult { eigenvalues: vals, eigenvectors: vecs.unwrap() })
199    }
200}
202impl<'a, 'b, T> EighAPI<DeviceBLAS> for EighArgs_<'a, 'b, DeviceBLAS, T>
203where
204    T: BlasFloat,
205    DeviceBLAS: LapackDriverAPI<T>,
206{
207    type Out = EighResult<Tensor<T::Real, DeviceBLAS, Ix1>, TensorMutable<'a, T, DeviceBLAS, Ix2>>;
208    fn eigh_f(self) -> Result<Self::Out> {
209        let args = self;
210        rstsr_assert!(!args.eigvals_only, InvalidValue, "Eigh only supports eigvals_only = false.")?;
211        let (vals, vecs) = ref_impl_eigh_simple_f(args)?;
212        Ok(EighResult { eigenvalues: vals, eigenvectors: vecs.unwrap() })
213    }
214}

/* #endregion */