Skip to main content

ferray_ufunc/ops/
explog.rs

1// ferray-ufunc: Exponential and logarithmic functions
2//
// exp, exp_into, exp_fast, exp2, expm1, log, log_into, log2, log10, log1p,
// logaddexp, logaddexp2 (plus f32-promoted f16 variants behind the `f16` feature)
4
5use ferray_core::Array;
6use ferray_core::dimension::Dimension;
7use ferray_core::dtype::Element;
8use ferray_core::error::FerrayResult;
9use num_traits::Float;
10
11use crate::cr_math::CrMath;
12use crate::helpers::{
13    binary_elementwise_op, unary_float_op, unary_float_op_compute, unary_float_op_into_compute,
14};
15
16/// Elementwise exponential (e^x).
17pub fn exp<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
18where
19    T: Element + Float + CrMath,
20    D: Dimension,
21{
22    unary_float_op_compute(input, T::cr_exp)
23}
24
25/// In-place `e^x` — `_into` counterpart of [`exp`]. Parallelizes along
26/// the compute-bound threshold for transcendentals (100k elements).
27pub fn exp_into<T, D>(input: &Array<T, D>, out: &mut Array<T, D>) -> FerrayResult<()>
28where
29    T: Element + Float + CrMath,
30    D: Dimension,
31{
32    unary_float_op_into_compute(input, out, "exp", T::cr_exp)
33}
34
35/// Fast elementwise exponential (e^x) with ≤1 ULP accuracy.
36///
37/// Uses an Even/Odd Remez decomposition that is ~30% faster than `exp()` (CORE-MATH)
38/// while achieving faithful rounding (≤1 ULP). The default `exp()` is correctly
39/// rounded (≤0.5 ULP) via CORE-MATH.
40///
41/// This function auto-vectorizes for SSE/AVX2/AVX-512/NEON with no lookup tables.
42/// Subnormal outputs (x < -708.4) are flushed to zero.
43///
44/// For f64 arrays, uses the optimized batch kernel directly.
45/// For f32 arrays, promotes to f64 internally (f32 has only 24 mantissa bits,
46/// so the result is correctly rounded for all finite f32 inputs).
47pub fn exp_fast<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
48where
49    T: Element + Float,
50    D: Dimension,
51{
52    use std::any::TypeId;
53    if TypeId::of::<T>() == TypeId::of::<f64>() {
54        // SAFETY: T is f64 — reinterpret the array reference
55        let f64_input =
56            unsafe { &*std::ptr::from_ref::<Array<T, D>>(input).cast::<Array<f64, D>>() };
57        let n = f64_input.size();
58        let result = if let Some(slice) = f64_input.as_slice() {
59            let mut data = Vec::with_capacity(n);
60            #[allow(clippy::uninit_vec)]
61            unsafe {
62                data.set_len(n);
63            }
64            crate::dispatch::dispatch_exp_fast_f64(slice, &mut data);
65            Array::from_vec(f64_input.dim().clone(), data)?
66        } else {
67            let data: Vec<f64> = f64_input
68                .iter()
69                .map(|&x| crate::fast_exp::exp_fast_f64(x))
70                .collect();
71            Array::from_vec(f64_input.dim().clone(), data)?
72        };
73        // SAFETY: T was verified to be f64 at the top of this branch.
74        Ok(unsafe { crate::helpers::reinterpret_array::<f64, T, D>(result) })
75    } else if TypeId::of::<T>() == TypeId::of::<f32>() {
76        let f32_input =
77            unsafe { &*std::ptr::from_ref::<Array<T, D>>(input).cast::<Array<f32, D>>() };
78        let n = f32_input.size();
79        let result = if let Some(slice) = f32_input.as_slice() {
80            let mut data = Vec::with_capacity(n);
81            #[allow(clippy::uninit_vec)]
82            unsafe {
83                data.set_len(n);
84            }
85            crate::dispatch::dispatch_exp_fast_f32(slice, &mut data);
86            Array::from_vec(f32_input.dim().clone(), data)?
87        } else {
88            let data: Vec<f32> = f32_input
89                .iter()
90                .map(|&x| crate::fast_exp::exp_fast_f32(x))
91                .collect();
92            Array::from_vec(f32_input.dim().clone(), data)?
93        };
94        // SAFETY: T was verified to be f32 at the top of this branch.
95        Ok(unsafe { crate::helpers::reinterpret_array::<f32, T, D>(result) })
96    } else {
97        // Fallback for other float types: use libm exp
98        unary_float_op(input, num_traits::Float::exp)
99    }
100}
101
102/// Elementwise 2^x.
103pub fn exp2<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
104where
105    T: Element + Float + CrMath,
106    D: Dimension,
107{
108    unary_float_op_compute(input, T::cr_exp2)
109}
110
111/// Elementwise exp(x) - 1, accurate near zero.
112pub fn expm1<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
113where
114    T: Element + Float + CrMath,
115    D: Dimension,
116{
117    unary_float_op_compute(input, T::cr_exp_m1)
118}
119
120/// Elementwise natural logarithm.
121pub fn log<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
122where
123    T: Element + Float + CrMath,
124    D: Dimension,
125{
126    unary_float_op_compute(input, T::cr_ln)
127}
128
129/// In-place natural logarithm — `_into` counterpart of [`log`].
130pub fn log_into<T, D>(input: &Array<T, D>, out: &mut Array<T, D>) -> FerrayResult<()>
131where
132    T: Element + Float + CrMath,
133    D: Dimension,
134{
135    unary_float_op_into_compute(input, out, "log", T::cr_ln)
136}
137
138/// Elementwise base-2 logarithm.
139pub fn log2<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
140where
141    T: Element + Float + CrMath,
142    D: Dimension,
143{
144    unary_float_op_compute(input, T::cr_log2)
145}
146
147/// Elementwise base-10 logarithm.
148pub fn log10<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
149where
150    T: Element + Float + CrMath,
151    D: Dimension,
152{
153    unary_float_op_compute(input, T::cr_log10)
154}
155
156/// Elementwise ln(1 + x), accurate near zero.
157pub fn log1p<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
158where
159    T: Element + Float + CrMath,
160    D: Dimension,
161{
162    unary_float_op_compute(input, T::cr_ln_1p)
163}
164
165/// log(exp(a) + exp(b)), computed in a numerically stable way.
166pub fn logaddexp<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
167where
168    T: Element + Float + CrMath,
169    D: Dimension,
170{
171    binary_elementwise_op(a, b, |x, y| {
172        if x.is_nan() || y.is_nan() {
173            return T::nan();
174        }
175        let max = if x > y { x } else { y };
176        let min = if x > y { y } else { x };
177        max + (min - max).cr_exp().cr_ln_1p()
178    })
179}
180
181/// log2(2^a + 2^b), computed in a numerically stable way.
182pub fn logaddexp2<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
183where
184    T: Element + Float + CrMath,
185    D: Dimension,
186{
187    let ln2 = T::from(std::f64::consts::LN_2).unwrap_or_else(|| <T as Element>::one());
188    binary_elementwise_op(a, b, |x, y| {
189        if x.is_nan() || y.is_nan() {
190            return T::nan();
191        }
192        let max = if x > y { x } else { y };
193        let min = if x > y { y } else { x };
194        max + ((min - max) * ln2).cr_exp().cr_ln_1p() / ln2
195    })
196}
197
198// ---------------------------------------------------------------------------
199// f16 variants (f32-promoted) — generated via the shared unary_f16_fn!
200// macro (#142).
201// ---------------------------------------------------------------------------
202
203use crate::helpers::unary_f16_fn;
204
// Each invocation below generates an `<op>_f16` function for half-precision
// arrays. Per the doc line on each item, it works by promoting every element
// to f32 and applying the named `f32` routine. The exact widening/narrowing
// lives in `unary_f16_fn!` (crate::helpers).
//
// NOTE: these call std's `f32` routines (e.g. `f32::exp`), not the `CrMath`
// kernels used by the generic functions above.
//
// The `#[cfg(feature = "f16")]` line is presumably forwarded to the generated
// item, so the functions only exist with the `f16` feature; confirm in the
// macro definition.

