use crate::binding::sequence::types::PyTrajectoryType;
use crate::distance::batch::{DistanceAlgorithm, Metric};
use crate::distance::distance_type::DistanceType;
use numpy::{PyArray1, PyArray2, PyArrayMethods};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyList;
use pyo3_stub_gen::{
    define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyfunction,
};
use pyo3_stub_gen_derive::gen_stub_pymethods;
use std::str::FromStr;
/// Python-visible distance metric, exposed to Python under the name `Metric`.
///
/// Wraps a crate-level `Metric` value (an algorithm paired with a distance type,
/// as built by `Metric::new(algorithm, distance_type)`). Python code cannot call
/// the constructor directly; instances are produced by the static factory
/// methods such as `Metric.sspd()` or `Metric.lcss(eps)`.
#[cfg(feature = "python-binding")]
#[gen_stub_pyclass]
#[pyclass(name = "Metric")]
pub struct PyMetric {
    /// The wrapped metric handed to the batch `pdist` / `cdist` routines.
    inner: Metric,
}
#[cfg(feature = "python-binding")]
#[gen_stub_pymethods]
#[pymethods]
impl PyMetric {
    /// Direct construction is not supported.
    ///
    /// Always raises `ValueError`; use one of the static factory methods instead.
    #[new]
    #[pyo3(signature = ())]
    fn new() -> PyResult<Self> {
        Err(PyValueError::new_err(
            "Metric objects cannot be created directly. Use a factory method like Metric.sspd() or Metric.lcss().",
        ))
    }

    /// Build a metric backed by `DistanceAlgorithm::SSPD`.
    ///
    /// `type_d` names the distance type (default `"euclidean"`); a value that
    /// cannot be parsed raises `ValueError`.
    #[staticmethod]
    #[pyo3(signature = (type_d = "euclidean"))]
    fn sspd(type_d: &str) -> PyResult<Self> {
        let calculator = build_calculator(DistanceAlgorithm::SSPD, type_d)?;
        Ok(Self { inner: calculator })
    }

    /// Build a metric backed by `DistanceAlgorithm::DTW`.
    ///
    /// `type_d` names the distance type (default `"euclidean"`); a value that
    /// cannot be parsed raises `ValueError`.
    #[staticmethod]
    #[pyo3(signature = (type_d = "euclidean"))]
    fn dtw(type_d: &str) -> PyResult<Self> {
        let calculator = build_calculator(DistanceAlgorithm::DTW, type_d)?;
        Ok(Self { inner: calculator })
    }

    /// Build a metric backed by `DistanceAlgorithm::Hausdorff`.
    ///
    /// `type_d` names the distance type (default `"euclidean"`); a value that
    /// cannot be parsed raises `ValueError`.
    #[staticmethod]
    #[pyo3(signature = (type_d = "euclidean"))]
    fn hausdorff(type_d: &str) -> PyResult<Self> {
        let calculator = build_calculator(DistanceAlgorithm::Hausdorff, type_d)?;
        Ok(Self { inner: calculator })
    }

    /// Build a metric backed by `DistanceAlgorithm::DiscretFrechet`.
    ///
    /// `type_d` names the distance type (default `"euclidean"`); a value that
    /// cannot be parsed raises `ValueError`.
    #[staticmethod]
    #[pyo3(signature = (type_d = "euclidean"))]
    fn discret_frechet(type_d: &str) -> PyResult<Self> {
        let calculator = build_calculator(DistanceAlgorithm::DiscretFrechet, type_d)?;
        Ok(Self { inner: calculator })
    }

    /// Build a metric backed by `DistanceAlgorithm::LCSS { eps }`.
    ///
    /// `eps` is forwarded unchanged to the algorithm. `type_d` names the distance
    /// type (default `"euclidean"`); a value that cannot be parsed raises `ValueError`.
    #[staticmethod]
    #[pyo3(signature = (eps, type_d = "euclidean"))]
    fn lcss(eps: f64, type_d: &str) -> PyResult<Self> {
        let calculator = build_calculator(DistanceAlgorithm::LCSS { eps }, type_d)?;
        Ok(Self { inner: calculator })
    }

    /// Build a metric backed by `DistanceAlgorithm::EDR { eps }`.
    ///
    /// `eps` is forwarded unchanged to the algorithm. `type_d` names the distance
    /// type (default `"euclidean"`); a value that cannot be parsed raises `ValueError`.
    #[staticmethod]
    #[pyo3(signature = (eps, type_d = "euclidean"))]
    fn edr(eps: f64, type_d: &str) -> PyResult<Self> {
        let calculator = build_calculator(DistanceAlgorithm::EDR { eps }, type_d)?;
        Ok(Self { inner: calculator })
    }

