fixed_trig/
cordic.rs

1//! Implementations based on the CORDIC algorithm.
2pub use super::cordic_number::CordicNumber;
3use fixed::types::U0F64;
4use std::convert::TryInto;
5
/// Raw bytes of the arctangent lookup table, embedded at compile time from
/// `tables/cordic_atan.table`.
///
/// `lookup_table` decodes it as consecutive 8-byte little-endian `U0F64` entries, and the CORDIC
/// loops in this file index it by iteration number.
// NOTE(review): entry `i` is presumably `atan(2^-i)`; confirm against whatever generates the table.
const ATAN_TABLE: &[u8] = include_bytes!("tables/cordic_atan.table");
7
8fn lookup_table(table: &[u8], index: u8) -> U0F64 {
9    let i = index as usize * 8;
10    U0F64::from_bits(u64::from_le_bytes(table[i..(i + 8)].try_into().unwrap()))
11}
12
13pub fn cordic_circular<T: CordicNumber>(mut x: T, mut y: T, mut z: T, vecmode: T) -> (T, T, T) {
14    let _0 = T::zero();
15    let _2 = T::one() + T::one();
16
17    for i in 0..T::num_fract_bits() {
18        if vecmode >= _0 && y < vecmode || vecmode < _0 && z >= _0 {
19            let x1 = x - (y >> i);
20            y = y + (x >> i);
21            x = x1;
22            z = z - T::from_u0f64(lookup_table(ATAN_TABLE, i));
23        } else {
24            let x1 = x + (y >> i);
25            y = y - (x >> i);
26            x = x1;
27            z = z + T::from_u0f64(lookup_table(ATAN_TABLE, i));
28        }
29    }
30
31    (x, y, z)
32}
33
34pub fn gain_cordic<T: CordicNumber>() -> T {
35    cordic_circular(T::one(), T::zero(), T::zero(), -T::one()).0
36}
37
38pub fn sin_cos<T: CordicNumber>(mut angle: T) -> (T, T) {
39    let mut negative = false;
40
41    while angle > T::frac_pi_2() {
42        angle -= T::pi();
43        negative = !negative;
44    }
45
46    while angle < -T::frac_pi_2() {
47        angle += T::pi();
48        negative = !negative;
49    }
50
51    let inv_gain = T::one() / gain_cordic(); // FIXME: precompute this.
52    let res = cordic_circular(inv_gain, T::zero(), angle, -T::one());
53
54    if negative {
55        (-res.1, -res.0)
56    } else {
57        (res.1, res.0)
58    }
59}
60
61pub fn sin<T: CordicNumber>(angle: T) -> T {
62    sin_cos(angle).0
63}
64
65pub fn cos<T: CordicNumber>(angle: T) -> T {
66    sin_cos(angle).1
67}
68
69pub fn tan<T: CordicNumber>(angle: T) -> T {
70    let (sin, cos) = sin_cos(angle);
71    sin / cos
72}
73
74pub fn asin<T: CordicNumber>(mut val: T) -> T {
75    // For asin, we use a double-rotation approach to reduce errors.
76    // NOTE: see https://stackoverflow.com/questions/25976656/cordic-arcsine-implementation-fails
77    // for details about the innacuracy of CORDIC for asin.
78
79    let mut theta = T::zero();
80    let mut z = (T::one(), T::zero());
81    let niter = T::num_fract_bits();
82
83    for j in 0..niter {
84        let sign_x = if z.0 < T::zero() { -T::one() } else { T::one() };
85        let sigma = if z.1 <= val { sign_x } else { -sign_x };
86        let rotate = |(x, y)| (x - ((y >> j) * sigma), y + ((x >> j) * sigma));
87        z = rotate(rotate(z));
88
89        let angle = T::from_u0f64(lookup_table(ATAN_TABLE, j));
90        theta = theta + ((angle + angle) * sigma);
91        val = val + (val >> (j + j));
92    }
93
94    theta
95}
96
97pub fn acos<T: CordicNumber>(val: T) -> T {
98    T::frac_pi_2() - asin(val)
99}
100
101pub fn atan<T: CordicNumber>(val: T) -> T {
102    cordic_circular(T::one(), val, T::zero(), T::zero()).2
103}
104
105pub fn atan2<T: CordicNumber>(y: T, x: T) -> T {
106    if x == T::zero() {
107        if y < T::zero() {
108            return -T::frac_pi_2();
109        } else {
110            return T::frac_pi_2();
111        }
112    }
113
114    if y == T::zero() {
115        if x >= T::zero() {
116            return T::zero();
117        } else {
118            return T::pi();
119        }
120    }
121
122    match (x < T::zero(), y < T::zero()) {
123        (false, false) => atan(y / x),
124        (false, true) => -atan(-y / x),
125        (true, false) => T::pi() - atan(y / -x),
126        (true, true) => atan(y / x) - T::pi(),
127    }
128}
129
130// FIXME: this should be contributed to the fixed-sqrt crate instead of remaining here.
131pub fn sqrt<T: CordicNumber>(x: T, niters: usize) -> T {
132    if x == T::zero() || x == T::one() {
133        return x;
134    }
135
136    // FIXME: optimize with bitshifts
137    let mut pow2 = T::one();
138    let mut result = T::zero();
139
140    if x < T::one() {
141        while x <= pow2 * pow2 {
142            pow2 = pow2 >> 1;
143        }
144
145        result = pow2;
146    } else {
147        // x >= T::one()
148        while pow2 * pow2 <= x {
149            pow2 = pow2 << 1;
150        }
151
152        result = pow2 >> 1;
153    }
154
155    for _ in 0..niters {
156        pow2 = pow2 >> 1;
157        let next_result = result + pow2;
158        if next_result * next_result <= x {
159            result = next_result;
160        }
161    }
162
163    result
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use fixed::types::I48F16;

    /// Raw `I48F16` bit pattern of `1.0` (16 fractional bits).
    const ONE_BITS: i64 = 1 << 16;

    /// Panics unless `computed` lies within `max_err` of `expected`.
    ///
    /// A NaN on either side is treated as a mismatch, so a reference value outside the function's
    /// domain cannot make a test pass silently.
    fn assert_approx_eq<T: std::fmt::Display>(
        input: T,
        computed: f64,
        expected: f64,
        max_err: f64,
    ) {
        let err = (computed - expected).abs();
        if err.is_nan() || err > max_err {
            panic!(
                "mismatch for input {}: computed {}, expected {}",
                input, computed, expected
            );
        }
    }

    /// Generates two tests comparing `$trigf` against the `f64` method of the same name: a coarse
    /// sweep over `[-10, 10)` in steps of 0.1, and an exhaustive sweep over every `I48F16` value
    /// with magnitude below 16 (both signs).
    macro_rules! test_trig(
        ($test: ident, $test_comprehensive: ident, $trigf: ident, $max_err: expr) => {
            #[test]
            fn $test() {
                for i in -100..100 {
                    let fx = f64::from(i) * 0.1_f64;
                    let x: I48F16 = I48F16::from_num(fx);
                    assert_approx_eq(x, $trigf(x).to_num(), fx.$trigf(), $max_err);
                }
            }

            #[test]
            fn $test_comprehensive() {
                for i in 0..(1 << 20) {
                    let x = I48F16::from_bits(i);
                    let fx: f64 = x.to_num();
                    assert_approx_eq(x, $trigf(x).to_num(), fx.$trigf(), $max_err);

                    // Test negative numbers too.
                    let x = -I48F16::from_bits(i);
                    let fx: f64 = x.to_num();
                    assert_approx_eq(x, $trigf(x).to_num(), fx.$trigf(), $max_err);
                }
            }
        }
    );

    test_trig!(test_sin, test_sin_comprehensive, sin, 0.001);
    test_trig!(test_cos, test_cos_comprehensive, cos, 0.001);
    test_trig!(test_atan, test_atan_comprehensive, atan, 0.001);

    #[test]
    fn test_asin() {
        // asin is only defined on [-1, 1]; sweep every representable value in that interval.
        for i in 0..=ONE_BITS {
            let x = I48F16::from_bits(i);
            let fx: f64 = x.to_num();
            assert_approx_eq(x, asin(x).to_num(), fx.asin(), 0.01);

            // Test negative numbers too.
            let x = -I48F16::from_bits(i);
            let fx: f64 = x.to_num();
            assert_approx_eq(x, asin(x).to_num(), fx.asin(), 0.01);
        }
    }

    #[test]
    fn test_acos() {
        // acos is only defined on [-1, 1]; sweep every representable value in that interval.
        for i in 0..=ONE_BITS {
            let x = I48F16::from_bits(i);
            let fx: f64 = x.to_num();
            assert_approx_eq(x, acos(x).to_num(), fx.acos(), 0.01);

            // Test negative numbers too.
            let x = -I48F16::from_bits(i);
            let fx: f64 = x.to_num();
            assert_approx_eq(x, acos(x).to_num(), fx.acos(), 0.01);
        }
    }

    #[test]
    fn test_sqrt() {
        for i in 0..(1 << 20) {
            let x = I48F16::from_bits(i);
            let fx: f64 = x.to_num();
            assert_approx_eq(x, sqrt(x, 40).to_num(), fx.sqrt(), 0.01);
        }
    }
}