retrofire_core/math/
spline.rs

1//! Bézier curves and splines.
2
3use alloc::vec::Vec;
4use core::{array, fmt::Debug};
5
6use crate::geom::{Polyline, Ray};
7use crate::math::{Affine, Lerp, Linear, Parametric};
8
/// A cubic Bézier curve, defined by four control points.
///
/// The curve starts at the first control point `p0` and ends at the last
/// one, `p3`. The two inner control points `p1` and `p2` generally do not
/// lie on the curve; instead, they pull the curve towards themselves. At
/// its start the curve is tangent to the line from `p0` to `p1`, and at its
/// end to the line from `p2` to `p3`.
///
/// ```text
///
///  p1
///   \         ____
///    \     _-´    `--_                      p3
///     \   /           `-_                    \
///      \ |               `-_                 |\
///       \|                  `-__            /  \
///        \                      `---_____--´    \
///        p0                                      \
///                                                 p2
/// ```
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CubicBezier<T>(pub [T; 4]);
27
28/// Interpolates smoothly from 0.0 to 1.0 as `t` goes from 0.0 to 1.0.
29///
30/// Returns 0 for all `t` <= 0 and 1 for all `t` >= 1. Has a continuous
31/// first derivative.
32pub fn smoothstep(t: f32) -> f32 {
33    step(t, &0.0, &1.0, |t| t * t * (3.0 - 2.0 * t))
34}
35
36/// Even smoother version of [`smoothstep`].
37///
38/// Has continuous first and second derivatives.
39pub fn smootherstep(t: f32) -> f32 {
40    step(t, &0.0, &1.0, |t| t * t * t * (10.0 + t * (6.0 * t - 15.0)))
41}
42
/// Helper for building clamped step functions.
///
/// Yields a clone of `min` when `t` ≤ 0, a clone of `max` when `t` ≥ 1,
/// and `f(t)` otherwise, i.e. when 0 < t < 1. A NaN `t` fails both
/// comparisons, so it is passed on to `f` as well.
#[inline]
pub fn step<T: Clone, F>(t: f32, min: &T, max: &T, f: F) -> T
where
    F: FnOnce(f32) -> T,
{
    match t {
        below if below <= 0.0 => min.clone(),
        above if above >= 1.0 => max.clone(),
        inside => f(inside),
    }
}
59
60impl<T> CubicBezier<T>
61where
62    T: Affine<Diff: Linear<Scalar = f32>> + Clone,
63{
64    /// Evaluates the value of `self` at `t`.
65    ///
66    /// For t < 0, returns the first control point. For t > 1, returns the last
67    /// control point. Uses [De Casteljau's algorithm][1].
68    ///
69    /// [1]: https://en.wikipedia.org/wiki/De_Casteljau%27s_algorithm
70    pub fn eval(&self, t: f32) -> T {
71        let [p0, p1, p2, p3] = &self.0;
72        step(t, p0, p3, |t| {
73            let p01 = p0.lerp(p1, t);
74            let p12 = p1.lerp(p2, t);
75            let p23 = p2.lerp(p3, t);
76            p01.lerp(&p12, t).lerp(&p12.lerp(&p23, t), t)
77        })
78    }
79
80    /// Evaluates the value of `self` at `t`.
81    ///
82    /// For t < 0, returns the first control point. For t > 1, returns the last
83    /// control point.
84    ///
85    /// Directly evaluates the cubic. Faster but possibly less numerically
86    /// stable than [`Self::eval`].
87    pub fn fast_eval(&self, t: f32) -> T {
88        let [p0, .., p3] = &self.0;
89        step(t, p0, p3, |t| {
90            // Add a linear combination of the three coefficients
91            // to `p0` to get the result
92            let [co3, co2, co1] = self.coefficients();
93            p0.add(&co3.mul(t).add(&co2).mul(t).add(&co1).mul(t))
94        })
95    }
96
97    /// Returns the tangent, or direction vector, of `self` at `t`.
98    ///
99    /// Clamps `t` to the range [0, 1].
100    pub fn tangent(&self, t: f32) -> T::Diff {
101        let [p0, p1, p2, p3] = &self.0;
102        let t = t.clamp(0.0, 1.0);
103
104        //   3 (3 (p1 - p2) + (p3 - p0)) * t^2
105        // + 6 ((p0 - p1 + p2 - p1) * t
106        // + 3 (p1 - p0)
107
108        let co2: T::Diff = p1.sub(p2).mul(3.0).add(&p3.sub(p0));
109        let co1: T::Diff = p0.sub(p1).add(&p2.sub(p1)).mul(2.0);
110        let co0: T::Diff = p1.sub(p0);
111
112        co2.mul(t).add(&co1).mul(t).add(&co0).mul(3.0)
113    }
114
115    /// Returns the coefficients used to evaluate the spline.
116    ///
117    /// These are constant as long as the control points do not change,
118    /// so they can be precomputed when the spline is evaluated several times,
119    /// for example by an iterator.
120    ///
121    /// The coefficient values are, from the first to the last:
122    /// ```text
123    /// co3 = (p3 - p0) + 3 * (p1 - p2)
124    /// co2 = 3 * (p0 - p1) + 3 * (p2 - p1)
125    /// co1 = 3 * (p1 - p0)
126    /// ```
127    /// The value of the spline at *t* is then computed as:
128    /// ```text
129    /// co3 * t^3 + co2 * t^2 + co1 * t + p0
130    ///
131    /// = (((co3 * t) + co2 * t) + co1 * t) + p0.
132    /// ```
133    fn coefficients(&self) -> [T::Diff; 3] {
134        let [p0, p1, p2, p3] = &self.0;
135
136        // Rewrite the parametric equation into a form where three of the
137        // coefficients are vectors, their linear combination added to `p0`
138        // so the equation can be expressed for affine types:
139        //
140        //   (p3 - p0) * t^3 + (p1 - p2) * 3t^3
141        // + (p0 + p2) * 3t^2 - p1 * 6t^2
142        // + (p1 - p0) * 3t
143        // + p0
144        // = ((p3 - p0 + 3(p1 - p2)) * t^3
145        // + 3(p0 - p1 + p2 - p1) * t^2
146        // + 3(p1 - p0) * t
147        // + p0
148        // = ((((p3 - p0 + 3(p1 - p2))) * t
149        //      + 3(p0 - p1 + p2 - p1)) * t)
150        //          + 3(p1 - p0)) * t)
151        //              + p0
152        let p3_p0 = p3.sub(p0);
153        let p1_p0_3 = p1.sub(p0).mul(3.0);
154        let p1_p2_3 = p1.sub(p2).mul(3.0);
155        [p3_p0.add(&p1_p2_3), p1_p0_3.add(&p1_p2_3).neg(), p1_p0_3]
156    }
157}
158
/// A curve composed of one or more concatenated
/// [cubic Bézier curves][CubicBezier].
///
/// Stores the control points of all the curves in order. Each curve shares
/// its last control point with the first control point of the next, so the
/// number of points is 3n + 1, where n ≥ 1 is the number of curves (see
/// [`BezierSpline::new`]).
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct BezierSpline<T>(Vec<T>);
163
164impl<T> BezierSpline<T>
165where
166    T: Affine<Diff: Linear<Scalar = f32> + Clone> + Clone,
167{
168    /// Creates a Bézier spline from the given control points. The number of
169    /// elements in `pts` must be 3n + 1 for some positive integer n.
170    ///
171    /// Consecutive points in `pts` make up Bézier curves such that:
172    /// * `pts[0..=3]` define the first curve,
173    /// * `pts[3..=6]` define the second curve,
174    ///
175    /// and so on.
176    ///
177    /// # Panics
178    /// If `pts.len() < 4` or if `pts.len() % 3 != 1`.
179    pub fn new(pts: &[T]) -> Self {
180        assert!(
181            pts.len() >= 4 && pts.len() % 3 == 1,
182            "length must be 3n+1 for some integer n > 0, was {}",
183            pts.len()
184        );
185        Self(pts.to_vec())
186    }
187
188    /// Constructs a Bézier spline
189    pub fn from_rays<I>(rays: I) -> Self
190    where
191        I: IntoIterator<Item = Ray<T>>,
192    {
193        let mut rays = rays.into_iter().peekable();
194        let mut first = true;
195        let mut pts = Vec::with_capacity(2 * rays.size_hint().0);
196        while let Some(ray) = rays.next() {
197            if !first {
198                pts.push(ray.eval(-1.0));
199            }
200            first = false;
201            pts.push(ray.0.clone());
202            if rays.peek().is_some() {
203                pts.push(ray.eval(1.0));
204            }
205        }
206        Self::new(&pts)
207    }
208
209    /// Evaluates `self` at position `t`.
210    ///
211    /// Returns the first point if `t < 0` and the last point if `t > 1`.
212    pub fn eval(&self, t: f32) -> T {
213        // invariant self.0.len() != 0 -> last always exists
214        step(t, &self.0[0], self.0.last().unwrap(), |t| {
215            let (t, seg) = self.segment(t);
216            CubicBezier(seg).fast_eval(t)
217        })
218    }
219
220    /// Returns the tangent of `self` at `t`.
221    ///
222    /// Clamps `t` to the range [0, 1].
223    pub fn tangent(&self, t: f32) -> T::Diff {
224        let (t, seg) = self.segment(t);
225        CubicBezier(seg).tangent(t)
226    }
227
228    fn segment(&self, t: f32) -> (f32, [T; 4]) {
229        let segs = ((self.0.len() - 1) / 3) as f32;
230        // TODO use floor and make the code cleaner
231        let seg = ((t * segs) as u32 as f32).min(segs - 1.0);
232        let t2 = t * segs - seg;
233        let idx = 3 * (seg as usize);
234        (t2, array::from_fn(|k| self.0[idx + k].clone()))
235    }
236
237    /// Approximates `self` as a sequence of line segments.
238    ///
239    /// Recursively subdivides the curve into two half-curves, stopping once
240    /// the approximation error is small enough, as determined by the `halt`
241    /// function.
242    ///
243    /// Given a curve segment between some points `p` and `r`, the parameter
244    /// passed to `halt` is the distance to the real midpoint `q` from its
245    /// linear approximation `q'`. If `halt` returns `true`, the line segment
246    /// `pr` is returned as the approximation of this curve segment, otherwise
247    /// the bisection continues.
248    ///
249    /// Note that this heuristic does not work well in certain edge cases
250    /// (consider, for example, an S-shaped curve where `q'` is very close
251    /// to `q`, yet a straight line would be a poor approximation). However,
252    /// in practice it tends to give reasonable results.
253    ///
254    /// ```text
255    ///                 ___--- q ---___
256    ///             _--´       |       `--_
257    ///         _--´           |           `--_
258    ///      _-p ------------- q' ------------ r-_
259    ///   _-´                                     `-_
260    /// ```
261    ///
262    /// # Examples
263    /// ```
264    /// use retrofire_core::math::{BezierSpline, vec2, Vec2};
265    ///
266    /// let curve = BezierSpline::<Vec2>::new(
267    ///     &[vec2(0.0, 0.0), vec2(0.0, 1.0), vec2(1.0, 1.0), vec2(1.0, 0.0)]
268    /// );
269    /// let approx = curve.approximate(|err| err.len_sqr() < 0.01*0.01);
270    /// assert_eq!(approx.0.len(), 17);
271    /// ```
272    pub fn approximate(&self, halt: impl Fn(&T::Diff) -> bool) -> Polyline<T> {
273        let len = self.0.len();
274        let mut res = Vec::with_capacity(3 * len);
275        self.do_approx(0.0, 1.0, 10 + len.ilog2(), &halt, &mut res);
276        res.push(self.0[len - 1].clone());
277        Polyline(res)
278    }
279
280    fn do_approx(
281        &self,
282        a: f32,
283        b: f32,
284        max_dep: u32,
285        halt: &impl Fn(&T::Diff) -> bool,
286        accum: &mut Vec<T>,
287    ) {
288        let mid = a.lerp(&b, 0.5);
289
290        let ap = self.eval(a);
291        let bp = self.eval(b);
292
293        let real = self.eval(mid);
294        let approx = ap.lerp(&bp, 0.5);
295
296        if max_dep == 0 || halt(&real.sub(&approx)) {
297            accum.push(ap);
298        } else {
299            self.do_approx(a, mid, max_dep - 1, halt, accum);
300            self.do_approx(mid, b, max_dep - 1, halt, accum);
301        }
302    }
303}
304
305impl<T> Parametric<T> for CubicBezier<T>
306where
307    T: Affine<Diff: Linear<Scalar = f32> + Clone> + Clone,
308{
309    fn eval(&self, t: f32) -> T {
310        self.fast_eval(t)
311    }
312}
313
314impl<T> Parametric<T> for BezierSpline<T>
315where
316    T: Affine<Diff: Linear<Scalar = f32> + Clone> + Clone,
317{
318    fn eval(&self, t: f32) -> T {
319        self.eval(t)
320    }
321}
322
#[cfg(test)]
mod tests {
    // Unit tests for the step functions, cubic Bézier curves, and splines.
    use alloc::vec;

