use scirs2_core::ndarray::{Array2};
use scirs2_linalg::inv;
use crate::error::{Result, TransformError};
/// A Mahalanobis-style distance parameterised by a matrix `M`.
///
/// The distance between `a` and `b` is `sqrt((a - b)ᵀ M (a - b))`.
#[derive(Debug, Clone)]
pub struct MahalanobisDistance {
    /// Metric matrix stored as rows: `m[i][j]` is row `i`, column `j`.
    /// Expected to be square (`from_matrix` enforces this).
    pub m: Vec<Vec<f64>>,
}
impl MahalanobisDistance {
    /// Builds a metric from an explicit matrix given as a list of rows.
    ///
    /// Only squareness is validated. Symmetry and positive semi-definiteness are the
    /// caller's responsibility: a non-PSD matrix can produce negative quadratic forms,
    /// which [`dist_sq`](Self::dist_sq) clamps to zero.
    ///
    /// # Errors
    /// Returns `TransformError::InvalidInput` if any row's length differs from the
    /// number of rows.
    pub fn from_matrix(m: Vec<Vec<f64>>) -> Result<Self> {
        let d = m.len();
        if let Some((i, row)) = m.iter().enumerate().find(|(_, row)| row.len() != d) {
            return Err(TransformError::InvalidInput(format!(
                "Row {i} has {} cols but expected {d}",
                row.len()
            )));
        }
        Ok(MahalanobisDistance { m })
    }

    /// The `d × d` identity metric, under which the distance is Euclidean.
    pub fn identity(d: usize) -> Self {
        let m: Vec<Vec<f64>> = (0..d)
            .map(|i| (0..d).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
            .collect();
        MahalanobisDistance { m }
    }

    /// Dimension of the vectors this metric accepts (the number of matrix rows).
    pub fn dim(&self) -> usize {
        self.m.len()
    }

    /// Computes the matrix-vector product `M v`, one entry per matrix row.
    fn mat_vec(&self, v: &[f64]) -> Vec<f64> {
        self.m
            .iter()
            .map(|row| row.iter().zip(v).map(|(mij, vj)| mij * vj).sum())
            .collect()
    }

    /// Squared distance `(a - b)ᵀ M (a - b)`.
    ///
    /// Negative results are clamped to `0.0`; these can come from round-off or a
    /// non-PSD `M`. A NaN result (e.g. from NaN inputs) is returned as NaN instead of
    /// being reported as a zero distance.
    ///
    /// # Errors
    /// `TransformError::InvalidInput` if `a` or `b` does not have length [`dim`](Self::dim).
    pub fn dist_sq(&self, a: &[f64], b: &[f64]) -> Result<f64> {
        let d = self.m.len();
        if a.len() != d || b.len() != d {
            return Err(TransformError::InvalidInput(format!(
                "Vectors must have length {d}, got {} and {}",
                a.len(),
                b.len()
            )));
        }
        let diff: Vec<f64> = a.iter().zip(b).map(|(ai, bi)| ai - bi).collect();
        let sq: f64 = diff
            .iter()
            .zip(self.mat_vec(&diff))
            .map(|(di, mdi)| di * mdi)
            .sum();
        // `f64::max(NaN, 0.0)` returns 0.0 and would hide NaNs, so clamp only real negatives.
        Ok(if sq < 0.0 { 0.0 } else { sq })
    }

    /// Distance `sqrt((a - b)ᵀ M (a - b))`.
    ///
    /// # Errors
    /// Same as [`dist_sq`](Self::dist_sq).
    pub fn dist(&self, a: &[f64], b: &[f64]) -> Result<f64> {
        self.dist_sq(a, b).map(f64::sqrt)
    }

    /// Multiplies `x` by the metric matrix, returning `M x`.
    ///
    /// NOTE(review): this applies `M` itself. Euclidean distances between transformed
    /// points are therefore `sqrt((x - y)ᵀ MᵀM (x - y))`, not the distance [`dist`](Self::dist)
    /// computes; matching `dist` would need a factor `L` with `M = LᵀL`. Confirm this
    /// is the intended behaviour.
    ///
    /// # Errors
    /// `TransformError::InvalidInput` if `x` does not have length [`dim`](Self::dim).
    pub fn transform_point(&self, x: &[f64]) -> Result<Vec<f64>> {
        let d = self.m.len();
        if x.len() != d {
            return Err(TransformError::InvalidInput(format!(
                "Expected length {d}, got {}",
                x.len()
            )));
        }
        Ok(self.mat_vec(x))
    }

    /// Symmetric `n × n` matrix of distances between the rows of `data`, with a zero
    /// diagonal.
    ///
    /// Rows are only length-checked when they take part in a pair, so a single-row
    /// input is never validated.
    ///
    /// # Errors
    /// Propagates the dimension error from [`dist`](Self::dist) for the first bad pair.
    pub fn pairwise_distances(&self, data: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        let n = data.len();
        let mut dists = vec![vec![0.0f64; n]; n];
        for i in 0..n {
            for j in (i + 1)..n {
                let dij = self.dist(&data[i], &data[j])?;
                dists[i][j] = dij;
                dists[j][i] = dij;
            }
        }
        Ok(dists)
    }
}
/// One piece of pairwise side information for [`ITML::fit`].
///
/// Samples `i` and `j` (row indices into the training data) should be close
/// (`similar == true`) or far apart (`similar == false`).
#[derive(Debug, Clone)]
pub struct ITMLConstraint {
    /// Row index of the first sample.
    pub i: usize,
    /// Row index of the second sample.
    pub j: usize,
    /// Distance bound: an upper bound for similar pairs, a lower bound for dissimilar ones.
    pub bound: f64,
    /// `true` for a similarity (upper-bound) constraint, `false` for dissimilarity.
    pub similar: bool,
}

impl ITMLConstraint {
    /// Similarity constraint: the distance between `i` and `j` should be at most `upper`.
    pub fn similar(i: usize, j: usize, upper: f64) -> Self {
        Self {
            similar: true,
            bound: upper,
            i,
            j,
        }
    }

