oxiblas_core/
scalar.rs

1//! Scalar traits for numeric types used in OxiBLAS.
2//!
3//! This module defines the trait hierarchy for numeric types:
4//! - `Scalar`: Base trait for all scalar types
5//! - `Real`: Real number types (f32, f64, f16 with feature, QuadFloat with f128 feature)
6//! - `ComplexScalar`: Complex number types
7//! - `Field`: Field operations (complete algebraic structure)
8
9use core::fmt::{Debug, Display};
10use core::iter::Sum;
11use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
12use num_complex::{Complex32, Complex64};
13use num_traits::{Float, FromPrimitive, NumAssign, One, Zero};
14
15// =============================================================================
16// Type aliases for convenience
17// =============================================================================
18
19/// 32-bit complex type alias (same as `Complex32`).
20pub type C32 = Complex32;
21
22/// 64-bit complex type alias (same as `Complex64`).
23pub type C64 = Complex64;
24
25// =============================================================================
26// Complex number constructors and utilities
27// =============================================================================
28
29/// Creates a complex number from real and imaginary parts.
30///
31/// # Examples
32///
33/// ```
34/// use oxiblas_core::scalar::c64;
35/// let z = c64(3.0, 4.0);
36/// assert_eq!(z.re, 3.0);
37/// assert_eq!(z.im, 4.0);
38/// ```
39#[inline]
40pub const fn c64(re: f64, im: f64) -> C64 {
41    Complex64::new(re, im)
42}
43
44/// Creates a complex number from real and imaginary parts (f32).
45///
46/// # Examples
47///
48/// ```
49/// use oxiblas_core::scalar::c32;
50/// let z = c32(3.0, 4.0);
51/// assert_eq!(z.re, 3.0);
52/// assert_eq!(z.im, 4.0);
53/// ```
54#[inline]
55pub const fn c32(re: f32, im: f32) -> C32 {
56    Complex32::new(re, im)
57}
58
59/// The imaginary unit i (64-bit).
60pub const I64: C64 = Complex64::new(0.0, 1.0);
61
62/// The imaginary unit i (32-bit).
63pub const I32: C32 = Complex32::new(0.0, 1.0);
64
65/// Returns the imaginary unit as a 64-bit complex number.
66#[inline]
67pub const fn imag_unit() -> C64 {
68    I64
69}
70
71/// Returns the imaginary unit as a 32-bit complex number.
72#[inline]
73pub const fn imag_unit32() -> C32 {
74    I32
75}
76
77/// Creates a purely imaginary number (64-bit).
78///
79/// # Examples
80///
81/// ```
82/// use oxiblas_core::scalar::imag;
83/// let z = imag(2.0);
84/// assert_eq!(z.re, 0.0);
85/// assert_eq!(z.im, 2.0);
86/// ```
87#[inline]
88pub const fn imag(im: f64) -> C64 {
89    Complex64::new(0.0, im)
90}
91
92/// Creates a purely imaginary number (32-bit).
93#[inline]
94pub const fn imag32(im: f32) -> C32 {
95    Complex32::new(0.0, im)
96}
97
98/// Creates a real number as complex (64-bit).
99///
100/// # Examples
101///
102/// ```
103/// use oxiblas_core::scalar::real;
104/// let z = real(3.0);
105/// assert_eq!(z.re, 3.0);
106/// assert_eq!(z.im, 0.0);
107/// ```
108#[inline]
109pub const fn real(re: f64) -> C64 {
110    Complex64::new(re, 0.0)
111}
112
113/// Creates a real number as complex (32-bit).
114#[inline]
115pub const fn real32(re: f32) -> C32 {
116    Complex32::new(re, 0.0)
117}
118
119/// Creates a complex number from polar coordinates (64-bit).
120///
121/// # Arguments
122/// * `r` - The magnitude (radius)
123/// * `theta` - The angle in radians
124///
125/// # Examples
126///
127/// ```
128/// use oxiblas_core::scalar::from_polar;
129/// use std::f64::consts::PI;
130/// let z = from_polar(1.0, PI / 2.0);
131/// assert!((z.re - 0.0).abs() < 1e-10);
132/// assert!((z.im - 1.0).abs() < 1e-10);
133/// ```
134#[inline]
135pub fn from_polar(r: f64, theta: f64) -> C64 {
136    Complex64::from_polar(r, theta)
137}
138
139/// Creates a complex number from polar coordinates (32-bit).
140#[inline]
141pub fn from_polar32(r: f32, theta: f32) -> C32 {
142    Complex32::from_polar(r, theta)
143}
144
/// Convenience operations layered on top of the complex number types.
///
/// All methods take `self` by value; the implementors in this module are
/// `Copy` complex types.
pub trait ComplexExt: Sized {
    /// Scalar type of the real and imaginary components.
    type Real;

    /// Tests whether the imaginary part is within `tolerance` of zero.
    #[allow(clippy::wrong_self_convention)] // Complex numbers are Copy, self by value is efficient
    fn is_purely_real(self, tolerance: Self::Real) -> bool;

    /// Tests whether the real part is within `tolerance` of zero.
    #[allow(clippy::wrong_self_convention)] // Complex numbers are Copy, self by value is efficient
    fn is_purely_imaginary(self, tolerance: Self::Real) -> bool;

    /// Turns the value about the origin by `angle` radians.
    fn rotate(self, angle: Self::Real) -> Self;

    /// Multiplies the magnitude by `factor`, leaving the phase untouched.
    fn scale_magnitude(self, factor: Self::Real) -> Self;

    /// Rescales the value to unit magnitude.
    fn normalize(self) -> Self;

    /// Mirrors the value across the real axis (i.e. takes the conjugate).
    fn reflect_real(self) -> Self;

    /// Mirrors the value across the imaginary axis (negates the real part).
    fn reflect_imag(self) -> Self;

    /// Euclidean distance between `self` and `other` in the complex plane.
    fn distance(self, other: Self) -> Self::Real;
}
176
177impl ComplexExt for C64 {
178    type Real = f64;
179
180    #[inline]
181    fn is_purely_real(self, tolerance: f64) -> bool {
182        self.im.abs() <= tolerance
183    }
184
185    #[inline]
186    fn is_purely_imaginary(self, tolerance: f64) -> bool {
187        self.re.abs() <= tolerance
188    }
189
190    #[inline]
191    fn rotate(self, angle: f64) -> Self {
192        self * Complex64::from_polar(1.0, angle)
193    }
194
195    #[inline]
196    fn scale_magnitude(self, factor: f64) -> Self {
197        let (r, theta) = self.to_polar();
198        Complex64::from_polar(r * factor, theta)
199    }
200
201    #[inline]
202    fn normalize(self) -> Self {
203        let norm = self.norm();
204        if norm == 0.0 {
205            Complex64::new(0.0, 0.0)
206        } else {
207            self / norm
208        }
209    }
210
211    #[inline]
212    fn reflect_real(self) -> Self {
213        self.conj()
214    }
215
216    #[inline]
217    fn reflect_imag(self) -> Self {
218        Complex64::new(-self.re, self.im)
219    }
220
221    #[inline]
222    fn distance(self, other: Self) -> f64 {
223        (self - other).norm()
224    }
225}
226
227impl ComplexExt for C32 {
228    type Real = f32;
229
230    #[inline]
231    fn is_purely_real(self, tolerance: f32) -> bool {
232        self.im.abs() <= tolerance
233    }
234
235    #[inline]
236    fn is_purely_imaginary(self, tolerance: f32) -> bool {
237        self.re.abs() <= tolerance
238    }
239
240    #[inline]
241    fn rotate(self, angle: f32) -> Self {
242        self * Complex32::from_polar(1.0, angle)
243    }
244
245    #[inline]
246    fn scale_magnitude(self, factor: f32) -> Self {
247        let (r, theta) = self.to_polar();
248        Complex32::from_polar(r * factor, theta)
249    }
250
251    #[inline]
252    fn normalize(self) -> Self {
253        let norm = self.norm();
254        if norm == 0.0 {
255            Complex32::new(0.0, 0.0)
256        } else {
257            self / norm
258        }
259    }
260
261    #[inline]
262    fn reflect_real(self) -> Self {
263        self.conj()
264    }
265
266    #[inline]
267    fn reflect_imag(self) -> Self {
268        Complex32::new(-self.re, self.im)
269    }
270
271    #[inline]
272    fn distance(self, other: Self) -> f32 {
273        (self - other).norm()
274    }
275}
276
/// Lifts a real value into a complex type `C`.
pub trait ToComplex<C> {
    /// Produces the complex value whose real part is `self` and whose
    /// imaginary part is zero.
    fn to_complex(self) -> C;

    /// Produces the complex value whose real part is `self` and whose
    /// imaginary part is `im`.
    fn with_imag(self, im: Self) -> C;
}
285
286impl ToComplex<C64> for f64 {
287    #[inline]
288    fn to_complex(self) -> C64 {
289        Complex64::new(self, 0.0)
290    }
291
292    #[inline]
293    fn with_imag(self, im: f64) -> C64 {
294        Complex64::new(self, im)
295    }
296}
297
298impl ToComplex<C32> for f32 {
299    #[inline]
300    fn to_complex(self) -> C32 {
301        Complex32::new(self, 0.0)
302    }
303
304    #[inline]
305    fn with_imag(self, im: f32) -> C32 {
306        Complex32::new(self, im)
307    }
308}
309
310#[cfg(feature = "f128")]
311use core::ops::{Rem, RemAssign};
312
313#[cfg(feature = "f16")]
314use half::f16;
315
316#[cfg(feature = "f128")]
317use twofloat::TwoFloat;
318
/// Extended-precision floating-point number backed by double-double arithmetic.
///
/// `QuadFloat` is a transparent newtype around `TwoFloat` (from the
/// `twofloat` crate). A double-double value carries roughly 106 bits of
/// mantissa, i.e. about 31 significant decimal digits, which is close to
/// (but not the same format as) IEEE 754 binary128. No platform quadmath
/// library is involved.
///
/// # Features
///
/// - Pure Rust, portable across platforms
/// - About 31 decimal digits of precision
/// - Standard mathematical functions (sin, cos, exp, ln, ...)
/// - Usable through the OxiBLAS scalar traits
///
/// # Example
///
/// ```ignore
/// use oxiblas_core::scalar::QuadFloat;
///
/// let x = QuadFloat::from(2.0);
/// let y = x.sqrt();
/// assert!((y * y - x).abs() < QuadFloat::from(1e-30));
/// ```
#[cfg(feature = "f128")]
#[derive(Clone, Copy, Debug, Default, PartialEq, PartialOrd)]
#[repr(transparent)]
pub struct QuadFloat(TwoFloat);
346
#[cfg(feature = "f128")]
impl QuadFloat {
    /// Wraps an `f64` as a `QuadFloat` holding the same value.
    #[inline]
    pub const fn new(value: f64) -> Self {
        QuadFloat(TwoFloat::from_f64(value))
    }

