1use anyhow::{Context, Result};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::Path;
12
/// Top-level description of a multi-component model: identity metadata,
/// global tensor shapes, per-component tensor declarations, and optional
/// file-naming patterns. Round-trips to/from JSON (see `load_from_file` /
/// `save_to_file`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ModelConfig {
    // Identity / provenance metadata for the model.
    pub model_info: ModelInfo,
    // Global shape parameters shared across components.
    pub shapes: ShapeConfig,
    // Per-component tensor declarations, keyed by component name
    // (e.g. "embeddings", "ffn_prefill", "ffn_infer", "lm_head").
    pub components: HashMap<String, ComponentConfig>,
    // Optional filename patterns for locating component files.
    pub naming: NamingConfig,
    // FFN execution mode; "split" selects separate prefill/infer models
    // (see `ffn_is_split`). Omitted from the JSON output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ffn_execution: Option<String>,
}
24
/// Provenance metadata recording where a model came from.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ModelInfo {
    // Optional model identifier; `#[serde(default)]` lets older JSON
    // without this key still deserialize (as `None`).
    #[serde(default)]
    pub model_id: Option<String>,
    // Optional filesystem path the model was found at.
    pub path: Option<String>,
    // Model family tag, e.g. "qwen" (see `ModelConfig::default_qwen`).
    pub model_type: String,
    // Optional timestamp string for when the model was discovered
    // (format defined by the producer — not parsed here).
    pub discovered_at: Option<String>,
}
34
/// Global tensor-shape parameters; `ModelConfig::validate` requires every
/// field to be non-zero.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ShapeConfig {
    // Batch dimension size.
    pub batch_size: usize,
    // Context length (presumably in tokens — semantics set by the model).
    pub context_length: usize,
    // Hidden-state width.
    pub hidden_size: usize,
    // Vocabulary size (width of the logits output).
    pub vocab_size: usize,
}
43
/// Tensor I/O declaration for a single model component.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ComponentConfig {
    // Optional path of the file backing this component; compared between
    // prefill and infer by `ModelConfig::ffn_is_split`.
    pub file_path: Option<String>,
    // Input tensors, keyed by tensor name.
    pub inputs: HashMap<String, TensorConfig>,
    // Output tensors, keyed by tensor name.
    pub outputs: HashMap<String, TensorConfig>,
    // Function names associated with this component (consumed elsewhere;
    // semantics not defined in this file).
    pub functions: Vec<String>,
    // Optional explicit ordering of input names; `#[serde(default)]`
    // tolerates JSON that omits the key.
    #[serde(default)]
    pub input_order: Option<Vec<String>>,
}
55
/// Declaration of a single tensor: name, shape, and element-type tag.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TensorConfig {
    // Tensor name (in practice mirrors its key in the parent map).
    pub name: String,
    // Dimension sizes; `ModelConfig::validate` rejects empty shapes.
    pub shape: Vec<usize>,
    // Element-type tag such as "FLOAT16", "FLOAT32", or "INT32".
    pub data_type: String,
}
63
/// Optional per-component filename patterns; every field is omitted from
/// serialized JSON when `None`. Pattern syntax is defined by the consumer,
/// not in this file.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct NamingConfig {
    // Pattern for locating the embeddings component file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embeddings_pattern: Option<String>,
    // Pattern for locating the FFN prefill component file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ffn_prefill_pattern: Option<String>,
    // Pattern for locating the FFN infer component file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ffn_infer_pattern: Option<String>,
    // Pattern for locating the LM head component file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lm_head_pattern: Option<String>,
}
77
78impl ModelConfig {
79 pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
81 let path = path.as_ref();
82 let content = std::fs::read_to_string(path)
83 .with_context(|| format!("Failed to read config file: {}", path.display()))?;
84
85 let config: ModelConfig = serde_json::from_str(&content)
86 .with_context(|| format!("Failed to parse config file: {}", path.display()))?;
87
88 Ok(config)
89 }
90
91 pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
93 let path = path.as_ref();
94 let content =
95 serde_json::to_string_pretty(self).context("Failed to serialize configuration")?;
96
97 std::fs::write(path, content)
98 .with_context(|| format!("Failed to write config file: {}", path.display()))?;
99
100 Ok(())
101 }
102
103 pub fn get_builtin_config(model_id: &str) -> Option<Self> {
105 crate::builtin_configs::get_builtin_config(model_id)
106 }
107
108 pub fn default_qwen() -> Self {
110 Self {
111 model_info: ModelInfo {
112 model_id: None,
113 path: None,
114 model_type: "qwen".to_string(),
115 discovered_at: None,
116 },
117 shapes: ShapeConfig {
118 batch_size: 1,
119 context_length: 512,
120 hidden_size: 1024,
121 vocab_size: 151936,
122 },
123 components: HashMap::new(),
124 naming: NamingConfig {
125 embeddings_pattern: None,
126 ffn_prefill_pattern: None,
127 ffn_infer_pattern: None,
128 lm_head_pattern: None,
129 },
130 ffn_execution: None,
131 }
132 }
133
134 pub fn get_tensor_shape(
136 &self,
137 component: &str,
138 tensor_name: &str,
139 is_input: bool,
140 ) -> Option<&Vec<usize>> {
141 let component_config = self.components.get(component)?;
142
143 let tensor_map = if is_input {
144 &component_config.inputs
145 } else {
146 &component_config.outputs
147 };
148
149 tensor_map.get(tensor_name).map(|tensor| &tensor.shape)
150 }
151
152 pub fn embeddings_input_shape(&self) -> Option<&Vec<usize>> {
154 self.get_tensor_shape("embeddings", "input_ids", true)
155 }
156
157 pub fn embeddings_output_shape(&self) -> Option<&Vec<usize>> {
159 self.get_tensor_shape("embeddings", "hidden_states", false)
160 }
161
162 pub fn ffn_prefill_input_shape(&self) -> Option<&Vec<usize>> {
164 self.get_tensor_shape("ffn_prefill", "hidden_states", true)
165 }
166
167 pub fn lm_head_input_shape(&self) -> Option<&Vec<usize>> {
169 self.get_tensor_shape("lm_head", "hidden_states", true)
170 }
171
172 pub fn has_multipart_logits(&self) -> bool {
174 if let Some(lm_head) = self.components.get("lm_head") {
175 let logits_outputs: Vec<_> = lm_head
177 .outputs
178 .keys()
179 .filter(|name| name.starts_with("logits") && name.len() > 6) .collect();
181 return logits_outputs.len() > 1;
182 }
183 false
184 }
185
186 pub fn logits_part_count(&self) -> usize {
188 if let Some(lm_head) = self.components.get("lm_head") {
189 let logits_outputs: Vec<_> = lm_head
190 .outputs
191 .keys()
192 .filter(|name| name.starts_with("logits"))
193 .collect();
194 if logits_outputs.is_empty() {
195 1 } else {
197 logits_outputs.len()
198 }
199 } else {
200 1
201 }
202 }
203
204 pub fn lm_head_primary_output_name(&self) -> Option<String> {
207 let lm_head = self.components.get("lm_head")?;
208
209 if lm_head.outputs.contains_key("logits1") {
211 return Some("logits1".to_string());
212 }
213
214 if lm_head.outputs.contains_key("logits") {
216 return Some("logits".to_string());
217 }
218
219 lm_head.outputs.keys().next().map(|k| k.to_string())
221 }
222
223 pub fn validate(&self) -> Result<()> {
225 let required_components = ["embeddings", "lm_head"];
227 for component in required_components {
228 if !self.components.contains_key(component) {
229 return Err(anyhow::anyhow!("Missing required component: {}", component));
230 }
231 }
232
233 if self.shapes.batch_size == 0 {
235 return Err(anyhow::anyhow!("batch_size must be greater than 0"));
236 }
237
238 if self.shapes.context_length == 0 {
239 return Err(anyhow::anyhow!("context_length must be greater than 0"));
240 }
241
242 if self.shapes.hidden_size == 0 {
243 return Err(anyhow::anyhow!("hidden_size must be greater than 0"));
244 }
245
246 if self.shapes.vocab_size == 0 {
247 return Err(anyhow::anyhow!("vocab_size must be greater than 0"));
248 }
249
250 for (component_name, component) in &self.components {
252 for (tensor_name, tensor) in &component.inputs {
253 if tensor.shape.is_empty() {
254 return Err(anyhow::anyhow!(
255 "Empty shape for {}.inputs.{}",
256 component_name,
257 tensor_name
258 ));
259 }
260 }
261 for (tensor_name, tensor) in &component.outputs {
262 if tensor.shape.is_empty() {
263 return Err(anyhow::anyhow!(
264 "Empty shape for {}.outputs.{}",
265 component_name,
266 tensor_name
267 ));
268 }
269 }
270 }
271
272 Ok(())
273 }
274
275 pub fn validate_internal_wiring(&self) -> Result<()> {
280 if let (Some(emb_out), Some(ffn_in_hidden)) = (
282 self.get_tensor_shape("embeddings", "hidden_states", false),
283 self.get_tensor_shape("ffn_prefill", "hidden_states", true),
284 ) {
285 if emb_out != ffn_in_hidden {
286 return Err(anyhow::anyhow!(
287 "Shape mismatch: embeddings.hidden_states {:?} != ffn_prefill.hidden_states {:?}",
288 emb_out, ffn_in_hidden
289 ));
290 }
291 }
292
293 if self.components.contains_key("ffn_infer") {
295 if let (Some(ffn_out), Some(lm_in)) = (
296 self.get_tensor_shape("ffn_infer", "output_hidden_states", false),
297 self.get_tensor_shape("lm_head", "hidden_states", true),
298 ) {
299 if ffn_out != lm_in {
300 return Err(anyhow::anyhow!(
301 "Shape mismatch: ffn_infer.output_hidden_states {:?} != lm_head.hidden_states {:?}",
302 ffn_out, lm_in
303 ));
304 }
305 }
306 } else {
307 if let (Some(ffn_out), Some(lm_in)) = (
309 self.get_tensor_shape("ffn_prefill", "output_hidden_states", false),
310 self.get_tensor_shape("lm_head", "hidden_states", true),
311 ) {
312 if ffn_out != lm_in {
313 return Err(anyhow::anyhow!(
314 "Shape mismatch: ffn_prefill.output_hidden_states {:?} != lm_head.hidden_states {:?}",
315 ffn_out, lm_in
316 ));
317 }
318 }
319 }
320
321 Ok(())
322 }
323
324 pub fn ffn_is_split(&self) -> bool {
326 if let Some(mode) = self.ffn_execution.as_deref() {
327 return mode == "split";
328 }
329 if let (Some(prefill), Some(infer)) = (
330 self.components.get("ffn_prefill"),
331 self.components.get("ffn_infer"),
332 ) {
333 match (&prefill.file_path, &infer.file_path) {
334 (Some(p), Some(i)) => p != i, _ => false,
336 }
337 } else {
338 false
339 }
340 }
341
342 pub fn prefill_is_single_token(&self) -> bool {
344 if let Some(prefill) = self.components.get("ffn_prefill") {
345 if let Some(hs) = prefill.inputs.get("hidden_states") {
346 return hs.shape.len() == 3 && hs.shape.get(1) == Some(&1);
347 }
348 }
349 false
350 }
351}
352
impl Default for ModelConfig {
    // The default configuration is the baseline Qwen setup.
    fn default() -> Self {
        Self::default_qwen()
    }
}
358
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Builds a minimal valid config with the two components `validate`
    /// requires (`embeddings`, `lm_head`), a single `logits` output, and
    /// `ffn_execution = "unified"`.
    fn create_test_config() -> ModelConfig {
        let mut components = HashMap::new();

        let mut embeddings_inputs = HashMap::new();
        embeddings_inputs.insert(
            "input_ids".to_string(),
            TensorConfig {
                name: "input_ids".to_string(),
                shape: vec![1, 64],
                data_type: "INT32".to_string(),
            },
        );

        let mut embeddings_outputs = HashMap::new();
        embeddings_outputs.insert(
            "hidden_states".to_string(),
            TensorConfig {
                name: "hidden_states".to_string(),
                shape: vec![1, 64, 1024],
                data_type: "FLOAT16".to_string(),
            },
        );

        components.insert(
            "embeddings".to_string(),
            ComponentConfig {
                file_path: None,
                inputs: embeddings_inputs,
                outputs: embeddings_outputs,
                functions: vec![],
                input_order: None,
            },
        );

        let mut lm_head_inputs = HashMap::new();
        lm_head_inputs.insert(
            "hidden_states".to_string(),
            TensorConfig {
                name: "hidden_states".to_string(),
                shape: vec![1, 1, 1024],
                data_type: "FLOAT16".to_string(),
            },
        );

        let mut lm_head_outputs = HashMap::new();
        lm_head_outputs.insert(
            "logits".to_string(),
            TensorConfig {
                name: "logits".to_string(),
                shape: vec![1, 1, 151936],
                data_type: "FLOAT32".to_string(),
            },
        );

        components.insert(
            "lm_head".to_string(),
            ComponentConfig {
                file_path: None,
                inputs: lm_head_inputs,
                outputs: lm_head_outputs,
                functions: vec![],
                input_order: None,
            },
        );

        ModelConfig {
            model_info: ModelInfo {
                model_id: Some("test/model".to_string()),
                path: Some("/test/path".to_string()),
                model_type: "qwen".to_string(),
                discovered_at: Some("2025-08-07T00:00:00".to_string()),
            },
            shapes: ShapeConfig {
                batch_size: 1,
                context_length: 512,
                hidden_size: 1024,
                vocab_size: 151936,
            },
            components,
            naming: NamingConfig {
                embeddings_pattern: None,
                ffn_prefill_pattern: None,
                ffn_infer_pattern: None,
                lm_head_pattern: None,
            },
            ffn_execution: Some("unified".to_string()),
        }
    }

    /// JSON round-trip: serialized text contains key fields and
    /// deserializing it reproduces the original values.
    #[test]
    fn test_config_serialization() {
        let config = create_test_config();

        let json = serde_json::to_string_pretty(&config).unwrap();
        assert!(json.contains("test/model"));
        assert!(json.contains("batch_size"));
        assert!(json.contains("embeddings"));

        let parsed: ModelConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.model_info.model_id, config.model_info.model_id);
        assert_eq!(parsed.shapes.batch_size, config.shapes.batch_size);
        assert_eq!(parsed.components.len(), config.components.len());
    }

    /// File round-trip via `save_to_file` / `load_from_file` on a temp file.
    #[test]
    fn test_config_file_io() {
        let config = create_test_config();
        let temp_file = NamedTempFile::new().unwrap();

        config.save_to_file(temp_file.path()).unwrap();

        let loaded = ModelConfig::load_from_file(temp_file.path()).unwrap();
        assert_eq!(loaded.model_info.model_id, config.model_info.model_id);
        assert_eq!(loaded.shapes.hidden_size, config.shapes.hidden_size);
    }

    /// The convenience shape accessors return the declared tensor shapes.
    #[test]
    fn test_shape_accessors() {
        let config = create_test_config();

        let embeddings_input = config.embeddings_input_shape().unwrap();
        assert_eq!(embeddings_input, &vec![1, 64]);

        let embeddings_output = config.embeddings_output_shape().unwrap();
        assert_eq!(embeddings_output, &vec![1, 64, 1024]);

        let lm_head_input = config.lm_head_input_shape().unwrap();
        assert_eq!(lm_head_input, &vec![1, 1, 1024]);
    }

    /// A single "logits" output is not multipart; replacing it with
    /// "logits1"/"logits2" flips detection and the part count.
    #[test]
    fn test_multipart_logits_detection() {
        let config = create_test_config();
        assert!(!config.has_multipart_logits());

        let mut config_multipart = config;
        let lm_head = config_multipart.components.get_mut("lm_head").unwrap();
        lm_head.outputs.clear();
        lm_head.outputs.insert(
            "logits1".to_string(),
            TensorConfig {
                name: "logits1".to_string(),
                shape: vec![1, 1, 9480],
                data_type: "FLOAT32".to_string(),
            },
        );
        lm_head.outputs.insert(
            "logits2".to_string(),
            TensorConfig {
                name: "logits2".to_string(),
                shape: vec![1, 1, 9479],
                data_type: "FLOAT32".to_string(),
            },
        );

        assert!(config_multipart.has_multipart_logits());
        assert_eq!(config_multipart.logits_part_count(), 2);
    }

    /// `validate` accepts the fixture, rejects a missing required component
    /// and a zero batch size; wiring validation passes for the fixture.
    #[test]
    fn test_config_validation() {
        let config = create_test_config();
        assert!(config.validate().is_ok());

        assert!(config.validate_internal_wiring().is_ok());

        let mut invalid_config = config.clone();
        invalid_config.components.remove("embeddings");
        assert!(invalid_config.validate().is_err());

        let mut invalid_shapes = config;
        invalid_shapes.shapes.batch_size = 0;
        assert!(invalid_shapes.validate().is_err());
    }

    /// `ModelConfig::default()` matches the `default_qwen` baseline values.
    #[test]
    fn test_default_config() {
        let config = ModelConfig::default();
        assert_eq!(config.model_info.model_type, "qwen");
        assert_eq!(config.shapes.batch_size, 1);
        assert_eq!(config.shapes.context_length, 512);
        assert_eq!(config.shapes.hidden_size, 1024);
        assert_eq!(config.shapes.vocab_size, 151936);
    }
}