Skip to main content

rlx_locateanything/
infer.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! High-level inference session — open once, ground many images with minimal code.
17//!
18//! Default image for quick runs: [`crate::fixtures::sample_image_path`].
19//! Full CLI and env documentation: `README.md` in this crate.
20
21use crate::config::LocateAnythingConfig;
22use crate::device::resolve_device;
23use crate::generation::{GenerationMode, SampleOpts};
24use crate::hub::{default_model_dir, resolve_weights_path};
25use crate::parse::{GroundingParse, parse_grounding};
26use crate::preprocess::{PreprocessedImage, preprocess_image};
27use crate::prompts;
28use crate::runner::{GenerateProfile, LocateAnythingRunner};
29use anyhow::{Context, Result};
30use image::{DynamicImage, GenericImageView};
31use rlx_runtime::Device;
32use std::path::Path;
33
/// How user text + vision placeholders are assembled.
///
/// The session consults this when turning a phrase into prompt token ids:
/// `Processor` routes through `crate::processor_prompt::build_processor_prompt_ids`,
/// `Rlx` routes through `crate::tokenizer::build_user_prompt_ids`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PromptStyle {
    /// HF `LocateAnythingProcessor` layout (recommended for boxes).
    ///
    /// This is the `Default` variant. With this style the session prepends
    /// `<image-1>` to the user text unless the text already starts with it.
    #[default]
    Processor,
    /// RLX minimal Qwen chat (`<|im_start|>user` + raw image slots).
    Rlx,
}
43
/// Inference settings (device, decode, prompt, optional resize for speed).
///
/// Every field is public, so values without a builder-style setter (for
/// example `temperature`) can be assigned directly on the struct.
#[derive(Debug, Clone)]
pub struct InferenceOptions {
    /// Device handed to the runner builder when a session is opened.
    pub device: Device,
    /// Generation mode forwarded to the runner builder.
    pub generation_mode: GenerationMode,
    /// `max_new_tokens` setting forwarded to the runner builder.
    pub max_new_tokens: usize,
    /// Temperature forwarded to the runner builder.
    pub temperature: f32,
    /// Repetition penalty forwarded to the runner builder.
    pub repetition_penalty: f32,
    /// `top_p` value copied into `SampleOpts`.
    ///
    /// NOTE(review): `LocateAnythingSession::open_with_options` does not pass
    /// this value to the runner builder — confirm whether that is intentional.
    pub top_p: f32,
    /// Which prompt layout the session uses when building prompt token ids.
    pub prompt_style: PromptStyle,
    /// Resize so the longest edge is at most this many pixels before patchify (faster).
    ///
    /// `None` leaves the image at its original size.
    pub max_image_side: Option<u32>,
    /// Compile LM graphs on [`LocateAnythingSession::open`].
    pub preload_language_model: bool,
}
59
60impl InferenceOptions {
61    /// Greedy hybrid defaults suited to grounding (processor prompt).
62    pub fn for_grounding() -> Self {
63        Self {
64            device: resolve_device(None).unwrap_or(Device::Cpu),
65            generation_mode: GenerationMode::Hybrid,
66            max_new_tokens: 64,
67            temperature: 0.0,
68            repetition_penalty: 1.0,
69            top_p: 1.0,
70            prompt_style: PromptStyle::Processor,
71            max_image_side: None,
72            preload_language_model: false,
73        }
74    }
75
76    pub fn device(mut self, device: Device) -> Self {
77        self.device = device;
78        self
79    }
80
81    pub fn device_name(mut self, name: &str) -> Result<Self> {
82        self.device = resolve_device(Some(name))?;
83        Ok(self)
84    }
85
86    pub fn max_new_tokens(mut self, n: usize) -> Self {
87        self.max_new_tokens = n;
88        self
89    }
90
91    pub fn generation_mode(mut self, mode: GenerationMode) -> Self {
92        self.generation_mode = mode;
93        self
94    }
95
96    pub fn prompt_style(mut self, style: PromptStyle) -> Self {
97        self.prompt_style = style;
98        self
99    }
100
101    pub fn max_image_side(mut self, side: u32) -> Self {
102        self.max_image_side = Some(side);
103        self
104    }
105
106    pub fn preload_language_model(mut self, yes: bool) -> Self {
107        self.preload_language_model = yes;
108        self
109    }
110
111    fn sample_opts(&self) -> SampleOpts {
112        SampleOpts {
113            temperature: self.temperature,
114            top_p: self.top_p,
115            repetition_penalty: self.repetition_penalty,
116            max_new_tokens: self.max_new_tokens,
117            mode: self.generation_mode,
118        }
119    }
120}
121
/// Parsed grounding output (boxes, points, ref labels).
///
/// Alias for [`GroundingParse`]; this is the value the session's `ground*`
/// and `detect` methods return.
pub type GroundingResult = GroundingParse;
124
/// Loaded model + compile caches for repeated grounding calls.
pub struct LocateAnythingSession {
    /// Runner built from the resolved weights directory; used for warmup and generation.
    runner: LocateAnythingRunner,
    /// Model configuration read from `config.json` and validated when the session is opened.
    cfg: LocateAnythingConfig,
    /// Options the session was opened with.
    options: InferenceOptions,
    /// Tokenizer loaded from the weights directory (present only with the `tokenizer` feature).
    #[cfg(feature = "tokenizer")]
    tokenizer: tokenizers::Tokenizer,
}
133
134impl LocateAnythingSession {
135    /// Open from [`default_model_dir`] (HF cache, env, or local fetch layout).
136    pub fn open_default() -> Result<Self> {
137        Self::open(default_model_dir()?)
138    }
139
140    /// Open checkpoint at `model_dir` (path, `hf`, or Hub id `nvidia/LocateAnything-3B`).
141    pub fn open(model_dir: impl AsRef<Path>) -> Result<Self> {
142        Self::open_with_options(model_dir, InferenceOptions::for_grounding())
143    }
144
145    /// Open with full control over device, decode, and prompt layout.
146    pub fn open_with_options(
147        model_dir: impl AsRef<Path>,
148        options: InferenceOptions,
149    ) -> Result<Self> {
150        let dir = resolve_weights_path(model_dir.as_ref())?;
151        let cfg = LocateAnythingConfig::from_file(&dir.join("config.json"))
152            .with_context(|| format!("load config from {dir:?}"))?;
153        cfg.validate()?;
154
155        let sample = options.sample_opts();
156        let mut runner = LocateAnythingRunner::builder()
157            .weights(&dir)
158            .device(options.device)
159            .max_new_tokens(sample.max_new_tokens)
160            .generation_mode(sample.mode)
161            .temperature(sample.temperature)
162            .repetition_penalty(sample.repetition_penalty)
163            .build()?;
164
165        if options.preload_language_model {
166            runner.preload_language_model()?;
167        }
168
169        #[cfg(feature = "tokenizer")]
170        let tokenizer = crate::tokenizer::load_tokenizer(&dir)?;
171
172        Ok(Self {
173            runner,
174            cfg,
175            options,
176            #[cfg(feature = "tokenizer")]
177            tokenizer,
178        })
179    }
180
181    pub fn model_dir(&self) -> &Path {
182        self.runner.model_dir()
183    }
184
185    pub fn device(&self) -> Device {
186        self.options.device
187    }
188
189    pub fn options(&self) -> &InferenceOptions {
190        &self.options
191    }
192
193    pub fn runner(&self) -> &LocateAnythingRunner {
194        &self.runner
195    }
196
197    pub fn runner_mut(&mut self) -> &mut LocateAnythingRunner {
198        &mut self.runner
199    }
200
201    pub fn preprocess_dynamic(&self, img: &DynamicImage) -> Result<PreprocessedImage> {
202        let img = maybe_resize(img, self.options.max_image_side);
203        preprocess_image(&img, &self.cfg)
204    }
205
206    pub fn preprocess_file(&self, path: impl AsRef<Path>) -> Result<PreprocessedImage> {
207        let img = image::open(path.as_ref())?;
208        self.preprocess_dynamic(&img)
209    }
210
211    /// Compile MoonViT + projector + LM prefill for this image/prompt shape (amortize first real query).
212    pub fn warmup(&mut self, image: &PreprocessedImage, phrase: &str) -> Result<()> {
213        let prompt_ids = self.build_prompt_ids(image, phrase)?;
214        self.runner.warmup_compile(&prompt_ids, image)
215    }
216
217    /// Ground a single phrase in one image (returns boxes when the model emits them).
218    #[cfg(feature = "tokenizer")]
219    pub fn ground(&mut self, image: &PreprocessedImage, phrase: &str) -> Result<GroundingResult> {
220        self.ground_with_profile(image, phrase).map(|(r, _)| r)
221    }
222
223    /// Like [`Self::ground`] but also returns per-phase generation timings.
224    #[cfg(feature = "tokenizer")]
225    pub fn ground_with_profile(
226        &mut self,
227        image: &PreprocessedImage,
228        phrase: &str,
229    ) -> Result<(GroundingResult, GenerateProfile)> {
230        let (w, h) = (image.pixel_w, image.pixel_h);
231        let prompt_ids = self.build_prompt_ids(image, phrase)?;
232        let (tokens, profile) = self.runner.generate_with_profile(&prompt_ids, image)?;
233        let result = self.decode_grounding(&tokens, prompt_ids.len(), w, h)?;
234        Ok((result, profile))
235    }
236
237    #[cfg(feature = "tokenizer")]
238    pub fn ground_path(&mut self, path: impl AsRef<Path>, phrase: &str) -> Result<GroundingResult> {
239        let prep = self.preprocess_file(path)?;
240        self.ground(&prep, phrase)
241    }
242
243    #[cfg(feature = "tokenizer")]
244    pub fn ground_dynamic(&mut self, img: &DynamicImage, phrase: &str) -> Result<GroundingResult> {
245        let prep = self.preprocess_dynamic(img)?;
246        self.ground(&prep, phrase)
247    }
248
249    #[cfg(feature = "tokenizer")]
250    pub fn detect(
251        &mut self,
252        image: &PreprocessedImage,
253        categories: &[&str],
254    ) -> Result<GroundingResult> {
255        self.ground(image, &prompts::detect(categories))
256    }
257
258    #[cfg(feature = "tokenizer")]
259    fn build_prompt_ids(&self, image: &PreprocessedImage, user_text: &str) -> Result<Vec<u32>> {
260        let kh = self.cfg.vision_config.merge_kernel_size[0];
261        let kw = self.cfg.vision_config.merge_kernel_size[1];
262        let n_image = (image.grid_h / kh) * (image.grid_w / kw);
263        match self.options.prompt_style {
264            PromptStyle::Processor => {
265                let with_ph = if user_text.starts_with("<image-1>") {
266                    user_text.to_string()
267                } else {
268                    format!("<image-1>{user_text}")
269                };
270                crate::processor_prompt::build_processor_prompt_ids(
271                    self.runner.model_dir(),
272                    &self.cfg,
273                    &self.tokenizer,
274                    &with_ph,
275                    n_image,
276                )
277            }
278            PromptStyle::Rlx => crate::tokenizer::build_user_prompt_ids(
279                &self.cfg,
280                &self.tokenizer,
281                user_text,
282                n_image,
283            ),
284        }
285    }
286
287    #[cfg(feature = "tokenizer")]
288    fn decode_grounding(
289        &self,
290        tokens: &[u32],
291        prompt_len: usize,
292        width: u32,
293        height: u32,
294    ) -> Result<GroundingResult> {
295        let new = &tokens[prompt_len..];
296        let text = crate::tokenizer::decode(&self.tokenizer, new)?;
297        let raw = self
298            .tokenizer
299            .decode(new, false)
300            .unwrap_or_else(|_| text.clone());
301        let mut parsed = parse_grounding(&text, width, height);
302        if parsed.boxes.is_empty() && raw != text {
303            let from_raw = parse_grounding(&raw, width, height);
304            if !from_raw.boxes.is_empty() || !from_raw.refs.is_empty() {
305                parsed = from_raw;
306            }
307        }
308        parsed.text = text;
309        parsed.raw = raw;
310        parsed.new_tokens = new.len();
311        parsed.prompt_len = prompt_len;
312        Ok(parsed)
313    }
314}
315
316fn maybe_resize(img: &DynamicImage, max_side: Option<u32>) -> DynamicImage {
317    let Some(max_side) = max_side else {
318        return img.clone();
319    };
320    let (w, h) = img.dimensions();
321    let longest = w.max(h);
322    if longest <= max_side {
323        return img.clone();
324    }
325    let scale = max_side as f32 / longest as f32;
326    let nw = ((w as f32 * scale).round() as u32).max(1);
327    let nh = ((h as f32 * scale).round() as u32).max(1);
328    img.resize_exact(nw, nh, image::imageops::FilterType::Triangle)
329}