    /// Returns the wrapped `TwoFloat` by value.
    #[inline]
    pub const fn inner(self) -> TwoFloat {
        let QuadFloat(wrapped) = self;
        wrapped
    }
}
361
/// Widening conversion from `f64`.
#[cfg(feature = "f128")]
impl From<f64> for QuadFloat {
    #[inline]
    fn from(value: f64) -> Self {
        let widened = TwoFloat::from_f64(value);
        QuadFloat(widened)
    }
}
369
/// Wraps an existing `TwoFloat` without changing its value.
#[cfg(feature = "f128")]
impl From<TwoFloat> for QuadFloat {
    #[inline]
    fn from(value: TwoFloat) -> Self {
        QuadFloat(value)
    }
}
377
// Arithmetic operators: each one forwards to the corresponding `TwoFloat` operator.
/// `QuadFloat + QuadFloat`, computed on the wrapped `TwoFloat` values.
#[cfg(feature = "f128")]
impl Add for QuadFloat {
    type Output = Self;
    #[inline]
    fn add(self, rhs: Self) -> Self {
        let total = self.0 + rhs.0;
        QuadFloat(total)
    }
}
387
/// `QuadFloat - QuadFloat`, computed on the wrapped `TwoFloat` values.
#[cfg(feature = "f128")]
impl Sub for QuadFloat {
    type Output = Self;
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        let difference = self.0 - rhs.0;
        QuadFloat(difference)
    }
}
396
/// `QuadFloat * QuadFloat`, computed on the wrapped `TwoFloat` values.
#[cfg(feature = "f128")]
impl Mul for QuadFloat {
    type Output = Self;
    #[inline]
    fn mul(self, rhs: Self) -> Self {
        let product = self.0 * rhs.0;
        QuadFloat(product)
    }
}
405
/// `QuadFloat / QuadFloat`, computed on the wrapped `TwoFloat` values.
#[cfg(feature = "f128")]
impl Div for QuadFloat {
    type Output = Self;
    #[inline]
    fn div(self, rhs: Self) -> Self {
        let quotient = self.0 / rhs.0;
        QuadFloat(quotient)
    }
}
414
/// Unary minus, forwarded to the wrapped `TwoFloat`.
#[cfg(feature = "f128")]
impl Neg for QuadFloat {
    type Output = Self;
    #[inline]
    fn neg(self) -> Self {
        let QuadFloat(inner) = self;
        QuadFloat(-inner)
    }
}
423
/// `+=` for `QuadFloat`: replaces `self` with the sum of the wrapped values.
#[cfg(feature = "f128")]
impl AddAssign for QuadFloat {
    #[inline]
    fn add_assign(&mut self, rhs: Self) {
        *self = QuadFloat(self.0 + rhs.0);
    }
}
431
/// `-=` for `QuadFloat`: replaces `self` with the difference of the wrapped values.
#[cfg(feature = "f128")]
impl SubAssign for QuadFloat {
    #[inline]
    fn sub_assign(&mut self, rhs: Self) {
        *self = QuadFloat(self.0 - rhs.0);
    }
}
439
/// `*=` for `QuadFloat`: replaces `self` with the product of the wrapped values.
#[cfg(feature = "f128")]
impl MulAssign for QuadFloat {
    #[inline]
    fn mul_assign(&mut self, rhs: Self) {
        *self = QuadFloat(self.0 * rhs.0);
    }
}
447
/// `/=` for `QuadFloat`: replaces `self` with the quotient of the wrapped values.
#[cfg(feature = "f128")]
impl DivAssign for QuadFloat {
    #[inline]
    fn div_assign(&mut self, rhs: Self) {
        *self = QuadFloat(self.0 / rhs.0);
    }
}
455
/// Remainder of `self / rhs`.
///
/// The result is `self - trunc(self / rhs) * rhs`, so it carries the sign of
/// the dividend, matching the behaviour of `%` on `f32`/`f64`.
///
/// The quotient is truncated at full double-double precision. Rounding only
/// the high `f64` word (as an earlier version did) discards the low word and
/// gives wrong results whenever the quotient sits just below an integer or
/// exceeds the 53-bit mantissa of an `f64`.
#[cfg(feature = "f128")]
impl Rem for QuadFloat {
    type Output = Self;
    #[inline]
    fn rem(self, rhs: Self) -> Self::Output {
        // Integer part of the quotient, rounded toward zero, using both words.
        let whole_quotient = QuadFloat((self.0 / rhs.0).trunc());
        self - whole_quotient * rhs
    }
}
466
/// `%=` for `QuadFloat`, expressed through the `Rem` implementation.
#[cfg(feature = "f128")]
impl RemAssign for QuadFloat {
    #[inline]
    fn rem_assign(&mut self, rhs: Self) {
        let remainder = Rem::rem(*self, rhs);
        *self = remainder;
    }
}
474
/// Sums an iterator of owned `QuadFloat` values, starting from zero and
/// adding the items in iteration order.
#[cfg(feature = "f128")]
impl Sum for QuadFloat {
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        let mut total = QuadFloat::from(0.0);
        for term in iter {
            total = total + term;
        }
        total
    }
}
481
/// Sums an iterator of borrowed `QuadFloat` values, starting from zero and
/// adding the items in iteration order.
#[cfg(feature = "f128")]
impl<'a> Sum<&'a QuadFloat> for QuadFloat {
    fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
        let mut total = QuadFloat::from(0.0);
        for term in iter {
            total = total + *term;
        }
        total
    }
}
488
/// Additive identity for `QuadFloat`.
#[cfg(feature = "f128")]
impl Zero for QuadFloat {
    /// The value converted from `0.0_f64`.
    #[inline]
    fn zero() -> Self {
        QuadFloat(TwoFloat::from_f64(0.0))
    }

    /// True when the wrapped value compares equal to zero.
    #[inline]
    fn is_zero(&self) -> bool {
        let zero = TwoFloat::from_f64(0.0);
        self.0 == zero
    }
}
501
/// Multiplicative identity for `QuadFloat`.
#[cfg(feature = "f128")]
impl One for QuadFloat {
    /// The value converted from `1.0_f64`.
    #[inline]
    fn one() -> Self {
        QuadFloat(TwoFloat::from_f64(1.0))
    }
}
509
// `NumAssign` needs no explicit impl: num_traits blanket-implements it for every `Num + NumAssignOps` type.
511
/// Formats a `QuadFloat` using the `Display` output of the wrapped `TwoFloat`.
#[cfg(feature = "f128")]
impl Display for QuadFloat {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let QuadFloat(inner) = self;
        f.write_fmt(format_args!("{}", inner))
    }
}
518
// `Float` for `QuadFloat`.
//
// Most methods forward to `TwoFloat`. Where they are composed by hand the
// intent is to respect the full double-double value, not only its high word:
//
// * `floor`/`ceil`/`round`/`trunc`/`fract` operate on both words. Rounding
//   only `hi()` is wrong when the low word pushes the value across an
//   integer (e.g. hi = 3.0, lo < 0 has floor 2) and for magnitudes above
//   2^53, where the low word carries integer digits.
// * `max`/`min` ignore a NaN argument when the other one is a number, as the
//   primitive float types do.
// * `signum` propagates NaN instead of mapping it to zero.
// * `cbrt` accepts negative inputs and uses a double-double 1/3 exponent.
// * `hypot` rescales before squaring so large finite inputs do not overflow.
#[cfg(feature = "f128")]
impl Float for QuadFloat {
    fn nan() -> Self {
        QuadFloat(TwoFloat::NAN)
    }

    fn infinity() -> Self {
        QuadFloat(TwoFloat::INFINITY)
    }

    fn neg_infinity() -> Self {
        QuadFloat(TwoFloat::NEG_INFINITY)
    }

    fn neg_zero() -> Self {
        QuadFloat(-TwoFloat::from_f64(0.0))
    }

    fn min_value() -> Self {
        QuadFloat(TwoFloat::MIN)
    }

    fn min_positive_value() -> Self {
        QuadFloat(TwoFloat::MIN_POSITIVE)
    }

    fn max_value() -> Self {
        QuadFloat(TwoFloat::MAX)
    }

    fn is_nan(self) -> bool {
        self.0.is_nan()
    }

    fn is_infinite(self) -> bool {
        self.0.is_infinite()
    }

    fn is_finite(self) -> bool {
        self.0.is_finite()
    }

    fn is_normal(self) -> bool {
        self.0.is_normal()
    }

    fn classify(self) -> core::num::FpCategory {
        self.0.classify()
    }

    /// Largest integer less than or equal to `self`, using both words.
    fn floor(self) -> Self {
        QuadFloat(self.0.floor())
    }

    /// Smallest integer greater than or equal to `self`, using both words.
    fn ceil(self) -> Self {
        QuadFloat(self.0.ceil())
    }

    /// Nearest integer to `self`, using both words.
    fn round(self) -> Self {
        QuadFloat(self.0.round())
    }

    /// Integer part of `self` (rounded toward zero), using both words.
    fn trunc(self) -> Self {
        QuadFloat(self.0.trunc())
    }

    /// Fractional part of `self`, using both words.
    fn fract(self) -> Self {
        QuadFloat(self.0.fract())
    }

    fn abs(self) -> Self {
        QuadFloat(self.0.abs())
    }

    /// Returns 1 for positive values, -1 for negative values, 0 for zero and
    /// NaN for NaN.
    ///
    /// NOTE(review): the primitive floats return ±1 for ±0; the zero result
    /// is kept here because existing callers may rely on it.
    fn signum(self) -> Self {
        if self.0.is_nan() {
            return self;
        }
        let zero = QuadFloat::from(0.0);
        let one = QuadFloat::from(1.0);
        if self > zero {
            one
        } else if self < zero {
            -one
        } else {
            zero
        }
    }

    fn is_sign_positive(self) -> bool {
        self.0.is_sign_positive()
    }

    fn is_sign_negative(self) -> bool {
        self.0.is_sign_negative()
    }

    /// Computes `self * a + b` as a multiply followed by an add (not fused).
    fn mul_add(self, a: Self, b: Self) -> Self {
        self * a + b
    }

    fn recip(self) -> Self {
        QuadFloat(self.0.recip())
    }

    fn powi(self, n: i32) -> Self {
        QuadFloat(self.0.powi(n))
    }

    fn powf(self, n: Self) -> Self {
        QuadFloat(self.0.powf(n.0))
    }

    fn sqrt(self) -> Self {
        QuadFloat(self.0.sqrt())
    }

    fn exp(self) -> Self {
        QuadFloat(self.0.exp())
    }

    /// `2^self`, computed as `powf(2, self)`.
    fn exp2(self) -> Self {
        QuadFloat(TwoFloat::from_f64(2.0).powf(self.0))
    }

    fn ln(self) -> Self {
        QuadFloat(self.0.ln())
    }

    /// Logarithm in an arbitrary base, computed as `ln(self) / ln(base)`.
    fn log(self, base: Self) -> Self {
        QuadFloat(self.0.ln() / base.0.ln())
    }

    /// Base-2 logarithm, computed as `ln(self) / ln(2)`.
    fn log2(self) -> Self {
        QuadFloat(self.0.ln() / TwoFloat::from_f64(2.0).ln())
    }

    fn log10(self) -> Self {
        QuadFloat(self.0.log10())
    }

    /// Larger of the two values; a NaN argument is ignored when the other
    /// argument is a number.
    fn max(self, other: Self) -> Self {
        if other.0.is_nan() || self > other {
            self
        } else {
            other
        }
    }

    /// Smaller of the two values; a NaN argument is ignored when the other
    /// argument is a number.
    fn min(self, other: Self) -> Self {
        if other.0.is_nan() || self < other {
            self
        } else {
            other
        }
    }

    /// Positive difference: `self - other` when `self > other`, else zero.
    fn abs_sub(self, other: Self) -> Self {
        if self > other {
            self - other
        } else {
            QuadFloat::from(0.0)
        }
    }

    /// Real cube root, defined for negative inputs as `-cbrt(-self)`.
    fn cbrt(self) -> Self {
        let zero = TwoFloat::from_f64(0.0);
        // Zero, infinities and NaN are their own cube roots.
        if self.0 == zero || !self.0.is_finite() {
            return self;
        }
        // 1/3 evaluated in double-double so the exponent is not limited to
        // f64 precision.
        let one_third = TwoFloat::from_f64(1.0) / TwoFloat::from_f64(3.0);
        let magnitude_root = self.0.abs().powf(one_third);
        if self.0 < zero {
            QuadFloat(-magnitude_root)
        } else {
            QuadFloat(magnitude_root)
        }
    }

    /// `sqrt(self^2 + other^2)` evaluated as `big * sqrt(1 + (small/big)^2)`
    /// so that squaring cannot overflow for large finite inputs.
    fn hypot(self, other: Self) -> Self {
        let a = Float::abs(self);
        let b = Float::abs(other);
        // An infinite leg dominates, even when the other leg is NaN.
        if a.0.is_infinite() || b.0.is_infinite() {
            return QuadFloat(TwoFloat::INFINITY);
        }
        if a.0.is_nan() || b.0.is_nan() {
            return QuadFloat(TwoFloat::NAN);
        }
        let (big, small) = if a < b { (b, a) } else { (a, b) };
        let zero = QuadFloat::from(0.0);
        if big == zero {
            return zero;
        }
        let ratio = small / big;
        big * Float::sqrt(QuadFloat::from(1.0) + ratio * ratio)
    }

    fn sin(self) -> Self {
        QuadFloat(self.0.sin())
    }

    fn cos(self) -> Self {
        QuadFloat(self.0.cos())
    }

    fn tan(self) -> Self {
        QuadFloat(self.0.tan())
    }

    fn asin(self) -> Self {
        QuadFloat(self.0.asin())
    }

    fn acos(self) -> Self {
        QuadFloat(self.0.acos())
    }

    fn atan(self) -> Self {
        QuadFloat(self.0.atan())
    }

    fn atan2(self, other: Self) -> Self {
        QuadFloat(self.0.atan2(other.0))
    }

    fn sin_cos(self) -> (Self, Self) {
        let (sin, cos) = self.0.sin_cos();
        (QuadFloat(sin), QuadFloat(cos))
    }

    /// `exp(self) - 1`, computed literally.
    // NOTE(review): this loses relative accuracy for |self| close to zero.
    fn exp_m1(self) -> Self {
        QuadFloat(self.0.exp() - TwoFloat::from_f64(1.0))
    }

