1use anyhow::{Context, Result};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::Path;
12use tracing::debug;
13
/// Top-level model configuration: identity metadata, global tensor shapes,
/// per-component tensor descriptions, and artifact naming patterns.
/// (De)serialized as JSON via `load_from_file` / `save_to_file`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ModelConfig {
    /// Identity / provenance of the model this config describes.
    pub model_info: ModelInfo,
    /// Global tensor dimensions shared across components.
    pub shapes: ShapeConfig,
    /// Per-component tensor descriptions, keyed by component name
    /// (e.g. "embeddings", "ffn_prefill", "ffn_infer", "lm_head").
    pub components: HashMap<String, ComponentConfig>,
    /// File-name patterns used to locate component artifacts.
    pub naming: NamingConfig,
    /// Optional FFN execution-mode override; "split" forces split
    /// prefill/infer execution (see `ffn_is_split`). Omitted from JSON when None.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ffn_execution: Option<String>,
}
25
/// Identity and provenance metadata for a discovered model.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ModelInfo {
    /// Model identifier, presumably hub-style "org/name" — the tests use
    /// "test/model". NOTE(review): `#[serde(default)]` is redundant here —
    /// serde already deserializes a missing `Option` field as `None` — but
    /// it is harmless.
    #[serde(default)]
    pub model_id: Option<String>,
    /// Local filesystem path to the model, when known.
    pub path: Option<String>,
    /// Model family tag, e.g. "qwen" (see `default_qwen`).
    pub model_type: String,
    /// Timestamp string recorded at discovery time; format not enforced here.
    pub discovered_at: Option<String>,
}
35
/// Global tensor dimensions shared by all components.
/// All fields must be non-zero; enforced by `ModelConfig::validate`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ShapeConfig {
    /// Batch dimension.
    pub batch_size: usize,
    /// Maximum sequence length.
    pub context_length: usize,
    /// Hidden-state width.
    pub hidden_size: usize,
    /// Vocabulary size.
    pub vocab_size: usize,
}
44
/// Description of a single model component (embeddings, FFN stage, LM head, ...).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ComponentConfig {
    /// Path to the component's artifact on disk, when known.
    pub file_path: Option<String>,
    /// Input tensors keyed by tensor name.
    pub inputs: HashMap<String, TensorConfig>,
    /// Output tensors keyed by tensor name.
    pub outputs: HashMap<String, TensorConfig>,
    /// Function names exposed by the component (empty when unused).
    pub functions: Vec<String>,
    /// Explicit input ordering for order-sensitive components.
    /// A field missing from the JSON deserializes to None.
    #[serde(default)]
    pub input_order: Option<Vec<String>>,
}
56
/// Shape and data type of a single named tensor.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TensorConfig {
    /// Tensor name (in the tests this mirrors its key in the parent map).
    pub name: String,
    /// Dimensions, outermost first; must be non-empty (see `ModelConfig::validate`).
    pub shape: Vec<usize>,
    /// Data type tag, e.g. "FLOAT16", "FLOAT32", "INT32".
    pub data_type: String,
}
64
/// Optional file-name patterns for locating component artifacts.
/// Every field is omitted from the serialized JSON when None.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct NamingConfig {
    /// Pattern for the embeddings artifact.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embeddings_pattern: Option<String>,
    /// Pattern for the FFN prefill artifact.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ffn_prefill_pattern: Option<String>,
    /// Pattern for the FFN infer artifact.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ffn_infer_pattern: Option<String>,
    /// Pattern for the LM head artifact.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lm_head_pattern: Option<String>,
}
78
79impl ModelConfig {
80 pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
82 let path = path.as_ref();
83 let content = std::fs::read_to_string(path)
84 .with_context(|| format!("Failed to read config file: {}", path.display()))?;
85
86 let config: ModelConfig = serde_json::from_str(&content)
87 .with_context(|| format!("Failed to parse config file: {}", path.display()))?;
88
89 Ok(config)
90 }
91
92 pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
94 let path = path.as_ref();
95 let content =
96 serde_json::to_string_pretty(self).context("Failed to serialize configuration")?;
97
98 std::fs::write(path, content)
99 .with_context(|| format!("Failed to write config file: {}", path.display()))?;
100
101 Ok(())
102 }
103
    /// Looks up a compiled-in configuration for `model_id`; None when the id
    /// is not among the built-ins (delegates to `crate::builtin_configs`).
    pub fn get_builtin_config(model_id: &str) -> Option<Self> {
        crate::builtin_configs::get_builtin_config(model_id)
    }
