use scirs2_core::ndarray::{Array1, Array2};
use super::autograd_projection::{DifferentiableProjection, ProjectionConfig};
use super::encoder::SentenceEncoder;
use super::infonce::{infonce_loss, top1_accuracy};
use crate::error::{Result, TextError};
/// Hyperparameters for SimCSE-style contrastive training.
#[derive(Debug, Clone)]
pub struct SimcseConfig {
    /// InfoNCE softmax temperature; smaller values sharpen the similarity distribution.
    pub temperature: f32,
    /// Nominal batch size (NOTE(review): `fit_unsupervised` takes its own
    /// `batch_size` argument and this field is not read elsewhere in this file — confirm intent).
    pub batch_size: usize,
    /// Configuration forwarded to the differentiable projection head.
    pub projection: ProjectionConfig,
}
impl Default for SimcseConfig {
    /// Default hyperparameters: temperature 0.05, batch size 32, and the
    /// projection head's own defaults.
    fn default() -> Self {
        Self {
            projection: ProjectionConfig::default(),
            temperature: 0.05,
            batch_size: 32,
        }
    }
}
/// Metrics reported by a single training step.
#[derive(Debug, Clone)]
pub struct TrainStep {
    /// Contrastive (InfoNCE) loss value for the step.
    pub loss: f32,
    /// Top-1 retrieval accuracy between the two projected views, in [0, 1].
    pub accuracy: f32,
}
/// Trains a projection head on top of a frozen-style sentence encoder using
/// SimCSE contrastive objectives (unsupervised and supervised pairs).
pub struct SimcseTrainer {
    /// Produces fixed-dimension sentence embeddings (input to the projection).
    pub encoder: SentenceEncoder,
    /// Trainable projection head with its own update/optimizer state.
    pub projection: DifferentiableProjection,
    /// Training hyperparameters.
    pub config: SimcseConfig,
}
impl SimcseTrainer {
pub fn new(encoder: SentenceEncoder, config: SimcseConfig) -> Self {
let projection = DifferentiableProjection::new(config.projection.clone());
SimcseTrainer {
encoder,
projection,
config,
}
}
pub fn unsupervised_step(&mut self, sentences: &[&str]) -> Result<TrainStep> {
if sentences.is_empty() {
return Ok(TrainStep {
loss: 0.0,
accuracy: 0.0,
});
}
let batch_size = sentences.len();
let d_in = self.encoder.embedding_dim();
let mut emb_matrix = Array2::<f32>::zeros((batch_size, d_in));
for (i, &s) in sentences.iter().enumerate() {
let enc = self.encoder.encode(s);
if enc.len() != d_in {
return Err(TextError::InvalidInput(format!(
"Encoder output dim {} != projection d_in {}",
enc.len(),
d_in
)));
}
for (j, &v) in enc.iter().enumerate() {
emb_matrix[[i, j]] = v;
}
}
let h_a = self.projection.forward_inference(&emb_matrix)?;
let h_b = self.projection.forward_inference(&emb_matrix)?;
let accuracy = top1_accuracy(&h_a, &h_b);
let loss = self
.projection
.update_step(&emb_matrix, self.config.temperature)?;
Ok(TrainStep { loss, accuracy })
}
pub fn fit_unsupervised(
&mut self,
sentences: &[&str],
n_steps: usize,
batch_size: usize,
) -> Result<Vec<TrainStep>> {
if sentences.is_empty() || n_steps == 0 || batch_size == 0 {
return Ok(vec![]);
}
let bs = batch_size.min(sentences.len());
let mut history = Vec::with_capacity(n_steps);
for step in 0..n_steps {
let start = (step * bs) % sentences.len();
let end = (start + bs).min(sentences.len());
let batch = &sentences[start..end];
let step_result = self.unsupervised_step(batch)?;
history.push(step_result);
}
Ok(history)
}
pub fn supervised_step(&mut self, anchors: &[&str], positives: &[&str]) -> Result<TrainStep> {
let n = anchors.len().min(positives.len());
if n == 0 {
return Ok(TrainStep {
loss: 0.0,
accuracy: 0.0,
});
}
let d_in = self.encoder.embedding_dim();
let mut anc_matrix = Array2::<f32>::zeros((n, d_in));
let mut pos_matrix = Array2::<f32>::zeros((n, d_in));
for i in 0..n {
let enc_a = self.encoder.encode(anchors[i]);
let enc_p = self.encoder.encode(positives[i]);
for j in 0..d_in {
anc_matrix[[i, j]] = *enc_a.get(j).unwrap_or(&0.0);
pos_matrix[[i, j]] = *enc_p.get(j).unwrap_or(&0.0);
}
}
let h_a = self.projection.forward_inference(&anc_matrix)?;
let h_b = self.projection.forward_inference(&pos_matrix)?;
let accuracy = top1_accuracy(&h_a, &h_b);
let loss = infonce_loss(&h_a, &h_b, self.config.temperature);
let loss_grad = self
.projection
.update_step(&anc_matrix, self.config.temperature)?;
Ok(TrainStep {
loss: (loss + loss_grad) * 0.5,
accuracy,
})
}
pub fn encode(&self, sentence: &str) -> Result<Array1<f32>> {
let enc = self.encoder.encode(sentence);
let d_in = self.encoder.embedding_dim();
let emb_matrix = Array2::from_shape_vec((1, d_in), enc)
.map_err(|e| TextError::InvalidInput(e.to_string()))?;
let projected = self.projection.forward_inference(&emb_matrix)?;
Ok(projected.row(0).to_owned())
}
pub fn encode_batch(&self, sentences: &[&str]) -> Result<Array2<f32>> {
if sentences.is_empty() {
let d_out = self.config.projection.d_out;
return Array2::zeros((0, d_out))
.into_shape_with_order((0, d_out))
.map_err(|e| TextError::InvalidInput(e.to_string()));
}
let d_in = self.encoder.embedding_dim();
let n = sentences.len();
let mut emb_matrix = Array2::<f32>::zeros((n, d_in));
for (i, &s) in sentences.iter().enumerate() {
let enc = self.encoder.encode(s);
for (j, &v) in enc.iter().enumerate() {
if j < d_in {
emb_matrix[[i, j]] = v;
}
}
}
self.projection.forward_inference(&emb_matrix)
}
    /// Number of optimizer update steps the projection head has performed.
    pub fn steps(&self) -> u64 {
        self.projection.steps()
    }
}
impl std::fmt::Debug for SimcseTrainer {
    /// Compact debug view: projection dimensions plus the optimizer step
    /// count (encoder contents and weights are deliberately omitted).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let proj = &self.config.projection;
        f.debug_struct("SimcseTrainer")
            .field("d_in", &proj.d_in)
            .field("d_out", &proj.d_out)
            .field("steps", &self.steps())
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sentence_embeddings::encoder::{PoolingStrategy, SentenceEncoderConfig};

