Skip to main content

numra_core/
scalar.rs

//! Scalar type trait for numerical computation.
//!
//! This module defines the [`Scalar`] trait, which captures all operations needed
//! by numerical algorithms without depending on external numeric-trait crates;
//! the elementary and special math routines are delegated to the `libm` crate.
//!
//! # Supported Types
//!
//! - `f64` - 64-bit IEEE 754 floating point (recommended for most use)
//! - `f32` - 32-bit IEEE 754 floating point (for memory-constrained applications)
//!
//! # Design Notes
//!
//! The trait is designed to be:
//! - Self-contained (no external trait dependencies like num-traits or nalgebra)
//! - Comprehensive (all operations needed by numerical algorithms)
//! - Zero-cost (every method is a small `#[inline]` wrapper around the underlying routine)

18use core::fmt::{Debug, Display};
19use core::iter::Sum;
20use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
21
/// A real scalar type suitable for numerical computation.
///
/// This trait provides all mathematical operations needed by numerical methods,
/// including basic arithmetic, trigonometric functions, and special functions.
///
/// Note that on a concrete type such as `f64`, method-call syntax (`x.signum()`,
/// `x.sqrt()`, ...) prefers the type's *inherent* method when one exists with the
/// same name. Use `Scalar::method(x)` to force the trait's version; generic code
/// (`S: Scalar`) always uses the trait.
///
/// # Example
///
/// ```rust
/// use numra_core::Scalar;
///
/// fn quadratic_formula<S: Scalar>(a: S, b: S, c: S) -> Option<(S, S)> {
///     let discriminant = b * b - S::from_f64(4.0) * a * c;
///     if discriminant < S::ZERO {
///         return None;
///     }
///     let sqrt_d = discriminant.sqrt();
///     let two_a = S::TWO * a;
///     Some(((-b + sqrt_d) / two_a, (-b - sqrt_d) / two_a))
/// }
/// ```
pub trait Scalar:
    Copy
    + Clone
    + Debug
    + Display
    + PartialOrd
    + Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + Div<Output = Self>
    + Neg<Output = Self>
    + AddAssign
    + SubAssign
    + MulAssign
    + DivAssign
    + Sum
    + Send
    + Sync
    + 'static
{
    // ===== Constants =====

    /// Additive identity: 0
    const ZERO: Self;
    /// Multiplicative identity: 1
    const ONE: Self;
    /// Two (commonly needed constant)
    const TWO: Self;
    /// One half
    const HALF: Self;
    /// Machine epsilon: the difference between 1.0 and the next larger
    /// representable value (2^-52 for `f64`, 2^-23 for `f32`).
    const EPSILON: Self;
    /// Positive infinity
    const INFINITY: Self;
    /// Negative infinity
    const NEG_INFINITY: Self;
    /// Not a number
    const NAN: Self;
    /// Pi (π)
    const PI: Self;
    /// Euler's number (e)
    const E: Self;
    /// Square root of 2
    const SQRT_2: Self;
    /// Natural log of 2
    const LN_2: Self;

    // ===== Conversion =====

    /// Create from f64 (may round when `Self` is narrower than f64).
    fn from_f64(x: f64) -> Self;

    /// Create from f32
    fn from_f32(x: f32) -> Self;

    /// Create from i32
    fn from_i32(x: i32) -> Self;

    /// Create from usize (large values are rounded to the nearest representable value).
    fn from_usize(n: usize) -> Self;

    /// Convert to f64
    fn to_f64(self) -> f64;

    /// Convert to f32 (may round when `Self` is wider than f32).
    fn to_f32(self) -> f32;

    // ===== Basic Operations =====

    /// Absolute value
    fn abs(self) -> Self;

    /// Square root
    fn sqrt(self) -> Self;

    /// Cube root
    fn cbrt(self) -> Self;

    /// Square (x²), computed as `self * self`.
    #[inline]
    fn sq(self) -> Self {
        self * self
    }

    /// Integer power
    fn powi(self, n: i32) -> Self;

    /// Floating point power
    fn powf(self, n: Self) -> Self;

    /// Reciprocal (1/x)
    #[inline]
    fn recip(self) -> Self {
        Self::ONE / self
    }

    /// Hypotenuse: sqrt(x² + y²) computed without undue intermediate overflow
    /// or underflow.
    fn hypot(self, other: Self) -> Self;

    // ===== Trigonometric Functions =====

    /// Sine
    fn sin(self) -> Self;

    /// Cosine
    fn cos(self) -> Self;

    /// Tangent
    fn tan(self) -> Self;

    /// Arcsine
    fn asin(self) -> Self;

    /// Arccosine
    fn acos(self) -> Self;

    /// Arctangent
    fn atan(self) -> Self;

    /// Two-argument arctangent: `atan2(self, other)` is the angle of the point
    /// `(other, self)`, i.e. `self` is the y coordinate.
    fn atan2(self, other: Self) -> Self;

    /// Simultaneous sine and cosine, returned as `(sin, cos)`.
    ///
    /// The default calls [`Scalar::sin`] and [`Scalar::cos`] separately;
    /// implementations may override it with a fused routine.
    #[inline]
    fn sincos(self) -> (Self, Self) {
        (self.sin(), self.cos())
    }

    // ===== Exponential and Logarithmic =====

    /// Natural exponential (e^x)
    fn exp(self) -> Self;

    /// Base-2 exponential (2^x)
    fn exp2(self) -> Self;

    /// exp(x) - 1, accurate for small x
    fn exp_m1(self) -> Self;

    /// Natural logarithm
    fn ln(self) -> Self;

    /// Base-2 logarithm
    fn log2(self) -> Self;

    /// Base-10 logarithm
    fn log10(self) -> Self;

    /// ln(1 + x), accurate for small x
    fn ln_1p(self) -> Self;

    // ===== Hyperbolic Functions =====

    /// Hyperbolic sine
    fn sinh(self) -> Self;

    /// Hyperbolic cosine
    fn cosh(self) -> Self;

    /// Hyperbolic tangent
    fn tanh(self) -> Self;

    /// Inverse hyperbolic sine
    fn asinh(self) -> Self;

    /// Inverse hyperbolic cosine
    fn acosh(self) -> Self;

    /// Inverse hyperbolic tangent
    fn atanh(self) -> Self;

    // ===== Comparison and Ordering =====

    /// Maximum of two values
    fn max(self, other: Self) -> Self;

    /// Minimum of two values
    fn min(self, other: Self) -> Self;

    /// Clamp value to the closed range `[min, max]`; callers should ensure `min <= max`.
    fn clamp(self, min: Self, max: Self) -> Self;

    /// Magnitude of `self` with the sign of `sign`.
    fn copysign(self, sign: Self) -> Self;

    // ===== Predicates =====

    /// Is the value finite (not NaN or infinity)?
    fn is_finite(self) -> bool;

    /// Is the value NaN?
    fn is_nan(self) -> bool;

    /// Is the value infinite?
    fn is_infinite(self) -> bool;

    /// Is the sign positive (including +0)?
    fn is_sign_positive(self) -> bool;

    /// Is the sign negative (including -0)?
    fn is_sign_negative(self) -> bool;

    // ===== Rounding =====

    /// Round toward negative infinity
    fn floor(self) -> Self;

    /// Round toward positive infinity
    fn ceil(self) -> Self;

    /// Round to nearest integer
    fn round(self) -> Self;

    /// Round toward zero
    fn trunc(self) -> Self;

    /// Fractional part (`self - self.trunc()`, so it carries the sign of `self`).
    fn fract(self) -> Self;

    // ===== Special Functions =====

    /// Gamma function Γ(x)
    fn gamma_fn(self) -> Self;

    /// Natural log of the absolute value of the gamma function, ln|Γ(x)|.
    fn ln_gamma(self) -> Self;

    /// Error function erf(x)
    fn erf_fn(self) -> Self;

    /// Complementary error function erfc(x) = 1 - erf(x)
    fn erfc_fn(self) -> Self;

    // ===== Utility Methods =====

    /// Fused multiply-add: (self * a) + b with single rounding
    fn mul_add(self, a: Self, b: Self) -> Self;

    /// Sign function.
    ///
    /// Returns 1 for positive values, -1 for negative values, and 0 for both
    /// +0 and -0. A NaN input is returned unchanged so that it propagates
    /// instead of being silently turned into 0.
    ///
    /// This differs from the inherent `f64::signum`/`f32::signum`, which return
    /// ±1 for signed zeros.
    #[inline]
    fn signum(self) -> Self {
        if self > Self::ZERO {
            Self::ONE
        } else if self < Self::ZERO {
            -Self::ONE
        } else if self == Self::ZERO {
            Self::ZERO
        } else {
            // Only NaN is unordered with respect to zero: propagate it.
            self
        }
    }
}
292
293// ============================================================================
294// Implementation for f64
295// ============================================================================
296
297impl Scalar for f64 {
298    const ZERO: Self = 0.0;
299    const ONE: Self = 1.0;
300    const TWO: Self = 2.0;
301    const HALF: Self = 0.5;
302    const EPSILON: Self = f64::EPSILON;
303    const INFINITY: Self = f64::INFINITY;
304    const NEG_INFINITY: Self = f64::NEG_INFINITY;
305    const NAN: Self = f64::NAN;
306    const PI: Self = core::f64::consts::PI;
307    const E: Self = core::f64::consts::E;
308    const SQRT_2: Self = core::f64::consts::SQRT_2;
309    const LN_2: Self = core::f64::consts::LN_2;
310
311    #[inline]
312    fn from_f64(x: f64) -> Self {
313        x
314    }
315    #[inline]
316    fn from_f32(x: f32) -> Self {
317        x as f64
318    }
319    #[inline]
320    fn from_i32(x: i32) -> Self {
321        x as f64
322    }
323    #[inline]
324    fn from_usize(n: usize) -> Self {
325        n as f64
326    }
327    #[inline]
328    fn to_f64(self) -> f64 {
329        self
330    }
331    #[inline]
332    fn to_f32(self) -> f32 {
333        self as f32
334    }
335
336    #[inline]
337    fn abs(self) -> Self {
338        libm::fabs(self)
339    }
340    #[inline]
341    fn sqrt(self) -> Self {
342        libm::sqrt(self)
343    }
344    #[inline]
345    fn cbrt(self) -> Self {
346        libm::cbrt(self)
347    }
348    #[inline]
349    fn powi(self, n: i32) -> Self {
350        libm::pow(self, n as f64)
351    }
352    #[inline]
353    fn powf(self, n: Self) -> Self {
354        libm::pow(self, n)
355    }
356    #[inline]
357    fn hypot(self, other: Self) -> Self {
358        libm::hypot(self, other)
359    }
360
361    #[inline]
362    fn sin(self) -> Self {
363        libm::sin(self)
364    }
365    #[inline]
366    fn cos(self) -> Self {
367        libm::cos(self)
368    }
369    #[inline]
370    fn tan(self) -> Self {
371        libm::tan(self)
372    }
373    #[inline]
374    fn asin(self) -> Self {
375        libm::asin(self)
376    }
377    #[inline]
378    fn acos(self) -> Self {
379        libm::acos(self)
380    }
381    #[inline]
382    fn atan(self) -> Self {
383        libm::atan(self)
384    }
385    #[inline]
386    fn atan2(self, other: Self) -> Self {
387        libm::atan2(self, other)
388    }
389    #[inline]
390    fn sincos(self) -> (Self, Self) {
391        libm::sincos(self)
392    }
393
394    #[inline]
395    fn exp(self) -> Self {
396        libm::exp(self)
397    }
398    #[inline]
399    fn exp2(self) -> Self {
400        libm::exp2(self)
401    }
402    #[inline]
403    fn exp_m1(self) -> Self {
404        libm::expm1(self)
405    }
406    #[inline]
407    fn ln(self) -> Self {
408        libm::log(self)
409    }
410    #[inline]
411    fn log2(self) -> Self {
412        libm::log2(self)
413    }
414    #[inline]
415    fn log10(self) -> Self {
416        libm::log10(self)
417    }
418    #[inline]
419    fn ln_1p(self) -> Self {
420        libm::log1p(self)
421    }
422
423    #[inline]
424    fn sinh(self) -> Self {
425        libm::sinh(self)
426    }
427    #[inline]
428    fn cosh(self) -> Self {
429        libm::cosh(self)
430    }
431    #[inline]
432    fn tanh(self) -> Self {
433        libm::tanh(self)
434    }
435    #[inline]
436    fn asinh(self) -> Self {
437        libm::asinh(self)
438    }
439    #[inline]
440    fn acosh(self) -> Self {
441        libm::acosh(self)
442    }
443    #[inline]
444    fn atanh(self) -> Self {
445        libm::atanh(self)
446    }
447
448    #[inline]
449    fn max(self, other: Self) -> Self {
450        libm::fmax(self, other)
451    }
452    #[inline]
453    fn min(self, other: Self) -> Self {
454        libm::fmin(self, other)
455    }
456    #[inline]
457    fn clamp(self, min: Self, max: Self) -> Self {
458        libm::fmax(min, libm::fmin(self, max))
459    }
460    #[inline]
461    fn copysign(self, sign: Self) -> Self {
462        libm::copysign(self, sign)
463    }
464
465    #[inline]
466    fn is_finite(self) -> bool {
467        self.is_finite()
468    }
469    #[inline]
470    fn is_nan(self) -> bool {
471        self.is_nan()
472    }
473    #[inline]
474    fn is_infinite(self) -> bool {
475        self.is_infinite()
476    }
477    #[inline]
478    fn is_sign_positive(self) -> bool {
479        self.is_sign_positive()
480    }
481    #[inline]
482    fn is_sign_negative(self) -> bool {
483        self.is_sign_negative()
484    }
485
486    #[inline]
487    fn floor(self) -> Self {
488        libm::floor(self)
489    }
490    #[inline]
491    fn ceil(self) -> Self {
492        libm::ceil(self)
493    }
494    #[inline]
495    fn round(self) -> Self {
496        libm::round(self)
497    }
498    #[inline]
499    fn trunc(self) -> Self {
500        libm::trunc(self)
501    }
502    #[inline]
503    fn fract(self) -> Self {
504        self - libm::trunc(self)
505    }
506
507    #[inline]
508    fn gamma_fn(self) -> Self {
509        libm::tgamma(self)
510    }
511    #[inline]
512    fn ln_gamma(self) -> Self {
513        libm::lgamma(self)
514    }
515    #[inline]
516    fn erf_fn(self) -> Self {
517        libm::erf(self)
518    }
519    #[inline]
520    fn erfc_fn(self) -> Self {
521        libm::erfc(self)
522    }
523
524    #[inline]
525    fn mul_add(self, a: Self, b: Self) -> Self {
526        libm::fma(self, a, b)
527    }
528}
529
530// ============================================================================
531// Implementation for f32
532// ============================================================================
533
534impl Scalar for f32 {
535    const ZERO: Self = 0.0;
536    const ONE: Self = 1.0;
537    const TWO: Self = 2.0;
538    const HALF: Self = 0.5;
539    const EPSILON: Self = f32::EPSILON;
540    const INFINITY: Self = f32::INFINITY;
541    const NEG_INFINITY: Self = f32::NEG_INFINITY;
542    const NAN: Self = f32::NAN;
543    const PI: Self = core::f32::consts::PI;
544    const E: Self = core::f32::consts::E;
545    const SQRT_2: Self = core::f32::consts::SQRT_2;
546    const LN_2: Self = core::f32::consts::LN_2;
547
548    #[inline]
549    fn from_f64(x: f64) -> Self {
550        x as f32
551    }
552    #[inline]
553    fn from_f32(x: f32) -> Self {
554        x
555    }
556    #[inline]
557    fn from_i32(x: i32) -> Self {
558        x as f32
559    }
560    #[inline]
561    fn from_usize(n: usize) -> Self {
562        n as f32
563    }
564    #[inline]
565    fn to_f64(self) -> f64 {
566        self as f64
567    }
568    #[inline]
569    fn to_f32(self) -> f32 {
570        self
571    }
572
573    #[inline]
574    fn abs(self) -> Self {
575        libm::fabsf(self)
576    }
577    #[inline]
578    fn sqrt(self) -> Self {
579        libm::sqrtf(self)
580    }
581    #[inline]
582    fn cbrt(self) -> Self {
583        libm::cbrtf(self)
584    }
585    #[inline]
586    fn powi(self, n: i32) -> Self {
587        libm::powf(self, n as f32)
588    }
589    #[inline]
590    fn powf(self, n: Self) -> Self {
591        libm::powf(self, n)
592    }
593    #[inline]
594    fn hypot(self, other: Self) -> Self {
595        libm::hypotf(self, other)
596    }
597
598    #[inline]
599    fn sin(self) -> Self {
600        libm::sinf(self)
601    }
602    #[inline]
603    fn cos(self) -> Self {
604        libm::cosf(self)
605    }
606    #[inline]
607    fn tan(self) -> Self {
608        libm::tanf(self)
609    }
610    #[inline]
611    fn asin(self) -> Self {
612        libm::asinf(self)
613    }
614    #[inline]
615    fn acos(self) -> Self {
616        libm::acosf(self)
617    }
618    #[inline]
619    fn atan(self) -> Self {
620        libm::atanf(self)
621    }
622    #[inline]
623    fn atan2(self, other: Self) -> Self {
624        libm::atan2f(self, other)
625    }
626    #[inline]
627    fn sincos(self) -> (Self, Self) {
628        libm::sincosf(self)
629    }
630
631    #[inline]
632    fn exp(self) -> Self {
633        libm::expf(self)
634    }
635    #[inline]
636    fn exp2(self) -> Self {
637        libm::exp2f(self)
638    }
639    #[inline]
640    fn exp_m1(self) -> Self {
641        libm::expm1f(self)
642    }
643    #[inline]
644    fn ln(self) -> Self {
645        libm::logf(self)
646    }
647    #[inline]
648    fn log2(self) -> Self {
649        libm::log2f(self)
650    }
651    #[inline]
652    fn log10(self) -> Self {
653        libm::log10f(self)
654    }
655    #[inline]
656    fn ln_1p(self) -> Self {
657        libm::log1pf(self)
658    }
659
660    #[inline]
661    fn sinh(self) -> Self {
662        libm::sinhf(self)
663    }
664    #[inline]
665    fn cosh(self) -> Self {
666        libm::coshf(self)
667    }
668    #[inline]
669    fn tanh(self) -> Self {
670        libm::tanhf(self)
671    }
672    #[inline]
673    fn asinh(self) -> Self {
674        libm::asinhf(self)
675    }
676    #[inline]
677    fn acosh(self) -> Self {
678        libm::acoshf(self)
679    }
680    #[inline]
681    fn atanh(self) -> Self {
682        libm::atanhf(self)
683    }
684
685    #[inline]
686    fn max(self, other: Self) -> Self {
687        libm::fmaxf(self, other)
688    }
689    #[inline]
690    fn min(self, other: Self) -> Self {
691        libm::fminf(self, other)
692    }
693    #[inline]
694    fn clamp(self, min: Self, max: Self) -> Self {
695        libm::fmaxf(min, libm::fminf(self, max))
696    }
697    #[inline]
698    fn copysign(self, sign: Self) -> Self {
699        libm::copysignf(self, sign)
700    }
701
702    #[inline]
703    fn is_finite(self) -> bool {
704        self.is_finite()
705    }
706    #[inline]
707    fn is_nan(self) -> bool {
708        self.is_nan()
709    }
710    #[inline]
711    fn is_infinite(self) -> bool {
712        self.is_infinite()
713    }
714    #[inline]
715    fn is_sign_positive(self) -> bool {
716        self.is_sign_positive()
717    }
718    #[inline]
719    fn is_sign_negative(self) -> bool {
720        self.is_sign_negative()
721    }
722
723    #[inline]
724    fn floor(self) -> Self {
725        libm::floorf(self)
726    }
727    #[inline]
728    fn ceil(self) -> Self {
729        libm::ceilf(self)
730    }
731    #[inline]
732    fn round(self) -> Self {
733        libm::roundf(self)
734    }
735    #[inline]
736    fn trunc(self) -> Self {
737        libm::truncf(self)
738    }
739    #[inline]
740    fn fract(self) -> Self {
741        self - libm::truncf(self)
742    }
743
744    #[inline]
745    fn gamma_fn(self) -> Self {
746        libm::tgammaf(self)
747    }
748    #[inline]
749    fn ln_gamma(self) -> Self {
750        libm::lgammaf(self)
751    }
752    #[inline]
753    fn erf_fn(self) -> Self {
754        libm::erff(self)
755    }
756    #[inline]
757    fn erfc_fn(self) -> Self {
758        libm::erfcf(self)
759    }
760
761    #[inline]
762    fn mul_add(self, a: Self, b: Self) -> Self {
763        libm::fmaf(self, a, b)
764    }
765}
766
767// ============================================================================
768// Batch conversion utilities
769// ============================================================================
770
771/// Convert a slice of any Scalar type to a `Vec<f64>`.
772pub fn to_f64_vec<S: Scalar>(v: &[S]) -> Vec<f64> {
773    v.iter().map(|x| x.to_f64()).collect()
774}
775
776/// Convert a slice of f64 to a Vec of any Scalar type.
777pub fn from_f64_vec<S: Scalar>(v: &[f64]) -> Vec<S> {
778    v.iter().map(|&x| S::from_f64(x)).collect()
779}
780
#[cfg(test)]
mod tests {
    use super::*;