108
    /// Fallback configuration for a Qwen-style model with no components
    /// discovered yet. Also backs `Default for ModelConfig`.
    pub fn default_qwen() -> Self {
        Self {
            model_info: ModelInfo {
                model_id: None,
                path: None,
                model_type: "qwen".to_string(),
                discovered_at: None,
            },
            shapes: ShapeConfig {
                batch_size: 1,
                context_length: 512,
                // 1024 hidden / 151936 vocab — presumably small-Qwen defaults;
                // TODO(review): confirm against the target checkpoint.
                hidden_size: 1024,
                vocab_size: 151936,
            },
            // No components: this default fails `validate()` until components
            // are discovered and filled in.
            components: HashMap::new(),
            naming: NamingConfig {
                embeddings_pattern: None,
                ffn_prefill_pattern: None,
                ffn_infer_pattern: None,
                lm_head_pattern: None,
            },
            ffn_execution: None,
        }
    }
134
135 pub fn get_tensor_shape(
137 &self,
138 component: &str,
139 tensor_name: &str,
140 is_input: bool,
141 ) -> Option<&Vec<usize>> {
142 let component_config = self.components.get(component)?;
143
144 let tensor_map = if is_input {
145 &component_config.inputs
146 } else {
147 &component_config.outputs
148 };
149
150 tensor_map.get(tensor_name).map(|tensor| &tensor.shape)
151 }
152
    /// Shape of `embeddings.input_ids` (input), when configured.
    pub fn embeddings_input_shape(&self) -> Option<&Vec<usize>> {
        self.get_tensor_shape("embeddings", "input_ids", true)
    }

    /// Shape of `embeddings.hidden_states` (output), when configured.
    pub fn embeddings_output_shape(&self) -> Option<&Vec<usize>> {
        self.get_tensor_shape("embeddings", "hidden_states", false)
    }

    /// Shape of `ffn_prefill.hidden_states` (input), when configured.
    pub fn ffn_prefill_input_shape(&self) -> Option<&Vec<usize>> {
        self.get_tensor_shape("ffn_prefill", "hidden_states", true)
    }

    /// Shape of `lm_head.hidden_states` (input), when configured.
    pub fn lm_head_input_shape(&self) -> Option<&Vec<usize>> {
        self.get_tensor_shape("lm_head", "hidden_states", true)
    }
