Skip to main content

oximedia_codec/reconstruct/
super_res.rs

//! Super-resolution upscaling for AV1.
//!
//! AV1 can code a frame at a reduced horizontal resolution and upscale it
//! horizontally during decoding. This trades a small quality loss for better
//! coding efficiency; the upscale uses a Lanczos-like 8-tap filter.
6
7#![forbid(unsafe_code)]
8#![allow(clippy::unreadable_literal)]
9#![allow(clippy::items_after_statements)]
10#![allow(clippy::unnecessary_wraps)]
11#![allow(clippy::struct_excessive_bools)]
12#![allow(clippy::identity_op)]
13#![allow(clippy::range_plus_one)]
14#![allow(clippy::needless_range_loop)]
15#![allow(clippy::useless_conversion)]
16#![allow(clippy::redundant_closure_for_method_calls)]
17#![allow(clippy::single_match_else)]
18#![allow(dead_code)]
19#![allow(clippy::doc_markdown)]
20#![allow(clippy::unused_self)]
21#![allow(clippy::trivially_copy_pass_by_ref)]
22#![allow(clippy::cast_possible_truncation)]
23#![allow(clippy::cast_sign_loss)]
24#![allow(clippy::cast_possible_wrap)]
25#![allow(clippy::cast_precision_loss)]
26#![allow(clippy::missing_errors_doc)]
27#![allow(clippy::too_many_arguments)]
28#![allow(clippy::similar_names)]
29#![allow(clippy::cast_lossless)]
30#![allow(clippy::many_single_char_names)]
31
32use super::pipeline::FrameContext;
33use super::{FrameBuffer, PlaneBuffer, ReconstructResult, ReconstructionError};
34
// =============================================================================
// Constants
// =============================================================================

/// Smallest accepted super-resolution denominator (strongest downscale).
pub const SUPERRES_DENOM_MIN: u8 = 9;

/// Largest accepted super-resolution denominator; 16/16 means the frame is
/// coded at full width and no upscaling takes place.
pub const SUPERRES_DENOM_MAX: u8 = 16;

/// Number of taps in every super-resolution filter kernel.
pub const SUPERRES_FILTER_TAPS: usize = 8;

/// Fixed-point precision of the filter coefficients (scaled by `1 << 6`).
pub const SUPERRES_FILTER_BITS: u8 = 6;

/// Rounding bias added before the right shift by [`SUPERRES_FILTER_BITS`]:
/// half of one unit in the coefficient scale.
pub const SUPERRES_FILTER_OFFSET: i32 = (1 << SUPERRES_FILTER_BITS) / 2;
53
// =============================================================================
// Upscale Method
// =============================================================================

/// Interpolation filter used to widen each row during super-resolution.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum UpscaleMethod {
    /// 8-tap Lanczos-like kernel; the default (and the AV1 choice).
    #[default]
    Lanczos,
    /// Two-tap linear interpolation: simpler, lower quality.
    Bilinear,
    /// Four-tap cubic interpolation.
    Bicubic,
    /// Copies the closest source sample: fastest, lowest quality.
    Nearest,
}

