use crate::builder::Unknown;
use crate::{Chain, ConstEquidistant, Curve, Identity, Signal, SortedChain};
use num_traits::real::Real;
use topology_traits::Merge;
use core::fmt::Debug;
mod builder;
pub use builder::{LinearBuilder, LinearDirector};
pub mod error;
pub use error::{KnotElementInequality, LinearError, NotSorted, TooFewElements};
/// Piecewise linear interpolation between `elements` positioned at `knots`.
///
/// Evaluating at a scalar locates the two neighbouring knots, evaluates the corresponding
/// elements and merges them by a factor that is first passed through `easing`.
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Linear<K, E, F> {
    /// Values being interpolated, indexed by the knot indices returned from `upper_border`.
    elements: E,
    /// Positions of the elements; the trait impls below require `K: SortedChain`.
    knots: K,
    /// Curve applied to the raw interpolation factor before two elements are merged.
    easing: F,
}
impl Linear<Unknown, Unknown, Unknown> {
    /// Returns a fresh [`LinearBuilder`] from which a `Linear` can be configured step by step.
    ///
    /// All type parameters start as `Unknown` except the third, which starts as `Identity`
    /// (presumably the default easing).
    pub fn builder() -> LinearBuilder<Unknown, Unknown, Identity, Unknown> {
        LinearBuilder::<Unknown, Unknown, Identity, Unknown>::new()
    }
}
impl<R, K, E, F> Signal<R> for Linear<K, E, F>
where
    K: SortedChain<Output = R>,
    E: Chain,
    E::Output: Merge<R> + Debug,
    F: Curve<R, Output = R>,
    R: Real + Debug,
{
    type Output = E::Output;

    /// Evaluates the interpolation at `scalar`.
    ///
    /// `upper_border` yields the indices of the two elements surrounding `scalar` together with
    /// the relative position between them. That factor is passed through the easing curve and
    /// used to merge the two elements. Outside the knot range the factor leaves `[0, 1]`, so with
    /// identity easing the result extrapolates the outermost segment linearly.
    fn eval(&self, scalar: K::Output) -> Self::Output {
        let (lower_idx, upper_idx, t) = self.knots.upper_border(scalar);
        // Tuple operands are evaluated left to right, so the lower element is fetched first.
        let (start, end) = (self.elements.eval(lower_idx), self.elements.eval(upper_idx));
        start.merge(end, self.easing.eval(t))
    }
}
impl<R, K, E, F> Curve<R> for Linear<K, E, F>
where
    K: SortedChain<Output = R>,
    E: Chain,
    E::Output: Merge<R> + Debug,
    F: Curve<R, Output = R>,
    R: Real + Debug,
{
    /// Returns `[first_knot, last_knot]`, the interval spanned by the knots.
    ///
    /// # Panics
    ///
    /// Panics if the knot chain is empty. `Linear::new` rejects fewer than two knots, so this
    /// can only happen for instances created through the unchecked constructors.
    fn domain(&self) -> [R; 2] {
        let first = self
            .knots
            .first()
            .expect("Linear::domain: knot chain is empty (built without validation?)");
        let last = self
            .knots
            .last()
            .expect("Linear::domain: knot chain is empty (built without validation?)");
        [first, last]
    }
}
impl<K, E, F> Linear<K, E, F>
where
    K: SortedChain,
    K::Output: Real,
    E: Chain,
    E::Output: Merge<K::Output>,
{
    /// Creates a linear interpolation after validating the element and knot counts.
    ///
    /// Sortedness of the knots is not checked here; presumably it is guaranteed by the
    /// `SortedChain` implementation of `K`.
    ///
    /// # Errors
    ///
    /// - [`TooFewElements`] if fewer than two elements are given.
    /// - [`KnotElementInequality`] if the number of knots differs from the number of elements.
    pub fn new(elements: E, knots: K, easing: F) -> Result<Self, LinearError> {
        let element_count = elements.len();
        if element_count < 2 {
            return Err(TooFewElements::new(element_count).into());
        }

        let knot_count = knots.len();
        if knot_count != element_count {
            return Err(KnotElementInequality::new(element_count, knot_count).into());
        }

        Ok(Self {
            elements,
            knots,
            easing,
        })
    }
}
impl<K, E, F> Linear<K, E, F>
where
    K: SortedChain,
    K::Output: Real,
    E: Chain,
    E::Output: Merge<K::Output>,
{
    /// Creates a linear interpolation without validating its input.
    ///
    /// Unlike `Linear::new`, this checks neither that there are at least two elements nor that
    /// there are exactly as many knots as elements. Callers are responsible for upholding both
    /// invariants. Other methods rely on them; for example, `domain` unwraps the first and last
    /// knot.
    pub fn new_unchecked(elements: E, knots: K, easing: F) -> Self {
        Self {
            elements,
            knots,
            easing,
        }
    }
}
impl<R, T, const N: usize> Linear<ConstEquidistant<R, N>, [T; N], Identity> {
    /// Creates a linear interpolation over `elements` with `ConstEquidistant` knots and
    /// `Identity` easing. Because this is a `const fn`, it can initialise `const` items.
    ///
    /// No validation happens here: unlike `Linear::new`, nothing checks that `N >= 2`.
    pub const fn equidistant_unchecked(elements: [T; N]) -> Self {
        let knots = ConstEquidistant::new();
        let easing = Identity::new();
        Self {
            elements,
            knots,
            easing,
        }
    }
}
/// A [`Linear`] over a const-sized array of `N` elements with `ConstEquidistant` knots and
/// `Identity` easing. This is the type produced by `Linear::equidistant_unchecked`.
pub type ConstEquidistantLinear<R, T, const N: usize> =
    Linear<ConstEquidistant<R, N>, [T; N], Identity>;
