use super::distance::{DistanceMetric, SquaredEuclidean};
use super::flat::DataRef;
use super::util;
use crate::error::{Error, Result};
use rand::prelude::*;
/// A pairwise clustering constraint between two data points, identified by row index.
///
/// Equality and hashing are structural: `MustLink(0, 1)` and `MustLink(1, 0)` compare
/// unequal even though they express the same requirement.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Constraint {
    /// The two points must be assigned to the same cluster.
    MustLink(usize, usize),
    /// The two points must be assigned to different clusters.
    CannotLink(usize, usize),
}
/// Constrained k-means (COP-KMeans) clusterer.
///
/// Build with `CopKmeans::new` or `CopKmeans::with_metric`, adjust with the `with_*`
/// builder methods, then call `fit_predict_constrained`.
#[derive(Debug, Clone)]
pub struct CopKmeans<D: DistanceMetric = SquaredEuclidean> {
/// Number of clusters; both constructors assert it is at least 1.
k: usize,
/// Maximum number of assign/update iterations (constructor default: 100).
max_iter: usize,
/// Relative convergence tolerance (constructor default: 1e-4). At fit time it is
/// multiplied by the data's mean variance and by `k` before being compared with the
/// summed squared centroid shift.
tol: f64,
/// RNG seed for reproducible runs; `None` seeds the RNG from the OS at fit time.
seed: Option<u64>,
/// Distance used for k-means++ seeding and for point-to-centroid assignment.
metric: D,
}
impl CopKmeans<SquaredEuclidean> {
    /// Creates a clusterer with `k` clusters using squared Euclidean distance.
    ///
    /// Defaults: at most 100 iterations, tolerance `1e-4`, and no fixed seed.
    ///
    /// # Panics
    /// Panics if `k` is zero.
    pub fn new(k: usize) -> Self {
        assert!(k > 0, "k must be at least 1");
        let metric = SquaredEuclidean;
        Self {
            metric,
            seed: None,
            tol: 1e-4,
            max_iter: 100,
            k,
        }
    }
}
impl<D: DistanceMetric> CopKmeans<D> {
pub fn with_metric(k: usize, metric: D) -> Self {
assert!(k > 0, "k must be at least 1");
Self {
k,
max_iter: 100,
tol: 1e-4,
seed: None,
metric,
}
}
pub fn with_max_iter(mut self, max_iter: usize) -> Self {
self.max_iter = max_iter;
self
}
pub fn with_tol(mut self, tol: f64) -> Self {
self.tol = tol;
self
}
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = Some(seed);
self
}
fn violates_cannot_link(
point_idx: usize,
cluster: usize,
labels: &[Option<usize>],
cannot_links: &[Vec<usize>],
) -> bool {
for &other in &cannot_links[point_idx] {
if labels[other] == Some(cluster) {
return true;
}
}
false
}
fn must_link_cluster(
point_idx: usize,
labels: &[Option<usize>],
must_links: &[Vec<usize>],
) -> Option<usize> {
for &other in &must_links[point_idx] {
if let Some(c) = labels[other] {
return Some(c);
}
}
None
}
fn transitive_closure(adjacency: &[Vec<usize>], n: usize) -> Vec<Vec<usize>> {
let mut visited = vec![false; n];
let mut groups = Vec::new();
for start in 0..n {
if visited[start] || adjacency[start].is_empty() {
continue;
}
let mut group = Vec::new();
let mut stack = vec![start];
while let Some(node) = stack.pop() {
if visited[node] {
continue;
}
visited[node] = true;
group.push(node);
for &neighbor in &adjacency[node] {
if !visited[neighbor] {
stack.push(neighbor);
}
}
}
groups.push(group);
}
groups
}
fn constrained_assign(
&self,
data: &(impl DataRef + ?Sized),
centroids: &[Vec<f32>],
must_links: &[Vec<usize>],
cannot_links: &[Vec<usize>],
order: &[usize],
) -> Result<Vec<Option<usize>>> {
let n = data.n();
let mut labels: Vec<Option<usize>> = vec![None; n];
let mut candidates: Vec<(usize, f32)> = Vec::with_capacity(self.k);
for &i in order {
if let Some(forced) = Self::must_link_cluster(i, &labels, must_links) {
if Self::violates_cannot_link(i, forced, &labels, cannot_links) {
return Err(Error::ConstraintViolation(format!(
"point {i}: must-link forces cluster {forced} but cannot-link forbids it"
)));
}
labels[i] = Some(forced);
continue;
}
candidates.clear();
candidates
.extend((0..self.k).map(|k| (k, self.metric.distance(data.row(i), ¢roids[k]))));
candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
let mut assigned = false;
for (k, _) in &candidates {
if !Self::violates_cannot_link(i, *k, &labels, cannot_links) {
labels[i] = Some(*k);
assigned = true;
break;
}
}
if !assigned {
return Err(Error::ConstraintViolation(format!(
"point {i}: no valid cluster assignment exists"
)));
}
}
Ok(labels)
}
}
impl<D: DistanceMetric> CopKmeans<D> {
pub fn fit_predict_constrained(
&self,
data: &(impl DataRef + ?Sized),
constraints: &[Constraint],
) -> Result<Vec<usize>> {
if data.n() == 0 {
return Err(Error::EmptyInput);
}
if self.k == 0 {
return Err(Error::InvalidParameter {
name: "k",
message: "must be at least 1",
});
}
let n = data.n();
let d = data.d();
if d == 0 {
return Err(Error::InvalidParameter {
name: "dimension",
message: "must be at least 1",
});
}
if self.k > n {
return Err(Error::InvalidClusterCount {
requested: self.k,
n_items: n,
});
}
for i in 0..n {
if data.row(i).len() != d {
return Err(Error::DimensionMismatch {
expected: d,
found: data.row(i).len(),
});
}
}
util::validate_finite(data)?;
for c in constraints {
let (a, b) = match c {
Constraint::MustLink(a, b) | Constraint::CannotLink(a, b) => (*a, *b),
};
if a >= n || b >= n {
return Err(Error::InvalidParameter {
name: "constraint index",
message: "exceeds dataset size",
});
}
}
let mut must_links: Vec<Vec<usize>> = vec![Vec::new(); n];
let mut cannot_links: Vec<Vec<usize>> = vec![Vec::new(); n];
for c in constraints {
match c {
Constraint::MustLink(a, b) => {
must_links[*a].push(*b);
must_links[*b].push(*a);
}
Constraint::CannotLink(a, b) => {
cannot_links[*a].push(*b);
cannot_links[*b].push(*a);
}
}
}
let must_link_groups = Self::transitive_closure(&must_links, n);
let mut must_links_closed: Vec<Vec<usize>> = vec![Vec::new(); n];
for group in &must_link_groups {
for &member in group {
for &other in group {
if other != member {
must_links_closed[member].push(other);
}
}
}
}
let must_links = must_links_closed;
for c in constraints {
if let Constraint::CannotLink(a, b) = c {
if must_links[*a].contains(b) {
return Err(Error::ConstraintViolation(format!(
"points {} and {} are both must-linked (transitively) and cannot-linked",
a, b
)));
}
}
}
let mut rng = match self.seed {
Some(s) => StdRng::seed_from_u64(s),
None => StdRng::from_os_rng(),
};
let mut centroids = util::kmeanspp_init(data, self.k, &self.metric, 2.0, &mut rng);
let mut new_centroids = vec![vec![0.0f32; d]; self.k];
let mut counts = vec![0usize; self.k];
let mut sums_f64 = vec![vec![0.0f64; d]; self.k];
let mut order: Vec<usize> = (0..n).collect();
let effective_tol = (self.tol * util::mean_variance(data) * self.k as f64) as f32;
for _ in 0..self.max_iter {
order.iter_mut().enumerate().for_each(|(i, v)| *v = i);
order.shuffle(&mut rng);
let labels =
self.constrained_assign(data, ¢roids, &must_links, &cannot_links, &order)?;
for c in &mut new_centroids {
c.fill(0.0);
}
counts.fill(0);
for s in &mut sums_f64 {
s.fill(0.0);
}
for (i, label) in labels.iter().enumerate() {
let k = label.expect("constrained_assign guarantees all labels are Some");
let row = data.row(i);
for j in 0..d {
sums_f64[k][j] += row[j] as f64;
}
counts[k] += 1;
}
for k in 0..self.k {
if counts[k] > 0 {
let divisor = counts[k] as f64;
for j in 0..d {
new_centroids[k][j] = (sums_f64[k][j] / divisor) as f32;
}
} else {
let idx = rng.random_range(0..n);
new_centroids[k] = data.row(idx).to_vec();
}
}
if self.metric.normalize_centroids() {
for c in &mut new_centroids {
let norm: f32 = c.iter().map(|&x| x * x).sum::<f32>().sqrt();
if norm > f32::EPSILON {
for val in c.iter_mut() {
*val /= norm;
}
}
}
}
let shift: f32 = centroids
.iter()
.zip(new_centroids.iter())
.flat_map(|(old, new)| old.iter().zip(new.iter()).map(|(a, b)| (a - b).powi(2)))
.sum();
std::mem::swap(&mut centroids, &mut new_centroids);
if shift < effective_tol {
break;
}
}
let mut order: Vec<usize> = (0..n).collect();
order.shuffle(&mut rng);
let labels =
self.constrained_assign(data, ¢roids, &must_links, &cannot_links, &order)?;
Ok(labels
.into_iter()
.map(|l| l.expect("final constrained_assign guarantees all labels are Some"))
.collect())
}
}
#[cfg(test)]
mod autotrait_tests {
    use super::*;

