use crate::error::{ClusteringError, Result};
use crate::subspace_advanced::ssc::spectral_cluster_normalized;
/// Configuration for subspace clustering via a low-rank-representation style
/// coefficient matrix followed by normalized spectral clustering.
///
/// All fields are public, so a value can be built with a struct literal or
/// through the constructor and builder methods on the type.
#[derive(Debug, Clone, PartialEq)]
pub struct LowRankRepresentation {
    /// Number of clusters requested from the spectral clustering step.
    pub n_clusters: usize,
    /// Regularization weight; divided by the fixed penalty `rho` to obtain the
    /// row-norm shrinkage threshold used inside the solver.
    pub lambda: f64,
    /// Upper bound on the number of solver iterations.
    pub max_iter: usize,
    /// Convergence tolerance: iteration stops once the Frobenius norm of the
    /// change in the coefficient matrix between iterations drops below it.
    pub tol: f64,
}
impl LowRankRepresentation {
    /// Creates a configuration with the given cluster count and `lambda`.
    ///
    /// `max_iter` is initialised to 100 and `tol` to `1e-6`.
    #[must_use]
    pub fn new(n_clusters: usize, lambda: f64) -> Self {
        Self {
            n_clusters,
            lambda,
            max_iter: 100,
            tol: 1e-6,
        }
    }

    /// Returns the configuration with `max_iter` replaced (builder style).
    #[must_use]
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Clusters the rows of `data`.
    ///
    /// A coefficient matrix `Z` is computed from the Gram matrix of `data`,
    /// turned into the symmetric affinity `(|Z[i][j]| + |Z[j][i]|) / 2`, and
    /// handed to `spectral_cluster_normalized` together with the number of
    /// points and `n_clusters`. The value returned by that routine is
    /// forwarded unchanged (presumably one label per data point).
    ///
    /// # Errors
    ///
    /// Returns `ClusteringError::InvalidInput` when
    /// * `data` is empty,
    /// * `n_clusters` is zero or larger than the number of data points,
    /// * the rows of `data` do not all have the same length.
    ///
    /// Errors from the solver or from the spectral clustering step are
    /// propagated.
    pub fn fit(&self, data: &[Vec<f64>]) -> Result<Vec<usize>> {
        let n = data.len();
        if n == 0 {
            return Err(ClusteringError::InvalidInput(
                "input data must not be empty".to_string(),
            ));
        }
        if self.n_clusters == 0 {
            return Err(ClusteringError::InvalidInput(
                "n_clusters must be at least 1".to_string(),
            ));
        }
        if self.n_clusters > n {
            return Err(ClusteringError::InvalidInput(format!(
                "n_clusters ({}) exceeds number of data points ({})",
                self.n_clusters, n
            )));
        }
        // The Gram matrix is built with `zip`, which would silently truncate
        // dot products of rows with differing lengths; reject such input.
        let dim = data[0].len();
        if let Some((index, row)) = data
            .iter()
            .enumerate()
            .find(|(_, row)| row.len() != dim)
        {
            return Err(ClusteringError::InvalidInput(format!(
                "data point {} has {} features, expected {}",
                index,
                row.len(),
                dim
            )));
        }

        let z = self.compute_low_rank_representation(data)?;

        // Symmetrise the magnitudes of the coefficients into an affinity matrix.
        let affinity: Vec<Vec<f64>> = (0..n)
            .map(|i| {
                (0..n)
                    .map(|j| (z[i][j].abs() + z[j][i].abs()) / 2.0)
                    .collect()
            })
            .collect();

        spectral_cluster_normalized(&affinity, n, self.n_clusters)
    }

