Skip to main content

dsfb_computer_graphics/taa/
mod.rs

1use crate::frame::{Color, ImageFrame, ScalarField};
2use crate::parameters::{BaselineParameters, SmoothstepThreshold};
3use crate::scene::{Normal3, SceneSequence};
4
/// Image sequences produced by one temporal-accumulation run.
///
/// When built by `run_heuristic_baseline`, both vectors hold one entry per input frame.
#[derive(Clone, Debug)]
pub struct TaaRun {
    /// Blended output per frame (`run_heuristic_baseline` copies the ground truth for frame 0).
    pub resolved_frames: Vec<ImageFrame>,
    /// The history color the policy actually blended with, per frame (after any policy-side
    /// clamping, e.g. the neighborhood clamp baselines store the clamped history here).
    pub reprojected_history_frames: Vec<ImageFrame>,
}
10
/// Result of running one heuristic TAA baseline over a scene sequence.
///
/// Carries identifying metadata, the TAA frames, and two per-frame diagnostic fields.
#[derive(Clone, Debug)]
pub struct HeuristicRun {
    /// Machine-readable identifier of the baseline (e.g. `"fixed_alpha"`).
    pub id: String,
    /// Human-readable name of the baseline.
    pub label: String,
    /// Short description of what the baseline's per-pixel policy does.
    pub description: String,
    /// Resolved and reprojected-history frames.
    pub taa: TaaRun,
    /// Per-pixel blend weight (`alpha`) returned by the policy, one field per frame.
    pub alpha_frames: Vec<ScalarField>,
    /// Per-pixel response/trigger value returned by the policy, one field per frame.
    pub response_frames: Vec<ScalarField>,
}
20
/// Output of the residual-threshold baseline in its dedicated shape.
///
/// Same data as a `HeuristicRun` for that baseline, minus the id/label/description, with the
/// response frames exposed under the name `trigger_frames`.
#[derive(Clone, Debug)]
pub struct ResidualThresholdRun {
    /// Resolved and reprojected-history frames.
    pub taa: TaaRun,
    /// Per-pixel blend weight used for each frame.
    pub alpha_frames: Vec<ScalarField>,
    /// Per-pixel smoothstep trigger (0..=1) computed from the color residual, per frame.
    pub trigger_frames: Vec<ScalarField>,
}
27
28pub fn run_fixed_alpha(sequence: &SceneSequence, alpha: f32) -> TaaRun {
29    run_fixed_alpha_baseline(sequence, alpha).taa
30}
31
32pub fn run_fixed_alpha_baseline(sequence: &SceneSequence, alpha: f32) -> HeuristicRun {
33    run_heuristic_baseline(
34        sequence,
35        "fixed_alpha",
36        "Fixed-alpha baseline",
37        "Uniform temporal blend weight with no rejection or clamp logic.",
38        move |context| (context.history, alpha, 0.0),
39    )
40}
41
42pub fn run_residual_threshold(
43    sequence: &SceneSequence,
44    alpha_low: f32,
45    alpha_high: f32,
46    threshold_low: f32,
47    threshold_high: f32,
48) -> ResidualThresholdRun {
49    let heuristic = run_heuristic_baseline(
50        sequence,
51        "residual_threshold",
52        "Residual-threshold baseline",
53        "Per-pixel alpha increases when current vs history residual exceeds a threshold.",
54        move |context| {
55            let residual = context.current.abs_diff(context.history);
56            let trigger = smoothstep_threshold(
57                SmoothstepThreshold::new(threshold_low, threshold_high),
58                residual,
59            );
60            let alpha = alpha_low + (alpha_high - alpha_low) * trigger;
61            (context.history, alpha, trigger)
62        },
63    );
64
65    ResidualThresholdRun {
66        taa: heuristic.taa,
67        alpha_frames: heuristic.alpha_frames,
68        trigger_frames: heuristic.response_frames,
69    }
70}
71
72pub fn run_residual_threshold_baseline(
73    sequence: &SceneSequence,
74    alpha_low: f32,
75    alpha_high: f32,
76    threshold_low: f32,
77    threshold_high: f32,
78) -> HeuristicRun {
79    run_heuristic_baseline(
80        sequence,
81        "residual_threshold",
82        "Residual-threshold baseline",
83        "Per-pixel alpha increases when current vs history residual exceeds a threshold.",
84        move |context| {
85            let residual = context.current.abs_diff(context.history);
86            let trigger = smoothstep_threshold(
87                SmoothstepThreshold::new(threshold_low, threshold_high),
88                residual,
89            );
90            let alpha = alpha_low + (alpha_high - alpha_low) * trigger;
91            (context.history, alpha, trigger)
92        },
93    )
94}
95
96pub fn run_neighborhood_clamp_baseline(
97    sequence: &SceneSequence,
98    parameters: &BaselineParameters,
99) -> HeuristicRun {
100    run_heuristic_baseline(
101        sequence,
102        "neighborhood_clamp",
103        "Neighborhood-clamped baseline",
104        "History is clamped to the current 3x3 neighborhood before blending. Alpha rises with clamp distance.",
105        move |context| {
106            let clamped = clamp_to_current_neighborhood(context.scene_frame, context.history, context.x, context.y);
107            let clamp_distance = clamped.abs_diff(context.history);
108            let trigger = smoothstep_threshold(parameters.clamp_distance, clamp_distance);
109            let alpha = parameters.residual_alpha_range.min
110                + (parameters.residual_alpha_range.max - parameters.residual_alpha_range.min)
111                    * trigger;
112            (clamped, alpha, trigger)
113        },
114    )
115}
116
117pub fn run_depth_normal_rejection_baseline(
118    sequence: &SceneSequence,
119    parameters: &BaselineParameters,
120) -> HeuristicRun {
121    run_heuristic_baseline(
122        sequence,
123        "depth_normal_reject",
124        "Depth/normal rejection baseline",
125        "Alpha rises with reprojected depth or normal disagreement.",
126        move |context| {
127            let depth_gate = smoothstep_threshold(
128                parameters.depth_disagreement,
129                (context.current_depth - context.reprojected_depth).abs(),
130            );
131            let normal_gate = smoothstep_threshold(
132                parameters.normal_disagreement,
133                1.0 - context
134                    .current_normal
135                    .dot(context.reprojected_normal)
136                    .clamp(-1.0, 1.0),
137            );
138            let trigger = depth_gate.max(normal_gate);
139            let alpha = parameters.residual_alpha_range.min
140                + (parameters.residual_alpha_range.max - parameters.residual_alpha_range.min)
141                    * trigger;
142            (context.history, alpha, trigger)
143        },
144    )
145}
146
147pub fn run_reactive_mask_baseline(
148    sequence: &SceneSequence,
149    parameters: &BaselineParameters,
150) -> HeuristicRun {
151    run_heuristic_baseline(
152        sequence,
153        "reactive_mask",
154        "Reactive-mask-style baseline",
155        "Residual, depth, and neighborhood disagreement combine into a reactive alpha increase.",
156        move |context| {
157            let residual_gate = smoothstep_threshold(
158                parameters.residual_threshold,
159                context.current.abs_diff(context.history),
160            );
161            let depth_gate = smoothstep_threshold(
162                parameters.depth_disagreement,
163                (context.current_depth - context.reprojected_depth).abs(),
164            );
165            let neighborhood_gate = smoothstep_threshold(
166                parameters.neighborhood_distance,
167                neighborhood_distance(context.scene_frame, context.history, context.x, context.y),
168            );
169            let trigger = residual_gate.max(depth_gate).max(neighborhood_gate);
170            let alpha = parameters.residual_alpha_range.min
171                + (parameters.residual_alpha_range.max - parameters.residual_alpha_range.min)
172                    * trigger;
173            (context.history, alpha, trigger)
174        },
175    )
176}
177
178pub fn run_strong_heuristic_baseline(
179    sequence: &SceneSequence,
180    parameters: &BaselineParameters,
181) -> HeuristicRun {
182    run_heuristic_baseline(
183        sequence,
184        "strong_heuristic",
185        "Strong heuristic baseline",
186        "Neighborhood clamp plus combined residual/depth/normal/neighborhood trigger.",
187        move |context| {
188            let clamped = clamp_to_current_neighborhood(
189                context.scene_frame,
190                context.history,
191                context.x,
192                context.y,
193            );
194            let clamp_distance = clamped.abs_diff(context.history);
195            let residual_gate = smoothstep_threshold(
196                parameters.residual_threshold,
197                context.current.abs_diff(clamped),
198            );
199            let depth_gate = smoothstep_threshold(
200                parameters.depth_disagreement,
201                (context.current_depth - context.reprojected_depth).abs(),
202            );
203            let normal_gate = smoothstep_threshold(
204                parameters.normal_disagreement,
205                1.0 - context
206                    .current_normal
207                    .dot(context.reprojected_normal)
208                    .clamp(-1.0, 1.0),
209            );
210            let neighborhood_gate =
211                smoothstep_threshold(parameters.neighborhood_distance, clamp_distance);
212            let trigger = residual_gate
213                .max(depth_gate)
214                .max(normal_gate)
215                .max(neighborhood_gate);
216            let alpha = parameters.residual_alpha_range.min
217                + (parameters.residual_alpha_range.max - parameters.residual_alpha_range.min)
218                    * trigger;
219            (clamped, alpha, trigger)
220        },
221    )
222}
223
/// Per-pixel inputs that `run_heuristic_baseline` hands to a baseline policy closure.
#[derive(Clone, Copy)]
struct PixelContext<'a> {
    /// Scene data of the frame being resolved (its ground truth is used for neighborhood clamps).
    scene_frame: &'a crate::scene::SceneFrame,
    /// Ground-truth color of this pixel in the current frame.
    current: Color,
    /// Previous resolved color, bilinearly sampled at the motion-reprojected position.
    history: Color,
    /// Current frame's depth at this pixel.
    current_depth: f32,
    /// Previous frame's depth, bilinearly sampled at the reprojected position.
    reprojected_depth: f32,
    /// Current frame's normal at this pixel.
    current_normal: Normal3,
    /// Previous frame's normal, bilinearly sampled at the reprojected position and renormalized.
    reprojected_normal: Normal3,
    /// Pixel column in the current frame.
    x: usize,
    /// Pixel row in the current frame.
    y: usize,
}
236
237fn run_heuristic_baseline(
238    sequence: &SceneSequence,
239    id: &str,
240    label: &str,
241    description: &str,
242    mut policy: impl FnMut(PixelContext<'_>) -> (Color, f32, f32),
243) -> HeuristicRun {
244    let mut resolved_frames = Vec::with_capacity(sequence.frames.len());
245    let mut reprojected_history_frames = Vec::with_capacity(sequence.frames.len());
246    let mut alpha_frames = Vec::with_capacity(sequence.frames.len());
247    let mut response_frames = Vec::with_capacity(sequence.frames.len());
248
249    for (frame_index, scene_frame) in sequence.frames.iter().enumerate() {
250        let width = scene_frame.ground_truth.width();
251        let height = scene_frame.ground_truth.height();
252        if frame_index == 0 {
253            resolved_frames.push(scene_frame.ground_truth.clone());
254            reprojected_history_frames.push(scene_frame.ground_truth.clone());
255            alpha_frames.push(fill_scalar(width, height, 0.0));
256            response_frames.push(ScalarField::new(width, height));
257            continue;
258        }
259
260        let previous_resolved = &resolved_frames[frame_index - 1];
261        let previous_scene = &sequence.frames[frame_index - 1];
262        let mut reprojected = ImageFrame::new(width, height);
263        let mut resolved = ImageFrame::new(width, height);
264        let mut alpha_frame = ScalarField::new(width, height);
265        let mut response_frame = ScalarField::new(width, height);
266
267        for y in 0..height {
268            for x in 0..width {
269                let motion = scene_frame.motion[y * width + x];
270                let prev_x = x as f32 + motion.to_prev_x;
271                let prev_y = y as f32 + motion.to_prev_y;
272                let history = previous_resolved.sample_bilinear_clamped(prev_x, prev_y);
273                let current = scene_frame.ground_truth.get(x, y);
274                let context = PixelContext {
275                    scene_frame,
276                    current,
277                    history,
278                    current_depth: scene_frame.depth[y * width + x],
279                    reprojected_depth: sample_scalar_bilinear_clamped(
280                        &previous_scene.depth,
281                        width,
282                        height,
283                        prev_x,
284                        prev_y,
285                    ),
286                    current_normal: scene_frame.normals[y * width + x],
287                    reprojected_normal: sample_normal_bilinear_clamped(
288                        &previous_scene.normals,
289                        width,
290                        height,
291                        prev_x,
292                        prev_y,
293                    ),
294                    x,
295                    y,
296                };
297                let (history_used, alpha, response) = policy(context);
298
299                reprojected.set(x, y, history_used);
300                resolved.set(x, y, history_used.lerp(current, alpha));
301                alpha_frame.set(x, y, alpha);
302                response_frame.set(x, y, response);
303            }
304        }
305
306        reprojected_history_frames.push(reprojected);
307        resolved_frames.push(resolved);
308        alpha_frames.push(alpha_frame);
309        response_frames.push(response_frame);
310    }
311
312    HeuristicRun {
313        id: id.to_string(),
314        label: label.to_string(),
315        description: description.to_string(),
316        taa: TaaRun {
317            resolved_frames,
318            reprojected_history_frames,
319        },
320        alpha_frames,
321        response_frames,
322    }
323}
324
325fn clamp_to_current_neighborhood(
326    scene_frame: &crate::scene::SceneFrame,
327    history: Color,
328    x: usize,
329    y: usize,
330) -> Color {
331    let mut min_r = f32::INFINITY;
332    let mut min_g = f32::INFINITY;
333    let mut min_b = f32::INFINITY;
334    let mut max_r = f32::NEG_INFINITY;
335    let mut max_g = f32::NEG_INFINITY;
336    let mut max_b = f32::NEG_INFINITY;
337
338    for (nx, ny) in neighbors(
339        x,
340        y,
341        scene_frame.ground_truth.width(),
342        scene_frame.ground_truth.height(),
343    ) {
344        let color = scene_frame.ground_truth.get(nx, ny);
345        min_r = min_r.min(color.r);
346        min_g = min_g.min(color.g);
347        min_b = min_b.min(color.b);
348        max_r = max_r.max(color.r);
349        max_g = max_g.max(color.g);
350        max_b = max_b.max(color.b);
351    }
352    let current = scene_frame.ground_truth.get(x, y);
353    min_r = min_r.min(current.r);
354    min_g = min_g.min(current.g);
355    min_b = min_b.min(current.b);
356    max_r = max_r.max(current.r);
357    max_g = max_g.max(current.g);
358    max_b = max_b.max(current.b);
359
360    Color::rgb(
361        history.r.clamp(min_r, max_r),
362        history.g.clamp(min_g, max_g),
363        history.b.clamp(min_b, max_b),
364    )
365}
366
367fn neighborhood_distance(
368    scene_frame: &crate::scene::SceneFrame,
369    history: Color,
370    x: usize,
371    y: usize,
372) -> f32 {
373    clamp_to_current_neighborhood(scene_frame, history, x, y).abs_diff(history)
374}
375
376fn fill_scalar(width: usize, height: usize, value: f32) -> ScalarField {
377    let mut field = ScalarField::new(width, height);
378    for y in 0..height {
379        for x in 0..width {
380            field.set(x, y, value);
381        }
382    }
383    field
384}
385
386fn smoothstep_threshold(threshold: SmoothstepThreshold, value: f32) -> f32 {
387    let span = (threshold.high - threshold.low).max(f32::EPSILON);
388    let t = ((value - threshold.low) / span).clamp(0.0, 1.0);
389    t * t * (3.0 - 2.0 * t)
390}
391
/// Bilinearly samples a row-major `width` x `height` scalar buffer at `(x, y)`.
///
/// Each of the four taps has its coordinates clamped to the buffer edges independently, so
/// positions outside the image return edge values.
fn sample_scalar_bilinear_clamped(
    values: &[f32],
    width: usize,
    height: usize,
    x: f32,
    y: f32,
) -> f32 {
    let max_col = width.saturating_sub(1) as f32;
    let max_row = height.saturating_sub(1) as f32;
    let fetch = |px: f32, py: f32| -> f32 {
        let col = px.clamp(0.0, max_col) as usize;
        let row = py.clamp(0.0, max_row) as usize;
        values[row * width + col]
    };

    let (left, top) = (x.floor(), y.floor());
    let (right, below) = (left + 1.0, top + 1.0);
    let fx = (x - left).clamp(0.0, 1.0);
    let fy = (y - top).clamp(0.0, 1.0);

    let upper_row = fetch(left, top) * (1.0 - fx) + fetch(right, top) * fx;
    let lower_row = fetch(left, below) * (1.0 - fx) + fetch(right, below) * fx;
    upper_row * (1.0 - fy) + lower_row * fy
}
416
417fn sample_normal_bilinear_clamped(
418    values: &[Normal3],
419    width: usize,
420    height: usize,
421    x: f32,
422    y: f32,
423) -> Normal3 {
424    let x0 = x.floor();
425    let y0 = y.floor();
426    let x1 = x0 + 1.0;
427    let y1 = y0 + 1.0;
428    let tx = (x - x0).clamp(0.0, 1.0);
429    let ty = (y - y0).clamp(0.0, 1.0);
430
431    let sample = |sample_x: f32, sample_y: f32| {
432        let sx = sample_x.clamp(0.0, width.saturating_sub(1) as f32) as usize;
433        let sy = sample_y.clamp(0.0, height.saturating_sub(1) as f32) as usize;
434        values[sy * width + sx]
435    };
436
437    let c00 = sample(x0, y0);
438    let c10 = sample(x1, y0);
439    let c01 = sample(x0, y1);
440    let c11 = sample(x1, y1);
441    Normal3::new(
442        (c00.x * (1.0 - tx) + c10.x * tx) * (1.0 - ty) + (c01.x * (1.0 - tx) + c11.x * tx) * ty,
443        (c00.y * (1.0 - tx) + c10.y * tx) * (1.0 - ty) + (c01.y * (1.0 - tx) + c11.y * tx) * ty,
444        (c00.z * (1.0 - tx) + c10.z * tx) * (1.0 - ty) + (c01.z * (1.0 - tx) + c11.z * tx) * ty,
445    )
446    .normalized()
447}
448
/// Returns the in-bounds 8-connected neighbors of `(x, y)` in a `width` x `height` grid.
///
/// `(x, y)` itself is excluded. Results are in row-major order (top row first, each row
/// left to right).
fn neighbors(x: usize, y: usize, width: usize, height: usize) -> Vec<(usize, usize)> {
    // (dx, dy) offsets listed row by row so the output stays row-major.
    const OFFSETS: [(i32, i32); 8] = [
        (-1, -1),
        (0, -1),
        (1, -1),
        (-1, 0),
        (1, 0),
        (-1, 1),
        (0, 1),
        (1, 1),
    ];
    let (w, h) = (width as i32, height as i32);
    let mut found = Vec::with_capacity(OFFSETS.len());
    found.extend(OFFSETS.iter().filter_map(|&(dx, dy)| {
        let nx = x as i32 + dx;
        let ny = y as i32 + dy;
        let inside = (0..w).contains(&nx) && (0..h).contains(&ny);
        if inside {
            Some((nx as usize, ny as usize))
        } else {
            None
        }
    }));
    found
}