use crate::error::{Result, TextError};
use scirs2_core::ndarray::{Array1, Array2, Axis};
/// Strategy for collapsing a sequence of per-token embeddings into a single
/// fixed-size sentence vector.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum PoolingStrategy {
    /// Average the embeddings of all unmasked tokens.
    MeanPooling,
    /// Element-wise maximum over the embeddings of all unmasked tokens.
    MaxPooling,
    /// Use the embedding of the sequence's first token as the sentence vector.
    ClsToken,
    /// Concatenate mean and max pooling results (output width is doubled).
    MeanMax,
    /// Position-weighted mean where earlier tokens get exponentially larger weights.
    WeightedMean,
}
/// Configuration for a `SentenceEmbedder`.
#[derive(Debug, Clone)]
pub struct SentenceEmbedderConfig {
    /// Pooling strategy used to combine token embeddings.
    pub pooling: PoolingStrategy,
    /// If true, the pooled sentence vector is L2-normalized.
    pub normalize: bool,
    /// Embedding dimensionality; 0 means "infer from the token-embedding matrix"
    /// (resolved in `SentenceEmbedder::new`).
    pub dim: usize,
}
impl Default for SentenceEmbedderConfig {
    /// Default configuration: mean pooling, L2 normalization on, 768 dimensions.
    fn default() -> Self {
        Self {
            pooling: PoolingStrategy::MeanPooling,
            normalize: true,
            dim: 768,
        }
    }
}
/// Pools pre-computed token embeddings into fixed-size sentence vectors.
pub struct SentenceEmbedder {
    /// Pooling and normalization settings.
    config: SentenceEmbedderConfig,
    /// Lookup table of shape (vocab_size, dim): row `i` is the embedding of token id `i`.
    token_embeddings: Array2<f64>,
}
impl SentenceEmbedder {
    /// Builds an embedder from a token-embedding matrix and a configuration.
    ///
    /// A `config.dim` of 0 is treated as "unset" and replaced by the embedding
    /// matrix's column count.
    pub fn new(token_embeddings: Array2<f64>, mut config: SentenceEmbedderConfig) -> Self {
        if config.dim == 0 {
            config.dim = token_embeddings.ncols();
        }
        Self {
            config,
            token_embeddings,
        }
    }
pub fn embed_token_ids(
&self,
token_ids: &[u32],
attention_mask: &[u32],
) -> Result<Array1<f64>> {
if token_ids.len() != attention_mask.len() {
return Err(TextError::InvalidInput(format!(
"token_ids length ({}) != attention_mask length ({})",
token_ids.len(),
attention_mask.len()
)));
}
if token_ids.is_empty() {
return Err(TextError::InvalidInput(
"Cannot embed an empty token sequence".to_string(),
));
}
let dim = self.token_embeddings.ncols();
let vocab_size = self.token_embeddings.nrows();
for &id in token_ids {
if id as usize >= vocab_size {
return Err(TextError::EmbeddingError(format!(
"Token ID {} out of vocabulary range [0, {})",
id, vocab_size
)));
}
}
let active: Vec<(usize, f64)> = token_ids
.iter()
.zip(attention_mask.iter())
.enumerate()
.filter_map(|(pos, (&id, &mask))| {
if mask != 0 {
Some((id as usize, pos as f64))
} else {
None
}
})
.collect();
if active.is_empty() {
return Ok(Array1::zeros(dim));
}
let result = match &self.config.pooling {
PoolingStrategy::MeanPooling => self.pool_mean(&active, dim),
PoolingStrategy::MaxPooling => self.pool_max(&active, dim),
PoolingStrategy::ClsToken => {
let cls_id = token_ids[0] as usize;
self.token_embeddings.row(cls_id).to_owned()
}
PoolingStrategy::MeanMax => {
let mean = self.pool_mean(&active, dim);
let max = self.pool_max(&active, dim);
let mut out = Array1::zeros(2 * dim);
out.slice_mut(scirs2_core::ndarray::s![..dim]).assign(&mean);
out.slice_mut(scirs2_core::ndarray::s![dim..]).assign(&max);
out
}
PoolingStrategy::WeightedMean => self.pool_weighted_mean(&active, dim),
};
if self.config.normalize && !matches!(self.config.pooling, PoolingStrategy::MeanMax) {
Ok(l2_normalize(result))
} else if self.config.normalize && matches!(self.config.pooling, PoolingStrategy::MeanMax) {
Ok(l2_normalize(result))
} else {
Ok(result)
}
}
pub fn embed_batch(
&self,
token_ids: &[Vec<u32>],
attention_masks: &[Vec<u32>],
) -> Result<Array2<f64>> {
if token_ids.len() != attention_masks.len() {
return Err(TextError::InvalidInput(format!(
"token_ids batch size ({}) != attention_masks batch size ({})",
token_ids.len(),
attention_masks.len()
)));
}
if token_ids.is_empty() {
return Err(TextError::InvalidInput(
"Batch must contain at least one sequence".to_string(),
));
}
let batch_size = token_ids.len();
let out_dim = match &self.config.pooling {
PoolingStrategy::MeanMax => self.config.dim * 2,
_ => self.config.dim,
};
let mut out = Array2::zeros((batch_size, out_dim));
for (i, (ids, mask)) in token_ids.iter().zip(attention_masks.iter()).enumerate() {
let emb = self.embed_token_ids(ids, mask)?;
out.row_mut(i).assign(&emb);
}
Ok(out)
}
    /// Cosine similarity between two sentence embeddings; thin wrapper around
    /// the free function `cosine_similarity` (result is clamped to [-1, 1]).
    pub fn semantic_similarity(&self, emb_a: &Array1<f64>, emb_b: &Array1<f64>) -> f64 {
        cosine_similarity(emb_a, emb_b)
    }
pub fn most_similar(
&self,
query: &Array1<f64>,
corpus: &Array2<f64>,
top_k: usize,
) -> Vec<(usize, f64)> {
let n = corpus.nrows();
let mut scores: Vec<(usize, f64)> = (0..n)
.map(|i| {
let row = corpus.row(i).to_owned();
(i, cosine_similarity(query, &row))
})
.collect();
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scores.truncate(top_k);
scores
}
fn pool_mean(&self, active: &[(usize, f64)], dim: usize) -> Array1<f64> {
let mut sum = Array1::zeros(dim);
for &(id, _pos) in active {
sum = sum + self.token_embeddings.row(id).to_owned();
}
sum / active.len() as f64
}
fn pool_max(&self, active: &[(usize, f64)], dim: usize) -> Array1<f64> {
let mut result = Array1::from_elem(dim, f64::NEG_INFINITY);
for &(id, _pos) in active {
let row = self.token_embeddings.row(id);
for j in 0..dim {
if row[j] > result[j] {
result[j] = row[j];
}
}
}
result
}
fn pool_weighted_mean(&self, active: &[(usize, f64)], dim: usize) -> Array1<f64> {
let max_pos = active
.iter()
.map(|(_, p)| *p)
.fold(f64::NEG_INFINITY, f64::max)
.max(1.0);
let alpha = std::f64::consts::LN_2 / max_pos;
let mut weighted_sum = Array1::zeros(dim);
let mut total_weight = 0.0f64;
for &(id, pos) in active {
let weight = (-alpha * pos).exp();
weighted_sum = weighted_sum + self.token_embeddings.row(id).to_owned() * weight;
total_weight += weight;
}
if total_weight > 0.0 {
weighted_sum / total_weight
} else {
weighted_sum
}
}
}
/// Cosine similarity of two vectors, clamped to [-1, 1].
///
/// Returns 0.0 when either vector has (near-)zero norm, avoiding division by zero.
pub fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
    let mut dot = 0.0f64;
    let mut sq_a = 0.0f64;
    let mut sq_b = 0.0f64;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a < f64::EPSILON || norm_b < f64::EPSILON {
        return 0.0;
    }
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
/// Scales `v` to unit L2 norm; (near-)zero-norm vectors pass through unchanged.
fn l2_normalize(mut v: Array1<f64>) -> Array1<f64> {
    let norm = v.iter().fold(0.0f64, |acc, &x| acc + x * x).sqrt();
    if norm > f64::EPSILON {
        v /= norm;
    }
    v
}
/// Euclidean (L2) norm of `v`.
pub fn l2_norm(v: &Array1<f64>) -> f64 {
    v.iter().fold(0.0f64, |acc, &x| acc + x * x).sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Builds an n x n identity matrix so token id i embeds to the one-hot
    // basis vector e_i — this makes every pooling result easy to predict.
    fn identity_embeddings(n: usize) -> Array2<f64> {
        let mut m = Array2::zeros((n, n));
        for i in 0..n {
            m[[i, i]] = 1.0;
        }
        m
    }

    // Config with the given pooling strategy, normalization off, dim = 4.
    fn make_config(pooling: PoolingStrategy) -> SentenceEmbedderConfig {
        SentenceEmbedderConfig {
            pooling,
            normalize: false,
            dim: 4,
        }
    }

    // Attention mask with all n positions active (no padding).
    fn all_ones_mask(n: usize) -> Vec<u32> {
        vec![1u32; n]
    }

    // Mean of e_0 and e_1 is [0.5, 0.5, 0, 0].
    #[test]
    fn test_sentence_embedder_mean_pool() {
        let emb = identity_embeddings(4);
        let embedder = SentenceEmbedder::new(emb, make_config(PoolingStrategy::MeanPooling));
        let ids = vec![0u32, 1];
        let mask = all_ones_mask(2);
        let out = embedder.embed_token_ids(&ids, &mask).expect("embed");
        assert!((out[0] - 0.5).abs() < 1e-9);
        assert!((out[1] - 0.5).abs() < 1e-9);
        assert!((out[2] - 0.0).abs() < 1e-9);
        assert!((out[3] - 0.0).abs() < 1e-9);
    }

    // The masked-out second token must not contribute to the mean.
    #[test]
    fn test_sentence_embedder_mean_pool_with_padding() {
        let emb = identity_embeddings(4);
        let embedder = SentenceEmbedder::new(emb, make_config(PoolingStrategy::MeanPooling));
        let ids = vec![0u32, 1];
        let mask = vec![1u32, 0];
        let out = embedder.embed_token_ids(&ids, &mask).expect("embed");
        assert!((out[0] - 1.0).abs() < 1e-9);
        assert!((out[1] - 0.0).abs() < 1e-9);
    }

    // Max over e_0, e_1, e_2 lights up the first three components.
    #[test]
    fn test_sentence_embedder_max_pool() {
        let emb = identity_embeddings(4);
        let embedder = SentenceEmbedder::new(emb, make_config(PoolingStrategy::MaxPooling));
        let ids = vec![0u32, 1, 2];
        let mask = all_ones_mask(3);
        let out = embedder.embed_token_ids(&ids, &mask).expect("embed");
        assert!((out[0] - 1.0).abs() < 1e-9);
        assert!((out[1] - 1.0).abs() < 1e-9);
        assert!((out[2] - 1.0).abs() < 1e-9);
        assert!((out[3] - 0.0).abs() < 1e-9);
    }

    // ClsToken pooling returns the first token's embedding (e_2 here).
    #[test]
    fn test_sentence_embedder_cls() {
        let emb = identity_embeddings(4);
        let embedder = SentenceEmbedder::new(emb, make_config(PoolingStrategy::ClsToken));
        let ids = vec![2u32, 3];
        let mask = all_ones_mask(2);
        let out = embedder.embed_token_ids(&ids, &mask).expect("embed");
        assert!((out[0] - 0.0).abs() < 1e-9);
        assert!((out[1] - 0.0).abs() < 1e-9);
        assert!((out[2] - 1.0).abs() < 1e-9);
        assert!((out[3] - 0.0).abs() < 1e-9);
    }

    // With normalize = true the output must have unit L2 norm.
    #[test]
    fn test_sentence_embedder_normalize() {
        let emb = identity_embeddings(4);
        let mut cfg = make_config(PoolingStrategy::MeanPooling);
        cfg.normalize = true;
        let embedder = SentenceEmbedder::new(emb, cfg);
        let ids = vec![0u32, 1];
        let mask = all_ones_mask(2);
        let out = embedder.embed_token_ids(&ids, &mask).expect("embed");
        let norm = l2_norm(&out);
        assert!((norm - 1.0).abs() < 1e-9, "norm = {}", norm);
    }

    // A vector compared with itself has similarity 1.
    #[test]
    fn test_similarity_identical() {
        let emb = identity_embeddings(4);
        let embedder = SentenceEmbedder::new(emb, make_config(PoolingStrategy::ClsToken));
        let v = array![1.0_f64, 0.0, 0.0, 0.0];
        let sim = embedder.semantic_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-9);
    }

    // Orthogonal vectors have similarity 0.
    #[test]
    fn test_similarity_orthogonal() {
        let emb = identity_embeddings(4);
        let embedder = SentenceEmbedder::new(emb, make_config(PoolingStrategy::ClsToken));
        let a = array![1.0_f64, 0.0, 0.0, 0.0];
        let b = array![0.0_f64, 1.0, 0.0, 0.0];
        let sim = embedder.semantic_similarity(&a, &b);
        assert!((sim - 0.0).abs() < 1e-9);
    }

    // Anti-parallel vectors have similarity -1.
    #[test]
    fn test_similarity_opposite() {
        let emb = identity_embeddings(4);
        let embedder = SentenceEmbedder::new(emb, make_config(PoolingStrategy::ClsToken));
        let a = array![1.0_f64, 0.0, 0.0, 0.0];
        let b = array![-1.0_f64, 0.0, 0.0, 0.0];
        let sim = embedder.semantic_similarity(&a, &b);
        assert!((sim - (-1.0)).abs() < 1e-9);
    }

    // MeanMax concatenates mean and max, so output width is 2 * dim = 8.
    #[test]
    fn test_sentence_embedder_mean_max() {
        let emb = identity_embeddings(4);
        let cfg = SentenceEmbedderConfig {
            pooling: PoolingStrategy::MeanMax,
            normalize: false,
            dim: 4,
        };
        let embedder = SentenceEmbedder::new(emb, cfg);
        let ids = vec![0u32, 1];
        let mask = all_ones_mask(2);
        let out = embedder.embed_token_ids(&ids, &mask).expect("embed");
        assert_eq!(out.len(), 8);
    }

    // A single token's weighted mean is just that token's embedding.
    #[test]
    fn test_sentence_embedder_weighted_mean_single_token() {
        let emb = identity_embeddings(4);
        let embedder = SentenceEmbedder::new(emb, make_config(PoolingStrategy::WeightedMean));
        let ids = vec![2u32];
        let mask = all_ones_mask(1);
        let out = embedder.embed_token_ids(&ids, &mask).expect("embed");
        assert!((out[2] - 1.0).abs() < 1e-9);
    }

    // Batch embedding yields one row per sequence with width dim.
    #[test]
    fn test_embed_batch_shape() {
        let emb = identity_embeddings(4);
        let embedder = SentenceEmbedder::new(emb, make_config(PoolingStrategy::MeanPooling));
        let token_ids = vec![vec![0u32, 1], vec![2u32, 3]];
        let masks = vec![all_ones_mask(2), all_ones_mask(2)];
        let out = embedder
            .embed_batch(&token_ids, &masks)
            .expect("batch embed");
        assert_eq!(out.nrows(), 2);
        assert_eq!(out.ncols(), 4);
    }

    // top_k = 2 returns two hits, best first: the corpus row equal to the query.
    #[test]
    fn test_most_similar_returns_top_k() {
        let emb = identity_embeddings(4);
        let embedder = SentenceEmbedder::new(emb, make_config(PoolingStrategy::ClsToken));
        let query = array![1.0_f64, 0.0, 0.0, 0.0];
        let corpus = Array2::from_shape_vec(
            (3, 4),
            vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
        )
        .expect("shape");
        let results = embedder.most_similar(&query, &corpus, 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 0);
        assert!((results[0].1 - 1.0).abs() < 1e-9);
    }

    // An empty token sequence is rejected with an error.
    #[test]
    fn test_embed_empty_sequence_errors() {
        let emb = identity_embeddings(4);
        let embedder = SentenceEmbedder::new(emb, make_config(PoolingStrategy::MeanPooling));
        assert!(embedder.embed_token_ids(&[], &[]).is_err());
    }

    // Mask length must match token length.
    #[test]
    fn test_embed_mismatched_mask_errors() {
        let emb = identity_embeddings(4);
        let embedder = SentenceEmbedder::new(emb, make_config(PoolingStrategy::MeanPooling));
        assert!(embedder.embed_token_ids(&[0u32, 1], &[1u32]).is_err());
    }

    // Token ids beyond the vocabulary (4 rows here) are rejected.
    #[test]
    fn test_embed_out_of_vocab_errors() {
        let emb = identity_embeddings(4);
        let embedder = SentenceEmbedder::new(emb, make_config(PoolingStrategy::MeanPooling));
        assert!(embedder.embed_token_ids(&[10u32], &[1u32]).is_err());
    }
}