Skip to main content

oximedia_codec/simd/
blend.rs

1//! Blending operations for video codec implementations.
2//!
3//! This module provides blending primitives used in:
4//! - Motion compensation (bilinear interpolation)
5//! - Frame blending
6//! - Alpha compositing
7//!
8//! All operations are designed to map efficiently to SIMD instructions.
9
10#![forbid(unsafe_code)]
11
12use super::scalar::ScalarFallback;
13use super::traits::{SimdOps, SimdOpsExt};
14use super::types::{I16x8, U8x16};
15
/// Blending primitives (lerp, weighted average, bilinear) on top of a SIMD backend `S`.
///
/// The `S: SimdOps` requirement lives on the `impl` blocks that need it rather
/// than on the struct itself, so the type can be named (and its derived traits
/// used) without repeating the bound everywhere.
#[derive(Debug, Clone, Copy)]
pub struct BlendOps<S> {
    simd: S,
}
20
21impl<S: SimdOps + Default> Default for BlendOps<S> {
22    fn default() -> Self {
23        Self::new(S::default())
24    }
25}
26
27impl<S: SimdOps> BlendOps<S> {
28    /// Create a new blending operations instance.
29    #[inline]
30    #[must_use]
31    pub const fn new(simd: S) -> Self {
32        Self { simd }
33    }
34
35    /// Get the underlying SIMD implementation.
36    #[inline]
37    #[must_use]
38    pub const fn simd(&self) -> &S {
39        &self.simd
40    }
41
42    /// Linear interpolation between two values.
43    ///
44    /// Returns: a + (b - a) * weight / 256
45    ///
46    /// Weight is in range [0, 256] where:
47    /// - 0 = 100% a
48    /// - 256 = 100% b
49    /// - 128 = 50% each
50    #[inline]
51    #[allow(clippy::cast_sign_loss)]
52    pub fn lerp_u8(&self, a: u8, b: u8, weight: u8) -> u8 {
53        let a32 = i32::from(a);
54        let b32 = i32::from(b);
55        let w32 = i32::from(weight);
56        let result = a32 + ((b32 - a32) * w32 + 128) / 256;
57        // Safe: clamping to [0, 255] ensures the value fits in u8
58        result.clamp(0, 255) as u8
59    }
60
61    /// Linear interpolation for i16x8 vectors.
62    ///
63    /// Returns: a + (b - a) * weight / 256
64    #[inline]
65    pub fn lerp_i16x8(&self, a: I16x8, b: I16x8, weight: i16) -> I16x8 {
66        let diff = self.simd.sub_i16x8(b, a);
67        let weight_vec = I16x8::splat(weight);
68        let scaled = self.simd.mul_i16x8(diff, weight_vec);
69        let shifted = self.simd.shr_i16x8(scaled, 8);
70        self.simd.add_i16x8(a, shifted)
71    }
72
73    /// Weighted average of two u8x16 vectors.
74    ///
75    /// Returns: (a * (256 - weight) + b * weight + 128) / 256
76    #[inline]
77    #[allow(clippy::needless_range_loop, clippy::cast_possible_truncation)]
78    pub fn weighted_avg_u8x16(&self, a: U8x16, b: U8x16, weight: u8) -> U8x16 {
79        let mut result = [0u8; 16];
80        let w = u16::from(weight);
81        let inv_w = 256 - w;
82
83        for i in 0..16 {
84            // Result is always in [0, 255] due to the weighted average
85            let val = (u16::from(a.0[i]) * inv_w + u16::from(b.0[i]) * w + 128) / 256;
86            result[i] = val as u8;
87        }
88
89        U8x16(result)
90    }
91
92    /// Bilinear blend for motion compensation.
93    ///
94    /// Blends 4 samples using horizontal and vertical weights.
95    /// Used for sub-pixel motion estimation.
96    ///
97    /// Layout:
98    /// ```text
99    /// tl --- tr
100    /// |      |
101    /// bl --- br
102    /// ```
103    ///
104    /// Returns: blend of all 4 based on (hweight, vweight)
105    #[inline]
106    #[allow(dead_code)]
107    pub fn bilinear_blend_u8(
108        &self,
109        tl: u8,
110        tr: u8,
111        bl: u8,
112        br: u8,
113        hweight: u8,
114        vweight: u8,
115    ) -> u8 {
116        // Horizontal interpolation for top and bottom
117        let top = self.lerp_u8(tl, tr, hweight);
118        let bottom = self.lerp_u8(bl, br, hweight);
119
120        // Vertical interpolation
121        self.lerp_u8(top, bottom, vweight)
122    }
123
124    /// Bilinear blend for a row of 8 pixels.
125    ///
126    /// Takes 4 input rows and blends them bilinearly.
127    #[inline]
128    #[allow(dead_code, clippy::too_many_arguments)]
129    pub fn bilinear_blend_row_8(
130        &self,
131        tl: &[u8],
132        tr: &[u8],
133        bl: &[u8],
134        br: &[u8],
135        hweight: u8,
136        vweight: u8,
137        dst: &mut [u8],
138    ) {
139        let len = 8
140            .min(tl.len())
141            .min(tr.len())
142            .min(bl.len())
143            .min(br.len())
144            .min(dst.len());
145        for i in 0..len {
146            dst[i] = self.bilinear_blend_u8(tl[i], tr[i], bl[i], br[i], hweight, vweight);
147        }
148    }
149}
150
151impl<S: SimdOps + SimdOpsExt> BlendOps<S> {
152    /// Bilinear blend using SIMD for a row of 8 pixels.
153    #[allow(dead_code, clippy::similar_names, clippy::too_many_arguments)]
154    pub fn bilinear_blend_row_8_simd(
155        &self,
156        tl: &[u8],
157        tr: &[u8],
158        bl: &[u8],
159        br: &[u8],
160        hweight: u8,
161        vweight: u8,
162        dst: &mut [u8],
163    ) {
164        // Load as i16 for computation
165        let tl_v = self.simd.load8_u8_to_i16x8(tl);
166        let tr_v = self.simd.load8_u8_to_i16x8(tr);
167        let bl_v = self.simd.load8_u8_to_i16x8(bl);
168        let br_v = self.simd.load8_u8_to_i16x8(br);
169
170        // Horizontal blend
171        let top = self.lerp_i16x8(tl_v, tr_v, i16::from(hweight));
172        let bottom = self.lerp_i16x8(bl_v, br_v, i16::from(hweight));
173
174        // Vertical blend
175        let result = self.lerp_i16x8(top, bottom, i16::from(vweight));
176
177        // Store result
178        self.simd.store8_i16x8_as_u8(result, dst);
179    }
180}
181
182/// Create a blending operations instance with scalar fallback.
183#[inline]
184#[must_use]
185pub fn blend_ops() -> BlendOps<ScalarFallback> {
186    BlendOps::new(ScalarFallback::new())
187}
188
/// Half-pixel interpolation filter taps (6-tap filter).
///
/// These are the H.264 luma half-sample taps `(1, -5, 20, 20, -5, 1)`. They
/// sum to 32, so filtered sums are normalized with `(sum + 16) >> 5`, as
/// `apply_half_pel_h` and `apply_half_pel_v` do.
#[allow(dead_code)]
pub const HALF_PEL_FILTER: [i16; 6] = [1, -5, 20, 20, -5, 1];

