// candle_coreml/model_config.rs

1//! Model configuration system for candle-coreml
2//!
3//! This module provides a configuration system that replaces hardcoded model shapes
4//! with discoverable, flexible configurations. It supports loading configurations from
5//! JSON files generated by the shape discovery tool, as well as built-in configurations
6//! for known models.
7
8use anyhow::{Context, Result};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::Path;
12use tracing::debug;
13
/// Complete model configuration including shapes, components, and naming patterns
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ModelConfig {
    /// Model identity and provenance metadata.
    pub model_info: ModelInfo,
    /// Top-level shape parameters shared across components.
    pub shapes: ShapeConfig,
    /// Per-component tensor configurations, keyed by component name
    /// (e.g. "embeddings", "ffn_prefill", "ffn_infer", "lm_head").
    pub components: HashMap<String, ComponentConfig>,
    /// File naming patterns (deprecated; see `NamingConfig`).
    pub naming: NamingConfig,
    /// Execution mode for FFN: "unified" (single component/function) or "split" (separate prefill/infer components)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ffn_execution: Option<String>,
}
25
/// Model metadata and identification
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ModelInfo {
    /// Model identifier (e.g. "org/name"); `None` for ad-hoc configs.
    #[serde(default)]
    pub model_id: Option<String>,
    /// Local filesystem path the model was loaded/discovered from, if any.
    pub path: Option<String>,
    /// Model family tag (e.g. "qwen").
    pub model_type: String,
    /// Timestamp recorded by the shape discovery tool, if this config was generated.
    pub discovered_at: Option<String>,
}
35
/// Overall model shape parameters
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ShapeConfig {
    /// Batch dimension; must be > 0 (enforced by `ModelConfig::validate`).
    pub batch_size: usize,
    /// Context length in tokens; must be > 0.
    pub context_length: usize,
    /// Hidden-state feature dimension; must be > 0.
    pub hidden_size: usize,
    /// Vocabulary size of the LM head; must be > 0.
    pub vocab_size: usize,
}
44
/// Configuration for a single model component (embeddings, FFN, LM head, etc.)
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ComponentConfig {
    /// Explicit path to the component's model file, when known.
    pub file_path: Option<String>,
    /// Input tensor configurations, keyed by tensor name.
    pub inputs: HashMap<String, TensorConfig>,
    /// Output tensor configurations, keyed by tensor name.
    pub outputs: HashMap<String, TensorConfig>,
    /// Function names exposed by the component (may be empty).
    pub functions: Vec<String>,
    /// Optional deterministic input order; if absent, caller must provide correct order.
    #[serde(default)]
    pub input_order: Option<Vec<String>>,
}
56
/// Configuration for a tensor (shape and data type)
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TensorConfig {
    /// Tensor name as exposed by the model (typically matches its map key).
    pub name: String,
    /// Dimension sizes, e.g. [1, 64, 1024]; must be non-empty (see `validate`).
    pub shape: Vec<usize>,
    /// Data type tag, e.g. "FLOAT16", "FLOAT32", "INT32".
    pub data_type: String,
}
64
/// Model file naming patterns for component discovery
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct NamingConfig {
    // Deprecated: patterns removed; explicit file paths are required per component
    // (see `ComponentConfig::file_path`). Fields are kept only so older config
    // files still (de)serialize; they are omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embeddings_pattern: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ffn_prefill_pattern: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ffn_infer_pattern: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lm_head_pattern: Option<String>,
}
78
79impl ModelConfig {
80    /// Load configuration from a JSON file
81    pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
82        let path = path.as_ref();
83        let content = std::fs::read_to_string(path)
84            .with_context(|| format!("Failed to read config file: {}", path.display()))?;
85
86        let config: ModelConfig = serde_json::from_str(&content)
87            .with_context(|| format!("Failed to parse config file: {}", path.display()))?;
88
89        Ok(config)
90    }
91
92    /// Save configuration to a JSON file
93    pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
94        let path = path.as_ref();
95        let content =
96            serde_json::to_string_pretty(self).context("Failed to serialize configuration")?;
97
98        std::fs::write(path, content)
99            .with_context(|| format!("Failed to write config file: {}", path.display()))?;
100
101        Ok(())
102    }
103
    /// Get built-in configuration for known model IDs
    ///
    /// Delegates to `crate::builtin_configs`; returns `None` for unknown IDs.
    pub fn get_builtin_config(model_id: &str) -> Option<Self> {
        crate::builtin_configs::get_builtin_config(model_id)
    }
