Skip to main content

oxicuda_vision/text/
clip_text.rs

1//! CLIP text encoder — a faithful CPU reference of the Transformer text
2//! tower from Radford et al. 2021, *"Learning Transferable Visual Models From
3//! Natural Language Supervision"* (CLIP).
4//!
5//! The encoder turns a sequence of integer token ids into a single unit-norm
6//! embedding in the joint image-text space:
7//!
8//! ```text
9//! tokens [n_ctx]
10//!   → token_embedding[token]            → [n_ctx, width]
11//!   → + positional_embedding            → [n_ctx, width]
12//!   → N × pre-LN transformer block
13//!        (CAUSAL multi-head self-attention + MLP with GELU)
14//!   → final LayerNorm                   → [n_ctx, width]
//!   → take row at the EOT position      → [width]
16//!   → linear projection                 → [embed_dim]
17//!   → L2-normalise                      → [embed_dim]
18//! ```
19//!
20//! The two architectural details that distinguish the CLIP text tower from a
21//! plain ViT encoder are (1) the **causal attention mask** — a token may only
22//! attend to itself and to earlier positions, exactly as in an autoregressive
23//! language model — and (2) the **EOS pooling** — the joint embedding is read
//! from the hidden state at the position of the end-of-text token — here, the
//! last occurrence of the configured `eot_token`, falling back to the position
//! of the highest token id (CLIP's `argmax` convention, which works because
//! `<|endoftext|>` is the largest BPE id) — not from a prepended CLS token.
27//!
28//! All weights are flat row-major `Vec<f32>`; no `unsafe`, no external RNG.
29
30use crate::{
31    error::{VisionError, VisionResult},
32    handle::LcgRng,
33    vit::vit_block::{gelu_exact, layer_norm, linear},
34};
35
36// ─── Config ────────────────────────────────────────────────────────────────────
37
/// Hyper-parameters of the CLIP text Transformer.
///
/// Prefer [`ClipTextConfig::new`], which validates the values. The fields are
/// public, so a hand-built struct literal bypasses that validation.
#[derive(Debug, Clone, PartialEq)]
pub struct ClipTextConfig {
    /// Number of distinct token ids (rows of the token-embedding table).
    pub vocab_size: usize,
    /// Maximum sequence length, i.e. the number of learned positional embeddings.
    pub n_ctx: usize,
    /// Width of the residual stream (token-embedding dimension).
    pub width: usize,
    /// How many transformer blocks are stacked.
    pub depth: usize,
    /// Attention head count; must evenly divide `width`.
    pub n_heads: usize,
    /// MLP expansion factor: the hidden layer has `mlp_ratio * width` units.
    pub mlp_ratio: usize,
    /// Dimension of the joint embedding produced by the text projection.
    pub embed_dim: usize,
    /// Token id that marks end-of-text.
    ///
    /// Pooling reads the hidden state at the *last* position holding this id.
    /// If the id is absent, pooling falls back to the position of the highest
    /// id in the sequence. This mirrors CLIP's `argmax` over ids, which works
    /// because `<|endoftext|>` is the largest BPE id.
    pub eot_token: usize,
}
61
62impl ClipTextConfig {
63    /// Validate and construct a config.
64    ///
65    /// # Errors
66    /// - [`VisionError::InvalidEmbedDim`] if `width == 0` or `embed_dim == 0`.
67    /// - [`VisionError::InvalidNumHeads`] if `n_heads == 0`.
68    /// - [`VisionError::HeadDimMismatch`] if `n_heads` does not divide `width`.
69    /// - [`VisionError::Internal`] if `vocab_size`, `n_ctx`, or `depth` is 0.
70    #[allow(clippy::too_many_arguments)]
71    pub fn new(
72        vocab_size: usize,
73        n_ctx: usize,
74        width: usize,
75        depth: usize,
76        n_heads: usize,
77        mlp_ratio: usize,
78        embed_dim: usize,
79        eot_token: usize,
80    ) -> VisionResult<Self> {
81        if width == 0 {
82            return Err(VisionError::InvalidEmbedDim(width));
83        }
84        if embed_dim == 0 {
85            return Err(VisionError::InvalidEmbedDim(embed_dim));
86        }
87        if n_heads == 0 {
88            return Err(VisionError::InvalidNumHeads(n_heads));
89        }
90        if width % n_heads != 0 {
91            return Err(VisionError::HeadDimMismatch {
92                n_heads,
93                embed_dim: width,
94            });
95        }
96        if vocab_size == 0 {
97            return Err(VisionError::Internal("vocab_size must be > 0".into()));
98        }
99        if n_ctx == 0 {
100            return Err(VisionError::Internal("n_ctx must be > 0".into()));
101        }
102        if depth == 0 {
103            return Err(VisionError::Internal("depth must be > 0".into()));
104        }
105        if eot_token >= vocab_size {
106            return Err(VisionError::Internal(
107                "eot_token must be < vocab_size".into(),
108            ));
109        }
110        Ok(Self {
111            vocab_size,
112            n_ctx,
113            width,
114            depth,
115            n_heads,
116            mlp_ratio,
117            embed_dim,
118            eot_token,
119        })
120    }
121
122    /// A tiny config suitable for unit tests.
123    ///
124    /// `vocab_size = 64`, `n_ctx = 16`, `width = 32`, `depth = 2`,
125    /// `n_heads = 4`, `mlp_ratio = 4`, `embed_dim = 24`, `eot_token = 63`.
126    #[must_use]
127    pub fn tiny() -> Self {
128        Self {
129            vocab_size: 64,
130            n_ctx: 16,
131            width: 32,
132            depth: 2,
133            n_heads: 4,
134            mlp_ratio: 4,
135            embed_dim: 24,
136            eot_token: 63,
137        }
138    }
139
140    /// Per-head dimension.
141    #[must_use]
142    #[inline]
143    pub fn head_dim(&self) -> usize {
144        self.width / self.n_heads
145    }
146
147    /// MLP hidden dimension.
148    #[must_use]
149    #[inline]
150    pub fn mlp_dim(&self) -> usize {
151        self.mlp_ratio * self.width
152    }
153}
154
155// ─── Per-block weights ──────────────────────────────────────────────────────────
156
/// Learnable parameters of one pre-LN causal transformer block, grouped by
/// sublayer. Matrices are flat row-major `[out, in]`.
struct TextBlockWeights {
    // ── Attention sublayer ──
    /// LayerNorm applied before attention: scale `[width]`.
    ln1_weight: Vec<f32>,
    /// LayerNorm applied before attention: shift `[width]`.
    ln1_bias: Vec<f32>,
    /// Fused Q/K/V projection `[3*width, width]`.
    qkv_weight: Vec<f32>,
    /// Fused Q/K/V bias `[3*width]`.
    qkv_bias: Vec<f32>,
    /// Attention output projection `[width, width]`.
    out_weight: Vec<f32>,
    /// Attention output bias `[width]`.
    out_bias: Vec<f32>,

