use std::{
any::type_name,
fs,
fs::File,
io::{Read, Write},
path::Path,
};
use crate::catenary::{Catenary, P2};
use anyhow::anyhow;
use bincode::config;
use nalgebra::{ComplexField, Point2, RealField};
use ndarray::{Array1, Array2};
use num_dual::{DualNum, DualNumFloat};
use num_traits::{AsPrimitive, FloatConst};
use serde::{Deserialize, Serialize};
/// Lookup table of precomputed unit-length catenaries sampled on a regular
/// (angle, distance/length ratio) grid.
///
/// [`Table::get_catenary`] interpolates this grid instead of solving the catenary
/// equation for every query.
///
/// The field order defines the bincode layout used by `save`, `load`, the cache
/// and the downloaded datasets, so it must not be changed.
#[derive(Serialize, Deserialize)]
pub struct Table<F>
where
F: DualNum<F> + DualNumFloat, {
/// `table[[i_t, i_d]]` is the catenary of length 1 joining the origin to
/// `(cos θ, sin θ) * ratio`, with `θ = theta[i_t]` and `ratio = frac_dist_length[i_d]`
/// (as built by `generate`).
table: Array2<Catenary<F, F>>,
/// Evenly spaced sample angles in radians (`[0, π/2]` when built by `generate`).
theta: Array1<F>,
/// Spacing between consecutive `theta` samples.
d_theta: F,
/// Evenly spaced distance/length ratios (`[0, 1]` when built by `generate`).
frac_dist_length: Array1<F>,
/// Spacing between consecutive `frac_dist_length` samples.
d_frac_dist_length: F,
}
/// Catenary lookup table with `f32` entries.
pub type Table32 = Table<f32>;
/// Catenary lookup table with `f64` entries.
pub type Table64 = Table<f64>;
impl<F> Table<F>
where
    F: DualNum<F>
        + DualNumFloat
        + std::default::Default
        + std::convert::From<f32>
        + RealField
        + Serialize
        + for<'a> Deserialize<'a>
        + AsPrimitive<usize>
        + FloatConst,
{
    /// Generates a table of unit-length catenaries.
    ///
    /// The grid has `n_theta` angles evenly spaced over `[0, π/2]` and `n_length`
    /// distance/length ratios evenly spaced over `[0, 1]`. Cell `(i_t, i_d)` holds the
    /// catenary of length 1 joining the origin to `(cos θ, sin θ) * ratio`.
    ///
    /// For each angle, the middle ratio is solved without an initial guess. The
    /// remaining ratios are then solved outwards (first upwards, then downwards), each
    /// seeded with its already-solved neighbour. Progress is printed to stdout.
    ///
    /// # Errors
    ///
    /// Returns an error if `n_theta` or `n_length` is smaller than 2 (the grid spacing
    /// would be undefined), or if any catenary fails to solve.
    pub fn generate(n_theta: usize, n_length: usize) -> anyhow::Result<Self> {
        // Two samples per axis are needed to define a grid step (`x[1] - x[0]` below).
        anyhow::ensure!(n_theta >= 2, "n_theta must be at least 2 (got {n_theta})");
        anyhow::ensure!(n_length >= 2, "n_length must be at least 2 (got {n_length})");

        let mut table = Array2::<Catenary<F, F>>::default((n_theta, n_length));
        let theta = Array1::linspace(F::zero(), F::FRAC_PI_2(), n_theta);
        let d_theta = theta[1] - theta[0];
        let frac_dist_length = Array1::linspace(F::zero(), F::one(), n_length);
        let d_frac_dist_length = frac_dist_length[1] - frac_dist_length[0];
        let p0 = P2::origin();
        // The middle ratio is solved from scratch and then seeds its neighbours.
        let mid_id = n_length / 2;
        theta.iter().enumerate().try_for_each(|(i_t, &t)| {
            println!("{}/{n_theta} ({}%)", i_t + 1, (i_t + 1) * 100 / n_theta);
            let p1 = P2::new(ComplexField::cos(t), ComplexField::sin(t));
            table[[i_t, mid_id]] = Catenary::<F, F>::from_points_length(
                &p0,
                &(p1 * frac_dist_length[mid_id]),
                1.0.into(),
            )
            .ok_or_else(|| anyhow!("Failed to generate mid catenary for theta index {i_t}"))?;
            // Upwards from the middle (seeded by index - 1), then downwards (seeded by
            // index + 1): every solve starts from an already-computed neighbour.
            frac_dist_length
                .iter()
                .enumerate()
                .skip(mid_id + 1)
                .map(|(i_d, d)| (i_d, d, i_d - 1))
                .chain(
                    frac_dist_length
                        .iter()
                        .enumerate()
                        .take(mid_id)
                        .rev()
                        .map(|(i_d, d)| (i_d, d, i_d + 1)),
                )
                .try_for_each(|(i_d, d, i_d0)| {
                    let cat0 = table[[i_t, i_d0]];
                    let cat = Catenary::<F, F>::from_points_length_init(
                        &p0,
                        &(p1 * *d),
                        1.0.into(),
                        &cat0,
                    )
                    .ok_or_else(|| {
                        anyhow!(
                            "Failed to generate catenary for theta index {i_t} and length index {i_d}"
                        )
                    })?;
                    table[[i_t, i_d]] = cat;
                    Ok::<(), anyhow::Error>(())
                })
        })?;
        Ok(Self {
            table,
            theta,
            d_theta,
            frac_dist_length,
            d_frac_dist_length,
        })
    }

    /// Serializes the table with bincode (standard config) and writes it to `path`.
    ///
    /// The file is created, or truncated if it exists.
    ///
    /// # Errors
    ///
    /// Returns an error if encoding fails or the file cannot be created or written.
    pub fn save(&self, path: impl AsRef<Path>) -> anyhow::Result<()> {
        let buffer = bincode::serde::encode_to_vec(self, config::standard())?;
        let mut file = File::create(path)?;
        file.write_all(&buffer)?;
        Ok(())
    }

    /// Directory holding cached tables: `<platform cache dir>/catenary`.
    fn cache_dir() -> anyhow::Result<std::path::PathBuf> {
        Ok(dirs::cache_dir()
            .ok_or_else(|| anyhow!("Could not determine cache directory"))?
            .join("catenary"))
    }

    /// Writes the table to the user cache directory under the name given by
    /// `make_dataset_name` for its dimensions, then returns the table.
    ///
    /// # Errors
    ///
    /// Returns an error if the cache directory cannot be determined or created, or if
    /// encoding or writing fails.
    pub fn save_to_cache(self) -> anyhow::Result<Self> {
        let cache_dir = Self::cache_dir()?;
        fs::create_dir_all(&cache_dir)?;
        let file_path = cache_dir.join(make_dataset_name::<F>(
            self.theta.len(),
            self.frac_dist_length.len(),
        ));
        let bytes = bincode::serde::encode_to_vec(&self, config::standard())?;
        fs::write(file_path, bytes)?;
        Ok(self)
    }

    /// Reads a bincode-encoded table (as written by [`Table::save`]) from `path`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read or decoded.
    pub fn load(path: impl AsRef<Path>) -> anyhow::Result<Self> {
        let mut file = File::open(path)?;
        let mut buffer = Vec::new();
        file.read_to_end(&mut buffer)?;
        Ok(bincode::serde::decode_from_slice(&buffer[..], config::standard())?.0)
    }

    /// Bilinearly interpolates the table at (`theta`, `frac_dist_length`).
    ///
    /// Only the real parts select the grid cell. The interpolation weights are computed
    /// from the full `D` values, so derivatives carried by dual numbers propagate
    /// through the result.
    ///
    /// Inputs on the upper edge of the grid (e.g. `theta == π/2` or
    /// `frac_dist_length == 1`) use the last cell instead of indexing past the table.
    /// Inputs beyond the grid are linearly extrapolated from the nearest edge cell.
    ///
    /// # Panics
    ///
    /// Panics if the table has fewer than two samples along either axis.
    #[must_use]
    pub fn interpolate<D: DualNum<F>>(&self, theta: D, frac_dist_length: D) -> Catenary<D, F> {
        // `as_()` behaves like an `as` cast: negative/NaN values saturate to 0. The index
        // is clamped to `len - 2` so that `i + 1` is always a valid neighbour.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let i_t: usize = ComplexField::floor(theta.re() / self.d_theta).as_();
        let i_t = i_t.min(self.theta.len().saturating_sub(2));
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let i_d: usize =
            ComplexField::floor(frac_dist_length.re() / self.d_frac_dist_length).as_();
        let i_d = i_d.min(self.frac_dist_length.len().saturating_sub(2));

        let table_t0d0 = &self.table[[i_t, i_d]];
        let table_t1d0 = &self.table[[i_t + 1, i_d]];
        let table_t0d1 = &self.table[[i_t, i_d + 1]];
        let table_t1d1 = &self.table[[i_t + 1, i_d + 1]];

        // Fractional position within the cell along theta (kept in `D` for derivatives).
        let p = (theta - self.theta[i_t]) / self.d_theta;
        let table_d0_diff = table_t1d0 - table_t0d0;
        let table_ti_d0 = &(table_d0_diff * &p) + table_t0d0;
        let table_d1_diff = table_t1d1 - table_t0d1;
        let table_ti_d1 = &(table_d1_diff * &p) + table_t0d1;

        // Fractional position within the cell along the distance/length ratio.
        let q = (frac_dist_length - self.frac_dist_length[i_d]) / self.d_frac_dist_length;
        &table_ti_d0 + &((&table_ti_d1 - &table_ti_d0) * &q)
    }

    /// Returns the catenary of length `l` hanging between `p0` and `p1`.
    ///
    /// Returns `None` if `l` is shorter than the distance between the points.
    ///
    /// The lower of the two points is taken as origin, and the pair is mirrored in x when
    /// the upper point lies to its left. This reduces every case to the table's
    /// canonical one (angle in `[0, π/2]`). The interpolated unit-length catenary is then
    /// scaled by `l` and mapped back; `s_0` always corresponds to `p0` and `s_1` to
    /// `p1`.
    ///
    /// If the points are closer than `1e-4`, a degenerate catenary is returned instead:
    /// `a = 0`, `c = p1.x`, `h = p1.y - l/2`, and arc-length bounds `±l/2`.
    ///
    /// All arithmetic is performed in `D`, so derivatives propagate when `D` is a dual
    /// number.
    #[must_use]
    pub fn get_catenary<D: DualNum<F> + RealField>(
        &self,
        p0: &Point2<D>,
        p1: &Point2<D>,
        l: &D,
    ) -> Option<Catenary<D, F>> {
        let p01 = p1 - p0;
        let dist = p01.norm();
        if l < &dist {
            return None;
        }
        if dist < D::from(1e-4.into()) {
            return Some(Catenary {
                a: D::zero(),
                c: p1.x.clone(),
                h: p1.y.clone() - l.clone() / D::from(2.0.into()),
                s_0: -l.clone() / D::from(2.0.into()),
                s_1: l.clone() / D::from(2.0.into()),
                _f: std::marker::PhantomData,
            });
        }
        let theta = (p01.y.clone() / p01.x.clone()).atan();
        let frac_dist_length = dist / l;
        match (p0.x.re() <= p1.x.re(), p0.y.re() <= p1.y.re()) {
            // p1 up-right of p0: origin p0, no mirroring.
            (true, true) => {
                let mut cat = self.interpolate::<D>(theta, frac_dist_length);
                cat.a *= l.clone();
                cat.c = cat.c * l + p0.x.clone();
                cat.h = cat.h * l + p0.y.clone();
                cat.s_0 = cat.s_0 * l;
                cat.s_1 = cat.s_1 * l;
                Some(cat)
            }
            // p0 up-left of p1: origin p1, mirrored in x; endpoints swap roles.
            (true, false) => {
                let mut cat = self.interpolate::<D>(-theta, frac_dist_length);
                cat.a *= l.clone();
                cat.c = -cat.c * l + p1.x.clone();
                cat.h = cat.h * l + p1.y.clone();
                std::mem::swap(&mut cat.s_0, &mut cat.s_1);
                cat.s_0 = -cat.s_0 * l;
                cat.s_1 = -cat.s_1 * l;
                Some(cat)
            }
            // p1 up-left of p0: origin p0, mirrored in x.
            (false, true) => {
                let mut cat = self.interpolate::<D>(-theta, frac_dist_length);
                cat.a *= l.clone();
                cat.c = -cat.c * l + p0.x.clone();
                cat.h = cat.h * l + p0.y.clone();
                cat.s_0 = -cat.s_0 * l;
                cat.s_1 = -cat.s_1 * l;
                Some(cat)
            }
            // p0 up-right of p1: origin p1, no mirroring; endpoints swap roles.
            (false, false) => {
                let mut cat = self.interpolate::<D>(theta, frac_dist_length);
                cat.a *= l.clone();
                cat.c = cat.c * l + p1.x.clone();
                cat.h = cat.h * l + p1.y.clone();
                std::mem::swap(&mut cat.s_0, &mut cat.s_1);
                cat.s_0 = cat.s_0 * l;
                cat.s_1 = cat.s_1 * l;
                Some(cat)
            }
        }
    }

    /// Downloads a pre-generated `table_t` x `table_d` table from the project repository.
    ///
    /// # Errors
    ///
    /// Returns an error if the request fails, the server answers with a non-success
    /// status, or the payload cannot be decoded.
    pub async fn download(table_t: usize, table_d: usize) -> anyhow::Result<Self> {
        // A URL is not a filesystem path: concatenate instead of using `Path::join`
        // (DATASET_URL ends with '/').
        let url = format!("{DATASET_URL}{}", make_dataset_name::<F>(table_t, table_d));
        println!("Downloading dataset [{url}]...");
        let bytes = reqwest::get(url.as_str())
            .await?
            // Without this, an HTTP error page would be handed to bincode.
            .error_for_status()?
            .bytes()
            .await?;
        Ok(bincode::serde::decode_from_slice(&bytes, config::standard())?.0)
    }

    /// Loads a `table_t` x `table_d` table previously written by
    /// [`Table::save_to_cache`].
    ///
    /// # Errors
    ///
    /// Returns an error if the cache directory cannot be determined, the file is absent,
    /// or it cannot be read or decoded.
    pub fn load_from_cache(table_t: usize, table_d: usize) -> anyhow::Result<Self> {
        let file_path = Self::cache_dir()?.join(make_dataset_name::<F>(table_t, table_d));
        if !file_path.exists() {
            return Err(anyhow!("Dataset not found in cache"));
        }
        let bytes = fs::read(file_path)?;
        Ok(bincode::serde::decode_from_slice(&bytes, config::standard())?.0)
    }

    /// Returns a `table_t` x `table_d` table from the first source that works:
    /// the local cache, then a download, then local generation.
    ///
    /// Downloaded or generated tables are written to the cache before being returned.
    ///
    /// # Errors
    ///
    /// Returns an error if generation fails (after cache and download both failed) or if
    /// the resulting table cannot be cached.
    pub async fn load_from_cache_or_download_or_generate_then_cache(
        table_t: usize,
        table_d: usize,
    ) -> anyhow::Result<Self> {
        if let Ok(table) = Self::load_from_cache(table_t, table_d) {
            return Ok(table);
        }
        // Any download failure (offline, missing file, bad payload) deliberately falls
        // back to generating the table locally.
        let table = match Self::download(table_t, table_d).await {
            Ok(table) => table,
            Err(_) => Self::generate(table_t, table_d)?,
        };
        table.save_to_cache()
    }
}
/// Base URL of the directory that pre-generated tables are downloaded from
/// (ends with `/`).
const DATASET_URL: &str = "https://gitlab.com/youbihub/catenary/-/raw/main/output/";

