1pub use super::cordic_number::CordicNumber;
3use fixed::types::U0F64;
4use std::convert::TryInto;
5
6const ATAN_TABLE: &[u8] = include_bytes!("tables/cordic_atan.table");
7
8fn lookup_table(table: &[u8], index: u8) -> U0F64 {
9 let i = index as usize * 8;
10 U0F64::from_bits(u64::from_le_bytes(table[i..(i + 8)].try_into().unwrap()))
11}
12
13pub fn cordic_circular<T: CordicNumber>(mut x: T, mut y: T, mut z: T, vecmode: T) -> (T, T, T) {
14 let _0 = T::zero();
15 let _2 = T::one() + T::one();
16
17 for i in 0..T::num_fract_bits() {
18 if vecmode >= _0 && y < vecmode || vecmode < _0 && z >= _0 {
19 let x1 = x - (y >> i);
20 y = y + (x >> i);
21 x = x1;
22 z = z - T::from_u0f64(lookup_table(ATAN_TABLE, i));
23 } else {
24 let x1 = x + (y >> i);
25 y = y - (x >> i);
26 x = x1;
27 z = z + T::from_u0f64(lookup_table(ATAN_TABLE, i));
28 }
29 }
30
31 (x, y, z)
32}
33
34pub fn gain_cordic<T: CordicNumber>() -> T {
35 cordic_circular(T::one(), T::zero(), T::zero(), -T::one()).0
36}
37
38pub fn sin_cos<T: CordicNumber>(mut angle: T) -> (T, T) {
39 let mut negative = false;
40
41 while angle > T::frac_pi_2() {
42 angle -= T::pi();
43 negative = !negative;
44 }
45
46 while angle < -T::frac_pi_2() {
47 angle += T::pi();
48 negative = !negative;
49 }
50
51 let inv_gain = T::one() / gain_cordic(); let res = cordic_circular(inv_gain, T::zero(), angle, -T::one());
53
54 if negative {
55 (-res.1, -res.0)
56 } else {
57 (res.1, res.0)
58 }
59}
60
61pub fn sin<T: CordicNumber>(angle: T) -> T {
62 sin_cos(angle).0
63}
64
65pub fn cos<T: CordicNumber>(angle: T) -> T {
66 sin_cos(angle).1
67}
68
69pub fn tan<T: CordicNumber>(angle: T) -> T {
70 let (sin, cos) = sin_cos(angle);
71 sin / cos
72}
73
74pub fn asin<T: CordicNumber>(mut val: T) -> T {
75 let mut theta = T::zero();
80 let mut z = (T::one(), T::zero());
81 let niter = T::num_fract_bits();
82
83 for j in 0..niter {
84 let sign_x = if z.0 < T::zero() { -T::one() } else { T::one() };
85 let sigma = if z.1 <= val { sign_x } else { -sign_x };
86 let rotate = |(x, y)| (x - ((y >> j) * sigma), y + ((x >> j) * sigma));
87 z = rotate(rotate(z));
88
89 let angle = T::from_u0f64(lookup_table(ATAN_TABLE, j));
90 theta = theta + ((angle + angle) * sigma);
91 val = val + (val >> (j + j));
92 }
93
94 theta
95}
96
97pub fn acos<T: CordicNumber>(val: T) -> T {
98 T::frac_pi_2() - asin(val)
99}
100
101pub fn atan<T: CordicNumber>(val: T) -> T {
102 cordic_circular(T::one(), val, T::zero(), T::zero()).2
103}
104
105pub fn atan2<T: CordicNumber>(y: T, x: T) -> T {
106 if x == T::zero() {
107 if y < T::zero() {
108 return -T::frac_pi_2();
109 } else {
110 return T::frac_pi_2();
111 }
112 }
113
114 if y == T::zero() {
115 if x >= T::zero() {
116 return T::zero();
117 } else {
118 return T::pi();
119 }
120 }
121
122 match (x < T::zero(), y < T::zero()) {
123 (false, false) => atan(y / x),
124 (false, true) => -atan(-y / x),
125 (true, false) => T::pi() - atan(y / -x),
126 (true, true) => atan(y / x) - T::pi(),
127 }
128}
129
130pub fn sqrt<T: CordicNumber>(x: T, niters: usize) -> T {
132 if x == T::zero() || x == T::one() {
133 return x;
134 }
135
136 let mut pow2 = T::one();
138 let mut result = T::zero();
139
140 if x < T::one() {
141 while x <= pow2 * pow2 {
142 pow2 = pow2 >> 1;
143 }
144
145 result = pow2;
146 } else {
147 while pow2 * pow2 <= x {
149 pow2 = pow2 << 1;
150 }
151
152 result = pow2 >> 1;
153 }
154
155 for _ in 0..niters {
156 pow2 = pow2 >> 1;
157 let next_result = result + pow2;
158 if next_result * next_result <= x {
159 result = next_result;
160 }
161 }
162
163 result
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use fixed::types::I48F16;

    /// Panics unless `computed` is within `max_err` of `expected`.
    ///
    /// A NaN error (from a NaN on either side) counts as a mismatch. Inputs outside a
    /// function's domain must therefore not be fed to this helper.
    fn assert_approx_eq<T: std::fmt::Display>(
        input: T,
        computed: f64,
        expected: f64,
        max_err: f64,
    ) {
        let err = (computed - expected).abs();
        // `err > max_err` alone is false for NaN, which would let domain errors pass silently.
        if err.is_nan() || err > max_err {
            panic!(
                "mismatch for input {}: computed {}, expected {}",
                input, computed, expected
            );
        }
    }

    macro_rules! test_trig {
        ($test: ident, $test_comprehensive: ident, $trigf: ident, $max_err: expr) => {
            #[test]
            fn $test() {
                for i in -100..100 {
                    let fx = f64::from(i) * 0.1_f64;
                    let x: I48F16 = I48F16::from_num(fx);
                    assert_approx_eq(x, $trigf(x).to_num(), fx.$trigf(), $max_err);
                }
            }

            #[test]
            fn $test_comprehensive() {
                for i in 0..(1 << 20) {
                    let x = I48F16::from_bits(i);
                    let fx: f64 = x.to_num();
                    assert_approx_eq(x, $trigf(x).to_num(), fx.$trigf(), $max_err);

                    let x = -I48F16::from_bits(i);
                    let fx: f64 = x.to_num();
                    assert_approx_eq(x, $trigf(x).to_num(), fx.$trigf(), $max_err);
                }
            }
        };
    }

    test_trig!(test_sin, test_sin_comprehensive, sin, 0.001);
    test_trig!(test_cos, test_cos_comprehensive, cos, 0.001);
    test_trig!(test_atan, test_atan_comprehensive, atan, 0.001);

    #[test]
    fn test_asin() {
        // asin is only defined on [-1, 1]; `1 << 16` is exactly 1.0 in I48F16.
        for i in 0..=(1 << 16) {
            let x = I48F16::from_bits(i);
            let fx: f64 = x.to_num();
            assert_approx_eq(x, asin(x).to_num(), fx.asin(), 0.01);

            let x = -I48F16::from_bits(i);
            let fx: f64 = x.to_num();
            assert_approx_eq(x, asin(x).to_num(), fx.asin(), 0.01);
        }
    }

    #[test]
    fn test_acos() {
        // acos is only defined on [-1, 1]; `1 << 16` is exactly 1.0 in I48F16.
        for i in 0..=(1 << 16) {
            let x = I48F16::from_bits(i);
            let fx: f64 = x.to_num();
            assert_approx_eq(x, acos(x).to_num(), fx.acos(), 0.01);

            let x = -I48F16::from_bits(i);
            let fx: f64 = x.to_num();
            assert_approx_eq(x, acos(x).to_num(), fx.acos(), 0.01);
        }
    }

    #[test]
    fn test_sqrt() {
        for i in 0..(1 << 20) {
            let x = I48F16::from_bits(i);
            let fx: f64 = x.to_num();
            assert_approx_eq(x, sqrt(x, 40).to_num(), fx.sqrt(), 0.01);
        }
    }
}