Skip to main content

dsfb_add/
sweep.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3
4use dsfb::{DsfbObserver, DsfbParams, DsfbState};
5use serde::{Deserialize, Serialize};
6
7use crate::aet::{self, AetSweep};
8use crate::analysis::rlt_phase::{analyze_rlt_phase_boundary, RltPhaseBoundary};
9use crate::analysis::structural_law::{diagnostics_from_fit, fit_with_ci, LinearFit};
10use crate::config::SimulationConfig;
11use crate::iwlt::{self, IwltSweep};
12use crate::output::{
13    write_aet_csv, write_cross_layer_thresholds_csv, write_diagnostics_summary_csv, write_iwlt_csv,
14    write_rlt_csv, write_rlt_phase_boundary_csv, write_rlt_trajectory_csv,
15    write_robustness_metrics_csv, write_structural_law_summary_csv, write_tcp_csv,
16    write_tcp_phase_alignment_csv, write_tcp_points_csv, CrossLayerThresholdRow,
17    DiagnosticsSummaryRow, PhaseBoundaryRow, RobustnessMetricRow, StructuralLawSummaryRow,
18    TcpPhaseAlignmentRow,
19};
20use crate::rlt::{self, RltExampleKind, RltSweep};
21use crate::tcp::{self, TcpSweep};
22use crate::AddError;
23
/// Sweep outputs for a single simulation horizon (one `steps_per_run` value).
///
/// Each layer field holds that layer's baseline (unperturbed) sweep, or `None` when the layer
/// was disabled in the configuration (`enable_aet`, `enable_tcp`, `enable_rlt`, `enable_iwlt`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SweepRunResult {
    /// Number of steps per run used for every layer in this horizon.
    pub steps_per_run: usize,
    /// AET baseline sweep, if AET was enabled.
    pub aet: Option<AetSweep>,
    /// TCP baseline sweep, if TCP was enabled.
    pub tcp: Option<TcpSweep>,
    /// RLT baseline sweep, if RLT was enabled.
    pub rlt: Option<RltSweep>,
    /// IWLT baseline sweep, if IWLT was enabled.
    pub iwlt: Option<IwltSweep>,
}
32
/// Aggregate result of `run_sweeps_into_dir`.
///
/// `runs` holds one entry per swept horizon, in sweep order. The top-level layer fields repeat
/// the baseline sweeps of the canonical horizon, or are `None` when that layer was disabled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SweepResult {
    /// Directory all CSV artifacts were written to.
    pub output_dir: PathBuf,
    /// Lambda values every layer was swept over.
    pub lambda_grid: Vec<f64>,
    /// Per-horizon results, one per `steps_per_run` value.
    pub runs: Vec<SweepRunResult>,
    /// Canonical-horizon AET baseline sweep.
    pub aet: Option<AetSweep>,
    /// Canonical-horizon TCP baseline sweep.
    pub tcp: Option<TcpSweep>,
    /// Canonical-horizon RLT baseline sweep.
    pub rlt: Option<RltSweep>,
    /// Canonical-horizon IWLT baseline sweep.
    pub iwlt: Option<IwltSweep>,
}
43
/// Console progress reporter for a sweep, measured in lambda-sample units.
///
/// A percentage line is printed only when the integer percentage strictly increases, so the
/// output stays at roughly one hundred lines regardless of sweep size.
struct ProgressTracker {
    /// Lambda samples expected across all stages and horizons.
    total_units: usize,
    /// Lambda samples belonging to stages that have already finished.
    completed_units: usize,
    /// Highest percentage printed so far (0 before the first update).
    last_percent_printed: usize,
}

impl ProgressTracker {
    /// Creates a tracker expecting `total_units` lambda samples in total.
    fn new(total_units: usize) -> Self {
        Self {
            last_percent_printed: 0,
            completed_units: 0,
            total_units,
        }
    }

    /// Prints a banner announcing that a stage of `stage_units` samples is starting.
    fn stage_start(&self, label: &str, steps_per_run: usize, stage_units: usize) {
        println!("[dsfb-add] Starting {label} (N={steps_per_run}, {stage_units} lambda samples)");
    }