/// File name of the dataset with `table_t` angle samples and `table_d` ratio samples.
///
/// The name is prefixed with `std::any::type_name::<T>()` of the scalar type, e.g.
/// `f64_catenary_table_t1000_d1000.bin` for `T = f64`.
pub fn make_dataset_name<T>(table_t: usize, table_d: usize) -> String {
    let scalar = type_name::<T>();
    format!("{scalar}_catenary_table_t{table_t}_d{table_d}.bin")
}
#[cfg(test)]
mod tests {
    use std::fs;

    use approx::{assert_abs_diff_eq, assert_relative_eq};
    use nalgebra::{Const, Point2, Vector3};
    use num_dual::{Derivative, Dual, DualNum, DualVec64};

    use super::Table;
    use crate::{CatMaker, Catenary};

    /// Asserts that the real parts of all five catenary parameters agree within an
    /// absolute tolerance `epsilon`.
    fn assert_abs_diff_cat_re<D, E, F>(cat1: &Catenary<D, F>, cat2: &Catenary<E, F>, epsilon: F)
    where
        D: DualNum<F>,
        E: DualNum<F>,
        F: approx::AbsDiffEq<Epsilon = F> + std::fmt::Debug + std::marker::Copy,
    {
        assert_abs_diff_eq!(cat1.a.re(), cat2.a.re(), epsilon = epsilon);
        assert_abs_diff_eq!(cat1.c.re(), cat2.c.re(), epsilon = epsilon);
        assert_abs_diff_eq!(cat1.h.re(), cat2.h.re(), epsilon = epsilon);
        assert_abs_diff_eq!(cat1.s_0.re(), cat2.s_0.re(), epsilon = epsilon);
        assert_abs_diff_eq!(cat1.s_1.re(), cat2.s_1.re(), epsilon = epsilon);
    }

