1use crate::DeviceBLAS;
2use rstsr_blas_traits::prelude::*;
3use rstsr_core::prelude_dev::*;
4use rstsr_linalg_traits::prelude_dev::*;
5
6#[duplicate_item(
7 ImplType TrA TrB ;
8 [T, DA, DB, Ra: DataAPI<Data = Vec<T>>, Rb: DataAPI<Data = Vec<T>>] [&TensorAny<Ra, T, DeviceBLAS, DA>] [&TensorAny<Rb, T, DeviceBLAS, DB>];
9 [T, DA, DB, R: DataAPI<Data = Vec<T>> ] [&TensorAny<R, T, DeviceBLAS, DA> ] [TensorView<'_, T, DeviceBLAS, DB>];
10 [T, DA, DB, R: DataAPI<Data = Vec<T>> ] [TensorView<'_, T, DeviceBLAS, DA>] [&TensorAny<R, T, DeviceBLAS, DB> ];
11 [T, DA, DB, ] [TensorView<'_, T, DeviceBLAS, DA>] [TensorView<'_, T, DeviceBLAS, DB>];
12)]
13impl<ImplType> SolveGeneralAPI<DeviceBLAS> for (TrA, TrB)
14where
15 T: BlasFloat,
16 DA: DimAPI,
17 DB: DimAPI,
18 DeviceBLAS: LapackDriverAPI<T>,
19{
20 type Out = Tensor<T, DeviceBLAS, DB>;
21 fn solve_general_f(self) -> Result<Self::Out> {
22 let (a, b) = self;
23 rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
24 rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?;
25 let is_b_vec = b.ndim() == 1;
26 let a_view = a.view().into_dim::<Ix2>();
27 let b_view = match is_b_vec {
28 true => b.i((.., None)).into_dim::<Ix2>(),
29 false => b.view().into_dim::<Ix2>(),
30 };
31 let result = ref_impl_solve_general_f(a_view.into(), b_view.into())?;
32 let result = result.into_owned().into_dim::<IxD>();
33 match is_b_vec {
34 true => Ok(result.into_shape(-1).into_dim::<DB>()),
35 false => Ok(result.into_dim::<DB>()),
36 }
37 }
38}
39
40#[duplicate_item(
41 ImplType TrA TrB ;
42 ['b, T, DA, DB, R: DataAPI<Data = Vec<T>>] [&TensorAny<R, T, DeviceBLAS, DA> ] [TensorMut<'b, T, DeviceBLAS, DB>];
43 ['b, T, DA, DB, ] [TensorView<'_, T, DeviceBLAS, DA>] [TensorMut<'b, T, DeviceBLAS, DB>];
44 [ T, DA, DB, R: DataAPI<Data = Vec<T>>] [&TensorAny<R, T, DeviceBLAS, DA> ] [Tensor<T, DeviceBLAS, DB> ];
45 [ T, DA, DB, ] [TensorView<'_, T, DeviceBLAS, DA>] [Tensor<T, DeviceBLAS, DB> ];
46)]
47impl<ImplType> SolveGeneralAPI<DeviceBLAS> for (TrA, TrB)
48where
49 T: BlasFloat,
50 DA: DimAPI,
51 DB: DimAPI,
52 DeviceBLAS: LapackDriverAPI<T>,
53{
54 type Out = TrB;
55 fn solve_general_f(self) -> Result<Self::Out> {
56 let (a, mut b) = self;
57 rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
58 rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?;
59 let is_b_vec = b.ndim() == 1;
60 let a_view = a.view().into_dim::<Ix2>();
61 let b_view = match is_b_vec {
62 true => b.i_mut((.., None)).into_dim::<Ix2>(),
63 false => b.view_mut().into_dim::<Ix2>(),
64 };
65 let result = ref_impl_solve_general_f(a_view.into(), b_view.into())?;
66 result.clone_to_mut();
67 Ok(b)
68 }
69}
70
71#[duplicate_item(
72 ImplType TrA TrB ;
73 [T, DA, DB, R: DataAPI<Data = Vec<T>>] [TensorMut<'_, T, DeviceBLAS, DA>] [&TensorAny<R, T, DeviceBLAS, DB> ];
74 [T, DA, DB, ] [TensorMut<'_, T, DeviceBLAS, DA>] [TensorView<'_, T, DeviceBLAS, DB>];
75 [T, DA, DB, R: DataAPI<Data = Vec<T>>] [Tensor<T, DeviceBLAS, DA> ] [&TensorAny<R, T, DeviceBLAS, DB> ];
76 [T, DA, DB, ] [Tensor<T, DeviceBLAS, DA> ] [TensorView<'_, T, DeviceBLAS, DB>];
77)]
78impl<ImplType> SolveGeneralAPI<DeviceBLAS> for (TrA, TrB)
79where
80 T: BlasFloat,
81 DA: DimAPI,
82 DB: DimAPI,
83 DeviceBLAS: LapackDriverAPI<T>,
84{
85 type Out = Tensor<T, DeviceBLAS, DB>;
86 fn solve_general_f(self) -> Result<Self::Out> {
87 let (mut a, b) = self;
88 rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
89 rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?;
90 let is_b_vec = b.ndim() == 1;
91 let a_view = a.view_mut().into_dim::<Ix2>();
92 let b_view = match is_b_vec {
93 true => b.i((.., None)).into_dim::<Ix2>(),
94 false => b.view().into_dim::<Ix2>(),
95 };
96 let result = ref_impl_solve_general_f(a_view.into(), b_view.into())?;
97 let result = result.into_owned().into_dim::<IxD>();
98 match is_b_vec {
99 true => Ok(result.into_shape(-1).into_dim::<DB>()),
100 false => Ok(result.into_dim::<DB>()),
101 }
102 }
103}
104
105#[duplicate_item(
106 ImplType TrA TrB ;
107 ['b, T, DA, DB] [TensorMut<'_, T, DeviceBLAS, DA>] [TensorMut<'b, T, DeviceBLAS, DB>];
108 [ T, DA, DB] [TensorMut<'_, T, DeviceBLAS, DA>] [Tensor<T, DeviceBLAS, DB> ];
109 ['b, T, DA, DB] [Tensor<T, DeviceBLAS, DA> ] [TensorMut<'b, T, DeviceBLAS, DB>];
110 [ T, DA, DB] [Tensor<T, DeviceBLAS, DA> ] [Tensor<T, DeviceBLAS, DB> ];
111)]
112impl<ImplType> SolveGeneralAPI<DeviceBLAS> for (TrA, TrB)
113where
114 T: BlasFloat,
115 DA: DimAPI,
116 DB: DimAPI,
117 DeviceBLAS: LapackDriverAPI<T>,
118{
119 type Out = TrB;
120 fn solve_general_f(self) -> Result<Self::Out> {
121 let (mut a, mut b) = self;
122 rstsr_assert_eq!(a.ndim(), 2, InvalidLayout, "Currently we can only handle 2-D matrix.")?;
123 rstsr_pattern!(b.ndim(), 1..=2, InvalidLayout, "Currently we can only handle 1/2-D matrix.")?;
124 let is_b_vec = b.ndim() == 1;
125 let a_view = a.view_mut().into_dim::<Ix2>();
126 let b_view = match is_b_vec {
127 true => b.i_mut((.., None)).into_dim::<Ix2>(),
128 false => b.view_mut().into_dim::<Ix2>(),
129 };
130 let result = ref_impl_solve_general_f(a_view.into(), b_view.into())?;
131 result.clone_to_mut();
132 Ok(b)
133 }
134}