burn_dinov3 0.1.2

DINOv3 with burn & Rust
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
use burn::{
    config::Config,
    module::{Module, Param},
    nn::{
        Dropout, DropoutConfig, Gelu, Initializer, LayerNorm, LayerNormConfig, Linear,
        LinearConfig,
        conv::{Conv2d, Conv2dConfig},
    },
    tensor::{Bool, FloatDType, Tensor, backend::Backend, module, ops::AttentionModuleOptions},
};

/// Patch embedding: a convolution whose kernel and stride are both `patch_size`, followed by
/// flattening of the resulting feature map into a token sequence.
#[derive(Module, Debug)]
pub struct PatchEmbed<B: Backend> {
    /// Convolution mapping `in_chans` input channels to `embed_dim` output channels.
    pub proj: Conv2d<B>,
    /// Side length of the square kernel, also used as the stride.
    pub patch_size: usize,
    /// Number of output channels of `proj` (token width).
    pub embed_dim: usize,
}

impl<B: Backend> PatchEmbed<B> {
    /// Creates the patch embedding with a `patch_size x patch_size` kernel and equal stride.
    pub fn new(in_chans: usize, embed_dim: usize, patch_size: usize, device: &B::Device) -> Self {
        let window = [patch_size, patch_size];
        Self {
            proj: Conv2dConfig::new([in_chans, embed_dim], window)
                .with_stride(window)
                .init(device),
            patch_size,
            embed_dim,
        }
    }

    /// Embeds `x` (presumably `[batch, in_chans, H, W]`) into tokens.
    ///
    /// Returns `(tokens, grid_h, grid_w)` where `tokens` is `[batch, grid_h * grid_w, embed_dim]`
    /// and `grid_h`/`grid_w` are the spatial sizes of the convolution output.
    pub fn forward(&self, x: Tensor<B, 4>) -> (Tensor<B, 3>, usize, usize) {
        let feature_map = self.proj.forward(x); // [b, embed_dim, grid_h, grid_w]
        let [_, _, grid_h, grid_w] = feature_map.dims();
        // [b, embed_dim, grid_h * grid_w] -> [b, grid_h * grid_w, embed_dim]
        let tokens = feature_map.flatten(2, 3).swap_dims(1, 2);
        (tokens, grid_h, grid_w)
    }
}

/// 2-D rotary position embedding (RoPE) table generator for a grid of patch tokens.
#[derive(Module, Debug)]
pub struct RopePositionEmbedding<B: Backend> {
    /// Rotation periods, one per frequency (`d_head / 4` entries when `d_head` is a multiple of
    /// 4). They are stored in the reference checkpoints, so they are kept as a parameter (with
    /// gradients disabled) instead of being recomputed on load.
    pub periods: Param<Tensor<B, 1>>,
    /// Per-head feature size (`embed_dim / num_heads`).
    pub d_head: usize,
}

impl<B: Backend> RopePositionEmbedding<B> {
    /// Builds the periods `base^(4i / d_head)` for `i = 0, 1, ...` with `4i < d_head`.
    pub fn new(embed_dim: usize, num_heads: usize, base: f32, device: &B::Device) -> Self {
        let d_head = embed_dim / num_heads;

        // Exponents 0/d, 4/d, 8/d, ... (step 4 over the head dimension).
        let exponents = Tensor::arange_step(0..d_head as i64, 4, device).float() / d_head as f32;
        let periods = Tensor::from_floats([base], device).powf(exponents);

        Self {
            periods: Param::from_tensor(periods).no_grad(), // fixed, never trained
            d_head,
        }
    }

