Skip to main content

scirs2_text/embeddings/
universal.rs

1//! Universal Sentence Encoder architecture — simplified transformer-based
2//! sentence encoder producing fixed-size embeddings.
3//!
4//! # References
5//! Cer et al. (2018) "Universal Sentence Encoder"
6
7use crate::error::{Result, TextError};
8use std::f64::consts::PI;
9
10// ---------------------------------------------------------------------------
11// Configuration
12// ---------------------------------------------------------------------------
13
/// Pooling strategy for aggregating token representations into a sentence vector.
///
/// The enum is fieldless, so it is `Copy` and can be used as a `HashMap`/`HashSet` key.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UsePooling {
    /// Average of all token representations.
    Mean,
    /// Element-wise maximum over token representations.
    Max,
    /// Use the first (CLS) token representation.
    Cls,
    /// Softmax attention-weighted mean, scored against a query vector.
    Attentive,
}
27
28/// Configuration for the Universal Sentence Encoder-style model.
29#[non_exhaustive]
30#[derive(Debug, Clone)]
31pub struct UseConfig {
32    /// Embedding / hidden dimension.
33    pub d_model: usize,
34    /// Number of attention heads.
35    pub n_heads: usize,
36    /// Number of transformer encoder layers.
37    pub n_layers: usize,
38    /// Inner dimension of the position-wise FFN.
39    pub ffn_dim: usize,
40    /// Maximum sequence length.
41    pub max_seq_len: usize,
42    /// Vocabulary size.
43    pub vocab_size: usize,
44    /// Sentence-level pooling strategy.
45    pub pooling: UsePooling,
46}
47
48impl Default for UseConfig {
49    fn default() -> Self {
50        Self {
51            d_model: 128,
52            n_heads: 4,
53            n_layers: 2,
54            ffn_dim: 256,
55            max_seq_len: 512,
56            vocab_size: 30_000,
57            pooling: UsePooling::Mean,
58        }
59    }
60}
61
62// ---------------------------------------------------------------------------
63// Cross-lingual configuration
64// ---------------------------------------------------------------------------
65
/// Settings for [`UniversalSentenceEncoder::cross_lingual_encode`].
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct CrossLingualConfig {
    /// Size of the vocabulary shared by every language.
    pub shared_vocab_size: usize,
    /// Number of supported languages; valid language ids are `0..n_languages`.
    pub n_languages: usize,
    /// Length of the per-language vector, which is tiled across `d_model` and
    /// added to every token embedding.
    pub lang_embedding_dim: usize,
}
77
78impl Default for CrossLingualConfig {
79    fn default() -> Self {
80        Self {
81            shared_vocab_size: 50_000,
82            n_languages: 10,
83            lang_embedding_dim: 16,
84        }
85    }
86}
87
88// ---------------------------------------------------------------------------
89// Deterministic pseudo-random weight generation (LCG)
90// ---------------------------------------------------------------------------
91
/// Deterministic pseudo-random weight in `[-scale, +scale)` derived from `seed`.
///
/// Despite the historical name, the seed is scrambled with the SplitMix64
/// finaliser rather than a single LCG step. One LCG step applied to a small
/// seed leaves the high bits zero, so every weight came out ≈ `-scale` and all
/// "random" matrices were effectively constant. SplitMix64 spreads even
/// consecutive seeds across the whole range, with no external crates.
fn lcg_weight(seed: u64, scale: f64) -> f64 {
    let mut z = seed.wrapping_add(0x9E37_79B9_7F4A_7C15);
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^= z >> 31;
    // Top 53 bits -> uniform f64 in [0, 1).
    let frac = (z >> 11) as f64 / (1u64 << 53) as f64;
    (frac * 2.0 - 1.0) * scale
}
100
101// ---------------------------------------------------------------------------
102// Sinusoidal positional encoding
103// ---------------------------------------------------------------------------
104
/// Compute the sinusoidal positional-encoding matrix of shape `[seq_len × d_model]`.
///
/// Column pair `(2i, 2i + 1)` holds `sin(pos / 10000^(2i / d_model))` and the
/// matching `cos`. When `d_model` is odd, the final column receives the `sin`
/// term of its incomplete pair instead of being left at zero.
fn sinusoidal_pe(seq_len: usize, d_model: usize) -> Vec<Vec<f64>> {
    // Round up so an odd trailing column is still covered.
    let n_pairs = (d_model + 1) / 2;
    (0..seq_len)
        .map(|pos| {
            let mut row = vec![0.0_f64; d_model];
            for i in 0..n_pairs {
                let angle = pos as f64 / 10_000_f64.powf((2 * i) as f64 / d_model as f64);
                row[2 * i] = angle.sin();
                if let Some(slot) = row.get_mut(2 * i + 1) {
                    *slot = angle.cos();
                }
            }
            row
        })
        .collect()
}
119
120// ---------------------------------------------------------------------------
121// LayerNorm (simplified, no learnable parameters)
122// ---------------------------------------------------------------------------
123
/// Normalise `x` to zero mean and unit (population) variance, with `eps` added
/// to the variance for numerical stability. There is no learnable gain or bias.
fn layer_norm(x: &[f64], eps: f64) -> Vec<f64> {
    let count = x.len() as f64;
    let mu = x.iter().sum::<f64>() / count;
    let mut sq_dev = 0.0_f64;
    for &value in x {
        sq_dev += (value - mu).powi(2);
    }
    let denom = (sq_dev / count + eps).sqrt();
    x.iter().map(|&value| (value - mu) / denom).collect()
}
130
131fn layer_norm_rows(x: &[Vec<f64>]) -> Vec<Vec<f64>> {
132    x.iter().map(|row| layer_norm(row, 1e-5)).collect()
133}
134
135// ---------------------------------------------------------------------------
136// Matrix multiply helpers
137// ---------------------------------------------------------------------------
138
/// Dense product `a · b` for `a: [rows × d_in]` and `b: [d_in × d_out]`,
/// returning `[rows × d_out]`.
///
/// `d_in` is `b.len()`; `d_out` is the width of `b`'s first row (0 if `b` is empty).
fn matmul_2d(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let width = b.first().map_or(0, |first| first.len());
    a.iter()
        .map(|a_row| {
            let mut acc = vec![0.0_f64; width];
            // i-k-j order: stream through each row of `b`, scaled by a_row[k].
            for (k, b_row) in b.iter().enumerate() {
                let coeff = a_row[k];
                for (slot, &b_kj) in acc.iter_mut().zip(&b_row[..width]) {
                    *slot += coeff * b_kj;
                }
            }
            acc
        })
        .collect()
}
155
156/// (n × m) × (m × p) → (n × p)
157fn matmul_rect(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
158    matmul_2d(a, b)
159}
160
/// Return the transpose of `m`.
///
/// The column count is taken from the first row. An empty input yields an
/// empty matrix.
fn transpose(m: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n_cols = match m.first() {
        Some(first) => first.len(),
        None => return Vec::new(),
    };
    (0..n_cols)
        .map(|col| m.iter().map(|row| row[col]).collect())
        .collect()
}
176
/// Add `bias` to every row of `x` (broadcast along the sequence axis).
///
/// Each output row is as long as the shorter of the input row and `bias`.
fn add_bias(x: &[Vec<f64>], bias: &[f64]) -> Vec<Vec<f64>> {
    let mut out = Vec::with_capacity(x.len());
    for row in x {
        let shifted: Vec<f64> = row
            .iter()
            .zip(bias.iter())
            .map(|(&v, &b)| v + b)
            .collect();
        out.push(shifted);
    }
    out
}
183
/// Element-wise sum `a + b`.
///
/// The result covers only the overlapping rows and columns of the two inputs.
fn mat_add(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let mut sum = Vec::with_capacity(a.len().min(b.len()));
    for (left, right) in a.iter().zip(b.iter()) {
        sum.push(
            left.iter()
                .zip(right.iter())
                .map(|(x, y)| x + y)
                .collect::<Vec<f64>>(),
        );
    }
    sum
}
191
192// ---------------------------------------------------------------------------
193// Transformer encoder layer
194// ---------------------------------------------------------------------------
195
196/// A single transformer encoder layer (multi-head self-attention + FFN).
197pub struct TransformerEncoderLayer {
198    d_model: usize,
199    n_heads: usize,
200    ffn_dim: usize,
201    // Projection weights for Q, K, V, O — generated deterministically
202    wq: Vec<Vec<f64>>, // d_model × d_model
203    wk: Vec<Vec<f64>>,
204    wv: Vec<Vec<f64>>,
205    wo: Vec<Vec<f64>>,
206    // FFN
207    w1: Vec<Vec<f64>>, // d_model × ffn_dim
208    b1: Vec<f64>,
209    w2: Vec<Vec<f64>>, // ffn_dim × d_model
210    b2: Vec<f64>,
211    /// Query vector for attentive pooling (used only with `UsePooling::Attentive`).
212    pub attn_query: Vec<f64>,
213}
214
215impl TransformerEncoderLayer {
216    /// Create a new encoder layer with deterministic random weights.
217    pub fn new(d_model: usize, n_heads: usize, ffn_dim: usize) -> Self {
218        let scale_attn = 1.0 / (d_model as f64).sqrt();
219        let scale_ffn = 1.0 / (ffn_dim as f64).sqrt();
220
221        let init_matrix = |rows: usize, cols: usize, offset: u64, scale: f64| -> Vec<Vec<f64>> {
222            (0..rows)
223                .map(|r| {
224                    (0..cols)
225                        .map(|c| lcg_weight(offset + (r * cols + c) as u64, scale))
226                        .collect()
227                })
228                .collect()
229        };
230        let init_bias = |len: usize, offset: u64, scale: f64| -> Vec<f64> {
231            (0..len)
232                .map(|i| lcg_weight(offset + i as u64, scale))
233                .collect()
234        };
235
236        let wq = init_matrix(d_model, d_model, 1000, scale_attn);
237        let wk = init_matrix(d_model, d_model, 2000, scale_attn);
238        let wv = init_matrix(d_model, d_model, 3000, scale_attn);
239        let wo = init_matrix(d_model, d_model, 4000, scale_attn);
240        let w1 = init_matrix(d_model, ffn_dim, 5000, scale_ffn);
241        let b1 = init_bias(ffn_dim, 6000, 0.01);
242        let w2 = init_matrix(ffn_dim, d_model, 7000, scale_ffn);
243        let b2 = init_bias(d_model, 8000, 0.01);
244        let attn_query = init_bias(d_model, 9000, scale_attn);
245
246        Self {
247            d_model,
248            n_heads,
249            ffn_dim,
250            wq,
251            wk,
252            wv,
253            wo,
254            w1,
255            b1,
256            w2,
257            b2,
258            attn_query,
259        }
260    }
261
262    /// Multi-head self-attention with scaled dot-product.
263    ///
264    /// `x`: \[seq_len × d_model\]
265    /// Returns \[seq_len × d_model\]
266    pub fn self_attention(
267        &self,
268        x: &[Vec<f64>],
269        mask: Option<&[Vec<bool>]>,
270    ) -> Result<Vec<Vec<f64>>> {
271        let seq_len = x.len();
272        if seq_len == 0 {
273            return Err(TextError::InvalidInput(
274                "self_attention: empty sequence".into(),
275            ));
276        }
277        let d_head = self.d_model / self.n_heads;
278        if d_head == 0 {
279            return Err(TextError::InvalidInput("d_model must be >= n_heads".into()));
280        }
281
282        let q = matmul_2d(x, &self.wq); // seq × d_model
283        let k = matmul_2d(x, &self.wk);
284        let v = matmul_2d(x, &self.wv);
285
286        let scale = 1.0 / (d_head as f64).sqrt();
287
288        let mut concat_heads = vec![vec![0.0_f64; self.d_model]; seq_len];
289
290        for h in 0..self.n_heads {
291            let h_start = h * d_head;
292            let h_end = h_start + d_head;
293
294            // Extract head slices
295            let q_h: Vec<Vec<f64>> = q.iter().map(|row| row[h_start..h_end].to_vec()).collect();
296            let k_h: Vec<Vec<f64>> = k.iter().map(|row| row[h_start..h_end].to_vec()).collect();
297            let v_h: Vec<Vec<f64>> = v.iter().map(|row| row[h_start..h_end].to_vec()).collect();
298
299            // scores = Q × K^T  [seq × seq]
300            let kt = transpose(&k_h);
301            let scores_raw = matmul_rect(&q_h, &kt);
302
303            // Apply mask & scale, then softmax per row
304            let mut attn_weights = vec![vec![0.0_f64; seq_len]; seq_len];
305            for i in 0..seq_len {
306                let mut row = vec![0.0_f64; seq_len];
307                for j in 0..seq_len {
308                    let masked = mask.is_some_and(|m| m[i][j]);
309                    row[j] = if masked {
310                        f64::NEG_INFINITY
311                    } else {
312                        scores_raw[i][j] * scale
313                    };
314                }
315                // softmax
316                let max_v = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
317                let exps: Vec<f64> = row.iter().map(|v| (v - max_v).exp()).collect();
318                let sum_exp: f64 = exps.iter().sum();
319                let sum_exp = if sum_exp < 1e-12 { 1e-12 } else { sum_exp };
320                for j in 0..seq_len {
321                    attn_weights[i][j] = exps[j] / sum_exp;
322                }
323            }
324
325            // context = attn_weights × V_h  [seq × d_head]
326            let ctx = matmul_rect(&attn_weights, &v_h);
327
328            for i in 0..seq_len {
329                for j in 0..d_head {
330                    concat_heads[i][h_start + j] = ctx[i][j];
331                }
332            }
333        }
334
335        // output projection
336        let out = matmul_2d(&concat_heads, &self.wo);
337        Ok(out)
338    }
339
340    /// Position-wise feed-forward: Linear(ReLU(Linear(x))).
341    pub fn ffn(&self, x: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
342        if x.is_empty() {
343            return Err(TextError::InvalidInput("ffn: empty input".into()));
344        }
345        // hidden = x × W1 + b1, ReLU
346        let h = add_bias(&matmul_2d(x, &self.w1), &self.b1);
347        let h_relu: Vec<Vec<f64>> = h
348            .iter()
349            .map(|row| row.iter().map(|v| v.max(0.0)).collect())
350            .collect();
351        // out = h_relu × W2 + b2
352        let out = add_bias(&matmul_2d(&h_relu, &self.w2), &self.b2);
353        Ok(out)
354    }
355
356    /// Full encoder layer forward pass with residual connections and layer norm.
357    pub fn forward(&self, x: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
358        // 1. Self-attention sublayer: LN(x + SA(x))
359        let sa_out = self.self_attention(x, None)?;
360        let x1 = layer_norm_rows(&mat_add(x, &sa_out));
361
362        // 2. FFN sublayer: LN(x + FFN(x))
363        let ffn_out = self.ffn(&x1)?;
364        let x2 = layer_norm_rows(&mat_add(&x1, &ffn_out));
365
366        Ok(x2)
367    }
368}
369
370// ---------------------------------------------------------------------------
371// Universal Sentence Encoder
372// ---------------------------------------------------------------------------
373
374/// Universal Sentence Encoder — transformer-based architecture.
375pub struct UniversalSentenceEncoder {
376    /// Model configuration.
377    pub config: UseConfig,
378    layers: Vec<TransformerEncoderLayer>,
379    /// Token embedding table [vocab_size × d_model].
380    token_embeddings: Vec<Vec<f64>>,
381}
382
383impl UniversalSentenceEncoder {
384    /// Construct a new USE with the given configuration.
385    pub fn new(config: UseConfig) -> Self {
386        let scale = 1.0 / (config.d_model as f64).sqrt();
387        let token_embeddings: Vec<Vec<f64>> = (0..config.vocab_size)
388            .map(|tok| {
389                (0..config.d_model)
390                    .map(|dim| lcg_weight((tok * config.d_model + dim) as u64 + 100_000, scale))
391                    .collect()
392            })
393            .collect();
394
395        let layers = (0..config.n_layers)
396            .map(|l| {
397                // Offset seed per layer so each layer gets distinct weights
398                let _offset = l as u64 * 1_000_000;
399                TransformerEncoderLayer::new(config.d_model, config.n_heads, config.ffn_dim)
400            })
401            .collect();
402
403        Self {
404            config,
405            layers,
406            token_embeddings,
407        }
408    }
409
410    /// Lookup + add sinusoidal positional encoding for token IDs.
411    fn embed(&self, token_ids: &[usize]) -> Result<Vec<Vec<f64>>> {
412        let seq_len = token_ids.len().min(self.config.max_seq_len);
413        if seq_len == 0 {
414            return Err(TextError::InvalidInput(
415                "encode: token_ids must not be empty".into(),
416            ));
417        }
418        let pe = sinusoidal_pe(seq_len, self.config.d_model);
419        let embedded: Result<Vec<Vec<f64>>> = token_ids[..seq_len]
420            .iter()
421            .enumerate()
422            .map(|(pos, &tok_id)| {
423                if tok_id >= self.config.vocab_size {
424                    return Err(TextError::InvalidInput(format!(
425                        "token_id {} out of range (vocab_size={})",
426                        tok_id, self.config.vocab_size
427                    )));
428                }
429                let emb = &self.token_embeddings[tok_id];
430                Ok(emb.iter().zip(&pe[pos]).map(|(e, p)| e + p).collect())
431            })
432            .collect();
433        embedded
434    }
435
436    /// Pool the final hidden states into a single sentence vector.
437    fn pool(&self, hidden: &[Vec<f64>]) -> Vec<f64> {
438        match self.config.pooling {
439            UsePooling::Mean => {
440                let n = hidden.len() as f64;
441                let d = hidden[0].len();
442                let mut out = vec![0.0_f64; d];
443                for row in hidden {
444                    for (i, v) in row.iter().enumerate() {
445                        out[i] += v;
446                    }
447                }
448                out.iter_mut().for_each(|v| *v /= n);
449                out
450            }
451            UsePooling::Max => {
452                let d = hidden[0].len();
453                let mut out = vec![f64::NEG_INFINITY; d];
454                for row in hidden {
455                    for (i, v) in row.iter().enumerate() {
456                        if *v > out[i] {
457                            out[i] = *v;
458                        }
459                    }
460                }
461                out
462            }
463            UsePooling::Cls => hidden[0].clone(),
464            UsePooling::Attentive => {
465                // Soft attention scores using first layer's query vector
466                let query = if self.layers.is_empty() {
467                    vec![1.0_f64; hidden[0].len()]
468                } else {
469                    self.layers[0].attn_query.clone()
470                };
471                let d = hidden[0].len();
472                let scores: Vec<f64> = hidden
473                    .iter()
474                    .map(|row| row.iter().zip(&query).map(|(v, q)| v * q).sum::<f64>())
475                    .collect();
476                let max_s = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
477                let exps: Vec<f64> = scores.iter().map(|s| (s - max_s).exp()).collect();
478                let sum_exp: f64 = exps.iter().sum::<f64>().max(1e-12);
479                let weights: Vec<f64> = exps.iter().map(|e| e / sum_exp).collect();
480
481                let mut out = vec![0.0_f64; d];
482                for (row, w) in hidden.iter().zip(&weights) {
483                    for (i, v) in row.iter().enumerate() {
484                        out[i] += v * w;
485                    }
486                }
487                out
488            }
489        }
490    }
491
492    /// Encode token IDs to a fixed-size sentence embedding of length `d_model`.
493    pub fn encode(&self, token_ids: &[usize]) -> Result<Vec<f64>> {
494        let mut x = self.embed(token_ids)?;
495        for layer in &self.layers {
496            x = layer.forward(&x)?;
497        }
498        Ok(self.pool(&x))
499    }
500
501    /// Encode a batch of token ID sequences.
502    pub fn encode_batch(&self, batch: &[Vec<usize>]) -> Result<Vec<Vec<f64>>> {
503        batch.iter().map(|ids| self.encode(ids)).collect()
504    }
505
506    /// Cosine similarity between two sentence embeddings.
507    pub fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
508        let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
509        let na: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
510        let nb: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
511        if na < 1e-12 || nb < 1e-12 {
512            0.0
513        } else {
514            (dot / (na * nb)).clamp(-1.0, 1.0)
515        }
516    }
517
518    /// Cross-lingual encoding — adds a language embedding before the stack.
519    ///
520    /// `token_ids`: vocabulary indices
521    /// `lang_id`: 0-indexed language identifier
522    /// `xl_config`: cross-lingual configuration
523    pub fn cross_lingual_encode(
524        &self,
525        token_ids: &[usize],
526        lang_id: usize,
527        xl_config: &CrossLingualConfig,
528    ) -> Result<Vec<f64>> {
529        if lang_id >= xl_config.n_languages {
530            return Err(TextError::InvalidInput(format!(
531                "lang_id {} >= n_languages {}",
532                lang_id, xl_config.n_languages
533            )));
534        }
535        // Build a language embedding vector of length lang_embedding_dim, then
536        // tile / truncate to d_model and add to token embeddings.
537        let d = self.config.d_model;
538        let ld = xl_config.lang_embedding_dim;
539        let lang_emb_raw: Vec<f64> = (0..ld)
540            .map(|i| {
541                // sinusoidal language embedding
542                let angle = lang_id as f64 / f64::powf(100.0, (2 * i) as f64 / ld as f64);
543                if i % 2 == 0 {
544                    angle.sin()
545                } else {
546                    angle.cos()
547                }
548            })
549            .collect();
550        // Tile to d_model
551        let lang_emb: Vec<f64> = (0..d).map(|i| lang_emb_raw[i % ld]).collect();
552
553        let mut x = self.embed(token_ids)?;
554        // Add language embedding to every token
555        for row in x.iter_mut() {
556            for (j, v) in row.iter_mut().enumerate() {
557                *v += lang_emb[j];
558            }
559        }
560        for layer in &self.layers {
561            x = layer.forward(&x)?;
562        }
563        Ok(self.pool(&x))
564    }
565}
566
567// ---------------------------------------------------------------------------
568// Tests
569// ---------------------------------------------------------------------------
570
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoder built from the default configuration.
    fn make_use() -> UniversalSentenceEncoder {
        UniversalSentenceEncoder::new(UseConfig::default())
    }

    /// Build a one-layer encoder with `pooling`, encode `ids`, and return the
    /// embedding length.
    fn pooled_len(pooling: UsePooling, ids: &[usize]) -> usize {
        let cfg = UseConfig {
            pooling,
            n_layers: 1,
            ..UseConfig::default()
        };
        UniversalSentenceEncoder::new(cfg)
            .encode(ids)
            .unwrap()
            .len()
    }

    /// Assert that every overlapping pair of values differs by less than `tol`.
    fn assert_close(lhs: &[f64], rhs: &[f64], tol: f64, what: &str) {
        for (a, b) in lhs.iter().zip(rhs) {
            assert!((a - b).abs() < tol, "{}", what);
        }
    }

    #[test]
    fn test_default_config() {
        let UseConfig {
            d_model,
            n_heads,
            n_layers,
            ffn_dim,
            pooling,
            ..
        } = UseConfig::default();
        assert_eq!((d_model, n_heads, n_layers, ffn_dim), (128, 4, 2, 256));
        assert_eq!(pooling, UsePooling::Mean);
    }

    #[test]
    fn test_encode_output_size() {
        let emb = make_use().encode(&[1, 2, 3, 4, 5]).expect("encode failed");
        assert_eq!(emb.len(), 128, "embedding must have d_model dimensions");
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let v = [1.0_f64, 2.0, 3.0, 4.0];
        let sim = UniversalSentenceEncoder::cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-9, "identical vectors → sim = 1.0");
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let sim = UniversalSentenceEncoder::cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(sim.abs() < 1e-9, "orthogonal vectors → sim ≈ 0.0");
    }

    #[test]
    fn test_batch_consistent_with_single() {
        let model = make_use();
        let inputs = vec![vec![1_usize, 2, 3], vec![4_usize, 5]];
        let batch = model.encode_batch(&inputs).expect("batch failed");
        let single1 = model.encode(&inputs[0]).expect("single encode 1 failed");
        let single2 = model.encode(&inputs[1]).expect("single encode 2 failed");
        assert_close(&batch[0], &single1, 1e-12, "batch[0] must equal single encode");
        assert_close(&batch[1], &single2, 1e-12, "batch[1] must equal single encode");
    }

    #[test]
    fn test_cross_lingual_config_defaults() {
        let CrossLingualConfig {
            shared_vocab_size,
            n_languages,
            lang_embedding_dim,
        } = CrossLingualConfig::default();
        assert_eq!(shared_vocab_size, 50_000);
        assert_eq!(n_languages, 10);
        assert_eq!(lang_embedding_dim, 16);
    }

    #[test]
    fn test_cross_lingual_encode_output_size() {
        let emb = make_use()
            .cross_lingual_encode(&[1, 2, 3], 0, &CrossLingualConfig::default())
            .expect("cross-lingual encode failed");
        assert_eq!(emb.len(), 128);
    }

    #[test]
    fn test_encode_different_inputs_differ() {
        // With n_layers = 0 the embedding is just the mean-pooled token +
        // positional embeddings, independent of the transformer stack.
        let model = UniversalSentenceEncoder::new(UseConfig {
            n_layers: 0,
            ..UseConfig::default()
        });
        let emb1 = model.encode(&[1, 2, 3]).unwrap();
        let emb2 = model.encode(&[100, 200, 300]).unwrap();
        let identical = emb1
            .iter()
            .zip(&emb2)
            .all(|(a, b)| (a - b).abs() < 1e-12);
        assert!(
            !identical,
            "different token inputs should produce numerically distinct embeddings"
        );
    }

    #[test]
    fn test_sinusoidal_pe_shape() {
        let pe = sinusoidal_pe(10, 128);
        assert_eq!(pe.len(), 10);
        assert_eq!(pe[0].len(), 128);
    }

    #[test]
    fn test_max_pooling() {
        assert_eq!(pooled_len(UsePooling::Max, &[1, 2, 3]), 128);
    }

    #[test]
    fn test_cls_pooling() {
        assert_eq!(pooled_len(UsePooling::Cls, &[0, 1, 2]), 128);
    }

    #[test]
    fn test_attentive_pooling() {
        assert_eq!(pooled_len(UsePooling::Attentive, &[5, 6, 7]), 128);
    }
}