1use anyhow::{Context, Result};
9use candle_core::{Device, Error as CandleError, Tensor};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::path::Path;
13use tracing::{debug, trace};
14
/// Top-level configuration for a multi-component model pipeline
/// (embeddings -> FFN -> LM head), loaded from / saved to JSON.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ModelConfig {
    /// Identity / provenance metadata for the model.
    pub model_info: ModelInfo,
    /// Global tensor dimensions shared across components.
    pub shapes: ShapeConfig,
    /// Per-component tensor layouts, keyed by component name
    /// (e.g. "embeddings", "ffn_prefill", "ffn_infer", "lm_head").
    pub components: HashMap<String, ComponentConfig>,
    /// Optional filename patterns for locating component artifacts.
    pub naming: NamingConfig,
    /// FFN execution mode; "split" forces separate prefill/infer models.
    /// Omitted from JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ffn_execution: Option<String>,
}
26
/// Provenance metadata for a model described by a [`ModelConfig`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ModelInfo {
    /// Hub-style identifier, e.g. "org/name". Defaults to `None`
    /// when the field is missing from the JSON.
    #[serde(default)]
    pub model_id: Option<String>,
    /// Filesystem path the model was found at, when known.
    pub path: Option<String>,
    /// Architecture family tag, e.g. "qwen".
    pub model_type: String,
    /// Timestamp string recorded at discovery time, if any.
    pub discovered_at: Option<String>,
}
36
/// Global tensor dimensions shared by all components.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ShapeConfig {
    /// Batch dimension (validated to be > 0).
    pub batch_size: usize,
    /// Maximum sequence length the model was built for.
    pub context_length: usize,
    /// Hidden-state feature dimension.
    pub hidden_size: usize,
    /// Vocabulary size of the LM head output.
    pub vocab_size: usize,
}
45
/// Tensor layout of a single pipeline component (embeddings, FFN, LM head).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ComponentConfig {
    /// Path to the component's model artifact, when it lives in its own file.
    pub file_path: Option<String>,
    /// Input tensors, keyed by tensor name.
    pub inputs: HashMap<String, TensorConfig>,
    /// Output tensors, keyed by tensor name.
    pub outputs: HashMap<String, TensorConfig>,
    // NOTE(review): presumably the callable function names exposed by the
    // component artifact — confirm against the producer of this config.
    pub functions: Vec<String>,
    /// Explicit ordering of input tensors, when the runtime requires one.
    /// Defaults to `None` when absent from the JSON.
    #[serde(default)]
    pub input_order: Option<Vec<String>>,
}
57
/// Declared name, shape, and dtype of one tensor.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TensorConfig {
    /// Tensor name as exposed by the model.
    pub name: String,
    /// Static dimensions, e.g. [1, 64, 1024].
    pub shape: Vec<usize>,
    /// Data type tag, e.g. "FLOAT16", "FLOAT32", "INT32".
    pub data_type: String,
}
65
/// Optional filename patterns used to locate each component's artifact.
/// All fields are omitted from JSON when unset.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct NamingConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embeddings_pattern: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ffn_prefill_pattern: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ffn_infer_pattern: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lm_head_pattern: Option<String>,
}
79
80impl ModelConfig {
    /// Builds a baseline configuration for a Qwen-style model:
    /// batch 1, 512-token context, 1024 hidden size, 151,936-entry vocab.
    /// Components and naming patterns are left empty for later discovery.
    pub fn default_qwen() -> Self {
        Self {
            model_info: ModelInfo {
                model_id: Some("default/qwen".to_string()),
                path: None,
                model_type: "qwen".to_string(),
                discovered_at: None,
            },
            shapes: ShapeConfig {
                batch_size: 1,
                context_length: 512,
                hidden_size: 1024,
                vocab_size: 151_936,
            },
            components: HashMap::new(),
            naming: NamingConfig {
                embeddings_pattern: None,
                ffn_prefill_pattern: None,
                ffn_infer_pattern: None,
                lm_head_pattern: None,
            },
            ffn_execution: None,
        }
    }
106 pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
108 let path = path.as_ref();
109 let content = std::fs::read_to_string(path)
110 .with_context(|| format!("Failed to read config file: {}", path.display()))?;
111
112 let config: ModelConfig = serde_json::from_str(&content)
113 .with_context(|| format!("Failed to parse config file: {}", path.display()))?;
114
115 Ok(config)
116 }
117
118 pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
120 let path = path.as_ref();
121 let content =
122 serde_json::to_string_pretty(self).context("Failed to serialize configuration")?;
123
124 std::fs::write(path, content)
125 .with_context(|| format!("Failed to write config file: {}", path.display()))?;
126
127 Ok(())
128 }
129
130 pub fn get_tensor_shape(
132 &self,
133 component: &str,
134 tensor_name: &str,
135 is_input: bool,
136 ) -> Option<&Vec<usize>> {
137 let component_config = self.components.get(component)?;
138
139 let tensor_map = if is_input {
140 &component_config.inputs
141 } else {
142 &component_config.outputs
143 };
144
145 tensor_map.get(tensor_name).map(|tensor| &tensor.shape)
146 }
147
    /// Shape of the embeddings `input_ids` input, if declared.
    pub fn embeddings_input_shape(&self) -> Option<&Vec<usize>> {
        self.get_tensor_shape("embeddings", "input_ids", true)
    }

    /// Shape of the embeddings `hidden_states` output, if declared.
    pub fn embeddings_output_shape(&self) -> Option<&Vec<usize>> {
        self.get_tensor_shape("embeddings", "hidden_states", false)
    }

    /// Shape of the prefill FFN `hidden_states` input, if declared.
    pub fn ffn_prefill_input_shape(&self) -> Option<&Vec<usize>> {
        self.get_tensor_shape("ffn_prefill", "hidden_states", true)
    }

    /// Shape of the LM head `hidden_states` input, if declared.
    pub fn lm_head_input_shape(&self) -> Option<&Vec<usize>> {
        self.get_tensor_shape("lm_head", "hidden_states", true)
    }
