// oximedia_codec/simd/types.rs

//! Common SIMD vector types.
//!
//! This module defines the vector types used in video codec
//! implementations. Each type is a thin wrapper around a fixed-size array,
//! so code written against it abstracts over the underlying SIMD
//! implementation (scalar fallback, SSE, AVX, NEON, etc.).
//!
//! # Naming Convention
//!
//! Types follow the pattern `{ElementType}x{lane_count}`:
//! - `I16x8` - 8 lanes of `i16` (128-bit)
//! - `I32x4` - 4 lanes of `i32` (128-bit)
//! - `U8x16` - 16 lanes of `u8` (128-bit)

use std::ops::{Add, Index, IndexMut, Mul, Sub};

/// 8-lane vector of 16-bit signed integers (128-bit).
///
/// The lanes live in the public array field; lane 0 is element 0.
///
/// Common uses:
/// - DCT coefficients
/// - Pixel differences
/// - Filter taps
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct I16x8(pub [i16; 8]);

/// 16-lane vector of 16-bit signed integers (256-bit).
///
/// The lanes live in the public array field; lane 0 is element 0.
///
/// Common uses:
/// - Wide DCT operations
/// - Parallel coefficient processing
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct I16x16(pub [i16; 16]);

/// 4-lane vector of 32-bit signed integers (128-bit).
///
/// The lanes live in the public array field; lane 0 is element 0.
///
/// Common uses:
/// - Accumulated DCT results
/// - Intermediate filter calculations
/// - SAD accumulation
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct I32x4(pub [i32; 4]);

/// 8-lane vector of 32-bit signed integers (256-bit).
///
/// The lanes live in the public array field; lane 0 is element 0.
///
/// Common uses:
/// - Wide accumulation
/// - 8-point parallel operations
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct I32x8(pub [i32; 8]);

/// 16-lane vector of 8-bit unsigned integers (128-bit).
///
/// The lanes live in the public array field; lane 0 is element 0.
///
/// Common uses:
/// - Raw pixel data
/// - SAD calculations
/// - Luma/chroma samples
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct U8x16(pub [u8; 16]);

/// 32-lane vector of 8-bit unsigned integers (256-bit).
///
/// The lanes live in the public array field; lane 0 is element 0.
///
/// Common uses:
/// - Wide pixel operations
/// - AVX-width processing
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct U8x32(pub [u8; 32]);

