Skip to main content

ferrotorch_diffusion/
clip_text_encoder.rs

1//! Stable-Diffusion 1.5 CLIP text encoder
2//! (`openai/clip-vit-large-patch14` — the text tower of CLIP-ViT-L/14).
3//!
4//! Phase B.3c of real-artifact-driven development — third and final SD
5//! sub-model. Together with the VAE decoder (Phase B.3a, #1150) and the
6//! UNet (Phase B.3b, #1151) this completes the bit-perfect SD-1.5
7//! inference pipeline:
8//!
9//! ```text
10//! CLIP text encoder  →  UNet (denoise)  →  VAE decoder
11//! [B, S=77, 768]        [B, 4, 64, 64]    [B, 3, 512, 512]
12//! ```
13//!
14//! Mirrors `transformers.CLIPTextModel` for the SD-1.5 config exactly:
15//!
16//! ```text
17//! hidden_size        = 768
18//! intermediate_size  = 3072
19//! num_attention_heads = 12
20//! num_hidden_layers  = 12
21//! max_position_embeddings = 77
22//! vocab_size         = 49408
23//! hidden_act         = "quick_gelu"     # x * sigmoid(1.702 * x)
24//! layer_norm_eps     = 1e-5
25//! ```
26//!
27//! Architecture (state-dict prefix in parens):
28//!
29//! ```text
30//! CLIPTextModel
31//! └── text_model
32//!     ├── embeddings
33//!     │   ├── token_embedding.weight    [49408, 768]
34//!     │   └── position_embedding.weight [77, 768]
35//!     ├── encoder
36//!     │   └── layers.{0..11}.
37//!     │       ├── layer_norm1.{weight,bias}    [768], [768]
38//!     │       ├── self_attn.
39//!     │       │   ├── q_proj.{weight,bias}    [768,768], [768]
40//!     │       │   ├── k_proj.{weight,bias}    [768,768], [768]
41//!     │       │   ├── v_proj.{weight,bias}    [768,768], [768]
42//!     │       │   └── out_proj.{weight,bias}  [768,768], [768]
43//!     │       ├── layer_norm2.{weight,bias}    [768], [768]
44//!     │       └── mlp.
45//!     │           ├── fc1.{weight,bias}        [3072,768], [3072]
46//!     │           └── fc2.{weight,bias}        [768,3072], [768]
47//!     └── final_layer_norm.{weight,bias}        [768], [768]
48//! ```
49//!
50//! Forward pass (per layer is pre-LayerNorm + residual):
51//!
52//! ```text
53//! h = token_embedding(input_ids) + position_embedding(arange(S))
54//! for layer in encoder.layers:
55//!     residual = h
56//!     h = layer_norm1(h)
57//!     h = causal_self_attn(h, h, h)            # ← causal mask is critical
58//!     h = residual + h
59//!     residual = h
60//!     h = layer_norm2(h)
61//!     h = fc2(quick_gelu(fc1(h)))
62//!     h = residual + h
63//! h = final_layer_norm(h)
64//! return h                                      # last_hidden_state [B, S, 768]
65//! ```
66//!
67//! ## Critical correctness gotchas
68//!
69//! 1. **Causal mask**. Despite the "encoder" name, CLIP text-side
70//!    self-attention is causal — position `i` attends only to
71//!    `0..=i`. Omitting this would still pass shape checks but break
72//!    parity vs `transformers`.
73//! 2. **QuickGELU**, not standard GELU. CLIP-ViT-L/14 uses the fast
74//!    sigmoid approximation `x * sigmoid(1.702 * x)`, not the erf-based
75//!    or tanh-based GELU. We pin this via
76//!    `GELU::with_approximate(GeluApproximate::Sigmoid)`.
77//! 3. **Position embedding is *learned* and absolute** — full table of
78//!    77 entries. Position ids are `[0, 1, ..., S-1]` for every forward.
79//! 4. **All four self-attention projections have bias** (unlike SD's
80//!    UNet `Attention` which has `bias=False` on q/k/v).
81//! 5. **SD uses `last_hidden_state` directly**, not the EOS-pooled
82//!    output. We return `[B, S, hidden_size]`.
83
84use std::collections::HashMap;
85
86use ferrotorch_core::grad_fns::arithmetic::{add, mul};
87use ferrotorch_core::{
88    FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage, numeric_cast,
89};
90use ferrotorch_nn::module::{Module, StateDict};
91use ferrotorch_nn::parameter::Parameter;
92use ferrotorch_nn::{
93    Embedding, GELU, GeluApproximate, LayerNorm, Linear, reshape_to_heads, standard_attention,
94    transpose_heads_to_2d,
95};
96
97// ---------------------------------------------------------------------------
98// Config
99// ---------------------------------------------------------------------------
100
/// Configuration for the SD-1.5 CLIP text encoder
/// (`runwayml/stable-diffusion-v1-5/text_encoder/config.json`).
///
/// Build one with [`ClipTextConfig::sd_v1_5`] (also the [`Default`]), or
/// parse a `config.json` with [`ClipTextConfig::from_json_str`] /
/// [`ClipTextConfig::from_file`]. `PartialEq` is derived so a parsed config
/// can be compared directly against the reference defaults.
#[derive(Debug, Clone, PartialEq)]
pub struct ClipTextConfig {
    /// Hidden width. SD-1.5: 768.
    pub hidden_size: usize,
    /// FFN expansion width. SD-1.5: 3072.
    pub intermediate_size: usize,
    /// Number of attention heads per layer. SD-1.5: 12. Must divide
    /// `hidden_size` evenly.
    pub num_attention_heads: usize,
    /// Number of transformer layers. SD-1.5: 12.
    pub num_hidden_layers: usize,
    /// Maximum sequence length (rows of the learned position table).
    /// SD-1.5: 77.
    pub max_position_embeddings: usize,
    /// Token vocabulary size (rows of the token table). SD-1.5: 49408.
    pub vocab_size: usize,
    /// LayerNorm epsilon; must be finite and > 0. SD-1.5: 1e-5.
    pub layer_norm_eps: f64,
}
121
122impl Default for ClipTextConfig {
123    fn default() -> Self {
124        Self::sd_v1_5()
125    }
126}
127
128impl ClipTextConfig {
129    /// SD-1.5 CLIP text encoder defaults (CLIP-ViT-L/14 text tower).
130    pub fn sd_v1_5() -> Self {
131        Self {
132            hidden_size: 768,
133            intermediate_size: 3072,
134            num_attention_heads: 12,
135            num_hidden_layers: 12,
136            max_position_embeddings: 77,
137            vocab_size: 49408,
138            layer_norm_eps: 1e-5,
139        }
140    }
141
142    /// Per-head dimension.
143    #[inline]
144    #[must_use]
145    pub fn head_dim(&self) -> usize {
146        self.hidden_size / self.num_attention_heads
147    }
148
149    /// Validate field bounds.
150    ///
151    /// # Errors
152    ///
153    /// Returns [`FerrotorchError::InvalidArgument`] for any out-of-bounds
154    /// or arithmetic-incompatible field.
155    pub fn validate(&self) -> FerrotorchResult<()> {
156        if self.hidden_size == 0
157            || self.intermediate_size == 0
158            || self.num_attention_heads == 0
159            || self.num_hidden_layers == 0
160            || self.max_position_embeddings == 0
161            || self.vocab_size == 0
162        {
163            return Err(FerrotorchError::InvalidArgument {
164                message: "ClipTextConfig: all size fields must be > 0".into(),
165            });
166        }
167        if self.hidden_size % self.num_attention_heads != 0 {
168            return Err(FerrotorchError::InvalidArgument {
169                message: format!(
170                    "ClipTextConfig: hidden_size {} not divisible by num_attention_heads {}",
171                    self.hidden_size, self.num_attention_heads,
172                ),
173            });
174        }
175        if !self.layer_norm_eps.is_finite() || self.layer_norm_eps <= 0.0 {
176            return Err(FerrotorchError::InvalidArgument {
177                message: format!(
178                    "ClipTextConfig: layer_norm_eps must be finite and > 0, got {}",
179                    self.layer_norm_eps,
180                ),
181            });
182        }
183        Ok(())
184    }
185
186    /// Parse a `text_encoder/config.json` document into a [`ClipTextConfig`].
187    ///
188    /// Recognised keys (all optional — anything missing falls back to the
189    /// SD-1.5 defaults): `hidden_size`, `intermediate_size`,
190    /// `num_attention_heads`, `num_hidden_layers`,
191    /// `max_position_embeddings`, `vocab_size`, `layer_norm_eps`.
192    ///
193    /// # Errors
194    ///
195    /// Returns [`FerrotorchError::InvalidArgument`] on malformed JSON or
196    /// invalid field values.
197    pub fn from_json_str(s: &str) -> FerrotorchResult<Self> {
198        let v: serde_json::Value =
199            serde_json::from_str(s).map_err(|e| FerrotorchError::InvalidArgument {
200                message: format!("ClipTextConfig::from_json_str: bad JSON: {e}"),
201            })?;
202        let mut cfg = Self::default();
203        if let Some(x) = v.get("hidden_size").and_then(serde_json::Value::as_u64) {
204            cfg.hidden_size = x as usize;
205        }
206        if let Some(x) = v.get("intermediate_size").and_then(serde_json::Value::as_u64) {
207            cfg.intermediate_size = x as usize;
208        }
209        if let Some(x) = v.get("num_attention_heads").and_then(serde_json::Value::as_u64) {
210            cfg.num_attention_heads = x as usize;
211        }
212        if let Some(x) = v.get("num_hidden_layers").and_then(serde_json::Value::as_u64) {
213            cfg.num_hidden_layers = x as usize;
214        }
215        if let Some(x) = v
216            .get("max_position_embeddings")
217            .and_then(serde_json::Value::as_u64)
218        {
219            cfg.max_position_embeddings = x as usize;
220        }
221        if let Some(x) = v.get("vocab_size").and_then(serde_json::Value::as_u64) {
222            cfg.vocab_size = x as usize;
223        }
224        if let Some(x) = v.get("layer_norm_eps").and_then(serde_json::Value::as_f64) {
225            cfg.layer_norm_eps = x;
226        }
227        cfg.validate()?;
228        Ok(cfg)
229    }
230
231    /// Parse a `text_encoder/config.json` file from disk.
232    ///
233    /// # Errors
234    ///
235    /// Returns [`FerrotorchError::InvalidArgument`] for I/O or parse
236    /// failures.
237    pub fn from_file(path: &std::path::Path) -> FerrotorchResult<Self> {
238        let s = std::fs::read_to_string(path).map_err(|e| FerrotorchError::InvalidArgument {
239            message: format!(
240                "ClipTextConfig::from_file: failed to read {}: {e}",
241                path.display(),
242            ),
243        })?;
244        Self::from_json_str(&s)
245    }
246}
247
248// ---------------------------------------------------------------------------
249// Helper: reshape utilities (own-the-data so the per-row buffers are
250// always contiguous and ready for the BMM-style attention call).
251// ---------------------------------------------------------------------------
252
253fn reshape_owned<T: Float>(t: &Tensor<T>, shape: Vec<usize>) -> FerrotorchResult<Tensor<T>> {
254    let prod: usize = shape.iter().product();
255    if prod != t.numel() {
256        return Err(FerrotorchError::ShapeMismatch {
257            message: format!(
258                "ClipTextEncoder reshape: target {shape:?} (= {prod} elements) does not \
259                 match source numel {}",
260                t.numel()
261            ),
262        });
263    }
264    let data = t.data_vec()?;
265    Tensor::from_storage(TensorStorage::cpu(data), shape, t.requires_grad())
266}
267
268/// Build a 1-D float-encoded index tensor from u32 ids (matches the
269/// trick `BertEmbeddings::float_index_tensor` uses).
270fn float_index_tensor<T: Float>(ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
271    let data: Vec<T> = ids
272        .iter()
273        .map(|&i| numeric_cast::cast::<u32, T>(i))
274        .collect::<FerrotorchResult<Vec<T>>>()?;
275    let n = data.len();
276    Tensor::from_storage(TensorStorage::cpu(data), vec![n], false)
277}
278
279// ---------------------------------------------------------------------------
280// CLIPTextEmbeddings
281// ---------------------------------------------------------------------------
282
283/// Token embedding + learned absolute position embedding. The two
284/// lookups are summed. Mirrors `CLIPTextEmbeddings` in transformers.
285///
286/// Note: there is NO LayerNorm at the embedding level (unlike BERT).
287/// The first per-layer `layer_norm1` handles normalisation downstream.
288#[derive(Debug)]
289pub struct ClipTextEmbeddings<T: Float> {
290    /// Token lookup — `[vocab_size, hidden_size]`.
291    pub token_embedding: Embedding<T>,
292    /// Learned position lookup — `[max_position_embeddings, hidden_size]`.
293    pub position_embedding: Embedding<T>,
294    hidden_size: usize,
295    max_position_embeddings: usize,
296    training: bool,
297}
298
299impl<T: Float> ClipTextEmbeddings<T> {
300    /// Build randomly-initialized embeddings for the given config.
301    ///
302    /// # Errors
303    ///
304    /// Returns [`FerrotorchError`] from the underlying [`Embedding`]
305    /// constructors.
306    pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
307        cfg.validate()?;
308        Ok(Self {
309            token_embedding: Embedding::new(cfg.vocab_size, cfg.hidden_size, None)?,
310            position_embedding: Embedding::new(cfg.max_position_embeddings, cfg.hidden_size, None)?,
311            hidden_size: cfg.hidden_size,
312            max_position_embeddings: cfg.max_position_embeddings,
313            training: false,
314        })
315    }
316
317    /// Run the embedding sum on a sequence of token ids.
318    ///
319    /// `input_ids` is the verbatim token-id vector (length `S`). The
320    /// output has shape `[1, S, hidden]`.
321    ///
322    /// # Errors
323    ///
324    /// * [`FerrotorchError::InvalidArgument`] if `input_ids` is empty or
325    ///   exceeds `max_position_embeddings`.
326    /// * Propagates downstream embedding-lookup errors.
327    pub fn forward_from_ids(&self, input_ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
328        if input_ids.is_empty() {
329            return Err(FerrotorchError::InvalidArgument {
330                message: "ClipTextEmbeddings::forward_from_ids needs at least one token".into(),
331            });
332        }
333        let seq_len = input_ids.len();
334        if seq_len > self.max_position_embeddings {
335            return Err(FerrotorchError::InvalidArgument {
336                message: format!(
337                    "ClipTextEmbeddings: sequence length {seq_len} exceeds \
338                     max_position_embeddings {}",
339                    self.max_position_embeddings,
340                ),
341            });
342        }
343
344        let word_idx = float_index_tensor::<T>(input_ids)?;
345        let word_2d = self.token_embedding.forward(&word_idx)?; // [S, hidden]
346
347        let pos_ids: Vec<u32> = (0..seq_len as u32).collect();
348        let pos_idx = float_index_tensor::<T>(&pos_ids)?;
349        let pos_2d = self.position_embedding.forward(&pos_idx)?; // [S, hidden]
350
351        let summed = add(&word_2d, &pos_2d)?;
352        // Promote to [1, S, hidden] so downstream uses 3-D ranks.
353        reshape_owned(&summed, vec![1, seq_len, self.hidden_size])
354    }
355}
356
357impl<T: Float> Module<T> for ClipTextEmbeddings<T> {
358    /// When called via `Module::forward` we treat `input` as a
359    /// 1-D float-index tensor (same convention as the inner
360    /// `Embedding` modules). Real callers should use
361    /// [`Self::forward_from_ids`].
362    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
363        let word_2d = self.token_embedding.forward(input)?;
364        let seq_len = input.numel();
365        let pos_ids: Vec<u32> = (0..seq_len as u32).collect();
366        let pos_idx = float_index_tensor::<T>(&pos_ids)?;
367        let pos_2d = self.position_embedding.forward(&pos_idx)?;
368        let summed = add(&word_2d, &pos_2d)?;
369        reshape_owned(&summed, vec![1, seq_len, self.hidden_size])
370    }
371
372    fn parameters(&self) -> Vec<&Parameter<T>> {
373        let mut out = Vec::new();
374        out.extend(self.token_embedding.parameters());
375        out.extend(self.position_embedding.parameters());
376        out
377    }
378
379    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
380        let mut out = Vec::new();
381        out.extend(self.token_embedding.parameters_mut());
382        out.extend(self.position_embedding.parameters_mut());
383        out
384    }
385
386    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
387        let mut out = Vec::new();
388        for (n, p) in self.token_embedding.named_parameters() {
389            out.push((format!("token_embedding.{n}"), p));
390        }
391        for (n, p) in self.position_embedding.named_parameters() {
392            out.push((format!("position_embedding.{n}"), p));
393        }
394        out
395    }
396
397    fn train(&mut self) {
398        self.training = true;
399    }
400
401    fn eval(&mut self) {
402        self.training = false;
403    }
404
405    fn is_training(&self) -> bool {
406        self.training
407    }
408
409    fn state_dict(&self) -> StateDict<T> {
410        self.named_parameters()
411            .into_iter()
412            .map(|(n, p)| (n, p.tensor().clone()))
413            .collect()
414    }
415
416    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
417        let extract = |prefix: &str| -> StateDict<T> {
418            let p = format!("{prefix}.");
419            state
420                .iter()
421                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
422                .collect()
423        };
424        if strict {
425            let prefixes = ["token_embedding", "position_embedding"];
426            for k in state.keys() {
427                if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
428                    return Err(FerrotorchError::InvalidArgument {
429                        message: format!("unexpected key in ClipTextEmbeddings state_dict: {k:?}"),
430                    });
431                }
432            }
433        }
434        self.token_embedding
435            .load_state_dict(&extract("token_embedding"), strict)?;
436        self.position_embedding
437            .load_state_dict(&extract("position_embedding"), strict)?;
438        Ok(())
439    }
440}
441
442// ---------------------------------------------------------------------------
443// CLIP self-attention
444// ---------------------------------------------------------------------------
445
446/// Multi-head self-attention with a *causal* mask. Mirrors
447/// `CLIPAttention` (the text-side variant) in transformers — all four
448/// projections (q/k/v/out) carry bias, the scale is `1/sqrt(head_dim)`,
449/// and the attention is causal (each position attends to itself and
450/// earlier positions only).
451///
452/// State-dict layout:
453///
454/// ```text
455/// q_proj.{weight,bias}    [hidden, hidden], [hidden]
456/// k_proj.{weight,bias}    [hidden, hidden], [hidden]
457/// v_proj.{weight,bias}    [hidden, hidden], [hidden]
458/// out_proj.{weight,bias}  [hidden, hidden], [hidden]
459/// ```
460#[derive(Debug)]
461pub struct ClipSelfAttention<T: Float> {
462    /// Query projection — `[hidden, hidden]`, with bias.
463    pub q_proj: Linear<T>,
464    /// Key projection — `[hidden, hidden]`, with bias.
465    pub k_proj: Linear<T>,
466    /// Value projection — `[hidden, hidden]`, with bias.
467    pub v_proj: Linear<T>,
468    /// Output projection — `[hidden, hidden]`, with bias.
469    pub out_proj: Linear<T>,
470    num_heads: usize,
471    head_dim: usize,
472    hidden: usize,
473    training: bool,
474}
475
476impl<T: Float> ClipSelfAttention<T> {
477    /// Build randomly-initialized self-attention projections.
478    ///
479    /// # Errors
480    ///
481    /// Returns the underlying [`FerrotorchError`] on bad config dims.
482    pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
483        cfg.validate()?;
484        Ok(Self {
485            q_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
486            k_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
487            v_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
488            out_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
489            num_heads: cfg.num_attention_heads,
490            head_dim: cfg.head_dim(),
491            hidden: cfg.hidden_size,
492            training: false,
493        })
494    }
495}
496
497impl<T: Float> Module<T> for ClipSelfAttention<T> {
498    /// Forward — input `[1, S, hidden]`, output `[1, S, hidden]`.
499    ///
500    /// The attention is causal: position `i` cannot attend to position
501    /// `j > i`. This matches `transformers.CLIPAttention`'s use of
502    /// `_create_4d_causal_attention_mask`.
503    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
504        let shape = input.shape();
505        if shape.len() != 3 || shape[0] != 1 || shape[2] != self.hidden {
506            return Err(FerrotorchError::ShapeMismatch {
507                message: format!(
508                    "ClipSelfAttention expects [1, S, {}], got {:?}",
509                    self.hidden, shape,
510                ),
511            });
512        }
513        let seq_len = shape[1];
514
515        // Projections — Linear handles any rank; output is [1, S, hidden].
516        let q = self.q_proj.forward(input)?;
517        let k = self.k_proj.forward(input)?;
518        let v = self.v_proj.forward(input)?;
519
520        // Drop the batch=1 leading dim so reshape_to_heads / the
521        // attention helper can treat the rows as [S, H*d].
522        let q2 = reshape_owned(&q, vec![seq_len, self.hidden])?;
523        let k2 = reshape_owned(&k, vec![seq_len, self.hidden])?;
524        let v2 = reshape_owned(&v, vec![seq_len, self.hidden])?;
525
526        // [S, H*d] → [H, S, d] (batch-first heads).
527        let q_h = reshape_to_heads(&q2, self.num_heads, seq_len, self.head_dim)?;
528        let k_h = reshape_to_heads(&k2, self.num_heads, seq_len, self.head_dim)?;
529        let v_h = reshape_to_heads(&v2, self.num_heads, seq_len, self.head_dim)?;
530
531        // Scaled dot-product attention with causal mask. `standard_attention`
532        // applies `1/sqrt(head_dim)` scaling and `-inf` upper-triangular
533        // mask, then softmax + value mix.
534        let ctx = standard_attention(&q_h, &k_h, &v_h, /* causal = */ true)?;
535
536        // [H, S, d] → [S, H*d] → [1, S, hidden].
537        let ctx2 = transpose_heads_to_2d(&ctx, self.num_heads, seq_len, self.head_dim)?;
538        let ctx3 = reshape_owned(&ctx2, vec![1, seq_len, self.hidden])?;
539
540        // Output projection (with bias).
541        self.out_proj.forward(&ctx3)
542    }
543
544    fn parameters(&self) -> Vec<&Parameter<T>> {
545        let mut out = Vec::new();
546        out.extend(self.q_proj.parameters());
547        out.extend(self.k_proj.parameters());
548        out.extend(self.v_proj.parameters());
549        out.extend(self.out_proj.parameters());
550        out
551    }
552
553    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
554        let mut out = Vec::new();
555        out.extend(self.q_proj.parameters_mut());
556        out.extend(self.k_proj.parameters_mut());
557        out.extend(self.v_proj.parameters_mut());
558        out.extend(self.out_proj.parameters_mut());
559        out
560    }
561
562    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
563        let mut out = Vec::new();
564        for (n, p) in self.q_proj.named_parameters() {
565            out.push((format!("q_proj.{n}"), p));
566        }
567        for (n, p) in self.k_proj.named_parameters() {
568            out.push((format!("k_proj.{n}"), p));
569        }
570        for (n, p) in self.v_proj.named_parameters() {
571            out.push((format!("v_proj.{n}"), p));
572        }
573        for (n, p) in self.out_proj.named_parameters() {
574            out.push((format!("out_proj.{n}"), p));
575        }
576        out
577    }
578
579    fn train(&mut self) {
580        self.training = true;
581    }
582
583    fn eval(&mut self) {
584        self.training = false;
585    }
586
587    fn is_training(&self) -> bool {
588        self.training
589    }
590
591    fn state_dict(&self) -> StateDict<T> {
592        self.named_parameters()
593            .into_iter()
594            .map(|(n, p)| (n, p.tensor().clone()))
595            .collect()
596    }
597
598    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
599        let extract = |prefix: &str| -> StateDict<T> {
600            let p = format!("{prefix}.");
601            state
602                .iter()
603                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
604                .collect()
605        };
606        if strict {
607            let prefixes = ["q_proj", "k_proj", "v_proj", "out_proj"];
608            for k in state.keys() {
609                if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
610                    return Err(FerrotorchError::InvalidArgument {
611                        message: format!("unexpected key in ClipSelfAttention state_dict: {k:?}"),
612                    });
613                }
614            }
615        }
616        self.q_proj.load_state_dict(&extract("q_proj"), strict)?;
617        self.k_proj.load_state_dict(&extract("k_proj"), strict)?;
618        self.v_proj.load_state_dict(&extract("v_proj"), strict)?;
619        self.out_proj
620            .load_state_dict(&extract("out_proj"), strict)?;
621        Ok(())
622    }
623}
624
625// ---------------------------------------------------------------------------
626// CLIPMLP
627// ---------------------------------------------------------------------------
628
629/// CLIP MLP: `fc2(quick_gelu(fc1(x)))`.
630///
631/// QuickGELU is `x * sigmoid(1.702 * x)` — the fast sigmoid
632/// approximation. NOT the standard `0.5 * x * (1 + erf(x/sqrt(2)))`
633/// kernel. This is the published `hidden_act = "quick_gelu"` in
634/// CLIP-ViT-L/14's config.
635#[derive(Debug)]
636pub struct ClipMlp<T: Float> {
637    /// Expansion projection — `[hidden, intermediate]`, with bias.
638    pub fc1: Linear<T>,
639    /// Reduction projection — `[intermediate, hidden]`, with bias.
640    pub fc2: Linear<T>,
641    activation: GELU,
642    training: bool,
643}
644
645impl<T: Float> ClipMlp<T> {
646    /// Build randomly-initialized MLP for the given config.
647    ///
648    /// # Errors
649    ///
650    /// Returns the underlying [`FerrotorchError`] on bad config dims.
651    pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
652        cfg.validate()?;
653        Ok(Self {
654            fc1: Linear::new(cfg.hidden_size, cfg.intermediate_size, true)?,
655            fc2: Linear::new(cfg.intermediate_size, cfg.hidden_size, true)?,
656            // QuickGELU: `x * sigmoid(1.702 * x)`. See module doc.
657            activation: GELU::with_approximate(GeluApproximate::Sigmoid),
658            training: false,
659        })
660    }
661}
662
663impl<T: Float> Module<T> for ClipMlp<T> {
664    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
665        let h = self.fc1.forward(input)?;
666        let h = self.activation.forward(&h)?;
667        self.fc2.forward(&h)
668    }
669
670    fn parameters(&self) -> Vec<&Parameter<T>> {
671        let mut out = Vec::new();
672        out.extend(self.fc1.parameters());
673        out.extend(self.fc2.parameters());
674        out
675    }
676
677    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
678        let mut out = Vec::new();
679        out.extend(self.fc1.parameters_mut());
680        out.extend(self.fc2.parameters_mut());
681        out
682    }
683
684    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
685        let mut out = Vec::new();
686        for (n, p) in self.fc1.named_parameters() {
687            out.push((format!("fc1.{n}"), p));
688        }
689        for (n, p) in self.fc2.named_parameters() {
690            out.push((format!("fc2.{n}"), p));
691        }
692        out
693    }
694
695    fn train(&mut self) {
696        self.training = true;
697    }
698
699    fn eval(&mut self) {
700        self.training = false;
701    }
702
703    fn is_training(&self) -> bool {
704        self.training
705    }
706
707    fn state_dict(&self) -> StateDict<T> {
708        self.named_parameters()
709            .into_iter()
710            .map(|(n, p)| (n, p.tensor().clone()))
711            .collect()
712    }
713
714    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
715        let extract = |prefix: &str| -> StateDict<T> {
716            let p = format!("{prefix}.");
717            state
718                .iter()
719                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
720                .collect()
721        };
722        if strict {
723            for k in state.keys() {
724                if !(k.starts_with("fc1.") || k.starts_with("fc2.")) {
725                    return Err(FerrotorchError::InvalidArgument {
726                        message: format!("unexpected key in ClipMlp state_dict: {k:?}"),
727                    });
728                }
729            }
730        }
731        self.fc1.load_state_dict(&extract("fc1"), strict)?;
732        self.fc2.load_state_dict(&extract("fc2"), strict)?;
733        Ok(())
734    }
735}
736
737// ---------------------------------------------------------------------------
738// CLIPEncoderLayer
739// ---------------------------------------------------------------------------
740
741/// One CLIP text encoder layer.
742///
743/// Pre-LayerNorm + residual stack:
744///
745/// ```text
746/// h = x + self_attn(layer_norm1(x))
747/// h = h + mlp(layer_norm2(h))
748/// ```
749#[derive(Debug)]
750pub struct ClipEncoderLayer<T: Float> {
751    /// Pre-attention LayerNorm.
752    pub layer_norm1: LayerNorm<T>,
753    /// Causal self-attention (q/k/v/out, all biased).
754    pub self_attn: ClipSelfAttention<T>,
755    /// Pre-FFN LayerNorm.
756    pub layer_norm2: LayerNorm<T>,
757    /// Two-layer MLP with QuickGELU activation.
758    pub mlp: ClipMlp<T>,
759    training: bool,
760}
761
762impl<T: Float> ClipEncoderLayer<T> {
763    /// Build a randomly-initialized encoder layer.
764    ///
765    /// # Errors
766    ///
767    /// Returns the underlying [`FerrotorchError`] on bad config dims.
768    pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
769        Ok(Self {
770            layer_norm1: LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?,
771            self_attn: ClipSelfAttention::new(cfg)?,
772            layer_norm2: LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?,
773            mlp: ClipMlp::new(cfg)?,
774            training: false,
775        })
776    }
777}
778
779impl<T: Float> Module<T> for ClipEncoderLayer<T> {
780    /// Forward — `[1, S, hidden]` → `[1, S, hidden]`.
781    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
782        // Self-attention sub-block (pre-norm).
783        let normed = self.layer_norm1.forward(input)?;
784        let attn_out = self.self_attn.forward(&normed)?;
785        let after_attn = add(input, &attn_out)?;
786
787        // MLP sub-block (pre-norm).
788        let normed_ffn = self.layer_norm2.forward(&after_attn)?;
789        let mlp_out = self.mlp.forward(&normed_ffn)?;
790        add(&after_attn, &mlp_out)
791    }
792
793    fn parameters(&self) -> Vec<&Parameter<T>> {
794        let mut out = Vec::new();
795        out.extend(self.layer_norm1.parameters());
796        out.extend(self.self_attn.parameters());
797        out.extend(self.layer_norm2.parameters());
798        out.extend(self.mlp.parameters());
799        out
800    }
801
802    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
803        let mut out = Vec::new();
804        out.extend(self.layer_norm1.parameters_mut());
805        out.extend(self.self_attn.parameters_mut());
806        out.extend(self.layer_norm2.parameters_mut());
807        out.extend(self.mlp.parameters_mut());
808        out
809    }
810
811    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
812        let mut out = Vec::new();
813        for (n, p) in self.layer_norm1.named_parameters() {
814            out.push((format!("layer_norm1.{n}"), p));
815        }
816        for (n, p) in self.self_attn.named_parameters() {
817            out.push((format!("self_attn.{n}"), p));
818        }
819        for (n, p) in self.layer_norm2.named_parameters() {
820            out.push((format!("layer_norm2.{n}"), p));
821        }
822        for (n, p) in self.mlp.named_parameters() {
823            out.push((format!("mlp.{n}"), p));
824        }
825        out
826    }
827
828    fn train(&mut self) {
829        self.training = true;
830    }
831
832    fn eval(&mut self) {
833        self.training = false;
834    }
835
836    fn is_training(&self) -> bool {
837        self.training
838    }
839
840    fn state_dict(&self) -> StateDict<T> {
841        self.named_parameters()
842            .into_iter()
843            .map(|(n, p)| (n, p.tensor().clone()))
844            .collect()
845    }
846
847    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
848        let extract = |prefix: &str| -> StateDict<T> {
849            let p = format!("{prefix}.");
850            state
851                .iter()
852                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
853                .collect()
854        };
855        if strict {
856            let prefixes = ["layer_norm1", "self_attn", "layer_norm2", "mlp"];
857            for k in state.keys() {
858                if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
859                    return Err(FerrotorchError::InvalidArgument {
860                        message: format!("unexpected key in ClipEncoderLayer state_dict: {k:?}"),
861                    });
862                }
863            }
864        }
865        self.layer_norm1
866            .load_state_dict(&extract("layer_norm1"), strict)?;
867        self.self_attn
868            .load_state_dict(&extract("self_attn"), strict)?;
869        self.layer_norm2
870            .load_state_dict(&extract("layer_norm2"), strict)?;
871        self.mlp.load_state_dict(&extract("mlp"), strict)?;
872        Ok(())
873    }
874}
875
876// ---------------------------------------------------------------------------
877// CLIPEncoder
878// ---------------------------------------------------------------------------
879
880/// Stack of `num_hidden_layers` [`ClipEncoderLayer`]s applied in order.
881#[derive(Debug)]
882pub struct ClipEncoder<T: Float> {
883    /// One layer per `num_hidden_layers`.
884    pub layers: Vec<ClipEncoderLayer<T>>,
885    training: bool,
886}
887
888impl<T: Float> ClipEncoder<T> {
889    /// Build a randomly-initialized encoder stack.
890    ///
891    /// # Errors
892    ///
893    /// Returns the underlying [`FerrotorchError`] on bad config dims.
894    pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
895        cfg.validate()?;
896        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
897        for _ in 0..cfg.num_hidden_layers {
898            layers.push(ClipEncoderLayer::new(cfg)?);
899        }
900        Ok(Self {
901            layers,
902            training: false,
903        })
904    }
905}
906
907impl<T: Float> Module<T> for ClipEncoder<T> {
908    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
909        let mut h = input.clone();
910        for l in &self.layers {
911            h = l.forward(&h)?;
912        }
913        Ok(h)
914    }
915
916    fn parameters(&self) -> Vec<&Parameter<T>> {
917        let mut out = Vec::new();
918        for l in &self.layers {
919            out.extend(l.parameters());
920        }
921        out
922    }
923
924    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
925        let mut out = Vec::new();
926        for l in &mut self.layers {
927            out.extend(l.parameters_mut());
928        }
929        out
930    }
931
932    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
933        let mut out = Vec::new();
934        for (i, l) in self.layers.iter().enumerate() {
935            for (n, p) in l.named_parameters() {
936                out.push((format!("layers.{i}.{n}"), p));
937            }
938        }
939        out
940    }
941
942    fn train(&mut self) {
943        self.training = true;
944        for l in &mut self.layers {
945            l.train();
946        }
947    }
948
949    fn eval(&mut self) {
950        self.training = false;
951        for l in &mut self.layers {
952            l.eval();
953        }
954    }
955
956    fn is_training(&self) -> bool {
957        self.training
958    }
959
960    fn state_dict(&self) -> StateDict<T> {
961        self.named_parameters()
962            .into_iter()
963            .map(|(n, p)| (n, p.tensor().clone()))
964            .collect()
965    }
966
967    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
968        let extract = |prefix: &str| -> StateDict<T> {
969            let p = format!("{prefix}.");
970            state
971                .iter()
972                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
973                .collect()
974        };
975        if strict {
976            for k in state.keys() {
977                if !k.starts_with("layers.") {
978                    return Err(FerrotorchError::InvalidArgument {
979                        message: format!("unexpected key in ClipEncoder state_dict: {k:?}"),
980                    });
981                }
982            }
983        }
984        for (i, l) in self.layers.iter_mut().enumerate() {
985            l.load_state_dict(&extract(&format!("layers.{i}")), strict)?;
986        }
987        Ok(())
988    }
989}
990
991// ---------------------------------------------------------------------------
992// CLIPTextTransformer / ClipTextEncoder
993// ---------------------------------------------------------------------------
994
995/// The full SD-1.5 CLIP text encoder. Wraps [`ClipTextEmbeddings`] +
996/// [`ClipEncoder`] + a final [`LayerNorm`] (`final_layer_norm`).
997///
998/// Mirrors `CLIPTextTransformer` in transformers. The HF
999/// `CLIPTextModel` wrapper sits one prefix above
1000/// (`text_model.embeddings.*`, `text_model.encoder.*`,
1001/// `text_model.final_layer_norm.*`) — see
1002/// [`crate::safetensors_loader::load_clip_text_encoder`] for the
1003/// `text_model.` strip.
1004///
1005/// Output is the per-token `last_hidden_state` `[B, S, hidden_size]`.
1006/// SD-1.5 consumes this directly as `encoder_hidden_states` for the
1007/// UNet's cross-attention (no pooling).
1008#[derive(Debug)]
1009pub struct ClipTextEncoder<T: Float> {
1010    /// Token + position embedding sum.
1011    pub embeddings: ClipTextEmbeddings<T>,
1012    /// 12 × [`ClipEncoderLayer`] for SD-1.5.
1013    pub encoder: ClipEncoder<T>,
1014    /// Final LayerNorm over the last hidden state.
1015    pub final_layer_norm: LayerNorm<T>,
1016    /// Frozen copy of the configuration used to build the module.
1017    pub config: ClipTextConfig,
1018    training: bool,
1019}
1020
1021impl<T: Float> ClipTextEncoder<T> {
1022    /// Build a randomly-initialized text encoder.
1023    ///
1024    /// # Errors
1025    ///
1026    /// Returns the underlying [`FerrotorchError`] from any sub-module
1027    /// constructor.
1028    pub fn new(cfg: ClipTextConfig) -> FerrotorchResult<Self> {
1029        cfg.validate()?;
1030        let embeddings = ClipTextEmbeddings::new(&cfg)?;
1031        let encoder = ClipEncoder::new(&cfg)?;
1032        let final_layer_norm = LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?;
1033        Ok(Self {
1034            embeddings,
1035            encoder,
1036            final_layer_norm,
1037            config: cfg,
1038            training: false,
1039        })
1040    }
1041
1042    /// Run the encoder on a token-id sequence and return the per-token
1043    /// `last_hidden_state` `[1, S, hidden_size]`.
1044    ///
1045    /// `input_ids` is the verbatim CLIP-BPE token-id vector (length
1046    /// `S`). For SD-1.5 the canonical inference call is `S = 77`
1047    /// (already-padded with EOS to the max length).
1048    ///
1049    /// # Errors
1050    ///
1051    /// * [`FerrotorchError::InvalidArgument`] if `input_ids` is empty
1052    ///   or longer than `max_position_embeddings`.
1053    /// * Propagates downstream Embedding / LayerNorm errors.
1054    pub fn forward_from_ids(&self, input_ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
1055        let h = self.embeddings.forward_from_ids(input_ids)?;
1056        let h = self.encoder.forward(&h)?;
1057        self.final_layer_norm.forward(&h)
1058    }
1059
1060    /// Run the encoder on a pre-built float-encoded token-id tensor of
1061    /// shape `[S]`. Returns `[1, S, hidden_size]`.
1062    ///
1063    /// `ids` carries u32 token ids losslessly cast to `T`
1064    /// (`numeric_cast::cast::<u32, T>`). Mirrors what the dump example
1065    /// reads off disk.
1066    ///
1067    /// # Errors
1068    ///
1069    /// Propagates downstream lookup / LayerNorm errors and converts
1070    /// invalid (negative / NaN / overflow) ids to
1071    /// [`FerrotorchError::InvalidArgument`].
1072    pub fn forward_from_id_tensor(&self, ids: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1073        // Convert to u32 ids — same defensive cast the underlying
1074        // `Embedding::forward` does.
1075        if ids.ndim() != 1 {
1076            return Err(FerrotorchError::ShapeMismatch {
1077                message: format!(
1078                    "ClipTextEncoder::forward_from_id_tensor expects 1-D ids, got {:?}",
1079                    ids.shape()
1080                ),
1081            });
1082        }
1083        let data = ids.data_vec()?;
1084        let mut u32_ids: Vec<u32> = Vec::with_capacity(data.len());
1085        for (i, v) in data.iter().enumerate() {
1086            let f = num_traits::ToPrimitive::to_f64(v).ok_or_else(|| {
1087                FerrotorchError::InvalidArgument {
1088                    message: format!(
1089                        "ClipTextEncoder::forward_from_id_tensor: id at {i} \
1090                         not representable as f64"
1091                    ),
1092                }
1093            })?;
1094            if !f.is_finite() || f < 0.0 || f > u32::MAX as f64 || f.fract() != 0.0 {
1095                return Err(FerrotorchError::InvalidArgument {
1096                    message: format!(
1097                        "ClipTextEncoder::forward_from_id_tensor: id at {i} ({f}) \
1098                         is not a non-negative integer"
1099                    ),
1100                });
1101            }
1102            u32_ids.push(f as u32);
1103        }
1104        self.forward_from_ids(&u32_ids)
1105    }
1106
1107    /// Load a HuggingFace `CLIPTextModel` state dict into this module.
1108    ///
1109    /// Accepts both:
1110    ///   - bare-`text_model` layout (no prefix; what the pin script
1111    ///     normalises to).
1112    ///   - full `text_model.<rest>` prefix (what the upstream HF
1113    ///     checkpoint ships).
1114    ///
1115    /// The HF safetensors also ships a non-parameter buffer we
1116    /// explicitly drop:
1117    ///
1118    /// * `text_model.embeddings.position_ids` — a `[1, max_pos]`
1119    ///   `arange(max_pos)` buffer regenerated each forward pass. Recorded
1120    ///   in the [`crate::safetensors_loader::DropReport`].
1121    ///
1122    /// # Errors
1123    ///
1124    /// Forwards whatever each sub-module's `load_state_dict` returns
1125    /// (shape mismatch / strict-mode missing key). Strict mode will
1126    /// surface `text_model.embeddings.position_ids` and any unknown
1127    /// key as errors; callers with a full HF checkpoint must pass
1128    /// `strict=false`.
1129    pub fn load_hf_state_dict(
1130        &mut self,
1131        hf_state: &StateDict<T>,
1132        strict: bool,
1133    ) -> FerrotorchResult<crate::safetensors_loader::DropReport> {
1134        let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
1135        let mut dropped: Vec<String> = Vec::new();
1136        for (k, v) in hf_state {
1137            // Strip the optional `text_model.` prefix.
1138            let after = k.strip_prefix("text_model.").map_or_else(|| k.clone(), str::to_owned);
1139
1140            // `embeddings.position_ids` is a buffer — not a parameter on our
1141            // side. Drop in both modes; record so the pin script can audit.
1142            if after == "embeddings.position_ids" {
1143                dropped.push(k.clone());
1144                continue;
1145            }
1146
1147            let is_known = after.starts_with("embeddings.token_embedding.")
1148                || after.starts_with("embeddings.position_embedding.")
1149                || after.starts_with("encoder.")
1150                || after.starts_with("final_layer_norm.");
1151            if is_known {
1152                remapped.insert(after, v.clone());
1153                continue;
1154            }
1155
1156            if strict {
1157                return Err(FerrotorchError::InvalidArgument {
1158                    message: format!(
1159                        "ClipTextEncoder::load_hf_state_dict: key {k:?} is not a \
1160                         known CLIP text-tower parameter and strict mode is on. \
1161                         Pass strict=false to drop unknown keys."
1162                    ),
1163                });
1164            }
1165            dropped.push(k.clone());
1166        }
1167        dropped.sort();
1168        self.load_state_dict(&remapped, strict)?;
1169        Ok(crate::safetensors_loader::DropReport { dropped })
1170    }
1171}
1172
1173impl<T: Float> Module<T> for ClipTextEncoder<T> {
1174    /// `Module::forward` treats `input` as already-summed embeddings
1175    /// `[1, S, hidden]` and only runs the encoder + final LayerNorm.
1176    /// Real callers should use [`Self::forward_from_ids`] /
1177    /// [`Self::forward_from_id_tensor`].
1178    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1179        let h = self.encoder.forward(input)?;
1180        self.final_layer_norm.forward(&h)
1181    }
1182
1183    fn parameters(&self) -> Vec<&Parameter<T>> {
1184        let mut out = Vec::new();
1185        out.extend(self.embeddings.parameters());
1186        out.extend(self.encoder.parameters());
1187        out.extend(self.final_layer_norm.parameters());
1188        out
1189    }
1190
1191    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1192        let mut out = Vec::new();
1193        out.extend(self.embeddings.parameters_mut());
1194        out.extend(self.encoder.parameters_mut());
1195        out.extend(self.final_layer_norm.parameters_mut());
1196        out
1197    }
1198
1199    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1200        let mut out = Vec::new();
1201        for (n, p) in self.embeddings.named_parameters() {
1202            out.push((format!("embeddings.{n}"), p));
1203        }
1204        for (n, p) in self.encoder.named_parameters() {
1205            out.push((format!("encoder.{n}"), p));
1206        }
1207        for (n, p) in self.final_layer_norm.named_parameters() {
1208            out.push((format!("final_layer_norm.{n}"), p));
1209        }
1210        out
1211    }
1212
1213    fn train(&mut self) {
1214        self.training = true;
1215        self.embeddings.train();
1216        self.encoder.train();
1217        self.final_layer_norm.train();
1218    }
1219
1220    fn eval(&mut self) {
1221        self.training = false;
1222        self.embeddings.eval();
1223        self.encoder.eval();
1224        self.final_layer_norm.eval();
1225    }
1226
1227    fn is_training(&self) -> bool {
1228        self.training
1229    }
1230
1231    fn state_dict(&self) -> StateDict<T> {
1232        self.named_parameters()
1233            .into_iter()
1234            .map(|(n, p)| (n, p.tensor().clone()))
1235            .collect()
1236    }
1237
1238    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
1239        let extract = |prefix: &str| -> StateDict<T> {
1240            let p = format!("{prefix}.");
1241            state
1242                .iter()
1243                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
1244                .collect()
1245        };
1246        if strict {
1247            for k in state.keys() {
1248                if !(k.starts_with("embeddings.")
1249                    || k.starts_with("encoder.")
1250                    || k.starts_with("final_layer_norm."))
1251                {
1252                    return Err(FerrotorchError::InvalidArgument {
1253                        message: format!("unexpected key in ClipTextEncoder state_dict: {k:?}"),
1254                    });
1255                }
1256            }
1257        }
1258        self.embeddings
1259            .load_state_dict(&extract("embeddings"), strict)?;
1260        self.encoder
1261            .load_state_dict(&extract("encoder"), strict)?;
1262        self.final_layer_norm
1263            .load_state_dict(&extract("final_layer_norm"), strict)?;
1264        Ok(())
1265    }
1266}
1267
1268// `mul` is re-exported above so downstream features (e.g. embedding
1269// scaling, never used for CLIP-ViT-L) can reach for it without a fresh
1270// import. Suppress the unused-import warning on a vanilla build.
1271#[allow(dead_code)]
1272fn _unused_mul_ref<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1273    mul(a, b)
1274}
1275
1276// ---------------------------------------------------------------------------
1277// Tests
1278// ---------------------------------------------------------------------------
1279
1280#[cfg(test)]
1281mod tests {
1282    use super::*;
1283
1284    fn tiny_cfg() -> ClipTextConfig {
1285        // 2 heads × 4 dim/head = 8 hidden_size; 16 intermediate; 1 layer;
1286        // 6 positions; tiny vocab.
1287        ClipTextConfig {
1288            hidden_size: 8,
1289            intermediate_size: 16,
1290            num_attention_heads: 2,
1291            num_hidden_layers: 1,
1292            max_position_embeddings: 6,
1293            vocab_size: 32,
1294            layer_norm_eps: 1e-5,
1295        }
1296    }
1297
1298    #[test]
1299    fn sd_v1_5_config_is_canonical() {
1300        let c = ClipTextConfig::sd_v1_5();
1301        assert_eq!(c.hidden_size, 768);
1302        assert_eq!(c.intermediate_size, 3072);
1303        assert_eq!(c.num_attention_heads, 12);
1304        assert_eq!(c.num_hidden_layers, 12);
1305        assert_eq!(c.max_position_embeddings, 77);
1306        assert_eq!(c.vocab_size, 49408);
1307        assert_eq!(c.head_dim(), 64);
1308        c.validate().unwrap();
1309    }
1310
1311    #[test]
1312    fn validate_catches_bad_head_count() {
1313        let mut c = tiny_cfg();
1314        c.num_attention_heads = 3; // 8 % 3 != 0
1315        assert!(c.validate().is_err());
1316    }
1317
1318    #[test]
1319    fn from_json_str_round_trip() {
1320        let json = r#"{
1321            "hidden_size": 768,
1322            "intermediate_size": 3072,
1323            "num_attention_heads": 12,
1324            "num_hidden_layers": 12,
1325            "max_position_embeddings": 77,
1326            "vocab_size": 49408,
1327            "layer_norm_eps": 1e-5,
1328            "hidden_act": "quick_gelu"
1329        }"#;
1330        let c = ClipTextConfig::from_json_str(json).unwrap();
1331        assert_eq!(c.hidden_size, 768);
1332        assert_eq!(c.intermediate_size, 3072);
1333        assert_eq!(c.num_attention_heads, 12);
1334        assert_eq!(c.num_hidden_layers, 12);
1335        assert_eq!(c.max_position_embeddings, 77);
1336    }
1337
1338    #[test]
1339    fn embeddings_forward_shape() {
1340        let emb = ClipTextEmbeddings::<f32>::new(&tiny_cfg()).unwrap();
1341        let ids = [1u32, 5, 7, 9];
1342        let out = emb.forward_from_ids(&ids).unwrap();
1343        assert_eq!(out.shape(), &[1, 4, 8]);
1344        for &v in out.data().unwrap() {
1345            assert!(v.is_finite(), "embedding non-finite: {v}");
1346        }
1347    }
1348
1349    #[test]
1350    fn embeddings_reject_too_long_sequence() {
1351        let emb = ClipTextEmbeddings::<f32>::new(&tiny_cfg()).unwrap();
1352        let ids: Vec<u32> = (0..7).collect(); // > max_position 6
1353        assert!(emb.forward_from_ids(&ids).is_err());
1354    }
1355
1356    #[test]
1357    fn self_attention_forward_shape() {
1358        let attn = ClipSelfAttention::<f32>::new(&tiny_cfg()).unwrap();
1359        let x = Tensor::from_storage(
1360            TensorStorage::cpu(vec![0.1f32; 5 * 8]),
1361            vec![1, 5, 8],
1362            false,
1363        )
1364        .unwrap();
1365        let out = attn.forward(&x).unwrap();
1366        assert_eq!(out.shape(), &[1, 5, 8]);
1367        for &v in out.data().unwrap() {
1368            assert!(v.is_finite());
1369        }
1370    }
1371
1372    #[test]
1373    fn self_attention_is_actually_causal() {
1374        // Changing later tokens MUST NOT change earlier rows.
1375        // Build a tensor [1, 4, 8] with the first 2 rows fixed and the
1376        // last 2 rows perturbed across two runs. The first 2 output
1377        // rows must be bit-identical between the runs (within f32
1378        // round-off).
1379        let attn = ClipSelfAttention::<f32>::new(&tiny_cfg()).unwrap();
1380        let mut a = vec![0.1f32; 4 * 8];
1381        for i in 0..2 * 8 {
1382            a[i] = ((i + 1) as f32).sin();
1383        }
1384        let mut b = a.clone();
1385        // Perturb only rows 2 and 3.
1386        for i in (2 * 8)..(4 * 8) {
1387            b[i] = ((i + 11) as f32).sin();
1388        }
1389        let xa = Tensor::from_storage(TensorStorage::cpu(a), vec![1, 4, 8], false).unwrap();
1390        let xb = Tensor::from_storage(TensorStorage::cpu(b), vec![1, 4, 8], false).unwrap();
1391        let oa = attn.forward(&xa).unwrap();
1392        let ob = attn.forward(&xb).unwrap();
1393        let da = oa.data().unwrap();
1394        let db = ob.data().unwrap();
1395        for i in 0..2 * 8 {
1396            assert!(
1397                (da[i] - db[i]).abs() < 1e-5,
1398                "row {} ({}) differs between runs: {} vs {}",
1399                i / 8,
1400                i % 8,
1401                da[i],
1402                db[i]
1403            );
1404        }
1405    }
1406
1407    #[test]
1408    fn mlp_uses_quick_gelu() {
1409        // QuickGELU(x) = x * sigmoid(1.702 * x). Verify the FC1 + GELU
1410        // branch produces this for a known scalar input — we do so
1411        // indirectly by checking that the forward output remains finite
1412        // and the intermediate activation at zero input gives bias.
1413        let mlp = ClipMlp::<f32>::new(&tiny_cfg()).unwrap();
1414        let x = Tensor::from_storage(
1415            TensorStorage::cpu(vec![0.0f32; 3 * 8]),
1416            vec![1, 3, 8],
1417            false,
1418        )
1419        .unwrap();
1420        let out = mlp.forward(&x).unwrap();
1421        assert_eq!(out.shape(), &[1, 3, 8]);
1422        for &v in out.data().unwrap() {
1423            assert!(v.is_finite());
1424        }
1425    }
1426
1427    #[test]
1428    fn encoder_layer_forward_shape() {
1429        let layer = ClipEncoderLayer::<f32>::new(&tiny_cfg()).unwrap();
1430        let x = Tensor::from_storage(
1431            TensorStorage::cpu(vec![0.1f32; 5 * 8]),
1432            vec![1, 5, 8],
1433            false,
1434        )
1435        .unwrap();
1436        let out = layer.forward(&x).unwrap();
1437        assert_eq!(out.shape(), &[1, 5, 8]);
1438        for &v in out.data().unwrap() {
1439            assert!(v.is_finite());
1440        }
1441    }
1442
1443    #[test]
1444    fn encoder_layer_named_parameters_use_hf_layout() {
1445        let layer = ClipEncoderLayer::<f32>::new(&tiny_cfg()).unwrap();
1446        let names: Vec<String> = layer.named_parameters().into_iter().map(|(n, _)| n).collect();
1447        for k in [
1448            "layer_norm1.weight",
1449            "layer_norm1.bias",
1450            "self_attn.q_proj.weight",
1451            "self_attn.q_proj.bias",
1452            "self_attn.k_proj.weight",
1453            "self_attn.v_proj.weight",
1454            "self_attn.out_proj.weight",
1455            "self_attn.out_proj.bias",
1456            "layer_norm2.weight",
1457            "mlp.fc1.weight",
1458            "mlp.fc1.bias",
1459            "mlp.fc2.weight",
1460            "mlp.fc2.bias",
1461        ] {
1462            assert!(
1463                names.iter().any(|n| n == k),
1464                "missing parameter key {k:?} in {names:?}"
1465            );
1466        }
1467    }
1468
1469    #[test]
1470    fn tiny_encoder_forward_from_ids_shape() {
1471        let enc = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
1472        let ids = vec![1u32, 5, 7];
1473        let out = enc.forward_from_ids(&ids).unwrap();
1474        assert_eq!(out.shape(), &[1, 3, 8]);
1475        for &v in out.data().unwrap() {
1476            assert!(v.is_finite());
1477        }
1478    }
1479
1480    #[test]
1481    fn tiny_named_parameters_use_hf_layout() {
1482        let enc = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
1483        let names: Vec<String> = enc.named_parameters().into_iter().map(|(n, _)| n).collect();
1484        for k in [
1485            "embeddings.token_embedding.weight",
1486            "embeddings.position_embedding.weight",
1487            "encoder.layers.0.layer_norm1.weight",
1488            "encoder.layers.0.self_attn.q_proj.weight",
1489            "encoder.layers.0.self_attn.out_proj.bias",
1490            "encoder.layers.0.layer_norm2.bias",
1491            "encoder.layers.0.mlp.fc1.weight",
1492            "encoder.layers.0.mlp.fc2.bias",
1493            "final_layer_norm.weight",
1494            "final_layer_norm.bias",
1495        ] {
1496            assert!(
1497                names.iter().any(|n| n == k),
1498                "missing parameter key {k:?} in {names:?}"
1499            );
1500        }
1501    }
1502
1503    #[test]
1504    fn round_trip_state_dict() {
1505        let src = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
1506        let sd = src.state_dict();
1507        let mut dst = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
1508        dst.load_state_dict(&sd, true).unwrap();
1509        let ids = vec![2u32, 4, 6];
1510        let a = src.forward_from_ids(&ids).unwrap();
1511        let b = dst.forward_from_ids(&ids).unwrap();
1512        for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
1513            assert!((x - y).abs() < 1e-5, "round-trip differs: {x} vs {y}");
1514        }
1515    }
1516
1517    #[test]
1518    fn load_hf_state_dict_strips_text_model_prefix() {
1519        let src = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
1520        let bare = src.state_dict();
1521        let mut prefixed: StateDict<f32> = HashMap::new();
1522        for (k, v) in bare {
1523            prefixed.insert(format!("text_model.{k}"), v);
1524        }
1525        // Add the position_ids buffer — it should be dropped.
1526        prefixed.insert(
1527            "text_model.embeddings.position_ids".into(),
1528            ferrotorch_core::zeros::<f32>(&[1, 6]).unwrap(),
1529        );
1530        let mut dst = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
1531        let rep = dst.load_hf_state_dict(&prefixed, false).unwrap();
1532        assert_eq!(rep.dropped, vec!["text_model.embeddings.position_ids".to_string()]);
1533        let ids = vec![1u32, 2, 3];
1534        let a = src.forward_from_ids(&ids).unwrap();
1535        let b = dst.forward_from_ids(&ids).unwrap();
1536        for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
1537            assert!((x - y).abs() < 1e-5);
1538        }
1539    }
1540
1541    #[test]
1542    fn load_hf_state_dict_strict_rejects_unknown_key() {
1543        let mut dst = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
1544        let mut sd: StateDict<f32> = HashMap::new();
1545        sd.insert(
1546            "mystery.key".into(),
1547            ferrotorch_core::zeros::<f32>(&[1]).unwrap(),
1548        );
1549        assert!(dst.load_hf_state_dict(&sd, true).is_err());
1550    }
1551
1552    #[test]
1553    fn forward_from_id_tensor_matches_forward_from_ids() {
1554        let enc = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
1555        let ids = vec![1u32, 5, 7];
1556        let id_tensor = float_index_tensor::<f32>(&ids).unwrap();
1557        let a = enc.forward_from_ids(&ids).unwrap();
1558        let b = enc.forward_from_id_tensor(&id_tensor).unwrap();
1559        for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
1560            assert!((x - y).abs() < 1e-5);
1561        }
1562    }
1563}