contourable 0.8.0

A library for differentiable functions
Documentation
use std::ops::Sub;

use nalgebra::Scalar;
use num_dual::{DualNum, DualNumFloat};
use num_traits::float::TotalOrder;

use crate::{chain::cum_sum, Contour};

/// A chain of heterogeneous contours, stored as boxed `Contour` trait objects and
/// traversed one after another as a single contour.
pub struct MixedChain<D, F>
where
    D: DualNum<F>,
    F: DualNumFloat,
{
    /// The contours making up the chain, in traversal order.
    contours: Vec<Box<dyn Contour<D, F>>>,

    /// Cumulative lengths of `contours`, computed once with `cum_sum` in `new`.
    // NOTE(review): `position` indexes this as "entry i = parameter at which contour i
    // starts, last entry = end of the chain" — confirm that `cum_sum` yields that layout
    // (a leading zero followed by running totals).
    cached_length: Vec<D>,
    /// `PhantomData` marker for the float type parameter `F`.
    _g: std::marker::PhantomData<F>,
}

impl<D, F> MixedChain<D, F>
where
    D: DualNum<F>,
    F: DualNumFloat,
{
    /// Builds a chain from `contours`, which are traversed in the given order.
    ///
    /// The length of every contour is queried exactly once here. The values are
    /// accumulated with `cum_sum` and stored, so lookups by chain parameter can use the
    /// cached values.
    pub fn new(contours: Vec<Box<dyn Contour<D, F>>>) -> Self {
        let lengths = contours.iter().map(|contour| contour.length());
        Self {
            cached_length: cum_sum(lengths),
            contours,
            _g: std::marker::PhantomData,
        }
    }
}

impl<D, F> Contour<D, F> for MixedChain<D, F>
where
    D: Scalar + DualNum<F>,
    F: DualNumFloat + TotalOrder,
    for<'a> &'a D: Sub<&'a D, Output = D>,
{
    /// Evaluates the chain at the global parameter `s`.
    ///
    /// Contour `i` covers the global range `cached_length[i] ..= cached_length[i + 1]`.
    /// The contour is selected using only the real part of `s`. It is then evaluated at
    /// `start_i + (s - cached_length[i])`, where `start_i` is the lower end of that
    /// contour's `s_interval()`. This way any derivative (dual) part carried by `s` reaches
    /// the contour, including when `s` lies exactly on a joint between two contours.
    ///
    /// At an exact joint the contour that *starts* there is used. Values of `s` before the
    /// start of the chain are handed to the first contour, and values past its end to the
    /// last contour (each receiving an out-of-range local parameter), instead of panicking.
    ///
    /// # Panics
    ///
    /// Panics if the chain contains no contours.
    fn position(&self, s: &D) -> nalgebra::Point2<D> {
        let last = self
            .contours
            .len()
            .checked_sub(1)
            .expect("MixedChain::position requires at least one contour");
        // Number of contour start points at or before `s`, ordered by real part only.
        let passed = self
            .cached_length
            .partition_point(|x| x.re().total_cmp(&s.re()).is_le());
        // Clamp so that out-of-range parameters map onto the first / last contour.
        let i = passed.saturating_sub(1).min(last);
        let contour = &self.contours[i];
        // Always derive the local parameter from `s` (never a constant), so that its
        // derivative part is preserved.
        let local = contour.s_interval().0 + (s - &self.cached_length[i]);
        contour.position(&local)
    }

    /// Parameter range of the whole chain: from zero to the last cached cumulative
    /// length.
    ///
    /// # Panics
    ///
    /// Panics if no cumulative lengths were cached.
    fn s_interval(&self) -> (D, D) {
        (D::zero(), self.cached_length.last().unwrap().clone())
    }
}

#[cfg(test)]
mod tests {

    use nalgebra::ComplexField;
    use num_dual::{Dual, Dual64};

    use super::*;

    /// Straight segment from `p0` to `p1`, parameterised on `[0, |p1 - p0|]`.
    struct Line<D, F>
    where
        D: DualNum<F>,
        F: DualNumFloat,
    {
        p0: nalgebra::Point2<D>,
        p1: nalgebra::Point2<D>,

        _f: std::marker::PhantomData<F>,
    }

