// dsfb_computer_graphics/dsfb/mod.rs

1use serde::Serialize;
2
3use crate::frame::{ImageFrame, ScalarField};
4use crate::host::{
5    default_host_realistic_profile, gated_reference_profile, motion_augmented_profile,
6    profile_residual_only, profile_without_alpha_modulation, profile_without_grammar,
7    profile_without_motion, profile_without_thin, profile_without_visibility,
8    supervise_temporal_reuse, synthetic_visibility_profile, HostSupervisionProfile,
9    HostTemporalInputs,
10};
11use crate::scene::{MotionVector, Normal3, SceneFrame, SceneSequence};
12
13#[derive(Clone, Debug)]
14pub struct ProxyFields {
15    pub residual_proxy: ScalarField,
16    pub visibility_proxy: ScalarField,
17    pub depth_proxy: ScalarField,
18    pub normal_proxy: ScalarField,
19    pub motion_proxy: ScalarField,
20    pub neighborhood_proxy: ScalarField,
21    pub thin_proxy: ScalarField,
22    pub history_instability_proxy: ScalarField,
23}
24
25#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize)]
26pub enum StructuralState {
27    Nominal,
28    DisocclusionLike,
29    UnstableHistory,
30    MotionEdge,
31}
32
33#[derive(Clone, Debug, Default, Serialize)]
34pub struct StateCounts {
35    pub nominal: usize,
36    pub disocclusion_like: usize,
37    pub unstable_history: usize,
38    pub motion_edge: usize,
39}
40
41#[derive(Clone, Debug)]
42pub struct StateField {
43    width: usize,
44    values: Vec<StructuralState>,
45}
46
47impl StateField {
48    pub fn new(width: usize, height: usize) -> Self {
49        Self {
50            width,
51            values: vec![StructuralState::Nominal; width * height],
52        }
53    }
54
55    pub fn width(&self) -> usize {
56        self.width
57    }
58
59    pub fn height(&self) -> usize {
60        self.values.len() / self.width.max(1)
61    }
62
63    pub fn set(&mut self, x: usize, y: usize, value: StructuralState) {
64        self.values[y * self.width + x] = value;
65    }
66
67    pub fn values(&self) -> &[StructuralState] {
68        &self.values
69    }
70
71    pub fn counts(&self) -> StateCounts {
72        let mut counts = StateCounts::default();
73        for state in &self.values {
74            match state {
75                StructuralState::Nominal => counts.nominal += 1,
76                StructuralState::DisocclusionLike => counts.disocclusion_like += 1,
77                StructuralState::UnstableHistory => counts.unstable_history += 1,
78                StructuralState::MotionEdge => counts.motion_edge += 1,
79            }
80        }
81        counts
82    }
83
84    pub fn counts_over_mask(&self, mask: &[bool]) -> StateCounts {
85        let mut counts = StateCounts::default();
86        for (state, include) in self.values.iter().zip(mask.iter().copied()) {
87            if !include {
88                continue;
89            }
90            match state {
91                StructuralState::Nominal => counts.nominal += 1,
92                StructuralState::DisocclusionLike => counts.disocclusion_like += 1,
93                StructuralState::UnstableHistory => counts.unstable_history += 1,
94                StructuralState::MotionEdge => counts.motion_edge += 1,
95            }
96        }
97        counts
98    }
99}
100
101#[derive(Clone, Debug)]
102pub struct SupervisionFrame {
103    pub residual: ScalarField,
104    pub trust: ScalarField,
105    pub alpha: ScalarField,
106    pub intervention: ScalarField,
107    pub proxies: ProxyFields,
108    pub state: StateField,
109}
110
111#[derive(Clone, Debug)]
112pub struct DsfbRun {
113    pub profile: HostSupervisionProfile,
114    pub resolved_frames: Vec<ImageFrame>,
115    pub reprojected_history_frames: Vec<ImageFrame>,
116    pub supervision_frames: Vec<SupervisionFrame>,
117}
118
119pub fn run_gated_taa(sequence: &SceneSequence, alpha_min: f32, alpha_max: f32) -> DsfbRun {
120    run_profiled_taa(sequence, &gated_reference_profile(alpha_min, alpha_max))
121}
122
123pub fn run_visibility_assisted_taa(
124    sequence: &SceneSequence,
125    alpha_min: f32,
126    alpha_max: f32,
127) -> DsfbRun {
128    run_profiled_taa(
129        sequence,
130        &synthetic_visibility_profile(alpha_min, alpha_max),
131    )
132}
133
134pub fn ablation_profiles(alpha_min: f32, alpha_max: f32) -> Vec<HostSupervisionProfile> {
135    vec![
136        synthetic_visibility_profile(alpha_min, alpha_max),
137        default_host_realistic_profile(alpha_min, alpha_max),
138        gated_reference_profile(alpha_min, alpha_max),
139        motion_augmented_profile(alpha_min, alpha_max),
140        profile_without_visibility(alpha_min, alpha_max),
141        profile_without_thin(alpha_min, alpha_max),
142        profile_without_motion(alpha_min, alpha_max),
143        profile_without_grammar(alpha_min, alpha_max),
144        profile_residual_only(alpha_min, alpha_max),
145        profile_without_alpha_modulation(alpha_min, alpha_max),
146    ]
147}
148
149pub fn run_profiled_taa(sequence: &SceneSequence, profile: &HostSupervisionProfile) -> DsfbRun {
150    let mut resolved_frames = Vec::with_capacity(sequence.frames.len());
151    let mut reprojected_history_frames = Vec::with_capacity(sequence.frames.len());
152    let mut supervision_frames = Vec::with_capacity(sequence.frames.len());
153
154    for (frame_index, scene_frame) in sequence.frames.iter().enumerate() {
155        let width = scene_frame.ground_truth.width();
156        let height = scene_frame.ground_truth.height();
157        if frame_index == 0 {
158            resolved_frames.push(scene_frame.ground_truth.clone());
159            reprojected_history_frames.push(scene_frame.ground_truth.clone());
160            supervision_frames.push(empty_supervision(
161                width,
162                height,
163                1.0,
164                profile.parameters.alpha_range.min,
165            ));
166            continue;
167        }
168
169        let previous_resolved = &resolved_frames[frame_index - 1];
170        let previous_scene_frame = &sequence.frames[frame_index - 1];
171        let reprojected = reproject_frame(previous_resolved, scene_frame);
172        let reprojected_depth = reproject_depth(previous_scene_frame, scene_frame);
173        let reprojected_normals = reproject_normals(previous_scene_frame, scene_frame);
174        let visibility_hint = profile
175            .use_visibility_hint
176            .then_some(scene_frame.disocclusion_mask.as_slice());
177        let thin_hint_field = profile
178            .use_visibility_hint
179            .then(|| compute_thin_hint(scene_frame));
180        let thin_hint = thin_hint_field.as_ref();
181
182        let host_inputs = HostTemporalInputs {
183            current_color: &scene_frame.ground_truth,
184            reprojected_history: &reprojected,
185            motion_vectors: &scene_frame.motion,
186            current_depth: &scene_frame.depth,
187            reprojected_depth: &reprojected_depth,
188            current_normals: &scene_frame.normals,
189            reprojected_normals: &reprojected_normals,
190            visibility_hint,
191            thin_hint,
192        };
193        let outputs = supervise_temporal_reuse(&host_inputs, profile);
194        let resolved = resolve_with_alpha(&reprojected, &scene_frame.ground_truth, &outputs.alpha);
195
196        reprojected_history_frames.push(reprojected);
197        resolved_frames.push(resolved);
198        supervision_frames.push(SupervisionFrame {
199            residual: outputs.residual,
200            trust: outputs.trust,
201            alpha: outputs.alpha,
202            intervention: outputs.intervention,
203            proxies: ProxyFields {
204                residual_proxy: outputs.proxies.residual_proxy,
205                visibility_proxy: outputs.proxies.visibility_proxy,
206                depth_proxy: outputs.proxies.depth_proxy,
207                normal_proxy: outputs.proxies.normal_proxy,
208                motion_proxy: outputs.proxies.motion_proxy,
209                neighborhood_proxy: outputs.proxies.neighborhood_proxy,
210                thin_proxy: outputs.proxies.thin_proxy,
211                history_instability_proxy: outputs.proxies.history_instability_proxy,
212            },
213            state: outputs.state,
214        });
215    }
216
217    DsfbRun {
218        profile: profile.clone(),
219        resolved_frames,
220        reprojected_history_frames,
221        supervision_frames,
222    }
223}
224
225fn resolve_with_alpha(
226    history: &ImageFrame,
227    current: &ImageFrame,
228    alpha: &ScalarField,
229) -> ImageFrame {
230    let mut resolved = ImageFrame::new(history.width(), history.height());
231    for y in 0..history.height() {
232        for x in 0..history.width() {
233            resolved.set(
234                x,
235                y,
236                history.get(x, y).lerp(current.get(x, y), alpha.get(x, y)),
237            );
238        }
239    }
240    resolved
241}
242
243fn reproject_frame(previous_resolved: &ImageFrame, scene_frame: &SceneFrame) -> ImageFrame {
244    let mut reprojected = ImageFrame::new(
245        scene_frame.ground_truth.width(),
246        scene_frame.ground_truth.height(),
247    );
248    for y in 0..scene_frame.ground_truth.height() {
249        for x in 0..scene_frame.ground_truth.width() {
250            let motion = scene_frame.motion[y * scene_frame.ground_truth.width() + x];
251            reprojected.set(
252                x,
253                y,
254                previous_resolved.sample_bilinear_clamped(
255                    x as f32 + motion.to_prev_x,
256                    y as f32 + motion.to_prev_y,
257                ),
258            );
259        }
260    }
261    reprojected
262}
263
264fn reproject_depth(previous_scene_frame: &SceneFrame, scene_frame: &SceneFrame) -> Vec<f32> {
265    reproject_scalar_buffer(
266        &previous_scene_frame.depth,
267        scene_frame.ground_truth.width(),
268        scene_frame.ground_truth.height(),
269        &scene_frame.motion,
270    )
271}
272
273fn reproject_normals(previous_scene_frame: &SceneFrame, scene_frame: &SceneFrame) -> Vec<Normal3> {
274    let width = scene_frame.ground_truth.width();
275    let height = scene_frame.ground_truth.height();
276    let mut reprojected = vec![Normal3::new(0.0, 0.0, 1.0); width * height];
277    for y in 0..height {
278        for x in 0..width {
279            let index = y * width + x;
280            let motion = scene_frame.motion[index];
281            reprojected[index] = sample_normal_bilinear_clamped(
282                &previous_scene_frame.normals,
283                width,
284                height,
285                x as f32 + motion.to_prev_x,
286                y as f32 + motion.to_prev_y,
287            );
288        }
289    }
290    reprojected
291}
292
293fn reproject_scalar_buffer(
294    previous_values: &[f32],
295    width: usize,
296    height: usize,
297    motion: &[MotionVector],
298) -> Vec<f32> {
299    let mut reprojected = vec![0.0; width * height];
300    for y in 0..height {
301        for x in 0..width {
302            let index = y * width + x;
303            let vector = motion[index];
304            reprojected[index] = sample_scalar_bilinear_clamped(
305                previous_values,
306                width,
307                height,
308                x as f32 + vector.to_prev_x,
309                y as f32 + vector.to_prev_y,
310            );
311        }
312    }
313    reprojected
314}
315
316fn compute_thin_hint(scene_frame: &SceneFrame) -> ScalarField {
317    let width = scene_frame.ground_truth.width();
318    let height = scene_frame.ground_truth.height();
319    let mut field = ScalarField::new(width, height);
320    for y in 0..height {
321        for x in 0..width {
322            let index = y * width + x;
323            let hint = matches!(
324                scene_frame.layers[index],
325                crate::scene::SurfaceTag::ThinStructure
326            ) || neighbors(x, y, width, height).into_iter().any(|(nx, ny)| {
327                matches!(
328                    scene_frame.layers[ny * width + nx],
329                    crate::scene::SurfaceTag::ThinStructure
330                )
331            });
332            field.set(x, y, if hint { 1.0 } else { 0.0 });
333        }
334    }
335    field
336}
337
338fn empty_supervision(
339    width: usize,
340    height: usize,
341    trust_value: f32,
342    alpha_value: f32,
343) -> SupervisionFrame {
344    let mut trust = ScalarField::new(width, height);
345    let mut alpha = ScalarField::new(width, height);
346    let mut intervention = ScalarField::new(width, height);
347    let mut state = StateField::new(width, height);
348    for y in 0..height {
349        for x in 0..width {
350            trust.set(x, y, trust_value);
351            alpha.set(x, y, alpha_value);
352            intervention.set(x, y, 1.0 - trust_value);
353            state.set(x, y, StructuralState::Nominal);
354        }
355    }
356    SupervisionFrame {
357        residual: ScalarField::new(width, height),
358        trust,
359        alpha,
360        intervention,
361        proxies: ProxyFields {
362            residual_proxy: ScalarField::new(width, height),
363            visibility_proxy: ScalarField::new(width, height),
364            depth_proxy: ScalarField::new(width, height),
365            normal_proxy: ScalarField::new(width, height),
366            motion_proxy: ScalarField::new(width, height),
367            neighborhood_proxy: ScalarField::new(width, height),
368            thin_proxy: ScalarField::new(width, height),
369            history_instability_proxy: ScalarField::new(width, height),
370        },
371        state,
372    }
373}
374
/// Returns the in-bounds 8-connected neighbours of `(x, y)` on a `width` x `height` grid.
///
/// Neighbours are listed row by row (top to bottom), left to right within each row. The
/// centre pixel is excluded, and coordinates outside the grid are dropped.
fn neighbors(x: usize, y: usize, width: usize, height: usize) -> Vec<(usize, usize)> {
    // Stay in `usize` with saturating bounds. Casting to `i32` truncated or wrapped
    // coordinates and dimensions above `i32::MAX`, which silently produced wrong neighbours.
    let x_range = x.saturating_sub(1)..x.saturating_add(2).min(width);
    let y_range = y.saturating_sub(1)..y.saturating_add(2).min(height);
    let mut values = Vec::with_capacity(8);
    for ny in y_range {
        for nx in x_range.clone() {
            if nx != x || ny != y {
                values.push((nx, ny));
            }
        }
    }
    values
}
391
/// Bilinearly samples a row-major `width` x `height` scalar grid at `(x, y)`.
///
/// Each of the four contributing texel coordinates is clamped to the grid
/// (`0..=width - 1`, `0..=height - 1`), so positions outside the grid repeat the edge values.
///
/// # Panics
///
/// Panics if `values` is shorter than the grid (including an empty grid).
fn sample_scalar_bilinear_clamped(
    values: &[f32],
    width: usize,
    height: usize,
    x: f32,
    y: f32,
) -> f32 {
    let max_x = width.saturating_sub(1) as f32;
    let max_y = height.saturating_sub(1) as f32;
    let texel = |px: f32, py: f32| -> f32 {
        let column = px.clamp(0.0, max_x) as usize;
        let row = py.clamp(0.0, max_y) as usize;
        values[row * width + column]
    };
    let lerp = |a: f32, b: f32, t: f32| a * (1.0 - t) + b * t;

    let (left, top) = (x.floor(), y.floor());
    let (right, bottom) = (left + 1.0, top + 1.0);
    let fx = (x - left).clamp(0.0, 1.0);
    let fy = (y - top).clamp(0.0, 1.0);

    let upper = lerp(texel(left, top), texel(right, top), fx);
    let lower = lerp(texel(left, bottom), texel(right, bottom), fx);
    lerp(upper, lower, fy)
}
416
417fn sample_normal_bilinear_clamped(
418    values: &[Normal3],
419    width: usize,
420    height: usize,
421    x: f32,
422    y: f32,
423) -> Normal3 {
424    let x0 = x.floor();
425    let y0 = y.floor();
426    let x1 = x0 + 1.0;
427    let y1 = y0 + 1.0;
428    let tx = (x - x0).clamp(0.0, 1.0);
429    let ty = (y - y0).clamp(0.0, 1.0);
430
431    let sample = |sample_x: f32, sample_y: f32| {
432        let sx = sample_x.clamp(0.0, width.saturating_sub(1) as f32) as usize;
433        let sy = sample_y.clamp(0.0, height.saturating_sub(1) as f32) as usize;
434        values[sy * width + sx]
435    };
436
437    let c00 = sample(x0, y0);
438    let c10 = sample(x1, y0);
439    let c01 = sample(x0, y1);
440    let c11 = sample(x1, y1);
441    Normal3::new(
442        (c00.x * (1.0 - tx) + c10.x * tx) * (1.0 - ty) + (c01.x * (1.0 - tx) + c11.x * tx) * ty,
443        (c00.y * (1.0 - tx) + c10.y * tx) * (1.0 - ty) + (c01.y * (1.0 - tx) + c11.y * tx) * ty,
444        (c00.z * (1.0 - tx) + c10.z * tx) * (1.0 - ty) + (c01.z * (1.0 - tx) + c11.z * tx) * ty,
445    )
446    .normalized()
447}