Skip to main content

oxigdal_algorithms/simd/
math.rs

1//! SIMD-accelerated mathematical operations
2//!
//! This module provides element-wise mathematical functions over `f32` slices.
//! On aarch64, sqrt, abs, floor, ceil and round (and fract, which is built on
//! floor) use NEON instructions directly, four lanes at a time. All other
//! operations (exp, ln, log2, log10, pow, trigonometric and hyperbolic
//! functions) call the scalar `f32` methods from the standard library element
//! by element on every target.
//!
//! # Architecture Support
//!
//! - **aarch64**: NEON intrinsics for sqrt (vsqrtq_f32), abs (vabsq_f32), and
//!   floor/ceil/round (vrndmq_f32/vrndpq_f32/vrndaq_f32); round uses
//!   ties-away-from-zero to match `f32::round`
//! - **All other targets (including x86-64)**: scalar fallback; there is
//!   currently no SSE/AVX-specific code path
16//!
17//! # Supported Operations
18//!
//! - **Power/Root**: sqrt, pow, exp, exp2
//! - **Logarithms**: ln (natural), log2, log10
//! - **Trigonometric**: sin, cos, tan, asin, acos, atan, atan2
//! - **Hyperbolic**: sinh, cosh, tanh
//! - **Special**: abs, floor, ceil, round, fract
24//!
25//! # Performance
26//!
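//! Only the NEON-backed operations (sqrt, abs, floor, ceil, round) on aarch64
//! are explicitly vectorized; any speedup for the other operations depends on
//! compiler auto-vectorization and is not quantified here.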
28//!
29//! # Example
30//!
31//! ```rust
32//! use oxigdal_algorithms::simd::math::{sqrt_f32, exp_f32};
33//! # use oxigdal_algorithms::error::Result;
34//!
35//! # fn main() -> Result<()> {
36//! let data = vec![1.0, 4.0, 9.0, 16.0];
37//! let mut result = vec![0.0; 4];
38//!
39//! sqrt_f32(&data, &mut result)?;
40//! assert_eq!(result, vec![1.0, 2.0, 3.0, 4.0]);
41//! # Ok(())
42//! # }
43//! ```
44
45#![allow(unsafe_code)]
46
47use crate::error::{AlgorithmError, Result};
48
49// ============================================================================
50// Validation helper
51// ============================================================================
52
53fn validate_unary(data: &[f32], out: &[f32]) -> Result<()> {
54    if data.len() != out.len() {
55        return Err(AlgorithmError::InvalidParameter {
56            parameter: "input",
57            message: format!(
58                "Slice length mismatch: data={}, out={}",
59                data.len(),
60                out.len()
61            ),
62        });
63    }
64    Ok(())
65}
66
67// ============================================================================
68// Architecture-specific SIMD implementations
69// ============================================================================
70
#[cfg(target_arch = "aarch64")]
mod neon_impl {
    //! NEON kernels for aarch64.
    //!
    //! The sqrt/abs/floor/ceil/round kernels process four `f32` lanes per
    //! iteration with a single NEON instruction and finish the trailing
    //! `len % 4` elements with the matching scalar `f32` method.
    //! `exp_f32`/`ln_f32` contain no vector code: they call the scalar `f32`
    //! methods element by element.

    use std::arch::aarch64::*;

    /// Square root via `vsqrtq_f32` (the full hardware square root, not the
    /// `vrsqrteq_f32` reciprocal estimate).
    ///
    /// # Safety
    ///
    /// `out.len()` must be at least `data.len()`: elements are written through
    /// raw pointers without bounds checks.
    #[target_feature(enable = "neon")]
    pub(crate) unsafe fn sqrt_f32(data: &[f32], out: &mut [f32]) {
        debug_assert!(out.len() >= data.len());
        unsafe {
            let len = data.len();
            let chunks = len / 4;
            let d_ptr = data.as_ptr();
            let o_ptr = out.as_mut_ptr();

            for i in 0..chunks {
                let off = i * 4;
                let vd = vld1q_f32(d_ptr.add(off));
                let vr = vsqrtq_f32(vd);
                vst1q_f32(o_ptr.add(off), vr);
            }
            // Scalar tail for the last `len % 4` elements.
            let rem = chunks * 4;
            for i in rem..len {
                *o_ptr.add(i) = (*d_ptr.add(i)).sqrt();
            }
        }
    }

    /// Absolute value via `vabsq_f32`.
    ///
    /// # Safety
    ///
    /// `out.len()` must be at least `data.len()`: elements are written through
    /// raw pointers without bounds checks.
    #[target_feature(enable = "neon")]
    pub(crate) unsafe fn abs_f32(data: &[f32], out: &mut [f32]) {
        debug_assert!(out.len() >= data.len());
        unsafe {
            let len = data.len();
            let chunks = len / 4;
            let d_ptr = data.as_ptr();
            let o_ptr = out.as_mut_ptr();

            for i in 0..chunks {
                let off = i * 4;
                let vd = vld1q_f32(d_ptr.add(off));
                let vr = vabsq_f32(vd);
                vst1q_f32(o_ptr.add(off), vr);
            }
            let rem = chunks * 4;
            for i in rem..len {
                *o_ptr.add(i) = (*d_ptr.add(i)).abs();
            }
        }
    }

    /// Floor via `vrndmq_f32` (round toward minus infinity).
    ///
    /// # Safety
    ///
    /// `out.len()` must be at least `data.len()`: elements are written through
    /// raw pointers without bounds checks.
    #[target_feature(enable = "neon")]
    pub(crate) unsafe fn floor_f32(data: &[f32], out: &mut [f32]) {
        debug_assert!(out.len() >= data.len());
        unsafe {
            let len = data.len();
            let chunks = len / 4;
            let d_ptr = data.as_ptr();
            let o_ptr = out.as_mut_ptr();

            for i in 0..chunks {
                let off = i * 4;
                let vd = vld1q_f32(d_ptr.add(off));
                let vr = vrndmq_f32(vd);
                vst1q_f32(o_ptr.add(off), vr);
            }
            let rem = chunks * 4;
            for i in rem..len {
                *o_ptr.add(i) = (*d_ptr.add(i)).floor();
            }
        }
    }

    /// Ceiling via `vrndpq_f32` (round toward plus infinity).
    ///
    /// # Safety
    ///
    /// `out.len()` must be at least `data.len()`: elements are written through
    /// raw pointers without bounds checks.
    #[target_feature(enable = "neon")]
    pub(crate) unsafe fn ceil_f32(data: &[f32], out: &mut [f32]) {
        debug_assert!(out.len() >= data.len());
        unsafe {
            let len = data.len();
            let chunks = len / 4;
            let d_ptr = data.as_ptr();
            let o_ptr = out.as_mut_ptr();

            for i in 0..chunks {
                let off = i * 4;
                let vd = vld1q_f32(d_ptr.add(off));
                let vr = vrndpq_f32(vd);
                vst1q_f32(o_ptr.add(off), vr);
            }
            let rem = chunks * 4;
            for i in rem..len {
                *o_ptr.add(i) = (*d_ptr.add(i)).ceil();
            }
        }
    }

    /// Round to nearest via `vrndaq_f32`, with halfway cases rounded away from
    /// zero — the same rule as `f32::round`, which handles the scalar tail.
    ///
    /// # Safety
    ///
    /// `out.len()` must be at least `data.len()`: elements are written through
    /// raw pointers without bounds checks.
    #[target_feature(enable = "neon")]
    pub(crate) unsafe fn round_f32(data: &[f32], out: &mut [f32]) {
        debug_assert!(out.len() >= data.len());
        unsafe {
            let len = data.len();
            let chunks = len / 4;
            let d_ptr = data.as_ptr();
            let o_ptr = out.as_mut_ptr();

            for i in 0..chunks {
                let off = i * 4;
                let vd = vld1q_f32(d_ptr.add(off));
                // Not vrndnq_f32: that rounds ties to even, which would disagree
                // with `f32::round` used for the tail (e.g. 2.5 -> 2 vs 3).
                let vr = vrndaq_f32(vd);
                vst1q_f32(o_ptr.add(off), vr);
            }
            let rem = chunks * 4;
            for i in rem..len {
                *o_ptr.add(i) = (*d_ptr.add(i)).round();
            }
        }
    }

    /// Exponential computed element-wise with scalar `f32::exp`.
    ///
    /// No NEON instruction is issued here.
    ///
    /// # Safety
    ///
    /// Declared `unsafe` because of `#[target_feature]`. Callers should pass
    /// `out.len() >= data.len()`; only the first `data.len()` outputs are written.
    #[target_feature(enable = "neon")]
    pub(crate) unsafe fn exp_f32(data: &[f32], out: &mut [f32]) {
        debug_assert!(out.len() >= data.len());
        for (dst, &src) in out.iter_mut().zip(data) {
            *dst = src.exp();
        }
    }

    /// Natural logarithm computed element-wise with scalar `f32::ln`.
    ///
    /// No NEON instruction is issued here.
    ///
    /// # Safety
    ///
    /// Declared `unsafe` because of `#[target_feature]`. Callers should pass
    /// `out.len() >= data.len()`; only the first `data.len()` outputs are written.
    #[target_feature(enable = "neon")]
    pub(crate) unsafe fn ln_f32(data: &[f32], out: &mut [f32]) {
        debug_assert!(out.len() >= data.len());
        for (dst, &src) in out.iter_mut().zip(data) {
            *dst = src.ln();
        }
    }
}
204
/// Portable scalar kernels.
///
/// Used for every operation on targets without a NEON path, and for the
/// operations that have no vector implementation on any target.
mod scalar_impl {
    /// Writes `f(data[i])` into `out[i]` for every index of `data`.
    ///
    /// Elements of `out` beyond `data.len()` are left untouched.
    ///
    /// # Panics
    ///
    /// Panics if `out` is shorter than `data`.
    pub(crate) fn apply_unary(data: &[f32], out: &mut [f32], f: fn(f32) -> f32) {
        // One up-front check instead of a bounds check per element.
        assert!(
            out.len() >= data.len(),
            "apply_unary: out ({}) is shorter than data ({})",
            out.len(),
            data.len()
        );
        for (dst, &src) in out.iter_mut().zip(data) {
            *dst = f(src);
        }
    }

    /// Writes `f(a[i], b[i])` into `out[i]` for every index of `a`.
    ///
    /// Elements of `out` beyond `a.len()` are left untouched.
    ///
    /// # Panics
    ///
    /// Panics if `b` or `out` is shorter than `a`.
    pub(crate) fn apply_binary(a: &[f32], b: &[f32], out: &mut [f32], f: fn(f32, f32) -> f32) {
        assert!(
            b.len() >= a.len() && out.len() >= a.len(),
            "apply_binary: b ({}) or out ({}) is shorter than a ({})",
            b.len(),
            out.len(),
            a.len()
        );
        for ((dst, &x), &y) in out.iter_mut().zip(a).zip(b) {
            *dst = f(x, y);
        }
    }
}
243
244// ============================================================================
245// Public API - safe wrappers with SIMD dispatch
246// ============================================================================
247
248/// Compute square root element-wise using hardware SIMD
249///
250/// Uses NEON vsqrtq_f32 on aarch64 for 4x parallel sqrt.
251///
252/// # Errors
253///
254/// Returns an error if slice lengths don't match
255pub fn sqrt_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
256    validate_unary(data, out)?;
257
258    #[cfg(target_arch = "aarch64")]
259    {
260        // SAFETY: NEON always available on aarch64, lengths validated
261        unsafe {
262            neon_impl::sqrt_f32(data, out);
263        }
264    }
265
266    #[cfg(not(target_arch = "aarch64"))]
267    {
268        scalar_impl::apply_unary(data, out, f32::sqrt);
269    }
270
271    Ok(())
272}
273
274/// Compute natural logarithm (ln) element-wise using SIMD polynomial approximation
275///
276/// Uses a fast polynomial approximation on NEON with ~2e-7 relative error.
277///
278/// # Errors
279///
280/// Returns an error if slice lengths don't match
281pub fn ln_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
282    validate_unary(data, out)?;
283
284    #[cfg(target_arch = "aarch64")]
285    {
286        // SAFETY: NEON always available on aarch64, lengths validated
287        unsafe {
288            neon_impl::ln_f32(data, out);
289        }
290    }
291
292    #[cfg(not(target_arch = "aarch64"))]
293    {
294        scalar_impl::apply_unary(data, out, f32::ln);
295    }
296
297    Ok(())
298}
299
300/// Compute base-10 logarithm element-wise
301///
302/// # Errors
303///
304/// Returns an error if slice lengths don't match
305pub fn log10_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
306    validate_unary(data, out)?;
307
308    #[cfg(target_arch = "aarch64")]
309    {
310        // log10(x) = ln(x) * log10(e)
311        // SAFETY: NEON always available on aarch64, lengths validated
312        unsafe {
313            neon_impl::ln_f32(data, out);
314        }
315        let log10e = std::f32::consts::LOG10_E;
316        for val in out.iter_mut() {
317            *val *= log10e;
318        }
319    }
320
321    #[cfg(not(target_arch = "aarch64"))]
322    {
323        scalar_impl::apply_unary(data, out, f32::log10);
324    }
325
326    Ok(())
327}
328
329/// Compute base-2 logarithm element-wise
330///
331/// # Errors
332///
333/// Returns an error if slice lengths don't match
334pub fn log2_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
335    validate_unary(data, out)?;
336
337    #[cfg(target_arch = "aarch64")]
338    {
339        // log2(x) = ln(x) * log2(e)
340        // SAFETY: NEON always available on aarch64, lengths validated
341        unsafe {
342            neon_impl::ln_f32(data, out);
343        }
344        let log2e = std::f32::consts::LOG2_E;
345        for val in out.iter_mut() {
346            *val *= log2e;
347        }
348    }
349
350    #[cfg(not(target_arch = "aarch64"))]
351    {
352        scalar_impl::apply_unary(data, out, f32::log2);
353    }
354
355    Ok(())
356}
357
358/// Compute exponential (e^x) element-wise using SIMD polynomial approximation
359///
360/// Uses a Cephes-style polynomial with ~1e-7 relative error for |x| < 88.
361///
362/// # Errors
363///
364/// Returns an error if slice lengths don't match
365pub fn exp_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
366    validate_unary(data, out)?;
367
368    #[cfg(target_arch = "aarch64")]
369    {
370        // SAFETY: NEON always available on aarch64, lengths validated
371        unsafe {
372            neon_impl::exp_f32(data, out);
373        }
374    }
375
376    #[cfg(not(target_arch = "aarch64"))]
377    {
378        scalar_impl::apply_unary(data, out, f32::exp);
379    }
380
381    Ok(())
382}
383
384/// Compute 2^x element-wise
385///
386/// # Errors
387///
388/// Returns an error if slice lengths don't match
389pub fn exp2_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
390    validate_unary(data, out)?;
391    scalar_impl::apply_unary(data, out, f32::exp2);
392    Ok(())
393}
394
395/// Compute power (base^exponent) element-wise
396///
397/// # Errors
398///
399/// Returns an error if slice lengths don't match
400pub fn pow_f32(base: &[f32], exponent: &[f32], out: &mut [f32]) -> Result<()> {
401    if base.len() != exponent.len() || base.len() != out.len() {
402        return Err(AlgorithmError::InvalidParameter {
403            parameter: "input",
404            message: "Slice length mismatch".to_string(),
405        });
406    }
407
408    scalar_impl::apply_binary(base, exponent, out, f32::powf);
409    Ok(())
410}
411
412/// Compute sine element-wise
413///
414/// # Errors
415///
416/// Returns an error if slice lengths don't match
417pub fn sin_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
418    validate_unary(data, out)?;
419    scalar_impl::apply_unary(data, out, f32::sin);
420    Ok(())
421}
422
423/// Compute cosine element-wise
424///
425/// # Errors
426///
427/// Returns an error if slice lengths don't match
428pub fn cos_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
429    validate_unary(data, out)?;
430    scalar_impl::apply_unary(data, out, f32::cos);
431    Ok(())
432}
433
434/// Compute tangent element-wise
435///
436/// # Errors
437///
438/// Returns an error if slice lengths don't match
439pub fn tan_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
440    validate_unary(data, out)?;
441    scalar_impl::apply_unary(data, out, f32::tan);
442    Ok(())
443}
444
445/// Compute arcsine element-wise
446///
447/// # Errors
448///
449/// Returns an error if slice lengths don't match
450pub fn asin_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
451    validate_unary(data, out)?;
452    scalar_impl::apply_unary(data, out, f32::asin);
453    Ok(())
454}
455
456/// Compute arccosine element-wise
457///
458/// # Errors
459///
460/// Returns an error if slice lengths don't match
461pub fn acos_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
462    validate_unary(data, out)?;
463    scalar_impl::apply_unary(data, out, f32::acos);
464    Ok(())
465}
466
467/// Compute arctangent element-wise
468///
469/// # Errors
470///
471/// Returns an error if slice lengths don't match
472pub fn atan_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
473    validate_unary(data, out)?;
474    scalar_impl::apply_unary(data, out, f32::atan);
475    Ok(())
476}
477
478/// Compute two-argument arctangent element-wise: atan2(y, x)
479///
480/// # Errors
481///
482/// Returns an error if slice lengths don't match
483pub fn atan2_f32(y: &[f32], x: &[f32], out: &mut [f32]) -> Result<()> {
484    if y.len() != x.len() || y.len() != out.len() {
485        return Err(AlgorithmError::InvalidParameter {
486            parameter: "input",
487            message: "Slice length mismatch".to_string(),
488        });
489    }
490    scalar_impl::apply_binary(y, x, out, f32::atan2);
491    Ok(())
492}
493
494/// Compute hyperbolic sine element-wise
495///
496/// # Errors
497///
498/// Returns an error if slice lengths don't match
499pub fn sinh_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
500    validate_unary(data, out)?;
501    scalar_impl::apply_unary(data, out, f32::sinh);
502    Ok(())
503}
504
505/// Compute hyperbolic cosine element-wise
506///
507/// # Errors
508///
509/// Returns an error if slice lengths don't match
510pub fn cosh_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
511    validate_unary(data, out)?;
512    scalar_impl::apply_unary(data, out, f32::cosh);
513    Ok(())
514}
515
516/// Compute hyperbolic tangent element-wise
517///
518/// # Errors
519///
520/// Returns an error if slice lengths don't match
521pub fn tanh_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
522    validate_unary(data, out)?;
523    scalar_impl::apply_unary(data, out, f32::tanh);
524    Ok(())
525}
526
527/// Compute absolute value element-wise using hardware SIMD
528///
529/// Uses NEON vabsq_f32 on aarch64 (bit mask clearing sign bit).
530///
531/// # Errors
532///
533/// Returns an error if slice lengths don't match
534pub fn abs_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
535    validate_unary(data, out)?;
536
537    #[cfg(target_arch = "aarch64")]
538    {
539        // SAFETY: NEON always available, lengths validated
540        unsafe {
541            neon_impl::abs_f32(data, out);
542        }
543    }
544
545    #[cfg(not(target_arch = "aarch64"))]
546    {
547        scalar_impl::apply_unary(data, out, f32::abs);
548    }
549
550    Ok(())
551}
552
553/// Compute floor element-wise using hardware SIMD
554///
555/// Uses NEON vrndmq_f32 on aarch64 for 4x parallel floor.
556///
557/// # Errors
558///
559/// Returns an error if slice lengths don't match
560pub fn floor_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
561    validate_unary(data, out)?;
562
563    #[cfg(target_arch = "aarch64")]
564    {
565        // SAFETY: NEON always available, lengths validated
566        unsafe {
567            neon_impl::floor_f32(data, out);
568        }
569    }
570
571    #[cfg(not(target_arch = "aarch64"))]
572    {
573        scalar_impl::apply_unary(data, out, f32::floor);
574    }
575
576    Ok(())
577}
578
579/// Compute ceiling element-wise using hardware SIMD
580///
581/// Uses NEON vrndpq_f32 on aarch64 for 4x parallel ceil.
582///
583/// # Errors
584///
585/// Returns an error if slice lengths don't match
586pub fn ceil_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
587    validate_unary(data, out)?;
588
589    #[cfg(target_arch = "aarch64")]
590    {
591        // SAFETY: NEON always available, lengths validated
592        unsafe {
593            neon_impl::ceil_f32(data, out);
594        }
595    }
596
597    #[cfg(not(target_arch = "aarch64"))]
598    {
599        scalar_impl::apply_unary(data, out, f32::ceil);
600    }
601
602    Ok(())
603}
604
605/// Compute round (nearest integer) element-wise using hardware SIMD
606///
607/// Uses NEON vrndaq_f32 on aarch64 for 4x parallel round-away-from-zero.
608///
609/// # Errors
610///
611/// Returns an error if slice lengths don't match
612pub fn round_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
613    validate_unary(data, out)?;
614
615    #[cfg(target_arch = "aarch64")]
616    {
617        // SAFETY: NEON always available, lengths validated
618        unsafe {
619            neon_impl::round_f32(data, out);
620        }
621    }
622
623    #[cfg(not(target_arch = "aarch64"))]
624    {
625        scalar_impl::apply_unary(data, out, f32::round);
626    }
627
628    Ok(())
629}
630
631/// Compute fractional part element-wise: fract(x) = x - floor(x)
632///
633/// # Errors
634///
635/// Returns an error if slice lengths don't match
636pub fn fract_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
637    validate_unary(data, out)?;
638    // Compute floor first, then subtract
639    floor_f32(data, out)?;
640    for i in 0..data.len() {
641        out[i] = data[i] - out[i];
642    }
643    Ok(())
644}
645
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use std::f32::consts::PI;

    #[test]
    fn test_sqrt_f32() {
        let data = vec![1.0, 4.0, 9.0, 16.0, 25.0];
        let mut out = vec![0.0; 5];

        sqrt_f32(&data, &mut out).expect("sqrt_f32 failed");

        assert_relative_eq!(out[0], 1.0);
        assert_relative_eq!(out[1], 2.0);
        assert_relative_eq!(out[2], 3.0);
        assert_relative_eq!(out[3], 4.0);
        assert_relative_eq!(out[4], 5.0);
    }

    #[test]
    fn test_sqrt_large() {
        let data = vec![4.0; 1000];
        let mut out = vec![0.0; 1000];

        sqrt_f32(&data, &mut out).expect("sqrt_f32 large failed");

        for &val in &out {
            assert_relative_eq!(val, 2.0);
        }
    }

    #[test]
    fn test_exp_ln() {
        let data = vec![0.0, 1.0, 2.0, 3.0];
        let mut exp_out = vec![0.0; 4];
        let mut ln_out = vec![0.0; 4];

        exp_f32(&data, &mut exp_out).expect("exp_f32 failed");
        ln_f32(&exp_out, &mut ln_out).expect("ln_f32 failed");

        for i in 0..4 {
            assert_relative_eq!(ln_out[i], data[i], epsilon = 1e-5);
        }
    }

    #[test]
    fn test_exp_large() {
        // Test with larger arrays to exercise SIMD paths
        let data: Vec<f32> = (0..100).map(|i| i as f32 * 0.1).collect();
        let mut out = vec![0.0; 100];

        exp_f32(&data, &mut out).expect("exp_f32 large failed");

        for i in 0..100 {
            assert_relative_eq!(out[i], data[i].exp(), epsilon = 1e-4);
        }
    }

    #[test]
    fn test_ln_large() {
        let data: Vec<f32> = (1..=100).map(|i| i as f32).collect();
        let mut out = vec![0.0; 100];

        ln_f32(&data, &mut out).expect("ln_f32 large failed");

        for i in 0..100 {
            assert_relative_eq!(out[i], data[i].ln(), epsilon = 1e-4);
        }
    }

    #[test]
    fn test_log10() {
        let data = vec![1.0, 10.0, 100.0, 1000.0];
        let mut out = vec![0.0; 4];

        log10_f32(&data, &mut out).expect("log10_f32 failed");

        assert_relative_eq!(out[0], 0.0, epsilon = 1e-5);
        assert_relative_eq!(out[1], 1.0, epsilon = 1e-5);
        assert_relative_eq!(out[2], 2.0, epsilon = 1e-4);
        assert_relative_eq!(out[3], 3.0, epsilon = 1e-4);
    }

    #[test]
    fn test_log2() {
        let data = vec![1.0, 2.0, 4.0, 8.0, 16.0];
        let mut out = vec![0.0; 5];

        log2_f32(&data, &mut out).expect("log2_f32 failed");

        assert_relative_eq!(out[0], 0.0, epsilon = 1e-5);
        assert_relative_eq!(out[1], 1.0, epsilon = 1e-4);
        assert_relative_eq!(out[2], 2.0, epsilon = 1e-4);
        assert_relative_eq!(out[3], 3.0, epsilon = 1e-4);
        assert_relative_eq!(out[4], 4.0, epsilon = 1e-4);
    }

    #[test]
    fn test_pow() {
        let base = vec![2.0, 3.0, 4.0, 5.0];
        let exp = vec![2.0, 2.0, 2.0, 2.0];
        let mut out = vec![0.0; 4];

        pow_f32(&base, &exp, &mut out).expect("pow_f32 failed");

        assert_relative_eq!(out[0], 4.0);
        assert_relative_eq!(out[1], 9.0);
        assert_relative_eq!(out[2], 16.0);
        assert_relative_eq!(out[3], 25.0);
    }

    #[test]
    fn test_sin_cos() {
        let data = vec![0.0, PI / 6.0, PI / 4.0, PI / 3.0, PI / 2.0];
        let mut sin_out = vec![0.0; 5];
        let mut cos_out = vec![0.0; 5];

        sin_f32(&data, &mut sin_out).expect("sin_f32 failed");
        cos_f32(&data, &mut cos_out).expect("cos_f32 failed");

        assert_relative_eq!(sin_out[0], 0.0, epsilon = 1e-6);
        assert_relative_eq!(sin_out[4], 1.0, epsilon = 1e-6);
        assert_relative_eq!(cos_out[0], 1.0, epsilon = 1e-6);
        assert_relative_eq!(cos_out[4], 0.0, epsilon = 1e-6);

        // sin^2 + cos^2 = 1
        for i in 0..5 {
            let sum = sin_out[i] * sin_out[i] + cos_out[i] * cos_out[i];
            assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_tan() {
        let data = vec![0.0, PI / 4.0];
        let mut out = vec![0.0; 2];

        tan_f32(&data, &mut out).expect("tan_f32 failed");

        assert_relative_eq!(out[0], 0.0, epsilon = 1e-6);
        assert_relative_eq!(out[1], 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_asin_acos() {
        let data = vec![0.0, 0.5, 1.0];
        let mut asin_out = vec![0.0; 3];
        let mut acos_out = vec![0.0; 3];

        asin_f32(&data, &mut asin_out).expect("asin_f32 failed");
        acos_f32(&data, &mut acos_out).expect("acos_f32 failed");

        assert_relative_eq!(asin_out[0], 0.0, epsilon = 1e-6);
        assert_relative_eq!(asin_out[2], PI / 2.0, epsilon = 1e-6);
        assert_relative_eq!(acos_out[0], PI / 2.0, epsilon = 1e-6);
        assert_relative_eq!(acos_out[2], 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_atan2() {
        let y = vec![0.0, 1.0, 0.0, -1.0];
        let x = vec![1.0, 0.0, -1.0, 0.0];
        let mut out = vec![0.0; 4];

        atan2_f32(&y, &x, &mut out).expect("atan2_f32 failed");

        assert_relative_eq!(out[0], 0.0, epsilon = 1e-6);
        assert_relative_eq!(out[1], PI / 2.0, epsilon = 1e-6);
        assert_relative_eq!(out[2], PI, epsilon = 1e-6);
        assert_relative_eq!(out[3], -PI / 2.0, epsilon = 1e-6);
    }

    #[test]
    fn test_hyperbolic() {
        let data = vec![0.0, 1.0];
        let mut sinh_out = vec![0.0; 2];
        let mut cosh_out = vec![0.0; 2];
        let mut tanh_out = vec![0.0; 2];

        sinh_f32(&data, &mut sinh_out).expect("sinh_f32 failed");
        cosh_f32(&data, &mut cosh_out).expect("cosh_f32 failed");
        tanh_f32(&data, &mut tanh_out).expect("tanh_f32 failed");

        assert_relative_eq!(sinh_out[0], 0.0, epsilon = 1e-6);
        assert_relative_eq!(cosh_out[0], 1.0, epsilon = 1e-6);
        assert_relative_eq!(tanh_out[0], 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_abs() {
        let data = vec![-1.0, -2.0, 3.0, -4.0, 5.0];
        let mut out = vec![0.0; 5];

        abs_f32(&data, &mut out).expect("abs_f32 failed");

        assert_relative_eq!(out[0], 1.0);
        assert_relative_eq!(out[1], 2.0);
        assert_relative_eq!(out[2], 3.0);
        assert_relative_eq!(out[3], 4.0);
        assert_relative_eq!(out[4], 5.0);
    }

    #[test]
    fn test_abs_large() {
        let data: Vec<f32> = (-500..500).map(|i| i as f32).collect();
        let mut out = vec![0.0; 1000];

        abs_f32(&data, &mut out).expect("abs_f32 large failed");

        for i in 0..1000 {
            assert_relative_eq!(out[i], (data[i]).abs());
        }
    }

    #[test]
    fn test_floor_ceil_round() {
        let data = vec![1.2, 1.7, -1.2, -1.7];
        let mut floor_out = vec![0.0; 4];
        let mut ceil_out = vec![0.0; 4];
        let mut round_out = vec![0.0; 4];

        floor_f32(&data, &mut floor_out).expect("floor_f32 failed");
        ceil_f32(&data, &mut ceil_out).expect("ceil_f32 failed");
        round_f32(&data, &mut round_out).expect("round_f32 failed");

        assert_relative_eq!(floor_out[0], 1.0);
        assert_relative_eq!(floor_out[1], 1.0);
        assert_relative_eq!(floor_out[2], -2.0);
        assert_relative_eq!(floor_out[3], -2.0);
        assert_relative_eq!(ceil_out[0], 2.0);
        assert_relative_eq!(ceil_out[1], 2.0);
        assert_relative_eq!(ceil_out[2], -1.0);
        assert_relative_eq!(ceil_out[3], -1.0);
        assert_relative_eq!(round_out[0], 1.0);
        assert_relative_eq!(round_out[1], 2.0);
        assert_relative_eq!(round_out[2], -1.0);
        assert_relative_eq!(round_out[3], -2.0);
    }

    #[test]
    fn test_round_ties_away_from_zero() {
        // Nine inputs: two full 4-lane vectors on the NEON path plus a scalar
        // tail, so both code paths must agree on halfway cases.
        let data = vec![0.5, 1.5, 2.5, -0.5, -1.5, -2.5, 3.5, -3.5, 4.5];
        let mut out = vec![0.0; 9];

        round_f32(&data, &mut out).expect("round_f32 failed");

        assert_eq!(out, vec![1.0, 2.0, 3.0, -1.0, -2.0, -3.0, 4.0, -4.0, 5.0]);
    }

    #[test]
    fn test_floor_ceil_with_tail() {
        // Seven inputs: one full 4-lane vector plus a 3-element scalar tail.
        let data = vec![-2.5, -1.0, -0.25, 0.0, 0.25, 1.0, 2.5];
        let mut floor_out = vec![0.0; 7];
        let mut ceil_out = vec![0.0; 7];

        floor_f32(&data, &mut floor_out).expect("floor_f32 failed");
        ceil_f32(&data, &mut ceil_out).expect("ceil_f32 failed");

        assert_eq!(floor_out, vec![-3.0, -1.0, -1.0, 0.0, 0.0, 1.0, 2.0]);
        // ceil(-0.25) is -0.0, which compares equal to 0.0.
        assert_eq!(ceil_out, vec![-2.0, -1.0, 0.0, 0.0, 1.0, 1.0, 3.0]);
    }

    #[test]
    fn test_fract() {
        let data = vec![1.3, 2.7, -1.3, -2.7];
        let mut out = vec![0.0; 4];

        fract_f32(&data, &mut out).expect("fract_f32 failed");

        assert_relative_eq!(out[0], 0.3, epsilon = 1e-6);
        assert_relative_eq!(out[1], 0.7, epsilon = 1e-6);
        // For negative numbers, fract = x - floor(x), so -1.3 - (-2.0) = 0.7
        assert_relative_eq!(out[2], 0.7, epsilon = 1e-6);
        assert_relative_eq!(out[3], 0.3, epsilon = 1e-6);
    }

    #[test]
    fn test_empty_input() {
        let empty: Vec<f32> = Vec::new();
        let mut out: Vec<f32> = Vec::new();

        assert!(sqrt_f32(&empty, &mut out).is_ok());
        assert!(exp_f32(&empty, &mut out).is_ok());
        assert!(fract_f32(&empty, &mut out).is_ok());
        assert!(pow_f32(&empty, &empty, &mut out).is_ok());
    }

    #[test]
    fn test_length_mismatch() {
        let data = vec![1.0; 10];
        let mut out = vec![0.0; 5];

        assert!(sqrt_f32(&data, &mut out).is_err());
    }

    #[test]
    fn test_binary_length_mismatch() {
        let four = vec![1.0; 4];
        let three = vec![1.0; 3];
        let mut out4 = vec![0.0; 4];
        let mut out3 = vec![0.0; 3];

        assert!(pow_f32(&four, &three, &mut out4).is_err());
        assert!(atan2_f32(&four, &three, &mut out4).is_err());
        assert!(pow_f32(&four, &four, &mut out3).is_err());
        assert!(atan2_f32(&four, &four, &mut out3).is_err());
    }
}