108
    /// Create a default configuration with standard Qwen shapes
    ///
    /// Components are left empty and all naming patterns are `None`; callers
    /// populate (or discover) component configs separately.
    pub fn default_qwen() -> Self {
        Self {
            model_info: ModelInfo {
                model_id: None,
                path: None,
                model_type: "qwen".to_string(),
                discovered_at: None,
            },
            shapes: ShapeConfig {
                batch_size: 1,
                context_length: 512,
                hidden_size: 1024,
                // presumably the Qwen tokenizer vocabulary size — TODO confirm
                vocab_size: 151936,
            },
            components: HashMap::new(),
            naming: NamingConfig {
                embeddings_pattern: None,
                ffn_prefill_pattern: None,
                ffn_infer_pattern: None,
                lm_head_pattern: None,
            },
            ffn_execution: None,
        }
    }
134
135    /// Get the shape configuration for a specific component and tensor
136    pub fn get_tensor_shape(
137        &self,
138        component: &str,
139        tensor_name: &str,
140        is_input: bool,
141    ) -> Option<&Vec<usize>> {
142        let component_config = self.components.get(component)?;
143
144        let tensor_map = if is_input {
145            &component_config.inputs
146        } else {
147            &component_config.outputs
148        };
149
150        tensor_map.get(tensor_name).map(|tensor| &tensor.shape)
151    }
152
    // --- Convenience shape accessors: each returns None when the component
    // --- or tensor is missing from `components` (see `get_tensor_shape`).

    /// Get the expected input shape for embeddings
    pub fn embeddings_input_shape(&self) -> Option<&Vec<usize>> {
        self.get_tensor_shape("embeddings", "input_ids", true)
    }

    /// Get the expected output shape for embeddings
    pub fn embeddings_output_shape(&self) -> Option<&Vec<usize>> {
        self.get_tensor_shape("embeddings", "hidden_states", false)
    }

    /// Get the expected input shape for FFN prefill
    pub fn ffn_prefill_input_shape(&self) -> Option<&Vec<usize>> {
        self.get_tensor_shape("ffn_prefill", "hidden_states", true)
    }

    /// Get the expected input shape for LM head
    pub fn lm_head_input_shape(&self) -> Option<&Vec<usize>> {
        self.get_tensor_shape("lm_head", "hidden_states", true)
    }