    /// Compiles only when `T` implements every listed auto trait; calling it is the check.
    fn require_auto_traits<T>()
    where
        T: Send + Sync + Sized + Unpin,
    {
    }

    #[test]
    fn cop_kmeans_is_send_sync() {
        use super::super::distance::Euclidean;

        require_auto_traits::<CopKmeans<SquaredEuclidean>>();
        require_auto_traits::<CopKmeans<Euclidean>>();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two tight, well-separated pairs shared by several tests.
    fn two_pairs() -> Vec<Vec<f32>> {
        vec![
            vec![0.0f32, 0.0],
            vec![0.1, 0.1],
            vec![10.0, 10.0],
            vec![10.1, 10.1],
        ]
    }

    #[test]
    fn must_link_forces_same_cluster() {
        let data = two_pairs();
        let constraints = vec![Constraint::MustLink(0, 1)];
        let labels = CopKmeans::new(2)
            .with_seed(42)
            .fit_predict_constrained(&data, &constraints)
            .unwrap();
        assert_eq!(
            labels[0], labels[1],
            "must-linked points should share a cluster"
        );
    }

    #[test]
    fn cannot_link_forces_different_clusters() {
        let data = two_pairs();
        let constraints = vec![Constraint::CannotLink(0, 1)];
        let labels = CopKmeans::new(2)
            .with_seed(42)
            .fit_predict_constrained(&data, &constraints)
            .unwrap();
        assert_ne!(
            labels[0], labels[1],
            "cannot-linked points should be in different clusters"
        );
    }

    #[test]
    fn must_link_transitive() {
        let data = vec![
            vec![0.0f32, 0.0],
            vec![5.0, 5.0],
            vec![0.1, 0.1],
            vec![10.0, 10.0],
        ];
        let constraints = vec![Constraint::MustLink(0, 1), Constraint::MustLink(1, 2)];
        let labels = CopKmeans::new(2)
            .with_seed(42)
            .fit_predict_constrained(&data, &constraints)
            .unwrap();
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[1], labels[2]);
    }