    /// Asserts that the real parts of all five catenary parameters agree within a
    /// relative tolerance `max_relative`.
    fn assert_rel_cat_re<D, E, F>(cat1: &Catenary<D, F>, cat2: &Catenary<E, F>, max_relative: F)
    where
        D: DualNum<F>,
        E: DualNum<F>,
        F: approx::RelativeEq<Epsilon = F> + std::fmt::Debug + std::marker::Copy,
    {
        assert_relative_eq!(cat1.a.re(), cat2.a.re(), max_relative = max_relative);
        assert_relative_eq!(cat1.c.re(), cat2.c.re(), max_relative = max_relative);
        assert_relative_eq!(cat1.h.re(), cat2.h.re(), max_relative = max_relative);
        assert_relative_eq!(cat1.s_0.re(), cat2.s_0.re(), max_relative = max_relative);
        assert_relative_eq!(cat1.s_1.re(), cat2.s_1.re(), max_relative = max_relative);
    }

    #[test]
    fn generate() {
        let n_theta = 10;
        let n_length = 10;
        assert!(Table::<f64>::generate(n_theta, n_length).is_ok());
    }

    /// Overwrites one grid cell with known values and checks both the bilinear result
    /// and its derivatives.
    ///
    /// theta sits 3.25 steps into the grid (cell 3, p = 0.25) and the ratio 2.1 steps
    /// (cell 2, q = 0.1). For `a` this gives
    /// 12.5 + 0.1 * (32.5 - 12.5) = 14.5, d/dtheta = 10 / d_theta with d_theta = π/18,
    /// and d/dratio = 20 / (1/11).
    #[test]
    fn interpolate() {
        let n_theta = 10;
        let n_length = 12;
        let mut table = Table::generate(n_theta, n_length).unwrap();
        let theta_r = std::f64::consts::FRAC_PI_2 / 9.0 * 3.25;
        let frac_dist_length_r = 1.0 / 11.0 * 2.1;
        let theta = DualVec64::new(theta_r, Derivative::new(Some(Vector3::new(1.0, 0.0, 0.0))));
        let frac_dist_length = DualVec64::new(
            frac_dist_length_r,
            Derivative::new(Some(Vector3::new(0.0, 1.0, 0.0))),
        );
        table.table[[3, 2]] = CatMaker::a(10.0)
            .c(100.0)
            .h(1000.0)
            .s_0(10_000.0)
            .s_1(100_000.0);
        table.table[[4, 2]] = CatMaker::a(20.0)
            .c(200.0)
            .h(2000.0)
            .s_0(20_000.0)
            .s_1(200_000.0);
        table.table[[3, 3]] = CatMaker::a(30.0)
            .c(300.0)
            .h(3000.0)
            .s_0(30_000.0)
            .s_1(300_000.0);
        table.table[[4, 3]] = CatMaker::a(40.0)
            .c(400.0)
            .h(4000.0)
            .s_0(40_000.0)
            .s_1(400_000.0);
        let expected = CatMaker::a(14.5).c(145.).h(1450.).s_0(14500.).s_1(145_000.);
        let expected_dtheta = 2.0 * 9.0 / std::f64::consts::PI;
        let expected_dlength = 11.0;
        let cat = table.interpolate(theta, frac_dist_length);
        println!("cat: {cat:#?}",);
        assert_abs_diff_cat_re(&cat, &expected, 1e-6);
        let v_a = cat.a.eps.unwrap_generic(Const::<3>, Const::<1>);
        let v_c = cat.c.eps.unwrap_generic(Const::<3>, Const::<1>);
        let v_h = cat.h.eps.unwrap_generic(Const::<3>, Const::<1>);
        let v_s_0 = cat.s_0.eps.unwrap_generic(Const::<3>, Const::<1>);
        let v_s_1 = cat.s_1.eps.unwrap_generic(Const::<3>, Const::<1>);
        assert_abs_diff_eq!(v_a[(0, 0)], expected_dtheta * 10.0, epsilon = 1e-6);
        assert_abs_diff_eq!(v_c[(0, 0)], expected_dtheta * 100.0, epsilon = 1e-6);
        assert_abs_diff_eq!(v_h[(0, 0)], expected_dtheta * 1000.0, epsilon = 1e-6);
        assert_abs_diff_eq!(v_s_0[(0, 0)], expected_dtheta * 10_000.0, epsilon = 1e-6);
        assert_abs_diff_eq!(v_s_1[(0, 0)], expected_dtheta * 100_000.0, epsilon = 1e-6);
        assert_abs_diff_eq!(v_a[(1, 0)], expected_dlength * 20.0, epsilon = 1e-6);
        assert_abs_diff_eq!(v_c[(1, 0)], expected_dlength * 200.0, epsilon = 1e-6);
        assert_abs_diff_eq!(v_h[(1, 0)], expected_dlength * 2000.0, epsilon = 1e-6);
        assert_abs_diff_eq!(v_s_0[(1, 0)], expected_dlength * 20_000., epsilon = 1e-6);
        assert_abs_diff_eq!(v_s_1[(1, 0)], expected_dlength * 200_000.0, epsilon = 1e-6);
    }

