1use derive_builder::Builder;
2use rstsr_blas_traits::prelude::BlasFloat;
3use rstsr_core::prelude_dev::*;
4
5#[duplicate_item(
8 LinalgAPI func func_f ;
9 [CholeskyAPI ] [cholesky ] [cholesky_f ];
10 [DetAPI ] [det ] [det_f ];
11 [EighAPI ] [eigh ] [eigh_f ];
12 [EigvalshAPI ] [eigvalsh ] [eigvalsh_f ];
13 [InvAPI ] [inv ] [inv_f ];
14 [PinvAPI ] [pinv ] [pinv_f ];
15 [SLogDetAPI ] [slogdet ] [slogdet_f ];
16 [SolveGeneralAPI ] [solve_general ] [solve_general_f ];
17 [SolveSymmetricAPI ] [solve_symmetric ] [solve_symmetric_f ];
18 [SolveTriangularAPI] [solve_triangular] [solve_triangular_f];
19 [SVDAPI ] [svd ] [svd_f ];
20 [SVDvalsAPI ] [svdvals ] [svdvals_f ];
21)]
22pub trait LinalgAPI<Inp> {
23 type Out;
24 fn func_f(self) -> Result<Self::Out>;
25 fn func(self) -> Self::Out
26 where
27 Self: Sized,
28 {
29 Self::func_f(self).rstsr_unwrap()
30 }
31}
32
33#[duplicate_item(
34 LinalgAPI func func_f ;
35 [CholeskyAPI ] [cholesky ] [cholesky_f ];
36 [DetAPI ] [det ] [det_f ];
37 [EighAPI ] [eigh ] [eigh_f ];
38 [EigvalshAPI ] [eigvalsh ] [eigvalsh_f ];
39 [InvAPI ] [inv ] [inv_f ];
40 [PinvAPI ] [pinv ] [pinv_f ];
41 [SLogDetAPI ] [slogdet ] [slogdet_f ];
42 [SolveGeneralAPI ] [solve_general ] [solve_general_f ];
43 [SolveSymmetricAPI ] [solve_symmetric ] [solve_symmetric_f ];
44 [SolveTriangularAPI] [solve_triangular] [solve_triangular_f];
45 [SVDAPI ] [svd ] [svd_f ];
46 [SVDvalsAPI ] [svdvals ] [svdvals_f ];
47)]
48pub fn func_f<Args, Inp>(args: Args) -> Result<<Args as LinalgAPI<Inp>>::Out>
49where
50 Args: LinalgAPI<Inp>,
51{
52 Args::func_f(args)
53}
54
55#[duplicate_item(
56 LinalgAPI func func_f ;
57 [CholeskyAPI ] [cholesky ] [cholesky_f ];
58 [DetAPI ] [det ] [det_f ];
59 [EighAPI ] [eigh ] [eigh_f ];
60 [EigvalshAPI ] [eigvalsh ] [eigvalsh_f ];
61 [InvAPI ] [inv ] [inv_f ];
62 [PinvAPI ] [pinv ] [pinv_f ];
63 [SLogDetAPI ] [slogdet ] [slogdet_f ];
64 [SolveGeneralAPI ] [solve_general ] [solve_general_f ];
65 [SolveSymmetricAPI ] [solve_symmetric ] [solve_symmetric_f ];
66 [SolveTriangularAPI] [solve_triangular] [solve_triangular_f];
67 [SVDAPI ] [svd ] [svd_f ];
68 [SVDvalsAPI ] [svdvals ] [svdvals_f ];
69)]
70pub fn func<Args, Inp>(args: Args) -> <Args as LinalgAPI<Inp>>::Out
71where
72 Args: LinalgAPI<Inp>,
73{
74 Args::func(args)
75}
76
/// Named pair of eigenvalues and eigenvectors.
///
/// Converts losslessly to and from the tuple `(W, V)`, where the first
/// element holds the eigenvalues and the second the eigenvectors.
pub struct EighResult<W, V> {
    /// Eigenvalue component (tuple position 0).
    pub eigenvalues: W,
    /// Eigenvector component (tuple position 1).
    pub eigenvectors: V,
}

