use crate::assert;
use crate::internal_prelude::*;
use linalg::householder;
use linalg::matmul::triangular::BlockStructure;
use linalg::matmul::{self, dot};
6
7#[derive(Copy, Clone, Debug)]
9pub struct TridiagParams {
10 pub par_threshold: usize,
12
13 #[doc(hidden)]
14 pub non_exhaustive: NonExhaustive,
15}
16
17impl<T: ComplexField> Auto<T> for TridiagParams {
18 fn auto() -> Self {
19 Self {
20 par_threshold: 192 * 256,
21 non_exhaustive: NonExhaustive(()),
22 }
23 }
24}
25
26pub fn tridiag_in_place_scratch<T: ComplexField>(dim: usize, par: Par, params: Spec<TridiagParams, T>) -> StackReq {
29 _ = par;
30 _ = params;
31 StackReq::all_of(&[temp_mat_scratch::<T>(dim, 1).array(2), temp_mat_scratch::<T>(dim, par.degree())])
32}
33
34#[math]
35fn tridiag_fused_op_simd<T: ComplexField>(
36 A: MatMut<'_, T, usize, usize, ContiguousFwd>,
37 y2: ColMut<'_, T, usize>,
38 z2: ColMut<'_, T, usize, ContiguousFwd>,
39
40 ry2: ColRef<'_, T, usize>,
41 rz2: ColRef<'_, T, usize, ContiguousFwd>,
42
43 u0: ColRef<'_, T, usize, ContiguousFwd>,
44 u1: ColRef<'_, T, usize>,
45 u2: ColRef<'_, T, usize>,
46 v2: ColRef<'_, T, usize, ContiguousFwd>,
47
48 f: T,
49 align: usize,
50) {
51 struct Impl<'a, 'M, 'N, T: ComplexField> {
52 A: MatMut<'a, T, Dim<'M>, Dim<'N>, ContiguousFwd>,
53 y2: ColMut<'a, T, Dim<'N>>,
54 z2: ColMut<'a, T, Dim<'M>, ContiguousFwd>,
55
56 ry2: ColRef<'a, T, Dim<'N>>,
57 rz2: ColRef<'a, T, Dim<'M>, ContiguousFwd>,
58
59 u0: ColRef<'a, T, Dim<'M>, ContiguousFwd>,
60 u1: ColRef<'a, T, Dim<'N>>,
61 u2: ColRef<'a, T, Dim<'N>>,
62 v2: ColRef<'a, T, Dim<'M>, ContiguousFwd>,
63
64 f: T,
65 align: usize,
66 }
67
68 impl<'a, 'M, 'N, T: ComplexField> pulp::WithSimd for Impl<'a, 'M, 'N, T> {
69 type Output = ();
70
71 #[inline(always)]
72 fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
73 let Self {
74 mut A,
75 mut y2,
76 mut z2,
77 ry2,
78 rz2,
79 u0,
80 u1,
81 u2,
82 v2,
83 f,
84 mut align,
85 } = self;
86
87 let simd = T::simd_ctx(simd);
88 let (m, n) = A.shape();
89 {
90 let simd = SimdCtx::<T, S>::new_align(simd, m, align);
91 let (head, body, tail) = simd.indices();
92
93 if let Some(i0) = head {
94 simd.write(z2.rb_mut(), i0, simd.zero());
95 }
96 for i0 in body {
97 simd.write(z2.rb_mut(), i0, simd.zero());
98 }
99 if let Some(i0) = tail {
100 simd.write(z2.rb_mut(), i0, simd.zero());
101 }
102 }
103
104 for j in n.indices() {
105 let i = m.idx_inc(*j);
106 with_dim!(m, *m - *j);
107
108 let simd = SimdCtx::<T, S>::new_align(simd, m, align);
109 align -= 1;
110
111 let mut A = A.rb_mut().col_mut(j).subrows_mut(i, m);
112
113 let mut z = z2.rb_mut().subrows_mut(i, m);
114 let rz = rz2.subrows(i, m);
115 let ua = u0.subrows(i, m);
116 let v = v2.subrows(i, m);
117
118 let y = y2.rb_mut().at_mut(j);
119 let ry = simd.splat(&(-ry2[j]));
120 let ub = simd.splat(&(-u1[j]));
121 let uc = simd.splat(&(f * u2[j]));
122
123 let mut acc0 = simd.zero();
124 let mut acc1 = simd.zero();
125 let mut acc2 = simd.zero();
126 let mut acc3 = simd.zero();
127
128 let (head, body4, body1, tail) = simd.batch_indices::<4>();
129 if let Some(i0) = head {
130 let mut a = simd.read(A.rb(), i0);
131 a = simd.conj_mul_add(ry, simd.read(ua, i0), a);
132 a = simd.conj_mul_add(ub, simd.read(rz, i0), a);
133 simd.write(A.rb_mut(), i0, a);
134
135 let tmp = simd.read(z.rb(), i0);
136 simd.write(z.rb_mut(), i0, simd.mul_add(a, uc, tmp));
137
138 acc0 = simd.conj_mul_add(a, simd.read(v, i0), acc0);
139 }
140
141 for [i0, i1, i2, i3] in body4 {
142 {
143 let mut a = simd.read(A.rb(), i0);
144 a = simd.conj_mul_add(ry, simd.read(ua, i0), a);
145 a = simd.conj_mul_add(ub, simd.read(rz, i0), a);
146 simd.write(A.rb_mut(), i0, a);
147
148 let tmp = simd.read(z.rb(), i0);
149 simd.write(z.rb_mut(), i0, simd.mul_add(a, uc, tmp));
150
151 acc0 = simd.conj_mul_add(a, simd.read(v, i0), acc0);
152 }
153 {
154 let mut a = simd.read(A.rb(), i1);
155 a = simd.conj_mul_add(ry, simd.read(ua, i1), a);
156 a = simd.conj_mul_add(ub, simd.read(rz, i1), a);
157 simd.write(A.rb_mut(), i1, a);
158
159 let tmp = simd.read(z.rb(), i1);
160 simd.write(z.rb_mut(), i1, simd.mul_add(a, uc, tmp));
161
162 acc1 = simd.conj_mul_add(a, simd.read(v, i1), acc1);
163 }
164 {
165 let mut a = simd.read(A.rb(), i2);
166 a = simd.conj_mul_add(ry, simd.read(ua, i2), a);
167 a = simd.conj_mul_add(ub, simd.read(rz, i2), a);
168 simd.write(A.rb_mut(), i2, a);
169
170 let tmp = simd.read(z.rb(), i2);
171 simd.write(z.rb_mut(), i2, simd.mul_add(a, uc, tmp));
172
173 acc2 = simd.conj_mul_add(a, simd.read(v, i2), acc2);
174 }
175 {
176 let mut a = simd.read(A.rb(), i3);
177 a = simd.conj_mul_add(ry, simd.read(ua, i3), a);
178 a = simd.conj_mul_add(ub, simd.read(rz, i3), a);
179 simd.write(A.rb_mut(), i3, a);
180
181 let tmp = simd.read(z.rb(), i3);
182 simd.write(z.rb_mut(), i3, simd.mul_add(a, uc, tmp));
183
184 acc3 = simd.conj_mul_add(a, simd.read(v, i3), acc3);
185 }
186 }
187 for i0 in body1 {
188 let mut a = simd.read(A.rb(), i0);
189 a = simd.conj_mul_add(ry, simd.read(ua, i0), a);
190 a = simd.conj_mul_add(ub, simd.read(rz, i0), a);
191 simd.write(A.rb_mut(), i0, a);
192
193 let tmp = simd.read(z.rb(), i0);
194 simd.write(z.rb_mut(), i0, simd.mul_add(a, uc, tmp));
195
196 acc0 = simd.conj_mul_add(a, simd.read(v, i0), acc0);
197 }
198 if let Some(i0) = tail {
199 let mut a = simd.read(A.rb(), i0);
200 a = simd.conj_mul_add(ry, simd.read(ua, i0), a);
201 a = simd.conj_mul_add(ub, simd.read(rz, i0), a);
202 simd.write(A.rb_mut(), i0, a);
203
204 let tmp = simd.read(z.rb(), i0);
205 simd.write(z.rb_mut(), i0, simd.mul_add(a, uc, tmp));
206
207 acc0 = simd.conj_mul_add(a, simd.read(v, i0), acc0);
208 }
209
210 acc0 = simd.add(acc0, acc1);
211 acc2 = simd.add(acc2, acc3);
212 acc0 = simd.add(acc0, acc2);
213
214 let acc0 = simd.reduce_sum(acc0);
215 let i0 = m.idx(0);
216 *y = f * (acc0 - A[i0] * v[i0]);
217 }
218 }
219 }
220
221 with_dim!(M, A.nrows());
222 with_dim!(N, A.ncols());
223
224 dispatch!(
225 Impl {
226 A: A.as_shape_mut(M, N),
227 y2: y2.as_row_shape_mut(N),
228 z2: z2.as_row_shape_mut(M),
229 ry2: ry2.as_row_shape(N),
230 rz2: rz2.as_row_shape(M),
231 u0: u0.as_row_shape(M),
232 u1: u1.as_row_shape(N),
233 u2: u2.as_row_shape(N),
234 v2: v2.as_row_shape(M),
235 f,
236 align,
237 },
238 Impl,
239 T
240 )
241}
242
243#[math]
244fn tridiag_fused_op<T: ComplexField>(
245 A: MatMut<'_, T>,
246 y2: ColMut<'_, T>,
247 z2: ColMut<'_, T>,
248
249 ry2: ColRef<'_, T>,
250 rz2: ColRef<'_, T>,
251
252 u0: ColRef<'_, T>,
253 u1: ColRef<'_, T>,
254 u2: ColRef<'_, T>,
255 v2: ColRef<'_, T>,
256
257 f: T,
258 align: usize,
259) {
260 let mut A = A;
261 let mut z2 = z2;
262
263 if try_const! { T::SIMD_CAPABILITIES.is_simd() } {
264 if let (Some(A), Some(z2), Some(rz2), Some(u0), Some(v2)) = (
265 A.rb_mut().try_as_col_major_mut(),
266 z2.rb_mut().try_as_col_major_mut(),
267 rz2.try_as_col_major(),
268 u0.try_as_col_major(),
269 v2.try_as_col_major(),
270 ) {
271 tridiag_fused_op_simd(A, y2, z2, ry2, rz2, u0, u1, u2, v2, f, align);
272 } else {
273 tridiag_fused_op_fallback(A, y2, z2, ry2, rz2, u0, u1, u2, v2, f);
274 }
275 } else {
276 tridiag_fused_op_fallback(A, y2, z2, ry2, rz2, u0, u1, u2, v2, f);
277 }
278}
279
280#[math]
281fn tridiag_fused_op_fallback<T: ComplexField>(
282 A: MatMut<'_, T>,
283 y2: ColMut<'_, T>,
284 z2: ColMut<'_, T>,
285
286 ry2: ColRef<'_, T>,
287 rz2: ColRef<'_, T>,
288
289 u0: ColRef<'_, T>,
290 u1: ColRef<'_, T>,
291 u2: ColRef<'_, T>,
292 v2: ColRef<'_, T>,
293
294 f: T,
295) {
296 let par = Par::Seq;
297
298 let mut A = A;
299 let mut y2 = y2;
300
301 let n = A.ncols();
302
303 let (mut A0, mut A1) = A.rb_mut().split_at_row_mut(n);
304 let (u00, u01) = u0.split_at_row(n);
305 let (v20, v21) = v2.split_at_row(n);
306 let (mut z20, mut z21) = z2.split_at_row_mut(n);
307
308 let (rz20, rz21) = rz2.split_at_row(n);
309
310 matmul::triangular::matmul(
311 A0.rb_mut(),
312 BlockStructure::TriangularLower,
313 Accum::Add,
314 u00,
315 BlockStructure::Rectangular,
316 ry2.adjoint(),
317 BlockStructure::Rectangular,
318 -one::<T>(),
319 par,
320 );
321 matmul::triangular::matmul(
322 A0.rb_mut(),
323 BlockStructure::TriangularLower,
324 Accum::Add,
325 rz20,
326 BlockStructure::Rectangular,
327 u1.adjoint(),
328 BlockStructure::Rectangular,
329 -one::<T>(),
330 par,
331 );
332 matmul::matmul(A1.rb_mut(), Accum::Add, u01, ry2.adjoint(), -one::<T>(), par);
333 matmul::matmul(A1.rb_mut(), Accum::Add, rz21, u1.adjoint(), -one::<T>(), par);
334
335 matmul::triangular::matmul(
336 z20.rb_mut(),
337 BlockStructure::Rectangular,
338 Accum::Replace,
339 A0.rb(),
340 BlockStructure::TriangularLower,
341 u2,
342 BlockStructure::Rectangular,
343 f.clone(),
344 par,
345 );
346 matmul::triangular::matmul(
347 y2.rb_mut(),
348 BlockStructure::Rectangular,
349 Accum::Replace,
350 A0.rb().adjoint(),
351 BlockStructure::StrictTriangularUpper,
352 v20,
353 BlockStructure::Rectangular,
354 f.clone(),
355 par,
356 );
357
358 matmul::matmul(z21.rb_mut(), Accum::Replace, A1.rb(), u2, f.clone(), par);
359 matmul::matmul(y2.rb_mut(), Accum::Add, A1.rb().adjoint(), v21, f.clone(), par);
360}
361
362#[math]
369pub fn tridiag_in_place<T: ComplexField>(
370 A: MatMut<'_, T>,
371 householder: MatMut<'_, T>,
372 par: Par,
373 stack: &mut MemStack,
374 params: Spec<TridiagParams, T>,
375) {
376 let params = params.config;
377 let mut A = A;
378 let mut H = householder;
379 let mut par = par;
380 let n = A.nrows();
381 let b = H.nrows();
382
383 assert!(H.ncols() == n.saturating_sub(1));
384
385 if n == 0 {
386 return;
387 }
388
389 let (mut y, stack) = unsafe { temp_mat_uninit(n, 1, stack) };
390 let (mut w, stack) = unsafe { temp_mat_uninit(n, 1, stack) };
391 let (mut z, _) = unsafe { temp_mat_uninit(n, par.degree(), stack) };
392 let mut y = y.as_mat_mut().col_mut(0);
393 let mut w = w.as_mat_mut().col_mut(0);
394 let mut z = z.as_mat_mut();
395
396 {
397 let mut H = H.rb_mut().row_mut(0);
398 for k in 0..n {
399 let (_, A01, A10, A11) = A.rb_mut().split_at_mut(k, k);
400
401 let (_, _) = A01.split_first_col().unwrap();
402 let (_, A20) = A10.split_first_row_mut().unwrap();
403 let (mut A11, _, A21, mut A22) = A11.split_at_mut(1, 1);
404
405 let mut A21 = A21.col_mut(0);
406
407 let a11 = &mut A11[(0, 0)];
408
409 let (y1, mut y2) = y.rb_mut().split_at_row_mut(k).1.split_at_row_mut(1);
410 let y1 = copy(y1[0]);
411
412 if k > 0 {
413 let p = k - 1;
414
415 let u2 = (A20.rb()).col(p);
416
417 *a11 = *a11 - y1 - conj(y1);
418
419 z!(A21.rb_mut(), u2, y2.rb()).for_each(|uz!(a, u, y)| {
420 *a = *a - conj(y1) * *u - *y;
421 });
422 }
423
424 if k + 1 == n {
425 break;
426 }
427
428 let rem = n - k - 1;
429 if rem * rem / 2 < params.par_threshold {
430 par = Par::Seq;
431 }
432
433 let k1 = k + 1;
434
435 let tau_inv;
436 {
437 let (mut a11, mut x2) = A21.rb_mut().split_at_row_mut(1);
438 let a11 = &mut a11[0];
439
440 let householder::HouseholderInfo { tau, .. } = householder::make_householder_in_place(a11, x2.rb_mut());
441
442 tau_inv = recip(real(tau));
443 H[k] = from_real(tau);
444
445 let mut z2 = z.rb_mut().split_at_row_mut(k + 2).1;
446 let mut w2 = w.rb_mut().split_at_row_mut(k + 2).1;
447
448 let (mut y1, mut y2) = y2.rb_mut().split_at_row_mut(1);
449 let y1 = &mut y1[0];
450
451 let (A1, A2) = A22.rb_mut().split_at_row_mut(1);
452 let A1 = A1.row_mut(0);
453
454 let (mut a11, _) = A1.split_at_col_mut(1);
455 let a11 = &mut a11[0];
456
457 let (A21, mut A22) = A2.split_at_col_mut(1);
458 let mut A21 = A21.col_mut(0);
459
460 if k > 0 {
461 let p = k - 1;
462
463 let (u1, u2) = (A20.rb()).col(p).split_at_row(1);
464 let u1 = copy(u1[0]);
465
466 *a11 = *a11 - u1 * conj(y1) - *y1 * conj(u1);
467
468 z!(A21.rb_mut(), u2.rb(), y2.rb()).for_each(|uz!(a, u, y)| {
469 *a = *a - *u * conj(y1) - *y * conj(u1);
470 });
471
472 w2.copy_from(y2.rb());
473
474 match par {
475 Par::Seq => {
476 let mut z2 = z2.rb_mut().col_mut(0);
477 tridiag_fused_op(
478 A22.rb_mut(),
479 y2.rb_mut(),
480 z2.rb_mut(),
481 w2.rb(),
482 w2.rb(),
483 u2.rb(),
484 u2.rb(),
485 x2.rb(),
486 x2.rb(),
487 from_real(tau_inv),
488 simd_align(k1 + 1),
489 );
490 z!(y2.rb_mut(), z2.rb_mut()).for_each(|uz!(y, z)| *y = *y + *z);
491 },
492 #[cfg(feature = "rayon")]
493 Par::Rayon(nthreads) => {
494 use rayon::prelude::*;
495 let nthreads = nthreads.get();
496 let mut z2 = z2.rb_mut().subcols_mut(0, nthreads);
497
498 let n2 = A22.ncols();
499 assert!((n2 as u64) < (1u64 << 50)); let idx_to_col_start = |idx: usize| {
503 let idx_as_percent = idx as f64 / nthreads as f64;
504 let col_start_percent = 1.0f64 - libm::sqrt(1.0f64 - idx_as_percent);
505 (col_start_percent * n2 as f64) as usize
506 };
507
508 {
509 let A22 = A22.rb();
510 let y2 = y2.rb();
511
512 let f = from_real(tau_inv);
513 z2.rb_mut().par_col_iter_mut().enumerate().for_each(|(idx, mut z2)| {
514 let first = idx_to_col_start(idx);
515 let last_col = idx_to_col_start(idx + 1);
516 let nrows = n2 - first;
517 let ncols = last_col - first;
518
519 let mut A = unsafe { A22.rb().subcols(first, ncols).subrows(first, nrows).const_cast() };
520
521 {
522 let y2 = unsafe { y2.subrows(first, ncols).const_cast() };
523 let mut z2 = z2.rb_mut().subrows_mut(first, nrows);
524
525 let ry2 = w2.rb().subrows(first, ncols);
526 let rz2 = w2.rb().subrows(first, nrows);
527
528 let u0 = u2.subrows(first, nrows);
529 let u1 = u2.subrows(first, ncols);
530 let u2 = x2.rb().subrows(first, ncols);
531 let v2 = x2.rb().subrows(first, nrows);
532
533 tridiag_fused_op(
534 A.rb_mut(),
535 y2,
536 z2.rb_mut(),
537 ry2,
538 rz2,
539 u0,
540 u1,
541 u2,
542 v2,
543 copy(f),
544 n.next_power_of_two() - (k1 + 1) - first,
545 );
546 }
547
548 z2.rb_mut().subrows_mut(0, first).fill(zero());
549 });
550 }
551
552 for mut z2 in z2.rb_mut().col_iter_mut() {
553 z!(y2.rb_mut(), z2.rb_mut()).for_each(|uz!(y, z)| *y = *y + *z);
554 }
555 },
556 }
557 } else {
558 matmul::triangular::matmul(
559 y2.rb_mut(),
560 BlockStructure::Rectangular,
561 Accum::Replace,
562 A22.rb(),
563 BlockStructure::TriangularLower,
564 x2.rb(),
565 BlockStructure::Rectangular,
566 from_real(tau_inv),
567 par,
568 );
569 matmul::triangular::matmul(
570 y2.rb_mut(),
571 BlockStructure::Rectangular,
572 Accum::Add,
573 A22.rb().adjoint(),
574 BlockStructure::StrictTriangularUpper,
575 x2.rb(),
576 BlockStructure::Rectangular,
577 from_real(tau_inv),
578 par,
579 );
580 }
581
582 z!(y2.rb_mut(), A21.rb()).for_each(|uz!(y, a)| *y = *y + mul_real(*a, tau_inv));
583
584 *y1 = mul_real(*a11 + dot::inner_prod(A21.rb().transpose(), Conj::Yes, x2.rb(), Conj::No), tau_inv);
585
586 let b = mul_real(
587 mul_pow2(*y1 + dot::inner_prod(x2.rb().transpose(), Conj::Yes, y2.rb(), Conj::No), from_f64(0.5)),
588 tau_inv,
589 );
590 *y1 = *y1 - b;
591 z!(y2.rb_mut(), x2.rb()).for_each(|uz!(y, u)| {
592 *y = *y - b * *u;
593 });
594 }
595 }
596 }
597
598 if n > 0 {
599 let n = n - 1;
600 let A = A.rb().submatrix(1, 0, n, n);
601 let mut H = H.rb_mut().subcols_mut(0, n);
602
603 let mut j = 0;
604 while j < n {
605 let b = Ord::min(b, n - j);
606
607 let mut H = H.rb_mut().submatrix_mut(0, j, b, b);
608
609 for k in 0..b {
610 H[(k, k)] = copy(H[(0, k)]);
611 }
612
613 householder::upgrade_householder_factor(H.rb_mut(), A.submatrix(j, j, n - j, b), b, 1, par);
614 j += b;
615 }
616 }
617}
618
#[cfg(test)]
mod tests {
	use super::*;
	use crate::stats::prelude::*;
	use crate::utils::approx::*;
	use crate::{Mat, assert, c64};
	use dyn_stack::MemBuffer;

