baiser/
bezier.rs

1use crate::{Curve, CurvePoint, Distance};
2use num_traits::Float;
3use std::fmt::Debug;
4use std::marker::PhantomData;
5
6/// Single point
7#[derive(Clone, PartialEq)]
8pub struct Bezier0<F: Float, P: CurvePoint<F>> {
9    pub point: P,
10    phantom_data: PhantomData<F>,
11}
12
13impl<F: Float, P: CurvePoint<F>> Bezier0<F, P> {
14    pub fn new(point: P) -> Self {
15        Self {
16            point,
17            phantom_data: Default::default(),
18        }
19    }
20}
21
22/// Line
23#[derive(Clone, PartialEq)]
24pub struct Bezier1<F: Float, P: CurvePoint<F>> {
25    pub p0: P,
26    pub p1: P,
27    phantom_data: PhantomData<F>,
28}
29
30impl<F: Float, P: CurvePoint<F>> Bezier1<F, P> {
31    pub fn new(p0: P, p1: P) -> Self {
32        Self {
33            p0,
34            p1,
35            phantom_data: Default::default(),
36        }
37    }
38}
39
40/// Quadratic bezier curve
41#[derive(Clone, PartialEq)]
42pub struct Bezier2<F: Float, P: CurvePoint<F>> {
43    pub p0: P,
44    pub p1: P,
45    pub p2: P,
46    phantom_data: PhantomData<F>,
47}
48
49impl<F: Float, P: CurvePoint<F>> Bezier2<F, P> {
50    pub fn new(p0: P, p1: P, p2: P) -> Self {
51        Self {
52            p0,
53            p1,
54            p2,
55            phantom_data: Default::default(),
56        }
57    }
58}
59
60/// Cubic bezier curve
61#[derive(Clone, PartialEq)]
62pub struct Bezier3<F: Float, P: CurvePoint<F>> {
63    pub p0: P,
64    pub p1: P,
65    pub p2: P,
66    pub p3: P,
67    phantom_data: PhantomData<F>,
68}
69
70impl<F: Float, P: CurvePoint<F>> Bezier3<F, P> {
71    pub fn new(p0: P, p1: P, p2: P, p3: P) -> Self {
72        Self {
73            p0,
74            p1,
75            p2,
76            p3,
77            phantom_data: Default::default(),
78        }
79    }
80}
81
82#[derive(Clone, PartialEq)]
83pub enum Bezier<F: Float, P: CurvePoint<F>> {
84    C0(Bezier0<F, P>),
85    C1(Bezier1<F, P>),
86    C2(Bezier2<F, P>),
87    C3(Bezier3<F, P>),
88}
89
90impl<F: Float, P: CurvePoint<F>> Copy for Bezier<F, P> where P: Copy {}
91impl<F: Float, P: CurvePoint<F>> Copy for Bezier0<F, P> where P: Copy {}
92impl<F: Float, P: CurvePoint<F>> Copy for Bezier1<F, P> where P: Copy {}
93impl<F: Float, P: CurvePoint<F>> Copy for Bezier2<F, P> where P: Copy {}
94impl<F: Float, P: CurvePoint<F>> Copy for Bezier3<F, P> where P: Copy {}
95
/// Expands to a `match` over every [`Bezier`] variant, binding the wrapped
/// curve to `$binding` and evaluating the same `$body` block in each arm.
macro_rules! for_every_level {
    ($target:ident, $binding:ident, $body:block) => {
        match $target {
            Bezier::C0($binding) => $body,
            Bezier::C1($binding) => $body,
            Bezier::C2($binding) => $body,
            Bezier::C3($binding) => $body,
        }
    };
}
106
107impl<F: Float, P: CurvePoint<F>> Debug for Bezier<F, P>
108where
109    P: Debug,
110{
111    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112        f.debug_tuple("Bezier")
113            .field(for_every_level!(self, c, { c }))
114            .finish()
115    }
116}
117impl<F: Float, P: CurvePoint<F>> Debug for Bezier0<F, P>
118where
119    P: Debug,
120{
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        f.debug_tuple("Bezier0").field(&self.point).finish()
123    }
124}
125impl<F: Float, P: CurvePoint<F>> Debug for Bezier1<F, P>
126where
127    P: Debug,
128{
129    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130        f.debug_tuple("Bezier1")
131            .field(&self.p0)
132            .field(&self.p1)
133            .finish()
134    }
135}
136
137impl<F: Float, P: CurvePoint<F>> Debug for Bezier2<F, P>
138where
139    P: Debug,
140{
141    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142        f.debug_tuple("Bezier2")
143            .field(&self.p0)
144            .field(&self.p1)
145            .field(&self.p2)
146            .finish()
147    }
148}
149
150impl<F: Float, P: CurvePoint<F>> Debug for Bezier3<F, P>
151where
152    P: Debug,
153{
154    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155        f.debug_tuple("Bezier3")
156            .field(&self.p0)
157            .field(&self.p1)
158            .field(&self.p2)
159            .field(&self.p3)
160            .finish()
161    }
162}
163
164impl<F: Float, P: CurvePoint<F>> Curve<F, P> for Bezier0<F, P> {
165    fn value_at(&self, _t: F) -> P {
166        self.point.clone()
167    }
168
169    fn tangent_at(&self, _t: F) -> P {
170        self.point.scale(F::zero())
171    }
172
173    fn start_point(&self) -> P {
174        self.point.clone()
175    }
176
177    fn end_point(&self) -> P {
178        self.point.clone()
179    }
180
181    fn estimate_length(&self, _precision: F) -> F
182    where
183        P: Distance<F>,
184    {
185        F::zero()
186    }
187}
188
189impl<F: Float, P: CurvePoint<F>> Curve<F, P> for Bezier1<F, P> {
190    fn value_at(&self, t: F) -> P {
191        self.p0.add(&self.p1.sub(&self.p0).scale(t))
192    }
193
194    fn tangent_at(&self, _t: F) -> P {
195        self.p1.sub(&self.p0)
196    }
197
198    fn start_point(&self) -> P {
199        self.p0.clone()
200    }
201
202    fn end_point(&self) -> P {
203        self.p1.clone()
204    }
205
206    fn estimate_length(&self, _precision: F) -> F
207    where
208        P: Distance<F>,
209    {
210        self.p0.distance(&self.p1)
211    }
212}
213
214impl<F: Float, P: CurvePoint<F>> Curve<F, P> for Bezier2<F, P> {
215    fn value_at(&self, t: F) -> P {
216        let t2 = t * t;
217        let t1 = F::one() - t;
218        let t12 = t1 * t1;
219
220        let two = F::one() + F::one();
221
222        self.p0
223            .scale(t12)
224            .add(&self.p1.scale(two * t1 * t))
225            .add(&self.p2.scale(t2))
226    }
227
228    fn tangent_at(&self, t: F) -> P {
229        let p0 = &self.p0;
230        let p1 = &self.p1;
231        let p2 = &self.p2;
232
233        let two = F::one() + F::one();
234
235        let t2 = t + t;
236        let nt2 = two - t2;
237
238        let v1 = p1.sub(p0).scale(nt2);
239        let v2 = p2.sub(p1).scale(t2);
240
241        v1.add(&v2)
242    }
243
244    fn start_point(&self) -> P {
245        self.p0.clone()
246    }
247
248    fn end_point(&self) -> P {
249        self.p2.clone()
250    }
251
252    fn estimate_length(&self, precision: F) -> F
253    where
254        P: Distance<F>,
255    {
256        let p0 = &self.p0;
257        let p1 = &self.p1;
258        let p2 = &self.p2;
259
260        let min = p0.distance(p1);
261        let max = p0.distance(p1) + p1.distance(p2);
262
263        let half = F::one() / (F::one() + F::one());
264
265        if max == F::zero() {
266            F::zero()
267        } else if (max - min) / max < precision {
268            (min + max) * half
269        } else {
270            let m01 = p0.add(p1).scale(half);
271            let m12 = p1.add(p2).scale(half);
272            let m = m01.add(&m12).scale(half);
273
274            let b1 = Bezier2::new(p0.clone(), m01, m.clone());
275            let b2 = Bezier2::new(m, m12, p2.clone());
276
277            b1.estimate_length(precision) + b2.estimate_length(precision)
278        }
279    }
280}
281
282impl<F: Float, P: CurvePoint<F>> Curve<F, P> for Bezier3<F, P> {
283    fn value_at(&self, t: F) -> P {
284        let three = F::one() + F::one() + F::one();
285
286        let t2 = t * t;
287        let t3 = t2 * t;
288
289        let nt = F::one() - t;
290        let nt2 = nt * nt;
291        let nt3 = nt2 * nt;
292
293        self.p0
294            .scale(nt3)
295            .add(&self.p1.scale(three * nt2 * t))
296            .add(&self.p2.scale(three * nt * t2).add(&self.p3.scale(t3)))
297    }
298
299    fn tangent_at(&self, t: F) -> P {
300        let p0 = &self.p0;
301        let p1 = &self.p1;
302        let p2 = &self.p2;
303        let p3 = &self.p3;
304
305        let three = F::one() + F::one() + F::one();
306        let six = three + three;
307
308        let t2 = t * t;
309
310        let nt = F::one() - t;
311        let nt2 = nt * nt;
312
313        let v1 = p1.sub(p0).scale(three * nt2);
314        let v2 = p2.sub(p1).scale(six * nt * t);
315        let v3 = p3.sub(p2).scale(three * t2);
316
317        v1.add(&v2).add(&v3)
318    }
319
320    fn start_point(&self) -> P {
321        self.p0.clone()
322    }
323
324    fn end_point(&self) -> P {
325        self.p3.clone()
326    }
327
328    fn estimate_length(&self, precision: F) -> F
329    where
330        P: Distance<F>,
331    {
332        let p0 = &self.p0;
333        let p1 = &self.p1;
334        let p2 = &self.p2;
335        let p3 = &self.p3;
336
337        let min = p0.distance(p3);
338        let max = p0.distance(p1) + p1.distance(p2) + p2.distance(p3);
339
340        let half = F::one() / (F::one() + F::one());
341
342        if max == F::zero() {
343            F::zero()
344        } else if (max - min) / max < precision {
345            (min + max) * half
346        } else {
347            let m01 = p0.add(p1).scale(half);
348            let m12 = p1.add(p2).scale(half);
349            let m23 = p2.add(p3).scale(half);
350            let m012 = m01.add(&m12).scale(half);
351            let m123 = m12.add(&m23).scale(half);
352            let m = m012.add(&m123).scale(half);
353
354            let b1 = Bezier3::new(p0.clone(), m01, m012, m.clone());
355            let b2 = Bezier3::new(m, m123, m23, p3.clone());
356
357            b1.estimate_length(precision) + b2.estimate_length(precision)
358        }
359    }
360}
361
362impl<F: Float, P: CurvePoint<F>> Curve<F, P> for Bezier<F, P> {
363    fn value_at(&self, t: F) -> P {
364        for_every_level!(self, c, { c.value_at(t) })
365    }
366
367    fn tangent_at(&self, t: F) -> P {
368        for_every_level!(self, c, { c.tangent_at(t) })
369    }
370
371    fn start_point(&self) -> P {
372        for_every_level!(self, c, { c.start_point() })
373    }
374
375    fn end_point(&self) -> P {
376        for_every_level!(self, c, { c.end_point() })
377    }
378
379    fn estimate_length(&self, precision: F) -> F
380    where
381        P: Distance<F>,
382    {
383        for_every_level!(self, c, { c.estimate_length(precision) })
384    }
385}
386
#[cfg(test)]
mod test {
    use super::*;

