Skip to main content

rlx_locateanything/
config.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! LocateAnything configuration — HuggingFace `config.json` (`nvidia/LocateAnything-3B`).
17
18use anyhow::{Context, Result, ensure};
19use serde::Deserialize;
20use std::path::Path;
21
22/// MoonViT-SO-400M vision tower (`vision_config` in HF JSON).
23#[derive(Debug, Clone, Deserialize)]
24pub struct MoonVitConfig {
25    pub model_type: String,
26    pub hidden_size: usize,
27    pub intermediate_size: usize,
28    pub num_attention_heads: usize,
29    pub num_hidden_layers: usize,
30    pub patch_size: usize,
31    pub merge_kernel_size: [usize; 2],
32    pub init_pos_emb_height: usize,
33    pub init_pos_emb_width: usize,
34}
35
36impl MoonVitConfig {
37    pub fn head_dim(&self) -> usize {
38        self.hidden_size / self.num_attention_heads
39    }
40
41    pub fn merge_area(&self) -> usize {
42        self.merge_kernel_size[0] * self.merge_kernel_size[1]
43    }
44}
45
46/// Qwen2.5-3B text trunk (`text_config` in HF JSON).
47#[derive(Debug, Clone, Deserialize)]
48pub struct LocateAnythingTextConfig {
49    pub vocab_size: usize,
50    pub hidden_size: usize,
51    pub intermediate_size: usize,
52    pub num_hidden_layers: usize,
53    pub num_attention_heads: usize,
54    pub num_key_value_heads: usize,
55    pub max_position_embeddings: usize,
56    #[serde(default = "default_rms_norm_eps")]
57    pub rms_norm_eps: f64,
58    #[serde(default = "default_rope_theta")]
59    pub rope_theta: f64,
60    #[serde(default = "default_hidden_act")]
61    pub hidden_act: String,
62    #[serde(default)]
63    pub tie_word_embeddings: bool,
64    /// MTP / parallel box block size (HF default 6).
65    #[serde(default = "default_block_size")]
66    pub block_size: usize,
67    #[serde(default)]
68    pub causal_attn: bool,
69    pub bos_token_id: u32,
70    pub eos_token_id: u32,
71    #[serde(default)]
72    pub null_token_id: Option<u32>,
73    #[serde(default)]
74    pub switch_token_id: Option<u32>,
75    #[serde(default)]
76    pub text_mask_token_id: Option<u32>,
77}
78
/// Serde fallback for `rms_norm_eps`.
fn default_rms_norm_eps() -> f64 {
    1.0e-6
}

/// Serde fallback for `rope_theta`.
fn default_rope_theta() -> f64 {
    1.0e6
}

/// Serde fallback for `hidden_act`.
fn default_hidden_act() -> String {
    String::from("silu")
}

/// Serde fallback for `block_size`.
fn default_block_size() -> usize {
    6_usize
}
91
92impl LocateAnythingTextConfig {
93    pub fn head_dim(&self) -> usize {
94        self.hidden_size / self.num_attention_heads
95    }
96
97    pub fn kv_group_size(&self) -> usize {
98        self.num_attention_heads / self.num_key_value_heads
99    }
100
101    /// Map HF text trunk to [`rlx_qwen3::Qwen3Config`] (Qwen2.5 — biases on, no QK-norm).
102    pub fn to_qwen3_config(&self) -> rlx_qwen3::Qwen3Config {
103        rlx_qwen3::Qwen3Config {
104            vocab_size: self.vocab_size,
105            hidden_size: self.hidden_size,
106            intermediate_size: self.intermediate_size,
107            num_hidden_layers: self.num_hidden_layers,
108            num_attention_heads: self.num_attention_heads,
109            num_key_value_heads: self.num_key_value_heads,
110            head_dim: self.head_dim(),
111            max_position_embeddings: self.max_position_embeddings,
112            rms_norm_eps: self.rms_norm_eps,
113            rope_theta: self.rope_theta,
114            hidden_act: self.hidden_act.clone(),
115            tie_word_embeddings: self.tie_word_embeddings,
116            attention_bias: true,
117            qk_norm: false,
118            sliding_window: None,
119            max_window_layers: usize::MAX,
120            use_sliding_window: false,
121            num_experts: 0,
122            num_experts_used: 0,
123            expert_ffn_size: 0,
124            shared_expert_ffn_size: 0,
125            expert_weights_scale: 1.0,
126        }
127    }
128}
129
130/// `preprocessor_config.json` — native-resolution rescale limits (HF image processor).
131#[derive(Debug, Clone, Deserialize)]
132pub struct LocateAnythingPreprocessorConfig {
133    #[serde(default = "default_in_token_limit")]
134    pub in_token_limit: usize,
135    #[serde(default = "default_image_mean")]
136    pub image_mean: [f32; 3],
137    #[serde(default = "default_image_std")]
138    pub image_std: [f32; 3],
139}
140
/// Serde fallback for `in_token_limit`.
fn default_in_token_limit() -> usize {
    25_600_usize
}

