rstsr_openblas/linalg_auto_impl/
pinv.rs

1use crate::DeviceBLAS;
2use num::FromPrimitive;
3use rstsr_blas_traits::prelude::*;
4use rstsr_core::prelude_dev::*;
5use rstsr_linalg_traits::prelude_dev::*;
6
7impl<T, D, R> PinvAPI<DeviceBLAS> for (&TensorAny<R, T, DeviceBLAS, D>, T::Real, T::Real)
8where
9    R: DataAPI<Data = Vec<T>>,
10    T: BlasFloat,
11    T::Real: FromPrimitive,
12    D: DimAPI + DimSmallerOneAPI,
13    D::SmallerOne: DimAPI,
14    DeviceBLAS: LapackDriverAPI<T>,
15{
16    type Out = PinvResult<Tensor<T, DeviceBLAS, D>>;
17    fn pinv_f(self) -> Result<Self::Out> {
18        let (a, atol, rtol) = self;
19        rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
20        let a = a.view().into_dim::<Ix2>();
21        let (b, rank) = ref_impl_pinv_f(a, Some(atol), Some(rtol))?.into();
22        let b = b.into_dim::<IxD>().into_dim::<D>();
23        return Ok(PinvResult { pinv: b, rank });
24    }
25}
26
27impl<T, D, R> PinvAPI<DeviceBLAS> for &TensorAny<R, T, DeviceBLAS, D>
28where
29    R: DataAPI<Data = Vec<T>>,
30    T: BlasFloat,
31    T::Real: FromPrimitive,
32    D: DimAPI + DimSmallerOneAPI,
33    D::SmallerOne: DimAPI,
34    DeviceBLAS: LapackDriverAPI<T>,
35{
36    type Out = PinvResult<Tensor<T, DeviceBLAS, D>>;
37    fn pinv_f(self) -> Result<Self::Out> {
38        let a = self;
39        rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
40        let a = a.view().into_dim::<Ix2>();
41        let (pinv, rank) = ref_impl_pinv_f(a, None, None)?.into();
42        let pinv = pinv.into_dim::<IxD>().into_dim::<D>();
43        return Ok(PinvResult { pinv, rank });
44    }
45}
46
47#[duplicate_item(
48    Tr                               ;
49   [Tensor<T, DeviceBLAS, D>        ];
50   [TensorView<'_, T, DeviceBLAS, D>];
51)]
52impl<T, D> PinvAPI<DeviceBLAS> for (Tr, T::Real, T::Real)
53where
54    T: BlasFloat,
55    T::Real: FromPrimitive,
56    D: DimAPI + DimSmallerOneAPI,
57    D::SmallerOne: DimAPI,
58    DeviceBLAS: LapackDriverAPI<T>,
59{
60    type Out = PinvResult<Tensor<T, DeviceBLAS, D>>;
61    fn pinv_f(self) -> Result<Self::Out> {
62        let (a, atol, rtol) = self;
63        PinvAPI::<DeviceBLAS>::pinv_f((&a, atol, rtol))
64    }
65}
66
67#[duplicate_item(
68    Tr                               ;
69   [Tensor<T, DeviceBLAS, D>        ];
70   [TensorView<'_, T, DeviceBLAS, D>];
71)]
72impl<T, D> PinvAPI<DeviceBLAS> for Tr
73where
74    T: BlasFloat,
75    T::Real: FromPrimitive,
76    D: DimAPI + DimSmallerOneAPI,
77    D::SmallerOne: DimAPI,
78    DeviceBLAS: LapackDriverAPI<T>,
79{
80    type Out = PinvResult<Tensor<T, DeviceBLAS, D>>;
81    fn pinv_f(self) -> Result<Self::Out> {
82        let a = self;
83        PinvAPI::<DeviceBLAS>::pinv_f(&a)
84    }
85}