Skip to main content

pineapple_neural/models/
vit_standard.rs

// NOTE: This implementation of a vision transformer was sourced from the Hugging Face candle
// GitHub repository. A direct copy was required to modify the forward pass, as methods were private.
3//
4// https://github.com/huggingface/candle/blob/main/candle-transformers/src/models/vit.rs
5
6use candle_core::{IndexOp, Module, Result, Tensor, D};
7use candle_nn::{conv2d, linear, linear_no_bias, Conv2d, Linear};
8use candle_nn::{layer_norm, LayerNorm, VarBuilder};
9
10// https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/configuration_vit.py
11#[derive(Debug, Clone, serde::Deserialize)]
12pub struct Config {
13    pub hidden_size: usize,
14    pub num_hidden_layers: usize,
15    pub num_attention_heads: usize,
16    pub intermediate_size: usize,
17    pub hidden_act: candle_nn::Activation,
18    pub layer_norm_eps: f64,
19    pub image_size: usize,
20    pub patch_size: usize,
21    pub num_channels: usize,
22    pub qkv_bias: bool,
23}
24
25impl Config {
26    // https://huggingface.co/google/vit-base-patch16-224/blob/main/config.json
27    pub fn vit_base_patch16_224() -> Self {
28        Self {
29            hidden_size: 768,
30            num_hidden_layers: 12,
31            num_attention_heads: 12,
32            intermediate_size: 3072,
33            hidden_act: candle_nn::Activation::Gelu,
34            layer_norm_eps: 1e-12,
35            image_size: 224,
36            patch_size: 16,
37            num_channels: 3,
38            qkv_bias: true,
39        }
40    }
41
42    pub fn vit_base_subcell() -> Self {
43        Self {
44            hidden_size: 768,
45            num_hidden_layers: 12,
46            num_attention_heads: 12,
47            intermediate_size: 3072,
48            hidden_act: candle_nn::Activation::Gelu,
49            layer_norm_eps: 1e-12,
50            image_size: 448,
51            patch_size: 16,
52            num_channels: 3,
53            qkv_bias: true,
54        }
55    }
56
57    pub fn vit_base_scdino() -> Self {
58        Self {
59            hidden_size: 384,
60            num_hidden_layers: 12,
61            num_attention_heads: 6,
62            intermediate_size: 1536,
63            hidden_act: candle_nn::Activation::Gelu,
64            layer_norm_eps: 1e-6,
65            image_size: 224,
66            patch_size: 16,
67            num_channels: 3,
68            qkv_bias: true,
69        }
70    }
71}
72
73#[derive(Debug, Clone)]
74struct PatchEmbeddings {
75    num_patches: usize,
76    projection: Conv2d,
77}
78
79impl PatchEmbeddings {
80    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
81        let image_size = cfg.image_size;
82        let patch_size = cfg.patch_size;
83        let num_patches = (image_size / patch_size) * (image_size / patch_size);
84        let conv_cfg = candle_nn::Conv2dConfig {
85            stride: patch_size,
86            ..Default::default()
87        };
88        let projection = conv2d(
89            cfg.num_channels,
90            cfg.hidden_size,
91            patch_size,
92            conv_cfg,
93            vb.pp("projection"),
94        )?;
95        Ok(Self {
96            num_patches,
97            projection,
98        })
99    }
100}
101
102impl Module for PatchEmbeddings {
103    fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
104        let (_b_size, _num_channels, _height, _width) = pixel_values.dims4()?;
105        self.projection
106            .forward(pixel_values)?
107            .flatten_from(2)?
108            .transpose(1, 2)
109    }
110}
111
112#[derive(Debug, Clone)]
113pub struct Embeddings {
114    cls_token: Tensor,
115    mask_token: Option<Tensor>,
116    patch_embeddings: PatchEmbeddings,
117    position_embeddings: Tensor,
118    hidden_size: usize,
119}
120
121impl Embeddings {
122    pub fn new(cfg: &Config, use_mask_token: bool, vb: VarBuilder) -> Result<Self> {
123        let hidden_size = cfg.hidden_size;
124        let cls_token = vb.get((1, 1, hidden_size), "cls_token")?;
125        let mask_token = if use_mask_token {
126            Some(vb.get((1, 1, hidden_size), "mask_token")?)
127        } else {
128            None
129        };
130        let patch_embeddings = PatchEmbeddings::new(cfg, vb.pp("patch_embeddings"))?;
131        let num_patches = patch_embeddings.num_patches;
132        let position_embeddings =
133            vb.get((1, num_patches + 1, hidden_size), "position_embeddings")?;
134        Ok(Self {
135            cls_token,
136            mask_token,
137            patch_embeddings,
138            position_embeddings,
139            hidden_size,
140        })
141    }
142
143    fn interpolate_pos_encoding(
144        &self,
145        _embeddings: &Tensor,
146        _height: usize,
147        _width: usize,
148    ) -> Result<Tensor> {
149        todo!()
150    }
151
152    pub fn forward(
153        &self,
154        pixel_values: &Tensor,
155        bool_masked_pos: Option<&Tensor>,
156        interpolate_pos_encoding: bool,
157    ) -> Result<Tensor> {
158        let (b_size, _num_channels, height, width) = pixel_values.dims4()?;
159        let embeddings = self.patch_embeddings.forward(pixel_values)?;
160        let embeddings = match (bool_masked_pos, &self.mask_token) {
161            (None, _) => embeddings,
162            (Some(_), None) => candle_core::bail!("bool_masked_pos set without mask_token"),
163            (Some(bool_masked_pos), Some(mask_tokens)) => {
164                let seq_len = embeddings.dim(1)?;
165                let mask_tokens = mask_tokens.broadcast_as((b_size, seq_len, self.hidden_size))?;
166                let mask = bool_masked_pos
167                    .unsqueeze(D::Minus1)?
168                    .to_dtype(mask_tokens.dtype())?;
169                ((mask_tokens * &mask)? - (embeddings * (mask - 1.)?)?)?
170            }
171        };
172        let cls_tokens = self.cls_token.broadcast_as((b_size, 1, self.hidden_size))?;
173        let embeddings = Tensor::cat(&[&cls_tokens, &embeddings], 1)?;
174        if interpolate_pos_encoding {
175            let pos = self.interpolate_pos_encoding(&embeddings, height, width)?;
176            embeddings.broadcast_add(&pos)
177        } else {
178            embeddings.broadcast_add(&self.position_embeddings)
179        }
180    }
181}
182
183#[derive(Debug, Clone)]
184struct SelfAttention {
185    query: Linear,
186    key: Linear,
187    value: Linear,
188    num_attention_heads: usize,
189    attention_head_size: usize,
190}
191
192impl SelfAttention {
193    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
194        let attention_head_size = cfg.hidden_size / cfg.num_attention_heads;
195        let num_attention_heads = cfg.num_attention_heads;
196        let all_head_size = num_attention_heads * attention_head_size;
197        let linear = |name| {
198            if cfg.qkv_bias {
199                linear(cfg.hidden_size, all_head_size, vb.pp(name))
200            } else {
201                linear_no_bias(cfg.hidden_size, all_head_size, vb.pp(name))
202            }
203        };
204        let query = linear("query")?;
205        let key = linear("key")?;
206        let value = linear("value")?;
207        Ok(Self {
208            query,
209            key,
210            value,
211            num_attention_heads,
212            attention_head_size,
213        })
214    }
215
216    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
217        let (b_size, seq_len, _) = xs.dims3()?;
218        xs.reshape((
219            b_size,
220            seq_len,
221            self.num_attention_heads,
222            self.attention_head_size,
223        ))?
224        .permute((0, 2, 1, 3))
225    }
226}
227
228impl Module for SelfAttention {
229    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
230        let query = self.query.forward(xs)?;
231        let key = self.key.forward(xs)?;
232        let value = self.value.forward(xs)?;
233
234        let query = self.transpose_for_scores(&query)?.contiguous()?;
235        let key = self.transpose_for_scores(&key)?.contiguous()?;
236        let value = self.transpose_for_scores(&value)?.contiguous()?;
237
238        let attention_scores =
239            (query.matmul(&key.t()?)? / f64::sqrt(self.attention_head_size as f64))?;
240        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
241        attention_probs
242            .matmul(&value)?
243            .permute((0, 2, 1, 3))?
244            .contiguous()?
245            .flatten_from(D::Minus2)
246    }
247}
248
249#[derive(Debug, Clone)]
250struct SelfOutput {
251    dense: Linear,
252}
253
254impl SelfOutput {
255    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
256        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
257        Ok(Self { dense })
258    }
259}
260
261impl Module for SelfOutput {
262    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
263        xs.apply(&self.dense)
264    }
265}
266
267#[derive(Debug, Clone)]
268struct Attention {
269    attention: SelfAttention,
270    output: SelfOutput,
271}
272
273impl Attention {
274    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
275        let attention = SelfAttention::new(cfg, vb.pp("attention"))?;
276        let output = SelfOutput::new(cfg, vb.pp("output"))?;
277        Ok(Self { attention, output })
278    }
279}
280
281impl Module for Attention {
282    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
283        xs.apply(&self.attention)?.apply(&self.output)
284    }
285}
286
287#[derive(Debug, Clone)]
288struct Intermediate {
289    dense: Linear,
290    intermediate_act_fn: candle_nn::Activation,
291}
292
293impl Intermediate {
294    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
295        let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
296        Ok(Self {
297            dense,
298            intermediate_act_fn: cfg.hidden_act,
299        })
300    }
301}
302
303impl Module for Intermediate {
304    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
305        xs.apply(&self.dense)?.apply(&self.intermediate_act_fn)
306    }
307}
308
309#[derive(Debug, Clone)]
310struct Output {
311    dense: Linear,
312}
313
314impl Output {
315    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
316        let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
317        Ok(Self { dense })
318    }
319
320    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
321        xs.apply(&self.dense)? + input_tensor
322    }
323}
324
325#[derive(Debug, Clone)]
326struct Layer {
327    attention: Attention,
328    intermediate: Intermediate,
329    output: Output,
330    layernorm_before: LayerNorm,
331    layernorm_after: LayerNorm,
332}
333
334impl Layer {
335    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
336        let attention = Attention::new(cfg, vb.pp("attention"))?;
337        let intermediate = Intermediate::new(cfg, vb.pp("intermediate"))?;
338        let output = Output::new(cfg, vb.pp("output"))?;
339        let h_sz = cfg.hidden_size;
340        let layernorm_before = layer_norm(h_sz, cfg.layer_norm_eps, vb.pp("layernorm_before"))?;
341        let layernorm_after = layer_norm(h_sz, cfg.layer_norm_eps, vb.pp("layernorm_after"))?;
342        Ok(Self {
343            attention,
344            intermediate,
345            output,
346            layernorm_after,
347            layernorm_before,
348        })
349    }
350}
351
352impl Module for Layer {
353    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
354        let xs = (xs.apply(&self.layernorm_before)?.apply(&self.attention)? + xs)?;
355        let ys = xs.apply(&self.layernorm_after)?.apply(&self.intermediate)?;
356        self.output.forward(&ys, &xs)
357    }
358}
359
360#[derive(Debug, Clone)]
361pub struct Encoder {
362    layers: Vec<Layer>,
363}
364
365impl Encoder {
366    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
367        let vb = vb.pp("layer");
368        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
369        for i in 0..cfg.num_hidden_layers {
370            let layer = Layer::new(cfg, vb.pp(i))?;
371            layers.push(layer)
372        }
373        Ok(Self { layers })
374    }
375}
376
377impl Module for Encoder {
378    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
379        let mut xs = xs.clone();
380        for layer in self.layers.iter() {
381            xs = xs.apply(layer)?
382        }
383        Ok(xs)
384    }
385}
386
387#[derive(Debug, Clone)]
388pub struct StandardVisionTransformer {
389    embeddings: Embeddings,
390    encoder: Encoder,
391    layernorm: LayerNorm,
392}
393
394impl StandardVisionTransformer {
395    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
396        let vb_v = vb.pp("vit");
397        let embeddings = Embeddings::new(cfg, false, vb_v.pp("embeddings"))?;
398        let encoder = Encoder::new(cfg, vb_v.pp("encoder"))?;
399        let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_v.pp("layernorm"))?;
400        Ok(Self {
401            embeddings,
402            encoder,
403            layernorm,
404        })
405    }
406
407    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
408        let embedding_output = self.embeddings.forward(xs, None, false)?;
409        let encoder_outputs = self.encoder.forward(&embedding_output)?;
410        encoder_outputs.i((.., 0, ..))?.apply(&self.layernorm)
411    }
412}