rstsr_openblas/linalg_auto_impl/
svd.rs1use crate::DeviceBLAS;
2use rstsr_blas_traits::prelude::*;
3use rstsr_core::prelude_dev::*;
4use rstsr_linalg_traits::prelude_dev::*;
5
6impl<T, D, R> SVDAPI<DeviceBLAS> for (&TensorAny<R, T, DeviceBLAS, D>, bool)
9where
10 R: DataAPI<Data = Vec<T>>,
11 T: BlasFloat,
12 D: DimAPI + DimSmallerOneAPI,
13 D::SmallerOne: DimAPI,
14 DeviceBLAS: LapackDriverAPI<T>,
15{
16 type Out =
17 SVDResult<Tensor<T, DeviceBLAS, D>, Tensor<T::Real, DeviceBLAS, D::SmallerOne>, Tensor<T, DeviceBLAS, D>>;
18 fn svd_f(self) -> Result<Self::Out> {
19 let (a, full_matrices) = self;
20 rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
21 let a = a.view().into_dim::<Ix2>();
22 let svd_args = SVDArgs::default().a(a).full_matrices(full_matrices).build()?;
23 let (u, s, vt) = ref_impl_svd_simple_f(svd_args)?;
24 let u = u.unwrap().into_dim::<IxD>().into_dim::<D>();
26 let vt = vt.unwrap().into_dim::<IxD>().into_dim::<D>();
27 let s = s.into_dim::<IxD>().into_dim::<D::SmallerOne>();
28 Ok(SVDResult { u, s, vt })
29 }
30}
31
32#[duplicate_item(
33 Tr; [Tensor<T, DeviceBLAS, D>]; [TensorView<'_, T, DeviceBLAS, D>];
34)]
35impl<T, D> SVDAPI<DeviceBLAS> for (Tr, bool)
36where
37 T: BlasFloat,
38 D: DimAPI + DimSmallerOneAPI,
39 D::SmallerOne: DimAPI,
40 DeviceBLAS: LapackDriverAPI<T>,
41{
42 type Out =
43 SVDResult<Tensor<T, DeviceBLAS, D>, Tensor<T::Real, DeviceBLAS, D::SmallerOne>, Tensor<T, DeviceBLAS, D>>;
44 fn svd_f(self) -> Result<Self::Out> {
45 let (a, full_matrices) = self;
46 SVDAPI::<DeviceBLAS>::svd_f((&a, full_matrices))
47 }
48}
49
50#[duplicate_item(
55 ImplType Tr;
56 ['a, T, D, R: DataAPI<Data = Vec<T>>] [&'a TensorAny<R, T, DeviceBLAS, D>];
57 ['a, T, D, ] [TensorView<'a, T, DeviceBLAS, D> ];
58 [ T, D ] [Tensor<T, DeviceBLAS, D> ];
59)]
60impl<ImplType> SVDAPI<DeviceBLAS> for Tr
61where
62 T: BlasFloat,
63 D: DimAPI,
64 (Tr, bool): SVDAPI<DeviceBLAS>,
65{
66 type Out = <(Tr, bool) as SVDAPI<DeviceBLAS>>::Out;
67 fn svd_f(self) -> Result<Self::Out> {
68 let a = self;
69 SVDAPI::<DeviceBLAS>::svd_f((a, true))
70 }
71}
72
73impl<'a, T> SVDAPI<DeviceBLAS> for SVDArgs<'a, DeviceBLAS, T>
78where
79 T: BlasFloat,
80 DeviceBLAS: LapackDriverAPI<T>,
81{
82 type Out = SVDResult<Tensor<T, DeviceBLAS, Ix2>, Tensor<T::Real, DeviceBLAS, Ix1>, Tensor<T, DeviceBLAS, Ix2>>;
83 fn svd_f(self) -> Result<Self::Out> {
84 SVDAPI::<DeviceBLAS>::svd_f(self.build()?)
85 }
86}
87
88impl<'a, T> SVDAPI<DeviceBLAS> for SVDArgs_<'a, DeviceBLAS, T>
89where
90 T: BlasFloat,
91 DeviceBLAS: LapackDriverAPI<T>,
92{
93 type Out = SVDResult<Tensor<T, DeviceBLAS, Ix2>, Tensor<T::Real, DeviceBLAS, Ix1>, Tensor<T, DeviceBLAS, Ix2>>;
94 fn svd_f(self) -> Result<Self::Out> {
95 let args = self;
96 rstsr_assert!(
97 args.full_matrices.is_some(),
98 InvalidValue,
99 "`svd` must compute UV. Refer to `svdvals` if UV is not required."
100 )?;
101 let (u, s, vt) = ref_impl_svd_simple_f(args)?;
102 let u = u.unwrap().into_dim::<IxD>().into_dim::<Ix2>();
103 let vt = vt.unwrap().into_dim::<IxD>().into_dim::<Ix2>();
104 let s = s.into_dim::<IxD>().into_dim::<Ix1>();
105 Ok(SVDResult { u, s, vt })
106 }
107}
108
109