1use serde::{Deserialize, Serialize};
2
3use crate::config::DemoConfig;
4use crate::dsfb::DsfbRun;
5use crate::error::{Error, Result};
6use crate::frame::{mean_abs_error, Color, ImageFrame, ScalarField};
7use crate::scene::{
8 Rect, ScenarioExpectation, ScenarioId, ScenarioSupportCategory, SceneFrame, SceneSequence,
9};
10
11#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize)]
12pub enum AllocationPolicyId {
13 Uniform,
14 EdgeGuided,
15 ResidualGuided,
16 ContrastGuided,
17 VarianceGuided,
18 CombinedHeuristic,
19 NativeTrust,
20 ImportedTrust,
21 HybridTrustVariance,
22}
23
24impl AllocationPolicyId {
25 pub fn as_str(self) -> &'static str {
26 match self {
27 Self::Uniform => "uniform",
28 Self::EdgeGuided => "edge_guided",
29 Self::ResidualGuided => "residual_guided",
30 Self::ContrastGuided => "contrast_guided",
31 Self::VarianceGuided => "variance_guided",
32 Self::CombinedHeuristic => "combined_heuristic",
33 Self::NativeTrust => "native_trust",
34 Self::ImportedTrust => "imported_trust",
35 Self::HybridTrustVariance => "hybrid_trust_variance",
36 }
37 }
38
39 pub fn label(self) -> &'static str {
40 match self {
41 Self::Uniform => "Uniform",
42 Self::EdgeGuided => "Edge-guided",
43 Self::ResidualGuided => "Residual-guided",
44 Self::ContrastGuided => "Contrast-guided",
45 Self::VarianceGuided => "Variance-guided",
46 Self::CombinedHeuristic => "Combined heuristic",
47 Self::NativeTrust => "Native trust",
48 Self::ImportedTrust => "Imported trust",
49 Self::HybridTrustVariance => "Hybrid trust + variance",
50 }
51 }
52}
53
54#[derive(Clone, Debug, Serialize, Deserialize)]
55pub struct BudgetCurvePoint {
56 pub average_spp: f32,
57 pub roi_mae: f32,
58}
59
60#[derive(Clone, Debug, Serialize, Deserialize)]
61pub struct BudgetCurve {
62 pub scenario_id: String,
63 pub policy_id: String,
64 pub points: Vec<BudgetCurvePoint>,
65}
66
67#[derive(Clone, Debug, Serialize, Deserialize)]
68pub struct DemoBPolicyMetrics {
69 pub policy_id: String,
70 pub label: String,
71 pub total_samples: usize,
72 pub overall_mae: f32,
73 pub overall_rmse: f32,
74 pub roi_mae: f32,
75 pub roi_rmse: f32,
76 pub non_roi_mae: f32,
77 pub non_roi_rmse: f32,
78 pub roi_mean_spp: f32,
79 pub non_roi_mean_spp: f32,
80 pub max_spp: usize,
81 pub allocation_concentration: f32,
82 pub extra_roi_samples_vs_uniform: f32,
83 pub roi_error_reduction_per_extra_roi_sample: f32,
84}
85
86#[derive(Clone, Debug, Serialize)]
87pub struct DemoBScenarioReport {
88 pub scenario_id: String,
89 pub scenario_title: String,
90 pub expectation: ScenarioExpectation,
91 pub support_category: ScenarioSupportCategory,
92 pub sampling_taxonomy: String,
93 pub demo_b_taxonomy: String,
94 pub onset_frame: usize,
95 pub target_label: String,
96 pub target_pixels: usize,
97 pub policies: Vec<DemoBPolicyMetrics>,
98 pub headline: String,
99 pub bounded_note: String,
100}
101
102#[derive(Clone, Debug, Serialize)]
103pub struct DemoBSummary {
104 pub scenario_ids: Vec<String>,
105 pub policy_ids: Vec<String>,
106 pub primary_behavioral_result: String,
107 pub imported_trust_beats_uniform_scenarios: usize,
108 pub imported_trust_beats_combined_heuristic_scenarios: usize,
109 pub neutral_or_mixed_scenarios: Vec<String>,
110}
111
112#[derive(Clone, Debug, Serialize)]
113pub struct DemoBSuiteMetrics {
114 pub summary: DemoBSummary,
115 pub scenarios: Vec<DemoBScenarioReport>,
116 pub budget_efficiency_curves: Vec<BudgetCurve>,
117}
118
119#[derive(Clone, Debug)]
120pub struct DemoBPolicyRun {
121 pub policy_id: AllocationPolicyId,
122 pub frame: ImageFrame,
123 pub error: ScalarField,
124 pub spp: ScalarField,
125 pub metrics: DemoBPolicyMetrics,
126}
127
128#[derive(Clone, Debug)]
129pub struct DemoBScenarioRun {
130 pub reference_frame: ImageFrame,
131 pub policy_runs: Vec<DemoBPolicyRun>,
132 pub target_bbox: crate::frame::BoundingBox,
133}
134
135struct PolicyMetricsContext<'a> {
136 counts: &'a [usize],
137 frame: &'a ImageFrame,
138 reference_frame: &'a ImageFrame,
139 _error: &'a ScalarField,
140 target_mask: &'a [bool],
141 extra_roi_samples_vs_uniform: f32,
142 uniform_roi_mae: f32,
143}
144
145pub fn run_demo_b_suite(
146 config: &DemoConfig,
147 scenarios: &[(SceneSequence, DsfbRun)],
148) -> Result<(DemoBSuiteMetrics, Vec<(String, DemoBScenarioRun)>)> {
149 if scenarios.is_empty() {
150 return Err(Error::Message(
151 "Demo B suite requires at least one scenario".to_string(),
152 ));
153 }
154
155 let mut reports = Vec::with_capacity(scenarios.len());
156 let mut runs = Vec::with_capacity(scenarios.len());
157 let mut curves = Vec::new();
158
159 for (sequence, dsfb_host_realistic) in scenarios {
160 let (report, scenario_run, scenario_curves) =
161 run_demo_b_scenario(config, sequence, dsfb_host_realistic)?;
162 reports.push(report);
163 runs.push((sequence.scenario_id.as_str().to_string(), scenario_run));
164 curves.extend(scenario_curves);
165 }
166
167 let canonical = &reports[0];
168 let canonical_uniform = find_policy(canonical, AllocationPolicyId::Uniform)?;
169 let canonical_imported = find_policy(canonical, AllocationPolicyId::ImportedTrust)?;
170
171 let imported_trust_beats_uniform_scenarios = reports
172 .iter()
173 .filter(|report| {
174 let uniform = report
175 .policies
176 .iter()
177 .find(|policy| policy.policy_id == AllocationPolicyId::Uniform.as_str());
178 let trust = report
179 .policies
180 .iter()
181 .find(|policy| policy.policy_id == AllocationPolicyId::ImportedTrust.as_str());
182 match (uniform, trust) {
183 (Some(uniform), Some(trust)) => trust.roi_mae + 1e-6 < uniform.roi_mae,
184 _ => false,
185 }
186 })
187 .count();
188 let imported_trust_beats_combined_heuristic_scenarios = reports
189 .iter()
190 .filter(|report| {
191 let combined = report
192 .policies
193 .iter()
194 .find(|policy| policy.policy_id == AllocationPolicyId::CombinedHeuristic.as_str());
195 let trust = report
196 .policies
197 .iter()
198 .find(|policy| policy.policy_id == AllocationPolicyId::ImportedTrust.as_str());
199 match (combined, trust) {
200 (Some(combined), Some(trust)) => trust.roi_mae + 1e-6 < combined.roi_mae,
201 _ => false,
202 }
203 })
204 .count();
205
206 let neutral_or_mixed_scenarios = reports
207 .iter()
208 .filter(|report| {
209 matches!(report.expectation, ScenarioExpectation::NeutralExpected)
210 || report
211 .policies
212 .iter()
213 .find(|policy| policy.policy_id == AllocationPolicyId::ImportedTrust.as_str())
214 .zip(report.policies.iter().find(|policy| {
215 policy.policy_id == AllocationPolicyId::CombinedHeuristic.as_str()
216 }))
217 .map(|(trust, combined)| trust.roi_mae > combined.roi_mae)
218 .unwrap_or(false)
219 })
220 .map(|report| report.scenario_id.clone())
221 .collect::<Vec<_>>();
222
223 Ok((
224 DemoBSuiteMetrics {
225 summary: DemoBSummary {
226 scenario_ids: reports
227 .iter()
228 .map(|report| report.scenario_id.clone())
229 .collect(),
230 policy_ids: [
231 AllocationPolicyId::Uniform,
232 AllocationPolicyId::EdgeGuided,
233 AllocationPolicyId::ResidualGuided,
234 AllocationPolicyId::ContrastGuided,
235 AllocationPolicyId::VarianceGuided,
236 AllocationPolicyId::CombinedHeuristic,
237 AllocationPolicyId::NativeTrust,
238 AllocationPolicyId::ImportedTrust,
239 AllocationPolicyId::HybridTrustVariance,
240 ]
241 .iter()
242 .map(|policy| policy.as_str().to_string())
243 .collect(),
244 primary_behavioral_result: format!(
245 "On the canonical sampling scenario, imported trust reduced ROI MAE from {:.5} for uniform allocation to {:.5} under the same total budget.",
246 canonical_uniform.roi_mae, canonical_imported.roi_mae
247 ),
248 imported_trust_beats_uniform_scenarios,
249 imported_trust_beats_combined_heuristic_scenarios,
250 neutral_or_mixed_scenarios,
251 },
252 scenarios: reports,
253 budget_efficiency_curves: curves,
254 },
255 runs,
256 ))
257}
258
259fn run_demo_b_scenario(
260 config: &DemoConfig,
261 sequence: &SceneSequence,
262 dsfb_host_realistic: &DsfbRun,
263) -> Result<(DemoBScenarioReport, DemoBScenarioRun, Vec<BudgetCurve>)> {
264 let onset = sequence
265 .onset_frame
266 .min(sequence.frames.len().saturating_sub(1));
267 let scene_frame = &sequence.frames[onset];
268 let width = sequence.config.width;
269 let height = sequence.config.height;
270 let total_pixels = width * height;
271 let uniform_total_samples = config.demo_b_uniform_spp * total_pixels;
272 let min_total = config.demo_b_min_spp * total_pixels;
273 let max_total = config.demo_b_max_spp * total_pixels;
274 if uniform_total_samples < min_total || uniform_total_samples > max_total {
275 return Err(Error::Message(
276 "Demo B uniform total sample budget is incompatible with the min/max spp bounds"
277 .to_string(),
278 ));
279 }
280
281 let reference_counts = vec![config.demo_b_reference_spp; total_pixels];
282 let uniform_counts = vec![config.demo_b_uniform_spp; total_pixels];
283 let reference_frame = render_with_counts(sequence, scene_frame, &reference_counts);
284 let pilot_a = render_with_counts(sequence, scene_frame, &vec![1usize; total_pixels]);
285 let pilot_b = render_with_offset_counts(sequence, scene_frame, &vec![1usize; total_pixels], 17);
286 let target_bbox = crate::frame::bounding_box_from_mask(&sequence.target_mask, width, height)
287 .ok_or_else(|| Error::Message("Demo B target mask was empty".to_string()))?;
288
289 let imported_trust = invert_trust(&dsfb_host_realistic.supervision_frames[onset].trust);
290 let edge_difficulty = gradient_field(&pilot_a);
291 let residual_difficulty = residual_proxy_field(&pilot_a);
292 let contrast_difficulty = local_contrast_field(&pilot_a);
293 let variance_difficulty = pilot_variance_field(&pilot_a, &pilot_b);
294 let combined_difficulty = combine_fields(
295 &[
296 (&edge_difficulty, 0.35),
297 (&residual_difficulty, 0.25),
298 (&contrast_difficulty, 0.25),
299 (&variance_difficulty, 0.15),
300 ],
301 width,
302 height,
303 );
304 let native_trust_difficulty = combine_fields(
305 &[
306 (&edge_difficulty, 0.18),
307 (&residual_difficulty, 0.28),
308 (&contrast_difficulty, 0.24),
309 (&variance_difficulty, 0.30),
310 ],
311 width,
312 height,
313 );
314 let hybrid_difficulty = combine_fields(
315 &[(&imported_trust, 0.55), (&variance_difficulty, 0.45)],
316 width,
317 height,
318 );
319
320 let policies = [
321 (AllocationPolicyId::Uniform, None),
322 (AllocationPolicyId::EdgeGuided, Some(&edge_difficulty)),
323 (
324 AllocationPolicyId::ResidualGuided,
325 Some(&residual_difficulty),
326 ),
327 (
328 AllocationPolicyId::ContrastGuided,
329 Some(&contrast_difficulty),
330 ),
331 (
332 AllocationPolicyId::VarianceGuided,
333 Some(&variance_difficulty),
334 ),
335 (
336 AllocationPolicyId::CombinedHeuristic,
337 Some(&combined_difficulty),
338 ),
339 (
340 AllocationPolicyId::NativeTrust,
341 Some(&native_trust_difficulty),
342 ),
343 (AllocationPolicyId::ImportedTrust, Some(&imported_trust)),
344 (
345 AllocationPolicyId::HybridTrustVariance,
346 Some(&hybrid_difficulty),
347 ),
348 ];
349
350 let uniform_frame = render_with_counts(sequence, scene_frame, &uniform_counts);
351 let uniform_error = build_error_field(&uniform_frame, &reference_frame);
352 let uniform_metrics = policy_metrics(
353 AllocationPolicyId::Uniform,
354 PolicyMetricsContext {
355 counts: &uniform_counts,
356 frame: &uniform_frame,
357 reference_frame: &reference_frame,
358 _error: &uniform_error,
359 target_mask: &sequence.target_mask,
360 extra_roi_samples_vs_uniform: 0.0,
361 uniform_roi_mae: 0.0,
362 },
363 );
364 let mut policy_runs = vec![DemoBPolicyRun {
365 policy_id: AllocationPolicyId::Uniform,
366 frame: uniform_frame,
367 error: uniform_error,
368 spp: build_count_field(&uniform_counts, width, height),
369 metrics: uniform_metrics.clone(),
370 }];
371
372 for (policy_id, field) in policies.iter().skip(1) {
373 let difficulty = field.expect("guided policies require a difficulty field");
374 let counts = guided_allocation(
375 difficulty,
376 uniform_total_samples,
377 config.demo_b_min_spp,
378 config.demo_b_max_spp,
379 )?;
380 let frame = render_with_counts(sequence, scene_frame, &counts);
381 let error = build_error_field(&frame, &reference_frame);
382 let roi_mean_spp = mean_count_over_mask(&counts, &sequence.target_mask);
383 let metrics = policy_metrics(
384 *policy_id,
385 PolicyMetricsContext {
386 counts: &counts,
387 frame: &frame,
388 reference_frame: &reference_frame,
389 _error: &error,
390 target_mask: &sequence.target_mask,
391 extra_roi_samples_vs_uniform: roi_mean_spp - uniform_metrics.roi_mean_spp,
392 uniform_roi_mae: uniform_metrics.roi_mae,
393 },
394 );
395 policy_runs.push(DemoBPolicyRun {
396 policy_id: *policy_id,
397 frame,
398 error,
399 spp: build_count_field(&counts, width, height),
400 metrics,
401 });
402 }
403
404 for run in &policy_runs {
405 let total = run.metrics.total_samples;
406 if total != uniform_total_samples {
407 return Err(Error::Message(format!(
408 "policy {} violated fixed-budget fairness: expected {}, got {}",
409 run.policy_id.as_str(),
410 uniform_total_samples,
411 total
412 )));
413 }
414 }
415
416 let budget_levels = [1.0f32, config.demo_b_uniform_spp as f32, 4.0f32, 8.0f32];
417 let mut curves = Vec::new();
418 for policy_id in [
419 AllocationPolicyId::Uniform,
420 AllocationPolicyId::CombinedHeuristic,
421 AllocationPolicyId::NativeTrust,
422 AllocationPolicyId::ImportedTrust,
423 AllocationPolicyId::HybridTrustVariance,
424 ] {
425 let mut points = Vec::new();
426 for average_spp in budget_levels {
427 let total_samples = (average_spp * total_pixels as f32).round() as usize;
428 let counts = match policy_id {
429 AllocationPolicyId::Uniform => {
430 vec![average_spp.round().max(1.0) as usize; total_pixels]
431 }
432 AllocationPolicyId::CombinedHeuristic => guided_allocation(
433 &combined_difficulty,
434 total_samples,
435 1,
436 config.demo_b_max_spp,
437 )?,
438 AllocationPolicyId::NativeTrust => guided_allocation(
439 &native_trust_difficulty,
440 total_samples,
441 1,
442 config.demo_b_max_spp,
443 )?,
444 AllocationPolicyId::ImportedTrust => {
445 guided_allocation(&imported_trust, total_samples, 1, config.demo_b_max_spp)?
446 }
447 AllocationPolicyId::HybridTrustVariance => {
448 guided_allocation(&hybrid_difficulty, total_samples, 1, config.demo_b_max_spp)?
449 }
450 _ => unreachable!(),
451 };
452 let frame = render_with_counts(sequence, scene_frame, &counts);
453 points.push(BudgetCurvePoint {
454 average_spp,
455 roi_mae: mean_abs_error_over_mask(&frame, &reference_frame, &sequence.target_mask),
456 });
457 }
458 curves.push(BudgetCurve {
459 scenario_id: sequence.scenario_id.as_str().to_string(),
460 policy_id: policy_id.as_str().to_string(),
461 points,
462 });
463 }
464
465 let imported_trust_metrics = policy_runs
466 .iter()
467 .find(|run| run.policy_id == AllocationPolicyId::ImportedTrust)
468 .map(|run| &run.metrics)
469 .ok_or_else(|| Error::Message("imported trust policy missing".to_string()))?;
470 let combined_metrics = policy_runs
471 .iter()
472 .find(|run| run.policy_id == AllocationPolicyId::CombinedHeuristic)
473 .map(|run| &run.metrics)
474 .ok_or_else(|| Error::Message("combined heuristic policy missing".to_string()))?;
475
476 let headline = format!(
477 "{}: imported-trust ROI MAE {:.5}, combined-heuristic ROI MAE {:.5}, uniform ROI MAE {:.5}.",
478 sequence.scenario_title,
479 imported_trust_metrics.roi_mae,
480 combined_metrics.roi_mae,
481 uniform_metrics.roi_mae
482 );
483 let bounded_note = match sequence.expectation {
484 ScenarioExpectation::BenefitExpected => {
485 if imported_trust_metrics.roi_mae > combined_metrics.roi_mae {
486 "Combined heuristic remains stronger on this scenario, which is surfaced explicitly in the decision report.".to_string()
487 } else {
488 "Imported trust remains competitive under equal budget on this scenario.".to_string()
489 }
490 }
491 ScenarioExpectation::NeutralExpected => {
492 "Neutral case: guidance is not expected to produce a large win, so non-ROI penalties and concentration behavior matter more than raw ROI gain.".to_string()
493 }
494 };
495
496 Ok((
497 DemoBScenarioReport {
498 scenario_id: sequence.scenario_id.as_str().to_string(),
499 scenario_title: sequence.scenario_title.clone(),
500 expectation: sequence.expectation,
501 support_category: sequence.support_category,
502 sampling_taxonomy: sequence.sampling_taxonomy.clone(),
503 demo_b_taxonomy: sequence.demo_b_taxonomy.clone(),
504 onset_frame: onset,
505 target_label: sequence.target_label.clone(),
506 target_pixels: sequence.target_mask.iter().filter(|value| **value).count(),
507 policies: policy_runs.iter().map(|run| run.metrics.clone()).collect(),
508 headline,
509 bounded_note,
510 },
511 DemoBScenarioRun {
512 reference_frame,
513 policy_runs,
514 target_bbox,
515 },
516 curves,
517 ))
518}
519
520fn policy_metrics(
521 policy_id: AllocationPolicyId,
522 context: PolicyMetricsContext<'_>,
523) -> DemoBPolicyMetrics {
524 let non_roi_mask = context
525 .target_mask
526 .iter()
527 .map(|value| !value)
528 .collect::<Vec<_>>();
529 let total_samples = context.counts.iter().sum::<usize>();
530 let roi_mean_spp = mean_count_over_mask(context.counts, context.target_mask);
531 let non_roi_mean_spp = mean_count_over_mask(context.counts, &non_roi_mask);
532 let roi_mae =
533 mean_abs_error_over_mask(context.frame, context.reference_frame, context.target_mask);
534 let policy_label = policy_id.label().to_string();
535 let allocation_concentration = if non_roi_mean_spp <= f32::EPSILON {
536 0.0
537 } else {
538 roi_mean_spp / non_roi_mean_spp
539 };
540
541 DemoBPolicyMetrics {
542 policy_id: policy_id.as_str().to_string(),
543 label: policy_label,
544 total_samples,
545 overall_mae: mean_abs_error(context.frame, context.reference_frame),
546 overall_rmse: rmse(context.frame, context.reference_frame, None),
547 roi_mae,
548 roi_rmse: rmse(
549 context.frame,
550 context.reference_frame,
551 Some(context.target_mask),
552 ),
553 non_roi_mae: mean_abs_error_over_mask(
554 context.frame,
555 context.reference_frame,
556 &non_roi_mask,
557 ),
558 non_roi_rmse: rmse(context.frame, context.reference_frame, Some(&non_roi_mask)),
559 roi_mean_spp,
560 non_roi_mean_spp,
561 max_spp: context.counts.iter().copied().max().unwrap_or(0),
562 allocation_concentration,
563 extra_roi_samples_vs_uniform: context.extra_roi_samples_vs_uniform,
564 roi_error_reduction_per_extra_roi_sample: if context.extra_roi_samples_vs_uniform
565 <= f32::EPSILON
566 {
567 0.0
568 } else {
569 (context.uniform_roi_mae - roi_mae) / context.extra_roi_samples_vs_uniform
570 },
571 }
572}
573
574fn find_policy(
575 report: &DemoBScenarioReport,
576 policy_id: AllocationPolicyId,
577) -> Result<&DemoBPolicyMetrics> {
578 report
579 .policies
580 .iter()
581 .find(|policy| policy.policy_id == policy_id.as_str())
582 .ok_or_else(|| Error::Message(format!("Demo B policy {} missing", policy_id.as_str())))
583}
584
585fn render_with_counts(
586 sequence: &SceneSequence,
587 scene_frame: &SceneFrame,
588 counts: &[usize],
589) -> ImageFrame {
590 render_with_offset_counts(sequence, scene_frame, counts, 0)
591}
592
593fn render_with_offset_counts(
594 sequence: &SceneSequence,
595 scene_frame: &SceneFrame,
596 counts: &[usize],
597 seed_offset: u32,
598) -> ImageFrame {
599 let mut frame = ImageFrame::new(sequence.config.width, sequence.config.height);
600 for y in 0..sequence.config.height {
601 for x in 0..sequence.config.width {
602 let pixel_index = y * sequence.config.width + x;
603 let sample_count = counts[pixel_index].max(1);
604 let mut accum = Color::rgb(0.0, 0.0, 0.0);
605
606 for sample_index in 0..sample_count {
607 let (offset_x, offset_y) =
608 sample_offset(pixel_index as u32 ^ seed_offset, sample_index as u32);
609 let sample = sample_scene_continuous(
610 sequence,
611 scene_frame,
612 x as f32 + offset_x,
613 y as f32 + offset_y,
614 );
615 accum = Color::rgb(accum.r + sample.r, accum.g + sample.g, accum.b + sample.b);
616 }
617
618 let inv = 1.0 / sample_count as f32;
619 frame.set(
620 x,
621 y,
622 Color::rgb(accum.r * inv, accum.g * inv, accum.b * inv).clamp01(),
623 );
624 }
625 }
626 frame
627}
628
629fn sample_scene_continuous(
630 sequence: &SceneSequence,
631 scene_frame: &SceneFrame,
632 sample_x: f32,
633 sample_y: f32,
634) -> Color {
635 let mut color =
636 background_color_continuous(sample_x, sample_y, &sequence.config, sequence.scenario_id);
637 if is_thin_structure_continuous(sample_x, sample_y, &sequence.config, sequence.scenario_id) {
638 color = thin_structure_color_continuous(
639 sample_x,
640 sample_y,
641 &sequence.config,
642 sequence.scenario_id,
643 );
644 }
645 if rect_contains_continuous(scene_frame.object_rect, sample_x, sample_y)
646 && !matches!(
647 sequence.scenario_id,
648 ScenarioId::ContrastPulse | ScenarioId::StabilityHoldout
649 )
650 {
651 color = object_color_continuous(sample_x, sample_y, scene_frame.object_rect);
652 }
653 if matches!(sequence.scenario_id, ScenarioId::ContrastPulse)
654 && scene_frame.index >= sequence.onset_frame
655 && contrast_pulse_rect(&sequence.config).contains(sample_x as i32, sample_y as i32)
656 {
657 color = Color::rgb(color.r * 1.22, color.g * 1.22, color.b * 1.22).clamp01();
658 }
659 color
660}
661
662fn background_color_continuous(
663 sample_x: f32,
664 sample_y: f32,
665 config: &crate::config::SceneConfig,
666 scenario_id: ScenarioId,
667) -> Color {
668 let xf = sample_x / config.width.max(1) as f32;
669 let yf = sample_y / config.height.max(1) as f32;
670 let checker = if ((sample_x / 12.0).floor() + (sample_y / 12.0).floor()) as i32 % 2 == 0 {
671 1.0
672 } else {
673 0.0
674 };
675 let diagonal = if (sample_x + 2.0 * sample_y).rem_euclid(22.0) < 6.0 {
676 1.0
677 } else {
678 0.0
679 };
680 let stripes = if (3.0 * sample_x + sample_y).rem_euclid(17.0) < 5.0 {
681 1.0
682 } else {
683 0.0
684 };
685 let vignette_x = (xf - 0.5).abs();
686 let vignette_y = (yf - 0.5).abs();
687 let vignette = 1.0 - (vignette_x * 0.35 + vignette_y * 0.4);
688
689 match scenario_id {
690 ScenarioId::ThinReveal | ScenarioId::StabilityHoldout => Color::rgb(
691 (0.12 + 0.16 * xf + 0.05 * checker + 0.03 * diagonal) * vignette,
692 (0.15 + 0.11 * yf + 0.04 * diagonal) * vignette,
693 (0.22 + 0.18 * (1.0 - xf) + 0.03 * checker) * vignette,
694 ),
695 ScenarioId::FastPan => Color::rgb(
696 (0.10 + 0.18 * xf + 0.08 * checker + 0.05 * stripes) * vignette,
697 (0.11 + 0.15 * yf + 0.10 * diagonal + 0.04 * stripes) * vignette,
698 (0.18 + 0.20 * (1.0 - xf) + 0.06 * checker) * vignette,
699 ),
700 ScenarioId::DiagonalReveal => Color::rgb(
701 (0.08 + 0.24 * checker + 0.20 * diagonal + 0.05 * xf) * vignette,
702 (0.08 + 0.18 * stripes + 0.07 * yf) * vignette,
703 (0.12 + 0.25 * (1.0 - checker) + 0.04 * xf) * vignette,
704 ),
705 ScenarioId::RevealBand
706 | ScenarioId::MotionBiasBand
707 | ScenarioId::LayeredSlats
708 | ScenarioId::NoisyReprojection
709 | ScenarioId::HeuristicFriendlyPan => {
710 let micro = ((sample_x * 0.83 + sample_y * 1.91).sin() * 0.5 + 0.5)
711 * ((sample_x * 1.37 - sample_y * 0.71).cos() * 0.5 + 0.5);
712 let band = if (18.0..=(config.height as f32 - 18.0)).contains(&sample_y)
713 && (26.0..=(config.width as f32 - 24.0)).contains(&sample_x)
714 {
715 1.0
716 } else {
717 0.0
718 };
719 let micro_gain = match scenario_id {
720 ScenarioId::LayeredSlats => 0.13,
721 ScenarioId::NoisyReprojection => 0.16,
722 ScenarioId::HeuristicFriendlyPan => 0.06,
723 _ => 0.10,
724 };
725 Color::rgb(
726 (0.10 + 0.14 * xf + 0.05 * checker + micro_gain * micro * band) * vignette,
727 (0.12 + 0.13 * yf + 0.06 * diagonal + 0.08 * micro * band) * vignette,
728 (0.16 + 0.18 * (1.0 - xf) + 0.07 * stripes + 0.12 * micro * band) * vignette,
729 )
730 }
731 ScenarioId::ContrastPulse => {
732 Color::rgb(0.18 + 0.06 * xf, 0.18 + 0.05 * yf, 0.24 + 0.06 * (1.0 - xf))
733 }
734 }
735}
736
737fn is_thin_structure_continuous(
738 sample_x: f32,
739 sample_y: f32,
740 config: &crate::config::SceneConfig,
741 scenario_id: ScenarioId,
742) -> bool {
743 if matches!(scenario_id, ScenarioId::ContrastPulse) {
744 return false;
745 }
746 let vertical_center = config.thin_vertical_x as f32 + 0.5;
747 let vertical = (sample_x - vertical_center).abs() <= 0.18
748 && sample_y >= 14.0
749 && sample_y <= config.height as f32 - 14.0;
750 let diagonal_line =
751 (sample_y - (0.58 * sample_x + 10.5)).abs() <= 0.20 && (28.0..=118.0).contains(&sample_x);
752 let mixed_width_band = {
753 let in_band = (18.0..=(config.height as f32 - 18.0)).contains(&sample_y)
754 && (26.0..=(config.width as f32 - 24.0)).contains(&sample_x);
755 let thin_slats = (sample_x - 28.0).rem_euclid(11.0) < 0.18;
756 let medium_slats = (sample_x - 34.0).rem_euclid(19.0) < 1.10;
757 let wide_slats = (sample_x - 48.0).rem_euclid(29.0) < 2.15;
758 let diagonal = (sample_y - (0.44 * sample_x + 12.0)).abs() <= 1.15
759 && (38.0..=(config.width as f32 - 32.0)).contains(&sample_x);
760 in_band && (thin_slats || medium_slats || wide_slats || diagonal)
761 };
762 match scenario_id {
763 ScenarioId::DiagonalReveal => diagonal_line,
764 ScenarioId::RevealBand
765 | ScenarioId::MotionBiasBand
766 | ScenarioId::LayeredSlats
767 | ScenarioId::NoisyReprojection
768 | ScenarioId::HeuristicFriendlyPan => mixed_width_band,
769 _ => vertical || diagonal_line,
770 }
771}
772
773fn thin_structure_color_continuous(
774 sample_x: f32,
775 sample_y: f32,
776 config: &crate::config::SceneConfig,
777 scenario_id: ScenarioId,
778) -> Color {
779 let vertical_center = config.thin_vertical_x as f32 + 0.5;
780 if !matches!(scenario_id, ScenarioId::DiagonalReveal)
781 && (sample_x - vertical_center).abs() <= 0.18
782 {
783 let pulse = if (sample_y / 3.0).floor() as i32 % 2 == 0 {
784 1.0
785 } else {
786 0.84
787 };
788 return Color::rgb(0.95 * pulse, 0.96 * pulse, 0.98);
789 }
790 if matches!(scenario_id, ScenarioId::DiagonalReveal) {
791 Color::rgb(0.24, 0.29, 0.35)
792 } else if matches!(
793 scenario_id,
794 ScenarioId::RevealBand | ScenarioId::LayeredSlats
795 ) {
796 let phase = ((sample_x + 2.0 * sample_y).rem_euclid(9.0)) / 8.0;
797 Color::rgb(
798 0.22 + 0.48 * phase,
799 0.58 + 0.26 * phase,
800 0.84 + 0.12 * (1.0 - phase),
801 )
802 } else if matches!(
803 scenario_id,
804 ScenarioId::MotionBiasBand | ScenarioId::NoisyReprojection
805 ) {
806 let phase = ((2.0 * sample_x + sample_y).rem_euclid(13.0)) / 12.0;
807 Color::rgb(
808 0.78 + 0.16 * phase,
809 0.74 + 0.10 * (1.0 - phase),
810 0.26 + 0.18 * phase,
811 )
812 } else if matches!(scenario_id, ScenarioId::HeuristicFriendlyPan) {
813 let phase = ((sample_x - 1.5 * sample_y).rem_euclid(15.0)) / 14.0;
814 Color::rgb(
815 0.18 + 0.14 * phase,
816 0.86 - 0.18 * phase,
817 0.28 + 0.10 * (1.0 - phase),
818 )
819 } else {
820 Color::rgb(0.64, 0.90, 0.96)
821 }
822}
823
824fn contrast_pulse_rect(config: &crate::config::SceneConfig) -> Rect {
825 Rect {
826 x: (config.width as i32 / 2) - 18,
827 y: (config.height as i32 / 2) - 18,
828 width: 52,
829 height: 36,
830 }
831}
832
833fn rect_contains_continuous(rect: Rect, sample_x: f32, sample_y: f32) -> bool {
834 sample_x >= rect.x as f32
835 && sample_x < (rect.x + rect.width) as f32
836 && sample_y >= rect.y as f32
837 && sample_y < (rect.y + rect.height) as f32
838}
839
840fn object_color_continuous(sample_x: f32, sample_y: f32, rect: Rect) -> Color {
841 let local_x = (sample_x - rect.x as f32) / rect.width.max(1) as f32;
842 let local_y = (sample_y - rect.y as f32) / rect.height.max(1) as f32;
843 let stripe = if (0.36..0.46).contains(&local_x) {
844 0.55
845 } else {
846 1.0
847 };
848 let rim = if !(0.05..=0.95).contains(&local_x) || !(0.05..=0.95).contains(&local_y) {
849 1.12
850 } else {
851 1.0
852 };
853 Color::rgb(
854 (0.82 + 0.10 * local_y) * stripe * rim,
855 (0.35 + 0.12 * (1.0 - local_y)) * stripe * rim,
856 (0.20 + 0.08 * local_x) * stripe * rim,
857 )
858 .clamp01()
859}
860
861pub(crate) fn guided_allocation(
862 difficulty: &ScalarField,
863 total_samples: usize,
864 min_spp: usize,
865 max_spp: usize,
866) -> Result<Vec<usize>> {
867 let total_pixels = difficulty.width() * difficulty.height();
868 let min_total = min_spp * total_pixels;
869 let max_total = max_spp * total_pixels;
870 if total_samples < min_total || total_samples > max_total {
871 return Err(Error::Message(
872 "guided allocation budget is incompatible with the min/max spp bounds".to_string(),
873 ));
874 }
875
876 let mut counts = vec![min_spp; total_pixels];
877 let mut remaining = total_samples - min_total;
878 if remaining == 0 {
879 return Ok(counts);
880 }
881
882 let weights = difficulty
883 .values()
884 .iter()
885 .map(|value| 0.05 + 0.95 * value.clamp(0.0, 1.0).powf(2.4))
886 .collect::<Vec<_>>();
887
888 while remaining > 0 {
889 let available_weight: f32 = counts
890 .iter()
891 .zip(weights.iter())
892 .filter_map(|(count, weight)| (*count < max_spp).then_some(*weight))
893 .sum();
894 if available_weight <= f32::EPSILON {
895 break;
896 }
897
898 let round_budget = remaining;
899 let mut floor_assignments = vec![0usize; total_pixels];
900 let mut fractional_parts = Vec::new();
901 for (index, (count, weight)) in counts
902 .iter()
903 .copied()
904 .zip(weights.iter().copied())
905 .enumerate()
906 {
907 if count >= max_spp {
908 continue;
909 }
910 let capacity = max_spp - count;
911 let target = round_budget as f32 * weight / available_weight;
912 let whole = target.floor() as usize;
913 let assigned = whole.min(capacity);
914 floor_assignments[index] = assigned;
915 if assigned < capacity {
916 fractional_parts.push((target - assigned as f32, index));
917 }
918 }
919
920 let mut assigned_this_round = 0usize;
921 for (count, extra) in counts.iter_mut().zip(floor_assignments.iter().copied()) {
922 *count += extra;
923 assigned_this_round += extra;
924 }
925 remaining -= assigned_this_round.min(remaining);
926 if remaining == 0 {
927 break;
928 }
929
930 fractional_parts.sort_by(|left, right| right.0.total_cmp(&left.0));
931 let mut assigned_fractional = 0usize;
932 for (_, index) in fractional_parts {
933 if remaining == 0 {
934 break;
935 }
936 if counts[index] < max_spp {
937 counts[index] += 1;
938 remaining -= 1;
939 assigned_fractional += 1;
940 }
941 }
942
943 if assigned_this_round == 0 && assigned_fractional == 0 {
944 let mut fallback = weights
945 .iter()
946 .copied()
947 .enumerate()
948 .filter(|(index, _)| counts[*index] < max_spp)
949 .map(|(index, weight)| (weight, index))
950 .collect::<Vec<_>>();
951 fallback.sort_by(|left, right| right.0.total_cmp(&left.0));
952 for (_, index) in fallback {
953 if remaining == 0 {
954 break;
955 }
956 counts[index] += 1;
957 remaining -= 1;
958 }
959 }
960 }
961
962 Ok(counts)
963}
964
965fn build_error_field(frame: &ImageFrame, reference: &ImageFrame) -> ScalarField {
966 let mut field = ScalarField::new(frame.width(), frame.height());
967 for y in 0..frame.height() {
968 for x in 0..frame.width() {
969 field.set(x, y, frame.get(x, y).abs_diff(reference.get(x, y)));
970 }
971 }
972 field
973}
974
975pub(crate) fn build_count_field(counts: &[usize], width: usize, height: usize) -> ScalarField {
976 let mut field = ScalarField::new(width, height);
977 for y in 0..height {
978 for x in 0..width {
979 field.set(x, y, counts[y * width + x] as f32);
980 }
981 }
982 field
983}
984
/// Mean of `counts` over entries whose `mask` value is `true`; 0.0 if none are selected.
///
/// The two slices are paired element-wise, and any excess in the longer one is ignored.
pub(crate) fn mean_count_over_mask(counts: &[usize], mask: &[bool]) -> f32 {
    let (selected_sum, selected) = counts
        .iter()
        .zip(mask)
        .filter(|&(_, &include)| include)
        .fold((0usize, 0usize), |(sum, n), (&spp, _)| (sum + spp, n + 1));
    match selected {
        0 => 0.0,
        n => selected_sum as f32 / n as f32,
    }
}
1000
1001fn mean_abs_error_over_mask(frame_a: &ImageFrame, frame_b: &ImageFrame, mask: &[bool]) -> f32 {
1002 crate::frame::mean_abs_error_over_mask(frame_a, frame_b, mask)
1003}
1004
1005fn rmse(frame: &ImageFrame, reference: &ImageFrame, mask: Option<&[bool]>) -> f32 {
1006 let mut sum = 0.0;
1007 let mut count = 0usize;
1008
1009 for y in 0..frame.height() {
1010 for x in 0..frame.width() {
1011 let index = y * frame.width() + x;
1012 if mask.map(|values| values[index]).unwrap_or(true) {
1013 let diff = frame.get(x, y).abs_diff(reference.get(x, y));
1014 sum += diff * diff;
1015 count += 1;
1016 }
1017 }
1018 }
1019
1020 if count == 0 {
1021 0.0
1022 } else {
1023 (sum / count as f32).sqrt()
1024 }
1025}
1026
1027pub(crate) fn invert_trust(trust: &ScalarField) -> ScalarField {
1028 let mut field = ScalarField::new(trust.width(), trust.height());
1029 for y in 0..trust.height() {
1030 for x in 0..trust.width() {
1031 field.set(x, y, (1.0 - trust.get(x, y)).clamp(0.0, 1.0));
1032 }
1033 }
1034 field
1035}
1036
1037pub(crate) fn gradient_field(frame: &ImageFrame) -> ScalarField {
1038 let mut field = ScalarField::new(frame.width(), frame.height());
1039 for y in 0..frame.height() {
1040 for x in 0..frame.width() {
1041 let center = frame.get(x, y).luma();
1042 let left = frame.sample_clamped(x as i32 - 1, y as i32).luma();
1043 let right = frame.sample_clamped(x as i32 + 1, y as i32).luma();
1044 let up = frame.sample_clamped(x as i32, y as i32 - 1).luma();
1045 let down = frame.sample_clamped(x as i32, y as i32 + 1).luma();
1046 let grad = (right - left)
1047 .abs()
1048 .max((down - up).abs())
1049 .max((center - left).abs());
1050 field.set(x, y, (grad / 0.25).clamp(0.0, 1.0));
1051 }
1052 }
1053 field
1054}
1055
1056fn residual_proxy_field(frame: &ImageFrame) -> ScalarField {
1057 let mut field = ScalarField::new(frame.width(), frame.height());
1058 for y in 0..frame.height() {
1059 for x in 0..frame.width() {
1060 let blurred = box_blur_luma(frame, x, y);
1061 let residual = (frame.get(x, y).luma() - blurred).abs();
1062 field.set(x, y, (residual / 0.22).clamp(0.0, 1.0));
1063 }
1064 }
1065 field
1066}
1067
1068pub(crate) fn local_contrast_field(frame: &ImageFrame) -> ScalarField {
1069 let mut field = ScalarField::new(frame.width(), frame.height());
1070 for y in 0..frame.height() {
1071 for x in 0..frame.width() {
1072 let center = frame.get(x, y).luma();
1073 let mut strongest = 0.0f32;
1074 for dy in -1i32..=1 {
1075 for dx in -1i32..=1 {
1076 if dx == 0 && dy == 0 {
1077 continue;
1078 }
1079 let neighbor = frame.sample_clamped(x as i32 + dx, y as i32 + dy).luma();
1080 strongest = strongest.max((center - neighbor).abs());
1081 }
1082 }
1083 field.set(x, y, (strongest / 0.18).clamp(0.0, 1.0));
1084 }
1085 }
1086 field
1087}
1088
1089fn pilot_variance_field(pilot_a: &ImageFrame, pilot_b: &ImageFrame) -> ScalarField {
1090 let mut field = ScalarField::new(pilot_a.width(), pilot_a.height());
1091 for y in 0..pilot_a.height() {
1092 for x in 0..pilot_a.width() {
1093 let diff = pilot_a.get(x, y).abs_diff(pilot_b.get(x, y));
1094 field.set(x, y, (diff / 0.20).clamp(0.0, 1.0));
1095 }
1096 }
1097 field
1098}
1099
1100pub(crate) fn combine_fields(
1101 fields: &[(&ScalarField, f32)],
1102 width: usize,
1103 height: usize,
1104) -> ScalarField {
1105 let mut field = ScalarField::new(width, height);
1106 for y in 0..height {
1107 for x in 0..width {
1108 let value = fields
1109 .iter()
1110 .map(|(current, weight)| current.get(x, y) * *weight)
1111 .sum::<f32>()
1112 .clamp(0.0, 1.0);
1113 field.set(x, y, value);
1114 }
1115 }
1116 field
1117}
1118
1119fn box_blur_luma(frame: &ImageFrame, x: usize, y: usize) -> f32 {
1120 let mut sum = 0.0;
1121 let mut count = 0usize;
1122 for dy in -1i32..=1 {
1123 for dx in -1i32..=1 {
1124 let color = frame.sample_clamped(x as i32 + dx, y as i32 + dy);
1125 sum += color.luma();
1126 count += 1;
1127 }
1128 }
1129 sum / count as f32
1130}
1131
1132fn sample_offset(pixel_seed: u32, sample_index: u32) -> (f32, f32) {
1133 let shift_x = unit_hash(pixel_seed ^ 0x9e37_79b9);
1134 let shift_y = unit_hash(pixel_seed ^ 0x85eb_ca6b);
1135 let u = (radical_inverse(sample_index + 1, 2) + shift_x).fract();
1136 let v = (radical_inverse(sample_index + 1, 3) + shift_y).fract();
1137 (u, v)
1138}
1139
/// Deterministically maps a `u32` to a pseudo-random `f32` in `[0, 1)`.
///
/// Scrambles the bits with a wrapping multiply, a 7-bit rotation, and an XOR
/// salt, then divides by `u32::MAX`. The final `fract` folds a quotient that
/// rounds to exactly `1.0` back to `0.0`.
fn unit_hash(value: u32) -> f32 {
    const MULTIPLIER: u32 = 0x045d_9f3b;
    const SALT: u32 = 0xa511_e9b3;
    let scrambled = value.wrapping_mul(MULTIPLIER).rotate_left(7) ^ SALT;
    let normalized = scrambled as f32 / u32::MAX as f32;
    normalized.fract()
}
1144
/// Radical inverse of `index` in `base` (the van der Corput sequence).
///
/// The base-`base` digits of `index` are mirrored about the radix point. For
/// example, in base 2, `6 = 0b110` maps to `0b0.011 = 0.375`. `index == 0` maps
/// to `0.0`, and the result mathematically lies in `[0, 1)`.
///
/// # Panics
/// Panics if `base < 2`. Base 0 would divide by zero, and base 1 never shrinks
/// `index`, so the digit loop would never terminate.
fn radical_inverse(mut index: u32, base: u32) -> f32 {
    assert!(base >= 2, "radical_inverse requires base >= 2, got {}", base);
    let mut reversed = 0.0;
    // Weight of the next digit: base^-1, base^-2, ...
    let mut inv_base = 1.0 / base as f32;
    while index > 0 {
        let digit = index % base;
        reversed += digit as f32 * inv_base;
        index /= base;
        inv_base /= base as f32;
    }
    reversed
}