// dsfb_computer_graphics/sampling.rs

1use serde::{Deserialize, Serialize};
2
3use crate::config::DemoConfig;
4use crate::dsfb::DsfbRun;
5use crate::error::{Error, Result};
6use crate::frame::{mean_abs_error, Color, ImageFrame, ScalarField};
7use crate::scene::{
8    Rect, ScenarioExpectation, ScenarioId, ScenarioSupportCategory, SceneFrame, SceneSequence,
9};
10
/// Identifies one of the sample-allocation policies compared by Demo B.
///
/// `Uniform` gives every pixel `demo_b_uniform_spp` samples. Every other policy distributes the
/// same total budget with `guided_allocation` over a per-pixel difficulty field built in
/// `run_demo_b_scenario`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize)]
pub enum AllocationPolicyId {
    /// Same sample count at every pixel.
    Uniform,
    /// Guided by `gradient_field` of a one-sample pilot render.
    EdgeGuided,
    /// Guided by `residual_proxy_field` of the one-sample pilot render.
    ResidualGuided,
    /// Guided by `local_contrast_field` of the one-sample pilot render.
    ContrastGuided,
    /// Guided by `pilot_variance_field` of two differently seeded one-sample pilot renders.
    VarianceGuided,
    /// Weighted blend of the edge, residual, contrast and variance fields (0.35/0.25/0.25/0.15).
    CombinedHeuristic,
    /// Blend of the same four pilot-derived fields with different weights (0.18/0.28/0.24/0.30).
    /// Despite the name, it does not read the DSFB trust field.
    NativeTrust,
    /// Guided by the inverted DSFB supervision trust at the onset frame.
    ImportedTrust,
    /// Blend of imported trust (0.55) and pilot variance (0.45).
    HybridTrustVariance,
}
23
24impl AllocationPolicyId {
25    pub fn as_str(self) -> &'static str {
26        match self {
27            Self::Uniform => "uniform",
28            Self::EdgeGuided => "edge_guided",
29            Self::ResidualGuided => "residual_guided",
30            Self::ContrastGuided => "contrast_guided",
31            Self::VarianceGuided => "variance_guided",
32            Self::CombinedHeuristic => "combined_heuristic",
33            Self::NativeTrust => "native_trust",
34            Self::ImportedTrust => "imported_trust",
35            Self::HybridTrustVariance => "hybrid_trust_variance",
36        }
37    }
38
39    pub fn label(self) -> &'static str {
40        match self {
41            Self::Uniform => "Uniform",
42            Self::EdgeGuided => "Edge-guided",
43            Self::ResidualGuided => "Residual-guided",
44            Self::ContrastGuided => "Contrast-guided",
45            Self::VarianceGuided => "Variance-guided",
46            Self::CombinedHeuristic => "Combined heuristic",
47            Self::NativeTrust => "Native trust",
48            Self::ImportedTrust => "Imported trust",
49            Self::HybridTrustVariance => "Hybrid trust + variance",
50        }
51    }
52}
53
/// One sample of a budget-efficiency curve: the ROI error a policy reaches at one budget level.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BudgetCurvePoint {
    /// Average samples-per-pixel budget level this point was measured at.
    pub average_spp: f32,
    /// Mean absolute error against the reference render, restricted to the target (ROI) mask.
    pub roi_mae: f32,
}

/// ROI error versus sample budget for one policy on one scenario.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct BudgetCurve {
    /// `ScenarioId::as_str()` of the scenario the curve was measured on.
    pub scenario_id: String,
    /// `AllocationPolicyId::as_str()` of the policy.
    pub policy_id: String,
    /// One point per budget level, in the order the levels were evaluated.
    pub points: Vec<BudgetCurvePoint>,
}
66
/// Error and allocation statistics for one policy on one scenario (filled by `policy_metrics`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DemoBPolicyMetrics {
    /// `AllocationPolicyId::as_str()` of the policy.
    pub policy_id: String,
    /// `AllocationPolicyId::label()` of the policy.
    pub label: String,
    /// Sum of the per-pixel sample counts. `run_demo_b_scenario` rejects any policy whose total
    /// differs from the uniform budget.
    pub total_samples: usize,
    /// Mean absolute error against the reference render over the whole frame.
    pub overall_mae: f32,
    /// Root-mean-square error against the reference render over the whole frame.
    pub overall_rmse: f32,
    /// Mean absolute error inside the target mask.
    pub roi_mae: f32,
    /// Root-mean-square error inside the target mask.
    pub roi_rmse: f32,
    /// Mean absolute error outside the target mask.
    pub non_roi_mae: f32,
    /// Root-mean-square error outside the target mask.
    pub non_roi_rmse: f32,
    /// Mean sample count inside the target mask.
    pub roi_mean_spp: f32,
    /// Mean sample count outside the target mask.
    pub non_roi_mean_spp: f32,
    /// Largest per-pixel sample count (0 if there are no pixels).
    pub max_spp: usize,
    /// `roi_mean_spp / non_roi_mean_spp`, or 0.0 when the non-ROI mean is at most `f32::EPSILON`.
    pub allocation_concentration: f32,
    /// This policy's `roi_mean_spp` minus the uniform policy's (0.0 for uniform itself).
    pub extra_roi_samples_vs_uniform: f32,
    /// `(uniform roi_mae - roi_mae) / extra_roi_samples_vs_uniform`, or 0.0 when the extra ROI
    /// samples are at most `f32::EPSILON`.
    pub roi_error_reduction_per_extra_roi_sample: f32,
}
85
/// Serializable per-scenario Demo B result: scenario metadata plus every policy's metrics.
#[derive(Clone, Debug, Serialize)]
pub struct DemoBScenarioReport {
    /// `ScenarioId::as_str()` of the scenario.
    pub scenario_id: String,
    /// Scenario title copied from the scene sequence.
    pub scenario_title: String,
    /// Whether guidance was expected to help on this scenario (copied from the sequence).
    pub expectation: ScenarioExpectation,
    /// Support category copied from the scene sequence.
    pub support_category: ScenarioSupportCategory,
    /// Sampling taxonomy label copied from the scene sequence.
    pub sampling_taxonomy: String,
    /// Demo B taxonomy label copied from the scene sequence.
    pub demo_b_taxonomy: String,
    /// Frame index actually evaluated: the sequence's onset clamped to its last frame.
    pub onset_frame: usize,
    /// Label of the target (ROI) region copied from the scene sequence.
    pub target_label: String,
    /// Number of pixels set in the target mask.
    pub target_pixels: usize,
    /// One entry per policy in evaluation order; uniform comes first.
    pub policies: Vec<DemoBPolicyMetrics>,
    /// One-line ROI MAE comparison of imported trust, combined heuristic and uniform.
    pub headline: String,
    /// Caveat chosen from the scenario expectation and from whether the combined heuristic had
    /// lower ROI MAE than imported trust.
    pub bounded_note: String,
}
101
/// Cross-scenario summary of the Demo B suite.
#[derive(Clone, Debug, Serialize)]
pub struct DemoBSummary {
    /// Scenario ids in the order the scenarios were evaluated.
    pub scenario_ids: Vec<String>,
    /// Identifiers of every allocation policy.
    pub policy_ids: Vec<String>,
    /// Sentence comparing uniform and imported-trust ROI MAE on the first (canonical) scenario.
    pub primary_behavioral_result: String,
    /// Scenarios where imported trust's ROI MAE is below uniform's by more than 1e-6.
    pub imported_trust_beats_uniform_scenarios: usize,
    /// Scenarios where imported trust's ROI MAE is below the combined heuristic's by more than
    /// 1e-6.
    pub imported_trust_beats_combined_heuristic_scenarios: usize,
    /// Ids of scenarios that are `NeutralExpected` or where imported trust's ROI MAE exceeds the
    /// combined heuristic's.
    pub neutral_or_mixed_scenarios: Vec<String>,
}