    // Note: on concrete float types, method-call syntax and `f64::NAME` paths pick
    // the inherent std item when one exists. The tests therefore call through
    // `Scalar::…` / `<T as Scalar>::…` wherever a name collides, so that the trait
    // implementation (not std) is what gets exercised.

    #[test]
    // Asserting on compile-time constants is the purpose of this test.
    #[allow(clippy::assertions_on_constants)]
    fn test_constants_f64() {
        assert_eq!(<f64 as Scalar>::ZERO, 0.0);
        assert_eq!(<f64 as Scalar>::ONE, 1.0);
        assert_eq!(<f64 as Scalar>::TWO, 2.0);
        assert_eq!(<f64 as Scalar>::HALF, 0.5);
        assert_eq!(<f64 as Scalar>::EPSILON, f64::EPSILON);
        assert!(<f64 as Scalar>::EPSILON > 0.0);
        assert!(<f64 as Scalar>::EPSILON < 1e-10);
    }

    #[test]
    fn test_constants_f32() {
        assert_eq!(<f32 as Scalar>::ZERO, 0.0);
        assert_eq!(<f32 as Scalar>::ONE, 1.0);
        assert_eq!(<f32 as Scalar>::EPSILON, f32::EPSILON);
    }

    #[test]
    fn test_basic_ops_f64() {
        let x: f64 = 4.0;
        // `x.sqrt()` would call the inherent std method; go through the trait.
        assert!((Scalar::sqrt(x) - 2.0).abs() < 1e-10);
        assert!((x.sq() - 16.0).abs() < 1e-10);
        assert_eq!(Scalar::powi(2.0_f64, 10), 1024.0);
        assert_eq!(Scalar::fract(2.5_f64), 0.5);
        assert_eq!(Scalar::fract(-2.5_f64), -0.5);
    }

