use crate::sentence_embeddings::encoder::{cosine_sim, l2_norm_f32, SentenceEncoder};
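/// Loss value and in-batch top-1 retrieval accuracy for one training step.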
#[derive(Debug, Clone)]
pub struct SimCSELoss {
pub loss: f32,
pub accuracy: f32,
}
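/// Unsupervised SimCSE-style trainer over a `SentenceEncoder`. Gaussian
/// noise stands in for the dropout masks of the original SimCSE recipe:
/// two independently noised encodings of the same sentence form a positive
/// pair, and every other in-batch pairing serves as a negative.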
pub struct SimCSETrainer {
pub encoder: SentenceEncoder,
pub temperature: f32,
    /// Standard deviation of the Gaussian noise applied to each view.
    noise_std: f32,
    /// State of the deterministic LCG used as the noise source.
    lcg_state: u64,
}
impl SimCSETrainer {
pub fn new(encoder: SentenceEncoder, temperature: f32) -> Self {
SimCSETrainer {
encoder,
temperature,
noise_std: 0.1,
lcg_state: 12345,
}
}
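    /// One unsupervised SimCSE step: each sentence is encoded once, then
    /// perturbed twice to produce views A and B, and `(a_i, b_i)` is scored
    /// as a positive pair against the in-batch negatives `(a_i, b_j)`.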
pub fn unsupervised_step(&mut self, sentences: &[&str]) -> SimCSELoss {
if sentences.is_empty() {
return SimCSELoss {
loss: 0.0,
accuracy: 0.0,
};
}
let n = sentences.len();
let mut embs_a: Vec<Vec<f32>> = Vec::with_capacity(n);
let mut embs_b: Vec<Vec<f32>> = Vec::with_capacity(n);
for &s in sentences {
let base = self.encoder.encode(s);
embs_a.push(self.add_noise(&base));
embs_b.push(self.add_noise(&base));
}
let loss_val = self.nt_xent_loss_inner(&embs_a, &embs_b);
let accuracy = self.top1_accuracy(&embs_a, &embs_b);
SimCSELoss {
loss: loss_val,
accuracy,
}
}
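    /// NT-Xent (InfoNCE) loss over two aligned batches of embeddings, where
    /// `embeddings_a[i]` and `embeddings_b[i]` are a positive pair; see
    /// `nt_xent_loss_inner` for the exact formula.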
pub fn nt_xent_loss(&self, embeddings_a: &[Vec<f32>], embeddings_b: &[Vec<f32>]) -> f32 {
self.nt_xent_loss_inner(embeddings_a, embeddings_b)
}
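    /// Applies a heuristic parameter update. This is not gradient descent:
    /// it merely rescales every stored embedding by a common factor,
    /// shrinking when the batch is already easy (accuracy above chance)
    /// and growing when it is hard.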
pub fn update(&mut self, loss: &SimCSELoss, lr: f32) {
if !loss.loss.is_finite() || loss.loss <= 0.0 {
return;
}
        let scale = if loss.accuracy > 0.5 {
            1.0 - lr * 0.01
        } else {
            1.0 + lr * 0.01
        };
        // Apply the same factor to every stored embedding.
        for emb in self.encoder.embeddings_mut().values_mut() {
            for v in emb.iter_mut() {
                *v *= scale;
            }
        }
}
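    /// Runs `n_steps` of training over sequential, wrapping batches of
    /// `sentences` (a batch that hits the end of the slice is truncated),
    /// applying the heuristic update after each step and returning the
    /// per-step loss history.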
pub fn train(
&mut self,
sentences: &[&str],
n_steps: usize,
batch_size: usize,
lr: f32,
) -> Vec<f32> {
if sentences.is_empty() || n_steps == 0 || batch_size == 0 {
return vec![];
}
let bs = batch_size.min(sentences.len());
let mut loss_history = Vec::with_capacity(n_steps);
for step in 0..n_steps {
let start = (step * bs) % sentences.len();
let end = (start + bs).min(sentences.len());
let batch = &sentences[start..end];
let step_loss = self.unsupervised_step(batch);
loss_history.push(step_loss.loss);
self.update(&step_loss, lr);
}
loss_history
}
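    /// Core NT-Xent computation:
    /// `loss_i = -ln( exp(cos(a_i, b_i)/tau) / sum_j exp(cos(a_i, b_j)/tau) )`,
    /// averaged over rows. The positive term also appears in its own
    /// denominator, so every row loss is non-negative.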
fn nt_xent_loss_inner(&self, embs_a: &[Vec<f32>], embs_b: &[Vec<f32>]) -> f32 {
let n = embs_a.len().min(embs_b.len());
if n == 0 {
return 0.0;
}
let tau = self.temperature;
let mut total_loss = 0.0f32;
for i in 0..n {
let pos_sim = cosine_sim(&embs_a[i], &embs_b[i]) / tau;
let exp_pos = pos_sim.exp();
let mut denom = 0.0f32;
for j in 0..n {
let sim = cosine_sim(&embs_a[i], &embs_b[j]) / tau;
denom += sim.exp();
}
            // Skip rows whose denominator overflowed to infinity at very low
            // temperatures, rather than poisoning the batch mean with NaN.
            if denom > 0.0 && denom.is_finite() {
                total_loss += -(exp_pos / denom).ln();
            }
}
total_loss / n as f32
}
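    /// Fraction of rows `i` whose own positive `b_i` is the nearest of all
    /// B-views to `a_i` by cosine similarity; chance level is `1 / n`.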
fn top1_accuracy(&self, embs_a: &[Vec<f32>], embs_b: &[Vec<f32>]) -> f32 {
let n = embs_a.len().min(embs_b.len());
if n == 0 {
return 0.0;
}
let mut correct = 0usize;
for i in 0..n {
let mut best_j = 0;
let mut best_sim = f32::NEG_INFINITY;
for j in 0..n {
let s = cosine_sim(&embs_a[i], &embs_b[j]);
if s > best_sim {
best_sim = s;
best_j = j;
}
}
if best_j == i {
correct += 1;
}
}
correct as f32 / n as f32
}
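    /// Produces one stochastic "view" of an embedding by adding bounded
    /// Gaussian noise and re-normalizing to unit length.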
fn add_noise(&mut self, v: &[f32]) -> Vec<f32> {
let std = self.noise_std;
        let noisy: Vec<f32> = v
            .iter()
            .enumerate()
            .map(|(i, &x)| x + lcg_normal(&mut self.lcg_state, i as u64) * std)
            .collect();
        // Re-normalize so the noisy view lies on the unit sphere again.
        l2_norm_f32(noisy)
}
}
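// Manual Debug impl: report only the scalar hyperparameters and leave the
// encoder, with its full embedding table, out of the output.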
impl std::fmt::Debug for SimCSETrainer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SimCSETrainer")
.field("temperature", &self.temperature)
.field("noise_std", &self.noise_std)
.finish()
}
}
/// Draws an approximately standard-normal sample from a 64-bit LCG via the
/// Box-Muller transform, clamped to [-3, 3] so a single outlier cannot blow
/// up an embedding. `extra` perturbs the stream per vector component.
fn lcg_normal(state: &mut u64, extra: u64) -> f32 {
    // Knuth's MMIX LCG constants.
    const A: u64 = 6_364_136_223_846_793_005;
    const C: u64 = 1_442_695_040_888_963_407;
    *state = A.wrapping_mul(state.wrapping_add(extra)).wrapping_add(C);
    // The top 52 bits of each state give a uniform sample in [0, 1).
    let u1 = ((*state >> 12) as f64) / ((1u64 << 52) as f64);
    *state = A.wrapping_mul(*state).wrapping_add(C);
    let u2 = ((*state >> 12) as f64) / ((1u64 << 52) as f64);
    // Box-Muller: sqrt(-2 ln u1) * cos(2 pi u2); the epsilon guards ln(0).
    let n = (-2.0 * (u1 + 1e-15).ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
    n.clamp(-3.0, 3.0) as f32
}
#[cfg(test)]
mod tests {
use super::*;
use crate::sentence_embeddings::encoder::{PoolingStrategy, SentenceEncoderConfig};
fn build_trainer() -> SimCSETrainer {
let vocab: Vec<String> = (0..200).map(|i| format!("word{i}")).collect();
let enc = SentenceEncoder::new(
&vocab,
SentenceEncoderConfig {
embedding_dim: 32,
max_seq_len: 64,
pooling: PoolingStrategy::Mean,
normalize: true,
},
);
SimCSETrainer::new(enc, 0.05)
}
#[test]
fn test_simcse_loss_valid_range() {
let trainer = build_trainer();
let sentences = [
"word0 word1 word2",
"word3 word4 word5",
"word6 word7 word8",
"word9 word10 word11",
];
let embs_a: Vec<Vec<f32>> = sentences
.iter()
.map(|s| trainer.encoder.encode(s))
.collect();
let embs_b: Vec<Vec<f32>> = sentences
.iter()
.map(|s| trainer.encoder.encode(s))
.collect();
let loss = trainer.nt_xent_loss(&embs_a, &embs_b);
assert!(loss.is_finite(), "NT-Xent loss must be finite, got {loss}");
assert!(loss >= 0.0, "NT-Xent loss must be >= 0, got {loss}");
}
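    /// Cross-checks the loss against a closed-form value on hand-built
    /// orthonormal embeddings; a sketch that assumes `cosine_sim` is the
    /// standard cosine similarity.
    #[test]
    fn test_nt_xent_loss_orthogonal_closed_form() {
        let mut trainer = build_trainer();
        // tau = 1.0 keeps the expected value well conditioned in f32.
        trainer.temperature = 1.0;
        // Four orthonormal rows: every positive has cosine 1, every
        // negative cosine 0, so each row loss is -ln(e / (e + 3)).
        let mut a = vec![vec![0.0f32; 4]; 4];
        for (i, row) in a.iter_mut().enumerate() {
            row[i] = 1.0;
        }
        let b = a.clone();
        let loss = trainer.nt_xent_loss(&a, &b);
        let e = std::f32::consts::E;
        let expected = -(e / (e + 3.0)).ln();
        assert!(
            (loss - expected).abs() < 1e-4,
            "got {loss}, expected {expected}"
        );
    }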
#[test]
fn test_simcse_train_runs() {
let mut trainer = build_trainer();
        let owned: Vec<String> = (0..10)
            .map(|i| format!("word{} word{}", i, i + 1))
            .collect();
        let sentences: Vec<&str> = owned.iter().map(String::as_str).collect();
let history = trainer.train(&sentences, 5, 4, 0.01);
assert_eq!(history.len(), 5, "should return one loss per step");
for &l in &history {
assert!(l.is_finite(), "each step loss must be finite");
}
}
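    /// Sanity-checks the noise source: samples are clamped to [-3, 3] and,
    /// assuming the LCG mixes reasonably, the sample mean sits near zero.
    #[test]
    fn test_lcg_normal_bounded_and_centered() {
        let mut state = 42u64;
        let n = 10_000;
        let mut sum = 0.0f64;
        for i in 0..n {
            let x = lcg_normal(&mut state, i as u64);
            assert!((-3.0..=3.0).contains(&x), "sample out of range: {x}");
            sum += x as f64;
        }
        let mean = sum / n as f64;
        assert!(mean.abs() < 0.1, "sample mean should be near 0, got {mean}");
    }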
#[test]
fn test_unsupervised_step_valid() {
let mut trainer = build_trainer();
let sentences = vec!["word0 word1", "word2 word3", "word4 word5"];
let result = trainer.unsupervised_step(&sentences);
assert!(result.loss.is_finite());
assert!(result.accuracy >= 0.0 && result.accuracy <= 1.0);
}
#[test]
fn test_nt_xent_loss_single_pair() {
let trainer = build_trainer();
let a = vec![trainer.encoder.encode("word0 word1")];
let b = vec![trainer.encoder.encode("word0 word1")];
let loss = trainer.nt_xent_loss(&a, &b);
assert!(loss.is_finite());
}
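    /// Reference NT-Xent via the log-sum-exp identity
    /// (`-ln(exp(p) / sum_j exp(s_j)) == logsumexp(s) - p`); a numerically
    /// stable sketch used only to cross-check the direct implementation.
    fn nt_xent_logsumexp(trainer: &SimCSETrainer, a: &[Vec<f32>], b: &[Vec<f32>]) -> f32 {
        let n = a.len().min(b.len());
        let tau = trainer.temperature;
        let mut total = 0.0f32;
        for i in 0..n {
            let sims: Vec<f32> = (0..n).map(|j| cosine_sim(&a[i], &b[j]) / tau).collect();
            let max = sims.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let lse = max + sims.iter().map(|s| (s - max).exp()).sum::<f32>().ln();
            total += lse - sims[i];
        }
        total / n as f32
    }
    #[test]
    fn test_nt_xent_matches_logsumexp_reference() {
        let trainer = build_trainer();
        let sentences = ["word0 word1", "word2 word3", "word4 word5"];
        let a: Vec<Vec<f32>> = sentences.iter().map(|s| trainer.encoder.encode(s)).collect();
        let b = a.clone();
        assert!(
            (trainer.nt_xent_loss(&a, &b) - nt_xent_logsumexp(&trainer, &a, &b)).abs() < 1e-3
        );
    }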
}