// caustic/tooling/core/time/bug.rs
//! Basis Update & Galerkin (BUG) integrator for Hierarchical Tucker format.
//!
//! BUG is a dynamical low-rank integrator that updates the HT tensor factors
//! (leaf frames and transfer tensors) directly without ever materializing the
//! full 6D grid. Each timestep consists of three conceptual stages:
//!
//! - **K-step:** Update spatial (drift) or velocity (kick) leaf bases by
//!   semi-Lagrangian shifting, then QR-decompose to maintain orthonormality.
//! - **L-step:** Recompute the gravitational acceleration from the updated
//!   density projection and apply the velocity kick.
//! - **S-step:** Update the transfer tensors to absorb the basis change
//!   coefficients (the R matrices from QR), optionally augmenting rank.
//!
//! This avoids the step-and-truncate (SAT) approach where the full tensor is
//! advanced and then re-compressed, providing controlled memory usage,
//! automatic rank adaptation, and robust stability on the low-rank manifold.
//!
//! Algorithm (rank-adaptive BUG, one step):
//! 1. **K-step:** For each active leaf, shift basis by semi-Lagrangian,
//!    QR-decompose to get new orthonormal basis and coefficient matrix R.
//! 2. **Transfer update:** Contract R into the parent's transfer tensor
//!    along the appropriate child axis.
//! 3. **S-step (optional augmentation):** When `rank_increase > 0`, augment
//!    bases with shifted columns at additional velocity/acceleration samples,
//!    then SVD-truncate transfer tensors for rank adaptation.
//!
//! Midpoint variant: half-step K-update to predict midpoint bases, then
//! full-step Galerkin projection using the midpoint bases for 2nd-order accuracy.
//!
//! Conservative variant: after truncation, correct the distribution to
//! restore mass conservation via root transfer tensor scaling.
//!
//! References: Ceruti & Lubich, "An unconventional robust integrator for
//! dynamical low-rank approximation", BIT Numer. Math. (2022); Ceruti,
//! Kusch & Lubich, "A rank-adaptive robust integrator for dynamical
//! low-rank approximation", BIT Numer. Math. (2022).

use std::sync::Arc;

use faer::Mat;