    /// Round-trips a generated table through `save`/`load` in the working directory.
    #[test]
    fn save_and_load() {
        let n_theta = 10;
        let n_length = 10;
        let table = Table::<f64>::generate(n_theta, n_length).unwrap();
        let path = "test_identity.bin";
        table.save(path).unwrap();
        let table2 = Table::load(path).unwrap();
        fs::remove_file(path).unwrap();
        assert_eq!(table.table, table2.table);
    }

    // The `get_catenary_*` tests below need the pre-generated fixture
    // `output/f64_catenary_table_t1000_d1000.bin`, resolved relative to the working
    // directory.

    #[test]
    fn get_catenary_north_east() {
        let table = Table::load("output/f64_catenary_table_t1000_d1000.bin").unwrap();
        let cat_ref = CatMaker::a(11.0)
            .c(22.0)
            .h(16.0)
            .s_0_from_x(5.0)
            .s_1_from_x(40.0);
        let (p0, p1) = cat_ref.end_points();
        let l = cat_ref.length();
        let cat = table.get_catenary::<f64>(&p0, &p1, &l).unwrap();
        assert_rel_cat_re(&cat, &cat_ref, 1e-5);
    }

    #[test]
    fn get_catenary_north_west() {
        let table = Table::load("output/f64_catenary_table_t1000_d1000.bin").unwrap();
        let cat_ref = CatMaker::a(11.0)
            .c(22.0)
            .h(16.0)
            .s_0_from_x(33.0)
            .s_1_from_x(-2.0);
        let (p0, p1) = cat_ref.end_points();
        let l = cat_ref.length();
        let cat = table.get_catenary::<f64>(&p0, &p1, &l).unwrap();
        assert_rel_cat_re(&cat, &cat_ref, 1e-5);
    }

