Skip to main content

oximedia_simd/
lib.rs

1//! Hand-written assembly SIMD kernels for `OxiMedia`
2//!
3//! This crate provides highly optimized assembly implementations of critical
4//! performance paths in the `OxiMedia` video codec, including:
5//! - DCT (Discrete Cosine Transform) in various sizes
6//! - Interpolation kernels (bilinear, bicubic, 8-tap)
7//! - SAD (Sum of Absolute Differences) for motion estimation
8//!
9//! All assembly is wrapped in safe Rust APIs with proper alignment checks,
10//! buffer validation, and runtime CPU feature detection.
11
// NOTE(review): crate-wide allowance — presumably because some helpers are only
// reachable under particular `cfg`s (e.g. the `native-asm` backends); consider
// narrowing it to the affected items.
#![allow(dead_code)]
// Even inside an `unsafe fn`, every unsafe operation must sit in its own
// `unsafe {}` block so it is individually visible and can carry a SAFETY note.
#![deny(unsafe_op_in_unsafe_fn)]
14
15use std::sync::OnceLock;
16
17#[cfg(all(feature = "native-asm", target_arch = "x86_64"))]
18mod x86;
19
20#[cfg(all(feature = "native-asm", target_arch = "aarch64"))]
21mod arm;
22
23mod scalar;
24
25pub mod accumulator;
26pub mod alpha_premul;
27pub mod audio_ops;
28pub mod avx512;
29pub mod bitwise_ops;
30pub mod blend;
31pub mod blend_simd;
32pub mod color_convert_simd;
33pub mod color_space;
34pub mod convolution;
35pub mod deblock_filter;
36pub mod dispatch;
37pub mod dot_product;
38pub mod entropy_coding;
39pub mod filter;
40pub mod fixed_point;
41pub mod gather_scatter;
42pub mod histogram;
43pub mod interleave;
44pub mod lookup_table;
45pub mod math_ops;
46pub mod matrix;
47pub mod min_max;
48pub mod motion_search;
49pub mod neon;
50pub mod pack_unpack;
51pub mod pixel_ops;
52pub mod portable;
53pub mod prefix_sum;
54pub mod psnr;
55pub mod reduce;
56pub mod resize;
57pub mod satd;
58pub mod saturate;
59pub mod simd_bench;
60pub mod ssim;
61pub mod swizzle;
62pub mod threshold;
63pub mod transpose;
64pub mod vector_math;
65pub mod yuv_ops;
66
67/// CPU features detected at runtime.
68///
69/// Use [`CpuFeatures::detect`] to query the current CPU capabilities, or
70/// [`detect_cpu_features`] for a cached version backed by a [`OnceLock`].
71#[derive(Debug, Clone, Copy, PartialEq, Eq)]
72#[allow(clippy::struct_excessive_bools)]
73pub struct CpuFeatures {
74    /// AVX2 (256-bit integer / float SIMD).
75    pub avx2: bool,
76    /// AVX-512 Foundation (512-bit float SIMD).
77    pub avx512f: bool,
78    /// AVX-512 Byte-and-Word (512-bit byte/word integer SIMD).
79    pub avx512bw: bool,
80    /// AVX-512 Vector Length extensions (128/256-bit masked operations).
81    pub avx512vl: bool,
82    /// SSE 4.2 (128-bit SIMD with string/text processing extras).
83    pub sse4_2: bool,
84    /// ARM NEON (128-bit SIMD on aarch64).
85    pub neon: bool,
86}
87
88impl CpuFeatures {
89    /// Detect CPU features on the current machine.
90    ///
91    /// This is a thin wrapper around [`detect_cpu_features`] which caches the
92    /// result in a `OnceLock` so subsequent calls are free.
93    #[must_use]
94    pub fn detect() -> Self {
95        detect_cpu_features()
96    }
97
98    /// Return the widest available SIMD register width in bits.
99    ///
100    /// | CPU capability | Width |
101    /// |---------------|-------|
102    /// | AVX-512F      | 512   |
103    /// | AVX2          | 256   |
104    /// | SSE 4.2       | 128   |
105    /// | scalar        |  64   |
106    #[must_use]
107    pub fn best_simd_width(&self) -> usize {
108        if self.avx512f {
109            512
110        } else if self.avx2 {
111            256
112        } else if self.sse4_2 {
113            128
114        } else {
115            64 // scalar word width
116        }
117    }
118}
119
120/// Returns `true` when the executing CPU provides NEON (always `true` on `aarch64`).
121///
122/// This is a convenience re-export of [`neon::has_neon`].
123#[must_use]
124pub fn has_neon() -> bool {
125    neon::has_neon()
126}
127
128static CPU_FEATURES: OnceLock<CpuFeatures> = OnceLock::new();
129
130/// Detect CPU features at runtime
131pub fn detect_cpu_features() -> CpuFeatures {
132    *CPU_FEATURES.get_or_init(|| {
133        #[cfg(target_arch = "x86_64")]
134        {
135            detect_x86_features()
136        }
137        #[cfg(target_arch = "aarch64")]
138        {
139            detect_arm_features()
140        }
141        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
142        {
143            CpuFeatures {
144                avx2: false,
145                avx512f: false,
146                avx512bw: false,
147                avx512vl: false,
148                sse4_2: false,
149                neon: false,
150            }
151        }
152    })
153}
154
155#[cfg(target_arch = "x86_64")]
156fn detect_x86_features() -> CpuFeatures {
157    CpuFeatures {
158        avx2: is_x86_feature_detected!("avx2"),
159        avx512f: is_x86_feature_detected!("avx512f"),
160        avx512bw: is_x86_feature_detected!("avx512bw"),
161        avx512vl: is_x86_feature_detected!("avx512vl"),
162        sse4_2: is_x86_feature_detected!("sse4.2"),
163        neon: false,
164    }
165}
166
#[cfg(target_arch = "aarch64")]
fn detect_arm_features() -> CpuFeatures {
    // NEON counts as present when the build already targets it; only otherwise
    // is the runtime probe consulted. None of the x86 extensions exist here.
    let neon = cfg!(target_feature = "neon") || std::arch::is_aarch64_feature_detected!("neon");

    CpuFeatures {
        neon,
        sse4_2: false,
        avx512vl: false,
        avx512bw: false,
        avx512f: false,
        avx2: false,
    }
}
178
/// Square DCT block sizes accepted by [`forward_dct`] and [`inverse_dct`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DctSize {
    /// 4×4 block (16 coefficients).
    Dct4x4,
    /// 8×8 block (64 coefficients).
    Dct8x8,
    /// 16×16 block (256 coefficients).
    Dct16x16,
    /// 32×32 block (1024 coefficients).
    Dct32x32,
    /// AV1 large-block 64×64 transform (4096 coefficients).
    Dct64x64,
}
189
/// Sub-pixel interpolation filters accepted by [`interpolate`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InterpolationFilter {
    /// Bilinear interpolation.
    Bilinear,
    /// Bicubic interpolation.
    Bicubic,
    /// Eight-tap interpolation filter.
    EightTap,
}
197
/// Square block sizes accepted by [`sad`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BlockSize {
    /// 16×16 pixels.
    Block16x16,
    /// 32×32 pixels.
    Block32x32,
    /// 64×64 pixels.
    Block64x64,
}
205
/// Errors reported by the SIMD entry points of this crate.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SimdError {
    /// A buffer does not start on the required alignment boundary.
    InvalidAlignment,
    /// A buffer is too small for the requested operation.
    InvalidBufferSize,
    /// The requested operation is not supported.
    UnsupportedOperation,
    /// A CPU feature required by the operation is not available.
    CpuFeatureNotAvailable,
}

