// oxygengine_animation/spline.rs

use crate::{curve::*, range_iter};
use core::Scalar;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{cmp::Ordering, convert::TryFrom, fmt};

6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub enum SplinePointDirection<T>
8where
9    T: Curved,
10{
11    Single(T),
12    InOut(T, T),
13}
14
15impl<T> Default for SplinePointDirection<T>
16where
17    T: Curved,
18{
19    fn default() -> Self {
20        Self::Single(T::zero())
21    }
22}
23
24#[derive(Debug, Default, Clone, Serialize, Deserialize)]
25pub struct SplinePoint<T>
26where
27    T: Curved,
28{
29    pub point: T,
30    #[serde(default)]
31    pub direction: SplinePointDirection<T>,
32}
33
34impl<T> SplinePoint<T>
35where
36    T: Curved,
37{
38    pub fn point(point: T) -> Self {
39        Self {
40            point,
41            direction: Default::default(),
42        }
43    }
44
45    pub fn new(point: T, direction: SplinePointDirection<T>) -> Self {
46        Self { point, direction }
47    }
48}
49
50impl<T> From<T> for SplinePoint<T>
51where
52    T: Curved,
53{
54    fn from(value: T) -> Self {
55        Self::point(value)
56    }
57}
58
59impl<T> From<(T, T)> for SplinePoint<T>
60where
61    T: Curved,
62{
63    fn from(value: (T, T)) -> Self {
64        Self::new(value.0, SplinePointDirection::Single(value.1))
65    }
66}
67
68impl<T> From<(T, T, T)> for SplinePoint<T>
69where
70    T: Curved,
71{
72    fn from(value: (T, T, T)) -> Self {
73        Self::new(value.0, SplinePointDirection::InOut(value.1, value.2))
74    }
75}
76
77impl<T> From<[T; 2]> for SplinePoint<T>
78where
79    T: Curved,
80{
81    fn from(value: [T; 2]) -> Self {
82        let [a, b] = value;
83        Self::new(a, SplinePointDirection::Single(b))
84    }
85}
86
87impl<T> From<[T; 3]> for SplinePoint<T>
88where
89    T: Curved,
90{
91    fn from(value: [T; 3]) -> Self {
92        let [a, b, c] = value;
93        Self::new(a, SplinePointDirection::InOut(b, c))
94    }
95}
96
97#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
98pub enum SplineError {
99    EmptyPointsList,
100    /// (points pair index, curve error)
101    Curve(usize, CurveError),
102}
103
104impl fmt::Display for SplineError {
105    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
106        write!(f, "{:?}", self)
107    }
108}
109
110pub type SplineDef<T> = Vec<SplinePoint<T>>;
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
113#[serde(try_from = "SplineDef<T>")]
114#[serde(into = "SplineDef<T>")]
115#[serde(bound = "T: Serialize + DeserializeOwned")]
116pub struct Spline<T>
117where
118    T: Default + Clone + Curved + CurvedChange,
119{
120    points: Vec<SplinePoint<T>>,
121    cached: Vec<Curve<T>>,
122    length: Scalar,
123    parts_times_values: Vec<(Scalar, T)>,
124}
125
126impl<T> Default for Spline<T>
127where
128    T: Default + Clone + Curved + CurvedChange,
129{
130    fn default() -> Self {
131        Self::point(T::zero()).unwrap()
132    }
133}
134
135impl<T> Spline<T>
136where
137    T: Default + Clone + Curved + CurvedChange,
138{
139    pub fn new(mut points: Vec<SplinePoint<T>>) -> Result<Self, SplineError> {
140        if points.is_empty() {
141            return Err(SplineError::EmptyPointsList);
142        }
143        if points.len() == 1 {
144            points.push(points[0].clone())
145        }
146        let cached = points
147            .windows(2)
148            .enumerate()
149            .map(|(index, pair)| {
150                let from_direction = match &pair[0].direction {
151                    SplinePointDirection::Single(dir) => dir.clone(),
152                    SplinePointDirection::InOut(_, dir) => dir.negate(),
153                };
154                let to_direction = match &pair[1].direction {
155                    SplinePointDirection::Single(dir) => dir.negate(),
156                    SplinePointDirection::InOut(dir, _) => dir.clone(),
157                };
158                let from_param = pair[0].point.offset(&from_direction);
159                let to_param = pair[1].point.offset(&to_direction);
160                Curve::bezier(
161                    pair[0].point.clone(),
162                    from_param,
163                    to_param,
164                    pair[1].point.clone(),
165                )
166                .map_err(|error| SplineError::Curve(index, error))
167            })
168            .collect::<Result<Vec<_>, _>>()?;
169        let lengths = cached
170            .iter()
171            .map(|curve| curve.length())
172            .collect::<Vec<_>>();
173        let mut time = 0.0;
174        let mut parts_times_values = Vec::with_capacity(points.len());
175        parts_times_values.push((0.0, points[0].point.clone()));
176        for (length, point) in lengths.iter().zip(points.iter().skip(1)) {
177            time += length;
178            parts_times_values.push((time, point.point.clone()));
179        }
180        Ok(Self {
181            points,
182            cached,
183            length: time,
184            parts_times_values,
185        })
186    }
187
188    pub fn linear(from: T, to: T) -> Result<Self, SplineError> {
189        Self::new(vec![SplinePoint::point(from), SplinePoint::point(to)])
190    }
191
192    pub fn point(point: T) -> Result<Self, SplineError> {
193        Self::linear(point.clone(), point)
194    }
195
196    pub fn value_along_axis_iter(
197        &self,
198        steps: usize,
199        axis_index: usize,
200    ) -> Option<impl Iterator<Item = Scalar>> {
201        let from = self.points.first()?.point.get_axis(axis_index)?;
202        let to = self.points.last()?.point.get_axis(axis_index)?;
203        Some(range_iter(steps, from, to))
204    }
205
206    pub fn sample(&self, factor: Scalar) -> T {
207        let (index, factor) = self.find_curve_index_factor(factor);
208        self.cached[index].sample(factor)
209    }
210
211    pub fn sample_along_axis(&self, axis_value: Scalar, axis_index: usize) -> Option<T> {
212        let index = self.find_curve_index_by_axis_value(axis_value, axis_index)?;
213        self.cached[index].sample_along_axis(axis_value, axis_index)
214    }
215
216    /// Velocity of change along the curve axis.
217    pub fn sample_first_derivative(&self, factor: Scalar) -> T {
218        let (index, factor) = self.find_curve_index_factor(factor);
219        self.cached[index].sample_first_derivative(factor)
220    }
221
222    /// Velocity of change along the curve axis.
223    pub fn sample_first_derivative_along_axis(
224        &self,
225        axis_value: Scalar,
226        axis_index: usize,
227    ) -> Option<T> {
228        let index = self.find_curve_index_by_axis_value(axis_value, axis_index)?;
229        self.cached[index].sample_first_derivative_along_axis(axis_value, axis_index)
230    }
231
232    /// Acceleration of change along the curve axis.
233    pub fn sample_second_derivative(&self, factor: Scalar) -> T {
234        let (index, factor) = self.find_curve_index_factor(factor);
235        self.cached[index].sample_second_derivative(factor)
236    }
237
238    /// Acceleration of change along the curve axis.
239    pub fn sample_second_derivative_along_axis(
240        &self,
241        axis_value: Scalar,
242        axis_index: usize,
243    ) -> Option<T> {
244        let index = self.find_curve_index_by_axis_value(axis_value, axis_index)?;
245        self.cached[index].sample_second_derivative_along_axis(axis_value, axis_index)
246    }
247
248    pub fn sample_k(&self, factor: Scalar) -> Scalar {
249        let (index, factor) = self.find_curve_index_factor(factor);
250        self.cached[index].sample_k(factor)
251    }
252
253    pub fn sample_curvature_radius(&self, factor: Scalar) -> Scalar {
254        let (index, factor) = self.find_curve_index_factor(factor);
255        self.cached[index].sample_curvature_radius(factor)
256    }
257
258    pub fn sample_tangent(&self, factor: Scalar) -> T {
259        let (index, factor) = self.find_curve_index_factor(factor);
260        self.cached[index].sample_tangent(factor)
261    }
262
263    pub fn sample_tangent_along_axis(&self, axis_value: Scalar, axis_index: usize) -> Option<T> {
264        let index = self.find_curve_index_by_axis_value(axis_value, axis_index)?;
265        self.cached[index].sample_tangent_along_axis(axis_value, axis_index)
266    }
267
268    pub fn length(&self) -> Scalar {
269        self.length
270    }
271
272    pub fn points(&self) -> &[SplinePoint<T>] {
273        &self.points
274    }
275
276    pub fn set_points(&mut self, points: Vec<SplinePoint<T>>) {
277        if let Ok(result) = Self::new(points) {
278            *self = result;
279        }
280    }
281
282    pub fn curves(&self) -> &[Curve<T>] {
283        &self.cached
284    }
285
286    pub fn find_curve_index_factor(&self, mut factor: Scalar) -> (usize, Scalar) {
287        factor = factor.max(0.0).min(1.0);
288        let t = factor * self.length;
289        let index = match self
290            .parts_times_values
291            .binary_search_by(|(time, _)| time.partial_cmp(&t).unwrap())
292        {
293            Ok(index) => index,
294            Err(index) => index.saturating_sub(1),
295        };
296        let index = index.min(self.cached.len().saturating_sub(1));
297        let start = self.parts_times_values[index].0;
298        let length = self.parts_times_values[index + 1].0 - start;
299        let factor = if length > 0.0 {
300            (t - start) / length
301        } else {
302            1.0
303        };
304        (index, factor)
305    }
306
307    pub fn find_curve_index_by_axis_value(
308        &self,
309        mut axis_value: Scalar,
310        axis_index: usize,
311    ) -> Option<usize> {
312        let min = self.points.first().unwrap().point.get_axis(axis_index)?;
313        let max = self.points.last().unwrap().point.get_axis(axis_index)?;
314        axis_value = axis_value.max(min).min(max);
315        let index = match self.parts_times_values.binary_search_by(|(_, value)| {
316            value
317                .get_axis(axis_index)
318                .unwrap()
319                .partial_cmp(&axis_value)
320                .unwrap()
321        }) {
322            Ok(index) => index,
323            Err(index) => index.saturating_sub(1),
324        };
325        Some(index.min(self.cached.len().saturating_sub(1)))
326    }
327
328    pub fn find_time_for_axis(&self, axis_value: Scalar, axis_index: usize) -> Option<Scalar> {
329        let index = self.find_curve_index_by_axis_value(axis_value, axis_index)?;
330        self.cached[index].find_time_for_axis(axis_value, axis_index)
331    }
332
333    // TODO: find_time_for()
334}
335
336impl<T> TryFrom<SplineDef<T>> for Spline<T>
337where
338    T: Default + Clone + Curved + CurvedChange,
339{
340    type Error = SplineError;
341
342    fn try_from(value: SplineDef<T>) -> Result<Self, Self::Error> {
343        Self::new(value)
344    }
345}
346
347impl<T> From<Spline<T>> for SplineDef<T>
348where
349    T: Default + Clone + Curved + CurvedChange,
350{
351    fn from(v: Spline<T>) -> Self {
352        v.points
353    }
354}