Skip to main content

oxiblas_core/scalar/
batch.rs

1//! Batch operations, SIMD compatibility, classification, and summation algorithms.
2
3use num_complex::{Complex32, Complex64};
4
5use super::traits::Scalar;
6
7#[cfg(feature = "f16")]
8use half::f16;
9
10#[cfg(feature = "f128")]
11use super::extended::QuadFloat;
12
13// =============================================================================
14// Scalar trait specialization for performance
15// =============================================================================
16
/// Marker trait for types with hardware FMA (fused multiply-add) support.
///
/// Types implementing this trait have efficient hardware FMA instructions,
/// enabling optimized implementations of algorithms like dot products and
/// matrix multiplications.
///
/// NOTE(review): whether FMA is actually emitted depends on the build target
/// (e.g. FMA3 on x86_64, NEON on aarch64) — confirm target features in CI.
pub trait HasFastFma: Scalar {}

// Real floats use `mul_add` directly; complex types are composed of real
// component FMAs, so they inherit the same hardware path.
impl HasFastFma for f32 {}
impl HasFastFma for f64 {}
impl HasFastFma for Complex32 {}
impl HasFastFma for Complex64 {}
28
29/// Marker trait for types that can be efficiently vectorized with SIMD.
30///
31/// This trait indicates that the type has a natural mapping to SIMD registers
32/// and operations.
33pub trait SimdCompatible: Scalar {
34    /// The preferred SIMD width (number of elements) for this type.
35    const SIMD_WIDTH: usize;
36
37    /// Returns true if SIMD operations are beneficial for the given length.
38    #[inline]
39    fn use_simd_for(len: usize) -> bool {
40        len >= Self::SIMD_WIDTH * 2
41    }
42}
43
// Widths below are elements per vector register, assuming AVX2 (256-bit) on
// x86_64 and NEON (128-bit) on aarch64. NOTE(review): AVX-512 targets would
// support wider vectors — confirm whether the SIMD backend exploits them.
impl SimdCompatible for f32 {
    #[cfg(target_arch = "x86_64")]
    const SIMD_WIDTH: usize = 8; // AVX2: 256-bit / 32-bit = 8

    #[cfg(target_arch = "aarch64")]
    const SIMD_WIDTH: usize = 4; // NEON: 128-bit / 32-bit = 4

    // Conservative fallback for other architectures.
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    const SIMD_WIDTH: usize = 4;
}

impl SimdCompatible for f64 {
    #[cfg(target_arch = "x86_64")]
    const SIMD_WIDTH: usize = 4; // AVX2: 256-bit / 64-bit = 4

    #[cfg(target_arch = "aarch64")]
    const SIMD_WIDTH: usize = 2; // NEON: 128-bit / 64-bit = 2

    // Conservative fallback for other architectures.
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    const SIMD_WIDTH: usize = 2;
}

impl SimdCompatible for Complex32 {
    // Complex types have half the SIMD width due to doubled storage
    // (one complex element = two f32 lanes).
    #[cfg(target_arch = "x86_64")]
    const SIMD_WIDTH: usize = 4;

    #[cfg(target_arch = "aarch64")]
    const SIMD_WIDTH: usize = 2;

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    const SIMD_WIDTH: usize = 2;
}

impl SimdCompatible for Complex64 {
    // One complex element = two f64 lanes; width 1 means a single element
    // fills a 128-bit register.
    #[cfg(target_arch = "x86_64")]
    const SIMD_WIDTH: usize = 2;

    #[cfg(target_arch = "aarch64")]
    const SIMD_WIDTH: usize = 1;

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    const SIMD_WIDTH: usize = 1;
}
88
/// Batch operations on scalar arrays for performance-critical code.
///
/// This trait provides optimized implementations of common operations on
/// contiguous arrays of scalars, leveraging SIMD where available.
pub trait ScalarBatch: Scalar + SimdCompatible {
    /// Computes the dot product of two slices.
    ///
    /// Both slices must have the same length. Implementations in this module
    /// check this with `debug_assert!` only, so release builds rely on the
    /// caller upholding the precondition. (The function is safe; a mismatch
    /// is a logic error, not UB.)
    fn dot_batch(x: &[Self], y: &[Self]) -> Self;

