rstsr_openblas/linalg_auto_impl/
svdvals.rs

1use crate::DeviceBLAS;
2use rstsr_blas_traits::prelude::*;
3use rstsr_core::prelude_dev::*;
4use rstsr_linalg_traits::prelude_dev::*;
5
6/* #region full-args */
7
8impl<T, D, R> SVDvalsAPI<DeviceBLAS> for &TensorAny<R, T, DeviceBLAS, D>
9where
10    R: DataAPI<Data = Vec<T>>,
11    T: BlasFloat,
12    D: DimAPI + DimSmallerOneAPI,
13    D::SmallerOne: DimAPI,
14    DeviceBLAS: LapackDriverAPI<T>,
15{
16    type Out = Tensor<T::Real, DeviceBLAS, D::SmallerOne>;
17    fn svdvals_f(self) -> Result<Self::Out> {
18        let a = self;
19        rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
20        let a = a.view().into_dim::<Ix2>();
21        let svd_args = SVDArgs::default().a(a).full_matrices(None).build()?;
22        let (_, s, _) = ref_impl_svd_simple_f(svd_args)?;
23        // convert dimensions
24        let s = s.into_dim::<IxD>().into_dim::<D::SmallerOne>();
25        Ok(s)
26    }
27}
28
29#[duplicate_item(
30    Tr; [Tensor<T, DeviceBLAS, D>]; [TensorView<'_, T, DeviceBLAS, D>];
31)]
32impl<T, D> SVDvalsAPI<DeviceBLAS> for Tr
33where
34    T: BlasFloat,
35    D: DimAPI + DimSmallerOneAPI,
36    D::SmallerOne: DimAPI,
37    DeviceBLAS: LapackDriverAPI<T>,
38{
39    type Out = Tensor<T::Real, DeviceBLAS, D::SmallerOne>;
40    fn svdvals_f(self) -> Result<Self::Out> {
41        let a = self;
42        SVDvalsAPI::<DeviceBLAS>::svdvals_f(&a)
43    }
44}
45
46/* #endregion */
47
48/* #region SVDArgs implementation */
49
50impl<'a, T> SVDvalsAPI<DeviceBLAS> for SVDArgs<'a, DeviceBLAS, T>
51where
52    T: BlasFloat,
53    DeviceBLAS: LapackDriverAPI<T>,
54{
55    type Out = Tensor<T::Real, DeviceBLAS, Ix1>;
56    fn svdvals_f(self) -> Result<Self::Out> {
57        SVDvalsAPI::<DeviceBLAS>::svdvals_f(self.build()?)
58    }
59}
60
61impl<'a, T> SVDvalsAPI<DeviceBLAS> for SVDArgs_<'a, DeviceBLAS, T>
62where
63    T: BlasFloat,
64    DeviceBLAS: LapackDriverAPI<T>,
65{
66    type Out = Tensor<T::Real, DeviceBLAS, Ix1>;
67    fn svdvals_f(self) -> Result<Self::Out> {
68        let args = self;
69        rstsr_assert!(
70            args.full_matrices.is_none(),
71            InvalidValue,
72            "`svdvals` must not compute UV. Refer to `svd` if UV is required."
73        )?;
74        let (_, s, _) = ref_impl_svd_simple_f(args)?;
75        let s = s.into_dim::<IxD>().into_dim::<Ix1>();
76        Ok(s)
77    }
78}
79
80/* #endregion */