    /// Computes `(sin, cos)` tables, each `[height * width, 2 * periods_len * 2]`
    /// (i.e. `[hw, d_head]` when `d_head` is a multiple of 4), for a `height x width` token grid.
    ///
    /// Token `h * width + w` uses the normalized coordinates of row `h` and column `w`.
    pub fn forward(&self, height: usize, width: usize) -> (Tensor<B, 2>, Tensor<B, 2>) {
        let device = self.periods.device();

        // Cell-centre coordinates in (0, 1) along each axis.
        let rows = (Tensor::arange(0..height as i64, &device).float() + 0.5) / (height as f32);
        let cols = (Tensor::arange(0..width as i64, &device).float() + 0.5) / (width as f32);

        // Row-major grid of (row, col) coordinate pairs, then mapped to (-1, 1).
        let row_coords = rows.unsqueeze_dim::<2>(1).repeat_dim(1, width).reshape([-1, 1]);
        let col_coords = cols.unsqueeze::<2>().repeat_dim(0, height).reshape([-1, 1]);
        let grid = Tensor::cat(vec![row_coords, col_coords], 1) * 2.0 - 1.0; // [hw, 2]

        // Pretrained periods may be loaded as BF16; compute the angles in F32.
        let periods = self
            .periods
            .val()
            .cast(FloatDType::F32)
            .unsqueeze::<3>(); // [1, 1, n]

        // [hw, 2, 1] / [1, 1, n] -> [hw, 2, n] -> [hw, 2n]
        let half = (grid.unsqueeze_dim::<3>(2) * std::f64::consts::PI * 2.0 / periods)
            .flatten(1, 2);
        // Duplicate along features to cover both halves that the rotation pairs up.
        let full = Tensor::cat(vec![half.clone(), half], 1);

        (full.clone().sin(), full.cos())
    }
}

#[derive(Module, Debug)]
pub struct LinearKMaskedBias<B: Backend> {
    pub linear: Linear<B>,
    pub bias_mask: Param<Tensor<B, 1>>,
}

impl<B: Backend> LinearKMaskedBias<B> {
    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        let masked_bias = self
            .linear
            .bias
            .as_ref()
            .map(|b| b.val() * self.bias_mask.val());

        module::linear(input, self.linear.weight.val(), masked_bias)
    }
}

/// Learnable per-channel scaling of the last dimension of a `[batch, seq, dim]` tensor.
#[derive(Module, Debug)]
pub struct LayerScale<B: Backend> {
    /// Scale factors, one per channel.
    pub gamma: Param<Tensor<B, 1>>,
}

impl<B: Backend> LayerScale<B> {
    /// Creates a layer scale with every factor set to `init_values`.
    pub fn new(dim: usize, init_values: f32, device: &B::Device) -> Self {
        let initial = Tensor::ones([dim], device).mul_scalar(init_values);
        Self {
            gamma: Param::from_tensor(initial),
        }
    }

    /// Multiplies each channel of `x` by the matching entry of `gamma`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let [_, _, channels] = x.dims();
        let scale = self.gamma.val().reshape([1, 1, channels]); // broadcast over batch and seq
        x * scale
    }
}

/// Low-rank adapter computing `x @ a @ b`.
#[derive(Module, Debug)]
pub struct LoRA<B: Backend> {
    /// Down-projection, `[dim, rank]`.
    pub a: Param<Tensor<B, 2>>,
    /// Up-projection, `[rank, dim]`.
    pub b: Param<Tensor<B, 2>>,
}

/// Configuration for [`LoRA`].
#[derive(Config, Debug)]
pub struct LoRAConfig {
    /// Input and output feature size of the adapter.
    pub dim: usize,
    /// Rank of the low-rank factorization.
    pub rank: usize,
    /// Initializer for `a` (zeros by default, so the adapter output starts at zero).
    #[config(default = "Initializer::Zeros")]
    pub a_initializer: Initializer,
    /// Initializer for `b`.
    #[config(default = "Initializer::KaimingUniform{gain:1.0/3.0f64.sqrt(), fan_out_only:false}")]
    pub b_initializer: Initializer,
}

impl LoRAConfig {
    /// Initializes the adapter on `device`.
    ///
    /// Fan-in and fan-out are passed explicitly via `init_with`. Fan-based initializers such
    /// as the default Kaiming `b_initializer` panic when no fan is supplied, which is what
    /// the plain `Initializer::init` does.
    pub fn init<B: Backend>(&self, device: &B::Device) -> LoRA<B> {
        LoRA {
            a: self.a_initializer.init_with(
                [self.dim, self.rank],
                Some(self.dim),
                Some(self.rank),
                device,
            ),
            b: self.b_initializer.init_with(
                [self.rank, self.dim],
                Some(self.rank),
                Some(self.dim),
                device,
            ),
        }
    }
}

impl<B: Backend> LoRA<B> {
    /// Applies the adapter.
    ///
    /// x: [batch_size, seq, dim]
    /// out: [batch_size, seq, dim]
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // [b, seq, dim] @ [1, dim, r] -> [b, seq, r]
        // [b, seq, r] @ [1, r, dim] -> [b, seq, dim]
        x.matmul(self.a.val().unsqueeze())
            .matmul(self.b.val().unsqueeze())
    }
}

