use std::ops::Sub;
use nalgebra::Scalar;
use num_dual::{DualNum, DualNumFloat};
use num_traits::float::TotalOrder;
use crate::{chain::cum_sum, Contour};
/// A contour built from a sequence of heterogeneous sub-contours whose parameter
/// ranges are concatenated in order.
///
/// Sub-contours are stored as boxed trait objects so that different contour types
/// can be mixed in a single chain.
pub struct MixedChain<D: DualNum<F>, F: DualNumFloat> {
    /// The sub-contours, in the order in which the chain traverses them.
    contours: Vec<Box<dyn Contour<D, F>>>,
    /// Cumulative sub-contour lengths as produced by `cum_sum` in `new`.
    cached_length: Vec<D>,
    /// Carries the otherwise unused float type parameter `F`.
    _g: std::marker::PhantomData<F>,
}
impl<D, F> MixedChain<D, F>
where
D: DualNum<F>,
F: DualNumFloat,
{
pub fn new(contours: Vec<Box<dyn Contour<D, F>>>) -> Self {
let cached_length = cum_sum(contours.iter().map(|c| c.length()));
Self {
contours,
cached_length,
_g: std::marker::PhantomData,
}
}
}
impl<D, F> Contour<D, F> for MixedChain<D, F>
where
    D: Scalar + DualNum<F>,
    F: DualNumFloat + TotalOrder,
    for<'a> &'a D: Sub<&'a D, Output = D>,
{
    /// Evaluates the chain at the global parameter `s`.
    ///
    /// The sub-contour owning `s` is found by comparing real parts against the
    /// cumulative lengths in `cached_length`. A parameter lying exactly on a knot
    /// belongs to the contour that starts there; the final knot belongs to the last
    /// contour. That contour is then evaluated at the local offset
    /// `s - cached_length[k]`.
    ///
    /// The subtraction is performed on knots as well, so the derivative part of a
    /// dual-number `s` is always propagated into the result.
    ///
    /// Parameters before the start of the chain are handed to the first contour
    /// (as a negative local offset), and parameters past the end to the last
    /// contour. What a contour returns outside its own interval is up to that
    /// contour.
    ///
    /// NOTE(review): assumes `cached_length[k]` is the global start of contour `k`
    /// (i.e. `cum_sum` yields a leading zero, one entry more than `contours`) and
    /// that every contour's local parameter starts at zero — confirm against
    /// `cum_sum` / `Contour`.
    ///
    /// # Panics
    ///
    /// Panics if the chain contains no contours.
    fn position(&self, s: &D) -> nalgebra::Point2<D> {
        let last = self
            .contours
            .len()
            .checked_sub(1)
            .expect("MixedChain::position called on a chain without contours");
        let target = s.re();
        // Number of knots at or before `s`; `total_cmp` keeps the comparison total.
        let knots_at_or_before = self
            .cached_length
            .partition_point(|knot| knot.re().total_cmp(&target).is_le());
        // Clamp so that out-of-range parameters go to the first/last contour
        // instead of underflowing or indexing past the end.
        let k = knots_at_or_before.saturating_sub(1).min(last);
        let ds = s - &self.cached_length[k];
        self.contours[k].position(&ds)
    }

    /// Parameter range of the whole chain: from zero to the last cumulative length.
    ///
    /// # Panics
    ///
    /// Panics if `cached_length` is empty.
    fn s_interval(&self) -> (D, D) {
        let end = self
            .cached_length
            .last()
            .expect("MixedChain has no cached lengths")
            .clone();
        (D::zero(), end)
    }
}

#[cfg(test)]
mod tests {
    use nalgebra::ComplexField;
    use num_dual::{Dual, Dual64};

    use super::*;

    /// Straight segment from `p0` to `p1`, parameterised by distance from `p0`.
    struct Line<D, F>
    where
        D: DualNum<F>,
        F: DualNumFloat,
    {
        p0: nalgebra::Point2<D>,
        p1: nalgebra::Point2<D>,
        _f: std::marker::PhantomData<F>,
    }

    impl<D, F> Line<D, F>
    where
        D: DualNum<F> + ComplexField<RealField = D>,
        F: DualNumFloat,
    {
        fn new(p0: nalgebra::Point2<D>, p1: nalgebra::Point2<D>) -> Self {
            Self {
                p0,
                p1,
                _f: std::marker::PhantomData,
            }
        }

        /// Euclidean distance between the two end points.
        fn length(&self) -> D {
            (&self.p1 - &self.p0).norm()
        }
    }

