// rstsr_openblas/linalg_auto_impl/cholesky.rs

1use crate::DeviceBLAS;
2use rstsr_blas_traits::prelude::*;
3use rstsr_core::prelude_dev::*;
4use rstsr_linalg_traits::prelude_dev::*;

/* #region full-args */

8#[duplicate_item(
9    ImplType                          Tr                               ;
10   [T, D, R: DataAPI<Data = Vec<T>>] [&TensorAny<R, T, DeviceBLAS, D> ];
11   [T, D                           ] [TensorView<'_, T, DeviceBLAS, D>];
12)]
13impl<ImplType> CholeskyAPI<DeviceBLAS> for (Tr, Option<FlagUpLo>)
14where
15    T: BlasFloat,
16    D: DimAPI,
17    DeviceBLAS: LapackDriverAPI<T>,
18{
19    type Out = Tensor<T, DeviceBLAS, D>;
20    fn cholesky_f(self) -> Result<Self::Out> {
21        let (a, uplo) = self;
22        rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
23        let a = a.view().into_dim::<Ix2>();
24        let result = ref_impl_cholesky_f(a.view().into(), uplo)?.into_owned();
25        Ok(result.into_dim::<IxD>().into_dim::<D>())
26    }
27}
29#[duplicate_item(
30    ImplType   Tr                              ;
31   ['a, T, D] [TensorMut<'a, T, DeviceBLAS, D>];
32   [    T, D] [Tensor<T, DeviceBLAS, D>       ];
33)]
34impl<ImplType> CholeskyAPI<DeviceBLAS> for (Tr, Option<FlagUpLo>)
35where
36    T: BlasFloat,
37    D: DimAPI,
38    DeviceBLAS: LapackDriverAPI<T>,
39{
40    type Out = Tr;
41    fn cholesky_f(self) -> Result<Self::Out> {
42        let (mut a, uplo) = self;
43        rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
44        let a_ix2 = a.view_mut().into_dim::<Ix2>();
45        let result = ref_impl_cholesky_f(a_ix2.into(), uplo)?;
46        result.clone_to_mut();
47        Ok(a)
48    }
49}

/* #endregion */

/* #region sub-args */

55#[duplicate_item(
56    ImplStruct        args_tuple  internal_tuple  ;
57   [(Tr, FlagUpLo)] [(a, uplo)] [(a, Some(uplo))];
58)]
59impl<Tr> CholeskyAPI<DeviceBLAS> for ImplStruct
60where
61    (Tr, Option<FlagUpLo>): CholeskyAPI<DeviceBLAS>,
62{
63    type Out = <(Tr, Option<FlagUpLo>) as CholeskyAPI<DeviceBLAS>>::Out;
64    fn cholesky_f(self) -> Result<Self::Out> {
65        let args_tuple = self;
66        CholeskyAPI::<DeviceBLAS>::cholesky_f(internal_tuple)
67    }
68}
70#[duplicate_item(
71    ImplType                              Tr;
72   ['a, T, D, R: DataAPI<Data = Vec<T>>] [&'a TensorAny<R, T, DeviceBLAS, D>];
73   ['a, T, D,                          ] [TensorView<'a, T, DeviceBLAS, D>  ];
74   [    T, D                           ] [Tensor<T, DeviceBLAS, D>          ];
75   ['a, T, D                           ] [TensorMut<'a, T, DeviceBLAS, D>   ];
76)]
77impl<ImplType> CholeskyAPI<DeviceBLAS> for Tr
78where
79    T: BlasFloat,
80    D: DimAPI,
81    (Tr, Option<FlagUpLo>): CholeskyAPI<DeviceBLAS>,
82{
83    type Out = <(Tr, Option<FlagUpLo>) as CholeskyAPI<DeviceBLAS>>::Out;
84    fn cholesky_f(self) -> Result<Self::Out> {
85        let a = self;
86        CholeskyAPI::<DeviceBLAS>::cholesky_f((a, None))
87    }
88}

/* #endregion */