// Unit tests for `Linear`. `assert_f64_near!` is not defined in this module; presumably it is a
// crate-level float-comparison macro (e.g. from a dev-dependency) - confirm at the crate root.
#[cfg(test)]
mod test {
    use super::*;
    // Explicit import of `Curve`, presumably for curve methods such as `take` used below.
    use crate::Curve;

    // Builder with equidistant, normalized knots. The expected values are identical to the
    // `linear` test, which uses explicit knots 0, 1/3, 2/3, 1.
    #[test]
    fn linear_equidistant() {
        let lin = Linear::builder()
            .elements([20.0, 100.0, 0.0, 200.0])
            .equidistant::<f64>()
            .normalized()
            .build()
            .unwrap();
        // Seven evenly spaced samples: the element values interleaved with segment midpoints.
        let expected = [20.0, 60.0, 100.0, 50.0, 0.0, 100.0, 200.0];
        let mut iter = lin.take(expected.len());
        for val in expected {
            assert_f64_near!(val, iter.next().unwrap());
        }
    }

    // Same data as `linear_equidistant`, but with the knots given explicitly.
    #[test]
    fn linear() {
        let lin = Linear::builder()
            .elements([20.0, 100.0, 0.0, 200.0])
            .knots([0.0, 1.0 / 3.0, 2.0 / 3.0, 1.0])
            .build()
            .unwrap();
        let expected = [20.0, 60.0, 100.0, 50.0, 0.0, 100.0, 200.0];
        let mut iter = lin.take(expected.len());
        for val in expected {
            assert_f64_near!(val, iter.next().unwrap());
        }
    }

    // Inside the knot range `eval` interpolates; outside it continues the outermost segment:
    // -1.0 extends 20 -> 100 backwards (20 - 2 * 80 = -140), and 5.0 extends 0 -> 200
    // forwards (0 + 2 * 200 = 400).
    #[test]
    fn extrapolation() {
        let lin = Linear::builder()
            .elements([20.0, 100.0, 0.0, 200.0])
            .knots([1.0, 2.0, 3.0, 4.0])
            .build()
            .unwrap();
        assert_f64_near!(lin.eval(1.5), 60.0);
        assert_f64_near!(lin.eval(2.5), 50.0);
        assert_f64_near!(lin.eval(-1.0), -140.0);
        assert_f64_near!(lin.eval(5.0), 400.0);
    }

    // Weighted elements: the result 0.1 equals (0 * 9 + 1 * 1) / (9 + 1), i.e. the heavier
    // weight on 0.0 pulls the midpoint towards it (presumably via homogeneous coordinates).
    #[test]
    fn weights() {
        let lin = Linear::builder()
            .elements_with_weights([(0.0, 9.0), (1.0, 1.0)])
            .equidistant::<f64>()
            .normalized()
            .build()
            .unwrap();
        assert_f64_near!(lin.eval(0.5), 0.1);
    }

    // `equidistant_unchecked` is a `const fn`, so the interpolation can live in a `const`.
    // The results match those of the normalized equidistant builder.
    #[test]
    fn const_creation() {
        const LIN: ConstEquidistantLinear<f64, f64, 4> =
            ConstEquidistantLinear::equidistant_unchecked([20.0, 100.0, 0.0, 200.0]);
        let expected = [20.0, 60.0, 100.0, 50.0, 0.0, 100.0, 200.0];
        let mut iter = LIN.take(expected.len());
        for val in expected {
            assert_f64_near!(val, iter.next().unwrap());
        }
    }

    // The builder also accepts borrowed arrays; `sample` evaluates at each given point.
    #[test]
    fn borrow_creation() {
        let elements = [20.0, 100.0, 0.0, 200.0];
        let knots = [0.0, 1.0, 2.0, 3.0];
        let samples = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0];
        let linear = Linear::builder()
            .elements(&elements)
            .knots(&knots)
            .build()
            .unwrap();
        let expected = [20.0, 60.0, 100.0, 50.0, 0.0, 100.0, 200.0];
        let mut iter = linear.sample(samples);
        for val in expected {
            assert_f64_near!(val, iter.next().unwrap());
        }
    }

    // Two independently built instances with identical inputs compare equal through the
    // derived `PartialEq`.
    #[test]
    fn partial_eq() {
        let linear = Linear::builder()
            .elements([20.0, 100.0, 0.0, 200.0])
            .knots([0.0, 1.0, 2.0, 3.0])
            .build()
            .unwrap();
        let linear2 = Linear::builder()
            .elements([20.0, 100.0, 0.0, 200.0])
            .knots([0.0, 1.0, 2.0, 3.0])
            .build()
            .unwrap();
        assert_eq!(linear, linear2);
    }
}