use crate::error::{ClusteringError, Result};
use crate::subspace_advanced::lrr::LowRankRepresentation;
/// Configuration for the ORSC clustering front end.
///
/// The type only stores parameters; the actual work happens in
/// [`OrderedRobustSC::fit`], which hands the data to `LowRankRepresentation`
/// with a regularisation weight derived from `rank` and the data shape.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrderedRobustSC {
    /// Number of clusters requested from the underlying solver.
    pub n_clusters: usize,
    /// Rank parameter; a larger value yields a smaller regularisation weight
    /// in `fit`. Must be at least 1.
    pub rank: usize,
    /// Iteration cap passed through to the underlying solver.
    pub max_iter: usize,
}
impl OrderedRobustSC {
    /// Creates a configuration with the given cluster count and rank.
    ///
    /// `max_iter` starts at 100; use [`OrderedRobustSC::with_max_iter`] to
    /// change it. No validation happens here — invalid parameters are
    /// reported by [`OrderedRobustSC::fit`].
    #[must_use]
    pub fn new(n_clusters: usize, rank: usize) -> Self {
        Self {
            n_clusters,
            rank,
            max_iter: 100,
        }
    }

    /// Builder-style setter: returns `self` with `max_iter` replaced.
    #[must_use]
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Clusters `data` (one point per row) and returns the solver's labels.
    ///
    /// The regularisation weight handed to `LowRankRepresentation` is
    /// `1 / (sqrt(max(n, dim)) * sqrt(rank))`, where `n` is the number of
    /// points and `dim` the length of each point.
    ///
    /// # Errors
    ///
    /// Returns [`ClusteringError::InvalidInput`] when
    /// * `data` is empty,
    /// * `n_clusters` exceeds the number of points,
    /// * `rank` is zero,
    /// * `n_clusters` is zero, or
    /// * the rows of `data` do not all have the same length.
    ///
    /// Any error produced by the underlying `LowRankRepresentation::fit`
    /// call is propagated unchanged.
    pub fn fit(&self, data: &[Vec<f64>]) -> Result<Vec<usize>> {
        let n = data.len();
        if n == 0 {
            return Err(ClusteringError::InvalidInput(
                "input data must not be empty".to_string(),
            ));
        }
        if self.n_clusters > n {
            return Err(ClusteringError::InvalidInput(format!(
                "n_clusters ({}) exceeds number of data points ({})",
                self.n_clusters, n
            )));
        }
        if self.rank == 0 {
            return Err(ClusteringError::InvalidInput(
                "rank must be at least 1".to_string(),
            ));
        }
        // The checks below are deliberately placed after the pre-existing
        // ones so that inputs which were already rejected keep reporting the
        // same message.
        if self.n_clusters == 0 {
            return Err(ClusteringError::InvalidInput(
                "n_clusters must be at least 1".to_string(),
            ));
        }

        // `dim` is taken from the first row, so the weight below is only
        // meaningful when every row shares that length.
        let dim = data[0].len();
        if let Some((row, point)) = data
            .iter()
            .enumerate()
            .find(|(_, point)| point.len() != dim)
        {
            return Err(ClusteringError::InvalidInput(format!(
                "data point {} has dimension {}, expected {}",
                row,
                point.len(),
                dim
            )));
        }

        // `n >= 1` and `rank >= 1` here, so the denominator is never zero.
        let lambda = 1.0 / ((n.max(dim) as f64).sqrt() * (self.rank as f64).sqrt());
        LowRankRepresentation::new(self.n_clusters, lambda)
            .with_max_iter(self.max_iter)
            .fit(data)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Six 3-D points: three along the x-axis followed by three along the
    /// y-axis.
    fn two_subspace_data() -> Vec<Vec<f64>> {
        vec![
            vec![1.0, 0.0, 0.0],
            vec![2.0, 0.0, 0.0],
            vec![3.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 2.0, 0.0],
            vec![0.0, 3.0, 0.0],
        ]
    }

    /// Asserts that `result` is specifically an `InvalidInput` rejection.
    ///
    /// A bare `is_err()` would also pass if the call failed for an unrelated
    /// reason (for example inside the underlying solver), which would hide a
    /// regression in the parameter validation.
    fn assert_invalid_input(result: Result<Vec<usize>>) {
        assert!(
            matches!(&result, Err(ClusteringError::InvalidInput(_))),
            "expected ClusteringError::InvalidInput, got {:?}",
            result
        );
    }

    #[test]
    fn test_orsc_basic() {
        let data = two_subspace_data();
        let labels = OrderedRobustSC::new(2, 1)
            .fit(&data)
            .expect("ORSC fit should succeed");
        assert_eq!(labels.len(), 6);
        for &l in &labels {
            assert!(l < 2);
        }
    }

    #[test]
    fn test_orsc_empty_input() {
        let data: Vec<Vec<f64>> = vec![];
        assert_invalid_input(OrderedRobustSC::new(2, 1).fit(&data));
    }

    #[test]
    fn test_orsc_zero_rank() {
        let data = two_subspace_data();
        assert_invalid_input(OrderedRobustSC::new(2, 0).fit(&data));
    }

    #[test]
    fn test_orsc_n_clusters_exceeds_n() {
        let data = vec![vec![1.0, 0.0]];
        assert_invalid_input(OrderedRobustSC::new(5, 1).fit(&data));
    }

    #[test]
    fn test_orsc_various_ranks() {
        let data = two_subspace_data();
        for rank in 1..=3 {
            let labels = OrderedRobustSC::new(2, rank)
                .fit(&data)
                .expect("ORSC with various ranks should succeed");
            assert_eq!(labels.len(), 6);
            // Same label-range check as the basic test, for every rank.
            assert!(
                labels.iter().all(|&l| l < 2),
                "rank {} produced an out-of-range label: {:?}",
                rank,
                labels
            );
        }
    }
}