use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_linalg::eigh;
use crate::error::{Result, TransformError};
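/// Squared Mahalanobis distance (x - y)^T M (x - y).
///
/// `m` is assumed to be a symmetric positive semidefinite d x d matrix and
/// `x`, `y` length-d vectors; the result is clamped at zero to absorb small
/// negative values caused by rounding error.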
pub fn mahalanobis_sq(x: &ArrayView1<f64>, y: &ArrayView1<f64>, m: &ArrayView2<f64>) -> f64 {
let diff: Array1<f64> = x.iter().zip(y.iter()).map(|(a, b)| a - b).collect();
let md: f64 = m.outer_iter()
.zip(diff.iter())
.map(|(row, &di)| {
let mv_i: f64 = row.iter().zip(diff.iter()).map(|(mij, &dj)| mij * dj).sum();
di * mv_i
})
.sum();
md.max(0.0)
}
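/// Mahalanobis distance, the square root of [`mahalanobis_sq`].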
pub fn mahalanobis(x: &ArrayView1<f64>, y: &ArrayView1<f64>, m: &ArrayView2<f64>) -> f64 {
mahalanobis_sq(x, y, m).sqrt()
}
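/// Symmetric n x n matrix of pairwise Mahalanobis distances between the rows of `x`.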
pub fn pairwise_mahalanobis(x: &ArrayView2<f64>, m: &ArrayView2<f64>) -> Array2<f64> {
let n = x.nrows();
let mut dist = Array2::<f64>::zeros((n, n));
for i in 0..n {
for j in (i + 1)..n {
let d = mahalanobis(&x.row(i), &x.row(j), m);
dist[[i, j]] = d;
dist[[j, i]] = d;
}
}
dist
}
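/// Projects the rows of `x` with the factor `l`, returning X * L^T
/// (one embedded row of length `l.nrows()` per sample).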
pub fn transform_with_factor(x: &ArrayView2<f64>, l: &ArrayView2<f64>) -> Result<Array2<f64>> {
let n = x.nrows();
let d = x.ncols();
let e = l.ncols();
if e != d {
return Err(TransformError::InvalidInput(format!(
"transform_with_factor: L has {} cols but X has {} features",
e, d
)));
}
let mut out = Array2::<f64>::zeros((n, l.nrows()));
for i in 0..n {
for k in 0..l.nrows() {
let mut s = 0.0f64;
for j in 0..d {
s += l[[k, j]] * x[[i, j]];
}
out[[i, k]] = s;
}
}
Ok(out)
}
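/// Indices of the `k` nearest neighbors of each row of `x` under the metric `m`,
/// by brute force in O(n^2 d^2) time. Squared distances are compared, since the
/// square root does not change the ordering.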
pub fn knn_indices_mahalanobis(
x: &ArrayView2<f64>,
m: &ArrayView2<f64>,
k: usize,
) -> Vec<Vec<usize>> {
let n = x.nrows();
(0..n).map(|i| {
let mut dists: Vec<(usize, f64)> = (0..n)
.filter(|&j| j != i)
.map(|j| {
let d = mahalanobis_sq(&x.row(i), &x.row(j), m);
(j, d)
})
.collect();
dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
dists.truncate(k);
dists.into_iter().map(|(idx, _)| idx).collect()
}).collect()
}
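/// Result of a metric learning fit: the learned Mahalanobis matrix M = L^T L,
/// the factor L used by `transform`, and the loss recorded at each iteration.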
#[derive(Debug, Clone)]
pub struct MetricLearningResult {
pub metric: Array2<f64>,
pub factor: Array2<f64>,
pub loss_history: Vec<f64>,
pub n_iters: usize,
}
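/// Large Margin Nearest Neighbor (LMNN) metric learning (Weinberger & Saul).
///
/// Learns M = L^T L by pulling each sample toward its `k` same-class target
/// neighbors and pushing away differently-labeled impostors that fall within a
/// unit margin of those neighbors. Unlike the original fixed-neighbor
/// formulation, target neighbors here are recomputed under the current metric
/// at every gradient step.
///
/// A minimal usage sketch (the data and hyperparameters below are placeholders):
///
/// ```ignore
/// let x = array![[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]];
/// let labels = array![0i64, 0, 1, 1];
/// let mut lmnn = LMNN::new(1, 100, 1e-3)?.with_output_dim(2);
/// let result = lmnn.fit(&x.view(), &labels.view())?;
/// let embedded = lmnn.transform(&x.view())?;
/// ```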
#[derive(Debug, Clone)]
pub struct LMNN {
pub k: usize,
pub max_iter: usize,
pub learning_rate: f64,
pub margin_weight: f64,
pub output_dim: Option<usize>,
pub tol: f64,
factor: Option<Array2<f64>>,
}
impl LMNN {
pub fn new(k: usize, max_iter: usize, learning_rate: f64) -> Result<Self> {
if k == 0 {
return Err(TransformError::InvalidInput("LMNN: k must be > 0".to_string()));
}
Ok(LMNN {
k,
max_iter,
learning_rate,
margin_weight: 1.0,
output_dim: None,
tol: 1e-6,
factor: None,
})
}
pub fn with_output_dim(mut self, dim: usize) -> Self {
self.output_dim = Some(dim);
self
}
pub fn with_margin_weight(mut self, w: f64) -> Self {
self.margin_weight = w;
self
}
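    /// Loss and full-batch gradient with respect to M of the LMNN objective
    /// sum_{(i,j) targets} d_M(i, j)^2
    ///   + c * sum_{(i,j,l)} max(0, 1 + d_M(i, j)^2 - d_M(i, l)^2).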
fn compute_gradient(
x: &ArrayView2<f64>,
labels: &[i64],
m: &Array2<f64>,
k: usize,
margin_weight: f64,
) -> (Array2<f64>, f64) {
let n = x.nrows();
let d = x.ncols();
let target_neighbors: Vec<Vec<usize>> = (0..n).map(|i| {
let mut dists: Vec<(usize, f64)> = (0..n)
.filter(|&j| j != i && labels[j] == labels[i])
.map(|j| (j, mahalanobis_sq(&x.row(i), &x.row(j), &m.view())))
.collect();
dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
dists.truncate(k);
dists.into_iter().map(|(idx, _)| idx).collect()
}).collect();
let mut grad_m = Array2::<f64>::zeros((d, d));
let mut loss = 0.0f64;
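        // Pull term: attract each sample to its target neighbors; the gradient
        // of d^T M d with respect to M is the outer product diff * diff^T.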
for i in 0..n {
for &j in &target_neighbors[i] {
                let diff: Array1<f64> = (0..d).map(|dim| x[[i, dim]] - x[[j, dim]]).collect();
for a in 0..d {
for b in 0..d {
grad_m[[a, b]] += diff[a] * diff[b];
}
}
loss += mahalanobis_sq(&x.row(i), &x.row(j), &m.view());
}
}
let c = margin_weight;
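        // Push term: hinge penalty whenever an impostor l (different label)
        // comes within the unit margin of a target neighbor j.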
for i in 0..n {
for &j in &target_neighbors[i] {
let d_ij = mahalanobis_sq(&x.row(i), &x.row(j), &m.view());
for l in 0..n {
if labels[l] == labels[i] {
continue;
}
let d_il = mahalanobis_sq(&x.row(i), &x.row(l), &m.view());
let margin_val = 1.0 + d_ij - d_il;
if margin_val > 0.0 {
loss += c * margin_val;
                        let diff_il: Array1<f64> = (0..d).map(|dim| x[[i, dim]] - x[[l, dim]]).collect();
                        let diff_ij: Array1<f64> = (0..d).map(|dim| x[[i, dim]] - x[[j, dim]]).collect();
for a in 0..d {
for b in 0..d {
grad_m[[a, b]] += c * (diff_ij[a] * diff_ij[b] - diff_il[a] * diff_il[b]);
}
}
}
}
}
}
(grad_m, loss)
}
pub fn fit(&mut self, x: &ArrayView2<f64>, labels: &ArrayView1<i64>) -> Result<MetricLearningResult> {
let n = x.nrows();
let d = x.ncols();
if labels.len() != n {
return Err(TransformError::InvalidInput(format!(
"LMNN: x has {} rows but labels has {} elements",
n, labels.len()
)));
}
if n < 2 {
return Err(TransformError::InvalidInput("LMNN requires at least 2 samples".to_string()));
}
let labels_vec: Vec<i64> = labels.iter().copied().collect();
        // Cap at d so the factor keeps a consistent shape across iterations.
        let out_dim = self.output_dim.unwrap_or(d).min(d);
        // Start from the (truncated) identity, i.e. plain Euclidean distance.
        let mut l = Array2::<f64>::zeros((out_dim, d));
        for i in 0..out_dim {
            l[[i, i]] = 1.0;
        }
let mut m = l.t().dot(&l);
let mut loss_history = Vec::with_capacity(self.max_iter);
let mut prev_loss = f64::INFINITY;
for iter in 0..self.max_iter {
let (grad, loss) = Self::compute_gradient(x, &labels_vec, &m, self.k, self.margin_weight);
loss_history.push(loss);
m = m - self.learning_rate * &grad;
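            // Project M back onto the PSD cone by clipping negative eigenvalues at zero.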
match eigh(&m.view(), None) {
Ok((eigenvalues, eigenvectors)) => {
let diag_plus: Array1<f64> = eigenvalues.mapv(|v| v.max(0.0));
let mut m_new = Array2::<f64>::zeros((d, d));
for i in 0..d {
if diag_plus[i] > 1e-12 {
let v = eigenvectors.column(i);
for a in 0..d {
for b in 0..d {
m_new[[a, b]] += diag_plus[i] * v[a] * v[b];
}
}
}
}
m = m_new;
}
Err(_) => {
for i in 0..d {
m[[i, i]] += 1e-8;
}
}
}
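            // Recover the factor L = diag(sqrt(lambda)) * V^T from the leading
            // eigenpairs of M, so that L^T L approximates M.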
match eigh(&m.view(), None) {
Ok((ev, evec)) => {
let mut l_new = Array2::<f64>::zeros((out_dim.min(d), d));
let mut pairs: Vec<(f64, usize)> = ev.iter().enumerate()
.map(|(i, &e)| (e, i))
.collect();
pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
for (k_idx, &(e_val, e_idx)) in pairs.iter().enumerate().take(out_dim.min(d)) {
let sqrt_e = e_val.max(0.0).sqrt();
for j in 0..d {
l_new[[k_idx, j]] = sqrt_e * evec[[j, e_idx]];
}
}
l = l_new;
}
Err(_) => {}
}
if (prev_loss - loss).abs() / (prev_loss.abs() + 1e-10) < self.tol {
self.factor = Some(l.clone());
return Ok(MetricLearningResult {
metric: m,
factor: l,
loss_history,
n_iters: iter + 1,
});
}
prev_loss = loss;
}
self.factor = Some(l.clone());
Ok(MetricLearningResult {
metric: m,
factor: l,
loss_history,
n_iters: self.max_iter,
})
}
pub fn transform(&self, x: &ArrayView2<f64>) -> Result<Array2<f64>> {
let l = self.factor.as_ref().ok_or_else(|| {
TransformError::NotFitted("LMNN must be fitted before transform".to_string())
})?;
transform_with_factor(x, &l.view())
}
}
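/// Neighbourhood Components Analysis (NCA, Goldberger et al.).
///
/// Learns a linear map A by maximizing the expected leave-one-out accuracy of
/// a stochastic nearest-neighbor classifier in which sample i picks neighbor j
/// with probability softmax(-||A x_i - A x_j||^2). Fitting uses plain gradient
/// ascent with an L2 penalty on A.
///
/// A minimal usage sketch (placeholder data and hyperparameters):
///
/// ```ignore
/// let mut nca = NCA::new(2, 100, 1e-2)?.with_regularization(1e-4);
/// let result = nca.fit(&x.view(), &labels.view())?;
/// let embedded = nca.transform(&x.view())?;
/// ```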
#[derive(Debug, Clone)]
pub struct NCA {
pub output_dim: usize,
pub max_iter: usize,
pub learning_rate: f64,
pub regularization: f64,
pub tol: f64,
factor: Option<Array2<f64>>,
}
impl NCA {
pub fn new(output_dim: usize, max_iter: usize, learning_rate: f64) -> Result<Self> {
if output_dim == 0 {
return Err(TransformError::InvalidInput("NCA: output_dim must be > 0".to_string()));
}
Ok(NCA {
output_dim,
max_iter,
learning_rate,
regularization: 1e-5,
tol: 1e-6,
factor: None,
})
}
pub fn with_regularization(mut self, reg: f64) -> Self {
self.regularization = reg;
self
}
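    /// Regularized NCA objective sum_i p_i - reg/2 * ||A||^2 and its gradient
    /// with respect to A, where p_i is the probability that sample i is
    /// classified correctly under the stochastic neighbor rule.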
fn nca_objective_gradient(
x: &ArrayView2<f64>,
labels: &[i64],
a: &Array2<f64>,
reg: f64,
) -> (f64, Array2<f64>) {
let n = x.nrows();
let d = x.ncols();
let e = a.nrows();
let mut z = Array2::<f64>::zeros((n, e));
for i in 0..n {
for k in 0..e {
for j in 0..d {
z[[i, k]] += a[[k, j]] * x[[i, j]];
}
}
}
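        // Row-wise softmax over negative squared embedded distances (p_ii = 0).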
let mut p = Array2::<f64>::zeros((n, n));
for i in 0..n {
let mut sum_exp = 0.0f64;
for j in 0..n {
if j == i {
continue;
}
let dist_sq: f64 = (0..e).map(|k| (z[[i, k]] - z[[j, k]]).powi(2)).sum();
let exp_val = (-dist_sq).exp();
p[[i, j]] = exp_val;
sum_exp += exp_val;
}
if sum_exp > 1e-15 {
for j in 0..n {
p[[i, j]] /= sum_exp;
}
}
}
let p_i: Array1<f64> = (0..n).map(|i| {
(0..n).filter(|&j| j != i && labels[j] == labels[i])
.map(|j| p[[i, j]])
.sum::<f64>()
}).collect();
let objective: f64 = p_i.iter().sum::<f64>();
let mut grad_a = Array2::<f64>::zeros((e, d));
for i in 0..n {
let mut t1 = Array2::<f64>::zeros((e, d));
for k in 0..n {
if k == i {
continue;
}
let p_ik = p[[i, k]];
if p_ik < 1e-15 {
continue;
}
                for r in 0..e {
                    for c in 0..d {
                        t1[[r, c]] += p_ik * (z[[i, r]] - z[[k, r]]) * (x[[i, c]] - x[[k, c]]);
                    }
                }
}
let mut t2 = Array2::<f64>::zeros((e, d));
for j in 0..n {
if j == i || labels[j] != labels[i] {
continue;
}
let p_ij = p[[i, j]];
if p_ij < 1e-15 {
continue;
}
                for r in 0..e {
                    for c in 0..d {
                        t2[[r, c]] += p_ij * (z[[i, r]] - z[[j, r]]) * (x[[i, c]] - x[[j, c]]);
                    }
                }
}
let pi = p_i[i];
            for r in 0..e {
                for c in 0..d {
                    grad_a[[r, c]] += 2.0 * (pi * t1[[r, c]] - t2[[r, c]]);
                }
            }
}
let reg_obj: f64 = reg * a.iter().map(|v| v * v).sum::<f64>() / 2.0;
let objective_reg = objective - reg_obj;
let grad_reg = grad_a - a.mapv(|v| reg * v);
(objective_reg, grad_reg)
}
pub fn fit(&mut self, x: &ArrayView2<f64>, labels: &ArrayView1<i64>) -> Result<MetricLearningResult> {
let n = x.nrows();
let d = x.ncols();
if labels.len() != n {
return Err(TransformError::InvalidInput(format!(
"NCA: x has {} rows but labels has {} elements",
n, labels.len()
)));
}
if n < 2 {
return Err(TransformError::InvalidInput("NCA requires at least 2 samples".to_string()));
}
let labels_vec: Vec<i64> = labels.iter().copied().collect();
let e = self.output_dim.min(d);
let mut a = Array2::<f64>::zeros((e, d));
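        // Deterministic LCG providing small symmetry-breaking perturbations
        // around an identity initialization; no external RNG dependency needed.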
let mut state: u64 = 54321;
for i in 0..e {
for j in 0..d {
state = state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1_442_695_040_888_963_407);
let v = (state >> 11) as f64 / (1u64 << 53) as f64 - 0.5;
a[[i, j]] = v * 0.01;
}
            // i < e <= d always holds, so this diagonal write is in bounds.
            a[[i, i]] = 1.0;
}
let mut loss_history = Vec::with_capacity(self.max_iter);
let mut prev_obj = f64::NEG_INFINITY;
for iter in 0..self.max_iter {
let (obj, grad) = Self::nca_objective_gradient(x, &labels_vec, &a, self.regularization);
loss_history.push(-obj);
a = a + self.learning_rate * &grad;
if (obj - prev_obj).abs() / (prev_obj.abs() + 1e-10) < self.tol {
let m = a.t().dot(&a);
self.factor = Some(a.clone());
return Ok(MetricLearningResult {
metric: m,
factor: a,
loss_history,
n_iters: iter + 1,
});
}
prev_obj = obj;
}
let m = a.t().dot(&a);
self.factor = Some(a.clone());
Ok(MetricLearningResult {
metric: m,
factor: a,
loss_history,
n_iters: self.max_iter,
})
}
pub fn transform(&self, x: &ArrayView2<f64>) -> Result<Array2<f64>> {
let a = self.factor.as_ref().ok_or_else(|| {
TransformError::NotFitted("NCA must be fitted before transform".to_string())
})?;
transform_with_factor(x, &a.view())
}
}
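/// Metric Learning for Kernel Regression (MLKR, Weinberger & Tesauro).
///
/// Learns a linear map A by minimizing the leave-one-out squared error of a
/// Nadaraya-Watson regressor with Gaussian kernel
/// k_ij = exp(-||A x_i - A x_j||^2) in the embedded space.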
#[derive(Debug, Clone)]
pub struct MLKR {
pub output_dim: usize,
pub max_iter: usize,
pub learning_rate: f64,
pub tol: f64,
factor: Option<Array2<f64>>,
}
impl MLKR {
pub fn new(output_dim: usize, max_iter: usize, learning_rate: f64) -> Result<Self> {
if output_dim == 0 {
return Err(TransformError::InvalidInput("MLKR: output_dim must be > 0".to_string()));
}
Ok(MLKR {
output_dim,
max_iter,
learning_rate,
tol: 1e-6,
factor: None,
})
}
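    /// Leave-one-out kernel regression loss sum_i (y_hat_i - y_i)^2 and its
    /// gradient with respect to A.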
fn mlkr_objective_gradient(
x: &ArrayView2<f64>,
y: &ArrayView1<f64>,
a: &Array2<f64>,
) -> (f64, Array2<f64>) {
let n = x.nrows();
let d = x.ncols();
let e = a.nrows();
let mut z = Array2::<f64>::zeros((n, e));
for i in 0..n {
for k in 0..e {
for j in 0..d {
z[[i, k]] += a[[k, j]] * x[[i, j]];
}
}
}
let mut k_mat = Array2::<f64>::zeros((n, n));
for i in 0..n {
for j in 0..n {
if i == j {
continue;
}
let dist_sq: f64 = (0..e).map(|k| (z[[i, k]] - z[[j, k]]).powi(2)).sum();
k_mat[[i, j]] = (-dist_sq).exp();
}
}
let k_sums: Array1<f64> = (0..n)
.map(|i| (0..n).filter(|&j| j != i).map(|j| k_mat[[i, j]]).sum::<f64>())
.collect();
let y_hat: Array1<f64> = (0..n).map(|i| {
let denom = k_sums[i];
if denom < 1e-15 {
return y[i];
}
(0..n).filter(|&j| j != i).map(|j| k_mat[[i, j]] * y[j]).sum::<f64>() / denom
}).collect();
let residuals: Array1<f64> = y_hat.iter().zip(y.iter()).map(|(yh, yi)| yh - yi).collect();
let loss: f64 = residuals.iter().map(|r| r * r).sum();
let mut grad_a = Array2::<f64>::zeros((e, d));
for i in 0..n {
let r_i = residuals[i];
let s_i = k_sums[i];
if s_i < 1e-15 {
continue;
}
for j in 0..n {
if i == j {
continue;
}
let k_ij = k_mat[[i, j]];
if k_ij < 1e-15 {
continue;
}
                // d(y_hat_i)/d(k_ij) for the Nadaraya-Watson estimate.
                let dy_hat_dk_ij = (y[j] - y_hat[i]) / s_i;
                // Chain rule with dk_ij/dA = -2 * k_ij * (z_i - z_j)(x_i - x_j)^T;
                // the x-side factor must be the difference x_i - x_j, not x_i alone.
                let scale = 4.0 * r_i * dy_hat_dk_ij * k_ij;
                for a_idx in 0..e {
                    for b_idx in 0..d {
                        grad_a[[a_idx, b_idx]] -= scale
                            * (z[[i, a_idx]] - z[[j, a_idx]])
                            * (x[[i, b_idx]] - x[[j, b_idx]]);
                    }
                }
}
}
(loss, grad_a)
}
pub fn fit(&mut self, x: &ArrayView2<f64>, y: &ArrayView1<f64>) -> Result<MetricLearningResult> {
let n = x.nrows();
let d = x.ncols();
if y.len() != n {
return Err(TransformError::InvalidInput(format!(
"MLKR: x has {} rows but y has {} elements",
n, y.len()
)));
}
if n < 3 {
return Err(TransformError::InvalidInput("MLKR requires at least 3 samples".to_string()));
}
let e = self.output_dim.min(d);
let mut a = Array2::<f64>::zeros((e, d));
        for i in 0..e {
a[[i, i]] = 1.0;
}
let mut loss_history = Vec::with_capacity(self.max_iter);
let mut prev_loss = f64::INFINITY;
for iter in 0..self.max_iter {
let (loss, grad) = Self::mlkr_objective_gradient(x, y, &a);
loss_history.push(loss);
a = a - self.learning_rate * &grad;
if (prev_loss - loss).abs() / (prev_loss.abs() + 1e-10) < self.tol {
let m = a.t().dot(&a);
self.factor = Some(a.clone());
return Ok(MetricLearningResult {
metric: m,
factor: a,
loss_history,
n_iters: iter + 1,
});
}
prev_loss = loss;
}
let m = a.t().dot(&a);
self.factor = Some(a.clone());
Ok(MetricLearningResult {
metric: m,
factor: a,
loss_history,
n_iters: self.max_iter,
})
}
pub fn transform(&self, x: &ArrayView2<f64>) -> Result<Array2<f64>> {
let a = self.factor.as_ref().ok_or_else(|| {
TransformError::NotFitted("MLKR must be fitted before transform".to_string())
})?;
transform_with_factor(x, &a.view())
}
}
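/// Contrastive (siamese) loss in the style of Hadsell et al.: similar pairs
/// (label 0) pay the squared distance, dissimilar pairs (any other label) pay
/// the squared hinge max(margin - d, 0)^2.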
#[derive(Debug, Clone)]
pub struct SiameseLoss {
pub margin: f64,
}
impl SiameseLoss {
pub fn new(margin: f64) -> Self {
SiameseLoss { margin }
}
pub fn compute(&self, distance: f64, label: u8) -> f64 {
match label {
            // Label 0 marks a similar pair: penalize the squared distance.
            0 => distance * distance,
            // Any other label marks a dissimilar pair: squared hinge inside the margin.
            _ => {
                let margin_dist = (self.margin - distance).max(0.0);
                margin_dist * margin_dist
            }
}
}
pub fn batch_loss(&self, distances: &[f64], labels: &[u8]) -> Result<f64> {
if distances.len() != labels.len() {
return Err(TransformError::InvalidInput(
"SiameseLoss: distances and labels must have the same length".to_string(),
));
}
if distances.is_empty() {
return Err(TransformError::InvalidInput(
"SiameseLoss: batch must not be empty".to_string(),
));
}
let total: f64 = distances.iter().zip(labels.iter())
.map(|(&d, &l)| self.compute(d, l))
.sum();
Ok(total / distances.len() as f64)
}
}
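/// A pairwise constraint between samples `i` and `j`: label 0 marks a similar
/// pair, any other label a dissimilar one.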
#[derive(Debug, Clone)]
pub struct PairConstraint {
pub i: usize,
pub j: usize,
pub label: u8,
}
impl PairConstraint {
pub fn similar(i: usize, j: usize) -> Self {
PairConstraint { i, j, label: 0 }
}
pub fn dissimilar(i: usize, j: usize) -> Self {
PairConstraint { i, j, label: 1 }
}
}
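/// Learns a linear embedding by gradient descent on the mean contrastive loss
/// over a set of [`PairConstraint`]s.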
#[derive(Debug, Clone)]
pub struct ContrastiveMetricLearner {
pub output_dim: usize,
pub margin: f64,
pub max_iter: usize,
pub learning_rate: f64,
pub tol: f64,
factor: Option<Array2<f64>>,
}
impl ContrastiveMetricLearner {
pub fn new(output_dim: usize, margin: f64, max_iter: usize, learning_rate: f64) -> Result<Self> {
if output_dim == 0 {
return Err(TransformError::InvalidInput(
"ContrastiveMetricLearner: output_dim must be > 0".to_string(),
));
}
if margin <= 0.0 {
return Err(TransformError::InvalidInput(
"ContrastiveMetricLearner: margin must be positive".to_string(),
));
}
Ok(ContrastiveMetricLearner {
output_dim,
margin,
max_iter,
learning_rate,
tol: 1e-6,
factor: None,
})
}
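    /// Mean contrastive loss over `pairs` and its gradient with respect to A.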
fn contrastive_loss_and_grad(
x: &ArrayView2<f64>,
pairs: &[PairConstraint],
a: &Array2<f64>,
margin: f64,
) -> (f64, Array2<f64>) {
let d = x.ncols();
let e = a.nrows();
let n = x.nrows();
let mut z = Array2::<f64>::zeros((n, e));
for i in 0..n {
for k in 0..e {
for j in 0..d {
z[[i, k]] += a[[k, j]] * x[[i, j]];
}
}
}
let mut total_loss = 0.0f64;
let mut grad_a = Array2::<f64>::zeros((e, d));
for pair in pairs {
let zi = z.row(pair.i);
let zj = z.row(pair.j);
let dist_sq: f64 = (0..e).map(|k| (zi[k] - zj[k]).powi(2)).sum();
let dist = dist_sq.sqrt();
match pair.label {
0 => {
total_loss += dist_sq;
let scale = 2.0;
for a_idx in 0..e {
for b_idx in 0..d {
grad_a[[a_idx, b_idx]] +=
scale * (zi[a_idx] - zj[a_idx]) * (x[[pair.i, b_idx]] - x[[pair.j, b_idx]]);
}
}
}
_ => {
let slack = margin - dist;
if slack > 0.0 {
total_loss += slack * slack;
if dist > 1e-10 {
let scale = -2.0 * slack / dist;
for a_idx in 0..e {
for b_idx in 0..d {
grad_a[[a_idx, b_idx]] +=
scale * (zi[a_idx] - zj[a_idx]) * (x[[pair.i, b_idx]] - x[[pair.j, b_idx]]);
}
}
}
}
}
}
}
let n_pairs = pairs.len().max(1) as f64;
(total_loss / n_pairs, grad_a / n_pairs)
}
pub fn fit(&mut self, x: &ArrayView2<f64>, pairs: &[PairConstraint]) -> Result<MetricLearningResult> {
let d = x.ncols();
let n = x.nrows();
if pairs.is_empty() {
return Err(TransformError::InvalidInput(
"ContrastiveMetricLearner: pairs list is empty".to_string(),
));
}
for p in pairs {
if p.i >= n || p.j >= n {
return Err(TransformError::InvalidInput(format!(
"ContrastiveMetricLearner: pair index out of bounds ({}, {}) for n={}",
p.i, p.j, n
)));
}
}
let e = self.output_dim.min(d);
let mut a = Array2::<f64>::zeros((e, d));
        for i in 0..e {
a[[i, i]] = 1.0;
}
let mut loss_history = Vec::with_capacity(self.max_iter);
let mut prev_loss = f64::INFINITY;
for iter in 0..self.max_iter {
let (loss, grad) = Self::contrastive_loss_and_grad(x, pairs, &a, self.margin);
loss_history.push(loss);
a = a - self.learning_rate * &grad;
if (prev_loss - loss).abs() / (prev_loss.abs() + 1e-10) < self.tol {
let m = a.t().dot(&a);
self.factor = Some(a.clone());
return Ok(MetricLearningResult {
metric: m,
factor: a,
loss_history,
n_iters: iter + 1,
});
}
prev_loss = loss;
}
let m = a.t().dot(&a);
self.factor = Some(a.clone());
Ok(MetricLearningResult {
metric: m,
factor: a,
loss_history,
n_iters: self.max_iter,
})
}
pub fn transform(&self, x: &ArrayView2<f64>) -> Result<Array2<f64>> {
let a = self.factor.as_ref().ok_or_else(|| {
TransformError::NotFitted("ContrastiveMetricLearner must be fitted before transform".to_string())
})?;
transform_with_factor(x, &a.view())
}
}
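/// An (anchor, positive, negative) index triple: the anchor should end up
/// closer to the positive than to the negative by at least the margin.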
#[derive(Debug, Clone)]
pub struct TripletConstraint {
pub anchor: usize,
pub positive: usize,
pub negative: usize,
}
impl TripletConstraint {
pub fn new(anchor: usize, positive: usize, negative: usize) -> Self {
TripletConstraint { anchor, positive, negative }
}
}
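/// Learns a linear embedding by gradient descent on the FaceNet-style triplet
/// hinge loss max(0, d(a, p)^2 - d(a, n)^2 + margin).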
#[derive(Debug, Clone)]
pub struct TripletMetricLearner {
pub output_dim: usize,
pub margin: f64,
pub max_iter: usize,
pub learning_rate: f64,
pub tol: f64,
factor: Option<Array2<f64>>,
}
impl TripletMetricLearner {
pub fn new(output_dim: usize, margin: f64, max_iter: usize, learning_rate: f64) -> Result<Self> {
if output_dim == 0 {
return Err(TransformError::InvalidInput(
"TripletMetricLearner: output_dim must be > 0".to_string(),
));
}
if margin <= 0.0 {
return Err(TransformError::InvalidInput(
"TripletMetricLearner: margin must be positive".to_string(),
));
}
Ok(TripletMetricLearner {
output_dim,
margin,
max_iter,
learning_rate,
tol: 1e-6,
factor: None,
})
}
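    /// Mean triplet hinge loss over `triplets` and its gradient with respect to A.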
fn triplet_loss_and_grad(
x: &ArrayView2<f64>,
triplets: &[TripletConstraint],
a: &Array2<f64>,
margin: f64,
) -> (f64, Array2<f64>) {
let d = x.ncols();
let e = a.nrows();
let n = x.nrows();
let mut z = Array2::<f64>::zeros((n, e));
for i in 0..n {
for k in 0..e {
for j in 0..d {
z[[i, k]] += a[[k, j]] * x[[i, j]];
}
}
}
let mut total_loss = 0.0f64;
let mut grad_a = Array2::<f64>::zeros((e, d));
for t in triplets {
let za = z.row(t.anchor);
let zp = z.row(t.positive);
let zn = z.row(t.negative);
let d_ap_sq: f64 = (0..e).map(|k| (za[k] - zp[k]).powi(2)).sum();
let d_an_sq: f64 = (0..e).map(|k| (za[k] - zn[k]).powi(2)).sum();
            let loss_t = d_ap_sq - d_an_sq + margin;
            if loss_t <= 0.0 {
                // Margin satisfied: no loss and no gradient from this triplet.
                continue;
            }
total_loss += loss_t;
for a_idx in 0..e {
for b_idx in 0..d {
let grad_ap = 2.0 * (za[a_idx] - zp[a_idx]) * (x[[t.anchor, b_idx]] - x[[t.positive, b_idx]]);
let grad_an = 2.0 * (za[a_idx] - zn[a_idx]) * (x[[t.anchor, b_idx]] - x[[t.negative, b_idx]]);
grad_a[[a_idx, b_idx]] += grad_ap - grad_an;
}
}
}
let n_triplets = triplets.len().max(1) as f64;
(total_loss / n_triplets, grad_a / n_triplets)
}
pub fn fit(&mut self, x: &ArrayView2<f64>, triplets: &[TripletConstraint]) -> Result<MetricLearningResult> {
let d = x.ncols();
let n = x.nrows();
if triplets.is_empty() {
return Err(TransformError::InvalidInput(
"TripletMetricLearner: triplets list is empty".to_string(),
));
}
for t in triplets {
if t.anchor >= n || t.positive >= n || t.negative >= n {
return Err(TransformError::InvalidInput(format!(
"TripletMetricLearner: triplet index out of bounds for n={}",
n
)));
}
}
let e = self.output_dim.min(d);
let mut a = Array2::<f64>::zeros((e, d));
        for i in 0..e {
a[[i, i]] = 1.0;
}
let mut loss_history = Vec::with_capacity(self.max_iter);
let mut prev_loss = f64::INFINITY;
for iter in 0..self.max_iter {
let (loss, grad) = Self::triplet_loss_and_grad(x, triplets, &a, self.margin);
loss_history.push(loss);
a = a - self.learning_rate * &grad;
if (prev_loss - loss).abs() / (prev_loss.abs() + 1e-10) < self.tol {
let m = a.t().dot(&a);
self.factor = Some(a.clone());
return Ok(MetricLearningResult {
metric: m,
factor: a,
loss_history,
n_iters: iter + 1,
});
}
prev_loss = loss;
}
let m = a.t().dot(&a);
self.factor = Some(a.clone());
Ok(MetricLearningResult {
metric: m,
factor: a,
loss_history,
n_iters: self.max_iter,
})
}
pub fn transform(&self, x: &ArrayView2<f64>) -> Result<Array2<f64>> {
let a = self.factor.as_ref().ok_or_else(|| {
TransformError::NotFitted("TripletMetricLearner must be fitted before transform".to_string())
})?;
transform_with_factor(x, &a.view())
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::{array, Array1, Array2};
#[test]
fn test_mahalanobis_identity() {
let m = Array2::<f64>::eye(3);
let x = array![1.0, 0.0, 0.0];
let y = array![0.0, 0.0, 0.0];
let d = mahalanobis(&x.view(), &y.view(), &m.view());
assert!((d - 1.0).abs() < 1e-10);
}
#[test]
fn test_siamese_loss_similar() {
let loss = SiameseLoss::new(1.0);
let l = loss.compute(0.5, 0);
assert!((l - 0.25).abs() < 1e-10);
}
#[test]
fn test_siamese_loss_dissimilar_within_margin() {
let loss = SiameseLoss::new(2.0);
        let l = loss.compute(0.5, 1);
        assert!((l - 2.25).abs() < 1e-10);
}
#[test]
fn test_siamese_loss_dissimilar_outside_margin() {
let loss = SiameseLoss::new(1.0);
        let l = loss.compute(2.0, 1);
        assert_eq!(l, 0.0);
}
#[test]
fn test_siamese_batch_loss() {
let loss = SiameseLoss::new(1.0);
let dists = vec![0.3, 1.5];
let labels = vec![0u8, 1u8];
let bl = loss.batch_loss(&dists, &labels).expect("batch_loss should succeed");
assert!(bl >= 0.0);
}
#[test]
fn test_lmnn_fit() {
let x = Array2::<f64>::eye(4);
let y = Array1::<i64>::from_vec(vec![0, 0, 0, 0]);
let mut lmnn = LMNN::new(1, 5, 1e-6).expect("LMNN::new should succeed");
let res = lmnn.fit(&x.view(), &y.view()).expect("LMNN fit should succeed");
assert_eq!(res.metric.shape(), &[4, 4]);
}
#[test]
fn test_nca_fit() {
let x = Array2::<f64>::zeros((8, 3));
let y = Array1::<i64>::from_vec(vec![0, 0, 1, 1, 0, 0, 1, 1]);
let mut nca = NCA::new(2, 5, 1e-5).expect("NCA::new should succeed");
let res = nca.fit(&x.view(), &y.view()).expect("NCA fit should succeed");
assert_eq!(res.factor.shape(), &[2, 3]);
}
#[test]
fn test_mlkr_fit() {
let x = Array2::<f64>::zeros((8, 3));
let y = Array1::<f64>::ones(8);
let mut mlkr = MLKR::new(2, 5, 1e-6).expect("MLKR::new should succeed");
let res = mlkr.fit(&x.view(), &y.view()).expect("MLKR fit should succeed");
assert_eq!(res.factor.shape(), &[2, 3]);
}
#[test]
fn test_contrastive_metric_learner() {
let x = Array2::<f64>::eye(4);
let pairs = vec![
PairConstraint::similar(0, 1),
PairConstraint::dissimilar(0, 2),
];
let mut cml = ContrastiveMetricLearner::new(2, 1.0, 5, 1e-5).expect("ContrastiveMetricLearner::new should succeed");
let res = cml.fit(&x.view(), &pairs).expect("ContrastiveMetricLearner fit should succeed");
assert_eq!(res.factor.shape()[1], 4);
}
#[test]
fn test_triplet_metric_learner() {
let x = Array2::<f64>::eye(4);
let triplets = vec![TripletConstraint::new(0, 1, 2)];
let mut tml = TripletMetricLearner::new(3, 1.0, 5, 1e-5).expect("TripletMetricLearner::new should succeed");
let res = tml.fit(&x.view(), &triplets).expect("TripletMetricLearner fit should succeed");
assert_eq!(res.factor.shape()[1], 4);
}
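    // A sanity check added as a sketch: any pairwise distance matrix should be
    // symmetric with a zero diagonal, and under the identity metric the
    // Mahalanobis distance reduces to the Euclidean distance.
    #[test]
    fn test_pairwise_mahalanobis_symmetry() {
        let m = Array2::<f64>::eye(2);
        let x = array![[0.0, 0.0], [3.0, 4.0], [1.0, 1.0]];
        let d = pairwise_mahalanobis(&x.view(), &m.view());
        for i in 0..3 {
            assert_eq!(d[[i, i]], 0.0);
            for j in 0..3 {
                assert!((d[[i, j]] - d[[j, i]]).abs() < 1e-12);
            }
        }
        assert!((d[[0, 1]] - 5.0).abs() < 1e-10);
    }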
}