rstsr_linalg_traits/
traits_def.rs

1use derive_builder::Builder;
2use rstsr_blas_traits::prelude::BlasFloat;
3use rstsr_core::prelude_dev::*;
4
5/* #region trait and fn definitions */
6
7#[duplicate_item(
8    LinalgAPI            func               func_f             ;
9   [CholeskyAPI       ] [cholesky        ] [cholesky_f        ];
10   [DetAPI            ] [det             ] [det_f             ];
11   [EighAPI           ] [eigh            ] [eigh_f            ];
12   [EigvalshAPI       ] [eigvalsh        ] [eigvalsh_f        ];
13   [InvAPI            ] [inv             ] [inv_f             ];
14   [PinvAPI           ] [pinv            ] [pinv_f            ];
15   [SLogDetAPI        ] [slogdet         ] [slogdet_f         ];
16   [SolveGeneralAPI   ] [solve_general   ] [solve_general_f   ];
17   [SolveSymmetricAPI ] [solve_symmetric ] [solve_symmetric_f ];
18   [SolveTriangularAPI] [solve_triangular] [solve_triangular_f];
19   [SVDAPI            ] [svd             ] [svd_f             ];
20   [SVDvalsAPI        ] [svdvals         ] [svdvals_f         ];
21)]
22pub trait LinalgAPI<Inp> {
23    type Out;
24    fn func_f(self) -> Result<Self::Out>;
25    fn func(self) -> Self::Out
26    where
27        Self: Sized,
28    {
29        Self::func_f(self).rstsr_unwrap()
30    }
31}
32
33#[duplicate_item(
34    LinalgAPI            func               func_f             ;
35   [CholeskyAPI       ] [cholesky        ] [cholesky_f        ];
36   [DetAPI            ] [det             ] [det_f             ];
37   [EighAPI           ] [eigh            ] [eigh_f            ];
38   [EigvalshAPI       ] [eigvalsh        ] [eigvalsh_f        ];
39   [InvAPI            ] [inv             ] [inv_f             ];
40   [PinvAPI           ] [pinv            ] [pinv_f            ];
41   [SLogDetAPI        ] [slogdet         ] [slogdet_f         ];
42   [SolveGeneralAPI   ] [solve_general   ] [solve_general_f   ];
43   [SolveSymmetricAPI ] [solve_symmetric ] [solve_symmetric_f ];
44   [SolveTriangularAPI] [solve_triangular] [solve_triangular_f];
45   [SVDAPI            ] [svd             ] [svd_f             ];
46   [SVDvalsAPI        ] [svdvals         ] [svdvals_f         ];
47)]
48pub fn func_f<Args, Inp>(args: Args) -> Result<<Args as LinalgAPI<Inp>>::Out>
49where
50    Args: LinalgAPI<Inp>,
51{
52    Args::func_f(args)
53}
54
55#[duplicate_item(
56    LinalgAPI            func               func_f             ;
57   [CholeskyAPI       ] [cholesky        ] [cholesky_f        ];
58   [DetAPI            ] [det             ] [det_f             ];
59   [EighAPI           ] [eigh            ] [eigh_f            ];
60   [EigvalshAPI       ] [eigvalsh        ] [eigvalsh_f        ];
61   [InvAPI            ] [inv             ] [inv_f             ];
62   [PinvAPI           ] [pinv            ] [pinv_f            ];
63   [SLogDetAPI        ] [slogdet         ] [slogdet_f         ];
64   [SolveGeneralAPI   ] [solve_general   ] [solve_general_f   ];
65   [SolveSymmetricAPI ] [solve_symmetric ] [solve_symmetric_f ];
66   [SolveTriangularAPI] [solve_triangular] [solve_triangular_f];
67   [SVDAPI            ] [svd             ] [svd_f             ];
68   [SVDvalsAPI        ] [svdvals         ] [svdvals_f         ];
69)]
70pub fn func<Args, Inp>(args: Args) -> <Args as LinalgAPI<Inp>>::Out
71where
72    Args: LinalgAPI<Inp>,
73{
74    Args::func(args)
75}
76
77/* #endregion */
78
79/* #region eigh */
80
81pub struct EighResult<W, V> {
82    pub eigenvalues: W,
83    pub eigenvectors: V,
84}
85
86impl<W, V> From<(W, V)> for EighResult<W, V> {
87    fn from((vals, vecs): (W, V)) -> Self {
88        Self { eigenvalues: vals, eigenvectors: vecs }
89    }
90}
91
92impl<W, V> From<EighResult<W, V>> for (W, V) {
93    fn from(eigh_result: EighResult<W, V>) -> Self {
94        (eigh_result.eigenvalues, eigh_result.eigenvectors)
95    }
96}
97
98#[derive(Builder)]
99#[builder(pattern = "owned", no_std, build_fn(error = "Error"))]
100pub struct EighArgs_<'a, 'b, B, T>
101where
102    T: BlasFloat,
103    B: DeviceAPI<T>,
104{
105    #[builder(setter(into))]
106    pub a: TensorReference<'a, T, B, Ix2>,
107    #[builder(setter(into, strip_option), default = "None")]
108    pub b: Option<TensorReference<'b, T, B, Ix2>>,
109
110    #[builder(setter(into), default = "None")]
111    pub uplo: Option<FlagUpLo>,
112    #[builder(setter(into), default = false)]
113    pub eigvals_only: bool,
114    #[builder(setter(into), default = 1)]
115    pub eig_type: i32,
116    #[builder(setter(into, strip_option), default = "None")]
117    pub subset_by_index: Option<(usize, usize)>,
118    #[builder(setter(into, strip_option), default = "None")]
119    pub subset_by_value: Option<(T::Real, T::Real)>,
120    #[builder(setter(into, strip_option), default = "None")]
121    pub driver: Option<&'static str>,
122}
123
124pub type EighArgs<'a, 'b, B, T> = EighArgs_Builder<'a, 'b, B, T>;
125
126/* #endregion */
127
128/* #region pinv */
129
130pub struct PinvResult<T> {
131    pub pinv: T,
132    pub rank: usize,
133}
134
135impl<T> From<(T, usize)> for PinvResult<T> {
136    fn from((pinv, rank): (T, usize)) -> Self {
137        Self { pinv, rank }
138    }
139}
140
141impl<T> From<PinvResult<T>> for (T, usize) {
142    fn from(pinv_result: PinvResult<T>) -> Self {
143        (pinv_result.pinv, pinv_result.rank)
144    }
145}
146
147/* #endregion */
148
149/* #region slogdet */
150
151pub struct SLogDetResult<T>
152where
153    T: BlasFloat,
154{
155    pub sign: T,
156    pub logabsdet: T::Real,
157}
158
159impl<T> From<(T, T::Real)> for SLogDetResult<T>
160where
161    T: BlasFloat,
162{
163    fn from((sign, logabsdet): (T, T::Real)) -> Self {
164        Self { sign, logabsdet }
165    }
166}
167
168impl<T> From<SLogDetResult<T>> for (T, T::Real)
169where
170    T: BlasFloat,
171{
172    fn from(slogdet_result: SLogDetResult<T>) -> Self {
173        (slogdet_result.sign, slogdet_result.logabsdet)
174    }
175}
176
177/* #endregion */
178
179/* #region svd */
180
181pub struct SVDResult<U, S, Vt> {
182    pub u: U,
183    pub s: S,
184    pub vt: Vt,
185}
186
187impl<U, S, Vt> From<(U, S, Vt)> for SVDResult<U, S, Vt> {
188    fn from((u, s, vt): (U, S, Vt)) -> Self {
189        Self { u, s, vt }
190    }
191}
192
193impl<U, S, Vt> From<SVDResult<U, S, Vt>> for (U, S, Vt) {
194    fn from(svd_result: SVDResult<U, S, Vt>) -> Self {
195        (svd_result.u, svd_result.s, svd_result.vt)
196    }
197}
198
199#[derive(Builder)]
200#[builder(pattern = "owned", no_std, build_fn(error = "Error"))]
201pub struct SVDArgs_<'a, B, T>
202where
203    T: BlasFloat,
204    B: DeviceAPI<T>,
205{
206    #[builder(setter(into))]
207    pub a: TensorReference<'a, T, B, Ix2>,
208    #[builder(setter(into), default = "Some(true)")]
209    pub full_matrices: Option<bool>,
210    #[builder(setter(into), default = "None")]
211    pub driver: Option<&'static str>,
212}
213
214pub type SVDArgs<'a, B, T> = SVDArgs_Builder<'a, B, T>;
215
216/* #endregion */