/// Quarter-pixel interpolation filter taps (6-tap filter).
///
/// Unlike [`HALF_PEL_FILTER`], these taps sum to 64 (`1 - 5 + 52 + 20 - 5 + 1`),
/// so a filtered sum would be normalized with `(sum + 32) >> 6`, not `>> 5`.
///
/// NOTE(review): this kernel is asymmetric. H.264 derives quarter-sample values
/// by averaging neighbouring integer/half-sample values rather than with a
/// dedicated 6-tap kernel, so confirm where these coefficients come from before
/// relying on them.
#[allow(dead_code)]
pub const QUARTER_PEL_FILTER: [i16; 6] = [1, -5, 52, 20, -5, 1];
198
199/// Apply 6-tap horizontal filter for half-pixel interpolation.
200#[allow(dead_code, clippy::cast_sign_loss)]
201pub fn apply_half_pel_h(src: &[u8], dst: &mut [u8], width: usize) {
202    if width < 6 || src.len() < width + 5 {
203        return;
204    }
205
206    for x in 0..width {
207        let mut sum: i32 = 0;
208        for (k, &tap) in HALF_PEL_FILTER.iter().enumerate() {
209            sum += i32::from(src[x + k]) * i32::from(tap);
210        }
211        // Round and clip - safe because we clamp to [0, 255]
212        let result = (sum + 16) >> 5;
213        dst[x] = result.clamp(0, 255) as u8;
214    }
215}
216
217/// Apply 6-tap vertical filter for half-pixel interpolation.
218#[allow(dead_code, clippy::cast_sign_loss)]
219pub fn apply_half_pel_v(src: &[&[u8]], dst: &mut [u8], width: usize) {
220    if src.len() < 6 {
221        return;
222    }
223
224    for x in 0..width.min(dst.len()) {
225        let mut sum: i32 = 0;
226        for (k, &tap) in HALF_PEL_FILTER.iter().enumerate() {
227            if x < src[k].len() {
228                sum += i32::from(src[k][x]) * i32::from(tap);
229            }
230        }
231        // Round and clip - safe because we clamp to [0, 255]
232        let result = (sum + 16) >> 5;
233        dst[x] = result.clamp(0, 255) as u8;
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lerp_u8() {
        let ops = blend_ops();

        // No weight: the first operand comes back unchanged.
        assert_eq!(ops.lerp_u8(100, 200, 0), 100);

        // Weight 128 is the midpoint, allowing one step of rounding either way.
        let midpoint = ops.lerp_u8(0, 200, 128);
        assert!((99..=101).contains(&midpoint));

        // The largest u8 weight lands close to the second operand.
        let near_end = ops.lerp_u8(0, 200, 255);
        assert!((198..=200).contains(&near_end));
    }

    #[test]
    fn test_weighted_avg_u8x16() {
        let ops = blend_ops();
        let low = U8x16::splat(100);
        let high = U8x16::splat(200);

        // Even split.
        let even = ops.weighted_avg_u8x16(low, high, 128);
        assert!(even.0.iter().all(|v| (149..=151).contains(v)));

        // Weight 0 reproduces the first input exactly.
        assert_eq!(ops.weighted_avg_u8x16(low, high, 0).0, [100; 16]);

        // Weight 255 is as close as a u8 weight gets to the second input
        // (a full 256 is not representable).
        let mostly_high = ops.weighted_avg_u8x16(low, high, 255);
        assert!(mostly_high.0.iter().all(|v| (199..=200).contains(v)));
    }

    #[test]
    fn test_bilinear_blend() {
        let ops = blend_ops();

        // Four identical corners blend to that same value.
        assert_eq!(ops.bilinear_blend_u8(100, 100, 100, 100, 128, 128), 100);

        // Zero weights pick the top-left corner.
        assert_eq!(ops.bilinear_blend_u8(100, 0, 0, 0, 0, 0), 100);

        // Full horizontal weight leans on the top-right corner.
        let toward_tr = ops.bilinear_blend_u8(0, 100, 0, 0, 255, 0);
        assert!(toward_tr >= 99);

        // Full vertical weight leans on the bottom-left corner.
        let toward_bl = ops.bilinear_blend_u8(0, 0, 100, 0, 0, 255);
        assert!(toward_bl >= 99);
    }

    #[test]
    fn test_lerp_i16x8() {
        let ops = blend_ops();

        let start = I16x8::from_array([0, 10, 20, 30, 40, 50, 60, 70]);
        let end = I16x8::from_array([100, 110, 120, 130, 140, 150, 160, 170]);

        // Weight 128 should land near the per-lane midpoint.
        let halfway = ops.lerp_i16x8(start, end, 128);
        assert!((49..=51).contains(&halfway.0[0]));
    }

    #[test]
    fn test_bilinear_row() {
        let ops = blend_ops();

        let flat = [100u8; 8];
        let mut dst = [0u8; 8];

        ops.bilinear_blend_row_8(&flat, &flat, &flat, &flat, 128, 128, &mut dst);

        assert_eq!(dst, [100u8; 8]);
    }

    #[test]
    fn test_half_pel_filter() {
        // The taps must sum to 32 so that `>> 5` normalizes the output.
        assert_eq!(HALF_PEL_FILTER.iter().sum::<i16>(), 32);
    }

    #[test]
    fn test_apply_half_pel_h() {
        // A flat input row should come out flat.
        let src = [128u8; 16];
        let mut dst = [0u8; 10];

        apply_half_pel_h(&src, &mut dst, 10);

        assert!(dst.iter().all(|v| (127..=129).contains(v)));
    }
}