    use crate::assert_approx_eq;
    use crate::math::{Parametric, Point2, Vec2, pt2, vec2};

    use super::*;

    #[test]
    fn smoothstep_test() {
        assert_eq!(0.0, smoothstep(-10.0));
        assert_eq!(0.0, smoothstep(0.0));

        assert_eq!(0.15625, smoothstep(0.25));
        assert_eq!(0.50000, smoothstep(0.5));
        assert_eq!(0.84375, smoothstep(0.75));

        assert_eq!(1.0, smoothstep(1.0));
        assert_eq!(1.0, smoothstep(10.0));

        assert_eq!(0.15625, smoothstep.eval(0.25));
    }

    #[test]
    fn smootherstep_test() {
        assert_eq!(0.0, smootherstep(-10.0));
        assert_eq!(0.0, smootherstep(0.0));

        assert_eq!(0.103515625, smootherstep(0.25));
        assert_eq!(0.5, smootherstep(0.5));
        assert_eq!(0.8964844, smootherstep(0.75));

        assert_eq!(1.0, smootherstep(1.0));
        assert_eq!(1.0, smootherstep(10.0));

        assert_eq!(0.103515625, smootherstep.eval(0.25));
    }

    #[test]
    fn bezier_spline_eval_eq_fast_eval() {
        let b: CubicBezier<Vec2> = CubicBezier(
            [[0.0, 0.0], [0.0, 2.0], [1.0, -1.0], [1.0, 1.0]].map(Vec2::from),
        );
        for i in 0..11 {
            let t = i as f32 / 10.0;
            let (v, u) = (b.eval(t), b.fast_eval(t));
            assert_approx_eq!(v.x(), u.x(), eps = 1e-5);
            assert_approx_eq!(v.y(), u.y(), eps = 1e-5);
        }
    }

