Skip to main content

rlx_dinov2/
lib.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! DINOv2 — Meta's self-supervised ViT (with optional registers).
17//!
18//! Public entry points:
19//!   - [`DinoV2Config`] — model dimensions + variant factories
20//!     (`vit_small`, `vit_base`, `vit_large`)
21//!   - [`build_dinov2_graph_sized`] — emits the IR graph and the
22//!     host-side [`DinoV2PreprocessWeights`].
23//!   - [`assemble_hidden`] / [`rgb_u8_to_imagenet_nchw`] — host-side
24//!     image → encoder-input plumbing.
25//!
26//! Weight keys match Meta / candle's safetensors so checkpoints from
27//! the HF Hub (e.g. `lmz/candle-dino-v2`) load with no remapping.
28
29pub mod builder;
30pub mod cli;
31pub mod config;
32pub mod flow;
33pub mod packed_gguf;
34pub mod preprocess;
35pub mod runner;
36
37pub use builder::build_dinov2_graph_sized;
38pub use config::{DinoV2Config, IMAGENET_MEAN, IMAGENET_STD};
39pub use flow::{DinoV2Built, DinoV2Flow, build_dinov2_built, build_dinov2_built_with_packed};
40pub use packed_gguf::{gguf_has_packed_linears, load_dinov2_from_gguf};
41pub use preprocess::{DinoV2PreprocessWeights, assemble_hidden, rgb_u8_to_imagenet_nchw};
42pub use runner::{DinoV2Output, DinoV2Runner, DinoV2RunnerBuilder, DinoV2Variant};
43
#[cfg(test)]
mod tests {
    // Graph-construction and host-side preprocessing tests. Every test runs
    // against a synthetic, mostly zero-filled checkpoint, so no real weights
    // are required.
    use super::*;
    use rlx_core::weight_map::WeightMap;
    use std::collections::HashMap;

    /// Builds a zero-filled [`WeightMap`] holding the tensors inserted below,
    /// with every shape derived from `cfg`.
    ///
    /// `register_tokens` is inserted only when `cfg.num_register_tokens > 0`,
    /// and `head.weight` / `head.bias` only when `cfg.num_classes > 0`.
    fn synthetic_weights(cfg: &DinoV2Config) -> WeightMap {
        let h = cfg.hidden_size;
        let int_dim = cfg.intermediate_size();
        let ps = cfg.patch_size;
        let pd = cfg.patch_dim();
        let seq = cfg.seq_len();

        // key -> (flat data, shape)
        let mut t: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
        let z = |n: usize| vec![0.0f32; n];

        // NOTE(review): the data length `h * pd` only matches the declared
        // shape if `patch_dim()` equals `3 * ps * ps` — confirm in `config`.
        t.insert(
            "patch_embed.proj.weight".into(),
            (z(h * pd), vec![h, 3, ps, ps]),
        );
        t.insert("patch_embed.proj.bias".into(), (z(h), vec![h]));
        t.insert("cls_token".into(), (z(h), vec![1, 1, h]));
        t.insert("pos_embed".into(), (z(seq * h), vec![1, seq, h]));
        if cfg.num_register_tokens > 0 {
            t.insert(
                "register_tokens".into(),
                (
                    z(cfg.num_register_tokens * h),
                    vec![1, cfg.num_register_tokens, h],
                ),
            );
        }

        // Per-block tensors: two norms, fused QKV + output projection, the
        // two-layer MLP, and the `ls1` / `ls2` gamma vectors.
        for i in 0..cfg.num_hidden_layers {
            let lp = format!("blocks.{i}");
            t.insert(format!("{lp}.norm1.weight"), (z(h), vec![h]));
            t.insert(format!("{lp}.norm1.bias"), (z(h), vec![h]));
            t.insert(
                format!("{lp}.attn.qkv.weight"),
                (z(h * 3 * h), vec![3 * h, h]),
            );
            t.insert(format!("{lp}.attn.qkv.bias"), (z(3 * h), vec![3 * h]));
            t.insert(format!("{lp}.attn.proj.weight"), (z(h * h), vec![h, h]));
            t.insert(format!("{lp}.attn.proj.bias"), (z(h), vec![h]));
            t.insert(format!("{lp}.ls1.gamma"), (z(h), vec![h]));
            t.insert(format!("{lp}.norm2.weight"), (z(h), vec![h]));
            t.insert(format!("{lp}.norm2.bias"), (z(h), vec![h]));
            t.insert(
                format!("{lp}.mlp.fc1.weight"),
                (z(int_dim * h), vec![int_dim, h]),
            );
            t.insert(format!("{lp}.mlp.fc1.bias"), (z(int_dim), vec![int_dim]));
            t.insert(
                format!("{lp}.mlp.fc2.weight"),
                (z(h * int_dim), vec![h, int_dim]),
            );
            t.insert(format!("{lp}.mlp.fc2.bias"), (z(h), vec![h]));
            t.insert(format!("{lp}.ls2.gamma"), (z(h), vec![h]));
        }

        t.insert("norm.weight".into(), (z(h), vec![h]));
        t.insert("norm.bias".into(), (z(h), vec![h]));
        if cfg.num_classes > 0 {
            // The head's input width is `2 * h`.
            // NOTE(review): presumably CLS concatenated with a pooled patch
            // token — confirm against the builder.
            t.insert(
                "head.weight".into(),
                (z(cfg.num_classes * 2 * h), vec![cfg.num_classes, 2 * h]),
            );
            t.insert(
                "head.bias".into(),
                (z(cfg.num_classes), vec![cfg.num_classes]),
            );
        }
        WeightMap::from_tensors(t)
    }