    /// `ln(1 + self)`, computed literally.
    // NOTE(review): this loses relative accuracy for |self| close to zero.
    fn ln_1p(self) -> Self {
        QuadFloat((self.0 + TwoFloat::from_f64(1.0)).ln())
    }

    fn sinh(self) -> Self {
        QuadFloat(self.0.sinh())
    }

    fn cosh(self) -> Self {
        QuadFloat(self.0.cosh())
    }

    fn tanh(self) -> Self {
        QuadFloat(self.0.tanh())
    }

    fn asinh(self) -> Self {
        QuadFloat(self.0.asinh())
    }

    fn acosh(self) -> Self {
        QuadFloat(self.0.acosh())
    }

    fn atanh(self) -> Self {
        QuadFloat(self.0.atanh())
    }

    /// Decodes the high `f64` word only; the low word is not represented in
    /// the returned triple.
    fn integer_decode(self) -> (u64, i16, i8) {
        self.0.hi().integer_decode()
    }

    /// `f64::EPSILON` squared (2^-104).
    fn epsilon() -> Self {
        QuadFloat::from(f64::EPSILON) * QuadFloat::from(f64::EPSILON)
    }

    // NOTE(review): the conversion factor below is an f64 constant, so the
    // result is only accurate to f64 precision.
    fn to_degrees(self) -> Self {
        const FACTOR: f64 = 180.0 / core::f64::consts::PI;
        self * QuadFloat::from(FACTOR)
    }

    // NOTE(review): the conversion factor below is an f64 constant, so the
    // result is only accurate to f64 precision.
    fn to_radians(self) -> Self {
        const FACTOR: f64 = core::f64::consts::PI / 180.0;
        self * QuadFloat::from(FACTOR)
    }
}
766
/// Conversions from primitive numbers.
///
/// A double-double value has enough mantissa bits to hold every `i64` and
/// `u64` exactly, so the integer conversions split the input into a high
/// `f64` word plus an exact residual instead of rounding through a single
/// `f64` (which loses the low bits of integers above 2^53).
#[cfg(feature = "f128")]
impl FromPrimitive for QuadFloat {
    fn from_i64(n: i64) -> Option<Self> {
        // Nearest f64 to `n`; intentionally a rounding cast.
        let hi = n as f64;
        // `hi` is an integer within 2^10 of `n`, so the residual is a small
        // integer and converts to f64 exactly.
        let lo = (i128::from(n) - hi as i128) as f64;
        Some(QuadFloat(TwoFloat::from_f64(hi) + TwoFloat::from_f64(lo)))
    }

    fn from_u64(n: u64) -> Option<Self> {
        // Nearest f64 to `n`; intentionally a rounding cast.
        let hi = n as f64;
        // `hi` is an integer within 2^10 of `n`, so the residual is a small
        // integer and converts to f64 exactly.
        let lo = (i128::from(n) - hi as i128) as f64;
        Some(QuadFloat(TwoFloat::from_f64(hi) + TwoFloat::from_f64(lo)))
    }

    fn from_f64(n: f64) -> Option<Self> {
        Some(QuadFloat(TwoFloat::from_f64(n)))
    }
}
781
/// `Num` for `QuadFloat`.
///
/// String parsing goes through `f64`, so the parsed value carries `f64`
/// precision. Any parse failure is reported as `FloatErrorKind::Invalid`.
#[cfg(feature = "f128")]
impl num_traits::Num for QuadFloat {
    type FromStrRadixErr = num_traits::ParseFloatError;

    fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
        match f64::from_str_radix(str, radix) {
            Ok(parsed) => Ok(<QuadFloat as From<f64>>::from(parsed)),
            Err(_) => Err(num_traits::ParseFloatError {
                kind: num_traits::FloatErrorKind::Invalid,
            }),
        }
    }
}
794
/// Generic numeric cast into `QuadFloat`, routed through `f64`.
///
/// Returns `None` exactly when the source value's `to_f64` does.
#[cfg(feature = "f128")]
impl num_traits::NumCast for QuadFloat {
    fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
        let as_double = n.to_f64()?;
        Some(<QuadFloat as From<f64>>::from(as_double))
    }
}
801
/// Conversions to primitive numbers.
///
/// The integer conversions truncate the full double-double value toward zero
/// and then add the two (integer-valued) words in `i128`. Converting only the
/// high word is wrong when the low word pulls the value below an integer
/// (hi = 3.0, lo < 0 must give 2) or carries integer digits beyond 2^53.
/// Out-of-range, infinite and NaN values yield `None`.
#[cfg(feature = "f128")]
impl num_traits::ToPrimitive for QuadFloat {
    fn to_i64(&self) -> Option<i64> {
        let whole = self.0.trunc();
        // After truncation both words are integers, so each converts exactly.
        let hi = num_traits::ToPrimitive::to_i128(&whole.hi())?;
        let lo = num_traits::ToPrimitive::to_i128(&whole.lo())?;
        num_traits::ToPrimitive::to_i64(&(hi + lo))
    }

    fn to_u64(&self) -> Option<u64> {
        let whole = self.0.trunc();
        // After truncation both words are integers, so each converts exactly.
        let hi = num_traits::ToPrimitive::to_i128(&whole.hi())?;
        let lo = num_traits::ToPrimitive::to_i128(&whole.lo())?;
        num_traits::ToPrimitive::to_u64(&(hi + lo))
    }