    /// Build a metric backed by `DistanceAlgorithm::ERP { g }`.
    ///
    /// `g` must contain exactly two values; any other length raises `ValueError`.
    /// `type_d` names the distance type (default `"euclidean"`); a value that
    /// cannot be parsed raises `ValueError`.
    #[staticmethod]
    #[pyo3(signature = (g, type_d = "euclidean"))]
    fn erp(g: Vec<f64>, type_d: &str) -> PyResult<Self> {
        // The algorithm stores `g` as `[f64; 2]`: reject other lengths instead of
        // silently dropping any values beyond the first two.
        let g = match g.as_slice() {
            [gx, gy] => [*gx, *gy],
            other => {
                return Err(PyValueError::new_err(format!(
                    "ERP 'g' parameter must have exactly 2 elements, got {}.",
                    other.len()
                )))
            }
        };
        let algorithm = DistanceAlgorithm::ERP { g };
        let calculator = build_calculator(algorithm, type_d)?;
        Ok(Self { inner: calculator })
    }

    /// Build a metric backed by `DistanceAlgorithm::EDwP`.
    ///
    /// Takes no `type_d`: the distance type is always `DistanceType::Euclidean`.
    #[staticmethod]
    #[pyo3(signature = ())]
    fn edwp() -> PyResult<Self> {
        let algorithm = DistanceAlgorithm::EDwP;
        let distance_type = DistanceType::Euclidean;
        Ok(Self {
            inner: Metric::new(algorithm, distance_type),
        })
    }

