1use crate::error::{Result, TextError};
8use std::f64::consts::PI;
9
/// Strategy for collapsing per-token hidden states into a single sentence vector.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so a pooling mode can be
/// used as a key in hash-based collections.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UsePooling {
    /// Element-wise average over all token positions.
    Mean,
    /// Element-wise maximum over all token positions.
    Max,
    /// Hidden state of the first token position.
    Cls,
    /// Softmax-weighted sum of token states, scored against a query vector.
    Attentive,
}
27
28#[non_exhaustive]
30#[derive(Debug, Clone)]
31pub struct UseConfig {
32 pub d_model: usize,
34 pub n_heads: usize,
36 pub n_layers: usize,
38 pub ffn_dim: usize,
40 pub max_seq_len: usize,
42 pub vocab_size: usize,
44 pub pooling: UsePooling,
46}
47
48impl Default for UseConfig {
49 fn default() -> Self {
50 Self {
51 d_model: 128,
52 n_heads: 4,
53 n_layers: 2,
54 ffn_dim: 256,
55 max_seq_len: 512,
56 vocab_size: 30_000,
57 pooling: UsePooling::Mean,
58 }
59 }
60}
61
/// Settings for language-conditioned encoding.
///
/// All fields are plain integers, so `PartialEq`, `Eq` and `Hash` are derived to
/// let configurations be compared and used as collection keys.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CrossLingualConfig {
    /// Size of the vocabulary shared across languages.
    pub shared_vocab_size: usize,
    /// Number of supported languages; valid language ids are `0..n_languages`.
    pub n_languages: usize,
    /// Length of the per-language embedding before it is tiled to the model width.
    pub lang_embedding_dim: usize,
}
77
78impl Default for CrossLingualConfig {
79 fn default() -> Self {
80 Self {
81 shared_vocab_size: 50_000,
82 n_languages: 10,
83 lang_embedding_dim: 16,
84 }
85 }
86}
87
/// Deterministic pseudo-random weight in `[-scale, scale)` derived from `seed`.
///
/// The same `(seed, scale)` pair always yields the same value, which keeps
/// model initialisation reproducible without an RNG dependency.
fn lcg_weight(seed: u64, scale: f64) -> f64 {
    // Single LCG step. These are 32-bit LCG constants, so for the small seeds
    // used by callers the product only populates the low ~40 bits of `v`.
    let v = seed.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);

    // The fraction below is taken from the TOP 53 bits. Without extra mixing
    // those bits are almost all zero, every fraction is ~0 and every weight
    // collapses to ~`-scale`. A SplitMix64-style finaliser spreads the entropy
    // of `v` across all 64 bits first.
    let mut z = v;
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^= z >> 31;

    // 53 random bits -> uniform fraction in [0, 1), then map to [-scale, scale).
    let frac = (z >> 11) as f64 / (1u64 << 53) as f64;
    (frac * 2.0 - 1.0) * scale
}
100
/// Builds a `seq_len x d_model` table of sinusoidal positional encodings.
///
/// Column `2i` holds `sin(pos / 10000^(2i / d_model))` and column `2i + 1`
/// holds the matching cosine. When `d_model` is odd the final column is a sine
/// without a cosine partner (previously it was left at zero).
fn sinusoidal_pe(seq_len: usize, d_model: usize) -> Vec<Vec<f64>> {
    // ceil(d_model / 2) frequency pairs so an odd width still fills its last column.
    let n_pairs = d_model / 2 + d_model % 2;

    // The divisor depends only on the column pair, so compute it once, not per position.
    let divisors: Vec<f64> = (0..n_pairs)
        .map(|i| f64::powf(10_000.0, (2 * i) as f64 / d_model as f64))
        .collect();

    (0..seq_len)
        .map(|pos| {
            let mut row = vec![0.0_f64; d_model];
            for (i, &div) in divisors.iter().enumerate() {
                let angle = pos as f64 / div;
                row[2 * i] = angle.sin();
                // Absent only for the trailing column of an odd `d_model`.
                if let Some(slot) = row.get_mut(2 * i + 1) {
                    *slot = angle.cos();
                }
            }
            row
        })
        .collect()
}
119
/// Normalises `x` to zero mean and unit (population) variance.
///
/// `eps` is added to the variance before the square root so a constant input
/// does not divide by zero. An empty slice yields an empty vector.
fn layer_norm(x: &[f64], eps: f64) -> Vec<f64> {
    let count = x.len() as f64;
    let mean = x.iter().sum::<f64>() / count;
    let variance = x.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / count;
    let denom = (variance + eps).sqrt();

    let mut normed = Vec::with_capacity(x.len());
    for &value in x {
        normed.push((value - mean) / denom);
    }
    normed
}
130
131fn layer_norm_rows(x: &[Vec<f64>]) -> Vec<Vec<f64>> {
132 x.iter().map(|row| layer_norm(row, 1e-5)).collect()
133}
134
/// Dense matrix product `a * b` for row-major `Vec<Vec<f64>>` matrices.
///
/// The inner dimension is `b.len()` and the output width is `b[0].len()`
/// (zero when `b` is empty). Extra trailing columns in a row of `a` are
/// ignored.
///
/// # Panics
/// Panics if a row of `a` has fewer than `b.len()` entries or a row of `b`
/// has fewer than `b[0].len()` entries.
fn matmul_2d(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let d_out = b.first().map_or(0, |row| row.len());
    a.iter()
        .map(|a_row| {
            let mut acc = vec![0.0_f64; d_out];
            // Walk `b` row by row so each output cell accumulates in k-order.
            for (k, b_row) in b.iter().enumerate() {
                let coeff = a_row[k];
                for (slot, weight) in acc.iter_mut().zip(&b_row[..d_out]) {
                    *slot += coeff * weight;
                }
            }
            acc
        })
        .collect()
}
155
/// Matrix product for non-square operands; delegates directly to `matmul_2d`.
///
/// Kept as a separately named entry point for the attention code, which uses it
/// for the `Q * K^T` score matrix and the `weights * V` context product.
fn matmul_rect(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    matmul_2d(a, b)
}
160
/// Returns the transpose of a row-major matrix.
///
/// The column count is taken from the first row; an empty matrix transposes to
/// an empty matrix.
///
/// # Panics
/// Panics if any row is shorter than the first row.
fn transpose(m: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let Some(first) = m.first() else {
        return Vec::new();
    };
    (0..first.len())
        .map(|col| m.iter().map(|row| row[col]).collect())
        .collect()
}
176
/// Adds `bias` element-wise to every row of `x`.
///
/// Each output row is as long as the shorter of the input row and `bias`.
fn add_bias(x: &[Vec<f64>], bias: &[f64]) -> Vec<Vec<f64>> {
    let mut shifted = Vec::with_capacity(x.len());
    for row in x {
        let mut out_row = Vec::with_capacity(row.len().min(bias.len()));
        for (value, offset) in row.iter().zip(bias) {
            out_row.push(value + offset);
        }
        shifted.push(out_row);
    }
    shifted
}
183
/// Element-wise sum of two matrices.
///
/// Rows and columns are paired positionally; the result is truncated to the
/// shorter operand along each axis.
fn mat_add(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let mut total = Vec::with_capacity(a.len().min(b.len()));
    for (left, right) in a.iter().zip(b) {
        let row: Vec<f64> = left.iter().zip(right).map(|(l, r)| l + r).collect();
        total.push(row);
    }
    total
}
191
/// One transformer encoder layer: multi-head self-attention followed by a
/// position-wise feed-forward network.
///
/// `Debug` and `Clone` are derived so layers can be inspected and duplicated
/// like any other plain-data type.
#[derive(Debug, Clone)]
pub struct TransformerEncoderLayer {
    /// Width of the input and output vectors.
    d_model: usize,
    /// Number of attention heads.
    n_heads: usize,
    /// Inner width of the feed-forward network.
    ffn_dim: usize,
    /// Query projection (`d_model x d_model`).
    wq: Vec<Vec<f64>>,
    /// Key projection (`d_model x d_model`).
    wk: Vec<Vec<f64>>,
    /// Value projection (`d_model x d_model`).
    wv: Vec<Vec<f64>>,
    /// Output projection applied to the concatenated heads (`d_model x d_model`).
    wo: Vec<Vec<f64>>,
    /// First feed-forward matrix (`d_model x ffn_dim`).
    w1: Vec<Vec<f64>>,
    /// Bias added after `w1` (`ffn_dim` entries).
    b1: Vec<f64>,
    /// Second feed-forward matrix (`ffn_dim x d_model`).
    w2: Vec<Vec<f64>>,
    /// Bias added after `w2` (`d_model` entries).
    b2: Vec<f64>,
    /// Query vector (`d_model` entries) exposed for attentive pooling.
    pub attn_query: Vec<f64>,
}
214
215impl TransformerEncoderLayer {
216 pub fn new(d_model: usize, n_heads: usize, ffn_dim: usize) -> Self {
218 let scale_attn = 1.0 / (d_model as f64).sqrt();
219 let scale_ffn = 1.0 / (ffn_dim as f64).sqrt();
220
221 let init_matrix = |rows: usize, cols: usize, offset: u64, scale: f64| -> Vec<Vec<f64>> {
222 (0..rows)
223 .map(|r| {
224 (0..cols)
225 .map(|c| lcg_weight(offset + (r * cols + c) as u64, scale))
226 .collect()
227 })
228 .collect()
229 };
230 let init_bias = |len: usize, offset: u64, scale: f64| -> Vec<f64> {
231 (0..len)
232 .map(|i| lcg_weight(offset + i as u64, scale))
233 .collect()
234 };
235
236 let wq = init_matrix(d_model, d_model, 1000, scale_attn);
237 let wk = init_matrix(d_model, d_model, 2000, scale_attn);
238 let wv = init_matrix(d_model, d_model, 3000, scale_attn);
239 let wo = init_matrix(d_model, d_model, 4000, scale_attn);
240 let w1 = init_matrix(d_model, ffn_dim, 5000, scale_ffn);
241 let b1 = init_bias(ffn_dim, 6000, 0.01);
242 let w2 = init_matrix(ffn_dim, d_model, 7000, scale_ffn);
243 let b2 = init_bias(d_model, 8000, 0.01);
244 let attn_query = init_bias(d_model, 9000, scale_attn);
245
246 Self {
247 d_model,
248 n_heads,
249 ffn_dim,
250 wq,
251 wk,
252 wv,
253 wo,
254 w1,
255 b1,
256 w2,
257 b2,
258 attn_query,
259 }
260 }
261
262 pub fn self_attention(
267 &self,
268 x: &[Vec<f64>],
269 mask: Option<&[Vec<bool>]>,
270 ) -> Result<Vec<Vec<f64>>> {
271 let seq_len = x.len();
272 if seq_len == 0 {
273 return Err(TextError::InvalidInput(
274 "self_attention: empty sequence".into(),
275 ));
276 }
277 let d_head = self.d_model / self.n_heads;
278 if d_head == 0 {
279 return Err(TextError::InvalidInput("d_model must be >= n_heads".into()));
280 }
281
282 let q = matmul_2d(x, &self.wq); let k = matmul_2d(x, &self.wk);
284 let v = matmul_2d(x, &self.wv);
285
286 let scale = 1.0 / (d_head as f64).sqrt();
287
288 let mut concat_heads = vec![vec![0.0_f64; self.d_model]; seq_len];
289
290 for h in 0..self.n_heads {
291 let h_start = h * d_head;
292 let h_end = h_start + d_head;
293
294 let q_h: Vec<Vec<f64>> = q.iter().map(|row| row[h_start..h_end].to_vec()).collect();
296 let k_h: Vec<Vec<f64>> = k.iter().map(|row| row[h_start..h_end].to_vec()).collect();
297 let v_h: Vec<Vec<f64>> = v.iter().map(|row| row[h_start..h_end].to_vec()).collect();
298
299 let kt = transpose(&k_h);
301 let scores_raw = matmul_rect(&q_h, &kt);
302
303 let mut attn_weights = vec![vec![0.0_f64; seq_len]; seq_len];
305 for i in 0..seq_len {
306 let mut row = vec![0.0_f64; seq_len];
307 for j in 0..seq_len {
308 let masked = mask.is_some_and(|m| m[i][j]);
309 row[j] = if masked {
310 f64::NEG_INFINITY
311 } else {
312 scores_raw[i][j] * scale
313 };
314 }
315 let max_v = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
317 let exps: Vec<f64> = row.iter().map(|v| (v - max_v).exp()).collect();
318 let sum_exp: f64 = exps.iter().sum();
319 let sum_exp = if sum_exp < 1e-12 { 1e-12 } else { sum_exp };
320 for j in 0..seq_len {
321 attn_weights[i][j] = exps[j] / sum_exp;
322 }
323 }
324
325 let ctx = matmul_rect(&attn_weights, &v_h);
327
328 for i in 0..seq_len {
329 for j in 0..d_head {
330 concat_heads[i][h_start + j] = ctx[i][j];
331 }
332 }
333 }
334
335 let out = matmul_2d(&concat_heads, &self.wo);
337 Ok(out)
338 }
339
340 pub fn ffn(&self, x: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
342 if x.is_empty() {
343 return Err(TextError::InvalidInput("ffn: empty input".into()));
344 }
345 let h = add_bias(&matmul_2d(x, &self.w1), &self.b1);
347 let h_relu: Vec<Vec<f64>> = h
348 .iter()
349 .map(|row| row.iter().map(|v| v.max(0.0)).collect())
350 .collect();
351 let out = add_bias(&matmul_2d(&h_relu, &self.w2), &self.b2);
353 Ok(out)
354 }
355
356 pub fn forward(&self, x: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
358 let sa_out = self.self_attention(x, None)?;
360 let x1 = layer_norm_rows(&mat_add(x, &sa_out));
361
362 let ffn_out = self.ffn(&x1)?;
364 let x2 = layer_norm_rows(&mat_add(&x1, &ffn_out));
365
366 Ok(x2)
367 }
368}
369
/// Sentence encoder made of a token embedding table and a stack of
/// transformer encoder layers.
pub struct UniversalSentenceEncoder {
    /// Hyper-parameters the encoder was built with.
    pub config: UseConfig,
    /// Encoder layers, applied in order when encoding.
    layers: Vec<TransformerEncoderLayer>,
    /// Embedding table: `vocab_size` rows of `d_model` values each.
    token_embeddings: Vec<Vec<f64>>,
}
382
383impl UniversalSentenceEncoder {
384 pub fn new(config: UseConfig) -> Self {
386 let scale = 1.0 / (config.d_model as f64).sqrt();
387 let token_embeddings: Vec<Vec<f64>> = (0..config.vocab_size)
388 .map(|tok| {
389 (0..config.d_model)
390 .map(|dim| lcg_weight((tok * config.d_model + dim) as u64 + 100_000, scale))
391 .collect()
392 })
393 .collect();
394
395 let layers = (0..config.n_layers)
396 .map(|l| {
397 let _offset = l as u64 * 1_000_000;
399 TransformerEncoderLayer::new(config.d_model, config.n_heads, config.ffn_dim)
400 })
401 .collect();
402
403 Self {
404 config,
405 layers,
406 token_embeddings,
407 }
408 }
409
410 fn embed(&self, token_ids: &[usize]) -> Result<Vec<Vec<f64>>> {
412 let seq_len = token_ids.len().min(self.config.max_seq_len);
413 if seq_len == 0 {
414 return Err(TextError::InvalidInput(
415 "encode: token_ids must not be empty".into(),
416 ));
417 }
418 let pe = sinusoidal_pe(seq_len, self.config.d_model);
419 let embedded: Result<Vec<Vec<f64>>> = token_ids[..seq_len]
420 .iter()
421 .enumerate()
422 .map(|(pos, &tok_id)| {
423 if tok_id >= self.config.vocab_size {
424 return Err(TextError::InvalidInput(format!(
425 "token_id {} out of range (vocab_size={})",
426 tok_id, self.config.vocab_size
427 )));
428 }
429 let emb = &self.token_embeddings[tok_id];
430 Ok(emb.iter().zip(&pe[pos]).map(|(e, p)| e + p).collect())
431 })
432 .collect();
433 embedded
434 }
435
436 fn pool(&self, hidden: &[Vec<f64>]) -> Vec<f64> {
438 match self.config.pooling {
439 UsePooling::Mean => {
440 let n = hidden.len() as f64;
441 let d = hidden[0].len();
442 let mut out = vec![0.0_f64; d];
443 for row in hidden {
444 for (i, v) in row.iter().enumerate() {
445 out[i] += v;
446 }
447 }
448 out.iter_mut().for_each(|v| *v /= n);
449 out
450 }
451 UsePooling::Max => {
452 let d = hidden[0].len();
453 let mut out = vec![f64::NEG_INFINITY; d];
454 for row in hidden {
455 for (i, v) in row.iter().enumerate() {
456 if *v > out[i] {
457 out[i] = *v;
458 }
459 }
460 }
461 out
462 }
463 UsePooling::Cls => hidden[0].clone(),
464 UsePooling::Attentive => {
465 let query = if self.layers.is_empty() {
467 vec![1.0_f64; hidden[0].len()]
468 } else {
469 self.layers[0].attn_query.clone()
470 };
471 let d = hidden[0].len();
472 let scores: Vec<f64> = hidden
473 .iter()
474 .map(|row| row.iter().zip(&query).map(|(v, q)| v * q).sum::<f64>())
475 .collect();
476 let max_s = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
477 let exps: Vec<f64> = scores.iter().map(|s| (s - max_s).exp()).collect();
478 let sum_exp: f64 = exps.iter().sum::<f64>().max(1e-12);
479 let weights: Vec<f64> = exps.iter().map(|e| e / sum_exp).collect();
480
481 let mut out = vec![0.0_f64; d];
482 for (row, w) in hidden.iter().zip(&weights) {
483 for (i, v) in row.iter().enumerate() {
484 out[i] += v * w;
485 }
486 }
487 out
488 }
489 }
490 }
491
492 pub fn encode(&self, token_ids: &[usize]) -> Result<Vec<f64>> {
494 let mut x = self.embed(token_ids)?;
495 for layer in &self.layers {
496 x = layer.forward(&x)?;
497 }
498 Ok(self.pool(&x))
499 }
500
501 pub fn encode_batch(&self, batch: &[Vec<usize>]) -> Result<Vec<Vec<f64>>> {
503 batch.iter().map(|ids| self.encode(ids)).collect()
504 }
505
506 pub fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
508 let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
509 let na: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
510 let nb: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
511 if na < 1e-12 || nb < 1e-12 {
512 0.0
513 } else {
514 (dot / (na * nb)).clamp(-1.0, 1.0)
515 }
516 }
517
518 pub fn cross_lingual_encode(
524 &self,
525 token_ids: &[usize],
526 lang_id: usize,
527 xl_config: &CrossLingualConfig,
528 ) -> Result<Vec<f64>> {
529 if lang_id >= xl_config.n_languages {
530 return Err(TextError::InvalidInput(format!(
531 "lang_id {} >= n_languages {}",
532 lang_id, xl_config.n_languages
533 )));
534 }
535 let d = self.config.d_model;
538 let ld = xl_config.lang_embedding_dim;
539 let lang_emb_raw: Vec<f64> = (0..ld)
540 .map(|i| {
541 let angle = lang_id as f64 / f64::powf(100.0, (2 * i) as f64 / ld as f64);
543 if i % 2 == 0 {
544 angle.sin()
545 } else {
546 angle.cos()
547 }
548 })
549 .collect();
550 let lang_emb: Vec<f64> = (0..d).map(|i| lang_emb_raw[i % ld]).collect();
552
553 let mut x = self.embed(token_ids)?;
554 for row in x.iter_mut() {
556 for (j, v) in row.iter_mut().enumerate() {
557 *v += lang_emb[j];
558 }
559 }
560 for layer in &self.layers {
561 x = layer.forward(&x)?;
562 }
563 Ok(self.pool(&x))
564 }
565}
566
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoder built from the default configuration.
    fn make_use() -> UniversalSentenceEncoder {
        let cfg = UseConfig::default();
        UniversalSentenceEncoder::new(cfg)
    }

    /// Single-layer encoder with the given pooling mode, otherwise default.
    fn single_layer_with(pooling: UsePooling) -> UniversalSentenceEncoder {
        UniversalSentenceEncoder::new(UseConfig {
            pooling,
            n_layers: 1,
            ..UseConfig::default()
        })
    }

    /// The default encoder configuration exposes the documented values.
    #[test]
    fn test_default_config() {
        let UseConfig {
            d_model,
            n_heads,
            n_layers,
            ffn_dim,
            pooling,
            ..
        } = UseConfig::default();
        assert_eq!(d_model, 128);
        assert_eq!(n_heads, 4);
        assert_eq!(n_layers, 2);
        assert_eq!(ffn_dim, 256);
        assert_eq!(pooling, UsePooling::Mean);
    }

    /// Encoding yields one value per model dimension.
    #[test]
    fn test_encode_output_size() {
        let encoder = make_use();
        let tokens = vec![1, 2, 3, 4, 5];
        let embedding = encoder.encode(&tokens).expect("encode failed");
        assert_eq!(
            embedding.len(),
            128,
            "embedding must have d_model dimensions"
        );
    }

    /// A vector is perfectly similar to itself.
    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![1.0_f64, 2.0, 3.0, 4.0];
        let similarity = UniversalSentenceEncoder::cosine_similarity(&v, &v);
        assert!(
            (similarity - 1.0).abs() < 1e-9,
            "identical vectors → sim = 1.0"
        );
    }

    /// Perpendicular unit vectors have zero similarity.
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let x_axis = vec![1.0_f64, 0.0];
        let y_axis = vec![0.0_f64, 1.0];
        let similarity = UniversalSentenceEncoder::cosine_similarity(&x_axis, &y_axis);
        assert!(similarity.abs() < 1e-9, "orthogonal vectors → sim ≈ 0.0");
    }

    /// Batch encoding agrees with encoding each sequence on its own.
    #[test]
    fn test_batch_consistent_with_single() {
        let encoder = make_use();
        let ids1 = vec![1_usize, 2, 3];
        let ids2 = vec![4_usize, 5];
        let batch = encoder
            .encode_batch(&[ids1.clone(), ids2.clone()])
            .expect("batch failed");
        let single1 = encoder.encode(&ids1).expect("single encode 1 failed");
        let single2 = encoder.encode(&ids2).expect("single encode 2 failed");
        batch[0].iter().zip(single1.iter()).for_each(|(a, b)| {
            assert!((a - b).abs() < 1e-12, "batch[0] must equal single encode");
        });
        batch[1].iter().zip(single2.iter()).for_each(|(a, b)| {
            assert!((a - b).abs() < 1e-12, "batch[1] must equal single encode");
        });
    }

    /// The default cross-lingual configuration exposes the documented values.
    #[test]
    fn test_cross_lingual_config_defaults() {
        let CrossLingualConfig {
            shared_vocab_size,
            n_languages,
            lang_embedding_dim,
        } = CrossLingualConfig::default();
        assert_eq!(shared_vocab_size, 50_000);
        assert_eq!(n_languages, 10);
        assert_eq!(lang_embedding_dim, 16);
    }

    /// Cross-lingual encoding also yields one value per model dimension.
    #[test]
    fn test_cross_lingual_encode_output_size() {
        let encoder = make_use();
        let xl_cfg = CrossLingualConfig::default();
        let embedding = encoder
            .cross_lingual_encode(&[1, 2, 3], 0, &xl_cfg)
            .expect("cross-lingual encode failed");
        assert_eq!(embedding.len(), 128);
    }

    /// With no encoder layers, different tokens must still give different vectors.
    #[test]
    fn test_encode_different_inputs_differ() {
        let encoder = UniversalSentenceEncoder::new(UseConfig {
            n_layers: 0,
            ..UseConfig::default()
        });
        let first = encoder.encode(&[1, 2, 3]).unwrap();
        let second = encoder.encode(&[100, 200, 300]).unwrap();
        let all_eq = first
            .iter()
            .zip(second.iter())
            .all(|(a, b)| (a - b).abs() < 1e-12);
        assert!(
            !all_eq,
            "different token inputs should produce numerically distinct embeddings"
        );
    }

    /// The positional-encoding table has one row per position and one column per dimension.
    #[test]
    fn test_sinusoidal_pe_shape() {
        let table = sinusoidal_pe(10, 128);
        assert_eq!(table.len(), 10);
        assert_eq!(table[0].len(), 128);
    }

    /// Max pooling keeps the model width.
    #[test]
    fn test_max_pooling() {
        let encoder = single_layer_with(UsePooling::Max);
        let embedding = encoder.encode(&[1, 2, 3]).unwrap();
        assert_eq!(embedding.len(), 128);
    }

    /// CLS pooling keeps the model width.
    #[test]
    fn test_cls_pooling() {
        let encoder = single_layer_with(UsePooling::Cls);
        let embedding = encoder.encode(&[0, 1, 2]).unwrap();
        assert_eq!(embedding.len(), 128);
    }

    /// Attentive pooling keeps the model width.
    #[test]
    fn test_attentive_pooling() {
        let encoder = single_layer_with(UsePooling::Attentive);
        let embedding = encoder.encode(&[5, 6, 7]).unwrap();
        assert_eq!(embedding.len(), 128);
    }
}