    #[test]
    fn bezier_spline_eval_1d() {
        let b = CubicBezier([0.0, 2.0, -1.0, 1.0]);

        assert_eq!(b.eval(-1.0), 0.0);
        assert_eq!(b.eval(0.00), 0.0);
        assert_eq!(b.eval(0.25), 0.71875);
        assert_eq!(b.eval(0.50), 0.5);
        assert_eq!(b.eval(0.75), 0.28125);
        assert_eq!(b.eval(1.00), 1.0);
        assert_eq!(b.eval(2.00), 1.0);
    }

    #[test]
    fn bezier_spline_eval_2d_vec() {
        let b = CubicBezier(
            [[0.0, 0.0], [0.0, 2.0], [1.0, -1.0], [1.0, 1.0]]
                .map(Vec2::<()>::from),
        );

        assert_eq!(b.eval(-1.0), vec2(0.0, 0.0));
        assert_eq!(b.eval(0.00), vec2(0.0, 0.0));
        assert_eq!(b.eval(0.25), vec2(0.15625, 0.71875));
        assert_eq!(b.eval(0.50), vec2(0.5, 0.5));
        assert_eq!(b.eval(0.75), vec2(0.84375, 0.281250));
        assert_eq!(b.eval(1.00), vec2(1.0, 1.0));
        assert_eq!(b.eval(2.00), vec2(1.0, 1.0));
    }

