1use serde::Serialize;
2
3use crate::frame::{ImageFrame, ScalarField};
4use crate::host::{
5 default_host_realistic_profile, gated_reference_profile, motion_augmented_profile,
6 profile_residual_only, profile_without_alpha_modulation, profile_without_grammar,
7 profile_without_motion, profile_without_thin, profile_without_visibility,
8 supervise_temporal_reuse, synthetic_visibility_profile, HostSupervisionProfile,
9 HostTemporalInputs,
10};
11use crate::scene::{MotionVector, Normal3, SceneFrame, SceneSequence};
12
/// Per-pixel proxy signals recorded for one supervised frame.
///
/// For every frame after the first, each field is moved verbatim from the identically named
/// field of the host supervisor's `proxies` output (see `run_profiled_taa`). For the first
/// frame every field is a freshly constructed `ScalarField` (see `empty_supervision`).
///
/// NOTE(review): the meaning and value range of each proxy are defined in `crate::host`,
/// not in this module.
#[derive(Clone, Debug)]
pub struct ProxyFields {
    pub residual_proxy: ScalarField,
    pub visibility_proxy: ScalarField,
    pub depth_proxy: ScalarField,
    pub normal_proxy: ScalarField,
    pub motion_proxy: ScalarField,
    pub neighborhood_proxy: ScalarField,
    pub thin_proxy: ScalarField,
    pub history_instability_proxy: ScalarField,
}
24
/// Per-pixel structural classification stored in a [`StateField`].
///
/// `Nominal` is the value every cell starts with in `StateField::new`, and it is the only
/// value used for the first frame of a run. Later frames take their states from the host
/// supervisor's `state` output.
///
/// NOTE(review): the criteria that select each variant live in `crate::host`; they are not
/// visible here.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize)]
pub enum StructuralState {
    Nominal,
    DisocclusionLike,
    UnstableHistory,
    MotionEdge,
}
32
/// Number of cells in each [`StructuralState`] variant.
///
/// Produced by `StateField::counts` (whole field) and `StateField::counts_over_mask`
/// (only cells whose mask entry is `true`). `Default` yields all-zero counters.
#[derive(Clone, Debug, Default, Serialize)]
pub struct StateCounts {
    /// Cells in `StructuralState::Nominal`.
    pub nominal: usize,
    /// Cells in `StructuralState::DisocclusionLike`.
    pub disocclusion_like: usize,
    /// Cells in `StructuralState::UnstableHistory`.
    pub unstable_history: usize,
    /// Cells in `StructuralState::MotionEdge`.
    pub motion_edge: usize,
}
40
/// Row-major grid of [`StructuralState`] values, one per pixel.
///
/// Only the width is stored. The height is derived from the buffer length
/// (see `StateField::height`). Cell `(x, y)` lives at index `y * width + x`.
#[derive(Clone, Debug)]
pub struct StateField {
    /// Number of columns.
    width: usize,
    /// Cell states in row-major order; `StateField::new` allocates `width * height` entries.
    values: Vec<StructuralState>,
}
46
47impl StateField {
48 pub fn new(width: usize, height: usize) -> Self {
49 Self {
50 width,
51 values: vec![StructuralState::Nominal; width * height],
52 }
53 }
54
55 pub fn width(&self) -> usize {
56 self.width
57 }
58
59 pub fn height(&self) -> usize {
60 self.values.len() / self.width.max(1)
61 }
62
63 pub fn set(&mut self, x: usize, y: usize, value: StructuralState) {
64 self.values[y * self.width + x] = value;
65 }
66
67 pub fn values(&self) -> &[StructuralState] {
68 &self.values
69 }
70
71 pub fn counts(&self) -> StateCounts {
72 let mut counts = StateCounts::default();
73 for state in &self.values {
74 match state {
75 StructuralState::Nominal => counts.nominal += 1,
76 StructuralState::DisocclusionLike => counts.disocclusion_like += 1,
77 StructuralState::UnstableHistory => counts.unstable_history += 1,
78 StructuralState::MotionEdge => counts.motion_edge += 1,
79 }
80 }
81 counts
82 }
83
84 pub fn counts_over_mask(&self, mask: &[bool]) -> StateCounts {
85 let mut counts = StateCounts::default();
86 for (state, include) in self.values.iter().zip(mask.iter().copied()) {
87 if !include {
88 continue;
89 }
90 match state {
91 StructuralState::Nominal => counts.nominal += 1,
92 StructuralState::DisocclusionLike => counts.disocclusion_like += 1,
93 StructuralState::UnstableHistory => counts.unstable_history += 1,
94 StructuralState::MotionEdge => counts.motion_edge += 1,
95 }
96 }
97 counts
98 }
99}
100
/// Supervision outputs recorded for a single frame of a [`DsfbRun`].
///
/// For every frame after the first, these fields are moved from the result of
/// `supervise_temporal_reuse`. The first frame uses the uniform placeholder built by
/// `empty_supervision`.
#[derive(Clone, Debug)]
pub struct SupervisionFrame {
    pub residual: ScalarField,
    pub trust: ScalarField,
    /// Per-pixel blend weight. For frames after the first, this is the weight passed to
    /// `resolve_with_alpha` (history is lerped toward the current frame by `alpha`).
    pub alpha: ScalarField,
    pub intervention: ScalarField,
    pub proxies: ProxyFields,
    pub state: StateField,
}
110
/// Complete output of one `run_profiled_taa` pass over a scene sequence.
///
/// The three per-frame vectors each contain one entry per input frame, in frame order.
#[derive(Clone, Debug)]
pub struct DsfbRun {
    /// Profile the run was executed with (a clone of the caller's profile).
    pub profile: HostSupervisionProfile,
    /// Resolved (blended) output per frame; entry 0 is the first frame's ground truth.
    pub resolved_frames: Vec<ImageFrame>,
    /// Previous resolved image warped into each frame; entry 0 is the first frame's ground truth.
    pub reprojected_history_frames: Vec<ImageFrame>,
    /// Supervision outputs per frame.
    pub supervision_frames: Vec<SupervisionFrame>,
}
118
119pub fn run_gated_taa(sequence: &SceneSequence, alpha_min: f32, alpha_max: f32) -> DsfbRun {
120 run_profiled_taa(sequence, &gated_reference_profile(alpha_min, alpha_max))
121}
122
123pub fn run_visibility_assisted_taa(
124 sequence: &SceneSequence,
125 alpha_min: f32,
126 alpha_max: f32,
127) -> DsfbRun {
128 run_profiled_taa(
129 sequence,
130 &synthetic_visibility_profile(alpha_min, alpha_max),
131 )
132}
133
134pub fn ablation_profiles(alpha_min: f32, alpha_max: f32) -> Vec<HostSupervisionProfile> {
135 vec![
136 synthetic_visibility_profile(alpha_min, alpha_max),
137 default_host_realistic_profile(alpha_min, alpha_max),
138 gated_reference_profile(alpha_min, alpha_max),
139 motion_augmented_profile(alpha_min, alpha_max),
140 profile_without_visibility(alpha_min, alpha_max),
141 profile_without_thin(alpha_min, alpha_max),
142 profile_without_motion(alpha_min, alpha_max),
143 profile_without_grammar(alpha_min, alpha_max),
144 profile_residual_only(alpha_min, alpha_max),
145 profile_without_alpha_modulation(alpha_min, alpha_max),
146 ]
147}
148
149pub fn run_profiled_taa(sequence: &SceneSequence, profile: &HostSupervisionProfile) -> DsfbRun {
150 let mut resolved_frames = Vec::with_capacity(sequence.frames.len());
151 let mut reprojected_history_frames = Vec::with_capacity(sequence.frames.len());
152 let mut supervision_frames = Vec::with_capacity(sequence.frames.len());
153
154 for (frame_index, scene_frame) in sequence.frames.iter().enumerate() {
155 let width = scene_frame.ground_truth.width();
156 let height = scene_frame.ground_truth.height();
157 if frame_index == 0 {
158 resolved_frames.push(scene_frame.ground_truth.clone());
159 reprojected_history_frames.push(scene_frame.ground_truth.clone());
160 supervision_frames.push(empty_supervision(
161 width,
162 height,
163 1.0,
164 profile.parameters.alpha_range.min,
165 ));
166 continue;
167 }
168
169 let previous_resolved = &resolved_frames[frame_index - 1];
170 let previous_scene_frame = &sequence.frames[frame_index - 1];
171 let reprojected = reproject_frame(previous_resolved, scene_frame);
172 let reprojected_depth = reproject_depth(previous_scene_frame, scene_frame);
173 let reprojected_normals = reproject_normals(previous_scene_frame, scene_frame);
174 let visibility_hint = profile
175 .use_visibility_hint
176 .then_some(scene_frame.disocclusion_mask.as_slice());
177 let thin_hint_field = profile
178 .use_visibility_hint
179 .then(|| compute_thin_hint(scene_frame));
180 let thin_hint = thin_hint_field.as_ref();
181
182 let host_inputs = HostTemporalInputs {
183 current_color: &scene_frame.ground_truth,
184 reprojected_history: &reprojected,
185 motion_vectors: &scene_frame.motion,
186 current_depth: &scene_frame.depth,
187 reprojected_depth: &reprojected_depth,
188 current_normals: &scene_frame.normals,
189 reprojected_normals: &reprojected_normals,
190 visibility_hint,
191 thin_hint,
192 };
193 let outputs = supervise_temporal_reuse(&host_inputs, profile);
194 let resolved = resolve_with_alpha(&reprojected, &scene_frame.ground_truth, &outputs.alpha);
195
196 reprojected_history_frames.push(reprojected);
197 resolved_frames.push(resolved);
198 supervision_frames.push(SupervisionFrame {
199 residual: outputs.residual,
200 trust: outputs.trust,
201 alpha: outputs.alpha,
202 intervention: outputs.intervention,
203 proxies: ProxyFields {
204 residual_proxy: outputs.proxies.residual_proxy,
205 visibility_proxy: outputs.proxies.visibility_proxy,
206 depth_proxy: outputs.proxies.depth_proxy,
207 normal_proxy: outputs.proxies.normal_proxy,
208 motion_proxy: outputs.proxies.motion_proxy,
209 neighborhood_proxy: outputs.proxies.neighborhood_proxy,
210 thin_proxy: outputs.proxies.thin_proxy,
211 history_instability_proxy: outputs.proxies.history_instability_proxy,
212 },
213 state: outputs.state,
214 });
215 }
216
217 DsfbRun {
218 profile: profile.clone(),
219 resolved_frames,
220 reprojected_history_frames,
221 supervision_frames,
222 }
223}
224
225fn resolve_with_alpha(
226 history: &ImageFrame,
227 current: &ImageFrame,
228 alpha: &ScalarField,
229) -> ImageFrame {
230 let mut resolved = ImageFrame::new(history.width(), history.height());
231 for y in 0..history.height() {
232 for x in 0..history.width() {
233 resolved.set(
234 x,
235 y,
236 history.get(x, y).lerp(current.get(x, y), alpha.get(x, y)),
237 );
238 }
239 }
240 resolved
241}
242
243fn reproject_frame(previous_resolved: &ImageFrame, scene_frame: &SceneFrame) -> ImageFrame {
244 let mut reprojected = ImageFrame::new(
245 scene_frame.ground_truth.width(),
246 scene_frame.ground_truth.height(),
247 );
248 for y in 0..scene_frame.ground_truth.height() {
249 for x in 0..scene_frame.ground_truth.width() {
250 let motion = scene_frame.motion[y * scene_frame.ground_truth.width() + x];
251 reprojected.set(
252 x,
253 y,
254 previous_resolved.sample_bilinear_clamped(
255 x as f32 + motion.to_prev_x,
256 y as f32 + motion.to_prev_y,
257 ),
258 );
259 }
260 }
261 reprojected
262}
263
264fn reproject_depth(previous_scene_frame: &SceneFrame, scene_frame: &SceneFrame) -> Vec<f32> {
265 reproject_scalar_buffer(
266 &previous_scene_frame.depth,
267 scene_frame.ground_truth.width(),
268 scene_frame.ground_truth.height(),
269 &scene_frame.motion,
270 )
271}
272
273fn reproject_normals(previous_scene_frame: &SceneFrame, scene_frame: &SceneFrame) -> Vec<Normal3> {
274 let width = scene_frame.ground_truth.width();
275 let height = scene_frame.ground_truth.height();
276 let mut reprojected = vec![Normal3::new(0.0, 0.0, 1.0); width * height];
277 for y in 0..height {
278 for x in 0..width {
279 let index = y * width + x;
280 let motion = scene_frame.motion[index];
281 reprojected[index] = sample_normal_bilinear_clamped(
282 &previous_scene_frame.normals,
283 width,
284 height,
285 x as f32 + motion.to_prev_x,
286 y as f32 + motion.to_prev_y,
287 );
288 }
289 }
290 reprojected
291}
292
293fn reproject_scalar_buffer(
294 previous_values: &[f32],
295 width: usize,
296 height: usize,
297 motion: &[MotionVector],
298) -> Vec<f32> {
299 let mut reprojected = vec![0.0; width * height];
300 for y in 0..height {
301 for x in 0..width {
302 let index = y * width + x;
303 let vector = motion[index];
304 reprojected[index] = sample_scalar_bilinear_clamped(
305 previous_values,
306 width,
307 height,
308 x as f32 + vector.to_prev_x,
309 y as f32 + vector.to_prev_y,
310 );
311 }
312 }
313 reprojected
314}
315
316fn compute_thin_hint(scene_frame: &SceneFrame) -> ScalarField {
317 let width = scene_frame.ground_truth.width();
318 let height = scene_frame.ground_truth.height();
319 let mut field = ScalarField::new(width, height);
320 for y in 0..height {
321 for x in 0..width {
322 let index = y * width + x;
323 let hint = matches!(
324 scene_frame.layers[index],
325 crate::scene::SurfaceTag::ThinStructure
326 ) || neighbors(x, y, width, height).into_iter().any(|(nx, ny)| {
327 matches!(
328 scene_frame.layers[ny * width + nx],
329 crate::scene::SurfaceTag::ThinStructure
330 )
331 });
332 field.set(x, y, if hint { 1.0 } else { 0.0 });
333 }
334 }
335 field
336}
337
338fn empty_supervision(
339 width: usize,
340 height: usize,
341 trust_value: f32,
342 alpha_value: f32,
343) -> SupervisionFrame {
344 let mut trust = ScalarField::new(width, height);
345 let mut alpha = ScalarField::new(width, height);
346 let mut intervention = ScalarField::new(width, height);
347 let mut state = StateField::new(width, height);
348 for y in 0..height {
349 for x in 0..width {
350 trust.set(x, y, trust_value);
351 alpha.set(x, y, alpha_value);
352 intervention.set(x, y, 1.0 - trust_value);
353 state.set(x, y, StructuralState::Nominal);
354 }
355 }
356 SupervisionFrame {
357 residual: ScalarField::new(width, height),
358 trust,
359 alpha,
360 intervention,
361 proxies: ProxyFields {
362 residual_proxy: ScalarField::new(width, height),
363 visibility_proxy: ScalarField::new(width, height),
364 depth_proxy: ScalarField::new(width, height),
365 normal_proxy: ScalarField::new(width, height),
366 motion_proxy: ScalarField::new(width, height),
367 neighborhood_proxy: ScalarField::new(width, height),
368 thin_proxy: ScalarField::new(width, height),
369 history_instability_proxy: ScalarField::new(width, height),
370 },
371 state,
372 }
373}
374
/// Returns the in-bounds 8-connected neighbours of `(x, y)` in a `width` x `height` grid.
///
/// Neighbours are listed row by row (top row first), left to right within a row. The centre
/// pixel is excluded.
fn neighbors(x: usize, y: usize, width: usize, height: usize) -> Vec<(usize, usize)> {
    // (dx, dy) offsets in row-major order around the centre.
    const OFFSETS: [(i32, i32); 8] = [
        (-1, -1),
        (0, -1),
        (1, -1),
        (-1, 0),
        (1, 0),
        (-1, 1),
        (0, 1),
        (1, 1),
    ];
    let (cx, cy) = (x as i32, y as i32);
    let (limit_x, limit_y) = (width as i32, height as i32);

    let mut found = Vec::with_capacity(8);
    found.extend(OFFSETS.iter().filter_map(|&(dx, dy)| {
        let (nx, ny) = (cx + dx, cy + dy);
        let inside = (0..limit_x).contains(&nx) && (0..limit_y).contains(&ny);
        inside.then_some((nx as usize, ny as usize))
    }));
    found
}
391
/// Bilinearly samples a row-major `width` x `height` scalar buffer at `(x, y)`.
///
/// All four taps are clamped to the buffer edges, so coordinates outside the buffer
/// reproduce the nearest edge values. Panics (index out of bounds) if a clamped tap falls
/// outside `values`, e.g. for an empty buffer.
fn sample_scalar_bilinear_clamped(
    values: &[f32],
    width: usize,
    height: usize,
    x: f32,
    y: f32,
) -> f32 {
    let max_col = width.saturating_sub(1) as f32;
    let max_row = height.saturating_sub(1) as f32;
    let fetch = |col: f32, row: f32| -> f32 {
        let ix = col.clamp(0.0, max_col) as usize;
        let iy = row.clamp(0.0, max_row) as usize;
        values[iy * width + ix]
    };

    let left = x.floor();
    let upper = y.floor();
    let fx = (x - left).clamp(0.0, 1.0);
    let fy = (y - upper).clamp(0.0, 1.0);

    // Interpolate along x within each row, then along y between the rows.
    let upper_row = fetch(left, upper) * (1.0 - fx) + fetch(left + 1.0, upper) * fx;
    let lower_row =
        fetch(left, upper + 1.0) * (1.0 - fx) + fetch(left + 1.0, upper + 1.0) * fx;
    upper_row * (1.0 - fy) + lower_row * fy
}
416
417fn sample_normal_bilinear_clamped(
418 values: &[Normal3],
419 width: usize,
420 height: usize,
421 x: f32,
422 y: f32,
423) -> Normal3 {
424 let x0 = x.floor();
425 let y0 = y.floor();
426 let x1 = x0 + 1.0;
427 let y1 = y0 + 1.0;
428 let tx = (x - x0).clamp(0.0, 1.0);
429 let ty = (y - y0).clamp(0.0, 1.0);
430
431 let sample = |sample_x: f32, sample_y: f32| {
432 let sx = sample_x.clamp(0.0, width.saturating_sub(1) as f32) as usize;
433 let sy = sample_y.clamp(0.0, height.saturating_sub(1) as f32) as usize;
434 values[sy * width + sx]
435 };
436
437 let c00 = sample(x0, y0);
438 let c10 = sample(x1, y0);
439 let c01 = sample(x0, y1);
440 let c11 = sample(x1, y1);
441 Normal3::new(
442 (c00.x * (1.0 - tx) + c10.x * tx) * (1.0 - ty) + (c01.x * (1.0 - tx) + c11.x * tx) * ty,
443 (c00.y * (1.0 - tx) + c10.y * tx) * (1.0 - ty) + (c01.y * (1.0 - tx) + c11.y * tx) * ty,
444 (c00.z * (1.0 - tx) + c10.z * tx) * (1.0 - ty) + (c01.z * (1.0 - tx) + c11.z * tx) * ty,
445 )
446 .normalized()
447}