1#![allow(unsafe_code)]
46
47use crate::error::{AlgorithmError, Result};
48
49fn validate_unary(data: &[f32], out: &[f32]) -> Result<()> {
54 if data.len() != out.len() {
55 return Err(AlgorithmError::InvalidParameter {
56 parameter: "input",
57 message: format!(
58 "Slice length mismatch: data={}, out={}",
59 data.len(),
60 out.len()
61 ),
62 });
63 }
64 Ok(())
65}
66
/// AArch64 kernels for the element-wise `f32` operations.
///
/// `sqrt`, `abs`, `floor`, `ceil` and `round` process four lanes per iteration with
/// NEON intrinsics and finish the tail (fewer than four elements) with the matching
/// scalar `f32` method. `exp` and `ln` use no NEON intrinsics; they are plain
/// element-by-element loops.
///
/// Every kernel writes exactly `data.len()` results into the front of `out` and
/// panics, before writing anything, if `out` is shorter than `data`.
#[cfg(target_arch = "aarch64")]
mod neon_impl {
    use std::arch::aarch64::*;

    /// Number of `f32` lanes in a 128-bit NEON register.
    const LANES: usize = 4;

    /// Element-wise square root (`vsqrtq_f32` for full vectors, `f32::sqrt` for the tail).
    ///
    /// # Safety
    ///
    /// The executing CPU must support the `neon` target feature.
    #[target_feature(enable = "neon")]
    pub(crate) unsafe fn sqrt_f32(data: &[f32], out: &mut [f32]) {
        // Bound the output by the input length up front so a short `out` panics
        // instead of being written past its end.
        let out = &mut out[..data.len()];
        let mut src = data.chunks_exact(LANES);
        let mut dst = out.chunks_exact_mut(LANES);

        for (s, d) in src.by_ref().zip(dst.by_ref()) {
            // SAFETY: `chunks_exact(LANES)` yields slices of exactly four `f32`s,
            // which is precisely what the 128-bit load and store access.
            unsafe {
                vst1q_f32(d.as_mut_ptr(), vsqrtq_f32(vld1q_f32(s.as_ptr())));
            }
        }
        for (s, d) in src.remainder().iter().zip(dst.into_remainder()) {
            *d = s.sqrt();
        }
    }

    /// Element-wise absolute value (`vabsq_f32` for full vectors, `f32::abs` for the tail).
    ///
    /// # Safety
    ///
    /// The executing CPU must support the `neon` target feature.
    #[target_feature(enable = "neon")]
    pub(crate) unsafe fn abs_f32(data: &[f32], out: &mut [f32]) {
        let out = &mut out[..data.len()];
        let mut src = data.chunks_exact(LANES);
        let mut dst = out.chunks_exact_mut(LANES);

        for (s, d) in src.by_ref().zip(dst.by_ref()) {
            // SAFETY: both chunks hold exactly four contiguous `f32`s.
            unsafe {
                vst1q_f32(d.as_mut_ptr(), vabsq_f32(vld1q_f32(s.as_ptr())));
            }
        }
        for (s, d) in src.remainder().iter().zip(dst.into_remainder()) {
            *d = s.abs();
        }
    }

    /// Element-wise floor (`vrndmq_f32`, round toward negative infinity, for full
    /// vectors; `f32::floor` for the tail).
    ///
    /// # Safety
    ///
    /// The executing CPU must support the `neon` target feature.
    #[target_feature(enable = "neon")]
    pub(crate) unsafe fn floor_f32(data: &[f32], out: &mut [f32]) {
        let out = &mut out[..data.len()];
        let mut src = data.chunks_exact(LANES);
        let mut dst = out.chunks_exact_mut(LANES);

        for (s, d) in src.by_ref().zip(dst.by_ref()) {
            // SAFETY: both chunks hold exactly four contiguous `f32`s.
            unsafe {
                vst1q_f32(d.as_mut_ptr(), vrndmq_f32(vld1q_f32(s.as_ptr())));
            }
        }
        for (s, d) in src.remainder().iter().zip(dst.into_remainder()) {
            *d = s.floor();
        }
    }