    /// Records `local_done` finished samples in the current stage and prints an update if the
    /// overall integer percentage went up.
    fn report(&mut self, label: &str, steps_per_run: usize, local_done: usize, stage_units: usize) {
        // An empty sweep has nothing to measure progress against.
        if self.total_units == 0 {
            return;
        }

        let percent = (self.completed_units + local_done) * 100 / self.total_units;
        if percent <= self.last_percent_printed {
            return;
        }

        self.last_percent_printed = percent;
        println!(
            "[dsfb-add] {percent:>3}% - {label} (N={steps_per_run}, {local_done}/{stage_units} lambda samples)"
        );
    }

    /// Folds a finished stage's samples into the completed total.
    fn finish_stage(&mut self, stage_units: usize) {
        self.completed_units += stage_units;
    }

    /// Prints the final completion line unless 100% has already been reported.
    fn finish_all(&mut self) {
        if self.last_percent_printed >= 100 {
            return;
        }
        self.last_percent_printed = 100;
        println!("[dsfb-add] 100% - sweep complete");
    }
}
90
/// Bias triple produced by `deterministic_drive` from a short DSFB observer run.
#[derive(Debug, Clone, Copy)]
pub(crate) struct DriveSignal {
    /// `tanh` of the observer's final `phi` state, so it lies in (-1, 1).
    pub phase_bias: f64,
    /// Trust weight of channel 0 minus trust weight of channel 1 after the run.
    pub trust_bias: f64,
    /// `tanh` of the observer's final `omega` state, so it lies in (-1, 1).
    pub drift_bias: f64,
}
97
98pub(crate) fn deterministic_drive(seed: u64, lambda: f64, salt: u64) -> DriveSignal {
99    let mut observer = DsfbObserver::new(DsfbParams::new(0.35, 0.08, 0.01, 0.92, 0.15), 2);
100    observer.init(DsfbState::new(lambda * 0.25, 0.0, 0.0));
101
102    let phase = lambda * std::f64::consts::TAU + (seed ^ salt) as f64 * 1.0e-6;
103    let dt = 0.125;
104
105    for step in 0..24 {
106        let t = step as f64 * dt;
107        let quantized0 =
108            (((seed.wrapping_add(salt).wrapping_add(step as u64)) % 11) as f64 - 5.0) * 0.01;
109        let quantized1 =
110            (((seed ^ salt).wrapping_add((step * 3) as u64) % 13) as f64 - 6.0) * 0.008;
111
112        let channel0 = lambda + 0.32 * (phase + 1.7 * t).sin() + quantized0;
113        let channel1 = lambda + 0.27 * (phase * 0.8 + 2.3 * t).cos() + quantized1;
114
115        observer.step(&[channel0, channel1], dt);
116    }
117
118    let state = observer.state();
119    DriveSignal {
120        phase_bias: state.phi.tanh(),
121        trust_bias: observer.trust_weight(0) - observer.trust_weight(1),
122        drift_bias: state.omega.tanh(),
123    }
124}
125
126pub fn run_sweeps_into_dir(
127    config: &SimulationConfig,
128    output_dir: &Path,
129) -> Result<SweepResult, AddError> {
130    config.validate()?;
131    fs::create_dir_all(output_dir)?;
132
133    let lambda_grid = config.lambda_grid();
134    let sweep_steps = config.sweep_steps();
135    let use_step_suffix = sweep_steps.len() > 1;
136    let canonical_steps = canonical_steps(config, &sweep_steps);
137    let lambda_count = lambda_grid.len();
138    let mut progress = ProgressTracker::new(total_progress_units(
139        config,
140        sweep_steps.len(),
141        lambda_count,
142    ));
143
144    let mut runs = Vec::with_capacity(sweep_steps.len());
145    let mut phase_rows = Vec::new();
146    let mut law_rows = Vec::new();
147    let mut scaling_rows = Vec::new();
148    let mut diagnostics_rows = Vec::new();
149    let mut threshold_rows = Vec::new();
150    let mut tcp_alignment_rows = Vec::new();
151    let mut robustness_rows = Vec::new();
152
153    let mut canonical_aet = None;
154    let mut canonical_tcp = None;
155    let mut canonical_rlt = None;
156    let mut canonical_iwlt = None;
157
158    for steps_per_run in sweep_steps {
159        let mut run_config = config.clone();
160        run_config.steps_per_run = steps_per_run;
161
162        let is_canonical = steps_per_run == canonical_steps;
163        let suffix = if use_step_suffix {
164            format!("_N{steps_per_run}")
165        } else {
166            String::new()
167        };
168
169        let (aet, aet_perturbed) = if config.enable_aet {
170            progress.stage_start("AET baseline", steps_per_run, lambda_count);
171            let baseline =
172                aet::run_aet_sweep_with_progress(&run_config, &lambda_grid, |completed, total| {
173                    progress.report("AET baseline", steps_per_run, completed, total)
174                })?;
175            progress.finish_stage(lambda_count);
176
177            progress.stage_start("AET perturbed", steps_per_run, lambda_count);
178            let perturbed = aet::run_aet_sweep_perturbed_with_progress(
179                &run_config,
180                &lambda_grid,
181                |completed, total| {
182                    progress.report("AET perturbed", steps_per_run, completed, total)
183                },
184            )?;
185            progress.finish_stage(lambda_count);
186
187            write_aet_csv(
188                &output_dir.join(format!("aet_sweep{suffix}.csv")),
189                &lambda_grid,
190                &baseline.echo_slope,
191                &baseline.avg_increment,
192                steps_per_run,
193                false,
194            )?;
195            write_aet_csv(
196                &output_dir.join(format!("aet_sweep_perturbed{suffix}.csv")),
197                &lambda_grid,
198                &perturbed.echo_slope,
199                &perturbed.avg_increment,
200                steps_per_run,
201                true,
202            )?;
203
204            if use_step_suffix && is_canonical {
205                write_aet_csv(
206                    &output_dir.join("aet_sweep.csv"),
207                    &lambda_grid,
208                    &baseline.echo_slope,
209                    &baseline.avg_increment,
210                    steps_per_run,
211                    false,
212                )?;
213                write_aet_csv(
214                    &output_dir.join("aet_sweep_perturbed.csv"),
215                    &lambda_grid,
216                    &perturbed.echo_slope,
217                    &perturbed.avg_increment,
218                    steps_per_run,
219                    true,
220                )?;
221            }
222
223            robustness_rows.push(comparison_metric(
224                "aet_curve_l2_diff",
225                steps_per_run,
226                0.0,
227                curve_l2_diff(&baseline.echo_slope, &perturbed.echo_slope),
228            ));
229            robustness_rows.push(comparison_metric(
230                "aet_curve_max_abs_diff",
231                steps_per_run,
232                0.0,
233                curve_max_abs_diff(&baseline.echo_slope, &perturbed.echo_slope),
234            ));
235
236            if is_canonical {
237                canonical_aet = Some(baseline.clone());
238            }
239
240            (Some(baseline), Some(perturbed))
241        } else {
242            (None, None)
243        };
244
245        let tcp = if config.enable_tcp {
246            progress.stage_start("TCP baseline", steps_per_run, lambda_count);
247            let baseline =
248                tcp::run_tcp_sweep_with_progress(&run_config, &lambda_grid, |completed, total| {
249                    progress.report("TCP baseline", steps_per_run, completed, total)
250                })?;
251            progress.finish_stage(lambda_count);
252
253            write_tcp_csv(
254                &output_dir.join(format!("tcp_sweep{suffix}.csv")),
255                &lambda_grid,
256                &baseline.betti0,
257                &baseline.betti1,
258                &baseline.l_tcp,
259                &baseline.avg_radius,
260                &baseline.max_radius,
261                &baseline.variance_radius,
262                steps_per_run,
263                false,
264            )?;
265
266            if use_step_suffix && is_canonical {
267                write_tcp_csv(
268                    &output_dir.join("tcp_sweep.csv"),
269                    &lambda_grid,
270                    &baseline.betti0,
271                    &baseline.betti1,
272                    &baseline.l_tcp,
273                    &baseline.avg_radius,
274                    &baseline.max_radius,
275                    &baseline.variance_radius,
276                    steps_per_run,
277                    false,
278                )?;
279            }
280
281            for points_dir in points_dirs(output_dir, steps_per_run, use_step_suffix, is_canonical)
282            {
283                fs::create_dir_all(&points_dir)?;
284                for (idx, runs_for_lambda) in baseline.point_cloud_runs.iter().enumerate() {
285                    for (run_idx, points) in runs_for_lambda.iter().enumerate() {
286                        let filename = format!("lambda_{idx:03}_run_{run_idx:02}.csv");
287                        write_tcp_points_csv(&points_dir.join(filename), points)?;
288                    }
289                }
290            }
291
292            if is_canonical {
293                canonical_tcp = Some(baseline.clone());
294            }
295
296            Some(baseline)
297        } else {
298            None
299        };
300
301        let (rlt, rlt_perturbed, baseline_phase, perturbed_phase) = if config.enable_rlt {
302            progress.stage_start("RLT baseline", steps_per_run, lambda_count);
303            let baseline =
304                rlt::run_rlt_sweep_with_progress(&run_config, &lambda_grid, |completed, total| {
305                    progress.report("RLT baseline", steps_per_run, completed, total)
306                })?;
307            progress.finish_stage(lambda_count);
308
309            progress.stage_start("RLT perturbed", steps_per_run, lambda_count);
310            let perturbed = rlt::run_rlt_sweep_perturbed_with_progress(
311                &run_config,
312                &lambda_grid,
313                |completed, total| {
314                    progress.report("RLT perturbed", steps_per_run, completed, total)
315                },
316            )?;
317            progress.finish_stage(lambda_count);
318            let baseline_phase = analyze_rlt_phase_boundary(
319                &lambda_grid,
320                &baseline.expansion_ratio,
321                &baseline.escape_rate,
322            )?;
323            let perturbed_phase = analyze_rlt_phase_boundary(
324                &lambda_grid,
325                &perturbed.expansion_ratio,
326                &perturbed.escape_rate,
327            )?;
328
329            write_rlt_csv(
330                &output_dir.join(format!("rlt_sweep{suffix}.csv")),
331                &lambda_grid,
332                &baseline.escape_rate,
333                &baseline.expansion_ratio,
334                steps_per_run,
335                false,
336            )?;
337            write_rlt_csv(
338                &output_dir.join(format!("rlt_sweep_perturbed{suffix}.csv")),
339                &lambda_grid,
340                &perturbed.escape_rate,
341                &perturbed.expansion_ratio,
342                steps_per_run,
343                true,
344            )?;
345
346            if use_step_suffix && is_canonical {
347                write_rlt_csv(
348                    &output_dir.join("rlt_sweep.csv"),
349                    &lambda_grid,
350                    &baseline.escape_rate,
351                    &baseline.expansion_ratio,
352                    steps_per_run,
353                    false,
354                )?;
355                write_rlt_csv(
356                    &output_dir.join("rlt_sweep_perturbed.csv"),
357                    &lambda_grid,
358                    &perturbed.escape_rate,
359                    &perturbed.expansion_ratio,
360                    steps_per_run,
361                    true,
362                )?;
363            }
364
365            phase_rows.push(phase_row("baseline", false, steps_per_run, baseline_phase));
366            phase_rows.push(phase_row("perturbed", true, steps_per_run, perturbed_phase));
367
368            robustness_rows.push(comparison_metric(
369                "rlt_curve_l2_diff",
370                steps_per_run,
371                0.0,
372                curve_l2_diff(&baseline.expansion_ratio, &perturbed.expansion_ratio),
373            ));
374            robustness_rows.push(comparison_metric(
375                "rlt_curve_max_abs_diff",
376                steps_per_run,
377                0.0,
378                curve_max_abs_diff(&baseline.expansion_ratio, &perturbed.expansion_ratio),
379            ));
380            robustness_rows.push(comparison_metric_option(
381                "lambda_star",
382                steps_per_run,
383                baseline_phase.lambda_star,
384                perturbed_phase.lambda_star,
385            ));
386            robustness_rows.push(comparison_metric_option(
387                "transition_width",
388                steps_per_run,
389                baseline_phase.transition_width,
390                perturbed_phase.transition_width,
391            ));
392            robustness_rows.push(comparison_metric_option(
393                "max_derivative",
394                steps_per_run,
395                baseline_phase.max_derivative,
396                perturbed_phase.max_derivative,
397            ));
398
399            for examples_dir in
400                example_dirs(output_dir, steps_per_run, use_step_suffix, is_canonical)
401            {
402                fs::create_dir_all(&examples_dir)?;
403                let (bounded_idx, expanding_idx) =
404                    rlt::find_representative_regime_indices(&baseline.escape_rate);
405                for (kind, idx) in [
406                    (RltExampleKind::Bounded, bounded_idx),
407                    (RltExampleKind::Expanding, expanding_idx),
408                ] {
409                    let lambda = lambda_grid[idx];
410                    let trajectory = rlt::simulate_example_trajectory(
411                        &run_config,
412                        lambda,
413                        rlt::RLT_EXAMPLE_STEPS,
414                    );
415                    let filename =
416                        format!("trajectory_{}_lambda_{idx:03}.csv", kind.filename_prefix());
417                    write_rlt_trajectory_csv(&examples_dir.join(filename), &trajectory)?;
418                }
419            }
420
421            if is_canonical {
422                canonical_rlt = Some(baseline.clone());
423            }
424
425            (
426                Some(baseline),
427                Some(perturbed),
428                Some(baseline_phase),
429                Some(perturbed_phase),
430            )
431        } else {
432            (None, None, None, None)
433        };
434
435        let (iwlt, iwlt_perturbed) = if config.enable_iwlt {
436            progress.stage_start("IWLT baseline", steps_per_run, lambda_count);
437            let baseline = iwlt::run_iwlt_sweep_with_progress(
438                &run_config,
439                &lambda_grid,
440                |completed, total| {
441                    progress.report("IWLT baseline", steps_per_run, completed, total)
442                },
443            )?;
444            progress.finish_stage(lambda_count);
445
446            progress.stage_start("IWLT perturbed", steps_per_run, lambda_count);
447            let perturbed = iwlt::run_iwlt_sweep_perturbed_with_progress(
448                &run_config,
449                &lambda_grid,
450                |completed, total| {
451                    progress.report("IWLT perturbed", steps_per_run, completed, total)
452                },
453            )?;
454            progress.finish_stage(lambda_count);
455
456            write_iwlt_csv(
457                &output_dir.join(format!("iwlt_sweep{suffix}.csv")),
458                &lambda_grid,
459                &baseline.entropy_density,
460                &baseline.avg_increment,
461                steps_per_run,
462                false,
463            )?;
464            write_iwlt_csv(
465                &output_dir.join(format!("iwlt_sweep_perturbed{suffix}.csv")),
466                &lambda_grid,
467                &perturbed.entropy_density,
468                &perturbed.avg_increment,
469                steps_per_run,
470                true,
471            )?;
472
473            if use_step_suffix && is_canonical {
474                write_iwlt_csv(
475                    &output_dir.join("iwlt_sweep.csv"),
476                    &lambda_grid,
477                    &baseline.entropy_density,
478                    &baseline.avg_increment,
479                    steps_per_run,
480                    false,
481                )?;
482                write_iwlt_csv(
483                    &output_dir.join("iwlt_sweep_perturbed.csv"),
484                    &lambda_grid,
485                    &perturbed.entropy_density,
486                    &perturbed.avg_increment,
487                    steps_per_run,
488                    true,
489                )?;
490            }
491
492            robustness_rows.push(comparison_metric(
493                "iwlt_curve_l2_diff",
494                steps_per_run,
495                0.0,
496                curve_l2_diff(&baseline.entropy_density, &perturbed.entropy_density),
497            ));
498            robustness_rows.push(comparison_metric(
499                "iwlt_curve_max_abs_diff",
500                steps_per_run,
501                0.0,
502                curve_max_abs_diff(&baseline.entropy_density, &perturbed.entropy_density),
503            ));
504
505            if is_canonical {
506                canonical_iwlt = Some(baseline.clone());
507            }
508
509            (Some(baseline), Some(perturbed))
510        } else {
511            (None, None)
512        };
513
514        if let (Some(aet_baseline), Some(iwlt_baseline)) = (&aet, &iwlt) {
515            let baseline_fit =
516                fit_with_ci(&aet_baseline.echo_slope, &iwlt_baseline.entropy_density)?;
517            let baseline_diag = diagnostics_from_fit(
518                &aet_baseline.echo_slope,
519                &iwlt_baseline.entropy_density,
520                &baseline_fit,
521            )?;
522            let baseline_row = law_summary_row(steps_per_run, false, baseline_fit, baseline_diag);
523            law_rows.push(baseline_row.clone());
524            scaling_rows.push(baseline_row);
525            diagnostics_rows.push(DiagnosticsSummaryRow {
526                steps_per_run,
527                residual_mean: baseline_diag.residual_mean,
528                residual_std: baseline_diag.residual_std,
529                residual_skew_approx: baseline_diag.residual_skew_approx,
530                residual_kurtosis_approx: baseline_diag.residual_kurtosis_approx,
531                ratio_mean: baseline_diag.ratio_mean,
532                ratio_std: baseline_diag.ratio_std,
533                ratio_min: baseline_diag.ratio_min,
534                ratio_max: baseline_diag.ratio_max,
535            });
536
537            if let Some(phase) = baseline_phase {
538                if let Some(phase_index) = closest_lambda_index(&lambda_grid, phase.lambda_star) {
539                    threshold_rows.push(CrossLayerThresholdRow {
540                        steps_per_run,
541                        lambda_star: phase.lambda_star,
542                        echo_slope_star: Some(aet_baseline.echo_slope[phase_index]),
543                        entropy_density_star: Some(iwlt_baseline.entropy_density[phase_index]),
544                    });
545                }
546            }
547
548            if let (Some(aet_perturbed_sweep), Some(iwlt_perturbed_sweep)) =
549                (&aet_perturbed, &iwlt_perturbed)
550            {
551                let perturbed_fit = fit_with_ci(
552                    &aet_perturbed_sweep.echo_slope,
553                    &iwlt_perturbed_sweep.entropy_density,
554                )?;
555                let perturbed_diag = diagnostics_from_fit(
556                    &aet_perturbed_sweep.echo_slope,
557                    &iwlt_perturbed_sweep.entropy_density,
558                    &perturbed_fit,
559                )?;
560                law_rows.push(law_summary_row(
561                    steps_per_run,
562                    true,
563                    perturbed_fit,
564                    perturbed_diag,
565                ));
566
567                robustness_rows.push(comparison_metric(
568                    "structural_law_slope",
569                    steps_per_run,
570                    baseline_fit.slope,
571                    perturbed_fit.slope,
572                ));
573                robustness_rows.push(comparison_metric(
574                    "structural_law_intercept",
575                    steps_per_run,
576                    baseline_fit.intercept,
577                    perturbed_fit.intercept,
578                ));
579                robustness_rows.push(comparison_metric(
580                    "structural_law_r2",
581                    steps_per_run,
582                    baseline_fit.r2,
583                    perturbed_fit.r2,
584                ));
585                robustness_rows.push(comparison_metric(
586                    "structural_law_residual_variance",
587                    steps_per_run,
588                    baseline_fit.residual_variance,
589                    perturbed_fit.residual_variance,
590                ));
591            }
592        }
593
594        if let (Some(tcp_baseline), Some(phase)) = (&tcp, baseline_phase) {
595            tcp_alignment_rows.push(tcp_phase_alignment_row(
596                steps_per_run,
597                phase.lambda_star,
598                peak_lambda(&lambda_grid, &tcp_baseline.l_tcp),
599                peak_lambda_usize(&lambda_grid, &tcp_baseline.betti1),
600            ));
601        }
602
603        let _ = rlt_perturbed;
604        let _ = perturbed_phase;
605
606        runs.push(SweepRunResult {
607            steps_per_run,
608            aet,
609            tcp,
610            rlt,
611            iwlt,
612        });
613    }
614
615    if !phase_rows.is_empty() {
616        write_rlt_phase_boundary_csv(&output_dir.join("rlt_phase_boundary.csv"), &phase_rows)?;
617    }
618    if !law_rows.is_empty() {
619        write_structural_law_summary_csv(&output_dir.join("aet_iwlt_law_summary.csv"), &law_rows)?;
620    }
621    if !scaling_rows.is_empty() {
622        write_structural_law_summary_csv(
623            &output_dir.join("aet_iwlt_scaling_summary.csv"),
624            &scaling_rows,
625        )?;
626    }
627    if !diagnostics_rows.is_empty() {
628        write_diagnostics_summary_csv(
629            &output_dir.join("aet_iwlt_diagnostics_summary.csv"),
630            &diagnostics_rows,
631        )?;
632    }
633    if !threshold_rows.is_empty() {
634        write_cross_layer_thresholds_csv(
635            &output_dir.join("cross_layer_thresholds.csv"),
636            &threshold_rows,
637        )?;
638    }
639    if !tcp_alignment_rows.is_empty() {
640        write_tcp_phase_alignment_csv(
641            &output_dir.join("tcp_phase_alignment.csv"),
642            &tcp_alignment_rows,
643        )?;
644    }
645    if !robustness_rows.is_empty() {
646        write_robustness_metrics_csv(&output_dir.join("robustness_metrics.csv"), &robustness_rows)?;
647    }
648
649    progress.finish_all();
650
651    Ok(SweepResult {
652        output_dir: output_dir.to_path_buf(),
653        lambda_grid,
654        runs,
655        aet: canonical_aet,
656        tcp: canonical_tcp,
657        rlt: canonical_rlt,
658        iwlt: canonical_iwlt,
659    })
660}
661
662fn canonical_steps(config: &SimulationConfig, sweep_steps: &[usize]) -> usize {
663    if sweep_steps.contains(&config.steps_per_run) {
664        config.steps_per_run
665    } else {
666        sweep_steps[0]
667    }
668}
669
670fn total_progress_units(
671    config: &SimulationConfig,
672    sweep_step_count: usize,
673    lambda_count: usize,
674) -> usize {
675    let stage_count = usize::from(config.enable_aet) * 2
676        + usize::from(config.enable_tcp)
677        + usize::from(config.enable_rlt) * 2
678        + usize::from(config.enable_iwlt) * 2;
679    stage_count * sweep_step_count * lambda_count
680}
681
/// Directories that receive TCP point-cloud CSVs for one horizon.
///
/// With a single horizon this is just `tcp_points`. With several horizons it is
/// `tcp_points_N{steps_per_run}`, followed by `tcp_points` when this is the canonical horizon.
fn points_dirs(
    output_dir: &Path,
    steps_per_run: usize,
    use_step_suffix: bool,
    is_canonical: bool,
) -> Vec<PathBuf> {
    let unsuffixed = output_dir.join("tcp_points");
    if !use_step_suffix {
        return vec![unsuffixed];
    }

    let suffixed = output_dir.join(format!("tcp_points_N{steps_per_run}"));
    if is_canonical {
        vec![suffixed, unsuffixed]
    } else {
        vec![suffixed]
    }
}
699
/// Directories that receive representative RLT trajectory CSVs for one horizon.
///
/// With a single horizon this is just `rlt_examples`. With several horizons it is
/// `rlt_examples_N{steps_per_run}`, followed by `rlt_examples` when this is the canonical
/// horizon.
fn example_dirs(
    output_dir: &Path,
    steps_per_run: usize,
    use_step_suffix: bool,
    is_canonical: bool,
) -> Vec<PathBuf> {
    let unsuffixed = output_dir.join("rlt_examples");
    if !use_step_suffix {
        return vec![unsuffixed];
    }

    let suffixed = output_dir.join(format!("rlt_examples_N{steps_per_run}"));
    if is_canonical {
        vec![suffixed, unsuffixed]
    } else {
        vec![suffixed]
    }
}
717
718fn phase_row(
719    mode: &str,
720    is_perturbed: bool,
721    steps_per_run: usize,
722    summary: RltPhaseBoundary,
723) -> PhaseBoundaryRow {
724    PhaseBoundaryRow {
725        steps_per_run,
726        mode: mode.to_string(),
727        is_perturbed,
728        lambda_star: summary.lambda_star,
729        lambda_0_1: summary.lambda_0_1,
730        lambda_0_9: summary.lambda_0_9,
731        transition_width: summary.transition_width,
732        max_derivative: summary.max_derivative,
733    }
734}
735
736fn law_summary_row(
737    steps_per_run: usize,
738    is_perturbed: bool,
739    fit: LinearFit,
740    diagnostics: crate::analysis::structural_law::StructuralLawDiagnostics,
741) -> StructuralLawSummaryRow {
742    StructuralLawSummaryRow {
743        steps_per_run,
744        is_perturbed,
745        pearson_r: fit.pearson_r,
746        spearman_rho: fit.spearman_rho,
747        slope: fit.slope,
748        intercept: fit.intercept,
749        r2: fit.r2,
750        residual_variance: fit.residual_variance,
751        mse_resid: fit.mse_resid,
752        slope_ci_low: fit.slope_ci_low,
753        slope_ci_high: fit.slope_ci_high,
754        sample_count: fit.sample_count,
755        ratio_mean: diagnostics.ratio_mean,
756        ratio_std: diagnostics.ratio_std,
757    }
758}
759
/// Returns the index of the grid lambda nearest to `target`.
///
/// Returns `None` in any of these cases:
/// - `target` is `None` or NaN;
/// - the grid is empty;
/// - no grid entry has a finite distance to `target`.
///
/// NaN distances are skipped rather than treated as ties, so a NaN grid entry can never be
/// reported as the closest point. On exact ties the earliest index wins.
fn closest_lambda_index(lambda_grid: &[f64], target: Option<f64>) -> Option<usize> {
    let target = target?;
    lambda_grid
        .iter()
        .map(|lambda| (lambda - target).abs())
        .enumerate()
        .filter(|(_, delta)| !delta.is_nan())
        // NaNs were filtered out above, so the `Equal` fallback is never taken.
        .min_by(|(_, left), (_, right)| {
            left.partial_cmp(right).unwrap_or(std::cmp::Ordering::Equal)
        })
        .map(|(idx, _)| idx)
}
774
/// Returns the lambda at which `values` is largest, pairing entries by index.
///
/// NaN values are ignored. Otherwise a NaN would compare as `Equal` and silently replace the
/// running maximum. On ties the last maximal entry wins, matching `peak_lambda_usize`. Returns
/// `None` when there is no non-NaN value to compare.
fn peak_lambda(lambda_grid: &[f64], values: &[f64]) -> Option<f64> {
    lambda_grid
        .iter()
        .zip(values.iter())
        .filter(|(_, value)| !value.is_nan())
        // NaNs were filtered out above, so the `Equal` fallback is never taken.
        .max_by(|(_, left), (_, right)| {
            left.partial_cmp(right).unwrap_or(std::cmp::Ordering::Equal)
        })
        .map(|(lambda, _)| *lambda)
}
784
/// Returns the lambda at which the integer-valued `values` is largest.
///
/// Entries are paired with the grid by index, and extra entries on either side are ignored.
/// On ties the last maximal entry wins. Returns `None` when there is nothing to compare.
fn peak_lambda_usize(lambda_grid: &[f64], values: &[usize]) -> Option<f64> {
    let mut best: Option<(f64, usize)> = None;
    for (&lambda, &value) in lambda_grid.iter().zip(values) {
        // `>=` lets a later equal value replace the current best (last-max-wins).
        let replaces = match best {
            Some((_, current)) => value >= current,
            None => true,
        };
        if replaces {
            best = Some((lambda, value));
        }
    }
    best.map(|(lambda, _)| lambda)
}
792
793fn tcp_phase_alignment_row(
794    steps_per_run: usize,
795    lambda_star: Option<f64>,
796    lambda_tp_peak: Option<f64>,
797    lambda_b1_peak: Option<f64>,
798) -> TcpPhaseAlignmentRow {
799    TcpPhaseAlignmentRow {
800        steps_per_run,
801        lambda_star,
802        lambda_tp_peak,
803        lambda_b1_peak,
804        delta_tp: option_diff(lambda_star, lambda_tp_peak),
805        delta_b1: option_diff(lambda_star, lambda_b1_peak),
806    }
807}
808
809fn comparison_metric(
810    metric: &str,
811    steps_per_run: usize,
812    baseline: f64,
813    perturbed: f64,
814) -> RobustnessMetricRow {
815    RobustnessMetricRow {
816        metric: metric.to_string(),
817        steps_per_run,
818        baseline,
819        perturbed,
820        delta: perturbed - baseline,
821    }
822}
823
824fn comparison_metric_option(
825    metric: &str,
826    steps_per_run: usize,
827    baseline: Option<f64>,
828    perturbed: Option<f64>,
829) -> RobustnessMetricRow {
830    comparison_metric(
831        metric,
832        steps_per_run,
833        baseline.unwrap_or(f64::NAN),
834        perturbed.unwrap_or(f64::NAN),
835    )
836}
837
/// Returns `left - right` when both are present, otherwise `None`.
fn option_diff(left: Option<f64>, right: Option<f64>) -> Option<f64> {
    left.zip(right).map(|(minuend, subtrahend)| minuend - subtrahend)
}
844
/// Euclidean (L2) distance between two curves sampled on the same grid.
///
/// Points are paired by index; any extra points on the longer curve are ignored.
fn curve_l2_diff(baseline: &[f64], perturbed: &[f64]) -> f64 {
    let sum_of_squares: f64 = baseline
        .iter()
        .zip(perturbed)
        .map(|(base, other)| {
            let gap = other - base;
            gap * gap
        })
        .sum();
    sum_of_squares.sqrt()
}
856
/// Largest absolute pointwise difference between two curves, paired by index.
///
/// Returns 0.0 when no points overlap. Uses `f64::max`, so NaN differences are skipped rather
/// than propagated.
fn curve_max_abs_diff(baseline: &[f64], perturbed: &[f64]) -> f64 {
    let mut largest = 0.0_f64;
    for (base, other) in baseline.iter().zip(perturbed) {
        largest = largest.max((other - base).abs());
    }
    largest
}