1use crate::config::LocateAnythingConfig;
22use crate::device::resolve_device;
23use crate::generation::{GenerationMode, SampleOpts};
24use crate::hub::{default_model_dir, resolve_weights_path};
25use crate::parse::{GroundingParse, parse_grounding};
26use crate::preprocess::{PreprocessedImage, preprocess_image};
27use crate::prompts;
28use crate::runner::{GenerateProfile, LocateAnythingRunner};
29use anyhow::{Context, Result};
30use image::{DynamicImage, GenericImageView};
31use rlx_runtime::Device;
32use std::path::Path;
33
/// Selects which prompt builder a session uses to turn user text into token ids.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum PromptStyle {
    /// Build the prompt with `crate::processor_prompt::build_processor_prompt_ids`.
    ///
    /// The user text is prefixed with the `<image-1>` placeholder when it does
    /// not already start with it. This is the default style.
    #[default]
    Processor,
    /// Build the prompt with `crate::tokenizer::build_user_prompt_ids`, passing
    /// the user text through unchanged.
    Rlx,
}
43
44#[derive(Debug, Clone)]
46pub struct InferenceOptions {
47 pub device: Device,
48 pub generation_mode: GenerationMode,
49 pub max_new_tokens: usize,
50 pub temperature: f32,
51 pub repetition_penalty: f32,
52 pub top_p: f32,
53 pub prompt_style: PromptStyle,
54 pub max_image_side: Option<u32>,
56 pub preload_language_model: bool,
58}
59
60impl InferenceOptions {
61 pub fn for_grounding() -> Self {
63 Self {
64 device: resolve_device(None).unwrap_or(Device::Cpu),
65 generation_mode: GenerationMode::Hybrid,
66 max_new_tokens: 64,
67 temperature: 0.0,
68 repetition_penalty: 1.0,
69 top_p: 1.0,
70 prompt_style: PromptStyle::Processor,
71 max_image_side: None,
72 preload_language_model: false,
73 }
74 }
75
76 pub fn device(mut self, device: Device) -> Self {
77 self.device = device;
78 self
79 }
80
81 pub fn device_name(mut self, name: &str) -> Result<Self> {
82 self.device = resolve_device(Some(name))?;
83 Ok(self)
84 }
85
86 pub fn max_new_tokens(mut self, n: usize) -> Self {
87 self.max_new_tokens = n;
88 self
89 }
90
91 pub fn generation_mode(mut self, mode: GenerationMode) -> Self {
92 self.generation_mode = mode;
93 self
94 }
95
96 pub fn prompt_style(mut self, style: PromptStyle) -> Self {
97 self.prompt_style = style;
98 self
99 }
100
101 pub fn max_image_side(mut self, side: u32) -> Self {
102 self.max_image_side = Some(side);
103 self
104 }
105
106 pub fn preload_language_model(mut self, yes: bool) -> Self {
107 self.preload_language_model = yes;
108 self
109 }
110
111 fn sample_opts(&self) -> SampleOpts {
112 SampleOpts {
113 temperature: self.temperature,
114 top_p: self.top_p,
115 repetition_penalty: self.repetition_penalty,
116 max_new_tokens: self.max_new_tokens,
117 mode: self.generation_mode,
118 }
119 }
120}
121
122pub type GroundingResult = GroundingParse;
124
125pub struct LocateAnythingSession {
127 runner: LocateAnythingRunner,
128 cfg: LocateAnythingConfig,
129 options: InferenceOptions,
130 #[cfg(feature = "tokenizer")]
131 tokenizer: tokenizers::Tokenizer,
132}
133
134impl LocateAnythingSession {
135 pub fn open_default() -> Result<Self> {
137 Self::open(default_model_dir()?)
138 }
139
140 pub fn open(model_dir: impl AsRef<Path>) -> Result<Self> {
142 Self::open_with_options(model_dir, InferenceOptions::for_grounding())
143 }
144
145 pub fn open_with_options(
147 model_dir: impl AsRef<Path>,
148 options: InferenceOptions,
149 ) -> Result<Self> {
150 let dir = resolve_weights_path(model_dir.as_ref())?;
151 let cfg = LocateAnythingConfig::from_file(&dir.join("config.json"))
152 .with_context(|| format!("load config from {dir:?}"))?;
153 cfg.validate()?;
154
155 let sample = options.sample_opts();
156 let mut runner = LocateAnythingRunner::builder()
157 .weights(&dir)
158 .device(options.device)
159 .max_new_tokens(sample.max_new_tokens)
160 .generation_mode(sample.mode)
161 .temperature(sample.temperature)
162 .repetition_penalty(sample.repetition_penalty)
163 .build()?;
164
165 if options.preload_language_model {
166 runner.preload_language_model()?;
167 }
168
169 #[cfg(feature = "tokenizer")]
170 let tokenizer = crate::tokenizer::load_tokenizer(&dir)?;
171
172 Ok(Self {
173 runner,
174 cfg,
175 options,
176 #[cfg(feature = "tokenizer")]
177 tokenizer,
178 })
179 }
180
181 pub fn model_dir(&self) -> &Path {
182 self.runner.model_dir()
183 }
184
185 pub fn device(&self) -> Device {
186 self.options.device
187 }
188
189 pub fn options(&self) -> &InferenceOptions {
190 &self.options
191 }
192
193 pub fn runner(&self) -> &LocateAnythingRunner {
194 &self.runner
195 }
196
197 pub fn runner_mut(&mut self) -> &mut LocateAnythingRunner {
198 &mut self.runner
199 }
200
201 pub fn preprocess_dynamic(&self, img: &DynamicImage) -> Result<PreprocessedImage> {
202 let img = maybe_resize(img, self.options.max_image_side);
203 preprocess_image(&img, &self.cfg)
204 }
205
206 pub fn preprocess_file(&self, path: impl AsRef<Path>) -> Result<PreprocessedImage> {
207 let img = image::open(path.as_ref())?;
208 self.preprocess_dynamic(&img)
209 }
210
211 pub fn warmup(&mut self, image: &PreprocessedImage, phrase: &str) -> Result<()> {
213 let prompt_ids = self.build_prompt_ids(image, phrase)?;
214 self.runner.warmup_compile(&prompt_ids, image)
215 }
216
217 #[cfg(feature = "tokenizer")]
219 pub fn ground(&mut self, image: &PreprocessedImage, phrase: &str) -> Result<GroundingResult> {
220 self.ground_with_profile(image, phrase).map(|(r, _)| r)
221 }
222
223 #[cfg(feature = "tokenizer")]
225 pub fn ground_with_profile(
226 &mut self,
227 image: &PreprocessedImage,
228 phrase: &str,
229 ) -> Result<(GroundingResult, GenerateProfile)> {
230 let (w, h) = (image.pixel_w, image.pixel_h);
231 let prompt_ids = self.build_prompt_ids(image, phrase)?;
232 let (tokens, profile) = self.runner.generate_with_profile(&prompt_ids, image)?;
233 let result = self.decode_grounding(&tokens, prompt_ids.len(), w, h)?;
234 Ok((result, profile))
235 }
236
237 #[cfg(feature = "tokenizer")]
238 pub fn ground_path(&mut self, path: impl AsRef<Path>, phrase: &str) -> Result<GroundingResult> {
239 let prep = self.preprocess_file(path)?;
240 self.ground(&prep, phrase)
241 }
242
243 #[cfg(feature = "tokenizer")]
244 pub fn ground_dynamic(&mut self, img: &DynamicImage, phrase: &str) -> Result<GroundingResult> {
245 let prep = self.preprocess_dynamic(img)?;
246 self.ground(&prep, phrase)
247 }
248
249 #[cfg(feature = "tokenizer")]
250 pub fn detect(
251 &mut self,
252 image: &PreprocessedImage,
253 categories: &[&str],
254 ) -> Result<GroundingResult> {
255 self.ground(image, &prompts::detect(categories))
256 }
257
258 #[cfg(feature = "tokenizer")]
259 fn build_prompt_ids(&self, image: &PreprocessedImage, user_text: &str) -> Result<Vec<u32>> {
260 let kh = self.cfg.vision_config.merge_kernel_size[0];
261 let kw = self.cfg.vision_config.merge_kernel_size[1];
262 let n_image = (image.grid_h / kh) * (image.grid_w / kw);
263 match self.options.prompt_style {
264 PromptStyle::Processor => {
265 let with_ph = if user_text.starts_with("<image-1>") {
266 user_text.to_string()
267 } else {
268 format!("<image-1>{user_text}")
269 };
270 crate::processor_prompt::build_processor_prompt_ids(
271 self.runner.model_dir(),
272 &self.cfg,
273 &self.tokenizer,
274 &with_ph,
275 n_image,
276 )
277 }
278 PromptStyle::Rlx => crate::tokenizer::build_user_prompt_ids(
279 &self.cfg,
280 &self.tokenizer,
281 user_text,
282 n_image,
283 ),
284 }
285 }
286
287 #[cfg(feature = "tokenizer")]
288 fn decode_grounding(
289 &self,
290 tokens: &[u32],
291 prompt_len: usize,
292 width: u32,
293 height: u32,
294 ) -> Result<GroundingResult> {
295 let new = &tokens[prompt_len..];
296 let text = crate::tokenizer::decode(&self.tokenizer, new)?;
297 let raw = self
298 .tokenizer
299 .decode(new, false)
300 .unwrap_or_else(|_| text.clone());
301 let mut parsed = parse_grounding(&text, width, height);
302 if parsed.boxes.is_empty() && raw != text {
303 let from_raw = parse_grounding(&raw, width, height);
304 if !from_raw.boxes.is_empty() || !from_raw.refs.is_empty() {
305 parsed = from_raw;
306 }
307 }
308 parsed.text = text;
309 parsed.raw = raw;
310 parsed.new_tokens = new.len();
311 parsed.prompt_len = prompt_len;
312 Ok(parsed)
313 }
314}
315
316fn maybe_resize(img: &DynamicImage, max_side: Option<u32>) -> DynamicImage {
317 let Some(max_side) = max_side else {
318 return img.clone();
319 };
320 let (w, h) = img.dimensions();
321 let longest = w.max(h);
322 if longest <= max_side {
323 return img.clone();
324 }
325 let scale = max_side as f32 / longest as f32;
326 let nw = ((w as f32 * scale).round() as u32).max(1);
327 let nh = ((h as f32 * scale).round() as u32).max(1);
328 img.resize_exact(nw, nh, image::imageops::FilterType::Triangle)
329}