67// ============================================================================
68// I16x8 Implementation
69// ============================================================================
70
71impl I16x8 {
72    /// Create a new vector with all lanes set to zero.
73    #[inline]
74    #[must_use]
75    pub const fn zero() -> Self {
76        Self([0; 8])
77    }
78
79    /// Create a new vector with all lanes set to the same value.
80    #[inline]
81    #[must_use]
82    pub const fn splat(value: i16) -> Self {
83        Self([value; 8])
84    }
85
86    /// Create a vector from an array.
87    #[inline]
88    #[must_use]
89    pub const fn from_array(arr: [i16; 8]) -> Self {
90        Self(arr)
91    }
92
93    /// Convert to an array.
94    #[inline]
95    #[must_use]
96    pub const fn to_array(self) -> [i16; 8] {
97        self.0
98    }
99
100    /// Get element at index.
101    #[inline]
102    #[must_use]
103    pub fn get(&self, index: usize) -> i16 {
104        self.0[index]
105    }
106
107    /// Set element at index.
108    #[inline]
109    pub fn set(&mut self, index: usize, value: i16) {
110        self.0[index] = value;
111    }
112
113    /// Widen to I32x4 (low half).
114    #[inline]
115    #[must_use]
116    pub fn widen_low(self) -> I32x4 {
117        I32x4([
118            i32::from(self.0[0]),
119            i32::from(self.0[1]),
120            i32::from(self.0[2]),
121            i32::from(self.0[3]),
122        ])
123    }
124
125    /// Widen to I32x4 (high half).
126    #[inline]
127    #[must_use]
128    pub fn widen_high(self) -> I32x4 {
129        I32x4([
130            i32::from(self.0[4]),
131            i32::from(self.0[5]),
132            i32::from(self.0[6]),
133            i32::from(self.0[7]),
134        ])
135    }
136
137    /// Get a pointer to the underlying array.
138    #[inline]
139    #[must_use]
140    pub const fn as_ptr(&self) -> *const i16 {
141        self.0.as_ptr()
142    }
143
144    /// Get a mutable pointer to the underlying array.
145    #[inline]
146    #[must_use]
147    pub fn as_mut_ptr(&mut self) -> *mut i16 {
148        self.0.as_mut_ptr()
149    }
150
151    /// Get an iterator over the elements.
152    #[inline]
153    pub fn iter(&self) -> std::slice::Iter<'_, i16> {
154        self.0.iter()
155    }
156
157    /// Copy elements from a slice.
158    #[inline]
159    pub fn copy_from_slice(&mut self, src: &[i16]) {
160        self.0.copy_from_slice(src);
161    }
162}
163
164impl Add for I16x8 {
165    type Output = Self;
166
167    #[inline]
168    fn add(self, rhs: Self) -> Self::Output {
169        Self([
170            self.0[0].wrapping_add(rhs.0[0]),
171            self.0[1].wrapping_add(rhs.0[1]),
172            self.0[2].wrapping_add(rhs.0[2]),
173            self.0[3].wrapping_add(rhs.0[3]),
174            self.0[4].wrapping_add(rhs.0[4]),
175            self.0[5].wrapping_add(rhs.0[5]),
176            self.0[6].wrapping_add(rhs.0[6]),
177            self.0[7].wrapping_add(rhs.0[7]),
178        ])
179    }
180}
181
182impl Sub for I16x8 {
183    type Output = Self;
184
185    #[inline]
186    fn sub(self, rhs: Self) -> Self::Output {
187        Self([
188            self.0[0].wrapping_sub(rhs.0[0]),
189            self.0[1].wrapping_sub(rhs.0[1]),
190            self.0[2].wrapping_sub(rhs.0[2]),
191            self.0[3].wrapping_sub(rhs.0[3]),
192            self.0[4].wrapping_sub(rhs.0[4]),
193            self.0[5].wrapping_sub(rhs.0[5]),
194            self.0[6].wrapping_sub(rhs.0[6]),
195            self.0[7].wrapping_sub(rhs.0[7]),
196        ])
197    }
198}
199
200impl Mul for I16x8 {
201    type Output = Self;
202
203    #[inline]
204    fn mul(self, rhs: Self) -> Self::Output {
205        Self([
206            self.0[0].wrapping_mul(rhs.0[0]),
207            self.0[1].wrapping_mul(rhs.0[1]),
208            self.0[2].wrapping_mul(rhs.0[2]),
209            self.0[3].wrapping_mul(rhs.0[3]),
210            self.0[4].wrapping_mul(rhs.0[4]),
211            self.0[5].wrapping_mul(rhs.0[5]),
212            self.0[6].wrapping_mul(rhs.0[6]),
213            self.0[7].wrapping_mul(rhs.0[7]),
214        ])
215    }
216}
217
218impl Index<usize> for I16x8 {
219    type Output = i16;
220
221    #[inline]
222    fn index(&self, index: usize) -> &Self::Output {
223        &self.0[index]
224    }
225}
226
227impl IndexMut<usize> for I16x8 {
228    #[inline]
229    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
230        &mut self.0[index]
231    }
232}
233
234// ============================================================================
235// I16x16 Implementation
236// ============================================================================
237
238impl I16x16 {
239    /// Create a new vector with all lanes set to zero.
240    #[inline]
241    #[must_use]
242    pub const fn zero() -> Self {
243        Self([0; 16])
244    }
245
246    /// Create a new vector with all lanes set to the same value.
247    #[inline]
248    #[must_use]
249    pub const fn splat(value: i16) -> Self {
250        Self([value; 16])
251    }
252
253    /// Create a vector from an array.
254    #[inline]
255    #[must_use]
256    pub const fn from_array(arr: [i16; 16]) -> Self {
257        Self(arr)
258    }
259
260    /// Convert to an array.
261    #[inline]
262    #[must_use]
263    pub const fn to_array(self) -> [i16; 16] {
264        self.0
265    }
266}
267
268// ============================================================================
269// I32x4 Implementation
270// ============================================================================
271
272impl I32x4 {
273    /// Create a new vector with all lanes set to zero.
274    #[inline]
275    #[must_use]
276    pub const fn zero() -> Self {
277        Self([0; 4])
278    }
279
280    /// Create a new vector with all lanes set to the same value.
281    #[inline]
282    #[must_use]
283    pub const fn splat(value: i32) -> Self {
284        Self([value; 4])
285    }
286
287    /// Create a vector from an array.
288    #[inline]
289    #[must_use]
290    pub const fn from_array(arr: [i32; 4]) -> Self {
291        Self(arr)
292    }
293
294    /// Convert to an array.
295    #[inline]
296    #[must_use]
297    pub const fn to_array(self) -> [i32; 4] {
298        self.0
299    }
300
301    /// Horizontal sum of all elements.
302    #[inline]
303    #[must_use]
304    pub fn horizontal_sum(self) -> i32 {
305        self.0[0]
306            .wrapping_add(self.0[1])
307            .wrapping_add(self.0[2])
308            .wrapping_add(self.0[3])
309    }
310
311    /// Narrow to I16x8 with another I32x4 (saturating).
312    #[inline]
313    #[must_use]
314    #[allow(clippy::cast_possible_truncation)]
315    pub fn narrow_sat(self, high: Self) -> I16x8 {
316        let saturate = |v: i32| -> i16 { v.clamp(i32::from(i16::MIN), i32::from(i16::MAX)) as i16 };
317        I16x8([
318            saturate(self.0[0]),
319            saturate(self.0[1]),
320            saturate(self.0[2]),
321            saturate(self.0[3]),
322            saturate(high.0[0]),
323            saturate(high.0[1]),
324            saturate(high.0[2]),
325            saturate(high.0[3]),
326        ])
327    }
328
329    /// Get a pointer to the underlying array.
330    #[inline]
331    #[must_use]
332    pub const fn as_ptr(&self) -> *const i32 {
333        self.0.as_ptr()
334    }
335
336    /// Get a mutable pointer to the underlying array.
337    #[inline]
338    #[must_use]
339    pub fn as_mut_ptr(&mut self) -> *mut i32 {
340        self.0.as_mut_ptr()
341    }
342
343    /// Get an iterator over the elements.
344    #[inline]
345    pub fn iter(&self) -> std::slice::Iter<'_, i32> {
346        self.0.iter()
347    }
348}
349
350impl Add for I32x4 {
351    type Output = Self;
352
353    #[inline]
354    fn add(self, rhs: Self) -> Self::Output {
355        Self([
356            self.0[0].wrapping_add(rhs.0[0]),
357            self.0[1].wrapping_add(rhs.0[1]),
358            self.0[2].wrapping_add(rhs.0[2]),
359            self.0[3].wrapping_add(rhs.0[3]),
360        ])
361    }
362}
363
364impl Sub for I32x4 {
365    type Output = Self;
366
367    #[inline]
368    fn sub(self, rhs: Self) -> Self::Output {
369        Self([
370            self.0[0].wrapping_sub(rhs.0[0]),
371            self.0[1].wrapping_sub(rhs.0[1]),
372            self.0[2].wrapping_sub(rhs.0[2]),
373            self.0[3].wrapping_sub(rhs.0[3]),
374        ])
375    }
376}
377
378impl Index<usize> for I32x4 {
379    type Output = i32;
380
381    #[inline]
382    fn index(&self, index: usize) -> &Self::Output {
383        &self.0[index]
384    }
385}
386
387impl IndexMut<usize> for I32x4 {
388    #[inline]
389    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
390        &mut self.0[index]
391    }
392}
393
394// ============================================================================
395// I32x8 Implementation
396// ============================================================================
397
398impl I32x8 {
399    /// Create a new vector with all lanes set to zero.
400    #[inline]
401    #[must_use]
402    pub const fn zero() -> Self {
403        Self([0; 8])
404    }
405
406    /// Create a new vector with all lanes set to the same value.
407    #[inline]
408    #[must_use]
409    pub const fn splat(value: i32) -> Self {
410        Self([value; 8])
411    }
412
413    /// Create a vector from an array.
414    #[inline]
415    #[must_use]
416    pub const fn from_array(arr: [i32; 8]) -> Self {
417        Self(arr)
418    }
419
420    /// Convert to an array.
421    #[inline]
422    #[must_use]
423    pub const fn to_array(self) -> [i32; 8] {
424        self.0
425    }
426
427    /// Horizontal sum of all elements.
428    #[inline]
429    #[must_use]
430    pub fn horizontal_sum(self) -> i32 {
431        self.0.iter().fold(0i32, |acc, &x| acc.wrapping_add(x))
432    }
433}
434
435// ============================================================================
436// U8x16 Implementation
437// ============================================================================
438
439impl U8x16 {
440    /// Create a new vector with all lanes set to zero.
441    #[inline]
442    #[must_use]
443    pub const fn zero() -> Self {
444        Self([0; 16])
445    }
446
447    /// Create a new vector with all lanes set to the same value.
448    #[inline]
449    #[must_use]
450    pub const fn splat(value: u8) -> Self {
451        Self([value; 16])
452    }
453
454    /// Create a vector from an array.
455    #[inline]
456    #[must_use]
457    pub const fn from_array(arr: [u8; 16]) -> Self {
458        Self(arr)
459    }
460
461    /// Convert to an array.
462    #[inline]
463    #[must_use]
464    pub const fn to_array(self) -> [u8; 16] {
465        self.0
466    }
467
468    /// Get element at index.
469    #[inline]
470    #[must_use]
471    pub fn get(&self, index: usize) -> u8 {
472        self.0[index]
473    }
474
475    /// Set element at index.
476    #[inline]
477    pub fn set(&mut self, index: usize, value: u8) {
478        self.0[index] = value;
479    }
480
481    /// Widen low 8 bytes to I16x8.
482    #[inline]
483    #[must_use]
484    pub fn widen_low_i16(self) -> I16x8 {
485        I16x8([
486            i16::from(self.0[0]),
487            i16::from(self.0[1]),
488            i16::from(self.0[2]),
489            i16::from(self.0[3]),
490            i16::from(self.0[4]),
491            i16::from(self.0[5]),
492            i16::from(self.0[6]),
493            i16::from(self.0[7]),
494        ])
495    }
496
497    /// Widen high 8 bytes to I16x8.
498    #[inline]
499    #[must_use]
500    pub fn widen_high_i16(self) -> I16x8 {
501        I16x8([
502            i16::from(self.0[8]),
503            i16::from(self.0[9]),
504            i16::from(self.0[10]),
505            i16::from(self.0[11]),
506            i16::from(self.0[12]),
507            i16::from(self.0[13]),
508            i16::from(self.0[14]),
509            i16::from(self.0[15]),
510        ])
511    }
512
513    /// Get a pointer to the underlying array.
514    #[inline]
515    #[must_use]
516    pub const fn as_ptr(&self) -> *const u8 {
517        self.0.as_ptr()
518    }
519
520    /// Get a mutable pointer to the underlying array.
521    #[inline]
522    #[must_use]
523    pub fn as_mut_ptr(&mut self) -> *mut u8 {
524        self.0.as_mut_ptr()
525    }
526
527    /// Get an iterator over the elements.
528    #[inline]
529    pub fn iter(&self) -> std::slice::Iter<'_, u8> {
530        self.0.iter()
531    }
532
533    /// Copy elements from a slice.
534    #[inline]
535    pub fn copy_from_slice(&mut self, src: &[u8]) {
536        self.0.copy_from_slice(src);
537    }
538}
539
540impl Index<usize> for U8x16 {
541    type Output = u8;
542
543    #[inline]
544    fn index(&self, index: usize) -> &Self::Output {
545        &self.0[index]
546    }
547}
548
549impl IndexMut<usize> for U8x16 {
550    #[inline]
551    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
552        &mut self.0[index]
553    }
554}
555
556// ============================================================================
557// U8x32 Implementation
558// ============================================================================
559
560impl U8x32 {
561    /// Create a new vector with all lanes set to zero.
562    #[inline]
563    #[must_use]
564    pub const fn zero() -> Self {
565        Self([0; 32])
566    }
567
568    /// Create a new vector with all lanes set to the same value.
569    #[inline]
570    #[must_use]
571    pub const fn splat(value: u8) -> Self {
572        Self([value; 32])
573    }
574
575    /// Create a vector from an array.
576    #[inline]
577    #[must_use]
578    pub const fn from_array(arr: [u8; 32]) -> Self {
579        Self(arr)
580    }
581
582    /// Convert to an array.
583    #[inline]
584    #[must_use]
585    pub const fn to_array(self) -> [u8; 32] {
586        self.0
587    }
588
589    /// Split into two U8x16 vectors.
590    #[inline]
591    #[must_use]
592    pub fn split(self) -> (U8x16, U8x16) {
593        let mut low = [0u8; 16];
594        let mut high = [0u8; 16];
595        low.copy_from_slice(&self.0[0..16]);
596        high.copy_from_slice(&self.0[16..32]);
597        (U8x16(low), U8x16(high))
598    }
599
600    /// Get a pointer to the underlying array.
601    #[inline]
602    #[must_use]
603    pub const fn as_ptr(&self) -> *const u8 {
604        self.0.as_ptr()
605    }
606}
607
#[cfg(test)]
mod tests {
    use super::*;

