1use std::sync::Arc;
37
38use faer::Mat;
39
40use super::super::{
41 advecator::Advector,
42 algos::ht::HtTensor,
43 algos::lagrangian::sl_shift_1d_into,
44 integrator::{StepProducts, StepTimings, TimeIntegrator},
45 phasespace::PhaseSpaceRepr,
46 progress::{StepPhase, StepProgress},
47 solver::PoissonSolver,
48 types::*,
49};
50use super::helpers;
51use crate::CausticError;
52
/// Tuning knobs for the BUG low-rank integrator.
pub struct BugConfig {
    /// SVD truncation tolerance: singular values whose tail energy stays
    /// below `tolerance^2` are dropped (see `truncation_rank`).
    pub tolerance: f64,
    /// Hard upper bound on the rank kept for any leaf frame after a K-step.
    pub max_rank: usize,
    /// When true, `advance` uses the midpoint BUG variant
    /// (`midpoint_bug_step`) instead of the plain step.
    pub midpoint: bool,
    /// When true, the root transfer tensor is rescaled after each step so
    /// total mass matches the pre-step density (`conservative_correction`).
    pub conservative: bool,
    /// Number of extra displacement samples used to augment the basis during
    /// a K-step; 0 disables augmentation (see `sample_aug_displacements`).
    pub rank_increase: usize,
}
66
67impl Default for BugConfig {
68 fn default() -> Self {
69 Self {
70 tolerance: 1e-8,
71 max_rank: 50,
72 midpoint: false,
73 conservative: false,
74 rank_increase: 2,
75 }
76 }
77}
78
/// For each leaf dimension 0..6, the index of the transfer-tensor node that
/// consumes that leaf's frame and whether the leaf occupies the node's left
/// (`true`) or right (`false`) child slot. Consumed by `update_transfer`.
/// NOTE(review): the node indices (6..9) encode the dimension-tree layout of
/// `HtTensor` — confirm against the tree construction if the tree changes.
pub(crate) const LEAF_PARENT: [(usize, bool); 6] = [
    (8, true), (6, true), (6, false), (9, true), (7, true), (7, false), ];