/// Serde fallback for `image_mean`: 0.5 in each of the three entries.
fn default_image_mean() -> [f32; 3] {
    [0.5_f32; 3]
}

/// Serde fallback for `image_std`: 0.5 in each of the three entries.
fn default_image_std() -> [f32; 3] {
    [0.5_f32; 3]
}
152
153impl LocateAnythingPreprocessorConfig {
154    pub fn from_file(path: &Path) -> Result<Self> {
155        let data = std::fs::read_to_string(path)
156            .with_context(|| format!("read preprocessor config {path:?}"))?;
157        serde_json::from_str(&data).with_context(|| format!("parse preprocessor config {path:?}"))
158    }
159}
160
161/// Top-level LocateAnything checkpoint config.
162#[derive(Debug, Clone, Deserialize)]
163pub struct LocateAnythingConfig {
164    pub model_type: String,
165    pub image_token_index: u32,
166    pub box_start_token_id: u32,
167    pub box_end_token_id: u32,
168    pub coord_start_token_id: u32,
169    pub coord_end_token_id: u32,
170    pub ref_start_token_id: u32,
171    pub ref_end_token_id: u32,
172    pub none_token_id: u32,
173    #[serde(default = "default_mlp_connector_layers")]
174    pub mlp_connector_layers: usize,
175    #[serde(default)]
176    pub mlp_checkpoint: bool,
177    pub text_config: LocateAnythingTextConfig,
178    pub vision_config: MoonVitConfig,
179    /// Loaded from `preprocessor_config.json` via [`Self::from_model_dir`].
180    #[serde(skip)]
181    pub preprocessor: LocateAnythingPreprocessorConfig,
182}
183
/// Serde fallback for `mlp_connector_layers`.
fn default_mlp_connector_layers() -> usize {
    2_usize
}
187
188impl Default for LocateAnythingPreprocessorConfig {
189    fn default() -> Self {
190        Self {
191            in_token_limit: default_in_token_limit(),
192            image_mean: default_image_mean(),
193            image_std: default_image_std(),
194        }
195    }
196}
197
198impl LocateAnythingConfig {
199    pub const HF_MODEL_ID: &'static str = "nvidia/LocateAnything-3B";
200
201    pub fn from_file(path: &Path) -> Result<Self> {
202        let data = std::fs::read_to_string(path)
203            .with_context(|| format!("read LocateAnything config {path:?}"))?;
204        let mut cfg: Self = serde_json::from_str(&data)
205            .with_context(|| format!("parse LocateAnything config {path:?}"))?;
206        let dir = path.parent().unwrap_or(Path::new("."));
207        cfg.preprocessor =
208            LocateAnythingPreprocessorConfig::from_file(&dir.join("preprocessor_config.json"))
209                .unwrap_or_default();
210        Ok(cfg)
211    }
212
213    pub fn from_model_dir(dir: &Path) -> Result<Self> {
214        Self::from_file(&dir.join("config.json"))
215    }
216
217    pub fn validate(&self) -> Result<()> {
218        ensure!(
219            self.model_type == "locateanything",
220            "model_type must be locateanything, got {}",
221            self.model_type
222        );
223        ensure!(
224            self.vision_config.model_type == "moonvit",
225            "vision_config.model_type must be moonvit, got {}",
226            self.vision_config.model_type
227        );
228        ensure!(self.text_config.num_hidden_layers > 0, "text layers");
229        ensure!(self.vision_config.num_hidden_layers > 0, "vision layers");
230        ensure!(self.mlp_connector_layers == 2, "mlp_connector_layers");
231        Ok(())
232    }
233
234    /// Projector input width: MoonViT hidden × merge kernel area (2×2).
235    pub fn projector_input_dim(&self) -> usize {
236        self.vision_config.hidden_size * self.vision_config.merge_area()
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Vision-tower fixture with a 2×2 merge kernel.
    fn sample_vision_config() -> MoonVitConfig {
        MoonVitConfig {
            model_type: "moonvit".into(),
            hidden_size: 1152,
            intermediate_size: 4304,
            num_attention_heads: 16,
            num_hidden_layers: 27,
            patch_size: 14,
            merge_kernel_size: [2, 2],
            init_pos_emb_height: 64,
            init_pos_emb_width: 64,
        }
    }

    #[test]
    fn moonvit_merge_area() {
        let vision = sample_vision_config();
        assert_eq!(vision.merge_area(), 4);
    }
}
259}