1use arrayvec::ArrayVec;
5
/// A polynomial in one variable with `N` `f64` coefficients.
///
/// Coefficients are stored in order of increasing power: `coeffs[0]` is the
/// constant term and `coeffs[N - 1]` multiplies `x^(N - 1)`, so a `Poly<N>`
/// has degree at most `N - 1`.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub struct Poly<const N: usize> {
    /// Coefficients, lowest power first.
    pub(crate) coeffs: [f64; N],
}
29
/// A polynomial with three coefficients, i.e. of degree at most 2.
pub type Quadratic = Poly<3>;
32
/// A polynomial with four coefficients, i.e. of degree at most 3.
pub type Cubic = Poly<4>;
35
/// A polynomial with five coefficients, i.e. of degree at most 4.
pub type Quartic = Poly<5>;
38
/// A polynomial with six coefficients, i.e. of degree at most 5.
pub type Quintic = Poly<6>;
41
42impl<const N: usize> Poly<N> {
43 pub const fn new(coeffs: [f64; N]) -> Poly<N> {
49 Poly { coeffs }
50 }
51
52 pub fn coeffs(&self) -> &[f64; N] {
56 &self.coeffs
57 }
58
59 pub fn eval(&self, x: f64) -> f64 {
61 let mut acc = 0.0;
62 for c in self.coeffs.iter().rev() {
63 acc = acc * x + c;
66 }
67 acc
68 }
69
70 pub fn max_abs_coefficient(&self) -> f64 {
74 let mut max = 0.0f64;
75 for c in &self.coeffs {
76 max = max.max(c.abs());
77 }
78 max
79 }
80
81 pub fn is_finite(&self) -> bool {
83 self.coeffs.iter().all(|c| c.is_finite())
84 }
85}
86
/// Implements `deriv` and `deflate` for `Poly<$N>`.
///
/// Both methods return a `Poly<$N_MINUS_ONE>`, so the second argument is
/// expected to be one less than the first.
macro_rules! impl_deriv_and_deflate {
    ($N:literal, $N_MINUS_ONE:literal) => {
        impl Poly<$N> {
            /// Returns the derivative of this polynomial.
            ///
            /// The coefficient of `x^k` (for `k >= 1`) becomes `k` times that
            /// value as the coefficient of `x^(k - 1)`; the constant term is
            /// dropped.
            pub fn deriv(&self) -> Poly<$N_MINUS_ONE> {
                let mut derived = [0.0; $N_MINUS_ONE];
                // Everything except the constant term, paired with its power.
                let higher_terms = &self.coeffs[1..];
                (1usize..)
                    .zip(derived.iter_mut().zip(higher_terms))
                    .for_each(|(power, (slot, &coeff))| {
                        *slot = power as f64 * coeff;
                    });
                Poly::new(derived)
            }

            /// Divides this polynomial by `x - root` and returns the quotient.
            ///
            /// This is synthetic division run from the highest power down.
            /// The constant term is never read, so the remainder of the
            /// division is discarded.
            pub fn deflate(&self, root: f64) -> Poly<$N_MINUS_ONE> {
                let mut quotient = [0.0; $N_MINUS_ONE];
                let mut carry = 0.0_f64;
                quotient
                    .iter_mut()
                    .zip(&self.coeffs[1..])
                    .rev()
                    .for_each(|(slot, &coeff)| {
                        carry = carry * root + coeff;
                        *slot = carry;
                    });
                Poly::new(quotient)
            }
        }
    };
}
120
/// Implements `roots_between` for `Poly<$N>` by recursing on the derivative.
///
/// The expansion calls `roots_between_with_buffer` on `Poly<$N_MINUS_ONE>`,
/// so that type must already provide a method with a compatible signature.
macro_rules! impl_roots_between_recursive {
    ($N:literal, $N_MINUS_ONE:literal) => {
        impl Poly<$N> {
            /// Searches for roots of this polynomial between `lower` and `upper`.
            ///
            /// `x_error` is forwarded unchanged to the underlying root
            /// refinement routine. At most `$N_MINUS_ONE` roots are returned.
            pub fn roots_between(
                self,
                lower: f64,
                upper: f64,
                x_error: f64,
            ) -> ArrayVec<f64, $N_MINUS_ONE> {
                let mut found = ArrayVec::new();
                let mut workspace = ArrayVec::new();
                self.roots_between_with_buffer(lower, upper, x_error, &mut found, &mut workspace);
                found
            }

            /// Buffer-reusing worker behind `roots_between`.
            ///
            /// Roots are written to `out`; `scratch` is working storage whose
            /// contents on return are unspecified. If the derivative has a
            /// non-finite coefficient the method returns immediately and
            /// leaves both buffers as they were.
            fn roots_between_with_buffer<const M: usize>(
                self,
                lower: f64,
                upper: f64,
                x_error: f64,
                out: &mut ArrayVec<f64, M>,
                scratch: &mut ArrayVec<f64, M>,
            ) {
                let slope = self.deriv();
                if !slope.is_finite() {
                    return;
                }

                // The two buffers swap roles for the recursive call, so the
                // derivative's roots (the critical points of `self`) end up
                // in `scratch`.
                slope.roots_between_with_buffer(lower, upper, x_error, scratch, out);
                // Close the final interval with the upper endpoint.
                scratch.push(upper);
                // `out` served as the callee's scratch space; discard that.
                out.clear();

                // Walk the consecutive intervals
                // `lower -> scratch[0] -> scratch[1] -> ...` and look for a
                // root wherever the polynomial's value changes sign.
                // NOTE(review): this relies on the recursive call producing
                // its roots in increasing order - confirm against the base case.
                let mut left = lower;
                let mut left_val = self.eval(left);

                for &right in scratch.iter() {
                    let right_val = self.eval(right);
                    if $crate::different_signs(left_val, right_val) {
                        let root = $crate::yuksel::find_root(
                            |t| self.eval(t),
                            |t| slope.eval(t),
                            left,
                            right,
                            left_val,
                            right_val,
                            x_error,
                        );
                        out.push(root);
                    }

                    left = right;
                    left_val = right_val;
                }
            }
        }
    };
}
189
// Generate `deriv` and `deflate` for `Poly<3>` through `Poly<10>`. The second
// argument is the coefficient count of the result, always one less than the
// first.
impl_deriv_and_deflate!(3, 2);
impl_deriv_and_deflate!(4, 3);
impl_deriv_and_deflate!(5, 4);
impl_deriv_and_deflate!(6, 5);
impl_deriv_and_deflate!(7, 6);
impl_deriv_and_deflate!(8, 7);
impl_deriv_and_deflate!(9, 8);
impl_deriv_and_deflate!(10, 9);
198
// Generate the derivative-recursive `roots_between` for `Poly<5>` through
// `Poly<10>`. Each expansion calls `roots_between_with_buffer` on the
// derivative, so `Poly<4>` must get that method from somewhere other than
// this macro (not visible in this part of the file).
impl_roots_between_recursive!(5, 4);
impl_roots_between_recursive!(6, 5);
impl_roots_between_recursive!(7, 6);
impl_roots_between_recursive!(8, 7);
impl_roots_between_recursive!(9, 8);
impl_roots_between_recursive!(10, 9);
205
206impl<const N: usize> core::ops::Mul<f64> for Poly<N> {
207 type Output = Poly<N>;
208
209 fn mul(mut self, scale: f64) -> Poly<N> {
210 self *= scale;
211 self
212 }
213}
214
215impl<const N: usize> core::ops::MulAssign<f64> for Poly<N> {
216 fn mul_assign(&mut self, scale: f64) {
217 for c in &mut self.coeffs {
218 *c *= scale;
219 }
220 }
221}
222
223impl<const N: usize> core::ops::Mul<f64> for &Poly<N> {
224 type Output = Poly<N>;
225
226 fn mul(self, scale: f64) -> Poly<N> {
227 (*self) * scale
228 }
229}
230
231impl<const N: usize> core::ops::Div<f64> for Poly<N> {
232 type Output = Poly<N>;
233
234 fn div(mut self, scale: f64) -> Poly<N> {
235 self /= scale;
236 self
237 }
238}
239
240impl<const N: usize> core::ops::DivAssign<f64> for Poly<N> {
241 fn div_assign(&mut self, scale: f64) {
242 for c in &mut self.coeffs {
243 *c /= scale;
244 }
245 }
246}
247
248impl<const N: usize> core::ops::Div<f64> for &Poly<N> {
249 type Output = Poly<N>;
250
251 fn div(self, scale: f64) -> Poly<N> {
252 (*self) / scale
253 }
254}
255
256impl<const N: usize> core::ops::AddAssign<&Poly<N>> for Poly<N> {
257 fn add_assign(&mut self, rhs: &Poly<N>) {
258 for (c, d) in self.coeffs.iter_mut().zip(rhs.coeffs) {
259 *c += d;
260 }
261 }
262}
263
264impl<const N: usize> core::ops::AddAssign<Poly<N>> for Poly<N> {
265 fn add_assign(&mut self, rhs: Poly<N>) {
266 *self += &rhs;
267 }
268}
269
270impl<const N: usize> core::ops::Add<Poly<N>> for Poly<N> {
271 type Output = Poly<N>;
272
273 fn add(mut self, rhs: Poly<N>) -> Poly<N> {
274 self += rhs;
275 self
276 }
277}
278
279impl<const N: usize> core::ops::Add<&Poly<N>> for Poly<N> {
280 type Output = Poly<N>;
281
282 fn add(mut self, rhs: &Poly<N>) -> Poly<N> {
283 self += rhs;
284 self
285 }
286}
287
288impl<const N: usize> core::ops::Add<Poly<N>> for &Poly<N> {
289 type Output = Poly<N>;
290
291 fn add(self, mut rhs: Poly<N>) -> Poly<N> {
292 rhs += self;
293 rhs
294 }
295}
296
297impl<const N: usize> core::ops::SubAssign<&Poly<N>> for Poly<N> {
298 fn sub_assign(&mut self, rhs: &Poly<N>) {
299 for (c, d) in self.coeffs.iter_mut().zip(rhs.coeffs) {
300 *c -= d;
301 }
302 }
303}
304
305impl<const N: usize> core::ops::SubAssign<Poly<N>> for Poly<N> {
306 fn sub_assign(&mut self, rhs: Poly<N>) {
307 *self -= &rhs;
308 }
309}
310
311impl<const N: usize> core::ops::Sub<Poly<N>> for Poly<N> {
312 type Output = Poly<N>;
313
314 fn sub(mut self, rhs: Poly<N>) -> Poly<N> {
315 self -= rhs;
316 self
317 }
318}
319
320impl<const N: usize> core::ops::Sub<&Poly<N>> for Poly<N> {
321 type Output = Poly<N>;
322
323 fn sub(mut self, rhs: &Poly<N>) -> Poly<N> {
324 self -= rhs;
325 self
326 }
327}
328
329impl<const N: usize> core::ops::Sub<Poly<N>> for &Poly<N> {
330 type Output = Poly<N>;
331
332 fn sub(self, mut rhs: Poly<N>) -> Poly<N> {
333 rhs -= self;
334 rhs
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Known-answer check: polynomials with roots at consecutive integers.
    #[test]
    fn smoke() {
        // (x - 1)(x - 2)(x - 3)
        let cubic = Poly::new([-6.0, 11.0, -6.0, 1.0]);
        let found = cubic.roots_between(0.0, 5.0, 1e-6);
        assert_eq!(found.len(), 3);
        for (root, expected) in found.iter().zip([1.0, 2.0, 3.0]) {
            assert!((root - expected).abs() <= 1e-6);
        }

        // (x - 1)(x - 2)(x - 3)(x - 4)
        let quartic = Poly::new([24.0, -50.0, 35.0, -10.0, 1.0]);
        let found = quartic.roots_between(0.0, 5.0, 1e-6);
        assert_eq!(found.len(), 4);
        for (root, expected) in found.iter().zip([1.0, 2.0, 3.0, 4.0]) {
            assert!((root - expected).abs() <= 1e-6);
        }
    }

    /// Asserts that `p` evaluates to approximately zero at each of `roots`.
    ///
    /// The tolerance scales with the largest coefficient magnitude (at least
    /// 1) and with `|root|^(N - 1)` (also at least 1).
    fn check_root_values<const N: usize>(p: &Poly<N>, roots: &[f64]) {
        let base_tolerance = p.max_abs_coefficient().max(1.0) * 1e-12;
        let top_power = N as i32 - 1;

        for r in roots {
            let accuracy = base_tolerance * r.abs().powi(top_power).max(1.0);
            let y = p.eval(*r);
            assert!(
                y.abs() <= accuracy,
                "poly {p:?} had root {r} evaluate to {y:?}, but expected {accuracy:?}"
            );
        }
    }

    /// Every root reported for an arbitrary `Poly<4>` must evaluate near zero.
    #[test]
    fn root_evaluation_deg3() {
        arbtest::arbtest(|u| {
            let poly: Poly<4> = crate::arbitrary::poly(u)?;
            // Reject inputs whose coefficients are large enough to overflow
            // this bound.
            let bound = poly.max_abs_coefficient() * 10.0f64.powi(4);
            if bound.is_infinite() {
                return Err(arbitrary::Error::IncorrectFormat);
            }

            check_root_values(&poly, &poly.roots_between(-10.0, 10.0, 1e-13));
            Ok(())
        })
        .budget_ms(5_000);
    }

    /// Every root reported for an arbitrary `Poly<5>` must evaluate near zero.
    #[test]
    fn root_evaluation_deg4() {
        arbtest::arbtest(|u| {
            let poly: Poly<5> = crate::arbitrary::poly(u)?;
            let bound = poly.max_abs_coefficient() * 10.0f64.powi(5);
            if bound.is_infinite() {
                return Err(arbitrary::Error::IncorrectFormat);
            }

            check_root_values(&poly, &poly.roots_between(-10.0, 10.0, 1e-13));
            Ok(())
        })
        .budget_ms(5_000);
    }

    /// Every root reported for an arbitrary `Poly<10>` must evaluate near zero.
    #[test]
    fn root_evaluation_deg9() {
        arbtest::arbtest(|u| {
            let poly: Poly<10> = crate::arbitrary::poly(u)?;
            let bound = poly.max_abs_coefficient() * 10.0f64.powi(11);
            if bound.is_infinite() {
                return Err(arbitrary::Error::IncorrectFormat);
            }

            check_root_values(&poly, &poly.roots_between(-10.0, 10.0, 1e-13));
            Ok(())
        })
        .budget_ms(5_000);
    }

    /// A `Poly<6>` built around a planted root must report roots that are
    /// finite, sorted, inside the search interval, and that include the
    /// planted root to within a coefficient-scaled tolerance.
    #[test]
    fn planted_root_deg5() {
        arbtest::arbtest(|u| {
            let planted_root = crate::arbitrary::float_in_unit_interval(u)?;
            let poly: Poly<6> = crate::arbitrary::poly_with_planted_root(u, planted_root, 1e-6)?;

            let largest = poly.max_abs_coefficient();
            if (largest * 1024.0).is_infinite() {
                return Err(arbitrary::Error::IncorrectFormat);
            }

            let found = poly.roots_between(-2.0, 2.0, 1e-13);
            let search_interval = -2.0..=2.0;

            assert!(found.iter().copied().all(f64::is_finite));

            assert!(found.is_sorted());
            assert!(found.iter().all(|r| search_interval.contains(r)));

            let error = largest.max(1.0) * 1e-12;
            assert!(found.iter().any(|r| (r - planted_root).abs() <= error));
            Ok(())
        })
        .budget_ms(5_000);
    }
}