    /// Computes the sum of all elements.
    fn sum_batch(x: &[Self]) -> Self;

    /// Computes the sum of absolute values (L1 norm).
    ///
    /// The complex implementations in this module use the BLAS-style
    /// `|re| + |im|` per element, not the modulus.
    fn asum_batch(x: &[Self]) -> Self::Real;

    /// Finds the index of the element with maximum absolute value.
    ///
    /// Implementations in this module return 0 for an empty slice.
    fn iamax_batch(x: &[Self]) -> usize;

    /// Scales a vector in place: `x = alpha * x`.
    fn scale_batch(alpha: Self, x: &mut [Self]);

    /// AXPY operation: `y = alpha * x + y`. Slices must have equal length.
    fn axpy_batch(alpha: Self, x: &[Self], y: &mut [Self]);

    /// Fused multiply-add on arrays: `z[i] = a[i] * b[i] + c[i]`.
    /// All four slices must have the same length.
    fn fma_batch(a: &[Self], b: &[Self], c: &[Self], out: &mut [Self]);
}
118
119impl ScalarBatch for f32 {
120    #[inline]
121    fn dot_batch(x: &[Self], y: &[Self]) -> Self {
122        debug_assert_eq!(x.len(), y.len());
123        let mut sum = 0.0f32;
124        for i in 0..x.len() {
125            sum = x[i].mul_add(y[i], sum);
126        }
127        sum
128    }
129
130    #[inline]
131    fn sum_batch(x: &[Self]) -> Self {
132        x.iter().copied().sum()
133    }
134
135    #[inline]
136    fn asum_batch(x: &[Self]) -> Self::Real {
137        x.iter().map(|&v| v.abs()).sum()
138    }
139
140    #[inline]
141    fn iamax_batch(x: &[Self]) -> usize {
142        x.iter()
143            .enumerate()
144            .max_by(|(_, a), (_, b)| {
145                a.abs()
146                    .partial_cmp(&b.abs())
147                    .unwrap_or(core::cmp::Ordering::Equal)
148            })
149            .map(|(i, _)| i)
150            .unwrap_or(0)
151    }
152
153    #[inline]
154    fn scale_batch(alpha: Self, x: &mut [Self]) {
155        for xi in x.iter_mut() {
156            *xi *= alpha;
157        }
158    }
159
160    #[inline]
161    fn axpy_batch(alpha: Self, x: &[Self], y: &mut [Self]) {
162        debug_assert_eq!(x.len(), y.len());
163        for i in 0..x.len() {
164            y[i] = alpha.mul_add(x[i], y[i]);
165        }
166    }
167
168    #[inline]
169    fn fma_batch(a: &[Self], b: &[Self], c: &[Self], out: &mut [Self]) {
170        debug_assert_eq!(a.len(), b.len());
171        debug_assert_eq!(a.len(), c.len());
172        debug_assert_eq!(a.len(), out.len());
173        for i in 0..a.len() {
174            out[i] = a[i].mul_add(b[i], c[i]);
175        }
176    }
177}
178
179impl ScalarBatch for f64 {
180    #[inline]
181    fn dot_batch(x: &[Self], y: &[Self]) -> Self {
182        debug_assert_eq!(x.len(), y.len());
183        let mut sum = 0.0f64;
184        for i in 0..x.len() {
185            sum = x[i].mul_add(y[i], sum);
186        }
187        sum
188    }
189
190    #[inline]
191    fn sum_batch(x: &[Self]) -> Self {
192        x.iter().copied().sum()
193    }
194
195    #[inline]
196    fn asum_batch(x: &[Self]) -> Self::Real {
197        x.iter().map(|&v| v.abs()).sum()
198    }
199
200    #[inline]
201    fn iamax_batch(x: &[Self]) -> usize {
202        x.iter()
203            .enumerate()
204            .max_by(|(_, a), (_, b)| {
205                a.abs()
206                    .partial_cmp(&b.abs())
207                    .unwrap_or(core::cmp::Ordering::Equal)
208            })
209            .map(|(i, _)| i)
210            .unwrap_or(0)
211    }
212
213    #[inline]
214    fn scale_batch(alpha: Self, x: &mut [Self]) {
215        for xi in x.iter_mut() {
216            *xi *= alpha;
217        }
218    }
219
220    #[inline]
221    fn axpy_batch(alpha: Self, x: &[Self], y: &mut [Self]) {
222        debug_assert_eq!(x.len(), y.len());
223        for i in 0..x.len() {
224            y[i] = alpha.mul_add(x[i], y[i]);
225        }
226    }
227
228    #[inline]
229    fn fma_batch(a: &[Self], b: &[Self], c: &[Self], out: &mut [Self]) {
230        debug_assert_eq!(a.len(), b.len());
231        debug_assert_eq!(a.len(), c.len());
232        debug_assert_eq!(a.len(), out.len());
233        for i in 0..a.len() {
234            out[i] = a[i].mul_add(b[i], c[i]);
235        }
236    }
237}
238
239impl ScalarBatch for Complex32 {
240    #[inline]
241    fn dot_batch(x: &[Self], y: &[Self]) -> Self {
242        debug_assert_eq!(x.len(), y.len());
243        let mut sum = Complex32::new(0.0, 0.0);
244        for i in 0..x.len() {
245            sum += x[i] * y[i];
246        }
247        sum
248    }
249
250    #[inline]
251    fn sum_batch(x: &[Self]) -> Self {
252        x.iter().copied().sum()
253    }
254
255    #[inline]
256    fn asum_batch(x: &[Self]) -> Self::Real {
257        x.iter().map(|z| z.re.abs() + z.im.abs()).sum()
258    }
259
260    #[inline]
261    fn iamax_batch(x: &[Self]) -> usize {
262        x.iter()
263            .enumerate()
264            .max_by(|(_, a), (_, b)| {
265                (a.re.abs() + a.im.abs())
266                    .partial_cmp(&(b.re.abs() + b.im.abs()))
267                    .unwrap_or(core::cmp::Ordering::Equal)
268            })
269            .map(|(i, _)| i)
270            .unwrap_or(0)
271    }
272
273    #[inline]
274    fn scale_batch(alpha: Self, x: &mut [Self]) {
275        for xi in x.iter_mut() {
276            *xi *= alpha;
277        }
278    }
279
280    #[inline]
281    fn axpy_batch(alpha: Self, x: &[Self], y: &mut [Self]) {
282        debug_assert_eq!(x.len(), y.len());
283        for i in 0..x.len() {
284            y[i] += alpha * x[i];
285        }
286    }
287
288    #[inline]
289    fn fma_batch(a: &[Self], b: &[Self], c: &[Self], out: &mut [Self]) {
290        debug_assert_eq!(a.len(), b.len());
291        debug_assert_eq!(a.len(), c.len());
292        debug_assert_eq!(a.len(), out.len());
293        for i in 0..a.len() {
294            out[i] = a[i] * b[i] + c[i];
295        }
296    }
297}
298
299impl ScalarBatch for Complex64 {
300    #[inline]
301    fn dot_batch(x: &[Self], y: &[Self]) -> Self {
302        debug_assert_eq!(x.len(), y.len());
303        let mut sum = Complex64::new(0.0, 0.0);
304        for i in 0..x.len() {
305            sum += x[i] * y[i];
306        }
307        sum
308    }
309
310    #[inline]
311    fn sum_batch(x: &[Self]) -> Self {
312        x.iter().copied().sum()
313    }
314
315    #[inline]
316    fn asum_batch(x: &[Self]) -> Self::Real {
317        x.iter().map(|z| z.re.abs() + z.im.abs()).sum()
318    }
319
320    #[inline]
321    fn iamax_batch(x: &[Self]) -> usize {
322        x.iter()
323            .enumerate()
324            .max_by(|(_, a), (_, b)| {
325                (a.re.abs() + a.im.abs())
326                    .partial_cmp(&(b.re.abs() + b.im.abs()))
327                    .unwrap_or(core::cmp::Ordering::Equal)
328            })
329            .map(|(i, _)| i)
330            .unwrap_or(0)
331    }
332
333    #[inline]
334    fn scale_batch(alpha: Self, x: &mut [Self]) {
335        for xi in x.iter_mut() {
336            *xi *= alpha;
337        }
338    }
339
340    #[inline]
341    fn axpy_batch(alpha: Self, x: &[Self], y: &mut [Self]) {
342        debug_assert_eq!(x.len(), y.len());
343        for i in 0..x.len() {
344            y[i] += alpha * x[i];
345        }
346    }
347
348    #[inline]
349    fn fma_batch(a: &[Self], b: &[Self], c: &[Self], out: &mut [Self]) {
350        debug_assert_eq!(a.len(), b.len());
351        debug_assert_eq!(a.len(), c.len());
352        debug_assert_eq!(a.len(), out.len());
353        for i in 0..a.len() {
354            out[i] = a[i] * b[i] + c[i];
355        }
356    }
357}
358
/// Type-level scalar classification for compile-time dispatch.
///
/// This enum enables algorithms to specialize at compile time based on
/// the scalar type's properties.
///
/// All variants exist unconditionally; the `RealF16`/`RealF128` variants are
/// only produced by `ScalarClassify` impls gated behind the `f16`/`f128`
/// features.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarClass {
    /// Single-precision real (f32)
    RealF32,
    /// Double-precision real (f64)
    RealF64,
    /// Single-precision complex
    ComplexF32,
    /// Double-precision complex
    ComplexF64,
    /// Half-precision real (f16)
    RealF16,
    /// Quad-precision real (f128)
    RealF128,
    /// Unknown/other type
    Other,
}
380
/// Trait for compile-time scalar classification.
pub trait ScalarClassify: Scalar {
    /// The compile-time class of this scalar type.
    const CLASS: ScalarClass;

