use crate::binding::PyDpResult;
use crate::binding::sequence::types::PyTrajectoryType;
use crate::distance::base::TrajectoryCalculator;
use crate::distance::distance_type::DistanceType;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3_stub_gen::{define_stub_info_gatherer, derive::gen_stub_pyfunction};
use std::str::FromStr;
/// Python-facing entry point for the discrete Fréchet computation on two trajectories.
///
/// `t1` and `t2` are converted with `PyTrajectoryType::try_from`, `dist_type` is
/// parsed into a `DistanceType`, and the work is delegated to
/// `crate::distance::discret_frechet::discret_frechet` through a
/// `TrajectoryCalculator`. `use_full_matrix` (default `false` in the Python
/// signature) is forwarded unchanged to that routine.
///
/// # Errors
///
/// Propagates any error raised while converting `t1` or `t2`, and returns a
/// `ValueError` when `dist_type` cannot be parsed or names anything other than
/// the Euclidean distance.
#[gen_stub_pyfunction]
#[pyfunction(signature = (t1, t2, dist_type, use_full_matrix=false))]
pub fn discret_frechet(
    #[gen_stub(override_type(type_repr="typing.List[typing.List[float]] | numpy.ndarray", imports=("typing", "numpy")))]
    t1: &Bound<'_, PyAny>,
    #[gen_stub(override_type(type_repr="typing.List[typing.List[float]] | numpy.ndarray", imports=("typing", "numpy")))]
    t2: &Bound<'_, PyAny>,
    dist_type: String,
    use_full_matrix: bool,
) -> PyResult<PyDpResult> {
    // Trajectories are converted first so their errors take precedence over
    // distance-type errors.
    let first = PyTrajectoryType::try_from(t1)?;
    let second = PyTrajectoryType::try_from(t2)?;

    let parsed = match DistanceType::from_str(&dist_type) {
        Ok(kind) => kind,
        Err(_) => {
            return Err(PyValueError::new_err(format!(
                "Invalid distance type '{}'. Only 'euclidean' is supported for discret_frechet",
                dist_type
            )));
        }
    };

    if parsed == DistanceType::Euclidean {
        let calculator = TrajectoryCalculator::new(&first, &second, DistanceType::Euclidean);
        let inner =
            crate::distance::discret_frechet::discret_frechet(&calculator, use_full_matrix);
        Ok(PyDpResult { inner })
    } else {
        // A name that parses successfully but is not supported by this function.
        Err(PyValueError::new_err(format!(
            "Invalid distance type '{:?}'. Only 'euclidean' is supported for discret_frechet",
            parsed
        )))
    }
}
/// Python-facing entry point for the discrete Fréchet computation on a
/// precomputed distance matrix.
///
/// `distance_matrix` must be a two-dimensional, C-order contiguous NumPy array of
/// `float64`. Its buffer is borrowed read-only as a flat slice and passed, along
/// with both dimensions, to `PrecomputedDistanceCalculator`. `use_full_matrix`
/// (default `false` in the Python signature) is forwarded unchanged to
/// `crate::distance::discret_frechet::discret_frechet`.
///
/// # Errors
///
/// Returns a `ValueError` when the argument is not a 2D `float64` NumPy array or
/// when its memory layout is not C-order contiguous.
#[gen_stub_pyfunction]
#[pyfunction(signature = (distance_matrix, use_full_matrix=false))]
pub fn discret_frechet_with_matrix<'py>(
    #[gen_stub(override_type(type_repr="numpy.ndarray", imports=("numpy",)))]
    distance_matrix: &Bound<'py, PyAny>,
    use_full_matrix: bool,
) -> PyResult<PyDpResult> {
    use numpy::{PyArrayMethods, PyUntypedArrayMethods};

    let array = distance_matrix
        .downcast::<numpy::PyArray2<f64>>()
        .map_err(|_| {
            PyValueError::new_err("distance_matrix must be a 2D numpy array of float64 values")
        })?;
    let readonly_array = array.readonly();

    // The `PyArray2` downcast should already reject other ranks; this guard is
    // kept as a defensive check and avoids indexing into the shape.
    let (n0, n1) = match readonly_array.shape() {
        &[rows, cols] => (rows, cols),
        shape => {
            return Err(PyValueError::new_err(format!(
                "distance_matrix must be a 2D array, got shape {:?}",
                shape
            )));
        }
    };

    // `as_slice` yields the flat row-major buffer only when the view is in
    // standard (C-order contiguous) layout, so no `unsafe` raw-pointer slice
    // construction is needed.
    let view = readonly_array.as_array();
    let distance_matrix_slice = view
        .as_slice()
        .ok_or_else(|| PyValueError::new_err("distance_matrix must be contiguous (C-order)"))?;

    let calculator =
        crate::distance::base::PrecomputedDistanceCalculator::new(distance_matrix_slice, n0, n1);
    let result = crate::distance::discret_frechet::discret_frechet(&calculator, use_full_matrix);
    Ok(PyDpResult { inner: result })
}
// Invokes pyo3_stub_gen's `define_stub_info_gatherer!` macro with the name `stub_info`.
// NOTE(review): presumably this generates the `stub_info` function that collects the
// `#[gen_stub_pyfunction]` metadata for stub generation — confirm against pyo3_stub_gen docs.
define_stub_info_gatherer!(stub_info);