use crate::error::{NeuralError, Result};
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::numeric::{Float, FromPrimitive, NumAssign, ToPrimitive};
use std::collections::VecDeque;
use std::fmt::Debug;
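/// Normalized temperature-scaled cross-entropy (NT-Xent), the contrastive
/// loss used by SimCLR. `temperature` must be positive; smaller values
/// sharpen the softmax over pairwise similarities.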
#[derive(Debug, Clone)]
pub struct NTXentLoss {
pub temperature: f64,
}
impl NTXentLoss {
pub fn new(temperature: f64) -> Self {
Self { temperature }
}
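/// Computes the NT-Xent loss for a batch of paired views.
///
/// `z_i` and `z_j` are `(N, D)` projections of two augmented views of the
/// same `N` samples; row `i` of `z_i` and row `i` of `z_j` form the positive
/// pair. Rows are L2-normalised internally, so inputs need not be unit
/// length.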
pub fn forward<F>(&self, z_i: &Array2<F>, z_j: &Array2<F>) -> Result<F>
where
F: Float + Debug + NumAssign + FromPrimitive + ToPrimitive,
{
let n = z_i.nrows();
if z_j.nrows() != n {
return Err(NeuralError::ShapeMismatch(format!(
"NT-Xent: z_i has {} rows but z_j has {}",
n,
z_j.nrows()
)));
}
if n == 0 {
return Err(NeuralError::InvalidArgument(
"NT-Xent: batch size must be > 0".to_string(),
));
}
let tau = F::from_f64(self.temperature).ok_or_else(|| {
NeuralError::ComputationError("NT-Xent: cannot convert temperature".to_string())
})?;
let zi_norm = l2_normalise(z_i)?;
let zj_norm = l2_normalise(z_j)?;
let z_all = concatenate_rows(&zi_norm, &zj_norm)?;
let two_n = z_all.nrows();
let sim = cosine_sim_matrix(&z_all)?;
let mut total_loss = F::zero();
let neg_inf = F::from_f64(-1e38).ok_or_else(|| {
NeuralError::ComputationError("NT-Xent: cannot convert neg_inf".to_string())
})?;
for i in 0..two_n {
let pos_idx = if i < n { i + n } else { i - n };
let num_val = sim[[i, pos_idx]] / tau;
let mut log_denom = neg_inf;
for k in 0..two_n {
if k == i {
continue;
}
let logit = sim[[i, k]] / tau;
log_denom = log_sum_exp_pair(log_denom, logit);
}
total_loss += num_val - log_denom;
}
let two_n_f = F::from_usize(two_n).ok_or_else(|| {
NeuralError::ComputationError("NT-Xent: cannot convert 2N".to_string())
})?;
let loss = -(total_loss / two_n_f);
Ok(loss)
}
}
impl Default for NTXentLoss {
fn default() -> Self {
Self::new(0.1)
}
}
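/// SimCLR hyperparameters: encoder representation width, projection-head
/// dimensions, NT-Xent temperature, and weight decay.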
#[derive(Debug, Clone)]
pub struct SimCLRConfig {
pub representation_dim: usize,
pub projection_hidden_dim: usize,
pub projection_output_dim: usize,
pub temperature: f64,
pub weight_decay: f64,
}
impl Default for SimCLRConfig {
fn default() -> Self {
Self {
representation_dim: 512,
projection_hidden_dim: 2048,
projection_output_dim: 128,
temperature: 0.07,
weight_decay: 1e-6,
}
}
}
impl SimCLRConfig {
pub fn validate(&self) -> Result<()> {
if self.representation_dim == 0 {
return Err(NeuralError::ConfigError(
"SimCLRConfig: representation_dim must be > 0".to_string(),
));
}
if self.projection_hidden_dim == 0 {
return Err(NeuralError::ConfigError(
"SimCLRConfig: projection_hidden_dim must be > 0".to_string(),
));
}
if self.projection_output_dim == 0 {
return Err(NeuralError::ConfigError(
"SimCLRConfig: projection_output_dim must be > 0".to_string(),
));
}
if self.temperature <= 0.0 {
return Err(NeuralError::ConfigError(
"SimCLRConfig: temperature must be > 0".to_string(),
));
}
Ok(())
}
}
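/// Two-layer MLP projection head (`Linear -> ReLU -> Linear`) that maps
/// encoder representations into the space where the contrastive loss is
/// applied. Weight matrices are stored as `(out_dim, in_dim)`.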
#[derive(Debug, Clone)]
pub struct ProjectionHead<F: Float + Debug + NumAssign> {
pub w1: Array2<F>,
pub b1: Array1<F>,
pub w2: Array2<F>,
pub b2: Array1<F>,
}
impl<F: Float + Debug + NumAssign + FromPrimitive> ProjectionHead<F> {
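/// Builds a head whose weights use the Xavier/Glorot-style scale
/// `sqrt(2 / (fan_in + fan_out))` with a deterministic pattern
/// (see `init_weight_matrix`).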
pub fn new(in_dim: usize, hidden_dim: usize, out_dim: usize) -> Result<Self> {
let scale1 = F::from_f64((2.0 / (in_dim + hidden_dim) as f64).sqrt()).ok_or_else(|| {
NeuralError::ComputationError("ProjectionHead: cannot compute scale1".to_string())
})?;
let scale2 =
F::from_f64((2.0 / (hidden_dim + out_dim) as f64).sqrt()).ok_or_else(|| {
NeuralError::ComputationError("ProjectionHead: cannot compute scale2".to_string())
})?;
let w1 = init_weight_matrix(hidden_dim, in_dim, scale1);
let b1 = Array1::zeros(hidden_dim);
let w2 = init_weight_matrix(out_dim, hidden_dim, scale2);
let b2 = Array1::zeros(out_dim);
Ok(Self { w1, b1, w2, b2 })
}
pub fn forward(&self, x: &Array2<F>) -> Result<Array2<F>> {
let h = linear_forward(x, &self.w1, &self.b1)?;
let h = relu2d(&h);
let out = linear_forward(&h, &self.w2, &self.b2)?;
Ok(out)
}
}
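/// Bundles a validated [`SimCLRConfig`], a projection head, and the NT-Xent
/// loss so a batch of paired representations can be scored in one call.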
#[derive(Debug, Clone)]
pub struct SimCLRTrainer<F: Float + Debug + NumAssign + FromPrimitive> {
pub config: SimCLRConfig,
pub projection_head: ProjectionHead<F>,
pub loss_fn: NTXentLoss,
}
impl<F: Float + Debug + NumAssign + FromPrimitive + ToPrimitive> SimCLRTrainer<F> {
pub fn new(config: SimCLRConfig) -> Result<Self> {
config.validate()?;
let projection_head = ProjectionHead::new(
config.representation_dim,
config.projection_hidden_dim,
config.projection_output_dim,
)?;
let loss_fn = NTXentLoss::new(config.temperature);
Ok(Self {
config,
projection_head,
loss_fn,
})
}
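/// Projects both views through the shared head and returns the NT-Xent loss
/// between the projections.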
pub fn batch_loss(&self, rep_i: &Array2<F>, rep_j: &Array2<F>) -> Result<F> {
let z_i = self.projection_head.forward(rep_i)?;
let z_j = self.projection_head.forward(rep_j)?;
self.loss_fn.forward(&z_i, &z_j)
}
}
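/// MoCo-style dictionary: a FIFO queue of L2-normalised key vectors that
/// serve as negatives for InfoNCE, plus the momentum used for the EMA
/// update of the key encoder.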
#[derive(Debug, Clone)]
pub struct MoCoQueue<F: Float + Debug + NumAssign> {
pub capacity: usize,
pub key_dim: usize,
pub momentum: f64,
pub temperature: f64,
queue: VecDeque<Array1<F>>,
}
impl<F: Float + Debug + NumAssign + FromPrimitive> MoCoQueue<F> {
pub fn new(capacity: usize, key_dim: usize, momentum: f64, temperature: f64) -> Self {
Self {
capacity,
key_dim,
momentum,
temperature,
queue: VecDeque::with_capacity(capacity),
}
}
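/// L2-normalises a batch of keys and pushes them onto the queue, evicting
/// the oldest entries once `capacity` is reached.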
pub fn enqueue_and_dequeue(&mut self, keys: &Array2<F>) -> Result<()> {
let n = keys.nrows();
if keys.ncols() != self.key_dim {
return Err(NeuralError::ShapeMismatch(format!(
"MoCoQueue::enqueue: keys have {} dims but queue expects {}",
keys.ncols(),
self.key_dim
)));
}
let keys_norm = l2_normalise(keys)?;
for i in 0..n {
let key = keys_norm.row(i).to_owned();
if self.queue.len() == self.capacity {
self.queue.pop_front();
}
self.queue.push_back(key);
}
Ok(())
}
pub fn len(&self) -> usize {
self.queue.len()
}
pub fn is_empty(&self) -> bool {
self.queue.is_empty()
}
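/// InfoNCE loss with the queue as the negative set: row `i` of `queries` is
/// matched against row `i` of `keys` as its positive, and every queued key
/// acts as a negative.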
pub fn info_nce_loss(&self, queries: &Array2<F>, keys: &Array2<F>) -> Result<F> {
let n = queries.nrows();
if keys.nrows() != n {
return Err(NeuralError::ShapeMismatch(
"MoCoQueue::info_nce_loss: queries and keys must have same batch size".to_string(),
));
}
// `zip` below would silently truncate mismatched rows, so check the
// feature dimension explicitly.
if queries.ncols() != self.key_dim || keys.ncols() != self.key_dim {
return Err(NeuralError::ShapeMismatch(format!(
"MoCoQueue::info_nce_loss: expected {} dims, got queries={} and keys={}",
self.key_dim,
queries.ncols(),
keys.ncols()
)));
}
if self.queue.is_empty() {
return Err(NeuralError::InvalidState(
"MoCoQueue::info_nce_loss: queue is empty; enqueue keys first".to_string(),
));
}
let tau = F::from_f64(self.temperature).ok_or_else(|| {
NeuralError::ComputationError(
"MoCoQueue::info_nce_loss: cannot convert temperature".to_string(),
)
})?;
let q_norm = l2_normalise(queries)?;
let k_norm = l2_normalise(keys)?;
let mut total_loss = F::zero();
let neg_inf = F::from_f64(-1e38).ok_or_else(|| {
NeuralError::ComputationError("MoCoQueue: cannot convert neg_inf".to_string())
})?;
for i in 0..n {
let q = q_norm.row(i);
let k_pos = k_norm.row(i);
let pos_sim: F = q
.iter()
.zip(k_pos.iter())
.map(|(a, b)| *a * *b)
.fold(F::zero(), |acc, x| acc + x);
let pos_logit = pos_sim / tau;
// Seed the log-denominator with the positive logit, then fold in every
// queued negative via a streaming log-sum-exp.
let mut log_denom = pos_logit;
for neg_key in &self.queue {
let neg_sim: F = q
.iter()
.zip(neg_key.iter())
.map(|(a, b)| *a * *b)
.fold(F::zero(), |acc, x| acc + x);
let neg_logit = neg_sim / tau;
log_denom = log_sum_exp_pair(log_denom, neg_logit);
}
total_loss += pos_logit - log_denom;
}
let n_f = F::from_usize(n).ok_or_else(|| {
NeuralError::ComputationError("MoCoQueue: cannot convert N".to_string())
})?;
Ok(-(total_loss / n_f))
}
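/// Exponential-moving-average update of the key (target) encoder:
/// `target = momentum * target + (1 - momentum) * online`, applied
/// parameter-by-parameter.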
pub fn ema_update<P>(&self, online_params: &[P], target_params: &mut [P]) -> Result<()>
where
P: Clone
+ std::ops::Mul<f64, Output = P>
+ std::ops::Add<Output = P>,
{
if online_params.len() != target_params.len() {
return Err(NeuralError::ShapeMismatch(format!(
"MoCoQueue::ema_update: online has {} params but target has {}",
online_params.len(),
target_params.len()
)));
}
let m = self.momentum;
let one_minus_m = 1.0 - m;
for (t, o) in target_params.iter_mut().zip(online_params.iter()) {
*t = t.clone() * m + o.clone() * one_minus_m;
}
Ok(())
}
}
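/// BYOL hyperparameters, including the cosine schedule that anneals the
/// target-network EMA rate from `initial_tau` to `final_tau` over
/// `total_steps`.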
#[derive(Debug, Clone)]
pub struct BYOLConfig {
pub representation_dim: usize,
pub hidden_dim: usize,
pub projection_dim: usize,
pub initial_tau: f64,
pub final_tau: f64,
pub total_steps: usize,
}
impl Default for BYOLConfig {
fn default() -> Self {
Self {
representation_dim: 512,
hidden_dim: 4096,
projection_dim: 256,
initial_tau: 0.996,
final_tau: 1.0,
total_steps: 100_000,
}
}
}
impl BYOLConfig {
pub fn validate(&self) -> Result<()> {
if self.representation_dim == 0 {
return Err(NeuralError::ConfigError(
"BYOLConfig: representation_dim must be > 0".to_string(),
));
}
if !(0.0..=1.0).contains(&self.initial_tau) || !(0.0..=1.0).contains(&self.final_tau) {
return Err(NeuralError::ConfigError(
"BYOLConfig: tau values must be in [0, 1]".to_string(),
));
}
if self.total_steps == 0 {
return Err(NeuralError::ConfigError(
"BYOLConfig: total_steps must be > 0".to_string(),
));
}
Ok(())
}
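/// Cosine schedule for the target EMA rate:
/// `tau(t) = tau_0 + (tau_T - tau_0) * (1 - cos(pi * t)) / 2`
/// with `t = step / total_steps` clamped to `[0, 1]`.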
pub fn tau_at_step(&self, step: usize) -> f64 {
let t = (step as f64) / (self.total_steps as f64);
let t = t.clamp(0.0, 1.0);
let cos_val = (std::f64::consts::PI * t).cos();
self.initial_tau + (self.final_tau - self.initial_tau) * (1.0 - cos_val) / 2.0
}
}
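/// Result of one BYOL step: the regression loss between the online
/// predictor output and the target projection, plus the EMA rate to use at
/// this step.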
#[derive(Debug, Clone)]
pub struct BYOLUpdate<F: Float + Debug> {
pub loss: F,
pub tau: f64,
pub step: usize,
}
impl<F: Float + Debug + NumAssign + FromPrimitive> BYOLUpdate<F> {
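/// BYOL loss `2 - 2 * mean_i <p_i, z_i>` over L2-normalised rows, which
/// equals the mean squared error `||p_i - z_i||^2` between unit vectors.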
pub fn compute(
online_pred: &Array2<F>,
target_proj: &Array2<F>,
config: &BYOLConfig,
step: usize,
) -> Result<Self> {
if online_pred.shape() != target_proj.shape() {
return Err(NeuralError::ShapeMismatch(format!(
"BYOLUpdate: online_pred shape {:?} != target_proj shape {:?}",
online_pred.shape(),
target_proj.shape()
)));
}
let p = l2_normalise(online_pred)?;
let z = l2_normalise(target_proj)?;
let n = p.nrows();
let mut dot_sum = F::zero();
for i in 0..n {
let pi = p.row(i);
let zi = z.row(i);
let dot: F = pi
.iter()
.zip(zi.iter())
.map(|(a, b)| *a * *b)
.fold(F::zero(), |acc, x| acc + x);
dot_sum += dot;
}
let n_f = F::from_usize(n).ok_or_else(|| {
NeuralError::ComputationError("BYOLUpdate: cannot convert N".to_string())
})?;
let two = F::from_f64(2.0).ok_or_else(|| {
NeuralError::ComputationError("BYOLUpdate: cannot convert 2.0".to_string())
})?;
let mean_dot = dot_sum / n_f;
let loss = two - two * mean_dot;
let tau = config.tau_at_step(step);
Ok(Self { loss, tau, step })
}
}
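/// Supervised contrastive (SupCon) loss over a multi-view batch. Feature
/// rows are expected view-major: rows `[0, N)` hold view 1 of every sample,
/// rows `[N, 2N)` view 2, and so on, so row `j` belongs to sample `j % N`.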
#[derive(Debug, Clone)]
pub struct SupConLoss {
pub temperature: f64,
pub contrast_mode: SupConMode,
}
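/// Which rows act as anchors: every view (`All`) or only the first view of
/// each sample (`One`).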
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SupConMode {
All,
One,
}
impl Default for SupConLoss {
fn default() -> Self {
Self {
temperature: 0.07,
contrast_mode: SupConMode::All,
}
}
}
impl SupConLoss {
pub fn new(temperature: f64, contrast_mode: SupConMode) -> Self {
Self {
temperature,
contrast_mode,
}
}
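/// Computes the SupCon loss for `N * n_views` feature rows and `N` labels.
/// Anchors with no positive in the batch are skipped; the result is
/// averaged over the anchors that remain.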
pub fn forward<F>(
&self,
features: &Array2<F>,
labels: &[usize],
n_views: usize,
) -> Result<F>
where
F: Float + Debug + NumAssign + FromPrimitive + ToPrimitive,
{
let total = features.nrows();
let n = labels.len();
if n == 0 {
return Err(NeuralError::InvalidArgument(
"SupConLoss: labels must not be empty".to_string(),
));
}
if n_views == 0 {
return Err(NeuralError::InvalidArgument(
"SupConLoss: n_views must be > 0".to_string(),
));
}
if total != n * n_views {
return Err(NeuralError::ShapeMismatch(format!(
"SupConLoss: features has {} rows but N={} * n_views={} = {}",
total,
n,
n_views,
n * n_views
)));
}
let tau = F::from_f64(self.temperature).ok_or_else(|| {
NeuralError::ComputationError("SupConLoss: cannot convert temperature".to_string())
})?;
let feats = l2_normalise(features)?;
let anchor_indices: Vec<usize> = match self.contrast_mode {
SupConMode::One => (0..n).collect(),
SupConMode::All => (0..total).collect(),
};
let neg_inf = F::from_f64(-1e38).ok_or_else(|| {
NeuralError::ComputationError("SupConLoss: cannot convert neg_inf".to_string())
})?;
let mut total_loss = F::zero();
let mut num_valid = 0usize;
for &anchor_idx in &anchor_indices {
let sample_idx = anchor_idx % n;
let anchor_label = labels[sample_idx];
let anchor = feats.row(anchor_idx);
let mut pos_logits: Vec<F> = Vec::new();
let mut all_logits: Vec<F> = Vec::new();
for j in 0..total {
if j == anchor_idx {
continue;
}
let j_sample = j % n;
let sim: F = anchor.iter().zip(feats.row(j).iter())
.map(|(a, b)| *a * *b)
.fold(F::zero(), |acc, x| acc + x);
let logit = sim / tau;
all_logits.push(logit);
if labels[j_sample] == anchor_label {
pos_logits.push(logit);
}
}
if pos_logits.is_empty() {
continue;
}
let mut log_denom = neg_inf;
for &logit in &all_logits {
log_denom = log_sum_exp_pair(log_denom, logit);
}
let n_pos = F::from_usize(pos_logits.len()).ok_or_else(|| {
NeuralError::ComputationError("SupConLoss: cannot convert n_pos".to_string())
})?;
let mut anchor_loss = F::zero();
for &pos_logit in &pos_logits {
anchor_loss += pos_logit - log_denom;
}
total_loss += -(anchor_loss / n_pos);
num_valid += 1;
}
if num_valid == 0 {
return Ok(F::zero());
}
let num_valid_f = F::from_usize(num_valid).ok_or_else(|| {
NeuralError::ComputationError("SupConLoss: cannot convert num_valid".to_string())
})?;
Ok(total_loss / num_valid_f)
}
}
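/// How the negative is mined for each (anchor, positive) pair.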
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TripletMiningStrategy {
/// Pick the negative closest to the anchor (hardest negative).
HardNegative,
/// Pick the closest negative that is farther than the positive but still
/// within the margin; fall back to the hardest negative if none qualify.
SemiHard,
/// Pick a negative by a deterministic hash of the anchor/positive indices.
Random,
}
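/// Triplet margin loss with online mining over a labelled batch, computed
/// on squared Euclidean distances.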
#[derive(Debug, Clone)]
pub struct TripletMarginLoss {
pub margin: f64,
pub strategy: TripletMiningStrategy,
pub normalize: bool,
}
impl Default for TripletMarginLoss {
fn default() -> Self {
Self {
margin: 1.0,
strategy: TripletMiningStrategy::SemiHard,
normalize: true,
}
}
}
impl TripletMarginLoss {
pub fn new(margin: f64, strategy: TripletMiningStrategy, normalize: bool) -> Self {
Self {
margin,
strategy,
normalize,
}
}
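/// Mines one negative per (anchor, positive) pair according to `strategy`
/// and averages `max(0, d²(a, p) - d²(a, n) + margin)` over all mined
/// triplets; returns zero when no valid triplet exists.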
pub fn forward<F>(&self, embeddings: &Array2<F>, labels: &[usize]) -> Result<F>
where
F: Float + Debug + NumAssign + FromPrimitive + ToPrimitive,
{
let n = embeddings.nrows();
if labels.len() != n {
return Err(NeuralError::ShapeMismatch(format!(
"TripletMarginLoss: embeddings has {} rows but labels has {} elements",
n,
labels.len()
)));
}
let margin = F::from_f64(self.margin).ok_or_else(|| {
NeuralError::ComputationError("TripletMarginLoss: cannot convert margin".to_string())
})?;
let emb = if self.normalize {
l2_normalise(embeddings)?
} else {
embeddings.to_owned()
};
let dist_sq = pairwise_squared_dist(&emb)?;
let mut total_loss = F::zero();
let mut count = 0usize;
for a in 0..n {
let pos_indices: Vec<usize> = (0..n)
.filter(|&k| k != a && labels[k] == labels[a])
.collect();
let neg_indices: Vec<usize> = (0..n)
.filter(|&k| labels[k] != labels[a])
.collect();
if pos_indices.is_empty() || neg_indices.is_empty() {
continue;
}
for &p in &pos_indices {
let d_ap = dist_sq[[a, p]];
let neg_choice = match self.strategy {
TripletMiningStrategy::HardNegative => {
neg_indices.iter().copied().min_by(|&x, &y| {
dist_sq[[a, x]]
.partial_cmp(&dist_sq[[a, y]])
.unwrap_or(std::cmp::Ordering::Equal)
})
}
TripletMiningStrategy::SemiHard => {
let upper = d_ap + margin;
let semi_hard: Vec<usize> = neg_indices
.iter()
.copied()
.filter(|&k| {
let d_an = dist_sq[[a, k]];
d_an > d_ap && d_an < upper
})
.collect();
if semi_hard.is_empty() {
neg_indices.iter().copied().min_by(|&x, &y| {
dist_sq[[a, x]]
.partial_cmp(&dist_sq[[a, y]])
.unwrap_or(std::cmp::Ordering::Equal)
})
} else {
semi_hard.iter().copied().min_by(|&x, &y| {
dist_sq[[a, x]]
.partial_cmp(&dist_sq[[a, y]])
.unwrap_or(std::cmp::Ordering::Equal)
})
}
}
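// Deterministic stand-in for uniform sampling: hash the anchor and
// positive indices so the mined triplets are reproducible without an
// RNG dependency.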
TripletMiningStrategy::Random => {
let idx = (a.wrapping_mul(131) ^ p.wrapping_mul(137)) % neg_indices.len();
Some(neg_indices[idx])
}
};
if let Some(neg) = neg_choice {
let d_an = dist_sq[[a, neg]];
let triplet_loss = (d_ap - d_an + margin).max(F::zero());
total_loss += triplet_loss;
count += 1;
}
}
}
if count == 0 {
return Ok(F::zero());
}
let count_f = F::from_usize(count).ok_or_else(|| {
NeuralError::ComputationError("TripletMarginLoss: cannot convert count".to_string())
})?;
Ok(total_loss / count_f)
}
}
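/// Classic contrastive pair loss (as in Hadsell et al., 2006): similar
/// pairs are pulled together via squared distance, dissimilar pairs pushed
/// apart up to `margin`.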
#[derive(Debug, Clone, Copy)]
pub struct ContrastivePairLoss {
pub margin: f64,
pub normalize: bool,
}
impl Default for ContrastivePairLoss {
fn default() -> Self {
Self {
margin: 1.0,
normalize: false,
}
}
}
impl ContrastivePairLoss {
pub fn new(margin: f64, normalize: bool) -> Self {
Self { margin, normalize }
}
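/// `labels[i] > 0` marks pair `i` as similar. The per-pair loss is `d²` for
/// similar pairs and `max(0, margin - d)²` for dissimilar ones.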
pub fn forward<F>(&self, emb1: &Array2<F>, emb2: &Array2<F>, labels: &[F]) -> Result<F>
where
F: Float + Debug + NumAssign + FromPrimitive,
{
let n = emb1.nrows();
if emb2.nrows() != n || labels.len() != n {
return Err(NeuralError::ShapeMismatch(
"ContrastivePairLoss: emb1, emb2, labels must have same batch size".to_string(),
));
}
// The row subtraction below panics on mismatched widths, so check the
// embedding dimension up front.
if emb2.ncols() != emb1.ncols() {
return Err(NeuralError::ShapeMismatch(
"ContrastivePairLoss: emb1 and emb2 must have same embedding dim".to_string(),
));
}
let margin = F::from_f64(self.margin).ok_or_else(|| {
NeuralError::ComputationError(
"ContrastivePairLoss: cannot convert margin".to_string(),
)
})?;
let e1 = if self.normalize {
l2_normalise(emb1)?
} else {
emb1.to_owned()
};
let e2 = if self.normalize {
l2_normalise(emb2)?
} else {
emb2.to_owned()
};
let mut total = F::zero();
for i in 0..n {
let diff = e1.row(i).to_owned() - e2.row(i).to_owned();
let dist_sq: F = diff.iter().map(|x| *x * *x).fold(F::zero(), |a, b| a + b);
let dist = dist_sq.sqrt();
let y = labels[i];
let pair_loss = if y > F::zero() {
dist_sq
} else {
let margin_term = (margin - dist).max(F::zero());
margin_term * margin_term
};
total += pair_loss;
}
let n_f = F::from_usize(n).ok_or_else(|| {
NeuralError::ComputationError(
"ContrastivePairLoss: cannot convert N".to_string(),
)
})?;
Ok(total / n_f)
}
}
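/// Row-wise L2 normalisation with an epsilon floor on the norm so zero rows
/// do not produce NaNs.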
fn l2_normalise<F>(x: &Array2<F>) -> Result<Array2<F>>
where
F: Float + Debug + NumAssign + FromPrimitive,
{
let eps = F::from_f64(1e-12).ok_or_else(|| {
NeuralError::ComputationError("l2_normalise: cannot convert eps".to_string())
})?;
let norms = x
.map_axis(Axis(1), |row| {
let sq: F = row.iter().map(|v| *v * *v).fold(F::zero(), |a, b| a + b);
sq.sqrt().max(eps)
});
let mut out = x.to_owned();
for (mut row, &norm) in out.rows_mut().into_iter().zip(norms.iter()) {
row.mapv_inplace(|v| v / norm);
}
Ok(out)
}
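/// Stacks `a` on top of `b` along the row axis; both must have the same
/// number of columns.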
fn concatenate_rows<F: Float + Debug>(a: &Array2<F>, b: &Array2<F>) -> Result<Array2<F>> {
if a.ncols() != b.ncols() {
return Err(NeuralError::ShapeMismatch(format!(
"concatenate_rows: a has {} cols but b has {}",
a.ncols(),
b.ncols()
)));
}
let mut out = Array2::zeros((a.nrows() + b.nrows(), a.ncols()));
for (i, row) in a.rows().into_iter().enumerate() {
out.row_mut(i).assign(&row);
}
for (i, row) in b.rows().into_iter().enumerate() {
out.row_mut(a.nrows() + i).assign(&row);
}
Ok(out)
}
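/// Pairwise dot products between rows. This equals cosine similarity only
/// because callers pass L2-normalised rows.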
fn cosine_sim_matrix<F: Float + Debug + NumAssign>(x: &Array2<F>) -> Result<Array2<F>> {
let n = x.nrows();
let mut sim = Array2::zeros((n, n));
for i in 0..n {
for j in 0..n {
let dot: F = x
.row(i)
.iter()
.zip(x.row(j).iter())
.map(|(a, b)| *a * *b)
.fold(F::zero(), |acc, v| acc + v);
sim[[i, j]] = dot;
}
}
Ok(sim)
}
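/// Numerically stable `log(exp(a) + exp(b))`, anchored at the larger
/// argument.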
fn log_sum_exp_pair<F: Float>(a: F, b: F) -> F {
if a > b {
a + (b - a).exp().ln_1p()
} else {
b + (a - b).exp().ln_1p()
}
}
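/// Symmetric matrix of squared Euclidean distances between rows.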
fn pairwise_squared_dist<F: Float + Debug + NumAssign>(x: &Array2<F>) -> Result<Array2<F>> {
let n = x.nrows();
let mut dist = Array2::zeros((n, n));
for i in 0..n {
for j in (i + 1)..n {
let d: F = x
.row(i)
.iter()
.zip(x.row(j).iter())
.map(|(a, b)| {
let diff = *a - *b;
diff * diff
})
.fold(F::zero(), |acc, v| acc + v);
dist[[i, j]] = d;
dist[[j, i]] = d;
}
}
Ok(dist)
}
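/// Dense layer `y = x · Wᵀ + b` with `W` stored as `(out_dim, in_dim)`.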
fn linear_forward<F: Float + Debug + NumAssign>(
x: &Array2<F>,
w: &Array2<F>,
b: &Array1<F>,
) -> Result<Array2<F>> {
let n = x.nrows();
let in_dim = x.ncols();
let out_dim = w.nrows();
if w.ncols() != in_dim {
return Err(NeuralError::ShapeMismatch(format!(
"linear_forward: x has {} cols but W has {} cols",
in_dim,
w.ncols()
)));
}
if b.len() != out_dim {
return Err(NeuralError::ShapeMismatch(format!(
"linear_forward: W has {} rows but b has {} elements",
out_dim,
b.len()
)));
}
let mut out = Array2::zeros((n, out_dim));
for i in 0..n {
for j in 0..out_dim {
let dot: F = x
.row(i)
.iter()
.zip(w.row(j).iter())
.map(|(a, wij)| *a * *wij)
.fold(F::zero(), |acc, v| acc + v);
out[[i, j]] = dot + b[j];
}
}
Ok(out)
}
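/// Element-wise ReLU.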
fn relu2d<F: Float>(x: &Array2<F>) -> Array2<F> {
x.mapv(|v| v.max(F::zero()))
}
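/// Deterministic weight initialisation at the given Xavier-style scale,
/// reproducible across runs without an RNG dependency.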
fn init_weight_matrix<F: Float + FromPrimitive>(rows: usize, cols: usize, scale: F) -> Array2<F> {
let mut w = Array2::zeros((rows, cols));
for i in 0..rows {
for j in 0..cols {
// Hashed-sine value in [-1, 1]: reproducible without an RNG and, unlike
// the former alternating-sign pattern (which factors as a rank-1 outer
// product), it gives the matrix full rank in practice.
let t = ((i * cols + j) as f64 * 12.9898 + 78.233).sin();
w[[i, j]] = F::from_f64(t).unwrap_or_else(F::zero) * scale;
}
}
w
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array2;
#[test]
fn test_ntxent_basic() {
let loss_fn = NTXentLoss::new(0.1);
let n = 4;
let d = 8;
let z_i = Array2::<f64>::from_shape_fn((n, d), |(i, j)| {
((i * d + j) as f64 * 0.1).sin()
});
let z_j = Array2::<f64>::from_shape_fn((n, d), |(i, j)| {
((i * d + j) as f64 * 0.1 + 0.5).cos()
});
let loss = loss_fn.forward(&z_i, &z_j).expect("NT-Xent forward");
assert!(loss.is_finite(), "NT-Xent loss should be finite");
assert!(loss >= 0.0, "NT-Xent loss should be non-negative");
}
#[test]
fn test_byol_update() {
let config = BYOLConfig::default();
let n = 4;
let d = 8;
let online = Array2::<f64>::from_shape_fn((n, d), |(i, j)| (i + j) as f64 * 0.01);
let target = Array2::<f64>::from_shape_fn((n, d), |(i, j)| (i + j) as f64 * 0.02);
let update =
BYOLUpdate::compute(&online, &target, &config, 1000).expect("BYOL update");
assert!(update.loss.is_finite());
assert!(update.tau > 0.0 && update.tau <= 1.0);
}
#[test]
fn test_supcon_loss() {
let loss_fn = SupConLoss::new(0.07, SupConMode::All);
let features = Array2::<f64>::from_shape_fn((8, 16), |(i, j)| {
((i * 16 + j) as f64 * 0.1).sin()
});
let labels = vec![0usize, 0, 1, 1];
let loss = loss_fn.forward(&features, &labels, 2).expect("SupCon forward");
assert!(loss.is_finite());
}
#[test]
fn test_triplet_margin_loss_hard() {
let loss_fn = TripletMarginLoss::new(1.0, TripletMiningStrategy::HardNegative, true);
let embeddings = Array2::<f64>::from_shape_fn((6, 4), |(i, j)| {
((i * 4 + j) as f64).sin()
});
let labels = vec![0usize, 0, 1, 1, 2, 2];
let loss = loss_fn.forward(&embeddings, &labels).expect("Triplet forward");
assert!(loss.is_finite());
assert!(loss >= 0.0);
}
#[test]
fn test_contrastive_pair_loss() {
let loss_fn = ContrastivePairLoss::new(1.0, true);
let emb1 = Array2::<f64>::from_shape_fn((4, 8), |(i, j)| (i + j) as f64 * 0.1);
let emb2 = Array2::<f64>::from_shape_fn((4, 8), |(i, j)| (i + j) as f64 * 0.2);
let labels = vec![1.0_f64, 0.0, 1.0, 0.0];
let loss = loss_fn
.forward(&emb1, &emb2, &labels)
.expect("ContrastivePairLoss forward");
assert!(loss.is_finite());
}
#[test]
fn test_moco_queue() {
let mut queue: MoCoQueue<f64> = MoCoQueue::new(16, 8, 0.999, 0.07);
let keys = Array2::<f64>::from_shape_fn((4, 8), |(i, j)| (i + j) as f64 * 0.1);
queue.enqueue_and_dequeue(&keys).expect("enqueue");
assert_eq!(queue.len(), 4);
let queries = Array2::<f64>::from_shape_fn((4, 8), |(i, j)| (i + j) as f64 * 0.05);
let loss = queue.info_nce_loss(&queries, &keys).expect("InfoNCE");
assert!(loss.is_finite());
}
#[test]
fn test_simclr_trainer() {
let config = SimCLRConfig {
representation_dim: 16,
projection_hidden_dim: 32,
projection_output_dim: 8,
temperature: 0.1,
weight_decay: 1e-4,
};
let trainer = SimCLRTrainer::<f64>::new(config).expect("SimCLRTrainer::new");
let rep_i = Array2::<f64>::from_shape_fn((4, 16), |(i, j)| (i + j) as f64 * 0.01);
let rep_j = Array2::<f64>::from_shape_fn((4, 16), |(i, j)| (i + j) as f64 * 0.02);
let loss = trainer.batch_loss(&rep_i, &rep_j).expect("batch_loss");
assert!(loss.is_finite());
}
}