use ndarray::{Array1, Array2, ArrayView2};
use rand::Rng;
pub mod clam;
pub use clam::{am_assign, am_contract, am_soft_assign, clam_loss};
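/// Debug-only check that a kernel bandwidth is finite and strictly positive.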
#[inline]
fn debug_assert_valid_bandwidth(sigma: f64) {
debug_assert!(
sigma.is_finite() && sigma > 0.0,
"sigma must be finite and > 0"
);
}
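/// Gaussian (RBF) kernel: k(x, y) = exp(-||x - y||^2 / (2 * sigma^2)).
///
/// `x` and `y` are zipped, so trailing elements of the longer slice are
/// ignored; callers are expected to pass equal-length vectors.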
pub fn rbf(x: &[f64], y: &[f64], sigma: f64) -> f64 {
debug_assert_valid_bandwidth(sigma);
let sq_dist: f64 = x
.iter()
.zip(y.iter())
.map(|(xi, yi)| (xi - yi).powi(2))
.sum();
(-sq_dist / (2.0 * sigma * sigma)).exp()
}
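/// Polynomial kernel: k(x, y) = (gamma * <x, y> + coef0)^degree.
/// `degree` is cast to `i32` for `powi`, so it should be modest in size.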
pub fn polynomial(x: &[f64], y: &[f64], degree: u32, gamma: f64, coef0: f64) -> f64 {
let dot: f64 = x.iter().zip(y.iter()).map(|(xi, yi)| xi * yi).sum();
(gamma * dot + coef0).powi(degree as i32)
}
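/// Linear kernel: k(x, y) = <x, y>, the plain dot product.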
pub fn linear(x: &[f64], y: &[f64]) -> f64 {
x.iter().zip(y.iter()).map(|(xi, yi)| xi * yi).sum()
}
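/// Laplacian kernel: k(x, y) = exp(-||x - y||_1 / sigma), using the L1 distance.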
pub fn laplacian(x: &[f64], y: &[f64], sigma: f64) -> f64 {
debug_assert_valid_bandwidth(sigma);
let l1_dist: f64 = x.iter().zip(y.iter()).map(|(xi, yi)| (xi - yi).abs()).sum();
(-l1_dist / sigma).exp()
}
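/// Epanechnikov kernel: max(0, 1 - u^2) with u = ||x - y|| / sigma.
/// The classical normalizing constant (3/4) is omitted so that k(x, x) = 1.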
pub fn epanechnikov(x: &[f64], y: &[f64], sigma: f64) -> f64 {
debug_assert_valid_bandwidth(sigma);
let sq_dist: f64 = x
.iter()
.zip(y.iter())
.map(|(xi, yi)| (xi - yi).powi(2))
.sum();
let u_sq = sq_dist / (sigma * sigma);
(1.0 - u_sq).max(0.0)
}
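/// Triangular kernel: max(0, 1 - u) with u = ||x - y|| / sigma,
/// decaying linearly to zero at distance `sigma`.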
pub fn triangle(x: &[f64], y: &[f64], sigma: f64) -> f64 {
debug_assert_valid_bandwidth(sigma);
let sq_dist: f64 = x
.iter()
.zip(y.iter())
.map(|(xi, yi)| (xi - yi).powi(2))
.sum();
let dist = sq_dist.sqrt();
let u = dist / sigma;
(1.0 - u).max(0.0)
}
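/// Cosine kernel: cos(pi/2 * u) with u = min(||x - y|| / sigma, 1).
/// Clamping `u` keeps the value at (numerically) zero beyond distance
/// `sigma`; the classical pi/4 normalizing constant is omitted.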
pub fn cosine(x: &[f64], y: &[f64], sigma: f64) -> f64 {
debug_assert_valid_bandwidth(sigma);
let sq_dist: f64 = x
.iter()
.zip(y.iter())
.map(|(xi, yi)| (xi - yi).powi(2))
.sum();
let dist = sq_dist.sqrt();
let u = (dist / sigma).min(1.0);
(std::f64::consts::FRAC_PI_2 * u).cos()
}
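/// Quartic (biweight) kernel: (1 - u^2)^2 inside the support,
/// computed as the square of the Epanechnikov kernel.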
pub fn quartic(x: &[f64], y: &[f64], sigma: f64) -> f64 {
let k = epanechnikov(x, y, sigma);
k * k
}
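/// Triweight kernel: (1 - u^2)^3 inside the support,
/// computed as the cube of the Epanechnikov kernel.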
pub fn triweight(x: &[f64], y: &[f64], sigma: f64) -> f64 {
let k = epanechnikov(x, y, sigma);
k * k * k
}
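/// Tricube kernel: max(0, 1 - u^3)^3 with u = ||x - y|| / sigma.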
pub fn tricube(x: &[f64], y: &[f64], sigma: f64) -> f64 {
debug_assert_valid_bandwidth(sigma);
let sq_dist: f64 = x
.iter()
.zip(y.iter())
.map(|(xi, yi)| (xi - yi).powi(2))
.sum();
let dist = sq_dist.sqrt();
let u = dist / sigma;
let term = (1.0 - u.powi(3)).max(0.0);
term * term * term
}
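/// Builds the n x n Gram matrix K[i, j] = kernel(data[i], data[j]).
/// Only the upper triangle is evaluated and then mirrored, so the supplied
/// kernel is assumed to be symmetric.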
pub fn kernel_matrix<F>(data: &[Vec<f64>], kernel: F) -> Array2<f64>
where
F: Fn(&[f64], &[f64]) -> f64,
{
let n = data.len();
let mut k = Array2::zeros((n, n));
for i in 0..n {
for j in i..n {
let kij = kernel(&data[i], &data[j]);
k[[i, j]] = kij;
            k[[j, i]] = kij;
        }
}
k
}
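/// Vectorized RBF Gram matrix over the rows of `points`, using the identity
/// ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 * <x_i, x_j> so the pairwise
/// squared distances come from a single matrix product. The `max(0.0)`
/// guards against small negative values from floating-point cancellation.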
pub fn rbf_kernel_matrix_ndarray(points: ArrayView2<'_, f64>, sigma: f64) -> Array2<f64> {
    debug_assert_valid_bandwidth(sigma);
let n = points.nrows();
let sigma_sq_2 = 2.0 * sigma * sigma;
let mut sq_norms = Array1::<f64>::zeros(n);
for i in 0..n {
let row = points.row(i);
sq_norms[i] = row.dot(&row);
}
let mut k = points.dot(&points.t());
for i in 0..n {
for j in 0..n {
let dist_sq = (sq_norms[i] + sq_norms[j] - 2.0 * k[[i, j]]).max(0.0);
k[[i, j]] = (-dist_sq / sigma_sq_2).exp();
}
}
k
}
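/// Biased (V-statistic) estimate of the squared MMD:
/// mean(k(x, x')) + mean(k(y, y')) - 2 * mean(k(x, y)), clamped at zero.
/// Returns 0.0 if either sample is empty.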
pub fn mmd_biased<F>(x: &[Vec<f64>], y: &[Vec<f64>], kernel: F) -> f64
where
F: Fn(&[f64], &[f64]) -> f64,
{
    if x.is_empty() || y.is_empty() {
        return 0.0;
    }
    let nx = x.len() as f64;
    let ny = y.len() as f64;
let mut kxx = 0.0;
for xi in x {
for xj in x {
kxx += kernel(xi, xj);
}
}
kxx /= nx * nx;
let mut kyy = 0.0;
for yi in y {
for yj in y {
kyy += kernel(yi, yj);
}
}
kyy /= ny * ny;
let mut kxy = 0.0;
for xi in x {
for yj in y {
kxy += kernel(xi, yj);
}
}
kxy /= nx * ny;
(kxx + kyy - 2.0 * kxy).max(0.0)
}
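/// Unbiased (U-statistic) estimate of the squared MMD: the diagonal terms of
/// the within-sample sums are excluded, so the estimate can be slightly
/// negative. Returns 0.0 unless both samples contain at least two points.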
pub fn mmd_unbiased<F>(x: &[Vec<f64>], y: &[Vec<f64>], kernel: F) -> f64
where
F: Fn(&[f64], &[f64]) -> f64,
{
let m = x.len();
let n = y.len();
if m < 2 || n < 2 {
return 0.0;
}
let mut kxx = 0.0;
for (i, xi) in x.iter().enumerate() {
for (j, xj) in x.iter().enumerate() {
if i != j {
kxx += kernel(xi, xj);
}
}
}
kxx /= (m * (m - 1)) as f64;
let mut kyy = 0.0;
for (i, yi) in y.iter().enumerate() {
for (j, yj) in y.iter().enumerate() {
if i != j {
kyy += kernel(yi, yj);
}
}
}
kyy /= (n * (n - 1)) as f64;
let mut kxy = 0.0;
for xi in x {
for yj in y {
kxy += kernel(xi, yj);
}
}
kxy /= (m * n) as f64;
kxx + kyy - 2.0 * kxy
}
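/// Permutation test for the null hypothesis that `x` and `y` are drawn from
/// the same distribution. Pools both samples, reshuffles the pool
/// (Fisher-Yates) `num_permutations` times, and recomputes the unbiased MMD
/// on each split. Returns `(observed_mmd, p_value)` with the add-one-smoothed
/// p-value (count + 1) / (num_permutations + 1).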
pub fn mmd_permutation_test<F>(
x: &[Vec<f64>],
y: &[Vec<f64>],
kernel: F,
num_permutations: usize,
) -> (f64, f64)
where
F: Fn(&[f64], &[f64]) -> f64 + Copy,
{
let observed_mmd = mmd_unbiased(x, y, kernel);
let mut pooled: Vec<&Vec<f64>> = x.iter().chain(y.iter()).collect();
let nx = x.len();
let mut rng = rand::rng();
let mut count_greater = 0usize;
for _ in 0..num_permutations {
for i in (1..pooled.len()).rev() {
let j = rng.random_range(0..=i);
pooled.swap(i, j);
}
let x_perm: Vec<Vec<f64>> = pooled[..nx].iter().map(|v| (*v).clone()).collect();
let y_perm: Vec<Vec<f64>> = pooled[nx..].iter().map(|v| (*v).clone()).collect();
let perm_mmd = mmd_unbiased(&x_perm, &y_perm, kernel);
if perm_mmd >= observed_mmd {
count_greater += 1;
}
}
let p_value = (count_greater as f64 + 1.0) / (num_permutations as f64 + 1.0);
(observed_mmd, p_value)
}
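/// Median heuristic for the RBF bandwidth: the median pairwise Euclidean
/// distance m, divided by sqrt(2) so that exp(-d^2 / (2 * sigma^2)) equals
/// exp(-d^2 / m^2). Falls back to 1.0 when fewer than two points are given.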
pub fn median_bandwidth(data: &[Vec<f64>]) -> f64 {
let n = data.len();
if n < 2 {
return 1.0;
}
let mut distances = Vec::new();
for i in 0..n {
for j in (i + 1)..n {
let sq_dist: f64 = data[i]
.iter()
.zip(data[j].iter())
.map(|(xi, xj)| (xi - xj).powi(2))
.sum();
distances.push(sq_dist.sqrt());
}
}
distances.sort_by(|a, b| a.total_cmp(b));
let median = distances[distances.len() / 2];
median / (2.0_f64).sqrt()
}
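/// Sum of kernel similarities between `v` and every stored memory,
/// i.e. an unnormalized kernel density estimate at `v`.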
pub fn kernel_sum<F>(v: &[f64], memories: &[Vec<f64>], kernel: F) -> f64
where
F: Fn(&[f64], &[f64]) -> f64,
{
memories.iter().map(|xi| kernel(v, xi)).sum()
}
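/// Log-sum-exp energy: E(v) = -log sum_mu exp(-beta/2 * ||v - x_mu||^2).
/// Evaluated with the max-shift trick for numerical stability; returns 0.0
/// for an empty memory set by convention.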
pub fn energy_lse(v: &[f64], memories: &[Vec<f64>], beta: f64) -> f64 {
if memories.is_empty() {
return 0.0;
}
let neg_half_beta = -0.5 * beta;
let log_terms: Vec<f64> = memories
.iter()
.map(|xi| {
let sq_dist: f64 = v
.iter()
.zip(xi.iter())
.map(|(vi, xii)| (vi - xii).powi(2))
.sum();
neg_half_beta * sq_dist
})
.collect();
let max_term = log_terms.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
if max_term.is_infinite() {
return f64::INFINITY;
}
let sum_exp: f64 = log_terms.iter().map(|&t| (t - max_term).exp()).sum();
-(max_term + sum_exp.ln())
}
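/// Gradient of `energy_lse`: grad E(v) = beta * sum_mu p_mu * (v - x_mu),
/// where p is the softmax of -beta/2 times the squared distances. The
/// gradient points away from nearby memories, so gradient *descent* moves
/// toward them.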
pub fn energy_lse_grad(v: &[f64], memories: &[Vec<f64>], beta: f64) -> Vec<f64> {
if memories.is_empty() || v.is_empty() {
return vec![0.0; v.len()];
}
let d = v.len();
let neg_half_beta = -0.5 * beta;
let sq_dists: Vec<f64> = memories
.iter()
.map(|xi| {
v.iter()
.zip(xi.iter())
.map(|(vi, xii)| (vi - xii).powi(2))
.sum()
})
.collect();
    let log_weights: Vec<f64> = sq_dists.iter().map(|&sq| neg_half_beta * sq).collect();
let max_log = log_weights
.iter()
.cloned()
.fold(f64::NEG_INFINITY, f64::max);
let exp_weights: Vec<f64> = log_weights.iter().map(|&w| (w - max_log).exp()).collect();
let sum_exp: f64 = exp_weights.iter().sum();
let softmax_weights: Vec<f64> = exp_weights.iter().map(|&w| w / sum_exp).collect();
let mut grad = vec![0.0; d];
for (mu, xi) in memories.iter().enumerate() {
let w = softmax_weights[mu];
for (i, (vi, xii)) in v.iter().zip(xi.iter()).enumerate() {
grad[i] += w * (vi - xii);
}
}
for g in &mut grad {
*g *= beta;
}
grad
}
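/// Log-sum of rectified (ReLU-style) similarity terms:
/// E(v) = -log sum_mu max(0, 1 - beta/2 * ||v - x_mu||^2).
/// Each term has compact support, so the energy is +infinity wherever `v`
/// lies outside the support of every memory.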
pub fn energy_lsr(v: &[f64], memories: &[Vec<f64>], beta: f64) -> f64 {
if memories.is_empty() {
return 0.0;
}
let half_beta = 0.5 * beta;
let sum: f64 = memories
.iter()
.map(|xi| {
let sq_dist: f64 = v
.iter()
.zip(xi.iter())
.map(|(vi, xii)| (vi - xii).powi(2))
.sum();
            (1.0 - half_beta * sq_dist).max(0.0)
        })
        .sum();
    if sum <= 0.0 {
        f64::INFINITY
    } else {
        -sum.ln()
    }
}
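/// Gradient of `energy_lsr`: (beta / S) * sum of (v - x_mu) over the active
/// memories, where S is the rectified similarity sum and a memory is active
/// when its term is positive. Outside the support of all memories the energy
/// is infinite, and the zero vector is returned by convention.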
pub fn energy_lsr_grad(v: &[f64], memories: &[Vec<f64>], beta: f64) -> Vec<f64> {
if memories.is_empty() || v.is_empty() {
return vec![0.0; v.len()];
}
let d = v.len();
let half_beta = 0.5 * beta;
let kernel_vals: Vec<f64> = memories
.iter()
.map(|xi| {
let sq_dist: f64 = v
.iter()
.zip(xi.iter())
.map(|(vi, xii)| (vi - xii).powi(2))
.sum();
(1.0 - half_beta * sq_dist).max(0.0)
})
.collect();
let sum: f64 = kernel_vals.iter().sum();
if sum <= 0.0 {
        return vec![0.0; d];
    }
let mut grad = vec![0.0; d];
for (mu, xi) in memories.iter().enumerate() {
if kernel_vals[mu] > 0.0 {
for (i, (vi, xii)) in v.iter().zip(xi.iter()).enumerate() {
grad[i] += vi - xii;
}
}
}
let scale = beta / sum;
for g in &mut grad {
*g *= scale;
}
grad
}
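/// One gradient-descent step: v <- v - learning_rate * grad.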
fn energy_descent_step(v: &mut [f64], grad: &[f64], learning_rate: f64) {
for (vi, gi) in v.iter_mut().zip(grad.iter()) {
*vi -= learning_rate * gi;
}
}
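/// Retrieves a memory by gradient descent on an energy landscape, starting
/// from `query`. Stops when the gradient norm drops below `tolerance` or
/// after `max_iters` steps; returns the final point and the number of
/// iterations performed.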
pub fn retrieve_memory<F>(
query: Vec<f64>,
memories: &[Vec<f64>],
energy_grad: F,
learning_rate: f64,
max_iters: usize,
tolerance: f64,
) -> (Vec<f64>, usize)
where
F: Fn(&[f64], &[Vec<f64>]) -> Vec<f64>,
{
let mut v = query;
for iter in 0..max_iters {
let grad = energy_grad(&v, memories);
let grad_norm: f64 = grad.iter().map(|g| g * g).sum::<f64>().sqrt();
if grad_norm < tolerance {
return (v, iter);
}
energy_descent_step(&mut v, &grad, learning_rate);
}
(v, max_iters)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rbf_self() {
let x = vec![1.0, 2.0, 3.0];
let k = rbf(&x, &x, 1.0);
assert!((k - 1.0).abs() < 1e-10, "k(x, x) should be 1 for RBF");
}
#[test]
fn test_rbf_distant() {
let x = vec![0.0, 0.0];
let y = vec![100.0, 100.0];
let k = rbf(&x, &y, 1.0);
assert!(k < 1e-10, "distant points should have ~0 similarity");
}
#[test]
fn test_polynomial() {
let x = vec![1.0, 2.0];
let y = vec![3.0, 4.0];
let k = polynomial(&x, &y, 2, 1.0, 1.0);
assert!((k - 144.0).abs() < 1e-10);
}
#[test]
fn test_mmd_same_distribution() {
let x = vec![vec![0.0], vec![0.1], vec![0.2]];
let y = vec![vec![0.05], vec![0.15], vec![0.25]];
let mmd = mmd_unbiased(&x, &y, |a, b| rbf(a, b, 1.0));
assert!(mmd < 0.1, "same distribution should have small MMD");
}
#[test]
fn test_mmd_different_distributions() {
let x = vec![vec![0.0], vec![0.1], vec![0.2]];
let y = vec![vec![10.0], vec![10.1], vec![10.2]];
let mmd = mmd_unbiased(&x, &y, |a, b| rbf(a, b, 1.0));
assert!(mmd > 0.5, "different distributions should have large MMD");
}
#[test]
fn test_mmd_non_negative() {
let x = vec![vec![0.0], vec![0.1], vec![0.2], vec![0.3]];
let y = vec![vec![10.0], vec![10.1], vec![10.2], vec![10.3]];
let mmd = mmd_unbiased(&x, &y, |a, b| rbf(a, b, 1.0));
assert!(
mmd >= 0.0,
"MMD should be non-negative for different distributions"
);
}
#[test]
fn test_kernel_matrix_symmetric() {
let data = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];
let k = kernel_matrix(&data, |x, y| rbf(x, y, 1.0));
for i in 0..3 {
for j in 0..3 {
assert!(
(k[[i, j]] - k[[j, i]]).abs() < 1e-10,
"kernel matrix should be symmetric"
);
}
}
}
#[test]
fn test_median_bandwidth_positive() {
let data = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];
let sigma = median_bandwidth(&data);
assert!(sigma > 0.0, "bandwidth should be positive");
}
#[test]
fn test_kernel_sum() {
let v = vec![0.0, 0.0];
let memories = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
let sum = kernel_sum(&v, &memories, |a, b| rbf(a, b, 1.0));
assert!((sum - 1.0).abs() < 0.01);
}
#[test]
fn test_energy_lse_at_memory() {
let memories = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
let e1 = energy_lse(&[0.0, 0.0], &memories, 1.0);
let e2 = energy_lse(&[5.0, 5.0], &memories, 1.0);
assert!(e1 < e2, "energy should be lower at stored memory");
}
#[test]
    fn test_energy_lse_grad_points_away_from_memory() {
        let memories = vec![vec![0.0, 0.0]];
        let v = vec![1.0, 0.0];
        let grad = energy_lse_grad(&v, &memories, 2.0);
        assert!(
            grad[0] > 0.0,
            "gradient should point away from memory, so descent moves toward it"
        );
}
#[test]
fn test_energy_lsr_finite_at_memory() {
let memories = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
let e = energy_lsr(&[0.0, 0.0], &memories, 1.0);
assert!(e.is_finite(), "energy should be finite at memory");
}
#[test]
fn test_energy_lsr_infinite_outside_support() {
let memories = vec![vec![0.0, 0.0]];
let e = energy_lsr(&[100.0, 100.0], &memories, 1.0);
assert!(e.is_infinite(), "energy should be infinite outside support");
}
#[test]
fn test_retrieve_memory_lse() {
let memories = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
let query = vec![1.0, 1.0];
let (retrieved, _iters) = retrieve_memory(
query,
&memories,
|v, m| energy_lse_grad(v, m, 2.0),
0.1,
100,
1e-6,
);
let dist_to_first: f64 = retrieved.iter().map(|x| x * x).sum::<f64>().sqrt();
assert!(
dist_to_first < 2.0,
"should retrieve near first memory, got dist {}",
dist_to_first
);
}
#[test]
    fn test_energy_lsr_grad_direction() {
        let memories = vec![vec![0.0, 0.0], vec![10.0, 10.0]];
        let beta = 2.0;
        let query = vec![0.1, 0.0];
        let grad = energy_lsr_grad(&query, &memories, beta);
        assert!(grad[0] > 0.0);
        assert!(grad[1].abs() < 1e-10);
}
#[test]
fn test_epanechnikov_compact_support() {
let x = vec![0.0, 0.0];
let y_close = vec![0.5, 0.0];
let y_far = vec![2.0, 0.0];
let k_close = epanechnikov(&x, &y_close, 1.0);
assert!(k_close > 0.0, "should be positive inside support");
let k_far = epanechnikov(&x, &y_far, 1.0);
assert!(
(k_far - 0.0).abs() < 1e-10,
"should be zero outside support"
);
}
#[test]
fn test_triangle_linear_decay() {
let x = vec![0.0];
let sigma = 1.0;
assert!((triangle(&x, &[0.0], sigma) - 1.0).abs() < 1e-10);
assert!((triangle(&x, &[0.5], sigma) - 0.5).abs() < 1e-10);
assert!((triangle(&x, &[1.0], sigma) - 0.0).abs() < 1e-10);
}
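    // Usage sketch: pick a bandwidth with the median heuristic over the
    // pooled sample, then run the permutation test. Only the bounds of the
    // p-value are asserted, since the exact value depends on the random
    // permutations drawn.
    #[test]
    fn test_permutation_test_p_value_bounds() {
        let x = vec![vec![0.0], vec![0.1], vec![0.2], vec![0.3]];
        let y = vec![vec![10.0], vec![10.1], vec![10.2], vec![10.3]];
        let pooled: Vec<Vec<f64>> = x.iter().chain(y.iter()).cloned().collect();
        let sigma = median_bandwidth(&pooled);
        let (mmd, p) = mmd_permutation_test(&x, &y, |a, b| rbf(a, b, sigma), 50);
        assert!(mmd.is_finite());
        assert!(p > 0.0 && p <= 1.0, "p-value should lie in (0, 1]");
    }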
}