//! High-level SimCSE trainer combining a frozen sentence encoder with a
//! differentiable projection head.
//!
//! # Overview
//!
//! `SimcseTrainer` provides the complete unsupervised and supervised SimCSE
//! training workflow:
//!
//! 1. Sentences are encoded by the **frozen** [`SentenceEncoder`] into dense
//!    `f32` vectors (no gradient required here).
//! 2. Those vectors are passed through the **differentiable** `DifferentiableProjection`
//!    MLP (`d_in → d_hidden → d_out`, Tanh activations).
//! 3. Two independent dropout masks are applied inside the projection to generate
//!    positive pairs `(h_i, h_i⁺)` — the SimCSE trick.
//! 4. InfoNCE / NT-Xent loss is computed over the batch (see the formula
//!    below) and gradients flow back through the projection weights via Adam.
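//!
//! For reference, the per-example InfoNCE objective in step 4, with cosine
//! similarity `sim` and temperature `τ` over a batch of `N` pairs, takes the
//! standard SimCSE form:
//!
//! ```text
//! ℓ_i = -log( exp(sim(h_i, h_i⁺) / τ) / Σ_{j=1..N} exp(sim(h_i, h_j⁺) / τ) )
//! ```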
//!
//! After training, [`SimcseTrainer::encode`] runs the encoder and projection
//! **without** dropout (inference mode) to produce stable embeddings.
//!
//! # References
//!
//! Gao, Yao & Chen (2021) "SimCSE: Simple Contrastive Learning of Sentence
//! Embeddings" <https://arxiv.org/abs/2104.08821>

use scirs2_core::ndarray::{Array1, Array2};

use super::autograd_projection::{DifferentiableProjection, ProjectionConfig};
use super::encoder::SentenceEncoder;
use super::infonce::{infonce_loss, top1_accuracy};
use crate::error::{Result, TextError};

// ── SimcseConfig ──────────────────────────────────────────────────────────────

/// Configuration for the high-level SimCSE trainer.
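///
/// # Example
///
/// A minimal sketch overriding only the temperature:
///
/// ```rust
/// use scirs2_text::sentence_embeddings::trainer::SimcseConfig;
///
/// let config = SimcseConfig {
///     temperature: 0.1,
///     ..Default::default()
/// };
/// assert_eq!(config.batch_size, 32);
/// ```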
#[derive(Debug, Clone)]
pub struct SimcseConfig {
    /// InfoNCE temperature τ (default: 0.05).
    pub temperature: f32,
    /// Training batch size.
    pub batch_size: usize,
    /// Projection head architecture.
    pub projection: ProjectionConfig,
}

impl Default for SimcseConfig {
    fn default() -> Self {
        SimcseConfig {
            temperature: 0.05,
            batch_size: 32,
            projection: ProjectionConfig::default(),
        }
    }
}

// ── TrainStep ─────────────────────────────────────────────────────────────────

/// Result of a single SimCSE training step.
#[derive(Debug, Clone)]
pub struct TrainStep {
    /// InfoNCE loss value.
    pub loss: f32,
    /// Top-1 accuracy of the anchor→positive retrieval task.
    pub accuracy: f32,
}

// ── SimcseTrainer ─────────────────────────────────────────────────────────────

/// SimCSE trainer with a frozen encoder and a differentiable projection head.
///
/// # Example
///
/// ```rust,no_run
/// use scirs2_text::sentence_embeddings::trainer::{SimcseTrainer, SimcseConfig};
/// use scirs2_text::sentence_embeddings::encoder::{SentenceEncoder, SentenceEncoderConfig};
///
/// let vocab: Vec<String> = (0..100).map(|i| format!("word{i}")).collect();
/// let encoder = SentenceEncoder::new(
///     &vocab,
///     SentenceEncoderConfig::default(),
/// );
/// let mut trainer = SimcseTrainer::new(encoder, SimcseConfig::default());
///
/// let sentences = ["word0 word1", "word2 word3", "word4 word5"];
/// let result = trainer.unsupervised_step(&sentences).unwrap();
/// println!("loss = {}", result.loss);
/// ```
pub struct SimcseTrainer {
    /// Frozen sentence encoder (inference only).
    pub encoder: SentenceEncoder,
    /// Trainable projection head backed by `scirs2-autograd`.
    pub projection: DifferentiableProjection,
    /// Training configuration.
    pub config: SimcseConfig,
}

impl SimcseTrainer {
    /// Create a new trainer.
    ///
    /// `encoder` must already be initialised.  `config.projection.d_in` must
    /// match `encoder.embedding_dim()`.
    pub fn new(encoder: SentenceEncoder, config: SimcseConfig) -> Self {
        let projection = DifferentiableProjection::new(config.projection.clone());
        SimcseTrainer {
            encoder,
            projection,
            config,
        }
    }

    // ── Unsupervised training ─────────────────────────────────────────────────

    /// Perform one unsupervised SimCSE step on a mini-batch of sentences.
    ///
    /// Sentences are encoded (frozen) then projected with two independent
    /// dropout masks to produce positive pairs.  The InfoNCE loss is
    /// backpropagated through the projection weights.
    pub fn unsupervised_step(&mut self, sentences: &[&str]) -> Result<TrainStep> {
        if sentences.is_empty() {
            return Ok(TrainStep {
                loss: 0.0,
                accuracy: 0.0,
            });
        }

        let batch_size = sentences.len();
        let d_in = self.encoder.embedding_dim();

        // Encode all sentences (frozen, no grad).
        let mut emb_matrix = Array2::<f32>::zeros((batch_size, d_in));
        for (i, &s) in sentences.iter().enumerate() {
            let enc = self.encoder.encode(s);
            if enc.len() != d_in {
                return Err(TextError::InvalidInput(format!(
                    "Encoder output dim {} != projection d_in {}",
                    enc.len(),
                    d_in
                )));
            }
            for (j, &v) in enc.iter().enumerate() {
                emb_matrix[[i, j]] = v;
            }
        }

        // Two forward passes through the projection for metric reporting.
        // Note: `forward_inference` applies no dropout, so the two views
        // coincide and this accuracy is optimistic.  The dropout-noised
        // views that actually form the positive pairs are generated inside
        // `projection.update_step`, which recomputes both views internally.
        let h_a = self.projection.forward_inference(&emb_matrix)?;
        let h_b = self.projection.forward_inference(&emb_matrix)?;

        let accuracy = top1_accuracy(&h_a, &h_b);

        // Apply gradient step.
        let loss = self
            .projection
            .update_step(&emb_matrix, self.config.temperature)?;

        Ok(TrainStep { loss, accuracy })
    }

    /// Train for `n_steps` on `sentences`, cycling through the data in
    /// mini-batches.  A batch that reaches the end of the slice is truncated
    /// rather than wrapped, so boundary batches may be smaller than
    /// `batch_size`.
    ///
    /// Returns the `(loss, accuracy)` history, one [`TrainStep`] per step.
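    ///
    /// # Example
    ///
    /// A minimal sketch of a short training loop, reusing the setup from the
    /// struct-level example:
    ///
    /// ```rust,no_run
    /// use scirs2_text::sentence_embeddings::trainer::{SimcseTrainer, SimcseConfig};
    /// use scirs2_text::sentence_embeddings::encoder::{SentenceEncoder, SentenceEncoderConfig};
    ///
    /// let vocab: Vec<String> = (0..100).map(|i| format!("word{i}")).collect();
    /// let encoder = SentenceEncoder::new(&vocab, SentenceEncoderConfig::default());
    /// let mut trainer = SimcseTrainer::new(encoder, SimcseConfig::default());
    ///
    /// let sentences = ["word0 word1", "word2 word3", "word4 word5", "word6 word7"];
    /// // 8 steps with mini-batches of 2; the loss typically trends downward.
    /// let history = trainer.fit_unsupervised(&sentences, 8, 2).unwrap();
    /// for (i, step) in history.iter().enumerate() {
    ///     println!("step {i}: loss = {:.4}", step.loss);
    /// }
    /// ```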
    pub fn fit_unsupervised(
        &mut self,
        sentences: &[&str],
        n_steps: usize,
        batch_size: usize,
    ) -> Result<Vec<TrainStep>> {
        if sentences.is_empty() || n_steps == 0 || batch_size == 0 {
            return Ok(vec![]);
        }

        let bs = batch_size.min(sentences.len());
        let mut history = Vec::with_capacity(n_steps);

        for step in 0..n_steps {
            let start = (step * bs) % sentences.len();
            let end = (start + bs).min(sentences.len());
            let batch = &sentences[start..end];
            let step_result = self.unsupervised_step(batch)?;
            history.push(step_result);
        }

        Ok(history)
    }

    // ── Supervised training ───────────────────────────────────────────────────

    /// Perform one supervised SimCSE step given pre-paired (anchor, positive)
    /// sentences.
    ///
    /// Unlike unsupervised mode, the positive is provided explicitly (e.g.,
    /// sentence + NLI entailment).  Each sentence is encoded once by the
    /// frozen encoder; anchors and positives are then projected independently
    /// for the reported metrics, while the dropout-based gradient step runs
    /// on the anchor batch.
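    ///
    /// # Example
    ///
    /// A minimal sketch with hand-made anchor/positive pairs (in practice
    /// these would come from a labelled source such as NLI entailment pairs):
    ///
    /// ```rust,no_run
    /// # use scirs2_text::sentence_embeddings::trainer::{SimcseTrainer, SimcseConfig};
    /// # use scirs2_text::sentence_embeddings::encoder::{SentenceEncoder, SentenceEncoderConfig};
    /// # let vocab: Vec<String> = (0..100).map(|i| format!("word{i}")).collect();
    /// # let encoder = SentenceEncoder::new(&vocab, SentenceEncoderConfig::default());
    /// # let mut trainer = SimcseTrainer::new(encoder, SimcseConfig::default());
    /// let anchors = ["word0 word1", "word2 word3"];
    /// let positives = ["word0 word1 word4", "word2 word3 word5"];
    /// let step = trainer.supervised_step(&anchors, &positives).unwrap();
    /// println!("loss = {}, accuracy = {}", step.loss, step.accuracy);
    /// ```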
    pub fn supervised_step(&mut self, anchors: &[&str], positives: &[&str]) -> Result<TrainStep> {
        let n = anchors.len().min(positives.len());
        if n == 0 {
            return Ok(TrainStep {
                loss: 0.0,
                accuracy: 0.0,
            });
        }

        let d_in = self.encoder.embedding_dim();

        // Encode anchors and positives separately; dims missing from a short
        // encoder output are left as zeros.
        let mut anc_matrix = Array2::<f32>::zeros((n, d_in));
        let mut pos_matrix = Array2::<f32>::zeros((n, d_in));
        for i in 0..n {
            let enc_a = self.encoder.encode(anchors[i]);
            let enc_p = self.encoder.encode(positives[i]);
            for j in 0..d_in {
                anc_matrix[[i, j]] = *enc_a.get(j).unwrap_or(&0.0);
                pos_matrix[[i, j]] = *enc_p.get(j).unwrap_or(&0.0);
            }
        }

        // Project both batches.  `forward_inference` applies no dropout, so
        // these projections are deterministic and are used only for the
        // reported metrics below.
        let h_a = self.projection.forward_inference(&anc_matrix)?;
        let h_b = self.projection.forward_inference(&pos_matrix)?;

        let accuracy = top1_accuracy(&h_a, &h_b);
        let loss = infonce_loss(&h_a, &h_b, self.config.temperature);

        // Apply the gradient update on the anchor batch.  `update_step`
        // regenerates dropout views of the anchors internally, so the
        // supervised positives influence the reported loss and accuracy
        // above but not this gradient step.
        let loss_grad = self
            .projection
            .update_step(&anc_matrix, self.config.temperature)?;

        Ok(TrainStep {
            // Report the mean of the supervised InfoNCE loss and the loss
            // returned by the gradient step.
            loss: (loss + loss_grad) * 0.5,
            accuracy,
        })
    }

    // ── Inference ─────────────────────────────────────────────────────────────

    /// Encode a sentence and project it (inference mode, no dropout).
    ///
    /// Returns a `d_out`-dimensional vector.
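    ///
    /// # Example
    ///
    /// A minimal sketch comparing two sentences by cosine similarity of
    /// their projected embeddings:
    ///
    /// ```rust,no_run
    /// # use scirs2_text::sentence_embeddings::trainer::{SimcseTrainer, SimcseConfig};
    /// # use scirs2_text::sentence_embeddings::encoder::{SentenceEncoder, SentenceEncoderConfig};
    /// # let vocab: Vec<String> = (0..100).map(|i| format!("word{i}")).collect();
    /// # let encoder = SentenceEncoder::new(&vocab, SentenceEncoderConfig::default());
    /// # let trainer = SimcseTrainer::new(encoder, SimcseConfig::default());
    /// let a = trainer.encode("word0 word1").unwrap();
    /// let b = trainer.encode("word0 word2").unwrap();
    /// let cos = a.dot(&b) / (a.dot(&a).sqrt() * b.dot(&b).sqrt());
    /// println!("cosine similarity = {cos}");
    /// ```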
    pub fn encode(&self, sentence: &str) -> Result<Array1<f32>> {
        let enc = self.encoder.encode(sentence);
        let d_in = self.encoder.embedding_dim();
        let emb_matrix = Array2::from_shape_vec((1, d_in), enc)
            .map_err(|e| TextError::InvalidInput(e.to_string()))?;
        let projected = self.projection.forward_inference(&emb_matrix)?;
        Ok(projected.row(0).to_owned())
    }

    /// Encode multiple sentences in a single batch (inference mode).
    pub fn encode_batch(&self, sentences: &[&str]) -> Result<Array2<f32>> {
        if sentences.is_empty() {
            // An empty batch maps to an empty (0, d_out) matrix.
            return Ok(Array2::zeros((0, self.config.projection.d_out)));
        }

        let d_in = self.encoder.embedding_dim();
        let n = sentences.len();
        let mut emb_matrix = Array2::<f32>::zeros((n, d_in));
        for (i, &s) in sentences.iter().enumerate() {
            let enc = self.encoder.encode(s);
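            // Copy up to `d_in` dims; extra dims are ignored and a shorter
            // encoder output leaves trailing zeros.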
            for (j, &v) in enc.iter().enumerate() {
                if j < d_in {
                    emb_matrix[[i, j]] = v;
                }
            }
        }
        self.projection.forward_inference(&emb_matrix)
    }

    /// Return the number of gradient steps taken.
    pub fn steps(&self) -> u64 {
        self.projection.steps()
    }
}

impl std::fmt::Debug for SimcseTrainer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SimcseTrainer")
            .field("d_in", &self.config.projection.d_in)
            .field("d_out", &self.config.projection.d_out)
            .field("steps", &self.steps())
            .finish()
    }
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::sentence_embeddings::encoder::{PoolingStrategy, SentenceEncoderConfig};

    fn build_encoder(dim: usize) -> SentenceEncoder {
        let vocab: Vec<String> = (0..200).map(|i| format!("word{i}")).collect();
        SentenceEncoder::new(
            &vocab,
            SentenceEncoderConfig {
                embedding_dim: dim,
                max_seq_len: 64,
                pooling: PoolingStrategy::Mean,
                normalize: true,
            },
        )
    }

    fn build_trainer(dim: usize) -> SimcseTrainer {
        let enc = build_encoder(dim);
        let config = SimcseConfig {
            temperature: 0.05,
            batch_size: 4,
            projection: ProjectionConfig {
                d_in: dim,
                d_hidden: dim,
                d_out: dim,
                dropout_rate: 0.1,
                learning_rate: 1e-3,
            },
        };
        SimcseTrainer::new(enc, config)
    }

    #[test]
    fn unsupervised_step_returns_finite_loss() {
        let mut trainer = build_trainer(32);
        let sentences = ["word0 word1", "word2 word3", "word4 word5", "word6 word7"];
        let result = trainer.unsupervised_step(&sentences).expect("step failed");
        assert!(
            result.loss.is_finite(),
            "loss must be finite: {}",
            result.loss
        );
        assert!(
            result.accuracy >= 0.0 && result.accuracy <= 1.0,
            "accuracy out of range: {}",
            result.accuracy
        );
    }

    #[test]
    fn encode_returns_correct_dim() {
        let trainer = build_trainer(32);
        let emb = trainer.encode("word0 word1 word2").expect("encode failed");
        assert_eq!(emb.len(), 32, "expected 32-dim output");
    }

    #[test]
    fn encode_batch_shape_is_correct() {
        let trainer = build_trainer(32);
        let sentences = ["word0 word1", "word2 word3", "word4 word5"];
        let batch = trainer
            .encode_batch(&sentences)
            .expect("batch encode failed");
        assert_eq!(batch.shape(), &[3, 32]);
    }

    #[test]
    fn supervised_step_runs_without_error() {
        let mut trainer = build_trainer(32);
        let anchors = ["word0 word1", "word2 word3"];
        let positives = ["word0 word1 word2", "word2 word3 word4"];
        let result = trainer
            .supervised_step(&anchors, &positives)
            .expect("supervised step failed");
        assert!(result.loss.is_finite());
    }

    #[test]
    fn steps_count_increments_after_update() {
        let mut trainer = build_trainer(32);
        assert_eq!(trainer.steps(), 0);
        let sentences = ["word0 word1", "word2 word3", "word4 word5", "word6 word7"];
        trainer.unsupervised_step(&sentences).expect("step failed");
        assert_eq!(trainer.steps(), 1);
    }
}