    /// Returns the high word, dropping the low word.
    fn to_f64(&self) -> Option<f64> {
        Some(self.0.hi())
    }
}
816
817/// Base trait for all scalar types used in OxiBLAS.
818///
819/// This trait provides the fundamental requirements for any numeric type
820/// that can be used in matrix operations.
821pub trait Scalar:
822    Copy
823    + Clone
824    + Debug
825    + Display
826    + Default
827    + Send
828    + Sync
829    + PartialEq
830    + Zero
831    + One
832    + Add<Output = Self>
833    + Sub<Output = Self>
834    + Mul<Output = Self>
835    + Div<Output = Self>
836    + AddAssign
837    + SubAssign
838    + MulAssign
839    + DivAssign
840    + Neg<Output = Self>
841    + Sum
842    + NumAssign
843    + FromPrimitive
844    + 'static
845{
846    /// The real component type (for complex numbers, this is the component type).
847    type Real: Real;
848
849    /// Returns the absolute value (modulus for complex numbers).
850    fn abs(self) -> Self::Real;
851
852    /// Returns the complex conjugate. For real numbers, returns self.
853    fn conj(self) -> Self;
854
855    /// Returns true if this is a real type (not complex).
856    fn is_real() -> bool;
857
858    /// Returns the real part.
859    fn real(self) -> Self::Real;
860
861    /// Returns the imaginary part (zero for real types).
862    fn imag(self) -> Self::Real;
863
864    /// Creates a scalar from real and imaginary parts.
865    fn from_real_imag(re: Self::Real, im: Self::Real) -> Self;
866
867    /// Creates a scalar from just the real part (imaginary = 0).
868    fn from_real(re: Self::Real) -> Self {
869        Self::from_real_imag(re, Self::Real::zero())
870    }
871
872    /// Square of the absolute value (more efficient than abs().powi(2)).
873    fn abs_sq(self) -> Self::Real {
874        let re = self.real();
875        let im = self.imag();
876        re * re + im * im
877    }
878
879    /// Machine epsilon for this type.
880    fn epsilon() -> Self::Real;
881
882    /// Smallest positive normal value.
883    fn min_positive() -> Self::Real;
884
885    /// Largest finite value.
886    fn max_value() -> Self::Real;
887
888    /// Size of the type in bytes.
889    const SIZE: usize = core::mem::size_of::<Self>();
890
891    /// Alignment requirement.
892    const ALIGN: usize = core::mem::align_of::<Self>();
893}
894
895/// Trait for real number types (f32, f64).
896pub trait Real: Scalar<Real = Self> + Float + PartialOrd {
897    /// Square root.
898    fn sqrt(self) -> Self;
899
900    /// Natural logarithm.
901    fn ln(self) -> Self;
902
903    /// Exponential function.
904    fn exp(self) -> Self;
905
906    /// Sine.
907    fn sin(self) -> Self;
908
909    /// Cosine.
910    fn cos(self) -> Self;
911
912    /// Arctangent of y/x with correct quadrant.
913    fn atan2(self, other: Self) -> Self;
914
915    /// Power function.
916    fn powf(self, n: Self) -> Self;
917
918    /// Sign function: 1.0 if positive, -1.0 if negative, 0.0 if zero.
919    fn signum(self) -> Self;
920
921    /// Fused multiply-add: self * a + b
922    fn mul_add(self, a: Self, b: Self) -> Self;
923
924    /// Floor function.
925    fn floor(self) -> Self;
926
927    /// Ceiling function.
928    fn ceil(self) -> Self;
929
930    /// Round to nearest integer.
931    fn round(self) -> Self;
932
933    /// Truncate toward zero.
934    fn trunc(self) -> Self;
935
936    /// Safe reciprocal (returns None if self is zero or would overflow).
937    fn safe_recip(self) -> Option<Self> {
938        if Scalar::abs(self) < Self::min_positive() {
939            None
940        } else {
941            Some(Self::one() / self)
942        }
943    }
944
945    /// Hypot: sqrt(self^2 + other^2) computed without overflow.
946    fn hypot(self, other: Self) -> Self;
947}
948
949/// Trait for complex scalar types.
950pub trait ComplexScalar: Scalar {
951    /// Creates a complex number from real and imaginary parts.
952    fn new(re: Self::Real, im: Self::Real) -> Self;
953
954    /// Returns the argument (phase angle) of the complex number.
955    fn arg(self) -> Self::Real;
956
957    /// Returns the polar form (r, theta) where self = r * e^(i*theta).
958    fn to_polar(self) -> (Self::Real, Self::Real) {
959        (self.abs(), self.arg())
960    }
961
962    /// Creates a complex number from polar form.
963    fn from_polar(r: Self::Real, theta: Self::Real) -> Self;
964
965    /// Complex exponential.
966    fn cexp(self) -> Self;
967
968    /// Complex logarithm (principal branch).
969    fn cln(self) -> Self;
970
971    /// Complex square root (principal branch).
972    fn csqrt(self) -> Self;
973}
974
975/// Field trait - complete algebraic structure with all operations.
976///
977/// This is the main trait used throughout OxiBLAS for generic programming
978/// over numeric types.
979pub trait Field: Scalar {
980    /// Computes self * alpha + other * beta
981    #[inline]
982    fn scale_add(self, alpha: Self, other: Self, beta: Self) -> Self {
983        self * alpha + other * beta
984    }
985
986    /// Computes self * conj(other) for complex, self * other for real.
987    fn mul_conj(self, other: Self) -> Self;
988
989    /// Computes conj(self) * other for complex, self * other for real.
990    fn conj_mul(self, other: Self) -> Self;
991
992    /// Reciprocal (1/self).
993    fn recip(self) -> Self;
994
995    /// Integer power.
996    fn powi(self, n: i32) -> Self;
997}
998
999// =============================================================================
1000// Implementations for f32
1001// =============================================================================
1002
1003impl Scalar for f32 {
1004    type Real = f32;
1005
1006    #[inline]
1007    fn abs(self) -> Self::Real {
1008        <f32 as Float>::abs(self)
1009    }
1010
1011    #[inline]
1012    fn conj(self) -> Self {
1013        self
1014    }
1015
1016    #[inline]
1017    fn is_real() -> bool {
1018        true
1019    }
1020
1021    #[inline]
1022    fn real(self) -> Self::Real {
1023        self
1024    }
1025
1026    #[inline]
1027    fn imag(self) -> Self::Real {
1028        0.0
1029    }
1030
1031    #[inline]
1032    fn from_real_imag(re: Self::Real, _im: Self::Real) -> Self {
1033        re
1034    }
1035
1036    #[inline]
1037    fn abs_sq(self) -> Self::Real {
1038        self * self
1039    }
1040
1041    #[inline]
1042    fn epsilon() -> Self::Real {
1043        f32::EPSILON
1044    }
1045
1046    #[inline]
1047    fn min_positive() -> Self::Real {
1048        f32::MIN_POSITIVE
1049    }
1050
1051    #[inline]
1052    fn max_value() -> Self::Real {
1053        f32::MAX
1054    }
1055}
1056
1057impl Real for f32 {
1058    #[inline]
1059    fn sqrt(self) -> Self {
1060        <f32 as Float>::sqrt(self)
1061    }
1062
1063    #[inline]
1064    fn ln(self) -> Self {
1065        <f32 as Float>::ln(self)
1066    }
1067
1068    #[inline]
1069    fn exp(self) -> Self {
1070        <f32 as Float>::exp(self)
1071    }
1072
1073    #[inline]
1074    fn sin(self) -> Self {
1075        <f32 as Float>::sin(self)
1076    }
1077
1078    #[inline]
1079    fn cos(self) -> Self {
1080        <f32 as Float>::cos(self)
1081    }
1082
1083    #[inline]
1084    fn atan2(self, other: Self) -> Self {
1085        <f32 as Float>::atan2(self, other)
1086    }
1087
1088    #[inline]
1089    fn powf(self, n: Self) -> Self {
1090        <f32 as Float>::powf(self, n)
1091    }
1092
1093    #[inline]
1094    fn signum(self) -> Self {
1095        <f32 as Float>::signum(self)
1096    }
1097
1098    #[inline]
1099    fn mul_add(self, a: Self, b: Self) -> Self {
1100        <f32 as Float>::mul_add(self, a, b)
1101    }
1102
1103    #[inline]
1104    fn floor(self) -> Self {
1105        <f32 as Float>::floor(self)
1106    }
1107
1108    #[inline]
1109    fn ceil(self) -> Self {
1110        <f32 as Float>::ceil(self)
1111    }
1112
1113    #[inline]
1114    fn round(self) -> Self {
1115        <f32 as Float>::round(self)
1116    }
1117
1118    #[inline]
1119    fn trunc(self) -> Self {
1120        <f32 as Float>::trunc(self)
1121    }
1122
1123    #[inline]
1124    fn hypot(self, other: Self) -> Self {
1125        <f32 as Float>::hypot(self, other)
1126    }
1127}
1128
1129impl Field for f32 {
1130    #[inline]
1131    fn mul_conj(self, other: Self) -> Self {
1132        self * other
1133    }
1134
1135    #[inline]
1136    fn conj_mul(self, other: Self) -> Self {
1137        self * other
1138    }
1139
1140    #[inline]
1141    fn recip(self) -> Self {
1142        1.0 / self
1143    }
1144
1145    #[inline]
1146    fn powi(self, n: i32) -> Self {
1147        <f32 as Float>::powi(self, n)
1148    }
1149}
1150
1151// =============================================================================
1152// Implementations for f64
1153// =============================================================================
1154
1155impl Scalar for f64 {
1156    type Real = f64;
1157
1158    #[inline]
1159    fn abs(self) -> Self::Real {
1160        <f64 as Float>::abs(self)
1161    }
1162
1163    #[inline]
1164    fn conj(self) -> Self {
1165        self
1166    }
1167
1168    #[inline]
1169    fn is_real() -> bool {
1170        true
1171    }
1172
1173    #[inline]
1174    fn real(self) -> Self::Real {
1175        self
1176    }
1177
1178    #[inline]
1179    fn imag(self) -> Self::Real {
1180        0.0
1181    }
1182
1183    #[inline]
1184    fn from_real_imag(re: Self::Real, _im: Self::Real) -> Self {
1185        re
1186    }
1187
1188    #[inline]
1189    fn abs_sq(self) -> Self::Real {
1190        self * self
1191    }
1192
1193    #[inline]
1194    fn epsilon() -> Self::Real {
1195        f64::EPSILON
1196    }
1197
1198    #[inline]
1199    fn min_positive() -> Self::Real {
1200        f64::MIN_POSITIVE
1201    }
1202
1203    #[inline]
1204    fn max_value() -> Self::Real {
1205        f64::MAX
1206    }
1207}
1208
1209impl Real for f64 {
1210    #[inline]
1211    fn sqrt(self) -> Self {
1212        <f64 as Float>::sqrt(self)
1213    }
1214
1215    #[inline]
1216    fn ln(self) -> Self {
1217        <f64 as Float>::ln(self)
1218    }
1219
1220    #[inline]
1221    fn exp(self) -> Self {
1222        <f64 as Float>::exp(self)
1223    }
1224
1225    #[inline]
1226    fn sin(self) -> Self {
1227        <f64 as Float>::sin(self)
1228    }
1229
1230    #[inline]
1231    fn cos(self) -> Self {
1232        <f64 as Float>::cos(self)
1233    }
1234
1235    #[inline]
1236    fn atan2(self, other: Self) -> Self {
1237        <f64 as Float>::atan2(self, other)
1238    }
1239
1240    #[inline]
1241    fn powf(self, n: Self) -> Self {
1242        <f64 as Float>::powf(self, n)
1243    }
1244
1245    #[inline]
1246    fn signum(self) -> Self {
1247        <f64 as Float>::signum(self)
1248    }
1249
1250    #[inline]
1251    fn mul_add(self, a: Self, b: Self) -> Self {
1252        <f64 as Float>::mul_add(self, a, b)
1253    }
1254
1255    #[inline]
1256    fn floor(self) -> Self {
1257        <f64 as Float>::floor(self)
1258    }
1259
1260    #[inline]
1261    fn ceil(self) -> Self {
1262        <f64 as Float>::ceil(self)
1263    }
1264
1265    #[inline]
1266    fn round(self) -> Self {
1267        <f64 as Float>::round(self)
1268    }
1269
1270    #[inline]
1271    fn trunc(self) -> Self {
1272        <f64 as Float>::trunc(self)
1273    }
1274
1275    #[inline]
1276    fn hypot(self, other: Self) -> Self {
1277        <f64 as Float>::hypot(self, other)
1278    }
1279}
1280
1281impl Field for f64 {
1282    #[inline]
1283    fn mul_conj(self, other: Self) -> Self {
1284        self * other
1285    }
1286
1287    #[inline]
1288    fn conj_mul(self, other: Self) -> Self {
1289        self * other
1290    }
1291
1292    #[inline]
1293    fn recip(self) -> Self {
1294        1.0 / self
1295    }
1296
1297    #[inline]
1298    fn powi(self, n: i32) -> Self {
1299        <f64 as Float>::powi(self, n)
1300    }
1301}
1302
1303// =============================================================================
1304// Implementations for Complex32
1305// =============================================================================
1306
1307impl Scalar for Complex32 {
1308    type Real = f32;
1309
1310    #[inline]
1311    fn abs(self) -> Self::Real {
1312        self.norm()
1313    }
1314
1315    #[inline]
1316    fn conj(self) -> Self {
1317        Complex32::conj(&self)
1318    }
1319
1320    #[inline]
1321    fn is_real() -> bool {
1322        false
1323    }
1324
1325    #[inline]
1326    fn real(self) -> Self::Real {
1327        self.re
1328    }
1329
1330    #[inline]
1331    fn imag(self) -> Self::Real {
1332        self.im
1333    }
1334
1335    #[inline]
1336    fn from_real_imag(re: Self::Real, im: Self::Real) -> Self {
1337        Complex32::new(re, im)
1338    }
1339
1340    #[inline]
1341    fn abs_sq(self) -> Self::Real {
1342        self.norm_sqr()
1343    }
1344
1345    #[inline]
1346    fn epsilon() -> Self::Real {
1347        f32::EPSILON
1348    }
1349
1350    #[inline]
1351    fn min_positive() -> Self::Real {
1352        f32::MIN_POSITIVE
1353    }
1354
1355    #[inline]
1356    fn max_value() -> Self::Real {
1357        f32::MAX
1358    }
1359}
1360
1361impl ComplexScalar for Complex32 {
1362    #[inline]
1363    fn new(re: Self::Real, im: Self::Real) -> Self {
1364        Complex32::new(re, im)
1365    }
1366
1367    #[inline]
1368    fn arg(self) -> Self::Real {
1369        self.arg()
1370    }
1371
1372    #[inline]
1373    fn from_polar(r: Self::Real, theta: Self::Real) -> Self {
1374        Complex32::from_polar(r, theta)
1375    }
1376
1377    #[inline]
1378    fn cexp(self) -> Self {
1379        self.exp()
1380    }
1381
1382    #[inline]
1383    fn cln(self) -> Self {
1384        self.ln()
1385    }
1386
1387    #[inline]
1388    fn csqrt(self) -> Self {
1389        self.sqrt()
1390    }
1391}
1392
1393impl Field for Complex32 {
1394    #[inline]
1395    fn mul_conj(self, other: Self) -> Self {
1396        self * other.conj()
1397    }
1398
1399    #[inline]
1400    fn conj_mul(self, other: Self) -> Self {
1401        self.conj() * other
1402    }
1403
1404    #[inline]
1405    fn recip(self) -> Self {
1406        Complex32::new(1.0, 0.0) / self
1407    }
1408
1409    #[inline]
1410    fn powi(self, n: i32) -> Self {
1411        self.powu(n.unsigned_abs())
1412            * if n < 0 {
1413                self.recip().powu(n.unsigned_abs())
1414            } else {
1415                Complex32::new(1.0, 0.0)
1416            }
1417    }
1418}
1419
1420// =============================================================================
1421// Implementations for Complex64
1422// =============================================================================
1423
1424impl Scalar for Complex64 {
1425    type Real = f64;
1426
1427    #[inline]
1428    fn abs(self) -> Self::Real {
1429        self.norm()
1430    }
1431
1432    #[inline]
1433    fn conj(self) -> Self {
1434        Complex64::conj(&self)
1435    }
1436
1437    #[inline]
1438    fn is_real() -> bool {
1439        false
1440    }
1441
1442    #[inline]
1443    fn real(self) -> Self::Real {
1444        self.re
1445    }
1446
1447    #[inline]
1448    fn imag(self) -> Self::Real {
1449        self.im
1450    }
1451
1452    #[inline]
1453    fn from_real_imag(re: Self::Real, im: Self::Real) -> Self {
1454        Complex64::new(re, im)
1455    }
1456
1457    #[inline]
1458    fn abs_sq(self) -> Self::Real {
1459        self.norm_sqr()
1460    }
1461
1462    #[inline]
1463    fn epsilon() -> Self::Real {
1464        f64::EPSILON
1465    }
1466
1467    #[inline]
1468    fn min_positive() -> Self::Real {
1469        f64::MIN_POSITIVE
1470    }
1471
1472    #[inline]
1473    fn max_value() -> Self::Real {
1474        f64::MAX
1475    }
1476}
1477
1478impl ComplexScalar for Complex64 {
1479    #[inline]
1480    fn new(re: Self::Real, im: Self::Real) -> Self {
1481        Complex64::new(re, im)
1482    }
1483
1484    #[inline]
1485    fn arg(self) -> Self::Real {
1486        self.arg()
1487    }
1488
1489    #[inline]
1490    fn from_polar(r: Self::Real, theta: Self::Real) -> Self {
1491        Complex64::from_polar(r, theta)
1492    }
1493
1494    #[inline]
1495    fn cexp(self) -> Self {
1496        self.exp()
1497    }
1498
1499    #[inline]
1500    fn cln(self) -> Self {
1501        self.ln()
1502    }
1503
1504    #[inline]
1505    fn csqrt(self) -> Self {
1506        self.sqrt()
1507    }
1508}
1509
1510impl Field for Complex64 {
1511    #[inline]
1512    fn mul_conj(self, other: Self) -> Self {
1513        self * other.conj()
1514    }
1515
1516    #[inline]
1517    fn conj_mul(self, other: Self) -> Self {
1518        self.conj() * other
1519    }
1520
1521    #[inline]
1522    fn recip(self) -> Self {
1523        Complex64::new(1.0, 0.0) / self
1524    }
1525
1526    #[inline]
1527    fn powi(self, n: i32) -> Self {
1528        if n >= 0 {
1529            self.powu(n as u32)
1530        } else {
1531            self.recip().powu((-n) as u32)
1532        }
1533    }
1534}
1535
/// Unit tests for the scalar traits, the complex helpers, the batch kernels
/// and the compensated-summation utilities defined in this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// Basic `Scalar` queries on a positive `f32`.
    #[test]
    fn test_f32_scalar() {
        let x: f32 = 3.0;
        assert_eq!(x.abs(), 3.0);
        assert_eq!(x.conj(), 3.0);
        assert!(f32::is_real());
        assert_eq!(x.real(), 3.0);
        assert_eq!(x.imag(), 0.0);
        assert_eq!(x.abs_sq(), 9.0);
    }

    /// Basic `Scalar` queries on a negative `f64`.
    #[test]
    fn test_f64_scalar() {
        let x: f64 = -4.0;
        assert_eq!(x.abs(), 4.0);
        assert_eq!(x.conj(), -4.0);
        assert!(f64::is_real());
        assert_eq!(x.real(), -4.0);
        assert_eq!(x.imag(), 0.0);
        assert_eq!(x.abs_sq(), 16.0);
    }

    /// `Scalar` queries on the 3-4-5 triangle in single-precision complex.
    #[test]
    fn test_complex32_scalar() {
        let z = Complex32::new(3.0, 4.0);
        assert!((z.abs() - 5.0).abs() < 1e-6);
        assert_eq!(z.conj(), Complex32::new(3.0, -4.0));
        assert!(!Complex32::is_real());
        assert_eq!(z.real(), 3.0);
        assert_eq!(z.imag(), 4.0);
        assert!((z.abs_sq() - 25.0).abs() < 1e-6);
    }

    /// `Scalar` queries on the 3-4-5 triangle in double-precision complex.
    #[test]
    fn test_complex64_scalar() {
        let z = Complex64::new(3.0, 4.0);
        assert!((z.abs() - 5.0).abs() < 1e-12);
        assert_eq!(z.conj(), Complex64::new(3.0, -4.0));
        assert!(!Complex64::is_real());
        assert_eq!(z.real(), 3.0);
        assert_eq!(z.imag(), 4.0);
        assert!((z.abs_sq() - 25.0).abs() < 1e-12);
    }

    /// `Field` operations on real numbers.
    #[test]
    fn test_field_operations() {
        let a: f64 = 2.0;
        let b: f64 = 3.0;
        assert_eq!(a.mul_conj(b), 6.0);
        assert_eq!(a.conj_mul(b), 6.0);
        assert!((a.recip() - 0.5).abs() < 1e-12);
        assert!((a.powi(3) - 8.0).abs() < 1e-12);
    }

    /// Conjugated products on complex numbers.
    #[test]
    fn test_complex_field_operations() {
        let a = Complex64::new(1.0, 2.0);
        let b = Complex64::new(3.0, 4.0);

        // mul_conj: a * conj(b) = (1+2i) * (3-4i) = 3 - 4i + 6i - 8i^2 = 3 + 2i + 8 = 11 + 2i
        let mc = a.mul_conj(b);
        assert!((mc.re - 11.0).abs() < 1e-12);
        assert!((mc.im - 2.0).abs() < 1e-12);

        // conj_mul: conj(a) * b = (1-2i) * (3+4i) = 3 + 4i - 6i - 8i^2 = 3 - 2i + 8 = 11 - 2i
        let cm = a.conj_mul(b);
        assert!((cm.re - 11.0).abs() < 1e-12);
        assert!((cm.im - (-2.0)).abs() < 1e-12);
    }

    /// Half-precision addition (only with the `f16` feature).
    #[test]
    #[cfg(feature = "f16")]
    fn test_f16_scalar() {
        use half::f16;
        let x = f16::from_f32(3.0);
        let y = f16::from_f32(4.0);
        let result = x + y;
        assert!((result.to_f32() - 7.0).abs() < 0.01);
    }

    /// Double-double arithmetic (only with the `f128` feature).
    #[test]
    #[cfg(feature = "f128")]
    fn test_f128_scalar() {
        use crate::scalar::Real as ScalarReal;

        // Test basic arithmetic
        let x = QuadFloat::from(3.0);
        let y = QuadFloat::from(4.0);
        let result = x + y;
        assert!(Scalar::abs(result - QuadFloat::from(7.0)) < QuadFloat::from(1e-28));

        // Test sqrt operation
        let a = QuadFloat::from(2.0);
        let epsilon = QuadFloat::from(1e-28);
        let sqrt_a = ScalarReal::sqrt(a);
        assert!(Scalar::abs(sqrt_a * sqrt_a - a) < epsilon);

        // Test other operations
        let b = QuadFloat::from(5.0);
        let c = QuadFloat::from(3.0);
        assert!(Scalar::abs(b / c - QuadFloat::from(5.0 / 3.0)) < QuadFloat::from(1e-15));
    }

    // Complex ergonomics tests

    /// Free-function constructors and the imaginary-unit constant.
    #[test]
    fn test_complex_constructors() {
        // c64 and c32 constructors
        let z64 = c64(3.0, 4.0);
        assert_eq!(z64.re, 3.0);
        assert_eq!(z64.im, 4.0);

        let z32 = c32(1.0, 2.0);
        assert_eq!(z32.re, 1.0);
        assert_eq!(z32.im, 2.0);

        // imag() and real() constructors
        let i = imag(5.0);
        assert_eq!(i.re, 0.0);
        assert_eq!(i.im, 5.0);

        let r = real(3.0);
        assert_eq!(r.re, 3.0);
        assert_eq!(r.im, 0.0);

        // I64 constant
        assert_eq!(I64.re, 0.0);
        assert_eq!(I64.im, 1.0);
    }

    /// Polar construction at angles pi/2 and 0.
    #[test]
    fn test_complex_polar() {
        use std::f64::consts::PI;

        let z = from_polar(1.0, PI / 2.0);
        assert!((z.re - 0.0).abs() < 1e-10);
        assert!((z.im - 1.0).abs() < 1e-10);

        let z2 = from_polar(2.0, 0.0);
        assert!((z2.re - 2.0).abs() < 1e-10);
        assert!((z2.im - 0.0).abs() < 1e-10);
    }

    /// Extension methods on complex numbers.
    #[test]
    fn test_complex_ext() {
        use std::f64::consts::PI;

        let z = c64(3.0, 4.0);

        // normalize
        let n = z.normalize();
        assert!((n.norm() - 1.0).abs() < 1e-10);

        // is_purely_real / is_purely_imaginary
        assert!(c64(3.0, 0.0).is_purely_real(1e-10));
        assert!(!c64(3.0, 1.0).is_purely_real(1e-10));
        assert!(c64(0.0, 4.0).is_purely_imaginary(1e-10));
        assert!(!c64(1.0, 4.0).is_purely_imaginary(1e-10));

        // rotate
        let r = c64(1.0, 0.0).rotate(PI / 2.0);
        assert!((r.re - 0.0).abs() < 1e-10);
        assert!((r.im - 1.0).abs() < 1e-10);

        // distance
        let a = c64(0.0, 0.0);
        let b = c64(3.0, 4.0);
        assert!((a.distance(b) - 5.0).abs() < 1e-10);

        // reflect_imag
        let reflected = c64(1.0, 2.0).reflect_imag();
        assert_eq!(reflected.re, -1.0);
        assert_eq!(reflected.im, 2.0);
    }

    /// Real-to-complex conversions.
    #[test]
    fn test_to_complex() {
        let x: f64 = 3.0;
        let z = x.to_complex();
        assert_eq!(z.re, 3.0);
        assert_eq!(z.im, 0.0);

        let z2 = (2.0f64).with_imag(5.0);
        assert_eq!(z2.re, 2.0);
        assert_eq!(z2.im, 5.0);

        // f32 version
        let x32: f32 = 4.0;
        let z32 = x32.to_complex();
        assert_eq!(z32.re, 4.0);
        assert_eq!(z32.im, 0.0);
    }

    // Scalar specialization tests

    /// SIMD width constants and the `use_simd_for` threshold.
    #[test]
    fn test_simd_compatible() {
        // Test SIMD width constants: on every configured target an f32 vector
        // holds twice as many lanes as an f64 vector, and each complex type
        // holds half as many elements as its real counterpart.
        let w_f32 = <f32 as SimdCompatible>::SIMD_WIDTH;
        let w_f64 = <f64 as SimdCompatible>::SIMD_WIDTH;
        assert_eq!(w_f32, 2 * w_f64);
        assert_eq!(<Complex32 as SimdCompatible>::SIMD_WIDTH * 2, w_f32);
        assert_eq!(<Complex64 as SimdCompatible>::SIMD_WIDTH * 2, w_f64);

        // Test use_simd_for
        assert!(!f32::use_simd_for(4));
        assert!(f32::use_simd_for(32));
    }

    /// Batch kernels on `f64` slices.
    #[test]
    fn test_scalar_batch_f64() {
        let x = [1.0f64, 2.0, 3.0, 4.0];
        let y = [5.0f64, 6.0, 7.0, 8.0];

        // dot_batch
        let dot = f64::dot_batch(&x, &y);
        assert!((dot - 70.0).abs() < 1e-10); // 1*5 + 2*6 + 3*7 + 4*8 = 70

        // sum_batch
        let sum = f64::sum_batch(&x);
        assert!((sum - 10.0).abs() < 1e-10);

        // asum_batch
        let x_neg = [-1.0f64, 2.0, -3.0, 4.0];
        let asum = f64::asum_batch(&x_neg);
        assert!((asum - 10.0).abs() < 1e-10);

        // iamax_batch
        let x_mixed = [1.0f64, -5.0, 3.0, 2.0];
        let iamax = f64::iamax_batch(&x_mixed);
        assert_eq!(iamax, 1); // index of -5.0

        // scale_batch
        let mut x_scale = [1.0f64, 2.0, 3.0];
        f64::scale_batch(2.0, &mut x_scale);
        assert!((x_scale[0] - 2.0).abs() < 1e-10);
        assert!((x_scale[1] - 4.0).abs() < 1e-10);
        assert!((x_scale[2] - 6.0).abs() < 1e-10);

        // axpy_batch
        let x_axpy = [1.0f64, 2.0, 3.0];
        let mut y_axpy = [1.0f64, 1.0, 1.0];
        f64::axpy_batch(2.0, &x_axpy, &mut y_axpy);
        assert!((y_axpy[0] - 3.0).abs() < 1e-10); // 2*1 + 1
        assert!((y_axpy[1] - 5.0).abs() < 1e-10); // 2*2 + 1
        assert!((y_axpy[2] - 7.0).abs() < 1e-10); // 2*3 + 1

        // fma_batch
        let a = [1.0f64, 2.0, 3.0];
        let b = [2.0f64, 3.0, 4.0];
        let c = [1.0f64, 1.0, 1.0];
        let mut out = [0.0f64; 3];
        f64::fma_batch(&a, &b, &c, &mut out);
        assert!((out[0] - 3.0).abs() < 1e-10); // 1*2 + 1
        assert!((out[1] - 7.0).abs() < 1e-10); // 2*3 + 1
        assert!((out[2] - 13.0).abs() < 1e-10); // 3*4 + 1
    }

    /// Batch kernels on `Complex64` slices.
    #[test]
    fn test_scalar_batch_complex64() {
        let x = [c64(1.0, 1.0), c64(2.0, 2.0)];
        let y = [c64(1.0, -1.0), c64(2.0, -2.0)];

        // dot_batch: (1+i)*(1-i) + (2+2i)*(2-2i) = 2 + 8 = 10
        let dot = Complex64::dot_batch(&x, &y);
        assert!((dot.re - 10.0).abs() < 1e-10);
        assert!(dot.im.abs() < 1e-10);

        // sum_batch
        let sum = Complex64::sum_batch(&x);
        assert!((sum.re - 3.0).abs() < 1e-10);
        assert!((sum.im - 3.0).abs() < 1e-10);

        // asum_batch (BLAS-style: sum of |re| + |im|)
        let asum = Complex64::asum_batch(&x);
        assert!((asum - 6.0).abs() < 1e-10); // (1+1) + (2+2)

        // iamax_batch
        let iamax = Complex64::iamax_batch(&x);
        assert_eq!(iamax, 1); // index of (2+2i) has larger |re|+|im|
    }

    /// Classification constants for the four built-in scalar types.
    #[test]
    fn test_scalar_classify() {
        assert_eq!(f32::CLASS, ScalarClass::RealF32);
        assert_eq!(f64::CLASS, ScalarClass::RealF64);
        assert_eq!(Complex32::CLASS, ScalarClass::ComplexF32);
        assert_eq!(Complex64::CLASS, ScalarClass::ComplexF64);

        assert_eq!(f32::PRECISION_LEVEL, 2);
        assert_eq!(f64::PRECISION_LEVEL, 3);

        assert_eq!(f32::STORAGE_BYTES, 4);
        assert_eq!(f64::STORAGE_BYTES, 8);
        assert_eq!(Complex64::STORAGE_BYTES, 16);
    }

    /// Unroll-factor hints.
    #[test]
    fn test_unroll_hints() {
        assert_eq!(f32::UNROLL_FACTOR, 8);
        assert_eq!(f64::UNROLL_FACTOR, 4);
        assert_eq!(Complex64::UNROLL_FACTOR, 2);
    }

    /// Round trips through the wider accumulator types.
    #[test]
    fn test_extended_precision() {
        // f32 -> f64 accumulation
        let x: f32 = 1.5;
        let acc: f64 = x.to_accumulator();
        assert!((acc - 1.5).abs() < 1e-10);

        let back: f32 = f32::from_accumulator(acc);
        assert!((back - 1.5).abs() < 1e-6);

        // Complex32 -> Complex64
        let z = c32(1.0, 2.0);
        let z_acc: Complex64 = z.to_accumulator();
        assert!((z_acc.re - 1.0).abs() < 1e-10);
        assert!((z_acc.im - 2.0).abs() < 1e-10);
    }

    /// Kahan summation of 1000 copies of 0.1.
    #[test]
    fn test_kahan_sum() {
        let mut kahan = KahanSum::<f64>::new();
        for _ in 0..1000 {
            kahan.add(0.1);
        }
        // Should be close to 100.0
        let result = kahan.sum();
        assert!((result - 100.0).abs() < 1e-10);
    }

    /// Pairwise summation, including the empty and small-input cases.
    #[test]
    fn test_pairwise_sum() {
        let values: Vec<f64> = (0..1000).map(|_| 0.1).collect();
        let result = pairwise_sum(&values);
        assert!((result - 100.0).abs() < 1e-10);

        // Empty case
        let empty: Vec<f64> = vec![];
        assert_eq!(pairwise_sum(&empty), 0.0);

        // Small case
        let small = [1.0, 2.0, 3.0];
        assert!(Scalar::abs(pairwise_sum(&small) - 6.0) < 1e-10);
    }

    /// KBK summation of 10000 copies of 0.1.
    #[test]
    fn test_kbk_sum() {
        let mut kbk = KBKSum::<f64>::new();
        for _ in 0..10000 {
            kbk.add(0.1);
        }
        let result = kbk.sum();
        // KBK should give very accurate results
        assert!((result - 1000.0).abs() < 1e-8);
    }
}
1892
1893// =============================================================================
1894// Half-precision (f16) support
1895// =============================================================================
1896
#[cfg(feature = "f16")]
/// `Scalar` for half-precision floats; the type is its own real type.
impl Scalar for f16 {
    type Real = f16;

