1use crate::{curve::*, range_iter};
2use core::Scalar;
3use serde::{de::DeserializeOwned, Deserialize, Serialize};
4use std::{convert::TryFrom, fmt};
5
/// Tangent handle(s) attached to a [`SplinePoint`].
///
/// `Spline::new` turns these into the inner control points of each Bézier segment by
/// calling `point.offset(..)` with:
/// * `Single(d)`: `d` for the segment leaving the point and `d.negate()` for the segment
///   arriving at it (one handle mirrored on both sides).
/// * `InOut(a, b)`: `a` for the segment arriving at the point and `b.negate()` for the
///   segment leaving it (two independent handles).
///
/// NOTE(review): with this convention `InOut(d, d)` produces the mirror image of
/// `Single(d)`; confirm that sign convention is intended.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SplinePointDirection<T>
where
    T: Curved,
{
    /// One direction used (mirrored) on both sides of the point.
    Single(T),
    /// Separate values for the arriving side (first) and the leaving side (second).
    InOut(T, T),
}
14
15impl<T> Default for SplinePointDirection<T>
16where
17 T: Curved,
18{
19 fn default() -> Self {
20 Self::Single(T::zero())
21 }
22}
23
/// A point the spline passes through, together with its tangent handle(s).
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct SplinePoint<T>
where
    T: Curved,
{
    /// Value used as a segment end point by `Spline::new`.
    pub point: T,
    /// Tangent handle(s); when absent from serialized data this falls back to
    /// `SplinePointDirection::default()`, i.e. `Single(T::zero())`.
    #[serde(default)]
    pub direction: SplinePointDirection<T>,
}
33
34impl<T> SplinePoint<T>
35where
36 T: Curved,
37{
38 pub fn point(point: T) -> Self {
39 Self {
40 point,
41 direction: Default::default(),
42 }
43 }
44
45 pub fn new(point: T, direction: SplinePointDirection<T>) -> Self {
46 Self { point, direction }
47 }
48}
49
50impl<T> From<T> for SplinePoint<T>
51where
52 T: Curved,
53{
54 fn from(value: T) -> Self {
55 Self::point(value)
56 }
57}
58
59impl<T> From<(T, T)> for SplinePoint<T>
60where
61 T: Curved,
62{
63 fn from(value: (T, T)) -> Self {
64 Self::new(value.0, SplinePointDirection::Single(value.1))
65 }
66}
67
68impl<T> From<(T, T, T)> for SplinePoint<T>
69where
70 T: Curved,
71{
72 fn from(value: (T, T, T)) -> Self {
73 Self::new(value.0, SplinePointDirection::InOut(value.1, value.2))
74 }
75}
76
77impl<T> From<[T; 2]> for SplinePoint<T>
78where
79 T: Curved,
80{
81 fn from(value: [T; 2]) -> Self {
82 let [a, b] = value;
83 Self::new(a, SplinePointDirection::Single(b))
84 }
85}
86
87impl<T> From<[T; 3]> for SplinePoint<T>
88where
89 T: Curved,
90{
91 fn from(value: [T; 3]) -> Self {
92 let [a, b, c] = value;
93 Self::new(a, SplinePointDirection::InOut(b, c))
94 }
95}
96
/// Errors produced while building a [`Spline`].
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
pub enum SplineError {
    /// The list of spline points was empty.
    EmptyPointsList,
    /// Building the Bézier segment with the given index (the segment between points
    /// `index` and `index + 1`) failed with the wrapped curve error.
    Curve(usize, CurveError),
}
103
104impl fmt::Display for SplineError {
105 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
106 write!(f, "{:?}", self)
107 }
108}
109
/// Serialized form of a [`Spline`]: just its ordered control points.
pub type SplineDef<T> = Vec<SplinePoint<T>>;
111
/// A chain of Bézier segments running through an ordered list of [`SplinePoint`]s.
///
/// Serialized as a [`SplineDef`] (only the point list). Deserialization goes through
/// `TryFrom<SplineDef<T>>`, which rebuilds the cached segments and rejects invalid point
/// lists.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(try_from = "SplineDef<T>")]
#[serde(into = "SplineDef<T>")]
#[serde(bound = "T: Serialize + DeserializeOwned")]
pub struct Spline<T>
where
    T: Default + Clone + Curved + CurvedChange,
{
    /// Control points. `Spline::new` guarantees at least two of them, because a single
    /// input point is duplicated.
    points: Vec<SplinePoint<T>>,
    /// One Bézier segment per consecutive pair of `points`.
    cached: Vec<Curve<T>>,
    /// Sum of the lengths reported by every segment in `cached`.
    length: Scalar,
    /// For each point: the cumulative segment length up to it, paired with the point
    /// value. The first entry's length is `0.0`.
    parts_times_values: Vec<(Scalar, T)>,
}
125
126impl<T> Default for Spline<T>
127where
128 T: Default + Clone + Curved + CurvedChange,
129{
130 fn default() -> Self {
131 Self::point(T::zero()).unwrap()
132 }
133}
134
135impl<T> Spline<T>
136where
137 T: Default + Clone + Curved + CurvedChange,
138{
139 pub fn new(mut points: Vec<SplinePoint<T>>) -> Result<Self, SplineError> {
140 if points.is_empty() {
141 return Err(SplineError::EmptyPointsList);
142 }
143 if points.len() == 1 {
144 points.push(points[0].clone())
145 }
146 let cached = points
147 .windows(2)
148 .enumerate()
149 .map(|(index, pair)| {
150 let from_direction = match &pair[0].direction {
151 SplinePointDirection::Single(dir) => dir.clone(),
152 SplinePointDirection::InOut(_, dir) => dir.negate(),
153 };
154 let to_direction = match &pair[1].direction {
155 SplinePointDirection::Single(dir) => dir.negate(),
156 SplinePointDirection::InOut(dir, _) => dir.clone(),
157 };
158 let from_param = pair[0].point.offset(&from_direction);
159 let to_param = pair[1].point.offset(&to_direction);
160 Curve::bezier(
161 pair[0].point.clone(),
162 from_param,
163 to_param,
164 pair[1].point.clone(),
165 )
166 .map_err(|error| SplineError::Curve(index, error))
167 })
168 .collect::<Result<Vec<_>, _>>()?;
169 let lengths = cached
170 .iter()
171 .map(|curve| curve.length())
172 .collect::<Vec<_>>();
173 let mut time = 0.0;
174 let mut parts_times_values = Vec::with_capacity(points.len());
175 parts_times_values.push((0.0, points[0].point.clone()));
176 for (length, point) in lengths.iter().zip(points.iter().skip(1)) {
177 time += length;
178 parts_times_values.push((time, point.point.clone()));
179 }
180 Ok(Self {
181 points,
182 cached,
183 length: time,
184 parts_times_values,
185 })
186 }
187
188 pub fn linear(from: T, to: T) -> Result<Self, SplineError> {
189 Self::new(vec![SplinePoint::point(from), SplinePoint::point(to)])
190 }
191
192 pub fn point(point: T) -> Result<Self, SplineError> {
193 Self::linear(point.clone(), point)
194 }
195
196 pub fn value_along_axis_iter(
197 &self,
198 steps: usize,
199 axis_index: usize,
200 ) -> Option<impl Iterator<Item = Scalar>> {
201 let from = self.points.first()?.point.get_axis(axis_index)?;
202 let to = self.points.last()?.point.get_axis(axis_index)?;
203 Some(range_iter(steps, from, to))
204 }
205
206 pub fn sample(&self, factor: Scalar) -> T {
207 let (index, factor) = self.find_curve_index_factor(factor);
208 self.cached[index].sample(factor)
209 }
210
211 pub fn sample_along_axis(&self, axis_value: Scalar, axis_index: usize) -> Option<T> {
212 let index = self.find_curve_index_by_axis_value(axis_value, axis_index)?;
213 self.cached[index].sample_along_axis(axis_value, axis_index)
214 }
215
216 pub fn sample_first_derivative(&self, factor: Scalar) -> T {
218 let (index, factor) = self.find_curve_index_factor(factor);
219 self.cached[index].sample_first_derivative(factor)
220 }
221
222 pub fn sample_first_derivative_along_axis(
224 &self,
225 axis_value: Scalar,
226 axis_index: usize,
227 ) -> Option<T> {
228 let index = self.find_curve_index_by_axis_value(axis_value, axis_index)?;
229 self.cached[index].sample_first_derivative_along_axis(axis_value, axis_index)
230 }
231
232 pub fn sample_second_derivative(&self, factor: Scalar) -> T {
234 let (index, factor) = self.find_curve_index_factor(factor);
235 self.cached[index].sample_second_derivative(factor)
236 }
237
238 pub fn sample_second_derivative_along_axis(
240 &self,
241 axis_value: Scalar,
242 axis_index: usize,
243 ) -> Option<T> {
244 let index = self.find_curve_index_by_axis_value(axis_value, axis_index)?;
245 self.cached[index].sample_second_derivative_along_axis(axis_value, axis_index)
246 }
247
248 pub fn sample_k(&self, factor: Scalar) -> Scalar {
249 let (index, factor) = self.find_curve_index_factor(factor);
250 self.cached[index].sample_k(factor)
251 }
252
253 pub fn sample_curvature_radius(&self, factor: Scalar) -> Scalar {
254 let (index, factor) = self.find_curve_index_factor(factor);
255 self.cached[index].sample_curvature_radius(factor)
256 }
257
258 pub fn sample_tangent(&self, factor: Scalar) -> T {
259 let (index, factor) = self.find_curve_index_factor(factor);
260 self.cached[index].sample_tangent(factor)
261 }
262
263 pub fn sample_tangent_along_axis(&self, axis_value: Scalar, axis_index: usize) -> Option<T> {
264 let index = self.find_curve_index_by_axis_value(axis_value, axis_index)?;
265 self.cached[index].sample_tangent_along_axis(axis_value, axis_index)
266 }
267
268 pub fn length(&self) -> Scalar {
269 self.length
270 }
271
272 pub fn points(&self) -> &[SplinePoint<T>] {
273 &self.points
274 }
275
276 pub fn set_points(&mut self, points: Vec<SplinePoint<T>>) {
277 if let Ok(result) = Self::new(points) {
278 *self = result;
279 }
280 }
281
282 pub fn curves(&self) -> &[Curve<T>] {
283 &self.cached
284 }
285
286 pub fn find_curve_index_factor(&self, mut factor: Scalar) -> (usize, Scalar) {
287 factor = factor.max(0.0).min(1.0);
288 let t = factor * self.length;
289 let index = match self
290 .parts_times_values
291 .binary_search_by(|(time, _)| time.partial_cmp(&t).unwrap())
292 {
293 Ok(index) => index,
294 Err(index) => index.saturating_sub(1),
295 };
296 let index = index.min(self.cached.len().saturating_sub(1));
297 let start = self.parts_times_values[index].0;
298 let length = self.parts_times_values[index + 1].0 - start;
299 let factor = if length > 0.0 {
300 (t - start) / length
301 } else {
302 1.0
303 };
304 (index, factor)
305 }
306
307 pub fn find_curve_index_by_axis_value(
308 &self,
309 mut axis_value: Scalar,
310 axis_index: usize,
311 ) -> Option<usize> {
312 let min = self.points.first().unwrap().point.get_axis(axis_index)?;
313 let max = self.points.last().unwrap().point.get_axis(axis_index)?;
314 axis_value = axis_value.max(min).min(max);
315 let index = match self.parts_times_values.binary_search_by(|(_, value)| {
316 value
317 .get_axis(axis_index)
318 .unwrap()
319 .partial_cmp(&axis_value)
320 .unwrap()
321 }) {
322 Ok(index) => index,
323 Err(index) => index.saturating_sub(1),
324 };
325 Some(index.min(self.cached.len().saturating_sub(1)))
326 }
327
328 pub fn find_time_for_axis(&self, axis_value: Scalar, axis_index: usize) -> Option<Scalar> {
329 let index = self.find_curve_index_by_axis_value(axis_value, axis_index)?;
330 self.cached[index].find_time_for_axis(axis_value, axis_index)
331 }
332
333 }
335
336impl<T> TryFrom<SplineDef<T>> for Spline<T>
337where
338 T: Default + Clone + Curved + CurvedChange,
339{
340 type Error = SplineError;
341
342 fn try_from(value: SplineDef<T>) -> Result<Self, Self::Error> {
343 Self::new(value)
344 }
345}
346
347impl<T> From<Spline<T>> for SplineDef<T>
348where
349 T: Default + Clone + Curved + CurvedChange,
350{
351 fn from(v: Spline<T>) -> Self {
352 v.points
353 }
354}