Skip to main content

oximedia_codec/reconstruct/
super_res.rs

1//! Super-resolution upscaling for AV1.
2//!
//! AV1 can encode a frame at a reduced horizontal resolution and upscale it
//! back to full width during decoding, trading a small loss in quality for
//! better coding efficiency. Upscaling defaults to a Lanczos-like 8-tap filter.
6
7#![forbid(unsafe_code)]
8#![allow(clippy::unreadable_literal)]
9#![allow(clippy::items_after_statements)]
10#![allow(clippy::unnecessary_wraps)]
11#![allow(clippy::struct_excessive_bools)]
12#![allow(clippy::identity_op)]
13#![allow(clippy::range_plus_one)]
14#![allow(clippy::needless_range_loop)]
15#![allow(clippy::useless_conversion)]
16#![allow(clippy::redundant_closure_for_method_calls)]
17#![allow(clippy::single_match_else)]
18#![allow(dead_code)]
19#![allow(clippy::doc_markdown)]
20#![allow(clippy::unused_self)]
21#![allow(clippy::trivially_copy_pass_by_ref)]
22#![allow(clippy::cast_possible_truncation)]
23#![allow(clippy::cast_sign_loss)]
24#![allow(clippy::cast_possible_wrap)]
25#![allow(clippy::cast_precision_loss)]
26#![allow(clippy::missing_errors_doc)]
27#![allow(clippy::too_many_arguments)]
28#![allow(clippy::similar_names)]
29#![allow(clippy::cast_lossless)]
30#![allow(clippy::many_single_char_names)]
31
32use super::pipeline::FrameContext;
33use super::{FrameBuffer, PlaneBuffer, ReconstructResult, ReconstructionError};
34
35// =============================================================================
36// Constants
37// =============================================================================
38
/// Smallest allowed super-resolution denominator (strongest upscale, 16/9).
pub const SUPERRES_DENOM_MIN: u8 = 9;

/// Largest allowed super-resolution denominator; 16/16 means no scaling.
pub const SUPERRES_DENOM_MAX: u8 = 16;

/// Number of taps in every upscaling filter kernel.
pub const SUPERRES_FILTER_TAPS: usize = 8;

/// Fixed-point precision, in bits, of the filter coefficients.
pub const SUPERRES_FILTER_BITS: u8 = 6;

/// Rounding offset (half of one fixed-point unit) added before shifting a
/// filtered sum right by `SUPERRES_FILTER_BITS`.
pub const SUPERRES_FILTER_OFFSET: i32 = (1 << SUPERRES_FILTER_BITS) / 2;
53
54// =============================================================================
55// Upscale Method
56// =============================================================================
57
/// Interpolation filter used to widen each row of a plane.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum UpscaleMethod {
    /// Lanczos-like filter (default for AV1).
    #[default]
    Lanczos,
    /// Bilinear interpolation (simpler, lower quality).
    Bilinear,
    /// Bicubic interpolation.
    Bicubic,
    /// Nearest neighbor (fastest, lowest quality).
    Nearest,
}