    /// Returns the precision level (1 = lowest, 4 = highest).
    ///
    /// Levels assigned in this module: f16 = 1, f32/Complex32 = 2,
    /// f64/Complex64 = 3, f128 = 4.
    const PRECISION_LEVEL: u8;

    /// Returns the storage size in bytes.
    ///
    /// Defaults to `size_of::<Self>()`, which is correct for plain-data
    /// scalar types.
    const STORAGE_BYTES: usize = core::mem::size_of::<Self>();
}
392
// Precision levels reflect mantissa width ordering, not storage size:
// complex types share the level of their component real type.
impl ScalarClassify for f32 {
    const CLASS: ScalarClass = ScalarClass::RealF32;
    const PRECISION_LEVEL: u8 = 2;
}

impl ScalarClassify for f64 {
    const CLASS: ScalarClass = ScalarClass::RealF64;
    const PRECISION_LEVEL: u8 = 3;
}

impl ScalarClassify for Complex32 {
    const CLASS: ScalarClass = ScalarClass::ComplexF32;
    const PRECISION_LEVEL: u8 = 2;
}

impl ScalarClassify for Complex64 {
    const CLASS: ScalarClass = ScalarClass::ComplexF64;
    const PRECISION_LEVEL: u8 = 3;
}

// Only classified when the `f16` feature provides the `half::f16` type.
#[cfg(feature = "f16")]
impl ScalarClassify for f16 {
    const CLASS: ScalarClass = ScalarClass::RealF16;
    const PRECISION_LEVEL: u8 = 1;
}