    // ── MLP sublayer ──
    /// LayerNorm applied before the MLP: scale `[width]`.
    ln2_weight: Vec<f32>,
    /// LayerNorm applied before the MLP: shift `[width]`.
    ln2_bias: Vec<f32>,
    /// MLP expansion `[mlp_dim, width]`.
    mlp1_weight: Vec<f32>,
    /// MLP expansion bias `[mlp_dim]`.
    mlp1_bias: Vec<f32>,
    /// MLP contraction `[width, mlp_dim]`.
    mlp2_weight: Vec<f32>,
    /// MLP contraction bias `[width]`.
    mlp2_bias: Vec<f32>,
}
172
173impl TextBlockWeights {
174    fn default_init(cfg: &ClipTextConfig, rng: &mut LcgRng) -> Self {
175        let w = cfg.width;
176        let mlp = cfg.mlp_dim();
177        let scale = 1.0 / (w as f32).sqrt();
178        let fill = |rng: &mut LcgRng, n: usize, sc: f32| -> Vec<f32> {
179            let mut v = vec![0.0f32; n];
180            rng.fill_normal(&mut v);
181            for x in &mut v {
182                *x *= sc;
183            }
184            v
185        };
186        Self {
187            qkv_weight: fill(rng, 3 * w * w, scale),
188            qkv_bias: vec![0.0f32; 3 * w],
189            out_weight: fill(rng, w * w, scale),
190            out_bias: vec![0.0f32; w],
191            mlp1_weight: fill(rng, mlp * w, scale),
192            mlp1_bias: vec![0.0f32; mlp],
193            mlp2_weight: fill(rng, w * mlp, scale),
194            mlp2_bias: vec![0.0f32; w],
195            ln1_weight: vec![1.0f32; w],
196            ln1_bias: vec![0.0f32; w],
197            ln2_weight: vec![1.0f32; w],
198            ln2_bias: vec![0.0f32; w],
199        }
200    }
201}
202
203// ─── Causal multi-head self-attention ───────────────────────────────────────────
204
205/// Causal (autoregressive) multi-head self-attention.
206///
207/// Identical to standard scaled-dot-product attention except that, before the
208/// row-wise softmax, every score `S[i, j]` with `j > i` is set to `-∞`, so the
209/// softmax weight of any *future* key is exactly zero. Position `i` therefore
210/// mixes only positions `0..=i`.
211///
212/// `tokens` is `[n · e]` row-major; returns `[n · e]`.
213#[allow(clippy::too_many_arguments)]
214fn causal_mhsa(
215    tokens: &[f32],
216    n: usize,
217    e: usize,
218    n_heads: usize,
219    head_dim: usize,
220    qkv_weight: &[f32],
221    qkv_bias: &[f32],
222    out_weight: &[f32],
223    out_bias: &[f32],
224) -> VisionResult<Vec<f32>> {
225    // Fused QKV projection → [n, 3e].
226    let qkv = linear(tokens, qkv_weight, qkv_bias, e, 3 * e);
227
228    let mut q = vec![0.0f32; n * e];
229    let mut k = vec![0.0f32; n * e];
230    let mut v = vec![0.0f32; n * e];
231    for t in 0..n {
232        let src = &qkv[t * 3 * e..(t + 1) * 3 * e];
233        q[t * e..(t + 1) * e].copy_from_slice(&src[..e]);
234        k[t * e..(t + 1) * e].copy_from_slice(&src[e..2 * e]);
235        v[t * e..(t + 1) * e].copy_from_slice(&src[2 * e..]);
236    }
237
238    let scale = 1.0 / (head_dim as f32).sqrt();
239    let mut concat = vec![0.0f32; n * e];
240
241    for h in 0..n_heads {
242        let off = h * head_dim;
243        for i in 0..n {
244            // Causal window: keys 0..=i only.
245            // Stable softmax over the masked row.
246            let mut max_score = f32::NEG_INFINITY;
247            let mut row_scores = vec![0.0f32; i + 1];
248            for (j, slot) in row_scores.iter_mut().enumerate() {
249                let mut dot = 0.0f32;
250                for d in 0..head_dim {
251                    dot += q[i * e + off + d] * k[j * e + off + d];
252                }
253                let s = dot * scale;
254                *slot = s;
255                if s > max_score {
256                    max_score = s;
257                }
258            }
259            let mut sum = 0.0f32;
260            for s in &mut row_scores {
261                *s = (*s - max_score).exp();
262                sum += *s;
263            }
264            let inv = if sum > 0.0 { 1.0 / sum } else { 1.0 };
265            for d in 0..head_dim {
266                let mut acc = 0.0f32;
267                for (j, &sw) in row_scores.iter().enumerate() {
268                    acc += sw * inv * v[j * e + off + d];
269                }
270                concat[i * e + off + d] = acc;
271            }
272        }
273    }
274
275    let out = linear(&concat, out_weight, out_bias, e, e);
276    if out.iter().any(|x| !x.is_finite()) {
277        return Err(VisionError::NonFinite("clip text attention output"));
278    }
279    Ok(out)
280}
281
282// ─── CLIP text encoder ──────────────────────────────────────────────────────────
283
284/// The CLIP Transformer text tower.
285pub struct ClipTextEncoder {
286    /// Configuration.
287    pub config: ClipTextConfig,
288    /// Token embedding table: `[vocab_size · width]` row-major.
289    pub token_embedding: Vec<f32>,
290    /// Learned positional embedding: `[n_ctx · width]` row-major.
291    pub positional_embedding: Vec<f32>,
292    /// Per-block weights.
293    blocks: Vec<TextBlockWeights>,
294    /// Final LayerNorm scale `[width]`.
295    final_ln_weight: Vec<f32>,
296    /// Final LayerNorm bias `[width]`.
297    final_ln_bias: Vec<f32>,
298    /// Text projection: `[embed_dim · width]` row-major (maps width → embed_dim).
299    text_projection: Vec<f32>,
300}
301
302impl ClipTextEncoder {
303    /// Construct a CLIP text encoder with Gaussian-initialised weights.
304    ///
305    /// # Errors
306    /// Propagates configuration / sub-component validation errors.
307    pub fn new(cfg: ClipTextConfig, rng: &mut LcgRng) -> VisionResult<Self> {
308        let w = cfg.width;
309
310        // Token & positional embeddings use the canonical CLIP init std of 0.02
311        // (small so the residual stream starts near the identity).
312        let mut token_embedding = vec![0.0f32; cfg.vocab_size * w];
313        rng.fill_normal(&mut token_embedding);
314        for v in &mut token_embedding {
315            *v *= 0.02;
316        }
317        let mut positional_embedding = vec![0.0f32; cfg.n_ctx * w];
318        rng.fill_normal(&mut positional_embedding);
319        for v in &mut positional_embedding {
320            *v *= 0.01;
321        }
322
323        let mut blocks = Vec::with_capacity(cfg.depth);
324        for _ in 0..cfg.depth {
325            blocks.push(TextBlockWeights::default_init(&cfg, rng));
326        }
327
328        let final_ln_weight = vec![1.0f32; w];
329        let final_ln_bias = vec![0.0f32; w];
330
331        // Text projection (width → embed_dim), scaled by 1/√width.
332        let scale = 1.0 / (w as f32).sqrt();
333        let mut text_projection = vec![0.0f32; cfg.embed_dim * w];
334        rng.fill_normal(&mut text_projection);
335        for v in &mut text_projection {
336            *v *= scale;
337        }
338
339        Ok(Self {
340            config: cfg,
341            token_embedding,
342            positional_embedding,
343            blocks,
344            final_ln_weight,
345            final_ln_bias,
346            text_projection,
347        })
348    }
349
350    /// Locate the pooling position for a token sequence.
351    ///
352    /// Following CLIP, the joint embedding is read at the position of the
353    /// end-of-text token. We select the position of the **last** occurrence of
354    /// `eot_token`; if it never appears, we fall back to CLIP's `argmax`
355    /// convention (the position of the highest token id), and as a final
356    /// fallback the last index.
357    #[must_use]
358    pub fn eot_position(&self, tokens: &[usize]) -> usize {
359        if tokens.is_empty() {
360            return 0;
361        }
362        // Last occurrence of the explicit EOT id.
363        for (idx, &tok) in tokens.iter().enumerate().rev() {
364            if tok == self.config.eot_token {
365                return idx;
366            }
367        }
368        // Fallback: position of the maximum id (CLIP argmax).
369        let mut best_idx = tokens.len() - 1;
370        let mut best_val = tokens[best_idx];
371        for (idx, &tok) in tokens.iter().enumerate() {
372            if tok > best_val {
373                best_val = tok;
374                best_idx = idx;
375            }
376        }
377        best_idx
378    }
379
380    /// Run the full encoder and return the contextual hidden states *before*
381    /// pooling and projection: `[n · width]`, after the final LayerNorm.
382    ///
383    /// Exposed so causality tests can probe individual token hidden states.
384    ///
385    /// # Errors
386    /// - [`VisionError::EmptyInput`] if `tokens` is empty.
387    /// - [`VisionError::Internal`] if the sequence is longer than `n_ctx` or a
388    ///   token id is out of the vocabulary range.
389    pub fn hidden_states(&self, tokens: &[usize]) -> VisionResult<Vec<f32>> {
390        let cfg = &self.config;
391        let w = cfg.width;
392        let n = tokens.len();
393        if n == 0 {
394            return Err(VisionError::EmptyInput("token sequence"));
395        }
396        if n > cfg.n_ctx {
397            return Err(VisionError::Internal(
398                "sequence length exceeds n_ctx".into(),
399            ));
400        }
401        for &tok in tokens {
402            if tok >= cfg.vocab_size {
403                return Err(VisionError::Internal(
404                    "token id out of vocabulary range".into(),
405                ));
406            }
407        }
408
409        // Embed: token_embedding[token] + positional_embedding[position].
410        let mut h = vec![0.0f32; n * w];
411        for (pos, &tok) in tokens.iter().enumerate() {
412            let te = &self.token_embedding[tok * w..(tok + 1) * w];
413            let pe = &self.positional_embedding[pos * w..(pos + 1) * w];
414            let dst = &mut h[pos * w..(pos + 1) * w];
415            for d in 0..w {
416                dst[d] = te[d] + pe[d];
417            }
418        }
419
420        // Pre-LN causal transformer blocks.
421        for blk in &self.blocks {
422            // Block 1: x = x + Attn(LN1(x))   (causal)
423            let normed = layer_norm(&h, &blk.ln1_weight, &blk.ln1_bias, n, w, 1e-5);
424            let attn = causal_mhsa(
425                &normed,
426                n,
427                w,
428                cfg.n_heads,
429                cfg.head_dim(),
430                &blk.qkv_weight,
431                &blk.qkv_bias,
432                &blk.out_weight,
433                &blk.out_bias,
434            )?;
435            for (hv, av) in h.iter_mut().zip(attn.iter()) {
436                *hv += av;
437            }
438
439            // Block 2: x = x + MLP(LN2(x))
440            let normed2 = layer_norm(&h, &blk.ln2_weight, &blk.ln2_bias, n, w, 1e-5);
441            let mlp_dim = cfg.mlp_dim();
442            let mid = linear(&normed2, &blk.mlp1_weight, &blk.mlp1_bias, w, mlp_dim);
443            let mid: Vec<f32> = mid.into_iter().map(gelu_exact).collect();
444            let mlp_out = linear(&mid, &blk.mlp2_weight, &blk.mlp2_bias, mlp_dim, w);
445            for (hv, mv) in h.iter_mut().zip(mlp_out.iter()) {
446                *hv += mv;
447            }
448        }
449
450        // Final LayerNorm.
451        let out = layer_norm(&h, &self.final_ln_weight, &self.final_ln_bias, n, w, 1e-5);
452        Ok(out)
453    }
454
455    /// Encode a token sequence to a unit-norm joint-space embedding.
456    ///
457    /// # Returns
458    /// `[embed_dim]` L2-normalised text embedding.
459    ///
460    /// # Errors
461    /// Propagates errors from [`Self::hidden_states`].
462    pub fn encode(&self, tokens: &[usize]) -> VisionResult<Vec<f32>> {
463        let cfg = &self.config;
464        let w = cfg.width;
465        let hs = self.hidden_states(tokens)?;
466
467        // Pool the hidden state at the EOS / argmax position.
468        let pool = self.eot_position(tokens);
469        let pooled = &hs[pool * w..(pool + 1) * w];
470
471        // Linear projection width → embed_dim (no bias, like CLIP's text_projection).
472        let mut z = vec![0.0f32; cfg.embed_dim];
473        for (p, zp) in z.iter_mut().enumerate() {
474            let row = &self.text_projection[p * w..(p + 1) * w];
475            *zp = row
476                .iter()
477                .zip(pooled.iter())
478                .map(|(&a, &b)| a * b)
479                .sum::<f32>();
480        }
481
482        // L2-normalise.
483        let norm: f32 = z.iter().map(|&v| v * v).sum::<f32>().sqrt();
484        let inv = 1.0 / norm.max(1e-12);
485        for v in &mut z {
486            *v *= inv;
487        }
488
489        if z.iter().any(|v| !v.is_finite()) {
490            return Err(VisionError::NonFinite("clip text embedding"));
491        }
492        Ok(z)
493    }
494
495    /// Encode a batch of (independent) token sequences.
496    ///
497    /// # Returns
498    /// One `[embed_dim]` embedding per input sequence.
499    ///
500    /// # Errors
501    /// Propagates the first encoding error.
502    pub fn encode_batch(&self, sequences: &[Vec<usize>]) -> VisionResult<Vec<Vec<f32>>> {
503        let mut out = Vec::with_capacity(sequences.len());
504        for seq in sequences {
505            out.push(self.encode(seq)?);
506        }
507        Ok(out)
508    }
509}
510
511// ─── Tests ──────────────────────────────────────────────────────────────────────
512
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoder over the tiny test config, seeded deterministically.
    fn make_encoder(seed: u64) -> ClipTextEncoder {
        let mut rng = LcgRng::new(seed);
        ClipTextEncoder::new(ClipTextConfig::tiny(), &mut rng).expect("encoder ok")
    }

    /// Cosine similarity with a tiny epsilon guarding zero-norm inputs.
    fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
        let na: f32 = a.iter().map(|&x| x * x).sum::<f32>().sqrt();
        let nb: f32 = b.iter().map(|&x| x * x).sum::<f32>().sqrt();
        dot / (na * nb + 1e-12)
    }

    // ── Config ──────────────────────────────────────────────────────────────────

    #[test]
    fn config_tiny_valid() {
        let cfg = ClipTextConfig::tiny();
        assert_eq!(cfg.head_dim(), 8);
        assert_eq!(cfg.mlp_dim(), 128);
    }

    #[test]
    fn config_new_accepts_tiny_values() {
        let cfg = ClipTextConfig::new(64, 16, 32, 2, 4, 4, 24, 63).expect("valid config");
        assert_eq!(cfg, ClipTextConfig::tiny());
    }

    #[test]
    fn config_head_mismatch_errors() {
        let r = ClipTextConfig::new(64, 16, 30, 2, 4, 4, 24, 63);
        assert!(matches!(r, Err(VisionError::HeadDimMismatch { .. })));
    }

    #[test]
    fn config_zero_width_errors() {
        let r = ClipTextConfig::new(64, 16, 0, 2, 4, 4, 24, 63);
        assert!(matches!(r, Err(VisionError::InvalidEmbedDim(0))));
    }

    #[test]
    fn config_zero_embed_dim_errors() {
        let r = ClipTextConfig::new(64, 16, 32, 2, 4, 4, 0, 63);
        assert!(matches!(r, Err(VisionError::InvalidEmbedDim(0))));
    }

    #[test]
    fn config_zero_heads_errors() {
        let r = ClipTextConfig::new(64, 16, 32, 2, 0, 4, 24, 63);
        assert!(matches!(r, Err(VisionError::InvalidNumHeads(0))));
    }

    #[test]
    fn config_zero_counts_error() {
        // vocab_size == 0 (eot_token irrelevant: vocab is checked first).
        let r = ClipTextConfig::new(0, 16, 32, 2, 4, 4, 24, 63);
        assert!(matches!(r, Err(VisionError::Internal(_))));
        // n_ctx == 0
        let r = ClipTextConfig::new(64, 0, 32, 2, 4, 4, 24, 63);
        assert!(matches!(r, Err(VisionError::Internal(_))));
        // depth == 0
        let r = ClipTextConfig::new(64, 16, 32, 0, 4, 4, 24, 63);
        assert!(matches!(r, Err(VisionError::Internal(_))));
    }

    #[test]
    fn config_eot_out_of_range_errors() {
        let r = ClipTextConfig::new(64, 16, 32, 2, 4, 4, 24, 64);
        assert!(matches!(r, Err(VisionError::Internal(_))));
    }

    // ── (a) Output embedding is unit-norm ─────────────────────────────────────────

    #[test]
    fn encode_output_is_unit_norm() {
        let enc = make_encoder(1);
        let tokens = vec![3usize, 7, 12, 5, 63];
        let z = enc.encode(&tokens).expect("encode ok");
        let norm: f32 = z.iter().map(|&v| v * v).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "text embedding must be L2-unit-norm; got {norm}"
        );
    }

    // ── (b) CAUSALITY: a future token cannot change an earlier hidden state ────────

    #[test]
    fn causality_future_token_does_not_affect_earlier_hidden_state() {
        let enc = make_encoder(2);
        // Two sequences identical in positions 0..=2, differing at position 3.
        let seq_a = vec![5usize, 9, 14, 2, 63];
        let seq_b = vec![5usize, 9, 14, 31, 63]; // position 3 changed
        let hs_a = enc.hidden_states(&seq_a).expect("ok");
        let hs_b = enc.hidden_states(&seq_b).expect("ok");
        let w = enc.config.width;
        // Hidden states at positions 0,1,2 must be identical (they only attend
        // to positions ≤ their index, none of which changed).
        for pos in 0..3 {
            for d in 0..w {
                let a = hs_a[pos * w + d];
                let b = hs_b[pos * w + d];
                assert!(
                    (a - b).abs() < 1e-6,
                    "causality violated at pos {pos}, dim {d}: {a} vs {b}"
                );
            }
        }
        // Sanity: position 3 (which saw the changed token) *should* differ.
        let diff_pos3: f32 = (0..w)
            .map(|d| (hs_a[3 * w + d] - hs_b[3 * w + d]).abs())
            .sum();
        assert!(
            diff_pos3 > 1e-6,
            "position 3 should change when its own token changes (diff={diff_pos3})"
        );
    }

    // ── (c) Different sequences → different embeddings ─────────────────────────────

    #[test]
    fn different_sequences_give_different_embeddings() {
        let enc = make_encoder(3);
        let za = enc.encode(&[1usize, 2, 3, 63]).expect("ok");
        let zb = enc.encode(&[10usize, 20, 30, 63]).expect("ok");
        let diff: f32 = za.iter().zip(zb.iter()).map(|(a, b)| (a - b).abs()).sum();
        assert!(
            diff > 1e-4,
            "distinct token sequences must produce distinct embeddings (diff={diff})"
        );
    }

    // ── (d) Determinism ───────────────────────────────────────────────────────────

    #[test]
    fn deterministic_same_input_same_output() {
        let enc = make_encoder(4);
        let tokens = vec![4usize, 8, 15, 16, 23, 42, 63];
        let z1 = enc.encode(&tokens).expect("ok");
        let z2 = enc.encode(&tokens).expect("ok");
        assert_eq!(z1, z2, "encoder must be deterministic");
    }

    // ── (e) Cosine similarity of identical inputs == 1 ────────────────────────────

    #[test]
    fn cosine_of_identical_inputs_is_one() {
        let enc = make_encoder(5);
        let tokens = vec![2usize, 4, 6, 8, 63];
        let z = enc.encode(&tokens).expect("ok");
        let sim = cosine(&z, &z);
        assert!(
            (sim - 1.0).abs() < 1e-5,
            "cosine(z, z) must be 1.0; got {sim}"
        );
    }

    // ── (f) Projection output dim == configured joint dim ─────────────────────────

    #[test]
    fn projection_output_dim_matches_config() {
        let enc = make_encoder(6);
        let z = enc.encode(&[1usize, 2, 63]).expect("ok");
        assert_eq!(
            z.len(),
            enc.config.embed_dim,
            "projected embedding dim must equal config.embed_dim"
        );
    }

    // ── (g) EOS / pooling position selection ──────────────────────────────────────

    #[test]
    fn eot_position_selects_eot_before_padding() {
        let enc = make_encoder(7);
        // EOT id is 63. It appears at index 4 (last real position before padding).
        let tokens = vec![5usize, 9, 14, 2, 63, 0, 0];
        assert_eq!(
            enc.eot_position(&tokens),
            4,
            "must pool at the last EOT (id=63) position"
        );
    }

    #[test]
    fn eot_position_selects_last_eot_occurrence() {
        let enc = make_encoder(7);
        // Two EOTs (indices 0 and 2): the later one wins.
        let tokens = vec![63usize, 5, 63, 0];
        assert_eq!(
            enc.eot_position(&tokens),
            2,
            "must pool at the LAST of several EOT occurrences"
        );
    }

    #[test]
    fn eot_position_argmax_fallback_when_no_explicit_eot() {
        let enc = make_encoder(8);
        // No id == 63; highest id is 40 at index 2 → argmax fallback.
        let tokens = vec![5usize, 9, 40, 2, 7];
        assert_eq!(
            enc.eot_position(&tokens),
            2,
            "argmax fallback should pick the highest-id position"
        );
    }

    #[test]
    fn eot_position_of_empty_sequence_is_zero() {
        let enc = make_encoder(8);
        assert_eq!(enc.eot_position(&[]), 0);
    }

    #[test]
    fn pooling_uses_eot_hidden_state() {
        // The embedding must be derived from the hidden state at the EOT
        // position: changing a token *after* the EOT must not change the
        // embedding (because pooling happens at EOT and causal attention means
        // EOT never sees later tokens).
        let enc = make_encoder(9);
        let base = vec![3usize, 7, 12, 63, 1, 2]; // EOT at index 3
        let changed = vec![3usize, 7, 12, 63, 30, 40]; // tokens after EOT differ
        let z_base = enc.encode(&base).expect("ok");
        let z_changed = enc.encode(&changed).expect("ok");
        let diff: f32 = z_base
            .iter()
            .zip(z_changed.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff < 1e-6,
            "tokens after the EOT must not affect the pooled embedding (diff={diff})"
        );
    }

    // ── Error paths ───────────────────────────────────────────────────────────────

    #[test]
    fn empty_sequence_errors() {
        let enc = make_encoder(10);
        let r = enc.encode(&[]);
        assert!(matches!(r, Err(VisionError::EmptyInput(_))));
    }

    #[test]
    fn sequence_too_long_errors() {
        let enc = make_encoder(11);
        let too_long: Vec<usize> = (0..enc.config.n_ctx + 1).map(|i| i % 60).collect();
        let r = enc.encode(&too_long);
        assert!(matches!(r, Err(VisionError::Internal(_))));
    }

    #[test]
    fn out_of_vocab_token_errors() {
        let enc = make_encoder(12);
        let r = enc.encode(&[1usize, 9999, 63]);
        assert!(matches!(r, Err(VisionError::Internal(_))));
    }

    // ── Batch ─────────────────────────────────────────────────────────────────────

    #[test]
    fn encode_batch_matches_individual() {
        let enc = make_encoder(13);
        let seqs = vec![vec![1usize, 2, 63], vec![5usize, 9, 14, 63]];
        let batch = enc.encode_batch(&seqs).expect("ok");
        assert_eq!(batch.len(), 2);
        for (i, seq) in seqs.iter().enumerate() {
            let single = enc.encode(seq).expect("ok");
            for (a, b) in batch[i].iter().zip(single.iter()) {
                assert!((a - b).abs() < 1e-6, "batch vs single mismatch");
            }
        }
    }

    #[test]
    fn encode_batch_propagates_error() {
        let enc = make_encoder(13);
        let seqs = vec![vec![1usize, 2, 63], vec![1usize, 9999, 63]];
        let r = enc.encode_batch(&seqs);
        assert!(matches!(r, Err(VisionError::Internal(_))));
    }

    // ── Extra causality guard: an EARLY token change DOES propagate forward ────────

    #[test]
    fn early_token_change_propagates_to_later_positions() {
        let enc = make_encoder(14);
        let seq_a = vec![5usize, 9, 14, 2, 63];
        let seq_b = vec![31usize, 9, 14, 2, 63]; // position 0 changed
        let hs_a = enc.hidden_states(&seq_a).expect("ok");
        let hs_b = enc.hidden_states(&seq_b).expect("ok");
        let w = enc.config.width;
        // A later position (e.g. 4) must change because it attends back to pos 0.
        let diff_pos4: f32 = (0..w)
            .map(|d| (hs_a[4 * w + d] - hs_b[4 * w + d]).abs())
            .sum();
        assert!(
            diff_pos4 > 1e-6,
            "changing position 0 must affect later positions (diff={diff_pos4})"
        );
    }
}