// Exponentials.
unary_f16_fn!(
    /// Elementwise exponential for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    exp_f16,
    f32::exp
);
unary_f16_fn!(
    /// Elementwise 2^x for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    exp2_f16,
    f32::exp2
);
unary_f16_fn!(
    /// Elementwise exp(x)-1 for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    expm1_f16,
    f32::exp_m1
);
// Logarithms.
unary_f16_fn!(
    /// Elementwise natural logarithm for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    log_f16,
    f32::ln
);
unary_f16_fn!(
    /// Elementwise base-2 logarithm for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    log2_f16,
    f32::log2
);
unary_f16_fn!(
    /// Elementwise base-10 logarithm for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    log10_f16,
    f32::log10
);
unary_f16_fn!(
    /// Elementwise ln(1+x) for f16 arrays via f32 promotion.
    #[cfg(feature = "f16")]
    log1p_f16,
    f32::ln_1p
);
247
#[cfg(test)]
mod tests {
    use super::*;

    use crate::test_util::arr1;

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!((actual - expected).abs() < tol);
    }

    #[test]
    fn test_exp() {
        let out = exp(&arr1(vec![0.0, 1.0])).unwrap();
        let vals = out.as_slice().unwrap();
        assert_close(vals[0], 1.0, 1e-12);
        assert_close(vals[1], std::f64::consts::E, 1e-12);
    }

    #[test]
    fn test_exp_fast() {
        let inputs = [0.0_f64, 1.0, -1.0, 10.0, -10.0];
        let out = exp_fast(&arr1(inputs.to_vec())).unwrap();
        let vals = out.as_slice().unwrap();
        assert_close(vals[0], 1.0, 1e-15);
        assert_close(vals[1], std::f64::consts::E, 1e-14);
        assert_close(vals[2], 1.0 / std::f64::consts::E, 1e-15);
        // Relative error vs std's `exp`, measured in units of f64::EPSILON,
        // must stay within 1.5.
        for (i, &x) in inputs.iter().enumerate() {
            let got = vals[i];
            let reference = x.exp();
            let ulp = (got - reference).abs() / (reference.abs() * f64::EPSILON);
            assert!(ulp <= 1.5, "exp_fast({x}) ulp = {ulp}");
        }
    }

    #[test]
    fn test_exp2() {
        let out = exp2(&arr1(vec![0.0, 3.0, 10.0])).unwrap();
        let vals = out.as_slice().unwrap();
        assert_close(vals[0], 1.0, 1e-12);
        assert_close(vals[1], 8.0, 1e-12);
        assert_close(vals[2], 1024.0, 1e-9);
    }

    #[test]
    fn test_expm1() {
        let out = expm1(&arr1(vec![0.0, 1e-15])).unwrap();
        let vals = out.as_slice().unwrap();
        assert_close(vals[0], 0.0, 1e-12);
        // Near zero expm1(x) ≈ x; a naive exp(x) - 1 would lose most digits.
        assert_close(vals[1], 1e-15, 1e-25);
    }

    #[test]
    fn test_log() {
        let out = log(&arr1(vec![1.0, std::f64::consts::E])).unwrap();
        let vals = out.as_slice().unwrap();
        assert_close(vals[0], 0.0, 1e-12);
        assert_close(vals[1], 1.0, 1e-12);
    }

    #[test]
    fn test_log2() {
        let out = log2(&arr1(vec![1.0, 8.0, 1024.0])).unwrap();
        let vals = out.as_slice().unwrap();
        assert_close(vals[0], 0.0, 1e-12);
        assert_close(vals[1], 3.0, 1e-12);
        assert_close(vals[2], 10.0, 1e-10);
    }

    #[test]
    fn test_log10() {
        let out = log10(&arr1(vec![1.0, 100.0, 1000.0])).unwrap();
        let vals = out.as_slice().unwrap();
        assert_close(vals[0], 0.0, 1e-12);
        assert_close(vals[1], 2.0, 1e-12);
        assert_close(vals[2], 3.0, 1e-12);
    }

    #[test]
    fn test_log1p() {
        let out = log1p(&arr1(vec![0.0, 1e-15])).unwrap();
        let vals = out.as_slice().unwrap();
        assert_close(vals[0], 0.0, 1e-12);
        // Near zero log1p(x) ≈ x.
        assert_close(vals[1], 1e-15, 1e-25);
    }

    #[test]
    fn test_logaddexp() {
        let out = logaddexp(&arr1(vec![0.0]), &arr1(vec![0.0])).unwrap();
        let vals = out.as_slice().unwrap();
        // log(e^0 + e^0) = log(2)
        assert_close(vals[0], std::f64::consts::LN_2, 1e-12);
    }

    #[test]
    fn test_logaddexp2() {
        let out = logaddexp2(&arr1(vec![0.0]), &arr1(vec![0.0])).unwrap();
        let vals = out.as_slice().unwrap();
        // log2(2^0 + 2^0) = log2(2) = 1
        assert_close(vals[0], 1.0, 1e-12);
    }

    #[cfg(feature = "f16")]
    mod f16_tests {
        use super::*;
        use ferray_core::dimension::Ix1;

        /// Builds a 1-D f16 array by rounding each f32 value to half precision.
        fn arr1_f16(data: &[f32]) -> Array<half::f16, Ix1> {
            let vals: Vec<half::f16> = data.iter().copied().map(half::f16::from_f32).collect();
            Array::from_vec(Ix1::new([data.len()]), vals).unwrap()
        }

        /// Asserts that `actual`, widened to f32, lies strictly within `tol` of `expected`.
        fn assert_close_f16(actual: half::f16, expected: f32, tol: f32) {
            assert!((actual.to_f32() - expected).abs() < tol);
        }

        #[test]
        fn test_exp_f16() {
            let out = exp_f16(&arr1_f16(&[0.0, 1.0])).unwrap();
            let vals = out.as_slice().unwrap();
            assert_close_f16(vals[0], 1.0, 0.01);
            assert_close_f16(vals[1], std::f32::consts::E, 0.02);
        }

        #[test]
        fn test_log_f16() {
            let out = log_f16(&arr1_f16(&[1.0, std::f32::consts::E])).unwrap();
            let vals = out.as_slice().unwrap();
            assert_close_f16(vals[0], 0.0, 0.01);
            assert_close_f16(vals[1], 1.0, 0.01);
        }

        #[test]
        fn test_log2_f16() {
            let out = log2_f16(&arr1_f16(&[1.0, 8.0])).unwrap();
            let vals = out.as_slice().unwrap();
            assert_close_f16(vals[0], 0.0, 0.01);
            assert_close_f16(vals[1], 3.0, 0.01);
        }
    }
}