    /// Dissimilarity constraint: the distance between `i` and `j` should be at least `lower`.
    pub fn dissimilar(i: usize, j: usize, lower: f64) -> Self {
        Self {
            similar: false,
            bound: lower,
            i,
            j,
        }
    }
}
/// Hyper-parameters for Information-Theoretic Metric Learning (ITML).
///
/// Construct with [`ITML::new`] or `ITML::default()`, then call `fit`.
#[derive(Debug, Clone)]
pub struct ITML {
    /// Upper limit on the number of passes `fit` makes over the constraint list.
    pub max_iter: usize,
    /// Regularisation parameter; `fit` starts from the prior metric `I / gamma`.
    pub gamma: f64,
    /// Convergence tolerance `fit` uses to stop before `max_iter` passes.
    pub tol: f64,
}
/// Result of fitting [`ITML`] to a dataset and a list of constraints.
#[derive(Debug, Clone)]
pub struct ITMLModel {
    /// Learned `d × d` metric matrix.
    pub metric: Array2<f64>,
    /// One dual variable per constraint, in the order the constraints were passed to
    /// `fit` (empty when no constraints were given).
    pub dual_vars: Vec<f64>,
    /// Number of passes over the constraint list that `fit` performed.
    pub n_iter: usize,
}
impl Default for ITML {
fn default() -> Self {
ITML { max_iter: 100, gamma: 1.0, tol: 1e-3 }
}
}
impl ITML {
    /// Creates an ITML solver.
    ///
    /// * `max_iter`: maximum number of passes over the constraint list.
    /// * `gamma`: slack trade-off. Larger values enforce the constraint bounds more
    ///   strictly. The starting (prior) metric is `I / gamma`.
    /// * `tol`: stop once the relative L1 change of the dual variables between two
    ///   consecutive passes falls below this value.
    pub fn new(max_iter: usize, gamma: f64, tol: f64) -> Self {
        ITML { max_iter, gamma, tol }
    }

