Skip to main content

dsfb_computer_graphics/
sensitivity.rs

1use serde::Serialize;
2
3use crate::config::DemoConfig;
4use crate::dsfb::run_profiled_taa;
5use crate::error::{Error, Result};
6use crate::frame::{mean_abs_error_over_mask, ImageFrame};
7use crate::host::{
8    default_host_realistic_profile, motion_augmented_profile, HostSupervisionProfile,
9};
10use crate::scene::{
11    generate_sequence_for_definition, scenario_by_id, ScenarioExpectation, ScenarioId,
12    SceneSequence,
13};
14use crate::taa::run_fixed_alpha_baseline;
15
16const SENSITIVITY_SCENARIOS: &[ScenarioId] = &[
17    ScenarioId::ThinReveal,
18    ScenarioId::RevealBand,
19    ScenarioId::MotionBiasBand,
20    ScenarioId::ContrastPulse,
21];
22
23#[derive(Clone, Debug, Serialize)]
24pub struct ParameterSweepPoint {
25    pub parameter_id: String,
26    pub profile_mode: String,
27    pub setting_label: String,
28    pub numeric_value: f32,
29    pub benefit_scenarios_beating_fixed: usize,
30    pub benefit_scenarios_with_zero_ghost_frames: usize,
31    pub canonical_cumulative_roi_mae: f32,
32    pub region_mean_cumulative_roi_mae: f32,
33    pub motion_bias_cumulative_roi_mae: f32,
34    pub neutral_non_roi_mae: f32,
35    pub robust_corridor_member: bool,
36    pub robustness_class: String,
37}
38
39#[derive(Clone, Debug, Serialize)]
40pub struct ParameterSensitivityMetrics {
41    pub baseline_mode: String,
42    pub sweep_points: Vec<ParameterSweepPoint>,
43    pub notes: Vec<String>,
44}
45
46#[derive(Clone, Copy)]
47struct ScenarioEval {
48    expectation: ScenarioExpectation,
49    cumulative_roi_mae: f32,
50    average_non_roi_mae: f32,
51    ghost_persistence_frames: usize,
52    beat_fixed: bool,
53}
54
55pub fn run_parameter_sensitivity_study(config: &DemoConfig) -> Result<ParameterSensitivityMetrics> {
56    let baseline_profile =
57        default_host_realistic_profile(config.dsfb_alpha_range.min, config.dsfb_alpha_range.max);
58    let baseline_eval = evaluate_profile(config, &baseline_profile)?;
59
60    let mut sweep_points = Vec::new();
61    for factor in [0.5f32, 0.75, 1.0, 1.25, 1.5] {
62        let mut profile = baseline_profile.clone();
63        profile.parameters.weights.depth *= factor;
64        sweep_points.push(build_sweep_point(
65            config,
66            "depth_weight",
67            "host_realistic",
68            factor,
69            &profile,
70            &baseline_eval,
71        )?);
72
73        let mut profile = baseline_profile.clone();
74        profile.parameters.weights.thin *= factor;
75        sweep_points.push(build_sweep_point(
76            config,
77            "thin_weight",
78            "host_realistic",
79            factor,
80            &profile,
81            &baseline_eval,
82        )?);
83
84        let mut profile = baseline_profile.clone();
85        profile.parameters.weights.grammar *= factor;
86        sweep_points.push(build_sweep_point(
87            config,
88            "grammar_weight",
89            "host_realistic",
90            factor,
91            &profile,
92            &baseline_eval,
93        )?);
94
95        let mut profile = baseline_profile.clone();
96        profile.parameters.thresholds.residual.low *= factor;
97        profile.parameters.thresholds.residual.high *= factor;
98        sweep_points.push(build_sweep_point(
99            config,
100            "residual_threshold_scale",
101            "host_realistic",
102            factor,
103            &profile,
104            &baseline_eval,
105        )?);
106
107        let mut motion_profile =
108            motion_augmented_profile(config.dsfb_alpha_range.min, config.dsfb_alpha_range.max);
109        motion_profile.parameters.weights.motion *= factor;
110        sweep_points.push(build_sweep_point(
111            config,
112            "motion_weight",
113            "motion_augmented",
114            factor,
115            &motion_profile,
116            &baseline_eval,
117        )?);
118    }
119
120    for alpha_min in [0.04f32, config.dsfb_alpha_range.min, 0.12f32] {
121        let mut profile = baseline_profile.clone();
122        profile.parameters.alpha_range.min = alpha_min;
123        sweep_points.push(build_sweep_point(
124            config,
125            "alpha_min",
126            "host_realistic",
127            alpha_min,
128            &profile,
129            &baseline_eval,
130        )?);
131    }
132
133    for alpha_max in [0.84f32, config.dsfb_alpha_range.max, 0.99f32] {
134        let mut profile = baseline_profile.clone();
135        profile.parameters.alpha_range.max = alpha_max;
136        sweep_points.push(build_sweep_point(
137            config,
138            "alpha_max",
139            "host_realistic",
140            alpha_max,
141            &profile,
142            &baseline_eval,
143        )?);
144    }
145
146    Ok(ParameterSensitivityMetrics {
147        baseline_mode: baseline_profile.label,
148        sweep_points,
149        notes: vec![
150            "These sweeps are one-at-a-time sensitivity checks around the centralized hand-set parameterization. They are intended to show robustness corridors, not to overclaim a global optimum.".to_string(),
151            "The motion-weight sweep uses the optional motion-augmented profile because the minimum host-realistic path no longer includes motion disagreement by default.".to_string(),
152        ],
153    })
154}
155
156fn build_sweep_point(
157    config: &DemoConfig,
158    parameter_id: &str,
159    profile_mode: &str,
160    numeric_value: f32,
161    profile: &HostSupervisionProfile,
162    baseline_eval: &[(&'static str, ScenarioEval)],
163) -> Result<ParameterSweepPoint> {
164    let current = evaluate_profile(config, profile)?;
165    let baseline_motion = scenario_metric(baseline_eval, "motion_bias_band")?;
166    let baseline_neutral = scenario_metric(baseline_eval, "contrast_pulse")?;
167    let motion = scenario_metric(&current, "motion_bias_band")?;
168    let neutral = scenario_metric(&current, "contrast_pulse")?;
169
170    let benefit_scenarios_beating_fixed = current
171        .iter()
172        .filter(|(_, metric)| {
173            matches!(metric.expectation, ScenarioExpectation::BenefitExpected) && metric.beat_fixed
174        })
175        .count();
176    let benefit_scenarios_with_zero_ghost_frames = current
177        .iter()
178        .filter(|(_, metric)| {
179            matches!(metric.expectation, ScenarioExpectation::BenefitExpected)
180                && metric.ghost_persistence_frames == 0
181        })
182        .count();
183
184    Ok(ParameterSweepPoint {
185        parameter_id: parameter_id.to_string(),
186        profile_mode: profile_mode.to_string(),
187        setting_label: format!("{parameter_id}={numeric_value:.3}"),
188        numeric_value,
189        benefit_scenarios_beating_fixed,
190        benefit_scenarios_with_zero_ghost_frames,
191        canonical_cumulative_roi_mae: scenario_metric(&current, "thin_reveal")?.cumulative_roi_mae,
192        region_mean_cumulative_roi_mae: mean_region_roi_mae(&current),
193        motion_bias_cumulative_roi_mae: motion.cumulative_roi_mae,
194        neutral_non_roi_mae: neutral.average_non_roi_mae,
195        robust_corridor_member: benefit_scenarios_beating_fixed >= 2
196            && motion.cumulative_roi_mae <= baseline_motion.cumulative_roi_mae * 1.20
197            && neutral.average_non_roi_mae <= baseline_neutral.average_non_roi_mae * 1.25,
198        robustness_class: classify_robustness(
199            benefit_scenarios_beating_fixed,
200            motion.cumulative_roi_mae,
201            baseline_motion.cumulative_roi_mae,
202            neutral.average_non_roi_mae,
203            baseline_neutral.average_non_roi_mae,
204        )
205        .to_string(),
206    })
207}
208
209fn evaluate_profile(
210    config: &DemoConfig,
211    profile: &HostSupervisionProfile,
212) -> Result<Vec<(&'static str, ScenarioEval)>> {
213    let mut results = Vec::new();
214    for scenario_id in SENSITIVITY_SCENARIOS {
215        let definition = scenario_by_id(&config.scene, *scenario_id).ok_or_else(|| {
216            Error::Message(format!(
217                "parameter sensitivity scenario {} unavailable",
218                scenario_id.as_str()
219            ))
220        })?;
221        let sequence = generate_sequence_for_definition(&definition);
222        let fixed = run_fixed_alpha_baseline(&sequence, config.baseline.fixed_alpha);
223        let profiled = run_profiled_taa(&sequence, profile);
224        let fixed_metric = evaluate_run(&sequence, &fixed.taa.resolved_frames);
225        let profiled_metric = evaluate_run(&sequence, &profiled.resolved_frames);
226        results.push((
227            scenario_id.as_str(),
228            ScenarioEval {
229                expectation: sequence.expectation,
230                cumulative_roi_mae: profiled_metric.cumulative_roi_mae,
231                average_non_roi_mae: profiled_metric.average_non_roi_mae,
232                ghost_persistence_frames: profiled_metric.ghost_persistence_frames,
233                beat_fixed: profiled_metric.cumulative_roi_mae + 1.0e-6
234                    < fixed_metric.cumulative_roi_mae,
235            },
236        ));
237    }
238    Ok(results)
239}
240
/// Error summary of a single resolved frame sequence against its ground truth.
#[derive(Copy, Clone)]
struct RunEval {
    /// Sum over frames of the mean absolute error inside the target mask.
    cumulative_roi_mae: f32,
    /// Per-frame average of the mean absolute error outside the target mask.
    average_non_roi_mae: f32,
    /// Count of frames at/after onset whose ROI error exceeded the persistence threshold.
    ghost_persistence_frames: usize,
}
247
248fn evaluate_run(sequence: &SceneSequence, resolved_frames: &[ImageFrame]) -> RunEval {
249    let target_mask = &sequence.target_mask;
250    let non_roi_mask = target_mask.iter().map(|value| !value).collect::<Vec<_>>();
251    let onset = sequence
252        .onset_frame
253        .min(sequence.frames.len().saturating_sub(1));
254    let threshold = persistence_threshold(sequence);
255    let mut cumulative_roi_mae = 0.0;
256    let mut average_non_roi_mae = 0.0;
257    let mut ghost_persistence_frames = 0usize;
258
259    for frame_index in 0..sequence.frames.len() {
260        let gt = &sequence.frames[frame_index].ground_truth;
261        let resolved = &resolved_frames[frame_index];
262        let roi_mae = mean_abs_error_over_mask(resolved, gt, target_mask);
263        let non_roi_mae = mean_abs_error_over_mask(resolved, gt, &non_roi_mask);
264        cumulative_roi_mae += roi_mae;
265        average_non_roi_mae += non_roi_mae;
266        if frame_index >= onset && roi_mae > threshold {
267            ghost_persistence_frames += 1;
268        }
269    }
270
271    RunEval {
272        cumulative_roi_mae,
273        average_non_roi_mae: average_non_roi_mae / sequence.frames.len().max(1) as f32,
274        ghost_persistence_frames,
275    }
276}
277
278fn persistence_threshold(sequence: &SceneSequence) -> f32 {
279    if sequence.onset_frame == 0 {
280        return 0.02;
281    }
282    let previous = &sequence.frames[sequence.onset_frame - 1].ground_truth;
283    let current = &sequence.frames[sequence.onset_frame].ground_truth;
284    (mean_abs_error_over_mask(previous, current, &sequence.target_mask) * 0.15).max(0.02)
285}
286
287fn scenario_metric<'a>(
288    values: &'a [(&'static str, ScenarioEval)],
289    scenario_id: &str,
290) -> Result<&'a ScenarioEval> {
291    values
292        .iter()
293        .find(|(current, _)| *current == scenario_id)
294        .map(|(_, metric)| metric)
295        .ok_or_else(|| Error::Message(format!("sensitivity metric {scenario_id} missing")))
296}
297
298fn mean_region_roi_mae(values: &[(&'static str, ScenarioEval)]) -> f32 {
299    let mut sum = 0.0;
300    let mut count = 0usize;
301    for (scenario_id, metric) in values {
302        if matches!(*scenario_id, "reveal_band" | "motion_bias_band") {
303            sum += metric.cumulative_roi_mae;
304            count += 1;
305        }
306    }
307    if count == 0 {
308        0.0
309    } else {
310        sum / count as f32
311    }
312}
313
/// Buckets a sweep point by how far its guard metrics drift from the reference profile.
///
/// Both non-fragile tiers require at least two benefit scenarios to beat the fixed-alpha run.
/// On top of that:
/// * `"robust"`: motion ROI MAE ≤ 1.20× the reference and neutral non-ROI MAE ≤ 1.25× the reference;
/// * `"moderately_sensitive"`: motion ROI MAE ≤ 1.35× and neutral non-ROI MAE ≤ 1.40×;
/// * `"fragile"` otherwise.
///
/// Any comparison involving NaN is false, so a NaN metric classifies as `"fragile"`.
fn classify_robustness(
    benefit_scenarios_beating_fixed: usize,
    motion_roi_mae: f32,
    baseline_motion_roi_mae: f32,
    neutral_non_roi_mae: f32,
    baseline_neutral_non_roi_mae: f32,
) -> &'static str {
    if benefit_scenarios_beating_fixed < 2 {
        return "fragile";
    }

    // True when both guard metrics stay within the given multiples of their reference values.
    let within = |motion_tolerance: f32, neutral_tolerance: f32| {
        motion_roi_mae <= baseline_motion_roi_mae * motion_tolerance
            && neutral_non_roi_mae <= baseline_neutral_non_roi_mae * neutral_tolerance
    };

    if within(1.20, 1.25) {
        "robust"
    } else if within(1.35, 1.40) {
        "moderately_sensitive"
    } else {
        "fragile"
    }
}
334}