    /// Element-wise ceiling (`vrndpq_f32`, round toward positive infinity, for full
    /// vectors; `f32::ceil` for the tail).
    ///
    /// # Safety
    ///
    /// The executing CPU must support the `neon` target feature.
    #[target_feature(enable = "neon")]
    pub(crate) unsafe fn ceil_f32(data: &[f32], out: &mut [f32]) {
        let out = &mut out[..data.len()];
        let mut src = data.chunks_exact(LANES);
        let mut dst = out.chunks_exact_mut(LANES);

        for (s, d) in src.by_ref().zip(dst.by_ref()) {
            // SAFETY: both chunks hold exactly four contiguous `f32`s.
            unsafe {
                vst1q_f32(d.as_mut_ptr(), vrndpq_f32(vld1q_f32(s.as_ptr())));
            }
        }
        for (s, d) in src.remainder().iter().zip(dst.into_remainder()) {
            *d = s.ceil();
        }
    }

    /// Element-wise rounding to the nearest integer with ties away from zero
    /// (`vrndaq_f32` for full vectors; `f32::round`, which uses the same tie rule,
    /// for the tail).
    ///
    /// # Safety
    ///
    /// The executing CPU must support the `neon` target feature.
    #[target_feature(enable = "neon")]
    pub(crate) unsafe fn round_f32(data: &[f32], out: &mut [f32]) {
        let out = &mut out[..data.len()];
        let mut src = data.chunks_exact(LANES);
        let mut dst = out.chunks_exact_mut(LANES);

        for (s, d) in src.by_ref().zip(dst.by_ref()) {
            // SAFETY: both chunks hold exactly four contiguous `f32`s.
            unsafe {
                vst1q_f32(d.as_mut_ptr(), vrndaq_f32(vld1q_f32(s.as_ptr())));
            }
        }
        for (s, d) in src.remainder().iter().zip(dst.into_remainder()) {
            *d = s.round();
        }
    }

    /// Element-wise `e^x`, evaluated with `f32::exp` one element at a time.
    ///
    /// No NEON intrinsics are used here; the function only shares the signature
    /// and attributes of the vectorised kernels.
    ///
    /// # Safety
    ///
    /// The executing CPU must support the `neon` target feature.
    #[target_feature(enable = "neon")]
    pub(crate) unsafe fn exp_f32(data: &[f32], out: &mut [f32]) {
        let out = &mut out[..data.len()];
        for (d, s) in out.iter_mut().zip(data) {
            *d = s.exp();
        }
    }

    /// Element-wise natural logarithm, evaluated with `f32::ln` one element at a time.
    ///
    /// No NEON intrinsics are used here; the function only shares the signature
    /// and attributes of the vectorised kernels.
    ///
    /// # Safety
    ///
    /// The executing CPU must support the `neon` target feature.
    #[target_feature(enable = "neon")]
    pub(crate) unsafe fn ln_f32(data: &[f32], out: &mut [f32]) {
        let out = &mut out[..data.len()];
        for (d, s) in out.iter_mut().zip(data) {
            *d = s.ln();
        }
    }
}
204
/// Portable element-wise kernels used wherever no SIMD path exists.
mod scalar_impl {
    /// Writes `f(data[i])` into `out[i]` for every `i < data.len()`.
    ///
    /// Elements of `out` beyond `data.len()` are left untouched.
    ///
    /// # Panics
    ///
    /// Panics, before anything is written, if `out` is shorter than `data`.
    pub(crate) fn apply_unary(data: &[f32], out: &mut [f32], f: fn(f32) -> f32) {
        // One length check here replaces a bounds check on every element and
        // guarantees that a too-short output is never partially overwritten.
        let out = &mut out[..data.len()];
        for (dst, &src) in out.iter_mut().zip(data) {
            *dst = f(src);
        }
    }

