Skip to main content

ferrotorch_diffusion/
clip_text_encoder.rs

1//! Stable-Diffusion 1.5 CLIP text encoder
2//! (`openai/clip-vit-large-patch14` — the text tower of CLIP-ViT-L/14).
3//!
//! Phase B.3c of real-artifact-driven development — the third and final SD
//! sub-model. Together with the VAE decoder (Phase B.3a, #1150) and the
//! UNet (Phase B.3b, #1151), this completes the bit-perfect SD-1.5
//! inference pipeline:
8//!
9//! ```text
10//! CLIP text encoder  →  UNet (denoise)  →  VAE decoder
11//! [B, S=77, 768]        [B, 4, 64, 64]    [B, 3, 512, 512]
12//! ```
13//!
14//! Mirrors `transformers.CLIPTextModel` for the SD-1.5 config exactly:
15//!
16//! ```text
17//! hidden_size        = 768
18//! intermediate_size  = 3072
19//! num_attention_heads = 12
20//! num_hidden_layers  = 12
21//! max_position_embeddings = 77
22//! vocab_size         = 49408
23//! hidden_act         = "quick_gelu"     # x * sigmoid(1.702 * x)
24//! layer_norm_eps     = 1e-5
25//! ```
26//!
27//! Architecture (state-dict prefix in parens):
28//!
29//! ```text
30//! CLIPTextModel
31//! └── text_model
32//!     ├── embeddings
33//!     │   ├── token_embedding.weight    [49408, 768]
34//!     │   └── position_embedding.weight [77, 768]
35//!     ├── encoder
36//!     │   └── layers.{0..11}.
37//!     │       ├── layer_norm1.{weight,bias}    [768], [768]
38//!     │       ├── self_attn.
39//!     │       │   ├── q_proj.{weight,bias}    [768,768], [768]
40//!     │       │   ├── k_proj.{weight,bias}    [768,768], [768]
41//!     │       │   ├── v_proj.{weight,bias}    [768,768], [768]
42//!     │       │   └── out_proj.{weight,bias}  [768,768], [768]
43//!     │       ├── layer_norm2.{weight,bias}    [768], [768]
44//!     │       └── mlp.
45//!     │           ├── fc1.{weight,bias}        [3072,768], [3072]
46//!     │           └── fc2.{weight,bias}        [768,3072], [768]
47//!     └── final_layer_norm.{weight,bias}        [768], [768]
48//! ```
49//!
//! Forward pass (every layer is a pre-LayerNorm block with residual connections):
51//!
52//! ```text
53//! h = token_embedding(input_ids) + position_embedding(arange(S))
54//! for layer in encoder.layers:
55//!     residual = h
56//!     h = layer_norm1(h)
57//!     h = causal_self_attn(h, h, h)            # ← causal mask is critical
58//!     h = residual + h
59//!     residual = h
60//!     h = layer_norm2(h)
61//!     h = fc2(quick_gelu(fc1(h)))
62//!     h = residual + h
63//! h = final_layer_norm(h)
64//! return h                                      # last_hidden_state [B, S, 768]
65//! ```
66//!
67//! ## Critical correctness gotchas
68//!
//! 1. **Causal mask**. Despite the "encoder" name, CLIP's text-side
//!    self-attention is causal — position `i` attends only to positions
//!    `0..=i`. Omitting the mask still passes every shape check but
//!    silently breaks numerical parity with `transformers`.
73//! 2. **QuickGELU**, not standard GELU. CLIP-ViT-L/14 uses the fast
74//!    sigmoid approximation `x * sigmoid(1.702 * x)`, not the erf-based
75//!    or tanh-based GELU. We pin this via
76//!    `GELU::with_approximate(GeluApproximate::Sigmoid)`.
77//! 3. **Position embedding is *learned* and absolute** — full table of
78//!    77 entries. Position ids are `[0, 1, ..., S-1]` for every forward.
79//! 4. **All four self-attention projections have bias** (unlike SD's
80//!    UNet `Attention` which has `bias=False` on q/k/v).
81//! 5. **SD uses `last_hidden_state` directly**, not the EOS-pooled
82//!    output. We return `[B, S, hidden_size]`.
83
84use std::collections::HashMap;
85
86use ferrotorch_core::grad_fns::arithmetic::{add, mul};
87use ferrotorch_core::{
88    FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage, numeric_cast,
89};
90use ferrotorch_nn::module::{Module, StateDict};
91use ferrotorch_nn::parameter::Parameter;
92use ferrotorch_nn::{
93    Embedding, GELU, GeluApproximate, LayerNorm, Linear, reshape_to_heads, standard_attention,
94    transpose_heads_to_2d,
95};
96
97// ---------------------------------------------------------------------------
98// Config
99// ---------------------------------------------------------------------------
100
/// Configuration for the SD-1.5 CLIP text encoder
/// (`runwayml/stable-diffusion-v1-5/text_encoder/config.json`).
///
/// Use [`ClipTextConfig::validate`] to check a hand-built or parsed config
/// before constructing modules from it. `PartialEq` is derived so parsed
/// configs can be compared directly (e.g. against
/// [`ClipTextConfig::sd_v1_5`]).
#[derive(Debug, Clone, PartialEq)]
pub struct ClipTextConfig {
    /// Hidden width. SD-1.5: 768. Must be > 0.
    pub hidden_size: usize,
    /// FFN expansion width. SD-1.5: 3072. Must be > 0.
    pub intermediate_size: usize,
    /// Number of attention heads per layer. SD-1.5: 12. Must be > 0 and
    /// divide `hidden_size` evenly.
    pub num_attention_heads: usize,
    /// Number of transformer layers. SD-1.5: 12. Must be > 0.
    pub num_hidden_layers: usize,
    /// Maximum sequence length (size of the learned position table).
    /// SD-1.5: 77. Must be > 0.
    pub max_position_embeddings: usize,
    /// Token vocabulary size. SD-1.5: 49408. Must be > 0.
    pub vocab_size: usize,
    /// LayerNorm epsilon. SD-1.5: 1e-5. Must be finite and > 0.
    pub layer_norm_eps: f64,
}
121
122impl Default for ClipTextConfig {
123    fn default() -> Self {
124        Self::sd_v1_5()
125    }
126}
127
128impl ClipTextConfig {
129    /// SD-1.5 CLIP text encoder defaults (CLIP-ViT-L/14 text tower).
130    pub fn sd_v1_5() -> Self {
131        Self {
132            hidden_size: 768,
133            intermediate_size: 3072,
134            num_attention_heads: 12,
135            num_hidden_layers: 12,
136            max_position_embeddings: 77,
137            vocab_size: 49408,
138            layer_norm_eps: 1e-5,
139        }
140    }
141
142    /// Per-head dimension.
143    #[inline]
144    #[must_use]
145    pub fn head_dim(&self) -> usize {
146        self.hidden_size / self.num_attention_heads
147    }
148
149    /// Validate field bounds.
150    ///
151    /// # Errors
152    ///
153    /// Returns [`FerrotorchError::InvalidArgument`] for any out-of-bounds
154    /// or arithmetic-incompatible field.
155    pub fn validate(&self) -> FerrotorchResult<()> {
156        if self.hidden_size == 0
157            || self.intermediate_size == 0
158            || self.num_attention_heads == 0
159            || self.num_hidden_layers == 0
160            || self.max_position_embeddings == 0
161            || self.vocab_size == 0
162        {
163            return Err(FerrotorchError::InvalidArgument {
164                message: "ClipTextConfig: all size fields must be > 0".into(),
165            });
166        }
167        if self.hidden_size % self.num_attention_heads != 0 {
168            return Err(FerrotorchError::InvalidArgument {
169                message: format!(
170                    "ClipTextConfig: hidden_size {} not divisible by num_attention_heads {}",
171                    self.hidden_size, self.num_attention_heads,
172                ),
173            });
174        }
175        if !self.layer_norm_eps.is_finite() || self.layer_norm_eps <= 0.0 {
176            return Err(FerrotorchError::InvalidArgument {
177                message: format!(
178                    "ClipTextConfig: layer_norm_eps must be finite and > 0, got {}",
179                    self.layer_norm_eps,
180                ),
181            });
182        }
183        Ok(())
184    }
185
186    /// Parse a `text_encoder/config.json` document into a [`ClipTextConfig`].
187    ///
188    /// Recognised keys (all optional — anything missing falls back to the
189    /// SD-1.5 defaults): `hidden_size`, `intermediate_size`,
190    /// `num_attention_heads`, `num_hidden_layers`,
191    /// `max_position_embeddings`, `vocab_size`, `layer_norm_eps`.
192    ///
193    /// # Errors
194    ///
195    /// Returns [`FerrotorchError::InvalidArgument`] on malformed JSON or
196    /// invalid field values.
197    pub fn from_json_str(s: &str) -> FerrotorchResult<Self> {
198        let v: serde_json::Value =
199            serde_json::from_str(s).map_err(|e| FerrotorchError::InvalidArgument {
200                message: format!("ClipTextConfig::from_json_str: bad JSON: {e}"),
201            })?;
202        let mut cfg = Self::default();
203        if let Some(x) = v.get("hidden_size").and_then(serde_json::Value::as_u64) {
204            cfg.hidden_size = x as usize;
205        }
206        if let Some(x) = v
207            .get("intermediate_size")
208            .and_then(serde_json::Value::as_u64)
209        {
210            cfg.intermediate_size = x as usize;
211        }
212        if let Some(x) = v
213            .get("num_attention_heads")
214            .and_then(serde_json::Value::as_u64)
215        {
216            cfg.num_attention_heads = x as usize;
217        }
218        if let Some(x) = v
219            .get("num_hidden_layers")
220            .and_then(serde_json::Value::as_u64)
221        {
222            cfg.num_hidden_layers = x as usize;
223        }
224        if let Some(x) = v
225            .get("max_position_embeddings")
226            .and_then(serde_json::Value::as_u64)
227        {
228            cfg.max_position_embeddings = x as usize;
229        }
230        if let Some(x) = v.get("vocab_size").and_then(serde_json::Value::as_u64) {
231            cfg.vocab_size = x as usize;
232        }
233        if let Some(x) = v.get("layer_norm_eps").and_then(serde_json::Value::as_f64) {
234            cfg.layer_norm_eps = x;
235        }
236        cfg.validate()?;
237        Ok(cfg)
238    }
239
240    /// Parse a `text_encoder/config.json` file from disk.
241    ///
242    /// # Errors
243    ///
244    /// Returns [`FerrotorchError::InvalidArgument`] for I/O or parse
245    /// failures.
246    pub fn from_file(path: &std::path::Path) -> FerrotorchResult<Self> {
247        let s = std::fs::read_to_string(path).map_err(|e| FerrotorchError::InvalidArgument {
248            message: format!(
249                "ClipTextConfig::from_file: failed to read {}: {e}",
250                path.display(),
251            ),
252        })?;
253        Self::from_json_str(&s)
254    }
255}
256
257// ---------------------------------------------------------------------------
258// Helper: reshape utilities (own-the-data so the per-row buffers are
259// always contiguous and ready for the BMM-style attention call).
260// ---------------------------------------------------------------------------
261
262fn reshape_owned<T: Float>(t: &Tensor<T>, shape: Vec<usize>) -> FerrotorchResult<Tensor<T>> {
263    let prod: usize = shape.iter().product();
264    if prod != t.numel() {
265        return Err(FerrotorchError::ShapeMismatch {
266            message: format!(
267                "ClipTextEncoder reshape: target {shape:?} (= {prod} elements) does not \
268                 match source numel {}",
269                t.numel()
270            ),
271        });
272    }
273    let data = t.data_vec()?;
274    Tensor::from_storage(TensorStorage::cpu(data), shape, t.requires_grad())
275}
276
277/// Build a 1-D float-encoded index tensor from u32 ids (matches the
278/// trick `BertEmbeddings::float_index_tensor` uses).
279fn float_index_tensor<T: Float>(ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
280    let data: Vec<T> = ids
281        .iter()
282        .map(|&i| numeric_cast::cast::<u32, T>(i))
283        .collect::<FerrotorchResult<Vec<T>>>()?;
284    let n = data.len();
285    Tensor::from_storage(TensorStorage::cpu(data), vec![n], false)
286}
287
288// ---------------------------------------------------------------------------
289// CLIPTextEmbeddings
290// ---------------------------------------------------------------------------
291
292/// Token embedding + learned absolute position embedding. The two
293/// lookups are summed. Mirrors `CLIPTextEmbeddings` in transformers.
294///
295/// Note: there is NO LayerNorm at the embedding level (unlike BERT).
296/// The first per-layer `layer_norm1` handles normalisation downstream.
297#[derive(Debug)]
298pub struct ClipTextEmbeddings<T: Float> {
299    /// Token lookup — `[vocab_size, hidden_size]`.
300    pub token_embedding: Embedding<T>,
301    /// Learned position lookup — `[max_position_embeddings, hidden_size]`.
302    pub position_embedding: Embedding<T>,
303    hidden_size: usize,
304    max_position_embeddings: usize,
305    training: bool,
306}
307
308impl<T: Float> ClipTextEmbeddings<T> {
309    /// Build randomly-initialized embeddings for the given config.
310    ///
311    /// # Errors
312    ///
313    /// Returns [`FerrotorchError`] from the underlying [`Embedding`]
314    /// constructors.
315    pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
316        cfg.validate()?;
317        Ok(Self {
318            token_embedding: Embedding::new(cfg.vocab_size, cfg.hidden_size, None)?,
319            position_embedding: Embedding::new(cfg.max_position_embeddings, cfg.hidden_size, None)?,
320            hidden_size: cfg.hidden_size,
321            max_position_embeddings: cfg.max_position_embeddings,
322            training: false,
323        })
324    }
325
326    /// Run the embedding sum on a sequence of token ids.
327    ///
328    /// `input_ids` is the verbatim token-id vector (length `S`). The
329    /// output has shape `[1, S, hidden]`.
330    ///
331    /// # Errors
332    ///
333    /// * [`FerrotorchError::InvalidArgument`] if `input_ids` is empty or
334    ///   exceeds `max_position_embeddings`.
335    /// * Propagates downstream embedding-lookup errors.
336    pub fn forward_from_ids(&self, input_ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
337        if input_ids.is_empty() {
338            return Err(FerrotorchError::InvalidArgument {
339                message: "ClipTextEmbeddings::forward_from_ids needs at least one token".into(),
340            });
341        }
342        let seq_len = input_ids.len();
343        if seq_len > self.max_position_embeddings {
344            return Err(FerrotorchError::InvalidArgument {
345                message: format!(
346                    "ClipTextEmbeddings: sequence length {seq_len} exceeds \
347                     max_position_embeddings {}",
348                    self.max_position_embeddings,
349                ),
350            });
351        }
352
353        let word_idx = float_index_tensor::<T>(input_ids)?;
354        let word_2d = self.token_embedding.forward(&word_idx)?; // [S, hidden]
355
356        let pos_ids: Vec<u32> = (0..seq_len as u32).collect();
357        let pos_idx = float_index_tensor::<T>(&pos_ids)?;
358        let pos_2d = self.position_embedding.forward(&pos_idx)?; // [S, hidden]
359
360        let summed = add(&word_2d, &pos_2d)?;
361        // Promote to [1, S, hidden] so downstream uses 3-D ranks.
362        reshape_owned(&summed, vec![1, seq_len, self.hidden_size])
363    }
364}
365
366impl<T: Float> Module<T> for ClipTextEmbeddings<T> {
367    /// When called via `Module::forward` we treat `input` as a
368    /// 1-D float-index tensor (same convention as the inner
369    /// `Embedding` modules). Real callers should use
370    /// [`Self::forward_from_ids`].
371    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
372        let word_2d = self.token_embedding.forward(input)?;
373        let seq_len = input.numel();
374        let pos_ids: Vec<u32> = (0..seq_len as u32).collect();
375        let pos_idx = float_index_tensor::<T>(&pos_ids)?;
376        let pos_2d = self.position_embedding.forward(&pos_idx)?;
377        let summed = add(&word_2d, &pos_2d)?;
378        reshape_owned(&summed, vec![1, seq_len, self.hidden_size])
379    }
380
381    fn parameters(&self) -> Vec<&Parameter<T>> {
382        let mut out = Vec::new();
383        out.extend(self.token_embedding.parameters());
384        out.extend(self.position_embedding.parameters());
385        out
386    }
387
388    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
389        let mut out = Vec::new();
390        out.extend(self.token_embedding.parameters_mut());
391        out.extend(self.position_embedding.parameters_mut());
392        out
393    }
394
395    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
396        let mut out = Vec::new();
397        for (n, p) in self.token_embedding.named_parameters() {
398            out.push((format!("token_embedding.{n}"), p));
399        }
400        for (n, p) in self.position_embedding.named_parameters() {
401            out.push((format!("position_embedding.{n}"), p));
402        }
403        out
404    }
405
406    fn train(&mut self) {
407        self.training = true;
408    }
409
410    fn eval(&mut self) {
411        self.training = false;
412    }
413
414    fn is_training(&self) -> bool {
415        self.training
416    }
417
418    fn state_dict(&self) -> StateDict<T> {
419        self.named_parameters()
420            .into_iter()
421            .map(|(n, p)| (n, p.tensor().clone()))
422            .collect()
423    }
424
425    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
426        let extract = |prefix: &str| -> StateDict<T> {
427            let p = format!("{prefix}.");
428            state
429                .iter()
430                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
431                .collect()
432        };
433        if strict {
434            let prefixes = ["token_embedding", "position_embedding"];
435            for k in state.keys() {
436                if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
437                    return Err(FerrotorchError::InvalidArgument {
438                        message: format!("unexpected key in ClipTextEmbeddings state_dict: {k:?}"),
439                    });
440                }
441            }
442        }
443        self.token_embedding
444            .load_state_dict(&extract("token_embedding"), strict)?;
445        self.position_embedding
446            .load_state_dict(&extract("position_embedding"), strict)?;
447        Ok(())
448    }
449}
450
451// ---------------------------------------------------------------------------
452// CLIP self-attention
453// ---------------------------------------------------------------------------
454
455/// Multi-head self-attention with a *causal* mask. Mirrors
456/// `CLIPAttention` (the text-side variant) in transformers — all four
457/// projections (q/k/v/out) carry bias, the scale is `1/sqrt(head_dim)`,
458/// and the attention is causal (each position attends to itself and
459/// earlier positions only).
460///
461/// State-dict layout:
462///
463/// ```text
464/// q_proj.{weight,bias}    [hidden, hidden], [hidden]
465/// k_proj.{weight,bias}    [hidden, hidden], [hidden]
466/// v_proj.{weight,bias}    [hidden, hidden], [hidden]
467/// out_proj.{weight,bias}  [hidden, hidden], [hidden]
468/// ```
469#[derive(Debug)]
470pub struct ClipSelfAttention<T: Float> {
471    /// Query projection — `[hidden, hidden]`, with bias.
472    pub q_proj: Linear<T>,
473    /// Key projection — `[hidden, hidden]`, with bias.
474    pub k_proj: Linear<T>,
475    /// Value projection — `[hidden, hidden]`, with bias.
476    pub v_proj: Linear<T>,
477    /// Output projection — `[hidden, hidden]`, with bias.
478    pub out_proj: Linear<T>,
479    num_heads: usize,
480    head_dim: usize,
481    hidden: usize,
482    training: bool,
483}
484
485impl<T: Float> ClipSelfAttention<T> {
486    /// Build randomly-initialized self-attention projections.
487    ///
488    /// # Errors
489    ///
490    /// Returns the underlying [`FerrotorchError`] on bad config dims.
491    pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
492        cfg.validate()?;
493        Ok(Self {
494            q_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
495            k_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
496            v_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
497            out_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
498            num_heads: cfg.num_attention_heads,
499            head_dim: cfg.head_dim(),
500            hidden: cfg.hidden_size,
501            training: false,
502        })
503    }
504}
505
506impl<T: Float> Module<T> for ClipSelfAttention<T> {
507    /// Forward — input `[1, S, hidden]`, output `[1, S, hidden]`.
508    ///
509    /// The attention is causal: position `i` cannot attend to position
510    /// `j > i`. This matches `transformers.CLIPAttention`'s use of
511    /// `_create_4d_causal_attention_mask`.
512    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
513        let shape = input.shape();
514        if shape.len() != 3 || shape[0] != 1 || shape[2] != self.hidden {
515            return Err(FerrotorchError::ShapeMismatch {
516                message: format!(
517                    "ClipSelfAttention expects [1, S, {}], got {:?}",
518                    self.hidden, shape,
519                ),
520            });
521        }
522        let seq_len = shape[1];
523
524        // Projections — Linear handles any rank; output is [1, S, hidden].
525        let q = self.q_proj.forward(input)?;
526        let k = self.k_proj.forward(input)?;
527        let v = self.v_proj.forward(input)?;
528
529        // Drop the batch=1 leading dim so reshape_to_heads / the
530        // attention helper can treat the rows as [S, H*d].
531        let q2 = reshape_owned(&q, vec![seq_len, self.hidden])?;
532        let k2 = reshape_owned(&k, vec![seq_len, self.hidden])?;
533        let v2 = reshape_owned(&v, vec![seq_len, self.hidden])?;
534
535        // [S, H*d] → [H, S, d] (batch-first heads).
536        let q_h = reshape_to_heads(&q2, self.num_heads, seq_len, self.head_dim)?;
537        let k_h = reshape_to_heads(&k2, self.num_heads, seq_len, self.head_dim)?;
538        let v_h = reshape_to_heads(&v2, self.num_heads, seq_len, self.head_dim)?;
539
540        // Scaled dot-product attention with causal mask. `standard_attention`
541        // applies `1/sqrt(head_dim)` scaling and `-inf` upper-triangular
542        // mask, then softmax + value mix.
543        let ctx = standard_attention(&q_h, &k_h, &v_h, /* causal = */ true)?;
544
545        // [H, S, d] → [S, H*d] → [1, S, hidden].
546        let ctx2 = transpose_heads_to_2d(&ctx, self.num_heads, seq_len, self.head_dim)?;
547        let ctx3 = reshape_owned(&ctx2, vec![1, seq_len, self.hidden])?;
548
549        // Output projection (with bias).
550        self.out_proj.forward(&ctx3)
551    }
552
553    fn parameters(&self) -> Vec<&Parameter<T>> {
554        let mut out = Vec::new();
555        out.extend(self.q_proj.parameters());
556        out.extend(self.k_proj.parameters());
557        out.extend(self.v_proj.parameters());
558        out.extend(self.out_proj.parameters());
559        out
560    }
561
562    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
563        let mut out = Vec::new();
564        out.extend(self.q_proj.parameters_mut());
565        out.extend(self.k_proj.parameters_mut());
566        out.extend(self.v_proj.parameters_mut());
567        out.extend(self.out_proj.parameters_mut());
568        out
569    }
570
571    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
572        let mut out = Vec::new();
573        for (n, p) in self.q_proj.named_parameters() {
574            out.push((format!("q_proj.{n}"), p));
575        }
576        for (n, p) in self.k_proj.named_parameters() {
577            out.push((format!("k_proj.{n}"), p));
578        }
579        for (n, p) in self.v_proj.named_parameters() {
580            out.push((format!("v_proj.{n}"), p));
581        }
582        for (n, p) in self.out_proj.named_parameters() {
583            out.push((format!("out_proj.{n}"), p));
584        }
585        out
586    }
587
588    fn train(&mut self) {
589        self.training = true;
590    }
591
592    fn eval(&mut self) {
593        self.training = false;
594    }
595
596    fn is_training(&self) -> bool {
597        self.training
598    }
599
600    fn state_dict(&self) -> StateDict<T> {
601        self.named_parameters()
602            .into_iter()
603            .map(|(n, p)| (n, p.tensor().clone()))
604            .collect()
605    }
606
607    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
608        let extract = |prefix: &str| -> StateDict<T> {
609            let p = format!("{prefix}.");
610            state
611                .iter()
612                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
613                .collect()
614        };
615        if strict {
616            let prefixes = ["q_proj", "k_proj", "v_proj", "out_proj"];
617            for k in state.keys() {
618                if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
619                    return Err(FerrotorchError::InvalidArgument {
620                        message: format!("unexpected key in ClipSelfAttention state_dict: {k:?}"),
621                    });
622                }
623            }
624        }
625        self.q_proj.load_state_dict(&extract("q_proj"), strict)?;
626        self.k_proj.load_state_dict(&extract("k_proj"), strict)?;
627        self.v_proj.load_state_dict(&extract("v_proj"), strict)?;
628        self.out_proj
629            .load_state_dict(&extract("out_proj"), strict)?;
630        Ok(())
631    }
632}
633
634// ---------------------------------------------------------------------------
635// CLIPMLP
636// ---------------------------------------------------------------------------
637
638/// CLIP MLP: `fc2(quick_gelu(fc1(x)))`.
639///
640/// QuickGELU is `x * sigmoid(1.702 * x)` — the fast sigmoid
641/// approximation. NOT the standard `0.5 * x * (1 + erf(x/sqrt(2)))`
642/// kernel. This is the published `hidden_act = "quick_gelu"` in
643/// CLIP-ViT-L/14's config.
644#[derive(Debug)]
645pub struct ClipMlp<T: Float> {
646    /// Expansion projection — `[hidden, intermediate]`, with bias.
647    pub fc1: Linear<T>,
648    /// Reduction projection — `[intermediate, hidden]`, with bias.
649    pub fc2: Linear<T>,
650    activation: GELU,
651    training: bool,
652}
653
654impl<T: Float> ClipMlp<T> {
655    /// Build randomly-initialized MLP for the given config.
656    ///
657    /// # Errors
658    ///
659    /// Returns the underlying [`FerrotorchError`] on bad config dims.
660    pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
661        cfg.validate()?;
662        Ok(Self {
663            fc1: Linear::new(cfg.hidden_size, cfg.intermediate_size, true)?,
664            fc2: Linear::new(cfg.intermediate_size, cfg.hidden_size, true)?,
665            // QuickGELU: `x * sigmoid(1.702 * x)`. See module doc.
666            activation: GELU::with_approximate(GeluApproximate::Sigmoid),
667            training: false,
668        })
669    }
670}
671
672impl<T: Float> Module<T> for ClipMlp<T> {
673    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
674        let h = self.fc1.forward(input)?;
675        let h = self.activation.forward(&h)?;
676        self.fc2.forward(&h)
677    }
678
679    fn parameters(&self) -> Vec<&Parameter<T>> {
680        let mut out = Vec::new();
681        out.extend(self.fc1.parameters());
682        out.extend(self.fc2.parameters());
683        out
684    }
685
686    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
687        let mut out = Vec::new();
688        out.extend(self.fc1.parameters_mut());
689        out.extend(self.fc2.parameters_mut());
690        out
691    }
692
693    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
694        let mut out = Vec::new();
695        for (n, p) in self.fc1.named_parameters() {
696            out.push((format!("fc1.{n}"), p));
697        }
698        for (n, p) in self.fc2.named_parameters() {
699            out.push((format!("fc2.{n}"), p));
700        }
701        out
702    }
703
704    fn train(&mut self) {
705        self.training = true;
706    }
707
708    fn eval(&mut self) {
709        self.training = false;
710    }
711
712    fn is_training(&self) -> bool {
713        self.training
714    }
715
716    fn state_dict(&self) -> StateDict<T> {
717        self.named_parameters()
718            .into_iter()
719            .map(|(n, p)| (n, p.tensor().clone()))
720            .collect()
721    }
722
723    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
724        let extract = |prefix: &str| -> StateDict<T> {
725            let p = format!("{prefix}.");
726            state
727                .iter()
728                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
729                .collect()
730        };
731        if strict {
732            for k in state.keys() {
733                if !(k.starts_with("fc1.") || k.starts_with("fc2.")) {
734                    return Err(FerrotorchError::InvalidArgument {
735                        message: format!("unexpected key in ClipMlp state_dict: {k:?}"),
736                    });
737                }
738            }
739        }
740        self.fc1.load_state_dict(&extract("fc1"), strict)?;
741        self.fc2.load_state_dict(&extract("fc2"), strict)?;
742        Ok(())
743    }
744}
745
746// ---------------------------------------------------------------------------
747// CLIPEncoderLayer
748// ---------------------------------------------------------------------------
749
750/// One CLIP text encoder layer.
751///
752/// Pre-LayerNorm + residual stack:
753///
754/// ```text
755/// h = x + self_attn(layer_norm1(x))
756/// h = h + mlp(layer_norm2(h))
757/// ```
758#[derive(Debug)]
759pub struct ClipEncoderLayer<T: Float> {
760    /// Pre-attention LayerNorm.
761    pub layer_norm1: LayerNorm<T>,
762    /// Causal self-attention (q/k/v/out, all biased).
763    pub self_attn: ClipSelfAttention<T>,
764    /// Pre-FFN LayerNorm.
765    pub layer_norm2: LayerNorm<T>,
766    /// Two-layer MLP with QuickGELU activation.
767    pub mlp: ClipMlp<T>,
768    training: bool,
769}
770
771impl<T: Float> ClipEncoderLayer<T> {
772    /// Build a randomly-initialized encoder layer.
773    ///
774    /// # Errors
775    ///
776    /// Returns the underlying [`FerrotorchError`] on bad config dims.
777    pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
778        Ok(Self {
779            layer_norm1: LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?,
780            self_attn: ClipSelfAttention::new(cfg)?,
781            layer_norm2: LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?,
782            mlp: ClipMlp::new(cfg)?,
783            training: false,
784        })
785    }
786}
787
788impl<T: Float> Module<T> for ClipEncoderLayer<T> {
789    /// Forward — `[1, S, hidden]` → `[1, S, hidden]`.
790    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
791        // Self-attention sub-block (pre-norm).
792        let normed = self.layer_norm1.forward(input)?;
793        let attn_out = self.self_attn.forward(&normed)?;
794        let after_attn = add(input, &attn_out)?;
795
796        // MLP sub-block (pre-norm).
797        let normed_ffn = self.layer_norm2.forward(&after_attn)?;
798        let mlp_out = self.mlp.forward(&normed_ffn)?;
799        add(&after_attn, &mlp_out)
800    }
801
802    fn parameters(&self) -> Vec<&Parameter<T>> {
803        let mut out = Vec::new();
804        out.extend(self.layer_norm1.parameters());
805        out.extend(self.self_attn.parameters());
806        out.extend(self.layer_norm2.parameters());
807        out.extend(self.mlp.parameters());
808        out
809    }
810
811    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
812        let mut out = Vec::new();
813        out.extend(self.layer_norm1.parameters_mut());
814        out.extend(self.self_attn.parameters_mut());
815        out.extend(self.layer_norm2.parameters_mut());
816        out.extend(self.mlp.parameters_mut());
817        out
818    }
819
820    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
821        let mut out = Vec::new();
822        for (n, p) in self.layer_norm1.named_parameters() {
823            out.push((format!("layer_norm1.{n}"), p));
824        }
825        for (n, p) in self.self_attn.named_parameters() {
826            out.push((format!("self_attn.{n}"), p));
827        }
828        for (n, p) in self.layer_norm2.named_parameters() {
829            out.push((format!("layer_norm2.{n}"), p));
830        }
831        for (n, p) in self.mlp.named_parameters() {
832            out.push((format!("mlp.{n}"), p));
833        }
834        out
835    }
836
837    fn train(&mut self) {
838        self.training = true;
839    }
840
841    fn eval(&mut self) {
842        self.training = false;
843    }
844
845    fn is_training(&self) -> bool {
846        self.training
847    }
848
849    fn state_dict(&self) -> StateDict<T> {
850        self.named_parameters()
851            .into_iter()
852            .map(|(n, p)| (n, p.tensor().clone()))
853            .collect()
854    }
855
856    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
857        let extract = |prefix: &str| -> StateDict<T> {
858            let p = format!("{prefix}.");
859            state
860                .iter()
861                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
862                .collect()
863        };
864        if strict {
865            let prefixes = ["layer_norm1", "self_attn", "layer_norm2", "mlp"];
866            for k in state.keys() {
867                if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
868                    return Err(FerrotorchError::InvalidArgument {
869                        message: format!("unexpected key in ClipEncoderLayer state_dict: {k:?}"),
870                    });
871                }
872            }
873        }
874        self.layer_norm1
875            .load_state_dict(&extract("layer_norm1"), strict)?;
876        self.self_attn
877            .load_state_dict(&extract("self_attn"), strict)?;
878        self.layer_norm2
879            .load_state_dict(&extract("layer_norm2"), strict)?;
880        self.mlp.load_state_dict(&extract("mlp"), strict)?;
881        Ok(())
882    }
883}
884
885// ---------------------------------------------------------------------------
886// CLIPEncoder
887// ---------------------------------------------------------------------------
888
889/// Stack of `num_hidden_layers` [`ClipEncoderLayer`]s applied in order.
890#[derive(Debug)]
891pub struct ClipEncoder<T: Float> {
892    /// One layer per `num_hidden_layers`.
893    pub layers: Vec<ClipEncoderLayer<T>>,
894    training: bool,
895}
896
897impl<T: Float> ClipEncoder<T> {
898    /// Build a randomly-initialized encoder stack.
899    ///
900    /// # Errors
901    ///
902    /// Returns the underlying [`FerrotorchError`] on bad config dims.
903    pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
904        cfg.validate()?;
905        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
906        for _ in 0..cfg.num_hidden_layers {
907            layers.push(ClipEncoderLayer::new(cfg)?);
908        }
909        Ok(Self {
910            layers,
911            training: false,
912        })
913    }
914}
915
916impl<T: Float> Module<T> for ClipEncoder<T> {
917    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
918        let mut h = input.clone();
919        for l in &self.layers {
920            h = l.forward(&h)?;
921        }
922        Ok(h)
923    }
924
925    fn parameters(&self) -> Vec<&Parameter<T>> {
926        let mut out = Vec::new();
927        for l in &self.layers {
928            out.extend(l.parameters());
929        }
930        out
931    }
932
933    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
934        let mut out = Vec::new();
935        for l in &mut self.layers {
936            out.extend(l.parameters_mut());
937        }
938        out
939    }
940
941    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
942        let mut out = Vec::new();
943        for (i, l) in self.layers.iter().enumerate() {
944            for (n, p) in l.named_parameters() {
945                out.push((format!("layers.{i}.{n}"), p));
946            }
947        }
948        out
949    }
950
951    fn train(&mut self) {
952        self.training = true;
953        for l in &mut self.layers {
954            l.train();
955        }
956    }
957
958    fn eval(&mut self) {
959        self.training = false;
960        for l in &mut self.layers {
961            l.eval();
962        }
963    }
964
965    fn is_training(&self) -> bool {
966        self.training
967    }
968
969    fn state_dict(&self) -> StateDict<T> {
970        self.named_parameters()
971            .into_iter()
972            .map(|(n, p)| (n, p.tensor().clone()))
973            .collect()
974    }
975
976    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
977        let extract = |prefix: &str| -> StateDict<T> {
978            let p = format!("{prefix}.");
979            state
980                .iter()
981                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
982                .collect()
983        };
984        if strict {
985            for k in state.keys() {
986                if !k.starts_with("layers.") {
987                    return Err(FerrotorchError::InvalidArgument {
988                        message: format!("unexpected key in ClipEncoder state_dict: {k:?}"),
989                    });
990                }
991            }
992        }
993        for (i, l) in self.layers.iter_mut().enumerate() {
994            l.load_state_dict(&extract(&format!("layers.{i}")), strict)?;
995        }
996        Ok(())
997    }
998}
999
1000// ---------------------------------------------------------------------------
1001// CLIPTextTransformer / ClipTextEncoder
1002// ---------------------------------------------------------------------------
1003
1004/// The full SD-1.5 CLIP text encoder. Wraps [`ClipTextEmbeddings`] +
1005/// [`ClipEncoder`] + a final [`LayerNorm`] (`final_layer_norm`).
1006///
1007/// Mirrors `CLIPTextTransformer` in transformers. The HF
1008/// `CLIPTextModel` wrapper sits one prefix above
1009/// (`text_model.embeddings.*`, `text_model.encoder.*`,
1010/// `text_model.final_layer_norm.*`) — see
1011/// [`crate::safetensors_loader::load_clip_text_encoder`] for the
1012/// `text_model.` strip.
1013///
1014/// Output is the per-token `last_hidden_state` `[B, S, hidden_size]`.
1015/// SD-1.5 consumes this directly as `encoder_hidden_states` for the
1016/// UNet's cross-attention (no pooling).
1017#[derive(Debug)]
1018pub struct ClipTextEncoder<T: Float> {
1019    /// Token + position embedding sum.
1020    pub embeddings: ClipTextEmbeddings<T>,
1021    /// 12 × [`ClipEncoderLayer`] for SD-1.5.
1022    pub encoder: ClipEncoder<T>,
1023    /// Final LayerNorm over the last hidden state.
1024    pub final_layer_norm: LayerNorm<T>,
1025    /// Frozen copy of the configuration used to build the module.
1026    pub config: ClipTextConfig,
1027    training: bool,
1028}
1029
1030impl<T: Float> ClipTextEncoder<T> {
1031    /// Build a randomly-initialized text encoder.
1032    ///
1033    /// # Errors
1034    ///
1035    /// Returns the underlying [`FerrotorchError`] from any sub-module
1036    /// constructor.
1037    pub fn new(cfg: ClipTextConfig) -> FerrotorchResult<Self> {
1038        cfg.validate()?;
1039        let embeddings = ClipTextEmbeddings::new(&cfg)?;
1040        let encoder = ClipEncoder::new(&cfg)?;
1041        let final_layer_norm = LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?;
1042        Ok(Self {
1043            embeddings,
1044            encoder,
1045            final_layer_norm,
1046            config: cfg,
1047            training: false,
1048        })
1049    }
1050
1051    /// Run the encoder on a token-id sequence and return the per-token
1052    /// `last_hidden_state` `[1, S, hidden_size]`.
1053    ///
1054    /// `input_ids` is the verbatim CLIP-BPE token-id vector (length
1055    /// `S`). For SD-1.5 the canonical inference call is `S = 77`
1056    /// (already-padded with EOS to the max length).
1057    ///
1058    /// # Errors
1059    ///
1060    /// * [`FerrotorchError::InvalidArgument`] if `input_ids` is empty
1061    ///   or longer than `max_position_embeddings`.
1062    /// * Propagates downstream Embedding / LayerNorm errors.
1063    pub fn forward_from_ids(&self, input_ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
1064        let h = self.embeddings.forward_from_ids(input_ids)?;
1065        let h = self.encoder.forward(&h)?;
1066        self.final_layer_norm.forward(&h)
1067    }
1068
1069    /// Run the encoder on a pre-built float-encoded token-id tensor of
1070    /// shape `[S]`. Returns `[1, S, hidden_size]`.
1071    ///
1072    /// `ids` carries u32 token ids losslessly cast to `T`
1073    /// (`numeric_cast::cast::<u32, T>`). Mirrors what the dump example
1074    /// reads off disk.
1075    ///
1076    /// # Errors
1077    ///
1078    /// Propagates downstream lookup / LayerNorm errors and converts
1079    /// invalid (negative / NaN / overflow) ids to
1080    /// [`FerrotorchError::InvalidArgument`].
1081    pub fn forward_from_id_tensor(&self, ids: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1082        // Convert to u32 ids — same defensive cast the underlying
1083        // `Embedding::forward` does.
1084        if ids.ndim() != 1 {
1085            return Err(FerrotorchError::ShapeMismatch {
1086                message: format!(
1087                    "ClipTextEncoder::forward_from_id_tensor expects 1-D ids, got {:?}",
1088                    ids.shape()
1089                ),
1090            });
1091        }
1092        let data = ids.data_vec()?;
1093        let mut u32_ids: Vec<u32> = Vec::with_capacity(data.len());
1094        for (i, v) in data.iter().enumerate() {
1095            let f = num_traits::ToPrimitive::to_f64(v).ok_or_else(|| {
1096                FerrotorchError::InvalidArgument {
1097                    message: format!(
1098                        "ClipTextEncoder::forward_from_id_tensor: id at {i} \
1099                         not representable as f64"
1100                    ),
1101                }
1102            })?;
1103            if !f.is_finite() || f < 0.0 || f > u32::MAX as f64 || f.fract() != 0.0 {
1104                return Err(FerrotorchError::InvalidArgument {
1105                    message: format!(
1106                        "ClipTextEncoder::forward_from_id_tensor: id at {i} ({f}) \
1107                         is not a non-negative integer"
1108                    ),
1109                });
1110            }
1111            u32_ids.push(f as u32);
1112        }
1113        self.forward_from_ids(&u32_ids)
1114    }
1115
1116    /// Load a HuggingFace `CLIPTextModel` state dict into this module.
1117    ///
1118    /// Accepts both:
1119    ///   - bare-`text_model` layout (no prefix; what the pin script
1120    ///     normalises to).
1121    ///   - full `text_model.<rest>` prefix (what the upstream HF
1122    ///     checkpoint ships).
1123    ///
1124    /// The HF safetensors also ships a non-parameter buffer we
1125    /// explicitly drop:
1126    ///
1127    /// * `text_model.embeddings.position_ids` — a `[1, max_pos]`
1128    ///   `arange(max_pos)` buffer regenerated each forward pass. Recorded
1129    ///   in the [`crate::safetensors_loader::DropReport`].
1130    ///
1131    /// # Errors
1132    ///
1133    /// Forwards whatever each sub-module's `load_state_dict` returns
1134    /// (shape mismatch / strict-mode missing key). Strict mode will
1135    /// surface `text_model.embeddings.position_ids` and any unknown
1136    /// key as errors; callers with a full HF checkpoint must pass
1137    /// `strict=false`.
1138    pub fn load_hf_state_dict(
1139        &mut self,
1140        hf_state: &StateDict<T>,
1141        strict: bool,
1142    ) -> FerrotorchResult<crate::safetensors_loader::DropReport> {
1143        let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
1144        let mut dropped: Vec<String> = Vec::new();
1145        for (k, v) in hf_state {
1146            // Strip the optional `text_model.` prefix.
1147            let after = k
1148                .strip_prefix("text_model.")
1149                .map_or_else(|| k.clone(), str::to_owned);
1150
1151            // `embeddings.position_ids` is a buffer — not a parameter on our
1152            // side. Drop in both modes; record so the pin script can audit.
1153            if after == "embeddings.position_ids" {
1154                dropped.push(k.clone());
1155                continue;
1156            }
1157
1158            let is_known = after.starts_with("embeddings.token_embedding.")
1159                || after.starts_with("embeddings.position_embedding.")
1160                || after.starts_with("encoder.")
1161                || after.starts_with("final_layer_norm.");
1162            if is_known {
1163                remapped.insert(after, v.clone());
1164                continue;
1165            }
1166
1167            if strict {
1168                return Err(FerrotorchError::InvalidArgument {
1169                    message: format!(
1170                        "ClipTextEncoder::load_hf_state_dict: key {k:?} is not a \
1171                         known CLIP text-tower parameter and strict mode is on. \
1172                         Pass strict=false to drop unknown keys."
1173                    ),
1174                });
1175            }
1176            dropped.push(k.clone());
1177        }
1178        dropped.sort();
1179        self.load_state_dict(&remapped, strict)?;
1180        Ok(crate::safetensors_loader::DropReport { dropped })
1181    }
1182}
1183
1184impl<T: Float> Module<T> for ClipTextEncoder<T> {
1185    /// `Module::forward` treats `input` as already-summed embeddings
1186    /// `[1, S, hidden]` and only runs the encoder + final LayerNorm.
1187    /// Real callers should use [`Self::forward_from_ids`] /
1188    /// [`Self::forward_from_id_tensor`].
1189    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1190        let h = self.encoder.forward(input)?;
1191        self.final_layer_norm.forward(&h)
1192    }
1193
1194    fn parameters(&self) -> Vec<&Parameter<T>> {
1195        let mut out = Vec::new();
1196        out.extend(self.embeddings.parameters());
1197        out.extend(self.encoder.parameters());
1198        out.extend(self.final_layer_norm.parameters());
1199        out
1200    }
1201
1202    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1203        let mut out = Vec::new();
1204        out.extend(self.embeddings.parameters_mut());
1205        out.extend(self.encoder.parameters_mut());
1206        out.extend(self.final_layer_norm.parameters_mut());
1207        out
1208    }
1209
1210    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1211        let mut out = Vec::new();
1212        for (n, p) in self.embeddings.named_parameters() {
1213            out.push((format!("embeddings.{n}"), p));
1214        }
1215        for (n, p) in self.encoder.named_parameters() {
1216            out.push((format!("encoder.{n}"), p));
1217        }
1218        for (n, p) in self.final_layer_norm.named_parameters() {
1219            out.push((format!("final_layer_norm.{n}"), p));
1220        }
1221        out
1222    }
1223
1224    fn train(&mut self) {
1225        self.training = true;
1226        self.embeddings.train();
1227        self.encoder.train();
1228        self.final_layer_norm.train();
1229    }
1230
1231    fn eval(&mut self) {
1232        self.training = false;
1233        self.embeddings.eval();
1234        self.encoder.eval();
1235        self.final_layer_norm.eval();
1236    }
1237
1238    fn is_training(&self) -> bool {
1239        self.training
1240    }
1241
1242    fn state_dict(&self) -> StateDict<T> {
1243        self.named_parameters()
1244            .into_iter()
1245            .map(|(n, p)| (n, p.tensor().clone()))
1246            .collect()
1247    }
1248
1249    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
1250        let extract = |prefix: &str| -> StateDict<T> {
1251            let p = format!("{prefix}.");
1252            state
1253                .iter()
1254                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
1255                .collect()
1256        };
1257        if strict {
1258            for k in state.keys() {
1259                if !(k.starts_with("embeddings.")
1260                    || k.starts_with("encoder.")
1261                    || k.starts_with("final_layer_norm."))
1262                {
1263                    return Err(FerrotorchError::InvalidArgument {
1264                        message: format!("unexpected key in ClipTextEncoder state_dict: {k:?}"),
1265                    });
1266                }
1267            }
1268        }
1269        self.embeddings
1270            .load_state_dict(&extract("embeddings"), strict)?;
1271        self.encoder.load_state_dict(&extract("encoder"), strict)?;
1272        self.final_layer_norm
1273            .load_state_dict(&extract("final_layer_norm"), strict)?;
1274        Ok(())
1275    }
1276}
1277
1278// `mul` is re-exported above so downstream features (e.g. embedding
1279// scaling, never used for CLIP-ViT-L) can reach for it without a fresh
1280// import. Suppress the unused-import warning on a vanilla build.
1281#[allow(dead_code)]
1282fn _unused_mul_ref<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1283    mul(a, b)
1284}
1285
1286// ---------------------------------------------------------------------------
1287// Tests
1288// ---------------------------------------------------------------------------
1289
#[cfg(test)]
mod tests {
    use super::*;