    #[test]
    fn infeasible_constraints_return_error() {
        let data = vec![vec![0.0f32, 0.0], vec![1.0, 1.0], vec![2.0, 2.0]];
        let constraints = vec![
            Constraint::CannotLink(0, 1),
            Constraint::CannotLink(0, 2),
            Constraint::CannotLink(1, 2),
        ];
        let result = CopKmeans::new(2)
            .with_seed(42)
            .fit_predict_constrained(&data, &constraints);
        // Match explicitly: a different error variant must fail the test, not slip through.
        match result {
            Err(Error::ConstraintViolation(msg)) => assert!(
                msg.contains("no valid cluster"),
                "error message should mention no valid cluster: {msg}"
            ),
            other => panic!("expected ConstraintViolation, got {other:?}"),
        }
    }

    #[test]
    fn conflicting_must_and_cannot_link_error() {
        let data = two_pairs();
        let constraints = vec![Constraint::MustLink(0, 1), Constraint::CannotLink(0, 1)];
        let result = CopKmeans::new(2)
            .with_seed(42)
            .fit_predict_constrained(&data, &constraints);
        assert!(
            matches!(result, Err(Error::ConstraintViolation(_))),
            "contradictory constraints should fail with ConstraintViolation, got {result:?}"
        );
    }

