rstsr_openblas/linalg_auto_impl/svd.rs

1use crate::DeviceBLAS;
2use rstsr_blas_traits::prelude::*;
3use rstsr_core::prelude_dev::*;
4use rstsr_linalg_traits::prelude_dev::*;
5
6/* #region full-args */
7
8impl<T, D, R> SVDAPI<DeviceBLAS> for (&TensorAny<R, T, DeviceBLAS, D>, bool)
9where
10    R: DataAPI<Data = Vec<T>>,
11    T: BlasFloat,
12    D: DimAPI + DimSmallerOneAPI,
13    D::SmallerOne: DimAPI,
14    DeviceBLAS: LapackDriverAPI<T>,
15{
16    type Out =
17        SVDResult<Tensor<T, DeviceBLAS, D>, Tensor<T::Real, DeviceBLAS, D::SmallerOne>, Tensor<T, DeviceBLAS, D>>;
18    fn svd_f(self) -> Result<Self::Out> {
19        let (a, full_matrices) = self;
20        rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
21        let a = a.view().into_dim::<Ix2>();
22        let svd_args = SVDArgs::default().a(a).full_matrices(full_matrices).build()?;
23        let (u, s, vt) = ref_impl_svd_simple_f(svd_args)?;
24        // convert dimensions
25        let u = u.unwrap().into_dim::<IxD>().into_dim::<D>();
26        let vt = vt.unwrap().into_dim::<IxD>().into_dim::<D>();
27        let s = s.into_dim::<IxD>().into_dim::<D::SmallerOne>();
28        Ok(SVDResult { u, s, vt })
29    }
30}
31
32#[duplicate_item(
33    Tr; [Tensor<T, DeviceBLAS, D>]; [TensorView<'_, T, DeviceBLAS, D>];
34)]
35impl<T, D> SVDAPI<DeviceBLAS> for (Tr, bool)
36where
37    T: BlasFloat,
38    D: DimAPI + DimSmallerOneAPI,
39    D::SmallerOne: DimAPI,
40    DeviceBLAS: LapackDriverAPI<T>,
41{
42    type Out =
43        SVDResult<Tensor<T, DeviceBLAS, D>, Tensor<T::Real, DeviceBLAS, D::SmallerOne>, Tensor<T, DeviceBLAS, D>>;
44    fn svd_f(self) -> Result<Self::Out> {
45        let (a, full_matrices) = self;
46        SVDAPI::<DeviceBLAS>::svd_f((&a, full_matrices))
47    }
48}
49
50/* #endregion */
51
52/* #region sub-args */
53
54#[duplicate_item(
55    ImplType                              Tr;
56   ['a, T, D, R: DataAPI<Data = Vec<T>>] [&'a TensorAny<R, T, DeviceBLAS, D>];
57   ['a, T, D,                          ] [TensorView<'a, T, DeviceBLAS, D>  ];
58   [    T, D                           ] [Tensor<T, DeviceBLAS, D>          ];
59)]
60impl<ImplType> SVDAPI<DeviceBLAS> for Tr
61where
62    T: BlasFloat,
63    D: DimAPI,
64    (Tr, bool): SVDAPI<DeviceBLAS>,
65{
66    type Out = <(Tr, bool) as SVDAPI<DeviceBLAS>>::Out;
67    fn svd_f(self) -> Result<Self::Out> {
68        let a = self;
69        SVDAPI::<DeviceBLAS>::svd_f((a, true))
70    }
71}
72
73/* #endregion */
74
75/* #region SVDArgs implementation */
76
77impl<'a, T> SVDAPI<DeviceBLAS> for SVDArgs<'a, DeviceBLAS, T>
78where
79    T: BlasFloat,
80    DeviceBLAS: LapackDriverAPI<T>,
81{
82    type Out = SVDResult<Tensor<T, DeviceBLAS, Ix2>, Tensor<T::Real, DeviceBLAS, Ix1>, Tensor<T, DeviceBLAS, Ix2>>;
83    fn svd_f(self) -> Result<Self::Out> {
84        SVDAPI::<DeviceBLAS>::svd_f(self.build()?)
85    }
86}
87
88impl<'a, T> SVDAPI<DeviceBLAS> for SVDArgs_<'a, DeviceBLAS, T>
89where
90    T: BlasFloat,
91    DeviceBLAS: LapackDriverAPI<T>,
92{
93    type Out = SVDResult<Tensor<T, DeviceBLAS, Ix2>, Tensor<T::Real, DeviceBLAS, Ix1>, Tensor<T, DeviceBLAS, Ix2>>;
94    fn svd_f(self) -> Result<Self::Out> {
95        let args = self;
96        rstsr_assert!(
97            args.full_matrices.is_some(),
98            InvalidValue,
99            "`svd` must compute UV. Refer to `svdvals` if UV is not required."
100        )?;
101        let (u, s, vt) = ref_impl_svd_simple_f(args)?;
102        let u = u.unwrap().into_dim::<IxD>().into_dim::<Ix2>();
103        let vt = vt.unwrap().into_dim::<IxD>().into_dim::<Ix2>();
104        let s = s.into_dim::<IxD>().into_dim::<Ix1>();
105        Ok(SVDResult { u, s, vt })
106    }
107}
108
109/* #endregion */