/// Multi-head self-attention with a masked QKV bias, an optional LoRA adapter added to the
/// query and value projections, and optional rotary position embedding on q and k.
#[derive(Module, Debug)]
pub struct Attention<B: Backend> {
    /// Fused query/key/value projection `dim -> 3 * dim` with a masked bias.
    pub qkv: LinearKMaskedBias<B>,
    /// Output projection `dim -> dim`.
    pub proj: Linear<B>,
    /// Dropout applied after the output projection.
    pub drop_out: Dropout,
    /// Optional low-rank adapter; its output is added to both q and v.
    pub lora: Option<LoRA<B>>,
    /// Number of attention heads; `dim` is expected to be divisible by it.
    pub num_heads: usize,
}

impl<B: Backend> Attention<B> {
    /// Creates an attention layer of width `dim` split into `num_heads` heads.
    ///
    /// The QKV bias mask is laid out as `[q | k | v]` and initialized to keep the query and
    /// value biases (`1`) while zeroing the key bias (`0`). An all-zero mask would silently
    /// discard the whole QKV bias of a freshly built model. Loading a pretrained record
    /// overwrites the mask.
    pub fn new(dim: usize, num_heads: usize, device: &B::Device) -> Self {
        let bias_mask = Tensor::cat(
            vec![
                Tensor::ones([dim], device),  // q
                Tensor::zeros([dim], device), // k
                Tensor::ones([dim], device),  // v
            ],
            0,
        );
        Self {
            qkv: LinearKMaskedBias {
                linear: LinearConfig::new(dim, dim * 3).with_bias(true).init(device),
                // Fixed buffer, not a trainable weight.
                bias_mask: Param::from_tensor(bias_mask).no_grad(),
            },
            proj: LinearConfig::new(dim, dim).with_bias(true).init(device),
            drop_out: DropoutConfig::new(0.0).init(), // did not see any config other than 0 in facebook repo
            lora: None,
            num_heads,
        }
    }

    /// Self-attention over `x` (`[batch, seq, dim]`).
    ///
    /// `rope`, when given, is the `(sin, cos)` pair (each `[rope_seq, head_dim]`) applied to the
    /// last `rope_seq` tokens of q and k; earlier tokens are left unrotated.
    pub fn forward(
        &self,
        x: Tensor<B, 3>,
        rope: Option<&(Tensor<B, 2>, Tensor<B, 2>)>,
    ) -> Tensor<B, 3> {
        let [batch_size, seq_len, dim] = x.dims();
        let head_dim = dim / self.num_heads;

        // [b, seq, dim] -> [b, seq, 3 * dim] -> [b, seq, 3, nh, dh]
        let qkv = self
            .qkv
            .forward(x.clone())
            .reshape([batch_size, seq_len, 3, self.num_heads, head_dim]);

        // Split into q, k, v, each [b, nh, seq, dh].
        let [mut q, mut k, mut v]: [Tensor<B, 4>; 3] = qkv
            .chunk(3, 2)
            .into_iter()
            .map(|tensor| tensor.squeeze_dim::<4>(2).swap_dims(1, 2))
            .collect::<Vec<_>>()
            .try_into()
            .expect("chunking a size-3 dimension into 3 yields exactly three tensors");

        if let Some(lora) = self.lora.as_ref() {
            // The adapter output is identical for q and v, so compute it once.
            let delta = lora
                .forward(x)
                .reshape([batch_size, seq_len, self.num_heads, head_dim])
                .swap_dims(1, 2);
            q = q + delta.clone();
            v = v + delta;
        }

        if let Some((sin, cos)) = rope {
            q = Self::apply_rope(q, sin, cos);
            k = Self::apply_rope(k, sin, cos);
        }

        let out = module::attention(q, k, v, None, None, AttentionModuleOptions::default());
        let out = out.swap_dims(1, 2).reshape([batch_size, seq_len, dim]);

        self.drop_out.forward(self.proj.forward(out))
    }