    /// With `num_classes = 0` the build succeeds, the graph has exactly one
    /// output, and the weight map is empty afterwards.
    #[test]
    fn encoder_only_graph_builds() {
        let mut cfg = DinoV2Config::vit_small(28);
        cfg.num_classes = 0; // encoder-only
        let mut wm = synthetic_weights(&cfg);
        let (g, _params, pre) = build_dinov2_graph_sized(&cfg, &mut wm, 1).unwrap();
        assert_eq!(g.outputs.len(), 1);
        assert_eq!(pre.embed_dim, cfg.hidden_size);
        // Nothing may remain in the map once the build has finished.
        assert_eq!(wm.len(), 0);
    }

    /// With the default classifier head the single graph output has static
    /// shape `[1, num_classes]`.
    #[test]
    fn classifier_graph_builds() {
        let cfg = DinoV2Config::vit_small(28); // num_classes defaults to 1000
        let mut wm = synthetic_weights(&cfg);
        let (g, _, _) = build_dinov2_graph_sized(&cfg, &mut wm, 1).unwrap();
        assert_eq!(g.outputs.len(), 1);
        // Final output should be [B, num_classes].
        let out_id = g.outputs[0];
        let s = g.shape(out_id);
        let dims: Vec<usize> = s.dims().iter().map(|d| d.unwrap_static()).collect();
        assert_eq!(dims, vec![1, cfg.num_classes]);
    }

    /// With four register tokens the preprocess weights carry `4 * hidden`
    /// register values and the sequence length is `1 + 4 + num_patches`.
    #[test]
    fn with_register_tokens() {
        let mut cfg = DinoV2Config::vit_small(28);
        cfg.num_register_tokens = 4;
        let mut wm = synthetic_weights(&cfg);
        let (_g, _, pre) = build_dinov2_graph_sized(&cfg, &mut wm, 1).unwrap();
        assert_eq!(pre.register_tokens.len(), 4 * cfg.hidden_size);
        assert_eq!(pre.seq, 1 + 4 + cfg.num_patches());
    }

    /// Build a WeightMap like `synthetic_weights` but with a callback
    /// to override the data for specific keys (preserving shape).
    ///
    /// `edit` is called once per key with that key's data buffer; the shape
    /// is carried over untouched.
    fn synthetic_weights_with<F: Fn(&str, &mut Vec<f32>)>(
        cfg: &DinoV2Config,
        edit: F,
    ) -> WeightMap {
        let mut wm = synthetic_weights(cfg);
        // Snapshot the keys as owned strings first so the borrow from
        // `keys()` has ended before `take` is called on `wm`.
        let keys: Vec<String> = wm.keys().map(|s| s.to_string()).collect();
        let mut all: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
        for k in keys {
            let (mut d, s) = wm
                .take(&k)
                .expect("key was just listed by `keys()` and must still be present");
            edit(&k, &mut d);
            all.insert(k, (d, s));
        }
        WeightMap::from_tensors(all)
    }

    /// Checks `assemble_hidden` on an all-zero image against hand-computed
    /// sums of the overridden `pos_embed`, `cls_token` and patch bias.
    #[test]
    fn assemble_hidden_zero_image_yields_pos_embed_plus_bias_plus_cls() {
        // With zero pixels, the patch projection contributes only its
        // bias; the assembled hidden then equals pos_embed broadcast +
        // [cls; proj_b…; proj_b…] per row.
        let mut cfg = DinoV2Config::vit_small(28);
        cfg.num_classes = 0;
        let h = cfg.hidden_size;
        let seq = cfg.seq_len();
        let np = cfg.num_patches();

        // Distinct, non-trivial values so a mix-up between the three
        // tensors cannot cancel out.
        let pos: Vec<f32> = (0..seq * h).map(|i| i as f32 * 1e-3).collect();
        let cls: Vec<f32> = (0..h).map(|i| 100.0 + i as f32).collect();
        let bias: Vec<f32> = (0..h).map(|i| -1.0 - (i as f32) * 0.1).collect();

        // The closure is not `move`, so it only borrows `pos` / `cls` /
        // `bias`; the borrows end when `synthetic_weights_with` returns and
        // the vectors stay usable for the assertions below.
        let mut wm = synthetic_weights_with(&cfg, |k, d| match k {
            "pos_embed" => d.copy_from_slice(&pos),
            "cls_token" => d.copy_from_slice(&cls),
            "patch_embed.proj.bias" => d.copy_from_slice(&bias),
            _ => {}
        });

        let (_g, _p, pre) = build_dinov2_graph_sized(&cfg, &mut wm, 1).unwrap();
        let image = vec![0f32; 3 * cfg.img_size * cfg.img_size];
        let hidden = assemble_hidden(&pre, &image, 1, cfg.patch_size, cfg.img_size).unwrap();
        assert_eq!(hidden.len(), seq * h);

        // Row 0 = CLS + pos_embed[0]; rows 1..1+np = bias + pos_embed[row]
        // NOTE(review): this layout assumes `vit_small` has no register
        // tokens by default — confirm in `config`.
        for k in 0..h {
            let exp = cls[k] + pos[k];
            let got = hidden[k];
            assert!((got - exp).abs() < 1e-5, "cls col {k}: {got} vs {exp}");
        }
        for row in 1..(1 + np) {
            for k in 0..h {
                let exp = bias[k] + pos[row * h + k];
                let got = hidden[row * h + k];
                assert!(
                    (got - exp).abs() < 1e-5,
                    "row {row} col {k}: {got} vs {exp}"
                );
            }
        }
    }
}