1use crate::{
31 error::{VisionError, VisionResult},
32 handle::LcgRng,
33 vit::vit_block::{gelu_exact, layer_norm, linear},
34};
35
/// Hyper-parameters of the CLIP text transformer.
///
/// Prefer [`ClipTextConfig::new`], which validates the fields; the fields are
/// public, so a struct literal bypasses that validation.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ClipTextConfig {
    /// Number of token ids; every input id must be `< vocab_size`.
    pub vocab_size: usize,
    /// Maximum accepted sequence length (number of positional embeddings).
    pub n_ctx: usize,
    /// Hidden width of the transformer residual stream.
    pub width: usize,
    /// Number of transformer blocks.
    pub depth: usize,
    /// Number of attention heads; must divide `width`.
    pub n_heads: usize,
    /// MLP hidden width is `mlp_ratio * width`.
    pub mlp_ratio: usize,
    /// Dimension of the final projected text embedding.
    pub embed_dim: usize,
    /// Token id whose hidden state is pooled into the embedding.
    pub eot_token: usize,
}
61
62impl ClipTextConfig {
63 #[allow(clippy::too_many_arguments)]
71 pub fn new(
72 vocab_size: usize,
73 n_ctx: usize,
74 width: usize,
75 depth: usize,
76 n_heads: usize,
77 mlp_ratio: usize,
78 embed_dim: usize,
79 eot_token: usize,
80 ) -> VisionResult<Self> {
81 if width == 0 {
82 return Err(VisionError::InvalidEmbedDim(width));
83 }
84 if embed_dim == 0 {
85 return Err(VisionError::InvalidEmbedDim(embed_dim));
86 }
87 if n_heads == 0 {
88 return Err(VisionError::InvalidNumHeads(n_heads));
89 }
90 if width % n_heads != 0 {
91 return Err(VisionError::HeadDimMismatch {
92 n_heads,
93 embed_dim: width,
94 });
95 }
96 if vocab_size == 0 {
97 return Err(VisionError::Internal("vocab_size must be > 0".into()));
98 }
99 if n_ctx == 0 {
100 return Err(VisionError::Internal("n_ctx must be > 0".into()));
101 }
102 if depth == 0 {
103 return Err(VisionError::Internal("depth must be > 0".into()));
104 }
105 if eot_token >= vocab_size {
106 return Err(VisionError::Internal(
107 "eot_token must be < vocab_size".into(),
108 ));
109 }
110 Ok(Self {
111 vocab_size,
112 n_ctx,
113 width,
114 depth,
115 n_heads,
116 mlp_ratio,
117 embed_dim,
118 eot_token,
119 })
120 }
121
122 #[must_use]
127 pub fn tiny() -> Self {
128 Self {
129 vocab_size: 64,
130 n_ctx: 16,
131 width: 32,
132 depth: 2,
133 n_heads: 4,
134 mlp_ratio: 4,
135 embed_dim: 24,
136 eot_token: 63,
137 }
138 }
139
140 #[must_use]
142 #[inline]
143 pub fn head_dim(&self) -> usize {
144 self.width / self.n_heads
145 }
146
147 #[must_use]
149 #[inline]
150 pub fn mlp_dim(&self) -> usize {
151 self.mlp_ratio * self.width
152 }
153}
154
/// Parameters of one transformer block of the text encoder.
///
/// With `W = width` and `M = mlp_dim()`, the buffers produced by
/// `default_init` have the sizes noted on each field.
struct TextBlockWeights {
    // Layer norm applied before attention.
    /// Gain, `W` values.
    ln1_weight: Vec<f32>,
    /// Bias, `W` values.
    ln1_bias: Vec<f32>,

    // Causal self-attention.
    /// Fused query/key/value projection, `3W * W` values.
    qkv_weight: Vec<f32>,
    /// Fused query/key/value bias, `3W` values.
    qkv_bias: Vec<f32>,
    /// Attention output projection, `W * W` values.
    out_weight: Vec<f32>,
    /// Attention output bias, `W` values.
    out_bias: Vec<f32>,

    // Layer norm applied before the MLP.
    /// Gain, `W` values.
    ln2_weight: Vec<f32>,
    /// Bias, `W` values.
    ln2_bias: Vec<f32>,

