Skip to main content

rlx_dinov2/
config.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! DINOv2 configuration. Mirrors Meta's reference configs.
17
18use serde::Deserialize;
19use std::path::Path;
20
/// Per-channel ImageNet-1k mean, applied to RGB pixels scaled to `[0, 1]`.
///
/// Matches `candle-examples::imagenet::load_image*` and the original
/// DINOv2 PyTorch preprocessing.
pub const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];

/// Per-channel ImageNet-1k standard deviation, the companion of
/// `IMAGENET_MEAN` for RGB pixels scaled to `[0, 1]`.
pub const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225];
26
27/// DINOv2 model configuration. `vit_giant` (SwiGLU MLP) is not yet
28/// supported — vit_small / vit_base / vit_large are.
29#[derive(Debug, Clone, Deserialize)]
30pub struct DinoV2Config {
31    pub hidden_size: usize,
32    pub num_hidden_layers: usize,
33    pub num_attention_heads: usize,
34    pub img_size: usize,
35    pub patch_size: usize,
36    #[serde(default = "default_mlp_ratio")]
37    pub mlp_ratio: f64,
38    #[serde(default = "default_dinov2_ln_eps")]
39    pub layer_norm_eps: f64,
40    #[serde(default)]
41    pub num_register_tokens: usize,
42    /// Number of ImageNet classes for the optional classifier head.
43    /// Set to 0 to skip the head entirely (encoder-only output).
44    #[serde(default = "default_num_classes")]
45    pub num_classes: usize,
46}
47
/// Serde fallback for `DinoV2Config::mlp_ratio` when the key is missing.
fn default_mlp_ratio() -> f64 {
    4.0_f64
}
/// Serde fallback for `DinoV2Config::layer_norm_eps` when the key is missing.
fn default_dinov2_ln_eps() -> f64 {
    1.0e-5_f64
}
/// Serde fallback for `DinoV2Config::num_classes` when the key is missing.
fn default_num_classes() -> usize {
    1_000_usize
}
57
58impl DinoV2Config {
59    pub fn from_file(path: &Path) -> anyhow::Result<Self> {
60        let data = std::fs::read_to_string(path)?;
61        Ok(serde_json::from_str(&data)?)
62    }
63
64    pub fn new(
65        img_size: usize,
66        depth: usize,
67        embed_dim: usize,
68        num_heads: usize,
69        num_register_tokens: usize,
70    ) -> Self {
71        Self {
72            hidden_size: embed_dim,
73            num_hidden_layers: depth,
74            num_attention_heads: num_heads,
75            img_size,
76            patch_size: 14,
77            mlp_ratio: 4.0,
78            layer_norm_eps: 1e-5,
79            num_register_tokens,
80            num_classes: 1000,
81        }
82    }
83
84    pub fn intermediate_size(&self) -> usize {
85        (self.hidden_size as f64 * self.mlp_ratio) as usize
86    }
87    pub fn head_dim(&self) -> usize {
88        self.hidden_size / self.num_attention_heads
89    }
90    pub fn num_patches(&self) -> usize {
91        let n = self.img_size / self.patch_size;
92        n * n
93    }
94    pub fn seq_len(&self) -> usize {
95        1 + self.num_register_tokens + self.num_patches()
96    }
97    pub fn patch_dim(&self) -> usize {
98        3 * self.patch_size * self.patch_size
99    }
100
101    pub fn vit_small(img_size: usize) -> Self {
102        Self::new(img_size, 12, 384, 6, 0)
103    }
104    pub fn vit_base(img_size: usize) -> Self {
105        Self::new(img_size, 12, 768, 12, 0)
106    }
107    pub fn vit_large(img_size: usize) -> Self {
108        Self::new(img_size, 24, 1024, 16, 0)
109    }
110}