88
/// Semi-Lagrangian K-step for a single leaf frame.
///
/// Shifts every column of the leaf's basis frame by `displacement` (and, when
/// `aug_displacements` is non-empty, by each augmenting displacement as well),
/// then re-orthonormalizes via QR. In the augmented case the combined basis is
/// truncated through an SVD of the R factor to at most
/// `min(k + aug_displacements.len(), max_rank)` columns, with `tolerance`
/// deciding how many singular values survive (see `truncation_rank`).
///
/// Returns `(Q, R)`: the new orthonormal frame `Q` and the matrix `R` mapping
/// the original `k` columns into the new basis (fed to `update_transfer`).
pub(crate) fn k_step_leaf(
    ht: &HtTensor,
    leaf_dim: usize,
    displacement: f64,
    aug_displacements: &[f64],
    max_rank: usize,
    tolerance: f64,
) -> (Mat<f64>, Mat<f64>) {
    let frame = ht.leaf_frame(leaf_dim);
    let (n, k) = (frame.nrows(), frame.ncols());
    // Leaves 0..3 are spatial axes, 3..6 are velocity axes.
    let is_spatial = leaf_dim < 3;
    let dim_idx = if is_spatial { leaf_dim } else { leaf_dim - 3 };

    // Grid geometry and boundary flag for the 1-D shift.
    let (cell_size, half_extent, periodic) = if is_spatial {
        let dx = ht.domain.dx();
        let lx = ht.domain.lx();
        let per = matches!(
            ht.domain.spatial_bc,
            super::super::init::domain::SpatialBoundType::Periodic
        );
        (dx[dim_idx], lx[dim_idx], per)
    } else {
        let dv = ht.domain.dv();
        let lv = ht.domain.lv();
        // NOTE(review): for velocity leaves the `periodic` flag is set when the
        // bound type is `Truncated` — confirm `sl_shift_1d_into` interprets the
        // flag that way for velocity space.
        let per = matches!(
            ht.domain.velocity_bc,
            super::super::init::domain::VelocityBoundType::Truncated
        );
        (dv[dim_idx], lv[dim_idx], per)
    };

    // Shift each frame column by the primary displacement, reusing buffers.
    let mut shifted = Mat::<f64>::zeros(n, k);
    let mut col_buf = vec![0.0f64; n];
    let mut out_buf = vec![0.0f64; n];

    for j in 0..k {
        for i in 0..n {
            col_buf[i] = frame[(i, j)];
        }
        sl_shift_1d_into(
            &col_buf,
            displacement,
            cell_size,
            n,
            half_extent,
            periodic,
            &mut out_buf,
        );
        for i in 0..n {
            shifted[(i, j)] = out_buf[i];
        }
    }

    // No augmentation: plain re-orthonormalization, no rank change.
    let n_aug = aug_displacements.len();
    if n_aug == 0 {
        let (q, r) = qr_thin(&shifted);
        return (q, r);
    }

    // Augmented basis: [shifted | frame shifted by each extra displacement].
    let total_cols = k + n_aug * k;
    let mut augmented = Mat::<f64>::zeros(n, total_cols);
    for j in 0..k {
        for i in 0..n {
            augmented[(i, j)] = shifted[(i, j)];
        }
    }
    for (s, &disp) in aug_displacements.iter().enumerate() {
        for j in 0..k {
            for i in 0..n {
                col_buf[i] = frame[(i, j)];
            }
            sl_shift_1d_into(
                &col_buf,
                disp,
                cell_size,
                n,
                half_extent,
                periodic,
                &mut out_buf,
            );
            for i in 0..n {
                augmented[(i, k + s * k + j)] = out_buf[i];
            }
        }
    }

    let (q_aug, r_aug) = qr_thin(&augmented);

    // Rank may grow by at most one per augmenting displacement, capped by
    // `max_rank` and the number of columns QR actually produced.
    let target_rank = (k + n_aug).min(max_rank).min(q_aug.ncols());
    let (u, sv, _vt) = svd_thin(&r_aug);
    if u.ncols() == 0 {
        // Degenerate SVD: fall back to the untruncated augmented basis.
        return (q_aug, r_aug.subcols(0, k).to_owned());
    }
    let rank = truncation_rank(&sv, tolerance)
        .max(1)
        .min(target_rank)
        .min(u.ncols());

    // New frame: rotate Q_aug onto the dominant left singular directions.
    let u_trunc = u.subcols(0, rank);
    let q_trunc = &q_aug * u_trunc;

    // R maps the original k columns into the truncated basis.
    let r_aug_left = r_aug.subcols(0, k);
    let r_trunc = u_trunc.transpose() * r_aug_left;

    (q_trunc.to_owned(), r_trunc.to_owned())
}
212
213pub(crate) fn update_transfer(ht: &mut HtTensor, leaf_dim: usize, r_matrix: &Mat<f64>) {
218 let (parent_idx, is_left) = LEAF_PARENT[leaf_dim];
219 let (transfer, ranks) = ht.transfer_tensor(parent_idx);
220 let [kp, kl, kr] = ranks;
221 let k_new = r_matrix.nrows();
222 let k_old = r_matrix.ncols();
223
224 if is_left {
225 assert_eq!(k_old, kl, "R cols must match old left rank");
226 let mut new_data = vec![0.0f64; kp * k_new * kr];
227 for p in 0..kp {
228 for l_new in 0..k_new {
229 for r in 0..kr {
230 let mut sum = 0.0;
231 for l in 0..kl {
232 sum += r_matrix[(l_new, l)] * transfer[p * kl * kr + l * kr + r];
233 }
234 new_data[p * k_new * kr + l_new * kr + r] = sum;
235 }
236 }
237 }
238 ht.set_transfer_tensor(parent_idx, new_data, [kp, k_new, kr]);
239 } else {
240 assert_eq!(k_old, kr, "R cols must match old right rank");
241 let mut new_data = vec![0.0f64; kp * kl * k_new];
242 for p in 0..kp {
243 for l in 0..kl {
244 for r_new in 0..k_new {
245 let mut sum = 0.0;
246 for r in 0..kr {
247 sum += r_matrix[(r_new, r)] * transfer[p * kl * kr + l * kr + r];
248 }
249 new_data[p * kl * k_new + l * k_new + r_new] = sum;
250 }
251 }
252 }
253 ht.set_transfer_tensor(parent_idx, new_data, [kp, kl, k_new]);
254 }
255}
256
257pub(crate) fn representative_velocities(ht: &HtTensor, vel_dim: usize) -> Vec<f64> {
259 let v_frame = ht.leaf_frame(vel_dim);
260 let (nv, kv) = (v_frame.nrows(), v_frame.ncols());
261 let dim_idx = vel_dim - 3;
262 let dv = ht.domain.dv();
263 let lv = ht.domain.lv();
264
265 (0..kv)
266 .map(|l| {
267 let mut wt_sum = 0.0f64;
268 let mut v_sum = 0.0f64;
269 for i in 0..nv {
270 let v = -lv[dim_idx] + (i as f64 + 0.5) * dv[dim_idx];
271 let w = v_frame[(i, l)] * v_frame[(i, l)];
272 v_sum += w * v;
273 wt_sum += w;
274 }
275 if wt_sum > 1e-30 { v_sum / wt_sum } else { 0.0 }
276 })
277 .collect()
278}
279
/// Per-column representative acceleration for a spatial leaf.
///
/// For each column `j` of the spatial leaf frame, the acceleration component
/// for this axis is averaged over the two *other* spatial axes at each cell
/// `i`, then those per-cell averages are combined with weights equal to the
/// squared frame amplitude `x_frame[(i, j)]^2`. Returns one value per column
/// (0.0 when the total weight is negligible).
pub(crate) fn representative_accelerations(
    ht: &HtTensor,
    spatial_dim: usize,
    accel: &AccelerationField,
) -> Vec<f64> {
    let x_frame = ht.leaf_frame(spatial_dim);
    let (nx_dim, kx) = (x_frame.nrows(), x_frame.ncols());
    let [nx1, nx2, nx3, _, _, _] = ht.shape;

    // Acceleration component matching this spatial axis.
    let accel_data = match spatial_dim {
        0 => &accel.gx,
        1 => &accel.gy,
        2 => &accel.gz,
        _ => unreachable!(),
    };

    (0..kx)
        .map(|j| {
            let mut wt_sum = 0.0f64;
            let mut a_sum = 0.0f64;
            for i in 0..nx_dim {
                // Weight: squared amplitude of the frame at this cell.
                let w = x_frame[(i, j)] * x_frame[(i, j)];
                let mut a_avg = 0.0f64;
                // Sum the field over the other two axes. Flat index layout is
                // row-major (ix1, ix2, ix3) -> ix1*nx2*nx3 + ix2*nx3 + ix3.
                let n_other: usize = match spatial_dim {
                    0 => {
                        for ix2 in 0..nx2 {
                            for ix3 in 0..nx3 {
                                a_avg += accel_data[i * nx2 * nx3 + ix2 * nx3 + ix3];
                            }
                        }
                        nx2 * nx3
                    }
                    1 => {
                        for ix1 in 0..nx1 {
                            for ix3 in 0..nx3 {
                                a_avg += accel_data[ix1 * nx2 * nx3 + i * nx3 + ix3];
                            }
                        }
                        nx1 * nx3
                    }
                    2 => {
                        for ix1 in 0..nx1 {
                            for ix2 in 0..nx2 {
                                a_avg += accel_data[ix1 * nx2 * nx3 + ix2 * nx3 + i];
                            }
                        }
                        nx1 * nx2
                    }
                    _ => unreachable!(),
                };
                a_avg /= n_other as f64;
                a_sum += w * a_avg;
                wt_sum += w;
            }
            // Guard against division by ~zero weight.
            if wt_sum > 1e-30 { a_sum / wt_sum } else { 0.0 }
        })
        .collect()
}
340
/// Picks up to `rank_increase` augmenting displacements from the scaled
/// representatives: first the largest shift, then the smallest, then
/// quantile samples of the sorted shifts. Returns an empty vector when
/// augmentation is disabled or there are no representatives.
pub(crate) fn sample_aug_displacements(
    representatives: &[f64],
    dt: f64,
    rank_increase: usize,
) -> Vec<f64> {
    if rank_increase == 0 || representatives.is_empty() {
        return Vec::new();
    }
    // Scale representatives into displacements and sort ascending.
    let mut shifts: Vec<f64> = representatives.iter().map(|&v| v * dt).collect();
    shifts.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let count = shifts.len();

    let mut picked = Vec::with_capacity(rank_increase);
    // Always include the largest shift.
    picked.push(shifts[count - 1]);
    // Add the smallest when a second slot exists and there are >= 2 entries.
    if rank_increase >= 2 && count >= 2 {
        picked.push(shifts[0]);
    }
    // Fill remaining slots with quantiles of the sorted shifts.
    for slot in 2..rank_increase {
        let frac = slot as f64 / (rank_increase - 1) as f64;
        let pos = ((count as f64 - 1.0) * frac) as usize;
        picked.push(shifts[pos.min(count - 1)]);
    }
    picked
}
367
368pub(crate) fn bug_drift_substep(ht: &mut HtTensor, dt: f64, config: &BugConfig) {
370 for d in 0..3 {
371 let reps = representative_velocities(ht, d + 3);
372 let primary = if reps.is_empty() {
373 0.0
374 } else {
375 reps.iter().sum::<f64>() / reps.len() as f64 * dt
376 };
377 let aug = sample_aug_displacements(&reps, dt, config.rank_increase);
378 let (new_frame, r_mat) =
379 k_step_leaf(ht, d, primary, &aug, config.max_rank, config.tolerance);
380 *ht.leaf_frame_mut(d) = new_frame;
381 update_transfer(ht, d, &r_mat);
382 }
383}
384
385pub(crate) fn bug_kick_substep(
387 ht: &mut HtTensor,
388 accel: &AccelerationField,
389 dt: f64,
390 config: &BugConfig,
391) {
392 for d in 3..6 {
393 let reps = representative_accelerations(ht, d - 3, accel);
394 let primary = if reps.is_empty() {
395 0.0
396 } else {
397 reps.iter().sum::<f64>() / reps.len() as f64 * dt
398 };
399 let aug = sample_aug_displacements(&reps, dt, config.rank_increase);
400 let (new_frame, r_mat) =
401 k_step_leaf(ht, d, primary, &aug, config.max_rank, config.tolerance);
402 *ht.leaf_frame_mut(d) = new_frame;
403 update_transfer(ht, d, &r_mat);
404 }
405}
406
407pub(crate) fn conservative_correction(ht: &mut HtTensor, density_before: &DensityField) {
409 let density_after = ht.compute_density();
410 let mass_before: f64 = density_before.data.iter().sum();
411 let mass_after: f64 = density_after.data.iter().sum();
412 if mass_before.abs() < 1e-30 || (mass_after - mass_before).abs() < 1e-14 * mass_before.abs() {
413 return;
414 }
415 let scale = mass_before / mass_after;
416 let (transfer, ranks) = ht.transfer_tensor(10); let new_data: Vec<f64> = transfer.iter().map(|&v| v * scale).collect();
418 ht.set_transfer_tensor(10, new_data, ranks);
419}
420
421pub(crate) fn qr_thin(mat: &Mat<f64>) -> (Mat<f64>, Mat<f64>) {
424 let m = mat.nrows();
425 let n = mat.ncols();
426 if m.min(n) == 0 {
427 return (Mat::zeros(m, 0), Mat::zeros(0, n));
428 }
429 let qr = mat.as_ref().qr();
430 (qr.compute_thin_Q(), qr.thin_R().to_owned())
431}
432
433pub(crate) fn svd_thin(mat: &Mat<f64>) -> (Mat<f64>, Vec<f64>, Mat<f64>) {
434 let m = mat.nrows();
435 let n = mat.ncols();
436 let k = m.min(n);
437 if k == 0 {
438 return (Mat::zeros(m, 0), vec![], Mat::zeros(0, n));
439 }
440 let svd = match mat.as_ref().thin_svd() {
441 Ok(s) => s,
442 Err(_) => return (Mat::zeros(m, 0), vec![], Mat::zeros(0, n)),
443 };
444 let u = svd.U().to_owned();
445 let vt = svd.V().transpose().to_owned();
446 let s_diag = svd.S().column_vector();
447 let s: Vec<f64> = (0..k).map(|i| s_diag[i]).collect();
448 (u, s, vt)
449}
450
/// Smallest rank whose discarded tail energy (sum of squared trailing
/// singular values) stays within `eps^2`. Never returns less than 1, even
/// for an empty slice (callers clamp further).
pub(crate) fn truncation_rank(sv: &[f64], eps: f64) -> usize {
    let budget = eps * eps;
    let mut discarded = 0.0f64;
    let mut idx = sv.len();
    while idx > 0 {
        idx -= 1;
        discarded += sv[idx] * sv[idx];
        if discarded > budget {
            // Keeping sv[..=idx] is necessary to stay within budget.
            return idx + 1;
        }
    }
    1
}
462
/// Time integrator implementing the BUG low-rank scheme on hierarchical-
/// Tucker phase-space tensors, with a Strang-splitting fallback for other
/// representations.
pub struct BugIntegrator {
    /// BUG-specific knobs (tolerance, rank caps, midpoint/conservative modes).
    pub config: BugConfig,
    /// Gravitational constant passed through to the Poisson solves.
    pub g: f64,
    // Timings of the most recent step, exposed via `last_step_timings`.
    last_timings: StepTimings,
    // Optional progress reporter; phases are reported only when set.
    progress: Option<Arc<StepProgress>>,
}
478
479impl BugIntegrator {
480 pub fn new(g: f64, config: BugConfig) -> Self {
482 Self {
483 config,
484 g,
485 last_timings: StepTimings::default(),
486 progress: None,
487 }
488 }
489
490 fn strang_fallback(
492 &self,
493 repr: &mut dyn PhaseSpaceRepr,
494 solver: &dyn PoissonSolver,
495 advector: &dyn Advector,
496 dt: f64,
497 timings: &mut StepTimings,
498 ) {
499 helpers::time_ms!(timings, drift_ms, advector.drift(repr, dt / 2.0));
500
501 let (_, _, accel) = helpers::time_ms!(
502 timings,
503 poisson_ms,
504 helpers::solve_poisson(repr, solver, self.g)
505 );
506
507 helpers::time_ms!(timings, kick_ms, advector.kick(repr, &accel, dt));
508
509 helpers::time_ms!(timings, drift_ms, advector.drift(repr, dt / 2.0));
510 }
511
512 fn bug_step_ht(
514 &self,
515 repr: &mut dyn PhaseSpaceRepr,
516 solver: &dyn PoissonSolver,
517 dt: f64,
518 timings: &mut StepTimings,
519 ) {
520 let Some(ht) = repr.as_any_mut().downcast_mut::<HtTensor>() else {
521 debug_assert!(false, "BUG step requires HtTensor");
522 return;
523 };
524
525 let density_before = if self.config.conservative {
526 Some(ht.compute_density())
527 } else {
528 None
529 };
530
531 helpers::report_phase!(self.progress, StepPhase::BugKStep, 0, 4);
532 helpers::time_ms!(
533 timings,
534 drift_ms,
535 bug_drift_substep(ht, dt / 2.0, &self.config)
536 );
537
538 helpers::report_phase!(self.progress, StepPhase::BugLStep, 1, 4);
539 let (_, _, accel) = helpers::time_ms!(
540 timings,
541 poisson_ms,
542 helpers::solve_poisson(ht, solver, self.g)
543 );
544
545 helpers::time_ms!(
546 timings,
547 kick_ms,
548 bug_kick_substep(ht, &accel, dt, &self.config)
549 );
550
551 helpers::time_ms!(
552 timings,
553 drift_ms,
554 bug_drift_substep(ht, dt / 2.0, &self.config)
555 );
556
557 helpers::report_phase!(self.progress, StepPhase::BugSStep, 2, 4);
558 if let Some(ref dens) = density_before {
559 conservative_correction(ht, dens);
560 }
561 }
562
563 fn midpoint_bug_step(
565 &self,
566 repr: &mut dyn PhaseSpaceRepr,
567 solver: &dyn PoissonSolver,
568 dt: f64,
569 timings: &mut StepTimings,
570 ) {
571 let Some(ht) = repr.as_any_mut().downcast_mut::<HtTensor>() else {
572 debug_assert!(false, "midpoint BUG requires HtTensor");
573 return;
574 };
575
576 let density_before = if self.config.conservative {
577 Some(ht.compute_density())
578 } else {
579 None
580 };
581
582 helpers::report_phase!(self.progress, StepPhase::BugKStep, 0, 4);
583
584 let saved = ht.clone();
586 helpers::time_ms!(
587 timings,
588 drift_ms,
589 bug_drift_substep(ht, dt / 4.0, &self.config)
590 );
591
592 let (_, _, accel) = helpers::time_ms!(
593 timings,
594 poisson_ms,
595 helpers::solve_poisson(ht, solver, self.g)
596 );
597
598 helpers::time_ms!(
599 timings,
600 kick_ms,
601 bug_kick_substep(ht, &accel, dt / 2.0, &self.config)
602 );
603
604 helpers::time_ms!(
605 timings,
606 drift_ms,
607 bug_drift_substep(ht, dt / 4.0, &self.config)
608 );
609
610 helpers::report_phase!(self.progress, StepPhase::BugLStep, 1, 4);
612 *ht = saved;
613
614 let aug_config = BugConfig {
615 rank_increase: self.config.rank_increase.max(1),
616 ..BugConfig {
617 tolerance: self.config.tolerance,
618 max_rank: self.config.max_rank,
619 midpoint: false,
620 conservative: false,
621 rank_increase: self.config.rank_increase.max(1),
622 }
623 };
624
625 helpers::time_ms!(
626 timings,
627 drift_ms,
628 bug_drift_substep(ht, dt / 2.0, &aug_config)
629 );
630
631 let (_, _, accel) = helpers::time_ms!(
632 timings,
633 poisson_ms,
634 helpers::solve_poisson(ht, solver, self.g)
635 );
636
637 helpers::time_ms!(
638 timings,
639 kick_ms,
640 bug_kick_substep(ht, &accel, dt, &aug_config)
641 );
642
643 helpers::time_ms!(
644 timings,
645 drift_ms,
646 bug_drift_substep(ht, dt / 2.0, &aug_config)
647 );
648
649 helpers::report_phase!(self.progress, StepPhase::BugSStep, 2, 4);
650 if let Some(ref dens) = density_before {
651 conservative_correction(ht, dens);
652 }
653 }
654}
655
impl TimeIntegrator for BugIntegrator {
    /// Advances the representation by `dt`.
    ///
    /// `HtTensor` representations take the BUG step (midpoint variant when
    /// `config.midpoint` is set); anything else falls back to Strang
    /// splitting via the supplied `advector`. Always finishes with a Poisson
    /// solve so the returned products reflect the post-step state.
    fn advance(
        &mut self,
        repr: &mut dyn PhaseSpaceRepr,
        solver: &dyn PoissonSolver,
        advector: &dyn Advector,
        dt: f64,
    ) -> Result<StepProducts, CausticError> {
        let _span = tracing::info_span!("bug_advance").entered();
        let mut timings = StepTimings::default();

        if let Some(ref p) = self.progress {
            p.start_step();
        }

        // Dispatch on the concrete representation type.
        let is_ht = repr.as_any().downcast_ref::<HtTensor>().is_some();

        if is_ht {
            if self.config.midpoint {
                self.midpoint_bug_step(repr, solver, dt, &mut timings);
            } else {
                self.bug_step_ht(repr, solver, dt, &mut timings);
            }
        } else {
            self.strang_fallback(repr, solver, advector, dt, &mut timings);
        }

        helpers::report_phase!(self.progress, StepPhase::StepComplete, 3, 4);

        // Final field solve producing the step's density/potential/accel.
        let (density, potential, acceleration) = helpers::time_ms!(
            timings,
            density_ms,
            helpers::solve_poisson(repr, solver, self.g)
        );

        self.last_timings = timings;

        Ok(StepProducts {
            density,
            potential,
            acceleration,
        })
    }

    /// Timestep bound derived from the dynamical time of the current state.
    fn max_dt(&self, repr: &dyn PhaseSpaceRepr, cfl_factor: f64) -> f64 {
        helpers::dynamical_timestep(repr, self.g, cfl_factor)
    }

    /// Timings recorded by the most recent `advance` call.
    fn last_step_timings(&self) -> Option<&StepTimings> {
        Some(&self.last_timings)
    }

    fn set_progress(&mut self, progress: Arc<StepProgress>) {
        self.progress = Some(progress);
    }
}