Skip to main content

ferrotorch_diffusion/
clip_text_encoder.rs

//! Stable-Diffusion 1.5 CLIP text encoder
//! (`openai/clip-vit-large-patch14` — the text tower of CLIP-ViT-L/14).
//!
//! Phase B.3c of real-artifact-driven development — the third and final SD
//! sub-model. Together with the VAE decoder (Phase B.3a, #1150) and the
//! UNet (Phase B.3b, #1151), this completes the bit-perfect SD-1.5
//! inference pipeline:
//!
//! ```text
//! CLIP text encoder  →  UNet (denoise)  →  VAE decoder
//! [B, S=77, 768]        [B, 4, 64, 64]    [B, 3, 512, 512]
//! ```
//!
//! Mirrors `transformers.CLIPTextModel` for the SD-1.5 config exactly:
//!
//! ```text
//! hidden_size             = 768
//! intermediate_size       = 3072
//! num_attention_heads     = 12
//! num_hidden_layers       = 12
//! max_position_embeddings = 77
//! vocab_size              = 49408
//! hidden_act              = "quick_gelu"   # x * sigmoid(1.702 * x)
//! layer_norm_eps          = 1e-5
//! ```
//!
//! Architecture (state-dict key paths, with tensor shapes):
//!
//! ```text
//! CLIPTextModel
//! └── text_model
//!     ├── embeddings
//!     │   ├── token_embedding.weight    [49408, 768]
//!     │   └── position_embedding.weight [77, 768]
//!     ├── encoder
//!     │   └── layers.{0..11}.
//!     │       ├── layer_norm1.{weight,bias}    [768], [768]
//!     │       ├── self_attn.
//!     │       │   ├── q_proj.{weight,bias}    [768,768], [768]
//!     │       │   ├── k_proj.{weight,bias}    [768,768], [768]
//!     │       │   ├── v_proj.{weight,bias}    [768,768], [768]
//!     │       │   └── out_proj.{weight,bias}  [768,768], [768]
//!     │       ├── layer_norm2.{weight,bias}    [768], [768]
//!     │       └── mlp.
//!     │           ├── fc1.{weight,bias}        [3072,768], [3072]
//!     │           └── fc2.{weight,bias}        [768,3072], [768]
//!     └── final_layer_norm.{weight,bias}        [768], [768]
//! ```
//!
//! Forward pass (each layer is pre-LayerNorm + residual):
//!
//! ```text
//! h = token_embedding(input_ids) + position_embedding(arange(S))
//! for layer in encoder.layers:
//!     residual = h
//!     h = layer_norm1(h)
//!     h = causal_self_attn(h, h, h)            # ← causal mask is critical
//!     h = residual + h
//!     residual = h
//!     h = layer_norm2(h)
//!     h = fc2(quick_gelu(fc1(h)))
//!     h = residual + h
//! h = final_layer_norm(h)
//! return h                                      # last_hidden_state [B, S, 768]
//! ```
//!
//! ## Critical correctness gotchas
//!
//! 1. **Causal mask**. Despite the "encoder" name, CLIP text-side
//!    self-attention is causal — position `i` attends only to
//!    `0..=i`. Omitting the mask would still pass shape checks but break
//!    parity with `transformers`.
//! 2. **QuickGELU**, not standard GELU. CLIP-ViT-L/14 uses the fast
//!    sigmoid approximation `x * sigmoid(1.702 * x)`, not the erf-based
//!    or tanh-based GELU. We pin this via
//!    `GELU::with_approximate(GeluApproximate::Sigmoid)`.
//! 3. **Position embedding is *learned* and absolute** — a full table of
//!    77 entries. Position ids are `[0, 1, ..., S-1]` for every forward.
//! 4. **All four self-attention projections have bias** (unlike SD's
//!    UNet `Attention`, which has `bias=False` on q/k/v).
//! 5. **SD uses `last_hidden_state` directly**, not the EOS-pooled
//!    output. We return `[B, S, hidden_size]`; the current implementation
//!    processes one prompt at a time (`B = 1`): the embeddings emit
//!    `[1, S, hidden]` and `ClipSelfAttention` rejects any other batch size.
//!
//! ## REQ status (per `.design/ferrotorch-diffusion/clip_text_encoder.md`)
//!
//! | REQ | Status | Evidence |
//! |---|---|---|
//! | REQ-1 | SHIPPED | `ClipTextConfig` at `ferrotorch-diffusion/src/clip_text_encoder.rs:104..184`; consumer: `ferrotorch-diffusion/src/safetensors_loader.rs:17` imports it and `load_clip_text_encoder` at `safetensors_loader.rs:272` constructs the encoder |
//! | REQ-2 | SHIPPED | `ClipTextEmbeddings` at `ferrotorch-diffusion/src/clip_text_encoder.rs:298..453`; consumer: `ClipTextEncoder::forward_from_ids` at `clip_text_encoder.rs:1064` calls `self.embeddings.forward_from_ids` |
//! | REQ-3 | SHIPPED | `ClipSelfAttention` at `ferrotorch-diffusion/src/clip_text_encoder.rs:470..636` with causal `standard_attention(..., causal=true)`; consumer: `ClipEncoderLayer` at `clip_text_encoder.rs:759..887` and `pipeline.rs:101..103` |
//! | REQ-4 | SHIPPED | `ClipMlp` at `ferrotorch-diffusion/src/clip_text_encoder.rs:645..748` using `GELU::with_approximate(GeluApproximate::Sigmoid)`; consumer: `ClipEncoderLayer` and the encode path from `pipeline.rs:101..103` |
//! | REQ-5 | SHIPPED | `ClipEncoderLayer` at `ferrotorch-diffusion/src/clip_text_encoder.rs:759..887`; consumer: `ClipEncoder` at `clip_text_encoder.rs:891..998` chains them; reached from `pipeline.rs:101..103` |
//! | REQ-6 | SHIPPED | `ClipEncoder::new` at `ferrotorch-diffusion/src/clip_text_encoder.rs:897..914` and forward at `clip_text_encoder.rs:916..998`; consumer: `ClipTextEncoder::forward_from_ids` at `clip_text_encoder.rs:1065` calls `self.encoder.forward(&h)` |
//! | REQ-7 | SHIPPED | `ClipTextEncoder::forward_from_ids` at `ferrotorch-diffusion/src/clip_text_encoder.rs:1063..1067`; consumer: `ferrotorch-diffusion/src/pipeline.rs:102` calls it inside `encode_prompt` |
//! | REQ-8 | SHIPPED | `load_hf_state_dict` at `ferrotorch-diffusion/src/clip_text_encoder.rs:1138..1181`; consumer: `ferrotorch-diffusion/src/safetensors_loader.rs:272..` `load_clip_text_encoder` invokes it on the HF checkpoint |
96
97use std::collections::HashMap;
98
99use ferrotorch_core::grad_fns::arithmetic::{add, mul};
100use ferrotorch_core::{
101    FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage, numeric_cast,
102};
103use ferrotorch_nn::module::{Module, StateDict};
104use ferrotorch_nn::parameter::Parameter;
105use ferrotorch_nn::{
106    Embedding, GELU, GeluApproximate, LayerNorm, Linear, reshape_to_heads, standard_attention,
107    transpose_heads_to_2d,
108};
109
110// ---------------------------------------------------------------------------
111// Config
112// ---------------------------------------------------------------------------
113
/// Hyper-parameters of the SD-1.5 CLIP text encoder, as published in
/// `runwayml/stable-diffusion-v1-5/text_encoder/config.json`.
///
/// [`ClipTextConfig::sd_v1_5`] (also the [`Default`]) yields the published
/// values; [`ClipTextConfig::validate`] checks a hand-built value.
#[derive(Debug, Clone)]
pub struct ClipTextConfig {
    /// Model width of every token vector (768 for SD-1.5).
    pub hidden_size: usize,
    /// Width of the MLP's inner layer (3072 for SD-1.5).
    pub intermediate_size: usize,
    /// Attention heads per layer (12 for SD-1.5). `hidden_size` must be an
    /// exact multiple of this.
    pub num_attention_heads: usize,
    /// Number of stacked encoder layers (12 for SD-1.5).
    pub num_hidden_layers: usize,
    /// Rows in the learned position table, i.e. the longest accepted token
    /// sequence (77 for SD-1.5).
    pub max_position_embeddings: usize,
    /// Rows in the token-embedding table (49408 for SD-1.5).
    pub vocab_size: usize,
    /// Numerical-stability epsilon handed to the LayerNorms (1e-5 for SD-1.5).
    pub layer_norm_eps: f64,
}
134
135impl Default for ClipTextConfig {
136    fn default() -> Self {
137        Self::sd_v1_5()
138    }
139}
140
141impl ClipTextConfig {
142    /// SD-1.5 CLIP text encoder defaults (CLIP-ViT-L/14 text tower).
143    pub fn sd_v1_5() -> Self {
144        Self {
145            hidden_size: 768,
146            intermediate_size: 3072,
147            num_attention_heads: 12,
148            num_hidden_layers: 12,
149            max_position_embeddings: 77,
150            vocab_size: 49408,
151            layer_norm_eps: 1e-5,
152        }
153    }
154
155    /// Per-head dimension.
156    #[inline]
157    #[must_use]
158    pub fn head_dim(&self) -> usize {
159        self.hidden_size / self.num_attention_heads
160    }
161
162    /// Validate field bounds.
163    ///
164    /// # Errors
165    ///
166    /// Returns [`FerrotorchError::InvalidArgument`] for any out-of-bounds
167    /// or arithmetic-incompatible field.
168    pub fn validate(&self) -> FerrotorchResult<()> {
169        if self.hidden_size == 0
170            || self.intermediate_size == 0
171            || self.num_attention_heads == 0
172            || self.num_hidden_layers == 0
173            || self.max_position_embeddings == 0
174            || self.vocab_size == 0
175        {
176            return Err(FerrotorchError::InvalidArgument {
177                message: "ClipTextConfig: all size fields must be > 0".into(),
178            });
179        }
180        if self.hidden_size % self.num_attention_heads != 0 {
181            return Err(FerrotorchError::InvalidArgument {
182                message: format!(
183                    "ClipTextConfig: hidden_size {} not divisible by num_attention_heads {}",
184                    self.hidden_size, self.num_attention_heads,
185                ),
186            });
187        }
188        if !self.layer_norm_eps.is_finite() || self.layer_norm_eps <= 0.0 {
189            return Err(FerrotorchError::InvalidArgument {
190                message: format!(
191                    "ClipTextConfig: layer_norm_eps must be finite and > 0, got {}",
192                    self.layer_norm_eps,
193                ),
194            });
195        }
196        Ok(())
197    }
198
199    /// Parse a `text_encoder/config.json` document into a [`ClipTextConfig`].
200    ///
201    /// Recognised keys (all optional — anything missing falls back to the
202    /// SD-1.5 defaults): `hidden_size`, `intermediate_size`,
203    /// `num_attention_heads`, `num_hidden_layers`,
204    /// `max_position_embeddings`, `vocab_size`, `layer_norm_eps`.
205    ///
206    /// # Errors
207    ///
208    /// Returns [`FerrotorchError::InvalidArgument`] on malformed JSON or
209    /// invalid field values.
210    pub fn from_json_str(s: &str) -> FerrotorchResult<Self> {
211        let v: serde_json::Value =
212            serde_json::from_str(s).map_err(|e| FerrotorchError::InvalidArgument {
213                message: format!("ClipTextConfig::from_json_str: bad JSON: {e}"),
214            })?;
215        let mut cfg = Self::default();
216        if let Some(x) = v.get("hidden_size").and_then(serde_json::Value::as_u64) {
217            cfg.hidden_size = x as usize;
218        }
219        if let Some(x) = v
220            .get("intermediate_size")
221            .and_then(serde_json::Value::as_u64)
222        {
223            cfg.intermediate_size = x as usize;
224        }
225        if let Some(x) = v
226            .get("num_attention_heads")
227            .and_then(serde_json::Value::as_u64)
228        {
229            cfg.num_attention_heads = x as usize;
230        }
231        if let Some(x) = v
232            .get("num_hidden_layers")
233            .and_then(serde_json::Value::as_u64)
234        {
235            cfg.num_hidden_layers = x as usize;
236        }
237        if let Some(x) = v
238            .get("max_position_embeddings")
239            .and_then(serde_json::Value::as_u64)
240        {
241            cfg.max_position_embeddings = x as usize;
242        }
243        if let Some(x) = v.get("vocab_size").and_then(serde_json::Value::as_u64) {
244            cfg.vocab_size = x as usize;
245        }
246        if let Some(x) = v.get("layer_norm_eps").and_then(serde_json::Value::as_f64) {
247            cfg.layer_norm_eps = x;
248        }
249        cfg.validate()?;
250        Ok(cfg)
251    }
252
253    /// Parse a `text_encoder/config.json` file from disk.
254    ///
255    /// # Errors
256    ///
257    /// Returns [`FerrotorchError::InvalidArgument`] for I/O or parse
258    /// failures.
259    pub fn from_file(path: &std::path::Path) -> FerrotorchResult<Self> {
260        let s = std::fs::read_to_string(path).map_err(|e| FerrotorchError::InvalidArgument {
261            message: format!(
262                "ClipTextConfig::from_file: failed to read {}: {e}",
263                path.display(),
264            ),
265        })?;
266        Self::from_json_str(&s)
267    }
268}
269
270// ---------------------------------------------------------------------------
271// Helper: reshape utilities (own-the-data so the per-row buffers are
272// always contiguous and ready for the BMM-style attention call).
273// ---------------------------------------------------------------------------
274
275fn reshape_owned<T: Float>(t: &Tensor<T>, shape: Vec<usize>) -> FerrotorchResult<Tensor<T>> {
276    let prod: usize = shape.iter().product();
277    if prod != t.numel() {
278        return Err(FerrotorchError::ShapeMismatch {
279            message: format!(
280                "ClipTextEncoder reshape: target {shape:?} (= {prod} elements) does not \
281                 match source numel {}",
282                t.numel()
283            ),
284        });
285    }
286    let data = t.data_vec()?;
287    Tensor::from_storage(TensorStorage::cpu(data), shape, t.requires_grad())
288}
289
290/// Build a 1-D float-encoded index tensor from u32 ids (matches the
291/// trick `BertEmbeddings::float_index_tensor` uses).
292fn float_index_tensor<T: Float>(ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
293    let data: Vec<T> = ids
294        .iter()
295        .map(|&i| numeric_cast::cast::<u32, T>(i))
296        .collect::<FerrotorchResult<Vec<T>>>()?;
297    let n = data.len();
298    Tensor::from_storage(TensorStorage::cpu(data), vec![n], false)
299}
300
301// ---------------------------------------------------------------------------
302// CLIPTextEmbeddings
303// ---------------------------------------------------------------------------
304
305/// Token embedding + learned absolute position embedding. The two
306/// lookups are summed. Mirrors `CLIPTextEmbeddings` in transformers.
307///
308/// Note: there is NO LayerNorm at the embedding level (unlike BERT).
309/// The first per-layer `layer_norm1` handles normalisation downstream.
310#[derive(Debug)]
311pub struct ClipTextEmbeddings<T: Float> {
312    /// Token lookup — `[vocab_size, hidden_size]`.
313    pub token_embedding: Embedding<T>,
314    /// Learned position lookup — `[max_position_embeddings, hidden_size]`.
315    pub position_embedding: Embedding<T>,
316    hidden_size: usize,
317    max_position_embeddings: usize,
318    training: bool,
319}
320
321impl<T: Float> ClipTextEmbeddings<T> {
322    /// Build randomly-initialized embeddings for the given config.
323    ///
324    /// # Errors
325    ///
326    /// Returns [`FerrotorchError`] from the underlying [`Embedding`]
327    /// constructors.
328    pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
329        cfg.validate()?;
330        Ok(Self {
331            token_embedding: Embedding::new(cfg.vocab_size, cfg.hidden_size, None)?,
332            position_embedding: Embedding::new(cfg.max_position_embeddings, cfg.hidden_size, None)?,
333            hidden_size: cfg.hidden_size,
334            max_position_embeddings: cfg.max_position_embeddings,
335            training: false,
336        })
337    }
338
339    /// Run the embedding sum on a sequence of token ids.
340    ///
341    /// `input_ids` is the verbatim token-id vector (length `S`). The
342    /// output has shape `[1, S, hidden]`.
343    ///
344    /// # Errors
345    ///
346    /// * [`FerrotorchError::InvalidArgument`] if `input_ids` is empty or
347    ///   exceeds `max_position_embeddings`.
348    /// * Propagates downstream embedding-lookup errors.
349    pub fn forward_from_ids(&self, input_ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
350        if input_ids.is_empty() {
351            return Err(FerrotorchError::InvalidArgument {
352                message: "ClipTextEmbeddings::forward_from_ids needs at least one token".into(),
353            });
354        }
355        let seq_len = input_ids.len();
356        if seq_len > self.max_position_embeddings {
357            return Err(FerrotorchError::InvalidArgument {
358                message: format!(
359                    "ClipTextEmbeddings: sequence length {seq_len} exceeds \
360                     max_position_embeddings {}",
361                    self.max_position_embeddings,
362                ),
363            });
364        }
365
366        let word_idx = float_index_tensor::<T>(input_ids)?;
367        let word_2d = self.token_embedding.forward(&word_idx)?; // [S, hidden]
368
369        let pos_ids: Vec<u32> = (0..seq_len as u32).collect();
370        let pos_idx = float_index_tensor::<T>(&pos_ids)?;
371        let pos_2d = self.position_embedding.forward(&pos_idx)?; // [S, hidden]
372
373        let summed = add(&word_2d, &pos_2d)?;
374        // Promote to [1, S, hidden] so downstream uses 3-D ranks.
375        reshape_owned(&summed, vec![1, seq_len, self.hidden_size])
376    }
377}
378
379impl<T: Float> Module<T> for ClipTextEmbeddings<T> {
380    /// When called via `Module::forward` we treat `input` as a
381    /// 1-D float-index tensor (same convention as the inner
382    /// `Embedding` modules). Real callers should use
383    /// [`Self::forward_from_ids`].
384    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
385        let word_2d = self.token_embedding.forward(input)?;
386        let seq_len = input.numel();
387        let pos_ids: Vec<u32> = (0..seq_len as u32).collect();
388        let pos_idx = float_index_tensor::<T>(&pos_ids)?;
389        let pos_2d = self.position_embedding.forward(&pos_idx)?;
390        let summed = add(&word_2d, &pos_2d)?;
391        reshape_owned(&summed, vec![1, seq_len, self.hidden_size])
392    }
393
394    fn parameters(&self) -> Vec<&Parameter<T>> {
395        let mut out = Vec::new();
396        out.extend(self.token_embedding.parameters());
397        out.extend(self.position_embedding.parameters());
398        out
399    }
400
401    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
402        let mut out = Vec::new();
403        out.extend(self.token_embedding.parameters_mut());
404        out.extend(self.position_embedding.parameters_mut());
405        out
406    }
407
408    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
409        let mut out = Vec::new();
410        for (n, p) in self.token_embedding.named_parameters() {
411            out.push((format!("token_embedding.{n}"), p));
412        }
413        for (n, p) in self.position_embedding.named_parameters() {
414            out.push((format!("position_embedding.{n}"), p));
415        }
416        out
417    }
418
419    fn train(&mut self) {
420        self.training = true;
421    }
422
423    fn eval(&mut self) {
424        self.training = false;
425    }
426
427    fn is_training(&self) -> bool {
428        self.training
429    }
430
431    fn state_dict(&self) -> StateDict<T> {
432        self.named_parameters()
433            .into_iter()
434            .map(|(n, p)| (n, p.tensor().clone()))
435            .collect()
436    }
437
438    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
439        let extract = |prefix: &str| -> StateDict<T> {
440            let p = format!("{prefix}.");
441            state
442                .iter()
443                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
444                .collect()
445        };
446        if strict {
447            let prefixes = ["token_embedding", "position_embedding"];
448            for k in state.keys() {
449                if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
450                    return Err(FerrotorchError::InvalidArgument {
451                        message: format!("unexpected key in ClipTextEmbeddings state_dict: {k:?}"),
452                    });
453                }
454            }
455        }
456        self.token_embedding
457            .load_state_dict(&extract("token_embedding"), strict)?;
458        self.position_embedding
459            .load_state_dict(&extract("position_embedding"), strict)?;
460        Ok(())
461    }
462}
463
464// ---------------------------------------------------------------------------
465// CLIP self-attention
466// ---------------------------------------------------------------------------
467
468/// Multi-head self-attention with a *causal* mask. Mirrors
469/// `CLIPAttention` (the text-side variant) in transformers — all four
470/// projections (q/k/v/out) carry bias, the scale is `1/sqrt(head_dim)`,
471/// and the attention is causal (each position attends to itself and
472/// earlier positions only).
473///
474/// State-dict layout:
475///
476/// ```text
477/// q_proj.{weight,bias}    [hidden, hidden], [hidden]
478/// k_proj.{weight,bias}    [hidden, hidden], [hidden]
479/// v_proj.{weight,bias}    [hidden, hidden], [hidden]
480/// out_proj.{weight,bias}  [hidden, hidden], [hidden]
481/// ```
482#[derive(Debug)]
483pub struct ClipSelfAttention<T: Float> {
484    /// Query projection — `[hidden, hidden]`, with bias.
485    pub q_proj: Linear<T>,
486    /// Key projection — `[hidden, hidden]`, with bias.
487    pub k_proj: Linear<T>,
488    /// Value projection — `[hidden, hidden]`, with bias.
489    pub v_proj: Linear<T>,
490    /// Output projection — `[hidden, hidden]`, with bias.
491    pub out_proj: Linear<T>,
492    num_heads: usize,
493    head_dim: usize,
494    hidden: usize,
495    training: bool,
496}
497
498impl<T: Float> ClipSelfAttention<T> {
499    /// Build randomly-initialized self-attention projections.
500    ///
501    /// # Errors
502    ///
503    /// Returns the underlying [`FerrotorchError`] on bad config dims.
504    pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
505        cfg.validate()?;
506        Ok(Self {
507            q_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
508            k_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
509            v_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
510            out_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
511            num_heads: cfg.num_attention_heads,
512            head_dim: cfg.head_dim(),
513            hidden: cfg.hidden_size,
514            training: false,
515        })
516    }
517}
518
519impl<T: Float> Module<T> for ClipSelfAttention<T> {
520    /// Forward — input `[1, S, hidden]`, output `[1, S, hidden]`.
521    ///
522    /// The attention is causal: position `i` cannot attend to position
523    /// `j > i`. This matches `transformers.CLIPAttention`'s use of
524    /// `_create_4d_causal_attention_mask`.
525    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
526        let shape = input.shape();
527        if shape.len() != 3 || shape[0] != 1 || shape[2] != self.hidden {
528            return Err(FerrotorchError::ShapeMismatch {
529                message: format!(
530                    "ClipSelfAttention expects [1, S, {}], got {:?}",
531                    self.hidden, shape,
532                ),
533            });
534        }
535        let seq_len = shape[1];
536
537        // Projections — Linear handles any rank; output is [1, S, hidden].
538        let q = self.q_proj.forward(input)?;
539        let k = self.k_proj.forward(input)?;
540        let v = self.v_proj.forward(input)?;
541
542        // Drop the batch=1 leading dim so reshape_to_heads / the
543        // attention helper can treat the rows as [S, H*d].
544        let q2 = reshape_owned(&q, vec![seq_len, self.hidden])?;
545        let k2 = reshape_owned(&k, vec![seq_len, self.hidden])?;
546        let v2 = reshape_owned(&v, vec![seq_len, self.hidden])?;
547
548        // [S, H*d] → [H, S, d] (batch-first heads).
549        let q_h = reshape_to_heads(&q2, self.num_heads, seq_len, self.head_dim)?;
550        let k_h = reshape_to_heads(&k2, self.num_heads, seq_len, self.head_dim)?;
551        let v_h = reshape_to_heads(&v2, self.num_heads, seq_len, self.head_dim)?;
552
553        // Scaled dot-product attention with causal mask. `standard_attention`
554        // applies `1/sqrt(head_dim)` scaling and `-inf` upper-triangular
555        // mask, then softmax + value mix.
556        let ctx = standard_attention(&q_h, &k_h, &v_h, /* causal = */ true)?;
557
558        // [H, S, d] → [S, H*d] → [1, S, hidden].
559        let ctx2 = transpose_heads_to_2d(&ctx, self.num_heads, seq_len, self.head_dim)?;
560        let ctx3 = reshape_owned(&ctx2, vec![1, seq_len, self.hidden])?;
561
562        // Output projection (with bias).
563        self.out_proj.forward(&ctx3)
564    }
565
566    fn parameters(&self) -> Vec<&Parameter<T>> {
567        let mut out = Vec::new();
568        out.extend(self.q_proj.parameters());
569        out.extend(self.k_proj.parameters());
570        out.extend(self.v_proj.parameters());
571        out.extend(self.out_proj.parameters());
572        out
573    }
574
575    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
576        let mut out = Vec::new();
577        out.extend(self.q_proj.parameters_mut());
578        out.extend(self.k_proj.parameters_mut());
579        out.extend(self.v_proj.parameters_mut());
580        out.extend(self.out_proj.parameters_mut());
581        out
582    }
583
584    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
585        let mut out = Vec::new();
586        for (n, p) in self.q_proj.named_parameters() {
587            out.push((format!("q_proj.{n}"), p));
588        }
589        for (n, p) in self.k_proj.named_parameters() {
590            out.push((format!("k_proj.{n}"), p));
591        }
592        for (n, p) in self.v_proj.named_parameters() {
593            out.push((format!("v_proj.{n}"), p));
594        }
595        for (n, p) in self.out_proj.named_parameters() {
596            out.push((format!("out_proj.{n}"), p));
597        }
598        out
599    }
600
601    fn train(&mut self) {
602        self.training = true;
603    }
604
605    fn eval(&mut self) {
606        self.training = false;
607    }
608
609    fn is_training(&self) -> bool {
610        self.training
611    }
612
613    fn state_dict(&self) -> StateDict<T> {
614        self.named_parameters()
615            .into_iter()
616            .map(|(n, p)| (n, p.tensor().clone()))
617            .collect()
618    }
619
620    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
621        let extract = |prefix: &str| -> StateDict<T> {
622            let p = format!("{prefix}.");
623            state
624                .iter()
625                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
626                .collect()
627        };
628        if strict {
629            let prefixes = ["q_proj", "k_proj", "v_proj", "out_proj"];
630            for k in state.keys() {
631                if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
632                    return Err(FerrotorchError::InvalidArgument {
633                        message: format!("unexpected key in ClipSelfAttention state_dict: {k:?}"),
634                    });
635                }
636            }
637        }
638        self.q_proj.load_state_dict(&extract("q_proj"), strict)?;
639        self.k_proj.load_state_dict(&extract("k_proj"), strict)?;
640        self.v_proj.load_state_dict(&extract("v_proj"), strict)?;
641        self.out_proj
642            .load_state_dict(&extract("out_proj"), strict)?;
643        Ok(())
644    }
645}
646
647// ---------------------------------------------------------------------------
648// CLIPMLP
649// ---------------------------------------------------------------------------
650
651/// CLIP MLP: `fc2(quick_gelu(fc1(x)))`.
652///
653/// QuickGELU is `x * sigmoid(1.702 * x)` — the fast sigmoid
654/// approximation. NOT the standard `0.5 * x * (1 + erf(x/sqrt(2)))`
655/// kernel. This is the published `hidden_act = "quick_gelu"` in
656/// CLIP-ViT-L/14's config.
657#[derive(Debug)]
658pub struct ClipMlp<T: Float> {
659    /// Expansion projection — `[hidden, intermediate]`, with bias.
660    pub fc1: Linear<T>,
661    /// Reduction projection — `[intermediate, hidden]`, with bias.
662    pub fc2: Linear<T>,
663    activation: GELU,
664    training: bool,
665}
666
667impl<T: Float> ClipMlp<T> {
668    /// Build randomly-initialized MLP for the given config.
669    ///
670    /// # Errors
671    ///
672    /// Returns the underlying [`FerrotorchError`] on bad config dims.
673    pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
674        cfg.validate()?;
675        Ok(Self {
676            fc1: Linear::new(cfg.hidden_size, cfg.intermediate_size, true)?,
677            fc2: Linear::new(cfg.intermediate_size, cfg.hidden_size, true)?,
678            // QuickGELU: `x * sigmoid(1.702 * x)`. See module doc.
679            activation: GELU::with_approximate(GeluApproximate::Sigmoid),
680            training: false,
681        })
682    }
683}
684
685impl<T: Float> Module<T> for ClipMlp<T> {
686    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
687        let h = self.fc1.forward(input)?;
688        let h = self.activation.forward(&h)?;
689        self.fc2.forward(&h)
690    }
691
692    fn parameters(&self) -> Vec<&Parameter<T>> {
693        let mut out = Vec::new();
694        out.extend(self.fc1.parameters());
695        out.extend(self.fc2.parameters());
696        out
697    }
698
699    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
700        let mut out = Vec::new();
701        out.extend(self.fc1.parameters_mut());
702        out.extend(self.fc2.parameters_mut());
703        out
704    }
705
706    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
707        let mut out = Vec::new();
708        for (n, p) in self.fc1.named_parameters() {
709            out.push((format!("fc1.{n}"), p));
710        }
711        for (n, p) in self.fc2.named_parameters() {
712            out.push((format!("fc2.{n}"), p));
713        }
714        out
715    }
716
717    fn train(&mut self) {
718        self.training = true;
719    }
720
721    fn eval(&mut self) {
722        self.training = false;
723    }
724
725    fn is_training(&self) -> bool {
726        self.training
727    }
728
729    fn state_dict(&self) -> StateDict<T> {
730        self.named_parameters()
731            .into_iter()
732            .map(|(n, p)| (n, p.tensor().clone()))
733            .collect()
734    }
735
736    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
737        let extract = |prefix: &str| -> StateDict<T> {
738            let p = format!("{prefix}.");
739            state
740                .iter()
741                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
742                .collect()
743        };
744        if strict {
745            for k in state.keys() {
746                if !(k.starts_with("fc1.") || k.starts_with("fc2.")) {
747                    return Err(FerrotorchError::InvalidArgument {
748                        message: format!("unexpected key in ClipMlp state_dict: {k:?}"),
749                    });
750                }
751            }
752        }
753        self.fc1.load_state_dict(&extract("fc1"), strict)?;
754        self.fc2.load_state_dict(&extract("fc2"), strict)?;
755        Ok(())
756    }
757}
758
759// ---------------------------------------------------------------------------
760// CLIPEncoderLayer
761// ---------------------------------------------------------------------------
762
/// One CLIP text encoder layer.
///
/// Pre-LayerNorm + residual stack:
///
/// ```text
/// h = x + self_attn(layer_norm1(x))
/// h = h + mlp(layer_norm2(h))
/// ```
///
/// Field names double as the HF state-dict path segments
/// (`layer_norm1`, `self_attn`, `layer_norm2`, `mlp`): `named_parameters`
/// prefixes each child's names with exactly these strings.
#[derive(Debug)]
pub struct ClipEncoderLayer<T: Float> {
    /// Pre-attention LayerNorm.
    pub layer_norm1: LayerNorm<T>,
    /// Causal self-attention (q/k/v/out, all biased).
    pub self_attn: ClipSelfAttention<T>,
    /// Pre-FFN LayerNorm.
    pub layer_norm2: LayerNorm<T>,
    /// Two-layer MLP with QuickGELU activation.
    pub mlp: ClipMlp<T>,
    // Train/eval mode flag reported by `Module::is_training`; `new`
    // constructs the layer with `false` (eval mode).
    training: bool,
}
783
784impl<T: Float> ClipEncoderLayer<T> {
785    /// Build a randomly-initialized encoder layer.
786    ///
787    /// # Errors
788    ///
789    /// Returns the underlying [`FerrotorchError`] on bad config dims.
790    pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
791        Ok(Self {
792            layer_norm1: LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?,
793            self_attn: ClipSelfAttention::new(cfg)?,
794            layer_norm2: LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?,
795            mlp: ClipMlp::new(cfg)?,
796            training: false,
797        })
798    }
799}
800
801impl<T: Float> Module<T> for ClipEncoderLayer<T> {
802    /// Forward — `[1, S, hidden]` → `[1, S, hidden]`.
803    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
804        // Self-attention sub-block (pre-norm).
805        let normed = self.layer_norm1.forward(input)?;
806        let attn_out = self.self_attn.forward(&normed)?;
807        let after_attn = add(input, &attn_out)?;
808
809        // MLP sub-block (pre-norm).
810        let normed_ffn = self.layer_norm2.forward(&after_attn)?;
811        let mlp_out = self.mlp.forward(&normed_ffn)?;
812        add(&after_attn, &mlp_out)
813    }
814
815    fn parameters(&self) -> Vec<&Parameter<T>> {
816        let mut out = Vec::new();
817        out.extend(self.layer_norm1.parameters());
818        out.extend(self.self_attn.parameters());
819        out.extend(self.layer_norm2.parameters());
820        out.extend(self.mlp.parameters());
821        out
822    }
823
824    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
825        let mut out = Vec::new();
826        out.extend(self.layer_norm1.parameters_mut());
827        out.extend(self.self_attn.parameters_mut());
828        out.extend(self.layer_norm2.parameters_mut());
829        out.extend(self.mlp.parameters_mut());
830        out
831    }
832
833    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
834        let mut out = Vec::new();
835        for (n, p) in self.layer_norm1.named_parameters() {
836            out.push((format!("layer_norm1.{n}"), p));
837        }
838        for (n, p) in self.self_attn.named_parameters() {
839            out.push((format!("self_attn.{n}"), p));
840        }
841        for (n, p) in self.layer_norm2.named_parameters() {
842            out.push((format!("layer_norm2.{n}"), p));
843        }
844        for (n, p) in self.mlp.named_parameters() {
845            out.push((format!("mlp.{n}"), p));
846        }
847        out
848    }
849
850    fn train(&mut self) {
851        self.training = true;
852    }
853
854    fn eval(&mut self) {
855        self.training = false;
856    }
857
858    fn is_training(&self) -> bool {
859        self.training
860    }
861
862    fn state_dict(&self) -> StateDict<T> {
863        self.named_parameters()
864            .into_iter()
865            .map(|(n, p)| (n, p.tensor().clone()))
866            .collect()
867    }
868
869    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
870        let extract = |prefix: &str| -> StateDict<T> {
871            let p = format!("{prefix}.");
872            state
873                .iter()
874                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
875                .collect()
876        };
877        if strict {
878            let prefixes = ["layer_norm1", "self_attn", "layer_norm2", "mlp"];
879            for k in state.keys() {
880                if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
881                    return Err(FerrotorchError::InvalidArgument {
882                        message: format!("unexpected key in ClipEncoderLayer state_dict: {k:?}"),
883                    });
884                }
885            }
886        }
887        self.layer_norm1
888            .load_state_dict(&extract("layer_norm1"), strict)?;
889        self.self_attn
890            .load_state_dict(&extract("self_attn"), strict)?;
891        self.layer_norm2
892            .load_state_dict(&extract("layer_norm2"), strict)?;
893        self.mlp.load_state_dict(&extract("mlp"), strict)?;
894        Ok(())
895    }
896}
897
898// ---------------------------------------------------------------------------
899// CLIPEncoder
900// ---------------------------------------------------------------------------
901
/// Stack of `num_hidden_layers` [`ClipEncoderLayer`]s applied in order.
///
/// Parameters are named `layers.{i}.<layer param>`, i.e. the HF
/// `encoder.layers.{i}.*` layout with the `encoder.` segment removed.
#[derive(Debug)]
pub struct ClipEncoder<T: Float> {
    /// One layer per `num_hidden_layers`.
    pub layers: Vec<ClipEncoderLayer<T>>,
    // Train/eval mode flag reported by `Module::is_training`; `new`
    // constructs the stack with `false` (eval mode).
    training: bool,
}
909
910impl<T: Float> ClipEncoder<T> {
911    /// Build a randomly-initialized encoder stack.
912    ///
913    /// # Errors
914    ///
915    /// Returns the underlying [`FerrotorchError`] on bad config dims.
916    pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
917        cfg.validate()?;
918        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
919        for _ in 0..cfg.num_hidden_layers {
920            layers.push(ClipEncoderLayer::new(cfg)?);
921        }
922        Ok(Self {
923            layers,
924            training: false,
925        })
926    }
927}
928
929impl<T: Float> Module<T> for ClipEncoder<T> {
930    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
931        let mut h = input.clone();
932        for l in &self.layers {
933            h = l.forward(&h)?;
934        }
935        Ok(h)
936    }
937
938    fn parameters(&self) -> Vec<&Parameter<T>> {
939        let mut out = Vec::new();
940        for l in &self.layers {
941            out.extend(l.parameters());
942        }
943        out
944    }
945
946    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
947        let mut out = Vec::new();
948        for l in &mut self.layers {
949            out.extend(l.parameters_mut());
950        }
951        out
952    }
953
954    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
955        let mut out = Vec::new();
956        for (i, l) in self.layers.iter().enumerate() {
957            for (n, p) in l.named_parameters() {
958                out.push((format!("layers.{i}.{n}"), p));
959            }
960        }
961        out
962    }
963
964    fn train(&mut self) {
965        self.training = true;
966        for l in &mut self.layers {
967            l.train();
968        }
969    }
970
971    fn eval(&mut self) {
972        self.training = false;
973        for l in &mut self.layers {
974            l.eval();
975        }
976    }
977
978    fn is_training(&self) -> bool {
979        self.training
980    }
981
982    fn state_dict(&self) -> StateDict<T> {
983        self.named_parameters()
984            .into_iter()
985            .map(|(n, p)| (n, p.tensor().clone()))
986            .collect()
987    }
988
989    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
990        let extract = |prefix: &str| -> StateDict<T> {
991            let p = format!("{prefix}.");
992            state
993                .iter()
994                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
995                .collect()
996        };
997        if strict {
998            for k in state.keys() {
999                if !k.starts_with("layers.") {
1000                    return Err(FerrotorchError::InvalidArgument {
1001                        message: format!("unexpected key in ClipEncoder state_dict: {k:?}"),
1002                    });
1003                }
1004            }
1005        }
1006        for (i, l) in self.layers.iter_mut().enumerate() {
1007            l.load_state_dict(&extract(&format!("layers.{i}")), strict)?;
1008        }
1009        Ok(())
1010    }
1011}
1012
1013// ---------------------------------------------------------------------------
1014// CLIPTextTransformer / ClipTextEncoder
1015// ---------------------------------------------------------------------------
1016
/// The full SD-1.5 CLIP text encoder. Wraps [`ClipTextEmbeddings`] +
/// [`ClipEncoder`] + a final [`LayerNorm`] (`final_layer_norm`).
///
/// Mirrors `CLIPTextTransformer` in transformers. The HF
/// `CLIPTextModel` wrapper sits one prefix above
/// (`text_model.embeddings.*`, `text_model.encoder.*`,
/// `text_model.final_layer_norm.*`) — see
/// [`crate::safetensors_loader::load_clip_text_encoder`] for the
/// `text_model.` strip.
///
/// Output is the per-token `last_hidden_state` `[B, S, hidden_size]`.
/// SD-1.5 consumes this directly as `encoder_hidden_states` for the
/// UNet's cross-attention (no pooling).
#[derive(Debug)]
pub struct ClipTextEncoder<T: Float> {
    /// Token + position embedding sum.
    pub embeddings: ClipTextEmbeddings<T>,
    /// 12 × [`ClipEncoderLayer`] for SD-1.5.
    pub encoder: ClipEncoder<T>,
    /// Final LayerNorm over the last hidden state.
    pub final_layer_norm: LayerNorm<T>,
    /// Frozen copy of the configuration used to build the module.
    pub config: ClipTextConfig,
    // Train/eval mode flag reported by `Module::is_training`; `new`
    // constructs the encoder with `false` (eval mode).
    training: bool,
}
1042
1043impl<T: Float> ClipTextEncoder<T> {
1044    /// Build a randomly-initialized text encoder.
1045    ///
1046    /// # Errors
1047    ///
1048    /// Returns the underlying [`FerrotorchError`] from any sub-module
1049    /// constructor.
1050    pub fn new(cfg: ClipTextConfig) -> FerrotorchResult<Self> {
1051        cfg.validate()?;
1052        let embeddings = ClipTextEmbeddings::new(&cfg)?;
1053        let encoder = ClipEncoder::new(&cfg)?;
1054        let final_layer_norm = LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?;
1055        Ok(Self {
1056            embeddings,
1057            encoder,
1058            final_layer_norm,
1059            config: cfg,
1060            training: false,
1061        })
1062    }
1063
1064    /// Run the encoder on a token-id sequence and return the per-token
1065    /// `last_hidden_state` `[1, S, hidden_size]`.
1066    ///
1067    /// `input_ids` is the verbatim CLIP-BPE token-id vector (length
1068    /// `S`). For SD-1.5 the canonical inference call is `S = 77`
1069    /// (already-padded with EOS to the max length).
1070    ///
1071    /// # Errors
1072    ///
1073    /// * [`FerrotorchError::InvalidArgument`] if `input_ids` is empty
1074    ///   or longer than `max_position_embeddings`.
1075    /// * Propagates downstream Embedding / LayerNorm errors.
1076    pub fn forward_from_ids(&self, input_ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
1077        let h = self.embeddings.forward_from_ids(input_ids)?;
1078        let h = self.encoder.forward(&h)?;
1079        self.final_layer_norm.forward(&h)
1080    }
1081
1082    /// Run the encoder on a pre-built float-encoded token-id tensor of
1083    /// shape `[S]`. Returns `[1, S, hidden_size]`.
1084    ///
1085    /// `ids` carries u32 token ids losslessly cast to `T`
1086    /// (`numeric_cast::cast::<u32, T>`). Mirrors what the dump example
1087    /// reads off disk.
1088    ///
1089    /// # Errors
1090    ///
1091    /// Propagates downstream lookup / LayerNorm errors and converts
1092    /// invalid (negative / NaN / overflow) ids to
1093    /// [`FerrotorchError::InvalidArgument`].
1094    pub fn forward_from_id_tensor(&self, ids: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1095        // Convert to u32 ids — same defensive cast the underlying
1096        // `Embedding::forward` does.
1097        if ids.ndim() != 1 {
1098            return Err(FerrotorchError::ShapeMismatch {
1099                message: format!(
1100                    "ClipTextEncoder::forward_from_id_tensor expects 1-D ids, got {:?}",
1101                    ids.shape()
1102                ),
1103            });
1104        }
1105        let data = ids.data_vec()?;
1106        let mut u32_ids: Vec<u32> = Vec::with_capacity(data.len());
1107        for (i, v) in data.iter().enumerate() {
1108            let f = num_traits::ToPrimitive::to_f64(v).ok_or_else(|| {
1109                FerrotorchError::InvalidArgument {
1110                    message: format!(
1111                        "ClipTextEncoder::forward_from_id_tensor: id at {i} \
1112                         not representable as f64"
1113                    ),
1114                }
1115            })?;
1116            if !f.is_finite() || f < 0.0 || f > u32::MAX as f64 || f.fract() != 0.0 {
1117                return Err(FerrotorchError::InvalidArgument {
1118                    message: format!(
1119                        "ClipTextEncoder::forward_from_id_tensor: id at {i} ({f}) \
1120                         is not a non-negative integer"
1121                    ),
1122                });
1123            }
1124            u32_ids.push(f as u32);
1125        }
1126        self.forward_from_ids(&u32_ids)
1127    }
1128
1129    /// Load a HuggingFace `CLIPTextModel` state dict into this module.
1130    ///
1131    /// Accepts both:
1132    ///   - bare-`text_model` layout (no prefix; what the pin script
1133    ///     normalises to).
1134    ///   - full `text_model.<rest>` prefix (what the upstream HF
1135    ///     checkpoint ships).
1136    ///
1137    /// The HF safetensors also ships a non-parameter buffer we
1138    /// explicitly drop:
1139    ///
1140    /// * `text_model.embeddings.position_ids` — a `[1, max_pos]`
1141    ///   `arange(max_pos)` buffer regenerated each forward pass. Recorded
1142    ///   in the [`crate::safetensors_loader::DropReport`].
1143    ///
1144    /// # Errors
1145    ///
1146    /// Forwards whatever each sub-module's `load_state_dict` returns
1147    /// (shape mismatch / strict-mode missing key). Strict mode will
1148    /// surface `text_model.embeddings.position_ids` and any unknown
1149    /// key as errors; callers with a full HF checkpoint must pass
1150    /// `strict=false`.
1151    pub fn load_hf_state_dict(
1152        &mut self,
1153        hf_state: &StateDict<T>,
1154        strict: bool,
1155    ) -> FerrotorchResult<crate::safetensors_loader::DropReport> {
1156        let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
1157        let mut dropped: Vec<String> = Vec::new();
1158        for (k, v) in hf_state {
1159            // Strip the optional `text_model.` prefix.
1160            let after = k
1161                .strip_prefix("text_model.")
1162                .map_or_else(|| k.clone(), str::to_owned);
1163
1164            // `embeddings.position_ids` is a buffer — not a parameter on our
1165            // side. Drop in both modes; record so the pin script can audit.
1166            if after == "embeddings.position_ids" {
1167                dropped.push(k.clone());
1168                continue;
1169            }
1170
1171            let is_known = after.starts_with("embeddings.token_embedding.")
1172                || after.starts_with("embeddings.position_embedding.")
1173                || after.starts_with("encoder.")
1174                || after.starts_with("final_layer_norm.");
1175            if is_known {
1176                remapped.insert(after, v.clone());
1177                continue;
1178            }
1179
1180            if strict {
1181                return Err(FerrotorchError::InvalidArgument {
1182                    message: format!(
1183                        "ClipTextEncoder::load_hf_state_dict: key {k:?} is not a \
1184                         known CLIP text-tower parameter and strict mode is on. \
1185                         Pass strict=false to drop unknown keys."
1186                    ),
1187                });
1188            }
1189            dropped.push(k.clone());
1190        }
1191        dropped.sort();
1192        self.load_state_dict(&remapped, strict)?;
1193        Ok(crate::safetensors_loader::DropReport { dropped })
1194    }
1195}
1196
1197impl<T: Float> Module<T> for ClipTextEncoder<T> {
1198    /// `Module::forward` treats `input` as already-summed embeddings
1199    /// `[1, S, hidden]` and only runs the encoder + final LayerNorm.
1200    /// Real callers should use [`Self::forward_from_ids`] /
1201    /// [`Self::forward_from_id_tensor`].
1202    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1203        let h = self.encoder.forward(input)?;
1204        self.final_layer_norm.forward(&h)
1205    }
1206
1207    fn parameters(&self) -> Vec<&Parameter<T>> {
1208        let mut out = Vec::new();
1209        out.extend(self.embeddings.parameters());
1210        out.extend(self.encoder.parameters());
1211        out.extend(self.final_layer_norm.parameters());
1212        out
1213    }
1214
1215    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1216        let mut out = Vec::new();
1217        out.extend(self.embeddings.parameters_mut());
1218        out.extend(self.encoder.parameters_mut());
1219        out.extend(self.final_layer_norm.parameters_mut());
1220        out
1221    }
1222
1223    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1224        let mut out = Vec::new();
1225        for (n, p) in self.embeddings.named_parameters() {
1226            out.push((format!("embeddings.{n}"), p));
1227        }
1228        for (n, p) in self.encoder.named_parameters() {
1229            out.push((format!("encoder.{n}"), p));
1230        }
1231        for (n, p) in self.final_layer_norm.named_parameters() {
1232            out.push((format!("final_layer_norm.{n}"), p));
1233        }
1234        out
1235    }
1236
1237    fn train(&mut self) {
1238        self.training = true;
1239        self.embeddings.train();
1240        self.encoder.train();
1241        self.final_layer_norm.train();
1242    }
1243
1244    fn eval(&mut self) {
1245        self.training = false;
1246        self.embeddings.eval();
1247        self.encoder.eval();
1248        self.final_layer_norm.eval();
1249    }
1250
1251    fn is_training(&self) -> bool {
1252        self.training
1253    }
1254
1255    fn state_dict(&self) -> StateDict<T> {
1256        self.named_parameters()
1257            .into_iter()
1258            .map(|(n, p)| (n, p.tensor().clone()))
1259            .collect()
1260    }
1261
1262    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
1263        let extract = |prefix: &str| -> StateDict<T> {
1264            let p = format!("{prefix}.");
1265            state
1266                .iter()
1267                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
1268                .collect()
1269        };
1270        if strict {
1271            for k in state.keys() {
1272                if !(k.starts_with("embeddings.")
1273                    || k.starts_with("encoder.")
1274                    || k.starts_with("final_layer_norm."))
1275                {
1276                    return Err(FerrotorchError::InvalidArgument {
1277                        message: format!("unexpected key in ClipTextEncoder state_dict: {k:?}"),
1278                    });
1279                }
1280            }
1281        }
1282        self.embeddings
1283            .load_state_dict(&extract("embeddings"), strict)?;
1284        self.encoder.load_state_dict(&extract("encoder"), strict)?;
1285        self.final_layer_norm
1286            .load_state_dict(&extract("final_layer_norm"), strict)?;
1287        Ok(())
1288    }
1289}
1290
1291// `mul` is re-exported above so downstream features (e.g. embedding
1292// scaling, never used for CLIP-ViT-L) can reach for it without a fresh
1293// import. Suppress the unused-import warning on a vanilla build.
1294#[allow(dead_code)]
1295fn _unused_mul_ref<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1296    mul(a, b)
1297}
1298
1299// ---------------------------------------------------------------------------
1300// Tests
1301// ---------------------------------------------------------------------------
1302
#[cfg(test)]
mod tests {
    use super::*;

