candle_coreml/
model_config.rs

1//! Model configuration system for candle-coreml
2//!
3//! This module provides a configuration system that replaces hardcoded model shapes
4//! with discoverable, flexible configurations. It supports loading configurations from
5//! JSON files generated by the shape discovery tool, as well as built-in configurations
6//! for known models.
7
8use anyhow::{Context, Result};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::Path;
12
/// Complete model configuration including shapes, components, and naming patterns.
///
/// Typically deserialized from JSON produced by the shape discovery tool
/// (see `load_from_file`) or created from built-in defaults (`default_qwen`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ModelConfig {
    // Model identity and provenance metadata.
    pub model_info: ModelInfo,
    // Global shape parameters shared across all components.
    pub shapes: ShapeConfig,
    // Per-component tensor configurations, keyed by component name
    // (e.g. "embeddings", "ffn_prefill", "ffn_infer", "lm_head").
    pub components: HashMap<String, ComponentConfig>,
    // Legacy file-naming patterns (deprecated; see NamingConfig).
    pub naming: NamingConfig,
    /// Execution mode for FFN: "unified" (single component/function) or "split" (separate prefill/infer components)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ffn_execution: Option<String>,
}
24
/// Model metadata and identification
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ModelInfo {
    // Optional model identifier; `#[serde(default)]` lets JSON omit it entirely.
    // Presumably a hub/model id like "test/model" (see tests) — confirm against callers.
    #[serde(default)]
    pub model_id: Option<String>,
    // Optional filesystem path where the model was found.
    pub path: Option<String>,
    // Model family/type tag, e.g. "qwen" (the only value used in this file).
    pub model_type: String,
    // Optional timestamp string recording when shapes were discovered.
    pub discovered_at: Option<String>,
}
34
/// Overall model shape parameters.
///
/// All fields must be non-zero for `ModelConfig::validate` to succeed.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ShapeConfig {
    // Batch dimension (1 in the built-in Qwen default).
    pub batch_size: usize,
    // Maximum sequence/context length in tokens.
    pub context_length: usize,
    // Hidden state dimension of the model.
    pub hidden_size: usize,
    // Vocabulary size (logits dimension of the LM head).
    pub vocab_size: usize,
}
43
/// Configuration for a single model component (embeddings, FFN, LM head, etc.)
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ComponentConfig {
    // Path to the component's model file, when known; `ffn_is_split` compares
    // prefill/infer paths to decide split vs unified execution.
    pub file_path: Option<String>,
    // Input tensor configurations keyed by tensor name.
    pub inputs: HashMap<String, TensorConfig>,
    // Output tensor configurations keyed by tensor name.
    pub outputs: HashMap<String, TensorConfig>,
    // Function names exposed by the component — presumably CoreML function
    // names for multi-function models; confirm against the loader.
    pub functions: Vec<String>,
    /// Optional deterministic input order; if absent, caller must provide correct order.
    #[serde(default)]
    pub input_order: Option<Vec<String>>,
}
55
/// Configuration for a tensor (shape and data type)
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TensorConfig {
    // Tensor name; in this file it always mirrors the map key it is stored under.
    pub name: String,
    // Tensor dimensions; `ModelConfig::validate` rejects empty shapes.
    pub shape: Vec<usize>,
    // Data type label as a free-form string (tests use "FLOAT16", "FLOAT32", "INT32").
    pub data_type: String,
}
63
/// Model file naming patterns for component discovery.
///
/// Deprecated: pattern-based discovery was removed; explicit
/// `ComponentConfig::file_path` values are required per component. These
/// fields remain only for backward-compatible (de)serialization and are
/// omitted from JSON when `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct NamingConfig {
    // Deprecated: patterns removed; explicit file paths are required per component
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embeddings_pattern: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ffn_prefill_pattern: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ffn_infer_pattern: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lm_head_pattern: Option<String>,
}
77
78impl ModelConfig {
79    /// Load configuration from a JSON file
80    pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
81        let path = path.as_ref();
82        let content = std::fs::read_to_string(path)
83            .with_context(|| format!("Failed to read config file: {}", path.display()))?;
84
85        let config: ModelConfig = serde_json::from_str(&content)
86            .with_context(|| format!("Failed to parse config file: {}", path.display()))?;
87
88        Ok(config)
89    }
90
91    /// Save configuration to a JSON file
92    pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
93        let path = path.as_ref();
94        let content =
95            serde_json::to_string_pretty(self).context("Failed to serialize configuration")?;
96
97        std::fs::write(path, content)
98            .with_context(|| format!("Failed to write config file: {}", path.display()))?;
99
100        Ok(())
101    }
102
103    /// Get built-in configuration for known model IDs
104    pub fn get_builtin_config(model_id: &str) -> Option<Self> {
105        crate::builtin_configs::get_builtin_config(model_id)
106    }
107
108    /// Create a default configuration with standard Qwen shapes
109    pub fn default_qwen() -> Self {
110        Self {
111            model_info: ModelInfo {
112                model_id: None,
113                path: None,
114                model_type: "qwen".to_string(),
115                discovered_at: None,
116            },
117            shapes: ShapeConfig {
118                batch_size: 1,
119                context_length: 512,
120                hidden_size: 1024,
121                vocab_size: 151936,
122            },
123            components: HashMap::new(),
124            naming: NamingConfig {
125                embeddings_pattern: None,
126                ffn_prefill_pattern: None,
127                ffn_infer_pattern: None,
128                lm_head_pattern: None,
129            },
130            ffn_execution: None,
131        }
132    }
133
134    /// Get the shape configuration for a specific component and tensor
135    pub fn get_tensor_shape(
136        &self,
137        component: &str,
138        tensor_name: &str,
139        is_input: bool,
140    ) -> Option<&Vec<usize>> {
141        let component_config = self.components.get(component)?;
142
143        let tensor_map = if is_input {
144            &component_config.inputs
145        } else {
146            &component_config.outputs
147        };
148
149        tensor_map.get(tensor_name).map(|tensor| &tensor.shape)
150    }
151
152    /// Get the expected input shape for embeddings
153    pub fn embeddings_input_shape(&self) -> Option<&Vec<usize>> {
154        self.get_tensor_shape("embeddings", "input_ids", true)
155    }
156
157    /// Get the expected output shape for embeddings
158    pub fn embeddings_output_shape(&self) -> Option<&Vec<usize>> {
159        self.get_tensor_shape("embeddings", "hidden_states", false)
160    }
161
162    /// Get the expected input shape for FFN prefill
163    pub fn ffn_prefill_input_shape(&self) -> Option<&Vec<usize>> {
164        self.get_tensor_shape("ffn_prefill", "hidden_states", true)
165    }
166
167    /// Get the expected input shape for LM head
168    pub fn lm_head_input_shape(&self) -> Option<&Vec<usize>> {
169        self.get_tensor_shape("lm_head", "hidden_states", true)
170    }
171
172    /// Check if this model supports multi-part logits output
173    pub fn has_multipart_logits(&self) -> bool {
174        if let Some(lm_head) = self.components.get("lm_head") {
175            // Check if there are multiple logits outputs (logits1, logits2, etc.)
176            let logits_outputs: Vec<_> = lm_head
177                .outputs
178                .keys()
179                .filter(|name| name.starts_with("logits") && name.len() > 6) // "logits" + number
180                .collect();
181            return logits_outputs.len() > 1;
182        }
183        false
184    }
185
186    /// Get the number of logits parts for multi-part output
187    pub fn logits_part_count(&self) -> usize {
188        if let Some(lm_head) = self.components.get("lm_head") {
189            let logits_outputs: Vec<_> = lm_head
190                .outputs
191                .keys()
192                .filter(|name| name.starts_with("logits"))
193                .collect();
194            if logits_outputs.is_empty() {
195                1 // Single logits output
196            } else {
197                logits_outputs.len()
198            }
199        } else {
200            1
201        }
202    }
203
204    /// Select the primary logits output name for the LM head.
205    /// Preference order: "logits1" (multipart), then "logits", otherwise the first available key.
206    pub fn lm_head_primary_output_name(&self) -> Option<String> {
207        let lm_head = self.components.get("lm_head")?;
208
209        // Prefer explicit multipart naming starting with logits1
210        if lm_head.outputs.contains_key("logits1") {
211            return Some("logits1".to_string());
212        }
213
214        // Common single output name
215        if lm_head.outputs.contains_key("logits") {
216            return Some("logits".to_string());
217        }
218
219        // Fallback to the first key if any
220        lm_head.outputs.keys().next().map(|k| k.to_string())
221    }
222
223    /// Validate the configuration for consistency
224    pub fn validate(&self) -> Result<()> {
225        // Check that required components exist
226        let required_components = ["embeddings", "lm_head"];
227        for component in required_components {
228            if !self.components.contains_key(component) {
229                return Err(anyhow::anyhow!("Missing required component: {}", component));
230            }
231        }
232
233        // Check shape consistency
234        if self.shapes.batch_size == 0 {
235            return Err(anyhow::anyhow!("batch_size must be greater than 0"));
236        }
237
238        if self.shapes.context_length == 0 {
239            return Err(anyhow::anyhow!("context_length must be greater than 0"));
240        }
241
242        if self.shapes.hidden_size == 0 {
243            return Err(anyhow::anyhow!("hidden_size must be greater than 0"));
244        }
245
246        if self.shapes.vocab_size == 0 {
247            return Err(anyhow::anyhow!("vocab_size must be greater than 0"));
248        }
249
250        // Validate tensor shapes make sense
251        for (component_name, component) in &self.components {
252            for (tensor_name, tensor) in &component.inputs {
253                if tensor.shape.is_empty() {
254                    return Err(anyhow::anyhow!(
255                        "Empty shape for {}.inputs.{}",
256                        component_name,
257                        tensor_name
258                    ));
259                }
260            }
261            for (tensor_name, tensor) in &component.outputs {
262                if tensor.shape.is_empty() {
263                    return Err(anyhow::anyhow!(
264                        "Empty shape for {}.outputs.{}",
265                        component_name,
266                        tensor_name
267                    ));
268                }
269            }
270        }
271
272        Ok(())
273    }
274
275    /// Validate internal wiring between components for basic shape compatibility.
276    /// Examples:
277    ///  - embeddings.outputs.hidden_states == ffn_prefill.inputs.hidden_states
278    ///  - ffn_infer.outputs.output_hidden_states == lm_head.inputs.hidden_states (when ffn_infer exists)
279    pub fn validate_internal_wiring(&self) -> Result<()> {
280        // Embeddings -> FFN prefill hidden_states flow
281        if let (Some(emb_out), Some(ffn_in_hidden)) = (
282            self.get_tensor_shape("embeddings", "hidden_states", false),
283            self.get_tensor_shape("ffn_prefill", "hidden_states", true),
284        ) {
285            if emb_out != ffn_in_hidden {
286                return Err(anyhow::anyhow!(
287                    "Shape mismatch: embeddings.hidden_states {:?} != ffn_prefill.hidden_states {:?}",
288                    emb_out, ffn_in_hidden
289                ));
290            }
291        }
292
293        // FFN infer -> LM head hidden_states flow
294        if self.components.contains_key("ffn_infer") {
295            if let (Some(ffn_out), Some(lm_in)) = (
296                self.get_tensor_shape("ffn_infer", "output_hidden_states", false),
297                self.get_tensor_shape("lm_head", "hidden_states", true),
298            ) {
299                if ffn_out != lm_in {
300                    return Err(anyhow::anyhow!(
301                        "Shape mismatch: ffn_infer.output_hidden_states {:?} != lm_head.hidden_states {:?}",
302                        ffn_out, lm_in
303                    ));
304                }
305            }
306        } else {
307            // If there's no separate ffn_infer, check ffn_prefill output shape matches lm_head (single-token path)
308            if let (Some(ffn_out), Some(lm_in)) = (
309                self.get_tensor_shape("ffn_prefill", "output_hidden_states", false),
310                self.get_tensor_shape("lm_head", "hidden_states", true),
311            ) {
312                if ffn_out != lm_in {
313                    return Err(anyhow::anyhow!(
314                        "Shape mismatch: ffn_prefill.output_hidden_states {:?} != lm_head.hidden_states {:?}",
315                        ffn_out, lm_in
316                    ));
317                }
318            }
319        }
320
321        Ok(())
322    }
323
324    /// Determine if FFN execution should be treated as split (separate infer component)
325    pub fn ffn_is_split(&self) -> bool {
326        if let Some(mode) = self.ffn_execution.as_deref() {
327            return mode == "split";
328        }
329        if let (Some(prefill), Some(infer)) = (
330            self.components.get("ffn_prefill"),
331            self.components.get("ffn_infer"),
332        ) {
333            match (&prefill.file_path, &infer.file_path) {
334                (Some(p), Some(i)) => p != i, // different files => split
335                _ => false,
336            }
337        } else {
338            false
339        }
340    }
341
342    /// Detect if prefill should run in single-token sequential mode based on configured shapes
343    pub fn prefill_is_single_token(&self) -> bool {
344        if let Some(prefill) = self.components.get("ffn_prefill") {
345            if let Some(hs) = prefill.inputs.get("hidden_states") {
346                return hs.shape.len() == 3 && hs.shape.get(1) == Some(&1);
347            }
348        }
349        false
350    }
351}
352
impl Default for ModelConfig {
    /// Defaults to the built-in Qwen shape configuration (`default_qwen`).
    fn default() -> Self {
        Self::default_qwen()
    }
}
358
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Build a single-entry tensor map keyed by the tensor's own name.
    fn tensor_map(name: &str, shape: &[usize], data_type: &str) -> HashMap<String, TensorConfig> {
        let mut map = HashMap::new();
        map.insert(
            name.to_string(),
            TensorConfig {
                name: name.to_string(),
                shape: shape.to_vec(),
                data_type: data_type.to_string(),
            },
        );
        map
    }

    /// Synthetic Qwen-style config with just embeddings + lm_head components.
    fn create_test_config() -> ModelConfig {
        let mut components = HashMap::new();

        components.insert(
            "embeddings".to_string(),
            ComponentConfig {
                file_path: None,
                inputs: tensor_map("input_ids", &[1, 64], "INT32"),
                outputs: tensor_map("hidden_states", &[1, 64, 1024], "FLOAT16"),
                functions: vec![],
                input_order: None,
            },
        );

        components.insert(
            "lm_head".to_string(),
            ComponentConfig {
                file_path: None,
                inputs: tensor_map("hidden_states", &[1, 1, 1024], "FLOAT16"),
                outputs: tensor_map("logits", &[1, 1, 151936], "FLOAT32"),
                functions: vec![],
                input_order: None,
            },
        );

        ModelConfig {
            model_info: ModelInfo {
                model_id: Some("test/model".to_string()),
                path: Some("/test/path".to_string()),
                model_type: "qwen".to_string(),
                discovered_at: Some("2025-08-07T00:00:00".to_string()),
            },
            shapes: ShapeConfig {
                batch_size: 1,
                context_length: 512,
                hidden_size: 1024,
                vocab_size: 151936,
            },
            components,
            naming: NamingConfig {
                embeddings_pattern: None,
                ffn_prefill_pattern: None,
                ffn_infer_pattern: None,
                lm_head_pattern: None,
            },
            ffn_execution: Some("unified".to_string()),
        }
    }

    #[test]
    fn test_config_serialization() {
        let original = create_test_config();

        // JSON output should mention key identifiers and fields.
        let json = serde_json::to_string_pretty(&original).unwrap();
        for needle in ["test/model", "batch_size", "embeddings"] {
            assert!(json.contains(needle));
        }

        // Round-trip: deserializing the JSON yields an equivalent config.
        let roundtripped: ModelConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(roundtripped.model_info.model_id, original.model_info.model_id);
        assert_eq!(roundtripped.shapes.batch_size, original.shapes.batch_size);
        assert_eq!(roundtripped.components.len(), original.components.len());
    }

    #[test]
    fn test_config_file_io() {
        let config = create_test_config();
        let tmp = NamedTempFile::new().unwrap();

        // Write to disk, then read it back.
        config.save_to_file(tmp.path()).unwrap();
        let reloaded = ModelConfig::load_from_file(tmp.path()).unwrap();

        assert_eq!(reloaded.model_info.model_id, config.model_info.model_id);
        assert_eq!(reloaded.shapes.hidden_size, config.shapes.hidden_size);
    }

    #[test]
    fn test_shape_accessors() {
        let config = create_test_config();

        // Convenience accessors resolve the shapes configured by the helper.
        assert_eq!(config.embeddings_input_shape().unwrap(), &vec![1, 64]);
        assert_eq!(config.embeddings_output_shape().unwrap(), &vec![1, 64, 1024]);
        assert_eq!(config.lm_head_input_shape().unwrap(), &vec![1, 1, 1024]);
    }

    #[test]
    fn test_multipart_logits_detection() {
        let mut config = create_test_config();
        // The base config has a single plain "logits" output.
        assert!(!config.has_multipart_logits());

        // Replace it with numbered multipart outputs.
        let lm_head = config.components.get_mut("lm_head").unwrap();
        lm_head.outputs.clear();
        for (part, width) in [("logits1", 9480usize), ("logits2", 9479)] {
            lm_head.outputs.insert(
                part.to_string(),
                TensorConfig {
                    name: part.to_string(),
                    shape: vec![1, 1, width],
                    data_type: "FLOAT32".to_string(),
                },
            );
        }

        assert!(config.has_multipart_logits());
        assert_eq!(config.logits_part_count(), 2);
    }

    #[test]
    fn test_config_validation() {
        let config = create_test_config();
        assert!(config.validate().is_ok());

        // Internal wiring should be consistent in this synthetic setup.
        assert!(config.validate_internal_wiring().is_ok());

        // Dropping a required component must fail validation.
        let mut missing_component = config.clone();
        missing_component.components.remove("embeddings");
        assert!(missing_component.validate().is_err());

        // A zero-valued shape parameter must fail validation.
        let mut zero_batch = config;
        zero_batch.shapes.batch_size = 0;
        assert!(zero_batch.validate().is_err());
    }

    #[test]
    fn test_default_config() {
        let config = ModelConfig::default();
        assert_eq!(config.model_info.model_type, "qwen");
        assert_eq!(
            (
                config.shapes.batch_size,
                config.shapes.context_length,
                config.shapes.hidden_size,
                config.shapes.vocab_size,
            ),
            (1, 512, 1024, 151936)
        );
    }
}