1use anyhow::{Context, Result, ensure};
19use serde::Deserialize;
20use std::path::Path;
21
/// Settings for the vision side of the model, stored in the `vision_config`
/// field of `LocateAnythingConfig`.
///
/// No field carries a serde default, so every key must be present in the JSON.
#[derive(Debug, Clone, Deserialize)]
pub struct MoonVitConfig {
    /// Architecture tag; `LocateAnythingConfig::validate` requires `"moonvit"`.
    pub model_type: String,
    /// Numerator of [`Self::head_dim`] and a factor of
    /// `LocateAnythingConfig::projector_input_dim`.
    pub hidden_size: usize,
    /// Not read anywhere in this file; presumably the feed-forward inner
    /// width — TODO confirm against the encoder implementation.
    pub intermediate_size: usize,
    /// Divisor in [`Self::head_dim`]; a value of zero makes that method panic.
    pub num_attention_heads: usize,
    /// Layer count; `LocateAnythingConfig::validate` rejects zero.
    pub num_hidden_layers: usize,
    /// Not read anywhere in this file; presumably the patch edge length in
    /// pixels — TODO confirm against the patch embedder.
    pub patch_size: usize,
    /// Two extents whose product is returned by [`Self::merge_area`].
    pub merge_kernel_size: [usize; 2],
    /// Not read anywhere in this file; presumably the height of the initial
    /// position-embedding grid — TODO confirm.
    pub init_pos_emb_height: usize,
    /// Not read anywhere in this file; presumably the width of the initial
    /// position-embedding grid — TODO confirm.
    pub init_pos_emb_width: usize,
}
35
36impl MoonVitConfig {
37 pub fn head_dim(&self) -> usize {
38 self.hidden_size / self.num_attention_heads
39 }
40
41 pub fn merge_area(&self) -> usize {
42 self.merge_kernel_size[0] * self.merge_kernel_size[1]
43 }
44}
45
/// Settings for the text side of the model, stored in the `text_config` field
/// of `LocateAnythingConfig`.
///
/// Fields without a `#[serde(default …)]` attribute must be present in the
/// JSON; the others fall back to the documented value when the key is absent.
#[derive(Debug, Clone, Deserialize)]
pub struct LocateAnythingTextConfig {
    /// Forwarded unchanged by [`Self::to_qwen3_config`].
    pub vocab_size: usize,
    /// Numerator of [`Self::head_dim`]; forwarded by [`Self::to_qwen3_config`].
    pub hidden_size: usize,
    /// Forwarded unchanged by [`Self::to_qwen3_config`].
    pub intermediate_size: usize,
    /// Layer count; `LocateAnythingConfig::validate` rejects zero.
    pub num_hidden_layers: usize,
    /// Divisor in [`Self::head_dim`] and numerator of [`Self::kv_group_size`].
    pub num_attention_heads: usize,
    /// Divisor in [`Self::kv_group_size`]; zero makes that method panic.
    pub num_key_value_heads: usize,
    /// Forwarded unchanged by [`Self::to_qwen3_config`].
    pub max_position_embeddings: usize,
    /// Defaults to `1e-6` when absent.
    #[serde(default = "default_rms_norm_eps")]
    pub rms_norm_eps: f64,
    /// Defaults to `1_000_000.0` when absent.
    #[serde(default = "default_rope_theta")]
    pub rope_theta: f64,
    /// Activation name; defaults to `"silu"` when absent.
    #[serde(default = "default_hidden_act")]
    pub hidden_act: String,
    /// Defaults to `false` when absent.
    #[serde(default)]
    pub tie_word_embeddings: bool,
    /// Defaults to `6` when absent. Not read anywhere in this file —
    /// NOTE(review): meaning must be confirmed against the decoder.
    #[serde(default = "default_block_size")]
    pub block_size: usize,
    /// Defaults to `false` when absent. Not read anywhere in this file.
    #[serde(default)]
    pub causal_attn: bool,
    /// Required token id; presumably the beginning-of-sequence token.
    pub bos_token_id: u32,
    /// Required token id; presumably the end-of-sequence token.
    pub eos_token_id: u32,
    /// Optional token id; `None` when the key is absent.
    #[serde(default)]
    pub null_token_id: Option<u32>,
    /// Optional token id; `None` when the key is absent.
    #[serde(default)]
    pub switch_token_id: Option<u32>,
    /// Optional token id; `None` when the key is absent.
    #[serde(default)]
    pub text_mask_token_id: Option<u32>,
}
78
/// Serde fallback for `rms_norm_eps` when the key is missing from the JSON.
fn default_rms_norm_eps() -> f64 {
    const EPSILON: f64 = 1.0e-6;
    EPSILON
}
/// Serde fallback for `rope_theta` when the key is missing from the JSON.
fn default_rope_theta() -> f64 {
    const THETA: f64 = 1.0e6;
    THETA
}
/// Serde fallback for `hidden_act` when the key is missing from the JSON.
fn default_hidden_act() -> String {
    String::from("silu")
}
/// Serde fallback for `block_size` when the key is missing from the JSON.
fn default_block_size() -> usize {
    const BLOCK: usize = 6;
    BLOCK
}
91
92impl LocateAnythingTextConfig {
93 pub fn head_dim(&self) -> usize {
94 self.hidden_size / self.num_attention_heads
95 }
96
97 pub fn kv_group_size(&self) -> usize {
98 self.num_attention_heads / self.num_key_value_heads
99 }
100
101 pub fn to_qwen3_config(&self) -> rlx_qwen3::Qwen3Config {
103 rlx_qwen3::Qwen3Config {
104 vocab_size: self.vocab_size,
105 hidden_size: self.hidden_size,
106 intermediate_size: self.intermediate_size,
107 num_hidden_layers: self.num_hidden_layers,
108 num_attention_heads: self.num_attention_heads,
109 num_key_value_heads: self.num_key_value_heads,
110 head_dim: self.head_dim(),
111 max_position_embeddings: self.max_position_embeddings,
112 rms_norm_eps: self.rms_norm_eps,
113 rope_theta: self.rope_theta,
114 hidden_act: self.hidden_act.clone(),
115 tie_word_embeddings: self.tie_word_embeddings,
116 attention_bias: true,
117 qk_norm: false,
118 sliding_window: None,
119 max_window_layers: usize::MAX,
120 use_sliding_window: false,
121 num_experts: 0,
122 num_experts_used: 0,
123 expert_ffn_size: 0,
124 shared_expert_ffn_size: 0,
125 expert_weights_scale: 1.0,
126 }
127 }
128}
129
/// Image preprocessing settings, loaded from a `preprocessor_config.json` file.
///
/// Every field has a serde default, so an empty JSON object deserializes to
/// the same values as `Default::default()`.
#[derive(Debug, Clone, Deserialize)]
pub struct LocateAnythingPreprocessorConfig {
    /// Defaults to `25_600` when absent. Presumably an upper bound on input
    /// tokens — TODO confirm against the preprocessor.
    #[serde(default = "default_in_token_limit")]
    pub in_token_limit: usize,
    /// Defaults to `[0.5, 0.5, 0.5]` when absent. Presumably the per-channel
    /// normalization mean — TODO confirm.
    #[serde(default = "default_image_mean")]
    pub image_mean: [f32; 3],
    /// Defaults to `[0.5, 0.5, 0.5]` when absent. Presumably the per-channel
    /// normalization standard deviation — TODO confirm.
    #[serde(default = "default_image_std")]
    pub image_std: [f32; 3],
}
140
/// Serde fallback for `in_token_limit` when the key is missing from the JSON.
fn default_in_token_limit() -> usize {
    const LIMIT: usize = 25_600;
    LIMIT
}
144
/// Serde fallback for `image_mean` when the key is missing from the JSON.
fn default_image_mean() -> [f32; 3] {
    [0.5; 3]
}
148
/// Serde fallback for `image_std` when the key is missing from the JSON.
fn default_image_std() -> [f32; 3] {
    [0.5; 3]
}
152
153impl LocateAnythingPreprocessorConfig {
154 pub fn from_file(path: &Path) -> Result<Self> {
155 let data = std::fs::read_to_string(path)
156 .with_context(|| format!("read preprocessor config {path:?}"))?;
157 serde_json::from_str(&data).with_context(|| format!("parse preprocessor config {path:?}"))
158 }
159}
160
/// Top-level model configuration, deserialized from the model's JSON config.
///
/// Fields without a serde attribute must be present in the JSON.
#[derive(Debug, Clone, Deserialize)]
pub struct LocateAnythingConfig {
    /// Architecture tag; `validate` requires `"locateanything"`.
    pub model_type: String,
    /// Required token id; presumably the image placeholder token — TODO confirm.
    pub image_token_index: u32,
    /// Required token id; presumably opens a box span — TODO confirm.
    pub box_start_token_id: u32,
    /// Required token id; presumably closes a box span — TODO confirm.
    pub box_end_token_id: u32,
    /// Required token id; presumably opens a coordinate span — TODO confirm.
    pub coord_start_token_id: u32,
    /// Required token id; presumably closes a coordinate span — TODO confirm.
    pub coord_end_token_id: u32,
    /// Required token id; presumably opens a reference span — TODO confirm.
    pub ref_start_token_id: u32,
    /// Required token id; presumably closes a reference span — TODO confirm.
    pub ref_end_token_id: u32,
    /// Required token id; presumably marks "no result" — TODO confirm.
    pub none_token_id: u32,
    /// Defaults to `2` when absent; `validate` accepts no other value.
    #[serde(default = "default_mlp_connector_layers")]
    pub mlp_connector_layers: usize,
    /// Defaults to `false` when absent. Not read anywhere in this file.
    #[serde(default)]
    pub mlp_checkpoint: bool,
    /// Text-side settings.
    pub text_config: LocateAnythingTextConfig,
    /// Vision-side settings.
    pub vision_config: MoonVitConfig,
    /// Never read from this JSON (`serde(skip)`), so deserialization leaves it
    /// at its `Default` value. `from_file` fills it from a sibling
    /// `preprocessor_config.json`, using the defaults when that file is absent.
    #[serde(skip)]
    pub preprocessor: LocateAnythingPreprocessorConfig,
}
183
/// Serde fallback for `mlp_connector_layers` when the key is missing from the JSON.
fn default_mlp_connector_layers() -> usize {
    const LAYERS: usize = 2;
    LAYERS
}
187
188impl Default for LocateAnythingPreprocessorConfig {
189 fn default() -> Self {
190 Self {
191 in_token_limit: default_in_token_limit(),
192 image_mean: default_image_mean(),
193 image_std: default_image_std(),
194 }
195 }
196}
197
198impl LocateAnythingConfig {
199 pub const HF_MODEL_ID: &'static str = "nvidia/LocateAnything-3B";
200
201 pub fn from_file(path: &Path) -> Result<Self> {
202 let data = std::fs::read_to_string(path)
203 .with_context(|| format!("read LocateAnything config {path:?}"))?;
204 let mut cfg: Self = serde_json::from_str(&data)
205 .with_context(|| format!("parse LocateAnything config {path:?}"))?;
206 let dir = path.parent().unwrap_or(Path::new("."));
207 cfg.preprocessor =
208 LocateAnythingPreprocessorConfig::from_file(&dir.join("preprocessor_config.json"))
209 .unwrap_or_default();
210 Ok(cfg)
211 }
212
213 pub fn from_model_dir(dir: &Path) -> Result<Self> {
214 Self::from_file(&dir.join("config.json"))
215 }
216
217 pub fn validate(&self) -> Result<()> {
218 ensure!(
219 self.model_type == "locateanything",
220 "model_type must be locateanything, got {}",
221 self.model_type
222 );
223 ensure!(
224 self.vision_config.model_type == "moonvit",
225 "vision_config.model_type must be moonvit, got {}",
226 self.vision_config.model_type
227 );
228 ensure!(self.text_config.num_hidden_layers > 0, "text layers");
229 ensure!(self.vision_config.num_hidden_layers > 0, "vision layers");
230 ensure!(self.mlp_connector_layers == 2, "mlp_connector_layers");
231 Ok(())
232 }
233
234 pub fn projector_input_dim(&self) -> usize {
236 self.vision_config.hidden_size * self.vision_config.merge_area()
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// A 2x2 merge kernel must report an area of four.
    #[test]
    fn moonvit_merge_area() {
        let merge_kernel_size = [2, 2];
        let vision = MoonVitConfig {
            model_type: String::from("moonvit"),
            hidden_size: 1152,
            intermediate_size: 4304,
            num_attention_heads: 16,
            num_hidden_layers: 27,
            patch_size: 14,
            merge_kernel_size,
            init_pos_emb_height: 64,
            init_pos_emb_width: 64,
        };

        let area = vision.merge_area();
        assert_eq!(area, 4);
    }
}