1use serde::Serialize;
2
3use crate::config::DemoConfig;
4use crate::dsfb::run_profiled_taa;
5use crate::error::{Error, Result};
6use crate::frame::{mean_abs_error_over_mask, ImageFrame};
7use crate::host::{
8 default_host_realistic_profile, motion_augmented_profile, HostSupervisionProfile,
9};
10use crate::scene::{
11 generate_sequence_for_definition, scenario_by_id, ScenarioExpectation, ScenarioId,
12 SceneSequence,
13};
14use crate::taa::run_fixed_alpha_baseline;
15
/// Scenarios evaluated for every profile in the sensitivity study.
///
/// `evaluate_profile` walks this list in order, so results are recorded in this order.
/// Sweep-point summaries later look results up by scenario string id (presumably the
/// `ScenarioId::as_str()` values such as `"thin_reveal"` — confirm), so removing an entry
/// here makes those lookups fail.
const SENSITIVITY_SCENARIOS: &[ScenarioId] = &[
    ScenarioId::ThinReveal,
    ScenarioId::RevealBand,
    ScenarioId::MotionBiasBand,
    ScenarioId::ContrastPulse,
];
22
/// Summary of one profile evaluated during a one-at-a-time parameter sweep.
///
/// Field order is preserved because it determines the serialized field order.
#[derive(Clone, Debug, Serialize)]
pub struct ParameterSweepPoint {
    /// Identifier of the swept knob, e.g. `"depth_weight"` or `"alpha_min"`.
    pub parameter_id: String,
    /// Profile family that was modified: `"host_realistic"` or `"motion_augmented"`.
    pub profile_mode: String,
    /// Human-readable label formatted as `"{parameter_id}={numeric_value:.3}"`.
    pub setting_label: String,
    /// Swept value: a scale factor for weight/threshold sweeps, an absolute bound for
    /// the alpha sweeps.
    pub numeric_value: f32,
    /// Number of benefit-expected scenarios whose cumulative ROI MAE beat the fixed-alpha
    /// baseline.
    pub benefit_scenarios_beating_fixed: usize,
    /// Number of benefit-expected scenarios with no ghost-persistence frames.
    pub benefit_scenarios_with_zero_ghost_frames: usize,
    /// Cumulative ROI MAE on the `thin_reveal` scenario.
    pub canonical_cumulative_roi_mae: f32,
    /// Mean cumulative ROI MAE over the `reveal_band` and `motion_bias_band` scenarios.
    pub region_mean_cumulative_roi_mae: f32,
    /// Cumulative ROI MAE on the `motion_bias_band` scenario.
    pub motion_bias_cumulative_roi_mae: f32,
    /// Average per-frame non-ROI MAE on the `contrast_pulse` scenario.
    pub neutral_non_roi_mae: f32,
    /// True when the point meets the strictest ("robust") tolerances relative to the
    /// baseline evaluation.
    pub robust_corridor_member: bool,
    /// One of `"robust"`, `"moderately_sensitive"` or `"fragile"`.
    pub robustness_class: String,
}
38
/// Result of `run_parameter_sensitivity_study`.
#[derive(Clone, Debug, Serialize)]
pub struct ParameterSensitivityMetrics {
    /// Label of the host-realistic profile that every sweep point is judged against.
    pub baseline_mode: String,
    /// All evaluated sweep points, in the order they were generated.
    pub sweep_points: Vec<ParameterSweepPoint>,
    /// Free-text caveats describing how the sweeps should be interpreted.
    pub notes: Vec<String>,
}
45
/// Per-scenario metrics for one profiled run, compared against the fixed-alpha baseline.
#[derive(Clone, Copy)]
struct ScenarioEval {
    /// Expectation copied from the generated scene sequence.
    expectation: ScenarioExpectation,
    /// Sum over all frames of the ROI mean absolute error of the profiled run.
    cumulative_roi_mae: f32,
    /// Mean over frames of the non-ROI mean absolute error of the profiled run.
    average_non_roi_mae: f32,
    /// Frames at or after the onset whose ROI MAE exceeded the persistence threshold.
    ghost_persistence_frames: usize,
    /// Whether the profiled cumulative ROI MAE is lower than the fixed-alpha one by more
    /// than `1.0e-6`.
    beat_fixed: bool,
}
54
55pub fn run_parameter_sensitivity_study(config: &DemoConfig) -> Result<ParameterSensitivityMetrics> {
56 let baseline_profile =
57 default_host_realistic_profile(config.dsfb_alpha_range.min, config.dsfb_alpha_range.max);
58 let baseline_eval = evaluate_profile(config, &baseline_profile)?;
59
60 let mut sweep_points = Vec::new();
61 for factor in [0.5f32, 0.75, 1.0, 1.25, 1.5] {
62 let mut profile = baseline_profile.clone();
63 profile.parameters.weights.depth *= factor;
64 sweep_points.push(build_sweep_point(
65 config,
66 "depth_weight",
67 "host_realistic",
68 factor,
69 &profile,
70 &baseline_eval,
71 )?);
72
73 let mut profile = baseline_profile.clone();
74 profile.parameters.weights.thin *= factor;
75 sweep_points.push(build_sweep_point(
76 config,
77 "thin_weight",
78 "host_realistic",
79 factor,
80 &profile,
81 &baseline_eval,
82 )?);
83
84 let mut profile = baseline_profile.clone();
85 profile.parameters.weights.grammar *= factor;
86 sweep_points.push(build_sweep_point(
87 config,
88 "grammar_weight",
89 "host_realistic",
90 factor,
91 &profile,
92 &baseline_eval,
93 )?);
94
95 let mut profile = baseline_profile.clone();
96 profile.parameters.thresholds.residual.low *= factor;
97 profile.parameters.thresholds.residual.high *= factor;
98 sweep_points.push(build_sweep_point(
99 config,
100 "residual_threshold_scale",
101 "host_realistic",
102 factor,
103 &profile,
104 &baseline_eval,
105 )?);
106
107 let mut motion_profile =
108 motion_augmented_profile(config.dsfb_alpha_range.min, config.dsfb_alpha_range.max);
109 motion_profile.parameters.weights.motion *= factor;
110 sweep_points.push(build_sweep_point(
111 config,
112 "motion_weight",
113 "motion_augmented",
114 factor,
115 &motion_profile,
116 &baseline_eval,
117 )?);
118 }
119
120 for alpha_min in [0.04f32, config.dsfb_alpha_range.min, 0.12f32] {
121 let mut profile = baseline_profile.clone();
122 profile.parameters.alpha_range.min = alpha_min;
123 sweep_points.push(build_sweep_point(
124 config,
125 "alpha_min",
126 "host_realistic",
127 alpha_min,
128 &profile,
129 &baseline_eval,
130 )?);
131 }
132
133 for alpha_max in [0.84f32, config.dsfb_alpha_range.max, 0.99f32] {
134 let mut profile = baseline_profile.clone();
135 profile.parameters.alpha_range.max = alpha_max;
136 sweep_points.push(build_sweep_point(
137 config,
138 "alpha_max",
139 "host_realistic",
140 alpha_max,
141 &profile,
142 &baseline_eval,
143 )?);
144 }
145
146 Ok(ParameterSensitivityMetrics {
147 baseline_mode: baseline_profile.label,
148 sweep_points,
149 notes: vec![
150 "These sweeps are one-at-a-time sensitivity checks around the centralized hand-set parameterization. They are intended to show robustness corridors, not to overclaim a global optimum.".to_string(),
151 "The motion-weight sweep uses the optional motion-augmented profile because the minimum host-realistic path no longer includes motion disagreement by default.".to_string(),
152 ],
153 })
154}
155
156fn build_sweep_point(
157 config: &DemoConfig,
158 parameter_id: &str,
159 profile_mode: &str,
160 numeric_value: f32,
161 profile: &HostSupervisionProfile,
162 baseline_eval: &[(&'static str, ScenarioEval)],
163) -> Result<ParameterSweepPoint> {
164 let current = evaluate_profile(config, profile)?;
165 let baseline_motion = scenario_metric(baseline_eval, "motion_bias_band")?;
166 let baseline_neutral = scenario_metric(baseline_eval, "contrast_pulse")?;
167 let motion = scenario_metric(¤t, "motion_bias_band")?;
168 let neutral = scenario_metric(¤t, "contrast_pulse")?;
169
170 let benefit_scenarios_beating_fixed = current
171 .iter()
172 .filter(|(_, metric)| {
173 matches!(metric.expectation, ScenarioExpectation::BenefitExpected) && metric.beat_fixed
174 })
175 .count();
176 let benefit_scenarios_with_zero_ghost_frames = current
177 .iter()
178 .filter(|(_, metric)| {
179 matches!(metric.expectation, ScenarioExpectation::BenefitExpected)
180 && metric.ghost_persistence_frames == 0
181 })
182 .count();
183
184 Ok(ParameterSweepPoint {
185 parameter_id: parameter_id.to_string(),
186 profile_mode: profile_mode.to_string(),
187 setting_label: format!("{parameter_id}={numeric_value:.3}"),
188 numeric_value,
189 benefit_scenarios_beating_fixed,
190 benefit_scenarios_with_zero_ghost_frames,
191 canonical_cumulative_roi_mae: scenario_metric(¤t, "thin_reveal")?.cumulative_roi_mae,
192 region_mean_cumulative_roi_mae: mean_region_roi_mae(¤t),
193 motion_bias_cumulative_roi_mae: motion.cumulative_roi_mae,
194 neutral_non_roi_mae: neutral.average_non_roi_mae,
195 robust_corridor_member: benefit_scenarios_beating_fixed >= 2
196 && motion.cumulative_roi_mae <= baseline_motion.cumulative_roi_mae * 1.20
197 && neutral.average_non_roi_mae <= baseline_neutral.average_non_roi_mae * 1.25,
198 robustness_class: classify_robustness(
199 benefit_scenarios_beating_fixed,
200 motion.cumulative_roi_mae,
201 baseline_motion.cumulative_roi_mae,
202 neutral.average_non_roi_mae,
203 baseline_neutral.average_non_roi_mae,
204 )
205 .to_string(),
206 })
207}
208
209fn evaluate_profile(
210 config: &DemoConfig,
211 profile: &HostSupervisionProfile,
212) -> Result<Vec<(&'static str, ScenarioEval)>> {
213 let mut results = Vec::new();
214 for scenario_id in SENSITIVITY_SCENARIOS {
215 let definition = scenario_by_id(&config.scene, *scenario_id).ok_or_else(|| {
216 Error::Message(format!(
217 "parameter sensitivity scenario {} unavailable",
218 scenario_id.as_str()
219 ))
220 })?;
221 let sequence = generate_sequence_for_definition(&definition);
222 let fixed = run_fixed_alpha_baseline(&sequence, config.baseline.fixed_alpha);
223 let profiled = run_profiled_taa(&sequence, profile);
224 let fixed_metric = evaluate_run(&sequence, &fixed.taa.resolved_frames);
225 let profiled_metric = evaluate_run(&sequence, &profiled.resolved_frames);
226 results.push((
227 scenario_id.as_str(),
228 ScenarioEval {
229 expectation: sequence.expectation,
230 cumulative_roi_mae: profiled_metric.cumulative_roi_mae,
231 average_non_roi_mae: profiled_metric.average_non_roi_mae,
232 ghost_persistence_frames: profiled_metric.ghost_persistence_frames,
233 beat_fixed: profiled_metric.cumulative_roi_mae + 1.0e-6
234 < fixed_metric.cumulative_roi_mae,
235 },
236 ));
237 }
238 Ok(results)
239}
240
/// Error metrics for one resolved frame sequence against its ground truth.
#[derive(Clone, Copy)]
struct RunEval {
    /// Sum over all frames of the ROI mean absolute error.
    cumulative_roi_mae: f32,
    /// Mean over frames of the non-ROI mean absolute error.
    average_non_roi_mae: f32,
    /// Frames at or after the (clamped) onset whose ROI MAE exceeded the persistence threshold.
    ghost_persistence_frames: usize,
}
247
248fn evaluate_run(sequence: &SceneSequence, resolved_frames: &[ImageFrame]) -> RunEval {
249 let target_mask = &sequence.target_mask;
250 let non_roi_mask = target_mask.iter().map(|value| !value).collect::<Vec<_>>();
251 let onset = sequence
252 .onset_frame
253 .min(sequence.frames.len().saturating_sub(1));
254 let threshold = persistence_threshold(sequence);
255 let mut cumulative_roi_mae = 0.0;
256 let mut average_non_roi_mae = 0.0;
257 let mut ghost_persistence_frames = 0usize;
258
259 for frame_index in 0..sequence.frames.len() {
260 let gt = &sequence.frames[frame_index].ground_truth;
261 let resolved = &resolved_frames[frame_index];
262 let roi_mae = mean_abs_error_over_mask(resolved, gt, target_mask);
263 let non_roi_mae = mean_abs_error_over_mask(resolved, gt, &non_roi_mask);
264 cumulative_roi_mae += roi_mae;
265 average_non_roi_mae += non_roi_mae;
266 if frame_index >= onset && roi_mae > threshold {
267 ghost_persistence_frames += 1;
268 }
269 }
270
271 RunEval {
272 cumulative_roi_mae,
273 average_non_roi_mae: average_non_roi_mae / sequence.frames.len().max(1) as f32,
274 ghost_persistence_frames,
275 }
276}
277
278fn persistence_threshold(sequence: &SceneSequence) -> f32 {
279 if sequence.onset_frame == 0 {
280 return 0.02;
281 }
282 let previous = &sequence.frames[sequence.onset_frame - 1].ground_truth;
283 let current = &sequence.frames[sequence.onset_frame].ground_truth;
284 (mean_abs_error_over_mask(previous, current, &sequence.target_mask) * 0.15).max(0.02)
285}
286
287fn scenario_metric<'a>(
288 values: &'a [(&'static str, ScenarioEval)],
289 scenario_id: &str,
290) -> Result<&'a ScenarioEval> {
291 values
292 .iter()
293 .find(|(current, _)| *current == scenario_id)
294 .map(|(_, metric)| metric)
295 .ok_or_else(|| Error::Message(format!("sensitivity metric {scenario_id} missing")))
296}
297
298fn mean_region_roi_mae(values: &[(&'static str, ScenarioEval)]) -> f32 {
299 let mut sum = 0.0;
300 let mut count = 0usize;
301 for (scenario_id, metric) in values {
302 if matches!(*scenario_id, "reveal_band" | "motion_bias_band") {
303 sum += metric.cumulative_roi_mae;
304 count += 1;
305 }
306 }
307 if count == 0 {
308 0.0
309 } else {
310 sum / count as f32
311 }
312}
313
/// Buckets a sweep point by how far it degrades relative to the baseline.
///
/// Fewer than two benefit scenarios beating the fixed baseline is always `"fragile"`.
/// Otherwise the point is `"robust"` when the motion ROI error stays within 1.20x and the
/// neutral non-ROI error within 1.25x of the baseline. It is `"moderately_sensitive"`
/// within 1.35x / 1.40x, and `"fragile"` beyond that. NaN errors fail every check.
fn classify_robustness(
    benefit_scenarios_beating_fixed: usize,
    motion_roi_mae: f32,
    baseline_motion_roi_mae: f32,
    neutral_non_roi_mae: f32,
    baseline_neutral_non_roi_mae: f32,
) -> &'static str {
    if benefit_scenarios_beating_fixed < 2 {
        return "fragile";
    }

    let within = |motion_tolerance: f32, neutral_tolerance: f32| {
        motion_roi_mae <= baseline_motion_roi_mae * motion_tolerance
            && neutral_non_roi_mae <= baseline_neutral_non_roi_mae * neutral_tolerance
    };

    if within(1.20, 1.25) {
        "robust"
    } else if within(1.35, 1.40) {
        "moderately_sensitive"
    } else {
        "fragile"
    }
}