// ferray_ufunc/ops/explog.rs

// ferray-ufunc: Exponential and logarithmic functions
//
// exp, exp_fast, exp2, expm1, log, log2, log10, log1p, logaddexp, logaddexp2,
// plus f16 variants (behind the `f16` feature) computed via f32 promotion.
4
5use ferray_core::Array;
6use ferray_core::dimension::Dimension;
7use ferray_core::dtype::Element;
8use ferray_core::error::FerrayResult;
9use num_traits::Float;
10
11use crate::cr_math::CrMath;
12use crate::helpers::{binary_float_op, unary_float_op};
13
14/// Elementwise exponential (e^x).
15pub fn exp<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
16where
17    T: Element + Float + CrMath,
18    D: Dimension,
19{
20    unary_float_op(input, T::cr_exp)
21}
22
23/// Fast elementwise exponential (e^x) with ≤1 ULP accuracy.
24///
25/// Uses an Even/Odd Remez decomposition that is ~30% faster than `exp()` (CORE-MATH)
26/// while achieving faithful rounding (≤1 ULP). The default `exp()` is correctly
27/// rounded (≤0.5 ULP) via CORE-MATH.
28///
29/// This function auto-vectorizes for SSE/AVX2/AVX-512/NEON with no lookup tables.
30/// Subnormal outputs (x < -708.4) are flushed to zero.
31///
32/// For f64 arrays, uses the optimized batch kernel directly.
33/// For f32 arrays, promotes to f64 internally (f32 has only 24 mantissa bits,
34/// so the result is correctly rounded for all finite f32 inputs).
35pub fn exp_fast<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
36where
37    T: Element + Float,
38    D: Dimension,
39{
40    use std::any::TypeId;
41    if TypeId::of::<T>() == TypeId::of::<f64>() {
42        // SAFETY: T is f64 — reinterpret the array reference
43        let f64_input = unsafe { &*(input as *const Array<T, D> as *const Array<f64, D>) };
44        let n = f64_input.size();
45        let result = if let Some(slice) = f64_input.as_slice() {
46            let mut data = Vec::with_capacity(n);
47            #[allow(clippy::uninit_vec)]
48            unsafe {
49                data.set_len(n);
50            }
51            crate::dispatch::dispatch_exp_fast_f64(slice, &mut data);
52            Array::from_vec(f64_input.dim().clone(), data)?
53        } else {
54            let data: Vec<f64> = f64_input
55                .iter()
56                .map(|&x| crate::fast_exp::exp_fast_f64(x))
57                .collect();
58            Array::from_vec(f64_input.dim().clone(), data)?
59        };
60        // SAFETY: T is f64, reinterpret back
61        Ok(unsafe { std::mem::transmute_copy(&std::mem::ManuallyDrop::new(result)) })
62    } else if TypeId::of::<T>() == TypeId::of::<f32>() {
63        let f32_input = unsafe { &*(input as *const Array<T, D> as *const Array<f32, D>) };
64        let n = f32_input.size();
65        let result = if let Some(slice) = f32_input.as_slice() {
66            let mut data = Vec::with_capacity(n);
67            #[allow(clippy::uninit_vec)]
68            unsafe {
69                data.set_len(n);
70            }
71            crate::dispatch::dispatch_exp_fast_f32(slice, &mut data);
72            Array::from_vec(f32_input.dim().clone(), data)?
73        } else {
74            let data: Vec<f32> = f32_input
75                .iter()
76                .map(|&x| crate::fast_exp::exp_fast_f32(x))
77                .collect();
78            Array::from_vec(f32_input.dim().clone(), data)?
79        };
80        Ok(unsafe { std::mem::transmute_copy(&std::mem::ManuallyDrop::new(result)) })
81    } else {
82        // Fallback for other float types: use libm exp
83        unary_float_op(input, |x| x.exp())
84    }
85}
86
87/// Elementwise 2^x.
88pub fn exp2<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
89where
90    T: Element + Float + CrMath,
91    D: Dimension,
92{
93    unary_float_op(input, T::cr_exp2)
94}
95
96/// Elementwise exp(x) - 1, accurate near zero.
97pub fn expm1<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
98where
99    T: Element + Float + CrMath,
100    D: Dimension,
101{
102    unary_float_op(input, T::cr_exp_m1)
103}
104
105/// Elementwise natural logarithm.
106pub fn log<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
107where
108    T: Element + Float + CrMath,
109    D: Dimension,
110{
111    unary_float_op(input, T::cr_ln)
112}
113
114/// Elementwise base-2 logarithm.
115pub fn log2<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
116where
117    T: Element + Float + CrMath,
118    D: Dimension,
119{
120    unary_float_op(input, T::cr_log2)
121}
122
123/// Elementwise base-10 logarithm.
124pub fn log10<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
125where
126    T: Element + Float + CrMath,
127    D: Dimension,
128{
129    unary_float_op(input, T::cr_log10)
130}
131
132/// Elementwise ln(1 + x), accurate near zero.
133pub fn log1p<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
134where
135    T: Element + Float + CrMath,
136    D: Dimension,
137{
138    unary_float_op(input, T::cr_ln_1p)
139}
140
141/// log(exp(a) + exp(b)), computed in a numerically stable way.
142pub fn logaddexp<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
143where
144    T: Element + Float + CrMath,
145    D: Dimension,
146{
147    binary_float_op(a, b, |x, y| {
148        let max = if x > y { x } else { y };
149        let min = if x > y { y } else { x };
150        max + (min - max).cr_exp().cr_ln_1p()
151    })
152}
153
154/// log2(2^a + 2^b), computed in a numerically stable way.
155pub fn logaddexp2<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
156where
157    T: Element + Float + CrMath,
158    D: Dimension,
159{
160    let ln2 = T::from(std::f64::consts::LN_2).unwrap_or_else(|| <T as Element>::one());
161    binary_float_op(a, b, |x, y| {
162        let max = if x > y { x } else { y };
163        let min = if x > y { y } else { x };
164        max + ((min - max) * ln2).cr_exp().cr_ln_1p() / ln2
165    })
166}
167
// ---------------------------------------------------------------------------
// f16 variants (f32-promoted)
// ---------------------------------------------------------------------------
171
/// Computes `e^x` for each `f16` element by promoting it to `f32`, applying
/// `f32::exp`, and converting the result back through `unary_f16_op`.
#[cfg(feature = "f16")]
pub fn exp_f16<D>(input: &Array<half::f16, D>) -> FerrayResult<Array<half::f16, D>>
where
    D: Dimension,
{
    crate::helpers::unary_f16_op(input, |v: f32| v.exp())
}
180
/// Computes `2^x` for each `f16` element by promoting it to `f32`, applying
/// `f32::exp2`, and converting the result back through `unary_f16_op`.
#[cfg(feature = "f16")]
pub fn exp2_f16<D>(input: &Array<half::f16, D>) -> FerrayResult<Array<half::f16, D>>
where
    D: Dimension,
{
    crate::helpers::unary_f16_op(input, |v: f32| v.exp2())
}
189
/// Computes `e^x - 1` for each `f16` element by promoting it to `f32`, applying
/// `f32::exp_m1`, and converting the result back through `unary_f16_op`.
#[cfg(feature = "f16")]
pub fn expm1_f16<D>(input: &Array<half::f16, D>) -> FerrayResult<Array<half::f16, D>>
where
    D: Dimension,
{
    crate::helpers::unary_f16_op(input, f32::exp_m1)
}
198
/// Computes `ln(x)` for each `f16` element by promoting it to `f32`, applying
/// `f32::ln`, and converting the result back through `unary_f16_op`.
#[cfg(feature = "f16")]
pub fn log_f16<D>(input: &Array<half::f16, D>) -> FerrayResult<Array<half::f16, D>>
where
    D: Dimension,
{
    crate::helpers::unary_f16_op(input, |v: f32| v.ln())
}
207
/// Computes `log2(x)` for each `f16` element by promoting it to `f32`, applying
/// `f32::log2`, and converting the result back through `unary_f16_op`.
#[cfg(feature = "f16")]
pub fn log2_f16<D>(input: &Array<half::f16, D>) -> FerrayResult<Array<half::f16, D>>
where
    D: Dimension,
{
    crate::helpers::unary_f16_op(input, |v: f32| v.log2())
}
216
/// Computes `log10(x)` for each `f16` element by promoting it to `f32`, applying
/// `f32::log10`, and converting the result back through `unary_f16_op`.
#[cfg(feature = "f16")]
pub fn log10_f16<D>(input: &Array<half::f16, D>) -> FerrayResult<Array<half::f16, D>>
where
    D: Dimension,
{
    crate::helpers::unary_f16_op(input, |v: f32| v.log10())
}
225
/// Computes `ln(1 + x)` for each `f16` element by promoting it to `f32`, applying
/// `f32::ln_1p`, and converting the result back through `unary_f16_op`.
#[cfg(feature = "f16")]
pub fn log1p_f16<D>(input: &Array<half::f16, D>) -> FerrayResult<Array<half::f16, D>>
where
    D: Dimension,
{
    crate::helpers::unary_f16_op(input, f32::ln_1p)
}
234
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;

    /// Builds a one-dimensional `f64` array that owns `data`.
    fn arr1(data: Vec<f64>) -> Array<f64, Ix1> {
        let len = data.len();
        Array::from_vec(Ix1::new([len]), data).unwrap()
    }

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        let err = (actual - expected).abs();
        assert!(err < tol, "expected {expected}, got {actual} (|err| = {err}, tol = {tol})");
    }

    #[test]
    fn test_exp() {
        let out = exp(&arr1(vec![0.0, 1.0])).unwrap();
        let vals = out.as_slice().unwrap();
        assert_close(vals[0], 1.0, 1e-12);
        assert_close(vals[1], std::f64::consts::E, 1e-12);
    }

    #[test]
    fn test_exp_fast() {
        let inputs = [0.0, 1.0, -1.0, 10.0, -10.0];
        let out = exp_fast(&arr1(inputs.to_vec())).unwrap();
        let vals = out.as_slice().unwrap();
        assert_close(vals[0], 1.0, 1e-15);
        assert_close(vals[1], std::f64::consts::E, 1e-14);
        assert_close(vals[2], 1.0 / std::f64::consts::E, 1e-15);
        // Relative error against libm, measured in units of f64::EPSILON (≤1.5 ULP).
        for (&x, &got) in inputs.iter().zip(vals) {
            let reference = x.exp();
            let ulp = (got - reference).abs() / (reference.abs() * f64::EPSILON);
            assert!(ulp <= 1.5, "exp_fast({x}) ulp = {ulp}");
        }
    }

    #[test]
    fn test_exp2() {
        let out = exp2(&arr1(vec![0.0, 3.0, 10.0])).unwrap();
        let vals = out.as_slice().unwrap();
        assert_close(vals[0], 1.0, 1e-12);
        assert_close(vals[1], 8.0, 1e-12);
        assert_close(vals[2], 1024.0, 1e-9);
    }

    #[test]
    fn test_expm1() {
        let out = expm1(&arr1(vec![0.0, 1e-15])).unwrap();
        let vals = out.as_slice().unwrap();
        assert_close(vals[0], 0.0, 1e-12);
        // expm1 must stay accurate near zero, where exp(x) - 1 would cancel.
        assert_close(vals[1], 1e-15, 1e-25);
    }

    #[test]
    fn test_log() {
        let out = log(&arr1(vec![1.0, std::f64::consts::E])).unwrap();
        let vals = out.as_slice().unwrap();
        assert_close(vals[0], 0.0, 1e-12);
        assert_close(vals[1], 1.0, 1e-12);
    }

    #[test]
    fn test_log2() {
        let out = log2(&arr1(vec![1.0, 8.0, 1024.0])).unwrap();
        let vals = out.as_slice().unwrap();
        assert_close(vals[0], 0.0, 1e-12);
        assert_close(vals[1], 3.0, 1e-12);
        assert_close(vals[2], 10.0, 1e-10);
    }

    #[test]
    fn test_log10() {
        let out = log10(&arr1(vec![1.0, 100.0, 1000.0])).unwrap();
        let vals = out.as_slice().unwrap();
        assert_close(vals[0], 0.0, 1e-12);
        assert_close(vals[1], 2.0, 1e-12);
        assert_close(vals[2], 3.0, 1e-12);
    }

    #[test]
    fn test_log1p() {
        let out = log1p(&arr1(vec![0.0, 1e-15])).unwrap();
        let vals = out.as_slice().unwrap();
        assert_close(vals[0], 0.0, 1e-12);
        assert_close(vals[1], 1e-15, 1e-25);
    }

    #[test]
    fn test_logaddexp() {
        let out = logaddexp(&arr1(vec![0.0]), &arr1(vec![0.0])).unwrap();
        let vals = out.as_slice().unwrap();
        // log(e^0 + e^0) = log(2) ≈ 0.693
        assert_close(vals[0], std::f64::consts::LN_2, 1e-12);
    }

    #[test]
    fn test_logaddexp2() {
        let out = logaddexp2(&arr1(vec![0.0]), &arr1(vec![0.0])).unwrap();
        let vals = out.as_slice().unwrap();
        // log2(2^0 + 2^0) = log2(2) = 1
        assert_close(vals[0], 1.0, 1e-12);
    }

    #[cfg(feature = "f16")]
    mod f16_tests {
        use super::*;

        /// Builds a one-dimensional `f16` array from `f32` literals.
        fn arr1_f16(data: &[f32]) -> Array<half::f16, Ix1> {
            let vals: Vec<half::f16> = data.iter().copied().map(half::f16::from_f32).collect();
            Array::from_vec(Ix1::new([vals.len()]), vals).unwrap()
        }

        /// Asserts that `actual` lies strictly within `tol` of `expected`.
        fn assert_close_f32(actual: f32, expected: f32, tol: f32) {
            let err = (actual - expected).abs();
            assert!(err < tol, "expected {expected}, got {actual} (tol = {tol})");
        }

        #[test]
        fn test_exp_f16() {
            let out = exp_f16(&arr1_f16(&[0.0, 1.0])).unwrap();
            let vals = out.as_slice().unwrap();
            assert_close_f32(vals[0].to_f32(), 1.0, 0.01);
            assert_close_f32(vals[1].to_f32(), std::f32::consts::E, 0.02);
        }

        #[test]
        fn test_log_f16() {
            let out = log_f16(&arr1_f16(&[1.0, std::f32::consts::E])).unwrap();
            let vals = out.as_slice().unwrap();
            assert_close_f32(vals[0].to_f32(), 0.0, 0.01);
            assert_close_f32(vals[1].to_f32(), 1.0, 0.01);
        }

        #[test]
        fn test_log2_f16() {
            let out = log2_f16(&arr1_f16(&[1.0, 8.0])).unwrap();
            let vals = out.as_slice().unwrap();
            assert_close_f32(vals[0].to_f32(), 0.0, 0.01);
            assert_close_f32(vals[1].to_f32(), 3.0, 0.01);
        }
    }
}