1use crate::DeviceBLAS;
2use rstsr_blas_traits::prelude::*;
3use rstsr_core::prelude_dev::*;
4use rstsr_linalg_traits::prelude_dev::*;
5
6#[duplicate_item(
9 ImplType Tr ;
10 [T, D, R: DataAPI<Data = Vec<T>>] [&TensorAny<R, T, DeviceBLAS, D> ];
11 [T, D ] [TensorView<'_, T, DeviceBLAS, D>];
12)]
13impl<ImplType> EighAPI<DeviceBLAS> for (Tr, FlagUpLo)
14where
15 T: BlasFloat,
16 D: DimAPI + DimSmallerOneAPI,
17 D::SmallerOne: DimAPI,
18 DeviceBLAS: LapackDriverAPI<T>,
19{
20 type Out = EighResult<Tensor<T::Real, DeviceBLAS, D::SmallerOne>, Tensor<T, DeviceBLAS, D>>;
21 fn eigh_f(self) -> Result<Self::Out> {
22 let (a, uplo) = self;
23 rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
24 let a_view = a.view().into_dim::<Ix2>();
25 let eigh_args = EighArgs::default().a(a_view).uplo(uplo).build()?;
26 let (vals, vecs) = ref_impl_eigh_simple_f(eigh_args)?;
27 let vals = vals.into_dim::<IxD>().into_dim::<D::SmallerOne>();
28 let vecs = vecs.unwrap().into_owned().into_dim::<IxD>().into_dim::<D>();
29 return Ok(EighResult { eigenvalues: vals, eigenvectors: vecs });
30 }
31}
32
33#[duplicate_item(
34 ImplType Tr ;
35 [T, D, R: DataAPI<Data = Vec<T>>] [&TensorAny<R, T, DeviceBLAS, D> ];
36 [T, D ] [TensorView<'_, T, DeviceBLAS, D>];
37)]
38impl<ImplType> EighAPI<DeviceBLAS> for Tr
39where
40 T: BlasFloat,
41 D: DimAPI + DimSmallerOneAPI,
42 D::SmallerOne: DimAPI,
43 DeviceBLAS: LapackDriverAPI<T>,
44{
45 type Out = EighResult<Tensor<T::Real, DeviceBLAS, D::SmallerOne>, Tensor<T, DeviceBLAS, D>>;
46 fn eigh_f(self) -> Result<Self::Out> {
47 let a = self;
48 let uplo = match a.device().default_order() {
49 RowMajor => Lower,
50 ColMajor => Upper,
51 };
52 EighAPI::<DeviceBLAS>::eigh_f((a, uplo))
53 }
54}
55
56#[duplicate_item(
57 ImplType Tr ;
58 ['a, T, D] [TensorMut<'a, T, DeviceBLAS, D>];
59 [ T, D] [Tensor<T, DeviceBLAS, D> ];
60)]
61impl<ImplType> EighAPI<DeviceBLAS> for (Tr, FlagUpLo)
62where
63 T: BlasFloat,
64 D: DimAPI + DimSmallerOneAPI,
65 D::SmallerOne: DimAPI,
66 DeviceBLAS: LapackDriverAPI<T>,
67{
68 type Out = EighResult<Tensor<T::Real, DeviceBLAS, D::SmallerOne>, Tr>;
69 fn eigh_f(self) -> Result<Self::Out> {
70 let (mut a, uplo) = self;
71 rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
72 let a_view = a.view_mut().into_dim::<Ix2>();
73 let eigh_args = EighArgs::default().a(a_view).uplo(uplo).build()?;
74 let (vals, vecs) = ref_impl_eigh_simple_f(eigh_args)?;
75 let vals = vals.into_dim::<IxD>().into_dim::<D::SmallerOne>();
76 vecs.unwrap().clone_to_mut();
77 return Ok(EighResult { eigenvalues: vals, eigenvectors: a });
78 }
79}
80
81#[duplicate_item(
82 ImplType Tr ;
83 ['a, T, D] [TensorMut<'a, T, DeviceBLAS, D>];
84 [ T, D] [Tensor<T, DeviceBLAS, D> ];
85)]
86impl<ImplType> EighAPI<DeviceBLAS> for Tr
87where
88 T: BlasFloat,
89 D: DimAPI + DimSmallerOneAPI,
90 D::SmallerOne: DimAPI,
91 DeviceBLAS: LapackDriverAPI<T>,
92{
93 type Out = EighResult<Tensor<T::Real, DeviceBLAS, D::SmallerOne>, Tr>;
94 fn eigh_f(self) -> Result<Self::Out> {
95 let a = self;
96 let uplo = match a.device().default_order() {
97 RowMajor => Lower,
98 ColMajor => Upper,
99 };
100 EighAPI::<DeviceBLAS>::eigh_f((a, uplo))
101 }
102}
103
104#[duplicate_item(
109 ImplType TrA TrB ;
110 [T, D, Ra: DataAPI<Data = Vec<T>>, Rb: DataAPI<Data = Vec<T>>] [&TensorAny<Ra, T, DeviceBLAS, D>] [&TensorAny<Rb, T, DeviceBLAS, D>];
111 [T, D, R: DataAPI<Data = Vec<T>> ] [&TensorAny<R, T, DeviceBLAS, D> ] [TensorView<'_, T, DeviceBLAS, D>];
112 [T, D, R: DataAPI<Data = Vec<T>> ] [TensorView<'_, T, DeviceBLAS, D>] [&TensorAny<R, T, DeviceBLAS, D> ];
113 [T, D, ] [TensorView<'_, T, DeviceBLAS, D>] [TensorView<'_, T, DeviceBLAS, D>];
114)]
115impl<ImplType> EighAPI<DeviceBLAS> for (TrA, TrB, FlagUpLo, i32)
116where
117 T: BlasFloat,
118 D: DimAPI + DimSmallerOneAPI,
119 D::SmallerOne: DimAPI,
120 DeviceBLAS: LapackDriverAPI<T>,
121{
122 type Out = EighResult<Tensor<T::Real, DeviceBLAS, D::SmallerOne>, Tensor<T, DeviceBLAS, D>>;
123 fn eigh_f(self) -> Result<Self::Out> {
124 let (a, b, uplo, eig_type) = self;
125 rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
126 rstsr_assert_eq!(b.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
127 rstsr_pattern!(eig_type, 1..=3, InvalidLayout, "Only eig_type = 1, 2, or 3 allowed.")?;
128 let a_view = a.view().into_dim::<Ix2>();
129 let b_view = b.view().into_dim::<Ix2>();
130 let eigh_args = EighArgs::default().a(a_view).b(b_view).uplo(uplo).eig_type(eig_type).build()?;
131 let (vals, vecs) = ref_impl_eigh_simple_f(eigh_args)?;
132 let vals = vals.into_dim::<IxD>().into_dim::<D::SmallerOne>();
133 let vecs = vecs.unwrap().into_owned().into_dim::<IxD>().into_dim::<D>();
134 return Ok(EighResult { eigenvalues: vals, eigenvectors: vecs });
135 }
136}
137
138#[duplicate_item(
139 ImplType TrA TrB ;
140 [T, D, Ra: DataAPI<Data = Vec<T>>, Rb: DataAPI<Data = Vec<T>>] [&TensorAny<Ra, T, DeviceBLAS, D>] [&TensorAny<Rb, T, DeviceBLAS, D>];
141 [T, D, R: DataAPI<Data = Vec<T>> ] [&TensorAny<R, T, DeviceBLAS, D> ] [TensorView<'_, T, DeviceBLAS, D>];
142 [T, D, R: DataAPI<Data = Vec<T>> ] [TensorView<'_, T, DeviceBLAS, D>] [&TensorAny<R, T, DeviceBLAS, D> ];
143 [T, D, ] [TensorView<'_, T, DeviceBLAS, D>] [TensorView<'_, T, DeviceBLAS, D>];
144)]
145impl<ImplType> EighAPI<DeviceBLAS> for (TrA, TrB, FlagUpLo)
146where
147 T: BlasFloat,
148 D: DimAPI + DimSmallerOneAPI,
149 D::SmallerOne: DimAPI,
150 DeviceBLAS: LapackDriverAPI<T>,
151{
152 type Out = EighResult<Tensor<T::Real, DeviceBLAS, D::SmallerOne>, Tensor<T, DeviceBLAS, D>>;
153 fn eigh_f(self) -> Result<Self::Out> {
154 let (a, b, uplo) = self;
155 EighAPI::<DeviceBLAS>::eigh_f((a, b, uplo, 1))
156 }
157}
158
159#[duplicate_item(
160 ImplType TrA TrB ;
161 [T, D, Ra: DataAPI<Data = Vec<T>>, Rb: DataAPI<Data = Vec<T>>] [&TensorAny<Ra, T, DeviceBLAS, D>] [&TensorAny<Rb, T, DeviceBLAS, D>];
162 [T, D, R: DataAPI<Data = Vec<T>> ] [&TensorAny<R, T, DeviceBLAS, D> ] [TensorView<'_, T, DeviceBLAS, D>];
163 [T, D, R: DataAPI<Data = Vec<T>> ] [TensorView<'_, T, DeviceBLAS, D>] [&TensorAny<R, T, DeviceBLAS, D> ];
164 [T, D, ] [TensorView<'_, T, DeviceBLAS, D>] [TensorView<'_, T, DeviceBLAS, D>];
165)]
166impl<ImplType> EighAPI<DeviceBLAS> for (TrA, TrB)
167where
168 T: BlasFloat,
169 D: DimAPI + DimSmallerOneAPI,
170 D::SmallerOne: DimAPI,
171 DeviceBLAS: LapackDriverAPI<T>,
172{
173 type Out = EighResult<Tensor<T::Real, DeviceBLAS, D::SmallerOne>, Tensor<T, DeviceBLAS, D>>;
174 fn eigh_f(self) -> Result<Self::Out> {
175 let (a, b) = self;
176 let uplo = match a.device().default_order() {
177 RowMajor => Lower,
178 ColMajor => Upper,
179 };
180 EighAPI::<DeviceBLAS>::eigh_f((a, b, uplo, 1))
181 }
182}
183
184impl<'a, 'b, T> EighAPI<DeviceBLAS> for EighArgs<'a, 'b, DeviceBLAS, T>
189where
190 T: BlasFloat,
191 DeviceBLAS: LapackDriverAPI<T>,
192{
193 type Out = EighResult<Tensor<T::Real, DeviceBLAS, Ix1>, TensorMutable<'a, T, DeviceBLAS, Ix2>>;
194 fn eigh_f(self) -> Result<Self::Out> {
195 let args = self.build()?;
196 rstsr_assert!(!args.eigvals_only, InvalidValue, "Eigh only supports eigvals_only = false.")?;
197 let (vals, vecs) = ref_impl_eigh_simple_f(args)?;
198 Ok(EighResult { eigenvalues: vals, eigenvectors: vecs.unwrap() })
199 }
200}
201
202impl<'a, 'b, T> EighAPI<DeviceBLAS> for EighArgs_<'a, 'b, DeviceBLAS, T>
203where
204 T: BlasFloat,
205 DeviceBLAS: LapackDriverAPI<T>,
206{
207 type Out = EighResult<Tensor<T::Real, DeviceBLAS, Ix1>, TensorMutable<'a, T, DeviceBLAS, Ix2>>;
208 fn eigh_f(self) -> Result<Self::Out> {
209 let args = self;
210 rstsr_assert!(!args.eigvals_only, InvalidValue, "Eigh only supports eigvals_only = false.")?;
211 let (vals, vecs) = ref_impl_eigh_simple_f(args)?;
212 Ok(EighResult { eigenvalues: vals, eigenvectors: vecs.unwrap() })
213 }
214}
215
216