1use crate::{DinoV2Config, DinoV2PreprocessWeights, assemble_hidden, rgb_u8_to_imagenet_nchw};
17use anyhow::{Result, anyhow};
18use rlx_core::validate_standard_device;
19use rlx_flow::CompileProfile;
20use rlx_runtime::Device;
21use std::path::PathBuf;
22
/// Preset DINOv2 ViT size, used by the runner builder when no explicit config is given.
///
/// Each variant maps to the matching `DinoV2Config` preset constructor
/// (`vit_small`, `vit_base`, `vit_large`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DinoV2Variant {
    /// ViT-S preset (`DinoV2Config::vit_small`).
    Small,
    /// ViT-B preset (`DinoV2Config::vit_base`); the builder's fallback when nothing is chosen.
    Base,
    /// ViT-L preset (`DinoV2Config::vit_large`).
    Large,
}
30
/// Result of one DINOv2 forward pass, split per batch element.
///
/// The runner returns `Logits` when the config has a classification head
/// (`num_classes > 0`). Otherwise it returns the raw token features as `Tokens`.
#[derive(Clone, Debug)]
pub enum DinoV2Output {
    /// Classification scores; each inner vector holds `num_classes` values.
    Logits {
        per_batch: Vec<Vec<f32>>,
        num_classes: usize,
    },
    /// Token features; each inner vector holds `seq * hidden` values.
    /// NOTE(review): presumably token-major (`seq` rows of `hidden` values) — confirm
    /// against the compiled graph's output layout.
    Tokens {
        per_batch: Vec<Vec<f32>>,
        seq: usize,
        hidden: usize,
    },
}
44
45#[derive(Debug, Clone, Default)]
47pub struct DinoV2RunnerBuilder {
48 weights: Option<PathBuf>,
49 device: Option<Device>,
50 variant: Option<DinoV2Variant>,
51 img_size: Option<usize>,
52 batch: Option<usize>,
53 config: Option<DinoV2Config>,
54}
55
56impl DinoV2RunnerBuilder {
57 pub fn weights<P: Into<PathBuf>>(mut self, p: P) -> Self {
58 self.weights = Some(p.into());
59 self
60 }
61 pub fn device(mut self, d: Device) -> Self {
62 self.device = Some(d);
63 self
64 }
65 pub fn variant(mut self, v: DinoV2Variant) -> Self {
67 self.variant = Some(v);
68 self
69 }
70 pub fn img_size(mut self, n: usize) -> Self {
73 self.img_size = Some(n);
74 self
75 }
76 pub fn batch(mut self, n: usize) -> Self {
77 self.batch = Some(n);
78 self
79 }
80 pub fn config(mut self, cfg: DinoV2Config) -> Self {
83 self.config = Some(cfg);
84 self
85 }
86
87 pub fn build(self) -> Result<DinoV2Runner> {
88 use rlx_runtime::Session;
89
90 let weights_path = self
91 .weights
92 .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?;
93 let device = self.device.unwrap_or(Device::Cpu);
94 validate_standard_device("dinov2", device)?;
95 let img_size = self.img_size.unwrap_or(518);
96 let batch = self.batch.unwrap_or(1);
97 let cfg = match (self.config, self.variant) {
98 (Some(c), _) => c,
99 (None, Some(DinoV2Variant::Small)) => DinoV2Config::vit_small(img_size),
100 (None, Some(DinoV2Variant::Large)) => DinoV2Config::vit_large(img_size),
101 (None, _) => DinoV2Config::vit_base(img_size),
103 };
104
105 let is_gguf = weights_path.extension().is_some_and(|e| e == "gguf");
106 if is_gguf {
107 rlx_core::gguf_validate_arch(&weights_path, rlx_core::DINOV2_GGUF_ARCHES)?;
108 }
109 let (mut wm, gguf_packed) =
110 if is_gguf && crate::packed_gguf::gguf_has_packed_linears(&weights_path)? {
111 eprintln!(
112 "[dinov2] loading GGUF with packed DequantMatMul {:?}",
113 weights_path
114 );
115 let (wm, packed) = crate::packed_gguf::load_dinov2_from_gguf(&weights_path)?;
116 (wm, Some(packed))
117 } else {
118 (
119 rlx_core::load_weight_map(&weights_path, rlx_core::DINOV2_GGUF_ARCHES)?,
120 None,
121 )
122 };
123 let built = super::flow::build_dinov2_built_with_packed(
124 &cfg,
125 &mut wm,
126 batch,
127 gguf_packed.as_ref(),
128 )?;
129 let typed = built.model.typed_params.clone();
130 let pre = built.preprocess;
131 let (graph, params) = rlx_core::flow_util::graph_from_built(built.model)?;
132 let opts =
133 rlx_core::flow_bridge::compile_options_for_profile(&CompileProfile::encoder(), device);
134 let mut compiled = Session::new(device).compile_with(graph, &opts);
135 rlx_core::flow_util::attach_built_params(&mut compiled, params, &typed);
136 Ok(DinoV2Runner {
137 compiled,
138 cfg,
139 preprocess: pre,
140 device,
141 batch,
142 })
143 }
144}
145
146pub struct DinoV2Runner {
148 compiled: rlx_runtime::CompiledGraph,
149 cfg: DinoV2Config,
150 preprocess: DinoV2PreprocessWeights,
151 device: Device,
152 batch: usize,
153}
154
155impl DinoV2Runner {
156 pub fn builder() -> DinoV2RunnerBuilder {
157 DinoV2RunnerBuilder::default()
158 }
159 pub fn config(&self) -> &DinoV2Config {
160 &self.cfg
161 }
162 pub fn device(&self) -> Device {
163 self.device
164 }
165
166 pub fn predict_image(&mut self, rgb: &[u8], h_in: usize, w_in: usize) -> Result<DinoV2Output> {
172 let img_size = self.cfg.img_size;
174 let mut nchw = rgb_u8_to_imagenet_nchw(rgb, h_in, w_in, img_size);
175 if self.batch > 1 {
177 let per = nchw.len();
178 let mut batched = Vec::with_capacity(per * self.batch);
179 for _ in 0..self.batch {
180 batched.extend_from_slice(&nchw);
181 }
182 nchw = batched;
183 }
184
185 let hidden = assemble_hidden(
187 &self.preprocess,
188 &nchw,
189 self.batch,
190 self.cfg.patch_size,
191 img_size,
192 )?;
193
194 let outputs = self.compiled.run(&[("hidden", hidden.as_slice())]);
196 let flat = outputs
197 .into_iter()
198 .next()
199 .ok_or_else(|| anyhow!("dinov2 forward returned no output"))?;
200
201 if self.cfg.num_classes > 0 {
203 let nc = self.cfg.num_classes;
204 let mut per_batch = Vec::with_capacity(self.batch);
205 for b in 0..self.batch {
206 per_batch.push(flat[b * nc..(b + 1) * nc].to_vec());
207 }
208 Ok(DinoV2Output::Logits {
209 per_batch,
210 num_classes: nc,
211 })
212 } else {
213 let seq = self.cfg.seq_len();
214 let hidden_dim = self.cfg.hidden_size;
215 let per = seq * hidden_dim;
216 let mut per_batch = Vec::with_capacity(self.batch);
217 for b in 0..self.batch {
218 per_batch.push(flat[b * per..(b + 1) * per].to_vec());
219 }
220 Ok(DinoV2Output::Tokens {
221 per_batch,
222 seq,
223 hidden: hidden_dim,
224 })
225 }
226 }
227}