    fn tiny_cfg() -> ClipTextConfig {
        // 2 heads × 4 dim/head = 8 hidden_size; 16 intermediate; 1 layer;
        // 6 positions; tiny vocab.
        ClipTextConfig {
            hidden_size: 8,
            intermediate_size: 16,
            num_attention_heads: 2,
            num_hidden_layers: 1,
            max_position_embeddings: 6,
            vocab_size: 32,
            layer_norm_eps: 1e-5,
        }
    }

    #[test]
    fn sd_v1_5_config_is_canonical() {
        let c = ClipTextConfig::sd_v1_5();
        assert_eq!(c.hidden_size, 768);
        assert_eq!(c.intermediate_size, 3072);
        assert_eq!(c.num_attention_heads, 12);
        assert_eq!(c.num_hidden_layers, 12);
        assert_eq!(c.max_position_embeddings, 77);
        assert_eq!(c.vocab_size, 49408);
        assert_eq!(c.head_dim(), 64);
        c.validate().unwrap();
    }

    #[test]
    fn validate_catches_bad_head_count() {
        let mut c = tiny_cfg();
        c.num_attention_heads = 3; // 8 % 3 != 0
        assert!(c.validate().is_err());
    }

    #[test]
    fn from_json_str_round_trip() {
        let json = r#"{
            "hidden_size": 768,
            "intermediate_size": 3072,
            "num_attention_heads": 12,
            "num_hidden_layers": 12,
            "max_position_embeddings": 77,
            "vocab_size": 49408,
            "layer_norm_eps": 1e-5,
            "hidden_act": "quick_gelu"
        }"#;
        let c = ClipTextConfig::from_json_str(json).unwrap();
        assert_eq!(c.hidden_size, 768);
        assert_eq!(c.intermediate_size, 3072);
        assert_eq!(c.num_attention_heads, 12);
        assert_eq!(c.num_hidden_layers, 12);
        assert_eq!(c.max_position_embeddings, 77);
    }

    #[test]
    fn embeddings_forward_shape() {
        let emb = ClipTextEmbeddings::<f32>::new(&tiny_cfg()).unwrap();
        let ids = [1u32, 5, 7, 9];
        let out = emb.forward_from_ids(&ids).unwrap();
        assert_eq!(out.shape(), &[1, 4, 8]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite(), "embedding non-finite: {v}");
        }
    }

    #[test]
    fn embeddings_reject_too_long_sequence() {
        let emb = ClipTextEmbeddings::<f32>::new(&tiny_cfg()).unwrap();
        let ids: Vec<u32> = (0..7).collect(); // > max_position 6
        assert!(emb.forward_from_ids(&ids).is_err());
    }

    #[test]
    fn self_attention_forward_shape() {
        let attn = ClipSelfAttention::<f32>::new(&tiny_cfg()).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.1f32; 5 * 8]),
            vec![1, 5, 8],
            false,
        )
        .unwrap();
        let out = attn.forward(&x).unwrap();
        assert_eq!(out.shape(), &[1, 5, 8]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn self_attention_is_actually_causal() {
        // Changing later tokens MUST NOT change earlier rows.
        // Build a [1, ROWS, DIM] tensor whose first FIXED rows are identical
        // across two runs while the remaining rows are perturbed. The first
        // FIXED output rows must agree between the runs up to f32 round-off.
        const ROWS: usize = 4;
        const DIM: usize = 8; // == tiny_cfg().hidden_size
        const FIXED: usize = 2;
        let attn = ClipSelfAttention::<f32>::new(&tiny_cfg()).unwrap();
        let mut a = vec![0.1f32; ROWS * DIM];
        for (i, v) in a.iter_mut().enumerate().take(FIXED * DIM) {
            *v = ((i + 1) as f32).sin();
        }
        let mut b = a.clone();
        // Perturb only the rows after the fixed prefix.
        for (i, v) in b.iter_mut().enumerate().skip(FIXED * DIM) {
            *v = ((i + 11) as f32).sin();
        }
        let xa = Tensor::from_storage(TensorStorage::cpu(a), vec![1, ROWS, DIM], false).unwrap();
        let xb = Tensor::from_storage(TensorStorage::cpu(b), vec![1, ROWS, DIM], false).unwrap();
        let oa = attn.forward(&xa).unwrap();
        let ob = attn.forward(&xb).unwrap();
        let da = oa.data().unwrap();
        let db = ob.data().unwrap();
        for (i, (x, y)) in da.iter().zip(db.iter()).enumerate().take(FIXED * DIM) {
            assert!(
                (x - y).abs() < 1e-5,
                "row {} col {} differs between runs: {} vs {}",
                i / DIM,
                i % DIM,
                x,
                y
            );
        }
    }

    #[test]
    fn mlp_uses_quick_gelu() {
        // With an all-zero input, fc1 reduces to its bias, so the MLP output
        // must equal fc2(quick_gelu(fc1.bias)), where
        // quick_gelu(x) = x * sigmoid(1.702 * x). Recompute that by hand from
        // the module's own state dict (HF Linear layout: weight is
        // [out_features, in_features]) and compare every output row. An
        // exact-erf GELU (or any other activation) gives visibly different
        // values for non-zero biases.
        let cfg = tiny_cfg();
        let (hidden, inter) = (cfg.hidden_size, cfg.intermediate_size);
        let mlp = ClipMlp::<f32>::new(&cfg).unwrap();
        let sd = mlp.state_dict();
        let b1 = sd["fc1.bias"].data().unwrap();
        let w2 = sd["fc2.weight"].data().unwrap();
        let b2 = sd["fc2.bias"].data().unwrap();
        assert_eq!(b1.len(), inter);
        assert_eq!(w2.len(), hidden * inter);
        assert_eq!(b2.len(), hidden);

        let quick_gelu = |x: f32| x / (1.0 + (-1.702 * x).exp());
        let act: Vec<f32> = b1.iter().map(|&x| quick_gelu(x)).collect();
        let expected: Vec<f32> = (0..hidden)
            .map(|j| {
                let row = &w2[j * inter..(j + 1) * inter];
                row.iter().zip(&act).map(|(w, a)| w * a).sum::<f32>() + b2[j]
            })
            .collect();

        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.0f32; 3 * hidden]),
            vec![1, 3, hidden],
            false,
        )
        .unwrap();
        let out = mlp.forward(&x).unwrap();
        assert_eq!(out.shape(), &[1, 3, hidden]);
        let got = out.data().unwrap();
        assert_eq!(got.len(), 3 * hidden);
        for (i, (&g, &e)) in got.iter().zip(expected.iter().cycle()).enumerate() {
            assert!(g.is_finite());
            assert!(
                (g - e).abs() < 1e-5,
                "mlp output [{}, {}] = {} but fc2(quick_gelu(fc1.bias)) = {}",
                i / hidden,
                i % hidden,
                g,
                e
            );
        }
    }

    #[test]
    fn encoder_layer_forward_shape() {
        let layer = ClipEncoderLayer::<f32>::new(&tiny_cfg()).unwrap();
        let x = Tensor::from_storage(
            TensorStorage::cpu(vec![0.1f32; 5 * 8]),
            vec![1, 5, 8],
            false,
        )
        .unwrap();
        let out = layer.forward(&x).unwrap();
        assert_eq!(out.shape(), &[1, 5, 8]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn encoder_layer_named_parameters_use_hf_layout() {
        let layer = ClipEncoderLayer::<f32>::new(&tiny_cfg()).unwrap();
        let names: Vec<String> = layer
            .named_parameters()
            .into_iter()
            .map(|(n, _)| n)
            .collect();
        for k in [
            "layer_norm1.weight",
            "layer_norm1.bias",
            "self_attn.q_proj.weight",
            "self_attn.q_proj.bias",
            "self_attn.k_proj.weight",
            "self_attn.v_proj.weight",
            "self_attn.out_proj.weight",
            "self_attn.out_proj.bias",
            "layer_norm2.weight",
            "mlp.fc1.weight",
            "mlp.fc1.bias",
            "mlp.fc2.weight",
            "mlp.fc2.bias",
        ] {
            assert!(
                names.iter().any(|n| n == k),
                "missing parameter key {k:?} in {names:?}"
            );
        }
    }

    #[test]
    fn tiny_encoder_forward_from_ids_shape() {
        let enc = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let ids = [1u32, 5, 7];
        let out = enc.forward_from_ids(&ids).unwrap();
        assert_eq!(out.shape(), &[1, 3, 8]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn tiny_named_parameters_use_hf_layout() {
        let enc = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let names: Vec<String> = enc.named_parameters().into_iter().map(|(n, _)| n).collect();
        for k in [
            "embeddings.token_embedding.weight",
            "embeddings.position_embedding.weight",
            "encoder.layers.0.layer_norm1.weight",
            "encoder.layers.0.self_attn.q_proj.weight",
            "encoder.layers.0.self_attn.out_proj.bias",
            "encoder.layers.0.layer_norm2.bias",
            "encoder.layers.0.mlp.fc1.weight",
            "encoder.layers.0.mlp.fc2.bias",
            "final_layer_norm.weight",
            "final_layer_norm.bias",
        ] {
            assert!(
                names.iter().any(|n| n == k),
                "missing parameter key {k:?} in {names:?}"
            );
        }
    }

    #[test]
    fn round_trip_state_dict() {
        let src = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let sd = src.state_dict();
        let mut dst = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        dst.load_state_dict(&sd, true).unwrap();
        let ids = [2u32, 4, 6];
        let a = src.forward_from_ids(&ids).unwrap();
        let b = dst.forward_from_ids(&ids).unwrap();
        // `zip` below would silently truncate on a length mismatch.
        assert_eq!(a.shape(), b.shape());
        for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
            assert!((x - y).abs() < 1e-5, "round-trip differs: {x} vs {y}");
        }
    }

    #[test]
    fn load_hf_state_dict_strips_text_model_prefix() {
        let src = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let bare = src.state_dict();
        let mut prefixed: StateDict<f32> = HashMap::new();
        for (k, v) in bare {
            prefixed.insert(format!("text_model.{k}"), v);
        }
        // Add the position_ids buffer — it should be dropped.
        prefixed.insert(
            "text_model.embeddings.position_ids".into(),
            ferrotorch_core::zeros::<f32>(&[1, 6]).unwrap(),
        );
        let mut dst = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let rep = dst.load_hf_state_dict(&prefixed, false).unwrap();
        assert_eq!(
            rep.dropped,
            vec!["text_model.embeddings.position_ids".to_string()]
        );
        let ids = [1u32, 2, 3];
        let a = src.forward_from_ids(&ids).unwrap();
        let b = dst.forward_from_ids(&ids).unwrap();
        assert_eq!(a.shape(), b.shape());
        for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
            assert!((x - y).abs() < 1e-5);
        }
    }

    #[test]
    fn load_hf_state_dict_strict_rejects_unknown_key() {
        let mut dst = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let mut sd: StateDict<f32> = HashMap::new();
        sd.insert(
            "mystery.key".into(),
            ferrotorch_core::zeros::<f32>(&[1]).unwrap(),
        );
        assert!(dst.load_hf_state_dict(&sd, true).is_err());
    }

    #[test]
    fn forward_from_id_tensor_matches_forward_from_ids() {
        let enc = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let ids = vec![1u32, 5, 7];
        let id_tensor = float_index_tensor::<f32>(&ids).unwrap();
        let a = enc.forward_from_ids(&ids).unwrap();
        let b = enc.forward_from_id_tensor(&id_tensor).unwrap();
        assert_eq!(a.shape(), b.shape());
        for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
            assert!((x - y).abs() < 1e-5);
        }
    }
}