impl UpscaleMethod {
    /// Number of source samples the method reads to produce one output sample.
    #[must_use]
    pub const fn kernel_size(self) -> usize {
        match self {
            Self::Nearest => 1,
            Self::Bilinear => 2,
            Self::Bicubic => 4,
            Self::Lanczos => 8,
        }
    }
}
84
85// =============================================================================
86// Super-Res Configuration
87// =============================================================================
88
89/// Configuration for super-resolution.
90#[derive(Clone, Debug)]
91pub struct SuperResConfig {
92    /// Scale denominator (9-16, 16 = no scaling).
93    pub denominator: u8,
94    /// Original (encoded) width.
95    pub encoded_width: u32,
96    /// Target (output) width.
97    pub upscaled_width: u32,
98    /// Frame height (unchanged).
99    pub height: u32,
100    /// Upscale method.
101    pub method: UpscaleMethod,
102}
103
104impl Default for SuperResConfig {
105    fn default() -> Self {
106        Self {
107            denominator: SUPERRES_DENOM_MAX,
108            encoded_width: 0,
109            upscaled_width: 0,
110            height: 0,
111            method: UpscaleMethod::Lanczos,
112        }
113    }
114}
115
116impl SuperResConfig {
117    /// Create a new super-res configuration.
118    #[must_use]
119    pub fn new(encoded_width: u32, upscaled_width: u32, height: u32) -> Self {
120        let denominator = if upscaled_width > 0 && encoded_width > 0 {
121            let ratio = (encoded_width as f32 * 16.0 / upscaled_width as f32).round() as u8;
122            ratio.clamp(SUPERRES_DENOM_MIN, SUPERRES_DENOM_MAX)
123        } else {
124            SUPERRES_DENOM_MAX
125        };
126
127        Self {
128            denominator,
129            encoded_width,
130            upscaled_width,
131            height,
132            method: UpscaleMethod::Lanczos,
133        }
134    }
135
136    /// Create from scale denominator.
137    #[must_use]
138    pub fn from_denominator(denominator: u8, encoded_width: u32, height: u32) -> Self {
139        let denom = denominator.clamp(SUPERRES_DENOM_MIN, SUPERRES_DENOM_MAX);
140        let upscaled_width = (encoded_width as u64 * 16 / u64::from(denom)) as u32;
141
142        Self {
143            denominator: denom,
144            encoded_width,
145            upscaled_width,
146            height,
147            method: UpscaleMethod::Lanczos,
148        }
149    }
150
151    /// Get the scale factor.
152    #[must_use]
153    pub fn scale_factor(&self) -> f32 {
154        16.0 / f32::from(self.denominator)
155    }
156
157    /// Check if super-res is enabled.
158    #[must_use]
159    pub fn is_enabled(&self) -> bool {
160        self.denominator < SUPERRES_DENOM_MAX
161    }
162
163    /// Set upscale method.
164    #[must_use]
165    pub const fn with_method(mut self, method: UpscaleMethod) -> Self {
166        self.method = method;
167        self
168    }
169
170    /// Calculate the source x position for a target x position.
171    #[must_use]
172    pub fn source_x(&self, target_x: u32) -> (u32, u32) {
173        // Calculate with higher precision
174        let step = (u64::from(self.denominator) << 14) / 16;
175        let offset = step / 2;
176        let src_pos = u64::from(target_x) * step + offset;
177
178        let integer = (src_pos >> 14) as u32;
179        let fraction = ((src_pos & 0x3FFF) >> 8) as u32; // 6 bits of fraction
180
181        (integer, fraction)
182    }
183}
184
185// =============================================================================
186// Lanczos Filter Kernels
187// =============================================================================
188
/// Pre-computed Lanczos-like filter kernels, one row per sub-pixel phase.
///
/// There are 64 rows, selected by the 6-bit phase that
/// `SuperResConfig::source_x` returns. Each row holds `SUPERRES_FILTER_TAPS`
/// signed coefficients summing to 64 (`1 << SUPERRES_FILTER_BITS`), so a flat
/// input passes through unchanged. The Lanczos upscaler aligns tap 3 with
/// the integer source column.
///
/// NOTE(review): the AV1 specification's `Upscale_Filter` uses 7-bit
/// coefficients (rows sum to 128), so this table is not bit-exact with the
/// spec. Later rows also move the dominant tap from index 3 to index 4/5;
/// confirm this phase layout is intentional.
const SUPERRES_FILTER_KERNELS: [[i16; SUPERRES_FILTER_TAPS]; 64] = [
    // Phase 0
    [0, 0, 0, 64, 0, 0, 0, 0],
    [0, 1, -3, 63, 4, -1, 0, 0],
    [0, 1, -5, 62, 8, -2, 0, 0],
    [-1, 2, -7, 61, 12, -3, 0, 0],
    [-1, 2, -9, 59, 17, -4, 0, 0],
    [-1, 3, -10, 57, 21, -5, -1, 0],
    [-1, 3, -11, 55, 26, -6, -1, -1],
    [-1, 3, -12, 52, 30, -6, -1, -1],
    // Phase 8
    [-1, 4, -13, 50, 34, -7, -2, -1],
    [-1, 4, -13, 47, 38, -7, -2, -2],
    [-1, 4, -13, 44, 42, -8, -2, -2],
    [-1, 4, -13, 41, 45, -8, -2, -2],
    [-1, 4, -13, 38, 48, -8, -2, -2],
    [-1, 4, -12, 35, 51, -8, -2, -3],
    [-1, 4, -12, 32, 53, -8, -2, -2],
    [-1, 4, -11, 29, 55, -7, -2, -3],
    // Phase 16
    [-1, 3, -10, 26, 57, -7, -2, -2],
    [-1, 3, -9, 23, 58, -6, -2, -2],
    [0, 3, -8, 21, 59, -5, -2, -4],
    [0, 3, -7, 18, 60, -4, -2, -4],
    [0, 2, -6, 15, 61, -3, -1, -4],
    [0, 2, -5, 13, 61, -2, -1, -4],
    [0, 2, -4, 10, 62, -1, -1, -4],
    [0, 1, -3, 8, 62, 0, 0, -4],
    // Phase 24
    [0, 1, -2, 5, 63, 1, 0, -4],
    [0, 1, -1, 3, 63, 2, 0, -4],
    [0, 0, 0, 1, 63, 3, 0, -3],
    [0, 0, 0, 0, 64, 4, 0, -4],
    [0, 0, 1, -2, 63, 5, 0, -3],
    [0, 0, 1, -3, 63, 7, 0, -4],
    [0, 0, 2, -4, 62, 9, -1, -4],
    [0, 0, 2, -6, 61, 11, -1, -3],
    // Phase 32 (center)
    [0, 0, 2, -7, 60, 13, -2, -2],
    [0, 0, 3, -8, 59, 15, -2, -3],
    [0, 0, 3, -9, 57, 18, -3, -2],
    [0, -1, 3, -10, 55, 21, -3, -1],
    [0, -1, 4, -11, 53, 24, -4, -1],
    [0, -1, 4, -11, 51, 27, -4, -2],
    [-1, -1, 4, -12, 49, 30, -4, -1],
    [-1, -1, 4, -12, 46, 33, -4, -1],
    // Phase 40
    [-1, -1, 5, -12, 44, 36, -5, -2],
    [-1, -2, 5, -12, 41, 39, -5, -1],
    [-1, -2, 5, -12, 38, 42, -5, -1],
    [-1, -2, 5, -11, 35, 44, -5, -1],
    [-1, -2, 5, -11, 33, 46, -5, -1],
    [-1, -2, 5, -10, 30, 48, -4, -2],
    [-1, -2, 5, -10, 27, 50, -4, -1],
    [-1, -2, 5, -9, 25, 51, -4, -1],
    // Phase 48
    [-1, -2, 4, -9, 22, 53, -3, 0],
    [-1, -2, 4, -8, 20, 54, -3, 0],
    [-1, -2, 4, -7, 17, 55, -2, 0],
    [-1, -2, 4, -6, 15, 56, -1, -1],
    [0, -2, 4, -5, 12, 57, 0, -2],
    [0, -2, 3, -4, 10, 58, 1, -2],
    [0, -2, 3, -3, 8, 59, 2, -3],
    [0, -1, 3, -2, 6, 59, 3, -4],
    // Phase 56
    [0, -1, 2, -1, 4, 60, 4, -4],
    [0, -1, 2, 0, 2, 60, 5, -4],
    [0, -1, 1, 0, 0, 61, 6, -3],
    [0, 0, 1, 1, -1, 60, 7, -4],
    [0, 0, 1, 1, -2, 60, 8, -4],
    [0, 0, 0, 2, -3, 59, 10, -4],
    [0, 0, 0, 2, -4, 58, 12, -4],
    [0, 0, 0, 3, -5, 57, 14, -5],
];
264
265/// Get filter kernel for a given phase.
266fn get_filter_kernel(phase: usize) -> &'static [i16; SUPERRES_FILTER_TAPS] {
267    &SUPERRES_FILTER_KERNELS[phase.min(63)]
268}
269
270// =============================================================================
271// Super-Res Upscaler
272// =============================================================================
273
274/// Super-resolution upscaler.
275#[derive(Debug)]
276pub struct SuperResUpscaler {
277    /// Current configuration.
278    config: Option<SuperResConfig>,
279    /// Temporary row buffer.
280    row_buffer: Vec<i32>,
281}
282
283impl Default for SuperResUpscaler {
284    fn default() -> Self {
285        Self::new()
286    }
287}
288
289impl SuperResUpscaler {
290    /// Create a new super-res upscaler.
291    #[must_use]
292    pub fn new() -> Self {
293        Self {
294            config: None,
295            row_buffer: Vec::new(),
296        }
297    }
298
299    /// Create with configuration.
300    #[must_use]
301    pub fn with_config(config: SuperResConfig) -> Self {
302        let buffer_size = config.upscaled_width as usize + SUPERRES_FILTER_TAPS;
303        Self {
304            config: Some(config),
305            row_buffer: vec![0; buffer_size],
306        }
307    }
308
309    /// Set configuration.
310    pub fn set_config(&mut self, config: SuperResConfig) {
311        let buffer_size = config.upscaled_width as usize + SUPERRES_FILTER_TAPS;
312        self.row_buffer.resize(buffer_size, 0);
313        self.config = Some(config);
314    }
315
316    /// Get current configuration.
317    #[must_use]
318    pub fn config(&self) -> Option<&SuperResConfig> {
319        self.config.as_ref()
320    }
321
322    /// Apply super-resolution to a frame.
323    ///
324    /// # Errors
325    ///
326    /// Returns error if super-res application fails.
327    pub fn apply(
328        &mut self,
329        frame: &mut FrameBuffer,
330        context: &FrameContext,
331    ) -> ReconstructResult<()> {
332        // Check if super-res is needed
333        if !context.needs_super_res() {
334            return Ok(());
335        }
336
337        // Clone config to avoid borrow issues
338        let config = self.config.clone().ok_or_else(|| {
339            ReconstructionError::InvalidInput("Super-res config not set".to_string())
340        })?;
341
342        if !config.is_enabled() {
343            return Ok(());
344        }
345
346        let bd = frame.bit_depth();
347
348        // Upscale Y plane
349        self.upscale_plane(frame.y_plane_mut(), &config, bd)?;
350
351        // Upscale chroma planes (with adjusted dimensions)
352        if let Some(u) = frame.u_plane_mut() {
353            self.upscale_plane(u, &config, bd)?;
354        }
355        if let Some(v) = frame.v_plane_mut() {
356            self.upscale_plane(v, &config, bd)?;
357        }
358
359        Ok(())
360    }
361
362    /// Upscale a single plane.
363    fn upscale_plane(
364        &mut self,
365        plane: &mut PlaneBuffer,
366        config: &SuperResConfig,
367        bd: u8,
368    ) -> ReconstructResult<()> {
369        let src_width = plane.width() as usize;
370        let height = plane.height() as usize;
371        let target_width = config.upscaled_width as usize;
372
373        if target_width <= src_width {
374            return Ok(()); // No upscaling needed
375        }
376
377        match config.method {
378            UpscaleMethod::Lanczos => {
379                self.upscale_lanczos(plane, src_width, target_width, height, bd)
380            }
381            UpscaleMethod::Bilinear => {
382                self.upscale_bilinear(plane, src_width, target_width, height, bd)
383            }
384            UpscaleMethod::Bicubic => {
385                self.upscale_bicubic(plane, src_width, target_width, height, bd)
386            }
387            UpscaleMethod::Nearest => self.upscale_nearest(plane, src_width, target_width, height),
388        }
389    }
390
391    /// Upscale using Lanczos-like filter.
392    fn upscale_lanczos(
393        &mut self,
394        plane: &mut PlaneBuffer,
395        src_width: usize,
396        target_width: usize,
397        height: usize,
398        bd: u8,
399    ) -> ReconstructResult<()> {
400        let max_val = (1i32 << bd) - 1;
401        let config = self
402            .config
403            .as_ref()
404            .expect("config initialized before upscale_lanczos is called");
405
406        // Process each row
407        for y in 0..height {
408            let row = plane.row(y as u32);
409
410            // Upscale row
411            for x in 0..target_width {
412                let (src_x, phase) = config.source_x(x as u32);
413                let kernel = get_filter_kernel(phase as usize);
414
415                let mut sum: i32 = 0;
416                for (i, &k) in kernel.iter().enumerate() {
417                    let sx = (src_x as i32 + i as i32 - 3).clamp(0, src_width as i32 - 1) as usize;
418                    sum += i32::from(row[sx]) * i32::from(k);
419                }
420
421                // Round and clamp
422                let result =
423                    ((sum + SUPERRES_FILTER_OFFSET) >> SUPERRES_FILTER_BITS).clamp(0, max_val);
424                self.row_buffer[x] = result;
425            }
426
427            // Write back (note: this would require resizing the plane)
428            // For now, just write what fits
429            let dst_row = plane.row_mut(y as u32);
430            let write_width = target_width.min(dst_row.len());
431            for x in 0..write_width {
432                dst_row[x] = self.row_buffer[x] as i16;
433            }
434        }
435
436        Ok(())
437    }
438
439    /// Upscale using bilinear interpolation.
440    fn upscale_bilinear(
441        &mut self,
442        plane: &mut PlaneBuffer,
443        src_width: usize,
444        target_width: usize,
445        height: usize,
446        bd: u8,
447    ) -> ReconstructResult<()> {
448        let max_val = (1i32 << bd) - 1;
449        let scale = src_width as f32 / target_width as f32;
450
451        for y in 0..height {
452            let row = plane.row(y as u32);
453
454            for x in 0..target_width {
455                let src_x = x as f32 * scale;
456                let x0 = src_x.floor() as usize;
457                let x1 = (x0 + 1).min(src_width - 1);
458                let frac = src_x.fract();
459
460                let v0 = i32::from(row[x0]);
461                let v1 = i32::from(row[x1]);
462                let result = ((v0 as f32 * (1.0 - frac) + v1 as f32 * frac).round() as i32)
463                    .clamp(0, max_val);
464                self.row_buffer[x] = result;
465            }
466
467            let dst_row = plane.row_mut(y as u32);
468            let write_width = target_width.min(dst_row.len());
469            for x in 0..write_width {
470                dst_row[x] = self.row_buffer[x] as i16;
471            }
472        }
473
474        Ok(())
475    }
476
477    /// Upscale using bicubic interpolation.
478    fn upscale_bicubic(
479        &mut self,
480        plane: &mut PlaneBuffer,
481        src_width: usize,
482        target_width: usize,
483        height: usize,
484        bd: u8,
485    ) -> ReconstructResult<()> {
486        let max_val = (1i32 << bd) - 1;
487        let scale = src_width as f32 / target_width as f32;
488
489        for y in 0..height {
490            let row = plane.row(y as u32);
491
492            for x in 0..target_width {
493                let src_x = x as f32 * scale;
494                let x1 = src_x.floor() as i32;
495                let frac = src_x.fract();
496
497                // Bicubic weights
498                let w = bicubic_weights(frac);
499
500                let mut sum = 0.0f32;
501                for (i, &weight) in w.iter().enumerate() {
502                    let sx = (x1 + i as i32 - 1).clamp(0, src_width as i32 - 1) as usize;
503                    sum += f32::from(row[sx]) * weight;
504                }
505
506                let result = (sum.round() as i32).clamp(0, max_val);
507                self.row_buffer[x] = result;
508            }
509
510            let dst_row = plane.row_mut(y as u32);
511            let write_width = target_width.min(dst_row.len());
512            for x in 0..write_width {
513                dst_row[x] = self.row_buffer[x] as i16;
514            }
515        }
516
517        Ok(())
518    }
519
520    /// Upscale using nearest neighbor.
521    fn upscale_nearest(
522        &mut self,
523        plane: &mut PlaneBuffer,
524        src_width: usize,
525        target_width: usize,
526        height: usize,
527    ) -> ReconstructResult<()> {
528        let scale = src_width as f32 / target_width as f32;
529
530        for y in 0..height {
531            let row = plane.row(y as u32);
532
533            for x in 0..target_width {
534                let src_x = ((x as f32 * scale).round() as usize).min(src_width - 1);
535                self.row_buffer[x] = i32::from(row[src_x]);
536            }
537
538            let dst_row = plane.row_mut(y as u32);
539            let write_width = target_width.min(dst_row.len());
540            for x in 0..write_width {
541                dst_row[x] = self.row_buffer[x] as i16;
542            }
543        }
544
545        Ok(())
546    }
547}
548
/// Weights of the four taps around a sample at fractional offset `t` (0..1).
///
/// Uses Keys' cubic convolution kernel with `a = -0.5` (Catmull-Rom). The
/// returned weights apply to the samples at offsets -1, 0, +1 and +2.
fn bicubic_weights(t: f32) -> [f32; 4] {
    const A: f32 = -0.5;

    let far_left = A * t * t * t - 2.0 * A * t * t + A * t;
    let near_left = (A + 2.0) * t * t * t - (A + 3.0) * t * t + 1.0;
    let near_right = -(A + 2.0) * t * t * t + (2.0 * A + 3.0) * t * t - A * t;
    let far_right = -A * t * t * t + A * t * t;

    [far_left, near_left, near_right, far_right]
}
560
561// =============================================================================
562// Tests
563// =============================================================================
564
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reconstruct::ChromaSubsampling;

    #[test]
    fn test_upscale_method() {
        assert_eq!(UpscaleMethod::Lanczos.kernel_size(), 8);
        assert_eq!(UpscaleMethod::Bicubic.kernel_size(), 4);
        assert_eq!(UpscaleMethod::Bilinear.kernel_size(), 2);
        assert_eq!(UpscaleMethod::Nearest.kernel_size(), 1);
        assert_eq!(UpscaleMethod::default(), UpscaleMethod::Lanczos);
    }

    #[test]
    fn test_super_res_config_default() {
        let config = SuperResConfig::default();
        assert_eq!(config.denominator, SUPERRES_DENOM_MAX);
        assert_eq!(config.upscaled_width, 0);
        assert_eq!(config.method, UpscaleMethod::Lanczos);
        assert!(!config.is_enabled());
    }

    #[test]
    fn test_super_res_config_new() {
        let config = SuperResConfig::new(1600, 1920, 1080);
        // 1600 * 16 / 1920 = 13.33.., rounded to 13.
        assert_eq!(config.denominator, 13);
        assert!(config.is_enabled());
        assert!(config.scale_factor() > 1.0);
    }

    #[test]
    fn test_super_res_config_new_degenerate_widths() {
        // Zero widths fall back to "no scaling".
        assert!(!SuperResConfig::new(0, 1920, 1080).is_enabled());
        assert!(!SuperResConfig::new(1920, 0, 1080).is_enabled());
        // Downscaling is not representable; the denominator clamps to 16.
        assert!(!SuperResConfig::new(1920, 1600, 1080).is_enabled());
        // Ratios beyond 16/9 clamp to the minimum denominator.
        assert_eq!(
            SuperResConfig::new(100, 1920, 1080).denominator,
            SUPERRES_DENOM_MIN
        );
    }

    #[test]
    fn test_super_res_config_from_denominator() {
        let config = SuperResConfig::from_denominator(12, 1440, 1080);
        assert_eq!(config.denominator, 12);
        assert!(config.is_enabled());
        // 1440 * 16 / 12 = 1920
        assert_eq!(config.upscaled_width, 1920);
    }

    #[test]
    fn test_super_res_config_from_denominator_clamps() {
        let config = SuperResConfig::from_denominator(4, 900, 10);
        assert_eq!(config.denominator, SUPERRES_DENOM_MIN);
        // 900 * 16 / 9 = 1600
        assert_eq!(config.upscaled_width, 1600);
    }

    #[test]
    fn test_super_res_config_no_scaling() {
        let config = SuperResConfig::from_denominator(16, 1920, 1080);
        assert!(!config.is_enabled());
        assert!((config.scale_factor() - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_super_res_config_source_x() {
        // denominator 12 => step 0.75 source px; position = (x + 0.5) * 0.75.
        let config = SuperResConfig::from_denominator(12, 1440, 1080);
        // x = 0 -> 0.375 -> (0, 0.375 * 64 = 24)
        assert_eq!(config.source_x(0), (0, 24));
        // x = 1 -> 1.125 -> (1, 8)
        assert_eq!(config.source_x(1), (1, 8));
        // x = 4 -> 3.375 -> (3, 24)
        assert_eq!(config.source_x(4), (3, 24));
    }

    #[test]
    fn test_super_res_upscaler_creation() {
        let upscaler = SuperResUpscaler::new();
        assert!(upscaler.config().is_none());
    }

    #[test]
    fn test_super_res_upscaler_with_config() {
        let config = SuperResConfig::from_denominator(12, 1440, 1080);
        let upscaler = SuperResUpscaler::with_config(config);
        assert!(upscaler.config().is_some());
        assert_eq!(upscaler.row_buffer.len(), 1920 + SUPERRES_FILTER_TAPS);
    }

    #[test]
    fn test_super_res_upscaler_set_config() {
        let mut upscaler = SuperResUpscaler::new();
        let config = SuperResConfig::from_denominator(12, 1440, 1080);
        upscaler.set_config(config);
        assert!(upscaler.config().is_some());
        assert_eq!(upscaler.row_buffer.len(), 1920 + SUPERRES_FILTER_TAPS);
    }

    #[test]
    fn test_super_res_apply_disabled() {
        let mut frame = FrameBuffer::new(64, 64, 8, ChromaSubsampling::Cs420);
        let context = FrameContext::new(64, 64); // No super-res needed

        let mut upscaler = SuperResUpscaler::new();
        let result = upscaler.apply(&mut frame, &context);
        assert!(result.is_ok());
    }

    #[test]
    fn test_bicubic_weights() {
        // At t = 0 the kernel is the identity: all weight on w[1].
        let w = bicubic_weights(0.0);
        assert!((w[1] - 1.0).abs() < 1e-6);
        assert!(w[0].abs() < 1e-6 && w[2].abs() < 1e-6 && w[3].abs() < 1e-6);

        // Weights form a partition of unity for any offset.
        for &t in &[0.0f32, 0.25, 0.5, 0.75] {
            let sum: f32 = bicubic_weights(t).iter().sum();
            assert!((sum - 1.0).abs() < 0.01, "t = {t}");
        }
    }

    #[test]
    fn test_filter_kernel() {
        // Phase 0 is the identity: only the centre tap (index 3) is non-zero.
        assert_eq!(get_filter_kernel(0), &[0, 0, 0, 64, 0, 0, 0, 0]);

        // Every phase sums to 64 (1.0 in fixed point) so DC is preserved.
        for phase in 0..64 {
            let sum: i32 = get_filter_kernel(phase).iter().map(|&k| i32::from(k)).sum();
            assert_eq!(sum, 1 << SUPERRES_FILTER_BITS, "phase {phase}");
        }

        // Out-of-range phases saturate to the last kernel.
        assert_eq!(get_filter_kernel(1000), get_filter_kernel(63));
    }

    #[test]
    fn test_constants() {
        assert_eq!(SUPERRES_DENOM_MIN, 9);
        assert_eq!(SUPERRES_DENOM_MAX, 16);
        assert_eq!(SUPERRES_FILTER_TAPS, 8);
        assert_eq!(SUPERRES_FILTER_BITS, 6);
        assert_eq!(SUPERRES_FILTER_OFFSET, 32);
    }
}