use super::super::{
    advecator::Advector,
    algos::ht::HtTensor,
    algos::lagrangian::sl_shift_1d_into,
    integrator::{StepProducts, StepTimings, TimeIntegrator},
    phasespace::PhaseSpaceRepr,
    progress::{StepPhase, StepProgress},
    solver::PoissonSolver,
    types::*,
};
use super::helpers;
use crate::CausticError;
/// Configuration for the BUG integrator.
pub struct BugConfig {
    /// Truncation tolerance for rank adaptation: the Frobenius norm of the
    /// discarded singular-value tail is kept at or below this value
    /// (see `truncation_rank`).
    pub tolerance: f64,
    /// Maximum rank per node; caps rank growth from basis augmentation.
    pub max_rank: usize,
    /// Use 2nd-order midpoint variant (otherwise 1st-order).
    pub midpoint: bool,
    /// Apply conservative moment correction after truncation
    /// (rescales the root transfer tensor to restore total mass).
    pub conservative: bool,
    /// Number of extra basis columns per K-step (0 = rank-preserving).
    pub rank_increase: usize,
}
66
67impl Default for BugConfig {
68    fn default() -> Self {
69        Self {
70            tolerance: 1e-8,
71            max_rank: 50,
72            midpoint: false,
73            conservative: false,
74            rank_increase: 2,
75        }
76    }
77}
78
/// Parent node index and whether the leaf is the left child.
///
/// Leaves 0–2 are spatial axes and 3–5 are velocity axes (see the
/// `is_spatial = leaf_dim < 3` split in `k_step_leaf`). Node indices refer
/// to the internal-node numbering of the HT dimension tree; node 10 is the
/// root (see `conservative_correction`).
pub(crate) const LEAF_PARENT: [(usize, bool); 6] = [
    (8, true),  // leaf 0 → node 8, left
    (6, true),  // leaf 1 → node 6, left
    (6, false), // leaf 2 → node 6, right
    (9, true),  // leaf 3 → node 9, left
    (7, true),  // leaf 4 → node 7, left
    (7, false), // leaf 5 → node 7, right
];
88
89// ─── Shared BUG helpers (used by BugIntegrator, ParallelBugIntegrator, RkBugIntegrator) ──
90
/// Shift all columns of a leaf frame by `displacement`, QR decompose,
/// and optionally augment with extra shifted columns.
///
/// This is the BUG K-step for one leaf: the basis is advected by a single
/// representative displacement (1D semi-Lagrangian shift), then
/// re-orthonormalized. When `aug_displacements` is non-empty, extra shifted
/// copies of the original basis are appended before QR and the result is
/// SVD-truncated for rank adaptation.
///
/// Returns `(new_frame, R_trunc)` where:
/// - `new_frame`: orthonormal basis (n × k_new)
/// - `R_trunc`: coefficient matrix (k_new × k_old) for transfer tensor update
pub(crate) fn k_step_leaf(
    ht: &HtTensor,
    leaf_dim: usize,
    displacement: f64,
    aug_displacements: &[f64],
    max_rank: usize,
    tolerance: f64,
) -> (Mat<f64>, Mat<f64>) {
    let frame = ht.leaf_frame(leaf_dim);
    let (n, k) = (frame.nrows(), frame.ncols());
    // Leaves 0–2 are spatial, 3–5 are velocity; dim_idx is the axis within
    // the respective 3-vector of grid metadata.
    let is_spatial = leaf_dim < 3;
    let dim_idx = if is_spatial { leaf_dim } else { leaf_dim - 3 };

    // Grid metadata for the 1D shift along this axis: cell size, half
    // extent, and whether the shift wraps around.
    let (cell_size, half_extent, periodic) = if is_spatial {
        let dx = ht.domain.dx();
        let lx = ht.domain.lx();
        let per = matches!(
            ht.domain.spatial_bc,
            super::super::init::domain::SpatialBoundType::Periodic
        );
        (dx[dim_idx], lx[dim_idx], per)
    } else {
        let dv = ht.domain.dv();
        let lv = ht.domain.lv();
        // NOTE(review): the `Truncated` velocity BC is mapped onto the
        // `periodic` flag of `sl_shift_1d_into` — confirm this wrap
        // behavior is intended for velocity-space shifts.
        let per = matches!(
            ht.domain.velocity_bc,
            super::super::init::domain::VelocityBoundType::Truncated
        );
        (dv[dim_idx], lv[dim_idx], per)
    };

    // Shift each column by the primary displacement
    let mut shifted = Mat::<f64>::zeros(n, k);
    // Reusable scratch buffers: one shift input / output per column.
    let mut col_buf = vec![0.0f64; n];
    let mut out_buf = vec![0.0f64; n];

    for j in 0..k {
        for i in 0..n {
            col_buf[i] = frame[(i, j)];
        }
        sl_shift_1d_into(
            &col_buf,
            displacement,
            cell_size,
            n,
            half_extent,
            periodic,
            &mut out_buf,
        );
        for i in 0..n {
            shifted[(i, j)] = out_buf[i];
        }
    }

    let n_aug = aug_displacements.len();
    if n_aug == 0 {
        // Rank-preserving: QR of shifted frame only
        let (q, r) = qr_thin(&shifted);
        return (q, r);
    }

    // Augmented: shift by additional displacements, collect extra columns.
    // Column layout: [shifted | shift(disp_0) | shift(disp_1) | ...],
    // each group k columns wide (the extra groups shift the ORIGINAL frame).
    let total_cols = k + n_aug * k;
    let mut augmented = Mat::<f64>::zeros(n, total_cols);
    for j in 0..k {
        for i in 0..n {
            augmented[(i, j)] = shifted[(i, j)];
        }
    }
    for (s, &disp) in aug_displacements.iter().enumerate() {
        for j in 0..k {
            for i in 0..n {
                col_buf[i] = frame[(i, j)];
            }
            sl_shift_1d_into(
                &col_buf,
                disp,
                cell_size,
                n,
                half_extent,
                periodic,
                &mut out_buf,
            );
            for i in 0..n {
                augmented[(i, k + s * k + j)] = out_buf[i];
            }
        }
    }

    // QR decompose augmented matrix
    let (q_aug, r_aug) = qr_thin(&augmented);

    // SVD truncate to target rank: old rank plus one per augmentation
    // displacement (not per augmented column), capped by max_rank.
    let target_rank = (k + n_aug).min(max_rank).min(q_aug.ncols());
    let (u, sv, _vt) = svd_thin(&r_aug);
    if u.ncols() == 0 {
        // SVD failed or degenerate: keep the full augmented basis and the
        // unprojected coefficients of the primary shifted columns.
        return (q_aug, r_aug.subcols(0, k).to_owned());
    }
    let rank = truncation_rank(&sv, tolerance)
        .max(1)
        .min(target_rank)
        .min(u.ncols());

    // Truncated basis: Q_trunc = Q_aug @ U[:, :rank]
    let u_trunc = u.subcols(0, rank);
    let q_trunc = &q_aug * u_trunc;

    // Coefficient matrix: how the shifted primary columns decompose in the new basis
    // shifted = Q_aug @ R_aug[:, :k], and Q_trunc = Q_aug @ U[:,:rank]
    // So R_trunc = U[:,:rank]^T @ R_aug[:, :k]  (k_new × k_old)
    let r_aug_left = r_aug.subcols(0, k);
    let r_trunc = u_trunc.transpose() * r_aug_left;

    (q_trunc.to_owned(), r_trunc.to_owned())
}
212
213/// Update the parent's transfer tensor after replacing a child's leaf frame.
214///
215/// If left child:  B_new[p, l_new, r] = Σ_l R[l_new, l] * B[p, l, r]
216/// If right child: B_new[p, l, r_new] = Σ_r R[r_new, r] * B[p, l, r]
217pub(crate) fn update_transfer(ht: &mut HtTensor, leaf_dim: usize, r_matrix: &Mat<f64>) {
218    let (parent_idx, is_left) = LEAF_PARENT[leaf_dim];
219    let (transfer, ranks) = ht.transfer_tensor(parent_idx);
220    let [kp, kl, kr] = ranks;
221    let k_new = r_matrix.nrows();
222    let k_old = r_matrix.ncols();
223
224    if is_left {
225        assert_eq!(k_old, kl, "R cols must match old left rank");
226        let mut new_data = vec![0.0f64; kp * k_new * kr];
227        for p in 0..kp {
228            for l_new in 0..k_new {
229                for r in 0..kr {
230                    let mut sum = 0.0;
231                    for l in 0..kl {
232                        sum += r_matrix[(l_new, l)] * transfer[p * kl * kr + l * kr + r];
233                    }
234                    new_data[p * k_new * kr + l_new * kr + r] = sum;
235                }
236            }
237        }
238        ht.set_transfer_tensor(parent_idx, new_data, [kp, k_new, kr]);
239    } else {
240        assert_eq!(k_old, kr, "R cols must match old right rank");
241        let mut new_data = vec![0.0f64; kp * kl * k_new];
242        for p in 0..kp {
243            for l in 0..kl {
244                for r_new in 0..k_new {
245                    let mut sum = 0.0;
246                    for r in 0..kr {
247                        sum += r_matrix[(r_new, r)] * transfer[p * kl * kr + l * kr + r];
248                    }
249                    new_data[p * kl * k_new + l * k_new + r_new] = sum;
250                }
251            }
252        }
253        ht.set_transfer_tensor(parent_idx, new_data, [kp, kl, k_new]);
254    }
255}
256
257/// Compute representative velocities from a velocity leaf frame.
258pub(crate) fn representative_velocities(ht: &HtTensor, vel_dim: usize) -> Vec<f64> {
259    let v_frame = ht.leaf_frame(vel_dim);
260    let (nv, kv) = (v_frame.nrows(), v_frame.ncols());
261    let dim_idx = vel_dim - 3;
262    let dv = ht.domain.dv();
263    let lv = ht.domain.lv();
264
265    (0..kv)
266        .map(|l| {
267            let mut wt_sum = 0.0f64;
268            let mut v_sum = 0.0f64;
269            for i in 0..nv {
270                let v = -lv[dim_idx] + (i as f64 + 0.5) * dv[dim_idx];
271                let w = v_frame[(i, l)] * v_frame[(i, l)];
272                v_sum += w * v;
273                wt_sum += w;
274            }
275            if wt_sum > 1e-30 { v_sum / wt_sum } else { 0.0 }
276        })
277        .collect()
278}
279
280/// Compute representative accelerations for a spatial leaf.
281pub(crate) fn representative_accelerations(
282    ht: &HtTensor,
283    spatial_dim: usize,
284    accel: &AccelerationField,
285) -> Vec<f64> {
286    let x_frame = ht.leaf_frame(spatial_dim);
287    let (nx_dim, kx) = (x_frame.nrows(), x_frame.ncols());
288    let [nx1, nx2, nx3, _, _, _] = ht.shape;
289
290    let accel_data = match spatial_dim {
291        0 => &accel.gx,
292        1 => &accel.gy,
293        2 => &accel.gz,
294        _ => unreachable!(),
295    };
296
297    (0..kx)
298        .map(|j| {
299            let mut wt_sum = 0.0f64;
300            let mut a_sum = 0.0f64;
301            for i in 0..nx_dim {
302                let w = x_frame[(i, j)] * x_frame[(i, j)];
303                // Average acceleration over the other two spatial dimensions
304                let mut a_avg = 0.0f64;
305                let n_other: usize = match spatial_dim {
306                    0 => {
307                        for ix2 in 0..nx2 {
308                            for ix3 in 0..nx3 {
309                                a_avg += accel_data[i * nx2 * nx3 + ix2 * nx3 + ix3];
310                            }
311                        }
312                        nx2 * nx3
313                    }
314                    1 => {
315                        for ix1 in 0..nx1 {
316                            for ix3 in 0..nx3 {
317                                a_avg += accel_data[ix1 * nx2 * nx3 + i * nx3 + ix3];
318                            }
319                        }
320                        nx1 * nx3
321                    }
322                    2 => {
323                        for ix1 in 0..nx1 {
324                            for ix2 in 0..nx2 {
325                                a_avg += accel_data[ix1 * nx2 * nx3 + ix2 * nx3 + i];
326                            }
327                        }
328                        nx1 * nx2
329                    }
330                    _ => unreachable!(),
331                };
332                a_avg /= n_other as f64;
333                a_sum += w * a_avg;
334                wt_sum += w;
335            }
336            if wt_sum > 1e-30 { a_sum / wt_sum } else { 0.0 }
337        })
338        .collect()
339}
340
/// Sample augmentation displacements from representative values.
///
/// Scales each representative by `dt`, sorts, then returns up to
/// `rank_increase` displacements: the maximum first, the minimum second
/// (when available), and the remainder at interior quantiles of the sorted
/// values. Returns an empty vector when `rank_increase == 0` or there are
/// no representatives.
pub(crate) fn sample_aug_displacements(
    representatives: &[f64],
    dt: f64,
    rank_increase: usize,
) -> Vec<f64> {
    if rank_increase == 0 || representatives.is_empty() {
        return vec![];
    }
    let mut reps: Vec<f64> = representatives.iter().map(|&v| v * dt).collect();
    // NaNs (if any) compare as Equal so the sort cannot panic.
    reps.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let n = reps.len();
    let mut aug = Vec::with_capacity(rank_increase);
    if n >= 1 {
        aug.push(reps[n - 1]); // max displacement
    }
    if rank_increase >= 2 && n >= 2 {
        aug.push(reps[0]); // min displacement
    }
    // Interior samples at quantiles strictly between min and max.
    // (Fix: the previous fraction `s / (rank_increase - 1)` reached 1.0 at
    // the last sample, duplicating the max column already pushed above —
    // the duplicate added no new basis direction and was wasted.)
    for s in 2..rank_increase {
        let frac = (s - 1) as f64 / (rank_increase - 1) as f64;
        let idx = ((n as f64 - 1.0) * frac) as usize;
        aug.push(reps[idx.min(n - 1)]);
    }
    aug
}
367
368/// BUG drift substep: K-step for spatial leaves 0, 1, 2.
369pub(crate) fn bug_drift_substep(ht: &mut HtTensor, dt: f64, config: &BugConfig) {
370    for d in 0..3 {
371        let reps = representative_velocities(ht, d + 3);
372        let primary = if reps.is_empty() {
373            0.0
374        } else {
375            reps.iter().sum::<f64>() / reps.len() as f64 * dt
376        };
377        let aug = sample_aug_displacements(&reps, dt, config.rank_increase);
378        let (new_frame, r_mat) =
379            k_step_leaf(ht, d, primary, &aug, config.max_rank, config.tolerance);
380        *ht.leaf_frame_mut(d) = new_frame;
381        update_transfer(ht, d, &r_mat);
382    }
383}
384
385/// BUG kick substep: K-step for velocity leaves 3, 4, 5.
386pub(crate) fn bug_kick_substep(
387    ht: &mut HtTensor,
388    accel: &AccelerationField,
389    dt: f64,
390    config: &BugConfig,
391) {
392    for d in 3..6 {
393        let reps = representative_accelerations(ht, d - 3, accel);
394        let primary = if reps.is_empty() {
395            0.0
396        } else {
397            reps.iter().sum::<f64>() / reps.len() as f64 * dt
398        };
399        let aug = sample_aug_displacements(&reps, dt, config.rank_increase);
400        let (new_frame, r_mat) =
401            k_step_leaf(ht, d, primary, &aug, config.max_rank, config.tolerance);
402        *ht.leaf_frame_mut(d) = new_frame;
403        update_transfer(ht, d, &r_mat);
404    }
405}
406
407/// Conservative correction: scale root transfer tensor to restore mass.
408pub(crate) fn conservative_correction(ht: &mut HtTensor, density_before: &DensityField) {
409    let density_after = ht.compute_density();
410    let mass_before: f64 = density_before.data.iter().sum();
411    let mass_after: f64 = density_after.data.iter().sum();
412    if mass_before.abs() < 1e-30 || (mass_after - mass_before).abs() < 1e-14 * mass_before.abs() {
413        return;
414    }
415    let scale = mass_before / mass_after;
416    let (transfer, ranks) = ht.transfer_tensor(10); // root
417    let new_data: Vec<f64> = transfer.iter().map(|&v| v * scale).collect();
418    ht.set_transfer_tensor(10, new_data, ranks);
419}
420
421// ─── Linear algebra helpers ─────────────────────────────────────────────
422
423pub(crate) fn qr_thin(mat: &Mat<f64>) -> (Mat<f64>, Mat<f64>) {
424    let m = mat.nrows();
425    let n = mat.ncols();
426    if m.min(n) == 0 {
427        return (Mat::zeros(m, 0), Mat::zeros(0, n));
428    }
429    let qr = mat.as_ref().qr();
430    (qr.compute_thin_Q(), qr.thin_R().to_owned())
431}
432
/// Thin SVD via faer: returns `(U, singular_values, Vᵀ)`.
///
/// Empty inputs and factorization failures both yield empty factors
/// (`U` with zero columns); `k_step_leaf` detects this via
/// `u.ncols() == 0` and skips truncation.
pub(crate) fn svd_thin(mat: &Mat<f64>) -> (Mat<f64>, Vec<f64>, Mat<f64>) {
    let m = mat.nrows();
    let n = mat.ncols();
    let k = m.min(n);
    if k == 0 {
        return (Mat::zeros(m, 0), vec![], Mat::zeros(0, n));
    }
    // NOTE(review): an SVD failure is silently mapped to empty factors,
    // which the caller treats as "skip truncation" — confirm this is
    // preferred over surfacing the error.
    let svd = match mat.as_ref().thin_svd() {
        Ok(s) => s,
        Err(_) => return (Mat::zeros(m, 0), vec![], Mat::zeros(0, n)),
    };
    let u = svd.U().to_owned();
    let vt = svd.V().transpose().to_owned();
    let s_diag = svd.S().column_vector();
    // Thin SVD yields exactly k = min(m, n) singular values.
    let s: Vec<f64> = (0..k).map(|i| s_diag[i]).collect();
    (u, s, vt)
}
450
/// Smallest rank whose discarded singular-value tail has Frobenius norm at
/// most `eps`. Scans the spectrum from the smallest value, accumulating
/// squared tail energy until the budget `eps²` is exceeded. Always returns
/// at least 1, even when the entire spectrum fits inside the tolerance.
pub(crate) fn truncation_rank(sv: &[f64], eps: f64) -> usize {
    let budget = eps * eps;
    let mut tail_energy = 0.0f64;
    for (idx, &s) in sv.iter().enumerate().rev() {
        tail_energy += s * s;
        if tail_energy > budget {
            return idx + 1;
        }
    }
    1
}
462
463// ─── BugIntegrator ──────────────────────────────────────────────────────
464
/// BUG (Basis Update & Galerkin) integrator for low-rank tensor formats.
///
/// When the representation is an `HtTensor`, this integrator evolves the
/// solution directly on the low-rank manifold via K/L/S-step updates.
/// For other representations, it falls back to standard Strang splitting.
pub struct BugIntegrator {
    /// BUG algorithm parameters (tolerance, max rank, midpoint, conservative).
    pub config: BugConfig,
    /// Gravitational constant G used for the Poisson solve.
    pub g: f64,
    // Timing breakdown recorded by the most recent `advance` call.
    last_timings: StepTimings,
    // Optional intra-step progress reporter (e.g. for TUI updates).
    progress: Option<Arc<StepProgress>>,
}
478
479impl BugIntegrator {
480    /// Create a new BUG integrator with the given gravitational constant and configuration.
481    pub fn new(g: f64, config: BugConfig) -> Self {
482        Self {
483            config,
484            g,
485            last_timings: StepTimings::default(),
486            progress: None,
487        }
488    }
489
490    /// Fallback: standard Strang splitting for non-HT representations.
491    fn strang_fallback(
492        &self,
493        repr: &mut dyn PhaseSpaceRepr,
494        solver: &dyn PoissonSolver,
495        advector: &dyn Advector,
496        dt: f64,
497        timings: &mut StepTimings,
498    ) {
499        helpers::time_ms!(timings, drift_ms, advector.drift(repr, dt / 2.0));
500
501        let (_, _, accel) = helpers::time_ms!(
502            timings,
503            poisson_ms,
504            helpers::solve_poisson(repr, solver, self.g)
505        );
506
507        helpers::time_ms!(timings, kick_ms, advector.kick(repr, &accel, dt));
508
509        helpers::time_ms!(timings, drift_ms, advector.drift(repr, dt / 2.0));
510    }
511
512    /// Standard BUG step: Strang-split drift-kick-drift on HT leaves.
513    fn bug_step_ht(
514        &self,
515        repr: &mut dyn PhaseSpaceRepr,
516        solver: &dyn PoissonSolver,
517        dt: f64,
518        timings: &mut StepTimings,
519    ) {
520        let Some(ht) = repr.as_any_mut().downcast_mut::<HtTensor>() else {
521            debug_assert!(false, "BUG step requires HtTensor");
522            return;
523        };
524
525        let density_before = if self.config.conservative {
526            Some(ht.compute_density())
527        } else {
528            None
529        };
530
531        helpers::report_phase!(self.progress, StepPhase::BugKStep, 0, 4);
532        helpers::time_ms!(
533            timings,
534            drift_ms,
535            bug_drift_substep(ht, dt / 2.0, &self.config)
536        );
537
538        helpers::report_phase!(self.progress, StepPhase::BugLStep, 1, 4);
539        let (_, _, accel) = helpers::time_ms!(
540            timings,
541            poisson_ms,
542            helpers::solve_poisson(ht, solver, self.g)
543        );
544
545        helpers::time_ms!(
546            timings,
547            kick_ms,
548            bug_kick_substep(ht, &accel, dt, &self.config)
549        );
550
551        helpers::time_ms!(
552            timings,
553            drift_ms,
554            bug_drift_substep(ht, dt / 2.0, &self.config)
555        );
556
557        helpers::report_phase!(self.progress, StepPhase::BugSStep, 2, 4);
558        if let Some(ref dens) = density_before {
559            conservative_correction(ht, dens);
560        }
561    }
562
563    /// Midpoint BUG step: half-step predict, full-step with augmented bases.
564    fn midpoint_bug_step(
565        &self,
566        repr: &mut dyn PhaseSpaceRepr,
567        solver: &dyn PoissonSolver,
568        dt: f64,
569        timings: &mut StepTimings,
570    ) {
571        let Some(ht) = repr.as_any_mut().downcast_mut::<HtTensor>() else {
572            debug_assert!(false, "midpoint BUG requires HtTensor");
573            return;
574        };
575
576        let density_before = if self.config.conservative {
577            Some(ht.compute_density())
578        } else {
579            None
580        };
581
582        helpers::report_phase!(self.progress, StepPhase::BugKStep, 0, 4);
583
584        // Predict midpoint with half-step
585        let saved = ht.clone();
586        helpers::time_ms!(
587            timings,
588            drift_ms,
589            bug_drift_substep(ht, dt / 4.0, &self.config)
590        );
591
592        let (_, _, accel) = helpers::time_ms!(
593            timings,
594            poisson_ms,
595            helpers::solve_poisson(ht, solver, self.g)
596        );
597
598        helpers::time_ms!(
599            timings,
600            kick_ms,
601            bug_kick_substep(ht, &accel, dt / 2.0, &self.config)
602        );
603
604        helpers::time_ms!(
605            timings,
606            drift_ms,
607            bug_drift_substep(ht, dt / 4.0, &self.config)
608        );
609
610        // ht is now at midpoint — restore and do full step
611        helpers::report_phase!(self.progress, StepPhase::BugLStep, 1, 4);
612        *ht = saved;
613
614        let aug_config = BugConfig {
615            rank_increase: self.config.rank_increase.max(1),
616            ..BugConfig {
617                tolerance: self.config.tolerance,
618                max_rank: self.config.max_rank,
619                midpoint: false,
620                conservative: false,
621                rank_increase: self.config.rank_increase.max(1),
622            }
623        };
624
625        helpers::time_ms!(
626            timings,
627            drift_ms,
628            bug_drift_substep(ht, dt / 2.0, &aug_config)
629        );
630
631        let (_, _, accel) = helpers::time_ms!(
632            timings,
633            poisson_ms,
634            helpers::solve_poisson(ht, solver, self.g)
635        );
636
637        helpers::time_ms!(
638            timings,
639            kick_ms,
640            bug_kick_substep(ht, &accel, dt, &aug_config)
641        );
642
643        helpers::time_ms!(
644            timings,
645            drift_ms,
646            bug_drift_substep(ht, dt / 2.0, &aug_config)
647        );
648
649        helpers::report_phase!(self.progress, StepPhase::BugSStep, 2, 4);
650        if let Some(ref dens) = density_before {
651            conservative_correction(ht, dens);
652        }
653    }
654}
655
impl TimeIntegrator for BugIntegrator {
    /// Advance the distribution by one timestep `dt`.
    ///
    /// If the representation is `HtTensor`, performs a BUG step (standard or midpoint
    /// depending on config). Otherwise falls back to Strang splitting.
    fn advance(
        &mut self,
        repr: &mut dyn PhaseSpaceRepr,
        solver: &dyn PoissonSolver,
        advector: &dyn Advector,
        dt: f64,
    ) -> Result<StepProducts, CausticError> {
        let _span = tracing::info_span!("bug_advance").entered();
        let mut timings = StepTimings::default();

        if let Some(ref p) = self.progress {
            p.start_step();
        }

        // Cheap type probe; the step methods re-downcast mutably themselves.
        let is_ht = repr.as_any().downcast_ref::<HtTensor>().is_some();

        if is_ht {
            if self.config.midpoint {
                self.midpoint_bug_step(repr, solver, dt, &mut timings);
            } else {
                self.bug_step_ht(repr, solver, dt, &mut timings);
            }
        } else {
            self.strang_fallback(repr, solver, advector, dt, &mut timings);
        }

        helpers::report_phase!(self.progress, StepPhase::StepComplete, 3, 4);

        // Compute end-of-step products for caller reuse
        // (one extra Poisson solve, booked under density_ms).
        let (density, potential, acceleration) = helpers::time_ms!(
            timings,
            density_ms,
            helpers::solve_poisson(repr, solver, self.g)
        );

        self.last_timings = timings;

        Ok(StepProducts {
            density,
            potential,
            acceleration,
        })
    }

    /// Dynamical-time CFL: dt <= cfl_factor / sqrt(G * rho_max).
    fn max_dt(&self, repr: &dyn PhaseSpaceRepr, cfl_factor: f64) -> f64 {
        helpers::dynamical_timestep(repr, self.g, cfl_factor)
    }

    /// Return timing breakdown from the most recent step.
    fn last_step_timings(&self) -> Option<&StepTimings> {
        Some(&self.last_timings)
    }

    /// Attach a progress reporter for intra-step TUI updates.
    fn set_progress(&mut self, progress: Arc<StepProgress>) {
        self.progress = Some(progress);
    }
}