Skip to main content

ferray_ufunc/ops/
explog.rs

1// ferray-ufunc: Exponential and logarithmic functions
2//
3// exp, exp2, expm1, log, log2, log10, log1p, logaddexp, logaddexp2
4//
5// ## REQ status — REQ-6 (exp/log family) + binary-promote tie-ins
6//
7// SHIPPED:
8//   - REQ-6 (`exp`/`exp2`/`expm1`/`log`/`log2`/`log10`/`log1p`/`logaddexp`/
9//     `logaddexp2`): the full NumPy exp/log ufunc family as generic free
10//     functions preserving input dimensionality (REQ-1). Anchors:
11//     `pub fn exp`/`pub fn exp2`/`pub fn expm1`, `pub fn log`/`pub fn log2`/
12//     `pub fn log10`/`pub fn log1p`, `pub fn logaddexp`/`pub fn logaddexp2`
13//     (binary). `T: Element + Float` (the `exp`/`log` defaults route through
14//     CORE-MATH `CrMath` for correctly-rounded results; faithful-rounding
15//     `_fast` kernels `pub fn exp_fast` and in-place `_into` counterparts
16//     `pub fn exp_into`/`pub fn log_into` are also provided). Special-value
17//     edges (`log(0)` -> -inf, `log(-x)` -> NaN, `exp` overflow -> inf,
18//     `expm1`/`log1p` near-zero accuracy) are audited against numpy 2.4.x and
19//     green. Non-test production consumer: re-exported verbatim from the crate
20//     root (`lib.rs` `pub use ops::explog::{exp, exp_fast, exp_into, exp2,
21//     expm1, log, log_into, log1p, log2, log10, logaddexp, logaddexp2}`), the
22//     public ufunc surface and the ferray-python exp/log binding target. (f16
23//     variants `exp_f16`/`log_f16`/… are feature-gated re-exports.)
24//   - REQ-23 tie-in (integer/bool input promotion): `exp_promote`/`log_promote`/
25//     `exp2_promote`/`expm1_promote`/`log2_promote`/`log10_promote`/
26//     `log1p_promote` (in `promoted.rs`) call THESE generic `T: Float` kernels
27//     monomorphised at the compute float — no separate int kernel here.
28//   - REQ-25 tie-in (binary int/bool promotion): `logaddexp_promote`/
29//     `logaddexp2_promote` (in `promoted.rs`) wrap `pub fn logaddexp`/
30//     `pub fn logaddexp2` here for integer/bool operand pairs; f32/f64 callers
31//     stay byte-identical.
32//
33// NOT-STARTED: none — REQ-6 is fully shipped for this module.
34
35use ferray_core::Array;
36use ferray_core::dimension::Dimension;
37use ferray_core::dtype::Element;
38use ferray_core::error::FerrayResult;
39use num_traits::Float;
40
41use crate::cr_math::CrMath;
42use crate::helpers::{
43    binary_elementwise_op, unary_float_op, unary_float_op_compute, unary_float_op_into_compute,
44};
45
46/// Elementwise exponential (e^x).
47pub fn exp<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
48where
49    T: Element + Float + CrMath,
50    D: Dimension,
51{
52    unary_float_op_compute(input, T::cr_exp)
53}
54
55/// In-place `e^x` — `_into` counterpart of [`exp`]. Parallelizes along
56/// the compute-bound threshold for transcendentals (100k elements).
57pub fn exp_into<T, D>(input: &Array<T, D>, out: &mut Array<T, D>) -> FerrayResult<()>
58where
59    T: Element + Float + CrMath,
60    D: Dimension,
61{
62    unary_float_op_into_compute(input, out, "exp", T::cr_exp)
63}
64
65/// Fast elementwise exponential (e^x) with ≤1 ULP accuracy.
66///
67/// Uses an Even/Odd Remez decomposition that is ~30% faster than `exp()` (CORE-MATH)
68/// while achieving faithful rounding (≤1 ULP). The default `exp()` is correctly
69/// rounded (≤0.5 ULP) via CORE-MATH.
70///
71/// This function auto-vectorizes for SSE/AVX2/AVX-512/NEON with no lookup tables.
72/// Subnormal outputs (x < -708.4) are flushed to zero.
73///
74/// For f64 arrays, uses the optimized batch kernel directly.
75/// For f32 arrays, promotes to f64 internally (f32 has only 24 mantissa bits,
76/// so the result is correctly rounded for all finite f32 inputs).
77pub fn exp_fast<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
78where
79    T: Element + Float,
80    D: Dimension,
81{
82    use std::any::TypeId;
83    if TypeId::of::<T>() == TypeId::of::<f64>() {
84        // SAFETY: T is f64 — reinterpret the array reference
85        let f64_input =
86            unsafe { &*std::ptr::from_ref::<Array<T, D>>(input).cast::<Array<f64, D>>() };
87        let n = f64_input.size();
88        let result = if let Some(slice) = f64_input.as_slice() {
89            let mut data = Vec::with_capacity(n);
90            #[allow(clippy::uninit_vec)]
91            unsafe {
92                data.set_len(n);
93            }
94            crate::dispatch::dispatch_exp_fast_f64(slice, &mut data);
95            Array::from_vec(f64_input.dim().clone(), data)?
96        } else {
97            let data: Vec<f64> = f64_input
98                .iter()
99                .map(|&x| crate::fast_exp::exp_fast_f64(x))
100                .collect();
101            Array::from_vec(f64_input.dim().clone(), data)?
102        };
103        // SAFETY: T was verified to be f64 at the top of this branch.
104        Ok(unsafe { crate::helpers::reinterpret_array::<f64, T, D>(result) })
105    } else if TypeId::of::<T>() == TypeId::of::<f32>() {
106        let f32_input =
107            unsafe { &*std::ptr::from_ref::<Array<T, D>>(input).cast::<Array<f32, D>>() };
108        let n = f32_input.size();
109        let result = if let Some(slice) = f32_input.as_slice() {
110            let mut data = Vec::with_capacity(n);
111            #[allow(clippy::uninit_vec)]
112            unsafe {
113                data.set_len(n);
114            }
115            crate::dispatch::dispatch_exp_fast_f32(slice, &mut data);
116            Array::from_vec(f32_input.dim().clone(), data)?
117        } else {
118            let data: Vec<f32> = f32_input
119                .iter()
120                .map(|&x| crate::fast_exp::exp_fast_f32(x))
121                .collect();
122            Array::from_vec(f32_input.dim().clone(), data)?
123        };
124        // SAFETY: T was verified to be f32 at the top of this branch.
125        Ok(unsafe { crate::helpers::reinterpret_array::<f32, T, D>(result) })
126    } else {
127        // Fallback for other float types: use libm exp
128        unary_float_op(input, num_traits::Float::exp)
129    }
130}
131
132/// Elementwise 2^x.
133pub fn exp2<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
134where
135    T: Element + Float + CrMath,
136    D: Dimension,
137{
138    unary_float_op_compute(input, T::cr_exp2)
139}
140
141/// Elementwise exp(x) - 1, accurate near zero.
142pub fn expm1<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
143where
144    T: Element + Float + CrMath,
145    D: Dimension,
146{
147    unary_float_op_compute(input, T::cr_exp_m1)
148}
149
150/// Elementwise natural logarithm.
151pub fn log<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
152where
153    T: Element + Float + CrMath,
154    D: Dimension,
155{
156    unary_float_op_compute(input, T::cr_ln)
157}
158
159/// In-place natural logarithm — `_into` counterpart of [`log`].
160pub fn log_into<T, D>(input: &Array<T, D>, out: &mut Array<T, D>) -> FerrayResult<()>
161where
162    T: Element + Float + CrMath,
163    D: Dimension,
164{
165    unary_float_op_into_compute(input, out, "log", T::cr_ln)
166}
167
168/// Elementwise base-2 logarithm.
169pub fn log2<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
170where
171    T: Element + Float + CrMath,
172    D: Dimension,
173{
174    unary_float_op_compute(input, T::cr_log2)
175}
176
177/// Elementwise base-10 logarithm.
178pub fn log10<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
179where
180    T: Element + Float + CrMath,
181    D: Dimension,
182{
183    unary_float_op_compute(input, T::cr_log10)
184}
185
186/// Elementwise ln(1 + x), accurate near zero.
187pub fn log1p<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
188where
189    T: Element + Float + CrMath,
190    D: Dimension,
191{
192    unary_float_op_compute(input, T::cr_ln_1p)
193}
194
195/// log(exp(a) + exp(b)), computed in a numerically stable way.
196pub fn logaddexp<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
197where
198    T: Element + Float + CrMath,
199    D: Dimension,
200{
201    binary_elementwise_op(a, b, |x, y| {
202        if x.is_nan() || y.is_nan() {
203            return T::nan();
204        }
205        let max = if x > y { x } else { y };
206        // Equal-infinity / infinite-max guard (numpy: logaddexp(inf,inf)=inf,
207        // logaddexp(-inf,-inf)=-inf). The stable form below computes
208        // inf - inf = NaN when both inputs are +inf; returning `max` directly
209        // mirrors numpy. (inf,-inf)/(inf,finite) → max=inf → inf;
210        // (-inf,finite) keeps max=finite and falls through to the finite path.
211        if max.is_infinite() {
212            return max;
213        }
214        let min = if x > y { y } else { x };
215        max + (min - max).cr_exp().cr_ln_1p()
216    })
217}
218
219/// log2(2^a + 2^b), computed in a numerically stable way.
220pub fn logaddexp2<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
221where
222    T: Element + Float + CrMath,
223    D: Dimension,
224{
225    let ln2 = T::from(std::f64::consts::LN_2).unwrap_or_else(|| <T as Element>::one());
226    binary_elementwise_op(a, b, |x, y| {
227        if x.is_nan() || y.is_nan() {
228            return T::nan();
229        }
230        let max = if x > y { x } else { y };
231        // Equal-infinity / infinite-max guard (numpy: logaddexp2(inf,inf)=inf,
232        // logaddexp2(-inf,-inf)=-inf). The stable form below computes
233        // inf - inf = NaN when both inputs are +inf; returning `max` directly
234        // mirrors numpy. (inf,-inf)/(inf,finite) → max=inf → inf;
235        // (-inf,finite) keeps max=finite and falls through to the finite path.
236        if max.is_infinite() {
237            return max;
238        }
239        let min = if x > y { y } else { x };
240        max + ((min - max) * ln2).cr_exp().cr_ln_1p() / ln2
241    })
242}
243
244// ---------------------------------------------------------------------------
245// f16 variants (f32-promoted) — generated via the shared unary_f16_fn!
246// macro (#142).
247// ---------------------------------------------------------------------------
248
249use crate::helpers::unary_f16_fn;
250
251unary_f16_fn!(
252    /// Elementwise exponential for f16 arrays via f32 promotion.
253    #[cfg(feature = "f16")]
254    exp_f16,
255    f32::exp
256);
257unary_f16_fn!(
258    /// Elementwise 2^x for f16 arrays via f32 promotion.
259    #[cfg(feature = "f16")]
260    exp2_f16,
261    f32::exp2
262);
263unary_f16_fn!(
264    /// Elementwise exp(x)-1 for f16 arrays via f32 promotion.
265    #[cfg(feature = "f16")]
266    expm1_f16,
267    f32::exp_m1
268);
269unary_f16_fn!(
270    /// Elementwise natural logarithm for f16 arrays via f32 promotion.
271    #[cfg(feature = "f16")]
272    log_f16,
273    f32::ln
274);
275unary_f16_fn!(
276    /// Elementwise base-2 logarithm for f16 arrays via f32 promotion.
277    #[cfg(feature = "f16")]
278    log2_f16,
279    f32::log2
280);
281unary_f16_fn!(
282    /// Elementwise base-10 logarithm for f16 arrays via f32 promotion.
283    #[cfg(feature = "f16")]
284    log10_f16,
285    f32::log10
286);
287unary_f16_fn!(
288    /// Elementwise ln(1+x) for f16 arrays via f32 promotion.
289    #[cfg(feature = "f16")]
290    log1p_f16,
291    f32::ln_1p
292);
293
#[cfg(test)]
mod tests {
    use super::*;

