strobe/
ops.rs

//! Implementations of operations that can be applied at expression nodes.
//!
//! Broadly speaking, we have three categories of operations:
//! * Identity:     `array`, `slice`, `iterator`, and `constant` are inputs from outside the expression
//! * N-ary:        Operations like `mul`, `add`, etc.
//! * Accumulators: Operations that aggregate a full array to a scalar (like `sum`)
//!
//! Custom operations can be assembled using the `unary`, `binary`, `ternary`, and
//! `accumulator` functions, along with a reference to a matching function or closure.
//!
//! For example, you might want to use `powi`, which isn't provided directly here because
//! it requires configuration using a different type from the element (the integer power).
//! Note that the closure must read from its input slice `a`, not from the original array:
//! ```rust
//! use strobe::{Expr, array, unary};
//!
//! let x = [0.0_f64, 1.0, 2.0];
//! let mut xn: Expr<'_, _, 64> = array(&x);  // Input expression node for x
//!
//! let sq_func = |a: &[f64], out: &mut [f64]| { (0..a.len()).for_each(|i| {out[i] = a[i].powi(2)}); Ok(()) };
//! let xsq = unary(&mut xn, &sq_func).eval().unwrap();
//!
//! (0..x.len()).for_each(|i| {assert_eq!(x[i] * x[i], xsq[i])});
//! ```
24use crate::expr::{Accumulator, AccumulatorFn, BinaryFn, Expr, Op, TernaryFn, UnaryFn};
25use crate::{Array, Elem};
26use num_traits::{Float, MulAdd};
27
28/// Array identity operation. This allows the use of Vec, etc. as inputs.
29pub fn array<'a, T: Elem, const N: usize>(v: &'a Array<T>) -> Expr<'a, T, N> {
30    use Op::Array;
31    Expr::new(T::zero(), Array { v: v.as_ref() }, v.as_ref().len())
32}
33
34/// Array identity operation via iterator.
35///
36/// This allows interoperation with non-contiguous array formats.
37/// Note this method is significantly slower than consuming an array or slice,
38/// even if the iterator is over a contiguous array or slice. Whenever possible,
39/// it is best to provide the contiguous data directly.
40///
41/// ## Panics
42/// * If the length of the iterator is not known exactly
43pub fn iterator<'a, T: Elem, const N: usize>(
44    v: &'a mut dyn Iterator<Item = &'a T>,
45) -> Result<Expr<'a, T, N>, &'static str> {
46    use Op::Iterator;
47    // In order to use the iterator in a calculation that requires concrete
48    // size bounds, it must have an upper bound on size, and that upper bound
49    // must be equal to the lower bound
50    let n = match v.size_hint() {
51        (lower, Some(upper)) if lower == upper => lower,
52        _ => return Err("Iterator has unbounded size"),
53    };
54    Ok(Expr::new(T::zero(), Iterator { v }, n))
55}
56
57/// Slice identity operation. This allows the use of a slice as an input.
58pub fn slice<T: Elem, const N: usize>(v: &[T]) -> Expr<'_, T, N> {
59    use Op::Array;
60    Expr::new(T::zero(), Array { v }, v.len())
61}
62
63/// A scalar identity operation is always either a constant or the
64/// output of an accumulator to be used in a downstream expression.
65fn scalar<T: Elem, const N: usize>(v: T, acc: Option<Accumulator<'_, T, N>>) -> Expr<'_, T, N> {
66    use Op::Scalar;
67    Expr::new(v, Scalar { acc }, usize::MAX)
68}
69
70/// Constant identity operation. This allows use of a constant value as input.
71pub fn constant<'a, T: Elem, const N: usize>(v: T) -> Expr<'a, T, N> {
72    scalar(v, None)
73}
74
75/// Assemble an arbitrary (1xN)-to-(1x1) operation.
76pub fn accumulator<'a, T: Elem, const N: usize>(
77    start: T,
78    a: &'a mut Expr<'a, T, N>,
79    f: &'a dyn AccumulatorFn<T>,
80) -> Accumulator<'a, T, N> {
81    Accumulator {
82        v: None,
83        start,
84        a,
85        f,
86    }
87}
88
89/// Assemble an arbitrary (1xN)-to-(1xN) operation.
90pub fn unary<'a, T: Elem, const N: usize>(
91    a: &'a mut Expr<'a, T, N>,
92    f: &'a dyn UnaryFn<T>,
93) -> Expr<'a, T, N> {
94    use Op::Unary;
95    let n = a.len();
96    Expr::new(T::zero(), Unary { a, f }, n)
97}
98
99/// Assemble an arbitrary (2xN)-to-(1xN) operation.
100pub fn binary<'a, T: Elem, const N: usize>(
101    a: &'a mut Expr<'a, T, N>,
102    b: &'a mut Expr<'a, T, N>,
103    f: &'a dyn BinaryFn<T>,
104) -> Expr<'a, T, N> {
105    use Op::Binary;
106    let n = a.len().min(b.len());
107    Expr::new(T::zero(), Binary { a, b, f }, n)
108}
109
110/// Assemble an arbitrary (3xN)-to-(1xN) operation.
111pub fn ternary<'a, T: Elem, const N: usize>(
112    a: &'a mut Expr<'a, T, N>,
113    b: &'a mut Expr<'a, T, N>,
114    c: &'a mut Expr<'a, T, N>,
115    f: &'a dyn TernaryFn<T>,
116) -> Expr<'a, T, N> {
117    use Op::Ternary;
118    let n = a.len().min(b.len().min(c.len()));
119    Expr::new(T::zero(), Ternary { a, b, c, f }, n)
120}
121
122fn lt_inner<T: Elem + PartialOrd>(
123    left: &[T],
124    right: &[T],
125    out: &mut [T],
126) -> Result<(), &'static str> {
127    // Check sizes
128    let n = out.len();
129    if left.len() != n || right.len() != n {
130        return Err("Size mismatch");
131    };
132
133    // Execute
134    for i in 0..n {
135        let res = left[i] < right[i];
136        if res {
137            out[i] = T::one();
138        } else {
139            out[i] = T::zero();
140        }
141    }
142    Ok(())
143}
144
145/// Elementwise less-than, returning T::one() for true and T::zero() for false.
146pub fn lt<'a, T: Elem + PartialOrd, const N: usize>(
147    left: &'a mut Expr<'a, T, N>,
148    right: &'a mut Expr<'a, T, N>,
149) -> Expr<'a, T, N> {
150    binary(left, right, &lt_inner)
151}
152
153fn gt_inner<T: Elem + PartialOrd>(
154    left: &[T],
155    right: &[T],
156    out: &mut [T],
157) -> Result<(), &'static str> {
158    // Check sizes
159    let n = out.len();
160    if left.len() != n || right.len() != n {
161        return Err("Size mismatch");
162    };
163
164    // Execute
165    for i in 0..n {
166        let res = left[i] > right[i];
167        if res {
168            out[i] = T::one();
169        } else {
170            out[i] = T::zero();
171        }
172    }
173    Ok(())
174}
175
176/// Elementwise greater-than, returning T::one() for true and T::zero() for false.
177pub fn gt<'a, T: Elem + PartialOrd, const N: usize>(
178    left: &'a mut Expr<'a, T, N>,
179    right: &'a mut Expr<'a, T, N>,
180) -> Expr<'a, T, N> {
181    binary(left, right, &gt_inner)
182}
183
184fn le_inner<T: Elem + PartialOrd>(
185    left: &[T],
186    right: &[T],
187    out: &mut [T],
188) -> Result<(), &'static str> {
189    // Check sizes
190    let n = out.len();
191    if left.len() != n || right.len() != n {
192        return Err("Size mismatch");
193    };
194
195    // Execute
196    for i in 0..n {
197        let res = left[i] <= right[i];
198        if res {
199            out[i] = T::one();
200        } else {
201            out[i] = T::zero();
202        }
203    }
204    Ok(())
205}
206
207/// Elementwise less-than-or-equal, returning T::one() for true and T::zero() for false.
208pub fn le<'a, T: Elem + PartialOrd, const N: usize>(
209    left: &'a mut Expr<'a, T, N>,
210    right: &'a mut Expr<'a, T, N>,
211) -> Expr<'a, T, N> {
212    binary(left, right, &le_inner)
213}
214
215fn ge_inner<T: Elem + PartialOrd>(
216    left: &[T],
217    right: &[T],
218    out: &mut [T],
219) -> Result<(), &'static str> {
220    // Check sizes
221    let n = out.len();
222    if left.len() != n || right.len() != n {
223        return Err("Size mismatch");
224    };
225
226    // Execute
227    for i in 0..n {
228        let res = left[i] >= right[i];
229        if res {
230            out[i] = T::one();
231        } else {
232            out[i] = T::zero();
233        }
234    }
235    Ok(())
236}
237
238/// Elementwise greater-than-or-equal, returning T::one() for true and T::zero() for false.
239pub fn ge<'a, T: Elem + PartialOrd, const N: usize>(
240    left: &'a mut Expr<'a, T, N>,
241    right: &'a mut Expr<'a, T, N>,
242) -> Expr<'a, T, N> {
243    binary(left, right, &ge_inner)
244}
245
246fn eq_inner<T: Elem + PartialOrd>(
247    left: &[T],
248    right: &[T],
249    out: &mut [T],
250) -> Result<(), &'static str> {
251    // Check sizes
252    let n = out.len();
253    if left.len() != n || right.len() != n {
254        return Err("Size mismatch");
255    };
256
257    // Execute
258    for i in 0..n {
259        let res = left[i] == right[i];
260        if res {
261            out[i] = T::one();
262        } else {
263            out[i] = T::zero();
264        }
265    }
266    Ok(())
267}
268
269/// Elementwise equals, returning T::one() for true and T::zero() for false.
270pub fn eq<'a, T: Elem + PartialOrd, const N: usize>(
271    left: &'a mut Expr<'a, T, N>,
272    right: &'a mut Expr<'a, T, N>,
273) -> Expr<'a, T, N> {
274    binary(left, right, &eq_inner)
275}
276
277fn ne_inner<T: Elem + PartialOrd>(
278    left: &[T],
279    right: &[T],
280    out: &mut [T],
281) -> Result<(), &'static str> {
282    // Check sizes
283    let n = out.len();
284    if left.len() != n || right.len() != n {
285        return Err("Size mismatch");
286    };
287
288    // Execute
289    for i in 0..n {
290        let res = left[i] != right[i];
291        if res {
292            out[i] = T::one();
293        } else {
294            out[i] = T::zero();
295        }
296    }
297    Ok(())
298}
299
300/// Elementwise not-equal, returning T::one() for true and T::zero() for false.
301pub fn ne<'a, T: Elem + PartialOrd, const N: usize>(
302    left: &'a mut Expr<'a, T, N>,
303    right: &'a mut Expr<'a, T, N>,
304) -> Expr<'a, T, N> {
305    binary(left, right, &ne_inner)
306}
307
308fn min_inner<T: Elem + Ord>(left: &[T], right: &[T], out: &mut [T]) -> Result<(), &'static str> {
309    // Check sizes
310    let n = out.len();
311    if left.len() != n || right.len() != n {
312        return Err("Size mismatch");
313    };
314
315    // Execute
316    (0..n).for_each(|i| out[i] = left[i].min(right[i]));
317    Ok(())
318}
319
320/// Elementwise minimum for strictly ordered number types.
321/// For floating-point version with NaN handling, see `fmin`.
322pub fn min<'a, T: Elem + Ord, const N: usize>(
323    left: &'a mut Expr<'a, T, N>,
324    right: &'a mut Expr<'a, T, N>,
325) -> Expr<'a, T, N> {
326    binary(left, right, &min_inner)
327}
328
329fn max_inner<T: Elem + Ord>(left: &[T], right: &[T], out: &mut [T]) -> Result<(), &'static str> {
330    // Check sizes
331    let n = out.len();
332    if left.len() != n || right.len() != n {
333        return Err("Size mismatch");
334    };
335
336    // Execute
337    (0..n).for_each(|i| out[i] = left[i].max(right[i]));
338    Ok(())
339}
340
341/// Elementwise maximum for strictly ordered number types.
342/// For floating-point version with NaN handling, see `fmax`.
343pub fn max<'a, T: Elem + Ord, const N: usize>(
344    left: &'a mut Expr<'a, T, N>,
345    right: &'a mut Expr<'a, T, N>,
346) -> Expr<'a, T, N> {
347    binary(left, right, &max_inner)
348}
349
350fn add_inner<T: Elem>(left: &[T], right: &[T], out: &mut [T]) -> Result<(), &'static str> {
351    // Check sizes
352    let n = out.len();
353    if left.len() != n || right.len() != n {
354        return Err("Size mismatch");
355    };
356
357    // Execute
358    (0..n).for_each(|i| out[i] = left[i] + right[i]);
359    Ok(())
360}
361
362/// Elementwise addition
363pub fn add<'a, T: Elem, const N: usize>(
364    left: &'a mut Expr<'a, T, N>,
365    right: &'a mut Expr<'a, T, N>,
366) -> Expr<'a, T, N> {
367    binary(left, right, &add_inner)
368}
369
370fn sub_inner<T: Elem>(left: &[T], right: &[T], out: &mut [T]) -> Result<(), &'static str> {
371    // Check sizes
372    let n = out.len();
373    if left.len() != n || right.len() != n {
374        return Err("Size mismatch");
375    };
376
377    // Execute
378    (0..n).for_each(|i| out[i] = left[i] - right[i]);
379    Ok(())
380}
381
382/// Elementwise subtraction
383pub fn sub<'a, T: Elem, const N: usize>(
384    left: &'a mut Expr<'a, T, N>,
385    right: &'a mut Expr<'a, T, N>,
386) -> Expr<'a, T, N> {
387    binary(left, right, &sub_inner)
388}
389
390fn mul_inner<T: Elem>(left: &[T], right: &[T], out: &mut [T]) -> Result<(), &'static str> {
391    // Check sizes
392    let n = out.len();
393    if left.len() != n || right.len() != n {
394        return Err("Size mismatch");
395    };
396
397    // Execute
398    (0..n).for_each(|i| out[i] = left[i] * right[i]);
399    Ok(())
400}
401
402/// Elementwise multiplication
403pub fn mul<'a, T: Elem, const N: usize>(
404    left: &'a mut Expr<'a, T, N>,
405    right: &'a mut Expr<'a, T, N>,
406) -> Expr<'a, T, N> {
407    binary(left, right, &mul_inner)
408}
409
410fn div_inner<T: Elem>(left: &[T], right: &[T], out: &mut [T]) -> Result<(), &'static str> {
411    // Check sizes
412    let n = out.len();
413    if left.len() != n || right.len() != n {
414        return Err("Size mismatch");
415    };
416
417    // Execute
418    (0..n).for_each(|i| out[i] = left[i] / right[i]);
419    Ok(())
420}
421
422/// Elementwise division
423pub fn div<'a, T: Elem, const N: usize>(
424    numer: &'a mut Expr<'a, T, N>,
425    denom: &'a mut Expr<'a, T, N>,
426) -> Expr<'a, T, N> {
427    binary(numer, denom, &div_inner)
428}
429
430fn fmin_inner<T: Float>(left: &[T], right: &[T], out: &mut [T]) -> Result<(), &'static str> {
431    // Check sizes
432    let n = out.len();
433    if left.len() != n || right.len() != n {
434        return Err("Size mismatch");
435    };
436
437    // Execute
438    (0..n).for_each(|i| out[i] = left[i].min(right[i]));
439    Ok(())
440}
441
442/// Elementwise floating-point minimum.
443/// Ignores NaN values if either value is a number.
444pub fn fmin<'a, T: Float, const N: usize>(
445    left: &'a mut Expr<'a, T, N>,
446    right: &'a mut Expr<'a, T, N>,
447) -> Expr<'a, T, N> {
448    binary(left, right, &fmin_inner)
449}
450
451fn fmax_inner<T: Float>(left: &[T], right: &[T], out: &mut [T]) -> Result<(), &'static str> {
452    // Check sizes
453    let n = out.len();
454    if left.len() != n || right.len() != n {
455        return Err("Size mismatch");
456    };
457
458    // Execute
459    (0..n).for_each(|i| out[i] = left[i].max(right[i]));
460    Ok(())
461}
462
463/// Elementwise floating-point maximum.
464/// Ignores NaN values if either value is a number.
465pub fn fmax<'a, T: Float, const N: usize>(
466    left: &'a mut Expr<'a, T, N>,
467    right: &'a mut Expr<'a, T, N>,
468) -> Expr<'a, T, N> {
469    binary(left, right, &fmax_inner)
470}
471
472fn powf_inner<T: Float>(left: &[T], right: &[T], out: &mut [T]) -> Result<(), &'static str> {
473    // Check sizes
474    let n = out.len();
475    if left.len() != n || right.len() != n {
476        return Err("Size mismatch");
477    };
478
479    // Execute
480    (0..n).for_each(|i| out[i] = left[i].powf(right[i]));
481    Ok(())
482}
483
484/// Elementwise float exponent for float types
485pub fn powf<'a, T: Float, const N: usize>(
486    a: &'a mut Expr<'a, T, N>,
487    b: &'a mut Expr<'a, T, N>,
488) -> Expr<'a, T, N> {
489    binary(a, b, &powf_inner)
490}
491
492fn flog2_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
493    // Check sizes
494    let n = out.len();
495    if x.len() != n {
496        return Err("Size mismatch");
497    };
498
499    // Execute
500    (0..n).for_each(|i| out[i] = x[i].log2());
501    Ok(())
502}
503
504/// Elementwise log base 2 for float types
505pub fn flog2<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
506    unary(a, &flog2_inner)
507}
508
509fn flog10_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
510    // Check sizes
511    let n = out.len();
512    if x.len() != n {
513        return Err("Size mismatch");
514    };
515
516    // Execute
517    (0..n).for_each(|i| out[i] = x[i].log10());
518    Ok(())
519}
520
521/// Elementwise log base 10 for float types
522pub fn flog10<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
523    unary(a, &flog10_inner)
524}
525
526fn exp_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
527    // Check sizes
528    let n = out.len();
529    if x.len() != n {
530        return Err("Size mismatch");
531    };
532
533    // Execute
534    (0..n).for_each(|i| out[i] = x[i].exp());
535    Ok(())
536}
537
538/// Elementwise e^x for float types
539pub fn exp<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
540    unary(a, &exp_inner)
541}
542
543fn atan2_inner<T: Float>(y: &[T], x: &[T], out: &mut [T]) -> Result<(), &'static str> {
544    // Check sizes
545    let n = out.len();
546    if x.len() != n || y.len() != n {
547        return Err("Size mismatch");
548    };
549
550    // Execute
551    (0..n).for_each(|i| out[i] = y[i].atan2(x[i]));
552    Ok(())
553}
554
555/// Elementwise atan2(y, x) for float types. Produces correct results where atan
556/// would produce errors due to the singularity in the tangent function.
557///
558/// In accordance with tradition, the inputs are taken in (`y`, `x`) order
559/// and evaluated like `y.atan2(x)`.
560pub fn atan2<'a, T: Float, const N: usize>(
561    y: &'a mut Expr<'a, T, N>,
562    x: &'a mut Expr<'a, T, N>,
563) -> Expr<'a, T, N> {
564    binary(y, x, &atan2_inner)
565}
566
567fn sin_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
568    // Check sizes
569    let n = out.len();
570    if x.len() != n {
571        return Err("Size mismatch");
572    };
573
574    // Execute
575    (0..n).for_each(|i| out[i] = x[i].sin());
576    Ok(())
577}
578
579/// Elementwise sin(x) for float types
580pub fn sin<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
581    unary(a, &sin_inner)
582}
583
584fn tan_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
585    // Check sizes
586    let n = out.len();
587    if x.len() != n {
588        return Err("Size mismatch");
589    };
590
591    // Execute
592    (0..n).for_each(|i| out[i] = x[i].tan());
593    Ok(())
594}
595
596/// Elementwise tan(x) for float types
597pub fn tan<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
598    unary(a, &tan_inner)
599}
600
601fn cos_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
602    // Check sizes
603    let n = out.len();
604    if x.len() != n {
605        return Err("Size mismatch");
606    };
607
608    // Execute
609    (0..n).for_each(|i| out[i] = x[i].cos());
610    Ok(())
611}
612
613/// Elementwise cos(x) for float types
614pub fn cos<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
615    unary(a, &cos_inner)
616}
617
618fn asin_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
619    // Check sizes
620    let n = out.len();
621    if x.len() != n {
622        return Err("Size mismatch");
623    };
624
625    // Execute
626    (0..n).for_each(|i| out[i] = x[i].asin());
627    Ok(())
628}
629
630/// Elementwise asin(x) for float types
631pub fn asin<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
632    unary(a, &asin_inner)
633}
634
635fn acos_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
636    // Check sizes
637    let n = out.len();
638    if x.len() != n {
639        return Err("Size mismatch");
640    };
641
642    // Execute
643    (0..n).for_each(|i| out[i] = x[i].acos());
644    Ok(())
645}
646
647/// Elementwise acos(x) for float types
648pub fn acos<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
649    unary(a, &acos_inner)
650}
651
652fn atan_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
653    // Check sizes
654    let n = out.len();
655    if x.len() != n {
656        return Err("Size mismatch");
657    };
658
659    // Execute
660    (0..n).for_each(|i| out[i] = x[i].atan());
661    Ok(())
662}
663
664/// Elementwise atan(x) for float types
665///
666/// This function will produce erroneous results near multiple of pi/2.
667/// For a version that maintains correctness near singularities in tan(x),
668/// see `atan2`.
669pub fn atan<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
670    unary(a, &atan_inner)
671}
672
673fn sinh_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
674    // Check sizes
675    let n = out.len();
676    if x.len() != n {
677        return Err("Size mismatch");
678    };
679
680    // Execute
681    (0..n).for_each(|i| out[i] = x[i].sinh());
682    Ok(())
683}
684
685/// Elementwise sinh(x) for float types
686pub fn sinh<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
687    unary(a, &sinh_inner)
688}
689
690fn cosh_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
691    // Check sizes
692    let n = out.len();
693    if x.len() != n {
694        return Err("Size mismatch");
695    };
696
697    // Execute
698    (0..n).for_each(|i| out[i] = x[i].cosh());
699    Ok(())
700}
701
702/// Elementwise cosh(x) for float types
703pub fn cosh<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
704    unary(a, &cosh_inner)
705}
706
707fn tanh_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
708    // Check sizes
709    let n = out.len();
710    if x.len() != n {
711        return Err("Size mismatch");
712    };
713
714    // Execute
715    (0..n).for_each(|i| out[i] = x[i].tanh());
716    Ok(())
717}
718
719/// Elementwise tanh(x) for float types
720pub fn tanh<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
721    unary(a, &tanh_inner)
722}
723
724fn asinh_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
725    // Check sizes
726    let n = out.len();
727    if x.len() != n {
728        return Err("Size mismatch");
729    };
730
731    // Execute
732    (0..n).for_each(|i| out[i] = x[i].asinh());
733    Ok(())
734}
735
736/// Elementwise asinh(x) for float types
737pub fn asinh<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
738    unary(a, &asinh_inner)
739}
740
741fn acosh_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
742    // Check sizes
743    let n = out.len();
744    if x.len() != n {
745        return Err("Size mismatch");
746    };
747
748    // Execute
749    (0..n).for_each(|i| out[i] = x[i].acosh());
750    Ok(())
751}
752
753/// Elementwise acosh(x) for float types
754pub fn acosh<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
755    unary(a, &acosh_inner)
756}
757
758fn atanh_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
759    // Check sizes
760    let n = out.len();
761    if x.len() != n {
762        return Err("Size mismatch");
763    };
764
765    // Execute
766    (0..n).for_each(|i| out[i] = x[i].atanh());
767    Ok(())
768}
769
770/// Elementwise atanh(x) for float types
771pub fn atanh<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
772    unary(a, &atanh_inner)
773}
774
775fn abs_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
776    // Check sizes
777    let n = out.len();
778    if x.len() != n {
779        return Err("Size mismatch");
780    };
781
782    // Execute
783    (0..n).for_each(|i| out[i] = x[i].abs());
784    Ok(())
785}
786
787/// Elementwise abs(x) for float types
788pub fn abs<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
789    unary(a, &abs_inner)
790}
791
792fn mul_add_inner<T: Elem + MulAdd<T, Output = T>>(
793    a: &[T],
794    b: &[T],
795    c: &[T],
796    out: &mut [T],
797) -> Result<(), &'static str> {
798    // Check sizes
799    let n = out.len();
800    if a.len() != n || b.len() != n || c.len() != n {
801        return Err("Size mismatch");
802    };
803
804    // Execute
805    (0..n).for_each(|i| out[i] = a[i].mul_add(b[i], c[i]));
806    Ok(())
807}
808
809/// Elementwise fused multiply-add
810///
811/// If the compilation target supports FMA (fused multiply-add)
812/// and `-Ctarget-feature=fma` is given to rustc, this
813/// performs the multiplication and addition in a single operation with
814/// a single roundoff error, and can provide a significant improvement
815/// in either or both of speed and float error.
816///
817/// However, if the compilation target does _not_ support FMA
818/// or if FMA is not enabled, this will be much slower than a
819/// separate multiply and add, because it will not vectorize.
820pub fn mul_add<'a, T: Elem + MulAdd<T, Output = T>, const N: usize>(
821    a: &'a mut Expr<'a, T, N>,
822    b: &'a mut Expr<'a, T, N>,
823    c: &'a mut Expr<'a, T, N>,
824) -> Expr<'a, T, N> {
825    ternary(a, b, c, &mul_add_inner)
826}
827
828fn sum_inner<T: Elem>(x: &[T], v: &mut T) -> Result<(), &'static str> {
829    (0..x.len()).for_each(|i| *v = *v + x[i]);
830    Ok(())
831}
832
833/// Cumulative sum of array elements.
834///
835/// Note that while it is allowed, applying this to an expression with
836/// a scalar operation will produce meaningless results.
837pub fn sum<'a, T: Elem, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
838    let acc = Some(accumulator(T::zero(), a, &sum_inner));
839    scalar(T::zero(), acc)
840}