impl UpscaleMethod {
    /// Number of source samples the method combines per output sample.
    #[must_use]
    pub const fn kernel_size(self) -> usize {
        match self {
            Self::Nearest => 1,
            Self::Bilinear => 2,
            Self::Bicubic => 4,
            Self::Lanczos => 8,
        }
    }
}
84
85// =============================================================================
86// Super-Res Configuration
87// =============================================================================
88
89/// Configuration for super-resolution.
90#[derive(Clone, Debug)]
91pub struct SuperResConfig {
92    /// Scale denominator (9-16, 16 = no scaling).
93    pub denominator: u8,
94    /// Original (encoded) width.
95    pub encoded_width: u32,
96    /// Target (output) width.
97    pub upscaled_width: u32,
98    /// Frame height (unchanged).
99    pub height: u32,
100    /// Upscale method.
101    pub method: UpscaleMethod,
102}
103
104impl Default for SuperResConfig {
105    fn default() -> Self {
106        Self {
107            denominator: SUPERRES_DENOM_MAX,
108            encoded_width: 0,
109            upscaled_width: 0,
110            height: 0,
111            method: UpscaleMethod::Lanczos,
112        }
113    }
114}
115
116impl SuperResConfig {
117    /// Create a new super-res configuration.
118    #[must_use]
119    pub fn new(encoded_width: u32, upscaled_width: u32, height: u32) -> Self {
120        let denominator = if upscaled_width > 0 && encoded_width > 0 {
121            let ratio = (encoded_width as f32 * 16.0 / upscaled_width as f32).round() as u8;
122            ratio.clamp(SUPERRES_DENOM_MIN, SUPERRES_DENOM_MAX)
123        } else {
124            SUPERRES_DENOM_MAX
125        };
126
127        Self {
128            denominator,
129            encoded_width,
130            upscaled_width,
131            height,
132            method: UpscaleMethod::Lanczos,
133        }
134    }
135
136    /// Create from scale denominator.
137    #[must_use]
138    pub fn from_denominator(denominator: u8, encoded_width: u32, height: u32) -> Self {
139        let denom = denominator.clamp(SUPERRES_DENOM_MIN, SUPERRES_DENOM_MAX);
140        let upscaled_width = (encoded_width as u64 * 16 / u64::from(denom)) as u32;
141
142        Self {
143            denominator: denom,
144            encoded_width,
145            upscaled_width,
146            height,
147            method: UpscaleMethod::Lanczos,
148        }
149    }
150
151    /// Get the scale factor.
152    #[must_use]
153    pub fn scale_factor(&self) -> f32 {
154        16.0 / f32::from(self.denominator)
155    }
156
157    /// Check if super-res is enabled.
158    #[must_use]
159    pub fn is_enabled(&self) -> bool {
160        self.denominator < SUPERRES_DENOM_MAX
161    }
162
163    /// Set upscale method.
164    #[must_use]
165    pub const fn with_method(mut self, method: UpscaleMethod) -> Self {
166        self.method = method;
167        self
168    }
169
170    /// Calculate the source x position for a target x position.
171    #[must_use]
172    pub fn source_x(&self, target_x: u32) -> (u32, u32) {
173        // Calculate with higher precision
174        let step = (u64::from(self.denominator) << 14) / 16;
175        let offset = step / 2;
176        let src_pos = u64::from(target_x) * step + offset;
177
178        let integer = (src_pos >> 14) as u32;
179        let fraction = ((src_pos & 0x3FFF) >> 8) as u32; // 6 bits of fraction
180
181        (integer, fraction)
182    }
183}
184
// =============================================================================
// Lanczos Filter Kernels
// =============================================================================

