// faer/linalg/matmul/triangular.rs

1use super::*;
2use crate::linalg::temp_mat_uninit;
3use crate::linalg::zip::Diag;
4use crate::mat::{AsMatMut, MatMut, MatRef};
5use crate::utils::thread::join_raw;
6use crate::{assert, debug_assert, unzip, zip};
7
/// how the diagonal of a triangular operand is interpreted when it is densified.
#[repr(u8)]
#[derive(Clone, Copy, Debug)]
pub(crate) enum DiagonalKind {
	/// the diagonal is replaced by zeros (strictly triangular operand).
	Zero = 0,
	/// the diagonal is replaced by ones (unit triangular operand).
	Unit = 1,
	/// the diagonal stored in the matrix is used as-is.
	Generic = 2,
}
15
16#[inline]
17fn pointer_offset<T>(ptr: *const T) -> usize {
18	if try_const! {core::mem::size_of::<T>().is_power_of_two() &&core::mem::size_of::<T>() <= 64 } {
19		ptr.align_offset(64).wrapping_neg() % 16
20	} else {
21		0
22	}
23}
24
25#[faer_macros::math]
26fn copy_lower<'N, T: ComplexField>(dst: MatMut<'_, T, Dim<'N>, Dim<'N>>, src: MatRef<'_, T, Dim<'N>, Dim<'N>>, src_diag: DiagonalKind) {
27	let N = dst.nrows();
28	let mut dst = dst;
29	match src_diag {
30		DiagonalKind::Zero => {
31			dst.copy_from_strict_triangular_lower(src);
32			for j in N.indices() {
33				let zero = zero();
34				dst[(j, j)] = zero;
35			}
36		},
37		DiagonalKind::Unit => {
38			dst.copy_from_strict_triangular_lower(src);
39			for j in N.indices() {
40				let one = one();
41				dst[(j, j)] = one;
42			}
43		},
44		DiagonalKind::Generic => dst.copy_from_triangular_lower(src),
45	}
46
47	zip!(dst).for_each_triangular_upper(Diag::Skip, |unzip!(dst)| *dst = zero());
48}
49
50#[faer_macros::math]
51fn accum_lower<'N, T: ComplexField>(dst: MatMut<'_, T, Dim<'N>, Dim<'N>>, src: MatRef<'_, T, Dim<'N>, Dim<'N>>, skip_diag: bool, beta: Accum) {
52	let N = dst.nrows();
53	debug_assert!(N == dst.nrows());
54	debug_assert!(N == dst.ncols());
55	debug_assert!(N == src.nrows());
56	debug_assert!(N == src.ncols());
57
58	match beta {
59		Accum::Add => {
60			zip!(dst, src).for_each_triangular_lower(if skip_diag { Diag::Skip } else { Diag::Include }, |unzip!(dst, src)| *dst = *dst + *src);
61		},
62		Accum::Replace => {
63			zip!(dst, src).for_each_triangular_lower(if skip_diag { Diag::Skip } else { Diag::Include }, |unzip!(dst, src)| *dst = copy(*src));
64		},
65	}
66}
67
68#[faer_macros::math]
69fn copy_upper<'N, T: ComplexField>(dst: MatMut<'_, T, Dim<'N>, Dim<'N>>, src: MatRef<'_, T, Dim<'N>, Dim<'N>>, src_diag: DiagonalKind) {
70	copy_lower(dst.transpose_mut(), src.transpose(), src_diag)
71}
72
/// 64-byte-aligned backing buffer for the stack temporaries built by `stack_mat_16x16!`.
///
/// holds `32 * 16` elements. that is a 32-row buffer with at most 16 columns: the temporaries use
/// at most 16 rows, starting at a row offset below 16.
#[repr(align(64))]
struct Storage<T>([T; 32 * 16]);
75
// declares `$name` as an uninitialized `$n × $n` matrix view on the stack. the view is backed by
// a 64-byte-aligned `Storage<$T>` buffer and starts `$offset` rows into a 32-row buffer, so
// `$offset + $n <= 32` is required (callers use `$n <= 16` and `$offset` from `pointer_offset`,
// which is `< 16`).
//
// the view's layout mirrors the strides `$rs` / `$cs` of the matrix it will receive a copy of:
// - `|$cs| == 1` (row-major source): the transposed, row-major view of the buffer is used, with
//   its columns reversed when `$cs == -1`,
// - otherwise, when `$rs == -1`, the rows are reversed.
//
// previously the `|$cs| == 1` branch transposed back to column-major (and reversed the columns)
// when `$cs == 1`. that cancelled the row-major choice and never mirrored a `$cs == -1` source.
// the layout only affects memory access patterns: every caller in this file fully overwrites the
// view before reading it.
macro_rules! stack_mat_16x16 {
	($name: ident, $n: expr, $offset: expr, $rs: expr, $cs: expr,  $T: ty $(,)?) => {
		let mut __tmp = core::mem::MaybeUninit::<Storage<$T>>::uninit();
		let __stack = MemStack::new_any(core::slice::from_mut(&mut __tmp));
		// NOTE(review): the returned view is uninitialized memory; callers must write every entry
		// (copy_lower / copy_upper / matmul with `Accum::Replace`) before reading it.
		let mut $name = unsafe { temp_mat_uninit(32, $n, __stack) }.0;
		let mut $name = $name.as_mat_mut().subrows_mut($offset, $n);
		if $cs.unsigned_abs() == 1 {
			// row-major source: use the row-major (transposed) view of the buffer
			$name = $name.transpose_mut();
			if $cs == -1 {
				$name = $name.reverse_cols_mut();
			}
		} else if $rs == -1 {
			$name = $name.reverse_rows_mut();
		}
	};
}
92
/// computes `alpha * op(lhs) * op(rhs)`, where `op` conjugates its operand when the matching
/// `Conj` is `Yes`. the result is written into `dst` (`Accum::Replace`) or added to it
/// (`Accum::Add`).
///
/// `lhs` (`m × n`) and `dst` (`m × n`) are dense. `rhs` (`n × n`) is interpreted as lower
/// triangular: its strict upper half is never used, and its diagonal is read or replaced by
/// zeros/ones according to `rhs_diag`. shapes are only checked with `debug_assert!`.
#[faer_macros::math]
fn mat_x_lower_impl_unchecked<'M, 'N, T: ComplexField>(
	dst: MatMut<'_, T, Dim<'M>, Dim<'N>>,
	beta: Accum,
	lhs: MatRef<'_, T, Dim<'M>, Dim<'N>>,
	rhs: MatRef<'_, T, Dim<'N>, Dim<'N>>,
	rhs_diag: DiagonalKind,
	alpha: &T,
	conj_lhs: Conj,
	conj_rhs: Conj,
	par: Par,
) {
	let N = rhs.nrows();
	let M = lhs.nrows();
	let n = N.unbound();
	let m = M.unbound();
	debug_assert!(M == lhs.nrows());
	debug_assert!(N == lhs.ncols());
	debug_assert!(N == rhs.nrows());
	debug_assert!(N == rhs.ncols());
	debug_assert!(M == dst.nrows());
	debug_assert!(N == dst.ncols());

	// below roughly 128 * 128 * 64 multiply-adds the two recursive halves are joined sequentially;
	// above it the caller's `par` is forwarded to `join_raw`
	let join_parallelism = if n * n * m < 128usize * 128usize * 64usize { Par::Seq } else { par };

	if n <= 16 {
		// base case: densify the triangular `rhs` into a stack temporary (strict upper half zeroed,
		// diagonal per `rhs_diag`) and use the dense matmul
		let op = {
			// NOTE(review): presumably `inline(never)` keeps the stack buffer out of the frame of
			// the recursive caller
			#[inline(never)]
			|| {
				stack_mat_16x16!(temp_rhs, N, pointer_offset(rhs.as_ptr()), rhs.row_stride(), rhs.col_stride(), T);

				copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);

				let mut dst = dst;
				super::matmul_with_conj(dst.rb_mut(), beta, lhs, conj_lhs, temp_rhs.rb(), conj_rhs, alpha.clone(), par);
			}
		};
		op();
	} else {
		// split rhs into 3 sections
		// split lhs and dst into 2 sections
		//
		// with the split point `bs = n / 2`:
		//   rhs = [A 0]    lhs = [L0 L1]    dst = [D0 D1]
		//         [B C]
		// where A and C are lower triangular:
		//   D0 = L0 * A + L1 * B    (recursive triangular product, then a dense product added)
		//   D1 = L1 * C             (recursive triangular product)

		make_guard!(HEAD);
		make_guard!(TAIL);
		let bs = N.partition(N.checked_idx_inc(N.unbound() / 2), HEAD, TAIL);

		let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_with(bs, bs);
		let (lhs_left, lhs_right) = lhs.split_cols_with(bs);
		let (mut dst_left, mut dst_right) = dst.split_cols_with_mut(bs);

		{
			// the two triangular products write disjoint column blocks of `dst`
			join_raw(
				|par| mat_x_lower_impl_unchecked(dst_left.rb_mut(), beta, lhs_left, rhs_top_left, rhs_diag, alpha, conj_lhs, conj_rhs, par),
				|par| {
					mat_x_lower_impl_unchecked(
						dst_right.rb_mut(),
						beta,
						lhs_right,
						rhs_bot_right,
						rhs_diag,
						alpha,
						conj_lhs,
						conj_rhs,
						par,
					)
				},
				join_parallelism,
			)
		};

		// D0 += L1 * B (dense); `beta` was already applied by the recursive call above
		super::matmul_with_conj(dst_left, Accum::Add, lhs_right, conj_lhs, rhs_bot_left, conj_rhs, alpha.clone(), par);
	}
}
166
/// computes the lower triangular half of `alpha * op(lhs) * op(rhs)`, where both `lhs` and `rhs`
/// are lower triangular (diagonals interpreted according to `lhs_diag` / `rhs_diag`).
///
/// the result is written into (`Accum::Replace`) or added to (`Accum::Add`) the lower triangular
/// half of `dst`. the strict upper half of `dst` is not written. when `skip_diag` is `true`, the
/// diagonal of `dst` is not written either. shapes are only checked with `debug_assert!`.
#[faer_macros::math]
fn lower_x_lower_into_lower_impl_unchecked<'N, T: ComplexField>(
	dst: MatMut<'_, T, Dim<'N>, Dim<'N>>,
	beta: Accum,
	skip_diag: bool,
	lhs: MatRef<'_, T, Dim<'N>, Dim<'N>>,
	lhs_diag: DiagonalKind,
	rhs: MatRef<'_, T, Dim<'N>, Dim<'N>>,
	rhs_diag: DiagonalKind,
	alpha: &T,
	conj_lhs: Conj,
	conj_rhs: Conj,
	par: Par,
) {
	let N = dst.nrows();
	let n = N.unbound();
	debug_assert!(N == lhs.nrows());
	debug_assert!(N == lhs.ncols());
	debug_assert!(N == rhs.nrows());
	debug_assert!(N == rhs.ncols());
	debug_assert!(N == dst.nrows());
	debug_assert!(N == dst.ncols());

	if n <= 16 {
		// base case: densify both operands, compute the full product into a temporary, then
		// transfer only its lower half into `dst`
		let op = {
			#[inline(never)]
			|| {
				stack_mat_16x16!(temp_dst, N, pointer_offset(dst.as_ptr()), dst.row_stride(), dst.col_stride(), T);
				stack_mat_16x16!(temp_lhs, N, pointer_offset(lhs.as_ptr()), lhs.row_stride(), lhs.col_stride(), T);
				stack_mat_16x16!(temp_rhs, N, pointer_offset(rhs.as_ptr()), rhs.row_stride(), rhs.col_stride(), T);

				copy_lower(temp_lhs.rb_mut(), lhs, lhs_diag);
				copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);

				super::matmul_with_conj(
					temp_dst.rb_mut(),
					Accum::Replace,
					temp_lhs.rb(),
					conj_lhs,
					temp_rhs.rb(),
					conj_rhs,
					alpha.clone(),
					par,
				);
				accum_lower(dst, temp_dst.rb(), skip_diag, beta);
			}
		};
		op();
	} else {
		make_guard!(HEAD);
		make_guard!(TAIL);
		let bs = N.partition(N.checked_idx_inc(N.unbound() / 2), HEAD, TAIL);

		let (dst_top_left, _, mut dst_bot_left, dst_bot_right) = dst.split_with_mut(bs, bs);
		let (lhs_top_left, _, lhs_bot_left, lhs_bot_right) = lhs.split_with(bs, bs);
		let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_with(bs, bs);

		// lhs_top_left  × rhs_top_left  => dst_top_left  | low × low => low |   X
		// lhs_bot_left  × rhs_top_left  => dst_bot_left  | mat × low => mat | 1/2
		// lhs_bot_right × rhs_bot_left  => dst_bot_left  | low × mat => mat | 1/2
		// lhs_bot_right × rhs_bot_right => dst_bot_right | low × low => low |   X

		lower_x_lower_into_lower_impl_unchecked(
			dst_top_left,
			beta,
			skip_diag,
			lhs_top_left,
			lhs_diag,
			rhs_top_left,
			rhs_diag,
			alpha,
			conj_lhs,
			conj_rhs,
			par,
		);
		mat_x_lower_impl_unchecked(
			dst_bot_left.rb_mut(),
			beta,
			lhs_bot_left,
			rhs_top_left,
			rhs_diag,
			alpha,
			conj_lhs,
			conj_rhs,
			par,
		);
		// (low × mat) is rewritten as (mat × low). with J the exchange matrix, the map
		// X -> J X^T J (reverse rows and cols, then transpose) keeps lower triangular matrices
		// lower triangular and reverses the order of a product:
		//   J (C E)^T J = (J E^T J) (J C^T J)
		// so the accumulation into `dst_bot_left` can reuse `mat_x_lower_impl_unchecked`, with
		// the operands and their conjugation flags swapped.
		mat_x_lower_impl_unchecked(
			dst_bot_left.reverse_rows_and_cols_mut().transpose_mut(),
			Accum::Add,
			rhs_bot_left.reverse_rows_and_cols().transpose(),
			lhs_bot_right.reverse_rows_and_cols().transpose(),
			lhs_diag,
			alpha,
			conj_rhs,
			conj_lhs,
			par,
		);
		lower_x_lower_into_lower_impl_unchecked(
			dst_bot_right,
			beta,
			skip_diag,
			lhs_bot_right,
			lhs_diag,
			rhs_bot_right,
			rhs_diag,
			alpha,
			conj_lhs,
			conj_rhs,
			par,
		)
	}
}
279
/// computes `alpha * op(lhs) * op(rhs)`, where `lhs` is upper triangular and `rhs` is lower
/// triangular (diagonals interpreted according to `lhs_diag` / `rhs_diag`).
///
/// the full `n × n` result is written into (`Accum::Replace`) or added to (`Accum::Add`) `dst`.
/// shapes are only checked with `debug_assert!`.
#[math]
fn upper_x_lower_impl_unchecked<'N, T: ComplexField>(
	dst: MatMut<'_, T, Dim<'N>, Dim<'N>>,
	beta: Accum,
	lhs: MatRef<'_, T, Dim<'N>, Dim<'N>>,
	lhs_diag: DiagonalKind,
	rhs: MatRef<'_, T, Dim<'N>, Dim<'N>>,
	rhs_diag: DiagonalKind,
	alpha: &T,
	conj_lhs: Conj,
	conj_rhs: Conj,
	par: Par,
) {
	let N = dst.nrows();
	let n = N.unbound();
	debug_assert!(N == lhs.nrows());
	debug_assert!(N == lhs.ncols());
	debug_assert!(N == rhs.nrows());
	debug_assert!(N == rhs.ncols());
	debug_assert!(N == dst.nrows());
	debug_assert!(N == dst.ncols());

	if n <= 16 {
		// base case: densify both triangular operands and use the dense matmul directly on `dst`
		let op = {
			#[inline(never)]
			|| {
				stack_mat_16x16!(temp_lhs, N, pointer_offset(lhs.as_ptr()), lhs.row_stride(), lhs.col_stride(), T);
				stack_mat_16x16!(temp_rhs, N, pointer_offset(rhs.as_ptr()), rhs.row_stride(), rhs.col_stride(), T);

				copy_upper(temp_lhs.rb_mut(), lhs, lhs_diag);
				copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);

				super::matmul_with_conj(dst, beta, temp_lhs.rb(), conj_lhs, temp_rhs.rb(), conj_rhs, alpha.clone(), par);
			}
		};
		op();
	} else {
		make_guard!(HEAD);
		make_guard!(TAIL);
		let bs = N.partition(N.checked_idx_inc(N.unbound() / 2), HEAD, TAIL);

		let (mut dst_top_left, dst_top_right, dst_bot_left, dst_bot_right) = dst.split_with_mut(bs, bs);
		let (lhs_top_left, lhs_top_right, _, lhs_bot_right) = lhs.split_with(bs, bs);
		let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_with(bs, bs);

		// lhs_top_right × rhs_bot_left  => dst_top_left  | mat × mat => mat |   1
		// lhs_top_left  × rhs_top_left  => dst_top_left  | upp × low => mat |   X
		//
		// lhs_top_right × rhs_bot_right => dst_top_right | mat × low => mat | 1/2
		// lhs_bot_right × rhs_bot_left  => dst_bot_left  | upp × mat => mat | 1/2
		// lhs_bot_right × rhs_bot_right => dst_bot_right | upp × low => mat |   X

		// the first closure only touches `dst_top_left`; the second touches the other three blocks
		join_raw(
			|par| {
				// `beta` is applied by this first (dense) product; the recursive one then adds
				super::matmul_with_conj(
					dst_top_left.rb_mut(),
					beta,
					lhs_top_right,
					conj_lhs,
					rhs_bot_left,
					conj_rhs,
					alpha.clone(),
					par,
				);
				upper_x_lower_impl_unchecked(
					dst_top_left,
					Accum::Add,
					lhs_top_left,
					lhs_diag,
					rhs_top_left,
					rhs_diag,
					alpha,
					conj_lhs,
					conj_rhs,
					par,
				)
			},
			|par| {
				join_raw(
					|par| {
						mat_x_lower_impl_unchecked(
							dst_top_right,
							beta,
							lhs_top_right,
							rhs_bot_right,
							rhs_diag,
							alpha,
							conj_lhs,
							conj_rhs,
							par,
						)
					},
					|par| {
						// (upp × mat) via transposition: (C E)^T = E^T C^T, and C^T is lower
						// triangular, so this is a (mat × low) product with swapped operands
						mat_x_lower_impl_unchecked(
							dst_bot_left.transpose_mut(),
							beta,
							rhs_bot_left.transpose(),
							lhs_bot_right.transpose(),
							lhs_diag,
							alpha,
							conj_rhs,
							conj_lhs,
							par,
						)
					},
					par,
				);

				upper_x_lower_impl_unchecked(
					dst_bot_right,
					beta,
					lhs_bot_right,
					lhs_diag,
					rhs_bot_right,
					rhs_diag,
					alpha,
					conj_lhs,
					conj_rhs,
					par,
				)
			},
			par,
		);
	}
}
405
/// computes the lower triangular half of `alpha * op(lhs) * op(rhs)`, where `lhs` is upper
/// triangular and `rhs` is lower triangular (diagonals interpreted according to `lhs_diag` /
/// `rhs_diag`).
///
/// the result is written into (`Accum::Replace`) or added to (`Accum::Add`) the lower triangular
/// half of `dst`. the strict upper half of `dst` is not written. when `skip_diag` is `true`, the
/// diagonal of `dst` is not written either. shapes are only checked with `debug_assert!`.
#[math]
fn upper_x_lower_into_lower_impl_unchecked<'N, T: ComplexField>(
	dst: MatMut<'_, T, Dim<'N>, Dim<'N>>,
	beta: Accum,
	skip_diag: bool,
	lhs: MatRef<'_, T, Dim<'N>, Dim<'N>>,
	lhs_diag: DiagonalKind,
	rhs: MatRef<'_, T, Dim<'N>, Dim<'N>>,
	rhs_diag: DiagonalKind,
	alpha: &T,
	conj_lhs: Conj,
	conj_rhs: Conj,
	par: Par,
) {
	let N = dst.nrows();
	let n = N.unbound();
	debug_assert!(N == lhs.nrows());
	debug_assert!(N == lhs.ncols());
	debug_assert!(N == rhs.nrows());
	debug_assert!(N == rhs.ncols());
	debug_assert!(N == dst.nrows());
	debug_assert!(N == dst.ncols());

	if n <= 16 {
		// base case: densify both operands, compute the full product into a temporary, then
		// transfer only its lower half into `dst`
		let op = {
			#[inline(never)]
			|| {
				stack_mat_16x16!(temp_dst, N, pointer_offset(dst.as_ptr()), dst.row_stride(), dst.col_stride(), T);
				stack_mat_16x16!(temp_lhs, N, pointer_offset(lhs.as_ptr()), lhs.row_stride(), lhs.col_stride(), T);
				stack_mat_16x16!(temp_rhs, N, pointer_offset(rhs.as_ptr()), rhs.row_stride(), rhs.col_stride(), T);

				copy_upper(temp_lhs.rb_mut(), lhs, lhs_diag);
				copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);

				super::matmul_with_conj(
					temp_dst.rb_mut(),
					Accum::Replace,
					temp_lhs.rb(),
					conj_lhs,
					temp_rhs.rb(),
					conj_rhs,
					alpha.clone(),
					par,
				);

				accum_lower(dst, temp_dst.rb(), skip_diag, beta);
			}
		};
		op();
	} else {
		make_guard!(HEAD);
		make_guard!(TAIL);
		let bs = N.partition(N.checked_idx_inc(N.unbound() / 2), HEAD, TAIL);

		// the top-right block of `dst` lies in the strict upper half and is never written
		let (mut dst_top_left, _, dst_bot_left, dst_bot_right) = dst.split_with_mut(bs, bs);
		let (lhs_top_left, lhs_top_right, _, lhs_bot_right) = lhs.split_with(bs, bs);
		let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_with(bs, bs);

		// lhs_top_left  × rhs_top_left  => dst_top_left  | upp × low => low |   X
		// lhs_top_right × rhs_bot_left  => dst_top_left  | mat × mat => low | 1/2
		//
		// lhs_bot_right × rhs_bot_left  => dst_bot_left  | upp × mat => mat | 1/2
		// lhs_bot_right × rhs_bot_right => dst_bot_right | upp × low => low |   X

		join_raw(
			|par| {
				// `beta` is applied by this first product; the recursive one then adds
				mat_x_mat_into_lower_impl_unchecked(
					dst_top_left.rb_mut(),
					beta,
					skip_diag,
					lhs_top_right,
					rhs_bot_left,
					alpha,
					conj_lhs,
					conj_rhs,
					par,
				);
				upper_x_lower_into_lower_impl_unchecked(
					dst_top_left,
					Accum::Add,
					skip_diag,
					lhs_top_left,
					lhs_diag,
					rhs_top_left,
					rhs_diag,
					alpha,
					conj_lhs,
					conj_rhs,
					par,
				)
			},
			|par| {
				// (upp × mat) via transposition: (C E)^T = E^T C^T with C^T lower triangular
				mat_x_lower_impl_unchecked(
					dst_bot_left.transpose_mut(),
					beta,
					rhs_bot_left.transpose(),
					lhs_bot_right.transpose(),
					lhs_diag,
					alpha,
					conj_rhs,
					conj_lhs,
					par,
				);
				upper_x_lower_into_lower_impl_unchecked(
					dst_bot_right,
					beta,
					skip_diag,
					lhs_bot_right,
					lhs_diag,
					rhs_bot_right,
					rhs_diag,
					alpha,
					conj_lhs,
					conj_rhs,
					par,
				)
			},
			par,
		);
	}
}
527
/// computes the lower triangular half of `alpha * op(lhs) * op(rhs)` for a dense `lhs` (`n × k`)
/// and a dense `rhs` (`k × n`).
///
/// the result is written into (`Accum::Replace`) or added to (`Accum::Add`) the lower triangular
/// half of `dst` (`n × n`). when `skip_diag` is `true`, the diagonal of `dst` is not written.
///
/// on `x86_64` with the `std` feature, native `f32`/`f64`/`c32`/`c64` types, non-empty
/// dimensions, and AVX-512F or AVX2+FMA detected at runtime, the work is delegated to
/// `private_gemm_x86::gemm`. otherwise a recursive blocked algorithm is used.
#[math]
fn mat_x_mat_into_lower_impl_unchecked<'N, 'K, T: ComplexField>(
	dst: MatMut<'_, T, Dim<'N>, Dim<'N>>,
	beta: Accum,
	skip_diag: bool,
	lhs: MatRef<'_, T, Dim<'N>, Dim<'K>>,
	rhs: MatRef<'_, T, Dim<'K>, Dim<'N>>,
	alpha: &T,
	conj_lhs: Conj,
	conj_rhs: Conj,
	par: Par,
) {
	#[cfg(all(target_arch = "x86_64", feature = "std"))]
	if const { T::IS_NATIVE_F64 || T::IS_NATIVE_F32 || T::IS_NATIVE_C64 || T::IS_NATIVE_C32 } {
		use private_gemm_x86::*;

		let feat = if std::arch::is_x86_feature_detected!("avx512f") {
			Some(InstrSet::Avx512)
		} else if std::arch::is_x86_feature_detected!("avx2") && std::arch::is_x86_feature_detected!("fma") {
			Some(InstrSet::Avx256)
		} else {
			None
		};

		if *dst.nrows() > 0 && *dst.ncols() > 0 && *lhs.ncols() > 0 {
			if let Some(feat) = feat {
				// NOTE(review): soundness relies on `private_gemm_x86::gemm`'s contract. the
				// pointers and strides passed below come from valid views whose dimensions are
				// passed alongside, and `T` is one of the four native types matched by `DType`.
				unsafe {
					// dropping the first row of `dst` and `lhs` shifts the row index by one: for the
					// remaining rows, the lower triangle (`i' >= j`) is the strict lower triangle
					// (`i > j`) of the original `dst`, which excludes the diagonal
					let (dst, lhs) = if skip_diag {
						(dst.as_dyn_mut().get_mut(1.., ..), lhs.as_dyn().get(1.., ..))
					} else {
						(dst.as_dyn_mut(), lhs.as_dyn())
					};

					private_gemm_x86::gemm(
						const {
							if T::IS_NATIVE_F64 {
								DType::F64
							} else if T::IS_NATIVE_F32 {
								DType::F32
							} else if T::IS_NATIVE_C64 {
								DType::C64
							} else {
								DType::C32
							}
						},
						const { IType::U32 },
						feat,
						dst.nrows(),
						dst.ncols(),
						lhs.ncols(),
						dst.as_ptr_mut() as _,
						dst.row_stride(),
						dst.col_stride(),
						core::ptr::null(),
						core::ptr::null(),
						DstKind::Lower,
						match beta {
							crate::Accum::Replace => Accum::Replace,
							crate::Accum::Add => Accum::Add,
						},
						lhs.as_ptr() as _,
						lhs.row_stride(),
						lhs.col_stride(),
						conj_lhs == Conj::Yes,
						core::ptr::null(),
						0,
						rhs.as_ptr() as _,
						rhs.row_stride(),
						rhs.col_stride(),
						conj_rhs == Conj::Yes,
						alpha as *const T as *const (),
						par.degree(),
					);

					return;
				}
			}
		}
	}

	let N = dst.nrows();
	let K = lhs.ncols();
	let n = N.unbound();
	let k = K.unbound();
	debug_assert!(dst.nrows() == dst.ncols());
	debug_assert!(dst.nrows() == lhs.nrows());
	debug_assert!(dst.ncols() == rhs.ncols());
	debug_assert!(lhs.ncols() == rhs.nrows());

	// below roughly 128^3 multiply-adds everything below this point runs sequentially
	let par = if n * n * k < 128usize * 128usize * 128usize { Par::Seq } else { par };

	if n <= 16 {
		// base case: compute the full product into a stack temporary, then transfer its lower half
		let op = {
			#[inline(never)]
			|| {
				stack_mat_16x16!(temp_dst, N, pointer_offset(dst.as_ptr()), dst.row_stride(), dst.col_stride(), T);

				super::matmul_with_conj(temp_dst.rb_mut(), Accum::Replace, lhs, conj_lhs, rhs, conj_rhs, alpha.clone(), par);
				accum_lower(dst, temp_dst.rb(), skip_diag, beta);
			}
		};
		op();
	} else {
		make_guard!(HEAD);
		make_guard!(TAIL);
		let bs = N.partition(N.checked_idx_inc(N.unbound() / 2), HEAD, TAIL);

		// lhs = [L0; L1] (row blocks), rhs = [R0 R1] (column blocks):
		//   dst_bot_left  = L1 * R0          (dense, strictly below the diagonal)
		//   dst_top_left  = lower(L0 * R0)   (recursive)
		//   dst_bot_right = lower(L1 * R1)   (recursive)
		// the top-right block lies in the strict upper half and is never written
		let (dst_top_left, _, dst_bot_left, dst_bot_right) = dst.split_with_mut(bs, bs);
		let (lhs_top, lhs_bot) = lhs.split_rows_with(bs);
		let (rhs_left, rhs_right) = rhs.split_cols_with(bs);

		join_raw(
			|par| super::matmul_with_conj(dst_bot_left, beta, lhs_bot, conj_lhs, rhs_left, conj_rhs, alpha.clone(), par),
			|par| {
				join_raw(
					|par| mat_x_mat_into_lower_impl_unchecked(dst_top_left, beta, skip_diag, lhs_top, rhs_left, alpha, conj_lhs, conj_rhs, par),
					|par| mat_x_mat_into_lower_impl_unchecked(dst_bot_right, beta, skip_diag, lhs_bot, rhs_right, alpha, conj_lhs, conj_rhs, par),
					par,
				)
			},
			par,
		);
	}
}
652
/// computes the lower triangular half of `alpha * op(lhs) * op(rhs)`, where `lhs` is dense and
/// `rhs` is lower triangular (diagonal interpreted according to `rhs_diag`).
///
/// the result is written into (`Accum::Replace`) or added to (`Accum::Add`) the lower triangular
/// half of `dst`. the strict upper half of `dst` is not written. when `skip_diag` is `true`, the
/// diagonal of `dst` is not written either. shapes are only checked with `debug_assert!`.
#[math]
fn mat_x_lower_into_lower_impl_unchecked<'N, T: ComplexField>(
	dst: MatMut<'_, T, Dim<'N>, Dim<'N>>,
	beta: Accum,
	skip_diag: bool,
	lhs: MatRef<'_, T, Dim<'N>, Dim<'N>>,
	rhs: MatRef<'_, T, Dim<'N>, Dim<'N>>,
	rhs_diag: DiagonalKind,
	alpha: &T,
	conj_lhs: Conj,
	conj_rhs: Conj,
	par: Par,
) {
	let N = dst.nrows();
	let n = N.unbound();
	debug_assert!(N == dst.nrows());
	debug_assert!(N == dst.ncols());
	debug_assert!(N == lhs.nrows());
	debug_assert!(N == lhs.ncols());
	debug_assert!(N == rhs.nrows());
	debug_assert!(N == rhs.ncols());

	if n <= 16 {
		// base case: densify `rhs`, compute the full product into a temporary, then transfer only
		// its lower half into `dst`
		let op = {
			#[inline(never)]
			|| {
				stack_mat_16x16!(temp_dst, N, pointer_offset(dst.as_ptr()), dst.row_stride(), dst.col_stride(), T);
				stack_mat_16x16!(temp_rhs, N, pointer_offset(rhs.as_ptr()), rhs.row_stride(), rhs.col_stride(), T);

				copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);
				super::matmul_with_conj(
					temp_dst.rb_mut(),
					Accum::Replace,
					lhs,
					conj_lhs,
					temp_rhs.rb(),
					conj_rhs,
					alpha.clone(),
					par,
				);
				accum_lower(dst, temp_dst.rb(), skip_diag, beta);
			}
		};
		op();
	} else {
		make_guard!(HEAD);
		make_guard!(TAIL);
		let bs = N.partition(N.checked_idx_inc(N.unbound() / 2), HEAD, TAIL);

		let (mut dst_top_left, _, mut dst_bot_left, dst_bot_right) = dst.split_with_mut(bs, bs);
		let (lhs_top_left, lhs_top_right, lhs_bot_left, lhs_bot_right) = lhs.split_with(bs, bs);
		let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_with(bs, bs);

		// lhs_bot_right × rhs_bot_left  => dst_bot_left  | mat × mat => mat |   1
		// lhs_bot_right × rhs_bot_right => dst_bot_right | mat × low => low |   X
		//
		// lhs_top_left  × rhs_top_left  => dst_top_left  | mat × low => low |   X
		// lhs_top_right × rhs_bot_left  => dst_top_left  | mat × mat => low | 1/2
		// lhs_bot_left  × rhs_top_left  => dst_bot_left  | mat × low => mat | 1/2
		//
		// the first product written into each block applies `beta`; later ones use `Accum::Add`

		super::matmul_with_conj(
			dst_bot_left.rb_mut(),
			beta,
			lhs_bot_right,
			conj_lhs,
			rhs_bot_left,
			conj_rhs,
			alpha.clone(),
			par,
		);
		mat_x_lower_into_lower_impl_unchecked(
			dst_bot_right,
			beta,
			skip_diag,
			lhs_bot_right,
			rhs_bot_right,
			rhs_diag,
			alpha,
			conj_lhs,
			conj_rhs,
			par,
		);

		mat_x_lower_into_lower_impl_unchecked(
			dst_top_left.rb_mut(),
			beta,
			skip_diag,
			lhs_top_left,
			rhs_top_left,
			rhs_diag,
			alpha,
			conj_lhs,
			conj_rhs,
			par,
		);
		mat_x_mat_into_lower_impl_unchecked(
			dst_top_left,
			Accum::Add,
			skip_diag,
			lhs_top_right,
			rhs_bot_left,
			alpha,
			conj_lhs,
			conj_rhs,
			par,
		);
		mat_x_lower_impl_unchecked(
			dst_bot_left,
			Accum::Add,
			lhs_bot_left,
			rhs_top_left,
			rhs_diag,
			alpha,
			conj_lhs,
			conj_rhs,
			par,
		);
	}
}
772
/// describes which part of a matrix operand is accessed by the triangular matmul routines.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockStructure {
	/// every entry of the matrix is accessed.
	Rectangular,
	/// the lower triangular half, diagonal included, is accessed.
	TriangularLower,
	/// the lower triangular half, diagonal excluded, is accessed (the diagonal is treated as
	/// zero).
	StrictTriangularLower,
	/// the lower triangular half, diagonal excluded, is accessed; the diagonal is assumed to be
	/// all `1.0`.
	UnitTriangularLower,
	/// the upper triangular half, diagonal included, is accessed.
	TriangularUpper,
	/// the upper triangular half, diagonal excluded, is accessed (the diagonal is treated as
	/// zero).
	StrictTriangularUpper,
	/// the upper triangular half, diagonal excluded, is accessed; the diagonal is assumed to be
	/// all `1.0`.
	UnitTriangularUpper,
}
793
794impl BlockStructure {
795	/// checks if `self` is full.
796	#[inline]
797	pub fn is_dense(self) -> bool {
798		matches!(self, BlockStructure::Rectangular)
799	}
800
801	/// checks if `self` is triangular lower (either inclusive or exclusive).
802	#[inline]
803	pub fn is_lower(self) -> bool {
804		use BlockStructure::*;
805		matches!(self, TriangularLower | StrictTriangularLower | UnitTriangularLower)
806	}
807
808	/// checks if `self` is triangular upper (either inclusive or exclusive).
809	#[inline]
810	pub fn is_upper(self) -> bool {
811		use BlockStructure::*;
812		matches!(self, TriangularUpper | StrictTriangularUpper | UnitTriangularUpper)
813	}
814
815	/// returns the block structure corresponding to the transposed matrix.
816	#[inline]
817	pub fn transpose(self) -> Self {
818		use BlockStructure::*;
819		match self {
820			Rectangular => Rectangular,
821			TriangularLower => TriangularUpper,
822			StrictTriangularLower => StrictTriangularUpper,
823			UnitTriangularLower => UnitTriangularUpper,
824			TriangularUpper => TriangularLower,
825			StrictTriangularUpper => StrictTriangularLower,
826			UnitTriangularUpper => UnitTriangularLower,
827		}
828	}
829
830	#[inline]
831	pub(crate) fn diag_kind(self) -> DiagonalKind {
832		use BlockStructure::*;
833		match self {
834			Rectangular | TriangularLower | TriangularUpper => DiagonalKind::Generic,
835			StrictTriangularLower | StrictTriangularUpper => DiagonalKind::Zero,
836			UnitTriangularLower | UnitTriangularUpper => DiagonalKind::Unit,
837		}
838	}
839}
840
841#[track_caller]
842fn precondition<M: Shape, N: Shape, K: Shape>(
843	dst_nrows: M,
844	dst_ncols: N,
845	dst_structure: BlockStructure,
846	lhs_nrows: M,
847	lhs_ncols: K,
848	lhs_structure: BlockStructure,
849	rhs_nrows: K,
850	rhs_ncols: N,
851	rhs_structure: BlockStructure,
852) {
853	assert!(all(dst_nrows == lhs_nrows, dst_ncols == rhs_ncols, lhs_ncols == rhs_nrows,));
854
855	let dst_nrows = dst_nrows.unbound();
856	let dst_ncols = dst_ncols.unbound();
857	let lhs_nrows = lhs_nrows.unbound();
858	let lhs_ncols = lhs_ncols.unbound();
859	let rhs_nrows = rhs_nrows.unbound();
860	let rhs_ncols = rhs_ncols.unbound();
861
862	if !dst_structure.is_dense() {
863		assert!(dst_nrows == dst_ncols);
864	}
865	if !lhs_structure.is_dense() {
866		assert!(lhs_nrows == lhs_ncols);
867	}
868	if !rhs_structure.is_dense() {
869		assert!(rhs_nrows == rhs_ncols);
870	}
871}
872
873/// computes the matrix product `[beta * acc] + alpha * lhs * rhs` (implicitly conjugating the
874/// operands if needed) and stores the result in `acc`
875///
876/// performs the operation:
877/// - `acc = alpha * lhs * rhs` if `beta` is `accum::replace` (in this case, the preexisting
878/// values in `acc` are not read)
879/// - `acc = acc + alpha * lhs * rhs` if `beta` is `accum::add`
880///
881/// the left hand side and right hand side may be interpreted as triangular depending on the
882/// given corresponding matrix structure.
883///
884/// for the destination matrix, the result is:
885/// - fully computed if the structure is rectangular,
886/// - only the triangular half (including the diagonal) is computed if the structure is
887/// triangular
888/// - only the strict triangular half (excluding the diagonal) is computed if the structure is
889/// strictly triangular or unit triangular
890///
891/// # panics
892///
893/// panics if the matrix dimensions are not compatible for matrix multiplication.
894/// i.e.  
895///  - `acc.nrows() == lhs.nrows()`
896///  - `acc.ncols() == rhs.ncols()`
897///  - `lhs.ncols() == rhs.nrows()`
898///
899/// additionally, matrices that are marked as triangular must be square, i.e., they must have
900/// the same number of rows and columns.
901///
902/// # example
903///
904/// ```
905/// use faer::linalg::matmul::triangular::{BlockStructure, matmul_with_conj};
906/// use faer::{Accum, Conj, Mat, Par, mat, unzip, zip};
907///
908/// let lhs = mat![[0.0, 2.0], [1.0, 3.0]];
909/// let rhs = mat![[4.0, 6.0], [5.0, 7.0]];
910///
911/// let mut acc = Mat::<f64>::zeros(2, 2);
912/// let target = mat![
913/// 	[
914/// 		2.5 * (lhs[(0, 0)] * rhs[(0, 0)] + lhs[(0, 1)] * rhs[(1, 0)]),
915/// 		0.0,
916/// 	],
917/// 	[
918/// 		2.5 * (lhs[(1, 0)] * rhs[(0, 0)] + lhs[(1, 1)] * rhs[(1, 0)]),
919/// 		2.5 * (lhs[(1, 0)] * rhs[(0, 1)] + lhs[(1, 1)] * rhs[(1, 1)]),
920/// 	],
921/// ];
922///
923/// matmul_with_conj(
924/// 	&mut acc,
925/// 	BlockStructure::TriangularLower,
926/// 	Accum::Replace,
927/// 	&lhs,
928/// 	BlockStructure::Rectangular,
929/// 	Conj::No,
930/// 	&rhs,
931/// 	BlockStructure::Rectangular,
932/// 	Conj::No,
933/// 	2.5,
934/// 	Par::Seq,
935/// );
936///
937/// zip!(&acc, &target).for_each(|unzip!(acc, target)| assert!((acc - target).abs() < 1e-10));
938/// ```
939#[track_caller]
940#[inline]
941pub fn matmul_with_conj<T: ComplexField, M: Shape, N: Shape, K: Shape>(
942	dst: impl AsMatMut<T = T, Rows = M, Cols = N>,
943	dst_structure: BlockStructure,
944	beta: Accum,
945	lhs: impl AsMatRef<T = T, Rows = M, Cols = K>,
946	lhs_structure: BlockStructure,
947	conj_lhs: Conj,
948	rhs: impl AsMatRef<T = T, Rows = K, Cols = N>,
949	rhs_structure: BlockStructure,
950	conj_rhs: Conj,
951	alpha: T,
952	par: Par,
953) {
954	let mut dst = dst;
955	let dst = dst.as_mat_mut();
956	let lhs = lhs.as_mat_ref();
957	let rhs = rhs.as_mat_ref();
958
959	precondition(
960		dst.nrows(),
961		dst.ncols(),
962		dst_structure,
963		lhs.nrows(),
964		lhs.ncols(),
965		lhs_structure,
966		rhs.nrows(),
967		rhs.ncols(),
968		rhs_structure,
969	);
970
971	make_guard!(M);
972	make_guard!(N);
973	make_guard!(K);
974	let M = dst.nrows().bind(M);
975	let N = dst.ncols().bind(N);
976	let K = lhs.ncols().bind(K);
977
978	matmul_imp(
979		dst.as_dyn_stride_mut().as_shape_mut(M, N),
980		dst_structure,
981		beta,
982		lhs.as_dyn_stride().canonical().as_shape(M, K),
983		lhs_structure,
984		conj_lhs,
985		rhs.as_dyn_stride().canonical().as_shape(K, N),
986		rhs_structure,
987		conj_rhs,
988		&alpha,
989		par,
990	);
991}
992
993/// computes the matrix product `[beta * acc] + alpha * lhs * rhs` (implicitly conjugating the
994/// operands if needed) and stores the result in `acc`
995///
996/// performs the operation:
997/// - `acc = alpha * lhs * rhs` if `beta` is `accum::replace` (in this case, the preexisting
998/// values in `acc` are not read)
999/// - `acc = acc + alpha * lhs * rhs` if `beta` is `accum::add`
1000///
1001/// the left hand side and right hand side may be interpreted as triangular depending on the
1002/// given corresponding matrix structure.
1003///
1004/// for the destination matrix, the result is:
1005/// - fully computed if the structure is rectangular,
1006/// - only the triangular half (including the diagonal) is computed if the structure is
1007/// triangular
1008/// - only the strict triangular half (excluding the diagonal) is computed if the structure is
1009/// strictly triangular or unit triangular
1010///
1011/// # panics
1012///
1013/// panics if the matrix dimensions are not compatible for matrix multiplication.
1014/// i.e.  
1015///  - `acc.nrows() == lhs.nrows()`
1016///  - `acc.ncols() == rhs.ncols()`
1017///  - `lhs.ncols() == rhs.nrows()`
1018///
1019/// additionally, matrices that are marked as triangular must be square, i.e., they must have
1020/// the same number of rows and columns.
1021///
1022/// # example
1023///
1024/// ```
1025/// use faer::linalg::matmul::triangular::{BlockStructure, matmul};
1026/// use faer::{Accum, Conj, Mat, Par, mat, unzip, zip};
1027///
1028/// let lhs = mat![[0.0, 2.0], [1.0, 3.0]];
1029/// let rhs = mat![[4.0, 6.0], [5.0, 7.0]];
1030///
1031/// let mut acc = Mat::<f64>::zeros(2, 2);
1032/// let target = mat![
1033/// 	[
1034/// 		2.5 * (lhs[(0, 0)] * rhs[(0, 0)] + lhs[(0, 1)] * rhs[(1, 0)]),
1035/// 		0.0,
1036/// 	],
1037/// 	[
1038/// 		2.5 * (lhs[(1, 0)] * rhs[(0, 0)] + lhs[(1, 1)] * rhs[(1, 0)]),
1039/// 		2.5 * (lhs[(1, 0)] * rhs[(0, 1)] + lhs[(1, 1)] * rhs[(1, 1)]),
1040/// 	],
1041/// ];
1042///
1043/// matmul(
1044/// 	&mut acc,
1045/// 	BlockStructure::TriangularLower,
1046/// 	Accum::Replace,
1047/// 	&lhs,
1048/// 	BlockStructure::Rectangular,
1049/// 	&rhs,
1050/// 	BlockStructure::Rectangular,
1051/// 	2.5,
1052/// 	Par::Seq,
1053/// );
1054///
1055/// zip!(&acc, &target).for_each(|unzip!(acc, target)| assert!((acc - target).abs() < 1e-10));
1056/// ```
1057#[track_caller]
1058#[inline]
1059pub fn matmul<T: ComplexField, LhsT: Conjugate<Canonical = T>, RhsT: Conjugate<Canonical = T>, M: Shape, N: Shape, K: Shape>(
1060	dst: impl AsMatMut<T = T, Rows = M, Cols = N>,
1061	dst_structure: BlockStructure,
1062	beta: Accum,
1063	lhs: impl AsMatRef<T = LhsT, Rows = M, Cols = K>,
1064	lhs_structure: BlockStructure,
1065	rhs: impl AsMatRef<T = RhsT, Rows = K, Cols = N>,
1066	rhs_structure: BlockStructure,
1067	alpha: T,
1068	par: Par,
1069) {
1070	let mut dst = dst;
1071	let dst = dst.as_mat_mut();
1072	let lhs = lhs.as_mat_ref();
1073	let rhs = rhs.as_mat_ref();
1074
1075	precondition(
1076		dst.nrows(),
1077		dst.ncols(),
1078		dst_structure,
1079		lhs.nrows(),
1080		lhs.ncols(),
1081		lhs_structure,
1082		rhs.nrows(),
1083		rhs.ncols(),
1084		rhs_structure,
1085	);
1086
1087	make_guard!(M);
1088	make_guard!(N);
1089	make_guard!(K);
1090	let M = dst.nrows().bind(M);
1091	let N = dst.ncols().bind(N);
1092	let K = lhs.ncols().bind(K);
1093
1094	matmul_imp(
1095		dst.as_dyn_stride_mut().as_shape_mut(M, N),
1096		dst_structure,
1097		beta,
1098		lhs.as_dyn_stride().canonical().as_shape(M, K),
1099		lhs_structure,
1100		try_const! { Conj::get::<LhsT>() },
1101		rhs.as_dyn_stride().canonical().as_shape(K, N),
1102		rhs_structure,
1103		try_const! { Conj::get::<RhsT>() },
1104		alpha.by_ref(),
1105		par,
1106	);
1107}
1108
1109#[math]
1110fn matmul_imp<'M, 'N, 'K, T: ComplexField>(
1111	dst: MatMut<'_, T, Dim<'M>, Dim<'N>>,
1112	dst_structure: BlockStructure,
1113	beta: Accum,
1114	lhs: MatRef<'_, T, Dim<'M>, Dim<'K>>,
1115	lhs_structure: BlockStructure,
1116	conj_lhs: Conj,
1117	rhs: MatRef<'_, T, Dim<'K>, Dim<'N>>,
1118	rhs_structure: BlockStructure,
1119	conj_rhs: Conj,
1120	alpha: &T,
1121	par: Par,
1122) {
1123	let mut acc = dst.as_dyn_mut();
1124	let mut lhs = lhs.as_dyn();
1125	let mut rhs = rhs.as_dyn();
1126
1127	let mut acc_structure = dst_structure;
1128	let mut lhs_structure = lhs_structure;
1129	let mut rhs_structure = rhs_structure;
1130
1131	let mut conj_lhs = conj_lhs;
1132	let mut conj_rhs = conj_rhs;
1133
1134	// if either the lhs or the rhs is triangular
1135	if rhs_structure.is_lower() {
1136		// do nothing
1137		false
1138	} else if rhs_structure.is_upper() {
1139		// invert acc, lhs and rhs
1140		acc = acc.reverse_rows_and_cols_mut();
1141		lhs = lhs.reverse_rows_and_cols();
1142		rhs = rhs.reverse_rows_and_cols();
1143		acc_structure = acc_structure.transpose();
1144		lhs_structure = lhs_structure.transpose();
1145		rhs_structure = rhs_structure.transpose();
1146		false
1147	} else if lhs_structure.is_lower() {
1148		// invert and transpose
1149		acc = acc.reverse_rows_and_cols_mut().transpose_mut();
1150		(lhs, rhs) = (rhs.reverse_rows_and_cols().transpose(), lhs.reverse_rows_and_cols().transpose());
1151		(conj_lhs, conj_rhs) = (conj_rhs, conj_lhs);
1152		(lhs_structure, rhs_structure) = (rhs_structure, lhs_structure);
1153		true
1154	} else if lhs_structure.is_upper() {
1155		// transpose
1156		acc_structure = acc_structure.transpose();
1157		acc = acc.transpose_mut();
1158		(lhs, rhs) = (rhs.transpose(), lhs.transpose());
1159		(conj_lhs, conj_rhs) = (conj_rhs, conj_lhs);
1160		(lhs_structure, rhs_structure) = (rhs_structure.transpose(), lhs_structure.transpose());
1161		true
1162	} else {
1163		// do nothing
1164		false
1165	};
1166
1167	make_guard!(M);
1168	make_guard!(N);
1169	make_guard!(K);
1170	let M = acc.nrows().bind(M);
1171	let N = acc.ncols().bind(N);
1172	let K = lhs.ncols().bind(K);
1173
1174	let clear_upper = |acc: MatMut<'_, T>, skip_diag: bool| match &beta {
1175		Accum::Add => {},
1176
1177		Accum::Replace => zip!(acc).for_each_triangular_upper(if skip_diag { Diag::Skip } else { Diag::Include }, |unzip!(acc)| *acc = zero()),
1178	};
1179
1180	let skip_diag = matches!(
1181		acc_structure,
1182		BlockStructure::StrictTriangularLower
1183			| BlockStructure::StrictTriangularUpper
1184			| BlockStructure::UnitTriangularLower
1185			| BlockStructure::UnitTriangularUpper
1186	);
1187	let lhs_diag = lhs_structure.diag_kind();
1188	let rhs_diag = rhs_structure.diag_kind();
1189
1190	if acc_structure.is_dense() {
1191		if lhs_structure.is_dense() && rhs_structure.is_dense() {
1192			super::matmul_with_conj(acc, beta, lhs, conj_lhs, rhs, conj_rhs, alpha.clone(), par);
1193		} else {
1194			debug_assert!(rhs_structure.is_lower());
1195
1196			if lhs_structure.is_dense() {
1197				mat_x_lower_impl_unchecked(
1198					acc.as_shape_mut(M, N),
1199					beta,
1200					lhs.as_shape(M, N),
1201					rhs.as_shape(N, N),
1202					rhs_diag,
1203					alpha,
1204					conj_lhs,
1205					conj_rhs,
1206					par,
1207				)
1208			} else if lhs_structure.is_lower() {
1209				clear_upper(acc.rb_mut(), true);
1210				lower_x_lower_into_lower_impl_unchecked(
1211					acc.as_shape_mut(N, N),
1212					beta,
1213					false,
1214					lhs.as_shape(N, N),
1215					lhs_diag,
1216					rhs.as_shape(N, N),
1217					rhs_diag,
1218					alpha,
1219					conj_lhs,
1220					conj_rhs,
1221					par,
1222				);
1223			} else {
1224				debug_assert!(lhs_structure.is_upper());
1225				upper_x_lower_impl_unchecked(
1226					acc.as_shape_mut(N, N),
1227					beta,
1228					lhs.as_shape(N, N),
1229					lhs_diag,
1230					rhs.as_shape(N, N),
1231					rhs_diag,
1232					alpha,
1233					conj_lhs,
1234					conj_rhs,
1235					par,
1236				)
1237			}
1238		}
1239	} else if acc_structure.is_lower() {
1240		if lhs_structure.is_dense() && rhs_structure.is_dense() {
1241			mat_x_mat_into_lower_impl_unchecked(
1242				acc.as_shape_mut(N, N),
1243				beta,
1244				skip_diag,
1245				lhs.as_shape(N, K),
1246				rhs.as_shape(K, N),
1247				alpha,
1248				conj_lhs,
1249				conj_rhs,
1250				par,
1251			)
1252		} else {
1253			debug_assert!(rhs_structure.is_lower());
1254			if lhs_structure.is_dense() {
1255				mat_x_lower_into_lower_impl_unchecked(
1256					acc.as_shape_mut(N, N),
1257					beta,
1258					skip_diag,
1259					lhs.as_shape(N, N),
1260					rhs.as_shape(N, N),
1261					rhs_diag,
1262					alpha,
1263					conj_lhs,
1264					conj_rhs,
1265					par,
1266				);
1267			} else if lhs_structure.is_lower() {
1268				lower_x_lower_into_lower_impl_unchecked(
1269					acc.as_shape_mut(N, N),
1270					beta,
1271					skip_diag,
1272					lhs.as_shape(N, N),
1273					lhs_diag,
1274					rhs.as_shape(N, N),
1275					rhs_diag,
1276					alpha,
1277					conj_lhs,
1278					conj_rhs,
1279					par,
1280				)
1281			} else {
1282				upper_x_lower_into_lower_impl_unchecked(
1283					acc.as_shape_mut(N, N),
1284					beta,
1285					skip_diag,
1286					lhs.as_shape(N, N),
1287					lhs_diag,
1288					rhs.as_shape(N, N),
1289					rhs_diag,
1290					alpha,
1291					conj_lhs,
1292					conj_rhs,
1293					par,
1294				)
1295			}
1296		}
1297	} else if lhs_structure.is_dense() && rhs_structure.is_dense() {
1298		mat_x_mat_into_lower_impl_unchecked(
1299			acc.as_shape_mut(N, N).transpose_mut(),
1300			beta,
1301			skip_diag,
1302			rhs.transpose().as_shape(N, K),
1303			lhs.transpose().as_shape(K, N),
1304			alpha,
1305			conj_rhs,
1306			conj_lhs,
1307			par,
1308		)
1309	} else {
1310		debug_assert!(rhs_structure.is_lower());
1311		if lhs_structure.is_dense() {
1312			// lower part of lhs does not contribute to result
1313			upper_x_lower_into_lower_impl_unchecked(
1314				acc.as_shape_mut(N, N).transpose_mut(),
1315				beta,
1316				skip_diag,
1317				rhs.transpose().as_shape(N, N),
1318				rhs_diag,
1319				lhs.transpose().as_shape(N, N),
1320				lhs_diag,
1321				alpha,
1322				conj_rhs,
1323				conj_lhs,
1324				par,
1325			)
1326		} else if lhs_structure.is_lower() {
1327			if !skip_diag {
1328				match beta {
1329					Accum::Add => {
1330						for j in 0..N.unbound() {
1331							acc[(j, j)] = acc[(j, j)] + *alpha * lhs[(j, j)] * rhs[(j, j)];
1332						}
1333					},
1334					Accum::Replace => {
1335						for j in 0..N.unbound() {
1336							acc[(j, j)] = *alpha * lhs[(j, j)] * rhs[(j, j)];
1337						}
1338					},
1339				}
1340			}
1341			clear_upper(acc.rb_mut(), true);
1342		} else {
1343			debug_assert!(lhs_structure.is_upper());
1344			upper_x_lower_into_lower_impl_unchecked(
1345				acc.as_shape_mut(N, N).transpose_mut(),
1346				beta,
1347				skip_diag,
1348				rhs.transpose().as_shape(N, N),
1349				rhs_diag,
1350				lhs.transpose().as_shape(N, N),
1351				lhs_diag,
1352				alpha,
1353				conj_rhs,
1354				conj_lhs,
1355				par,
1356			)
1357		}
1358	}
1359}