// Only classified when the `f128` feature provides `QuadFloat`.
#[cfg(feature = "f128")]
impl ScalarClassify for QuadFloat {
    const CLASS: ScalarClass = ScalarClass::RealF128;
    const PRECISION_LEVEL: u8 = 4;
}
424
/// Unrolling hints for vectorized loops.
///
/// These constants help the compiler make better unrolling decisions
/// for different scalar types. They are advisory tuning knobs for callers;
/// nothing in this module enforces them.
pub trait UnrollHints: Scalar {
    /// Recommended unroll factor for tight loops.
    const UNROLL_FACTOR: usize;

    /// Recommended chunk size for blocked algorithms.
    const BLOCK_SIZE: usize;

    /// Whether to prefer streaming stores (for large writes).
    const PREFER_STREAMING: bool;
}
439
// Unroll factors track SIMD width (wider element type -> smaller factor);
// block sizes shrink proportionally with element size.
// NOTE(review): these values look like heuristics, not measured optima —
// worth benchmarking per target before relying on them.
impl UnrollHints for f32 {
    const UNROLL_FACTOR: usize = 8;
    const BLOCK_SIZE: usize = 64;
    const PREFER_STREAMING: bool = true;
}

impl UnrollHints for f64 {
    const UNROLL_FACTOR: usize = 4;
    const BLOCK_SIZE: usize = 32;
    const PREFER_STREAMING: bool = true;
}