impl std::fmt::Display for SimdError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::InvalidAlignment => "Invalid buffer alignment",
            Self::InvalidBufferSize => "Invalid buffer size",
            Self::UnsupportedOperation => "Unsupported operation",
            Self::CpuFeatureNotAvailable => "Required CPU feature not available",
        };
        f.write_str(message)
    }
}

impl std::error::Error for SimdError {}

/// Result type used throughout this crate, with [`SimdError`] as the error.
pub type Result<T> = std::result::Result<T, SimdError>;
229
230/// Perform forward DCT transform
231///
232/// # Safety
233/// - `input` must be properly aligned (32-byte for AVX2)
234/// - Buffer sizes must match the transform size
235/// - No overlapping buffers
236///
237/// # Arguments
238/// - `input`: Input pixel data
239/// - `output`: Output DCT coefficients
240/// - `size`: Transform size
241///
242/// # Returns
243/// - `Ok(())` on success
244/// - `Err(SimdError)` if validation fails
245///
246/// # Errors
247///
248/// Returns an error if buffer sizes don't match the transform size.
249pub fn forward_dct(input: &[i16], output: &mut [i16], size: DctSize) -> Result<()> {
250    let required_size = match size {
251        DctSize::Dct4x4 => 16,
252        DctSize::Dct8x8 => 64,
253        DctSize::Dct16x16 => 256,
254        DctSize::Dct32x32 => 1024,
255        DctSize::Dct64x64 => 4096,
256    };
257
258    if input.len() < required_size || output.len() < required_size {
259        return Err(SimdError::InvalidBufferSize);
260    }
261
262    let _features = detect_cpu_features();
263
264    #[cfg(all(feature = "native-asm", target_arch = "x86_64"))]
265    {
266        if _features.avx2 {
267            return x86::forward_dct_avx2(input, output, size);
268        }
269    }
270
271    #[cfg(all(feature = "native-asm", target_arch = "aarch64"))]
272    {
273        if _features.neon {
274            return arm::forward_dct_neon(input, output, size);
275        }
276    }
277
278    // Fallback to scalar implementation
279    scalar::forward_dct_scalar(input, output, size)
280}
281
282/// Perform inverse DCT transform
283///
284/// # Errors
285///
286/// Returns an error if:
287/// - Buffer alignment is insufficient
288/// - Buffer sizes don't match the transform size
289/// - CPU features validation fails
290pub fn inverse_dct(input: &[i16], output: &mut [i16], size: DctSize) -> Result<()> {
291    let required_size = match size {
292        DctSize::Dct4x4 => 16,
293        DctSize::Dct8x8 => 64,
294        DctSize::Dct16x16 => 256,
295        DctSize::Dct32x32 => 1024,
296        DctSize::Dct64x64 => 4096,
297    };
298
299    if input.len() < required_size || output.len() < required_size {
300        return Err(SimdError::InvalidBufferSize);
301    }
302
303    let _features = detect_cpu_features();
304
305    #[cfg(all(feature = "native-asm", target_arch = "x86_64"))]
306    {
307        if _features.avx2 {
308            return x86::inverse_dct_avx2(input, output, size);
309        }
310    }
311
312    #[cfg(all(feature = "native-asm", target_arch = "aarch64"))]
313    {
314        if _features.neon {
315            return arm::inverse_dct_neon(input, output, size);
316        }
317    }
318
319    // Fallback to scalar implementation
320    scalar::inverse_dct_scalar(input, output, size)
321}
322
323/// Perform interpolation for motion compensation
324///
325/// # Arguments
326/// - `src`: Source image data
327/// - `dst`: Destination buffer
328/// - `src_stride`: Source stride in pixels
329/// - `dst_stride`: Destination stride in pixels
330/// - `width`: Block width
331/// - `height`: Block height
332/// - `dx`: Horizontal fractional position (0-15)
333/// - `dy`: Vertical fractional position (0-15)
334/// - `filter`: Interpolation filter type
335///
336/// # Errors
337///
338/// Returns an error if buffer sizes are invalid
339#[allow(clippy::too_many_arguments)]
340pub fn interpolate(
341    src: &[u8],
342    dst: &mut [u8],
343    src_stride: usize,
344    dst_stride: usize,
345    width: usize,
346    height: usize,
347    dx: i32,
348    dy: i32,
349    filter: InterpolationFilter,
350) -> Result<()> {
351    // Validate buffer sizes
352    if src.len() < (height + 8) * src_stride {
353        return Err(SimdError::InvalidBufferSize);
354    }
355    if dst.len() < height * dst_stride {
356        return Err(SimdError::InvalidBufferSize);
357    }
358
359    let _features = detect_cpu_features();
360
361    #[cfg(all(feature = "native-asm", target_arch = "x86_64"))]
362    {
363        if _features.avx2 {
364            return x86::interpolate_avx2(
365                src, dst, src_stride, dst_stride, width, height, dx, dy, filter,
366            );
367        }
368    }
369
370    #[cfg(all(feature = "native-asm", target_arch = "aarch64"))]
371    {
372        if _features.neon {
373            return arm::interpolate_neon(
374                src, dst, src_stride, dst_stride, width, height, dx, dy, filter,
375            );
376        }
377    }
378
379    // Fallback to scalar implementation
380    scalar::interpolate_scalar(
381        src, dst, src_stride, dst_stride, width, height, dx, dy, filter,
382    )
383}
384
385/// Calculate Sum of Absolute Differences (SAD)
386///
387/// # Arguments
388/// - `src1`: First source block
389/// - `src2`: Second source block
390/// - `stride1`: Stride for src1
391/// - `stride2`: Stride for src2
392/// - `size`: Block size
393///
394/// # Returns
395/// - `Ok(sad_value)` on success
396/// - `Err(SimdError)` if validation fails
397///
398/// # Errors
399///
400/// Returns an error if buffer sizes are invalid
401pub fn sad(
402    src1: &[u8],
403    src2: &[u8],
404    stride1: usize,
405    stride2: usize,
406    size: BlockSize,
407) -> Result<u32> {
408    let (width, height) = match size {
409        BlockSize::Block16x16 => (16, 16),
410        BlockSize::Block32x32 => (32, 32),
411        BlockSize::Block64x64 => (64, 64),
412    };
413
414    if src1.len() < height * stride1 || src2.len() < height * stride2 {
415        return Err(SimdError::InvalidBufferSize);
416    }
417
418    let _features = detect_cpu_features();
419
420    #[cfg(all(feature = "native-asm", target_arch = "x86_64"))]
421    {
422        if _features.avx512bw {
423            return x86::sad_avx512(src1, src2, stride1, stride2, size);
424        }
425        if _features.avx2 {
426            return x86::sad_avx2(src1, src2, stride1, stride2, size);
427        }
428    }
429
430    #[cfg(all(feature = "native-asm", target_arch = "aarch64"))]
431    {
432        if _features.neon {
433            return arm::sad_neon(src1, src2, stride1, stride2, size);
434        }
435    }
436
437    // Fallback to scalar implementation
438    scalar::sad_scalar(src1, src2, stride1, stride2, width, height)
439}
440
/// Return `true` when the address of `ptr` is a multiple of `alignment`.
///
/// `alignment` need not be a power of two. An `alignment` of `0` matches only
/// the null address.
#[inline]
#[must_use]
pub fn is_aligned(ptr: *const u8, alignment: usize) -> bool {
    let address = ptr.addr();
    address.is_multiple_of(alignment)
}
447
448/// Validate buffer alignment for AVX2 (32-byte alignment)
449///
450/// # Errors
451///
452/// Returns an error if buffer is not 32-byte aligned
453pub fn validate_avx2_alignment(buffer: &[u8]) -> Result<()> {
454    if !is_aligned(buffer.as_ptr(), 32) {
455        return Err(SimdError::InvalidAlignment);
456    }
457    Ok(())
458}
459
460/// Validate buffer alignment for AVX-512 (64-byte alignment)
461///
462/// # Errors
463///
464/// Returns an error if buffer is not 64-byte aligned
465pub fn validate_avx512_alignment(buffer: &[u8]) -> Result<()> {
466    if !is_aligned(buffer.as_ptr(), 64) {
467        return Err(SimdError::InvalidAlignment);
468    }
469    Ok(())
470}
471
472/// Validate buffer alignment for NEON (16-byte alignment)
473///
474/// # Errors
475///
476/// Returns an error if buffer is not 16-byte aligned
477pub fn validate_neon_alignment(buffer: &[u8]) -> Result<()> {
478    if !is_aligned(buffer.as_ptr(), 16) {
479        return Err(SimdError::InvalidAlignment);
480    }
481    Ok(())
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Byte buffer with a guaranteed 64-byte alignment. A bare `[u8; N]` local
    /// only guarantees 1-byte alignment, so alignment assertions on one are flaky.
    #[repr(C, align(64))]
    struct Aligned64([u8; 128]);

    #[test]
    fn test_cpu_feature_detection() {
        let features = detect_cpu_features();
        // Detection is cached, so every query must report the same capabilities.
        assert_eq!(features, detect_cpu_features());
        assert_eq!(features, CpuFeatures::detect());
        println!("Detected CPU features: {features:?}");
    }

    #[test]
    fn test_alignment_check() {
        let buffer = Aligned64([0u8; 128]);
        assert!(is_aligned(buffer.0.as_ptr(), 8));
        assert!(is_aligned(buffer.0.as_ptr(), 64));
        // One byte past a 64-byte boundary is not even 2-byte aligned.
        assert!(!is_aligned(buffer.0[1..].as_ptr(), 2));

        assert_eq!(validate_avx512_alignment(&buffer.0), Ok(()));
        assert_eq!(validate_avx2_alignment(&buffer.0[32..]), Ok(()));
        assert_eq!(validate_neon_alignment(&buffer.0[16..]), Ok(()));
        assert_eq!(
            validate_avx2_alignment(&buffer.0[16..]),
            Err(SimdError::InvalidAlignment)
        );
    }

    #[test]
    fn test_dct_sizes() {
        // Buffers one element short of a block must be rejected before any kernel runs.
        let input = [0i16; 4096];
        let mut output = [0i16; 4096];
        let cases = [
            (DctSize::Dct4x4, 16),
            (DctSize::Dct8x8, 64),
            (DctSize::Dct16x16, 256),
            (DctSize::Dct32x32, 1024),
            (DctSize::Dct64x64, 4096),
        ];
        for (size, count) in cases {
            assert_eq!(
                forward_dct(&input[..count - 1], &mut output, size),
                Err(SimdError::InvalidBufferSize)
            );
            assert_eq!(
                inverse_dct(&input, &mut output[..count - 1], size),
                Err(SimdError::InvalidBufferSize)
            );
        }
    }

    #[test]
    fn test_short_buffers_rejected() {
        let block = [0u8; 256];
        let mut dst = [0u8; 256];
        assert_eq!(
            sad(&block[..255], &block, 16, 16, BlockSize::Block16x16),
            Err(SimdError::InvalidBufferSize)
        );
        // Interpolation needs (height + 8) * src_stride = 24 * 16 = 384 source bytes.
        assert_eq!(
            interpolate(&block, &mut dst, 16, 16, 16, 16, 0, 0, InterpolationFilter::Bilinear),
            Err(SimdError::InvalidBufferSize)
        );
    }
}