Skip to main content

rlx_dinov2/
flow.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Tier-0 DINOv2 flow — native [`ModelFlow`] ViT assembly.
17
18use anyhow::Result;
19use rlx_flow::{BuiltModel, CompileProfile, GgufPackedParams, ModelFlow};
20use rlx_ir::hir::HirMut;
21use rlx_ir::{DType, HirGraphExt, Shape};
22
23use super::config::DinoV2Config;
24use super::preprocess::DinoV2PreprocessWeights;
25use rlx_core::flow_util::WeightMapSource;
26use rlx_core::weight_map::WeightMap;
27
28#[derive(Debug, Clone)]
29pub struct DinoV2Flow<'a> {
30    cfg: &'a DinoV2Config,
31    batch: usize,
32}
33
34impl<'a> DinoV2Flow<'a> {
35    pub fn new(cfg: &'a DinoV2Config, batch: usize) -> Self {
36        Self { cfg, batch }
37    }
38
39    pub fn build(self, weights: &mut WeightMap) -> Result<DinoV2Built> {
40        build_dinov2_built(self.cfg, weights, self.batch)
41    }
42}
43
44pub struct DinoV2Built {
45    pub model: BuiltModel,
46    pub preprocess: DinoV2PreprocessWeights,
47}
48
49pub fn build_dinov2_built(
50    cfg: &DinoV2Config,
51    weights: &mut WeightMap,
52    batch: usize,
53) -> Result<DinoV2Built> {
54    build_dinov2_built_with_packed(cfg, weights, batch, None)
55}
56
57pub fn build_dinov2_built_with_packed(
58    cfg: &DinoV2Config,
59    weights: &mut WeightMap,
60    batch: usize,
61    gguf_packed: Option<&GgufPackedParams>,
62) -> Result<DinoV2Built> {
63    let preprocess = super::preprocess::extract_preprocess_weights(weights, cfg)?;
64    let h = cfg.hidden_size;
65    let nh = cfg.num_attention_heads;
66    let eps = cfg.layer_norm_eps as f32;
67    let seq = cfg.seq_len();
68    let f = DType::F32;
69
70    let mut flow = ModelFlow::new("dinov2")
71        .with_profile(CompileProfile::encoder())
72        .input("hidden", Shape::new(&[batch, seq, h], f))
73        .attn_mask_ones(batch, seq)
74        .repeat_dinov2_layers(cfg.num_hidden_layers, h, nh, eps)
75        .layer_norm("norm.weight", "norm.bias", eps);
76
77    if cfg.num_classes > 0 {
78        let patch_start = 1 + cfg.num_register_tokens;
79        let num_patches = cfg.num_patches();
80        let num_classes = cfg.num_classes;
81        flow = flow.plugin_named("dinov2.head", move |emit, hidden| {
82            let encoded = hidden.ok_or_else(|| anyhow::anyhow!("dinov2 head requires hidden"))?;
83            let head_w = emit.load_param("head.weight", true)?;
84            let head_b = emit.load_param("head.bias", false)?;
85            let mut gb = HirMut::new(emit.hir());
86            let cls_slice = gb.narrow_(encoded.hir_id(), 1, 0, 1);
87            let cls_flat = gb.reshape_(cls_slice, vec![batch as i64, h as i64]);
88            let patch_tokens = gb.narrow_(encoded.hir_id(), 1, patch_start, num_patches);
89            let mean_patches = gb.mean(patch_tokens, vec![1], false);
90            let features = gb.concat_(vec![cls_flat, mean_patches], 1);
91            let logits_mm = gb.mm(features, head_w);
92            let logits = gb.add(logits_mm, head_b);
93            Ok(Some(emit.wrap(
94                logits,
95                Shape::new(&[batch, num_classes], DType::F32),
96            )))
97        });
98        flow = flow.output("logits");
99    } else {
100        flow = flow.output("hidden");
101    }
102
103    Ok(DinoV2Built {
104        model: flow.build_with(&mut WeightMapSource(weights), gguf_packed)?,
105        preprocess,
106    })
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Small single-layer configuration with no register tokens and no
    /// classification head (`num_classes == 0`).
    fn tiny_cfg() -> DinoV2Config {
        DinoV2Config {
            img_size: 32,
            patch_size: 16,
            hidden_size: 16,
            num_attention_heads: 4,
            num_hidden_layers: 1,
            mlp_ratio: 4.0,
            layer_norm_eps: 1e-5,
            num_register_tokens: 0,
            num_classes: 0,
        }
    }

    /// Builds an all-zero [`WeightMap`] holding the tensors for one
    /// transformer block plus the patch embedding, class token, position
    /// embedding and final norm.
    fn synth_weights(cfg: &DinoV2Config) -> WeightMap {
        let hidden = cfg.hidden_size;
        let mlp_dim = (hidden as f64 * cfg.mlp_ratio) as usize;
        let patch_dim = cfg.patch_dim();
        let tokens = cfg.seq_len();
        let patch = cfg.patch_size;
        let block = |suffix: &str| format!("blocks.0.{suffix}");

        // (tensor name, element count, shape)
        let specs: Vec<(String, usize, Vec<usize>)> = vec![
            (
                "patch_embed.proj.weight".to_string(),
                hidden * patch_dim,
                vec![hidden, 3, patch, patch],
            ),
            ("patch_embed.proj.bias".to_string(), hidden, vec![hidden]),
            ("cls_token".to_string(), hidden, vec![1, 1, hidden]),
            (
                "pos_embed".to_string(),
                tokens * hidden,
                vec![1, tokens, hidden],
            ),
            (block("norm1.weight"), hidden, vec![hidden]),
            (block("norm1.bias"), hidden, vec![hidden]),
            (block("norm2.weight"), hidden, vec![hidden]),
            (block("norm2.bias"), hidden, vec![hidden]),
            (
                block("attn.qkv.weight"),
                3 * hidden * hidden,
                vec![3 * hidden, hidden],
            ),
            (block("attn.qkv.bias"), 3 * hidden, vec![3 * hidden]),
            (
                block("attn.proj.weight"),
                hidden * hidden,
                vec![hidden, hidden],
            ),
            (block("attn.proj.bias"), hidden, vec![hidden]),
            (block("ls1.gamma"), hidden, vec![hidden]),
            (block("ls2.gamma"), hidden, vec![hidden]),
            (
                block("mlp.fc1.weight"),
                mlp_dim * hidden,
                vec![mlp_dim, hidden],
            ),
            (block("mlp.fc1.bias"), mlp_dim, vec![mlp_dim]),
            (
                block("mlp.fc2.weight"),
                hidden * mlp_dim,
                vec![hidden, mlp_dim],
            ),
            (block("mlp.fc2.bias"), hidden, vec![hidden]),
            ("norm.weight".to_string(), hidden, vec![hidden]),
            ("norm.bias".to_string(), hidden, vec![hidden]),
        ];

        let tensors: HashMap<String, (Vec<f32>, Vec<usize>)> = specs
            .into_iter()
            .map(|(name, len, shape)| (name, (vec![0.0f32; len], shape)))
            .collect();
        WeightMap::from_tensors(tensors)
    }

    /// The headless flow builds and its primary shape has rank 3.
    #[test]
    fn dinov2_flow_builds() {
        let cfg = tiny_cfg();
        let mut weights = synth_weights(&cfg);
        let flow = DinoV2Flow::new(&cfg, 1);
        let built = flow.build(&mut weights).unwrap();
        assert_eq!(built.model.primary_shape().rank(), 3);
    }
}