    fn tiny_cfg() -> ClipTextConfig {
        // 2 heads × 4 dim/head = 8 hidden_size; 16 intermediate; 1 layer;
        // 6 positions; tiny vocab.
        ClipTextConfig {
            hidden_size: 8,
            intermediate_size: 16,
            num_attention_heads: 2,
            num_hidden_layers: 1,
            max_position_embeddings: 6,
            vocab_size: 32,
            layer_norm_eps: 1e-5,
        }
    }

    /// Assert two tensors have identical shapes and agree element-wise
    /// within `tol`. The shape is checked first because a bare `zip`
    /// would silently stop at the shorter buffer and pass on a truncated
    /// output.
    fn assert_close(a: &Tensor<f32>, b: &Tensor<f32>, tol: f32) {
        assert_eq!(a.shape(), b.shape(), "shape mismatch");
        let (da, db) = (a.data().unwrap(), b.data().unwrap());
        for (i, (x, y)) in da.iter().zip(db.iter()).enumerate() {
            assert!((x - y).abs() < tol, "element {i} differs: {x} vs {y}");
        }
    }

    #[test]
    fn sd_v1_5_config_is_canonical() {
        let c = ClipTextConfig::sd_v1_5();
        assert_eq!(c.hidden_size, 768);
        assert_eq!(c.intermediate_size, 3072);
        assert_eq!(c.num_attention_heads, 12);
        assert_eq!(c.num_hidden_layers, 12);
        assert_eq!(c.max_position_embeddings, 77);
        assert_eq!(c.vocab_size, 49408);
        assert_eq!(c.head_dim(), 64);
        c.validate().unwrap();
    }