    #[test]
    fn bezier_spline_eval_2d_point() {
        let b = CubicBezier(
            [[0.0, 0.0], [0.0, 2.0], [1.0, -1.0], [1.0, 1.0]]
                .map(Point2::<()>::from),
        );

        assert_eq!(b.eval(-1.0), pt2(0.0, 0.0));
        assert_eq!(b.eval(0.00), pt2(0.0, 0.0));
        assert_eq!(b.eval(0.25), pt2(0.15625, 0.71875));
        assert_eq!(b.eval(0.50), pt2(0.5, 0.5));
        assert_eq!(b.eval(0.75), pt2(0.84375, 0.281250));
        assert_eq!(b.eval(1.00), pt2(1.0, 1.0));
        assert_eq!(b.eval(2.00), pt2(1.0, 1.0));
    }

    #[test]
    fn bezier_spline_tangent_1d() {
        let b = CubicBezier([0.0, 2.0, -1.0, 1.0]);

        assert_eq!(b.tangent(-1.0), 6.0);
        assert_eq!(b.tangent(0.00), 6.0);
        assert_eq!(b.tangent(0.25), 0.375);
        assert_eq!(b.tangent(0.50), -1.5);
        assert_eq!(b.tangent(0.75), 0.375);
        assert_eq!(b.tangent(1.00), 6.0);
        assert_eq!(b.tangent(2.00), 6.0);
    }

