1use ferray_core::Array;
36use ferray_core::dimension::Dimension;
37use ferray_core::dtype::Element;
38use ferray_core::error::FerrayResult;
39use num_traits::Float;
40
41use crate::cr_math::CrMath;
42use crate::helpers::{
43 binary_elementwise_op, unary_float_op, unary_float_op_compute, unary_float_op_into_compute,
44};
45
46pub fn exp<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
48where
49 T: Element + Float + CrMath,
50 D: Dimension,
51{
52 unary_float_op_compute(input, T::cr_exp)
53}
54
55pub fn exp_into<T, D>(input: &Array<T, D>, out: &mut Array<T, D>) -> FerrayResult<()>
58where
59 T: Element + Float + CrMath,
60 D: Dimension,
61{
62 unary_float_op_into_compute(input, out, "exp", T::cr_exp)
63}
64
65pub fn exp_fast<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
78where
79 T: Element + Float,
80 D: Dimension,
81{
82 use std::any::TypeId;
83 if TypeId::of::<T>() == TypeId::of::<f64>() {
84 let f64_input =
86 unsafe { &*std::ptr::from_ref::<Array<T, D>>(input).cast::<Array<f64, D>>() };
87 let n = f64_input.size();
88 let result = if let Some(slice) = f64_input.as_slice() {
89 let mut data = Vec::with_capacity(n);
90 #[allow(clippy::uninit_vec)]
91 unsafe {
92 data.set_len(n);
93 }
94 crate::dispatch::dispatch_exp_fast_f64(slice, &mut data);
95 Array::from_vec(f64_input.dim().clone(), data)?
96 } else {
97 let data: Vec<f64> = f64_input
98 .iter()
99 .map(|&x| crate::fast_exp::exp_fast_f64(x))
100 .collect();
101 Array::from_vec(f64_input.dim().clone(), data)?
102 };
103 Ok(unsafe { crate::helpers::reinterpret_array::<f64, T, D>(result) })
105 } else if TypeId::of::<T>() == TypeId::of::<f32>() {
106 let f32_input =
107 unsafe { &*std::ptr::from_ref::<Array<T, D>>(input).cast::<Array<f32, D>>() };
108 let n = f32_input.size();
109 let result = if let Some(slice) = f32_input.as_slice() {
110 let mut data = Vec::with_capacity(n);
111 #[allow(clippy::uninit_vec)]
112 unsafe {
113 data.set_len(n);
114 }
115 crate::dispatch::dispatch_exp_fast_f32(slice, &mut data);
116 Array::from_vec(f32_input.dim().clone(), data)?
117 } else {
118 let data: Vec<f32> = f32_input
119 .iter()
120 .map(|&x| crate::fast_exp::exp_fast_f32(x))
121 .collect();
122 Array::from_vec(f32_input.dim().clone(), data)?
123 };
124 Ok(unsafe { crate::helpers::reinterpret_array::<f32, T, D>(result) })
126 } else {
127 unary_float_op(input, num_traits::Float::exp)
129 }
130}
131
132pub fn exp2<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
134where
135 T: Element + Float + CrMath,
136 D: Dimension,
137{
138 unary_float_op_compute(input, T::cr_exp2)
139}
140
141pub fn expm1<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
143where
144 T: Element + Float + CrMath,
145 D: Dimension,
146{
147 unary_float_op_compute(input, T::cr_exp_m1)
148}
149
150pub fn log<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
152where
153 T: Element + Float + CrMath,
154 D: Dimension,
155{
156 unary_float_op_compute(input, T::cr_ln)
157}
158
159pub fn log_into<T, D>(input: &Array<T, D>, out: &mut Array<T, D>) -> FerrayResult<()>
161where
162 T: Element + Float + CrMath,
163 D: Dimension,
164{
165 unary_float_op_into_compute(input, out, "log", T::cr_ln)
166}
167
168pub fn log2<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
170where
171 T: Element + Float + CrMath,
172 D: Dimension,
173{
174 unary_float_op_compute(input, T::cr_log2)
175}
176
177pub fn log10<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
179where
180 T: Element + Float + CrMath,
181 D: Dimension,
182{
183 unary_float_op_compute(input, T::cr_log10)
184}
185
186pub fn log1p<T, D>(input: &Array<T, D>) -> FerrayResult<Array<T, D>>
188where
189 T: Element + Float + CrMath,
190 D: Dimension,
191{
192 unary_float_op_compute(input, T::cr_ln_1p)
193}
194
195pub fn logaddexp<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
197where
198 T: Element + Float + CrMath,
199 D: Dimension,
200{
201 binary_elementwise_op(a, b, |x, y| {
202 if x.is_nan() || y.is_nan() {
203 return T::nan();
204 }
205 let max = if x > y { x } else { y };
206 if max.is_infinite() {
212 return max;
213 }
214 let min = if x > y { y } else { x };
215 max + (min - max).cr_exp().cr_ln_1p()
216 })
217}
218
219pub fn logaddexp2<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
221where
222 T: Element + Float + CrMath,
223 D: Dimension,
224{
225 let ln2 = T::from(std::f64::consts::LN_2).unwrap_or_else(|| <T as Element>::one());
226 binary_elementwise_op(a, b, |x, y| {
227 if x.is_nan() || y.is_nan() {
228 return T::nan();
229 }
230 let max = if x > y { x } else { y };
231 if max.is_infinite() {
237 return max;
238 }
239 let min = if x > y { y } else { x };
240 max + ((min - max) * ln2).cr_exp().cr_ln_1p() / ln2
241 })
242}
243
244use crate::helpers::unary_f16_fn;
250
// Half-precision (`half::f16`) variants of the exp/log family. Each
// `unary_f16_fn!` invocation passes a `cfg` attribute (so the item exists only
// with the `f16` feature), the name of the function to define, and the `f32`
// function to apply.
// NOTE(review): the macro presumably widens each f16 element to f32, applies
// the given function, and narrows the result back to f16. Confirm against
// `crate::helpers::unary_f16_fn`.