impl<W, V> From<(W, V)> for EighResult<W, V> {
    /// Wraps an `(eigenvalues, eigenvectors)` tuple.
    fn from(pair: (W, V)) -> Self {
        let (eigenvalues, eigenvectors) = pair;
        EighResult { eigenvalues, eigenvectors }
    }
}

impl<W, V> From<EighResult<W, V>> for (W, V) {
    /// Unwraps into an `(eigenvalues, eigenvectors)` tuple.
    fn from(result: EighResult<W, V>) -> Self {
        let EighResult { eigenvalues, eigenvectors } = result;
        (eigenvalues, eigenvectors)
    }
}
97
/// Argument bundle for the `eigh` routine, filled in through a generated builder.
///
/// `#[derive(Builder)]` (derive_builder) generates `EighArgs_Builder`, which is
/// re-exported below as `EighArgs`. Builder configuration:
/// - `pattern = "owned"`: setters take and return the builder by value.
/// - `no_std`: the generated builder code does not depend on `std`.
/// - `build_fn(error = "Error")`: `build()` returns this crate's `Error` on
///   failure, for example when the required field `a` was never set.
#[derive(Builder)]
#[builder(pattern = "owned", no_std, build_fn(error = "Error"))]
pub struct EighArgs_<'a, 'b, B, T>
where
    T: BlasFloat,
    B: DeviceAPI<T>,
{
    /// Input 2-D matrix. Required: there is no default.
    /// The setter accepts anything convertible `into` a `TensorReference`.
    #[builder(setter(into))]
    pub a: TensorReference<'a, T, B, Ix2>,
    /// Optional second 2-D matrix. Defaults to `None`.
    /// `strip_option` lets the setter take the reference without wrapping it
    /// in `Some`. NOTE(review): presumably the `B` matrix of a generalized
    /// eigenproblem; confirm in the implementation.
    #[builder(setter(into, strip_option), default = "None")]
    pub b: Option<TensorReference<'b, T, B, Ix2>>,

    /// Triangle selector. Defaults to `None`.
    /// With `setter(into)`, either a bare `FlagUpLo` or an `Option` is accepted.
    /// NOTE(review): what `None` resolves to is decided by the implementation,
    /// not visible here.
    #[builder(setter(into), default = false)]
    /// Defaults to `false`. NOTE(review): presumably, when `true`, only
    /// eigenvalues are computed; confirm in the implementation.
    #[builder(setter(into), default = false)]
    pub eigvals_only: bool,
    /// Defaults to `1`. NOTE(review): presumably selects the
    /// generalized-problem form (LAPACK `itype`-style); confirm.
    #[builder(setter(into), default = 1)]
    pub eig_type: i32,
    /// Optional index window, given as a pair. Defaults to `None`.
    /// NOTE(review): bound inclusivity and ordering are not visible here.
    #[builder(setter(into, strip_option), default = "None")]
    pub subset_by_index: Option<(usize, usize)>,
    /// Optional value window, as a pair of `T::Real`. Defaults to `None`.
    /// NOTE(review): bound inclusivity is not visible here.
    #[builder(setter(into, strip_option), default = "None")]
    pub subset_by_value: Option<(T::Real, T::Real)>,
    /// Optional backend driver name. Defaults to `None`.
    /// NOTE(review): the accepted names are defined by the implementation.
    #[builder(setter(into, strip_option), default = "None")]
    pub driver: Option<&'static str>,
}

/// Public name for the derive_builder-generated builder of `EighArgs_`.
pub type EighArgs<'a, 'b, B, T> = EighArgs_Builder<'a, 'b, B, T>;
125
/// Pseudo-inverse value paired with a rank.
///
/// Converts losslessly to and from the tuple `(T, usize)`, ordered as
/// `(pinv, rank)`.
/// NOTE(review): `rank` is presumably the effective numerical rank found
/// while computing the pseudo-inverse; confirm in the implementation.
pub struct PinvResult<T> {
    /// The pseudo-inverse (tuple position 0).
    pub pinv: T,
    /// The rank (tuple position 1).
    pub rank: usize,
}