    /// Lane-wise add and subtract on splatted vectors.
    #[test]
    fn test_i16x8_basic() {
        let a = I16x8::splat(10);
        let b = I16x8::splat(5);
        let sum = a + b;
        assert_eq!(sum.0, [15; 8]);

        let diff = a - b;
        assert_eq!(diff.0, [5; 8]);
    }

    /// Lane-wise multiply on splatted vectors.
    #[test]
    fn test_i16x8_mul() {
        let product = I16x8::splat(10) * I16x8::splat(5);
        assert_eq!(product.0, [50; 8]);
    }

    /// Both halves of an `I16x8` widen to the matching `I32x4` lanes.
    #[test]
    fn test_i16x8_widen() {
        let v = I16x8::from_array([1, 2, 3, 4, 5, 6, 7, 8]);
        let low = v.widen_low();
        let high = v.widen_high();
        assert_eq!(low.0, [1, 2, 3, 4]);
        assert_eq!(high.0, [5, 6, 7, 8]);
    }

    /// The horizontal sum adds every lane.
    #[test]
    fn test_i32x4_horizontal_sum() {
        let v = I32x4::from_array([1, 2, 3, 4]);
        assert_eq!(v.horizontal_sum(), 10);
    }

    /// Narrowing keeps in-range values and clamps out-of-range ones.
    #[test]
    fn test_i32x4_narrow_sat() {
        let low = I32x4::from_array([100, -100, 32767, -32768]);
        let high = I32x4::from_array([40000, -40000, 0, 1]);
        let result = low.narrow_sat(high);
        assert_eq!(result.0, [100, -100, 32767, -32768, 32767, -32768, 0, 1]);
    }

    /// Both halves of a `U8x16` widen to the matching `I16x8` lanes.
    #[test]
    fn test_u8x16_widen() {
        let v = U8x16::from_array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
        let low = v.widen_low_i16();
        let high = v.widen_high_i16();
        assert_eq!(low.0, [0, 1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(high.0, [8, 9, 10, 11, 12, 13, 14, 15]);
    }

    /// Splitting a `U8x32` yields lanes 0..16 and 16..32 in order.
    #[test]
    fn test_u8x32_split() {
        let mut arr = [0u8; 32];
        // Count in `u8` directly so no narrowing cast is needed.
        for (value, elem) in (0u8..).zip(arr.iter_mut()) {
            *elem = value;
        }
        let v = U8x32::from_array(arr);
        let (low, high) = v.split();
        assert_eq!(low.0[0], 0);
        assert_eq!(low.0[15], 15);
        assert_eq!(high.0[0], 16);
        assert_eq!(high.0[15], 31);
    }
}