    #[test]
    fn get_catenary_south_east() {
        let table = Table::load("output/f64_catenary_table_t1000_d1000.bin").unwrap();
        let cat_ref = CatMaker::a(11.0).c(22.0).h(16.0).s_0_from_x(5.0).s_1(40.0);
        let (p0, p1) = cat_ref.end_points();
        let l = cat_ref.length();
        println!("p0: {p0}, p1: {p1}, l: {l}");
        let cat = table.get_catenary::<f64>(&p0, &p1, &l).unwrap();
        assert_rel_cat_re(&cat, &cat_ref, 1e-5);
    }

    #[test]
    fn get_catenary_south_west() {
        let table = Table::load("output/f64_catenary_table_t1000_d1000.bin").unwrap();
        let cat_ref = CatMaker::a(11.0)
            .c(22.0)
            .h(16.0)
            .s_0_from_x(5.0)
            .s_1_from_x(32.0);
        let (p0, p1) = cat_ref.end_points();
        let l = cat_ref.length();
        let cat = table.get_catenary::<f64>(&p0, &p1, &l).unwrap();
        assert_rel_cat_re(&cat, &cat_ref, 1e-5);
    }

    /// A cable shorter than the distance between its end points has no solution.
    #[test]
    fn get_catenary_none() {
        let table = Table::load("output/f64_catenary_table_t1000_d1000.bin").unwrap();
        let p0 = Point2::new(0.0, 0.0);
        let p1 = Point2::new(1.0, 0.0);
        let l = 0.5;
        let cat = table.get_catenary::<f64>(&p0, &p1, &l);
        assert!(cat.is_none());
    }

