use crate::DType;
use numr::error::Result;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// Tuning parameters for a mean-shift clustering run.
///
/// Defaults (see the `Default` impl): no explicit bandwidth, `max_iter = 300`,
/// `tol = 1e-3`, and `bin_seeding = false`.
///
/// All fields are plain values, so the struct is cheap to clone and can be
/// compared with `==` (useful for caching and for asserting configurations in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct MeanShiftOptions {
    /// Kernel bandwidth (window radius).
    ///
    /// NOTE(review): `None` presumably asks the implementation to estimate a
    /// bandwidth from the data — confirm against the backend implementations.
    pub bandwidth: Option<f64>,
    /// Upper bound on the number of mean-shift iterations.
    ///
    /// NOTE(review): whether this limit applies per seed or to the whole run
    /// is implementation-defined — confirm.
    pub max_iter: usize,
    /// Convergence tolerance.
    ///
    /// NOTE(review): presumably compared against the shift magnitude between
    /// iterations — confirm against the implementations.
    pub tol: f64,
    /// Whether to use binned seeding for the initial kernel positions.
    ///
    /// NOTE(review): assumed to follow the common convention of seeding from
    /// discretised bins instead of every sample — confirm.
    pub bin_seeding: bool,
}
impl Default for MeanShiftOptions {
fn default() -> Self {
Self {
bandwidth: None,
max_iter: 300,
tol: 1e-3,
bin_seeding: false,
}
}
}
/// Output of a mean-shift clustering run.
///
/// NOTE(review): tensor shapes and dtypes are set by the implementation and
/// are not visible here. The field notes below are assumptions to confirm.
#[derive(Debug, Clone)]
pub struct MeanShiftResult<R>
where
    R: Runtime<DType = DType>,
{
    /// Cluster assignment for each input sample.
    ///
    /// Presumably a 1-D tensor of indices into `cluster_centers` — confirm.
    pub labels: Tensor<R>,
    /// Locations of the discovered cluster modes.
    ///
    /// Presumably `(n_clusters, n_features)` — confirm.
    pub cluster_centers: Tensor<R>,
    /// Number of iterations the algorithm ran.
    ///
    /// NOTE(review): whether this is a total or a per-seed maximum is
    /// implementation-defined — confirm.
    pub n_iter: usize,
}
/// Backend-provided mean-shift clustering.
///
/// Implemented per runtime `R`, so the same call runs on whichever device
/// the runtime targets.
pub trait MeanShiftAlgorithms<R>
where
    R: Runtime<DType = DType>,
{
    /// Clusters `data` with mean shift, configured by `options`.
    ///
    /// NOTE(review): `data` is presumably a 2-D `(n_samples, n_features)`
    /// tensor — confirm against the implementations.
    ///
    /// # Errors
    ///
    /// Returns an error if the implementation cannot complete the run. The
    /// exact failure conditions (invalid shape, dtype, options, backend
    /// failure) are defined by each implementor.
    fn mean_shift(
        &self,
        data: &Tensor<R>,
        options: &MeanShiftOptions,
    ) -> Result<MeanShiftResult<R>>;
}