    #[test]
    fn validate_catches_bad_head_count() {
        let mut c = tiny_cfg();
        c.num_attention_heads = 3; // 8 % 3 != 0
        assert!(c.validate().is_err());
    }

    #[test]
    fn from_json_str_round_trip() {
        let json = r#"{
            "hidden_size": 768,
            "intermediate_size": 3072,
            "num_attention_heads": 12,
            "num_hidden_layers": 12,
            "max_position_embeddings": 77,
            "vocab_size": 49408,
            "layer_norm_eps": 1e-5,
            "hidden_act": "quick_gelu"
        }"#;
        let c = ClipTextConfig::from_json_str(json).unwrap();
        assert_eq!(c.hidden_size, 768);
        assert_eq!(c.intermediate_size, 3072);
        assert_eq!(c.num_attention_heads, 12);
        assert_eq!(c.num_hidden_layers, 12);
        assert_eq!(c.max_position_embeddings, 77);
        assert_eq!(c.vocab_size, 49408);
    }

    #[test]
    fn embeddings_forward_shape() {
        let emb = ClipTextEmbeddings::<f32>::new(&tiny_cfg()).unwrap();
        let ids = [1u32, 5, 7, 9];
        let out = emb.forward_from_ids(&ids).unwrap();
        assert_eq!(out.shape(), &[1, 4, 8]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite(), "embedding non-finite: {v}");
        }
    }

    #[test]
    fn embeddings_reject_too_long_sequence() {
        let emb = ClipTextEmbeddings::<f32>::new(&tiny_cfg()).unwrap();
        let ids: Vec<u32> = (0..7).collect(); // > max_position 6
        assert!(emb.forward_from_ids(&ids).is_err());
    }

    #[test]
    fn self_attention_forward_shape() {
        let attn = ClipSelfAttention::<f32>::new(&tiny_cfg()).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.1f32; 5 * 8]),
            vec![1, 5, 8],
            false,
        )
        .unwrap();
        let out = attn.forward(&x).unwrap();
        assert_eq!(out.shape(), &[1, 5, 8]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn self_attention_is_actually_causal() {
        // Changing later tokens MUST NOT change earlier rows. Build two
        // `[1, 4, 8]` inputs that share rows 0-1 and differ only in rows
        // 2-3. With a causal mask, output rows 0-1 must agree between the
        // two runs up to f32 round-off.
        const D: usize = 8;
        let attn = ClipSelfAttention::<f32>::new(&tiny_cfg()).unwrap();
        let mut a = vec![0.1f32; 4 * D];
        for (i, v) in a.iter_mut().take(2 * D).enumerate() {
            *v = ((i + 1) as f32).sin();
        }
        let mut b = a.clone();
        // Perturb only rows 2 and 3 (`enumerate` before `skip` keeps
        // absolute indices).
        for (i, v) in b.iter_mut().enumerate().skip(2 * D) {
            *v = ((i + 11) as f32).sin();
        }
        // Sanity: the perturbation really changes the later rows,
        // otherwise the test would prove nothing.
        assert_ne!(&a[2 * D..], &b[2 * D..]);

        let xa = Tensor::from_storage(TensorStorage::cpu(a), vec![1, 4, D], false).unwrap();
        let xb = Tensor::from_storage(TensorStorage::cpu(b), vec![1, 4, D], false).unwrap();
        let oa = attn.forward(&xa).unwrap();
        let ob = attn.forward(&xb).unwrap();
        assert_eq!(oa.shape(), &[1, 4, D]);
        assert_eq!(ob.shape(), &[1, 4, D]);
        let da = oa.data().unwrap();
        let db = ob.data().unwrap();
        for (i, (x, y)) in da.iter().zip(db.iter()).take(2 * D).enumerate() {
            assert!(
                (x - y).abs() < 1e-5,
                "row {} col {} differs between runs: {} vs {}",
                i / D,
                i % D,
                x,
                y
            );
        }
    }

    #[test]
    fn mlp_forward_zero_input_shape_and_finite() {
        // Smoke test only: a zero input must produce a finite output of
        // the right shape. This alone does NOT distinguish QuickGELU from
        // other activations; see `mlp_matches_quick_gelu_reference`.
        let mlp = ClipMlp::<f32>::new(&tiny_cfg()).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.0f32; 3 * 8]),
            vec![1, 3, 8],
            false,
        )
        .unwrap();
        let out = mlp.forward(&x).unwrap();
        assert_eq!(out.shape(), &[1, 3, 8]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn mlp_matches_quick_gelu_reference() {
        // Recompute `fc2(quick_gelu(fc1(x)))` by hand from the module's own
        // parameters (HF `[out, in]` weight layout) and compare, where
        // QuickGELU(h) = h * sigmoid(1.702 * h).
        let cfg = tiny_cfg();
        let (hidden, inter) = (cfg.hidden_size, cfg.intermediate_size);
        let mlp = ClipMlp::<f32>::new(&cfg).unwrap();
        let params: HashMap<String, Vec<f32>> = mlp
            .named_parameters()
            .into_iter()
            .map(|(n, p)| (n, p.tensor().data().unwrap().to_vec()))
            .collect();
        let w1 = &params["fc1.weight"];
        let b1 = &params["fc1.bias"];
        let w2 = &params["fc2.weight"];
        let b2 = &params["fc2.bias"];
        assert_eq!(w1.len(), inter * hidden);
        assert_eq!(w2.len(), hidden * inter);

        let rows = 2;
        let x: Vec<f32> = (0..rows * hidden)
            .map(|i| ((i as f32) * 0.37).sin())
            .collect();
        let input = Tensor::from_storage(
            TensorStorage::cpu(x.clone()),
            vec![1, rows, hidden],
            false,
        )
        .unwrap();
        let out = mlp.forward(&input).unwrap();
        assert_eq!(out.shape(), &[1, rows, hidden]);
        let got = out.data().unwrap();

        for r in 0..rows {
            let xr = &x[r * hidden..(r + 1) * hidden];
            let act: Vec<f32> = (0..inter)
                .map(|k| {
                    let h = b1[k]
                        + w1[k * hidden..(k + 1) * hidden]
                            .iter()
                            .zip(xr)
                            .map(|(w, v)| w * v)
                            .sum::<f32>();
                    h / (1.0 + (-1.702 * h).exp())
                })
                .collect();
            for j in 0..hidden {
                let expected = b2[j]
                    + w2[j * inter..(j + 1) * inter]
                        .iter()
                        .zip(&act)
                        .map(|(w, a)| w * a)
                        .sum::<f32>();
                let actual = got[r * hidden + j];
                assert!(
                    (expected - actual).abs() < 1e-4,
                    "row {r} col {j}: expected {expected}, got {actual}"
                );
            }
        }
    }

    #[test]
    fn encoder_layer_forward_shape() {
        let layer = ClipEncoderLayer::<f32>::new(&tiny_cfg()).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.1f32; 5 * 8]),
            vec![1, 5, 8],
            false,
        )
        .unwrap();
        let out = layer.forward(&x).unwrap();
        assert_eq!(out.shape(), &[1, 5, 8]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn encoder_layer_named_parameters_use_hf_layout() {
        let layer = ClipEncoderLayer::<f32>::new(&tiny_cfg()).unwrap();
        let names: Vec<String> = layer
            .named_parameters()
            .into_iter()
            .map(|(n, _)| n)
            .collect();
        for k in [
            "layer_norm1.weight",
            "layer_norm1.bias",
            "self_attn.q_proj.weight",
            "self_attn.q_proj.bias",
            "self_attn.k_proj.weight",
            "self_attn.v_proj.weight",
            "self_attn.out_proj.weight",
            "self_attn.out_proj.bias",
            "layer_norm2.weight",
            "mlp.fc1.weight",
            "mlp.fc1.bias",
            "mlp.fc2.weight",
            "mlp.fc2.bias",
        ] {
            assert!(
                names.iter().any(|n| n == k),
                "missing parameter key {k:?} in {names:?}"
            );
        }
    }

    #[test]
    fn tiny_encoder_forward_from_ids_shape() {
        let enc = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let ids = [1u32, 5, 7];
        let out = enc.forward_from_ids(&ids).unwrap();
        assert_eq!(out.shape(), &[1, 3, 8]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn tiny_named_parameters_use_hf_layout() {
        let enc = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let names: Vec<String> = enc.named_parameters().into_iter().map(|(n, _)| n).collect();
        for k in [
            "embeddings.token_embedding.weight",
            "embeddings.position_embedding.weight",
            "encoder.layers.0.layer_norm1.weight",
            "encoder.layers.0.self_attn.q_proj.weight",
            "encoder.layers.0.self_attn.out_proj.bias",
            "encoder.layers.0.layer_norm2.bias",
            "encoder.layers.0.mlp.fc1.weight",
            "encoder.layers.0.mlp.fc2.bias",
            "final_layer_norm.weight",
            "final_layer_norm.bias",
        ] {
            assert!(
                names.iter().any(|n| n == k),
                "missing parameter key {k:?} in {names:?}"
            );
        }
    }

    #[test]
    fn round_trip_state_dict() {
        let src = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let sd = src.state_dict();
        let mut dst = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        dst.load_state_dict(&sd, true).unwrap();
        let ids = [2u32, 4, 6];
        let a = src.forward_from_ids(&ids).unwrap();
        let b = dst.forward_from_ids(&ids).unwrap();
        assert_close(&a, &b, 1e-5);
    }

    /// Build an HF-style (`text_model.`-prefixed) copy of `src`'s state
    /// dict, including the `position_ids` buffer the real checkpoint ships.
    fn hf_prefixed_state(src: &ClipTextEncoder<f32>) -> StateDict<f32> {
        let mut prefixed: StateDict<f32> = HashMap::new();
        for (k, v) in src.state_dict() {
            prefixed.insert(format!("text_model.{k}"), v);
        }
        prefixed.insert(
            "text_model.embeddings.position_ids".into(),
            ferrotorch_core::zeros::<f32>(&[1, 6]).unwrap(),
        );
        prefixed
    }

    #[test]
    fn load_hf_state_dict_strips_text_model_prefix() {
        let src = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let prefixed = hf_prefixed_state(&src);
        let mut dst = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let rep = dst.load_hf_state_dict(&prefixed, false).unwrap();
        // The position_ids buffer is dropped and reported.
        assert_eq!(
            rep.dropped,
            vec!["text_model.embeddings.position_ids".to_string()]
        );
        let ids = [1u32, 2, 3];
        let a = src.forward_from_ids(&ids).unwrap();
        let b = dst.forward_from_ids(&ids).unwrap();
        assert_close(&a, &b, 1e-5);
    }

    #[test]
    fn load_hf_state_dict_strict_drops_position_ids() {
        // `position_ids` is dropped in strict mode too, so a full HF-style
        // dict loads without needing `strict = false`.
        let src = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let prefixed = hf_prefixed_state(&src);
        let mut dst = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let rep = dst.load_hf_state_dict(&prefixed, true).unwrap();
        assert_eq!(
            rep.dropped,
            vec!["text_model.embeddings.position_ids".to_string()]
        );
        let ids = [3u32, 1, 4];
        let a = src.forward_from_ids(&ids).unwrap();
        let b = dst.forward_from_ids(&ids).unwrap();
        assert_close(&a, &b, 1e-5);
    }

    #[test]
    fn load_hf_state_dict_strict_rejects_unknown_key() {
        let mut dst = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let mut sd: StateDict<f32> = HashMap::new();
        sd.insert(
            "mystery.key".into(),
            ferrotorch_core::zeros::<f32>(&[1]).unwrap(),
        );
        assert!(dst.load_hf_state_dict(&sd, true).is_err());
    }

    #[test]
    fn forward_from_id_tensor_matches_forward_from_ids() {
        let enc = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let ids = vec![1u32, 5, 7];
        let id_tensor = float_index_tensor::<f32>(&ids).unwrap();
        let a = enc.forward_from_ids(&ids).unwrap();
        let b = enc.forward_from_id_tensor(&id_tensor).unwrap();
        assert_close(&a, &b, 1e-5);
    }
}
1576        for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
1577            assert!((x - y).abs() < 1e-5);
1578        }
1579    }
1580}