1use candle_core::{Result, Tensor, D};
8use candle_nn as nn;
9use candle_nn::Module;
10
11use crate::config::DiffusionConfig;
12
/// Multi-head self-attention block of one CLIP encoder layer.
#[derive(Debug)]
struct ClipAttention {
    // Query/key/value/output projections, all embed_dim -> embed_dim.
    q_proj: nn::Linear,
    k_proj: nn::Linear,
    v_proj: nn::Linear,
    out_proj: nn::Linear,
    /// Number of attention heads; the embedding is split evenly across them.
    num_heads: usize,
    /// Per-head feature width (embed_dim / num_heads).
    head_dim: usize,
    /// Softmax temperature, 1 / sqrt(head_dim).
    scale: f64,
}
28
29impl ClipAttention {
30 fn new(vs: nn::VarBuilder, embed_dim: usize, num_heads: usize) -> Result<Self> {
31 let head_dim = embed_dim / num_heads;
32 let scale = 1.0 / (head_dim as f64).sqrt();
33 let q_proj = nn::linear(embed_dim, embed_dim, vs.pp("q_proj"))?;
34 let k_proj = nn::linear(embed_dim, embed_dim, vs.pp("k_proj"))?;
35 let v_proj = nn::linear(embed_dim, embed_dim, vs.pp("v_proj"))?;
36 let out_proj = nn::linear(embed_dim, embed_dim, vs.pp("out_proj"))?;
37 Ok(Self {
38 q_proj,
39 k_proj,
40 v_proj,
41 out_proj,
42 num_heads,
43 head_dim,
44 scale,
45 })
46 }
47
48 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
49 let (b, seq_len, _) = xs.dims3()?;
50 let q = self.q_proj.forward(xs)?;
51 let k = self.k_proj.forward(xs)?;
52 let v = self.v_proj.forward(xs)?;
53
54 let reshape = |t: Tensor| -> Result<Tensor> {
55 t.reshape((b, seq_len, self.num_heads, self.head_dim))?
56 .transpose(1, 2)
57 };
58
59 let q = reshape(q)?;
60 let k = reshape(k)?;
61 let v = reshape(v)?;
62
63 let attn = (q.matmul(&k.transpose(D::Minus2, D::Minus1)?)? * self.scale)?;
64 let attn = nn::ops::softmax_last_dim(&attn)?;
65 let out = attn.matmul(&v)?;
66
67 let out = out.transpose(1, 2)?.reshape((b, seq_len, ()))?;
68 self.out_proj.forward(&out)
69 }
70}
71
/// One pre-norm transformer encoder layer: self-attention + 2-layer MLP,
/// each wrapped with layer-norm and a residual connection.
#[derive(Debug)]
struct ClipEncoderLayer {
    /// Norm applied before self-attention.
    layer_norm1: nn::LayerNorm,
    self_attn: ClipAttention,
    /// Norm applied before the MLP.
    layer_norm2: nn::LayerNorm,
    // MLP: embed_dim -> intermediate_size -> embed_dim.
    fc1: nn::Linear,
    fc2: nn::Linear,
}
81
82impl ClipEncoderLayer {
83 fn new(
84 vs: nn::VarBuilder,
85 embed_dim: usize,
86 num_heads: usize,
87 intermediate_size: usize,
88 ) -> Result<Self> {
89 let layer_norm1 = nn::layer_norm(embed_dim, 1e-5, vs.pp("layer_norm1"))?;
90 let self_attn = ClipAttention::new(vs.pp("self_attn"), embed_dim, num_heads)?;
91 let layer_norm2 = nn::layer_norm(embed_dim, 1e-5, vs.pp("layer_norm2"))?;
92 let fc1 = nn::linear(embed_dim, intermediate_size, vs.pp("mlp.fc1"))?;
93 let fc2 = nn::linear(intermediate_size, embed_dim, vs.pp("mlp.fc2"))?;
94 Ok(Self {
95 layer_norm1,
96 self_attn,
97 layer_norm2,
98 fc1,
99 fc2,
100 })
101 }
102}
103
104impl Module for ClipEncoderLayer {
105 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
106 let residual = xs;
107 let xs = self.layer_norm1.forward(xs)?;
108 let xs = self.self_attn.forward(&xs)?;
109 let xs = (xs + residual)?;
110
111 let residual = &xs;
112 let h = self.layer_norm2.forward(&xs)?;
113 let h = self.fc1.forward(&h)?.gelu()?;
114 let h = self.fc2.forward(&h)?;
115 h + residual
116 }
117}
118
/// Hyper-parameters of the CLIP vision transformer.
#[derive(Debug, Clone)]
pub struct ClipVisionConfig {
    /// Width of token embeddings / the residual stream.
    pub embed_dim: usize,
    /// Attention heads per encoder layer.
    pub num_heads: usize,
    /// Number of stacked encoder layers.
    pub num_layers: usize,
    /// Hidden width of each layer's MLP.
    pub intermediate_size: usize,
    /// Expected square input resolution, in pixels.
    pub image_size: usize,
    /// Side length of each square patch, in pixels.
    pub patch_size: usize,
}
133
impl Default for ClipVisionConfig {
    // 224px input with 14px patches -> 16x16 = 256 patch tokens (+1 class).
    // NOTE(review): these sizes (1280 wide, 32 layers, 16 heads) look like a
    // ViT-H/14-class vision tower — confirm against the checkpoint actually
    // loaded by `build_clip_encoder`.
    fn default() -> Self {
        Self {
            embed_dim: 1280,
            num_heads: 16,
            num_layers: 32,
            intermediate_size: 5120,
            image_size: 224,
            patch_size: 14,
        }
    }
}
146
147impl ClipVisionConfig {
148 pub fn num_patches(&self) -> usize {
149 (self.image_size / self.patch_size).pow(2)
150 }
151}
152
/// CLIP vision transformer that encodes images into per-token hidden states,
/// optionally projected to another width via `ip_projection`.
#[derive(Debug)]
pub struct ClipImageEncoder {
    /// Conv2d patchifier; kernel == stride == patch_size (non-overlapping).
    patch_embedding: nn::Conv2d,
    /// Learned position table with num_patches + 1 entries.
    position_embedding: nn::Embedding,
    /// Learned class token, shape (1, 1, embed_dim).
    class_embedding: Tensor,
    pre_layernorm: nn::LayerNorm,
    encoder_layers: Vec<ClipEncoderLayer>,
    post_layernorm: nn::LayerNorm,
    /// Optional linear head mapping embed_dim to an external target width.
    ip_projection: Option<nn::Linear>,
    config: ClipVisionConfig,
}
168
169impl ClipImageEncoder {
170 pub fn new(
172 vs: nn::VarBuilder,
173 clip_config: &ClipVisionConfig,
174 project_to: Option<usize>,
175 ) -> Result<Self> {
176 let embed_dim = clip_config.embed_dim;
177
178 let patch_embedding = nn::conv2d(
179 3,
180 embed_dim,
181 clip_config.patch_size,
182 nn::Conv2dConfig {
183 stride: clip_config.patch_size,
184 ..Default::default()
185 },
186 vs.pp("embeddings.patch_embedding"),
187 )?;
188
189 let num_positions = clip_config.num_patches() + 1; let position_embedding = nn::embedding(
191 num_positions,
192 embed_dim,
193 vs.pp("embeddings.position_embedding"),
194 )?;
195
196 let class_embedding = vs.get((1, 1, embed_dim), "embeddings.class_embedding")?;
197
198 let pre_layernorm = nn::layer_norm(embed_dim, 1e-5, vs.pp("pre_layrnorm"))?;
199
200 let vs_layers = vs.pp("encoder.layers");
201 let mut encoder_layers = Vec::with_capacity(clip_config.num_layers);
202 for i in 0..clip_config.num_layers {
203 encoder_layers.push(ClipEncoderLayer::new(
204 vs_layers.pp(i.to_string()),
205 embed_dim,
206 clip_config.num_heads,
207 clip_config.intermediate_size,
208 )?);
209 }
210
211 let post_layernorm = nn::layer_norm(embed_dim, 1e-5, vs.pp("post_layernorm"))?;
212
213 let ip_projection = if let Some(target_dim) = project_to {
214 Some(nn::linear(embed_dim, target_dim, vs.pp("ip_projection"))?)
215 } else {
216 None
217 };
218
219 Ok(Self {
220 patch_embedding,
221 position_embedding,
222 class_embedding,
223 pre_layernorm,
224 encoder_layers,
225 post_layernorm,
226 ip_projection,
227 config: clip_config.clone(),
228 })
229 }
230
231 pub fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
237 let batch_size = pixel_values.dim(0)?;
238 let device = pixel_values.device();
239 let dtype = pixel_values.dtype();
240
241 let patches = self.patch_embedding.forward(pixel_values)?;
243 let (_, _, h, w) = patches.dims4()?;
244 let num_patches = h * w;
245
246 let patches = patches.flatten(2, 3)?.transpose(1, 2)?;
248
249 let cls = self
251 .class_embedding
252 .broadcast_as((batch_size, 1, self.config.embed_dim))?;
253 let embeddings = Tensor::cat(&[cls.to_dtype(dtype)?, patches], 1)?;
254
255 let position_ids = Tensor::arange(0u32, (num_patches + 1) as u32, device)?;
257 let pos_embeds = self.position_embedding.forward(&position_ids)?;
258 let embeddings = (embeddings + pos_embeds.unsqueeze(0)?)?;
259
260 let mut hidden = self.pre_layernorm.forward(&embeddings)?;
262
263 for layer in &self.encoder_layers {
265 hidden = layer.forward(&hidden)?;
266 }
267
268 hidden = self.post_layernorm.forward(&hidden)?;
270
271 if let Some(ref proj) = self.ip_projection {
273 hidden = proj.forward(&hidden)?;
274 }
275
276 Ok(hidden)
277 }
278}
279
280pub fn build_clip_encoder(
282 vs: nn::VarBuilder,
283 config: &DiffusionConfig,
284) -> Result<ClipImageEncoder> {
285 let clip_config = ClipVisionConfig::default();
286 ClipImageEncoder::new(vs, &clip_config, Some(config.cross_attention_dim))
287}