use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
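
/// Strategy for pooling a sequence of token embeddings into one
/// fixed-size sentence embedding.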
#[derive(Debug, Clone, PartialEq)]
pub enum UniversalPoolingStrategy {
    /// Use the embedding of the first token (assumes a CLS-style token at position 0).
    ClsToken,
    /// Unweighted mean of the token embeddings.
    Mean,
    /// Element-wise maximum over the token embeddings.
    Max,
    /// Sum of the token embeddings scaled by 1/sqrt(sequence length).
    MeanSqrt,
    /// Softmax attention over tokens, scored against a fitted query vector.
    AttentionPooling,
    /// Mean weighted by per-token IDF scores.
    WeightedMean,
}
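
/// A sentence encoder built from a fixed token-embedding table and a
/// pooling strategy; optionally L2-normalizes its output.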
pub struct UniversalSentenceEncoder {
    /// Embedding table of shape `(vocab_size, d_model)`.
    pub token_embeddings: Array2<f32>,
    /// How token embeddings are pooled into a sentence embedding.
    pub pooling: UniversalPoolingStrategy,
    /// Embedding dimensionality (the number of columns of `token_embeddings`).
    pub d_model: usize,
    /// If true, the pooled vector is L2-normalized before being returned.
    pub normalize_output: bool,
    /// Query vector for `AttentionPooling`, set by `fit_attention_pooling`.
    attention_query: Option<Array1<f32>>,
    /// Per-token IDF weights for `WeightedMean`, set by `fit_idf_weights`.
    idf_weights: Option<Array1<f32>>,
}
impl UniversalSentenceEncoder {
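    /// Builds an encoder around an embedding table of shape
    /// `(vocab_size, d_model)`; `d_model` is inferred from the column count.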
pub fn new(
token_embeddings: Array2<f32>,
pooling: UniversalPoolingStrategy,
normalize_output: bool,
) -> Self {
let d_model = token_embeddings.ncols();
UniversalSentenceEncoder {
token_embeddings,
pooling,
d_model,
normalize_output,
attention_query: None,
idf_weights: None,
}
}
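
    /// Encodes a token-id sequence into a single embedding using the
    /// configured pooling strategy. Out-of-range ids are clamped to the last
    /// vocabulary row; an empty input or an empty vocabulary yields zeros.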
pub fn encode(&self, tokens: &[usize]) -> Array1<f32> {
if tokens.is_empty() || self.token_embeddings.nrows() == 0 {
return Array1::zeros(self.d_model);
}
        let vocab_size = self.token_embeddings.nrows();
        // Clamp out-of-range ids to the last row so indexing cannot panic.
        let safe_tokens: Vec<usize> = tokens
            .iter()
            .map(|&t| t.min(vocab_size.saturating_sub(1)))
            .collect();
let result = match &self.pooling {
UniversalPoolingStrategy::ClsToken => {
self.token_embeddings.row(safe_tokens[0]).to_owned()
}
UniversalPoolingStrategy::Mean => self.mean_pool(&safe_tokens),
UniversalPoolingStrategy::Max => self.max_pool(&safe_tokens),
            UniversalPoolingStrategy::MeanSqrt => {
                // Sqrt-length normalization: sum / sqrt(len). `mean_pool`
                // divides by len, so scale back up by sqrt(len).
                let n = safe_tokens.len() as f32;
                self.mean_pool(&safe_tokens).mapv(|v| v * n.sqrt())
            }
            UniversalPoolingStrategy::AttentionPooling => {
                // Fall back to mean pooling until a query has been fit.
                if let Some(q) = &self.attention_query {
                    self.attention_pool_with_query(&safe_tokens, q)
                } else {
                    self.mean_pool(&safe_tokens)
                }
            }
            UniversalPoolingStrategy::WeightedMean => {
                // Fall back to mean pooling until IDF weights have been fit.
                // Raw ids are passed through: `weighted_mean_pool` clamps row
                // indices itself but keys IDF lookups by the original id.
                if let Some(idf) = &self.idf_weights {
                    self.weighted_mean_pool(tokens, idf)
                } else {
                    self.mean_pool(&safe_tokens)
                }
            }
};
if self.normalize_output {
l2_normalize(result)
} else {
result
}
}
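
    /// Fits smoothed IDF weights over `vocab_size` token ids:
    /// `idf(t) = ln((N + 1) / (df(t) + 1))`, where `N` is the corpus size
    /// and `df(t)` counts the documents containing token `t`.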
pub fn fit_idf_weights(&mut self, corpus: &[Vec<usize>], vocab_size: usize) {
let n = corpus.len() as f32;
let mut df = vec![0u32; vocab_size];
        for doc in corpus {
            // Count each token at most once per document.
            let mut seen = vec![false; vocab_size];
for &t in doc {
if t < vocab_size && !seen[t] {
df[t] += 1;
seen[t] = true;
}
}
}
let idf: Array1<f32> =
Array1::from_iter(df.iter().map(|&d| ((n + 1.0) / (d as f32 + 1.0)).ln()));
self.idf_weights = Some(idf);
}
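
    /// Fits the attention query by numerical gradient descent: the query is
    /// initialized to the mean token embedding, nudged per document via a
    /// forward finite difference of a perturbation loss, and L2-normalized.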
pub fn fit_attention_pooling(&mut self, corpus: &[Vec<usize>], epochs: usize, lr: f32) {
        let vocab_size = self.token_embeddings.nrows();
        // Nothing to fit against an empty vocabulary (pooling would panic).
        if vocab_size == 0 {
            return;
        }
        // Initialize the query to the mean of all token embeddings.
        let mut q = Array1::<f32>::zeros(self.d_model);
        for i in 0..vocab_size {
            q += &self.token_embeddings.row(i);
        }
        q.mapv_inplace(|v| v / vocab_size as f32);
        // Perturbation size for the forward finite difference.
        let h = 1e-4_f32;
for _epoch in 0..epochs {
for doc in corpus {
if doc.is_empty() {
continue;
}
                // Clamp ids into the vocabulary, as in `encode`.
                let safe: Vec<usize> = doc
                    .iter()
                    .map(|&t| t.min(vocab_size.saturating_sub(1)))
                    .collect();
                let out0 = self.attention_pool_with_query(&safe, &q);
                // Forward-difference gradient of the perturbation loss
                // L(q') = ||pool(q') - pool(q)||^2. L(q) itself is zero, so
                // the forward difference reduces to loss_plus / h.
                let mut grad = Array1::<f32>::zeros(self.d_model);
                for j in 0..self.d_model {
                    let mut q_plus = q.clone();
                    q_plus[j] += h;
                    let out_plus = self.attention_pool_with_query(&safe, &q_plus);
                    let loss_plus: f32 = out_plus
                        .iter()
                        .zip(out0.iter())
                        .map(|(a, b)| (a - b).powi(2))
                        .sum();
                    grad[j] = loss_plus / h;
                }
                // Gradient-descent step: q <- q - lr * grad.
                q.scaled_add(-lr, &grad);
}
}
        // Normalize the fitted query with the shared helper.
        self.attention_query = Some(l2_normalize(q));
}

    /// Returns the fitted IDF weights, if any.
    pub fn idf_weights(&self) -> Option<&Array1<f32>> {
        self.idf_weights.as_ref()
    }

    /// Returns the fitted attention query, if any.
    pub fn attention_query(&self) -> Option<&Array1<f32>> {
        self.attention_query.as_ref()
    }

    /// Unweighted mean of the embeddings of `safe_tokens`.
    fn mean_pool(&self, safe_tokens: &[usize]) -> Array1<f32> {
        let mut sum = Array1::<f32>::zeros(self.d_model);
        for &t in safe_tokens {
            sum += &self.token_embeddings.row(t);
        }
        let n = safe_tokens.len().max(1) as f32;
        sum / n
    }

    /// Element-wise maximum over the embeddings of `safe_tokens`.
    /// Callers guarantee `safe_tokens` is non-empty.
    fn max_pool(&self, safe_tokens: &[usize]) -> Array1<f32> {
        let mut result = self.token_embeddings.row(safe_tokens[0]).to_owned();
        for &t in &safe_tokens[1..] {
            result.zip_mut_with(&self.token_embeddings.row(t), |r, &v| {
                if v > *r {
                    *r = v;
                }
            });
        }
        result
    }

    /// Softmax attention pooling: tokens are scored by their dot product
    /// with `q`, the scores are softmaxed, and the embeddings are combined
    /// with the resulting weights. Callers guarantee a non-empty input.
    fn attention_pool_with_query(&self, safe_tokens: &[usize], q: &Array1<f32>) -> Array1<f32> {
        let n = safe_tokens.len();
        let mut scores = vec![0.0f32; n];
for (i, &t) in safe_tokens.iter().enumerate() {
let row = self.token_embeddings.row(t);
scores[i] = row.iter().zip(q.iter()).map(|(a, b)| a * b).sum();
}
        // Numerically stable softmax: subtract the max score before exp.
        let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mut exp_scores: Vec<f32> = scores.iter().map(|&s| (s - max_score).exp()).collect();
let sum_exp: f32 = exp_scores.iter().sum();
if sum_exp > 1e-12 {
exp_scores.iter_mut().for_each(|s| *s /= sum_exp);
        } else {
            // All scores underflowed: fall back to uniform weights.
            let uniform = 1.0 / n as f32;
            exp_scores.iter_mut().for_each(|s| *s = uniform);
        }
        // Weighted sum of the token embeddings.
        let mut result = Array1::<f32>::zeros(self.d_model);
        for (i, &t) in safe_tokens.iter().enumerate() {
            result.scaled_add(exp_scores[i], &self.token_embeddings.row(t));
        }
result
}

    /// IDF-weighted mean of token embeddings. Row indices are clamped into
    /// the vocabulary; IDF weights are keyed by the original token id, with
    /// weight 1.0 for ids outside the fitted IDF table.
    fn weighted_mean_pool(&self, tokens: &[usize], idf: &Array1<f32>) -> Array1<f32> {
let vocab_size = self.token_embeddings.nrows();
let idf_len = idf.len();
let mut result = Array1::<f32>::zeros(self.d_model);
let mut total_weight = 0.0f32;
        for &t in tokens {
            // Clamp the row index, but key the IDF weight by the raw id.
            let row_idx = t.min(vocab_size.saturating_sub(1));
            let weight = if t < idf_len { idf[t] } else { 1.0f32 };
let row = self.token_embeddings.row(row_idx);
for j in 0..self.d_model {
result[j] += weight * row[j];
}
total_weight += weight;
}
if total_weight > 1e-12 {
result.mapv_inplace(|v| v / total_weight);
}
result
}
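
    /// Read-only view of the underlying token-embedding table.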
pub fn embeddings_view(&self) -> ArrayView2<f32> {
self.token_embeddings.view()
}
}
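
// Manual `Debug` impl: reports shapes and fitted state instead of dumping
// the full embedding matrix.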
impl std::fmt::Debug for UniversalSentenceEncoder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("UniversalSentenceEncoder")
.field("vocab_size", &self.token_embeddings.nrows())
.field("d_model", &self.d_model)
.field("pooling", &self.pooling)
.field("normalize_output", &self.normalize_output)
.field("has_attention_query", &self.attention_query.is_some())
.field("has_idf_weights", &self.idf_weights.is_some())
.finish()
}
}
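
/// L2-normalizes `v`; vectors with a (near-)zero or non-finite norm are
/// returned unchanged to avoid producing NaNs.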
fn l2_normalize(mut v: Array1<f32>) -> Array1<f32> {
let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 1e-12 && norm.is_finite() {
v.mapv_inplace(|x| x / norm);
}
v
}
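
// Minimal smoke tests (an illustrative sketch; the toy embedding values
// below are arbitrary, not reference outputs).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mean_pooling_averages_rows() {
        // Two one-hot rows; their mean is (0.5, 0.5).
        let emb = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();
        let enc = UniversalSentenceEncoder::new(emb, UniversalPoolingStrategy::Mean, false);
        let v = enc.encode(&[0, 1]);
        assert!((v[0] - 0.5).abs() < 1e-6);
        assert!((v[1] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn out_of_range_ids_are_clamped() {
        let emb = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();
        let enc = UniversalSentenceEncoder::new(emb, UniversalPoolingStrategy::Max, false);
        // Id 99 clamps to the last vocabulary row (index 1).
        let v = enc.encode(&[99]);
        assert_eq!(v[1], 1.0);
    }
}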