    #[test]
    fn test_sign_and_clamp_f64() {
        assert_eq!(Scalar::signum(7.0_f64), 1.0);
        assert_eq!(Scalar::signum(-3.0_f64), -1.0);
        assert_eq!(Scalar::signum(0.0_f64), 0.0);
        assert_eq!(Scalar::clamp(5.0_f64, 0.0, 1.0), 1.0);
        assert_eq!(Scalar::clamp(-5.0_f64, 0.0, 1.0), 0.0);
        assert_eq!(Scalar::clamp(0.5_f64, 0.0, 1.0), 0.5);
    }

    #[test]
    fn test_trig_f64() {
        let x: f64 = <f64 as Scalar>::PI / 4.0;
        let (s, c) = x.sincos();
        assert!((s - c).abs() < 1e-10); // sin(π/4) = cos(π/4)
        assert!((s - Scalar::sin(x)).abs() < 1e-12);
    }

    #[test]
    fn test_trig_f32() {
        let (s, c) = Scalar::sincos(0.0_f32);
        assert_eq!(s, 0.0);
        assert_eq!(c, 1.0);
    }

    #[test]
    fn test_special_functions_f64() {
        // Gamma(5) = 4! = 24
        let g: f64 = 5.0_f64.gamma_fn();
        assert!((g - 24.0).abs() < 1e-10);

        // ln Gamma(5) = ln 24 (called via the trait to avoid std's unstable `ln_gamma`)
        assert!((Scalar::ln_gamma(5.0_f64) - 24.0_f64.ln()).abs() < 1e-12);

        // erf(0) = 0, erfc(0) = 1
        let e: f64 = 0.0_f64.erf_fn();
        assert!(e.abs() < 1e-10);
        assert!((Scalar::erfc_fn(0.0_f64) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_to_f64_vec_identity() {
        let v = vec![1.0_f64, 2.5, 3.7];
        let converted = to_f64_vec(&v);
        assert_eq!(v, converted);
    }

    #[test]
    fn test_from_f64_vec_identity() {
        let v = vec![1.0, 2.5, 3.7];
        let converted: Vec<f64> = from_f64_vec(&v);
        assert_eq!(v, converted);
    }

    #[test]
    fn test_f32_roundtrip() {
        let orig = vec![1.0_f32, 2.5, 3.7];
        let f64_vec = to_f64_vec(&orig);
        let back: Vec<f32> = from_f64_vec(&f64_vec);
        // f32 -> f64 is lossless, so the round trip must be exact.
        assert_eq!(orig, back);
    }
}