use std::ops::{Add, Sub};
use nalgebra::Scalar;
use num_dual::{DualNum, DualNumFloat};
use num_traits::{float::TotalOrder, Zero};
use crate::Contour;
/// A contour assembled from several contours of type `C`, addressed by one
/// parameter that runs over their accumulated lengths.
///
/// `D` is the (possibly dual) number type used for parameters and coordinates,
/// and `F` is the floating-point type underlying `D`.
#[derive(Debug, Clone)]
pub struct Chain<C, D, F>
where
C: Contour<D, F>,
D: DualNum<F>,
F: DualNumFloat,
{
/// The chained contours, in the order they are traversed.
contours: Vec<C>,
/// Running totals of the contour lengths, starting with zero, so it holds
/// one more entry than `contours`. Computed once in `Chain::new`.
cached_length: Vec<D>,
/// Marker that ties the otherwise unused type parameter `F` to the struct.
_g: std::marker::PhantomData<F>,
}
impl<C, D, F> Chain<C, D, F>
where
C: Contour<D, F>,
D: DualNum<F>,
F: DualNumFloat,
{
pub fn new(contours: Vec<C>) -> Self {
let cached_length = cum_sum(contours.iter().map(|c| c.length()));
Self {
contours,
cached_length,
_g: std::marker::PhantomData,
}
}
}
impl<C, D, F> Contour<D, F> for Chain<C, D, F>
where
    C: Contour<D, F>,
    D: Scalar + DualNum<F>,
    F: DualNumFloat + TotalOrder,
    for<'a> &'a D: Sub<&'a D, Output = D>,
{
    /// Returns the point of the chain at parameter `s`.
    ///
    /// The contour containing `s` is located by a binary search over the
    /// cumulative lengths (compared on their real parts), and that contour is
    /// evaluated at its local parameter, i.e. `s` minus the cumulative length
    /// at which the contour begins. The local parameter is always derived from
    /// `s`, so any dual (derivative) components of `s` are propagated at the
    /// joints and at the end of the chain as well as inside a contour.
    ///
    /// Each contour is assumed to be parameterised starting from zero.
    ///
    /// If `s` lies before the start or beyond the end of the chain (for
    /// example because of floating-point round-off), the query is forwarded to
    /// the first or last contour with the correspondingly out-of-range local
    /// parameter; the result then depends on that contour's implementation.
    ///
    /// # Panics
    ///
    /// Panics if the chain contains no contours.
    fn position(&self, s: &D) -> nalgebra::Point2<D> {
        let last = self
            .contours
            .len()
            .checked_sub(1)
            .expect("Chain::position requires at least one contour");

        match self
            .cached_length
            .binary_search_by(|x| x.re().total_cmp(&s.re()))
        {
            // `s` coincides with the start of contour `i`. The difference has a
            // zero real part but keeps the derivative information of `s`.
            Ok(i) if i <= last => self.contours[i].position(&(s - &self.cached_length[i])),
            // `s` coincides with the total length: evaluate the last contour at
            // the exact end of its interval, plus the zero-valued offset that
            // carries the derivative information of `s`.
            Ok(i) => {
                let tail = &self.contours[last];
                let offset = s - &self.cached_length[i];
                tail.position(&(tail.s_interval().1 + offset))
            }
            // `i` is the insertion point, so `s` belongs to contour `i - 1`.
            // Clamp the index so that values outside the chain fall back to
            // the first or last contour instead of underflowing or indexing
            // out of bounds.
            Err(i) => {
                let k = i.saturating_sub(1).min(last);
                self.contours[k].position(&(s - &self.cached_length[k]))
            }
        }
    }

    /// Returns the parameter interval of the chain: from zero to the sum of
    /// the lengths of all contours.
    fn s_interval(&self) -> (D, D) {
        let total = self
            .cached_length
            .last()
            .expect("cached_length always holds the leading zero")
            .clone();
        (D::zero(), total)
    }
}
/// Computes the running totals of `v`, starting from zero.
///
/// The returned vector has one more element than `v` yields: the first entry
/// is `T::zero()` and entry `k` is the sum of the first `k` items, so an empty
/// input produces `[T::zero()]`.
pub(crate) fn cum_sum<I, T>(v: I) -> Vec<T>
where
    I: IntoIterator<Item = T>,
    T: Zero + Clone + Add<Output = T>,
{
    let items = v.into_iter();
    // One slot per item plus the leading zero.
    let mut sums = Vec::with_capacity(items.size_hint().0.saturating_add(1));

    let mut running = T::zero();
    sums.push(running.clone());
    for x in items {
        // `x` is already owned, so it can be consumed without cloning.
        running = running + x;
        sums.push(running.clone());
    }
    sums
}
#[cfg(test)]
mod tests {
    use nalgebra::{ComplexField, Point2, Vector2};
    use num_dual::{Dual, Dual64};