impl<T> From<(T, usize)> for PinvResult<T> {
    /// Wraps a `(pinv, rank)` tuple.
    fn from(pair: (T, usize)) -> Self {
        let (pinv, rank) = pair;
        PinvResult { pinv, rank }
    }
}

impl<T> From<PinvResult<T>> for (T, usize) {
    /// Unwraps into a `(pinv, rank)` tuple.
    fn from(result: PinvResult<T>) -> Self {
        let PinvResult { pinv, rank } = result;
        (pinv, rank)
    }
}
146
147pub struct SLogDetResult<T>
152where
153 T: BlasFloat,
154{
155 pub sign: T,
156 pub logabsdet: T::Real,
157}
158
159impl<T> From<(T, T::Real)> for SLogDetResult<T>
160where
161 T: BlasFloat,
162{
163 fn from((sign, logabsdet): (T, T::Real)) -> Self {
164 Self { sign, logabsdet }
165 }
166}
167
168impl<T> From<SLogDetResult<T>> for (T, T::Real)
169where
170 T: BlasFloat,
171{
172 fn from(slogdet_result: SLogDetResult<T>) -> Self {
173 (slogdet_result.sign, slogdet_result.logabsdet)
174 }
175}
176
/// Three-part singular value decomposition output.
///
/// Converts losslessly to and from the tuple `(U, S, Vt)`, ordered as
/// `(u, s, vt)`.
/// NOTE(review): presumably left singular vectors, singular values, and
/// transposed (or adjoint) right singular vectors; confirm in the implementation.
pub struct SVDResult<U, S, Vt> {
    /// First component (tuple position 0).
    pub u: U,
    /// Second component (tuple position 1).
    pub s: S,
    /// Third component (tuple position 2).
    pub vt: Vt,
}

impl<U, S, Vt> From<(U, S, Vt)> for SVDResult<U, S, Vt> {
    /// Wraps a `(u, s, vt)` tuple.
    fn from(triple: (U, S, Vt)) -> Self {
        let (u, s, vt) = triple;
        SVDResult { u, s, vt }
    }
}

impl<U, S, Vt> From<SVDResult<U, S, Vt>> for (U, S, Vt) {
    /// Unwraps into a `(u, s, vt)` tuple.
    fn from(result: SVDResult<U, S, Vt>) -> Self {
        let SVDResult { u, s, vt } = result;
        (u, s, vt)
    }
}
198
/// Argument bundle for the `svd` routine, filled in through a generated builder.
///
/// `#[derive(Builder)]` (derive_builder) generates `SVDArgs_Builder`, which is
/// re-exported below as `SVDArgs`. Builder configuration:
/// - `pattern = "owned"`: setters take and return the builder by value.
/// - `no_std`: the generated builder code does not depend on `std`.
/// - `build_fn(error = "Error")`: `build()` returns this crate's `Error` on
///   failure, for example when the required field `a` was never set.
#[derive(Builder)]
#[builder(pattern = "owned", no_std, build_fn(error = "Error"))]
pub struct SVDArgs_<'a, B, T>
where
    T: BlasFloat,
    B: DeviceAPI<T>,
{
    /// Input 2-D matrix. Required: there is no default.
    /// The setter accepts anything convertible `into` a `TensorReference`.
    #[builder(setter(into))]
    pub a: TensorReference<'a, T, B, Ix2>,
    /// Defaults to `Some(true)`.
    /// With `setter(into)`, either a bare `bool` or an `Option<bool>` is accepted.
    /// NOTE(review): presumably mirrors NumPy's `full_matrices`
    /// (full vs. reduced factors); confirm.
    #[builder(setter(into), default = "Some(true)")]
    pub full_matrices: Option<bool>,
    /// Optional backend driver name. Defaults to `None`.
    /// With `setter(into)`, a bare `&'static str` or an `Option` is accepted.
    /// NOTE(review): the accepted names are defined by the implementation.
    #[builder(setter(into), default = "None")]
    pub driver: Option<&'static str>,
}

/// Public name for the derive_builder-generated builder of `SVDArgs_`.
pub type SVDArgs<'a, B, T> = SVDArgs_Builder<'a, B, T>;
215
216