/// Complete serializable output of the Demo B suite.
#[derive(Clone, Debug, Serialize)]
pub struct DemoBSuiteMetrics {
    /// Cross-scenario summary.
    pub summary: DemoBSummary,
    /// One report per scenario, in evaluation order.
    pub scenarios: Vec<DemoBScenarioReport>,
    /// Budget-efficiency curves of all scenarios, concatenated in scenario order.
    pub budget_efficiency_curves: Vec<BudgetCurve>,
}
118
/// Rendered outputs of one policy on one scenario (kept in memory, not serialized).
#[derive(Clone, Debug)]
pub struct DemoBPolicyRun {
    /// Policy that produced this run.
    pub policy_id: AllocationPolicyId,
    /// Onset frame rendered with this policy's per-pixel sample counts.
    pub frame: ImageFrame,
    /// Per-pixel error of `frame` against the reference, from `build_error_field`.
    pub error: ScalarField,
    /// Per-pixel sample counts as a scalar field, from `build_count_field`.
    pub spp: ScalarField,
    /// Summary metrics for this run.
    pub metrics: DemoBPolicyMetrics,
}

/// Rendered outputs of every policy on one scenario.
#[derive(Clone, Debug)]
pub struct DemoBScenarioRun {
    /// Onset frame rendered with `demo_b_reference_spp` samples at every pixel.
    pub reference_frame: ImageFrame,
    /// One run per policy; uniform comes first.
    pub policy_runs: Vec<DemoBPolicyRun>,
    /// Bounding box of the scenario's target mask.
    pub target_bbox: crate::frame::BoundingBox,
}
134
/// Borrowed inputs for `policy_metrics`.
struct PolicyMetricsContext<'a> {
    /// Per-pixel sample counts the policy used (row-major, one per pixel).
    counts: &'a [usize],
    /// Frame rendered with `counts`.
    frame: &'a ImageFrame,
    /// Reference render that errors are measured against.
    reference_frame: &'a ImageFrame,
    /// Per-pixel error field. It is carried along but never read by `policy_metrics`.
    _error: &'a ScalarField,
    /// Target (ROI) mask, one entry per pixel; `true` marks ROI pixels.
    target_mask: &'a [bool],
    /// The policy's ROI mean spp minus the uniform policy's (0.0 when measuring uniform itself).
    extra_roi_samples_vs_uniform: f32,
    /// The uniform policy's ROI MAE, baseline for the error-reduction-per-sample metric.
    uniform_roi_mae: f32,
}
144
145pub fn run_demo_b_suite(
146    config: &DemoConfig,
147    scenarios: &[(SceneSequence, DsfbRun)],
148) -> Result<(DemoBSuiteMetrics, Vec<(String, DemoBScenarioRun)>)> {
149    if scenarios.is_empty() {
150        return Err(Error::Message(
151            "Demo B suite requires at least one scenario".to_string(),
152        ));
153    }
154
155    let mut reports = Vec::with_capacity(scenarios.len());
156    let mut runs = Vec::with_capacity(scenarios.len());
157    let mut curves = Vec::new();
158
159    for (sequence, dsfb_host_realistic) in scenarios {
160        let (report, scenario_run, scenario_curves) =
161            run_demo_b_scenario(config, sequence, dsfb_host_realistic)?;
162        reports.push(report);
163        runs.push((sequence.scenario_id.as_str().to_string(), scenario_run));
164        curves.extend(scenario_curves);
165    }
166
167    let canonical = &reports[0];
168    let canonical_uniform = find_policy(canonical, AllocationPolicyId::Uniform)?;
169    let canonical_imported = find_policy(canonical, AllocationPolicyId::ImportedTrust)?;
170
171    let imported_trust_beats_uniform_scenarios = reports
172        .iter()
173        .filter(|report| {
174            let uniform = report
175                .policies
176                .iter()
177                .find(|policy| policy.policy_id == AllocationPolicyId::Uniform.as_str());
178            let trust = report
179                .policies
180                .iter()
181                .find(|policy| policy.policy_id == AllocationPolicyId::ImportedTrust.as_str());
182            match (uniform, trust) {
183                (Some(uniform), Some(trust)) => trust.roi_mae + 1e-6 < uniform.roi_mae,
184                _ => false,
185            }
186        })
187        .count();
188    let imported_trust_beats_combined_heuristic_scenarios = reports
189        .iter()
190        .filter(|report| {
191            let combined = report
192                .policies
193                .iter()
194                .find(|policy| policy.policy_id == AllocationPolicyId::CombinedHeuristic.as_str());
195            let trust = report
196                .policies
197                .iter()
198                .find(|policy| policy.policy_id == AllocationPolicyId::ImportedTrust.as_str());
199            match (combined, trust) {
200                (Some(combined), Some(trust)) => trust.roi_mae + 1e-6 < combined.roi_mae,
201                _ => false,
202            }
203        })
204        .count();
205
206    let neutral_or_mixed_scenarios = reports
207        .iter()
208        .filter(|report| {
209            matches!(report.expectation, ScenarioExpectation::NeutralExpected)
210                || report
211                    .policies
212                    .iter()
213                    .find(|policy| policy.policy_id == AllocationPolicyId::ImportedTrust.as_str())
214                    .zip(report.policies.iter().find(|policy| {
215                        policy.policy_id == AllocationPolicyId::CombinedHeuristic.as_str()
216                    }))
217                    .map(|(trust, combined)| trust.roi_mae > combined.roi_mae)
218                    .unwrap_or(false)
219        })
220        .map(|report| report.scenario_id.clone())
221        .collect::<Vec<_>>();
222
223    Ok((
224        DemoBSuiteMetrics {
225            summary: DemoBSummary {
226                scenario_ids: reports
227                    .iter()
228                    .map(|report| report.scenario_id.clone())
229                    .collect(),
230                policy_ids: [
231                    AllocationPolicyId::Uniform,
232                    AllocationPolicyId::EdgeGuided,
233                    AllocationPolicyId::ResidualGuided,
234                    AllocationPolicyId::ContrastGuided,
235                    AllocationPolicyId::VarianceGuided,
236                    AllocationPolicyId::CombinedHeuristic,
237                    AllocationPolicyId::NativeTrust,
238                    AllocationPolicyId::ImportedTrust,
239                    AllocationPolicyId::HybridTrustVariance,
240                ]
241                .iter()
242                .map(|policy| policy.as_str().to_string())
243                .collect(),
244                primary_behavioral_result: format!(
245                    "On the canonical sampling scenario, imported trust reduced ROI MAE from {:.5} for uniform allocation to {:.5} under the same total budget.",
246                    canonical_uniform.roi_mae, canonical_imported.roi_mae
247                ),
248                imported_trust_beats_uniform_scenarios,
249                imported_trust_beats_combined_heuristic_scenarios,
250                neutral_or_mixed_scenarios,
251            },
252            scenarios: reports,
253            budget_efficiency_curves: curves,
254        },
255        runs,
256    ))
257}
258
259fn run_demo_b_scenario(
260    config: &DemoConfig,
261    sequence: &SceneSequence,
262    dsfb_host_realistic: &DsfbRun,
263) -> Result<(DemoBScenarioReport, DemoBScenarioRun, Vec<BudgetCurve>)> {
264    let onset = sequence
265        .onset_frame
266        .min(sequence.frames.len().saturating_sub(1));
267    let scene_frame = &sequence.frames[onset];
268    let width = sequence.config.width;
269    let height = sequence.config.height;
270    let total_pixels = width * height;
271    let uniform_total_samples = config.demo_b_uniform_spp * total_pixels;
272    let min_total = config.demo_b_min_spp * total_pixels;
273    let max_total = config.demo_b_max_spp * total_pixels;
274    if uniform_total_samples < min_total || uniform_total_samples > max_total {
275        return Err(Error::Message(
276            "Demo B uniform total sample budget is incompatible with the min/max spp bounds"
277                .to_string(),
278        ));
279    }
280
281    let reference_counts = vec![config.demo_b_reference_spp; total_pixels];
282    let uniform_counts = vec![config.demo_b_uniform_spp; total_pixels];
283    let reference_frame = render_with_counts(sequence, scene_frame, &reference_counts);
284    let pilot_a = render_with_counts(sequence, scene_frame, &vec![1usize; total_pixels]);
285    let pilot_b = render_with_offset_counts(sequence, scene_frame, &vec![1usize; total_pixels], 17);
286    let target_bbox = crate::frame::bounding_box_from_mask(&sequence.target_mask, width, height)
287        .ok_or_else(|| Error::Message("Demo B target mask was empty".to_string()))?;
288
289    let imported_trust = invert_trust(&dsfb_host_realistic.supervision_frames[onset].trust);
290    let edge_difficulty = gradient_field(&pilot_a);
291    let residual_difficulty = residual_proxy_field(&pilot_a);
292    let contrast_difficulty = local_contrast_field(&pilot_a);
293    let variance_difficulty = pilot_variance_field(&pilot_a, &pilot_b);
294    let combined_difficulty = combine_fields(
295        &[
296            (&edge_difficulty, 0.35),
297            (&residual_difficulty, 0.25),
298            (&contrast_difficulty, 0.25),
299            (&variance_difficulty, 0.15),
300        ],
301        width,
302        height,
303    );
304    let native_trust_difficulty = combine_fields(
305        &[
306            (&edge_difficulty, 0.18),
307            (&residual_difficulty, 0.28),
308            (&contrast_difficulty, 0.24),
309            (&variance_difficulty, 0.30),
310        ],
311        width,
312        height,
313    );
314    let hybrid_difficulty = combine_fields(
315        &[(&imported_trust, 0.55), (&variance_difficulty, 0.45)],
316        width,
317        height,
318    );
319
320    let policies = [
321        (AllocationPolicyId::Uniform, None),
322        (AllocationPolicyId::EdgeGuided, Some(&edge_difficulty)),
323        (
324            AllocationPolicyId::ResidualGuided,
325            Some(&residual_difficulty),
326        ),
327        (
328            AllocationPolicyId::ContrastGuided,
329            Some(&contrast_difficulty),
330        ),
331        (
332            AllocationPolicyId::VarianceGuided,
333            Some(&variance_difficulty),
334        ),
335        (
336            AllocationPolicyId::CombinedHeuristic,
337            Some(&combined_difficulty),
338        ),
339        (
340            AllocationPolicyId::NativeTrust,
341            Some(&native_trust_difficulty),
342        ),
343        (AllocationPolicyId::ImportedTrust, Some(&imported_trust)),
344        (
345            AllocationPolicyId::HybridTrustVariance,
346            Some(&hybrid_difficulty),
347        ),
348    ];
349
350    let uniform_frame = render_with_counts(sequence, scene_frame, &uniform_counts);
351    let uniform_error = build_error_field(&uniform_frame, &reference_frame);
352    let uniform_metrics = policy_metrics(
353        AllocationPolicyId::Uniform,
354        PolicyMetricsContext {
355            counts: &uniform_counts,
356            frame: &uniform_frame,
357            reference_frame: &reference_frame,
358            _error: &uniform_error,
359            target_mask: &sequence.target_mask,
360            extra_roi_samples_vs_uniform: 0.0,
361            uniform_roi_mae: 0.0,
362        },
363    );
364    let mut policy_runs = vec![DemoBPolicyRun {
365        policy_id: AllocationPolicyId::Uniform,
366        frame: uniform_frame,
367        error: uniform_error,
368        spp: build_count_field(&uniform_counts, width, height),
369        metrics: uniform_metrics.clone(),
370    }];
371
372    for (policy_id, field) in policies.iter().skip(1) {
373        let difficulty = field.expect("guided policies require a difficulty field");
374        let counts = guided_allocation(
375            difficulty,
376            uniform_total_samples,
377            config.demo_b_min_spp,
378            config.demo_b_max_spp,
379        )?;
380        let frame = render_with_counts(sequence, scene_frame, &counts);
381        let error = build_error_field(&frame, &reference_frame);
382        let roi_mean_spp = mean_count_over_mask(&counts, &sequence.target_mask);
383        let metrics = policy_metrics(
384            *policy_id,
385            PolicyMetricsContext {
386                counts: &counts,
387                frame: &frame,
388                reference_frame: &reference_frame,
389                _error: &error,
390                target_mask: &sequence.target_mask,
391                extra_roi_samples_vs_uniform: roi_mean_spp - uniform_metrics.roi_mean_spp,
392                uniform_roi_mae: uniform_metrics.roi_mae,
393            },
394        );
395        policy_runs.push(DemoBPolicyRun {
396            policy_id: *policy_id,
397            frame,
398            error,
399            spp: build_count_field(&counts, width, height),
400            metrics,
401        });
402    }
403
404    for run in &policy_runs {
405        let total = run.metrics.total_samples;
406        if total != uniform_total_samples {
407            return Err(Error::Message(format!(
408                "policy {} violated fixed-budget fairness: expected {}, got {}",
409                run.policy_id.as_str(),
410                uniform_total_samples,
411                total
412            )));
413        }
414    }
415
416    let budget_levels = [1.0f32, config.demo_b_uniform_spp as f32, 4.0f32, 8.0f32];
417    let mut curves = Vec::new();
418    for policy_id in [
419        AllocationPolicyId::Uniform,
420        AllocationPolicyId::CombinedHeuristic,
421        AllocationPolicyId::NativeTrust,
422        AllocationPolicyId::ImportedTrust,
423        AllocationPolicyId::HybridTrustVariance,
424    ] {
425        let mut points = Vec::new();
426        for average_spp in budget_levels {
427            let total_samples = (average_spp * total_pixels as f32).round() as usize;
428            let counts = match policy_id {
429                AllocationPolicyId::Uniform => {
430                    vec![average_spp.round().max(1.0) as usize; total_pixels]
431                }
432                AllocationPolicyId::CombinedHeuristic => guided_allocation(
433                    &combined_difficulty,
434                    total_samples,
435                    1,
436                    config.demo_b_max_spp,
437                )?,
438                AllocationPolicyId::NativeTrust => guided_allocation(
439                    &native_trust_difficulty,
440                    total_samples,
441                    1,
442                    config.demo_b_max_spp,
443                )?,
444                AllocationPolicyId::ImportedTrust => {
445                    guided_allocation(&imported_trust, total_samples, 1, config.demo_b_max_spp)?
446                }
447                AllocationPolicyId::HybridTrustVariance => {
448                    guided_allocation(&hybrid_difficulty, total_samples, 1, config.demo_b_max_spp)?
449                }
450                _ => unreachable!(),
451            };
452            let frame = render_with_counts(sequence, scene_frame, &counts);
453            points.push(BudgetCurvePoint {
454                average_spp,
455                roi_mae: mean_abs_error_over_mask(&frame, &reference_frame, &sequence.target_mask),
456            });
457        }
458        curves.push(BudgetCurve {
459            scenario_id: sequence.scenario_id.as_str().to_string(),
460            policy_id: policy_id.as_str().to_string(),
461            points,
462        });
463    }
464
465    let imported_trust_metrics = policy_runs
466        .iter()
467        .find(|run| run.policy_id == AllocationPolicyId::ImportedTrust)
468        .map(|run| &run.metrics)
469        .ok_or_else(|| Error::Message("imported trust policy missing".to_string()))?;
470    let combined_metrics = policy_runs
471        .iter()
472        .find(|run| run.policy_id == AllocationPolicyId::CombinedHeuristic)
473        .map(|run| &run.metrics)
474        .ok_or_else(|| Error::Message("combined heuristic policy missing".to_string()))?;
475
476    let headline = format!(
477        "{}: imported-trust ROI MAE {:.5}, combined-heuristic ROI MAE {:.5}, uniform ROI MAE {:.5}.",
478        sequence.scenario_title,
479        imported_trust_metrics.roi_mae,
480        combined_metrics.roi_mae,
481        uniform_metrics.roi_mae
482    );
483    let bounded_note = match sequence.expectation {
484        ScenarioExpectation::BenefitExpected => {
485            if imported_trust_metrics.roi_mae > combined_metrics.roi_mae {
486                "Combined heuristic remains stronger on this scenario, which is surfaced explicitly in the decision report.".to_string()
487            } else {
488                "Imported trust remains competitive under equal budget on this scenario.".to_string()
489            }
490        }
491        ScenarioExpectation::NeutralExpected => {
492            "Neutral case: guidance is not expected to produce a large win, so non-ROI penalties and concentration behavior matter more than raw ROI gain.".to_string()
493        }
494    };
495
496    Ok((
497        DemoBScenarioReport {
498            scenario_id: sequence.scenario_id.as_str().to_string(),
499            scenario_title: sequence.scenario_title.clone(),
500            expectation: sequence.expectation,
501            support_category: sequence.support_category,
502            sampling_taxonomy: sequence.sampling_taxonomy.clone(),
503            demo_b_taxonomy: sequence.demo_b_taxonomy.clone(),
504            onset_frame: onset,
505            target_label: sequence.target_label.clone(),
506            target_pixels: sequence.target_mask.iter().filter(|value| **value).count(),
507            policies: policy_runs.iter().map(|run| run.metrics.clone()).collect(),
508            headline,
509            bounded_note,
510        },
511        DemoBScenarioRun {
512            reference_frame,
513            policy_runs,
514            target_bbox,
515        },
516        curves,
517    ))
518}
519
520fn policy_metrics(
521    policy_id: AllocationPolicyId,
522    context: PolicyMetricsContext<'_>,
523) -> DemoBPolicyMetrics {
524    let non_roi_mask = context
525        .target_mask
526        .iter()
527        .map(|value| !value)
528        .collect::<Vec<_>>();
529    let total_samples = context.counts.iter().sum::<usize>();
530    let roi_mean_spp = mean_count_over_mask(context.counts, context.target_mask);
531    let non_roi_mean_spp = mean_count_over_mask(context.counts, &non_roi_mask);
532    let roi_mae =
533        mean_abs_error_over_mask(context.frame, context.reference_frame, context.target_mask);
534    let policy_label = policy_id.label().to_string();
535    let allocation_concentration = if non_roi_mean_spp <= f32::EPSILON {
536        0.0
537    } else {
538        roi_mean_spp / non_roi_mean_spp
539    };
540
541    DemoBPolicyMetrics {
542        policy_id: policy_id.as_str().to_string(),
543        label: policy_label,
544        total_samples,
545        overall_mae: mean_abs_error(context.frame, context.reference_frame),
546        overall_rmse: rmse(context.frame, context.reference_frame, None),
547        roi_mae,
548        roi_rmse: rmse(
549            context.frame,
550            context.reference_frame,
551            Some(context.target_mask),
552        ),
553        non_roi_mae: mean_abs_error_over_mask(
554            context.frame,
555            context.reference_frame,
556            &non_roi_mask,
557        ),
558        non_roi_rmse: rmse(context.frame, context.reference_frame, Some(&non_roi_mask)),
559        roi_mean_spp,
560        non_roi_mean_spp,
561        max_spp: context.counts.iter().copied().max().unwrap_or(0),
562        allocation_concentration,
563        extra_roi_samples_vs_uniform: context.extra_roi_samples_vs_uniform,
564        roi_error_reduction_per_extra_roi_sample: if context.extra_roi_samples_vs_uniform
565            <= f32::EPSILON
566        {
567            0.0
568        } else {
569            (context.uniform_roi_mae - roi_mae) / context.extra_roi_samples_vs_uniform
570        },
571    }
572}
573
574fn find_policy(
575    report: &DemoBScenarioReport,
576    policy_id: AllocationPolicyId,
577) -> Result<&DemoBPolicyMetrics> {
578    report
579        .policies
580        .iter()
581        .find(|policy| policy.policy_id == policy_id.as_str())
582        .ok_or_else(|| Error::Message(format!("Demo B policy {} missing", policy_id.as_str())))
583}
584
585fn render_with_counts(
586    sequence: &SceneSequence,
587    scene_frame: &SceneFrame,
588    counts: &[usize],
589) -> ImageFrame {
590    render_with_offset_counts(sequence, scene_frame, counts, 0)
591}
592
593fn render_with_offset_counts(
594    sequence: &SceneSequence,
595    scene_frame: &SceneFrame,
596    counts: &[usize],
597    seed_offset: u32,
598) -> ImageFrame {
599    let mut frame = ImageFrame::new(sequence.config.width, sequence.config.height);
600    for y in 0..sequence.config.height {
601        for x in 0..sequence.config.width {
602            let pixel_index = y * sequence.config.width + x;
603            let sample_count = counts[pixel_index].max(1);
604            let mut accum = Color::rgb(0.0, 0.0, 0.0);
605
606            for sample_index in 0..sample_count {
607                let (offset_x, offset_y) =
608                    sample_offset(pixel_index as u32 ^ seed_offset, sample_index as u32);
609                let sample = sample_scene_continuous(
610                    sequence,
611                    scene_frame,
612                    x as f32 + offset_x,
613                    y as f32 + offset_y,
614                );
615                accum = Color::rgb(accum.r + sample.r, accum.g + sample.g, accum.b + sample.b);
616            }
617
618            let inv = 1.0 / sample_count as f32;
619            frame.set(
620                x,
621                y,
622                Color::rgb(accum.r * inv, accum.g * inv, accum.b * inv).clamp01(),
623            );
624        }
625    }
626    frame
627}
628
629fn sample_scene_continuous(
630    sequence: &SceneSequence,
631    scene_frame: &SceneFrame,
632    sample_x: f32,
633    sample_y: f32,
634) -> Color {
635    let mut color =
636        background_color_continuous(sample_x, sample_y, &sequence.config, sequence.scenario_id);
637    if is_thin_structure_continuous(sample_x, sample_y, &sequence.config, sequence.scenario_id) {
638        color = thin_structure_color_continuous(
639            sample_x,
640            sample_y,
641            &sequence.config,
642            sequence.scenario_id,
643        );
644    }
645    if rect_contains_continuous(scene_frame.object_rect, sample_x, sample_y)
646        && !matches!(
647            sequence.scenario_id,
648            ScenarioId::ContrastPulse | ScenarioId::StabilityHoldout
649        )
650    {
651        color = object_color_continuous(sample_x, sample_y, scene_frame.object_rect);
652    }
653    if matches!(sequence.scenario_id, ScenarioId::ContrastPulse)
654        && scene_frame.index >= sequence.onset_frame
655        && contrast_pulse_rect(&sequence.config).contains(sample_x as i32, sample_y as i32)
656    {
657        color = Color::rgb(color.r * 1.22, color.g * 1.22, color.b * 1.22).clamp01();
658    }
659    color
660}
661
/// Procedural background color at a continuous sample position for the given scenario.
///
/// Builds on:
/// - a horizontal/vertical gradient (`xf`, `yf` are the position divided by the frame size);
/// - three 0/1 patterns: a 12×12-pixel checkerboard, a diagonal band pattern (`x + 2y` mod 22)
///   and a steeper stripe pattern (`3x + y` mod 17);
/// - a vignette that falls off linearly with distance from the frame center.
///
/// Band-style scenarios add a sin/cos `micro` texture inside an inset rectangle.
/// `ContrastPulse` uses a flat gradient with no patterns and no vignette.
fn background_color_continuous(
    sample_x: f32,
    sample_y: f32,
    config: &crate::config::SceneConfig,
    scenario_id: ScenarioId,
) -> Color {
    // `.max(1)` guards against division by zero for degenerate frame sizes.
    let xf = sample_x / config.width.max(1) as f32;
    let yf = sample_y / config.height.max(1) as f32;
    // `as` binds tighter than `%`: the floored cell sum is cast to i32 before the parity test,
    // and negative odd sums give -1, which counts as an "off" cell.
    let checker = if ((sample_x / 12.0).floor() + (sample_y / 12.0).floor()) as i32 % 2 == 0 {
        1.0
    } else {
        0.0
    };
    let diagonal = if (sample_x + 2.0 * sample_y).rem_euclid(22.0) < 6.0 {
        1.0
    } else {
        0.0
    };
    let stripes = if (3.0 * sample_x + sample_y).rem_euclid(17.0) < 5.0 {
        1.0
    } else {
        0.0
    };
    // Linear darkening toward the borders: 1.0 at the center, 1 - (0.175 + 0.2) at a corner.
    let vignette_x = (xf - 0.5).abs();
    let vignette_y = (yf - 0.5).abs();
    let vignette = 1.0 - (vignette_x * 0.35 + vignette_y * 0.4);

    match scenario_id {
        ScenarioId::ThinReveal | ScenarioId::StabilityHoldout => Color::rgb(
            (0.12 + 0.16 * xf + 0.05 * checker + 0.03 * diagonal) * vignette,
            (0.15 + 0.11 * yf + 0.04 * diagonal) * vignette,
            (0.22 + 0.18 * (1.0 - xf) + 0.03 * checker) * vignette,
        ),
        ScenarioId::FastPan => Color::rgb(
            (0.10 + 0.18 * xf + 0.08 * checker + 0.05 * stripes) * vignette,
            (0.11 + 0.15 * yf + 0.10 * diagonal + 0.04 * stripes) * vignette,
            (0.18 + 0.20 * (1.0 - xf) + 0.06 * checker) * vignette,
        ),
        ScenarioId::DiagonalReveal => Color::rgb(
            (0.08 + 0.24 * checker + 0.20 * diagonal + 0.05 * xf) * vignette,
            (0.08 + 0.18 * stripes + 0.07 * yf) * vignette,
            (0.12 + 0.25 * (1.0 - checker) + 0.04 * xf) * vignette,
        ),
        ScenarioId::RevealBand
        | ScenarioId::MotionBiasBand
        | ScenarioId::LayeredSlats
        | ScenarioId::NoisyReprojection
        | ScenarioId::HeuristicFriendlyPan => {
            // Product of two [0, 1]-remapped sinusoids: a fine texture in [0, 1].
            let micro = ((sample_x * 0.83 + sample_y * 1.91).sin() * 0.5 + 0.5)
                * ((sample_x * 1.37 - sample_y * 0.71).cos() * 0.5 + 0.5);
            // The texture is confined to the same inset rectangle that
            // `is_thin_structure_continuous` uses for its slat band.
            let band = if (18.0..=(config.height as f32 - 18.0)).contains(&sample_y)
                && (26.0..=(config.width as f32 - 24.0)).contains(&sample_x)
            {
                1.0
            } else {
                0.0
            };
            // Per-scenario strength of the texture in the red channel.
            let micro_gain = match scenario_id {
                ScenarioId::LayeredSlats => 0.13,
                ScenarioId::NoisyReprojection => 0.16,
                ScenarioId::HeuristicFriendlyPan => 0.06,
                _ => 0.10,
            };
            Color::rgb(
                (0.10 + 0.14 * xf + 0.05 * checker + micro_gain * micro * band) * vignette,
                (0.12 + 0.13 * yf + 0.06 * diagonal + 0.08 * micro * band) * vignette,
                (0.16 + 0.18 * (1.0 - xf) + 0.07 * stripes + 0.12 * micro * band) * vignette,
            )
        }
        ScenarioId::ContrastPulse => {
            // Flat, texture-free backdrop so the pulse brightening is the only change.
            Color::rgb(0.18 + 0.06 * xf, 0.18 + 0.05 * yf, 0.24 + 0.06 * (1.0 - xf))
        }
    }
}
736
737fn is_thin_structure_continuous(
738    sample_x: f32,
739    sample_y: f32,
740    config: &crate::config::SceneConfig,
741    scenario_id: ScenarioId,
742) -> bool {
743    if matches!(scenario_id, ScenarioId::ContrastPulse) {
744        return false;
745    }
746    let vertical_center = config.thin_vertical_x as f32 + 0.5;
747    let vertical = (sample_x - vertical_center).abs() <= 0.18
748        && sample_y >= 14.0
749        && sample_y <= config.height as f32 - 14.0;
750    let diagonal_line =
751        (sample_y - (0.58 * sample_x + 10.5)).abs() <= 0.20 && (28.0..=118.0).contains(&sample_x);
752    let mixed_width_band = {
753        let in_band = (18.0..=(config.height as f32 - 18.0)).contains(&sample_y)
754            && (26.0..=(config.width as f32 - 24.0)).contains(&sample_x);
755        let thin_slats = (sample_x - 28.0).rem_euclid(11.0) < 0.18;
756        let medium_slats = (sample_x - 34.0).rem_euclid(19.0) < 1.10;
757        let wide_slats = (sample_x - 48.0).rem_euclid(29.0) < 2.15;
758        let diagonal = (sample_y - (0.44 * sample_x + 12.0)).abs() <= 1.15
759            && (38.0..=(config.width as f32 - 32.0)).contains(&sample_x);
760        in_band && (thin_slats || medium_slats || wide_slats || diagonal)
761    };
762    match scenario_id {
763        ScenarioId::DiagonalReveal => diagonal_line,
764        ScenarioId::RevealBand
765        | ScenarioId::MotionBiasBand
766        | ScenarioId::LayeredSlats
767        | ScenarioId::NoisyReprojection
768        | ScenarioId::HeuristicFriendlyPan => mixed_width_band,
769        _ => vertical || diagonal_line,
770    }
771}
772
773fn thin_structure_color_continuous(
774    sample_x: f32,
775    sample_y: f32,
776    config: &crate::config::SceneConfig,
777    scenario_id: ScenarioId,
778) -> Color {
779    let vertical_center = config.thin_vertical_x as f32 + 0.5;
780    if !matches!(scenario_id, ScenarioId::DiagonalReveal)
781        && (sample_x - vertical_center).abs() <= 0.18
782    {
783        let pulse = if (sample_y / 3.0).floor() as i32 % 2 == 0 {
784            1.0
785        } else {
786            0.84
787        };
788        return Color::rgb(0.95 * pulse, 0.96 * pulse, 0.98);
789    }
790    if matches!(scenario_id, ScenarioId::DiagonalReveal) {
791        Color::rgb(0.24, 0.29, 0.35)
792    } else if matches!(
793        scenario_id,
794        ScenarioId::RevealBand | ScenarioId::LayeredSlats
795    ) {
796        let phase = ((sample_x + 2.0 * sample_y).rem_euclid(9.0)) / 8.0;
797        Color::rgb(
798            0.22 + 0.48 * phase,
799            0.58 + 0.26 * phase,
800            0.84 + 0.12 * (1.0 - phase),
801        )
802    } else if matches!(
803        scenario_id,
804        ScenarioId::MotionBiasBand | ScenarioId::NoisyReprojection
805    ) {
806        let phase = ((2.0 * sample_x + sample_y).rem_euclid(13.0)) / 12.0;
807        Color::rgb(
808            0.78 + 0.16 * phase,
809            0.74 + 0.10 * (1.0 - phase),
810            0.26 + 0.18 * phase,
811        )
812    } else if matches!(scenario_id, ScenarioId::HeuristicFriendlyPan) {
813        let phase = ((sample_x - 1.5 * sample_y).rem_euclid(15.0)) / 14.0;
814        Color::rgb(
815            0.18 + 0.14 * phase,
816            0.86 - 0.18 * phase,
817            0.28 + 0.10 * (1.0 - phase),
818        )
819    } else {
820        Color::rgb(0.64, 0.90, 0.96)
821    }
822}
823
824fn contrast_pulse_rect(config: &crate::config::SceneConfig) -> Rect {
825    Rect {
826        x: (config.width as i32 / 2) - 18,
827        y: (config.height as i32 / 2) - 18,
828        width: 52,
829        height: 36,
830    }
831}
832
833fn rect_contains_continuous(rect: Rect, sample_x: f32, sample_y: f32) -> bool {
834    sample_x >= rect.x as f32
835        && sample_x < (rect.x + rect.width) as f32
836        && sample_y >= rect.y as f32
837        && sample_y < (rect.y + rect.height) as f32
838}
839
840fn object_color_continuous(sample_x: f32, sample_y: f32, rect: Rect) -> Color {
841    let local_x = (sample_x - rect.x as f32) / rect.width.max(1) as f32;
842    let local_y = (sample_y - rect.y as f32) / rect.height.max(1) as f32;
843    let stripe = if (0.36..0.46).contains(&local_x) {
844        0.55
845    } else {
846        1.0
847    };
848    let rim = if !(0.05..=0.95).contains(&local_x) || !(0.05..=0.95).contains(&local_y) {
849        1.12
850    } else {
851        1.0
852    };
853    Color::rgb(
854        (0.82 + 0.10 * local_y) * stripe * rim,
855        (0.35 + 0.12 * (1.0 - local_y)) * stripe * rim,
856        (0.20 + 0.08 * local_x) * stripe * rim,
857    )
858    .clamp01()
859}
860
861pub(crate) fn guided_allocation(
862    difficulty: &ScalarField,
863    total_samples: usize,
864    min_spp: usize,
865    max_spp: usize,
866) -> Result<Vec<usize>> {
867    let total_pixels = difficulty.width() * difficulty.height();
868    let min_total = min_spp * total_pixels;
869    let max_total = max_spp * total_pixels;
870    if total_samples < min_total || total_samples > max_total {
871        return Err(Error::Message(
872            "guided allocation budget is incompatible with the min/max spp bounds".to_string(),
873        ));
874    }
875
876    let mut counts = vec![min_spp; total_pixels];
877    let mut remaining = total_samples - min_total;
878    if remaining == 0 {
879        return Ok(counts);
880    }
881
882    let weights = difficulty
883        .values()
884        .iter()
885        .map(|value| 0.05 + 0.95 * value.clamp(0.0, 1.0).powf(2.4))
886        .collect::<Vec<_>>();
887
888    while remaining > 0 {
889        let available_weight: f32 = counts
890            .iter()
891            .zip(weights.iter())
892            .filter_map(|(count, weight)| (*count < max_spp).then_some(*weight))
893            .sum();
894        if available_weight <= f32::EPSILON {
895            break;
896        }
897
898        let round_budget = remaining;
899        let mut floor_assignments = vec![0usize; total_pixels];
900        let mut fractional_parts = Vec::new();
901        for (index, (count, weight)) in counts
902            .iter()
903            .copied()
904            .zip(weights.iter().copied())
905            .enumerate()
906        {
907            if count >= max_spp {
908                continue;
909            }
910            let capacity = max_spp - count;
911            let target = round_budget as f32 * weight / available_weight;
912            let whole = target.floor() as usize;
913            let assigned = whole.min(capacity);
914            floor_assignments[index] = assigned;
915            if assigned < capacity {
916                fractional_parts.push((target - assigned as f32, index));
917            }
918        }
919
920        let mut assigned_this_round = 0usize;
921        for (count, extra) in counts.iter_mut().zip(floor_assignments.iter().copied()) {
922            *count += extra;
923            assigned_this_round += extra;
924        }
925        remaining -= assigned_this_round.min(remaining);
926        if remaining == 0 {
927            break;
928        }
929
930        fractional_parts.sort_by(|left, right| right.0.total_cmp(&left.0));
931        let mut assigned_fractional = 0usize;
932        for (_, index) in fractional_parts {
933            if remaining == 0 {
934                break;
935            }
936            if counts[index] < max_spp {
937                counts[index] += 1;
938                remaining -= 1;
939                assigned_fractional += 1;
940            }
941        }
942
943        if assigned_this_round == 0 && assigned_fractional == 0 {
944            let mut fallback = weights
945                .iter()
946                .copied()
947                .enumerate()
948                .filter(|(index, _)| counts[*index] < max_spp)
949                .map(|(index, weight)| (weight, index))
950                .collect::<Vec<_>>();
951            fallback.sort_by(|left, right| right.0.total_cmp(&left.0));
952            for (_, index) in fallback {
953                if remaining == 0 {
954                    break;
955                }
956                counts[index] += 1;
957                remaining -= 1;
958            }
959        }
960    }
961
962    Ok(counts)
963}
964
965fn build_error_field(frame: &ImageFrame, reference: &ImageFrame) -> ScalarField {
966    let mut field = ScalarField::new(frame.width(), frame.height());
967    for y in 0..frame.height() {
968        for x in 0..frame.width() {
969            field.set(x, y, frame.get(x, y).abs_diff(reference.get(x, y)));
970        }
971    }
972    field
973}
974
975pub(crate) fn build_count_field(counts: &[usize], width: usize, height: usize) -> ScalarField {
976    let mut field = ScalarField::new(width, height);
977    for y in 0..height {
978        for x in 0..width {
979            field.set(x, y, counts[y * width + x] as f32);
980        }
981    }
982    field
983}
984
/// Averages the entries of `counts` whose paired `mask` entry is `true`.
///
/// The two slices are paired positionally and iteration stops at the shorter one. Returns `0.0`
/// when no entry is selected.
pub(crate) fn mean_count_over_mask(counts: &[usize], mask: &[bool]) -> f32 {
    let (total, included) = counts
        .iter()
        .zip(mask)
        .filter_map(|(&spp, &keep)| keep.then_some(spp))
        .fold((0usize, 0usize), |(total, included), spp| {
            (total + spp, included + 1)
        });
    if included == 0 {
        return 0.0;
    }
    total as f32 / included as f32
}
1000
1001fn mean_abs_error_over_mask(frame_a: &ImageFrame, frame_b: &ImageFrame, mask: &[bool]) -> f32 {
1002    crate::frame::mean_abs_error_over_mask(frame_a, frame_b, mask)
1003}
1004
1005fn rmse(frame: &ImageFrame, reference: &ImageFrame, mask: Option<&[bool]>) -> f32 {
1006    let mut sum = 0.0;
1007    let mut count = 0usize;
1008
1009    for y in 0..frame.height() {
1010        for x in 0..frame.width() {
1011            let index = y * frame.width() + x;
1012            if mask.map(|values| values[index]).unwrap_or(true) {
1013                let diff = frame.get(x, y).abs_diff(reference.get(x, y));
1014                sum += diff * diff;
1015                count += 1;
1016            }
1017        }
1018    }
1019
1020    if count == 0 {
1021        0.0
1022    } else {
1023        (sum / count as f32).sqrt()
1024    }
1025}
1026
1027pub(crate) fn invert_trust(trust: &ScalarField) -> ScalarField {
1028    let mut field = ScalarField::new(trust.width(), trust.height());
1029    for y in 0..trust.height() {
1030        for x in 0..trust.width() {
1031            field.set(x, y, (1.0 - trust.get(x, y)).clamp(0.0, 1.0));
1032        }
1033    }
1034    field
1035}
1036
1037pub(crate) fn gradient_field(frame: &ImageFrame) -> ScalarField {
1038    let mut field = ScalarField::new(frame.width(), frame.height());
1039    for y in 0..frame.height() {
1040        for x in 0..frame.width() {
1041            let center = frame.get(x, y).luma();
1042            let left = frame.sample_clamped(x as i32 - 1, y as i32).luma();
1043            let right = frame.sample_clamped(x as i32 + 1, y as i32).luma();
1044            let up = frame.sample_clamped(x as i32, y as i32 - 1).luma();
1045            let down = frame.sample_clamped(x as i32, y as i32 + 1).luma();
1046            let grad = (right - left)
1047                .abs()
1048                .max((down - up).abs())
1049                .max((center - left).abs());
1050            field.set(x, y, (grad / 0.25).clamp(0.0, 1.0));
1051        }
1052    }
1053    field
1054}
1055
1056fn residual_proxy_field(frame: &ImageFrame) -> ScalarField {
1057    let mut field = ScalarField::new(frame.width(), frame.height());
1058    for y in 0..frame.height() {
1059        for x in 0..frame.width() {
1060            let blurred = box_blur_luma(frame, x, y);
1061            let residual = (frame.get(x, y).luma() - blurred).abs();
1062            field.set(x, y, (residual / 0.22).clamp(0.0, 1.0));
1063        }
1064    }
1065    field
1066}
1067
1068pub(crate) fn local_contrast_field(frame: &ImageFrame) -> ScalarField {
1069    let mut field = ScalarField::new(frame.width(), frame.height());
1070    for y in 0..frame.height() {
1071        for x in 0..frame.width() {
1072            let center = frame.get(x, y).luma();
1073            let mut strongest = 0.0f32;
1074            for dy in -1i32..=1 {
1075                for dx in -1i32..=1 {
1076                    if dx == 0 && dy == 0 {
1077                        continue;
1078                    }
1079                    let neighbor = frame.sample_clamped(x as i32 + dx, y as i32 + dy).luma();
1080                    strongest = strongest.max((center - neighbor).abs());
1081                }
1082            }
1083            field.set(x, y, (strongest / 0.18).clamp(0.0, 1.0));
1084        }
1085    }
1086    field
1087}
1088
1089fn pilot_variance_field(pilot_a: &ImageFrame, pilot_b: &ImageFrame) -> ScalarField {
1090    let mut field = ScalarField::new(pilot_a.width(), pilot_a.height());
1091    for y in 0..pilot_a.height() {
1092        for x in 0..pilot_a.width() {
1093            let diff = pilot_a.get(x, y).abs_diff(pilot_b.get(x, y));
1094            field.set(x, y, (diff / 0.20).clamp(0.0, 1.0));
1095        }
1096    }
1097    field
1098}
1099
1100pub(crate) fn combine_fields(
1101    fields: &[(&ScalarField, f32)],
1102    width: usize,
1103    height: usize,
1104) -> ScalarField {
1105    let mut field = ScalarField::new(width, height);
1106    for y in 0..height {
1107        for x in 0..width {
1108            let value = fields
1109                .iter()
1110                .map(|(current, weight)| current.get(x, y) * *weight)
1111                .sum::<f32>()
1112                .clamp(0.0, 1.0);
1113            field.set(x, y, value);
1114        }
1115    }
1116    field
1117}
1118
1119fn box_blur_luma(frame: &ImageFrame, x: usize, y: usize) -> f32 {
1120    let mut sum = 0.0;
1121    let mut count = 0usize;
1122    for dy in -1i32..=1 {
1123        for dx in -1i32..=1 {
1124            let color = frame.sample_clamped(x as i32 + dx, y as i32 + dy);
1125            sum += color.luma();
1126            count += 1;
1127        }
1128    }
1129    sum / count as f32
1130}
1131
1132fn sample_offset(pixel_seed: u32, sample_index: u32) -> (f32, f32) {
1133    let shift_x = unit_hash(pixel_seed ^ 0x9e37_79b9);
1134    let shift_y = unit_hash(pixel_seed ^ 0x85eb_ca6b);
1135    let u = (radical_inverse(sample_index + 1, 2) + shift_x).fract();
1136    let v = (radical_inverse(sample_index + 1, 3) + shift_y).fract();
1137    (u, v)
1138}
1139
/// Cheap integer scramble of `value` (multiply, rotate, xor) mapped into `[0, 1)`.
fn unit_hash(value: u32) -> f32 {
    const MULTIPLIER: u32 = 0x045d_9f3b;
    const SALT: u32 = 0xa511_e9b3;
    let scrambled = value.wrapping_mul(MULTIPLIER).rotate_left(7) ^ SALT;
    // `u32::MAX as f32` rounds up to 2^32, and values near the top can round to it as well;
    // `fract` folds a ratio of exactly 1.0 back to 0.0.
    let ratio = scrambled as f32 / u32::MAX as f32;
    ratio.fract()
}
1144
/// Van der Corput radical inverse: mirrors the base-`base` digits of `index` about the radix
/// point. For example, with base 2, index 3 (`11`) maps to 0.75.
fn radical_inverse(index: u32, base: u32) -> f32 {
    let base_f = base as f32;
    let mut remaining = index;
    let mut scale = 1.0 / base_f;
    let mut result = 0.0;
    while remaining != 0 {
        result += (remaining % base) as f32 * scale;
        remaining /= base;
        scale /= base_f;
    }
    result
}