1use crate::DeviceBLAS;
2use rstsr_blas_traits::prelude::*;
3use rstsr_core::prelude_dev::*;
4use rstsr_linalg_traits::prelude_dev::*;
5
6#[duplicate_item(
9 ImplType TrA TrB ;
10 [T, DA, DB, Ra: DataAPI<Data = Vec<T>>, Rb: DataAPI<Data = Vec<T>>] [&TensorAny<Ra, T, DeviceBLAS, DA>] [&TensorAny<Rb, T, DeviceBLAS, DB>];
11 [T, DA, DB, R: DataAPI<Data = Vec<T>> ] [&TensorAny<R, T, DeviceBLAS, DA> ] [TensorView<'_, T, DeviceBLAS, DB>];
12 [T, DA, DB, R: DataAPI<Data = Vec<T>> ] [TensorView<'_, T, DeviceBLAS, DA>] [&TensorAny<R, T, DeviceBLAS, DB> ];
13 [T, DA, DB, ] [TensorView<'_, T, DeviceBLAS, DA>] [TensorView<'_, T, DeviceBLAS, DB>];
14)]
15impl<ImplType> SolveSymmetricAPI<DeviceBLAS> for (TrA, TrB, bool, Option<FlagUpLo>)
16where
17 T: BlasFloat,
18 DA: DimAPI,
19 DB: DimAPI,
20 DeviceBLAS: LapackDriverAPI<T>,
21{
22 type Out = Tensor<T, DeviceBLAS, DB>;
23 fn solve_symmetric_f(self) -> Result<Self::Out> {
24 let (a, b, hermi, uplo) = self;
25 rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
26 rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?;
27 let is_b_vec = b.ndim() == 1;
28 let a_view = a.view().into_dim::<Ix2>();
29 let b_view = match is_b_vec {
30 true => b.i((.., None)).into_dim::<Ix2>(),
31 false => b.view().into_dim::<Ix2>(),
32 };
33 let result = ref_impl_solve_symmetric_f(a_view.into(), b_view.into(), hermi, uplo)?;
34 let result = result.into_owned().into_dim::<IxD>();
35 match is_b_vec {
36 true => Ok(result.into_shape(-1).into_dim::<DB>()),
37 false => Ok(result.into_dim::<DB>()),
38 }
39 }
40}
41
42#[duplicate_item(
43 ImplType TrA TrB ;
44 ['b, T, DA, DB, R: DataAPI<Data = Vec<T>>] [&TensorAny<R, T, DeviceBLAS, DA> ] [TensorMut<'b, T, DeviceBLAS, DB>];
45 ['b, T, DA, DB, ] [TensorView<'_, T, DeviceBLAS, DA>] [TensorMut<'b, T, DeviceBLAS, DB>];
46 [ T, DA, DB, R: DataAPI<Data = Vec<T>>] [&TensorAny<R, T, DeviceBLAS, DA> ] [Tensor<T, DeviceBLAS, DB> ];
47 [ T, DA, DB, ] [TensorView<'_, T, DeviceBLAS, DA>] [Tensor<T, DeviceBLAS, DB> ];
48)]
49impl<ImplType> SolveSymmetricAPI<DeviceBLAS> for (TrA, TrB, bool, Option<FlagUpLo>)
50where
51 T: BlasFloat,
52 DA: DimAPI,
53 DB: DimAPI,
54 DeviceBLAS: LapackDriverAPI<T>,
55{
56 type Out = TrB;
57 fn solve_symmetric_f(self) -> Result<Self::Out> {
58 let (a, mut b, hermi, uplo) = self;
59 rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
60 rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?;
61 let is_b_vec = b.ndim() == 1;
62 let a_view = a.view().into_dim::<Ix2>();
63 let b_view = match is_b_vec {
64 true => b.i_mut((.., None)).into_dim::<Ix2>(),
65 false => b.view_mut().into_dim::<Ix2>(),
66 };
67 let result = ref_impl_solve_symmetric_f(a_view.into(), b_view.into(), hermi, uplo)?;
68 result.clone_to_mut();
69 Ok(b)
70 }
71}
72
73#[duplicate_item(
74 ImplType TrA TrB ;
75 [T, DA, DB, R: DataAPI<Data = Vec<T>>] [TensorMut<'_, T, DeviceBLAS, DA>] [&TensorAny<R, T, DeviceBLAS, DB> ];
76 [T, DA, DB, ] [TensorMut<'_, T, DeviceBLAS, DA>] [TensorView<'_, T, DeviceBLAS, DB>];
77 [T, DA, DB, R: DataAPI<Data = Vec<T>>] [Tensor<T, DeviceBLAS, DA> ] [&TensorAny<R, T, DeviceBLAS, DB> ];
78 [T, DA, DB, ] [Tensor<T, DeviceBLAS, DA> ] [TensorView<'_, T, DeviceBLAS, DB>];
79)]
80impl<ImplType> SolveSymmetricAPI<DeviceBLAS> for (TrA, TrB, bool, Option<FlagUpLo>)
81where
82 T: BlasFloat,
83 DA: DimAPI,
84 DB: DimAPI,
85 DeviceBLAS: LapackDriverAPI<T>,
86{
87 type Out = Tensor<T, DeviceBLAS, DB>;
88 fn solve_symmetric_f(self) -> Result<Self::Out> {
89 let (mut a, b, hermi, uplo) = self;
90 rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
91 rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?;
92 let is_b_vec = b.ndim() == 1;
93 let a_view = a.view_mut().into_dim::<Ix2>();
94 let b_view = match is_b_vec {
95 true => b.i((.., None)).into_dim::<Ix2>(),
96 false => b.view().into_dim::<Ix2>(),
97 };
98 let result = ref_impl_solve_symmetric_f(a_view.into(), b_view.into(), hermi, uplo)?;
99 let result = result.into_owned().into_dim::<IxD>();
100 match is_b_vec {
101 true => Ok(result.into_shape(-1).into_dim::<DB>()),
102 false => Ok(result.into_dim::<DB>()),
103 }
104 }
105}
106
107#[duplicate_item(
108 ImplType TrA TrB ;
109 ['b, T, DA, DB] [TensorMut<'_, T, DeviceBLAS, DA>] [TensorMut<'b, T, DeviceBLAS, DB>];
110 [ T, DA, DB] [TensorMut<'_, T, DeviceBLAS, DA>] [Tensor<T, DeviceBLAS, DB> ];
111 ['b, T, DA, DB] [Tensor<T, DeviceBLAS, DA> ] [TensorMut<'b, T, DeviceBLAS, DB>];
112 [ T, DA, DB] [Tensor<T, DeviceBLAS, DA> ] [Tensor<T, DeviceBLAS, DB> ];
113)]
114impl<ImplType> SolveSymmetricAPI<DeviceBLAS> for (TrA, TrB, bool, Option<FlagUpLo>)
115where
116 T: BlasFloat,
117 DA: DimAPI,
118 DB: DimAPI,
119 DeviceBLAS: LapackDriverAPI<T>,
120{
121 type Out = TrB;
122 fn solve_symmetric_f(self) -> Result<Self::Out> {
123 let (mut a, mut b, hermi, uplo) = self;
124 rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
125 rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?;
126 let is_b_vec = b.ndim() == 1;
127 let a_view = a.view_mut().into_dim::<Ix2>();
128 let b_view = match is_b_vec {
129 true => b.i_mut((.., None)).into_dim::<Ix2>(),
130 false => b.view_mut().into_dim::<Ix2>(),
131 };
132 let result = ref_impl_solve_symmetric_f(a_view.into(), b_view.into(), hermi, uplo)?;
133 result.clone_to_mut();
134 Ok(b)
135 }
136}
137
138#[duplicate_item(
143 ImplStruct args_tuple internal_tuple ;
144 [(TrA, TrB, bool, FlagUpLo)] [(a, b, hermi, uplo)] [(a, b, hermi, Some(uplo))];
145 [(TrA, TrB, bool, )] [(a, b, hermi, )] [(a, b, hermi, None )];
146 [(TrA, TrB, FlagUpLo)] [(a, b, uplo)] [(a, b, true , Some(uplo))];
147 [(TrA, TrB, )] [(a, b, )] [(a, b, true , None )];
148)]
149impl<TrA, TrB> SolveSymmetricAPI<DeviceBLAS> for ImplStruct
150where
151 (TrA, TrB, bool, Option<FlagUpLo>): SolveSymmetricAPI<DeviceBLAS>,
152{
153 type Out = <(TrA, TrB, bool, Option<FlagUpLo>) as SolveSymmetricAPI<DeviceBLAS>>::Out;
154 fn solve_symmetric_f(self) -> Result<Self::Out> {
155 let args_tuple = self;
156 SolveSymmetricAPI::<DeviceBLAS>::solve_symmetric_f(internal_tuple)
157 }
158}
159
160