use candle::{DType, Device, Result, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;

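// Activation functions used by the CLIP MLP blocks. `QuickGelu` is the sigmoid-based
// approximation x * sigmoid(1.702 * x) used by the original CLIP implementation,
// `Gelu` is candle's tanh-based GELU approximation, and `GeluErf` the exact erf-based GELU.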
#[derive(Debug, Clone, Copy)]
pub enum Activation {
    QuickGelu,
    Gelu,
    GeluErf,
}

impl Module for Activation {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
            Activation::Gelu => xs.gelu(),
            Activation::GeluErf => xs.gelu_erf(),
        }
    }
}

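// Hyperparameters of the CLIP text encoder; `embed_dim` and `activation` correspond to
// `hidden_size` and `hidden_act` in the Hugging Face `text_config`.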
#[derive(Debug, Clone)]
pub struct Config {
    vocab_size: usize,
    embed_dim: usize,
    activation: Activation,
    intermediate_size: usize,
    pub max_position_embeddings: usize,
    pub pad_with: Option<String>,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    #[allow(dead_code)]
    projection_dim: usize,
}

impl Config {
    pub fn v1_5() -> Self {
        Self {
            vocab_size: 49408,
            embed_dim: 768,
            intermediate_size: 3072,
            max_position_embeddings: 77,
            pad_with: None,
            num_hidden_layers: 12,
            num_attention_heads: 12,
            projection_dim: 768,
            activation: Activation::QuickGelu,
        }
    }

    pub fn v2_1() -> Self {
        Self {
            vocab_size: 49408,
            embed_dim: 1024,
            intermediate_size: 4096,
            max_position_embeddings: 77,
            pad_with: Some("!".to_string()),
            num_hidden_layers: 23,
            num_attention_heads: 16,
            projection_dim: 512,
            activation: Activation::Gelu,
        }
    }

    pub fn sdxl() -> Self {
        Self {
            vocab_size: 49408,
            embed_dim: 768,
            intermediate_size: 3072,
            max_position_embeddings: 77,
            pad_with: Some("!".to_string()),
            num_hidden_layers: 12,
            num_attention_heads: 12,
            projection_dim: 768,
            activation: Activation::QuickGelu,
        }
    }

    pub fn sdxl2() -> Self {
        Self {
            vocab_size: 49408,
            embed_dim: 1280,
            intermediate_size: 5120,
            max_position_embeddings: 77,
            pad_with: Some("!".to_string()),
            num_hidden_layers: 32,
            num_attention_heads: 20,
            projection_dim: 1280,
            activation: Activation::Gelu,
        }
    }

    pub fn ssd1b() -> Self {
        Self::sdxl()
    }

    pub fn ssd1b2() -> Self {
        Self::sdxl2()
    }

    pub fn wuerstchen() -> Self {
        Self {
            vocab_size: 49408,
            embed_dim: 1024,
            intermediate_size: 4096,
            max_position_embeddings: 77,
            pad_with: None,
            num_hidden_layers: 24,
            num_attention_heads: 16,
            projection_dim: 1024,
            activation: Activation::GeluErf,
        }
    }

    pub fn wuerstchen_prior() -> Self {
        Self {
            vocab_size: 49408,
            embed_dim: 1280,
            intermediate_size: 5120,
            max_position_embeddings: 77,
            pad_with: None,
            num_hidden_layers: 32,
            num_attention_heads: 20,
            projection_dim: 512,
            activation: Activation::GeluErf,
        }
    }
}

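// CLIP text embeddings: learned token embeddings plus learned absolute position embeddings.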
#[derive(Debug)]
struct ClipTextEmbeddings {
    token_embedding: candle_nn::Embedding,
    position_embedding: candle_nn::Embedding,
    position_ids: Tensor,
}

impl ClipTextEmbeddings {
    fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
        let token_embedding =
            candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
        let position_embedding = candle_nn::embedding(
            c.max_position_embeddings,
            c.embed_dim,
            vs.pp("position_embedding"),
        )?;
        let position_ids =
            Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
        Ok(ClipTextEmbeddings {
            token_embedding,
            position_embedding,
            position_ids,
        })
    }
}

impl Module for ClipTextEmbeddings {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let token_embedding = self.token_embedding.forward(xs)?;
        let position_embedding = self.position_embedding.forward(&self.position_ids)?;
        token_embedding.broadcast_add(&position_embedding)
    }
}

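// Multi-head self-attention. Attention scores are computed in f32 (see `forward`) and the
// output is cast back to the input dtype.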
#[derive(Debug)]
struct ClipAttention {
    k_proj: candle_nn::Linear,
    v_proj: candle_nn::Linear,
    q_proj: candle_nn::Linear,
    out_proj: candle_nn::Linear,
    head_dim: usize,
    scale: f64,
    num_attention_heads: usize,
}

impl ClipAttention {
    fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
        let embed_dim = c.embed_dim;
        let num_attention_heads = c.num_attention_heads;
        let k_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?;
        let v_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?;
        let q_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?;
        let out_proj = candle_nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
        let head_dim = embed_dim / num_attention_heads;
        let scale = (head_dim as f64).powf(-0.5);
        Ok(ClipAttention {
            k_proj,
            v_proj,
            q_proj,
            out_proj,
            head_dim,
            scale,
            num_attention_heads,
        })
    }

    fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
        xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()
    }

    fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
        let in_dtype = xs.dtype();
        let (bsz, seq_len, embed_dim) = xs.dims3()?;
        let query_states = (self.q_proj.forward(xs)? * self.scale)?;
        let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
        let query_states = self
            .shape(&query_states, seq_len, bsz)?
            .reshape(proj_shape)?
            .to_dtype(DType::F32)?;
        let key_states = self
            .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
            .reshape(proj_shape)?
            .to_dtype(DType::F32)?;
        let value_states = self
            .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
            .reshape(proj_shape)?
            .to_dtype(DType::F32)?;
        let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;

        let src_len = key_states.dim(1)?;
        let attn_weights = attn_weights
            .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
            .broadcast_add(causal_attention_mask)?;
        let attn_weights =
            attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
        let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;

        let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?;
        let attn_output = attn_output
            .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
            .transpose(1, 2)?
            .reshape((bsz, seq_len, embed_dim))?;
        self.out_proj.forward(&attn_output)
    }
}

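// Position-wise feed-forward block: fc1 -> activation -> fc2.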
#[derive(Debug)]
struct ClipMlp {
    fc1: candle_nn::Linear,
    fc2: candle_nn::Linear,
    activation: Activation,
}

impl ClipMlp {
    fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
        let fc1 = candle_nn::linear(c.embed_dim, c.intermediate_size, vs.pp("fc1"))?;
        let fc2 = candle_nn::linear(c.intermediate_size, c.embed_dim, vs.pp("fc2"))?;
        Ok(ClipMlp {
            fc1,
            fc2,
            activation: c.activation,
        })
    }
}

impl ClipMlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.fc1.forward(xs)?;
        self.fc2.forward(&self.activation.forward(&xs)?)
    }
}

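// A pre-LayerNorm transformer encoder block with residual connections around the
// attention and MLP sub-blocks.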
#[derive(Debug)]
struct ClipEncoderLayer {
    self_attn: ClipAttention,
    layer_norm1: candle_nn::LayerNorm,
    mlp: ClipMlp,
    layer_norm2: candle_nn::LayerNorm,
}

impl ClipEncoderLayer {
    fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
        let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
        let layer_norm1 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm1"))?;
        let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
        let layer_norm2 = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("layer_norm2"))?;
        Ok(ClipEncoderLayer {
            self_attn,
            layer_norm1,
            mlp,
            layer_norm2,
        })
    }

    fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
        let residual = xs;
        let xs = self.layer_norm1.forward(xs)?;
        let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
        let xs = (xs + residual)?;

        let residual = &xs;
        let xs = self.layer_norm2.forward(&xs)?;
        let xs = self.mlp.forward(&xs)?;
        xs + residual
    }
}

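// A stack of `num_hidden_layers` encoder blocks applied sequentially.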
#[derive(Debug)]
struct ClipEncoder {
    layers: Vec<ClipEncoderLayer>,
}

impl ClipEncoder {
    fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
        let vs = vs.pp("layers");
        let mut layers: Vec<ClipEncoderLayer> = Vec::new();
        for index in 0..c.num_hidden_layers {
            let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?;
            layers.push(layer)
        }
        Ok(ClipEncoder { layers })
    }

    fn forward(&self, xs: &Tensor, causal_attention_mask: &Tensor) -> Result<Tensor> {
        let mut xs = xs.clone();
        for layer in self.layers.iter() {
            xs = layer.forward(&xs, causal_attention_mask)?;
        }
        Ok(xs)
    }
}

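// The full CLIP text encoder: embeddings, encoder stack, and a final layer norm.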
#[derive(Debug)]
pub struct ClipTextTransformer {
    embeddings: ClipTextEmbeddings,
    encoder: ClipEncoder,
    final_layer_norm: candle_nn::LayerNorm,
}

impl ClipTextTransformer {
    pub fn new(vs: candle_nn::VarBuilder, c: &Config) -> Result<Self> {
        let vs = vs.pp("text_model");
        let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
        let encoder = ClipEncoder::new(vs.pp("encoder"), c)?;
        let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
        Ok(ClipTextTransformer {
            embeddings,
            encoder,
            final_layer_norm,
        })
    }

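    // Builds an additive attention mask of shape (bsz, seq_len, seq_len): future positions
    // (j > i) and positions past `mask_after` are set to f32::MIN, everything else to 0.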
    fn build_causal_attention_mask(
        bsz: usize,
        seq_len: usize,
        mask_after: usize,
        device: &Device,
    ) -> Result<Tensor> {
        let mask: Vec<_> = (0..seq_len)
            .flat_map(|i| {
                (0..seq_len).map(move |j| {
                    if j > i || j > mask_after {
                        f32::MIN
                    } else {
                        0.
                    }
                })
            })
            .collect();
        let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
        mask.broadcast_as((bsz, seq_len, seq_len))
    }

    pub fn forward_with_mask(&self, xs: &Tensor, mask_after: usize) -> Result<Tensor> {
        let (bsz, seq_len) = xs.dims2()?;
        let xs = self.embeddings.forward(xs)?;
        let causal_attention_mask =
            Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?;
        let xs = self.encoder.forward(&xs, &causal_attention_mask)?;
        self.final_layer_norm.forward(&xs)
    }

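    // Like `forward_with_mask`, but also returns the hidden state after encoder layer
    // `until_layer`; a negative index counts back from the last layer, Python-style.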
    pub fn forward_until_encoder_layer(
        &self,
        xs: &Tensor,
        mask_after: usize,
        until_layer: isize,
    ) -> Result<(Tensor, Tensor)> {
        let (bsz, seq_len) = xs.dims2()?;
        let xs = self.embeddings.forward(xs)?;
        let causal_attention_mask =
            Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?;

        let mut xs = xs.clone();
        let mut intermediate = xs.clone();

        let until_layer = if until_layer < 0 {
            self.encoder.layers.len() as isize + until_layer
        } else {
            until_layer
        } as usize;

        for (layer_id, layer) in self.encoder.layers.iter().enumerate() {
            xs = layer.forward(&xs, &causal_attention_mask)?;
            if layer_id == until_layer {
                intermediate = xs.clone();
            }
        }

        Ok((self.final_layer_norm.forward(&xs)?, intermediate))
    }
}

impl Module for ClipTextTransformer {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.forward_with_mask(xs, usize::MAX)
    }
}
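
// Minimal usage sketch, assuming the text-encoder weights are available as a safetensors
// file with keys rooted at `text_model.*` (the file name and token ids below are placeholders):
//
//     let device = Device::Cpu;
//     let vb = unsafe {
//         nn::VarBuilder::from_mmaped_safetensors(&["clip_text_model.safetensors"], DType::F32, &device)?
//     };
//     let model = ClipTextTransformer::new(vb, &Config::v1_5())?;
//     // Token ids of shape (batch, seq_len), padded/truncated to 77 positions.
//     let tokens = Tensor::zeros((1, 77), DType::U32, &device)?;
//     let hidden_states = model.forward(&tokens)?; // (1, 77, 768) for the v1.5 config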