1use crate::{Curve, CurvePoint, Distance};
2use num_traits::Float;
3use std::fmt::Debug;
4use std::marker::PhantomData;
5
6#[derive(Clone, PartialEq)]
8pub struct Bezier0<F: Float, P: CurvePoint<F>> {
9 pub point: P,
10 phantom_data: PhantomData<F>,
11}
12
13impl<F: Float, P: CurvePoint<F>> Bezier0<F, P> {
14 pub fn new(point: P) -> Self {
15 Self {
16 point,
17 phantom_data: Default::default(),
18 }
19 }
20}
21
22#[derive(Clone, PartialEq)]
24pub struct Bezier1<F: Float, P: CurvePoint<F>> {
25 pub p0: P,
26 pub p1: P,
27 phantom_data: PhantomData<F>,
28}
29
30impl<F: Float, P: CurvePoint<F>> Bezier1<F, P> {
31 pub fn new(p0: P, p1: P) -> Self {
32 Self {
33 p0,
34 p1,
35 phantom_data: Default::default(),
36 }
37 }
38}
39
40#[derive(Clone, PartialEq)]
42pub struct Bezier2<F: Float, P: CurvePoint<F>> {
43 pub p0: P,
44 pub p1: P,
45 pub p2: P,
46 phantom_data: PhantomData<F>,
47}
48
49impl<F: Float, P: CurvePoint<F>> Bezier2<F, P> {
50 pub fn new(p0: P, p1: P, p2: P) -> Self {
51 Self {
52 p0,
53 p1,
54 p2,
55 phantom_data: Default::default(),
56 }
57 }
58}
59
60#[derive(Clone, PartialEq)]
62pub struct Bezier3<F: Float, P: CurvePoint<F>> {
63 pub p0: P,
64 pub p1: P,
65 pub p2: P,
66 pub p3: P,
67 phantom_data: PhantomData<F>,
68}
69
70impl<F: Float, P: CurvePoint<F>> Bezier3<F, P> {
71 pub fn new(p0: P, p1: P, p2: P, p3: P) -> Self {
72 Self {
73 p0,
74 p1,
75 p2,
76 p3,
77 phantom_data: Default::default(),
78 }
79 }
80}
81
82#[derive(Clone, PartialEq)]
83pub enum Bezier<F: Float, P: CurvePoint<F>> {
84 C0(Bezier0<F, P>),
85 C1(Bezier1<F, P>),
86 C2(Bezier2<F, P>),
87 C3(Bezier3<F, P>),
88}
89
90impl<F: Float, P: CurvePoint<F>> Copy for Bezier<F, P> where P: Copy {}
91impl<F: Float, P: CurvePoint<F>> Copy for Bezier0<F, P> where P: Copy {}
92impl<F: Float, P: CurvePoint<F>> Copy for Bezier1<F, P> where P: Copy {}
93impl<F: Float, P: CurvePoint<F>> Copy for Bezier2<F, P> where P: Copy {}
94impl<F: Float, P: CurvePoint<F>> Copy for Bezier3<F, P> where P: Copy {}
95
// Expands to a `match` over every `Bezier` variant. In each arm the wrapped
// fixed-degree curve is bound to `$inner` and `$body` is evaluated.
// The body is type-checked separately per arm, so `$inner` may have a
// different type in each arm; all arms must still yield the same type.
// `$value` must be an identifier (e.g. `self`).
macro_rules! for_every_level {
    ($value:ident, $inner:ident, $body:block) => {
        match $value {
            Bezier::C0($inner) => $body,
            Bezier::C1($inner) => $body,
            Bezier::C2($inner) => $body,
            Bezier::C3($inner) => $body,
        }
    };
}
106
107impl<F: Float, P: CurvePoint<F>> Debug for Bezier<F, P>
108where
109 P: Debug,
110{
111 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112 f.debug_tuple("Bezier")
113 .field(for_every_level!(self, c, { c }))
114 .finish()
115 }
116}
117impl<F: Float, P: CurvePoint<F>> Debug for Bezier0<F, P>
118where
119 P: Debug,
120{
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 f.debug_tuple("Bezier0").field(&self.point).finish()
123 }
124}
125impl<F: Float, P: CurvePoint<F>> Debug for Bezier1<F, P>
126where
127 P: Debug,
128{
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 f.debug_tuple("Bezier1")
131 .field(&self.p0)
132 .field(&self.p1)
133 .finish()
134 }
135}
136
137impl<F: Float, P: CurvePoint<F>> Debug for Bezier2<F, P>
138where
139 P: Debug,
140{
141 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142 f.debug_tuple("Bezier2")
143 .field(&self.p0)
144 .field(&self.p1)
145 .field(&self.p2)
146 .finish()
147 }
148}
149
150impl<F: Float, P: CurvePoint<F>> Debug for Bezier3<F, P>
151where
152 P: Debug,
153{
154 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155 f.debug_tuple("Bezier3")
156 .field(&self.p0)
157 .field(&self.p1)
158 .field(&self.p2)
159 .field(&self.p3)
160 .finish()
161 }
162}
163
164impl<F: Float, P: CurvePoint<F>> Curve<F, P> for Bezier0<F, P> {
165 fn value_at(&self, _t: F) -> P {
166 self.point.clone()
167 }
168
169 fn tangent_at(&self, _t: F) -> P {
170 self.point.scale(F::zero())
171 }
172
173 fn start_point(&self) -> P {
174 self.point.clone()
175 }
176
177 fn end_point(&self) -> P {
178 self.point.clone()
179 }
180
181 fn estimate_length(&self, _precision: F) -> F
182 where
183 P: Distance<F>,
184 {
185 F::zero()
186 }
187}
188
189impl<F: Float, P: CurvePoint<F>> Curve<F, P> for Bezier1<F, P> {
190 fn value_at(&self, t: F) -> P {
191 self.p0.add(&self.p1.sub(&self.p0).scale(t))
192 }
193
194 fn tangent_at(&self, _t: F) -> P {
195 self.p1.sub(&self.p0)
196 }
197
198 fn start_point(&self) -> P {
199 self.p0.clone()
200 }
201
202 fn end_point(&self) -> P {
203 self.p1.clone()
204 }
205
206 fn estimate_length(&self, _precision: F) -> F
207 where
208 P: Distance<F>,
209 {
210 self.p0.distance(&self.p1)
211 }
212}
213
214impl<F: Float, P: CurvePoint<F>> Curve<F, P> for Bezier2<F, P> {
215 fn value_at(&self, t: F) -> P {
216 let t2 = t * t;
217 let t1 = F::one() - t;
218 let t12 = t1 * t1;
219
220 let two = F::one() + F::one();
221
222 self.p0
223 .scale(t12)
224 .add(&self.p1.scale(two * t1 * t))
225 .add(&self.p2.scale(t2))
226 }
227
228 fn tangent_at(&self, t: F) -> P {
229 let p0 = &self.p0;
230 let p1 = &self.p1;
231 let p2 = &self.p2;
232
233 let two = F::one() + F::one();
234
235 let t2 = t + t;
236 let nt2 = two - t2;
237
238 let v1 = p1.sub(p0).scale(nt2);
239 let v2 = p2.sub(p1).scale(t2);
240
241 v1.add(&v2)
242 }
243
244 fn start_point(&self) -> P {
245 self.p0.clone()
246 }
247
248 fn end_point(&self) -> P {
249 self.p2.clone()
250 }
251
252 fn estimate_length(&self, precision: F) -> F
253 where
254 P: Distance<F>,
255 {
256 let p0 = &self.p0;
257 let p1 = &self.p1;
258 let p2 = &self.p2;
259
260 let min = p0.distance(p1);
261 let max = p0.distance(p1) + p1.distance(p2);
262
263 let half = F::one() / (F::one() + F::one());
264
265 if max == F::zero() {
266 F::zero()
267 } else if (max - min) / max < precision {
268 (min + max) * half
269 } else {
270 let m01 = p0.add(p1).scale(half);
271 let m12 = p1.add(p2).scale(half);
272 let m = m01.add(&m12).scale(half);
273
274 let b1 = Bezier2::new(p0.clone(), m01, m.clone());
275 let b2 = Bezier2::new(m, m12, p2.clone());
276
277 b1.estimate_length(precision) + b2.estimate_length(precision)
278 }
279 }
280}
281
282impl<F: Float, P: CurvePoint<F>> Curve<F, P> for Bezier3<F, P> {
283 fn value_at(&self, t: F) -> P {
284 let three = F::one() + F::one() + F::one();
285
286 let t2 = t * t;
287 let t3 = t2 * t;
288
289 let nt = F::one() - t;
290 let nt2 = nt * nt;
291 let nt3 = nt2 * nt;
292
293 self.p0
294 .scale(nt3)
295 .add(&self.p1.scale(three * nt2 * t))
296 .add(&self.p2.scale(three * nt * t2).add(&self.p3.scale(t3)))
297 }
298
299 fn tangent_at(&self, t: F) -> P {
300 let p0 = &self.p0;
301 let p1 = &self.p1;
302 let p2 = &self.p2;
303 let p3 = &self.p3;
304
305 let three = F::one() + F::one() + F::one();
306 let six = three + three;
307
308 let t2 = t * t;
309
310 let nt = F::one() - t;
311 let nt2 = nt * nt;
312
313 let v1 = p1.sub(p0).scale(three * nt2);
314 let v2 = p2.sub(p1).scale(six * nt * t);
315 let v3 = p3.sub(p2).scale(three * t2);
316
317 v1.add(&v2).add(&v3)
318 }
319
320 fn start_point(&self) -> P {
321 self.p0.clone()
322 }
323
324 fn end_point(&self) -> P {
325 self.p3.clone()
326 }
327
328 fn estimate_length(&self, precision: F) -> F
329 where
330 P: Distance<F>,
331 {
332 let p0 = &self.p0;
333 let p1 = &self.p1;
334 let p2 = &self.p2;
335 let p3 = &self.p3;
336
337 let min = p0.distance(p3);
338 let max = p0.distance(p1) + p1.distance(p2) + p2.distance(p3);
339
340 let half = F::one() / (F::one() + F::one());
341
342 if max == F::zero() {
343 F::zero()
344 } else if (max - min) / max < precision {
345 (min + max) * half
346 } else {
347 let m01 = p0.add(p1).scale(half);
348 let m12 = p1.add(p2).scale(half);
349 let m23 = p2.add(p3).scale(half);
350 let m012 = m01.add(&m12).scale(half);
351 let m123 = m12.add(&m23).scale(half);
352 let m = m012.add(&m123).scale(half);
353
354 let b1 = Bezier3::new(p0.clone(), m01, m012, m.clone());
355 let b2 = Bezier3::new(m, m123, m23, p3.clone());
356
357 b1.estimate_length(precision) + b2.estimate_length(precision)
358 }
359 }
360}
361
362impl<F: Float, P: CurvePoint<F>> Curve<F, P> for Bezier<F, P> {
363 fn value_at(&self, t: F) -> P {
364 for_every_level!(self, c, { c.value_at(t) })
365 }
366
367 fn tangent_at(&self, t: F) -> P {
368 for_every_level!(self, c, { c.tangent_at(t) })
369 }
370
371 fn start_point(&self) -> P {
372 for_every_level!(self, c, { c.start_point() })
373 }
374
375 fn end_point(&self) -> P {
376 for_every_level!(self, c, { c.end_point() })
377 }
378
379 fn estimate_length(&self, precision: F) -> F
380 where
381 P: Distance<F>,
382 {
383 for_every_level!(self, c, { c.estimate_length(precision) })
384 }
385}
386
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn bezier_0() {
        let curve = Bezier0::new(2.0);
        for &t in &[0.0, 0.5, 1.0] {
            assert_eq!(curve.value_at(t), 2.0);
        }
    }

    #[test]
    fn bezier_1() {
        let curve = Bezier1::new(1.0, 3.0);
        for &(t, expected) in &[(0.0, 1.0), (0.5, 2.0), (1.0, 3.0)] {
            assert_eq!(curve.value_at(t), expected);
        }
    }

    #[test]
    fn bezier_2() {
        let curve = Bezier2::new(1.0, 3.0, 2.0);
        for &(t, expected) in &[(0.0, 1.0), (0.5, 2.25), (1.0, 2.0)] {
            assert_eq!(curve.value_at(t), expected);
        }
    }

    #[test]
    fn bezier_3() {
        let curve = Bezier3::new(1.0, 4.0, 2.0, 4.0);
        for &(t, expected) in &[(0.0, 1.0), (0.5, 2.875), (1.0, 4.0)] {
            assert_eq!(curve.value_at(t), expected);
        }
    }

    /// Minimal 2-D point used to exercise the generic curve code.
    #[derive(Clone, PartialEq, Debug)]
    struct Point {
        x: f64,
        y: f64,
    }

    impl Point {
        /// Applies `op` component-wise to `self` and `other`.
        fn combine(&self, other: &Self, op: impl Fn(f64, f64) -> f64) -> Self {
            Point {
                x: op(self.x, other.x),
                y: op(self.y, other.y),
            }
        }
    }

    impl CurvePoint<f64> for Point {
        fn add(&self, other: &Self) -> Self {
            self.combine(other, |a, b| a + b)
        }

        fn sub(&self, other: &Self) -> Self {
            self.combine(other, |a, b| a - b)
        }

        fn multiply(&self, other: &Self) -> Self {
            self.combine(other, |a, b| a * b)
        }

        fn scale(&self, s: f64) -> Self {
            Point {
                x: self.x * s,
                y: self.y * s,
            }
        }
    }

    /// Shorthand constructor for test points.
    fn pt(x: f64, y: f64) -> Point {
        Point { x, y }
    }

    #[test]
    fn cubic_bezier_2d() {
        let curve = Bezier3::new(pt(0.0, 0.0), pt(0.0, 1.0), pt(2.0, -1.0), pt(2.0, 0.0));

        let expected_values = [
            (0.0, pt(0.0, 0.0)),
            (0.5, pt(1.0, 0.0)),
            (1.0, pt(2.0, 0.0)),
        ];
        for (t, expected) in expected_values.iter() {
            assert_eq!(curve.value_at(*t), *expected);
        }

        let expected_tangents = [
            (0.0, pt(0.0, 3.0)),
            (0.5, pt(3.0, -1.5)),
            (1.0, pt(0.0, 3.0)),
        ];
        for (t, expected) in expected_tangents.iter() {
            assert_eq!(curve.tangent_at(*t), *expected);
        }
    }
}