    impl<D, F> Contour<D, F> for Line<D, F>
    where
        D: DualNum<F> + ComplexField<RealField = D>,
        F: DualNumFloat,
    {
        /// Linear interpolation; parameters outside `[0, length]` extrapolate.
        fn position(&self, s: &D) -> nalgebra::Point2<D> {
            &self.p0 + (&self.p1 - &self.p0) * s.clone() / self.length()
        }

        fn s_interval(&self) -> (D, D) {
            (D::zero(), self.length())
        }
    }

    /// Asserts that two floats agree up to rounding noise.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-12,
            "expected {expected}, got {actual}"
        );
    }

    /// Asserts value (`re`) and derivative (`eps`) of both coordinates of `p`.
    fn assert_dual_point(p: &nalgebra::Point2<Dual64>, x: (f64, f64), y: (f64, f64)) {
        assert_close(p.x.re, x.0);
        assert_close(p.x.eps, x.1);
        assert_close(p.y.re, y.0);
        assert_close(p.y.eps, y.1);
    }

    #[test]
    fn test_chain_f64() {
        let p0 = nalgebra::Point2::new(0.0, 0.0);
        let p1 = nalgebra::Point2::new(2.0, 0.0);
        let p2 = p1 + nalgebra::Vector2::new(3.0, 4.0);
        let l0 = Line::new(p0, p1);
        let l1 = Line::new(p1, p2);
        let chain = MixedChain::new(vec![Box::new(l0), Box::new(l1)]);
        assert_eq!(chain.s_interval(), (0.0, 7.0));
        assert_eq!(chain.position(&0.0), p0);
        assert_eq!(chain.position(&2.0), p1);
        assert_eq!(chain.position(&1.5), (p0 + p1.coords) * 0.75);
        assert_eq!(chain.position(&(2.0 + 5.0)), p2);
    }

    #[test]
    fn test_chain_f64_out_of_range() {
        let p0 = nalgebra::Point2::new(0.0, 0.0);
        let p1 = nalgebra::Point2::new(2.0, 0.0);
        let p2 = p1 + nalgebra::Vector2::new(3.0, 4.0);
        let l0 = Line::new(p0, p1);
        let l1 = Line::new(p1, p2);
        let chain = MixedChain::new(vec![Box::new(l0), Box::new(l1)]);
        // Before the start: handed to the first line, which extrapolates backwards.
        let before = chain.position(&-1.0);
        assert_close(before.x, -1.0);
        assert_close(before.y, 0.0);
        // Past the end: handed to the last line at local offset 8 - 2 = 6.
        let after = chain.position(&8.0);
        assert_close(after.x, 5.6);
        assert_close(after.y, 4.8);
    }

    #[test]
    #[should_panic]
    fn test_empty_chain_position_panics() {
        let chain: MixedChain<f64, f64> = MixedChain::new(Vec::new());
        let _ = chain.position(&0.0);
    }

    #[test]
    fn test_chain_dual() {
        let p0 = nalgebra::Point2::new(Dual64::new(0.0, 1.0), Dual64::new(0.0, 0.0));
        let p1 = nalgebra::Point2::new(Dual64::new(2.0, 1.0), Dual64::new(0.0, 0.0));
        let p2 = p1 + nalgebra::Vector2::new(Dual::new(3.0, 1.0), Dual::new(4.0, 0.0));
        let l0 = Line::new(p0, p1);
        let l1 = Line::new(p1, p2);
        let chain = MixedChain::new(vec![Box::new(l0), Box::new(l1)]);
        // Segment lengths: |p1 - p0| = 2 + 0ε and |p2 - p1| = 5 + 0.6ε.
        // Start of the chain: local offset 0 + 1ε on the first line.
        assert_dual_point(&chain.position(&Dual::new(0.0, 1.0)), (0.0, 2.0), (0.0, 0.0));
        // Interior of the first line.
        assert_dual_point(&chain.position(&Dual::new(1.5, 1.0)), (1.5, 2.0), (0.0, 0.0));
        // Exactly on the inner knot: belongs to the second line, whose direction
        // (0.6, 0.8) contributes to the derivative.
        assert_dual_point(&chain.position(&Dual::new(2.0, 1.0)), (2.0, 1.6), (0.0, 0.8));
        // Exactly on the final knot: local offset 5 + 2ε on the second line.
        assert_dual_point(
            &chain.position(&(Dual::new(2.0, 1.0) + Dual::new(5.0, 1.0))),
            (5.0, 2.84),
            (4.0, 1.12),
        );
    }
}