172
173 pub fn has_multipart_logits(&self) -> bool {
175 if let Some(lm_head) = self.components.get("lm_head") {
176 let logits_outputs: Vec<_> = lm_head
178 .outputs
179 .keys()
180 .filter(|name| name.starts_with("logits") && name.len() > 6) .collect();
182 return logits_outputs.len() > 1;
183 }
184 false
185 }
186
187 pub fn logits_part_count(&self) -> usize {
189 if let Some(lm_head) = self.components.get("lm_head") {
190 let logits_outputs: Vec<_> = lm_head
191 .outputs
192 .keys()
193 .filter(|name| name.starts_with("logits"))
194 .collect();
195 if logits_outputs.is_empty() {
196 1 } else {
198 logits_outputs.len()
199 }
200 } else {
201 1
202 }
203 }
204
205 pub fn lm_head_primary_output_name(&self) -> Option<String> {
208 let lm_head = self.components.get("lm_head")?;
209
210 if lm_head.outputs.contains_key("logits1") {
212 return Some("logits1".to_string());
213 }
214
215 if lm_head.outputs.contains_key("logits") {
217 return Some("logits".to_string());
218 }
219
220 lm_head.outputs.keys().next().map(|k| k.to_string())
222 }
223
224 pub fn validate(&self) -> Result<()> {
226 let required_components = ["embeddings", "lm_head"];
228 for component in required_components {
229 if !self.components.contains_key(component) {
230 return Err(anyhow::anyhow!("Missing required component: {}", component));
231 }
232 }
233
234 if self.shapes.batch_size == 0 {
236 return Err(anyhow::anyhow!("batch_size must be greater than 0"));
237 }
238
239 if self.shapes.context_length == 0 {
240 return Err(anyhow::anyhow!("context_length must be greater than 0"));
241 }
242
243 if self.shapes.hidden_size == 0 {
244 return Err(anyhow::anyhow!("hidden_size must be greater than 0"));
245 }
246
247 if self.shapes.vocab_size == 0 {
248 return Err(anyhow::anyhow!("vocab_size must be greater than 0"));
249 }
250
251 for (component_name, component) in &self.components {
253 for (tensor_name, tensor) in &component.inputs {
254 if tensor.shape.is_empty() {
255 return Err(anyhow::anyhow!(
256 "Empty shape for {}.inputs.{}",
257 component_name,
258 tensor_name
259 ));
260 }
261 }
262 for (tensor_name, tensor) in &component.outputs {
263 if tensor.shape.is_empty() {
264 return Err(anyhow::anyhow!(
265 "Empty shape for {}.outputs.{}",
266 component_name,
267 tensor_name
268 ));
269 }
270 }
271 }
272
273 Ok(())
274 }
275
276 pub fn validate_internal_wiring(&self) -> Result<()> {
281 if let (Some(emb_out), Some(ffn_in_hidden)) = (
283 self.get_tensor_shape("embeddings", "hidden_states", false),
284 self.get_tensor_shape("ffn_prefill", "hidden_states", true),
285 ) {
286 if emb_out != ffn_in_hidden {
287 return Err(anyhow::anyhow!(
288 "Shape mismatch: embeddings.hidden_states {:?} != ffn_prefill.hidden_states {:?}",
289 emb_out, ffn_in_hidden
290 ));
291 }
292 }
293
294 if self.components.contains_key("ffn_infer") {
296 if let (Some(ffn_out), Some(lm_in)) = (
297 self.get_tensor_shape("ffn_infer", "output_hidden_states", false),
298 self.get_tensor_shape("lm_head", "hidden_states", true),
299 ) {
300 if ffn_out != lm_in {
301 return Err(anyhow::anyhow!(
302 "Shape mismatch: ffn_infer.output_hidden_states {:?} != lm_head.hidden_states {:?}",
303 ffn_out, lm_in
304 ));
305 }
306 }
307 } else {
308 if let (Some(ffn_out), Some(lm_in)) = (
310 self.get_tensor_shape("ffn_prefill", "output_hidden_states", false),
311 self.get_tensor_shape("lm_head", "hidden_states", true),
312 ) {
313 if ffn_out != lm_in {
314 return Err(anyhow::anyhow!(
315 "Shape mismatch: ffn_prefill.output_hidden_states {:?} != lm_head.hidden_states {:?}",
316 ffn_out, lm_in
317 ));
318 }
319 }
320 }
321
322 Ok(())
323 }
324
325 pub fn ffn_is_split(&self) -> bool {
327 if let Some(mode) = self.ffn_execution.as_deref() {
328 return mode == "split";
329 }
330 if let (Some(prefill), Some(infer)) = (
331 self.components.get("ffn_prefill"),
332 self.components.get("ffn_infer"),
333 ) {
334 match (&prefill.file_path, &infer.file_path) {
335 (Some(p), Some(i)) => p != i, _ => false,
337 }
338 } else {
339 false
340 }
341 }
342
343 pub fn prefill_is_single_token(&self) -> bool {
345 if let Some(prefill) = self.components.get("ffn_prefill") {
346 if let Some(hs) = prefill.inputs.get("hidden_states") {
347 let is_single = hs.shape.len() == 3 && hs.shape.get(1) == Some(&1);
348 debug!(
349 "🔍 prefill_is_single_token: shape={:?}, len={}, dim[1]={:?}, result={}",
350 hs.shape,
351 hs.shape.len(),
352 hs.shape.get(1),
353 is_single
354 );
355 return is_single;
356 }
357 }
358 debug!(
359 "🔍 prefill_is_single_token: no ffn_prefill or hidden_states found, returning false"
360 );
361 false
362 }
363
364 pub fn expects_full_sequence_prefill(&self) -> bool {
367 if let Some(prefill) = self.components.get("ffn_prefill") {
368 if let Some(hs) = prefill.inputs.get("hidden_states") {
369 let expects_full =
371 hs.shape.len() == 3 && hs.shape.get(1).is_some_and(|&seq_len| seq_len > 1);
372 debug!(
373 "🔍 expects_full_sequence_prefill: shape={:?}, len={}, dim[1]={:?}, result={}",
374 hs.shape,
375 hs.shape.len(),
376 hs.shape.get(1),
377 expects_full
378 );
379 return expects_full;
380 }
381 }
382 debug!("🔍 expects_full_sequence_prefill: no ffn_prefill or hidden_states found, returning false");
383 false
384 }
385}
386
/// `ModelConfig::default()` is the Qwen fallback configuration.
impl Default for ModelConfig {
    fn default() -> Self {
        Self::default_qwen()
    }
}
392
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Builds a minimal two-component config (embeddings + lm_head) shared by
    /// the tests below. Deliberately omits the FFN components.
    fn create_test_config() -> ModelConfig {
        let mut components = HashMap::new();

        // embeddings: input_ids [1, 64] -> hidden_states [1, 64, 1024]
        let mut embeddings_inputs = HashMap::new();
        embeddings_inputs.insert(
            "input_ids".to_string(),
            TensorConfig {
                name: "input_ids".to_string(),
                shape: vec![1, 64],
                data_type: "INT32".to_string(),
            },
        );

        let mut embeddings_outputs = HashMap::new();
        embeddings_outputs.insert(
            "hidden_states".to_string(),
            TensorConfig {
                name: "hidden_states".to_string(),
                shape: vec![1, 64, 1024],
                data_type: "FLOAT16".to_string(),
            },
        );

        components.insert(
            "embeddings".to_string(),
            ComponentConfig {
                file_path: None,
                inputs: embeddings_inputs,
                outputs: embeddings_outputs,
                functions: vec![],
                input_order: None,
            },
        );

        // lm_head: hidden_states [1, 1, 1024] -> single "logits" output
        let mut lm_head_inputs = HashMap::new();
        lm_head_inputs.insert(
            "hidden_states".to_string(),
            TensorConfig {
                name: "hidden_states".to_string(),
                shape: vec![1, 1, 1024],
                data_type: "FLOAT16".to_string(),
            },
        );

        let mut lm_head_outputs = HashMap::new();
        lm_head_outputs.insert(
            "logits".to_string(),
            TensorConfig {
                name: "logits".to_string(),
                shape: vec![1, 1, 151936],
                data_type: "FLOAT32".to_string(),
            },
        );

        components.insert(
            "lm_head".to_string(),
            ComponentConfig {
                file_path: None,
                inputs: lm_head_inputs,
                outputs: lm_head_outputs,
                functions: vec![],
                input_order: None,
            },
        );

        ModelConfig {
            model_info: ModelInfo {
                model_id: Some("test/model".to_string()),
                path: Some("/test/path".to_string()),
                model_type: "qwen".to_string(),
                discovered_at: Some("2025-08-07T00:00:00".to_string()),
            },
            shapes: ShapeConfig {
                batch_size: 1,
                context_length: 512,
                hidden_size: 1024,
                vocab_size: 151936,
            },
            components,
            naming: NamingConfig {
                embeddings_pattern: None,
                ffn_prefill_pattern: None,
                ffn_infer_pattern: None,
                lm_head_pattern: None,
            },
            ffn_execution: Some("unified".to_string()),
        }
    }

    /// JSON round-trip: serialize, spot-check the payload, parse it back.
    #[test]
    fn test_config_serialization() {
        let config = create_test_config();

        let json = serde_json::to_string_pretty(&config).unwrap();
        assert!(json.contains("test/model"));
        assert!(json.contains("batch_size"));
        assert!(json.contains("embeddings"));

        let parsed: ModelConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.model_info.model_id, config.model_info.model_id);
        assert_eq!(parsed.shapes.batch_size, config.shapes.batch_size);
        assert_eq!(parsed.components.len(), config.components.len());
    }

    /// save_to_file / load_from_file round-trip through a temp file.
    #[test]
    fn test_config_file_io() {
        let config = create_test_config();
        let temp_file = NamedTempFile::new().unwrap();

        config.save_to_file(temp_file.path()).unwrap();

        let loaded = ModelConfig::load_from_file(temp_file.path()).unwrap();
        assert_eq!(loaded.model_info.model_id, config.model_info.model_id);
        assert_eq!(loaded.shapes.hidden_size, config.shapes.hidden_size);
    }

    /// The convenience shape accessors return the fixture's tensor shapes.
    #[test]
    fn test_shape_accessors() {
        let config = create_test_config();

        let embeddings_input = config.embeddings_input_shape().unwrap();
        assert_eq!(embeddings_input, &vec![1, 64]);

        let embeddings_output = config.embeddings_output_shape().unwrap();
        assert_eq!(embeddings_output, &vec![1, 64, 1024]);

        let lm_head_input = config.lm_head_input_shape().unwrap();
        assert_eq!(lm_head_input, &vec![1, 1, 1024]);
    }

    /// A single "logits" output is not multipart; replacing it with
    /// "logits1"/"logits2" makes the config multipart with two parts.
    #[test]
    fn test_multipart_logits_detection() {
        let config = create_test_config();
        assert!(!config.has_multipart_logits()); let mut config_multipart = config;
        let lm_head = config_multipart.components.get_mut("lm_head").unwrap();
        lm_head.outputs.clear();
        lm_head.outputs.insert(
            "logits1".to_string(),
            TensorConfig {
                name: "logits1".to_string(),
                shape: vec![1, 1, 9480],
                data_type: "FLOAT32".to_string(),
            },
        );
        lm_head.outputs.insert(
            "logits2".to_string(),
            TensorConfig {
                name: "logits2".to_string(),
                shape: vec![1, 1, 9479],
                data_type: "FLOAT32".to_string(),
            },
        );

        assert!(config_multipart.has_multipart_logits());
        assert_eq!(config_multipart.logits_part_count(), 2);
    }

    /// validate() passes on the fixture, fails on a missing required
    /// component, and fails on a zero shape dimension.
    #[test]
    fn test_config_validation() {
        let config = create_test_config();
        assert!(config.validate().is_ok());

        // Wiring check is vacuously ok: the fixture has no FFN components.
        assert!(config.validate_internal_wiring().is_ok());

        let mut invalid_config = config.clone();
        invalid_config.components.remove("embeddings");
        assert!(invalid_config.validate().is_err());

        let mut invalid_shapes = config;
        invalid_shapes.shapes.batch_size = 0;
        assert!(invalid_shapes.validate().is_err());
    }

    /// Default::default() matches the default_qwen() constants.
    #[test]
    fn test_default_config() {
        let config = ModelConfig::default();
        assert_eq!(config.model_info.model_type, "qwen");
        assert_eq!(config.shapes.batch_size, 1);
        assert_eq!(config.shapes.context_length, 512);
        assert_eq!(config.shapes.hidden_size, 1024);
        assert_eq!(config.shapes.vocab_size, 151936);
    }
}