    /// Rotates the last `rope_seq` tokens of `x` (`[b, nh, seq, dh]`); the leading
    /// `seq - rope_seq` tokens (class/storage tokens) pass through unchanged.
    ///
    /// # Panics
    ///
    /// If the sin/cos table has more rows than `x` has tokens.
    fn apply_rope(x: Tensor<B, 4>, sin: &Tensor<B, 2>, cos: &Tensor<B, 2>) -> Tensor<B, 4> {
        let [_, _, seq, _] = x.dims();
        let [rope_seq, _] = sin.dims();
        assert!(
            rope_seq <= seq,
            "RoPE table has {rope_seq} rows but the sequence only has {seq} tokens"
        );
        let num_cls_and_storage_tokens = seq - rope_seq;

        if num_cls_and_storage_tokens == 0 {
            return Self::rotate(x, sin, cos);
        }

        let [prefix, patches]: [Tensor<B, 4>; 2] = x
            .split_with_sizes(vec![num_cls_and_storage_tokens, rope_seq], 2)
            .try_into()
            .expect("split_with_sizes with two sizes yields two tensors");

        Tensor::cat(vec![prefix, Self::rotate(patches, sin, cos)], 2)
    }

    /// Computes `x * cos + rotate_half(x) * sin`, where `rotate_half([x1, x2]) = [-x2, x1]`
    /// splits the head dimension into two equal halves.
    fn rotate(x: Tensor<B, 4>, sin: &Tensor<B, 2>, cos: &Tensor<B, 2>) -> Tensor<B, 4> {
        let [_, _, seq, head_dim] = x.dims();
        let [x1, x2]: [Tensor<B, 4>; 2] = x
            .clone()
            .split(head_dim / 2, 3)
            .try_into()
            .expect("head_dim must be even for rotary embedding");

        let rotated = Tensor::cat(vec![x2.mul_scalar(-1.0), x1], 3);
        (x * cos.clone().reshape([1, 1, seq, head_dim]))
            + (rotated * sin.clone().reshape([1, 1, seq, head_dim]))
    }
}

/// Two-layer feed-forward network: `Linear -> GELU -> Linear`.
#[derive(Module, Debug)]
pub struct Mlp<B: Backend> {
    /// Projection `in_features -> hidden_features`.
    pub fc1: Linear<B>,
    /// Activation between the two projections.
    pub act: Gelu,
    /// Projection `hidden_features -> in_features`.
    pub fc2: Linear<B>,
}

impl<B: Backend> Mlp<B> {
    /// Builds the network; both projections use the `LinearConfig` defaults.
    pub fn new(in_features: usize, hidden_features: usize, device: &B::Device) -> Self {
        let expand = LinearConfig::new(in_features, hidden_features).init(device);
        let contract = LinearConfig::new(hidden_features, in_features).init(device);
        Self {
            fc1: expand,
            act: Gelu::new(),
            fc2: contract,
        }
    }

    /// Applies `fc2(gelu(fc1(x)))`; the output has the same shape as `x`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let hidden = self.act.forward(self.fc1.forward(x));
        self.fc2.forward(hidden)
    }
}

/// Pre-norm transformer block: attention and MLP sub-layers, each wrapped as
/// `x + layer_scale(sublayer(norm(x)))`.
#[derive(Module, Debug)]
pub struct Block<B: Backend> {
    /// Norm applied before attention.
    pub norm1: LayerNorm<B>,
    /// Self-attention sub-layer.
    pub attn: Attention<B>,
    /// Scale applied to the attention output.
    pub ls1: LayerScale<B>,
    /// Norm applied before the MLP.
    pub norm2: LayerNorm<B>,
    /// Feed-forward sub-layer.
    pub mlp: Mlp<B>,
    /// Scale applied to the MLP output.
    pub ls2: LayerScale<B>,
}

impl<B: Backend> Block<B> {
    /// Creates a block of width `dim`; the MLP hidden size is `dim * ffn_ratio` (truncated).
    pub fn new(dim: usize, num_heads: usize, ffn_ratio: f64, device: &B::Device) -> Self {
        let mlp_hidden = (dim as f64 * ffn_ratio) as usize;
        let norm = || LayerNormConfig::new(dim).with_bias(true);
        Self {
            norm1: norm().init(device),
            attn: Attention::new(dim, num_heads, device),
            ls1: LayerScale::new(dim, 1e-5, device),
            norm2: norm().init(device),
            mlp: Mlp::new(dim, mlp_hidden, device),
            ls2: LayerScale::new(dim, 1e-5, device),
        }
    }

