use crate::distance::base::TrajectoryCalculator;
pub use crate::distance::distance_type::DistanceType;
use crate::err::TrajDistError;
use crate::traits::CoordSequence;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
#[cfg(feature = "progress")]
use indicatif::{ProgressBar, ProgressStyle};
#[cfg(feature = "progress")]
use std::io::IsTerminal;
#[cfg(feature = "progress")]
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
#[cfg(feature = "progress")]
use std::sync::Arc;
#[cfg(feature = "progress")]
/// Progress reporting backend, chosen once at construction time.
///
/// `Terminal` draws an `indicatif` progress bar (used when stderr is a terminal); `Logging`
/// counts completed work in an atomic counter while a background thread periodically prints
/// the count to stderr.
enum ProgressTracker {
/// Interactive progress bar rendered on stderr.
Terminal(ProgressBar),
/// Non-interactive fallback used when stderr is not a terminal.
Logging {
/// Number of completed work items, shared with the monitor thread.
counter: Arc<AtomicU64>,
/// Total number of work items expected.
total: u64,
/// Set to `true` to ask the monitor thread to exit.
stop_flag: Arc<AtomicBool>,
/// Monitor thread handle; taken (left as `None`) once the thread has been joined.
monitor_handle: Option<std::thread::JoinHandle<()>>,
},
}
#[cfg(feature = "progress")]
impl ProgressTracker {
    /// Interval between two `[progress]` lines in logging mode.
    const LOG_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);

    /// Creates a tracker for `total` units of work, labelled `label`.
    ///
    /// When stderr is a terminal an `indicatif` bar is drawn. Otherwise a background thread
    /// prints a `[progress]` line to stderr every [`Self::LOG_INTERVAL`] until the tracker is
    /// finished or dropped.
    fn new(total: u64, label: &str) -> Self {
        if std::io::stderr().is_terminal() {
            let progress_bar = ProgressBar::new(total);
            progress_bar.set_style(
                ProgressStyle::with_template(
                    "{msg} {bar:40.cyan/blue} {pos}/{len} [{elapsed_precise}<{eta_precise}, {per_sec}]",
                )
                .expect("hardcoded progress bar template should be valid")
                .progress_chars("█▉▊▋▌▍▎▏ "),
            );
            progress_bar.set_message(label.to_string());
            return Self::Terminal(progress_bar);
        }

        let counter = Arc::new(AtomicU64::new(0));
        let stop_flag = Arc::new(AtomicBool::new(false));
        let monitor_counter = Arc::clone(&counter);
        let monitor_stop = Arc::clone(&stop_flag);
        let monitor_label = label.to_string();
        let interval = Self::LOG_INTERVAL;
        let monitor_handle = std::thread::spawn(move || loop {
            // Wait one interval. Unlike `sleep`, `park_timeout` returns early when
            // `stop_monitor` calls `unpark`, so stopping never blocks for a whole interval.
            // Parking may also wake spuriously, hence the loop against a fixed deadline.
            let deadline = std::time::Instant::now() + interval;
            loop {
                if monitor_stop.load(Ordering::Acquire) {
                    return;
                }
                let now = std::time::Instant::now();
                if now >= deadline {
                    break;
                }
                std::thread::park_timeout(deadline - now);
            }
            let current = monitor_counter.load(Ordering::Relaxed);
            let percentage = if total > 0 {
                (current as f64 / total as f64) * 100.0
            } else {
                0.0
            };
            eprintln!(
                "[progress] {}: {}/{} ({:.1}%)",
                monitor_label, current, total, percentage
            );
        });
        Self::Logging {
            counter,
            total,
            stop_flag,
            monitor_handle: Some(monitor_handle),
        }
    }

    /// Records `delta` completed work items.
    fn inc(&self, delta: u64) {
        match self {
            Self::Terminal(pb) => pb.inc(delta),
            Self::Logging { counter, .. } => {
                counter.fetch_add(delta, Ordering::Relaxed);
            }
        }
    }

    /// Marks the work as complete.
    ///
    /// In terminal mode this finishes the bar. In logging mode it stops and joins the monitor
    /// thread, then prints a final `[progress] done` line.
    fn finish(&mut self) {
        self.stop_monitor();
        match self {
            Self::Terminal(pb) => pb.finish(),
            Self::Logging { counter, total, .. } => {
                let final_count = counter.load(Ordering::Relaxed);
                eprintln!("[progress] done: {}/{}", final_count, total);
            }
        }
    }

    /// Asks the logging monitor thread (if any) to exit, wakes it, and waits for it.
    ///
    /// This is idempotent: the join handle is taken, so later calls are no-ops.
    fn stop_monitor(&mut self) {
        if let Self::Logging {
            stop_flag,
            monitor_handle,
            ..
        } = self
        {
            if let Some(handle) = monitor_handle.take() {
                stop_flag.store(true, Ordering::Release);
                // Wakes the thread out of `park_timeout` immediately.
                handle.thread().unpark();
                let _ = handle.join();
            }
        }
    }
}

