1use candle_core::{IndexOp, Module, Result, Tensor, D};
7use candle_nn::{conv2d, linear, linear_no_bias, Conv2d, Linear};
8use candle_nn::{layer_norm, LayerNorm, VarBuilder};
9
10#[derive(Debug, Clone, serde::Deserialize)]
12pub struct Config {
13 pub hidden_size: usize,
14 pub num_hidden_layers: usize,
15 pub num_attention_heads: usize,
16 pub intermediate_size: usize,
17 pub hidden_act: candle_nn::Activation,
18 pub layer_norm_eps: f64,
19 pub image_size: usize,
20 pub patch_size: usize,
21 pub num_channels: usize,
22 pub qkv_bias: bool,
23}
24
25impl Config {
26 pub fn vit_base_patch16_224() -> Self {
28 Self {
29 hidden_size: 768,
30 num_hidden_layers: 12,
31 num_attention_heads: 12,
32 intermediate_size: 3072,
33 hidden_act: candle_nn::Activation::Gelu,
34 layer_norm_eps: 1e-12,
35 image_size: 224,
36 patch_size: 16,
37 num_channels: 3,
38 qkv_bias: true,
39 }
40 }
41
42 pub fn vit_base_subcell() -> Self {
43 Self {
44 hidden_size: 768,
45 num_hidden_layers: 12,
46 num_attention_heads: 12,
47 intermediate_size: 3072,
48 hidden_act: candle_nn::Activation::Gelu,
49 layer_norm_eps: 1e-12,
50 image_size: 448,
51 patch_size: 16,
52 num_channels: 3,
53 qkv_bias: true,
54 }
55 }
56
57 pub fn vit_base_scdino() -> Self {
58 Self {
59 hidden_size: 384,
60 num_hidden_layers: 12,
61 num_attention_heads: 6,
62 intermediate_size: 1536,
63 hidden_act: candle_nn::Activation::Gelu,
64 layer_norm_eps: 1e-6,
65 image_size: 224,
66 patch_size: 16,
67 num_channels: 3,
68 qkv_bias: true,
69 }
70 }
71}
72
73#[derive(Debug, Clone)]
74struct PatchEmbeddings {
75 num_patches: usize,
76 projection: Conv2d,
77}
78
79impl PatchEmbeddings {
80 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
81 let image_size = cfg.image_size;
82 let patch_size = cfg.patch_size;
83 let num_patches = (image_size / patch_size) * (image_size / patch_size);
84 let conv_cfg = candle_nn::Conv2dConfig {
85 stride: patch_size,
86 ..Default::default()
87 };
88 let projection = conv2d(
89 cfg.num_channels,
90 cfg.hidden_size,
91 patch_size,
92 conv_cfg,
93 vb.pp("projection"),
94 )?;
95 Ok(Self {
96 num_patches,
97 projection,
98 })
99 }
100}
101
102impl Module for PatchEmbeddings {
103 fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
104 let (_b_size, _num_channels, _height, _width) = pixel_values.dims4()?;
105 self.projection
106 .forward(pixel_values)?
107 .flatten_from(2)?
108 .transpose(1, 2)
109 }
110}
111
112#[derive(Debug, Clone)]
113pub struct Embeddings {
114 cls_token: Tensor,
115 mask_token: Option<Tensor>,
116 patch_embeddings: PatchEmbeddings,
117 position_embeddings: Tensor,
118 hidden_size: usize,
119}
120
121impl Embeddings {
122 pub fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
123 let hidden_size = cfg.hidden_size;
124 let cls_token = vb.get((1, 1, hidden_size), "cls_token")?;
125 let mask_token = if use_mask_token {
126 Some(vb.get((1, 1, hidden_size), "mask_token")?)
127 } else {
128 None
129 };
130 let patch_embeddings = PatchEmbeddings::new(cfg, vb.pp("patch_embeddings"))?;
131 let num_patches = patch_embeddings.num_patches;
132 let position_embeddings =
133 vb.get((1, num_patches + 1, hidden_size), "position_embeddings")?;
134 Ok(Self {
135 cls_token,
136 mask_token,
137 patch_embeddings,
138 position_embeddings,
139 hidden_size,
140 })
141 }
142
143 fn interpolate_pos_encoding(
144 &self,
145 _embeddings: &Tensor,
146 _height: usize,
147 _width: usize,
148 ) -> Result<Tensor> {
149 todo!()
150 }
151
152 pub fn forward(
153 &self,
154 pixel_values: &Tensor,
155 bool_masked_pos: Option<&Tensor>,
156 interpolate_pos_encoding: bool,
157 ) -> Result<Tensor> {
158 let (b_size, _num_channels, height, width) = pixel_values.dims4()?;
159 let embeddings = self.patch_embeddings.forward(pixel_values)?;
160 let embeddings = match (bool_masked_pos, &self.mask_token) {
161 (None, _) => embeddings,
162 (Some(_), None) => candle_core::bail!("bool_masked_pos set without mask_token"),
163 (Some(bool_masked_pos), Some(mask_tokens)) => {
164 let seq_len = embeddings.dim(1)?;
165 let mask_tokens = mask_tokens.broadcast_as((b_size, seq_len, self.hidden_size))?;
166 let mask = bool_masked_pos
167 .unsqueeze(D::Minus1)?
168 .to_dtype(mask_tokens.dtype())?;
169 ((mask_tokens * &mask)? - (embeddings * (mask - 1.)?)?)?
170 }
171 };
172 let cls_tokens = self.cls_token.broadcast_as((b_size, 1, self.hidden_size))?;
173 let embeddings = Tensor::cat(&[&cls_tokens, &embeddings], 1)?;
174 if interpolate_pos_encoding {
175 let pos = self.interpolate_pos_encoding(&embeddings, height, width)?;
176 embeddings.broadcast_add(&pos)
177 } else {
178 embeddings.broadcast_add(&self.position_embeddings)
179 }
180 }
181}
182
183#[derive(Debug, Clone)]
184struct SelfAttention {
185 query: Linear,
186 key: Linear,
187 value: Linear,
188 num_attention_heads: usize,
189 attention_head_size: usize,
190}
191
192impl SelfAttention {
193 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
194 let attention_head_size = cfg.hidden_size / cfg.num_attention_heads;
195 let num_attention_heads = cfg.num_attention_heads;
196 let all_head_size = num_attention_heads * attention_head_size;
197 let linear = |name| {
198 if cfg.qkv_bias {
199 linear(cfg.hidden_size, all_head_size, vb.pp(name))
200 } else {
201 linear_no_bias(cfg.hidden_size, all_head_size, vb.pp(name))
202 }
203 };
204 let query = linear("query")?;
205 let key = linear("key")?;
206 let value = linear("value")?;
207 Ok(Self {
208 query,
209 key,
210 value,
211 num_attention_heads,
212 attention_head_size,
213 })
214 }
215
216 fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
217 let (b_size, seq_len, _) = xs.dims3()?;
218 xs.reshape((
219 b_size,
220 seq_len,
221 self.num_attention_heads,
222 self.attention_head_size,
223 ))?
224 .permute((0, 2, 1, 3))
225 }
226}
227
228impl Module for SelfAttention {
229 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
230 let query = self.query.forward(xs)?;
231 let key = self.key.forward(xs)?;
232 let value = self.value.forward(xs)?;
233
234 let query = self.transpose_for_scores(&query)?.contiguous()?;
235 let key = self.transpose_for_scores(&key)?.contiguous()?;
236 let value = self.transpose_for_scores(&value)?.contiguous()?;
237
238 let attention_scores =
239 (query.matmul(&key.t()?)? / f64::sqrt(self.attention_head_size as f64))?;
240 let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
241 attention_probs
242 .matmul(&value)?
243 .permute((0, 2, 1, 3))?
244 .contiguous()?
245 .flatten_from(D::Minus2)
246 }
247}
248
249#[derive(Debug, Clone)]
250struct SelfOutput {
251 dense: Linear,
252}
253
254impl SelfOutput {
255 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
256 let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
257 Ok(Self { dense })
258 }
259}
260
261impl Module for SelfOutput {
262 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
263 xs.apply(&self.dense)
264 }
265}
266
267#[derive(Debug, Clone)]
268struct Attention {
269 attention: SelfAttention,
270 output: SelfOutput,
271}
272
273impl Attention {
274 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
275 let attention = SelfAttention::new(cfg, vb.pp("attention"))?;
276 let output = SelfOutput::new(cfg, vb.pp("output"))?;
277 Ok(Self { attention, output })
278 }
279}
280
281impl Module for Attention {
282 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
283 xs.apply(&self.attention)?.apply(&self.output)
284 }
285}
286
287#[derive(Debug, Clone)]
288struct Intermediate {
289 dense: Linear,
290 intermediate_act_fn: candle_nn::Activation,
291}
292
293impl Intermediate {
294 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
295 let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
296 Ok(Self {
297 dense,
298 intermediate_act_fn: cfg.hidden_act,
299 })
300 }
301}
302
303impl Module for Intermediate {
304 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
305 xs.apply(&self.dense)?.apply(&self.intermediate_act_fn)
306 }
307}
308
309#[derive(Debug, Clone)]
310struct Output {
311 dense: Linear,
312}
313
314impl Output {
315 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
316 let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
317 Ok(Self { dense })
318 }
319
320 fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
321 xs.apply(&self.dense)? + input_tensor
322 }
323}
324
325#[derive(Debug, Clone)]
326struct Layer {
327 attention: Attention,
328 intermediate: Intermediate,
329 output: Output,
330 layernorm_before: LayerNorm,
331 layernorm_after: LayerNorm,
332}
333
334impl Layer {
335 fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
336 let attention = Attention::new(cfg, vb.pp("attention"))?;
337 let intermediate = Intermediate::new(cfg, vb.pp("intermediate"))?;
338 let output = Output::new(cfg, vb.pp("output"))?;
339 let h_sz = cfg.hidden_size;
340 let layernorm_before = layer_norm(h_sz, cfg.layer_norm_eps, vb.pp("layernorm_before"))?;
341 let layernorm_after = layer_norm(h_sz, cfg.layer_norm_eps, vb.pp("layernorm_after"))?;
342 Ok(Self {
343 attention,
344 intermediate,
345 output,
346 layernorm_after,
347 layernorm_before,
348 })
349 }
350}
351
352impl Module for Layer {
353 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
354 let xs = (xs.apply(&self.layernorm_before)?.apply(&self.attention)? + xs)?;
355 let ys = xs.apply(&self.layernorm_after)?.apply(&self.intermediate)?;
356 self.output.forward(&ys, &xs)
357 }
358}
359
360#[derive(Debug, Clone)]
361pub struct Encoder {
362 layers: Vec<Layer>,
363}
364
365impl Encoder {
366 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
367 let vb = vb.pp("layer");
368 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
369 for i in 0..cfg.num_hidden_layers {
370 let layer = Layer::new(cfg, vb.pp(i))?;
371 layers.push(layer)
372 }
373 Ok(Self { layers })
374 }
375}
376
377impl Module for Encoder {
378 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
379 let mut xs = xs.clone();
380 for layer in self.layers.iter() {
381 xs = xs.apply(layer)?
382 }
383 Ok(xs)
384 }
385}
386
387#[derive(Debug, Clone)]
388pub struct StandardVisionTransformer {
389 embeddings: Embeddings,
390 encoder: Encoder,
391 layernorm: LayerNorm,
392}
393
394impl StandardVisionTransformer {
395 pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
396 let vb_v = vb.pp("vit");
397 let embeddings = Embeddings::new(cfg, false, vb_v.pp("embeddings"))?;
398 let encoder = Encoder::new(cfg, vb_v.pp("encoder"))?;
399 let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_v.pp("layernorm"))?;
400 Ok(Self {
401 embeddings,
402 encoder,
403 layernorm,
404 })
405 }
406
407 pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
408 let embedding_output = self.embeddings.forward(xs, None, false)?;
409 let encoder_outputs = self.encoder.forward(&embedding_output)?;
410 encoder_outputs.i((.., 0, ..))?.apply(&self.layernorm)
411 }
412}