    /// Runs both residual sub-layers; `rope` is forwarded to the attention layer.
    pub fn forward(
        &self,
        x: Tensor<B, 3>,
        rope: Option<&(Tensor<B, 2>, Tensor<B, 2>)>,
    ) -> Tensor<B, 3> {
        let attn_branch = self.attn.forward(self.norm1.forward(x.clone()), rope);
        let x = x + self.ls1.forward(attn_branch);

        let mlp_branch = self.mlp.forward(self.norm2.forward(x.clone()));
        x + self.ls2.forward(mlp_branch)
    }
}

/// DINOv3 Vision Transformer backbone.
///
/// Pipeline:
/// 1. Images are patch-embedded, and patch tokens are optionally masked.
/// 2. The sequence is prefixed with one class token and `num_storage_tokens` storage tokens.
/// 3. The tokens go through `blocks`, with 2-D RoPE applied to the patch tokens.
/// 4. The result is layer-normalized.
#[derive(Module, Debug)]
pub struct DinoVisionTransformer<B: Backend> {
    /// Strided-convolution patch embedding (3 input channels).
    pub patch_embed: PatchEmbed<B>,
    /// Class token, `[1, 1, embed_dim]`.
    pub cls_token: Param<Tensor<B, 3>>,
    /// Storage tokens, `[1, num_storage_tokens, embed_dim]`.
    pub storage_tokens: Param<Tensor<B, 3>>,
    /// Rotary position embedding for the patch grid.
    pub rope_embed: RopePositionEmbedding<B>,
    /// Transformer blocks, applied in order.
    pub blocks: Vec<Block<B>>,
    /// Final layer norm applied to every output token.
    pub norm: LayerNorm<B>,
    /// Token written into masked patch positions, `[1, embed_dim]`.
    pub mask_token: Option<Param<Tensor<B, 2>>>,
}

impl<B: Backend> DinoVisionTransformer<B> {
    /// Builds a backbone with `depth` independently initialized blocks.
    pub fn new(
        patch_size: usize,
        embed_dim: usize,
        num_storage_tokens: usize,
        depth: usize,
        num_heads: usize,
        ffn_ratio: f64,
        device: &B::Device,
    ) -> Self {
        let patch_embed = PatchEmbed::new(3, embed_dim, patch_size, device);
        let cls_token = Param::from_tensor(Tensor::zeros([1, 1, embed_dim], device));
        let storage_tokens =
            Param::from_tensor(Tensor::zeros([1, num_storage_tokens, embed_dim], device));
        let rope_embed = RopePositionEmbedding::new(embed_dim, num_heads, 100.0, device).no_grad();
        // Construct each block separately. `vec![block; depth]` would clone a single block,
        // giving every layer the same initial weights and the same parameter ids. That can
        // make gradients/optimizer state of different layers collide.
        let blocks = (0..depth)
            .map(|_| Block::new(embed_dim, num_heads, ffn_ratio, device))
            .collect();
        let norm = LayerNormConfig::new(embed_dim).init(device);
        let mask_token = Param::from_tensor(Tensor::zeros([1, embed_dim], device));
        Self {
            patch_embed,
            cls_token,
            storage_tokens,
            rope_embed,
            blocks,
            norm,
            mask_token: Some(mask_token),
        }
    }

    /// Encodes images `x` (presumably `[batch, 3, H, W]`) into
    /// `[batch, 1 + num_storage_tokens + num_patches, embed_dim]`.
    ///
    /// `masks` marks patch positions to replace with `mask_token`. Two layouts are accepted:
    /// - a per-token mask `[batch, num_patches]` (or `[1, num_patches]`, shared by the batch);
    /// - any 2-D mask with `num_patches * embed_dim` elements, applied element-wise and shared
    ///   by the batch (the original layout, still used when the shape is ambiguous).
    pub fn forward(&self, x: Tensor<B, 4>, masks: Option<&Tensor<B, 2, Bool>>) -> Tensor<B, 3> {
        let (mut x, height, width) = self.patch_embed.forward(x);
        let [batch_size, seq, dim] = x.dims();

        if let Some(masks) = masks
            && let Some(mask_token) = self.mask_token.as_ref()
        {
            let mask = Self::expand_mask(masks, batch_size, seq, dim);
            let fill = mask_token
                .val()
                .reshape([1, 1, dim])
                .repeat_dim(1, seq)
                .repeat_dim(0, batch_size);
            x = x.mask_where(mask, fill);
        }

        // rotary position encoding 2d
        let rope = self.rope_embed.forward(height, width);

        let cls_token_batch = self.cls_token.val().repeat_dim(0, batch_size);
        let storage_tokens_batch = self.storage_tokens.val().repeat_dim(0, batch_size);
        x = Tensor::cat(vec![cls_token_batch, storage_tokens_batch, x], 1);

        for block in &self.blocks {
            x = block.forward(x, Some(&rope));
        }

        self.norm.forward(x)
    }