    #[test]
    fn get_catenary_south_west_dual() {
        let table = Table::load("output/f64_catenary_table_t1000_d1000.bin").unwrap();
        let cat_ref = CatMaker::a(11.0)
            .c(22.0)
            .h(16.0)
            .s_0_from_x(5.0)
            .s_1_from_x(32.0);
        let (p0, p1) = cat_ref.end_points();
        let p0_dual = p0.map(Dual::from_re);
        let p1_dual = p1.map(Dual::from_re);
        let l = cat_ref.length();
        let l_dual = Dual::from_re(l);
        let cat = table
            .get_catenary::<Dual<f64, f64>>(&p0_dual, &p1_dual, &l_dual)
            .unwrap();
        assert_rel_cat_re(&cat, &cat_ref, 1e-5);
    }

    /// Coincident end points take the degenerate branch, where `h = p1.y - l/2`, so
    /// dh/dl must be -0.5.
    #[test]
    fn get_catenary_line_dual() {
        let table = Table::load("output/f64_catenary_table_t1000_d1000.bin").unwrap();
        let cat_ref = CatMaker::a(0.0).c(0.0).h(0.0).s_0(-1.0).s_1(1.0);
        let (p0, p1) = cat_ref.end_points();
        let p0_dual = p0.map(Dual::from_re);
        let p1_dual = p1.map(Dual::from_re);
        let l = cat_ref.length();
        let l_dual = Dual::new(l, 1.0);
        let cat_dual = table
            .get_catenary::<Dual<f64, f64>>(&p0_dual, &p1_dual, &l_dual)
            .unwrap();
        assert_rel_cat_re(&cat_dual, &cat_ref, 1e-5);
        assert_relative_eq!(cat_dual.h.eps, -0.5, max_relative = 1e-5);
    }

    /// Requires network access to the dataset host.
    #[tokio::test]
    async fn test_get_dataset_path() {
        assert!(Table::<f32>::download(10, 10).await.is_ok());
    }

    /// Writes into the user cache directory; a cached table must behave like the original.
    #[test]
    fn test_save_load_cache() {
        let dataset = Table::<f32>::generate(7, 9)
            .unwrap()
            .save_to_cache()
            .unwrap();
        let loaded_dataset = Table::<f32>::load_from_cache(7, 9).unwrap();
        let p0 = Point2::new(0.0, 0.0);
        let p1 = Point2::new(1.0, 1.0);
        let l = 3.0;
        assert_eq!(
            dataset.get_catenary(&p0, &p1, &l),
            loaded_dataset.get_catenary(&p0, &p1, &l)
        );
    }

    #[tokio::test]
    async fn test_load_from_cache_or_download_or_generate_then_cache() {
        assert!(
            Table::<f32>::load_from_cache_or_download_or_generate_then_cache(11, 12)
                .await
                .is_ok()
        );
    }
}