impl UnrollHints for Complex32 {
    const UNROLL_FACTOR: usize = 4;
    const BLOCK_SIZE: usize = 32;
    const PREFER_STREAMING: bool = true;
}

impl UnrollHints for Complex64 {
    const UNROLL_FACTOR: usize = 2;
    const BLOCK_SIZE: usize = 16;
    const PREFER_STREAMING: bool = true;
}
463
/// Extended precision accumulation support.
///
/// For algorithms requiring higher precision during intermediate calculations,
/// this trait provides access to an extended precision accumulator type.
pub trait ExtendedPrecision: Scalar {
    /// The type used for extended precision accumulation.
    ///
    /// May be `Self` when no wider type is available (see the f64 impl).
    type Accumulator: Scalar;

    /// Converts a value to the accumulator type (expected lossless).
    fn to_accumulator(self) -> Self::Accumulator;

    /// Converts from the accumulator type back to this type
    /// (may round when the accumulator is wider).
    fn from_accumulator(acc: Self::Accumulator) -> Self;
}
478
479impl ExtendedPrecision for f32 {
480    type Accumulator = f64;
481
482    #[inline]
483    fn to_accumulator(self) -> f64 {
484        self as f64
485    }
486
487    #[inline]
488    fn from_accumulator(acc: f64) -> f32 {
489        acc as f32
490    }
491}
492
impl ExtendedPrecision for f64 {
    // For f64, we use the same type (or could use f128 if available),
    // so both conversions are identities.
    type Accumulator = f64;

    #[inline]
    fn to_accumulator(self) -> f64 {
        self
    }

    #[inline]
    fn from_accumulator(acc: f64) -> f64 {
        acc
    }
}
507
508impl ExtendedPrecision for Complex32 {
509    type Accumulator = Complex64;
510
511    #[inline]
512    fn to_accumulator(self) -> Complex64 {
513        Complex64::new(self.re as f64, self.im as f64)
514    }
515
516    #[inline]
517    fn from_accumulator(acc: Complex64) -> Complex32 {
518        Complex32::new(acc.re as f32, acc.im as f32)
519    }
520}
521
impl ExtendedPrecision for Complex64 {
    // No wider complex type is available here, so both conversions are
    // identities (mirrors the f64 impl).
    type Accumulator = Complex64;

    #[inline]
    fn to_accumulator(self) -> Complex64 {
        self
    }