    /// A degree-0 curve yields its point for every parameter.
    #[test]
    fn bezier_0() {
        let curve = Bezier0::new(2.0);
        for &t in [0.0, 0.5, 1.0].iter() {
            assert_eq!(curve.value_at(t), 2.0);
        }
    }

    /// A line interpolates linearly between its end points.
    #[test]
    fn bezier_1() {
        let curve = Bezier1::new(1.0, 3.0);
        for &(t, expected) in [(0.0, 1.0), (0.5, 2.0), (1.0, 3.0)].iter() {
            assert_eq!(curve.value_at(t), expected);
        }
    }

    /// Scalar quadratic curve sampled at the ends and the middle.
    #[test]
    fn bezier_2() {
        let curve = Bezier2::new(1.0, 3.0, 2.0);
        for &(t, expected) in [(0.0, 1.0), (0.5, 2.25), (1.0, 2.0)].iter() {
            assert_eq!(curve.value_at(t), expected);
        }
    }

    /// Scalar cubic curve sampled at the ends and the middle.
    #[test]
    fn bezier_3() {
        let curve = Bezier3::new(1.0, 4.0, 2.0, 4.0);
        for &(t, expected) in [(0.0, 1.0), (0.5, 2.875), (1.0, 4.0)].iter() {
            assert_eq!(curve.value_at(t), expected);
        }
    }