#[cfg(feature = "progress")]
impl Drop for ProgressTracker {
    /// Stops the monitor thread when the tracker is dropped without `finish`, e.g. when a
    /// distance computation panics. Otherwise the detached thread would keep printing
    /// progress lines until the process exits.
    fn drop(&mut self) {
        self.stop_monitor();
    }
}
#[cfg(all(feature = "progress", feature = "parallel"))]
/// Borrowed view of a [`ProgressTracker`] (or of no tracker at all) that rayon worker
/// closures share to report progress.
enum SyncProgressRef<'a> {
/// Borrows the tracker's `indicatif` progress bar.
Terminal(&'a ProgressBar),
/// Borrows the tracker's shared completion counter.
Logging(&'a Arc<AtomicU64>),
/// Progress reporting disabled; `inc` does nothing.
None,
}
#[cfg(all(feature = "progress", feature = "parallel"))]
impl SyncProgressRef<'_> {
    /// Adds `delta` completed work items to whichever backend is borrowed.
    ///
    /// This is a no-op for [`SyncProgressRef::None`].
    fn inc(&self, delta: u64) {
        match *self {
            Self::None => {}
            Self::Terminal(bar) => bar.inc(delta),
            Self::Logging(completed) => {
                completed.fetch_add(delta, Ordering::Relaxed);
            }
        }
    }
}
// `SyncProgressRef` only holds shared references to an `indicatif::ProgressBar` (which is
// `Send + Sync`) and to an `Arc<AtomicU64>`. The compiler therefore derives `Sync` on its
// own, and no `unsafe impl` is needed. This zero-cost check keeps the guarantee explicit:
// adding a non-`Sync` field becomes a compile error instead of silent unsoundness.
#[cfg(all(feature = "progress", feature = "parallel"))]
const _: fn() = || {
    fn assert_sync<T: ?Sized + Sync>() {}
    assert_sync::<SyncProgressRef<'static>>();
};
#[cfg(all(feature = "progress", feature = "parallel"))]
impl ProgressTracker {
    /// Borrows the part of this tracker that parallel workers update: the progress bar in
    /// terminal mode, or the shared counter in logging mode.
    fn as_sync_ref(&self) -> SyncProgressRef<'_> {
        match self {
            Self::Logging {
                counter: completed, ..
            } => SyncProgressRef::Logging(completed),
            Self::Terminal(bar) => SyncProgressRef::Terminal(bar),
        }
    }
}
/// Trajectory distance algorithm selectable through [`Metric`].
///
/// Variants with fields carry algorithm-specific parameters that [`Metric::distance`]
/// forwards to the corresponding implementation in `crate::distance`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DistanceAlgorithm {
/// SSPD, computed by `crate::distance::sspd::sspd`.
SSPD,
/// Dynamic Time Warping, computed by `crate::distance::dtw::dtw`.
DTW,
/// Hausdorff distance, computed by `crate::distance::hausdorff::hausdorff`.
Hausdorff,
/// LCSS. `eps` is forwarded to `crate::distance::lcss::lcss`
/// (presumably the point-matching tolerance — confirm there).
LCSS { eps: f64 },
/// EDR. `eps` is forwarded to `crate::distance::edr::edr`
/// (presumably the point-matching tolerance — confirm there).
EDR { eps: f64 },
/// ERP. The 2-element `g` is forwarded to `crate::distance::erp::erp_standard`
/// (presumably the gap/reference point — confirm there).
ERP { g: [f64; 2] },
/// Discrete Fréchet distance, computed by `crate::distance::discret_frechet`;
/// displayed as "Discret Frechet".
DiscretFrechet,
/// EDwP, computed by `crate::distance::edwp::edwp`. Only supports
/// `DistanceType::Euclidean`; `Metric::distance` panics otherwise.
EDwP,
/// Fréchet distance, computed by `crate::distance::frechet::frechet`. Only supports
/// `DistanceType::Euclidean`; `Metric::distance` panics otherwise.
Frechet,
}
/// A trajectory distance algorithm paired with the [`DistanceType`] it is evaluated with.
///
/// Build one with [`Metric::new`] and evaluate it with [`Metric::distance`]; [`pdist`] and
/// [`cdist`] apply it to whole collections of trajectories.
#[derive(Debug, Clone, Copy)]
pub struct Metric {
/// Which trajectory distance to compute.
algorithm: DistanceAlgorithm,
/// Distance type forwarded to the algorithm implementation.
distance_type: DistanceType,
}
impl Metric {
pub fn new(algorithm: DistanceAlgorithm, distance_type: DistanceType) -> Self {
Self {
algorithm,
distance_type,
}
}
pub fn algorithm(&self) -> DistanceAlgorithm {
self.algorithm
}
pub fn distance_type(&self) -> DistanceType {
self.distance_type
}
}
impl std::fmt::Display for Metric {
    /// Formats the metric as `<algorithm>/<distance type>`, e.g. `"DTW/..."`. Algorithm
    /// parameters such as `eps` or `g` are not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name: &str = match &self.algorithm {
            DistanceAlgorithm::SSPD => "SSPD",
            DistanceAlgorithm::DTW => "DTW",
            DistanceAlgorithm::Hausdorff => "Hausdorff",
            DistanceAlgorithm::LCSS { .. } => "LCSS",
            DistanceAlgorithm::EDR { .. } => "EDR",
            DistanceAlgorithm::ERP { .. } => "ERP",
            DistanceAlgorithm::DiscretFrechet => "Discret Frechet",
            DistanceAlgorithm::EDwP => "EDwP",
            DistanceAlgorithm::Frechet => "Frechet",
        };
        f.write_str(name)?;
        write!(f, "/{}", self.distance_type)
    }
}
impl Metric {
    /// Computes the distance between `traj1` and `traj2` using this metric's algorithm and
    /// distance type.
    ///
    /// DTW, LCSS, EDR, ERP and discrete Fréchet are evaluated through a
    /// [`TrajectoryCalculator`] built from both trajectories. SSPD, Hausdorff, EDwP and
    /// Fréchet receive the trajectories directly.
    ///
    /// # Panics
    ///
    /// Panics if the algorithm is [`DistanceAlgorithm::EDwP`] or
    /// [`DistanceAlgorithm::Frechet`] and the distance type is not
    /// [`DistanceType::Euclidean`].
    pub fn distance<T: CoordSequence>(&self, traj1: &T, traj2: &T) -> f64 {
        let dist_type = self.distance_type;
        match self.algorithm {
            // Algorithms that take the raw trajectories.
            DistanceAlgorithm::SSPD => crate::distance::sspd::sspd(traj1, traj2, dist_type),
            DistanceAlgorithm::Hausdorff => {
                crate::distance::hausdorff::hausdorff(traj1, traj2, dist_type)
            }
            DistanceAlgorithm::EDwP => {
                assert!(
                    dist_type == DistanceType::Euclidean,
                    "EDwP only supports Euclidean distance"
                );
                crate::distance::edwp::edwp(traj1, traj2, false).distance
            }
            DistanceAlgorithm::Frechet => {
                assert!(
                    dist_type == DistanceType::Euclidean,
                    "Frechet only supports Euclidean distance"
                );
                crate::distance::frechet::frechet(traj1, traj2)
            }
            // Algorithms evaluated through a shared trajectory calculator.
            DistanceAlgorithm::DTW => {
                let calc = TrajectoryCalculator::new(traj1, traj2, dist_type);
                crate::distance::dtw::dtw(&calc, false).distance
            }
            DistanceAlgorithm::LCSS { eps } => {
                let calc = TrajectoryCalculator::new(traj1, traj2, dist_type);
                crate::distance::lcss::lcss(&calc, eps, false).distance
            }
            DistanceAlgorithm::EDR { eps } => {
                let calc = TrajectoryCalculator::new(traj1, traj2, dist_type);
                crate::distance::edr::edr(&calc, eps, false).distance
            }
            DistanceAlgorithm::ERP { g } => {
                let calc = TrajectoryCalculator::new(traj1, traj2, dist_type);
                crate::distance::erp::erp_standard(&calc, &g, false).distance
            }
            DistanceAlgorithm::DiscretFrechet => {
                let calc = TrajectoryCalculator::new(traj1, traj2, dist_type);
                crate::distance::discret_frechet::discret_frechet(&calc, false).distance
            }
        }
    }
}
pub fn pdist<T>(
trajectories: &[T],
metric: &Metric,
parallel: bool,
show_progress: bool,
) -> Result<Vec<f64>, TrajDistError>
where
T: CoordSequence + Sync,
{
let n = trajectories.len();
if n < 2 {
return Err(TrajDistError::InvalidParams(
"pdist requires at least 2 trajectories".to_string(),
));
}
#[cfg(feature = "progress")]
let mut tracker = if show_progress {
let total = (n * (n - 1) / 2) as u64;
Some(ProgressTracker::new(total, &format!("pdist [{}]", metric)))
} else {
None
};
#[cfg(not(feature = "progress"))]
let _ = show_progress;
#[cfg(feature = "parallel")]
let distances = if parallel {
compute_pdist_parallel(
trajectories,
metric,
#[cfg(feature = "progress")]
&tracker,
)
} else {
compute_pdist_sequential(
trajectories,
metric,
#[cfg(feature = "progress")]
&tracker,
)
};
#[cfg(not(feature = "parallel"))]
{
let _ = parallel;
}
#[cfg(not(feature = "parallel"))]
let distances = compute_pdist_sequential(
trajectories,
metric,
#[cfg(feature = "progress")]
&tracker,
);
#[cfg(feature = "progress")]
if let Some(ref mut t) = tracker {
t.finish();
}
Ok(distances)
}
/// Single-threaded condensed pairwise distances.
///
/// Produces one value per pair `(i, j)` with `i < j`, in row-major order, and reports one
/// unit of progress per pair when a tracker is given.
fn compute_pdist_sequential<T: CoordSequence>(
    trajectories: &[T],
    metric: &Metric,
    #[cfg(feature = "progress")] tracker: &Option<ProgressTracker>,
) -> Vec<f64> {
    let count = trajectories.len();
    let mut condensed = Vec::with_capacity(count * (count - 1) / 2);
    for (row, first) in trajectories.iter().enumerate() {
        // Only the trajectories after `row`, i.e. the upper triangle.
        for second in &trajectories[row + 1..] {
            condensed.push(metric.distance(first, second));
            #[cfg(feature = "progress")]
            if let Some(t) = tracker.as_ref() {
                t.inc(1);
            }
        }
    }
    condensed
}
/// Rayon-parallel condensed pairwise distances, in the same row-major pair order as
/// `compute_pdist_sequential`.
#[cfg(feature = "parallel")]
fn compute_pdist_parallel<T: CoordSequence + Sync>(
    trajectories: &[T],
    metric: &Metric,
    #[cfg(feature = "progress")] tracker: &Option<ProgressTracker>,
) -> Vec<f64> {
    let count = trajectories.len();
    // Materialise the upper-triangle index pairs so they form an indexed parallel iterator;
    // `collect` then keeps results in this row-major order.
    let index_pairs: Vec<(usize, usize)> = (0..count)
        .flat_map(|row| ((row + 1)..count).map(move |col| (row, col)))
        .collect();
    #[cfg(feature = "progress")]
    let progress = tracker
        .as_ref()
        .map_or(SyncProgressRef::None, ProgressTracker::as_sync_ref);
    index_pairs
        .into_par_iter()
        .map(|(row, col)| {
            let value = metric.distance(&trajectories[row], &trajectories[col]);
            #[cfg(feature = "progress")]
            progress.inc(1);
            value
        })
        .collect()
}
pub fn cdist<T>(
trajectories_a: &[T],
trajectories_b: &[T],
metric: &Metric,
parallel: bool,
show_progress: bool,
) -> Result<Vec<f64>, TrajDistError>
where
T: CoordSequence + Sync,
{
let n_a = trajectories_a.len();
let n_b = trajectories_b.len();
if n_a == 0 {
return Err(TrajDistError::InvalidParams(
"cdist requires at least 1 trajectory in the first collection".to_string(),
));
}
if n_b == 0 {
return Err(TrajDistError::InvalidParams(
"cdist requires at least 1 trajectory in the second collection".to_string(),
));
}
#[cfg(feature = "progress")]
let mut tracker = if show_progress {
let total = (n_a * n_b) as u64;
Some(ProgressTracker::new(total, &format!("cdist [{}]", metric)))
} else {
None
};
#[cfg(not(feature = "progress"))]
let _ = show_progress;
let mut distances = vec![0.0; n_a * n_b];
#[cfg(feature = "parallel")]
{
if parallel {
compute_cdist_parallel(
trajectories_a,
trajectories_b,
&mut distances,
metric,
#[cfg(feature = "progress")]
&tracker,
);
} else {
compute_cdist_sequential(
trajectories_a,
trajectories_b,
&mut distances,
metric,
#[cfg(feature = "progress")]
&tracker,
);
}
}
#[cfg(not(feature = "parallel"))]
{
let _ = parallel;
compute_cdist_sequential(
trajectories_a,
trajectories_b,
&mut distances,
metric,
#[cfg(feature = "progress")]
&tracker,
);
}
#[cfg(feature = "progress")]
if let Some(ref mut t) = tracker {
t.finish();
}
Ok(distances)
}
/// Single-threaded cross-distance matrix.
///
/// Writes `metric.distance(a[i], b[j])` into `distances[i * n_b + j]` and reports `n_b` units
/// of progress after each completed row.
fn compute_cdist_sequential<T: CoordSequence>(
    trajectories_a: &[T],
    trajectories_b: &[T],
    distances: &mut [f64],
    metric: &Metric,
    #[cfg(feature = "progress")] tracker: &Option<ProgressTracker>,
) {
    let width = trajectories_b.len();
    for (row, source) in trajectories_a.iter().enumerate() {
        let offset = row * width;
        for (col, target) in trajectories_b.iter().enumerate() {
            distances[offset + col] = metric.distance(source, target);
        }
        #[cfg(feature = "progress")]
        if let Some(t) = tracker.as_ref() {
            t.inc(width as u64);
        }
    }
}
/// Rayon-parallel cross-distance matrix.
///
/// Each `n_b`-wide row of `distances` is filled by one task, and `n_b` units of progress are
/// reported per finished row.
#[cfg(feature = "parallel")]
fn compute_cdist_parallel<T: CoordSequence + Sync>(
    trajectories_a: &[T],
    trajectories_b: &[T],
    distances: &mut [f64],
    metric: &Metric,
    #[cfg(feature = "progress")] tracker: &Option<ProgressTracker>,
) {
    let row_len = trajectories_b.len();
    #[cfg(feature = "progress")]
    let progress = tracker
        .as_ref()
        .map_or(SyncProgressRef::None, ProgressTracker::as_sync_ref);
    distances
        .par_chunks_mut(row_len)
        .enumerate()
        .for_each(|(row_idx, row)| {
            for (col_idx, cell) in row.iter_mut().enumerate() {
                *cell = metric.distance(&trajectories_a[row_idx], &trajectories_b[col_idx]);
            }
            #[cfg(feature = "progress")]
            progress.inc(row_len as u64);
        });
}