    #[test]
    fn no_constraints_matches_kmeans_structure() {
        let data = two_pairs();
        let labels = CopKmeans::new(2)
            .with_seed(42)
            .fit_predict_constrained(&data, &[])
            .unwrap();
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[2], labels[3]);
        assert_ne!(labels[0], labels[2]);
    }

    #[test]
    fn empty_input_error() {
        let result = CopKmeans::new(2)
            .with_seed(1)
            .fit_predict_constrained(&[] as &[Vec<f32>], &[]);
        assert!(
            matches!(result, Err(Error::EmptyInput)),
            "expected EmptyInput, got {result:?}"
        );
    }

    #[test]
    fn invalid_constraint_index_error() {
        let data = vec![vec![0.0f32, 0.0], vec![1.0, 1.0]];
        let constraints = vec![Constraint::MustLink(0, 5)];
        let result = CopKmeans::new(2)
            .with_seed(1)
            .fit_predict_constrained(&data, &constraints);
        assert!(
            matches!(
                result,
                Err(Error::InvalidParameter {
                    name: "constraint index",
                    ..
                })
            ),
            "expected InvalidParameter for the constraint index, got {result:?}"
        );
    }

    #[test]
    fn k_exceeding_point_count_error() {
        let data = vec![vec![0.0f32, 0.0], vec![1.0, 1.0]];
        let result = CopKmeans::new(3)
            .with_seed(1)
            .fit_predict_constrained(&data, &[]);
        assert!(
            matches!(
                result,
                Err(Error::InvalidClusterCount {
                    requested: 3,
                    n_items: 2
                })
            ),
            "expected InvalidClusterCount, got {result:?}"
        );
    }

    #[test]
    fn with_custom_metric() {
        use crate::cluster::distance::Euclidean;
        let data = two_pairs();
        let constraints = vec![Constraint::MustLink(0, 1), Constraint::CannotLink(0, 2)];
        let labels = CopKmeans::with_metric(2, Euclidean)
            .with_seed(42)
            .fit_predict_constrained(&data, &constraints)
            .unwrap();
        assert_eq!(labels[0], labels[1]);
        assert_ne!(labels[0], labels[2]);
    }

    #[test]
    fn transitive_contradiction_detected_early() {
        let data = vec![
            vec![0.0f32, 0.0],
            vec![0.1, 0.1],
            vec![0.2, 0.2],
            vec![10.0, 10.0],
        ];
        let constraints = vec![
            Constraint::MustLink(0, 1),
            Constraint::MustLink(1, 2),
            Constraint::CannotLink(0, 2),
        ];
        let result = CopKmeans::new(2)
            .with_seed(42)
            .fit_predict_constrained(&data, &constraints);
        match result {
            Err(Error::ConstraintViolation(msg)) => assert!(
                msg.contains("must-linked") && msg.contains("cannot-linked"),
                "error should mention both constraint types: {msg}"
            ),
            other => panic!("transitive contradiction should be caught, got {other:?}"),
        }
    }

    #[test]
    fn monotonicity_same_cluster_must_link() {
        let data = two_pairs();
        let labels_none = CopKmeans::new(2)
            .with_seed(42)
            .fit_predict_constrained(&data, &[])
            .unwrap();
        assert_eq!(labels_none[0], labels_none[1]);
        let labels_with = CopKmeans::new(2)
            .with_seed(42)
            .fit_predict_constrained(&data, &[Constraint::MustLink(0, 1)])
            .unwrap();
        assert_eq!(
            labels_none[0] == labels_none[2],
            labels_with[0] == labels_with[2]
        );
        assert_eq!(
            labels_none[0] == labels_none[3],
            labels_with[0] == labels_with[3]
        );
    }

    #[test]
    fn deterministic_with_seed() {
        let data = vec![
            vec![0.0f32, 0.0],
            vec![0.1, 0.1],
            vec![5.0, 5.0],
            vec![10.0, 10.0],
            vec![10.1, 10.1],
        ];
        let constraints = vec![Constraint::MustLink(0, 1), Constraint::CannotLink(0, 3)];
        let labels1 = CopKmeans::new(2)
            .with_seed(99)
            .fit_predict_constrained(&data, &constraints)
            .unwrap();
        let labels2 = CopKmeans::new(2)
            .with_seed(99)
            .fit_predict_constrained(&data, &constraints)
            .unwrap();
        assert_eq!(
            labels1, labels2,
            "same seed should produce identical results"
        );
    }

    #[test]
    fn nan_input_rejected() {
        let data = vec![vec![0.0, f32::NAN], vec![1.0, 1.0]];
        let result = CopKmeans::new(2)
            .with_seed(42)
            .fit_predict_constrained(&data, &[]);
        assert!(result.is_err());
    }

    #[test]
    fn inf_input_rejected() {
        let data = vec![vec![0.0, 0.0], vec![f32::INFINITY, 1.0]];
        let result = CopKmeans::new(2)
            .with_seed(42)
            .fit_predict_constrained(&data, &[]);
        assert!(result.is_err());
    }
}
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Two tight pairs of points separated by a random `gap` along the diagonal and
    /// translated by a random offset. The constraints must-link each pair and
    /// cannot-link the pairs to each other. A real strategy is used instead of
    /// `Just(..)`, so each case exercises different geometry rather than one fixed
    /// example.
    fn separated_data_and_constraints() -> impl Strategy<Value = (Vec<Vec<f32>>, Vec<Constraint>)> {
        (-1000.0f32..1000.0, -1000.0f32..1000.0, 50.0f32..500.0).prop_map(|(ox, oy, gap)| {
            (
                vec![
                    vec![ox, oy],
                    vec![ox + 0.1, oy + 0.1],
                    vec![ox + gap, oy + gap],
                    vec![ox + gap + 0.1, oy + gap + 0.1],
                ],
                vec![
                    Constraint::MustLink(0, 1),
                    Constraint::MustLink(2, 3),
                    Constraint::CannotLink(0, 2),
                ],
            )
        })
    }

    proptest! {
        #[test]
        fn must_links_satisfied((data, constraints) in separated_data_and_constraints()) {
            let labels = CopKmeans::new(2)
                .with_seed(42)
                .fit_predict_constrained(&data, &constraints)
                .unwrap();
            for c in &constraints {
                if let Constraint::MustLink(a, b) = c {
                    prop_assert_eq!(
                        labels[*a], labels[*b],
                        "must-link ({}, {}) violated", a, b
                    );
                }
            }
        }

        #[test]
        fn cannot_links_satisfied((data, constraints) in separated_data_and_constraints()) {
            let labels = CopKmeans::new(2)
                .with_seed(42)
                .fit_predict_constrained(&data, &constraints)
                .unwrap();
            for c in &constraints {
                if let Constraint::CannotLink(a, b) = c {
                    prop_assert_ne!(
                        labels[*a], labels[*b],
                        "cannot-link ({}, {}) violated", a, b
                    );
                }
            }
        }

        #[test]
        fn labels_in_range((data, constraints) in separated_data_and_constraints()) {
            let labels = CopKmeans::new(2)
                .with_seed(42)
                .fit_predict_constrained(&data, &constraints)
                .unwrap();
            for (i, &l) in labels.iter().enumerate() {
                prop_assert!(l < 2, "point {i}: label {l} >= k=2");
            }
        }
    }
}