1#![allow(clippy::excessive_precision)]
3#![allow(clippy::approx_constant)]
4
5use rten_base::hint::unlikely;
6use rten_simd::ops::{FloatOps, MaskOps, NumOps};
7use rten_simd::{Isa, Simd, SimdUnaryOp};
8
// Mathematical constants used by the kernel below. They are derived from the
// standard library rather than typed out by hand; the resulting `f32` values
// are bit-for-bit the same as the literals `3.1415927`, `0.15915494` and
// `1.5707964`.

/// π as an `f32`.
const PI: f32 = std::f32::consts::PI;
/// 1 / (2π). Halving `FRAC_1_PI` is exact in binary floating point.
const INV_2_PI: f32 = std::f32::consts::FRAC_1_PI * 0.5;
/// π / 2.
const HALF_PI: f32 = std::f32::consts::FRAC_PI_2;

/// Inputs whose absolute value is at or above this threshold are not handled
/// by the vectorized approximation; the kernel evaluates them with the
/// standard library's scalar `sin` / `cos` instead.
const LARGE_THRESHOLD: f32 = 48_000.0;
18
19#[derive(Default)]
23pub struct Sin(SinCos<false>);
24
25impl Sin {
26 pub fn new() -> Self {
27 Self::default()
28 }
29}
30
31impl SimdUnaryOp<f32> for Sin {
32 #[inline(always)]
33 fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
34 self.0.eval(isa, x)
35 }
36}
37
38#[derive(Default)]
42pub struct Cos(SinCos<true>);
43
44impl Cos {
45 pub fn new() -> Self {
46 Self::default()
47 }
48}
49
50impl SimdUnaryOp<f32> for Cos {
51 #[inline(always)]
52 fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
53 self.0.eval(isa, x)
54 }
55}
56
57#[derive(Default)]
59struct SinCos<const COS: bool> {}
60
61impl<const COS: bool> SimdUnaryOp<f32> for SinCos<COS> {
62 #[inline(always)]
63 fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
64 let ops = isa.f32();
65 let mask_ops = isa.m32();
66
67 let large = ops.ge(ops.abs(x), ops.splat(LARGE_THRESHOLD));
71 if unlikely(mask_ops.any(large)) {
72 let mut y = x.to_array();
73 for i in 0..ops.len() {
74 y[i] = if COS { y[i].cos() } else { y[i].sin() };
75 }
76 return ops.load(y.as_ref());
77 }
78
79 let inv_2_pi = ops.splat(INV_2_PI);
84 let two_pi_hi = ops.splat(6.28125);
85 let two_pi_lo = ops.splat(1.9353072e-3);
86 let pi = ops.splat(PI);
87
88 let a3 = ops.splat(-1.3314664364e-01);
90 let a5 = ops.splat(3.2340581529e-03);
91 let one = ops.splat(1.0);
92
93 let b2 = ops.splat(3.3519912511e-02);
95 let b4 = ops.splat(4.8770775902e-04);
96
97 let k = ops.round_ties_even(ops.mul(x, inv_2_pi));
99 let x_rr = ops.mul_sub_from(k, two_pi_hi, x);
100 let mut x_rr = ops.mul_sub_from(k, two_pi_lo, x_rr);
101
102 if COS {
103 let pi_half = ops.splat(HALF_PI);
104 x_rr = ops.sub(pi_half, x_rr);
105 }
106
107 let x_rr = ops.min(x_rr, ops.sub(pi, x_rr));
109 let x_rr = ops.max(x_rr, ops.sub(ops.neg(pi), x_rr));
110 let x_rr = ops.min(x_rr, ops.sub(pi, x_rr));
111
112 let x_rr_sq = ops.mul(x_rr, x_rr);
114
115 let p = ops.mul_add(x_rr_sq, a5, a3);
117 let p = ops.mul_add(x_rr_sq, p, one);
118 let p = ops.mul(p, x_rr);
119
120 let q = ops.mul_add(x_rr_sq, b4, b2);
122 let q = ops.mul_add(x_rr_sq, q, one);
123
124 ops.div(p, q)
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::LARGE_THRESHOLD;
    use crate::testing::{ARange, Tolerance, UnaryOpTester, arange};
    use crate::{Cos, Sin};

    /// Absolute error tolerated by the non-exhaustive tests.
    const MAX_ERROR_FOR_SMALL_X: f32 = 2.0 * f32::EPSILON;

    /// Inputs from -10 to 10 in steps of 0.1.
    // NOTE(review): whether the upper bound is included depends on `arange`
    // — confirm in `crate::testing`.
    const SMALL_X: ARange<f32> = arange(-10., 10., 0.1f32);

    /// Yields `n * PI` for every integer `n` in `-5..5`, i.e. -5π up to 4π.
    fn multiples_of_pi() -> impl Iterator<Item = f32> + Clone {
        (-5..5).map(|n| (n as f32) * super::PI)
    }

    /// Yields every `f32` from `min` up to and including `max`, stepping from
    /// one value to the next with `f32::next_up`.
    fn all_floats_in_range(min: f32, max: f32) -> impl Iterator<Item = f32> + Clone {
        std::iter::successors(Some(min), |f| Some(f.next_up())).take_while(move |x| *x <= max)
    }

    #[test]
    fn test_sin() {
        let test = UnaryOpTester {
            reference: f32::sin,
            simd: Sin::new(),
            range: SMALL_X.chain(multiples_of_pi()),
            tolerance: Tolerance::Absolute(MAX_ERROR_FOR_SMALL_X),
        };
        test.run();
    }

    // Walks every float in (-LARGE_THRESHOLD, LARGE_THRESHOLD), so it is
    // opt-in: run with `cargo test -- --ignored`.
    #[test]
    #[ignore]
    fn test_sin_exhaustive() {
        let test = UnaryOpTester {
            reference: f32::sin,
            simd: Sin::new(),
            range: all_floats_in_range(-LARGE_THRESHOLD, LARGE_THRESHOLD),
            tolerance: Tolerance::Absolute(3e-7),
        };
        test.run_with_progress();
    }

    #[test]
    fn test_cos() {
        let test = UnaryOpTester {
            reference: f32::cos,
            simd: Cos::new(),
            range: SMALL_X.chain(multiples_of_pi()),
            tolerance: Tolerance::Absolute(MAX_ERROR_FOR_SMALL_X),
        };
        test.run();
    }

    // Walks every float in (-LARGE_THRESHOLD, LARGE_THRESHOLD), so it is
    // opt-in: run with `cargo test -- --ignored`.
    #[test]
    #[ignore]
    fn test_cos_exhaustive() {
        let test = UnaryOpTester {
            reference: f32::cos,
            simd: Cos::new(),
            range: all_floats_in_range(-LARGE_THRESHOLD, LARGE_THRESHOLD),
            tolerance: Tolerance::Absolute(5e-7),
        };
        test.run_with_progress();
    }
}