    #[inline]
    fn from_accumulator(acc: Complex64) -> Complex64 {
        acc
    }
}
535
536// =============================================================================
537// Summation algorithms
538// =============================================================================
539
/// Kahan summation for improved accuracy.
///
/// Uses compensated summation to reduce floating-point errors: the rounding
/// error of each addition is tracked separately and fed back into the next.
#[derive(Debug, Clone, Copy)]
pub struct KahanSum<T: Scalar> {
    // Running total.
    sum: T,
    // Low-order error not yet absorbed into `sum` (subtracted on next add).
    compensation: T,
}
548
impl<T: Scalar> Default for KahanSum<T> {
    /// Equivalent to [`KahanSum::new`]: a zeroed accumulator.
    fn default() -> Self {
        Self::new()
    }
}
554
impl<T: Scalar> KahanSum<T> {
    /// Creates a new Kahan sum accumulator initialized to zero.
    #[inline]
    pub fn new() -> Self {
        Self {
            sum: T::zero(),
            compensation: T::zero(),
        }
    }

    /// Adds a value to the sum with compensation.
    ///
    /// NOTE: the statement order below is the Kahan algorithm itself; do not
    /// reorder or algebraically "simplify" these expressions — they would
    /// cancel to zero in exact arithmetic.
    #[inline]
    pub fn add(&mut self, value: T) {
        // Apply the error carried over from previous additions.
        let y = value - self.compensation;
        // `t` drops the low-order bits of `y` when `sum` dominates.
        let t = self.sum + y;
        // (t - sum) is the part of `y` actually absorbed; subtracting `y`
        // leaves the (negated) rounding error for the next call.
        self.compensation = (t - self.sum) - y;
        self.sum = t;
    }

    /// Returns the current sum.
    #[inline]
    pub fn sum(self) -> T {
        self.sum
    }
}
580
581/// Pairwise summation for reduced error accumulation.
582///
583/// Recursively splits the array and sums pairs, reducing error from O(n) to O(log n).
584#[inline]
585pub fn pairwise_sum<T: Scalar>(values: &[T]) -> T {
586    const THRESHOLD: usize = 32;
587
588    if values.is_empty() {
589        return T::zero();
590    }
591    if values.len() <= THRESHOLD {
592        return values.iter().copied().fold(T::zero(), |acc, x| acc + x);
593    }
594
595    let mid = values.len() / 2;
596    pairwise_sum(&values[..mid]) + pairwise_sum(&values[mid..])
597}
598
/// Kahan-Babuska-Klein summation (improved compensated summation).
///
/// Provides even better error bounds than standard Kahan summation by
/// carrying two levels of compensation.
#[derive(Debug, Clone, Copy)]
pub struct KBKSum<T: Scalar> {
    // Running total.
    sum: T,
    // First-order compensation: accumulated error of `sum` updates.
    cs: T,
    // Second-order compensation: accumulated error of `cs` updates.
    ccs: T,
}
608
impl<T: Scalar> Default for KBKSum<T> {
    /// Equivalent to [`KBKSum::new`]: a zeroed accumulator.
    fn default() -> Self {
        Self::new()
    }
}
614
impl<T: Scalar> KBKSum<T> {
    /// Creates a new KBK sum accumulator.
    #[inline]
    pub fn new() -> Self {
        Self {
            sum: T::zero(),
            cs: T::zero(),
            ccs: T::zero(),
        }
    }

    /// Adds a value with double compensation.
    ///
    /// NOTE: statement order and the magnitude-based branches are essential
    /// (the larger-magnitude operand determines how the rounding error is
    /// recovered exactly); do not reorder or simplify.
    #[inline]
    pub fn add(&mut self, value: T) {
        let t = self.sum + value;
        // Recover the exact rounding error of `sum + value`.
        let c = if Scalar::abs(self.sum) >= Scalar::abs(value) {
            (self.sum - t) + value
        } else {
            (value - t) + self.sum
        };
        self.sum = t;

        // Fold that error into the first-order compensation, recovering the
        // error of *that* addition into the second-order term.
        let t2 = self.cs + c;
        let cc = if Scalar::abs(self.cs) >= Scalar::abs(c) {
            (self.cs - t2) + c
        } else {
            (c - t2) + self.cs
        };
        self.cs = t2;
        self.ccs += cc;
    }

    /// Returns the compensated sum (total plus both correction terms).
    #[inline]
    pub fn sum(self) -> T {
        self.sum + self.cs + self.ccs
    }
}