    /// Absolute value, obtained by clearing the sign bit.
    ///
    /// A comparison against zero would leave `-0.0` (and a negatively signed
    /// NaN) unchanged, because neither compares as less than zero.
    #[inline]
    fn abs(self) -> Self::Real {
        // Bit 15 is the sign bit of the 16-bit encoding.
        f16::from_bits(self.to_bits() & 0x7FFF)
    }

    /// Conjugation is the identity for a real type.
    #[inline]
    fn conj(self) -> Self {
        self
    }

    /// Always `true`.
    #[inline]
    fn is_real() -> bool {
        true
    }

    /// The value itself.
    #[inline]
    fn real(self) -> Self::Real {
        self
    }

    /// Always zero.
    #[inline]
    fn imag(self) -> Self::Real {
        f16::ZERO
    }

    /// Builds the value from its real part; the imaginary argument is ignored.
    #[inline]
    fn from_real_imag(re: Self::Real, _im: Self::Real) -> Self {
        re
    }

    /// Square of the value, `self * self`.
    #[inline]
    fn abs_sq(self) -> Self::Real {
        self * self
    }

    /// Returns `f16::EPSILON`.
    #[inline]
    fn epsilon() -> Self::Real {
        f16::EPSILON
    }

    /// Returns `f16::MIN_POSITIVE`.
    #[inline]
    fn min_positive() -> Self::Real {
        f16::MIN_POSITIVE
    }

