enterpolation/base/
adaptors.rs

1use crate::{ConstDiscreteGenerator, Curve, DiscreteGenerator, Generator};
2use core::ops::{Add, Bound, Mul, RangeBounds};
3use num_traits::clamp;
4use num_traits::real::Real;
5
/// Wrapper for curves to clamp input to their domain.
///
/// Any input outside the domain of the wrapped curve is clamped to the nearest
/// end of that domain before the wrapped curve is evaluated.
///
/// This struct is constructed through the [`clamp()`] method of curves.
/// Please look there for more information.
///
/// [`clamp()`]: crate::Curve::clamp()
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Clamp<G>(G);
15
16impl<G> Clamp<G> {
17    /// Create a new `Clamp` struct.
18    pub fn new(gen: G) -> Self {
19        Clamp(gen)
20    }
21}
22
23impl<G, R> Generator<R> for Clamp<G>
24where
25    G: Curve<R>,
26    R: Real,
27{
28    type Output = G::Output;
29    fn gen(&self, input: R) -> Self::Output {
30        let [min, max] = self.domain();
31        let clamped = clamp(input, min, max);
32        self.0.gen(clamped)
33    }
34}
35
36impl<G, R> Curve<R> for Clamp<G>
37where
38    G: Curve<R>,
39    R: Real,
40{
41    fn domain(&self) -> [R; 2] {
42        self.0.domain()
43    }
44}
45
/// Acts like a slice of a curve.
///
/// That is, a slice of a curve has the same domain as the curve itself but maps the domain onto the range given.
///
/// Internally this is an input transformation (multiply, then add) in front of the wrapped curve.
///
/// This struct is created by the [`slice()`] method. Please look there for more information.
///
/// [`slice()`]: crate::Curve::slice()
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Slice<G, R>(TransformInput<G, R, R>);
56
57impl<G, R> Slice<G, R>
58where
59    G: Curve<R>,
60    R: Real,
61{
62    /// Create a new slice of the given generator.
63    ///
64    /// It does not matter if the bounds itself are included or excluded as we assume a continuous curve.
65    pub fn new<B>(gen: G, bound: B) -> Self
66    where
67        B: RangeBounds<R>,
68    {
69        let [gen_start, gen_end] = gen.domain();
70        let bound_start = match bound.start_bound() {
71            Bound::Included(x) | Bound::Excluded(x) => *x,
72            Bound::Unbounded => gen_start,
73        };
74        let bound_end = match bound.end_bound() {
75            Bound::Included(x) | Bound::Excluded(x) => *x,
76            Bound::Unbounded => gen_end,
77        };
78        let scale = (bound_end - bound_start) / (gen_end - gen_start);
79        Slice(TransformInput::new(gen, bound_start - gen_start, scale))
80    }
81}
82
83impl<G, R> Generator<R> for Slice<G, R>
84where
85    G: Generator<R>,
86    R: Real,
87{
88    type Output = G::Output;
89    fn gen(&self, input: R) -> Self::Output {
90        self.0.gen(input)
91    }
92}
93
94impl<G, R> Curve<R> for Slice<G, R>
95where
96    G: Curve<R>,
97    R: Real,
98{
99    fn domain(&self) -> [R; 2] {
100        self.0.inner.domain()
101    }
102}
103
/// Struct which transforms the input before sending it to the underlying generator.
///
/// Both addition and multiplication are done. In regards to math operation priorities, multiplication is done first,
/// that is, an input `x` is passed on as `x * multiplication + addition`.
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct TransformInput<G, A, M> {
    /// Value added to the input after it was multiplied.
    addition: A,
    /// Factor the input is multiplied with first.
    multiplication: M,
    /// The wrapped generator receiving the transformed input.
    inner: G,
}
114
115impl<G, A, M> TransformInput<G, A, M> {
116    /// Create a generic `TransformInput`.
117    pub fn new(generator: G, addition: A, multiplication: M) -> Self {
118        TransformInput {
119            inner: generator,
120            addition,
121            multiplication,
122        }
123    }
124}
125
126impl<G, R> TransformInput<G, R, R>
127where
128    G: Curve<R>,
129    R: Real,
130{
131    /// Transfrom an input such that the wrapped generator changes its domain from [0.0,1.0] to
132    /// the domain wished for.
133    pub fn normalized_to_domain(generator: G, start: R, end: R) -> Self {
134        Self::new(generator, -start, (end - start).recip())
135    }
136}
137
138impl<G, A, M, I> Generator<I> for TransformInput<G, A, M>
139where
140    I: Mul<M>,
141    I::Output: Add<A>,
142    A: Copy,
143    M: Copy,
144    G: Generator<<<I as Mul<M>>::Output as Add<A>>::Output>,
145{
146    type Output = G::Output;
147    fn gen(&self, input: I) -> Self::Output {
148        self.inner.gen(input * self.multiplication + self.addition)
149    }
150}
151
152impl<G, R> Curve<R> for TransformInput<G, R, R>
153where
154    G: Curve<R>,
155    R: Real,
156{
157    fn domain(&self) -> [R; 2] {
158        let orig = self.inner.domain();
159        let start = (orig[0] - self.addition) / self.multiplication;
160        let end = (orig[1] - self.addition) / self.multiplication;
161        [start, end]
162    }
163}
164
/// Struct which composes two generators together to act as one generator.
///
/// The output of the first generator is used as the input of the second one.
///
/// This `struct` is created by [`Generator::composite`]. See its documentation for more.
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Composite<A, B>(A, B);
171
172impl<A, B> Composite<A, B> {
173    /// Creates a composite generator.
174    pub fn new(first: A, second: B) -> Self {
175        Composite(first, second)
176    }
177}
178
179impl<A, B, T> Generator<T> for Composite<A, B>
180where
181    A: Generator<T>,
182    B: Generator<A::Output>,
183{
184    type Output = B::Output;
185    fn gen(&self, scalar: T) -> Self::Output {
186        self.1.gen(self.0.gen(scalar))
187    }
188}
189
190impl<A, B, R> Curve<R> for Composite<A, B>
191where
192    A: Curve<R>,
193    B: Generator<A::Output>,
194    R: Real,
195{
196    fn domain(&self) -> [R; 2] {
197        self.0.domain()
198    }
199}
200
/// DiscreteGenerator adaptor which stacks two generators.
///
/// That is, the struct holds two generators with output S and T and outputs (S,T).
///
/// This `struct` is created by [`Generator::stack`]. See its documentation for more.
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Stack<G, H>(G, H);
209
210impl<G, H> Stack<G, H> {
211    /// Creates a stacked generator, working similar like the `zip` method of iterators.
212    pub fn new(first: G, second: H) -> Self {
213        Stack(first, second)
214    }
215}
216
217impl<G, H, Input> Generator<Input> for Stack<G, H>
218where
219    G: Generator<Input>,
220    H: Generator<Input>,
221    Input: Copy,
222{
223    type Output = (G::Output, H::Output);
224    fn gen(&self, input: Input) -> Self::Output {
225        (self.0.gen(input), self.1.gen(input))
226    }
227}
228
229impl<G, H> DiscreteGenerator for Stack<G, H>
230where
231    G: DiscreteGenerator,
232    H: DiscreteGenerator,
233{
234    fn len(&self) -> usize {
235        self.0.len().min(self.1.len())
236    }
237}
238
// Marker impl: if both stacked generators have the compile-time length `N`,
// the stack has length `N` as well (matching `min(N, N)` from `DiscreteGenerator::len`).
impl<G, H, const N: usize> ConstDiscreteGenerator<N> for Stack<G, H>
where
    G: ConstDiscreteGenerator<N>,
    H: ConstDiscreteGenerator<N>,
{
}
245
246impl<G, H, R> Curve<R> for Stack<G, H>
247where
248    G: Curve<R>,
249    H: Curve<R>,
250    R: Real,
251{
252    fn domain(&self) -> [R; 2] {
253        let first = self.0.domain();
254        let second = self.1.domain();
255        [first[0].max(second[0]), first[1].min(second[1])]
256    }
257}
258
/// DiscreteGenerator Adaptor which repeats the underlying elements.
///
/// Index `i` is mapped to index `i % len` of the wrapped generator, and the adaptor
/// reports a length of `usize::MAX`.
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Repeat<G>(G);
263
264impl<G> Repeat<G> {
265    /// Repeat a given DiscreteGenerator pseudo-endlessly.
266    ///
267    /// In reality this adaptpor repeats the underlying elements until `usize::MAX` is reached.
268    pub fn new(gen: G) -> Self {
269        Repeat(gen)
270    }
271}
272
273impl<G> Generator<usize> for Repeat<G>
274where
275    G: DiscreteGenerator,
276{
277    type Output = G::Output;
278    fn gen(&self, input: usize) -> Self::Output {
279        self.0.gen(input % self.0.len())
280    }
281}
282
283impl<G> DiscreteGenerator for Repeat<G>
284where
285    G: DiscreteGenerator,
286{
287    fn len(&self) -> usize {
288        usize::MAX
289    }
290}
291
// Marker impl: the compile-time length matches `DiscreteGenerator::len`, which is `usize::MAX`.
impl<G> ConstDiscreteGenerator<{ usize::MAX }> for Repeat<G> where G: DiscreteGenerator {}
293
/// Generator adaptor which repeats a fixed amount of first elements.
///
/// The adaptor is `n` elements longer than the wrapped generator; index `i` is mapped to
/// index `i % len` of the wrapped generator, so the appended elements are the first ones again.
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Wrap<G> {
    /// The wrapped generator.
    inner: G,
    /// Number of leading elements appended again at the end.
    n: usize,
}
301
302impl<G> Wrap<G> {
303    /// Wrap the first `n` elements to the end.
304    pub fn new(gen: G, n: usize) -> Self {
305        Wrap { inner: gen, n }
306    }
307}
308
309impl<G> Generator<usize> for Wrap<G>
310where
311    G: DiscreteGenerator,
312{
313    type Output = G::Output;
314    fn gen(&self, input: usize) -> Self::Output {
315        self.inner.gen(input % self.inner.len())
316    }
317}
318
319impl<G> DiscreteGenerator for Wrap<G>
320where
321    G: DiscreteGenerator,
322{
323    fn len(&self) -> usize {
324        self.inner.len() + self.n
325    }
326}
327
#[cfg(test)]
mod test {
    use super::*;
    use crate::easing::Identity;