    // Two-layer MLP.
    /// First projection, `M * W` values.
    mlp1_weight: Vec<f32>,
    /// First bias, `M` values.
    mlp1_bias: Vec<f32>,
    /// Second projection, `W * M` values.
    mlp2_weight: Vec<f32>,
    /// Second bias, `W` values.
    mlp2_bias: Vec<f32>,
}
172
173impl TextBlockWeights {
174 fn default_init(cfg: &ClipTextConfig, rng: &mut LcgRng) -> Self {
175 let w = cfg.width;
176 let mlp = cfg.mlp_dim();
177 let scale = 1.0 / (w as f32).sqrt();
178 let fill = |rng: &mut LcgRng, n: usize, sc: f32| -> Vec<f32> {
179 let mut v = vec![0.0f32; n];
180 rng.fill_normal(&mut v);
181 for x in &mut v {
182 *x *= sc;
183 }
184 v
185 };
186 Self {
187 qkv_weight: fill(rng, 3 * w * w, scale),
188 qkv_bias: vec![0.0f32; 3 * w],
189 out_weight: fill(rng, w * w, scale),
190 out_bias: vec![0.0f32; w],
191 mlp1_weight: fill(rng, mlp * w, scale),
192 mlp1_bias: vec![0.0f32; mlp],
193 mlp2_weight: fill(rng, w * mlp, scale),
194 mlp2_bias: vec![0.0f32; w],
195 ln1_weight: vec![1.0f32; w],
196 ln1_bias: vec![0.0f32; w],
197 ln2_weight: vec![1.0f32; w],
198 ln2_bias: vec![0.0f32; w],
199 }
200 }
201}
202
203#[allow(clippy::too_many_arguments)]
214fn causal_mhsa(
215 tokens: &[f32],
216 n: usize,
217 e: usize,
218 n_heads: usize,
219 head_dim: usize,
220 qkv_weight: &[f32],
221 qkv_bias: &[f32],
222 out_weight: &[f32],
223 out_bias: &[f32],
224) -> VisionResult<Vec<f32>> {
225 let qkv = linear(tokens, qkv_weight, qkv_bias, e, 3 * e);
227
228 let mut q = vec![0.0f32; n * e];
229 let mut k = vec![0.0f32; n * e];
230 let mut v = vec![0.0f32; n * e];
231 for t in 0..n {
232 let src = &qkv[t * 3 * e..(t + 1) * 3 * e];
233 q[t * e..(t + 1) * e].copy_from_slice(&src[..e]);
234 k[t * e..(t + 1) * e].copy_from_slice(&src[e..2 * e]);
235 v[t * e..(t + 1) * e].copy_from_slice(&src[2 * e..]);
236 }
237
238 let scale = 1.0 / (head_dim as f32).sqrt();
239 let mut concat = vec![0.0f32; n * e];
240
241 for h in 0..n_heads {
242 let off = h * head_dim;
243 for i in 0..n {
244 let mut max_score = f32::NEG_INFINITY;
247 let mut row_scores = vec![0.0f32; i + 1];
248 for (j, slot) in row_scores.iter_mut().enumerate() {
249 let mut dot = 0.0f32;
250 for d in 0..head_dim {
251 dot += q[i * e + off + d] * k[j * e + off + d];
252 }
253 let s = dot * scale;
254 *slot = s;
255 if s > max_score {
256 max_score = s;
257 }
258 }
259 let mut sum = 0.0f32;
260 for s in &mut row_scores {
261 *s = (*s - max_score).exp();
262 sum += *s;
263 }
264 let inv = if sum > 0.0 { 1.0 / sum } else { 1.0 };
265 for d in 0..head_dim {
266 let mut acc = 0.0f32;
267 for (j, &sw) in row_scores.iter().enumerate() {
268 acc += sw * inv * v[j * e + off + d];
269 }
270 concat[i * e + off + d] = acc;
271 }
272 }
273 }
274
275 let out = linear(&concat, out_weight, out_bias, e, e);
276 if out.iter().any(|x| !x.is_finite()) {
277 return Err(VisionError::NonFinite("clip text attention output"));
278 }
279 Ok(out)
280}
281
282pub struct ClipTextEncoder {
286 pub config: ClipTextConfig,
288 pub token_embedding: Vec<f32>,
290 pub positional_embedding: Vec<f32>,
292 blocks: Vec<TextBlockWeights>,
294 final_ln_weight: Vec<f32>,
296 final_ln_bias: Vec<f32>,
298 text_projection: Vec<f32>,
300}
301
302impl ClipTextEncoder {
303 pub fn new(cfg: ClipTextConfig, rng: &mut LcgRng) -> VisionResult<Self> {
308 let w = cfg.width;
309
310 let mut token_embedding = vec![0.0f32; cfg.vocab_size * w];
313 rng.fill_normal(&mut token_embedding);
314 for v in &mut token_embedding {
315 *v *= 0.02;
316 }
317 let mut positional_embedding = vec![0.0f32; cfg.n_ctx * w];
318 rng.fill_normal(&mut positional_embedding);
319 for v in &mut positional_embedding {
320 *v *= 0.01;
321 }
322
323 let mut blocks = Vec::with_capacity(cfg.depth);
324 for _ in 0..cfg.depth {
325 blocks.push(TextBlockWeights::default_init(&cfg, rng));
326 }
327
328 let final_ln_weight = vec![1.0f32; w];
329 let final_ln_bias = vec![0.0f32; w];
330
331 let scale = 1.0 / (w as f32).sqrt();
333 let mut text_projection = vec![0.0f32; cfg.embed_dim * w];
334 rng.fill_normal(&mut text_projection);
335 for v in &mut text_projection {
336 *v *= scale;
337 }
338
339 Ok(Self {
340 config: cfg,
341 token_embedding,
342 positional_embedding,
343 blocks,
344 final_ln_weight,
345 final_ln_bias,
346 text_projection,
347 })
348 }
349
350 #[must_use]
358 pub fn eot_position(&self, tokens: &[usize]) -> usize {
359 if tokens.is_empty() {
360 return 0;
361 }
362 for (idx, &tok) in tokens.iter().enumerate().rev() {
364 if tok == self.config.eot_token {
365 return idx;
366 }
367 }
368 let mut best_idx = tokens.len() - 1;
370 let mut best_val = tokens[best_idx];
371 for (idx, &tok) in tokens.iter().enumerate() {
372 if tok > best_val {
373 best_val = tok;
374 best_idx = idx;
375 }
376 }
377 best_idx
378 }
379
380 pub fn hidden_states(&self, tokens: &[usize]) -> VisionResult<Vec<f32>> {
390 let cfg = &self.config;
391 let w = cfg.width;
392 let n = tokens.len();
393 if n == 0 {
394 return Err(VisionError::EmptyInput("token sequence"));
395 }
396 if n > cfg.n_ctx {
397 return Err(VisionError::Internal(
398 "sequence length exceeds n_ctx".into(),
399 ));
400 }
401 for &tok in tokens {
402 if tok >= cfg.vocab_size {
403 return Err(VisionError::Internal(
404 "token id out of vocabulary range".into(),
405 ));
406 }
407 }
408
409 let mut h = vec![0.0f32; n * w];
411 for (pos, &tok) in tokens.iter().enumerate() {
412 let te = &self.token_embedding[tok * w..(tok + 1) * w];
413 let pe = &self.positional_embedding[pos * w..(pos + 1) * w];
414 let dst = &mut h[pos * w..(pos + 1) * w];
415 for d in 0..w {
416 dst[d] = te[d] + pe[d];
417 }
418 }
419
420 for blk in &self.blocks {
422 let normed = layer_norm(&h, &blk.ln1_weight, &blk.ln1_bias, n, w, 1e-5);
424 let attn = causal_mhsa(
425 &normed,
426 n,
427 w,
428 cfg.n_heads,
429 cfg.head_dim(),
430 &blk.qkv_weight,
431 &blk.qkv_bias,
432 &blk.out_weight,
433 &blk.out_bias,
434 )?;
435 for (hv, av) in h.iter_mut().zip(attn.iter()) {
436 *hv += av;
437 }
438
439 let normed2 = layer_norm(&h, &blk.ln2_weight, &blk.ln2_bias, n, w, 1e-5);
441 let mlp_dim = cfg.mlp_dim();
442 let mid = linear(&normed2, &blk.mlp1_weight, &blk.mlp1_bias, w, mlp_dim);
443 let mid: Vec<f32> = mid.into_iter().map(gelu_exact).collect();
444 let mlp_out = linear(&mid, &blk.mlp2_weight, &blk.mlp2_bias, mlp_dim, w);
445 for (hv, mv) in h.iter_mut().zip(mlp_out.iter()) {
446 *hv += mv;
447 }
448 }
449
450 let out = layer_norm(&h, &self.final_ln_weight, &self.final_ln_bias, n, w, 1e-5);
452 Ok(out)
453 }
454
455 pub fn encode(&self, tokens: &[usize]) -> VisionResult<Vec<f32>> {
463 let cfg = &self.config;
464 let w = cfg.width;
465 let hs = self.hidden_states(tokens)?;
466
467 let pool = self.eot_position(tokens);
469 let pooled = &hs[pool * w..(pool + 1) * w];
470
471 let mut z = vec![0.0f32; cfg.embed_dim];
473 for (p, zp) in z.iter_mut().enumerate() {
474 let row = &self.text_projection[p * w..(p + 1) * w];
475 *zp = row
476 .iter()
477 .zip(pooled.iter())
478 .map(|(&a, &b)| a * b)
479 .sum::<f32>();
480 }
481
482 let norm: f32 = z.iter().map(|&v| v * v).sum::<f32>().sqrt();
484 let inv = 1.0 / norm.max(1e-12);
485 for v in &mut z {
486 *v *= inv;
487 }
488
489 if z.iter().any(|v| !v.is_finite()) {
490 return Err(VisionError::NonFinite("clip text embedding"));
491 }
492 Ok(z)
493 }
494
495 pub fn encode_batch(&self, sequences: &[Vec<usize>]) -> VisionResult<Vec<Vec<f32>>> {
503 let mut out = Vec::with_capacity(sequences.len());
504 for seq in sequences {
505 out.push(self.encode(seq)?);
506 }
507 Ok(out)
508 }
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministically builds a `tiny()` encoder from `seed`.
    fn make_encoder(seed: u64) -> ClipTextEncoder {
        let mut rng = LcgRng::new(seed);
        ClipTextEncoder::new(ClipTextConfig::tiny(), &mut rng).expect("encoder ok")
    }

    /// Cosine similarity; the `1e-12` term keeps the division finite for zero vectors.
    fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let l2 = |v: &[f32]| v.iter().map(|&x| x * x).sum::<f32>().sqrt();
        let dot: f32 = a.iter().zip(b).map(|(&x, &y)| x * y).sum();
        dot / (l2(a) * l2(b) + 1e-12)
    }

    /// Sum of absolute element-wise differences between two slices.
    fn l1_diff(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum()
    }

    #[test]
    fn config_tiny_valid() {
        let cfg = ClipTextConfig::tiny();
        assert_eq!(cfg.head_dim(), 8);
        assert_eq!(cfg.mlp_dim(), 128);
    }

    #[test]
    fn config_head_mismatch_errors() {
        // width 30 is not divisible by 4 heads.
        let result = ClipTextConfig::new(64, 16, 30, 2, 4, 4, 24, 63);
        assert!(matches!(result, Err(VisionError::HeadDimMismatch { .. })));
    }

    #[test]
    fn config_zero_width_errors() {
        let result = ClipTextConfig::new(64, 16, 0, 2, 4, 4, 24, 63);
        assert!(matches!(result, Err(VisionError::InvalidEmbedDim(0))));
    }

    #[test]
    fn config_eot_out_of_range_errors() {
        // eot_token == vocab_size is one past the last valid id.
        let result = ClipTextConfig::new(64, 16, 32, 2, 4, 4, 24, 64);
        assert!(matches!(result, Err(VisionError::Internal(_))));
    }

    #[test]
    fn encode_output_is_unit_norm() {
        let enc = make_encoder(1);
        let z = enc.encode(&[3, 7, 12, 5, 63]).expect("encode ok");
        let norm = z.iter().map(|&v| v * v).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "text embedding must be L2-unit-norm; got {norm}"
        );
    }

    #[test]
    fn causality_future_token_does_not_affect_earlier_hidden_state() {
        let enc = make_encoder(2);
        // The two sequences differ only at index 3.
        let hs_a = enc.hidden_states(&[5, 9, 14, 2, 63]).expect("ok");
        let hs_b = enc.hidden_states(&[5, 9, 14, 31, 63]).expect("ok");
        let w = enc.config.width;

        // Positions 0..3 come before the edit and must be unaffected.
        let prefix = 3 * w;
        for (idx, (a, b)) in hs_a[..prefix].iter().zip(&hs_b[..prefix]).enumerate() {
            let (pos, d) = (idx / w, idx % w);
            assert!(
                (a - b).abs() < 1e-6,
                "causality violated at pos {pos}, dim {d}: {a} vs {b}"
            );
        }

        let diff_pos3 = l1_diff(&hs_a[3 * w..4 * w], &hs_b[3 * w..4 * w]);
        assert!(
            diff_pos3 > 1e-6,
            "position 3 should change when its own token changes (diff={diff_pos3})"
        );
    }

    #[test]
    fn different_sequences_give_different_embeddings() {
        let enc = make_encoder(3);
        let za = enc.encode(&[1, 2, 3, 63]).expect("ok");
        let zb = enc.encode(&[10, 20, 30, 63]).expect("ok");
        let diff = l1_diff(&za, &zb);
        assert!(
            diff > 1e-4,
            "distinct token sequences must produce distinct embeddings (diff={diff})"
        );
    }

    #[test]
    fn deterministic_same_input_same_output() {
        let enc = make_encoder(4);
        let tokens = [4usize, 8, 15, 16, 23, 42, 63];
        let first = enc.encode(&tokens).expect("ok");
        let second = enc.encode(&tokens).expect("ok");
        assert_eq!(first, second, "encoder must be deterministic");
    }

    #[test]
    fn cosine_of_identical_inputs_is_one() {
        let enc = make_encoder(5);
        let z = enc.encode(&[2, 4, 6, 8, 63]).expect("ok");
        let sim = cosine(&z, &z);
        assert!(
            (sim - 1.0).abs() < 1e-5,
            "cosine(z, z) must be 1.0; got {sim}"
        );
    }

    #[test]
    fn projection_output_dim_matches_config() {
        let enc = make_encoder(6);
        let z = enc.encode(&[1, 2, 63]).expect("ok");
        assert_eq!(
            z.len(),
            enc.config.embed_dim,
            "projected embedding dim must equal config.embed_dim"
        );
    }

    #[test]
    fn eot_position_selects_last_eot_occurrence() {
        let enc = make_encoder(7);
        // EOT (63) sits at index 4, followed by padding.
        let tokens = [5usize, 9, 14, 2, 63, 0, 0];
        assert_eq!(
            enc.eot_position(&tokens),
            4,
            "must pool at the last EOT (id=63) position"
        );
    }

    #[test]
    fn eot_position_argmax_fallback_when_no_explicit_eot() {
        let enc = make_encoder(8);
        // No 63 present; the largest id (40) is at index 2.
        let tokens = [5usize, 9, 40, 2, 7];
        assert_eq!(
            enc.eot_position(&tokens),
            2,
            "argmax fallback should pick the highest-id position"
        );
    }

    #[test]
    fn pooling_uses_eot_hidden_state() {
        let enc = make_encoder(9);
        // Identical up to and including the EOT at index 3; only the tail differs.
        let z_base = enc.encode(&[3, 7, 12, 63, 1, 2]).expect("ok");
        let z_changed = enc.encode(&[3, 7, 12, 63, 30, 40]).expect("ok");
        let diff = l1_diff(&z_base, &z_changed);
        assert!(
            diff < 1e-6,
            "tokens after the EOT must not affect the pooled embedding (diff={diff})"
        );
    }

    #[test]
    fn empty_sequence_errors() {
        let enc = make_encoder(10);
        let result = enc.encode(&[]);
        assert!(matches!(result, Err(VisionError::EmptyInput(_))));
    }

    #[test]
    fn sequence_too_long_errors() {
        let enc = make_encoder(11);
        let too_long: Vec<usize> = (0..enc.config.n_ctx + 1).map(|i| i % 60).collect();
        let result = enc.encode(&too_long);
        assert!(matches!(result, Err(VisionError::Internal(_))));
    }

    #[test]
    fn out_of_vocab_token_errors() {
        let enc = make_encoder(12);
        let result = enc.encode(&[1usize, 9999, 63]);
        assert!(matches!(result, Err(VisionError::Internal(_))));
    }

    #[test]
    fn encode_batch_matches_individual() {
        let enc = make_encoder(13);
        let seqs = vec![vec![1usize, 2, 63], vec![5usize, 9, 14, 63]];
        let batch = enc.encode_batch(&seqs).expect("ok");
        assert_eq!(batch.len(), 2);
        for (seq, from_batch) in seqs.iter().zip(&batch) {
            let single = enc.encode(seq).expect("ok");
            for (a, b) in from_batch.iter().zip(&single) {
                assert!((a - b).abs() < 1e-6, "batch vs single mismatch");
            }
        }
    }

    #[test]
    fn early_token_change_propagates_to_later_positions() {
        let enc = make_encoder(14);
        // Only index 0 differs; index 4 must still see the change.
        let hs_a = enc.hidden_states(&[5, 9, 14, 2, 63]).expect("ok");
        let hs_b = enc.hidden_states(&[31, 9, 14, 2, 63]).expect("ok");
        let w = enc.config.width;
        let diff_pos4 = l1_diff(&hs_a[4 * w..5 * w], &hs_b[4 * w..5 * w]);
        assert!(
            diff_pos4 > 1e-6,
            "changing position 0 must affect later positions (diff={diff_pos4})"
        );
    }
}