    /// Returns `f16::MAX`.
    #[inline]
    fn max_value() -> Self::Real {
        f16::MAX
    }
}
1951
#[cfg(feature = "f16")]
/// `Real` for half-precision floats.
///
/// Except for `signum`, every operation widens to `f32`, applies the `f32`
/// operation and narrows the result back with `f16::from_f32`.
impl Real for f16 {
    /// Square root, evaluated in `f32`.
    #[inline]
    fn sqrt(self) -> Self {
        f16::from_f32(self.to_f32().sqrt())
    }

    /// Natural logarithm, evaluated in `f32`.
    #[inline]
    fn ln(self) -> Self {
        f16::from_f32(self.to_f32().ln())
    }

    /// Exponential, evaluated in `f32`.
    #[inline]
    fn exp(self) -> Self {
        f16::from_f32(self.to_f32().exp())
    }

    /// Sine, evaluated in `f32`.
    #[inline]
    fn sin(self) -> Self {
        f16::from_f32(self.to_f32().sin())
    }

    /// Cosine, evaluated in `f32`.
    #[inline]
    fn cos(self) -> Self {
        f16::from_f32(self.to_f32().cos())
    }

    /// Two-argument arctangent with `self` as the first argument, evaluated
    /// in `f32`.
    #[inline]
    fn atan2(self, other: Self) -> Self {
        f16::from_f32(self.to_f32().atan2(other.to_f32()))
    }

    /// `self` raised to the power `n`, evaluated in `f32`.
    #[inline]
    fn powf(self, n: Self) -> Self {
        f16::from_f32(self.to_f32().powf(n.to_f32()))
    }

    /// Sign of the value: `1` for positive, `-1` for negative, `0` for zero.
    ///
    /// NaN is returned unchanged; without the explicit check both
    /// comparisons fail for NaN and it would be reported as zero.
    #[inline]
    fn signum(self) -> Self {
        if self.is_nan() {
            self
        } else if self > f16::ZERO {
            f16::ONE
        } else if self < f16::ZERO {
            -f16::ONE
        } else {
            f16::ZERO
        }
    }

    /// `self * a + b`, evaluated with the `f32` `mul_add`.
    #[inline]
    fn mul_add(self, a: Self, b: Self) -> Self {
        f16::from_f32(self.to_f32().mul_add(a.to_f32(), b.to_f32()))
    }

    /// Floor, evaluated in `f32`.
    #[inline]
    fn floor(self) -> Self {
        f16::from_f32(self.to_f32().floor())
    }

    /// Ceiling, evaluated in `f32`.
    #[inline]
    fn ceil(self) -> Self {
        f16::from_f32(self.to_f32().ceil())
    }

    /// Rounding to the nearest integer, evaluated in `f32`.
    #[inline]
    fn round(self) -> Self {
        f16::from_f32(self.to_f32().round())
    }

    /// Truncation toward zero, evaluated in `f32`.
    #[inline]
    fn trunc(self) -> Self {
        f16::from_f32(self.to_f32().trunc())
    }

    /// `sqrt(self^2 + other^2)` via the `f32` `hypot`.
    #[inline]
    fn hypot(self, other: Self) -> Self {
        f16::from_f32(self.to_f32().hypot(other.to_f32()))
    }
}
2030
#[cfg(feature = "f16")]
/// `Field` for half-precision floats. Conjugation is the identity on real
/// numbers, so both conjugated products are plain multiplication.
impl Field for f16 {
    /// Reciprocal, `1 / self`.
    #[inline]
    fn recip(self) -> Self {
        f16::ONE / self
    }

    /// Integer power, computed in `f32` and narrowed back to `f16`.
    #[inline]
    fn powi(self, n: i32) -> Self {
        let widened = self.to_f32();
        f16::from_f32(widened.powi(n))
    }

    /// `self * conj(other)`, which is `self * other` for a real type.
    #[inline]
    fn mul_conj(self, other: Self) -> Self {
        self * other
    }

    /// `conj(self) * other`, which is `self * other` for a real type.
    #[inline]
    fn conj_mul(self, other: Self) -> Self {
        self * other
    }
}
2053
2054// =============================================================================
2055// Quad-precision (double-double) support using QuadFloat wrapper
2056// =============================================================================
2057
#[cfg(feature = "f128")]
/// `Scalar` for the double-double `QuadFloat` wrapper; the type is its own
/// real type.
impl Scalar for QuadFloat {
    type Real = QuadFloat;

    /// Always `true`.
    #[inline]
    fn is_real() -> bool {
        true
    }

    /// Builds the value from its real part; the imaginary argument is ignored.
    #[inline]
    fn from_real_imag(re: Self::Real, _im: Self::Real) -> Self {
        re
    }

    /// The value itself.
    #[inline]
    fn real(self) -> Self::Real {
        self
    }

    /// Always zero.
    #[inline]
    fn imag(self) -> Self::Real {
        QuadFloat::from(0.0)
    }

    /// Conjugation is the identity for a real type.
    #[inline]
    fn conj(self) -> Self {
        self
    }

    /// Absolute value, delegated to the wrapped value.
    #[inline]
    fn abs(self) -> Self::Real {
        let inner = self.0;
        QuadFloat(inner.abs())
    }

    /// Square of the value, `self * self`.
    #[inline]
    fn abs_sq(self) -> Self::Real {
        self * self
    }

    /// Machine epsilon estimate: `f64::EPSILON` squared, i.e. 2^-104.
    #[inline]
    fn epsilon() -> Self::Real {
        let eps = QuadFloat::from(f64::EPSILON);
        eps * eps
    }

    /// `f64::MIN_POSITIVE` converted to `QuadFloat`.
    #[inline]
    fn min_positive() -> Self::Real {
        QuadFloat::from(f64::MIN_POSITIVE)
    }

    /// `f64::MAX` converted to `QuadFloat`.
    #[inline]
    fn max_value() -> Self::Real {
        QuadFloat::from(f64::MAX)
    }
}
2113
/// Largest integral value not greater than `x`, taking both components of
/// the double-double pair into account.
///
/// When the high component is not an integer, the low component is too small
/// to reach the neighbouring integers, so flooring the high component is
/// enough. When the high component is an integer, the low component decides
/// the result.
#[cfg(feature = "f128")]
#[inline]
fn quad_floor(x: QuadFloat) -> QuadFloat {
    let (hi, lo) = (x.0.hi(), x.0.lo());
    let hi_floor = hi.floor();
    if hi_floor == hi && lo != 0.0 {
        QuadFloat::from(hi) + QuadFloat::from(lo.floor())
    } else {
        // Also the path for plain f64 values, infinities and NaN.
        QuadFloat::from(hi_floor)
    }
}

/// Smallest integral value not less than `x`; mirror image of `quad_floor`.
#[cfg(feature = "f128")]
#[inline]
fn quad_ceil(x: QuadFloat) -> QuadFloat {
    let (hi, lo) = (x.0.hi(), x.0.lo());
    let hi_ceil = hi.ceil();
    if hi_ceil == hi && lo != 0.0 {
        QuadFloat::from(hi) + QuadFloat::from(lo.ceil())
    } else {
        QuadFloat::from(hi_ceil)
    }
}

#[cfg(feature = "f128")]
/// `Real` for the double-double `QuadFloat` wrapper.
///
/// Elementary functions delegate to the wrapped value. The rounding
/// functions use both components, so a value such as `3 - 1e-20` (high part
/// `3.0`, small negative low part) floors to `2`, not `3`.
impl Real for QuadFloat {
    /// Square root of the wrapped value.
    #[inline]
    fn sqrt(self) -> Self {
        QuadFloat(self.0.sqrt())
    }

    /// Natural logarithm of the wrapped value.
    #[inline]
    fn ln(self) -> Self {
        QuadFloat(self.0.ln())
    }

    /// Exponential of the wrapped value.
    #[inline]
    fn exp(self) -> Self {
        QuadFloat(self.0.exp())
    }

    /// Sine of the wrapped value.
    #[inline]
    fn sin(self) -> Self {
        QuadFloat(self.0.sin())
    }

    /// Cosine of the wrapped value.
    #[inline]
    fn cos(self) -> Self {
        QuadFloat(self.0.cos())
    }

    /// Two-argument arctangent with `self` as the first argument.
    #[inline]
    fn atan2(self, other: Self) -> Self {
        QuadFloat(self.0.atan2(other.0))
    }

    /// `self` raised to the power `n`.
    #[inline]
    fn powf(self, n: Self) -> Self {
        QuadFloat(self.0.powf(n.0))
    }

    /// Sign of the value: `1` for positive, `-1` for negative, `0` for zero.
    ///
    /// NaN is returned unchanged; without the explicit check both
    /// comparisons fail for NaN and it would be reported as zero.
    #[inline]
    fn signum(self) -> Self {
        let zero = QuadFloat::from(0.0);
        let one = QuadFloat::from(1.0);
        if self.0.hi().is_nan() {
            self
        } else if self > zero {
            one
        } else if self < zero {
            -one
        } else {
            zero
        }
    }

    /// `self * a + b`, as a separate multiply and add.
    #[inline]
    fn mul_add(self, a: Self, b: Self) -> Self {
        // TwoFloat doesn't have mul_add, so implement manually
        self * a + b
    }

    /// Largest integral value not greater than `self`.
    #[inline]
    fn floor(self) -> Self {
        quad_floor(self)
    }

    /// Smallest integral value not less than `self`.
    #[inline]
    fn ceil(self) -> Self {
        quad_ceil(self)
    }

