1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6pub trait HEMMNum: BLASFloat {
9 unsafe fn hemm(
10 side: *const c_char,
11 uplo: *const c_char,
12 m: *const blas_int,
13 n: *const blas_int,
14 alpha: *const Self,
15 a: *const Self,
16 lda: *const blas_int,
17 b: *const Self,
18 ldb: *const blas_int,
19 beta: *const Self,
20 c: *mut Self,
21 ldc: *const blas_int,
22 );
23}
24
25macro_rules! impl_func {
26 ($type: ty, $func: ident) => {
27 impl HEMMNum for $type {
28 unsafe fn hemm(
29 side: *const c_char,
30 uplo: *const c_char,
31 m: *const blas_int,
32 n: *const blas_int,
33 alpha: *const Self,
34 a: *const Self,
35 lda: *const blas_int,
36 b: *const Self,
37 ldb: *const blas_int,
38 beta: *const Self,
39 c: *mut Self,
40 ldc: *const blas_int,
41 ) {
42 ffi::$func(side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc);
43 }
44 }
45 };
46}
47
48impl_func!(c32, chemm_);
49impl_func!(c64, zhemm_);
50
51pub struct HEMM_Driver<'a, 'b, 'c, F>
56where
57 F: HEMMNum,
58{
59 side: c_char,
60 uplo: c_char,
61 m: blas_int,
62 n: blas_int,
63 alpha: F,
64 a: ArrayView2<'a, F>,
65 lda: blas_int,
66 b: ArrayView2<'b, F>,
67 ldb: blas_int,
68 beta: F,
69 c: ArrayOut2<'c, F>,
70 ldc: blas_int,
71}
72
73impl<'a, 'b, 'c, F> BLASDriver<'c, F, Ix2> for HEMM_Driver<'a, 'b, 'c, F>
74where
75 F: HEMMNum,
76{
77 fn run_blas(self) -> Result<ArrayOut2<'c, F>, BLASError> {
78 let Self { side, uplo, m, n, alpha, a, lda, b, ldb, beta, mut c, ldc, .. } = self;
79 let a_ptr = a.as_ptr();
80 let b_ptr = b.as_ptr();
81 let c_ptr = c.get_data_mut_ptr();
82
83 if m == 0 || n == 0 {
86 return Ok(c.clone_to_view_mut());
87 }
88
89 unsafe {
90 F::hemm(&side, &uplo, &m, &n, &alpha, a_ptr, &lda, b_ptr, &ldb, &beta, c_ptr, &ldc);
91 }
92 return Ok(c.clone_to_view_mut());
93 }
94}
95
96#[derive(Builder)]
101#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
102pub struct HEMM_<'a, 'b, 'c, F>
103where
104 F: HEMMNum,
105{
106 pub a: ArrayView2<'a, F>,
107 pub b: ArrayView2<'b, F>,
108
109 #[builder(setter(into, strip_option), default = "None")]
110 pub c: Option<ArrayViewMut2<'c, F>>,
111 #[builder(setter(into), default = "F::one()")]
112 pub alpha: F,
113 #[builder(setter(into), default = "F::zero()")]
114 pub beta: F,
115 #[builder(setter(into), default = "BLASLeft")]
116 pub side: BLASSide,
117 #[builder(setter(into), default = "BLASLower")]
118 pub uplo: BLASUpLo,
119 #[builder(setter(into, strip_option), default = "None")]
120 pub layout: Option<BLASLayout>,
121}
122
123impl<'a, 'b, 'c, F> BLASBuilder_<'c, F, Ix2> for HEMM_<'a, 'b, 'c, F>
124where
125 F: HEMMNum,
126{
127 fn driver(self) -> Result<HEMM_Driver<'a, 'b, 'c, F>, BLASError> {
128 let Self { a, b, c, alpha, beta, side, uplo, layout, .. } = self;
129
130 assert_eq!(layout, Some(BLASColMajor));
132 assert!(a.is_fpref() && a.is_fpref());
133
134 let m = b.len_of(Axis(0));
136 let n = b.len_of(Axis(1));
137 let lda = a.stride_of(Axis(1));
138 let ldb = b.stride_of(Axis(1));
139
140 match side {
142 BLASLeft => blas_assert_eq!(a.dim(), (m, m), InvalidDim)?,
143 BLASRight => blas_assert_eq!(a.dim(), (n, n), InvalidDim)?,
144 _ => blas_invalid!(side)?,
145 }
146
147 let c = match c {
149 Some(c) => {
150 blas_assert_eq!(c.dim(), (m, n), InvalidDim)?;
151 if c.view().is_fpref() {
152 ArrayOut2::ViewMut(c)
153 } else {
154 let c_buffer = c.view().to_col_layout()?.into_owned();
155 ArrayOut2::ToBeCloned(c, c_buffer)
156 }
157 },
158 None => ArrayOut2::Owned(Array2::zeros((m, n).f())),
159 };
160 let ldc = c.view().stride_of(Axis(1));
161
162 let driver = HEMM_Driver::<'a, 'b, 'c, F> {
164 side: side.try_into()?,
165 uplo: uplo.try_into()?,
166 m: m.try_into()?,
167 n: n.try_into()?,
168 alpha,
169 a,
170 lda: lda.try_into()?,
171 b,
172 ldb: ldb.try_into()?,
173 beta,
174 c,
175 ldc: ldc.try_into()?,
176 };
177 return Ok(driver);
178 }
179}
180
181pub type HEMM<'a, 'b, 'c, F> = HEMM_Builder<'a, 'b, 'c, F>;
186pub type CHEMM<'a, 'b, 'c> = HEMM<'a, 'b, 'c, c32>;
187pub type ZHEMM<'a, 'b, 'c> = HEMM<'a, 'b, 'c, c64>;
188
189impl<'a, 'b, 'c, F> BLASBuilder<'c, F, Ix2> for HEMM_Builder<'a, 'b, 'c, F>
190where
191 F: HEMMNum,
192{
193 fn run(self) -> Result<ArrayOut2<'c, F>, BLASError> {
194 let HEMM_ { a, b, c, alpha, beta, side, uplo, layout, .. } = self.build()?;
196
197 let layout_a = get_layout_array2(&a);
198 let layout_b = get_layout_array2(&b);
199 let layout_c = c.as_ref().map(|c| get_layout_array2(&c.view()));
200
201 let layout = get_layout_row_preferred(&[layout, layout_c], &[layout_a, layout_b]);
202 if layout == BLASColMajor {
203 let a_cow = a.to_col_layout()?;
205 let b_cow = b.to_col_layout()?;
206 let obj = HEMM_ {
207 a: a_cow.view(),
208 b: b_cow.view(),
209 c,
210 alpha,
211 beta,
212 side,
213 uplo,
214 layout: Some(BLASColMajor),
215 };
216 return obj.driver()?.run_blas();
217 } else {
218 let a_cow = a.to_row_layout()?;
220 let b_cow = b.to_row_layout()?;
221 let obj = HEMM_ {
222 a: a_cow.t(),
223 b: b_cow.t(),
224 c: c.map(|c| c.reversed_axes()),
225 alpha,
226 beta,
227 side: side.flip()?,
228 uplo: uplo.flip()?,
229 layout: Some(BLASColMajor),
230 };
231 let c = obj.driver()?.run_blas()?.reversed_axes();
232 return Ok(c);
233 }
234 }
235}
236
237