    impl<D, F> Line<D, F>
    where
        D: DualNum<F> + ComplexField<RealField = D>,
        F: DualNumFloat,
    {
        fn new(p0: nalgebra::Point2<D>, p1: nalgebra::Point2<D>) -> Self {
            Self {
                p0,
                p1,
                _f: std::marker::PhantomData,
            }
        }

        /// Euclidean length of the segment.
        fn length(&self) -> D {
            (&self.p1 - &self.p0).norm()
        }
    }

    impl<D, F> Contour<D, F> for Line<D, F>
    where
        D: DualNum<F> + ComplexField<RealField = D>,
        F: DualNumFloat,
    {
        fn position(&self, s: &D) -> nalgebra::Point2<D> {
            &self.p0 + (&self.p1 - &self.p0) * s.clone() / self.length()
        }

        fn s_interval(&self) -> (D, D) {
            (D::zero(), self.length())
        }
    }

    #[test]
    fn test_chain_f64() {
        let p0 = nalgebra::Point2::new(0.0, 0.0);
        let p1 = nalgebra::Point2::new(2.0, 0.0);
        let p2 = p1 + nalgebra::Vector2::new(3.0, 4.0);
        let l0 = Line::new(p0, p1);
        let l1 = Line::new(p1, p2);
        let chain = MixedChain::new(vec![Box::new(l0), Box::new(l1)]);
        assert_eq!(chain.position(&0.0), p0);
        assert_eq!(chain.position(&2.0), p1);
        assert_eq!(chain.position(&1.5), (p0 + p1.coords) * 0.75);
        assert_eq!(chain.position(&(2.0 + 5.0)), p2);
    }

    #[test]
    fn test_chain_dual() {
        // The x-coordinates of the points carry a unit derivative part; the y-coordinates
        // and the chain parameter carry none, so every derivative part below comes from
        // the geometry alone.
        let p0 = nalgebra::Point2::new(Dual64::new(0.0, 1.0), Dual64::new(0.0, 0.0));
        let p1 = nalgebra::Point2::new(Dual64::new(2.0, 1.0), Dual64::new(0.0, 0.0));
        let p2 = p1 + nalgebra::Vector2::new(Dual::new(3.0, 1.0), Dual::new(4.0, 0.0));
        let l0 = Line::new(p0, p1);
        let l1 = Line::new(p1, p2);
        let chain = MixedChain::new(vec![Box::new(l0), Box::new(l1)]);
        assert_eq!(chain.position(&Dual::new(0.0, 0.0)), p0);
        assert_eq!(chain.position(&Dual::new(2.0, 0.0)), p1);
        // Three quarters along the first line is p0 + 0.75 * (p1 - p0); the shortcut
        // (p0 + p1) * 0.75 only agrees with it when p0 is the origin, which is not the
        // case for the derivative parts here.
        assert_eq!(
            chain.position(&Dual::new(1.5, 0.0)),
            p0 + (p1 - p0).map(|x| x * 0.75)
        );
        // The end of the chain parameter range maps onto the final point.
        let (_, end) = chain.s_interval();
        assert_eq!(
            chain.position(&end).map(|x| x.re()),
            p2.map(|x| x.re())
        );
    }

    #[test]
    fn test_chain_dual_parameter_derivative() {
        // Geometry without derivative parts; the chain parameter carries a unit one, so the
        // derivative part of the result is d(position)/ds, i.e. the unit tangent.
        let q0 = nalgebra::Point2::new(Dual64::new(0.0, 0.0), Dual64::new(0.0, 0.0));
        let q1 = nalgebra::Point2::new(Dual64::new(2.0, 0.0), Dual64::new(0.0, 0.0));
        let q2 = q1 + nalgebra::Vector2::new(Dual::new(3.0, 0.0), Dual::new(4.0, 0.0));
        let l0 = Line::new(q0, q1);
        let l1 = Line::new(q1, q2);
        let chain = MixedChain::new(vec![Box::new(l0), Box::new(l1)]);
        // A quarter of the way along the first line, which points in +x.
        assert_eq!(
            chain.position(&Dual::new(0.5, 1.0)),
            nalgebra::Point2::new(Dual64::new(0.5, 1.0), Dual64::new(0.0, 0.0))
        );
    }
}