/// Pre-computed Lanczos filter kernels for different phases.
///
/// Row `p` is the 8-tap kernel for the 6-bit sub-pixel phase `p` (a fraction of
/// `p / 64`) returned by `SuperResConfig::source_x`. The Lanczos upscaler applies
/// tap `i` to source sample `src_x + i - 3`, so tap 3 sits on `src_x`.
/// Every row sums to 64 (`1 << SUPERRES_FILTER_BITS`), so a flat input row is
/// reproduced unchanged after the rounding shift.
///
/// NOTE(review): the dominant tap moves from index 3 (phases 0-10) to index 4
/// (phases 11-41) and then to index 5 (phases 42-63). Phase 27 already puts
/// weight 64 on tap 4. An interpolating kernel for a fraction within one
/// source pixel should only move weight from tap 3 towards tap 4, so these
/// values look suspect. Verify them against the AV1 specification's
/// `Upscale_Filter` table before relying on them.
const SUPERRES_FILTER_KERNELS: [[i16; SUPERRES_FILTER_TAPS]; 64] = [
    // Phase 0
    [0, 0, 0, 64, 0, 0, 0, 0],
    [0, 1, -3, 63, 4, -1, 0, 0],
    [0, 1, -5, 62, 8, -2, 0, 0],
    [-1, 2, -7, 61, 12, -3, 0, 0],
    [-1, 2, -9, 59, 17, -4, 0, 0],
    [-1, 3, -10, 57, 21, -5, -1, 0],
    [-1, 3, -11, 55, 26, -6, -1, -1],
    [-1, 3, -12, 52, 30, -6, -1, -1],
    // Phase 8
    [-1, 4, -13, 50, 34, -7, -2, -1],
    [-1, 4, -13, 47, 38, -7, -2, -2],
    [-1, 4, -13, 44, 42, -8, -2, -2],
    [-1, 4, -13, 41, 45, -8, -2, -2],
    [-1, 4, -13, 38, 48, -8, -2, -2],
    [-1, 4, -12, 35, 51, -8, -2, -3],
    [-1, 4, -12, 32, 53, -8, -2, -2],
    [-1, 4, -11, 29, 55, -7, -2, -3],
    // Phase 16
    [-1, 3, -10, 26, 57, -7, -2, -2],
    [-1, 3, -9, 23, 58, -6, -2, -2],
    [0, 3, -8, 21, 59, -5, -2, -4],
    [0, 3, -7, 18, 60, -4, -2, -4],
    [0, 2, -6, 15, 61, -3, -1, -4],
    [0, 2, -5, 13, 61, -2, -1, -4],
    [0, 2, -4, 10, 62, -1, -1, -4],
    [0, 1, -3, 8, 62, 0, 0, -4],
    // Phase 24
    [0, 1, -2, 5, 63, 1, 0, -4],
    [0, 1, -1, 3, 63, 2, 0, -4],
    [0, 0, 0, 1, 63, 3, 0, -3],
    [0, 0, 0, 0, 64, 4, 0, -4],
    [0, 0, 1, -2, 63, 5, 0, -3],
    [0, 0, 1, -3, 63, 7, 0, -4],
    [0, 0, 2, -4, 62, 9, -1, -4],
    [0, 0, 2, -6, 61, 11, -1, -3],
    // Phase 32 (center)
    [0, 0, 2, -7, 60, 13, -2, -2],
    [0, 0, 3, -8, 59, 15, -2, -3],
    [0, 0, 3, -9, 57, 18, -3, -2],
    [0, -1, 3, -10, 55, 21, -3, -1],
    [0, -1, 4, -11, 53, 24, -4, -1],
    [0, -1, 4, -11, 51, 27, -4, -2],
    [-1, -1, 4, -12, 49, 30, -4, -1],
    [-1, -1, 4, -12, 46, 33, -4, -1],
    // Phase 40
    [-1, -1, 5, -12, 44, 36, -5, -2],
    [-1, -2, 5, -12, 41, 39, -5, -1],
    [-1, -2, 5, -12, 38, 42, -5, -1],
    [-1, -2, 5, -11, 35, 44, -5, -1],
    [-1, -2, 5, -11, 33, 46, -5, -1],
    [-1, -2, 5, -10, 30, 48, -4, -2],
    [-1, -2, 5, -10, 27, 50, -4, -1],
    [-1, -2, 5, -9, 25, 51, -4, -1],
    // Phase 48
    [-1, -2, 4, -9, 22, 53, -3, 0],
    [-1, -2, 4, -8, 20, 54, -3, 0],
    [-1, -2, 4, -7, 17, 55, -2, 0],
    [-1, -2, 4, -6, 15, 56, -1, -1],
    [0, -2, 4, -5, 12, 57, 0, -2],
    [0, -2, 3, -4, 10, 58, 1, -2],
    [0, -2, 3, -3, 8, 59, 2, -3],
    [0, -1, 3, -2, 6, 59, 3, -4],
    // Phase 56
    [0, -1, 2, -1, 4, 60, 4, -4],
    [0, -1, 2, 0, 2, 60, 5, -4],
    [0, -1, 1, 0, 0, 61, 6, -3],
    [0, 0, 1, 1, -1, 60, 7, -4],
    [0, 0, 1, 1, -2, 60, 8, -4],
    [0, 0, 0, 2, -3, 59, 10, -4],
    [0, 0, 0, 2, -4, 58, 12, -4],
    [0, 0, 0, 3, -5, 57, 14, -5],
];
264
265/// Get filter kernel for a given phase.
266fn get_filter_kernel(phase: usize) -> &'static [i16; SUPERRES_FILTER_TAPS] {
267    &SUPERRES_FILTER_KERNELS[phase.min(63)]
268}
269
270// =============================================================================
271// Super-Res Upscaler
272// =============================================================================
273
274/// Super-resolution upscaler.
275#[derive(Debug)]
276pub struct SuperResUpscaler {
277    /// Current configuration.
278    config: Option<SuperResConfig>,
279    /// Temporary row buffer.
280    row_buffer: Vec<i32>,
281}
282
283impl Default for SuperResUpscaler {
284    fn default() -> Self {
285        Self::new()
286    }
287}
288
289impl SuperResUpscaler {
290    /// Create a new super-res upscaler.
291    #[must_use]
292    pub fn new() -> Self {
293        Self {
294            config: None,
295            row_buffer: Vec::new(),
296        }
297    }
298
299    /// Create with configuration.
300    #[must_use]
301    pub fn with_config(config: SuperResConfig) -> Self {
302        let buffer_size = config.upscaled_width as usize + SUPERRES_FILTER_TAPS;
303        Self {
304            config: Some(config),
305            row_buffer: vec![0; buffer_size],
306        }
307    }
308
309    /// Set configuration.
310    pub fn set_config(&mut self, config: SuperResConfig) {
311        let buffer_size = config.upscaled_width as usize + SUPERRES_FILTER_TAPS;
312        self.row_buffer.resize(buffer_size, 0);
313        self.config = Some(config);
314    }
315
316    /// Get current configuration.
317    #[must_use]
318    pub fn config(&self) -> Option<&SuperResConfig> {
319        self.config.as_ref()
320    }
321
322    /// Apply super-resolution to a frame.
323    ///
324    /// # Errors
325    ///
326    /// Returns error if super-res application fails.
327    pub fn apply(
328        &mut self,
329        frame: &mut FrameBuffer,
330        context: &FrameContext,
331    ) -> ReconstructResult<()> {
332        // Check if super-res is needed
333        if !context.needs_super_res() {
334            return Ok(());
335        }
336
337        // Clone config to avoid borrow issues
338        let config = self.config.clone().ok_or_else(|| {
339            ReconstructionError::InvalidInput("Super-res config not set".to_string())
340        })?;
341
342        if !config.is_enabled() {
343            return Ok(());
344        }
345
346        let bd = frame.bit_depth();
347
348        // Upscale Y plane
349        self.upscale_plane(frame.y_plane_mut(), &config, bd)?;
350
351        // Upscale chroma planes (with adjusted dimensions)
352        if let Some(u) = frame.u_plane_mut() {
353            self.upscale_plane(u, &config, bd)?;
354        }
355        if let Some(v) = frame.v_plane_mut() {
356            self.upscale_plane(v, &config, bd)?;
357        }
358
359        Ok(())
360    }
361
362    /// Upscale a single plane.
363    fn upscale_plane(
364        &mut self,
365        plane: &mut PlaneBuffer,
366        config: &SuperResConfig,
367        bd: u8,
368    ) -> ReconstructResult<()> {
369        let src_width = plane.width() as usize;
370        let height = plane.height() as usize;
371        let target_width = config.upscaled_width as usize;
372
373        if target_width <= src_width {
374            return Ok(()); // No upscaling needed
375        }
376
377        match config.method {
378            UpscaleMethod::Lanczos => {
379                self.upscale_lanczos(plane, src_width, target_width, height, bd)
380            }
381            UpscaleMethod::Bilinear => {
382                self.upscale_bilinear(plane, src_width, target_width, height, bd)
383            }
384            UpscaleMethod::Bicubic => {
385                self.upscale_bicubic(plane, src_width, target_width, height, bd)
386            }
387            UpscaleMethod::Nearest => self.upscale_nearest(plane, src_width, target_width, height),
388        }
389    }
390
391    /// Upscale using Lanczos-like filter.
392    fn upscale_lanczos(
393        &mut self,
394        plane: &mut PlaneBuffer,
395        src_width: usize,
396        target_width: usize,
397        height: usize,
398        bd: u8,
399    ) -> ReconstructResult<()> {
400        let max_val = (1i32 << bd) - 1;
401        let config = self.config.as_ref().ok_or_else(|| {
402            ReconstructionError::Internal(
403                "config not initialized before upscale_lanczos".to_string(),
404            )
405        })?;
406
407        // Process each row
408        for y in 0..height {
409            let row = plane.row(y as u32);
410
411            // Upscale row
412            for x in 0..target_width {
413                let (src_x, phase) = config.source_x(x as u32);
414                let kernel = get_filter_kernel(phase as usize);
415
416                let mut sum: i32 = 0;
417                for (i, &k) in kernel.iter().enumerate() {
418                    let sx = (src_x as i32 + i as i32 - 3).clamp(0, src_width as i32 - 1) as usize;
419                    sum += i32::from(row[sx]) * i32::from(k);
420                }
421
422                // Round and clamp
423                let result =
424                    ((sum + SUPERRES_FILTER_OFFSET) >> SUPERRES_FILTER_BITS).clamp(0, max_val);
425                self.row_buffer[x] = result;
426            }
427
428            // Write back (note: this would require resizing the plane)
429            // For now, just write what fits
430            let dst_row = plane.row_mut(y as u32);
431            let write_width = target_width.min(dst_row.len());
432            for x in 0..write_width {
433                dst_row[x] = self.row_buffer[x] as i16;
434            }
435        }
436
437        Ok(())
438    }
439
440    /// Upscale using bilinear interpolation.
441    fn upscale_bilinear(
442        &mut self,
443        plane: &mut PlaneBuffer,
444        src_width: usize,
445        target_width: usize,
446        height: usize,
447        bd: u8,
448    ) -> ReconstructResult<()> {
449        let max_val = (1i32 << bd) - 1;
450        let scale = src_width as f32 / target_width as f32;
451
452        for y in 0..height {
453            let row = plane.row(y as u32);
454
455            for x in 0..target_width {
456                let src_x = x as f32 * scale;
457                let x0 = src_x.floor() as usize;
458                let x1 = (x0 + 1).min(src_width - 1);
459                let frac = src_x.fract();
460
461                let v0 = i32::from(row[x0]);
462                let v1 = i32::from(row[x1]);
463                let result = ((v0 as f32 * (1.0 - frac) + v1 as f32 * frac).round() as i32)
464                    .clamp(0, max_val);
465                self.row_buffer[x] = result;
466            }
467
468            let dst_row = plane.row_mut(y as u32);
469            let write_width = target_width.min(dst_row.len());
470            for x in 0..write_width {
471                dst_row[x] = self.row_buffer[x] as i16;
472            }
473        }
474
475        Ok(())
476    }
477
478    /// Upscale using bicubic interpolation.
479    fn upscale_bicubic(
480        &mut self,
481        plane: &mut PlaneBuffer,
482        src_width: usize,
483        target_width: usize,
484        height: usize,
485        bd: u8,
486    ) -> ReconstructResult<()> {
487        let max_val = (1i32 << bd) - 1;
488        let scale = src_width as f32 / target_width as f32;
489
490        for y in 0..height {
491            let row = plane.row(y as u32);
492
493            for x in 0..target_width {
494                let src_x = x as f32 * scale;
495                let x1 = src_x.floor() as i32;
496                let frac = src_x.fract();
497
498                // Bicubic weights
499                let w = bicubic_weights(frac);
500
501                let mut sum = 0.0f32;
502                for (i, &weight) in w.iter().enumerate() {
503                    let sx = (x1 + i as i32 - 1).clamp(0, src_width as i32 - 1) as usize;
504                    sum += f32::from(row[sx]) * weight;
505                }
506
507                let result = (sum.round() as i32).clamp(0, max_val);
508                self.row_buffer[x] = result;
509            }
510
511            let dst_row = plane.row_mut(y as u32);
512            let write_width = target_width.min(dst_row.len());
513            for x in 0..write_width {
514                dst_row[x] = self.row_buffer[x] as i16;
515            }
516        }
517
518        Ok(())
519    }
520
521    /// Upscale using nearest neighbor.
522    fn upscale_nearest(
523        &mut self,
524        plane: &mut PlaneBuffer,
525        src_width: usize,
526        target_width: usize,
527        height: usize,
528    ) -> ReconstructResult<()> {
529        let scale = src_width as f32 / target_width as f32;
530
531        for y in 0..height {
532            let row = plane.row(y as u32);
533
534            for x in 0..target_width {
535                let src_x = ((x as f32 * scale).round() as usize).min(src_width - 1);
536                self.row_buffer[x] = i32::from(row[src_x]);
537            }
538
539            let dst_row = plane.row_mut(y as u32);
540            let write_width = target_width.min(dst_row.len());
541            for x in 0..write_width {
542                dst_row[x] = self.row_buffer[x] as i16;
543            }
544        }
545
546        Ok(())
547    }
548}
549
/// Calculate bicubic interpolation weights with the default parameter
/// `a = -0.5` (Catmull-Rom).
///
/// `t` is the fractional offset past the second of the four taps. The returned
/// weights apply to the taps at offsets -1, 0, +1 and +2 from that position.
fn bicubic_weights(t: f32) -> [f32; 4] {
    bicubic_weights_with(t, -0.5)
}