    /// Nearest integral value, with ties rounded away from zero.
    #[inline]
    fn round(self) -> Self {
        let (hi, lo) = (self.0.hi(), self.0.lo());
        if lo == 0.0 {
            // Exactly an f64 value: the f64 rounding is already correct and
            // keeps infinities intact.
            return QuadFloat::from(hi.round());
        }
        let half = QuadFloat::from(0.5);
        if hi < 0.0 {
            quad_ceil(self - half)
        } else {
            quad_floor(self + half)
        }
    }

    /// Integral part, rounded toward zero.
    #[inline]
    fn trunc(self) -> Self {
        // The sign of the value is the sign of its high component.
        if self.0.hi() < 0.0 {
            quad_ceil(self)
        } else {
            quad_floor(self)
        }
    }

    /// `sqrt(self^2 + other^2)`, computed directly.
    #[inline]
    fn hypot(self, other: Self) -> Self {
        Float::sqrt(self * self + other * other)
    }
}
2195
#[cfg(feature = "f128")]
/// `Field` for the double-double `QuadFloat` wrapper. Conjugation is the
/// identity on real numbers, so both conjugated products are plain
/// multiplication.
impl Field for QuadFloat {
    /// Reciprocal, delegated to the wrapped value.
    #[inline]
    fn recip(self) -> Self {
        let inner = self.0;
        QuadFloat(inner.recip())
    }

    /// Integer power, delegated to the wrapped value.
    #[inline]
    fn powi(self, n: i32) -> Self {
        let inner = self.0;
        QuadFloat(inner.powi(n))
    }

    /// `self * conj(other)`, which is `self * other` for a real type.
    #[inline]
    fn mul_conj(self, other: Self) -> Self {
        self * other
    }

    /// `conj(self) * other`, which is `self * other` for a real type.
    #[inline]
    fn conj_mul(self, other: Self) -> Self {
        self * other
    }
}
2218
2219// =============================================================================
2220// Scalar trait specialization for performance
2221// =============================================================================
2222
2223/// Marker trait for types with hardware FMA (fused multiply-add) support.
2224///
2225/// Types implementing this trait have efficient hardware FMA instructions,
2226/// enabling optimized implementations of algorithms like dot products and
2227/// matrix multiplications.
2228pub trait HasFastFma: Scalar {}
2229
2230impl HasFastFma for f32 {}
2231impl HasFastFma for f64 {}
2232impl HasFastFma for Complex32 {}
2233impl HasFastFma for Complex64 {}
2234
2235/// Marker trait for types that can be efficiently vectorized with SIMD.
2236///
2237/// This trait indicates that the type has a natural mapping to SIMD registers
2238/// and operations.
2239pub trait SimdCompatible: Scalar {
2240    /// The preferred SIMD width (number of elements) for this type.
2241    const SIMD_WIDTH: usize;
2242
2243    /// Returns true if SIMD operations are beneficial for the given length.
2244    #[inline]
2245    fn use_simd_for(len: usize) -> bool {
2246        len >= Self::SIMD_WIDTH * 2
2247    }
2248}
2249
2250impl SimdCompatible for f32 {
2251    #[cfg(target_arch = "x86_64")]
2252    const SIMD_WIDTH: usize = 8; // AVX2: 256-bit / 32-bit = 8
2253
2254    #[cfg(target_arch = "aarch64")]
2255    const SIMD_WIDTH: usize = 4; // NEON: 128-bit / 32-bit = 4
2256
2257    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
2258    const SIMD_WIDTH: usize = 4;
2259}
2260
2261impl SimdCompatible for f64 {
2262    #[cfg(target_arch = "x86_64")]
2263    const SIMD_WIDTH: usize = 4; // AVX2: 256-bit / 64-bit = 4
2264
2265    #[cfg(target_arch = "aarch64")]
2266    const SIMD_WIDTH: usize = 2; // NEON: 128-bit / 64-bit = 2
2267
2268    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
2269    const SIMD_WIDTH: usize = 2;
2270}
2271
2272impl SimdCompatible for Complex32 {
2273    // Complex types have half the SIMD width due to doubled storage
2274    #[cfg(target_arch = "x86_64")]
2275    const SIMD_WIDTH: usize = 4;
2276
2277    #[cfg(target_arch = "aarch64")]
2278    const SIMD_WIDTH: usize = 2;
2279
2280    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
2281    const SIMD_WIDTH: usize = 2;
2282}
2283
2284impl SimdCompatible for Complex64 {
2285    #[cfg(target_arch = "x86_64")]
2286    const SIMD_WIDTH: usize = 2;
2287
2288    #[cfg(target_arch = "aarch64")]
2289    const SIMD_WIDTH: usize = 1;
2290
2291    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
2292    const SIMD_WIDTH: usize = 1;
2293}
2294
/// Batch operations on scalar arrays for performance-critical code.
///
/// This trait collects common operations on contiguous slices of scalars so
/// that each scalar type can supply its own implementation.
///
/// Several methods take more than one slice. Those slices are expected to
/// have equal lengths; the implementations that accompany this trait check
/// this with `debug_assert_eq!`.
pub trait ScalarBatch: Scalar + SimdCompatible {
    /// Computes the dot product of two slices.
    ///
    /// `x` and `y` are expected to have the same length.
    fn dot_batch(x: &[Self], y: &[Self]) -> Self;

    /// Computes the sum of all elements.
    fn sum_batch(x: &[Self]) -> Self;

    /// Computes the sum of absolute values (L1 norm).
    ///
    /// The `Complex32` implementation uses the BLAS convention and adds
    /// `|re| + |im|` per element instead of the complex modulus.
    fn asum_batch(x: &[Self]) -> Self::Real;

    /// Finds the index of the element with maximum absolute value.
    ///
    /// The accompanying implementations return `0` for an empty slice.
    fn iamax_batch(x: &[Self]) -> usize;

    /// Scales a vector: x = alpha * x
    fn scale_batch(alpha: Self, x: &mut [Self]);

    /// AXPY operation: y = alpha * x + y
    ///
    /// `x` and `y` are expected to have the same length.
    fn axpy_batch(alpha: Self, x: &[Self], y: &mut [Self]);

    /// Fused multiply-add on arrays: `out[i] = a[i] * b[i] + c[i]`
    ///
    /// All four slices are expected to have the same length.
    fn fma_batch(a: &[Self], b: &[Self], c: &[Self], out: &mut [Self]);
}
2324
2325impl ScalarBatch for f32 {
2326    #[inline]
2327    fn dot_batch(x: &[Self], y: &[Self]) -> Self {
2328        debug_assert_eq!(x.len(), y.len());
2329        let mut sum = 0.0f32;
2330        for i in 0..x.len() {
2331            sum = x[i].mul_add(y[i], sum);
2332        }
2333        sum
2334    }
2335
2336    #[inline]
2337    fn sum_batch(x: &[Self]) -> Self {
2338        x.iter().copied().sum()
2339    }
2340
2341    #[inline]
2342    fn asum_batch(x: &[Self]) -> Self::Real {
2343        x.iter().map(|&v| v.abs()).sum()
2344    }
2345
2346    #[inline]
2347    fn iamax_batch(x: &[Self]) -> usize {
2348        x.iter()
2349            .enumerate()
2350            .max_by(|(_, a), (_, b)| a.abs().partial_cmp(&b.abs()).unwrap())
2351            .map(|(i, _)| i)
2352            .unwrap_or(0)
2353    }
2354
2355    #[inline]
2356    fn scale_batch(alpha: Self, x: &mut [Self]) {
2357        for xi in x.iter_mut() {
2358            *xi *= alpha;
2359        }
2360    }
2361
2362    #[inline]
2363    fn axpy_batch(alpha: Self, x: &[Self], y: &mut [Self]) {
2364        debug_assert_eq!(x.len(), y.len());
2365        for i in 0..x.len() {
2366            y[i] = alpha.mul_add(x[i], y[i]);
2367        }
2368    }
2369
2370    #[inline]
2371    fn fma_batch(a: &[Self], b: &[Self], c: &[Self], out: &mut [Self]) {
2372        debug_assert_eq!(a.len(), b.len());
2373        debug_assert_eq!(a.len(), c.len());
2374        debug_assert_eq!(a.len(), out.len());
2375        for i in 0..a.len() {
2376            out[i] = a[i].mul_add(b[i], c[i]);
2377        }
2378    }
2379}
2380
2381impl ScalarBatch for f64 {
2382    #[inline]
2383    fn dot_batch(x: &[Self], y: &[Self]) -> Self {
2384        debug_assert_eq!(x.len(), y.len());
2385        let mut sum = 0.0f64;
2386        for i in 0..x.len() {
2387            sum = x[i].mul_add(y[i], sum);
2388        }
2389        sum
2390    }
2391
2392    #[inline]
2393    fn sum_batch(x: &[Self]) -> Self {
2394        x.iter().copied().sum()
2395    }
2396
2397    #[inline]
2398    fn asum_batch(x: &[Self]) -> Self::Real {
2399        x.iter().map(|&v| v.abs()).sum()
2400    }
2401
2402    #[inline]
2403    fn iamax_batch(x: &[Self]) -> usize {
2404        x.iter()
2405            .enumerate()
2406            .max_by(|(_, a), (_, b)| a.abs().partial_cmp(&b.abs()).unwrap())
2407            .map(|(i, _)| i)
2408            .unwrap_or(0)
2409    }
2410
2411    #[inline]
2412    fn scale_batch(alpha: Self, x: &mut [Self]) {
2413        for xi in x.iter_mut() {
2414            *xi *= alpha;
2415        }
2416    }
2417
2418    #[inline]
2419    fn axpy_batch(alpha: Self, x: &[Self], y: &mut [Self]) {
2420        debug_assert_eq!(x.len(), y.len());
2421        for i in 0..x.len() {
2422            y[i] = alpha.mul_add(x[i], y[i]);
2423        }
2424    }
2425
2426    #[inline]
2427    fn fma_batch(a: &[Self], b: &[Self], c: &[Self], out: &mut [Self]) {
2428        debug_assert_eq!(a.len(), b.len());
2429        debug_assert_eq!(a.len(), c.len());
2430        debug_assert_eq!(a.len(), out.len());
2431        for i in 0..a.len() {
2432            out[i] = a[i].mul_add(b[i], c[i]);
2433        }
2434    }
2435}
2436
2437impl ScalarBatch for Complex32 {
2438    #[inline]
2439    fn dot_batch(x: &[Self], y: &[Self]) -> Self {
2440        debug_assert_eq!(x.len(), y.len());
2441        let mut sum = Complex32::new(0.0, 0.0);
2442        for i in 0..x.len() {
2443            sum += x[i] * y[i];
2444        }
2445        sum
2446    }
2447
2448    #[inline]
2449    fn sum_batch(x: &[Self]) -> Self {
2450        x.iter().copied().sum()
2451    }
2452
2453    #[inline]
2454    fn asum_batch(x: &[Self]) -> Self::Real {
2455        x.iter().map(|z| z.re.abs() + z.im.abs()).sum()
2456    }
2457
2458    #[inline]
2459    fn iamax_batch(x: &[Self]) -> usize {
2460        x.iter()
2461            .enumerate()
2462            .max_by(|(_, a), (_, b)| {
2463                (a.re.abs() + a.im.abs())
2464                    .partial_cmp(&(b.re.abs() + b.im.abs()))
2465                    .unwrap()
2466            })
2467            .map(|(i, _)| i)
2468            .unwrap_or(0)
2469    }
2470
2471    #[inline]
2472    fn scale_batch(alpha: Self, x: &mut [Self]) {
2473        for xi in x.iter_mut() {
2474            *xi *= alpha;
2475        }
2476    }
2477
2478    #[inline]
2479    fn axpy_batch(alpha: Self, x: &[Self], y: &mut [Self]) {
2480        debug_assert_eq!(x.len(), y.len());
2481        for i in 0..x.len() {
2482            y[i] += alpha * x[i];
2483        }
2484    }
2485
2486    #[inline]
2487    fn fma_batch(a: &[Self], b: &[Self], c: &[Self], out: &mut [Self]) {
2488        debug_assert_eq!(a.len(), b.len());
2489        debug_assert_eq!(a.len(), c.len());
2490        debug_assert_eq!(a.len(), out.len());
2491        for i in 0..a.len() {
2492            out[i] = a[i] * b[i] + c[i];
2493        }
2494    }
2495}
2496
2497impl ScalarBatch for Complex64 {
2498    #[inline]
2499    fn dot_batch(x: &[Self], y: &[Self]) -> Self {
2500        debug_assert_eq!(x.len(), y.len());
2501        let mut sum = Complex64::new(0.0, 0.0);
2502        for i in 0..x.len() {
2503            sum += x[i] * y[i];
2504        }
2505        sum
2506    }
2507
2508    #[inline]
2509    fn sum_batch(x: &[Self]) -> Self {
2510        x.iter().copied().sum()
2511    }
2512
2513    #[inline]
2514    fn asum_batch(x: &[Self]) -> Self::Real {
2515        x.iter().map(|z| z.re.abs() + z.im.abs()).sum()
2516    }
2517
2518    #[inline]
2519    fn iamax_batch(x: &[Self]) -> usize {
2520        x.iter()
2521            .enumerate()
2522            .max_by(|(_, a), (_, b)| {
2523                (a.re.abs() + a.im.abs())
2524                    .partial_cmp(&(b.re.abs() + b.im.abs()))
2525                    .unwrap()
2526            })
2527            .map(|(i, _)| i)
2528            .unwrap_or(0)
2529    }
2530
2531    #[inline]
2532    fn scale_batch(alpha: Self, x: &mut [Self]) {
2533        for xi in x.iter_mut() {
2534            *xi *= alpha;
2535        }
2536    }
2537
2538    #[inline]
2539    fn axpy_batch(alpha: Self, x: &[Self], y: &mut [Self]) {
2540        debug_assert_eq!(x.len(), y.len());
2541        for i in 0..x.len() {
2542            y[i] += alpha * x[i];
2543        }
2544    }
2545
2546    #[inline]
2547    fn fma_batch(a: &[Self], b: &[Self], c: &[Self], out: &mut [Self]) {
2548        debug_assert_eq!(a.len(), b.len());
2549        debug_assert_eq!(a.len(), c.len());
2550        debug_assert_eq!(a.len(), out.len());
2551        for i in 0..a.len() {
2552            out[i] = a[i] * b[i] + c[i];
2553        }
2554    }
2555}
2556
/// Type-level scalar classification for compile-time dispatch.
///
/// This enum enables algorithms to specialize at compile time based on
/// the scalar type's properties. A scalar type advertises its class through
/// the `CLASS` associated constant of [`ScalarClassify`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarClass {
    /// Single-precision real (`f32`).
    RealF32,
    /// Double-precision real (`f64`).
    RealF64,
    /// Single-precision complex (`Complex32`).
    ComplexF32,
    /// Double-precision complex (`Complex64`).
    ComplexF64,
    /// Half-precision real (`f16`); assigned only when the `f16` feature is enabled.
    RealF16,
    /// Quad-precision real (`QuadFloat`); assigned only when the `f128` feature is enabled.
    RealF128,
    /// Unknown/other type.
    Other,
}
2578
/// Trait for compile-time scalar classification.
///
/// Everything here is an associated constant, so the values are available in
/// `const` contexts and cost nothing at run time.
pub trait ScalarClassify: Scalar {
    /// The compile-time class of this scalar type.
    const CLASS: ScalarClass;