    /// Iteratively computes the `n x n` coefficient matrix `Z` from the Gram
    /// matrix `G` of `data`, using an auxiliary matrix `A` and an accumulator
    /// `Y` (all start at zero) with a fixed penalty `rho = 1`.
    ///
    /// Each iteration performs, in order:
    /// 1. `Z[i][j] = soft_threshold(G[i][j] + rho * A[i][j] - Y[i][j], 1 / rho)`
    /// 2. row-wise shrinkage: with `r = Z[i] + Y[i] / rho`, row `A[i]` becomes
    ///    `(1 - (lambda / rho) / ||r||) * r` if `||r|| > lambda / rho`,
    ///    otherwise all zeros
    /// 3. `Y += rho * (Z - A)`
    ///
    /// The loop ends after `max_iter` iterations or as soon as the Frobenius
    /// norm of the change in `Z` falls below `tol`. With `max_iter == 0` the
    /// zero matrix is returned.
    fn compute_low_rank_representation(&self, data: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        let n = data.len();
        let g = gram_matrix(data);
        let rho = 1.0;
        // Both thresholds depend only on constants, so compute them once.
        let thresh = 1.0 / rho;
        let row_thresh = self.lambda / rho;

        let mut z = vec![vec![0.0f64; n]; n];
        let mut a = vec![vec![0.0f64; n]; n];
        let mut y = vec![vec![0.0f64; n]; n];
        // Buffers reused across iterations to avoid per-iteration allocation.
        let mut z_prev = z.clone();
        let mut shifted = vec![0.0f64; n];

        for _ in 0..self.max_iter {
            z_prev.clone_from(&z);

            // Step 1: element-wise soft thresholding.
            for (i, z_row) in z.iter_mut().enumerate() {
                for (j, z_ij) in z_row.iter_mut().enumerate() {
                    let arg = g[i][j] + rho * a[i][j] - y[i][j];
                    *z_ij = soft_threshold_scalar(arg, thresh);
                }
            }

            // Step 2: shrink each row of `Z + Y / rho` by its Euclidean norm.
            for ((a_row, z_row), y_row) in a.iter_mut().zip(&z).zip(&y) {
                for ((slot, z_ij), y_ij) in shifted.iter_mut().zip(z_row).zip(y_row) {
                    *slot = z_ij + y_ij / rho;
                }
                let row_norm: f64 = shifted.iter().map(|x| x * x).sum::<f64>().sqrt();
                if row_norm > row_thresh {
                    let scale = 1.0 - row_thresh / row_norm;
                    for (a_ij, value) in a_row.iter_mut().zip(&shifted) {
                        *a_ij = scale * value;
                    }
                } else {
                    // Also taken when `row_norm` is NaN, as the comparison is false.
                    a_row.fill(0.0);
                }
            }

            // Step 3: accumulate the mismatch between Z and A.
            for ((y_row, z_row), a_row) in y.iter_mut().zip(&z).zip(&a) {
                for ((y_ij, z_ij), a_ij) in y_row.iter_mut().zip(z_row).zip(a_row) {
                    *y_ij += rho * (z_ij - a_ij);
                }
            }

            // Stopping rule: Frobenius norm of the change in Z this iteration.
            let z_change: f64 = z
                .iter()
                .zip(&z_prev)
                .flat_map(|(row, prev)| {
                    row.iter()
                        .zip(prev)
                        .map(|(current, old)| (current - old).powi(2))
                })
                .sum::<f64>()
                .sqrt();
            if z_change < self.tol {
                break;
            }
        }
        Ok(z)
    }
}
/// Builds the symmetric matrix of pairwise dot products between the rows of
/// `data`.
///
/// Entry `(i, j)` holds the dot product of `data[i]` and `data[j]`. Only the
/// upper triangle (including the diagonal) is computed; every value is then
/// mirrored into the lower triangle. When two rows differ in length the dot
/// product covers only their common prefix, because the rows are zipped.
fn gram_matrix(data: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let count = data.len();
    let mut gram = vec![vec![0.0f64; count]; count];
    for (row, lhs) in data.iter().enumerate() {
        for (col, rhs) in data.iter().enumerate().skip(row) {
            let inner = lhs.iter().zip(rhs).map(|(x, y)| x * y).sum::<f64>();
            gram[row][col] = inner;
            gram[col][row] = inner;
        }
    }
    gram
}
/// Scalar soft-thresholding (shrinkage) of `z` by `thresh`.
///
/// Values above `thresh` are reduced by `thresh`, values below `-thresh` are
/// increased by `thresh`, and everything else (including NaN, for which both
/// comparisons are false) maps to `0.0`.
#[inline]
fn soft_threshold_scalar(z: f64, thresh: f64) -> f64 {
    if z > thresh {
        return z - thresh;
    }
    if z < -thresh {
        return z + thresh;
    }
    0.0
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by the floating-point comparisons below.
    const EPS: f64 = 1e-10;

    /// Returns `true` when `actual` is within `EPS` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    /// Six 3-D points: three multiples of the x unit vector followed by three
    /// multiples of the y unit vector.
    fn two_subspace_data() -> Vec<Vec<f64>> {
        let scales = [1.0, 2.0, 3.0];
        let along_x = scales.iter().map(|&s| vec![s, 0.0, 0.0]);
        let along_y = scales.iter().map(|&s| vec![0.0, s, 0.0]);
        along_x.chain(along_y).collect()
    }

    /// Fitting succeeds and yields one in-range label per point.
    #[test]
    fn test_lrr_basic() {
        let points = two_subspace_data();
        let model = LowRankRepresentation::new(2, 0.1);
        let labels = model.fit(&points).expect("LRR fit should succeed");
        assert_eq!(labels.len(), 6);
        for label in labels.iter().copied() {
            assert!(label < 2);
        }
    }

    /// An empty data set is rejected.
    #[test]
    fn test_lrr_empty_input() {
        let points: Vec<Vec<f64>> = Vec::new();
        let outcome = LowRankRepresentation::new(2, 0.1).fit(&points);
        assert!(outcome.is_err());
    }

    /// Requesting more clusters than there are points is rejected.
    #[test]
    fn test_lrr_n_clusters_exceeds_n() {
        let points = vec![vec![1.0, 0.0]];
        let outcome = LowRankRepresentation::new(5, 0.1).fit(&points);
        assert!(outcome.is_err());
    }

    /// Spot-checks selected Gram matrix entries.
    #[test]
    fn test_gram_matrix() {
        let points = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let gram = gram_matrix(&points);
        let expected = [
            (0, 0, 1.0),
            (0, 1, 0.0),
            (0, 2, 1.0),
            (1, 2, 1.0),
            (2, 2, 2.0),
        ];
        for &(row, col, value) in expected.iter() {
            assert!(close(gram[row][col], value));
        }
    }

    /// Covers the positive, negative and dead-zone branches.
    #[test]
    fn test_soft_threshold_scalar() {
        assert!(close(soft_threshold_scalar(0.5, 0.3), 0.2));
        assert!(close(soft_threshold_scalar(-0.5, 0.3), -0.2));
        assert_eq!(soft_threshold_scalar(0.1, 0.3), 0.0);
    }
}