    #[test]
    fn bezier_spline_tangent_2d() {
        let b = CubicBezier(
            [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]
                .map(Point2::<()>::from),
        );

        assert_eq!(b.tangent(-1.0), vec2(0.0, 3.0));
        assert_eq!(b.tangent(0.0), vec2(0.0, 3.0));
        assert_eq!(b.tangent(0.25), vec2(1.125, 0.75));
        assert_eq!(b.tangent(0.5), vec2(1.5, 0.0));
        assert_eq!(b.tangent(0.75), vec2(1.125, 0.75));
        assert_eq!(b.tangent(1.0), vec2(0.0, 3.0));
        assert_eq!(b.tangent(2.0), vec2(0.0, 3.0));
    }

    #[test]
    fn bezier_spline_eval() {
        let c = BezierSpline(vec![0.0, 0.8, 0.9, 1.0, 0.6, 0.5, 0.5]);
        assert_eq!(c.eval(-1.0), 0.0);
        assert_eq!(c.eval(0.0), 0.0);
        assert_approx_eq!(c.eval(0.25), 0.7625);
        assert_eq!(c.eval(0.5), 1.0);
        assert_eq!(c.eval(0.75), 0.6);
        assert_eq!(c.eval(1.0), 0.5);
        assert_eq!(c.eval(2.0), 0.5);
    }

    #[test]
    fn bezier_spline_tangent() {
        let c = BezierSpline(vec![0.0, 0.8, 0.9, 1.0, 0.6, 0.5, 0.5]);
        // Out-of-range t clamps to the start of the first segment
        assert_approx_eq!(c.tangent(-1.0), 2.4, eps = 1e-6);
        assert_approx_eq!(c.tangent(0.0), 2.4, eps = 1e-6);
        // Midpoint of the first segment
        assert_approx_eq!(c.tangent(0.25), 0.825, eps = 1e-6);
        // Start and midpoint of the second segment
        assert_approx_eq!(c.tangent(0.5), -1.2, eps = 1e-6);
        assert_approx_eq!(c.tangent(0.75), -0.45, eps = 1e-6);
    }

    #[test]
    fn bezier_spline_new_valid_point_count() {
        let c = BezierSpline::<f32>::new(&[0.0, 0.8, 0.9, 1.0, 0.6, 0.5, 0.5]);
        assert_eq!(c, BezierSpline(vec![0.0, 0.8, 0.9, 1.0, 0.6, 0.5, 0.5]));
    }

    #[test]
    #[should_panic]
    fn bezier_spline_new_too_few_points() {
        let _ = BezierSpline::<f32>::new(&[0.0, 1.0, 2.0]);
    }

    #[test]
    #[should_panic]
    fn bezier_spline_new_wrong_point_count() {
        let _ = BezierSpline::<f32>::new(&[0.0; 5]);
    }
}