use crate::DType;
use crate::spatial::{validate_matching_dims, validate_points_2d, validate_points_dtype};
use numr::error::Result;
use numr::ops::DistanceOps;
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use crate::spatial::traits::distance::DistanceMetric;
/// Front-end for `cdist`: validates both point sets, then hands off to the client.
///
/// Checks run in this fixed order, and the first failure is returned:
/// 1. dtype of `x`, then of `y` (`validate_points_dtype`);
/// 2. rank of `x`, then of `y` (`validate_points_2d`);
/// 3. compatibility of the two shapes (`validate_matching_dims`).
///
/// NOTE(review): rows are presumably points and columns coordinates, SciPy-style.
/// The shape contract lives in the validators, so confirm it there.
///
/// # Errors
///
/// Returns any error produced by the validators above, or by `DistanceOps::cdist`.
pub fn cdist_impl<R, C>(
    client: &C,
    x: &Tensor<R>,
    y: &Tensor<R>,
    metric: DistanceMetric,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: DistanceOps<R> + RuntimeClient<R>,
{
    // Label attached to every validation error raised on this path.
    const OP: &str = "cdist";

    // Check all dtypes first, then all ranks, to keep the original error precedence.
    for points in [x, y] {
        validate_points_dtype(points.dtype(), OP)?;
    }
    for points in [x, y] {
        validate_points_2d(points.shape(), OP)?;
    }
    validate_matching_dims(x.shape(), y.shape(), OP)?;

    client.cdist(x, y, metric)
}
/// Front-end for `pdist`: validates the single point set, then hands off to the client.
///
/// The dtype is checked first (`validate_points_dtype`), then the rank
/// (`validate_points_2d`). The first failure is returned.
///
/// # Errors
///
/// Returns any error produced by the two validators, or by `DistanceOps::pdist`.
pub fn pdist_impl<R, C>(client: &C, x: &Tensor<R>, metric: DistanceMetric) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: DistanceOps<R> + RuntimeClient<R>,
{
    // Label attached to every validation error raised on this path.
    let op_name = "pdist";

    validate_points_dtype(x.dtype(), op_name)?;
    validate_points_2d(x.shape(), op_name)?;

    client.pdist(x, metric)
}
pub fn squareform_impl<R, C>(client: &C, condensed: &Tensor<R>, n: usize) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: DistanceOps<R> + RuntimeClient<R>,
{
client.squareform(condensed, n)
}
pub fn squareform_inverse_impl<R, C>(client: &C, square: &Tensor<R>) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: DistanceOps<R> + RuntimeClient<R>,
{
client.squareform_inverse(square)
}