1use crate::DeviceBLAS;
2use rstsr_blas_traits::prelude::*;
3use rstsr_core::prelude_dev::*;
4use rstsr_linalg_traits::prelude_dev::*;
5
6#[duplicate_item(
9 ImplType Tr ;
10 [T, D, R: DataAPI<Data = Vec<T>>] [&TensorAny<R, T, DeviceBLAS, D> ];
11 [T, D ] [TensorView<'_, T, DeviceBLAS, D>];
12)]
13impl<ImplType> EigvalshAPI<DeviceBLAS> for (Tr, FlagUpLo)
14where
15 T: BlasFloat,
16 D: DimAPI + DimSmallerOneAPI,
17 D::SmallerOne: DimAPI,
18 DeviceBLAS: LapackDriverAPI<T>,
19{
20 type Out = Tensor<T::Real, DeviceBLAS, D::SmallerOne>;
21 fn eigvalsh_f(self) -> Result<Self::Out> {
22 let (a, uplo) = self;
23 rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
24 let a_view = a.view().into_dim::<Ix2>();
25 let eigh_args = EighArgs::default().a(a_view).uplo(uplo).eigvals_only(true).build()?;
26 let (vals, _) = ref_impl_eigh_simple_f(eigh_args)?;
27 let vals = vals.into_dim::<IxD>().into_dim::<D::SmallerOne>();
28 return Ok(vals);
29 }
30}
31
32#[duplicate_item(
33 ImplType Tr ;
34 [T, D, R: DataAPI<Data = Vec<T>>] [&TensorAny<R, T, DeviceBLAS, D> ];
35 [T, D ] [TensorView<'_, T, DeviceBLAS, D>];
36)]
37impl<ImplType> EigvalshAPI<DeviceBLAS> for Tr
38where
39 T: BlasFloat,
40 D: DimAPI + DimSmallerOneAPI,
41 D::SmallerOne: DimAPI,
42 DeviceBLAS: LapackDriverAPI<T>,
43{
44 type Out = Tensor<T::Real, DeviceBLAS, D::SmallerOne>;
45 fn eigvalsh_f(self) -> Result<Self::Out> {
46 let a = self;
47 let uplo = match a.device().default_order() {
48 RowMajor => Lower,
49 ColMajor => Upper,
50 };
51 EigvalshAPI::<DeviceBLAS>::eigvalsh_f((a, uplo))
52 }
53}
54
55#[duplicate_item(
56 ImplType Tr ;
57 ['a, T, D] [TensorMut<'a, T, DeviceBLAS, D>];
58 [ T, D] [Tensor<T, DeviceBLAS, D> ];
59)]
60impl<ImplType> EigvalshAPI<DeviceBLAS> for (Tr, FlagUpLo)
61where
62 T: BlasFloat,
63 D: DimAPI + DimSmallerOneAPI,
64 D::SmallerOne: DimAPI,
65 DeviceBLAS: LapackDriverAPI<T>,
66{
67 type Out = Tensor<T::Real, DeviceBLAS, D::SmallerOne>;
68 fn eigvalsh_f(self) -> Result<Self::Out> {
69 let (mut a, uplo) = self;
70 rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
71 let a_view = a.view_mut().into_dim::<Ix2>();
72 let eigh_args = EighArgs::default().a(a_view).uplo(uplo).eigvals_only(true).build()?;
73 let (vals, _) = ref_impl_eigh_simple_f(eigh_args)?;
74 let vals = vals.into_dim::<IxD>().into_dim::<D::SmallerOne>();
75 return Ok(vals);
76 }
77}
78
79#[duplicate_item(
80 ImplType Tr ;
81 ['a, T, D] [TensorMut<'a, T, DeviceBLAS, D>];
82 [ T, D] [Tensor<T, DeviceBLAS, D> ];
83)]
84impl<ImplType> EigvalshAPI<DeviceBLAS> for Tr
85where
86 T: BlasFloat,
87 D: DimAPI + DimSmallerOneAPI,
88 D::SmallerOne: DimAPI,
89 DeviceBLAS: LapackDriverAPI<T>,
90{
91 type Out = Tensor<T::Real, DeviceBLAS, D::SmallerOne>;
92 fn eigvalsh_f(self) -> Result<Self::Out> {
93 let a = self;
94 let uplo = match a.device().default_order() {
95 RowMajor => Lower,
96 ColMajor => Upper,
97 };
98 EigvalshAPI::<DeviceBLAS>::eigvalsh_f((a, uplo))
99 }
100}
101
102#[duplicate_item(
107 ImplType TrA TrB ;
108 [T, D, Ra: DataAPI<Data = Vec<T>>, Rb: DataAPI<Data = Vec<T>>] [&TensorAny<Ra, T, DeviceBLAS, D>] [&TensorAny<Rb, T, DeviceBLAS, D>];
109 [T, D, R: DataAPI<Data = Vec<T>> ] [&TensorAny<R, T, DeviceBLAS, D> ] [TensorView<'_, T, DeviceBLAS, D>];
110 [T, D, R: DataAPI<Data = Vec<T>> ] [TensorView<'_, T, DeviceBLAS, D>] [&TensorAny<R, T, DeviceBLAS, D> ];
111 [T, D, ] [TensorView<'_, T, DeviceBLAS, D>] [TensorView<'_, T, DeviceBLAS, D>];
112)]
113impl<ImplType> EigvalshAPI<DeviceBLAS> for (TrA, TrB, FlagUpLo, i32)
114where
115 T: BlasFloat,
116 D: DimAPI + DimSmallerOneAPI,
117 D::SmallerOne: DimAPI,
118 DeviceBLAS: LapackDriverAPI<T>,
119{
120 type Out = Tensor<T::Real, DeviceBLAS, D::SmallerOne>;
121 fn eigvalsh_f(self) -> Result<Self::Out> {
122 let (a, b, uplo, eig_type) = self;
123 rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
124 rstsr_assert_eq!(b.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
125 rstsr_pattern!(eig_type, 1..=3, InvalidLayout, "Only eig_type = 1, 2, or 3 allowed.")?;
126 let a_view = a.view().into_dim::<Ix2>();
127 let b_view = b.view().into_dim::<Ix2>();
128 let eigh_args =
129 EighArgs::default().a(a_view).b(b_view).uplo(uplo).eig_type(eig_type).eigvals_only(true).build()?;
130 let (vals, _) = ref_impl_eigh_simple_f(eigh_args)?;
131 let vals = vals.into_dim::<IxD>().into_dim::<D::SmallerOne>();
132 return Ok(vals);
133 }
134}
135
136#[duplicate_item(
137 ImplType TrA TrB ;
138 [T, D, Ra: DataAPI<Data = Vec<T>>, Rb: DataAPI<Data = Vec<T>>] [&TensorAny<Ra, T, DeviceBLAS, D>] [&TensorAny<Rb, T, DeviceBLAS, D>];
139 [T, D, R: DataAPI<Data = Vec<T>> ] [&TensorAny<R, T, DeviceBLAS, D> ] [TensorView<'_, T, DeviceBLAS, D>];
140 [T, D, R: DataAPI<Data = Vec<T>> ] [TensorView<'_, T, DeviceBLAS, D>] [&TensorAny<R, T, DeviceBLAS, D> ];
141 [T, D, ] [TensorView<'_, T, DeviceBLAS, D>] [TensorView<'_, T, DeviceBLAS, D>];
142)]
143impl<ImplType> EigvalshAPI<DeviceBLAS> for (TrA, TrB, FlagUpLo)
144where
145 T: BlasFloat,
146 D: DimAPI + DimSmallerOneAPI,
147 D::SmallerOne: DimAPI,
148 DeviceBLAS: LapackDriverAPI<T>,
149{
150 type Out = Tensor<T::Real, DeviceBLAS, D::SmallerOne>;
151 fn eigvalsh_f(self) -> Result<Self::Out> {
152 let (a, b, uplo) = self;
153 EigvalshAPI::<DeviceBLAS>::eigvalsh_f((a, b, uplo, 1))
154 }
155}
156
157#[duplicate_item(
158 ImplType TrA TrB ;
159 [T, D, Ra: DataAPI<Data = Vec<T>>, Rb: DataAPI<Data = Vec<T>>] [&TensorAny<Ra, T, DeviceBLAS, D>] [&TensorAny<Rb, T, DeviceBLAS, D>];
160 [T, D, R: DataAPI<Data = Vec<T>> ] [&TensorAny<R, T, DeviceBLAS, D> ] [TensorView<'_, T, DeviceBLAS, D>];
161 [T, D, R: DataAPI<Data = Vec<T>> ] [TensorView<'_, T, DeviceBLAS, D>] [&TensorAny<R, T, DeviceBLAS, D> ];
162 [T, D, ] [TensorView<'_, T, DeviceBLAS, D>] [TensorView<'_, T, DeviceBLAS, D>];
163)]
164impl<ImplType> EigvalshAPI<DeviceBLAS> for (TrA, TrB)
165where
166 T: BlasFloat,
167 D: DimAPI + DimSmallerOneAPI,
168 D::SmallerOne: DimAPI,
169 DeviceBLAS: LapackDriverAPI<T>,
170{
171 type Out = Tensor<T::Real, DeviceBLAS, D::SmallerOne>;
172 fn eigvalsh_f(self) -> Result<Self::Out> {
173 let (a, b) = self;
174 let uplo = match a.device().default_order() {
175 RowMajor => Lower,
176 ColMajor => Upper,
177 };
178 EigvalshAPI::<DeviceBLAS>::eigvalsh_f((a, b, uplo, 1))
179 }
180}
181
182impl<'a, 'b, T> EigvalshAPI<DeviceBLAS> for EighArgs<'a, 'b, DeviceBLAS, T>
187where
188 T: BlasFloat,
189 DeviceBLAS: LapackDriverAPI<T>,
190{
191 type Out = Tensor<T::Real, DeviceBLAS, Ix1>;
192 fn eigvalsh_f(self) -> Result<Self::Out> {
193 let args = self.build()?;
194 rstsr_assert!(args.eigvals_only, InvalidValue, "Eigvalsh only supports eigvals_only = true.")?;
195 let (vals, _) = ref_impl_eigh_simple_f(args)?;
196 Ok(vals)
197 }
198}
199
200impl<'a, 'b, T> EigvalshAPI<DeviceBLAS> for EighArgs_<'a, 'b, DeviceBLAS, T>
201where
202 T: BlasFloat,
203 DeviceBLAS: LapackDriverAPI<T>,
204{
205 type Out = Tensor<T::Real, DeviceBLAS, Ix1>;
206 fn eigvalsh_f(self) -> Result<Self::Out> {
207 let args = self;
208 rstsr_assert!(args.eigvals_only, InvalidValue, "Eigvalsh only supports eigvals_only = true.")?;
209 let (vals, _) = ref_impl_eigh_simple_f(args)?;
210 Ok(vals)
211 }
212}
213
214