rstsr_openblas/linalg_auto_impl/
pinv.rs1use crate::DeviceBLAS;
2use num::FromPrimitive;
3use rstsr_blas_traits::prelude::*;
4use rstsr_core::prelude_dev::*;
5use rstsr_linalg_traits::prelude_dev::*;
6
7impl<T, D, R> PinvAPI<DeviceBLAS> for (&TensorAny<R, T, DeviceBLAS, D>, T::Real, T::Real)
8where
9 R: DataAPI<Data = Vec<T>>,
10 T: BlasFloat,
11 T::Real: FromPrimitive,
12 D: DimAPI + DimSmallerOneAPI,
13 D::SmallerOne: DimAPI,
14 DeviceBLAS: LapackDriverAPI<T>,
15{
16 type Out = PinvResult<Tensor<T, DeviceBLAS, D>>;
17 fn pinv_f(self) -> Result<Self::Out> {
18 let (a, atol, rtol) = self;
19 rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
20 let a = a.view().into_dim::<Ix2>();
21 let (b, rank) = ref_impl_pinv_f(a, Some(atol), Some(rtol))?.into();
22 let b = b.into_dim::<IxD>().into_dim::<D>();
23 return Ok(PinvResult { pinv: b, rank });
24 }
25}
26
27impl<T, D, R> PinvAPI<DeviceBLAS> for &TensorAny<R, T, DeviceBLAS, D>
28where
29 R: DataAPI<Data = Vec<T>>,
30 T: BlasFloat,
31 T::Real: FromPrimitive,
32 D: DimAPI + DimSmallerOneAPI,
33 D::SmallerOne: DimAPI,
34 DeviceBLAS: LapackDriverAPI<T>,
35{
36 type Out = PinvResult<Tensor<T, DeviceBLAS, D>>;
37 fn pinv_f(self) -> Result<Self::Out> {
38 let a = self;
39 rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
40 let a = a.view().into_dim::<Ix2>();
41 let (pinv, rank) = ref_impl_pinv_f(a, None, None)?.into();
42 let pinv = pinv.into_dim::<IxD>().into_dim::<D>();
43 return Ok(PinvResult { pinv, rank });
44 }
45}
46
47#[duplicate_item(
48 Tr ;
49 [Tensor<T, DeviceBLAS, D> ];
50 [TensorView<'_, T, DeviceBLAS, D>];
51)]
52impl<T, D> PinvAPI<DeviceBLAS> for (Tr, T::Real, T::Real)
53where
54 T: BlasFloat,
55 T::Real: FromPrimitive,
56 D: DimAPI + DimSmallerOneAPI,
57 D::SmallerOne: DimAPI,
58 DeviceBLAS: LapackDriverAPI<T>,
59{
60 type Out = PinvResult<Tensor<T, DeviceBLAS, D>>;
61 fn pinv_f(self) -> Result<Self::Out> {
62 let (a, atol, rtol) = self;
63 PinvAPI::<DeviceBLAS>::pinv_f((&a, atol, rtol))
64 }
65}
66
67#[duplicate_item(
68 Tr ;
69 [Tensor<T, DeviceBLAS, D> ];
70 [TensorView<'_, T, DeviceBLAS, D>];
71)]
72impl<T, D> PinvAPI<DeviceBLAS> for Tr
73where
74 T: BlasFloat,
75 T::Real: FromPrimitive,
76 D: DimAPI + DimSmallerOneAPI,
77 D::SmallerOne: DimAPI,
78 DeviceBLAS: LapackDriverAPI<T>,
79{
80 type Out = PinvResult<Tensor<T, DeviceBLAS, D>>;
81 fn pinv_f(self) -> Result<Self::Out> {
82 let a = self;
83 PinvAPI::<DeviceBLAS>::pinv_f(&a)
84 }
85}