    /// Mean-pooling, normalized encoder over a synthetic "word0".."word199" vocabulary.
    fn build_encoder(dim: usize) -> SentenceEncoder {
        let vocab = (0..200).map(|i| format!("word{i}")).collect::<Vec<String>>();
        let cfg = SentenceEncoderConfig {
            embedding_dim: dim,
            max_seq_len: 64,
            pooling: PoolingStrategy::Mean,
            normalize: true,
        };
        SentenceEncoder::new(&vocab, cfg)
    }

    /// Trainer with a square projection head (d_in == d_hidden == d_out == dim).
    fn build_trainer(dim: usize) -> SimcseTrainer {
        let projection = ProjectionConfig {
            d_in: dim,
            d_hidden: dim,
            d_out: dim,
            dropout_rate: 0.1,
            learning_rate: 1e-3,
        };
        SimcseTrainer::new(
            build_encoder(dim),
            SimcseConfig {
                temperature: 0.05,
                batch_size: 4,
                projection,
            },
        )
    }

    #[test]
    fn unsupervised_step_returns_finite_loss() {
        let sentences = ["word0 word1", "word2 word3", "word4 word5", "word6 word7"];
        let mut trainer = build_trainer(32);
        let result = trainer.unsupervised_step(&sentences).expect("step failed");
        assert!(
            result.loss.is_finite(),
            "loss must be finite: {}",
            result.loss
        );
        assert!(
            result.accuracy >= 0.0 && result.accuracy <= 1.0,
            "accuracy out of range: {}",
            result.accuracy
        );
    }

    #[test]
    fn encode_returns_correct_dim() {
        let trainer = build_trainer(32);
        let emb = trainer.encode("word0 word1 word2").expect("encode failed");
        assert_eq!(emb.len(), 32, "expected 32-dim output");
    }

    #[test]
    fn encode_batch_shape_is_correct() {
        let sentences = ["word0 word1", "word2 word3", "word4 word5"];
        let trainer = build_trainer(32);
        let batch = trainer
            .encode_batch(&sentences)
            .expect("batch encode failed");
        assert_eq!(batch.shape(), &[3, 32]);
    }

    #[test]
    fn supervised_step_runs_without_error() {
        let anchors = ["word0 word1", "word2 word3"];
        let positives = ["word0 word1 word2", "word2 word3 word4"];
        let mut trainer = build_trainer(32);
        let result = trainer
            .supervised_step(&anchors, &positives)
            .expect("supervised step failed");
        assert!(result.loss.is_finite());
    }

    #[test]
    fn steps_count_increments_after_update() {
        let mut trainer = build_trainer(32);
        assert_eq!(trainer.steps(), 0);
        let sentences = ["word0 word1", "word2 word3", "word4 word5", "word6 word7"];
        trainer.unsupervised_step(&sentences).expect("step failed");
        assert_eq!(trainer.steps(), 1);
    }
}