/// Greedily selects `k` distinct indices out of `n` points so that the
/// uniform distribution on the selection approximately minimizes the squared
/// MMD to the uniform distribution on all `n` points.
///
/// `gram` is the row-major `n x n` kernel matrix (`gram[i * n + j] = K[i, j]`).
/// With `S` the current selection and `s = |S| + 1`, each step appends the
/// unselected candidate `c` that minimizes
///
/// ```text
/// (1/s^2) * sum_{a,b in S+c} K[a,b]  -  (2/s) * sum_{a in S+c} colmean[a]
/// ```
///
/// where `colmean[a]` is the mean of column `a`. For a symmetric `gram`, this
/// is the squared MMD of `S+c` minus the all-pairs term, which does not
/// depend on the candidate. Running sums are updated incrementally, so each
/// step costs `O(n)` and the whole call `O(n^2 + k * n)`.
///
/// Ties go to the smallest index. If no candidate has an objective below
/// `+inf` (e.g. `gram` contains NaN or infinities), the smallest unselected
/// index is taken instead. The result is therefore always `k` distinct indices.
///
/// # Panics
///
/// Panics if `k > n` or `gram.len() != n * n`.
pub fn kernel_thin(gram: &[f64], n: usize, k: usize) -> Vec<usize> {
    assert!(k <= n, "k ({k}) must be <= n ({n})");
    assert_eq!(gram.len(), n * n, "gram must be n*n");
    if k == 0 {
        return Vec::new();
    }

    // Mean kernel value of each point against every point (column means).
    let mut col_mean = vec![0.0; n];
    for (j, mean) in col_mean.iter_mut().enumerate() {
        let mut s = 0.0;
        for i in 0..n {
            s += gram[i * n + j];
        }
        *mean = s / n as f64;
    }

    let mut selected = Vec::with_capacity(k);
    let mut in_set = vec![false; n];
    // sum_cross[c] = sum over selected a of K[a, c].
    let mut sum_cross = vec![0.0; n];
    // sum_within = sum over selected a, b of K[a, b].
    let mut sum_within = 0.0;
    // cross_mean_sum = sum over selected a of col_mean[a].
    let mut cross_mean_sum = 0.0;

    for step in 0..k {
        let s_new = (step + 1) as f64;
        let s_new_sq = s_new * s_new;

        let mut best_idx = None;
        let mut best_obj = f64::INFINITY;
        for c in (0..n).filter(|&c| !in_set[c]) {
            // Adding c contributes K[c,c] plus both K[a,c] and K[c,a] for every selected a.
            let new_within = sum_within + 2.0 * sum_cross[c] + gram[c * n + c];
            let new_cross_mean = cross_mean_sum + col_mean[c];
            let obj = new_within / s_new_sq - 2.0 * new_cross_mean / s_new;
            if obj < best_obj {
                best_obj = obj;
                best_idx = Some(c);
            }
        }
        // Non-finite objectives never beat +inf; fall back to the first free
        // index instead of indexing with a sentinel.
        let best_idx = best_idx
            .or_else(|| in_set.iter().position(|&taken| !taken))
            .expect("k <= n leaves at least one unselected candidate");

        selected.push(best_idx);
        in_set[best_idx] = true;
        // Must read sum_cross before it is updated with best_idx's own row.
        sum_within += 2.0 * sum_cross[best_idx] + gram[best_idx * n + best_idx];
        cross_mean_sum += col_mean[best_idx];
        let row = &gram[best_idx * n..(best_idx + 1) * n];
        for (acc, &g) in sum_cross.iter_mut().zip(row) {
            *acc += g;
        }
    }
    selected
}
/// Kernel herding: deterministically picks `k` indices, with replacement, so
/// that the running empirical mean embedding tracks the mean embedding of all
/// `n` points.
///
/// `gram` is the row-major `n x n` kernel matrix (`gram[i * n + j] = K[i, j]`).
/// With `mu[j]` the mean of column `j`, step `t = 1, 2, ..., k` selects
///
/// ```text
/// argmax_j  mu[j] - (1/t) * sum_{a already selected} K[a, j]
/// ```
///
/// Ties go to the smallest index, and the same index may appear more than
/// once in the output.
///
/// # Panics
///
/// Panics if `n == 0` (checked even when `k == 0`) or if `gram.len() != n * n`.
pub fn kernel_herd(gram: &[f64], n: usize, k: usize) -> Vec<usize> {
    assert!(n > 0, "n must be > 0");
    assert_eq!(gram.len(), n * n, "gram must be n*n");
    if k == 0 {
        return Vec::new();
    }

    // Target mean embedding evaluated at every point.
    let target: Vec<f64> = (0..n)
        .map(|col| (0..n).fold(0.0_f64, |acc, row| acc + gram[row * n + col]) / n as f64)
        .collect();

    // herd_sum[j] accumulates K[a, j] over every index a chosen so far.
    let mut herd_sum = vec![0.0_f64; n];
    let mut picks = Vec::with_capacity(k);

    for t in 1..=k {
        let denom = t as f64;
        let mut winner = 0;
        let mut winner_score = f64::NEG_INFINITY;
        for (j, (&m, &acc)) in target.iter().zip(&herd_sum).enumerate() {
            let score = m - acc / denom;
            if score > winner_score {
                winner_score = score;
                winner = j;
            }
        }
        picks.push(winner);
        let row = &gram[winner * n..(winner + 1) * n];
        herd_sum
            .iter_mut()
            .zip(row)
            .for_each(|(acc, &g)| *acc += g);
    }
    picks
}
/// Squared maximum mean discrepancy between the uniform distribution on the
/// points in `subset` and the uniform distribution on all `n` points, read
/// from a row-major `n x n` Gram matrix (`gram[i * n + j] = K[i, j]`).
///
/// Evaluates
///
/// ```text
/// MMD^2 = mean_{a,b in S} K[a,b] - 2 * mean_{a in S, j} K[a,j] + mean_{i,j} K[i,j]
/// ```
///
/// Indices in `subset` may repeat; each occurrence is weighted equally.
/// An empty `subset` yields `0.0`.
///
/// # Panics
///
/// Panics with an out-of-bounds index if a computed position
/// `a * n + b` falls outside `gram`.
pub fn mmd_sq_from_gram(gram: &[f64], n: usize, subset: &[usize]) -> f64 {
    if subset.is_empty() {
        return 0.0;
    }
    let m = subset.len() as f64;
    let total = n as f64;
    let at = |a: usize, b: usize| gram[a * n + b];

    // Subset-vs-subset term: average kernel value over ordered pairs of S.
    let self_term = subset
        .iter()
        .fold(0.0_f64, |acc, &a| subset.iter().fold(acc, |acc, &b| acc + at(a, b)))
        / (m * m);

    // Subset-vs-population term, doubled as in the MMD expansion.
    let cross_sum = subset
        .iter()
        .fold(0.0_f64, |acc, &a| (0..n).fold(acc, |acc, j| acc + at(a, j)));
    let cross_term = 2.0 * cross_sum / (m * total);

    // Population-vs-population term: average of every Gram entry, row-major.
    let pop_term = (0..n * n).fold(0.0_f64, |acc, idx| acc + gram[idx]) / (total * total);

    self_term - cross_term + pop_term
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Gaussian-kernel Gram matrix for the points 0, 1, ..., n-1 on a line,
    /// with bandwidth n/2, stored row-major.
    fn simple_gram(n: usize) -> Vec<f64> {
        let sigma = (n as f64) / 2.0;
        let mut g = vec![0.0; n * n];
        for i in 0..n {
            for j in 0..n {
                let d = (i as f64 - j as f64).powi(2);
                g[i * n + j] = (-d / (2.0 * sigma * sigma)).exp();
            }
        }
        g
    }

    #[test]
    fn thin_indices_unique_and_bounded() {
        let n = 20;
        let k = 5;
        let gram = simple_gram(n);
        let sel = kernel_thin(&gram, n, k);
        assert_eq!(sel.len(), k);
        for &idx in &sel {
            assert!(idx < n);
        }
        let mut sorted = sel.clone();
        sorted.sort();
        sorted.dedup();
        assert_eq!(sorted.len(), k);
    }

    #[test]
    fn thin_k_equals_n() {
        let n = 8;
        let gram = simple_gram(n);
        let sel = kernel_thin(&gram, n, n);
        assert_eq!(sel.len(), n);
        let mut sorted = sel.clone();
        sorted.sort();
        assert_eq!(sorted, (0..n).collect::<Vec<_>>());
    }

    #[test]
    fn thin_k_zero() {
        let gram = simple_gram(5);
        let sel = kernel_thin(&gram, 5, 0);
        assert!(sel.is_empty());
    }

    #[test]
    #[should_panic(expected = "must be <= n")]
    fn thin_rejects_k_above_n() {
        let gram = simple_gram(3);
        kernel_thin(&gram, 3, 4);
    }

    #[test]
    fn thin_beats_endpoints() {
        let n = 30;
        let k = 5;
        let gram = simple_gram(n);
        let thinned = kernel_thin(&gram, n, k);
        let first_k: Vec<usize> = (0..k).collect();
        let mmd_thin = mmd_sq_from_gram(&gram, n, &thinned);
        let mmd_first = mmd_sq_from_gram(&gram, n, &first_k);
        assert!(
            mmd_thin <= mmd_first + 1e-12,
            "thinned MMD^2 ({mmd_thin}) should be <= first-k MMD^2 ({mmd_first})"
        );
    }

    #[test]
    fn thin_k1_picks_closest_to_mean() {
        let n = 11;
        let gram = simple_gram(n);
        let sel = kernel_thin(&gram, n, 1);
        assert_eq!(sel.len(), 1);
        assert_eq!(sel[0], 5, "k=1 should select the center point (index 5)");
    }

    #[test]
    fn herd_correct_length() {
        let n = 10;
        let k = 7;
        let gram = simple_gram(n);
        let sel = kernel_herd(&gram, n, k);
        assert_eq!(sel.len(), k);
        for &idx in &sel {
            assert!(idx < n);
        }
    }

    #[test]
    fn herd_k_zero() {
        let gram = simple_gram(5);
        assert!(kernel_herd(&gram, 5, 0).is_empty());
    }

    #[test]
    fn herd_allows_duplicates_when_needed() {
        let n = 3;
        let k = 6;
        let gram = simple_gram(n);
        let sel = kernel_herd(&gram, n, k);
        assert_eq!(sel.len(), k);
        assert!(sel.iter().all(|&idx| idx < n));
        // k > n, so a valid herd must repeat at least one index.
        let mut unique = sel.clone();
        unique.sort_unstable();
        unique.dedup();
        assert!(unique.len() < k, "expected repeated indices, got {sel:?}");
    }

    #[test]
    fn herd_round_robin_on_identity_kernel() {
        // Identity kernel: every point is equally close to the mean, so
        // herding cycles through indices in order (ties go to the smallest).
        let identity = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        assert_eq!(kernel_herd(&identity, 3, 6), vec![0, 1, 2, 0, 1, 2]);
    }

    #[test]
    fn herd_beats_single_point() {
        let n = 20;
        let k = 5;
        let gram = simple_gram(n);
        let herded = kernel_herd(&gram, n, k);
        let mut unique_herded: Vec<usize> = herded.clone();
        unique_herded.sort();
        unique_herded.dedup();
        let single_point = vec![herded[0]];
        if unique_herded.len() > 1 {
            let mmd_herd = mmd_sq_from_gram(&gram, n, &unique_herded);
            let mmd_single = mmd_sq_from_gram(&gram, n, &single_point);
            assert!(
                mmd_herd <= mmd_single + 1e-12,
                "herded MMD^2 ({mmd_herd}) should be <= single-point MMD^2 ({mmd_single})"
            );
        }
    }

    #[test]
    fn mmd_sq_full_set_is_zero() {
        let n = 10;
        let gram = simple_gram(n);
        let all: Vec<usize> = (0..n).collect();
        let mmd = mmd_sq_from_gram(&gram, n, &all);
        assert!(mmd.abs() < 1e-12, "MMD^2(X, X) should be 0, got {mmd}");
    }

    #[test]
    fn mmd_sq_empty_subset_is_zero() {
        let gram = simple_gram(4);
        assert_eq!(mmd_sq_from_gram(&gram, 4, &[]), 0.0);
    }

    #[test]
    fn mmd_sq_hand_computed() {
        // K = [[2, 1], [1, 3]], S = {1}: 3 - 2 * (1 + 3) / 2 + 7 / 4 = 0.75.
        let gram = [2.0, 1.0, 1.0, 3.0];
        assert_eq!(mmd_sq_from_gram(&gram, 2, &[1]), 0.75);
    }
}