    /// Build a metric backed by `DistanceAlgorithm::Frechet`.
    ///
    /// Takes no `type_d`: the distance type is always `DistanceType::Euclidean`.
    #[staticmethod]
    #[pyo3(signature = ())]
    fn frechet() -> PyResult<Self> {
        let algorithm = DistanceAlgorithm::Frechet;
        let distance_type = DistanceType::Euclidean;
        Ok(Self {
            inner: Metric::new(algorithm, distance_type),
        })
    }
}
/// Parse `type_d` into a `DistanceType` and pair it with `algorithm`.
///
/// # Errors
/// Returns a Python `ValueError` naming the offending string when
/// `DistanceType::from_str` rejects `type_d`; the parse error itself is discarded.
fn build_calculator(algorithm: DistanceAlgorithm, type_d: &str) -> PyResult<Metric> {
    match DistanceType::from_str(type_d) {
        Ok(distance_type) => Ok(Metric::new(algorithm, distance_type)),
        Err(_) => Err(PyValueError::new_err(format!(
            "Invalid distance type '{}'. Expected 'euclidean' or 'spherical'.",
            type_d
        ))),
    }
}
/// Pairwise distances between every pair of trajectories in one collection.
///
/// Every element of `trajectories` is converted to a `PyTrajectoryType` while the
/// GIL is held. The computation itself runs inside `py.detach`, which releases the
/// GIL, via `crate::distance::batch::pdist`. That function's 1-D output is returned
/// as a NumPy array. NOTE(review): the ordering is presumably the condensed
/// upper-triangle layout used by `scipy.spatial.distance.pdist`; confirm against
/// `crate::distance::batch::pdist`.
///
/// # Errors
/// Raises `ValueError` when fewer than 2 trajectories are given, when any element
/// fails to convert, or when the distance computation reports an error.
#[cfg(feature = "python-binding")]
#[gen_stub_pyfunction]
#[pyfunction(signature = (trajectories, metric, parallel=true, show_progress=false))]
pub fn pdist<'py>(
    py: Python<'py>,
    #[gen_stub(override_type(type_repr="typing.Sequence[typing.List[typing.List[float]] | numpy.ndarray]", imports=("typing", "numpy")))]
    trajectories: &Bound<'py, PyList>,
    metric: &PyMetric,
    parallel: bool,
    show_progress: bool,
) -> PyResult<Py<PyArray1<f64>>> {
    let count = trajectories.len();
    if count < 2 {
        return Err(PyValueError::new_err("pdist requires at least 2 trajectories"));
    }

    // Conversion touches Python objects, so it must finish before the GIL is released.
    let mut converted: Vec<PyTrajectoryType> = Vec::with_capacity(count);
    for item in trajectories.iter() {
        let trajectory = PyTrajectoryType::try_from(&item).map_err(|err| {
            PyValueError::new_err(format!("Failed to convert trajectory: {}", err))
        })?;
        converted.push(trajectory);
    }

    let calculator = metric.inner;
    let condensed = py
        .detach(|| {
            crate::distance::batch::pdist(&converted, &calculator, parallel, show_progress)
        })
        .map_err(|err| PyValueError::new_err(format!("Failed to compute distances: {}", err)))?;

    Ok(PyArray1::from_vec(py, condensed).unbind())
}
/// Cross distances between every trajectory in `trajectories_a` and every
/// trajectory in `trajectories_b`.
///
/// Both collections are converted to `PyTrajectoryType` while the GIL is held.
/// The computation runs inside `py.detach`, which releases the GIL, via
/// `crate::distance::batch::cdist`. Its flat output is treated as row-major with
/// one row of `len(trajectories_b)` entries per trajectory in `trajectories_a`.
/// It is returned as a NumPy array of shape `(len(a), len(b))`.
///
/// # Errors
/// Raises `ValueError` when either collection is empty, when any element fails to
/// convert, when the distance computation reports an error, or when the result
/// cannot be shaped into `(len(a), len(b))`.
#[cfg(feature = "python-binding")]
#[gen_stub_pyfunction]
#[pyfunction(signature = (trajectories_a, trajectories_b, metric, parallel=true, show_progress=false))]
pub fn cdist<'py>(
    py: Python<'py>,
    #[gen_stub(override_type(type_repr="typing.Sequence[typing.List[typing.List[float]] | numpy.ndarray]", imports=("typing", "numpy")))]
    trajectories_a: &Bound<'py, PyList>,
    #[gen_stub(override_type(type_repr="typing.Sequence[typing.List[typing.List[float]] | numpy.ndarray]", imports=("typing", "numpy")))]
    trajectories_b: &Bound<'py, PyList>,
    metric: &PyMetric,
    parallel: bool,
    show_progress: bool,
) -> PyResult<Py<PyArray2<f64>>> {
    if trajectories_a.is_empty() {
        return Err(PyValueError::new_err(
            "cdist requires at least 1 trajectory in the first collection",
        ));
    }
    if trajectories_b.is_empty() {
        return Err(PyValueError::new_err(
            "cdist requires at least 1 trajectory in the second collection",
        ));
    }

    // Conversion touches Python objects, so it must finish before the GIL is released.
    let trajectories_a: Vec<PyTrajectoryType> = trajectories_a
        .iter()
        .map(|t| {
            PyTrajectoryType::try_from(&t).map_err(|e| {
                PyValueError::new_err(format!(
                    "Failed to convert trajectory in first collection: {}",
                    e
                ))
            })
        })
        .collect::<Result<Vec<_>, _>>()?;
    let trajectories_b: Vec<PyTrajectoryType> = trajectories_b
        .iter()
        .map(|t| {
            PyTrajectoryType::try_from(&t).map_err(|e| {
                PyValueError::new_err(format!(
                    "Failed to convert trajectory in second collection: {}",
                    e
                ))
            })
        })
        .collect::<Result<Vec<_>, _>>()?;

    let metric_inner = metric.inner;
    let n_a = trajectories_a.len();
    let n_b = trajectories_b.len();
    let distances = py
        .detach(|| {
            crate::distance::batch::cdist(
                &trajectories_a,
                &trajectories_b,
                &metric_inner,
                parallel,
                show_progress,
            )
        })
        .map_err(|e| PyValueError::new_err(format!("Failed to compute distances: {}", e)))?;

    // Move the flat buffer into NumPy without copying, then view it as (n_a, n_b) in
    // C order. This is the same row-major split as chunking by `n_b`, but it avoids
    // building an intermediate Vec<Vec<f64>> plus a second copy. `reshape` also
    // errors if the buffer length is not exactly n_a * n_b.
    let array = PyArray1::from_vec(py, distances)
        .reshape([n_a, n_b])
        .map_err(|e| PyValueError::new_err(format!("Failed to create 2D array: {}", e)))?;
    Ok(array.unbind())
}
// Defines the `stub_info` gatherer from pyo3-stub-gen, which collects the items
// annotated with `#[gen_stub_pyclass]`, `#[gen_stub_pymethods]` and
// `#[gen_stub_pyfunction]` so `.pyi` stubs can be generated.
// NOTE(review): presumably invoked by a separate stub-generation binary; confirm.
define_stub_info_gatherer!(stub_info);