1use crate::expr::{Accumulator, AccumulatorFn, BinaryFn, Expr, Op, TernaryFn, UnaryFn};
25use crate::{Array, Elem};
26use num_traits::{Float, MulAdd};
27
28pub fn array<'a, T: Elem, const N: usize>(v: &'a Array<T>) -> Expr<'a, T, N> {
30 use Op::Array;
31 Expr::new(T::zero(), Array { v: v.as_ref() }, v.as_ref().len())
32}
33
34pub fn iterator<'a, T: Elem, const N: usize>(
44 v: &'a mut dyn Iterator<Item = &'a T>,
45) -> Result<Expr<'a, T, N>, &'static str> {
46 use Op::Iterator;
47 let n = match v.size_hint() {
51 (lower, Some(upper)) if lower == upper => lower,
52 _ => return Err("Iterator has unbounded size"),
53 };
54 Ok(Expr::new(T::zero(), Iterator { v }, n))
55}
56
57pub fn slice<T: Elem, const N: usize>(v: &[T]) -> Expr<'_, T, N> {
59 use Op::Array;
60 Expr::new(T::zero(), Array { v }, v.len())
61}
62
63fn scalar<T: Elem, const N: usize>(v: T, acc: Option<Accumulator<'_, T, N>>) -> Expr<'_, T, N> {
66 use Op::Scalar;
67 Expr::new(v, Scalar { acc }, usize::MAX)
68}
69
70pub fn constant<'a, T: Elem, const N: usize>(v: T) -> Expr<'a, T, N> {
72 scalar(v, None)
73}
74
75pub fn accumulator<'a, T: Elem, const N: usize>(
77 start: T,
78 a: &'a mut Expr<'a, T, N>,
79 f: &'a dyn AccumulatorFn<T>,
80) -> Accumulator<'a, T, N> {
81 Accumulator {
82 v: None,
83 start,
84 a,
85 f,
86 }
87}
88
89pub fn unary<'a, T: Elem, const N: usize>(
91 a: &'a mut Expr<'a, T, N>,
92 f: &'a dyn UnaryFn<T>,
93) -> Expr<'a, T, N> {
94 use Op::Unary;
95 let n = a.len();
96 Expr::new(T::zero(), Unary { a, f }, n)
97}
98
99pub fn binary<'a, T: Elem, const N: usize>(
101 a: &'a mut Expr<'a, T, N>,
102 b: &'a mut Expr<'a, T, N>,
103 f: &'a dyn BinaryFn<T>,
104) -> Expr<'a, T, N> {
105 use Op::Binary;
106 let n = a.len().min(b.len());
107 Expr::new(T::zero(), Binary { a, b, f }, n)
108}
109
110pub fn ternary<'a, T: Elem, const N: usize>(
112 a: &'a mut Expr<'a, T, N>,
113 b: &'a mut Expr<'a, T, N>,
114 c: &'a mut Expr<'a, T, N>,
115 f: &'a dyn TernaryFn<T>,
116) -> Expr<'a, T, N> {
117 use Op::Ternary;
118 let n = a.len().min(b.len().min(c.len()));
119 Expr::new(T::zero(), Ternary { a, b, c, f }, n)
120}
121
122fn lt_inner<T: Elem + PartialOrd>(
123 left: &[T],
124 right: &[T],
125 out: &mut [T],
126) -> Result<(), &'static str> {
127 let n = out.len();
129 if left.len() != n || right.len() != n {
130 return Err("Size mismatch");
131 };
132
133 for i in 0..n {
135 let res = left[i] < right[i];
136 if res {
137 out[i] = T::one();
138 } else {
139 out[i] = T::zero();
140 }
141 }
142 Ok(())
143}
144
145pub fn lt<'a, T: Elem + PartialOrd, const N: usize>(
147 left: &'a mut Expr<'a, T, N>,
148 right: &'a mut Expr<'a, T, N>,
149) -> Expr<'a, T, N> {
150 binary(left, right, <_inner)
151}
152
153fn gt_inner<T: Elem + PartialOrd>(
154 left: &[T],
155 right: &[T],
156 out: &mut [T],
157) -> Result<(), &'static str> {
158 let n = out.len();
160 if left.len() != n || right.len() != n {
161 return Err("Size mismatch");
162 };
163
164 for i in 0..n {
166 let res = left[i] > right[i];
167 if res {
168 out[i] = T::one();
169 } else {
170 out[i] = T::zero();
171 }
172 }
173 Ok(())
174}
175
176pub fn gt<'a, T: Elem + PartialOrd, const N: usize>(
178 left: &'a mut Expr<'a, T, N>,
179 right: &'a mut Expr<'a, T, N>,
180) -> Expr<'a, T, N> {
181 binary(left, right, >_inner)
182}
183
184fn le_inner<T: Elem + PartialOrd>(
185 left: &[T],
186 right: &[T],
187 out: &mut [T],
188) -> Result<(), &'static str> {
189 let n = out.len();
191 if left.len() != n || right.len() != n {
192 return Err("Size mismatch");
193 };
194
195 for i in 0..n {
197 let res = left[i] <= right[i];
198 if res {
199 out[i] = T::one();
200 } else {
201 out[i] = T::zero();
202 }
203 }
204 Ok(())
205}
206
207pub fn le<'a, T: Elem + PartialOrd, const N: usize>(
209 left: &'a mut Expr<'a, T, N>,
210 right: &'a mut Expr<'a, T, N>,
211) -> Expr<'a, T, N> {
212 binary(left, right, &le_inner)
213}
214
215fn ge_inner<T: Elem + PartialOrd>(
216 left: &[T],
217 right: &[T],
218 out: &mut [T],
219) -> Result<(), &'static str> {
220 let n = out.len();
222 if left.len() != n || right.len() != n {
223 return Err("Size mismatch");
224 };
225
226 for i in 0..n {
228 let res = left[i] >= right[i];
229 if res {
230 out[i] = T::one();
231 } else {
232 out[i] = T::zero();
233 }
234 }
235 Ok(())
236}
237
238pub fn ge<'a, T: Elem + PartialOrd, const N: usize>(
240 left: &'a mut Expr<'a, T, N>,
241 right: &'a mut Expr<'a, T, N>,
242) -> Expr<'a, T, N> {
243 binary(left, right, &ge_inner)
244}
245
246fn eq_inner<T: Elem + PartialOrd>(
247 left: &[T],
248 right: &[T],
249 out: &mut [T],
250) -> Result<(), &'static str> {
251 let n = out.len();
253 if left.len() != n || right.len() != n {
254 return Err("Size mismatch");
255 };
256
257 for i in 0..n {
259 let res = left[i] == right[i];
260 if res {
261 out[i] = T::one();
262 } else {
263 out[i] = T::zero();
264 }
265 }
266 Ok(())
267}
268
269pub fn eq<'a, T: Elem + PartialOrd, const N: usize>(
271 left: &'a mut Expr<'a, T, N>,
272 right: &'a mut Expr<'a, T, N>,
273) -> Expr<'a, T, N> {
274 binary(left, right, &eq_inner)
275}
276
277fn ne_inner<T: Elem + PartialOrd>(
278 left: &[T],
279 right: &[T],
280 out: &mut [T],
281) -> Result<(), &'static str> {
282 let n = out.len();
284 if left.len() != n || right.len() != n {
285 return Err("Size mismatch");
286 };
287
288 for i in 0..n {
290 let res = left[i] != right[i];
291 if res {
292 out[i] = T::one();
293 } else {
294 out[i] = T::zero();
295 }
296 }
297 Ok(())
298}
299
300pub fn ne<'a, T: Elem + PartialOrd, const N: usize>(
302 left: &'a mut Expr<'a, T, N>,
303 right: &'a mut Expr<'a, T, N>,
304) -> Expr<'a, T, N> {
305 binary(left, right, &ne_inner)
306}
307
308fn min_inner<T: Elem + Ord>(left: &[T], right: &[T], out: &mut [T]) -> Result<(), &'static str> {
309 let n = out.len();
311 if left.len() != n || right.len() != n {
312 return Err("Size mismatch");
313 };
314
315 (0..n).for_each(|i| out[i] = left[i].min(right[i]));
317 Ok(())
318}
319
320pub fn min<'a, T: Elem + Ord, const N: usize>(
323 left: &'a mut Expr<'a, T, N>,
324 right: &'a mut Expr<'a, T, N>,
325) -> Expr<'a, T, N> {
326 binary(left, right, &min_inner)
327}
328
329fn max_inner<T: Elem + Ord>(left: &[T], right: &[T], out: &mut [T]) -> Result<(), &'static str> {
330 let n = out.len();
332 if left.len() != n || right.len() != n {
333 return Err("Size mismatch");
334 };
335
336 (0..n).for_each(|i| out[i] = left[i].max(right[i]));
338 Ok(())
339}
340
341pub fn max<'a, T: Elem + Ord, const N: usize>(
344 left: &'a mut Expr<'a, T, N>,
345 right: &'a mut Expr<'a, T, N>,
346) -> Expr<'a, T, N> {
347 binary(left, right, &max_inner)
348}
349
350fn add_inner<T: Elem>(left: &[T], right: &[T], out: &mut [T]) -> Result<(), &'static str> {
351 let n = out.len();
353 if left.len() != n || right.len() != n {
354 return Err("Size mismatch");
355 };
356
357 (0..n).for_each(|i| out[i] = left[i] + right[i]);
359 Ok(())
360}
361
362pub fn add<'a, T: Elem, const N: usize>(
364 left: &'a mut Expr<'a, T, N>,
365 right: &'a mut Expr<'a, T, N>,
366) -> Expr<'a, T, N> {
367 binary(left, right, &add_inner)
368}
369
370fn sub_inner<T: Elem>(left: &[T], right: &[T], out: &mut [T]) -> Result<(), &'static str> {
371 let n = out.len();
373 if left.len() != n || right.len() != n {
374 return Err("Size mismatch");
375 };
376
377 (0..n).for_each(|i| out[i] = left[i] - right[i]);
379 Ok(())
380}
381
382pub fn sub<'a, T: Elem, const N: usize>(
384 left: &'a mut Expr<'a, T, N>,
385 right: &'a mut Expr<'a, T, N>,
386) -> Expr<'a, T, N> {
387 binary(left, right, &sub_inner)
388}
389
390fn mul_inner<T: Elem>(left: &[T], right: &[T], out: &mut [T]) -> Result<(), &'static str> {
391 let n = out.len();
393 if left.len() != n || right.len() != n {
394 return Err("Size mismatch");
395 };
396
397 (0..n).for_each(|i| out[i] = left[i] * right[i]);
399 Ok(())
400}
401
402pub fn mul<'a, T: Elem, const N: usize>(
404 left: &'a mut Expr<'a, T, N>,
405 right: &'a mut Expr<'a, T, N>,
406) -> Expr<'a, T, N> {
407 binary(left, right, &mul_inner)
408}
409
410fn div_inner<T: Elem>(left: &[T], right: &[T], out: &mut [T]) -> Result<(), &'static str> {
411 let n = out.len();
413 if left.len() != n || right.len() != n {
414 return Err("Size mismatch");
415 };
416
417 (0..n).for_each(|i| out[i] = left[i] / right[i]);
419 Ok(())
420}
421
422pub fn div<'a, T: Elem, const N: usize>(
424 numer: &'a mut Expr<'a, T, N>,
425 denom: &'a mut Expr<'a, T, N>,
426) -> Expr<'a, T, N> {
427 binary(numer, denom, &div_inner)
428}
429
430fn fmin_inner<T: Float>(left: &[T], right: &[T], out: &mut [T]) -> Result<(), &'static str> {
431 let n = out.len();
433 if left.len() != n || right.len() != n {
434 return Err("Size mismatch");
435 };
436
437 (0..n).for_each(|i| out[i] = left[i].min(right[i]));
439 Ok(())
440}
441
442pub fn fmin<'a, T: Float, const N: usize>(
445 left: &'a mut Expr<'a, T, N>,
446 right: &'a mut Expr<'a, T, N>,
447) -> Expr<'a, T, N> {
448 binary(left, right, &fmin_inner)
449}
450
451fn fmax_inner<T: Float>(left: &[T], right: &[T], out: &mut [T]) -> Result<(), &'static str> {
452 let n = out.len();
454 if left.len() != n || right.len() != n {
455 return Err("Size mismatch");
456 };
457
458 (0..n).for_each(|i| out[i] = left[i].max(right[i]));
460 Ok(())
461}
462
463pub fn fmax<'a, T: Float, const N: usize>(
466 left: &'a mut Expr<'a, T, N>,
467 right: &'a mut Expr<'a, T, N>,
468) -> Expr<'a, T, N> {
469 binary(left, right, &fmax_inner)
470}
471
472fn powf_inner<T: Float>(left: &[T], right: &[T], out: &mut [T]) -> Result<(), &'static str> {
473 let n = out.len();
475 if left.len() != n || right.len() != n {
476 return Err("Size mismatch");
477 };
478
479 (0..n).for_each(|i| out[i] = left[i].powf(right[i]));
481 Ok(())
482}
483
484pub fn powf<'a, T: Float, const N: usize>(
486 a: &'a mut Expr<'a, T, N>,
487 b: &'a mut Expr<'a, T, N>,
488) -> Expr<'a, T, N> {
489 binary(a, b, &powf_inner)
490}
491
492fn flog2_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
493 let n = out.len();
495 if x.len() != n {
496 return Err("Size mismatch");
497 };
498
499 (0..n).for_each(|i| out[i] = x[i].log2());
501 Ok(())
502}
503
504pub fn flog2<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
506 unary(a, &flog2_inner)
507}
508
509fn flog10_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
510 let n = out.len();
512 if x.len() != n {
513 return Err("Size mismatch");
514 };
515
516 (0..n).for_each(|i| out[i] = x[i].log10());
518 Ok(())
519}
520
521pub fn flog10<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
523 unary(a, &flog10_inner)
524}
525
526fn exp_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
527 let n = out.len();
529 if x.len() != n {
530 return Err("Size mismatch");
531 };
532
533 (0..n).for_each(|i| out[i] = x[i].exp());
535 Ok(())
536}
537
538pub fn exp<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
540 unary(a, &exp_inner)
541}
542
543fn atan2_inner<T: Float>(y: &[T], x: &[T], out: &mut [T]) -> Result<(), &'static str> {
544 let n = out.len();
546 if x.len() != n || y.len() != n {
547 return Err("Size mismatch");
548 };
549
550 (0..n).for_each(|i| out[i] = y[i].atan2(x[i]));
552 Ok(())
553}
554
555pub fn atan2<'a, T: Float, const N: usize>(
561 y: &'a mut Expr<'a, T, N>,
562 x: &'a mut Expr<'a, T, N>,
563) -> Expr<'a, T, N> {
564 binary(y, x, &atan2_inner)
565}
566
567fn sin_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
568 let n = out.len();
570 if x.len() != n {
571 return Err("Size mismatch");
572 };
573
574 (0..n).for_each(|i| out[i] = x[i].sin());
576 Ok(())
577}
578
579pub fn sin<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
581 unary(a, &sin_inner)
582}
583
584fn tan_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
585 let n = out.len();
587 if x.len() != n {
588 return Err("Size mismatch");
589 };
590
591 (0..n).for_each(|i| out[i] = x[i].tan());
593 Ok(())
594}
595
596pub fn tan<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
598 unary(a, &tan_inner)
599}
600
601fn cos_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
602 let n = out.len();
604 if x.len() != n {
605 return Err("Size mismatch");
606 };
607
608 (0..n).for_each(|i| out[i] = x[i].cos());
610 Ok(())
611}
612
613pub fn cos<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
615 unary(a, &cos_inner)
616}
617
618fn asin_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
619 let n = out.len();
621 if x.len() != n {
622 return Err("Size mismatch");
623 };
624
625 (0..n).for_each(|i| out[i] = x[i].asin());
627 Ok(())
628}
629
630pub fn asin<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
632 unary(a, &asin_inner)
633}
634
635fn acos_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
636 let n = out.len();
638 if x.len() != n {
639 return Err("Size mismatch");
640 };
641
642 (0..n).for_each(|i| out[i] = x[i].acos());
644 Ok(())
645}
646
647pub fn acos<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
649 unary(a, &acos_inner)
650}
651
652fn atan_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
653 let n = out.len();
655 if x.len() != n {
656 return Err("Size mismatch");
657 };
658
659 (0..n).for_each(|i| out[i] = x[i].atan());
661 Ok(())
662}
663
664pub fn atan<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
670 unary(a, &atan_inner)
671}
672
673fn sinh_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
674 let n = out.len();
676 if x.len() != n {
677 return Err("Size mismatch");
678 };
679
680 (0..n).for_each(|i| out[i] = x[i].sinh());
682 Ok(())
683}
684
685pub fn sinh<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
687 unary(a, &sinh_inner)
688}
689
690fn cosh_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
691 let n = out.len();
693 if x.len() != n {
694 return Err("Size mismatch");
695 };
696
697 (0..n).for_each(|i| out[i] = x[i].cosh());
699 Ok(())
700}
701
702pub fn cosh<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
704 unary(a, &cosh_inner)
705}
706
707fn tanh_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
708 let n = out.len();
710 if x.len() != n {
711 return Err("Size mismatch");
712 };
713
714 (0..n).for_each(|i| out[i] = x[i].tanh());
716 Ok(())
717}
718
719pub fn tanh<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
721 unary(a, &tanh_inner)
722}
723
724fn asinh_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
725 let n = out.len();
727 if x.len() != n {
728 return Err("Size mismatch");
729 };
730
731 (0..n).for_each(|i| out[i] = x[i].asinh());
733 Ok(())
734}
735
736pub fn asinh<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
738 unary(a, &asinh_inner)
739}
740
741fn acosh_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
742 let n = out.len();
744 if x.len() != n {
745 return Err("Size mismatch");
746 };
747
748 (0..n).for_each(|i| out[i] = x[i].acosh());
750 Ok(())
751}
752
753pub fn acosh<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
755 unary(a, &acosh_inner)
756}
757
758fn atanh_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
759 let n = out.len();
761 if x.len() != n {
762 return Err("Size mismatch");
763 };
764
765 (0..n).for_each(|i| out[i] = x[i].atanh());
767 Ok(())
768}
769
770pub fn atanh<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
772 unary(a, &atanh_inner)
773}
774
775fn abs_inner<T: Float>(x: &[T], out: &mut [T]) -> Result<(), &'static str> {
776 let n = out.len();
778 if x.len() != n {
779 return Err("Size mismatch");
780 };
781
782 (0..n).for_each(|i| out[i] = x[i].abs());
784 Ok(())
785}
786
787pub fn abs<'a, T: Float, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
789 unary(a, &abs_inner)
790}
791
792fn mul_add_inner<T: Elem + MulAdd<T, Output = T>>(
793 a: &[T],
794 b: &[T],
795 c: &[T],
796 out: &mut [T],
797) -> Result<(), &'static str> {
798 let n = out.len();
800 if a.len() != n || b.len() != n || c.len() != n {
801 return Err("Size mismatch");
802 };
803
804 (0..n).for_each(|i| out[i] = a[i].mul_add(b[i], c[i]));
806 Ok(())
807}
808
809pub fn mul_add<'a, T: Elem + MulAdd<T, Output = T>, const N: usize>(
821 a: &'a mut Expr<'a, T, N>,
822 b: &'a mut Expr<'a, T, N>,
823 c: &'a mut Expr<'a, T, N>,
824) -> Expr<'a, T, N> {
825 ternary(a, b, c, &mul_add_inner)
826}
827
828fn sum_inner<T: Elem>(x: &[T], v: &mut T) -> Result<(), &'static str> {
829 (0..x.len()).for_each(|i| *v = *v + x[i]);
830 Ok(())
831}
832
833pub fn sum<'a, T: Elem, const N: usize>(a: &'a mut Expr<'a, T, N>) -> Expr<'a, T, N> {
838 let acc = Some(accumulator(T::zero(), a, &sum_inner));
839 scalar(T::zero(), acc)
840}