167
168 pub fn has_multipart_logits(&self) -> bool {
170 if let Some(lm_head) = self.components.get("lm_head") {
171 let logits_outputs: Vec<_> = lm_head
173 .outputs
174 .keys()
175 .filter(|name| name.starts_with("logits") && name.len() > 6) .collect();
177 return logits_outputs.len() > 1;
178 }
179 false
180 }
181
182 pub fn logits_part_count(&self) -> usize {
184 if let Some(lm_head) = self.components.get("lm_head") {
185 let logits_outputs: Vec<_> = lm_head
186 .outputs
187 .keys()
188 .filter(|name| name.starts_with("logits"))
189 .collect();
190 if logits_outputs.is_empty() {
191 1 } else {
193 logits_outputs.len()
194 }
195 } else {
196 1
197 }
198 }
199
200 pub fn lm_head_primary_output_name(&self) -> Option<String> {
203 let lm_head = self.components.get("lm_head")?;
204
205 if lm_head.outputs.contains_key("logits1") {
207 return Some("logits1".to_string());
208 }
209
210 if lm_head.outputs.contains_key("logits") {
212 return Some("logits".to_string());
213 }
214
215 lm_head.outputs.keys().next().map(|k| k.to_string())
217 }
218
219 pub fn validate(&self) -> Result<()> {
221 let required_components = ["embeddings", "lm_head"];
223 for component in required_components {
224 if !self.components.contains_key(component) {
225 return Err(anyhow::anyhow!("Missing required component: {}", component));
226 }
227 }
228
229 if self.shapes.batch_size == 0 {
231 return Err(anyhow::anyhow!("batch_size must be greater than 0"));
232 }
233
234 if self.shapes.context_length == 0 {
235 return Err(anyhow::anyhow!("context_length must be greater than 0"));
236 }
237
238 if self.shapes.hidden_size == 0 {
239 return Err(anyhow::anyhow!("hidden_size must be greater than 0"));
240 }
241
242 if self.shapes.vocab_size == 0 {
243 return Err(anyhow::anyhow!("vocab_size must be greater than 0"));
244 }
245
246 for (component_name, component) in &self.components {
248 for (tensor_name, tensor) in &component.inputs {
249 if tensor.shape.is_empty() {
250 return Err(anyhow::anyhow!(
251 "Empty shape for {}.inputs.{}",
252 component_name,
253 tensor_name
254 ));
255 }
256 }
257 for (tensor_name, tensor) in &component.outputs {
258 if tensor.shape.is_empty() {
259 return Err(anyhow::anyhow!(
260 "Empty shape for {}.outputs.{}",
261 component_name,
262 tensor_name
263 ));
264 }
265 }
266 }
267
268 Ok(())
269 }
270
271 pub fn validate_internal_wiring(&self) -> Result<()> {
276 if let (Some(emb_out), Some(ffn_in_hidden)) = (
278 self.get_tensor_shape("embeddings", "hidden_states", false),
279 self.get_tensor_shape("ffn_prefill", "hidden_states", true),
280 ) {
281 if emb_out != ffn_in_hidden {
282 return Err(anyhow::anyhow!(
283 "Shape mismatch: embeddings.hidden_states {:?} != ffn_prefill.hidden_states {:?}",
284 emb_out, ffn_in_hidden
285 ));
286 }
287 }
288
289 if self.components.contains_key("ffn_infer") {
291 if let (Some(ffn_out), Some(lm_in)) = (
292 self.get_tensor_shape("ffn_infer", "output_hidden_states", false),
293 self.get_tensor_shape("lm_head", "hidden_states", true),
294 ) {
295 if ffn_out != lm_in {
296 return Err(anyhow::anyhow!(
297 "Shape mismatch: ffn_infer.output_hidden_states {:?} != lm_head.hidden_states {:?}",
298 ffn_out, lm_in
299 ));
300 }
301 }
302 } else {
303 if let (Some(ffn_out), Some(lm_in)) = (
305 self.get_tensor_shape("ffn_prefill", "output_hidden_states", false),
306 self.get_tensor_shape("lm_head", "hidden_states", true),
307 ) {
308 if ffn_out != lm_in {
309 return Err(anyhow::anyhow!(
310 "Shape mismatch: ffn_prefill.output_hidden_states {:?} != lm_head.hidden_states {:?}",
311 ffn_out, lm_in
312 ));
313 }
314 }
315 }
316
317 Ok(())
318 }
319
320 pub fn ffn_is_split(&self) -> bool {
322 if let Some(mode) = self.ffn_execution.as_deref() {
323 return mode == "split";
324 }
325 if let (Some(prefill), Some(infer)) = (
326 self.components.get("ffn_prefill"),
327 self.components.get("ffn_infer"),
328 ) {
329 match (&prefill.file_path, &infer.file_path) {
330 (Some(p), Some(i)) => p != i, _ => false,
332 }
333 } else {
334 false
335 }
336 }
337
338 pub fn prefill_is_single_token(&self) -> bool {
340 if let Some(prefill) = self.components.get("ffn_prefill") {
341 if let Some(hs) = prefill.inputs.get("hidden_states") {
342 let is_single = hs.shape.len() == 3 && hs.shape.get(1) == Some(&1);
343 debug!(
344 "🔍 prefill_is_single_token: shape={:?}, len={}, dim[1]={:?}, result={}",
345 hs.shape,
346 hs.shape.len(),
347 hs.shape.get(1),
348 is_single
349 );
350 return is_single;
351 }
352 }
353 debug!(
354 "🔍 prefill_is_single_token: no ffn_prefill or hidden_states found, returning false"
355 );
356 false
357 }
358
359 pub fn expects_full_sequence_prefill(&self) -> bool {
362 if let Some(prefill) = self.components.get("ffn_prefill") {
363 if let Some(hs) = prefill.inputs.get("hidden_states") {
364 let expects_full =
366 hs.shape.len() == 3 && hs.shape.get(1).is_some_and(|&seq_len| seq_len > 1);
367 trace!(
368 "🔍 expects_full_sequence_prefill: shape={:?}, len={}, dim[1]={:?}, result={}",
369 hs.shape,
370 hs.shape.len(),
371 hs.shape.get(1),
372 expects_full
373 );
374 return expects_full;
375 }
376 }
377 trace!("🔍 expects_full_sequence_prefill: no ffn_prefill or hidden_states found, returning false");
378 false
379 }
380
381 pub fn create_embeddings_input_tensor(
385 &self,
386 tokens: &[i64],
387 device: &Device,
388 ) -> Result<Tensor, CandleError> {
389 let expected_shape = self
390 .embeddings_input_shape()
391 .ok_or_else(|| CandleError::Msg("No embeddings input shape found".to_string()))?;
392 let expected_len = expected_shape[1]; let mut padded_tokens = tokens.to_vec();
396 padded_tokens.resize(expected_len, 0); Tensor::from_vec(
399 padded_tokens,
400 (expected_shape[0], expected_shape[1]),
401 device,
402 )
403 }
404
405 pub fn create_ffn_position_ids_tensor(
407 &self,
408 positions: &[i64],
409 device: &Device,
410 ) -> Result<Tensor, CandleError> {
411 let expected_shape = self
412 .get_tensor_shape("ffn_prefill", "position_ids", true)
413 .ok_or_else(|| {
414 CandleError::Msg("No FFN prefill position_ids shape found".to_string())
415 })?;
416
417 let mut expected_len = expected_shape[0];
420 if expected_len == 1 {
421 if let Some(hs_shape) = self.get_tensor_shape("ffn_prefill", "hidden_states", true) {
423 if hs_shape.len() == 3 && hs_shape[1] > 1 {
424 expected_len = hs_shape[1];
425 }
426 }
427 if expected_len == 1 {
429 if let Some(emb) = self.embeddings_input_shape() {
430 if emb.len() == 2 && emb[1] > 1 {
431 expected_len = emb[1];
432 }
433 }
434 }
435 }
437
438 let mut position_ids = Vec::with_capacity(expected_len);
440 for i in 0..expected_len {
441 if i < positions.len() {
442 position_ids.push(positions[i]);
443 } else {
444 position_ids.push(0); }
446 }
447
448 Tensor::from_vec(position_ids, (expected_len,), device)
449 }
450
451 pub fn create_ffn_causal_mask_tensor(
453 &self,
454 _batch_size: usize,
455 _context_length: usize,
456 device: &Device,
457 ) -> Result<Tensor, CandleError> {
458 let expected_shape_vec =
460 if let Some(shape) = self.get_tensor_shape("ffn_prefill", "causal_mask", true) {
461 shape.clone()
462 } else {
463 let mut seq_len = 0usize;
465 if let Some(hs) = self.get_tensor_shape("ffn_prefill", "hidden_states", true) {
466 if hs.len() == 3 && hs[1] > 0 {
467 seq_len = hs[1];
468 }
469 }
470 if seq_len == 0 {
471 if let Some(emb) = self.embeddings_input_shape() {
472 if emb.len() == 2 && emb[1] > 0 {
473 seq_len = emb[1];
474 }
475 }
476 }
477 if seq_len == 0 {
478 seq_len = self.shapes.context_length;
479 }
480 vec![1, 1, seq_len, seq_len]
481 };
482
483 let mask_rows = expected_shape_vec[2];
484 let mask_context_length = expected_shape_vec[3];
485
486 let mut mask_data = vec![f32::NEG_INFINITY; mask_rows * mask_context_length];
488 for i in 0..mask_rows {
489 for j in 0..=i.min(mask_context_length - 1) {
490 mask_data[i * mask_context_length + j] = 0.0;
491 }
492 }
493
494 Tensor::from_vec(
495 mask_data,
496 (
497 expected_shape_vec[0],
498 expected_shape_vec[1],
499 expected_shape_vec[2],
500 expected_shape_vec[3],
501 ),
502 device,
503 )
504 }
505
506 pub fn create_single_token_hidden_states(
508 &self,
509 _tokens: &[i64],
510 device: &Device,
511 ) -> Result<Tensor, CandleError> {
512 let expected_shape = self
513 .get_tensor_shape("lm_head", "hidden_states", true)
514 .ok_or_else(|| CandleError::Msg("No LM head hidden_states shape found".to_string()))?;
515
516 let tensor_data = vec![0.0f32; expected_shape.iter().product()];
518 let shape = (expected_shape[0], expected_shape[1], expected_shape[2]);
519
520 Tensor::from_vec(tensor_data, shape, device)
521 }
522
523 pub fn create_infer_position_ids_tensor(
525 &self,
526 position: i64,
527 device: &Device,
528 ) -> Result<Tensor, CandleError> {
529 if let Some(infer_shape) = self.get_tensor_shape("ffn_infer", "position_ids", true) {
531 if infer_shape.len() == 1 {
533 Tensor::from_vec(vec![position], (infer_shape[0],), device)
534 } else {
535 let size = infer_shape.iter().product();
536 let mut data = vec![0i64; size];
537 data[0] = position;
538 Tensor::from_vec(data, infer_shape.as_slice(), device)
539 }
540 } else {
541 Tensor::from_vec(vec![position], (1,), device)
543 }
544 }
545
    /// Builds the 1-element current-position tensor passed to the FFN.
    pub fn create_current_pos_tensor(
        &self,
        position: i64,
        device: &Device,
    ) -> Result<Tensor, CandleError> {
        Tensor::from_vec(vec![position], (1,), device)
    }