    /// Precision level of this scalar type (1 = lowest, 4 = highest).
    const PRECISION_LEVEL: u8;

    /// Storage size in bytes; defaults to `size_of::<Self>()`.
    const STORAGE_BYTES: usize = core::mem::size_of::<Self>();
}
2590
/// Classifies `f32` as [`ScalarClass::RealF32`] with precision level 2.
impl ScalarClassify for f32 {
    const CLASS: ScalarClass = ScalarClass::RealF32;
    const PRECISION_LEVEL: u8 = 2;
}
2595
/// Classifies `f64` as [`ScalarClass::RealF64`] with precision level 3.
impl ScalarClassify for f64 {
    const CLASS: ScalarClass = ScalarClass::RealF64;
    const PRECISION_LEVEL: u8 = 3;
}
2600
/// Classifies `Complex32` as [`ScalarClass::ComplexF32`]; it shares precision
/// level 2 with `f32`.
impl ScalarClassify for Complex32 {
    const CLASS: ScalarClass = ScalarClass::ComplexF32;
    const PRECISION_LEVEL: u8 = 2;
}
2605
/// Classifies `Complex64` as [`ScalarClass::ComplexF64`]; it shares precision
/// level 3 with `f64`.
impl ScalarClassify for Complex64 {
    const CLASS: ScalarClass = ScalarClass::ComplexF64;
    const PRECISION_LEVEL: u8 = 3;
}
2610
/// Classifies `f16` as [`ScalarClass::RealF16`] with precision level 1, the
/// lowest on the scale. Compiled only with the `f16` feature.
#[cfg(feature = "f16")]
impl ScalarClassify for f16 {
    const CLASS: ScalarClass = ScalarClass::RealF16;
    const PRECISION_LEVEL: u8 = 1;
}
2616
/// Classifies `QuadFloat` as [`ScalarClass::RealF128`] with precision level 4,
/// the highest on the scale. Compiled only with the `f128` feature.
#[cfg(feature = "f128")]
impl ScalarClassify for QuadFloat {
    const CLASS: ScalarClass = ScalarClass::RealF128;
    const PRECISION_LEVEL: u8 = 4;
}
2622
/// Unrolling hints for vectorized loops.
///
/// These per-type constants are tuning parameters for code that unrolls or
/// blocks loops over a given scalar type.
pub trait UnrollHints: Scalar {
    /// Recommended unroll factor for tight loops.
    const UNROLL_FACTOR: usize;

    /// Recommended chunk size for blocked algorithms.
    // NOTE(review): presumably counted in elements rather than bytes — confirm against users.
    const BLOCK_SIZE: usize;

    /// Whether to prefer streaming stores (for large writes).
    const PREFER_STREAMING: bool;
}
2637
/// Loop-tuning hints for `f32`: unroll factor 8, block size 64, streaming
/// stores preferred.
impl UnrollHints for f32 {
    const UNROLL_FACTOR: usize = 8;
    const BLOCK_SIZE: usize = 64;
    const PREFER_STREAMING: bool = true;
}
2643
/// Loop-tuning hints for `f64`: unroll factor 4, block size 32, streaming
/// stores preferred.
impl UnrollHints for f64 {
    const UNROLL_FACTOR: usize = 4;
    const BLOCK_SIZE: usize = 32;
    const PREFER_STREAMING: bool = true;
}
2649
/// Loop-tuning hints for `Complex32`: unroll factor 4, block size 32,
/// streaming stores preferred (the same values as `f64`).
impl UnrollHints for Complex32 {
    const UNROLL_FACTOR: usize = 4;
    const BLOCK_SIZE: usize = 32;
    const PREFER_STREAMING: bool = true;
}
2655
/// Loop-tuning hints for `Complex64`: unroll factor 2, block size 16,
/// streaming stores preferred.
impl UnrollHints for Complex64 {
    const UNROLL_FACTOR: usize = 2;
    const BLOCK_SIZE: usize = 16;
    const PREFER_STREAMING: bool = true;
}
2661
/// Extended precision accumulation support.
///
/// For algorithms requiring higher precision during intermediate calculations,
/// this trait provides access to an extended precision accumulator type.
/// In the implementations in this file, the single-precision types widen to
/// their double-precision counterparts, while `f64` and `Complex64` use
/// themselves as the accumulator.
pub trait ExtendedPrecision: Scalar {
    /// The type used for extended precision accumulation.
    type Accumulator: Scalar;

    /// Converts a value to the accumulator type.
    fn to_accumulator(self) -> Self::Accumulator;

    /// Converts from the accumulator type back to this type.
    ///
    /// This may round when the accumulator is wider than `Self`.
    fn from_accumulator(acc: Self::Accumulator) -> Self;
}
2676
2677impl ExtendedPrecision for f32 {
2678    type Accumulator = f64;
2679
2680    #[inline]
2681    fn to_accumulator(self) -> f64 {
2682        self as f64
2683    }
2684
2685    #[inline]
2686    fn from_accumulator(acc: f64) -> f32 {
2687        acc as f32
2688    }
2689}
2690
2691impl ExtendedPrecision for f64 {
2692    // For f64, we use the same type (or could use f128 if available)
2693    type Accumulator = f64;
2694
2695    #[inline]
2696    fn to_accumulator(self) -> f64 {
2697        self
2698    }
2699
2700    #[inline]
2701    fn from_accumulator(acc: f64) -> f64 {
2702        acc
2703    }
2704}
2705
2706impl ExtendedPrecision for Complex32 {
2707    type Accumulator = Complex64;
2708
2709    #[inline]
2710    fn to_accumulator(self) -> Complex64 {
2711        Complex64::new(self.re as f64, self.im as f64)
2712    }
2713
2714    #[inline]
2715    fn from_accumulator(acc: Complex64) -> Complex32 {
2716        Complex32::new(acc.re as f32, acc.im as f32)
2717    }
2718}
2719
2720impl ExtendedPrecision for Complex64 {
2721    type Accumulator = Complex64;
2722
2723    #[inline]
2724    fn to_accumulator(self) -> Complex64 {
2725        self
2726    }
2727
2728    #[inline]
2729    fn from_accumulator(acc: Complex64) -> Complex64 {
2730        acc
2731    }
2732}
2733
/// Kahan summation for improved accuracy.
///
/// Uses compensated summation to reduce floating-point errors: alongside the
/// running total it tracks the rounding error of the most recent addition and
/// subtracts it from the next value added.
#[derive(Debug, Clone, Copy)]
pub struct KahanSum<T: Scalar> {
    // Running total.
    sum: T,
    // Rounding error of the most recent addition, as computed by `add`.
    compensation: T,
}
2742
2743impl<T: Scalar> Default for KahanSum<T> {
2744    fn default() -> Self {
2745        Self::new()
2746    }
2747}
2748
2749impl<T: Scalar> KahanSum<T> {
2750    /// Creates a new Kahan sum accumulator initialized to zero.
2751    #[inline]
2752    pub fn new() -> Self {
2753        Self {
2754            sum: T::zero(),
2755            compensation: T::zero(),
2756        }
2757    }
2758
2759    /// Adds a value to the sum with compensation.
2760    #[inline]
2761    pub fn add(&mut self, value: T) {
2762        let y = value - self.compensation;
2763        let t = self.sum + y;
2764        self.compensation = (t - self.sum) - y;
2765        self.sum = t;
2766    }
2767
2768    /// Returns the current sum.
2769    #[inline]
2770    pub fn sum(self) -> T {
2771        self.sum
2772    }
2773}
2774
2775/// Pairwise summation for reduced error accumulation.
2776///
2777/// Recursively splits the array and sums pairs, reducing error from O(n) to O(log n).
2778#[inline]
2779pub fn pairwise_sum<T: Scalar>(values: &[T]) -> T {
2780    const THRESHOLD: usize = 32;
2781
2782    if values.is_empty() {
2783        return T::zero();
2784    }
2785    if values.len() <= THRESHOLD {
2786        return values.iter().copied().fold(T::zero(), |acc, x| acc + x);
2787    }
2788
2789    let mid = values.len() / 2;
2790    pairwise_sum(&values[..mid]) + pairwise_sum(&values[mid..])
2791}
2792
/// Kahan-Babuska-Klein summation (improved compensated summation).
///
/// Provides even better error bounds than standard Kahan summation by keeping
/// two correction terms in addition to the running total.
#[derive(Debug, Clone, Copy)]
pub struct KBKSum<T: Scalar> {
    // Running (uncompensated) total.
    sum: T,
    // First-order correction: accumulated rounding errors of the updates to `sum`.
    cs: T,
    // Second-order correction: accumulated rounding errors of the updates to `cs`.
    ccs: T,
}
2802
2803impl<T: Scalar> Default for KBKSum<T> {
2804    fn default() -> Self {
2805        Self::new()
2806    }
2807}
2808
2809impl<T: Scalar> KBKSum<T> {
2810    /// Creates a new KBK sum accumulator.
2811    #[inline]
2812    pub fn new() -> Self {
2813        Self {
2814            sum: T::zero(),
2815            cs: T::zero(),
2816            ccs: T::zero(),
2817        }
2818    }
2819
2820    /// Adds a value with double compensation.
2821    #[inline]
2822    pub fn add(&mut self, value: T) {
2823        let t = self.sum + value;
2824        let c = if Scalar::abs(self.sum) >= Scalar::abs(value) {
2825            (self.sum - t) + value
2826        } else {
2827            (value - t) + self.sum
2828        };
2829        self.sum = t;
2830
2831        let t2 = self.cs + c;
2832        let cc = if Scalar::abs(self.cs) >= Scalar::abs(c) {
2833            (self.cs - t2) + c
2834        } else {
2835            (c - t2) + self.cs
2836        };
2837        self.cs = t2;
2838        self.ccs += cc;
2839    }
2840
2841    /// Returns the compensated sum.
2842    #[inline]
2843    pub fn sum(self) -> T {
2844        self.sum + self.cs + self.ccs
2845    }
2846}