use crate::error::{Result, TextError};
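/// Strategy for collapsing per-token hidden states into a single fixed-size
/// sentence vector.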
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum UsePooling {
    /// Average of the hidden states over all positions.
    Mean,
    /// Element-wise maximum over all positions.
    Max,
    /// Hidden state of the first (CLS-style) position.
    Cls,
    /// Softmax-weighted average using the first layer's attention query.
    Attentive,
}
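/// Hyper-parameters for [`UniversalSentenceEncoder`]. `n_heads` must be
/// non-zero and evenly divide `d_model`; inputs longer than `max_seq_len` are
/// truncated during encoding.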
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct UseConfig {
pub d_model: usize,
pub n_heads: usize,
pub n_layers: usize,
pub ffn_dim: usize,
pub max_seq_len: usize,
pub vocab_size: usize,
pub pooling: UsePooling,
}
impl Default for UseConfig {
fn default() -> Self {
Self {
d_model: 128,
n_heads: 4,
n_layers: 2,
ffn_dim: 256,
max_seq_len: 512,
vocab_size: 30_000,
pooling: UsePooling::Mean,
}
}
}
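/// Settings for [`UniversalSentenceEncoder::cross_lingual_encode`].
/// `shared_vocab_size` records the intended shared vocabulary size and is not
/// consulted by `cross_lingual_encode` itself; the language embedding is a
/// fixed sinusoidal code of `lang_embedding_dim` values tiled up to `d_model`.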
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct CrossLingualConfig {
pub shared_vocab_size: usize,
pub n_languages: usize,
pub lang_embedding_dim: usize,
}
impl Default for CrossLingualConfig {
fn default() -> Self {
Self {
shared_vocab_size: 50_000,
n_languages: 10,
lang_embedding_dim: 16,
}
}
}
/// Deterministic pseudo-random weight in `[-scale, scale)`, derived from a
/// single linear-congruential step of `seed`. Used instead of a real RNG so
/// that model initialization is reproducible.
fn lcg_weight(seed: u64, scale: f64) -> f64 {
    let v = seed.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
    // Keep the top 53 bits so the fraction fits an f64 mantissa exactly.
    let frac = (v >> 11) as f64 / (1u64 << 53) as f64;
    (frac * 2.0 - 1.0) * scale
}
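/// Sinusoidal positional encodings as in "Attention Is All You Need": even
/// dimensions carry `sin`, odd dimensions carry `cos`, and the frequency
/// decreases geometrically with the dimension index.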
fn sinusoidal_pe(seq_len: usize, d_model: usize) -> Vec<Vec<f64>> {
let mut pe = vec![vec![0.0_f64; d_model]; seq_len];
for pos in 0..seq_len {
for i in 0..d_model / 2 {
let angle = pos as f64 / f64::powf(10_000.0, (2 * i) as f64 / d_model as f64);
pe[pos][2 * i] = angle.sin();
if 2 * i + 1 < d_model {
pe[pos][2 * i + 1] = angle.cos();
}
}
}
pe
}
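/// Layer normalization without learned gain or bias: shifts `x` to zero mean
/// and scales it to unit variance, with `eps` guarding the division.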
fn layer_norm(x: &[f64], eps: f64) -> Vec<f64> {
let n = x.len() as f64;
let mean = x.iter().sum::<f64>() / n;
let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
x.iter().map(|v| (v - mean) / (var + eps).sqrt()).collect()
}
fn layer_norm_rows(x: &[Vec<f64>]) -> Vec<Vec<f64>> {
x.iter().map(|row| layer_norm(row, 1e-5)).collect()
}
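/// Naive matrix product: `a` is `seq × d_in`, `b` is `d_in × d_out`, and the
/// result is `seq × d_out`. Rows of `a` are assumed to have `d_in` columns.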
fn matmul_2d(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
let seq = a.len();
let d_in = b.len();
let d_out = if d_in == 0 { 0 } else { b[0].len() };
let mut out = vec![vec![0.0_f64; d_out]; seq];
for i in 0..seq {
for k in 0..d_in {
let a_ik = a[i][k];
for j in 0..d_out {
out[i][j] += a_ik * b[k][j];
}
}
}
out
}
fn transpose(m: &[Vec<f64>]) -> Vec<Vec<f64>> {
if m.is_empty() {
return vec![];
}
let rows = m.len();
let cols = m[0].len();
let mut out = vec![vec![0.0_f64; rows]; cols];
for i in 0..rows {
for j in 0..cols {
out[j][i] = m[i][j];
}
}
out
}
fn add_bias(x: &[Vec<f64>], bias: &[f64]) -> Vec<Vec<f64>> {
x.iter()
.map(|row| row.iter().zip(bias).map(|(v, b)| v + b).collect())
.collect()
}
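/// Element-wise sum of two equally shaped matrices (used for the residual
/// connections).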
fn mat_add(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
a.iter()
.zip(b)
.map(|(ra, rb)| ra.iter().zip(rb).map(|(x, y)| x + y).collect())
.collect()
}
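/// A single post-norm transformer encoder layer: multi-head self-attention
/// and a position-wise feed-forward network, each followed by a residual
/// connection and layer normalization.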
pub struct TransformerEncoderLayer {
d_model: usize,
n_heads: usize,
ffn_dim: usize,
    // Attention projection matrices (query, key, value, output).
    wq: Vec<Vec<f64>>,
    wk: Vec<Vec<f64>>,
    wv: Vec<Vec<f64>>,
    wo: Vec<Vec<f64>>,
    // Position-wise feed-forward weights and biases.
    w1: Vec<Vec<f64>>,
    b1: Vec<f64>,
    w2: Vec<Vec<f64>>,
    b2: Vec<f64>,
    /// Query vector used by [`UsePooling::Attentive`] pooling.
    pub attn_query: Vec<f64>,
}
impl TransformerEncoderLayer {
    /// Builds a layer with deterministic pseudo-random weights (seed offset 0).
    pub fn new(d_model: usize, n_heads: usize, ffn_dim: usize) -> Self {
        Self::with_seed(d_model, n_heads, ffn_dim, 0)
    }

    /// Like [`Self::new`], but offsets every weight seed by `seed` so that
    /// stacked layers do not end up with identical parameters.
    pub fn with_seed(d_model: usize, n_heads: usize, ffn_dim: usize, seed: u64) -> Self {
let scale_attn = 1.0 / (d_model as f64).sqrt();
let scale_ffn = 1.0 / (ffn_dim as f64).sqrt();
let init_matrix = |rows: usize, cols: usize, offset: u64, scale: f64| -> Vec<Vec<f64>> {
(0..rows)
.map(|r| {
(0..cols)
.map(|c| lcg_weight(offset + (r * cols + c) as u64, scale))
.collect()
})
.collect()
};
let init_bias = |len: usize, offset: u64, scale: f64| -> Vec<f64> {
(0..len)
.map(|i| lcg_weight(offset + i as u64, scale))
.collect()
};
        let wq = init_matrix(d_model, d_model, seed + 1000, scale_attn);
        let wk = init_matrix(d_model, d_model, seed + 2000, scale_attn);
        let wv = init_matrix(d_model, d_model, seed + 3000, scale_attn);
        let wo = init_matrix(d_model, d_model, seed + 4000, scale_attn);
        let w1 = init_matrix(d_model, ffn_dim, seed + 5000, scale_ffn);
        let b1 = init_bias(ffn_dim, seed + 6000, 0.01);
        let w2 = init_matrix(ffn_dim, d_model, seed + 7000, scale_ffn);
        let b2 = init_bias(d_model, seed + 8000, 0.01);
        let attn_query = init_bias(d_model, seed + 9000, scale_attn);
Self {
d_model,
n_heads,
ffn_dim,
wq,
wk,
wv,
wo,
w1,
b1,
w2,
b2,
attn_query,
}
}
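    /// Multi-head scaled dot-product self-attention over `x`
    /// (`seq_len × d_model`). A `mask[i][j] == true` entry blocks position `i`
    /// from attending to position `j`.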
pub fn self_attention(
&self,
x: &[Vec<f64>],
mask: Option<&[Vec<bool>]>,
) -> Result<Vec<Vec<f64>>> {
let seq_len = x.len();
if seq_len == 0 {
return Err(TextError::InvalidInput(
"self_attention: empty sequence".into(),
));
}
        if self.n_heads == 0 || self.d_model % self.n_heads != 0 {
            return Err(TextError::InvalidInput(
                "self_attention: n_heads must be non-zero and divide d_model".into(),
            ));
        }
        let d_head = self.d_model / self.n_heads;
        if d_head == 0 {
            return Err(TextError::InvalidInput("d_model must be >= n_heads".into()));
        }
        let q = matmul_2d(x, &self.wq);
        let k = matmul_2d(x, &self.wk);
let v = matmul_2d(x, &self.wv);
let scale = 1.0 / (d_head as f64).sqrt();
let mut concat_heads = vec![vec![0.0_f64; self.d_model]; seq_len];
for h in 0..self.n_heads {
let h_start = h * d_head;
let h_end = h_start + d_head;
let q_h: Vec<Vec<f64>> = q.iter().map(|row| row[h_start..h_end].to_vec()).collect();
let k_h: Vec<Vec<f64>> = k.iter().map(|row| row[h_start..h_end].to_vec()).collect();
let v_h: Vec<Vec<f64>> = v.iter().map(|row| row[h_start..h_end].to_vec()).collect();
let kt = transpose(&k_h);
            let scores_raw = matmul_2d(&q_h, &kt);
let mut attn_weights = vec![vec![0.0_f64; seq_len]; seq_len];
for i in 0..seq_len {
let mut row = vec![0.0_f64; seq_len];
for j in 0..seq_len {
let masked = mask.is_some_and(|m| m[i][j]);
row[j] = if masked {
f64::NEG_INFINITY
} else {
scores_raw[i][j] * scale
};
}
let max_v = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let exps: Vec<f64> = row.iter().map(|v| (v - max_v).exp()).collect();
let sum_exp: f64 = exps.iter().sum();
let sum_exp = if sum_exp < 1e-12 { 1e-12 } else { sum_exp };
for j in 0..seq_len {
attn_weights[i][j] = exps[j] / sum_exp;
}
}
            let ctx = matmul_2d(&attn_weights, &v_h);
for i in 0..seq_len {
for j in 0..d_head {
concat_heads[i][h_start + j] = ctx[i][j];
}
}
}
let out = matmul_2d(&concat_heads, &self.wo);
Ok(out)
}
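    /// Position-wise feed-forward network (linear, ReLU, linear) applied to
    /// every position independently.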
pub fn ffn(&self, x: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
if x.is_empty() {
return Err(TextError::InvalidInput("ffn: empty input".into()));
}
let h = add_bias(&matmul_2d(x, &self.w1), &self.b1);
let h_relu: Vec<Vec<f64>> = h
.iter()
.map(|row| row.iter().map(|v| v.max(0.0)).collect())
.collect();
let out = add_bias(&matmul_2d(&h_relu, &self.w2), &self.b2);
Ok(out)
}
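    /// Runs one full encoder layer: self-attention and feed-forward sub-layers,
    /// each wrapped in a residual connection and layer normalization.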
pub fn forward(&self, x: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
let sa_out = self.self_attention(x, None)?;
let x1 = layer_norm_rows(&mat_add(x, &sa_out));
let ffn_out = self.ffn(&x1)?;
let x2 = layer_norm_rows(&mat_add(&x1, &ffn_out));
Ok(x2)
}
}
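/// A small, deterministic Universal Sentence Encoder: token plus sinusoidal
/// positional embeddings, a stack of [`TransformerEncoderLayer`]s, and a
/// configurable pooling step that yields one `d_model`-sized vector per input.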
pub struct UniversalSentenceEncoder {
pub config: UseConfig,
layers: Vec<TransformerEncoderLayer>,
token_embeddings: Vec<Vec<f64>>,
}
impl UniversalSentenceEncoder {
pub fn new(config: UseConfig) -> Self {
let scale = 1.0 / (config.d_model as f64).sqrt();
let token_embeddings: Vec<Vec<f64>> = (0..config.vocab_size)
.map(|tok| {
(0..config.d_model)
.map(|dim| lcg_weight((tok * config.d_model + dim) as u64 + 100_000, scale))
.collect()
})
.collect();
        let layers = (0..config.n_layers)
            .map(|l| {
                // Give each layer its own seed offset so the stacked layers do
                // not share identical weights.
                TransformerEncoderLayer::with_seed(
                    config.d_model,
                    config.n_heads,
                    config.ffn_dim,
                    l as u64 * 1_000_000,
                )
            })
            .collect();
Self {
config,
layers,
token_embeddings,
}
}
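    /// Looks up token embeddings, truncates to `max_seq_len`, and adds
    /// sinusoidal positional encodings. Errors on empty input or on token ids
    /// outside the vocabulary.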
fn embed(&self, token_ids: &[usize]) -> Result<Vec<Vec<f64>>> {
let seq_len = token_ids.len().min(self.config.max_seq_len);
if seq_len == 0 {
return Err(TextError::InvalidInput(
"encode: token_ids must not be empty".into(),
));
}
let pe = sinusoidal_pe(seq_len, self.config.d_model);
let embedded: Result<Vec<Vec<f64>>> = token_ids[..seq_len]
.iter()
.enumerate()
.map(|(pos, &tok_id)| {
if tok_id >= self.config.vocab_size {
return Err(TextError::InvalidInput(format!(
"token_id {} out of range (vocab_size={})",
tok_id, self.config.vocab_size
)));
}
let emb = &self.token_embeddings[tok_id];
Ok(emb.iter().zip(&pe[pos]).map(|(e, p)| e + p).collect())
})
.collect();
embedded
}
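    /// Collapses per-token hidden states into one vector according to
    /// `config.pooling`. Assumes `hidden` is non-empty, which `embed`
    /// guarantees.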
fn pool(&self, hidden: &[Vec<f64>]) -> Vec<f64> {
match self.config.pooling {
UsePooling::Mean => {
let n = hidden.len() as f64;
let d = hidden[0].len();
let mut out = vec![0.0_f64; d];
for row in hidden {
for (i, v) in row.iter().enumerate() {
out[i] += v;
}
}
out.iter_mut().for_each(|v| *v /= n);
out
}
UsePooling::Max => {
let d = hidden[0].len();
let mut out = vec![f64::NEG_INFINITY; d];
for row in hidden {
for (i, v) in row.iter().enumerate() {
if *v > out[i] {
out[i] = *v;
}
}
}
out
}
UsePooling::Cls => hidden[0].clone(),
UsePooling::Attentive => {
let query = if self.layers.is_empty() {
vec![1.0_f64; hidden[0].len()]
} else {
self.layers[0].attn_query.clone()
};
let d = hidden[0].len();
let scores: Vec<f64> = hidden
.iter()
.map(|row| row.iter().zip(&query).map(|(v, q)| v * q).sum::<f64>())
.collect();
let max_s = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let exps: Vec<f64> = scores.iter().map(|s| (s - max_s).exp()).collect();
let sum_exp: f64 = exps.iter().sum::<f64>().max(1e-12);
let weights: Vec<f64> = exps.iter().map(|e| e / sum_exp).collect();
let mut out = vec![0.0_f64; d];
for (row, w) in hidden.iter().zip(&weights) {
for (i, v) in row.iter().enumerate() {
out[i] += v * w;
}
}
out
}
}
}
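    /// Encodes a token-id sequence into a sentence embedding of length
    /// `config.d_model`. Inputs longer than `max_seq_len` are truncated; empty
    /// input or out-of-vocabulary ids return `TextError::InvalidInput`.
    ///
    /// ```ignore
    /// // Usage sketch (import paths depend on how this module is re-exported).
    /// let model = UniversalSentenceEncoder::new(UseConfig::default());
    /// let emb = model.encode(&[1, 2, 3])?;
    /// assert_eq!(emb.len(), model.config.d_model);
    /// ```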
pub fn encode(&self, token_ids: &[usize]) -> Result<Vec<f64>> {
let mut x = self.embed(token_ids)?;
for layer in &self.layers {
x = layer.forward(&x)?;
}
Ok(self.pool(&x))
}
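    /// Encodes every sequence in `batch` independently; equivalent to calling
    /// [`Self::encode`] on each element.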
pub fn encode_batch(&self, batch: &[Vec<usize>]) -> Result<Vec<Vec<f64>>> {
batch.iter().map(|ids| self.encode(ids)).collect()
}
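    /// Cosine similarity of two vectors, clamped to `[-1, 1]`. Returns `0.0`
    /// if either vector is numerically zero.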
pub fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
let na: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
let nb: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
if na < 1e-12 || nb < 1e-12 {
0.0
} else {
(dot / (na * nb)).clamp(-1.0, 1.0)
}
}
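    /// Encodes `token_ids` after adding a fixed sinusoidal language embedding
    /// (derived from `lang_id` and tiled up to `d_model`) to every position,
    /// then runs the usual encoder stack and pooling.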
pub fn cross_lingual_encode(
&self,
token_ids: &[usize],
lang_id: usize,
xl_config: &CrossLingualConfig,
) -> Result<Vec<f64>> {
if lang_id >= xl_config.n_languages {
return Err(TextError::InvalidInput(format!(
"lang_id {} >= n_languages {}",
lang_id, xl_config.n_languages
)));
        }
        let d = self.config.d_model;
        let ld = xl_config.lang_embedding_dim;
        if ld == 0 {
            return Err(TextError::InvalidInput(
                "cross_lingual_encode: lang_embedding_dim must be non-zero".into(),
            ));
        }
let lang_emb_raw: Vec<f64> = (0..ld)
.map(|i| {
let angle = lang_id as f64 / f64::powf(100.0, (2 * i) as f64 / ld as f64);
if i % 2 == 0 {
angle.sin()
} else {
angle.cos()
}
})
.collect();
let lang_emb: Vec<f64> = (0..d).map(|i| lang_emb_raw[i % ld]).collect();
let mut x = self.embed(token_ids)?;
for row in x.iter_mut() {
for (j, v) in row.iter_mut().enumerate() {
*v += lang_emb[j];
}
}
for layer in &self.layers {
x = layer.forward(&x)?;
}
Ok(self.pool(&x))
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_use() -> UniversalSentenceEncoder {
UniversalSentenceEncoder::new(UseConfig::default())
}
#[test]
fn test_default_config() {
let cfg = UseConfig::default();
assert_eq!(cfg.d_model, 128);
assert_eq!(cfg.n_heads, 4);
assert_eq!(cfg.n_layers, 2);
assert_eq!(cfg.ffn_dim, 256);
assert_eq!(cfg.pooling, UsePooling::Mean);
}
#[test]
fn test_encode_output_size() {
let use_model = make_use();
let ids = vec![1, 2, 3, 4, 5];
let emb = use_model.encode(&ids).expect("encode failed");
assert_eq!(emb.len(), 128, "embedding must have d_model dimensions");
}
#[test]
fn test_cosine_similarity_identical() {
let v = vec![1.0_f64, 2.0, 3.0, 4.0];
let sim = UniversalSentenceEncoder::cosine_similarity(&v, &v);
assert!((sim - 1.0).abs() < 1e-9, "identical vectors → sim = 1.0");
}
#[test]
fn test_cosine_similarity_orthogonal() {
let a = vec![1.0_f64, 0.0];
let b = vec![0.0_f64, 1.0];
let sim = UniversalSentenceEncoder::cosine_similarity(&a, &b);
assert!(sim.abs() < 1e-9, "orthogonal vectors → sim ≈ 0.0");
}
#[test]
fn test_batch_consistent_with_single() {
let use_model = make_use();
let ids1 = vec![1_usize, 2, 3];
let ids2 = vec![4_usize, 5];
let batch = use_model
.encode_batch(&[ids1.clone(), ids2.clone()])
.expect("batch failed");
let single1 = use_model.encode(&ids1).expect("single encode 1 failed");
let single2 = use_model.encode(&ids2).expect("single encode 2 failed");
for (a, b) in batch[0].iter().zip(&single1) {
assert!((a - b).abs() < 1e-12, "batch[0] must equal single encode");
}
for (a, b) in batch[1].iter().zip(&single2) {
assert!((a - b).abs() < 1e-12, "batch[1] must equal single encode");
}
}
#[test]
fn test_cross_lingual_config_defaults() {
let cfg = CrossLingualConfig::default();
assert_eq!(cfg.shared_vocab_size, 50_000);
assert_eq!(cfg.n_languages, 10);
assert_eq!(cfg.lang_embedding_dim, 16);
}
#[test]
fn test_cross_lingual_encode_output_size() {
let use_model = make_use();
let xl = CrossLingualConfig::default();
let emb = use_model
.cross_lingual_encode(&[1, 2, 3], 0, &xl)
.expect("cross-lingual encode failed");
assert_eq!(emb.len(), 128);
}
#[test]
fn test_encode_different_inputs_differ() {
let cfg = UseConfig {
n_layers: 0,
..UseConfig::default()
};
let use_model = UniversalSentenceEncoder::new(cfg);
let emb1 = use_model.encode(&[1, 2, 3]).unwrap();
let emb2 = use_model.encode(&[100, 200, 300]).unwrap();
let all_eq = emb1.iter().zip(&emb2).all(|(a, b)| (a - b).abs() < 1e-12);
assert!(
!all_eq,
"different token inputs should produce numerically distinct embeddings"
);
}
#[test]
fn test_sinusoidal_pe_shape() {
let pe = sinusoidal_pe(10, 128);
assert_eq!(pe.len(), 10);
assert_eq!(pe[0].len(), 128);
}
#[test]
fn test_max_pooling() {
let cfg = UseConfig {
pooling: UsePooling::Max,
n_layers: 1,
..UseConfig::default()
};
let m = UniversalSentenceEncoder::new(cfg);
let emb = m.encode(&[1, 2, 3]).unwrap();
assert_eq!(emb.len(), 128);
}
#[test]
fn test_cls_pooling() {
let cfg = UseConfig {
pooling: UsePooling::Cls,
n_layers: 1,
..UseConfig::default()
};
let m = UniversalSentenceEncoder::new(cfg);
let emb = m.encode(&[0, 1, 2]).unwrap();
assert_eq!(emb.len(), 128);
}
#[test]
fn test_attentive_pooling() {
let cfg = UseConfig {
pooling: UsePooling::Attentive,
n_layers: 1,
..UseConfig::default()
};
let m = UniversalSentenceEncoder::new(cfg);
let emb = m.encode(&[5, 6, 7]).unwrap();
assert_eq!(emb.len(), 128);
}
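    // Additional sanity checks on the internal helpers.
    #[test]
    fn test_layer_norm_zero_mean() {
        let normed = layer_norm(&[1.0_f64, 2.0, 3.0, 4.0], 1e-5);
        let mean: f64 = normed.iter().sum::<f64>() / normed.len() as f64;
        assert!(mean.abs() < 1e-9, "layer_norm output must be zero-mean");
    }
    #[test]
    fn test_transpose_roundtrip() {
        let m = vec![vec![1.0_f64, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        assert_eq!(transpose(&transpose(&m)), m);
    }
    #[test]
    fn test_sinusoidal_pe_position_zero() {
        // At position 0 every sine channel is 0 and every cosine channel is 1.
        let pe = sinusoidal_pe(4, 8);
        for i in 0..4 {
            assert!(pe[0][2 * i].abs() < 1e-12);
            assert!((pe[0][2 * i + 1] - 1.0).abs() < 1e-12);
        }
    }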
}