rten_vecmath/
sin_cos.rs

1// We reuse constants exactly from the XNNPACK implementation.
2#![allow(clippy::excessive_precision)]
3#![allow(clippy::approx_constant)]
4
5use rten_base::hint::unlikely;
6use rten_simd::ops::{FloatOps, MaskOps, NumOps};
7use rten_simd::{Isa, Simd, SimdUnaryOp};
8
// Constants copied verbatim from the XNNPACK vsin kernel that served as the
// reference implementation. Digit separators are for readability only; each
// literal parses to the same f32 as the undecorated XNNPACK value.

/// π rounded to f32.
const PI: f32 = 3.141_592_7;
/// 1 / (2π) rounded to f32; used to find the nearest multiple of 2π.
const INV_2_PI: f32 = 0.159_154_94;
/// π / 2 rounded to f32; used to turn cosine into a shifted sine.
const HALF_PI: f32 = 1.570_796_4;

/// Magnitude at or above which the vectorized path is abandoned.
///
/// If `abs(x) >= LARGE_THRESHOLD` for any lane, the whole vector is evaluated
/// with the standard library's scalar `sin`/`cos` instead.
const LARGE_THRESHOLD: f32 = 48_000.0;
18
19/// Computes the sine function.
20///
21/// The implementation has a maximum absolute error of 2.98e-7 (2.5 * f32::EPSILON).
22#[derive(Default)]
23pub struct Sin(SinCos<false>);
24
25impl Sin {
26    pub fn new() -> Self {
27        Self::default()
28    }
29}
30
31impl SimdUnaryOp<f32> for Sin {
32    #[inline(always)]
33    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
34        self.0.eval(isa, x)
35    }
36}
37
38/// Computes the cosine function.
39///
40/// The implementation has a maximum absolute error of 4.17e-7 (3.5 * f32::EPSILON).
41#[derive(Default)]
42pub struct Cos(SinCos<true>);
43
44impl Cos {
45    pub fn new() -> Self {
46        Self::default()
47    }
48}
49
50impl SimdUnaryOp<f32> for Cos {
51    #[inline(always)]
52    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
53        self.0.eval(isa, x)
54    }
55}
56
57/// Computes the sine or cosine function.
58#[derive(Default)]
59struct SinCos<const COS: bool> {}
60
61impl<const COS: bool> SimdUnaryOp<f32> for SinCos<COS> {
62    #[inline(always)]
63    fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
64        let ops = isa.f32();
65        let mask_ops = isa.m32();
66
67        // For large inputs the vectorized algorithm can produce results
68        // outside the [-1, 1] range. Fall back to scalar evaluation for such
69        // inputs to avoid this.
70        let large = ops.ge(ops.abs(x), ops.splat(LARGE_THRESHOLD));
71        if unlikely(mask_ops.any(large)) {
72            let mut y = x.to_array();
73            for i in 0..ops.len() {
74                y[i] = if COS { y[i].cos() } else { y[i].sin() };
75            }
76            return ops.load(y.as_ref());
77        }
78
79        // The implementation here is based on XNNPACK.
80        // See https://github.com/google/XNNPACK/blob/master/src/f32-vsin/rational-5-4.c.in.
81
82        // Range reduction constants.
83        let inv_2_pi = ops.splat(INV_2_PI);
84        let two_pi_hi = ops.splat(6.28125);
85        let two_pi_lo = ops.splat(1.9353072e-3);
86        let pi = ops.splat(PI);
87
88        // Rational approximation numerator constants.
89        let a3 = ops.splat(-1.3314664364e-01);
90        let a5 = ops.splat(3.2340581529e-03);
91        let one = ops.splat(1.0);
92
93        // Rational approximation denominator constants.
94        let b2 = ops.splat(3.3519912511e-02);
95        let b4 = ops.splat(4.8770775902e-04);
96
97        // Compute range-reduced `x_rr` such that `x_rr ∈ [−π, π]`.
98        let k = ops.round_ties_even(ops.mul(x, inv_2_pi));
99        let x_rr = ops.mul_sub_from(k, two_pi_hi, x);
100        let mut x_rr = ops.mul_sub_from(k, two_pi_lo, x_rr);
101
102        if COS {
103            let pi_half = ops.splat(HALF_PI);
104            x_rr = ops.sub(pi_half, x_rr);
105        }
106
107        // Further reduce range to [-π/2, π/2].
108        let x_rr = ops.min(x_rr, ops.sub(pi, x_rr));
109        let x_rr = ops.max(x_rr, ops.sub(ops.neg(pi), x_rr));
110        let x_rr = ops.min(x_rr, ops.sub(pi, x_rr));
111
112        // Approximate sin via a rational approximation.
113        let x_rr_sq = ops.mul(x_rr, x_rr);
114
115        // Numerator polynomial
116        let p = ops.mul_add(x_rr_sq, a5, a3);
117        let p = ops.mul_add(x_rr_sq, p, one);
118        let p = ops.mul(p, x_rr);
119
120        // Denominator polynomial
121        let q = ops.mul_add(x_rr_sq, b4, b2);
122        let q = ops.mul_add(x_rr_sq, q, one);
123
124        ops.div(p, q)
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::LARGE_THRESHOLD;
    use crate::testing::{ARange, Tolerance, UnaryOpTester, arange};
    use crate::{Cos, Sin};

    // Maximum error of `SinCos` compared to `f32::sin` and `f32::cos` in the
    // `SMALL_X` range.
    //
    // Uses the associated constant `f32::EPSILON`; the module-level
    // `std::f32::EPSILON` is a legacy constant flagged by clippy.
    const MAX_ERROR_FOR_SMALL_X: f32 = 2.0 * f32::EPSILON; // 2.38e-7

    // Range of small/medium X values which we expect most inputs will be in.
    const SMALL_X: ARange<f32> = arange(-10., 10., 0.1f32);

    // Multiples of π are the worst case for range reduction.
    fn multiples_of_pi() -> impl Iterator<Item = f32> + Clone {
        (-5..5).map(|n| (n as f32) * super::PI)
    }

    // Generate all float values in the range [min, max].
    fn all_floats_in_range(min: f32, max: f32) -> impl Iterator<Item = f32> + Clone {
        std::iter::successors(Some(min), |f| Some(f.next_up())).take_while(move |x| *x <= max)
    }

    #[test]
    fn test_sin() {
        let test = UnaryOpTester {
            reference: f32::sin,
            simd: Sin::new(),
            range: SMALL_X.chain(multiples_of_pi()),
            tolerance: Tolerance::Absolute(MAX_ERROR_FOR_SMALL_X),
        };
        test.run();
    }

    #[test]
    #[ignore] // Ignored by default due to long runtime
    fn test_sin_exhaustive() {
        let test = UnaryOpTester {
            reference: f32::sin,
            simd: Sin::new(),
            range: all_floats_in_range(-LARGE_THRESHOLD, LARGE_THRESHOLD),
            // Slightly above the documented bound for `Sin`
            // (2.98e-7 = 2.5 * f32::EPSILON).
            tolerance: Tolerance::Absolute(3e-7),
        };
        test.run_with_progress();
    }

    #[test]
    fn test_cos() {
        let test = UnaryOpTester {
            reference: f32::cos,
            simd: Cos::new(),
            range: SMALL_X.chain(multiples_of_pi()),
            tolerance: Tolerance::Absolute(MAX_ERROR_FOR_SMALL_X),
        };
        test.run();
    }

    #[test]
    #[ignore] // Ignored by default due to long runtime
    fn test_cos_exhaustive() {
        let test = UnaryOpTester {
            reference: f32::cos,
            simd: Cos::new(),
            range: all_floats_in_range(-LARGE_THRESHOLD, LARGE_THRESHOLD),
            // Maximum error for cos is larger than for sin because cos has an
            // extra subtraction. The documented bound for `Cos` is
            // 4.17e-7 (3.5 * f32::EPSILON).
            tolerance: Tolerance::Absolute(5e-7),
        };
        test.run_with_progress();
    }
}