	/// Tridiagonalizes random real symmetric matrices, applies the resulting Householder
	/// sequence to the original matrix from both sides and compares the result with the
	/// tridiagonal part stored in the factorized matrix.
	#[test]
	fn test_tridiag_real() {
		let rng = &mut StdRng::seed_from_u64(0);

		for n in [2, 3, 4, 8, 16] {
			let A = CwiseMatDistribution {
				nrows: n,
				ncols: n,
				dist: StandardNormal,
			}
			.rand::<Mat<f64>>(rng);

			// Symmetrize the random matrix.
			let A = A.rb() + A.adjoint();

			let b = 3;
			let mut H = Mat::zeros(b, n - 1);

			let mut V = A.clone();
			let mut V = V.as_mut();
			tridiag_in_place(
				V.rb_mut(),
				H.rb_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(StackReq::all_of(&[
					householder::apply_block_householder_sequence_transpose_on_the_left_in_place_scratch::<f64>(n - 1, b, n),
					tridiag_in_place_scratch::<f64>(n, Par::Seq, default()),
				]))),
				default(),
			);

			let mut A = A.clone();
			let mut A = A.as_mut();

			// First pass acts on `A` from the left; the second pass acts on its transpose,
			// i.e. on `A` from the right.
			for iter in 0..2 {
				let mut A = if iter == 0 { A.rb_mut() } else { A.rb_mut().transpose_mut() };

				let n = n - 1;

				let V = V.rb().submatrix(1, 0, n, n);
				let mut A = A.rb_mut().subrows_mut(1, n);
				let H = H.as_ref();

				householder::apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj(
					V,
					H.as_ref(),
					if iter == 0 { Conj::Yes } else { Conj::No },
					A.rb_mut(),
					Par::Seq,
					MemStack::new(&mut MemBuffer::new(
						householder::apply_block_householder_sequence_transpose_on_the_left_in_place_scratch::<f64>(n, b, n + 1),
					)),
				);
			}

			let approx_eq = CwiseMat(ApproxEq::<f64>::eps());
			// Keep only the tridiagonal band of `V`; everything else holds reflector data.
			for j in 0..n {
				for i in 0..n {
					if i > j + 1 || j > i + 1 {
						V[(i, j)] = 0.0;
					}
				}
			}
			// Mirror the subdiagonal onto the superdiagonal.
			for i in 0..n {
				if i + 1 < n {
					V[(i, i + 1)] = V[(i + 1, i)];
				}
			}

			assert!(V ~ A);
		}
	}

	/// Same check as `test_tridiag_real`, for random complex Hermitian matrices.
	#[test]
	fn test_tridiag_cplx() {
		let rng = &mut StdRng::seed_from_u64(0);

		for n in [2, 3, 4, 8, 16] {
			let A = CwiseMatDistribution {
				nrows: n,
				ncols: n,
				dist: ComplexDistribution::new(StandardNormal, StandardNormal),
			}
			.rand::<Mat<c64>>(rng);

			// Make the random matrix Hermitian.
			let A = A.rb() + A.adjoint();

			let b = 3;
			let mut H = Mat::zeros(b, n - 1);

			let mut V = A.clone();
			let mut V = V.as_mut();
			tridiag_in_place(
				V.rb_mut(),
				H.as_mut(),
				Par::Seq,
				MemStack::new(&mut MemBuffer::new(tridiag_in_place_scratch::<c64>(n, Par::Seq, default()))),
				default(),
			);

			let mut A = A.clone();
			let mut A = A.as_mut();

			// First pass acts on `A` from the left; the second pass acts on its transpose,
			// i.e. on `A` from the right.
			for iter in 0..2 {
				let mut A = if iter == 0 { A.rb_mut() } else { A.rb_mut().transpose_mut() };

				let n = n - 1;

				let V = V.rb().submatrix(1, 0, n, n);
				let mut A = A.rb_mut().subrows_mut(1, n);
				let H = H.as_ref();

				householder::apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj(
					V,
					H.as_ref(),
					if iter == 0 { Conj::Yes } else { Conj::No },
					A.rb_mut(),
					Par::Seq,
					MemStack::new(&mut MemBuffer::new(
						householder::apply_block_householder_sequence_transpose_on_the_left_in_place_scratch::<c64>(n, b, n + 1),
					)),
				);
			}

			let approx_eq = CwiseMat(ApproxEq::eps());
			// Keep only the tridiagonal band of `V`; everything else holds reflector data.
			for j in 0..n {
				for i in 0..n {
					if i > j + 1 || j > i + 1 {
						V[(i, j)] = c64::ZERO;
					}
				}
			}
			// Mirror the subdiagonal onto the superdiagonal, conjugated.
			for i in 0..n {
				if i + 1 < n {
					V[(i, i + 1)] = V[(i + 1, i)].conj();
				}
			}

			assert!(V ~ A);
		}
	}
}