Skip to main content

rlx_dinov2/
runner.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use crate::{DinoV2Config, DinoV2PreprocessWeights, assemble_hidden, rgb_u8_to_imagenet_nchw};
17use anyhow::{Result, anyhow};
18use rlx_core::validate_standard_device;
19use rlx_flow::CompileProfile;
20use rlx_runtime::Device;
21use std::path::PathBuf;
22
/// Published DINOv2 backbone sizes, selectable through
/// [`DinoV2RunnerBuilder::variant`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DinoV2Variant {
    /// Built from `DinoV2Config::vit_small`.
    Small,
    /// Built from `DinoV2Config::vit_base` (vit-b/14); the builder's default.
    Base,
    /// Built from `DinoV2Config::vit_large`.
    Large,
}
30
/// Result of a DINOv2 forward pass: either classifier logits or token
/// features, with one flat `Vec<f32>` per batch entry.
#[derive(Clone, Debug)]
pub enum DinoV2Output {
    /// Classifier logits; every inner vector holds `num_classes` values.
    Logits {
        per_batch: Vec<Vec<f32>>,
        num_classes: usize,
    },
    /// Token features; every inner vector holds `seq * hidden` values
    /// (presumably token-major, i.e. `[seq][hidden]` — TODO confirm).
    Tokens {
        per_batch: Vec<Vec<f32>>,
        seq: usize,
        hidden: usize,
    },
}
44
45/// Builder for [`DinoV2Runner`]. Mirrors the qwen3 / sam shape.
46#[derive(Debug, Clone, Default)]
47pub struct DinoV2RunnerBuilder {
48    weights: Option<PathBuf>,
49    device: Option<Device>,
50    variant: Option<DinoV2Variant>,
51    img_size: Option<usize>,
52    batch: Option<usize>,
53    config: Option<DinoV2Config>,
54}
55
56impl DinoV2RunnerBuilder {
57    pub fn weights<P: Into<PathBuf>>(mut self, p: P) -> Self {
58        self.weights = Some(p.into());
59        self
60    }
61    pub fn device(mut self, d: Device) -> Self {
62        self.device = Some(d);
63        self
64    }
65    /// One of the published HF presets. Default `Base` (vit-b/14).
66    pub fn variant(mut self, v: DinoV2Variant) -> Self {
67        self.variant = Some(v);
68        self
69    }
70    /// Image side length (square). Must be a multiple of the patch
71    /// size (14 for the standard DINOv2 checkpoints). Default 518.
72    pub fn img_size(mut self, n: usize) -> Self {
73        self.img_size = Some(n);
74        self
75    }
76    pub fn batch(mut self, n: usize) -> Self {
77        self.batch = Some(n);
78        self
79    }
80    /// Skip preset selection and use an explicit
81    /// [`DinoV2Config`].
82    pub fn config(mut self, cfg: DinoV2Config) -> Self {
83        self.config = Some(cfg);
84        self
85    }
86
87    pub fn build(self) -> Result<DinoV2Runner> {
88        use rlx_runtime::Session;
89
90        let weights_path = self
91            .weights
92            .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?;
93        let device = self.device.unwrap_or(Device::Cpu);
94        validate_standard_device("dinov2", device)?;
95        let img_size = self.img_size.unwrap_or(518);
96        let batch = self.batch.unwrap_or(1);
97        let cfg = match (self.config, self.variant) {
98            (Some(c), _) => c,
99            (None, Some(DinoV2Variant::Small)) => DinoV2Config::vit_small(img_size),
100            (None, Some(DinoV2Variant::Large)) => DinoV2Config::vit_large(img_size),
101            // Default: vit_base.
102            (None, _) => DinoV2Config::vit_base(img_size),
103        };
104
105        let is_gguf = weights_path.extension().is_some_and(|e| e == "gguf");
106        if is_gguf {
107            rlx_core::gguf_validate_arch(&weights_path, rlx_core::DINOV2_GGUF_ARCHES)?;
108        }
109        let (mut wm, gguf_packed) =
110            if is_gguf && crate::packed_gguf::gguf_has_packed_linears(&weights_path)? {
111                eprintln!(
112                    "[dinov2] loading GGUF with packed DequantMatMul {:?}",
113                    weights_path
114                );
115                let (wm, packed) = crate::packed_gguf::load_dinov2_from_gguf(&weights_path)?;
116                (wm, Some(packed))
117            } else {
118                (
119                    rlx_core::load_weight_map(&weights_path, rlx_core::DINOV2_GGUF_ARCHES)?,
120                    None,
121                )
122            };
123        let built = super::flow::build_dinov2_built_with_packed(
124            &cfg,
125            &mut wm,
126            batch,
127            gguf_packed.as_ref(),
128        )?;
129        let typed = built.model.typed_params.clone();
130        let pre = built.preprocess;
131        let (graph, params) = rlx_core::flow_util::graph_from_built(built.model)?;
132        let opts =
133            rlx_core::flow_bridge::compile_options_for_profile(&CompileProfile::encoder(), device);
134        let mut compiled = Session::new(device).compile_with(graph, &opts);
135        rlx_core::flow_util::attach_built_params(&mut compiled, params, &typed);
136        Ok(DinoV2Runner {
137            compiled,
138            cfg,
139            preprocess: pre,
140            device,
141            batch,
142        })
143    }
144}
145
146/// Resolved DINOv2 runner.
147pub struct DinoV2Runner {
148    compiled: rlx_runtime::CompiledGraph,
149    cfg: DinoV2Config,
150    preprocess: DinoV2PreprocessWeights,
151    device: Device,
152    batch: usize,
153}
154
155impl DinoV2Runner {
156    pub fn builder() -> DinoV2RunnerBuilder {
157        DinoV2RunnerBuilder::default()
158    }
159    pub fn config(&self) -> &DinoV2Config {
160        &self.cfg
161    }
162    pub fn device(&self) -> Device {
163        self.device
164    }
165
166    /// End-to-end forward on a single image. `rgb` is HWC u8 of any
167    /// resolution; will be resized + normalized to the configured
168    /// `img_size`. Returns logits when the loaded checkpoint
169    /// includes a classifier head, otherwise the post-LN feature
170    /// tokens.
171    pub fn predict_image(&mut self, rgb: &[u8], h_in: usize, w_in: usize) -> Result<DinoV2Output> {
172        // 1. resize + normalize
173        let img_size = self.cfg.img_size;
174        let mut nchw = rgb_u8_to_imagenet_nchw(rgb, h_in, w_in, img_size);
175        // Replicate across batch dim if batch > 1.
176        if self.batch > 1 {
177            let per = nchw.len();
178            let mut batched = Vec::with_capacity(per * self.batch);
179            for _ in 0..self.batch {
180                batched.extend_from_slice(&nchw);
181            }
182            nchw = batched;
183        }
184
185        // 2. host-side patchify + token assembly
186        let hidden = assemble_hidden(
187            &self.preprocess,
188            &nchw,
189            self.batch,
190            self.cfg.patch_size,
191            img_size,
192        )?;
193
194        // 3. forward through the compiled graph
195        let outputs = self.compiled.run(&[("hidden", hidden.as_slice())]);
196        let flat = outputs
197            .into_iter()
198            .next()
199            .ok_or_else(|| anyhow!("dinov2 forward returned no output"))?;
200
201        // 4. split the flat output back into per-batch slices.
202        if self.cfg.num_classes > 0 {
203            let nc = self.cfg.num_classes;
204            let mut per_batch = Vec::with_capacity(self.batch);
205            for b in 0..self.batch {
206                per_batch.push(flat[b * nc..(b + 1) * nc].to_vec());
207            }
208            Ok(DinoV2Output::Logits {
209                per_batch,
210                num_classes: nc,
211            })
212        } else {
213            let seq = self.cfg.seq_len();
214            let hidden_dim = self.cfg.hidden_size;
215            let per = seq * hidden_dim;
216            let mut per_batch = Vec::with_capacity(self.batch);
217            for b in 0..self.batch {
218                per_batch.push(flat[b * per..(b + 1) * per].to_vec());
219            }
220            Ok(DinoV2Output::Tokens {
221                per_batch,
222                seq,
223                hidden: hidden_dim,
224            })
225        }
226    }
227}