172
173    /// Check if this model supports multi-part logits output
174    pub fn has_multipart_logits(&self) -> bool {
175        if let Some(lm_head) = self.components.get("lm_head") {
176            // Check if there are multiple logits outputs (logits1, logits2, etc.)
177            let logits_outputs: Vec<_> = lm_head
178                .outputs
179                .keys()
180                .filter(|name| name.starts_with("logits") && name.len() > 6) // "logits" + number
181                .collect();
182            return logits_outputs.len() > 1;
183        }
184        false
185    }
186
187    /// Get the number of logits parts for multi-part output
188    pub fn logits_part_count(&self) -> usize {
189        if let Some(lm_head) = self.components.get("lm_head") {
190            let logits_outputs: Vec<_> = lm_head
191                .outputs
192                .keys()
193                .filter(|name| name.starts_with("logits"))
194                .collect();
195            if logits_outputs.is_empty() {
196                1 // Single logits output
197            } else {
198                logits_outputs.len()
199            }
200        } else {
201            1
202        }
203    }
204
205    /// Select the primary logits output name for the LM head.
206    /// Preference order: "logits1" (multipart), then "logits", otherwise the first available key.
207    pub fn lm_head_primary_output_name(&self) -> Option<String> {
208        let lm_head = self.components.get("lm_head")?;
209
210        // Prefer explicit multipart naming starting with logits1
211        if lm_head.outputs.contains_key("logits1") {
212            return Some("logits1".to_string());
213        }
214
215        // Common single output name
216        if lm_head.outputs.contains_key("logits") {
217            return Some("logits".to_string());
218        }
219
220        // Fallback to the first key if any
221        lm_head.outputs.keys().next().map(|k| k.to_string())
222    }
223
224    /// Validate the configuration for consistency
225    pub fn validate(&self) -> Result<()> {
226        // Check that required components exist
227        let required_components = ["embeddings", "lm_head"];
228        for component in required_components {
229            if !self.components.contains_key(component) {
230                return Err(anyhow::anyhow!("Missing required component: {}", component));
231            }
232        }
233
234        // Check shape consistency
235        if self.shapes.batch_size == 0 {
236            return Err(anyhow::anyhow!("batch_size must be greater than 0"));
237        }
238
239        if self.shapes.context_length == 0 {
240            return Err(anyhow::anyhow!("context_length must be greater than 0"));
241        }
242
243        if self.shapes.hidden_size == 0 {
244            return Err(anyhow::anyhow!("hidden_size must be greater than 0"));
245        }
246
247        if self.shapes.vocab_size == 0 {
248            return Err(anyhow::anyhow!("vocab_size must be greater than 0"));
249        }
250
251        // Validate tensor shapes make sense
252        for (component_name, component) in &self.components {
253            for (tensor_name, tensor) in &component.inputs {
254                if tensor.shape.is_empty() {
255                    return Err(anyhow::anyhow!(
256                        "Empty shape for {}.inputs.{}",
257                        component_name,
258                        tensor_name
259                    ));
260                }
261            }
262            for (tensor_name, tensor) in &component.outputs {
263                if tensor.shape.is_empty() {
264                    return Err(anyhow::anyhow!(
265                        "Empty shape for {}.outputs.{}",
266                        component_name,
267                        tensor_name
268                    ));
269                }
270            }
271        }
272
273        Ok(())
274    }
275
276    /// Validate internal wiring between components for basic shape compatibility.
277    /// Examples:
278    ///  - embeddings.outputs.hidden_states == ffn_prefill.inputs.hidden_states
279    ///  - ffn_infer.outputs.output_hidden_states == lm_head.inputs.hidden_states (when ffn_infer exists)
280    pub fn validate_internal_wiring(&self) -> Result<()> {
281        // Embeddings -> FFN prefill hidden_states flow
282        if let (Some(emb_out), Some(ffn_in_hidden)) = (
283            self.get_tensor_shape("embeddings", "hidden_states", false),
284            self.get_tensor_shape("ffn_prefill", "hidden_states", true),
285        ) {
286            if emb_out != ffn_in_hidden {
287                return Err(anyhow::anyhow!(
288                    "Shape mismatch: embeddings.hidden_states {:?} != ffn_prefill.hidden_states {:?}",
289                    emb_out, ffn_in_hidden
290                ));
291            }
292        }
293
294        // FFN infer -> LM head hidden_states flow
295        if self.components.contains_key("ffn_infer") {
296            if let (Some(ffn_out), Some(lm_in)) = (
297                self.get_tensor_shape("ffn_infer", "output_hidden_states", false),
298                self.get_tensor_shape("lm_head", "hidden_states", true),
299            ) {
300                if ffn_out != lm_in {
301                    return Err(anyhow::anyhow!(
302                        "Shape mismatch: ffn_infer.output_hidden_states {:?} != lm_head.hidden_states {:?}",
303                        ffn_out, lm_in
304                    ));
305                }
306            }
307        } else {
308            // If there's no separate ffn_infer, check ffn_prefill output shape matches lm_head (single-token path)
309            if let (Some(ffn_out), Some(lm_in)) = (
310                self.get_tensor_shape("ffn_prefill", "output_hidden_states", false),
311                self.get_tensor_shape("lm_head", "hidden_states", true),
312            ) {
313                if ffn_out != lm_in {
314                    return Err(anyhow::anyhow!(
315                        "Shape mismatch: ffn_prefill.output_hidden_states {:?} != lm_head.hidden_states {:?}",
316                        ffn_out, lm_in
317                    ));
318                }
319            }
320        }
321
322        Ok(())
323    }
324
325    /// Determine if FFN execution should be treated as split (separate infer component)
326    pub fn ffn_is_split(&self) -> bool {
327        if let Some(mode) = self.ffn_execution.as_deref() {
328            return mode == "split";
329        }
330        if let (Some(prefill), Some(infer)) = (
331            self.components.get("ffn_prefill"),
332            self.components.get("ffn_infer"),
333        ) {
334            match (&prefill.file_path, &infer.file_path) {
335                (Some(p), Some(i)) => p != i, // different files => split
336                _ => false,
337            }
338        } else {
339            false
340        }
341    }
342
343    /// Detect if prefill should run in single-token sequential mode based on configured shapes
344    pub fn prefill_is_single_token(&self) -> bool {
345        if let Some(prefill) = self.components.get("ffn_prefill") {
346            if let Some(hs) = prefill.inputs.get("hidden_states") {
347                let is_single = hs.shape.len() == 3 && hs.shape.get(1) == Some(&1);
348                debug!(
349                    "🔍 prefill_is_single_token: shape={:?}, len={}, dim[1]={:?}, result={}",
350                    hs.shape,
351                    hs.shape.len(),
352                    hs.shape.get(1),
353                    is_single
354                );
355                return is_single;
356            }
357        }
358        debug!(
359            "🔍 prefill_is_single_token: no ffn_prefill or hidden_states found, returning false"
360        );
361        false
362    }
363
364    /// Check if the model expects full sequence prefill (as opposed to single-token processing)
365    /// This is typically true for CoreML models with fixed-shape inputs like [1, 128, 1024]
366    pub fn expects_full_sequence_prefill(&self) -> bool {
367        if let Some(prefill) = self.components.get("ffn_prefill") {
368            if let Some(hs) = prefill.inputs.get("hidden_states") {
369                // If the model expects a fixed sequence length > 1, it needs full-sequence prefill
370                let expects_full =
371                    hs.shape.len() == 3 && hs.shape.get(1).is_some_and(|&seq_len| seq_len > 1);
372                debug!(
373                    "🔍 expects_full_sequence_prefill: shape={:?}, len={}, dim[1]={:?}, result={}",
374                    hs.shape,
375                    hs.shape.len(),
376                    hs.shape.get(1),
377                    expects_full
378                );
379                return expects_full;
380            }
381        }
382        debug!("🔍 expects_full_sequence_prefill: no ffn_prefill or hidden_states found, returning false");
383        false
384    }
385}
386
impl Default for ModelConfig {
    /// Defaults to the standard Qwen configuration (see `ModelConfig::default_qwen`).
    fn default() -> Self {
        Self::default_qwen()
    }
}
392
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    // Build a minimal two-component config (embeddings + lm_head) that passes
    // `validate()`; used as the baseline fixture for all tests below.
    fn create_test_config() -> ModelConfig {
        let mut components = HashMap::new();

        // Embeddings component
        let mut embeddings_inputs = HashMap::new();
        embeddings_inputs.insert(
            "input_ids".to_string(),
            TensorConfig {
                name: "input_ids".to_string(),
                shape: vec![1, 64],
                data_type: "INT32".to_string(),
            },
        );

        let mut embeddings_outputs = HashMap::new();
        embeddings_outputs.insert(
            "hidden_states".to_string(),
            TensorConfig {
                name: "hidden_states".to_string(),
                shape: vec![1, 64, 1024],
                data_type: "FLOAT16".to_string(),
            },
        );

        components.insert(
            "embeddings".to_string(),
            ComponentConfig {
                file_path: None,
                inputs: embeddings_inputs,
                outputs: embeddings_outputs,
                functions: vec![],
                input_order: None,
            },
        );

        // LM Head component
        let mut lm_head_inputs = HashMap::new();
        lm_head_inputs.insert(
            "hidden_states".to_string(),
            TensorConfig {
                name: "hidden_states".to_string(),
                shape: vec![1, 1, 1024],
                data_type: "FLOAT16".to_string(),
            },
        );

        let mut lm_head_outputs = HashMap::new();
        lm_head_outputs.insert(
            "logits".to_string(),
            TensorConfig {
                name: "logits".to_string(),
                shape: vec![1, 1, 151936],
                data_type: "FLOAT32".to_string(),
            },
        );

        components.insert(
            "lm_head".to_string(),
            ComponentConfig {
                file_path: None,
                inputs: lm_head_inputs,
                outputs: lm_head_outputs,
                functions: vec![],
                input_order: None,
            },
        );

        ModelConfig {
            model_info: ModelInfo {
                model_id: Some("test/model".to_string()),
                path: Some("/test/path".to_string()),
                model_type: "qwen".to_string(),
                discovered_at: Some("2025-08-07T00:00:00".to_string()),
            },
            shapes: ShapeConfig {
                batch_size: 1,
                context_length: 512,
                hidden_size: 1024,
                vocab_size: 151936,
            },
            components,
            naming: NamingConfig {
                embeddings_pattern: None,
                ffn_prefill_pattern: None,
                ffn_infer_pattern: None,
                lm_head_pattern: None,
            },
            ffn_execution: Some("unified".to_string()),
        }
    }

    // Round-trip through serde_json strings preserves the config.
    #[test]
    fn test_config_serialization() {
        let config = create_test_config();

        // Test JSON serialization
        let json = serde_json::to_string_pretty(&config).unwrap();
        assert!(json.contains("test/model"));
        assert!(json.contains("batch_size"));
        assert!(json.contains("embeddings"));

        // Test deserialization
        let parsed: ModelConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.model_info.model_id, config.model_info.model_id);
        assert_eq!(parsed.shapes.batch_size, config.shapes.batch_size);
        assert_eq!(parsed.components.len(), config.components.len());
    }

    // save_to_file / load_from_file round-trip on a temp file.
    #[test]
    fn test_config_file_io() {
        let config = create_test_config();
        let temp_file = NamedTempFile::new().unwrap();

        // Save configuration
        config.save_to_file(temp_file.path()).unwrap();

        // Load configuration
        let loaded = ModelConfig::load_from_file(temp_file.path()).unwrap();
        assert_eq!(loaded.model_info.model_id, config.model_info.model_id);
        assert_eq!(loaded.shapes.hidden_size, config.shapes.hidden_size);
    }

    // The convenience shape accessors return the fixture's configured shapes.
    #[test]
    fn test_shape_accessors() {
        let config = create_test_config();

        // Test embeddings shapes
        let embeddings_input = config.embeddings_input_shape().unwrap();
        assert_eq!(embeddings_input, &vec![1, 64]);

        let embeddings_output = config.embeddings_output_shape().unwrap();
        assert_eq!(embeddings_output, &vec![1, 64, 1024]);

        let lm_head_input = config.lm_head_input_shape().unwrap();
        assert_eq!(lm_head_input, &vec![1, 1, 1024]);
    }

    // Single "logits" output is not multipart; "logits1"+"logits2" is (count 2).
    #[test]
    fn test_multipart_logits_detection() {
        let config = create_test_config();
        assert!(!config.has_multipart_logits()); // Single logits output

        // Create config with multipart logits
        let mut config_multipart = config;
        let lm_head = config_multipart.components.get_mut("lm_head").unwrap();
        lm_head.outputs.clear();
        lm_head.outputs.insert(
            "logits1".to_string(),
            TensorConfig {
                name: "logits1".to_string(),
                shape: vec![1, 1, 9480],
                data_type: "FLOAT32".to_string(),
            },
        );
        lm_head.outputs.insert(
            "logits2".to_string(),
            TensorConfig {
                name: "logits2".to_string(),
                shape: vec![1, 1, 9479],
                data_type: "FLOAT32".to_string(),
            },
        );

        assert!(config_multipart.has_multipart_logits());
        assert_eq!(config_multipart.logits_part_count(), 2);
    }

    // validate() accepts the fixture and rejects missing components / zero shapes.
    #[test]
    fn test_config_validation() {
        let config = create_test_config();
        assert!(config.validate().is_ok());

        // Internal wiring should be consistent in this synthetic setup
        assert!(config.validate_internal_wiring().is_ok());

        // Test missing component
        let mut invalid_config = config.clone();
        invalid_config.components.remove("embeddings");
        assert!(invalid_config.validate().is_err());

        // Test invalid shapes
        let mut invalid_shapes = config;
        invalid_shapes.shapes.batch_size = 0;
        assert!(invalid_shapes.validate().is_err());
    }

    // Default::default() delegates to default_qwen().
    #[test]
    fn test_default_config() {
        let config = ModelConfig::default();
        assert_eq!(config.model_info.model_type, "qwen");
        assert_eq!(config.shapes.batch_size, 1);
        assert_eq!(config.shapes.context_length, 512);
        assert_eq!(config.shapes.hidden_size, 1024);
        assert_eq!(config.shapes.vocab_size, 151936);
    }
}