candle_coreml/config/generator/
mod.rs

1//! Modular configuration generator for CoreML models
2//!
3//! This module provides automatic configuration generation from CoreML .mlpackage files
4//! with a clean, modular architecture that's truly model-agnostic.
5
use std::collections::HashMap;
use std::path::{Path, PathBuf};

use anyhow::Result;
use tracing::{debug, info, warn};

use crate::cache::manager::CacheManager;
use crate::config::model::{ComponentConfig, ModelConfig, NamingConfig};
12
13pub mod caching;
14pub mod coreml_metadata;
15pub mod file_discovery;
16pub mod manifest_parser;
17pub mod schema_extractor;
18pub mod shape_inference;
19
20use caching::ConfigCaching;
21use file_discovery::FileDiscovery;
22use manifest_parser::ManifestParser;
23use schema_extractor::SchemaExtractor;
24use shape_inference::ShapeInference;
25
26// Re-export ComponentRole for external use
27pub use schema_extractor::ComponentRole;
28// Re-export CoreML metadata extractor for testing
29pub use coreml_metadata::CoreMLMetadataExtractor;
30
31/// Modular configuration generator for auto-detecting model parameters
32pub struct ConfigGenerator {
33    file_discovery: FileDiscovery,
34    manifest_parser: ManifestParser,
35    schema_extractor: SchemaExtractor,
36    shape_inference: ShapeInference,
37    caching: ConfigCaching,
38    #[allow(dead_code)]
39    metadata_extractor: CoreMLMetadataExtractor,
40}
41
42impl ConfigGenerator {
43    /// Create a new config generator with all modules initialized
44    pub fn new() -> Result<Self> {
45        let cache_manager = CacheManager::new()?;
46        let caching = ConfigCaching::new(cache_manager);
47
48        Ok(Self {
49            file_discovery: FileDiscovery::new(),
50            manifest_parser: ManifestParser::new(),
51            schema_extractor: SchemaExtractor::new(),
52            shape_inference: ShapeInference::new(),
53            caching,
54            metadata_extractor: coreml_metadata::CoreMLMetadataExtractor::new(),
55        })
56    }
57
58    /// Generate a config from a downloaded model directory using enhanced metadata-driven detection
59    ///
60    /// This function inspects .mlpackage files in a directory and generates a complete
61    /// ModelConfig with proper shapes and component configurations, using metadata-driven
62    /// component role detection to support both unified and split FFN architectures.
63    pub fn generate_config_from_directory_enhanced(
64        &self,
65        model_dir: &Path,
66        model_id: &str,
67        model_type: &str,
68    ) -> Result<ModelConfig> {
69        info!("🔍 Generating config (enhanced) for model: {}", model_id);
70        debug!("   Model directory: {}", model_dir.display());
71        debug!("   Model type: {}", model_type);
72
73        // Validate and discover CoreML packages
74        self.file_discovery.validate_model_directory(model_dir)?;
75        let packages = self.file_discovery.find_coreml_packages(model_dir)?;
76
77        info!("📦 Found {} CoreML model files", packages.len());
78        for package in &packages {
79            debug!(
80                "   • {}",
81                package.file_name().unwrap_or_default().to_string_lossy()
82            );
83        }
84
85        // Analyze and parse each package using metadata-driven detection
86        let mut components = HashMap::new();
87        for package_path in &packages {
88            self.process_package_with_metadata_detection(package_path, &mut components)?;
89        }
90
91        // Check for required components
92        self.validate_required_components(&components)?;
93
94        // Generate final configuration with enhanced shape inference
95        let config = self
96            .build_model_config_enhanced(model_id, model_type, model_dir, components, &packages)?;
97
98        info!(
99            "✅ Generated enhanced config for {} with {} components",
100            model_id,
101            config.components.len()
102        );
103
104        // Cache the generated config
105        self.caching.cache_config(model_id, &config)?;
106
107        Ok(config)
108    }
109
110    /// Generate a config from a downloaded model directory (legacy method)
111    ///
112    /// This function inspects .mlpackage files in a directory and generates
113    /// a complete ModelConfig with proper shapes and component configurations.
114    pub fn generate_config_from_directory(
115        &self,
116        model_dir: &Path,
117        model_id: &str,
118        model_type: &str,
119    ) -> Result<ModelConfig> {
120        info!("🔍 Generating config for model: {}", model_id);
121        debug!("   Model directory: {}", model_dir.display());
122        debug!("   Model type: {}", model_type);
123
124        // Validate and discover CoreML packages
125        self.file_discovery.validate_model_directory(model_dir)?;
126        let packages = self.file_discovery.find_coreml_packages(model_dir)?;
127
128        info!("📦 Found {} CoreML model files", packages.len());
129        for package in &packages {
130            debug!(
131                "   • {}",
132                package.file_name().unwrap_or_default().to_string_lossy()
133            );
134        }
135
136        // Analyze and parse each package
137        let mut components = HashMap::new();
138        for package_path in &packages {
139            self.process_package(package_path, &mut components)?;
140        }
141
142        // Generate final configuration
143        let config =
144            self.build_model_config(model_id, model_type, model_dir, components, &packages)?;
145
146        info!(
147            "✅ Generated config for {} with {} components",
148            model_id,
149            config.components.len()
150        );
151
152        // Cache the generated config
153        self.caching.cache_config(model_id, &config)?;
154
155        Ok(config)
156    }
157
158    /// Load a cached configuration if available
159    pub fn load_cached_config(&self, model_id: &str) -> Result<Option<ModelConfig>> {
160        self.caching.load_cached_config(model_id)
161    }
162
163    /// Check if a cached configuration exists
164    pub fn has_cached_config(&self, model_id: &str) -> bool {
165        self.caching.has_cached_config(model_id)
166    }
167
168    /// Clear cached configuration for a model
169    pub fn clear_cached_config(&self, model_id: &str) -> Result<()> {
170        self.caching.clear_cached_config(model_id)
171    }
172
173    // Private implementation methods
174
175    fn process_package(
176        &self,
177        package_path: &Path,
178        components: &mut HashMap<String, ComponentConfig>,
179    ) -> Result<()> {
180        let manifest = self.file_discovery.read_manifest(package_path)?;
181        let base_component_name = self.file_discovery.infer_component_name(package_path);
182
183        // Parse package into component configurations
184        let parsed_components =
185            self.manifest_parser
186                .parse_package(package_path, &manifest, &base_component_name)?;
187
188        // Add all components to the collection
189        for (name, config) in parsed_components {
190            debug!(
191                "📋 Component '{}': inputs={:?} outputs={:?}",
192                name,
193                config.inputs.keys().collect::<Vec<_>>(),
194                config.outputs.keys().collect::<Vec<_>>()
195            );
196            components.insert(name, config);
197        }
198
199        Ok(())
200    }
201
202    fn process_package_with_metadata_detection(
203        &self,
204        package_path: &Path,
205        components: &mut HashMap<String, ComponentConfig>,
206    ) -> Result<()> {
207        // Get manifest source to determine parsing strategy
208        let manifest_source = self.file_discovery.find_manifest_source(package_path)?;
209        let manifest = self.file_discovery.read_manifest(package_path)?;
210
211        debug!("🔍 Processing package with source: {:?}", manifest_source);
212
213        // Parse package using enhanced method that handles all source types
214        let parsed_components = self.manifest_parser.parse_package_enhanced(
215            package_path,
216            &manifest_source,
217            &manifest,
218            &self.schema_extractor,
219        )?;
220
221        // Add all components to the collection
222        for (name, config) in parsed_components {
223            debug!(
224                "📋 Enhanced component '{}': inputs={:?} outputs={:?}",
225                name,
226                config.inputs.keys().collect::<Vec<_>>(),
227                config.outputs.keys().collect::<Vec<_>>()
228            );
229            components.insert(name, config);
230        }
231
232        Ok(())
233    }
234
235    fn validate_required_components(
236        &self,
237        components: &HashMap<String, ComponentConfig>,
238    ) -> Result<()> {
239        // During incremental parsing (tests creating one package at a time),
240        // allow partial configurations. Full-pipeline validation is enforced
241        // later when all components are present in the directory.
242
243        // Require at least one component with non-empty tensors to proceed.
244        if components.is_empty() {
245            return Err(anyhow::Error::msg("No components discovered"));
246        }
247
248        info!("✅ Component presence check passed (partial allowed in enhanced mode)");
249        Ok(())
250    }
251
252    fn build_model_config_enhanced(
253        &self,
254        model_id: &str,
255        model_type: &str,
256        model_dir: &Path,
257        components: HashMap<String, ComponentConfig>,
258        packages: &[PathBuf],
259    ) -> Result<ModelConfig> {
260        // If multiple packages are present (pipeline), ensure we end up with the core components
261        let is_pipeline = packages.len() >= 3; // embeddings, ffn*, lm_head
262
263        // Compute shape configuration using enhanced inference (fails fast on empty tensor metadata)
264        let shape_config = self
265            .shape_inference
266            .infer_shapes_with_schema_extractor(&components, &self.schema_extractor)?;
267
268        // Generate naming patterns (generic approach)
269        let naming_config = self.generate_naming_config(packages);
270
271        // Determine execution mode using enhanced detection
272        let component_list: Vec<(String, ComponentConfig)> = components.into_iter().collect();
273        let ffn_execution = self.manifest_parser.infer_execution_mode(&component_list);
274        info!("🔧 Detected execution mode: {}", ffn_execution);
275
276        let final_components: HashMap<String, ComponentConfig> =
277            component_list.into_iter().collect();
278
279        let model = ModelConfig {
280            model_info: crate::config::model::ModelInfo {
281                model_id: Some(model_id.to_string()),
282                path: Some(model_dir.to_string_lossy().to_string()),
283                model_type: model_type.to_string(),
284                discovered_at: Some(chrono::Utc::now().to_rfc3339()),
285            },
286            shapes: shape_config,
287            components: final_components,
288            naming: naming_config,
289            ffn_execution: Some(ffn_execution),
290        };
291
292        // For pipeline runs, require the minimal set of components
293        if is_pipeline {
294            let required = ["embeddings", "lm_head"];
295            let missing: Vec<_> = required
296                .iter()
297                .filter(|c| !model.components.contains_key(**c))
298                .collect();
299            if !missing.is_empty() {
300                return Err(anyhow::anyhow!(
301                    "ModelConfig missing required components in pipeline: {:?}. Found: {:?}",
302                    missing,
303                    model.components.keys().collect::<Vec<_>>()
304                ));
305            }
306        }
307
308        Ok(model)
309    }
310
311    fn build_model_config(
312        &self,
313        model_id: &str,
314        model_type: &str,
315        model_dir: &Path,
316        components: HashMap<String, ComponentConfig>,
317        packages: &[PathBuf],
318    ) -> Result<ModelConfig> {
319        // Compute shape configuration from discovered components (fails fast on empty tensor metadata)
320        let shape_config = self.shape_inference.infer_shapes(&components)?;
321
322        // Generate naming patterns (generic approach - mostly empty for truly generic models)
323        let naming_config = self.generate_naming_config(packages);
324
325        // Determine execution mode
326        let component_list: Vec<(String, ComponentConfig)> = components.into_iter().collect();
327        let ffn_execution = self.manifest_parser.infer_execution_mode(&component_list);
328        info!("🔧 Detected execution mode: {}", ffn_execution);
329
330        let final_components: HashMap<String, ComponentConfig> =
331            component_list.into_iter().collect();
332
333        Ok(ModelConfig {
334            model_info: crate::config::model::ModelInfo {
335                model_id: Some(model_id.to_string()),
336                path: Some(model_dir.to_string_lossy().to_string()),
337                model_type: model_type.to_string(),
338                discovered_at: Some(chrono::Utc::now().to_rfc3339()),
339            },
340            shapes: shape_config,
341            components: final_components,
342            naming: naming_config,
343            ffn_execution: Some(ffn_execution),
344        })
345    }
346
347    fn generate_naming_config(&self, _packages: &[PathBuf]) -> NamingConfig {
348        // For truly generic models, we don't assume specific naming patterns
349        // Just return empty patterns since we're being model-agnostic
350        NamingConfig {
351            embeddings_pattern: None,
352            ffn_infer_pattern: None,
353            ffn_prefill_pattern: None,
354            lm_head_pattern: None,
355        }
356    }
357}
358
359// Re-export the old interface for backward compatibility
360impl ConfigGenerator {
361    /// Find all .mlpackage files in a directory (legacy interface)
362    pub fn find_mlpackage_files(&self, model_dir: &Path) -> Result<Vec<PathBuf>> {
363        self.file_discovery.find_coreml_packages(model_dir)
364    }
365
366    /// Infer component name from package filename (legacy interface)
367    pub fn infer_component_name_from_file(&self, package_path: &Path) -> String {
368        self.file_discovery.infer_component_name(package_path)
369    }
370
371    /// Analyze a single .mlpackage file (legacy interface)
372    pub fn analyze_mlpackage(&self, package_path: &Path) -> Result<ComponentConfig> {
373        let manifest = self.file_discovery.read_manifest(package_path)?;
374        let inputs = self.schema_extractor.extract_inputs(&manifest)?;
375        let outputs = self.schema_extractor.extract_outputs(&manifest)?;
376
377        Ok(ComponentConfig {
378            file_path: Some(package_path.to_string_lossy().to_string()),
379            inputs,
380            outputs,
381            functions: Vec::new(),
382            input_order: None,
383        })
384    }
385
386    /// Extract function-based components (legacy interface)
387    pub fn extract_function_based_components(
388        &self,
389        package_path: &Path,
390        _base_config: &ComponentConfig,
391    ) -> Result<Option<HashMap<String, ComponentConfig>>> {
392        let manifest = self.file_discovery.read_manifest(package_path)?;
393        let base_component_name = self.file_discovery.infer_component_name(package_path);
394
395        let parsed_components =
396            self.manifest_parser
397                .parse_package(package_path, &manifest, &base_component_name)?;
398
399        if parsed_components.len() > 1 {
400            // Has multiple function-based components
401            let function_components: HashMap<String, ComponentConfig> =
402                parsed_components.into_iter().collect();
403            Ok(Some(function_components))
404        } else {
405            // Single component, no functions
406            Ok(None)
407        }
408    }
409
410    /// Parse tensor configurations from schema (legacy interface)
411    pub fn parse_tensor_configs_from_schema(
412        &self,
413        schema: &[serde_json::Value],
414    ) -> Result<HashMap<String, crate::model_config::TensorConfig>> {
415        self.schema_extractor.parse_tensor_configs(schema)
416    }
417
418    /// Compute shape info (legacy interface)
419    /// Returns an error if components have insufficient tensor metadata
420    pub fn compute_shape_info_generic(
421        &self,
422        components: &HashMap<String, ComponentConfig>,
423    ) -> Result<crate::model_config::ShapeConfig> {
424        self.shape_inference.infer_shapes(components)
425    }
426
427    /// Generate naming config (legacy interface)  
428    pub fn generate_naming_config_generic(&self, packages: &[PathBuf]) -> NamingConfig {
429        self.generate_naming_config(packages)
430    }
431
432    /// Determine execution mode (legacy interface)
433    pub fn determine_execution_mode_generic(
434        &self,
435        components: &HashMap<String, ComponentConfig>,
436    ) -> String {
437        let component_list: Vec<(String, ComponentConfig)> = components
438            .iter()
439            .map(|(k, v)| (k.clone(), v.clone()))
440            .collect();
441        self.manifest_parser.infer_execution_mode(&component_list)
442    }
443
444    /// Cache generated config (legacy interface)
445    pub fn cache_generated_config(&self, model_id: &str, config: &ModelConfig) -> Result<()> {
446        self.caching.cache_config(model_id, config)
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Create a mock `.mlpackage` directory named `{name}.mlpackage` under `temp_dir`.
    ///
    /// The package holds a minimal `Manifest.json` and a `model.mlmodel` file containing
    /// placeholder bytes (not a real CoreML model), so it exists on disk without carrying
    /// usable model metadata.
    fn create_mock_mlpackage(temp_dir: &Path, name: &str) -> Result<PathBuf> {
        let package_path = temp_dir.join(format!("{name}.mlpackage"));
        std::fs::create_dir_all(&package_path)?;

        // Minimal manifest referencing the single model file written below.
        let manifest = serde_json::json!({
            "fileFormatVersion": "1.0.0",
            "itemInfoEntries": [
                {
                    "path": "model.mlmodel",
                    "digestType": "SHA256"
                }
            ]
        });
        std::fs::write(
            package_path.join("Manifest.json"),
            serde_json::to_string_pretty(&manifest)?,
        )?;

        // Placeholder bytes only: enough for the file to exist, not a valid model.
        std::fs::write(package_path.join("model.mlmodel"), b"mock model data")?;

        Ok(package_path)
    }

    #[test]
    fn test_modular_config_generator_creation() -> Result<()> {
        let generator = ConfigGenerator::new()?;

        // With every module initialized, querying an unknown model must report
        // "not cached" rather than fail.
        assert!(!generator.caching.has_cached_config("nonexistent"));

        Ok(())
    }

    #[test]
    fn test_modular_file_discovery() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let generator = ConfigGenerator::new()?;

        let names = ["embeddings", "transformer", "head"];
        for &name in &names {
            create_mock_mlpackage(temp_dir.path(), name)?;
        }

        let packages = generator
            .file_discovery
            .find_coreml_packages(temp_dir.path())?;

        assert_eq!(packages.len(), names.len());
        for &name in &names {
            assert!(
                packages
                    .iter()
                    .any(|p| p.file_name().unwrap().to_string_lossy().contains(name)),
                "expected a package containing '{name}', got: {packages:?}"
            );
        }

        Ok(())
    }

    // Note: legacy generation on mock/placeholder packages intentionally fails under strict
    // validation. Tests that assumed permissive behavior were removed to keep assertions strict.

    #[test]
    fn test_enhanced_parsing_falls_back_to_filename_when_metadata_empty() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let generator = ConfigGenerator::new()?;

        // A package named "embeddings" whose model file holds only placeholder bytes,
        // simulating environments where metadata extraction yields empty inputs/outputs.
        let pkg_path = create_mock_mlpackage(temp_dir.path(), "embeddings")?;

        let mut components = HashMap::new();
        generator
            .process_package_with_metadata_detection(&pkg_path, &mut components)
            .expect("enhanced parsing should not fail");

        // The component should be classified as "embeddings" via the filename fallback.
        assert!(
            components.contains_key("embeddings"),
            "expected 'embeddings' component, got: {:?}",
            components.keys().collect::<Vec<_>>()
        );
        Ok(())
    }
}