    /// Minimal 2D point used to exercise the curves with a non-scalar type.
    #[derive(Clone, PartialEq, Debug)]
    struct Point {
        x: f64,
        y: f64,
    }

    /// Shorthand constructor for [`Point`].
    fn pt(x: f64, y: f64) -> Point {
        Point { x, y }
    }

    impl CurvePoint<f64> for Point {
        fn add(&self, other: &Self) -> Self {
            pt(self.x + other.x, self.y + other.y)
        }

        fn sub(&self, other: &Self) -> Self {
            pt(self.x - other.x, self.y - other.y)
        }

        fn multiply(&self, other: &Self) -> Self {
            pt(self.x * other.x, self.y * other.y)
        }

        fn scale(&self, s: f64) -> Self {
            pt(self.x * s, self.y * s)
        }
    }

    /// Checks positions and tangents of a 2D cubic curve.
    #[test]
    fn cubic_bezier_2d() {
        let curve = Bezier3::new(pt(0.0, 0.0), pt(0.0, 1.0), pt(2.0, -1.0), pt(2.0, 0.0));

        // (parameter, expected position, expected tangent)
        let samples = [
            (0.0, pt(0.0, 0.0), pt(0.0, 3.0)),
            (0.5, pt(1.0, 0.0), pt(3.0, -1.5)),
            (1.0, pt(2.0, 0.0), pt(0.0, 3.0)),
        ];

        for (t, position, _) in samples.iter() {
            assert_eq!(curve.value_at(*t), *position);
        }

        for (t, _, tangent) in samples.iter() {
            assert_eq!(curve.tangent_at(*t), *tangent);
        }
    }
}