555}
556
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Builds a minimal valid config with just "embeddings" and "lm_head"
    /// components (the two required ones), using small Qwen-like shapes.
    fn create_test_config() -> ModelConfig {
        let mut components = HashMap::new();

        // embeddings: [1, 64] token ids -> [1, 64, 1024] hidden states
        let mut embeddings_inputs = HashMap::new();
        embeddings_inputs.insert(
            "input_ids".to_string(),
            TensorConfig {
                name: "input_ids".to_string(),
                shape: vec![1, 64],
                data_type: "INT32".to_string(),
            },
        );

        let mut embeddings_outputs = HashMap::new();
        embeddings_outputs.insert(
            "hidden_states".to_string(),
            TensorConfig {
                name: "hidden_states".to_string(),
                shape: vec![1, 64, 1024],
                data_type: "FLOAT16".to_string(),
            },
        );

        components.insert(
            "embeddings".to_string(),
            ComponentConfig {
                file_path: None,
                inputs: embeddings_inputs,
                outputs: embeddings_outputs,
                functions: vec![],
                input_order: None,
            },
        );

        // lm_head: [1, 1, 1024] hidden states -> [1, 1, 151936] logits
        let mut lm_head_inputs = HashMap::new();
        lm_head_inputs.insert(
            "hidden_states".to_string(),
            TensorConfig {
                name: "hidden_states".to_string(),
                shape: vec![1, 1, 1024],
                data_type: "FLOAT16".to_string(),
            },
        );

        let mut lm_head_outputs = HashMap::new();
        lm_head_outputs.insert(
            "logits".to_string(),
            TensorConfig {
                name: "logits".to_string(),
                shape: vec![1, 1, 151936],
                data_type: "FLOAT32".to_string(),
            },
        );

        components.insert(
            "lm_head".to_string(),
            ComponentConfig {
                file_path: None,
                inputs: lm_head_inputs,
                outputs: lm_head_outputs,
                functions: vec![],
                input_order: None,
            },
        );

        ModelConfig {
            model_info: ModelInfo {
                model_id: Some("test/model".to_string()),
                path: Some("/test/path".to_string()),
                model_type: "qwen".to_string(),
                discovered_at: Some("2025-08-07T00:00:00".to_string()),
            },
            shapes: ShapeConfig {
                batch_size: 1,
                context_length: 512,
                hidden_size: 1024,
                vocab_size: 151936,
            },
            components,
            naming: NamingConfig {
                embeddings_pattern: None,
                ffn_prefill_pattern: None,
                ffn_infer_pattern: None,
                lm_head_pattern: None,
            },
            ffn_execution: Some("unified".to_string()),
        }
    }

    // Round-trips the config through serde_json in memory.
    #[test]
    fn test_config_serialization() {
        let config = create_test_config();

        let json = serde_json::to_string_pretty(&config).unwrap();
        assert!(json.contains("test/model"));
        assert!(json.contains("batch_size"));
        assert!(json.contains("embeddings"));

        let parsed: ModelConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.model_info.model_id, config.model_info.model_id);
        assert_eq!(parsed.shapes.batch_size, config.shapes.batch_size);
        assert_eq!(parsed.components.len(), config.components.len());
    }

    // Round-trips the config through the filesystem via a temp file.
    #[test]
    fn test_config_file_io() {
        let config = create_test_config();
        let temp_file = NamedTempFile::new().unwrap();

        config.save_to_file(temp_file.path()).unwrap();

        let loaded = ModelConfig::load_from_file(temp_file.path()).unwrap();
        assert_eq!(loaded.model_info.model_id, config.model_info.model_id);
        assert_eq!(loaded.shapes.hidden_size, config.shapes.hidden_size);
    }

    // Verifies the convenience shape accessors against the known fixture.
    #[test]
    fn test_shape_accessors() {
        let config = create_test_config();

        let embeddings_input = config.embeddings_input_shape().unwrap();
        assert_eq!(embeddings_input, &vec![1, 64]);

        let embeddings_output = config.embeddings_output_shape().unwrap();
        assert_eq!(embeddings_output, &vec![1, 64, 1024]);

        let lm_head_input = config.lm_head_input_shape().unwrap();
        assert_eq!(lm_head_input, &vec![1, 1, 1024]);
    }

    // Single "logits" output is not multipart; "logits1" + "logits2" is.
    #[test]
    fn test_multipart_logits_detection() {
        let config = create_test_config();
        assert!(!config.has_multipart_logits());

        let mut config_multipart = config;
        let lm_head = config_multipart.components.get_mut("lm_head").unwrap();
        lm_head.outputs.clear();
        lm_head.outputs.insert(
            "logits1".to_string(),
            TensorConfig {
                name: "logits1".to_string(),
                shape: vec![1, 1, 9480],
                data_type: "FLOAT32".to_string(),
            },
        );
        lm_head.outputs.insert(
            "logits2".to_string(),
            TensorConfig {
                name: "logits2".to_string(),
                shape: vec![1, 1, 9479],
                data_type: "FLOAT32".to_string(),
            },
        );

        assert!(config_multipart.has_multipart_logits());
        assert_eq!(config_multipart.logits_part_count(), 2);
    }

    // A well-formed config passes validation; removing a required component
    // or zeroing a dimension fails it.
    #[test]
    fn test_config_validation() {
        let config = create_test_config();
        assert!(config.validate().is_ok());

        assert!(config.validate_internal_wiring().is_ok());

        let mut invalid_config = config.clone();
        invalid_config.components.remove("embeddings");
        assert!(invalid_config.validate().is_err());

        let mut invalid_shapes = config;
        invalid_shapes.shapes.batch_size = 0;
        assert!(invalid_shapes.validate().is_err());
    }
}