    /// `TransformInput` must scale the input, and its reported domain must shrink accordingly.
    #[test]
    fn input_transform() {
        let identity = Identity {};
        // addition 0.0, multiplication 2.0: x -> 2x
        let transformed = TransformInput::new(identity, 0.0, 2.0);
        assert_f64_near!(transformed.gen(1.0), 2.0);
        let results = [0.0, 1.0, 2.0];
        // try to extract
        let extractor = transformed.extract([0.0, 0.5, 1.0]);
        for (val, res) in extractor.zip(results.iter()) {
            assert_f64_near!(val, res);
        }
        // try to take - should be the same as before as the domain should have changed accordingly
        let transformed = TransformInput::new(identity, 0.0, 2.0);
        for (val, res) in transformed
            .take(results.len())
            .zip(<Identity as Curve<f64>>::take(identity, results.len()))
        {
            assert_f64_near!(val, res);
        }
    }

    /// A slice keeps the wrapped curve's domain but maps its ends onto the given range.
    #[test]
    fn slice() {
        let identity = Identity {};
        let slice = Slice::new(identity, 10.0..100.0);
        let results = [10.0, 100.0];
        assert_f64_near!(slice.gen(0.0), 10.0);
        assert_f64_near!(slice.gen(1.0), 100.0);
        for (val, res) in slice.take(results.len()).zip(results.iter()) {
            assert_f64_near!(val, res);
        }
    }
}