Skip to main content

candle_transformers/models/stable_diffusion/
clip.rs

1//! Contrastive Language-Image Pre-Training
2//!
3//! Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
4//! pairs of images with related texts.
5//!
6//! - [CLIP](https://github.com/openai/CLIP)
7use candle::{DType, Device, Result, Tensor, D};
8use candle_nn as nn;
9use candle_nn::Module;
10
/// Activation function applied between the two linear layers of each CLIP MLP block.
///
/// Fieldless and `Copy`, so it also derives `PartialEq`/`Eq`/`Hash` to let callers
/// compare configurations or use an activation as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Activation {
    /// `x * sigmoid(1.702 * x)`.
    QuickGelu,
    /// candle's `Tensor::gelu`.
    Gelu,
    /// candle's `Tensor::gelu_erf` (erf-based GELU).
    GeluErf,
}
17
18impl Module for Activation {
19    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
20        match self {
21            Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
22            Activation::Gelu => xs.gelu(),
23            Activation::GeluErf => xs.gelu_erf(),
24        }
25    }
26}
27
28#[derive(Debug, Clone)]
29pub struct Config {
30    vocab_size: usize,
31    embed_dim: usize,       // aka config.hidden_size
32    activation: Activation, // aka config.hidden_act
33    intermediate_size: usize,
34    pub max_position_embeddings: usize,
35    // The character to use for padding, use EOS when not set.
36    pub pad_with: Option<String>,
37    num_hidden_layers: usize,
38    num_attention_heads: usize,
39    #[allow(dead_code)]
40    projection_dim: usize,
41}
42
43impl Config {
44    // The config details can be found in the "text_config" section of this json file:
45    // https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
46    pub fn v1_5() -> Self {
47        Self {
48            vocab_size: 49408,
49            embed_dim: 768,
50            intermediate_size: 3072,
51            max_position_embeddings: 77,
52            pad_with: None,
53            num_hidden_layers: 12,
54            num_attention_heads: 12,
55            projection_dim: 768,
56            activation: Activation::QuickGelu,
57        }
58    }
59
60    // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/text_encoder/config.json
61    pub fn v2_1() -> Self {
62        Self {
63            vocab_size: 49408,
64            embed_dim: 1024,
65            intermediate_size: 4096,
66            max_position_embeddings: 77,
67            pad_with: Some("!".to_string()),
68            num_hidden_layers: 23,
69            num_attention_heads: 16,
70            projection_dim: 512,
71            activation: Activation::Gelu,
72        }
73    }
74
75    // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder/config.json
76    pub fn sdxl() -> Self {
77        Self {
78            vocab_size: 49408,
79            embed_dim: 768,
80            intermediate_size: 3072,
81            max_position_embeddings: 77,
82            pad_with: Some("!".to_string()),
83            num_hidden_layers: 12,
84            num_attention_heads: 12,
85            projection_dim: 768,
86            activation: Activation::QuickGelu,
87        }
88    }
89
90    // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder_2/config.json
91    pub fn sdxl2() -> Self {
92        Self {
93            vocab_size: 49408,
94            embed_dim: 1280,
95            intermediate_size: 5120,
96            max_position_embeddings: 77,
97            pad_with: Some("!".to_string()),
98            num_hidden_layers: 32,
99            num_attention_heads: 20,
100            projection_dim: 1280,
101            activation: Activation::Gelu,
102        }
103    }
104
105    pub fn ssd1b() -> Self {
106        Self::sdxl()
107    }
108
109    pub fn ssd1b2() -> Self {
110        Self::sdxl2()
111    }
112
113    // https://huggingface.co/warp-ai/wuerstchen/blob/main/text_encoder/config.json
114    pub fn wuerstchen() -> Self {
115        Self {
116            vocab_size: 49408,
117            embed_dim: 1024,
118            intermediate_size: 4096,
119            max_position_embeddings: 77,
120            pad_with: None,
121            num_hidden_layers: 24,
122            num_attention_heads: 16,
123            projection_dim: 1024,
124            activation: Activation::GeluErf,
125        }
126    }
127
128    // https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/text_encoder/config.json
129    pub fn wuerstchen_prior() -> Self {
130        Self {
131            vocab_size: 49408,
132            embed_dim: 1280,
133            intermediate_size: 5120,
134            max_position_embeddings: 77,
135            pad_with: None,
136            num_hidden_layers: 32,
137            num_attention_heads: 20,
138            projection_dim: 512,
139            activation: Activation::GeluErf,
140        }
141    }
142}
143
144// CLIP Text Model
145// https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py
146#[derive(Debug)]
147struct ClipTextEmbeddings {
148    token_embedding: candle_nn::Embedding,
149    position_embedding: candle_nn::Embedding,
150    position_ids: Tensor,
151}
152
153impl ClipTextEmbeddings {
154    fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
155        let token_embedding =
156            candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
157        let position_embedding = candle_nn::embedding(
158            c.max_position_embeddings,
159            c.embed_dim,
160            vs.pp("position_embedding"),
161        )?;
162        let position_ids =
163            Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
164        Ok(ClipTextEmbeddings {
165            token_embedding,
166            position_embedding,
167            position_ids,
168        })
169    }
170}
171
172impl Module for ClipTextEmbeddings {
173    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
174        let token_embedding = self.token_embedding.forward(xs)?;
175        let position_embedding = self.position_embedding.forward(&self.position_ids)?;
176        token_embedding.broadcast_add(&position_embedding)
177    }
178}
179
180#[derive(Debug)]
181struct ClipAttention {
182    k_proj: candle_nn::Linear,
183    v_proj: candle_nn::Linear,
184    q_proj: candle_nn::Linear,
185    out_proj: candle_nn::Linear,
186    head_dim: usize,
187    scale: f64,
188    num_attention_heads: usize,
189}
190
191impl ClipAttention {
192    fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
193        let embed_dim = c.embed_dim;
194        let num_attention_heads = c.num_attention_heads;
195        let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?;
196        let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?;
197        let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?;
198        let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
199        let head_dim = embed_dim / num_attention_heads;
200        let scale = (head_dim as f64).powf(-0.5);
201        Ok(ClipAttention {
202            k_proj,
203            v_proj,
204            q_proj,
205            out_proj,
206            head_dim,
207            scale,
208            num_attention_heads,
209        })
210    }
211
212    fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
213        xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
214            .transpose(1, 2)?
215            .contiguous()
216    }
217
218    fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
219        let in_dtype = xs.dtype();
220        let (bsz, seq_len, embed_dim) = xs.dims3()?;
221        let query_states = (self.q_proj.forward(xs)? * self.scale)?;
222        let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
223        let query_states = self
224            .shape(&query_states, seq_len, bsz)?
225            .reshape(proj_shape)?
226            .to_dtype(DType::F32)?;
227        let key_states = self
228            .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
229            .reshape(proj_shape)?
230            .to_dtype(DType::F32)?;
231        let value_states = self
232            .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
233            .reshape(proj_shape)?
234            .to_dtype(DType::F32)?;
235        let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
236
237        let src_len = key_states.dim(1)?;
238        let attn_weights = attn_weights
239            .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
240            .broadcast_add(causal_attention_mask)?;
241        let attn_weights =
242            attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
243        let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
244
245        let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?;
246        let attn_output = attn_output
247            .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
248            .transpose(1, 2)?
249            .reshape((bsz, seq_len, embed_dim))?;
250        self.out_proj.forward(&attn_output)
251    }
252}
253
254#[derive(Debug)]
255struct ClipMlp {
256    fc1: candle_nn::Linear,
257    fc2: candle_nn::Linear,
258    activation: Activation,
259}
260
261impl ClipMlp {
262    fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
263        let fc1 = candle_nn::linear(c.embed_dim, c.intermediate_size, vs.pp("fc1"))?;
264        let fc2 = candle_nn::linear(c.intermediate_size, c.embed_dim, vs.pp("fc2"))?;
265        Ok(ClipMlp {
266            fc1,
267            fc2,
268            activation: c.activation,
269        })
270    }
271}
272
273impl ClipMlp {
274    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
275        let xs = self.fc1.forward(xs)?;
276        self.fc2.forward(&self.activation.forward(&xs)?)
277    }
278}
279
280#[derive(Debug)]
281struct ClipEncoderLayer {
282    self_attn: ClipAttention,
283    layer_norm1: candle_nn::LayerNorm,
284    mlp: ClipMlp,
285    layer_norm2: candle_nn::LayerNorm,
286}
287
288impl ClipEncoderLayer {
289    fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
290        let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
291        let layer_norm1 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm1"))?;
292        let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
293        let layer_norm2 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm2"))?;
294        Ok(ClipEncoderLayer {
295            self_attn,
296            layer_norm1,
297            mlp,
298            layer_norm2,
299        })
300    }
301
302    fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
303        let residual = xs;
304        let xs = self.layer_norm1.forward(xs)?;
305        let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
306        let xs = (xs + residual)?;
307
308        let residual = &xs;
309        let xs = self.layer_norm2.forward(&xs)?;
310        let xs = self.mlp.forward(&xs)?;
311        xs + residual
312    }
313}
314
315#[derive(Debug)]
316struct ClipEncoder {
317    layers: Vec<ClipEncoderLayer>,
318}
319
320impl ClipEncoder {
321    fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
322        let vs = vs.pp("layers");
323        let mut layers: Vec<ClipEncoderLayer> = Vec::new();
324        for index in 0..c.num_hidden_layers {
325            let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?;
326            layers.push(layer)
327        }
328        Ok(ClipEncoder { layers })
329    }
330
331    fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
332        let mut xs = xs.clone();
333        for layer in self.layers.iter() {
334            xs = layer.forward(&xs, causal_attention_mask)?;
335        }
336        Ok(xs)
337    }
338}
339
340/// A CLIP transformer based model.
341#[derive(Debug)]
342pub struct ClipTextTransformer {
343    embeddings: ClipTextEmbeddings,
344    encoder: ClipEncoder,
345    final_layer_norm: candle_nn::LayerNorm,
346}
347
348impl ClipTextTransformer {
349    pub fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
350        let vs = vs.pp("text_model");
351        let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
352        let encoder = ClipEncoder::new(vs.pp("encoder"), c)?;
353        let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
354        Ok(ClipTextTransformer {
355            embeddings,
356            encoder,
357            final_layer_norm,
358        })
359    }
360
361    // https://github.com/huggingface/transformers/blob/674f750a57431222fa2832503a108df3badf1564/src/transformers/models/clip/modeling_clip.py#L678
362    fn build_causal_attention_mask(
363        bsz: usize,
364        seq_len: usize,
365        mask_after: usize,
366        device: &Device,
367    ) -> Result<Tensor> {
368        let mask: Vec<_> = (0..seq_len)
369            .flat_map(|i| {
370                (0..seq_len).map(move |j| {
371                    if j > i || j > mask_after {
372                        f32::MIN
373                    } else {
374                        0.
375                    }
376                })
377            })
378            .collect();
379        let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
380        mask.broadcast_as((bsz, seq_len, seq_len))
381    }
382
383    pub fn forward_with_mask(&self, xs: &Tensor, mask_after: usize) -> Result<Tensor> {
384        let (bsz, seq_len) = xs.dims2()?;
385        let xs = self.embeddings.forward(xs)?;
386        let causal_attention_mask =
387            Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?;
388        let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
389        self.final_layer_norm.forward(&xs)
390    }
391
392    pub fn forward_until_encoder_layer(
393        &self,
394        xs: &Tensor,
395        mask_after: usize,
396        until_layer: isize,
397    ) -> Result<(Tensor, Tensor)> {
398        let (bsz, seq_len) = xs.dims2()?;
399        let xs = self.embeddings.forward(xs)?;
400        let causal_attention_mask =
401            Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?;
402
403        let mut xs = xs.clone();
404        let mut intermediate = xs.clone();
405
406        // Modified encoder.forward that returns the intermediate tensor along with final output.
407        let until_layer = if until_layer < 0 {
408            self.encoder.layers.len() as isize + until_layer
409        } else {
410            until_layer
411        } as usize;
412
413        for (layer_id, layer) in self.encoder.layers.iter().enumerate() {
414            xs = layer.forward(&xs, &causal_attention_mask)?;
415            if layer_id == until_layer {
416                intermediate = xs.clone();
417            }
418        }
419
420        Ok((self.final_layer_norm.forward(&xs)?, intermediate))
421    }
422}
423
424impl Module for ClipTextTransformer {
425    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
426        self.forward_with_mask(xs, usize::MAX)
427    }
428}