Skip to main content

irithyll_core/
math.rs

1//! Platform-agnostic f64 math operations.
2//!
//! In both `std` and `no_std` builds, every function here delegates to `libm`
//! (pure Rust software implementations), so behavior is identical across targets.
5//!
6//! This module exists because f64 inherent methods (.sqrt(), .exp(), etc.)
7//! are not available in no_std on MSRV 1.75.
8
// NOTE: We always use libm functions regardless of the std feature. This avoids
// conditional-compilation complexity and ensures identical behavior on every
// target. The trade-off: libm routines are not guaranteed to compile to the same
// native instructions as the inherent f64 methods, so they may be slower on std.
12
13/// Absolute value.
14#[inline]
15pub fn abs(x: f64) -> f64 {
16    libm::fabs(x)
17}
18
19/// Square root.
20#[inline]
21pub fn sqrt(x: f64) -> f64 {
22    libm::sqrt(x)
23}
24
25/// Natural exponential (e^x).
26#[inline]
27pub fn exp(x: f64) -> f64 {
28    libm::exp(x)
29}
30
31/// Natural logarithm (ln).
32#[inline]
33pub fn ln(x: f64) -> f64 {
34    libm::log(x)
35}
36
37/// Base-2 logarithm.
38#[inline]
39pub fn log2(x: f64) -> f64 {
40    libm::log2(x)
41}
42
43/// Base-10 logarithm.
44#[inline]
45pub fn log10(x: f64) -> f64 {
46    libm::log10(x)
47}
48
49/// Power: x^n (floating point exponent).
50#[inline]
51pub fn powf(x: f64, n: f64) -> f64 {
52    libm::pow(x, n)
53}
54
55/// Power: x^n (integer exponent).
56#[inline]
57pub fn powi(x: f64, n: i32) -> f64 {
58    libm::pow(x, n as f64)
59}
60
61/// Sine.
62#[inline]
63pub fn sin(x: f64) -> f64 {
64    libm::sin(x)
65}
66
67/// Cosine.
68#[inline]
69pub fn cos(x: f64) -> f64 {
70    libm::cos(x)
71}
72
73/// Floor.
74#[inline]
75pub fn floor(x: f64) -> f64 {
76    libm::floor(x)
77}
78
79/// Ceil.
80#[inline]
81pub fn ceil(x: f64) -> f64 {
82    libm::ceil(x)
83}
84
85/// Round to nearest integer.
86#[inline]
87pub fn round(x: f64) -> f64 {
88    libm::round(x)
89}
90
91/// Hyperbolic tangent.
92#[inline]
93pub fn tanh(x: f64) -> f64 {
94    libm::tanh(x)
95}
96
97/// Softplus: ln(1 + exp(x)), numerically stable.
98#[inline]
99pub fn softplus(x: f64) -> f64 {
100    if x > 20.0 {
101        x
102    } else if x < -20.0 {
103        libm::exp(x)
104    } else {
105        libm::log(1.0 + libm::exp(x))
106    }
107}
108
109/// Logistic sigmoid: 1 / (1 + exp(-x)), numerically stable.
110#[inline]
111pub fn sigmoid(x: f64) -> f64 {
112    if x >= 0.0 {
113        let e = libm::exp(-x);
114        1.0 / (1.0 + e)
115    } else {
116        let e = libm::exp(x);
117        e / (1.0 + e)
118    }
119}
120
121/// Minimum of two f64 values (handles NaN: returns the non-NaN value).
122#[inline]
123pub fn fmin(x: f64, y: f64) -> f64 {
124    libm::fmin(x, y)
125}
126
127/// Maximum of two f64 values (handles NaN: returns the non-NaN value).
128#[inline]
129pub fn fmax(x: f64, y: f64) -> f64 {
130    libm::fmax(x, y)
131}
132
133/// Error function.
134#[inline]
135pub fn erf(x: f64) -> f64 {
136    libm::erf(x)
137}
138
139/// f32 absolute value.
140#[inline]
141pub fn abs_f32(x: f32) -> f32 {
142    libm::fabsf(x)
143}
144
145/// f32 square root.
146#[inline]
147pub fn sqrt_f32(x: f32) -> f32 {
148    libm::sqrtf(x)
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sqrt_of_4() {
        assert!((sqrt(4.0) - 2.0).abs() < 1e-15);
    }

    #[test]
    fn exp_of_0() {
        assert!((exp(0.0) - 1.0).abs() < 1e-15);
    }

    #[test]
    fn ln_of_e() {
        assert!((ln(core::f64::consts::E) - 1.0).abs() < 1e-15);
    }

    #[test]
    fn abs_negative() {
        assert_eq!(abs(-5.0), 5.0);
        assert_eq!(abs(5.0), 5.0);
        assert_eq!(abs(0.0), 0.0);
    }

    #[test]
    fn powf_squares() {
        assert!((powf(3.0, 2.0) - 9.0).abs() < 1e-15);
    }

    #[test]
    fn powi_cubes() {
        assert!((powi(2.0, 3) - 8.0).abs() < 1e-15);
    }

    #[test]
    fn powi_negative_and_zero_exponent() {
        assert_eq!(powi(2.0, -2), 0.25);
        assert_eq!(powi(5.0, 0), 1.0);
    }

    #[test]
    fn sin_cos_identity() {
        let x = 1.0;
        let s = sin(x);
        let c = cos(x);
        assert!((s * s + c * c - 1.0).abs() < 1e-15);
    }

    #[test]
    fn floor_ceil_round() {
        assert_eq!(floor(2.7), 2.0);
        assert_eq!(ceil(2.3), 3.0);
        assert_eq!(round(2.5), 3.0);
        assert_eq!(round(2.4), 2.0);
    }

    #[test]
    fn round_half_away_from_zero() {
        assert_eq!(round(-2.5), -3.0);
        assert_eq!(round(0.5), 1.0);
    }

    #[test]
    fn log2_of_8() {
        assert!((log2(8.0) - 3.0).abs() < 1e-15);
    }

    #[test]
    fn log10_of_1000() {
        assert!((log10(1000.0) - 3.0).abs() < 1e-15);
    }

    #[test]
    fn tanh_of_0() {
        assert!((tanh(0.0)).abs() < 1e-15);
    }

    #[test]
    fn erf_values() {
        assert_eq!(erf(0.0), 0.0);
        assert!((erf(1.0) - 0.842_700_792_949_714_9).abs() < 1e-15);
        // erf is odd.
        assert!((erf(-1.0) + erf(1.0)).abs() < 1e-15);
    }

    #[test]
    fn f32_helpers() {
        assert_eq!(abs_f32(-1.5), 1.5);
        assert_eq!(sqrt_f32(9.0), 3.0);
    }

    #[test]
    fn fmin_fmax() {
        assert_eq!(fmin(1.0, 2.0), 1.0);
        assert_eq!(fmax(1.0, 2.0), 2.0);
    }

    #[test]
    fn fmin_fmax_return_non_nan_operand() {
        // Documented contract: a single NaN argument is ignored.
        assert_eq!(fmin(f64::NAN, 1.0), 1.0);
        assert_eq!(fmin(1.0, f64::NAN), 1.0);
        assert_eq!(fmax(f64::NAN, 2.0), 2.0);
        assert_eq!(fmax(2.0, f64::NAN), 2.0);
    }

    #[test]
    fn softplus_large_positive() {
        // For x >> 0, softplus(x) ~ x
        assert!((softplus(50.0) - 50.0).abs() < 1e-10);
    }

    #[test]
    fn softplus_large_negative() {
        // For x << 0, softplus(x) ~ 0
        let result = softplus(-50.0);
        assert!(result >= 0.0 && result < 1e-20);
    }

    #[test]
    fn softplus_zero() {
        let expected = ln(2.0);
        assert!((softplus(0.0) - expected).abs() < 1e-12);
    }

    #[test]
    fn softplus_always_positive() {
        for &x in &[-10.0, -1.0, 0.0, 1.0, 10.0] {
            assert!(softplus(x) > 0.0, "softplus({}) should be > 0", x);
        }
    }

    #[test]
    fn sigmoid_at_zero() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-12);
    }

    #[test]
    fn sigmoid_range() {
        for &x in &[-10.0, -1.0, 0.0, 1.0, 10.0] {
            let s = sigmoid(x);
            assert!(
                s > 0.0 && s < 1.0,
                "sigmoid({}) = {} should be in (0, 1)",
                x,
                s
            );
        }
    }

    #[test]
    fn sigmoid_symmetry() {
        let x = 3.0;
        assert!((sigmoid(x) + sigmoid(-x) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn sigmoid_extreme_values() {
        // Saturates without overflow: ~1 for large x, tiny but positive for
        // very negative x (exp(-100) ≈ 3.7e-44).
        let s_pos = sigmoid(100.0);
        let s_neg = sigmoid(-100.0);
        assert!((s_pos - 1.0).abs() < 1e-15);
        assert!(s_neg > 0.0 && s_neg < 1e-40);
    }
}