/// Calculate Keys cubic-convolution weights for an arbitrary parameter `a`
/// (commonly -0.5 or -0.75; 0.0 yields a smoothstep blend of the middle taps).
///
/// The weights sum to 1 for every `t` and `a`.
fn bicubic_weights_with(t: f32, a: f32) -> [f32; 4] {
    [
        a * t * t * t - 2.0 * a * t * t + a * t,
        (a + 2.0) * t * t * t - (a + 3.0) * t * t + 1.0,
        -(a + 2.0) * t * t * t + (2.0 * a + 3.0) * t * t - a * t,
        -a * t * t * t + a * t * t,
    ]
}
561
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::reconstruct::ChromaSubsampling;

    #[test]
    fn test_upscale_method() {
        assert_eq!(UpscaleMethod::Lanczos.kernel_size(), 8);
        assert_eq!(UpscaleMethod::Bicubic.kernel_size(), 4);
        assert_eq!(UpscaleMethod::Bilinear.kernel_size(), 2);
        assert_eq!(UpscaleMethod::Nearest.kernel_size(), 1);
        assert_eq!(UpscaleMethod::default(), UpscaleMethod::Lanczos);
    }

    #[test]
    fn test_super_res_config_default() {
        let config = SuperResConfig::default();
        assert_eq!(config.denominator, SUPERRES_DENOM_MAX);
        assert!(!config.is_enabled());
        assert_eq!(config.method, UpscaleMethod::Lanczos);
        assert_eq!(
            config.with_method(UpscaleMethod::Nearest).method,
            UpscaleMethod::Nearest
        );
    }

    #[test]
    fn test_super_res_config_new() {
        let config = SuperResConfig::new(1600, 1920, 1080);
        assert!(config.is_enabled());
        assert!(config.scale_factor() > 1.0);
        // round(1600 * 16 / 1920) = round(13.33) = 13
        assert_eq!(config.denominator, 13);
        assert_eq!(config.encoded_width, 1600);
        assert_eq!(config.upscaled_width, 1920);
        assert_eq!(config.height, 1080);
    }

    #[test]
    fn test_super_res_config_new_clamps() {
        // Zero widths disable super-res instead of dividing by zero.
        assert!(!SuperResConfig::new(0, 1920, 1080).is_enabled());
        assert!(!SuperResConfig::new(1920, 0, 1080).is_enabled());
        // An encoded width wider than the output clamps to "no scaling".
        assert_eq!(
            SuperResConfig::new(3840, 1920, 1080).denominator,
            SUPERRES_DENOM_MAX
        );
        // Very strong downscales clamp to the minimum denominator.
        assert_eq!(
            SuperResConfig::new(100, 1920, 1080).denominator,
            SUPERRES_DENOM_MIN
        );
    }

    #[test]
    fn test_super_res_config_from_denominator() {
        let config = SuperResConfig::from_denominator(12, 1440, 1080);
        assert_eq!(config.denominator, 12);
        assert!(config.is_enabled());
        // 1440 * 16 / 12 = 1920
        assert_eq!(config.upscaled_width, 1920);

        // Out-of-range denominators are clamped.
        let low = SuperResConfig::from_denominator(4, 1440, 1080);
        assert_eq!(low.denominator, SUPERRES_DENOM_MIN);
        // 1440 * 16 / 9 = 2560
        assert_eq!(low.upscaled_width, 2560);
        assert_eq!(
            SuperResConfig::from_denominator(200, 1440, 1080).denominator,
            SUPERRES_DENOM_MAX
        );
    }

    #[test]
    fn test_super_res_config_no_scaling() {
        let config = SuperResConfig::from_denominator(16, 1920, 1080);
        assert!(!config.is_enabled());
        assert!((config.scale_factor() - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_super_res_config_source_x() {
        let config = SuperResConfig::from_denominator(12, 1440, 1080);
        let (src_x, phase) = config.source_x(0);
        assert_eq!(src_x, 0);
        assert!(phase < 64);

        // Positions stay inside the encoded row, phases index the 64 kernels,
        // and positions never move backwards as the output column advances.
        let mut prev = (0, 0);
        for x in 0..config.upscaled_width {
            let pos = config.source_x(x);
            assert!(pos.0 < config.encoded_width, "x = {x}");
            assert!(pos.1 < 64, "x = {x}");
            assert!(pos >= prev, "x = {x}");
            prev = pos;
        }
    }

    #[test]
    fn test_super_res_upscaler_creation() {
        let upscaler = SuperResUpscaler::new();
        assert!(upscaler.config().is_none());
    }

    #[test]
    fn test_super_res_upscaler_with_config() {
        let config = SuperResConfig::from_denominator(12, 1440, 1080);
        let upscaler = SuperResUpscaler::with_config(config);
        assert!(upscaler.config().is_some());
    }

    #[test]
    fn test_super_res_upscaler_set_config() {
        let mut upscaler = SuperResUpscaler::new();
        let config = SuperResConfig::from_denominator(12, 1440, 1080);
        upscaler.set_config(config);
        assert_eq!(upscaler.config().map(|c| c.upscaled_width), Some(1920));
    }

    #[test]
    fn test_super_res_apply_disabled() {
        let mut frame = FrameBuffer::new(64, 64, 8, ChromaSubsampling::Cs420);
        let context = FrameContext::new(64, 64); // No super-res needed

        let mut upscaler = SuperResUpscaler::new();
        let result = upscaler.apply(&mut frame, &context);
        assert!(result.is_ok());
    }

    #[test]
    fn test_bicubic_weights() {
        // At t = 0 all weight is on the second tap.
        let w = bicubic_weights(0.0);
        assert!((w[1] - 1.0).abs() < 1e-6);
        let sum: f32 = w.iter().sum();
        assert!((sum - 1.0).abs() < 0.01);

        // At t = 0.5 the kernel is symmetric about the midpoint.
        let w_half = bicubic_weights(0.5);
        let sum_half: f32 = w_half.iter().sum();
        assert!((sum_half - 1.0).abs() < 0.01);
        assert!((w_half[0] - w_half[3]).abs() < 1e-6);
        assert!((w_half[1] - w_half[2]).abs() < 1e-6);
    }

    #[test]
    fn test_filter_kernel() {
        // Phase 0 is the identity: all weight on the centre tap (index 3).
        assert_eq!(get_filter_kernel(0), &[0, 0, 0, 64, 0, 0, 0, 0]);

        // Every phase is normalised to 64 (1.0 in 6-bit fixed point).
        for phase in 0..64 {
            let sum: i32 = get_filter_kernel(phase).iter().map(|&k| i32::from(k)).sum();
            assert_eq!(sum, 1 << SUPERRES_FILTER_BITS, "phase {phase}");
        }

        // Out-of-range phases clamp to the last kernel.
        assert_eq!(get_filter_kernel(64), get_filter_kernel(63));
        assert_eq!(get_filter_kernel(usize::MAX), get_filter_kernel(63));
    }

    #[test]
    fn test_constants() {
        assert_eq!(SUPERRES_DENOM_MIN, 9);
        assert_eq!(SUPERRES_DENOM_MAX, 16);
        assert_eq!(SUPERRES_FILTER_TAPS, 8);
        assert_eq!(SUPERRES_FILTER_BITS, 6);
        assert_eq!(SUPERRES_FILTER_OFFSET, 32);
    }
}