    use crate::test_util::arr1;
    use ferray_core::dimension::Ix1;

    #[test]
    fn test_exp() {
        let a = arr1(vec![0.0, 1.0]);
        let r = exp(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert!((s[0] - 1.0).abs() < 1e-12);
        assert!((s[1] - std::f64::consts::E).abs() < 1e-12);
    }

    #[test]
    fn test_exp_fast() {
        let a = arr1(vec![0.0, 1.0, -1.0, 10.0, -10.0]);
        let r = exp_fast(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert!((s[0] - 1.0).abs() < 1e-15);
        assert!((s[1] - std::f64::consts::E).abs() < 1e-14);
        assert!((s[2] - 1.0 / std::f64::consts::E).abs() < 1e-15);
        // Check ≤1.5 ULP vs libm
        for (i, &x) in [0.0, 1.0, -1.0, 10.0, -10.0].iter().enumerate() {
            let reference = x.exp();
            let ulp = (s[i] - reference).abs() / (reference.abs() * f64::EPSILON);
            assert!(ulp <= 1.5, "exp_fast({x}) ulp = {ulp}");
        }
    }

    /// Exercises the `f32` branch of `exp_fast`, which the f64-only test above
    /// never reaches.
    #[test]
    fn test_exp_fast_f32() {
        let xs = [0.0_f32, 1.0, -1.0, 10.0, -10.0];
        let a = Array::from_vec(Ix1::new([xs.len()]), xs.to_vec()).unwrap();
        let r = exp_fast(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s.len(), xs.len());
        for (&x, &got) in xs.iter().zip(s) {
            let reference = x.exp();
            let rel = ((got - reference) / reference).abs();
            assert!(rel <= 4.0 * f32::EPSILON, "exp_fast_f32({x}) rel err = {rel}");
        }
    }

    #[test]
    fn test_exp2() {
        let a = arr1(vec![0.0, 3.0, 10.0]);
        let r = exp2(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert!((s[0] - 1.0).abs() < 1e-12);
        assert!((s[1] - 8.0).abs() < 1e-12);
        assert!((s[2] - 1024.0).abs() < 1e-9);
    }

    #[test]
    fn test_expm1() {
        let a = arr1(vec![0.0, 1e-15]);
        let r = expm1(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert!((s[0]).abs() < 1e-12);
        // expm1 should be accurate near zero
        assert!((s[1] - 1e-15).abs() < 1e-25);
    }

    #[test]
    fn test_log() {
        let a = arr1(vec![1.0, std::f64::consts::E]);
        let r = log(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert!((s[0]).abs() < 1e-12);
        assert!((s[1] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_log2() {
        let a = arr1(vec![1.0, 8.0, 1024.0]);
        let r = log2(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert!((s[0]).abs() < 1e-12);
        assert!((s[1] - 3.0).abs() < 1e-12);
        assert!((s[2] - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_log10() {
        let a = arr1(vec![1.0, 100.0, 1000.0]);
        let r = log10(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert!((s[0]).abs() < 1e-12);
        assert!((s[1] - 2.0).abs() < 1e-12);
        assert!((s[2] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_log1p() {
        let a = arr1(vec![0.0, 1e-15]);
        let r = log1p(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert!((s[0]).abs() < 1e-12);
        assert!((s[1] - 1e-15).abs() < 1e-25);
    }

    #[test]
    fn test_logaddexp() {
        let a = arr1(vec![0.0]);
        let b = arr1(vec![0.0]);
        let r = logaddexp(&a, &b).unwrap();
        let s = r.as_slice().unwrap();
        // log(e^0 + e^0) = log(2) ~ 0.693
        assert!((s[0] - std::f64::consts::LN_2).abs() < 1e-12);
    }

    #[test]
    fn test_logaddexp2() {
        let a = arr1(vec![0.0]);
        let b = arr1(vec![0.0]);
        let r = logaddexp2(&a, &b).unwrap();
        let s = r.as_slice().unwrap();
        // log2(2^0 + 2^0) = log2(2) = 1
        assert!((s[0] - 1.0).abs() < 1e-12);
    }

    /// Runs a binary logaddexp-style kernel on two 1-D f64 inputs and returns
    /// the results in logical order. The test fails if the kernel errors.
    fn eval_logaddexp(
        f: impl Fn(&Array<f64, Ix1>, &Array<f64, Ix1>) -> FerrayResult<Array<f64, Ix1>>,
        a: Vec<f64>,
        b: Vec<f64>,
    ) -> Vec<f64> {
        match f(&arr1(a), &arr1(b)) {
            Ok(arr) => arr.iter().copied().collect(),
            Err(e) => panic!("logaddexp kernel returned error: {e:?}"),
        }
    }

    #[test]
    fn test_logaddexp_infinities() {
        // Live numpy 2.4.5 oracle (generate_umath.py:710 `logaddexp`):
        //   np.logaddexp(inf, inf)   == inf
        //   np.logaddexp(-inf, -inf) == -inf
        //   np.logaddexp(inf, -inf)  == inf
        let inf = f64::INFINITY;
        let s = eval_logaddexp(logaddexp, vec![inf, -inf, inf], vec![inf, -inf, -inf]);
        assert_eq!(s[0], inf, "logaddexp(inf, inf)");
        assert_eq!(s[1], -inf, "logaddexp(-inf, -inf)");
        assert_eq!(s[2], inf, "logaddexp(inf, -inf)");
        // Finite case stays byte-identical: logaddexp(1,1) == 1 + ln(2).
        let f = eval_logaddexp(logaddexp, vec![1.0], vec![1.0]);
        assert!((f[0] - (1.0 + std::f64::consts::LN_2)).abs() < 1e-12);
    }

    #[test]
    fn test_logaddexp2_infinities() {
        // Live numpy 2.4.5 oracle (generate_umath.py:716 `logaddexp2`):
        //   np.logaddexp2(inf, inf)   == inf
        //   np.logaddexp2(-inf, -inf) == -inf
        //   np.logaddexp2(inf, -inf)  == inf
        let inf = f64::INFINITY;
        let s = eval_logaddexp(logaddexp2, vec![inf, -inf, inf], vec![inf, -inf, -inf]);
        assert_eq!(s[0], inf, "logaddexp2(inf, inf)");
        assert_eq!(s[1], -inf, "logaddexp2(-inf, -inf)");
        assert_eq!(s[2], inf, "logaddexp2(inf, -inf)");
        // Finite case stays byte-identical: logaddexp2(0,0) == 1.
        let f = eval_logaddexp(logaddexp2, vec![0.0], vec![0.0]);
        assert!((f[0] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_logaddexp_nan_and_one_sided_edges() {
        let nan = f64::NAN;
        let inf = f64::INFINITY;
        let s = eval_logaddexp(
            logaddexp,
            vec![nan, 1.0, -inf, 1000.0, 1000.0],
            vec![1.0, nan, 3.0, 0.0, 1000.0],
        );
        assert!(s[0].is_nan(), "logaddexp(nan, 1)");
        assert!(s[1].is_nan(), "logaddexp(1, nan)");
        // (-inf, finite) takes the finite path: e^-inf = 0, so the result is 3.
        assert_eq!(s[2], 3.0, "logaddexp(-inf, 3)");
        assert_eq!(s[3], 1000.0, "logaddexp(1000, 0)");
        // A naive ln(e^1000 + e^1000) overflows; the stable form does not.
        assert!((s[4] - (1000.0 + std::f64::consts::LN_2)).abs() < 1e-12);
    }

    #[test]
    fn test_logaddexp2_nan_and_one_sided_edges() {
        let nan = f64::NAN;
        let inf = f64::INFINITY;
        let s = eval_logaddexp(
            logaddexp2,
            vec![nan, 1.0, -inf, 1000.0, 1000.0],
            vec![1.0, nan, 3.0, 0.0, 1000.0],
        );
        assert!(s[0].is_nan(), "logaddexp2(nan, 1)");
        assert!(s[1].is_nan(), "logaddexp2(1, nan)");
        assert_eq!(s[2], 3.0, "logaddexp2(-inf, 3)");
        assert_eq!(s[3], 1000.0, "logaddexp2(1000, 0)");
        // A naive log2(2^1000 + 2^1000) would overflow 2^1000 in f32 and is
        // needlessly large in f64; the stable form gives exactly 1001.
        assert!((s[4] - 1001.0).abs() < 1e-12);
    }

    #[cfg(feature = "f16")]
    mod f16_tests {
        use super::*;
        use ferray_core::dimension::Ix1;

        fn arr1_f16(data: &[f32]) -> Array<half::f16, Ix1> {
            let n = data.len();
            let vals: Vec<half::f16> = data.iter().map(|&x| half::f16::from_f32(x)).collect();
            Array::from_vec(Ix1::new([n]), vals).unwrap()
        }

        #[test]
        fn test_exp_f16() {
            let a = arr1_f16(&[0.0, 1.0]);
            let r = exp_f16(&a).unwrap();
            let s = r.as_slice().unwrap();
            assert!((s[0].to_f32() - 1.0).abs() < 0.01);
            assert!((s[1].to_f32() - std::f32::consts::E).abs() < 0.02);
        }

        #[test]
        fn test_log_f16() {
            let a = arr1_f16(&[1.0, std::f32::consts::E]);
            let r = log_f16(&a).unwrap();
            let s = r.as_slice().unwrap();
            assert!(s[0].to_f32().abs() < 0.01);
            assert!((s[1].to_f32() - 1.0).abs() < 0.01);
        }

        #[test]
        fn test_log2_f16() {
            let a = arr1_f16(&[1.0, 8.0]);
            let r = log2_f16(&a).unwrap();
            let s = r.as_slice().unwrap();
            assert!(s[0].to_f32().abs() < 0.01);
            assert!((s[1].to_f32() - 3.0).abs() < 0.01);
        }
    }
}