    /// Writes `f(a[i], b[i])` into `out[i]` for every `i < a.len()`.
    ///
    /// Elements of `out` beyond `a.len()` are left untouched.
    ///
    /// # Panics
    ///
    /// Panics, before anything is written, if `b` or `out` is shorter than `a`.
    pub(crate) fn apply_binary(a: &[f32], b: &[f32], out: &mut [f32], f: fn(f32, f32) -> f32) {
        let len = a.len();
        // Hoisted length checks; see `apply_unary`.
        let b = &b[..len];
        let out = &mut out[..len];
        for ((dst, &lhs), &rhs) in out.iter_mut().zip(a).zip(b) {
            *dst = f(lhs, rhs);
        }
    }
}
243
244pub fn sqrt_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
256 validate_unary(data, out)?;
257
258 #[cfg(target_arch = "aarch64")]
259 {
260 unsafe {
262 neon_impl::sqrt_f32(data, out);
263 }
264 }
265
266 #[cfg(not(target_arch = "aarch64"))]
267 {
268 scalar_impl::apply_unary(data, out, f32::sqrt);
269 }
270
271 Ok(())
272}
273
274pub fn ln_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
282 validate_unary(data, out)?;
283
284 #[cfg(target_arch = "aarch64")]
285 {
286 unsafe {
288 neon_impl::ln_f32(data, out);
289 }
290 }
291
292 #[cfg(not(target_arch = "aarch64"))]
293 {
294 scalar_impl::apply_unary(data, out, f32::ln);
295 }
296
297 Ok(())
298}
299
300pub fn log10_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
306 validate_unary(data, out)?;
307
308 #[cfg(target_arch = "aarch64")]
309 {
310 unsafe {
313 neon_impl::ln_f32(data, out);
314 }
315 let log10e = std::f32::consts::LOG10_E;
316 for val in out.iter_mut() {
317 *val *= log10e;
318 }
319 }
320
321 #[cfg(not(target_arch = "aarch64"))]
322 {
323 scalar_impl::apply_unary(data, out, f32::log10);
324 }
325
326 Ok(())
327}
328
329pub fn log2_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
335 validate_unary(data, out)?;
336
337 #[cfg(target_arch = "aarch64")]
338 {
339 unsafe {
342 neon_impl::ln_f32(data, out);
343 }
344 let log2e = std::f32::consts::LOG2_E;
345 for val in out.iter_mut() {
346 *val *= log2e;
347 }
348 }
349
350 #[cfg(not(target_arch = "aarch64"))]
351 {
352 scalar_impl::apply_unary(data, out, f32::log2);
353 }
354
355 Ok(())
356}
357
358pub fn exp_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
366 validate_unary(data, out)?;
367
368 #[cfg(target_arch = "aarch64")]
369 {
370 unsafe {
372 neon_impl::exp_f32(data, out);
373 }
374 }
375
376 #[cfg(not(target_arch = "aarch64"))]
377 {
378 scalar_impl::apply_unary(data, out, f32::exp);
379 }
380
381 Ok(())
382}
383
384pub fn exp2_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
390 validate_unary(data, out)?;
391 scalar_impl::apply_unary(data, out, f32::exp2);
392 Ok(())
393}
394
395pub fn pow_f32(base: &[f32], exponent: &[f32], out: &mut [f32]) -> Result<()> {
401 if base.len() != exponent.len() || base.len() != out.len() {
402 return Err(AlgorithmError::InvalidParameter {
403 parameter: "input",
404 message: "Slice length mismatch".to_string(),
405 });
406 }
407
408 scalar_impl::apply_binary(base, exponent, out, f32::powf);
409 Ok(())
410}
411
412pub fn sin_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
418 validate_unary(data, out)?;
419 scalar_impl::apply_unary(data, out, f32::sin);
420 Ok(())
421}
422
423pub fn cos_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
429 validate_unary(data, out)?;
430 scalar_impl::apply_unary(data, out, f32::cos);
431 Ok(())
432}
433
434pub fn tan_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
440 validate_unary(data, out)?;
441 scalar_impl::apply_unary(data, out, f32::tan);
442 Ok(())
443}
444
445pub fn asin_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
451 validate_unary(data, out)?;
452 scalar_impl::apply_unary(data, out, f32::asin);
453 Ok(())
454}
455
456pub fn acos_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
462 validate_unary(data, out)?;
463 scalar_impl::apply_unary(data, out, f32::acos);
464 Ok(())
465}
466
467pub fn atan_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
473 validate_unary(data, out)?;
474 scalar_impl::apply_unary(data, out, f32::atan);
475 Ok(())
476}
477
478pub fn atan2_f32(y: &[f32], x: &[f32], out: &mut [f32]) -> Result<()> {
484 if y.len() != x.len() || y.len() != out.len() {
485 return Err(AlgorithmError::InvalidParameter {
486 parameter: "input",
487 message: "Slice length mismatch".to_string(),
488 });
489 }
490 scalar_impl::apply_binary(y, x, out, f32::atan2);
491 Ok(())
492}
493
494pub fn sinh_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
500 validate_unary(data, out)?;
501 scalar_impl::apply_unary(data, out, f32::sinh);
502 Ok(())
503}
504
505pub fn cosh_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
511 validate_unary(data, out)?;
512 scalar_impl::apply_unary(data, out, f32::cosh);
513 Ok(())
514}
515
516pub fn tanh_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
522 validate_unary(data, out)?;
523 scalar_impl::apply_unary(data, out, f32::tanh);
524 Ok(())
525}
526
527pub fn abs_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
535 validate_unary(data, out)?;
536
537 #[cfg(target_arch = "aarch64")]
538 {
539 unsafe {
541 neon_impl::abs_f32(data, out);
542 }
543 }
544
545 #[cfg(not(target_arch = "aarch64"))]
546 {
547 scalar_impl::apply_unary(data, out, f32::abs);
548 }
549
550 Ok(())
551}
552
553pub fn floor_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
561 validate_unary(data, out)?;
562
563 #[cfg(target_arch = "aarch64")]
564 {
565 unsafe {
567 neon_impl::floor_f32(data, out);
568 }
569 }
570
571 #[cfg(not(target_arch = "aarch64"))]
572 {
573 scalar_impl::apply_unary(data, out, f32::floor);
574 }
575
576 Ok(())
577}
578
579pub fn ceil_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
587 validate_unary(data, out)?;
588
589 #[cfg(target_arch = "aarch64")]
590 {
591 unsafe {
593 neon_impl::ceil_f32(data, out);
594 }
595 }
596
597 #[cfg(not(target_arch = "aarch64"))]
598 {
599 scalar_impl::apply_unary(data, out, f32::ceil);
600 }
601
602 Ok(())
603}
604
605pub fn round_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
613 validate_unary(data, out)?;
614
615 #[cfg(target_arch = "aarch64")]
616 {
617 unsafe {
619 neon_impl::round_f32(data, out);
620 }
621 }
622
623 #[cfg(not(target_arch = "aarch64"))]
624 {
625 scalar_impl::apply_unary(data, out, f32::round);
626 }
627
628 Ok(())
629}
630
631pub fn fract_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
637 validate_unary(data, out)?;
638 floor_f32(data, out)?;
640 for i in 0..data.len() {
641 out[i] = data[i] - out[i];
642 }
643 Ok(())
644}
645
/// Unit tests for the public element-wise `f32` operations.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use std::f32::consts::PI;

    /// Square roots of small perfect squares.
    #[test]
    fn test_sqrt_f32() {
        let data = vec![1.0, 4.0, 9.0, 16.0, 25.0];
        let mut out = vec![0.0; 5];

        sqrt_f32(&data, &mut out).expect("sqrt_f32 failed");

        let expected: [f32; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];
        for (&got, &want) in out.iter().zip(&expected) {
            assert_relative_eq!(got, want);
        }
    }

    /// A long uniform input exercises both the blocked body and the tail.
    #[test]
    fn test_sqrt_large() {
        let data = vec![4.0; 1000];
        let mut out = vec![0.0; 1000];

        sqrt_f32(&data, &mut out).expect("sqrt_f32 large failed");

        for val in out.iter().copied() {
            assert_relative_eq!(val, 2.0);
        }
    }

    /// `ln(exp(x))` must round-trip to `x`.
    #[test]
    fn test_exp_ln() {
        let data = vec![0.0, 1.0, 2.0, 3.0];
        let mut exp_out = vec![0.0; 4];
        let mut ln_out = vec![0.0; 4];

        exp_f32(&data, &mut exp_out).expect("exp_f32 failed");
        ln_f32(&exp_out, &mut ln_out).expect("ln_f32 failed");

        for (&round_trip, &original) in ln_out.iter().zip(&data) {
            assert_relative_eq!(round_trip, original, epsilon = 1e-5);
        }
    }

    /// `exp_f32` agrees with `f32::exp` over 100 evenly spaced inputs.
    #[test]
    fn test_exp_large() {
        let data: Vec<f32> = (0..100).map(|i| i as f32 * 0.1).collect();
        let mut out = vec![0.0; 100];

        exp_f32(&data, &mut out).expect("exp_f32 large failed");

        for (&got, &x) in out.iter().zip(&data) {
            assert_relative_eq!(got, x.exp(), epsilon = 1e-4);
        }
    }

    /// `ln_f32` agrees with `f32::ln` for the integers 1 through 100.
    #[test]
    fn test_ln_large() {
        let data: Vec<f32> = (1..=100).map(|i| i as f32).collect();
        let mut out = vec![0.0; 100];

        ln_f32(&data, &mut out).expect("ln_f32 large failed");

        for (&got, &x) in out.iter().zip(&data) {
            assert_relative_eq!(got, x.ln(), epsilon = 1e-4);
        }
    }

    /// Powers of ten map to their exponents.
    #[test]
    fn test_log10() {
        let data = vec![1.0, 10.0, 100.0, 1000.0];
        let mut out = vec![0.0; 4];

        log10_f32(&data, &mut out).expect("log10_f32 failed");

        // (expected value, tolerance) per element.
        let expected: [(f32, f32); 4] = [(0.0, 1e-5), (1.0, 1e-5), (2.0, 1e-4), (3.0, 1e-4)];
        for (&got, &(want, tolerance)) in out.iter().zip(&expected) {
            assert_relative_eq!(got, want, epsilon = tolerance);
        }
    }

    /// Powers of two map to their exponents.
    #[test]
    fn test_log2() {
        let data = vec![1.0, 2.0, 4.0, 8.0, 16.0];
        let mut out = vec![0.0; 5];

        log2_f32(&data, &mut out).expect("log2_f32 failed");

        // (expected value, tolerance) per element.
        let expected: [(f32, f32); 5] = [
            (0.0, 1e-5),
            (1.0, 1e-4),
            (2.0, 1e-4),
            (3.0, 1e-4),
            (4.0, 1e-4),
        ];
        for (&got, &(want, tolerance)) in out.iter().zip(&expected) {
            assert_relative_eq!(got, want, epsilon = tolerance);
        }
    }

    /// Squaring small integers through `pow_f32`.
    #[test]
    fn test_pow() {
        let base = vec![2.0, 3.0, 4.0, 5.0];
        let exp = vec![2.0, 2.0, 2.0, 2.0];
        let mut out = vec![0.0; 4];

        pow_f32(&base, &exp, &mut out).expect("pow_f32 failed");

        let expected: [f32; 4] = [4.0, 9.0, 16.0, 25.0];
        for (&got, &want) in out.iter().zip(&expected) {
            assert_relative_eq!(got, want);
        }
    }

    /// Endpoint values plus the identity `sin^2 + cos^2 = 1`.
    #[test]
    fn test_sin_cos() {
        let data = vec![0.0, PI / 6.0, PI / 4.0, PI / 3.0, PI / 2.0];
        let mut sin_out = vec![0.0; 5];
        let mut cos_out = vec![0.0; 5];

        sin_f32(&data, &mut sin_out).expect("sin_f32 failed");
        cos_f32(&data, &mut cos_out).expect("cos_f32 failed");

        let (first, last) = (0, 4);
        assert_relative_eq!(sin_out[first], 0.0, epsilon = 1e-6);
        assert_relative_eq!(sin_out[last], 1.0, epsilon = 1e-6);
        assert_relative_eq!(cos_out[first], 1.0, epsilon = 1e-6);
        assert_relative_eq!(cos_out[last], 0.0, epsilon = 1e-6);

        for (&s, &c) in sin_out.iter().zip(&cos_out) {
            let sum = s * s + c * c;
            assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
        }
    }

    /// `tan(0) = 0` and `tan(pi/4) = 1`.
    #[test]
    fn test_tan() {
        let data = vec![0.0, PI / 4.0];
        let mut out = vec![0.0; 2];

        tan_f32(&data, &mut out).expect("tan_f32 failed");

        let expected: [f32; 2] = [0.0, 1.0];
        for (&got, &want) in out.iter().zip(&expected) {
            assert_relative_eq!(got, want, epsilon = 1e-6);
        }
    }

    /// Inverse sine and cosine at 0 and 1.
    #[test]
    fn test_asin_acos() {
        let data = vec![0.0, 0.5, 1.0];
        let mut asin_out = vec![0.0; 3];
        let mut acos_out = vec![0.0; 3];

        asin_f32(&data, &mut asin_out).expect("asin_f32 failed");
        acos_f32(&data, &mut acos_out).expect("acos_f32 failed");

        let half_pi = PI / 2.0;
        assert_relative_eq!(asin_out[0], 0.0, epsilon = 1e-6);
        assert_relative_eq!(asin_out[2], half_pi, epsilon = 1e-6);
        assert_relative_eq!(acos_out[0], half_pi, epsilon = 1e-6);
        assert_relative_eq!(acos_out[2], 0.0, epsilon = 1e-6);
    }

    /// One point on each axis direction.
    #[test]
    fn test_atan2() {
        let y = vec![0.0, 1.0, 0.0, -1.0];
        let x = vec![1.0, 0.0, -1.0, 0.0];
        let mut out = vec![0.0; 4];

        atan2_f32(&y, &x, &mut out).expect("atan2_f32 failed");

        let expected: [f32; 4] = [0.0, PI / 2.0, PI, -PI / 2.0];
        for (&got, &want) in out.iter().zip(&expected) {
            assert_relative_eq!(got, want, epsilon = 1e-6);
        }
    }

    /// Hyperbolic functions at the origin.
    #[test]
    fn test_hyperbolic() {
        let data = vec![0.0, 1.0];
        let mut sinh_out = vec![0.0; 2];
        let mut cosh_out = vec![0.0; 2];
        let mut tanh_out = vec![0.0; 2];

        sinh_f32(&data, &mut sinh_out).expect("sinh_f32 failed");
        cosh_f32(&data, &mut cosh_out).expect("cosh_f32 failed");
        tanh_f32(&data, &mut tanh_out).expect("tanh_f32 failed");

        let at_zero: [(f32, f32); 3] = [
            (sinh_out[0], 0.0),
            (cosh_out[0], 1.0),
            (tanh_out[0], 0.0),
        ];
        for &(got, want) in &at_zero {
            assert_relative_eq!(got, want, epsilon = 1e-6);
        }
    }

    /// Mixed-sign input of odd length.
    #[test]
    fn test_abs() {
        let data = vec![-1.0, -2.0, 3.0, -4.0, 5.0];
        let mut out = vec![0.0; 5];

        abs_f32(&data, &mut out).expect("abs_f32 failed");

        let expected: [f32; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];
        for (&got, &want) in out.iter().zip(&expected) {
            assert_relative_eq!(got, want);
        }
    }

    /// `abs_f32` agrees with `f32::abs` over -500..500.
    #[test]
    fn test_abs_large() {
        let data: Vec<f32> = (-500..500).map(|i| i as f32).collect();
        let mut out = vec![0.0; 1000];

        abs_f32(&data, &mut out).expect("abs_f32 large failed");

        for (&got, &x) in out.iter().zip(&data) {
            assert_relative_eq!(got, x.abs());
        }
    }

    /// Floor, ceiling and rounding of positive and negative non-integers.
    #[test]
    fn test_floor_ceil_round() {
        let data = vec![1.2, 1.7, -1.2, -1.7];
        let mut floor_out = vec![0.0; 4];
        let mut ceil_out = vec![0.0; 4];
        let mut round_out = vec![0.0; 4];

        floor_f32(&data, &mut floor_out).expect("floor_f32 failed");
        ceil_f32(&data, &mut ceil_out).expect("ceil_f32 failed");
        round_f32(&data, &mut round_out).expect("round_f32 failed");

        let floor_expected: [f32; 4] = [1.0, 1.0, -2.0, -2.0];
        let ceil_expected: [f32; 4] = [2.0, 2.0, -1.0, -1.0];
        let round_expected: [f32; 4] = [1.0, 2.0, -1.0, -2.0];

        for (&got, &want) in floor_out.iter().zip(&floor_expected) {
            assert_relative_eq!(got, want);
        }
        for (&got, &want) in ceil_out.iter().zip(&ceil_expected) {
            assert_relative_eq!(got, want);
        }
        for (&got, &want) in round_out.iter().zip(&round_expected) {
            assert_relative_eq!(got, want);
        }
    }

    /// The fractional part is measured from the floor, so negatives give positive results.
    #[test]
    fn test_fract() {
        let data = vec![1.3, 2.7, -1.3, -2.7];
        let mut out = vec![0.0; 4];

        fract_f32(&data, &mut out).expect("fract_f32 failed");

        let expected: [f32; 4] = [0.3, 0.7, 0.7, 0.3];
        for (&got, &want) in out.iter().zip(&expected) {
            assert_relative_eq!(got, want, epsilon = 1e-6);
        }
    }

    /// Mismatched slice lengths are rejected.
    #[test]
    fn test_length_mismatch() {
        let data = vec![1.0; 10];
        let mut out = vec![0.0; 5];

        let result = sqrt_f32(&data, &mut out);
        assert!(result.is_err());
    }
}