    /// Broadcasts a 2-D boolean mask to the `[batch_size, seq, dim]` patch-token shape.
    fn expand_mask(
        masks: &Tensor<B, 2, Bool>,
        batch_size: usize,
        seq: usize,
        dim: usize,
    ) -> Tensor<B, 3, Bool> {
        match masks.dims() {
            // Per-token mask: one row per image, or one row shared by the whole batch.
            [rows, cols]
                if cols == seq
                    && (rows == batch_size || rows == 1)
                    && (rows, cols) != (seq, dim) =>
            {
                let mask = masks.clone().unsqueeze_dim::<3>(2).repeat_dim(2, dim);
                if rows == 1 {
                    mask.repeat_dim(0, batch_size)
                } else {
                    mask
                }
            }
            // Element-wise mask shared across the batch.
            _ => masks
                .clone()
                .reshape([1, seq, dim])
                .repeat_dim(0, batch_size),
        }
    }
}

/// Shared builder for the presets below; every preset uses four storage tokens.
fn dinov3_preset<B: Backend>(
    patch_size: usize,
    embed_dim: usize,
    depth: usize,
    num_heads: usize,
    ffn_ratio: f64,
    device: &B::Device,
) -> DinoVisionTransformer<B> {
    const NUM_STORAGE_TOKENS: usize = 4;
    DinoVisionTransformer::new(
        patch_size,
        embed_dim,
        NUM_STORAGE_TOKENS,
        depth,
        num_heads,
        ffn_ratio,
        device,
    )
}

/// ViT-S: width 384, 12 blocks, 6 heads, MLP ratio 4.
pub fn vit_small<B: Backend>(patch_size: usize, device: &B::Device) -> DinoVisionTransformer<B> {
    dinov3_preset(patch_size, 384, 12, 6, 4.0, device)
}

/// ViT-B: width 768, 12 blocks, 12 heads, MLP ratio 4.
pub fn vit_base<B: Backend>(patch_size: usize, device: &B::Device) -> DinoVisionTransformer<B> {
    dinov3_preset(patch_size, 768, 12, 12, 4.0, device)
}

/// ViT-L: width 1024, 24 blocks, 16 heads, MLP ratio 4.
pub fn vit_large<B: Backend>(patch_size: usize, device: &B::Device) -> DinoVisionTransformer<B> {
    dinov3_preset(patch_size, 1024, 24, 16, 4.0, device)
}

/// ViT-SO400M: width 1152, 27 blocks, 18 heads, MLP ratio ~3.78.
pub fn vit_so400m<B: Backend>(patch_size: usize, device: &B::Device) -> DinoVisionTransformer<B> {
    dinov3_preset(patch_size, 1152, 27, 18, 3.777777778, device)
}

/// ViT-H2: width 1280, 32 blocks, 20 heads, MLP ratio 4.
pub fn vit_huge2<B: Backend>(patch_size: usize, device: &B::Device) -> DinoVisionTransformer<B> {
    dinov3_preset(patch_size, 1280, 32, 20, 4.0, device)
}

/// ViT-G2: width 1536, 40 blocks, 24 heads, MLP ratio 4.
pub fn vit_giant2<B: Backend>(patch_size: usize, device: &B::Device) -> DinoVisionTransformer<B> {
    dinov3_preset(patch_size, 1536, 40, 24, 4.0, device)
}

/// ViT-7B: width 4096, 40 blocks, 32 heads, MLP ratio 3.
pub fn vit_7b<B: Backend>(patch_size: usize, device: &B::Device) -> DinoVisionTransformer<B> {
    dinov3_preset(patch_size, 4096, 40, 32, 3.0, device)
}