// `e^x` for f16 arrays.
unary_f16_fn!(
    #[cfg(feature = "f16")]
    exp_f16,
    f32::exp
);
// `2^x` for f16 arrays.
unary_f16_fn!(
    #[cfg(feature = "f16")]
    exp2_f16,
    f32::exp2
);
// `e^x - 1` for f16 arrays.
unary_f16_fn!(
    #[cfg(feature = "f16")]
    expm1_f16,
    f32::exp_m1
);
// Natural logarithm for f16 arrays.
unary_f16_fn!(
    #[cfg(feature = "f16")]
    log_f16,
    f32::ln
);
// Base-2 logarithm for f16 arrays.
unary_f16_fn!(
    #[cfg(feature = "f16")]
    log2_f16,
    f32::log2
);
// Base-10 logarithm for f16 arrays.
unary_f16_fn!(
    #[cfg(feature = "f16")]
    log10_f16,
    f32::log10
);
// `ln(1 + x)` for f16 arrays.
unary_f16_fn!(
    #[cfg(feature = "f16")]
    log1p_f16,
    f32::ln_1p
);
293
#[cfg(test)]
mod tests {
    use super::*;

    use crate::test_util::arr1;

    #[test]
    fn test_exp() {
        let a = arr1(vec![0.0, 1.0]);
        let r = exp(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert!((s[0] - 1.0).abs() < 1e-12);
        assert!((s[1] - std::f64::consts::E).abs() < 1e-12);
    }

    #[test]
    fn test_exp_fast() {
        let a = arr1(vec![0.0, 1.0, -1.0, 10.0, -10.0]);
        let r = exp_fast(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert!((s[0] - 1.0).abs() < 1e-15);
        assert!((s[1] - std::f64::consts::E).abs() < 1e-14);
        assert!((s[2] - 1.0 / std::f64::consts::E).abs() < 1e-15);
        // The `f64` suffix is required: calling `.exp()` on an unsuffixed float
        // literal is an ambiguous-numeric-type error (E0689).
        let samples = [0.0_f64, 1.0, -1.0, 10.0, -10.0];
        for (&x, &got) in samples.iter().zip(s) {
            let reference = x.exp();
            // Error measured in units of the reference's epsilon-scaled magnitude.
            let ulp = (got - reference).abs() / (reference.abs() * f64::EPSILON);
            assert!(ulp <= 1.5, "exp_fast({x}) ulp = {ulp}");
        }
    }

    #[test]
    fn test_exp_fast_f32() {
        use ferray_core::dimension::Ix1;

        // Exercises the `f32` specialisation, which the f64 tests do not reach.
        let samples = [0.0_f32, 1.0, -1.0, 5.0, -5.0];
        let a = Array::from_vec(Ix1::new([samples.len()]), samples.to_vec()).unwrap();
        let r = exp_fast(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s.len(), samples.len());
        for (&x, &got) in samples.iter().zip(s) {
            let reference = x.exp();
            let rel = (got - reference).abs() / reference.abs();
            assert!(rel < 1e-5, "exp_fast::<f32>({x}) rel err = {rel}");
        }
    }

    #[test]
    fn test_exp2() {
        let a = arr1(vec![0.0, 3.0, 10.0]);
        let r = exp2(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert!((s[0] - 1.0).abs() < 1e-12);
        assert!((s[1] - 8.0).abs() < 1e-12);
        assert!((s[2] - 1024.0).abs() < 1e-9);
    }

    #[test]
    fn test_expm1() {
        let a = arr1(vec![0.0, 1e-15]);
        let r = expm1(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert!((s[0]).abs() < 1e-12);
        // A naive `exp(x) - 1` would lose nearly all digits here.
        assert!((s[1] - 1e-15).abs() < 1e-25);
    }

    #[test]
    fn test_log() {
        let a = arr1(vec![1.0, std::f64::consts::E]);
        let r = log(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert!((s[0]).abs() < 1e-12);
        assert!((s[1] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_log2() {
        let a = arr1(vec![1.0, 8.0, 1024.0]);
        let r = log2(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert!((s[0]).abs() < 1e-12);
        assert!((s[1] - 3.0).abs() < 1e-12);
        assert!((s[2] - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_log10() {
        let a = arr1(vec![1.0, 100.0, 1000.0]);
        let r = log10(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert!((s[0]).abs() < 1e-12);
        assert!((s[1] - 2.0).abs() < 1e-12);
        assert!((s[2] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_log1p() {
        let a = arr1(vec![0.0, 1e-15]);
        let r = log1p(&a).unwrap();
        let s = r.as_slice().unwrap();
        assert!((s[0]).abs() < 1e-12);
        assert!((s[1] - 1e-15).abs() < 1e-25);
    }

    #[test]
    fn test_logaddexp() {
        let a = arr1(vec![0.0]);
        let b = arr1(vec![0.0]);
        let r = logaddexp(&a, &b).unwrap();
        let s = r.as_slice().unwrap();
        assert!((s[0] - std::f64::consts::LN_2).abs() < 1e-12);
    }

    #[test]
    fn test_logaddexp2() {
        let a = arr1(vec![0.0]);
        let b = arr1(vec![0.0]);
        let r = logaddexp2(&a, &b).unwrap();
        let s = r.as_slice().unwrap();
        assert!((s[0] - 1.0).abs() < 1e-12);
    }

    /// Runs a binary logaddexp-style kernel on two 1-D arrays and returns the
    /// result as a `Vec`, failing the test if the kernel reports an error.
    fn eval_logaddexp(
        f: impl Fn(
            &Array<f64, ferray_core::dimension::Ix1>,
            &Array<f64, ferray_core::dimension::Ix1>,
        ) -> FerrayResult<Array<f64, ferray_core::dimension::Ix1>>,
        a: Vec<f64>,
        b: Vec<f64>,
    ) -> Vec<f64> {
        let r = f(&arr1(a), &arr1(b));
        assert!(r.is_ok(), "logaddexp kernel returned error: {:?}", r.err());
        match r {
            Ok(arr) => match arr.as_slice() {
                Some(s) => s.to_vec(),
                None => arr.iter().copied().collect(),
            },
            Err(_) => Vec::new(),
        }
    }

    #[test]
    fn test_logaddexp_infinities() {
        let inf = f64::INFINITY;
        let s = eval_logaddexp(logaddexp, vec![inf, -inf, inf], vec![inf, -inf, -inf]);
        assert_eq!(s[0], inf, "logaddexp(inf, inf)");
        assert_eq!(s[1], -inf, "logaddexp(-inf, -inf)");
        assert_eq!(s[2], inf, "logaddexp(inf, -inf)");
        let f = eval_logaddexp(logaddexp, vec![1.0], vec![1.0]);
        assert!((f[0] - (1.0 + std::f64::consts::LN_2)).abs() < 1e-12);
    }

    #[test]
    fn test_logaddexp2_infinities() {
        let inf = f64::INFINITY;
        let s = eval_logaddexp(logaddexp2, vec![inf, -inf, inf], vec![inf, -inf, -inf]);
        assert_eq!(s[0], inf, "logaddexp2(inf, inf)");
        assert_eq!(s[1], -inf, "logaddexp2(-inf, -inf)");
        assert_eq!(s[2], inf, "logaddexp2(inf, -inf)");
        let f = eval_logaddexp(logaddexp2, vec![0.0], vec![0.0]);
        assert!((f[0] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_logaddexp_nan_propagates() {
        let nan = f64::NAN;
        let inf = f64::INFINITY;
        // NaN must win in either position, even against an infinite partner.
        let a = vec![nan, 0.0, nan, -inf];
        let b = vec![0.0, nan, inf, nan];
        let s = eval_logaddexp(logaddexp, a.clone(), b.clone());
        assert_eq!(s.len(), 4);
        assert!(s.iter().all(|v| v.is_nan()), "logaddexp: {s:?}");
        let s2 = eval_logaddexp(logaddexp2, a, b);
        assert_eq!(s2.len(), 4);
        assert!(s2.iter().all(|v| v.is_nan()), "logaddexp2: {s2:?}");
    }

    #[test]
    fn test_logaddexp_unequal_args() {
        // Both argument orders must give max + log(1 + base^(min - max)).
        let s = eval_logaddexp(logaddexp, vec![0.0, -30.0], vec![-30.0, 0.0]);
        let expected = (-30.0_f64).exp().ln_1p();
        assert_eq!(s.len(), 2);
        for &v in &s {
            assert!((v - expected).abs() <= 1e-12 * expected, "got {v}, want {expected}");
        }

        let s2 = eval_logaddexp(logaddexp2, vec![0.0, -30.0], vec![-30.0, 0.0]);
        let expected2 = (-30.0_f64).exp2().ln_1p() / std::f64::consts::LN_2;
        assert_eq!(s2.len(), 2);
        for &v in &s2 {
            assert!((v - expected2).abs() <= 1e-12 * expected2, "got {v}, want {expected2}");
        }
    }

    #[cfg(feature = "f16")]
    mod f16_tests {
        use super::*;
        use ferray_core::dimension::Ix1;

        fn arr1_f16(data: &[f32]) -> Array<half::f16, Ix1> {
            let n = data.len();
            let vals: Vec<half::f16> = data.iter().map(|&x| half::f16::from_f32(x)).collect();
            Array::from_vec(Ix1::new([n]), vals).unwrap()
        }

        #[test]
        fn test_exp_f16() {
            let a = arr1_f16(&[0.0, 1.0]);
            let r = exp_f16(&a).unwrap();
            let s = r.as_slice().unwrap();
            assert!((s[0].to_f32() - 1.0).abs() < 0.01);
            assert!((s[1].to_f32() - std::f32::consts::E).abs() < 0.02);
        }

        #[test]
        fn test_log_f16() {
            let a = arr1_f16(&[1.0, std::f32::consts::E]);
            let r = log_f16(&a).unwrap();
            let s = r.as_slice().unwrap();
            assert!(s[0].to_f32().abs() < 0.01);
            assert!((s[1].to_f32() - 1.0).abs() < 0.01);
        }

        #[test]
        fn test_log2_f16() {
            let a = arr1_f16(&[1.0, 8.0]);
            let r = log2_f16(&a).unwrap();
            let s = r.as_slice().unwrap();
            assert!(s[0].to_f32().abs() < 0.01);
            assert!((s[1].to_f32() - 3.0).abs() < 0.01);
        }
    }
}