    /// Learns a Mahalanobis metric from pairwise constraints.
    ///
    /// Uses the cyclic Bregman (LogDet) projection algorithm of Davis et al. (2007),
    /// "Information-Theoretic Metric Learning", Algorithm 1.
    ///
    /// Starting from the prior `M0 = I / gamma`, each constraint is visited in turn.
    /// `M` receives a rank-one update `M <- M + beta * (M v)(M v)ᵀ` with `v = x[i] - x[j]`.
    /// Similar pairs are moved towards squared distance `<= bound²` and dissimilar pairs
    /// towards `>= bound²`; `gamma` controls how much slack each target may absorb. The
    /// step sizes are chosen so that, in exact arithmetic, every update keeps `M`
    /// positive definite. Dual variables stay non-negative.
    ///
    /// # Errors
    /// Returns `TransformError::InvalidInput` if:
    /// * `x` is empty,
    /// * `x` has zero-width or ragged rows,
    /// * a constraint index is out of range, or
    /// * a constraint bound is not a positive finite number.
    pub fn fit(&self, x: &[Vec<f64>], constraints: &[ITMLConstraint]) -> Result<ITMLModel> {
        let n = x.len();
        if n == 0 {
            return Err(TransformError::InvalidInput("Empty dataset".to_string()));
        }
        let d = x[0].len();
        if d == 0 {
            return Err(TransformError::InvalidInput(
                "Feature dimension must be > 0".to_string(),
            ));
        }
        // Ragged rows would make the difference vectors shorter than `d` and panic on
        // indexing below.
        if let Some((i, row)) = x.iter().enumerate().find(|(_, row)| row.len() != d) {
            return Err(TransformError::InvalidInput(format!(
                "Row {i}: expected {d} features, got {}",
                row.len()
            )));
        }
        for (k, c) in constraints.iter().enumerate() {
            if c.i >= n || c.j >= n {
                return Err(TransformError::InvalidInput(format!(
                    "Constraint {k}: indices ({}, {}) out of range for n={n}",
                    c.i, c.j
                )));
            }
            // The projection works with 1 / bound², so zero or non-finite bounds are unusable.
            if !(c.bound.is_finite() && c.bound > 0.0) {
                return Err(TransformError::InvalidInput(format!(
                    "Constraint {k}: bound must be positive and finite, got {}",
                    c.bound
                )));
            }
        }

        // `f64::max` also maps a NaN gamma to the floor value.
        let gamma = self.gamma.max(1e-12);
        let prior_val = 1.0 / gamma;
        let mut m = Array2::<f64>::zeros((d, d));
        for i in 0..d {
            m[[i, i]] = prior_val;
        }

        let nc = constraints.len();
        if nc == 0 {
            return Ok(ITMLModel { metric: m, dual_vars: vec![], n_iter: 0 });
        }

        // Slack-adjusted target for each pair's squared distance, initialised to bound².
        let mut target: Vec<f64> = constraints.iter().map(|c| c.bound * c.bound).collect();
        // Step-size factor gamma / (gamma + 1); it tends to 1 (hard constraints) as gamma → ∞.
        let gamma_proj = if gamma.is_infinite() {
            1.0
        } else {
            gamma / (gamma + 1.0)
        };
        let l2 = |v: &[f64]| v.iter().map(|e| e * e).sum::<f64>().sqrt();

        let mut lambda = vec![0.0f64; nc];
        let mut lambda_prev = vec![0.0f64; nc];
        let mut diff = vec![0.0f64; d];
        let mut mdiff = vec![0.0f64; d];
        let mut n_iter = 0usize;

        for _ in 0..self.max_iter {
            n_iter += 1;
            for (ci, con) in constraints.iter().enumerate() {
                for ((dv, a), b) in diff.iter_mut().zip(&x[con.i]).zip(&x[con.j]) {
                    *dv = a - b;
                }
                for (r, out) in mdiff.iter_mut().enumerate() {
                    *out = (0..d).map(|col| m[[r, col]] * diff[col]).sum();
                }
                let p: f64 = diff.iter().zip(&mdiff).map(|(a, b)| a * b).sum();
                // For coincident (or numerically degenerate) pairs M v ≈ 0, so no rank-one
                // update can change their distance; skip them.
                if p.is_nan() || p <= 1e-15 {
                    continue;
                }
                let inv_p = 1.0 / p;
                let inv_t = 1.0 / target[ci];
                // Capping alpha at lambda keeps the dual non-negative. A negative alpha means
                // the constraint is violated and M must move; a positive alpha undoes part of
                // an earlier move that is no longer needed.
                let (alpha, beta, new_inv_t) = if con.similar {
                    let alpha = lambda[ci].min(gamma_proj * (inv_p - inv_t));
                    (alpha, alpha / (1.0 - alpha * p), inv_t + alpha / gamma)
                } else {
                    let alpha = lambda[ci].min(gamma_proj * (inv_t - inv_p));
                    (alpha, -alpha / (1.0 + alpha * p), inv_t - alpha / gamma)
                };
                if alpha == 0.0 {
                    continue;
                }
                lambda[ci] -= alpha;
                target[ci] = 1.0 / new_inv_t;
                for a in 0..d {
                    for b in 0..d {
                        m[[a, b]] += beta * mdiff[a] * mdiff[b];
                    }
                }
            }

            // Converge on the relative L1 change of the duals. If the duals are all zero
            // before and after the pass, every constraint was already satisfied (or
            // degenerate), so further passes cannot change anything.
            let norm_sum = l2(&lambda) + l2(&lambda_prev);
            if norm_sum == 0.0 {
                break;
            }
            let change: f64 = lambda
                .iter()
                .zip(&lambda_prev)
                .map(|(a, b)| (a - b).abs())
                .sum();
            if change / norm_sum < self.tol {
                break;
            }
            lambda_prev.copy_from_slice(&lambda);
        }

        Ok(ITMLModel {
            metric: m,
            dual_vars: lambda,
            n_iter,
        })
    }
}
impl ITMLModel {
    /// Copies the learned metric into a [`MahalanobisDistance`].
    ///
    /// Copies the leading `nrows × nrows` block of `metric`.
    pub fn mahalanobis(&self) -> Result<MahalanobisDistance> {
        let size = self.metric.nrows();
        let mut rows = Vec::with_capacity(size);
        for r in 0..size {
            let mut row = Vec::with_capacity(size);
            for c in 0..size {
                row.push(self.metric[[r, c]]);
            }
            rows.push(row);
        }
        MahalanobisDistance::from_matrix(rows)
    }