    use super::*;

    /// Straight segment between two points; a minimal `Contour` used to
    /// exercise `Chain`.
    struct Line<D, F>
    where
        D: DualNum<F>,
        F: DualNumFloat,
    {
        start: Point2<D>,
        end: Point2<D>,
        _f: std::marker::PhantomData<F>,
    }

    impl<D, F> Line<D, F>
    where
        D: DualNum<F> + ComplexField<RealField = D>,
        F: DualNumFloat,
    {
        /// Creates the segment running from `p0` to `p1`.
        fn new(p0: Point2<D>, p1: Point2<D>) -> Self {
            Self {
                start: p0,
                end: p1,
                _f: std::marker::PhantomData,
            }
        }

        /// Distance between the two end points.
        fn length(&self) -> D {
            (&self.end - &self.start).norm()
        }
    }

    impl<D, F> Contour<D, F> for Line<D, F>
    where
        D: DualNum<F> + ComplexField<RealField = D>,
        F: DualNumFloat,
    {
        /// Point at arc length `s` measured from the start of the segment.
        fn position(&self, s: &D) -> Point2<D> {
            let span = &self.end - &self.start;
            &self.start + span * s.clone() / self.length()
        }

        /// The segment is parameterised from zero to its length.
        fn s_interval(&self) -> (D, D) {
            (D::zero(), self.length())
        }
    }

    #[test]
    fn test_chain_f64() {
        // Two segments of length 2 and 5 meeting at `corner`.
        let origin = Point2::new(0.0, 0.0);
        let corner = Point2::new(2.0, 0.0);
        let tip = corner + Vector2::new(3.0, 4.0);

        let first = Line::new(origin, corner);
        let second = Line::new(corner, tip);
        let chain: Chain<Line<f64, f64>, f64, f64> = Chain::new(vec![first, second]);

        assert_eq!(chain.position(&0.0), origin);
        assert_eq!(chain.position(&2.0), corner);
        assert_eq!(chain.position(&1.5), (origin + corner.coords) * 0.75);
        assert_eq!(chain.position(&(2.0 + 5.0)), tip);
    }

    #[test]
    fn test_chain_dual() {
        // Same geometry as the `f64` test, with dual-number coordinates.
        let origin = Point2::new(Dual64::new(0.0, 1.0), Dual64::new(0.0, 0.0));
        let corner = Point2::new(Dual64::new(2.0, 1.0), Dual64::new(0.0, 0.0));
        let tip = corner + Vector2::new(Dual::new(3.0, 1.0), Dual::new(4.0, 0.0));

        let first = Line::new(origin, corner);
        let second = Line::new(corner, tip);
        let chain = Chain::new(vec![first, second]);

        assert_eq!(chain.position(&Dual::new(0.0, 1.0)), origin);
        assert_eq!(chain.position(&Dual::new(2.0, 1.0)), corner);
        assert_eq!(
            chain.position(&Dual::new(1.5, 1.0)),
            (origin + corner.coords).map(|x| x * 0.75)
        );
        assert_eq!(
            chain.position(&(Dual::new(2.0, 1.0) + Dual::new(5.0, 1.0))),
            tip
        );
    }
}