    /// Multiplies every row of `x` by the metric matrix, returning `M xᵢ` per row.
    ///
    /// NOTE(review): like `MahalanobisDistance::transform_point`, this applies `M`
    /// itself rather than a factor `L` with `M = LᵀL`. Euclidean distances in the output
    /// are therefore not the learned Mahalanobis distances. Confirm this is intended.
    ///
    /// # Errors
    /// `TransformError::InvalidInput` for the first row whose length is not `metric.nrows()`.
    pub fn transform(&self, x: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        let d = self.metric.nrows();
        x.iter()
            .enumerate()
            .map(|(i, row)| {
                if row.len() != d {
                    return Err(TransformError::InvalidInput(format!(
                        "Row {i}: expected {d} features, got {}",
                        row.len()
                    )));
                }
                Ok((0..d)
                    .map(|a| (0..d).map(|b| self.metric[[a, b]] * row[b]).sum::<f64>())
                    .collect::<Vec<f64>>())
            })
            .collect()
    }

    /// Pairwise distances between the rows of `x` under the learned metric.
    ///
    /// # Errors
    /// Propagates errors from [`mahalanobis`](Self::mahalanobis) and from
    /// `MahalanobisDistance::pairwise_distances`.
    pub fn pairwise_distances(&self, x: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        self.mahalanobis()?.pairwise_distances(x)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mahalanobis_identity() {
        let mah = MahalanobisDistance::identity(3);
        let a = [1.0, 0.0, 0.0];
        let b = [0.0, 0.0, 0.0];
        let d = mah.dist(&a, &b).expect("dist");
        assert!((d - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_mahalanobis_scaled() {
        let m = vec![
            vec![4.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let mah = MahalanobisDistance::from_matrix(m).expect("from_matrix");
        let a = [1.0, 0.0, 0.0];
        let b = [0.0, 0.0, 0.0];
        let d = mah.dist(&a, &b).expect("dist");
        assert!((d - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_mahalanobis_pairwise() {
        let mah = MahalanobisDistance::identity(2);
        let data = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];
        let dists = mah.pairwise_distances(&data).expect("pairwise");
        assert_eq!(dists.len(), 3);
        assert!((dists[0][1] - 1.0).abs() < 1e-10);
        assert!((dists[0][2] - 1.0).abs() < 1e-10);
        assert!((dists[1][2] - 2.0_f64.sqrt()).abs() < 1e-9);
        for (i, row) in dists.iter().enumerate() {
            assert_eq!(row[i], 0.0);
            for (j, &v) in row.iter().enumerate() {
                assert_eq!(v, dists[j][i]);
            }
        }
    }

    #[test]
    fn test_from_matrix_rejects_non_square() {
        let m = vec![vec![1.0, 0.0], vec![0.0]];
        assert!(MahalanobisDistance::from_matrix(m).is_err());
    }

    #[test]
    fn test_transform_point_identity() {
        let mah = MahalanobisDistance::identity(2);
        let out = mah.transform_point(&[1.5, -2.0]).expect("transform_point");
        assert_eq!(out, vec![1.5, -2.0]);
        assert!(mah.transform_point(&[1.0]).is_err());
    }

    #[test]
    fn test_itml_basic() {
        let x: Vec<Vec<f64>> = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.1],
            vec![5.0, 5.0],
            vec![5.1, 5.2],
        ];
        let constraints = vec![
            ITMLConstraint::similar(0, 1, 1.0),
            ITMLConstraint::dissimilar(0, 2, 2.0),
            ITMLConstraint::similar(2, 3, 1.0),
        ];
        let itml = ITML::new(50, 1.0, 1e-3);
        let model = itml.fit(&x, &constraints).expect("ITML fit");
        assert!(model.n_iter > 0);
        assert_eq!(model.metric.nrows(), 2);
        assert_eq!(model.dual_vars.len(), constraints.len());
        let mah = model.mahalanobis().expect("mahalanobis");
        let d_sim = mah.dist(&x[0], &x[1]).expect("d_sim");
        let d_dis = mah.dist(&x[0], &x[2]).expect("d_dis");
        assert!(d_sim < d_dis, "similar pair {d_sim:.4} should be < dissimilar {d_dis:.4}");
    }

    #[test]
    fn test_itml_no_constraints() {
        let x = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let itml = ITML::default();
        let model = itml.fit(&x, &[]).expect("ITML no constraints");
        assert_eq!(model.n_iter, 0);
        assert!(model.dual_vars.is_empty());
        // The default gamma is 1.0, so the prior metric I / gamma is the identity.
        assert!((model.metric[[0, 0]] - 1.0).abs() < 1e-12);
        assert!((model.metric[[1, 1]] - 1.0).abs() < 1e-12);
        assert_eq!(model.metric[[0, 1]], 0.0);
    }

    #[test]
    fn test_itml_rejects_bad_inputs() {
        let itml = ITML::default();
        assert!(itml.fit(&[], &[]).is_err());
        let x = vec![vec![0.0, 0.0], vec![1.0, 1.0]];
        assert!(itml.fit(&x, &[ITMLConstraint::similar(0, 2, 1.0)]).is_err());
    }

    #[test]
    fn test_mahalanobis_dimension_error() {
        let mah = MahalanobisDistance::identity(3);
        let a = [1.0, 2.0];
        let b = [0.0, 0.0, 0.0];
        assert!(mah.dist(&a, &b).is_err());
    }
}