1use crate::cache::manager::CacheManager;
7use crate::config::model::{ComponentConfig, ModelConfig, NamingConfig};
8use anyhow::Result;
9use std::collections::HashMap;
10use std::path::{Path, PathBuf};
11use tracing::{debug, info};
12
13pub mod caching;
14pub mod coreml_metadata;
15pub mod file_discovery;
16pub mod manifest_parser;
17pub mod schema_extractor;
18pub mod shape_inference;
19
20use caching::ConfigCaching;
21use file_discovery::FileDiscovery;
22use manifest_parser::ManifestParser;
23use schema_extractor::SchemaExtractor;
24use shape_inference::ShapeInference;
25
26pub use schema_extractor::ComponentRole;
28pub use coreml_metadata::CoreMLMetadataExtractor;
30
31pub struct ConfigGenerator {
33 file_discovery: FileDiscovery,
34 manifest_parser: ManifestParser,
35 schema_extractor: SchemaExtractor,
36 shape_inference: ShapeInference,
37 caching: ConfigCaching,
38 #[allow(dead_code)]
39 metadata_extractor: CoreMLMetadataExtractor,
40}
41
42impl ConfigGenerator {
43 pub fn new() -> Result<Self> {
45 let cache_manager = CacheManager::new()?;
46 let caching = ConfigCaching::new(cache_manager);
47
48 Ok(Self {
49 file_discovery: FileDiscovery::new(),
50 manifest_parser: ManifestParser::new(),
51 schema_extractor: SchemaExtractor::new(),
52 shape_inference: ShapeInference::new(),
53 caching,
54 metadata_extractor: coreml_metadata::CoreMLMetadataExtractor::new(),
55 })
56 }
57
58 pub fn generate_config_from_directory_enhanced(
64 &self,
65 model_dir: &Path,
66 model_id: &str,
67 model_type: &str,
68 ) -> Result<ModelConfig> {
69 info!("🔍 Generating config (enhanced) for model: {}", model_id);
70 debug!(" Model directory: {}", model_dir.display());
71 debug!(" Model type: {}", model_type);
72
73 self.file_discovery.validate_model_directory(model_dir)?;
75 let packages = self.file_discovery.find_coreml_packages(model_dir)?;
76
77 info!("📦 Found {} CoreML model files", packages.len());
78 for package in &packages {
79 debug!(
80 " • {}",
81 package.file_name().unwrap_or_default().to_string_lossy()
82 );
83 }
84
85 let mut components = HashMap::new();
87 for package_path in &packages {
88 self.process_package_with_metadata_detection(package_path, &mut components)?;
89 }
90
91 self.validate_required_components(&components)?;
93
94 let config = self
96 .build_model_config_enhanced(model_id, model_type, model_dir, components, &packages)?;
97
98 info!(
99 "✅ Generated enhanced config for {} with {} components",
100 model_id,
101 config.components.len()
102 );
103
104 self.caching.cache_config(model_id, &config)?;
106
107 Ok(config)
108 }
109
110 pub fn generate_config_from_directory(
115 &self,
116 model_dir: &Path,
117 model_id: &str,
118 model_type: &str,
119 ) -> Result<ModelConfig> {
120 info!("🔍 Generating config for model: {}", model_id);
121 debug!(" Model directory: {}", model_dir.display());
122 debug!(" Model type: {}", model_type);
123
124 self.file_discovery.validate_model_directory(model_dir)?;
126 let packages = self.file_discovery.find_coreml_packages(model_dir)?;
127
128 info!("📦 Found {} CoreML model files", packages.len());
129 for package in &packages {
130 debug!(
131 " • {}",
132 package.file_name().unwrap_or_default().to_string_lossy()
133 );
134 }
135
136 let mut components = HashMap::new();
138 for package_path in &packages {
139 self.process_package(package_path, &mut components)?;
140 }
141
142 let config =
144 self.build_model_config(model_id, model_type, model_dir, components, &packages)?;
145
146 info!(
147 "✅ Generated config for {} with {} components",
148 model_id,
149 config.components.len()
150 );
151
152 self.caching.cache_config(model_id, &config)?;
154
155 Ok(config)
156 }
157
158 pub fn load_cached_config(&self, model_id: &str) -> Result<Option<ModelConfig>> {
160 self.caching.load_cached_config(model_id)
161 }
162
163 pub fn has_cached_config(&self, model_id: &str) -> bool {
165 self.caching.has_cached_config(model_id)
166 }
167
168 pub fn clear_cached_config(&self, model_id: &str) -> Result<()> {
170 self.caching.clear_cached_config(model_id)
171 }
172
173 fn process_package(
176 &self,
177 package_path: &Path,
178 components: &mut HashMap<String, ComponentConfig>,
179 ) -> Result<()> {
180 let manifest = self.file_discovery.read_manifest(package_path)?;
181 let base_component_name = self.file_discovery.infer_component_name(package_path);
182
183 let parsed_components =
185 self.manifest_parser
186 .parse_package(package_path, &manifest, &base_component_name)?;
187
188 for (name, config) in parsed_components {
190 debug!(
191 "📋 Component '{}': inputs={:?} outputs={:?}",
192 name,
193 config.inputs.keys().collect::<Vec<_>>(),
194 config.outputs.keys().collect::<Vec<_>>()
195 );
196 components.insert(name, config);
197 }
198
199 Ok(())
200 }
201
202 fn process_package_with_metadata_detection(
203 &self,
204 package_path: &Path,
205 components: &mut HashMap<String, ComponentConfig>,
206 ) -> Result<()> {
207 let manifest_source = self.file_discovery.find_manifest_source(package_path)?;
209 let manifest = self.file_discovery.read_manifest(package_path)?;
210
211 debug!("🔍 Processing package with source: {:?}", manifest_source);
212
213 let parsed_components = self.manifest_parser.parse_package_enhanced(
215 package_path,
216 &manifest_source,
217 &manifest,
218 &self.schema_extractor,
219 )?;
220
221 for (name, config) in parsed_components {
223 debug!(
224 "📋 Enhanced component '{}': inputs={:?} outputs={:?}",
225 name,
226 config.inputs.keys().collect::<Vec<_>>(),
227 config.outputs.keys().collect::<Vec<_>>()
228 );
229 components.insert(name, config);
230 }
231
232 Ok(())
233 }
234
235 fn validate_required_components(
236 &self,
237 components: &HashMap<String, ComponentConfig>,
238 ) -> Result<()> {
239 if components.is_empty() {
245 return Err(anyhow::Error::msg("No components discovered"));
246 }
247
248 info!("✅ Component presence check passed (partial allowed in enhanced mode)");
249 Ok(())
250 }
251
252 fn build_model_config_enhanced(
253 &self,
254 model_id: &str,
255 model_type: &str,
256 model_dir: &Path,
257 components: HashMap<String, ComponentConfig>,
258 packages: &[PathBuf],
259 ) -> Result<ModelConfig> {
260 let is_pipeline = packages.len() >= 3; let shape_config = self
265 .shape_inference
266 .infer_shapes_with_schema_extractor(&components, &self.schema_extractor)?;
267
268 let naming_config = self.generate_naming_config(packages);
270
271 let component_list: Vec<(String, ComponentConfig)> = components.into_iter().collect();
273 let ffn_execution = self.manifest_parser.infer_execution_mode(&component_list);
274 info!("🔧 Detected execution mode: {}", ffn_execution);
275
276 let final_components: HashMap<String, ComponentConfig> =
277 component_list.into_iter().collect();
278
279 let model = ModelConfig {
280 model_info: crate::config::model::ModelInfo {
281 model_id: Some(model_id.to_string()),
282 path: Some(model_dir.to_string_lossy().to_string()),
283 model_type: model_type.to_string(),
284 discovered_at: Some(chrono::Utc::now().to_rfc3339()),
285 },
286 shapes: shape_config,
287 components: final_components,
288 naming: naming_config,
289 ffn_execution: Some(ffn_execution),
290 };
291
292 if is_pipeline {
294 let required = ["embeddings", "lm_head"];
295 let missing: Vec<_> = required
296 .iter()
297 .filter(|c| !model.components.contains_key(**c))
298 .collect();
299 if !missing.is_empty() {
300 return Err(anyhow::anyhow!(
301 "ModelConfig missing required components in pipeline: {:?}. Found: {:?}",
302 missing,
303 model.components.keys().collect::<Vec<_>>()
304 ));
305 }
306 }
307
308 Ok(model)
309 }
310
311 fn build_model_config(
312 &self,
313 model_id: &str,
314 model_type: &str,
315 model_dir: &Path,
316 components: HashMap<String, ComponentConfig>,
317 packages: &[PathBuf],
318 ) -> Result<ModelConfig> {
319 let shape_config = self.shape_inference.infer_shapes(&components)?;
321
322 let naming_config = self.generate_naming_config(packages);
324
325 let component_list: Vec<(String, ComponentConfig)> = components.into_iter().collect();
327 let ffn_execution = self.manifest_parser.infer_execution_mode(&component_list);
328 info!("🔧 Detected execution mode: {}", ffn_execution);
329
330 let final_components: HashMap<String, ComponentConfig> =
331 component_list.into_iter().collect();
332
333 Ok(ModelConfig {
334 model_info: crate::config::model::ModelInfo {
335 model_id: Some(model_id.to_string()),
336 path: Some(model_dir.to_string_lossy().to_string()),
337 model_type: model_type.to_string(),
338 discovered_at: Some(chrono::Utc::now().to_rfc3339()),
339 },
340 shapes: shape_config,
341 components: final_components,
342 naming: naming_config,
343 ffn_execution: Some(ffn_execution),
344 })
345 }
346
347 fn generate_naming_config(&self, _packages: &[PathBuf]) -> NamingConfig {
348 NamingConfig {
351 embeddings_pattern: None,
352 ffn_infer_pattern: None,
353 ffn_prefill_pattern: None,
354 lm_head_pattern: None,
355 }
356 }
357}
358
359impl ConfigGenerator {
361 pub fn find_mlpackage_files(&self, model_dir: &Path) -> Result<Vec<PathBuf>> {
363 self.file_discovery.find_coreml_packages(model_dir)
364 }
365
366 pub fn infer_component_name_from_file(&self, package_path: &Path) -> String {
368 self.file_discovery.infer_component_name(package_path)
369 }
370
371 pub fn analyze_mlpackage(&self, package_path: &Path) -> Result<ComponentConfig> {
373 let manifest = self.file_discovery.read_manifest(package_path)?;
374 let inputs = self.schema_extractor.extract_inputs(&manifest)?;
375 let outputs = self.schema_extractor.extract_outputs(&manifest)?;
376
377 Ok(ComponentConfig {
378 file_path: Some(package_path.to_string_lossy().to_string()),
379 inputs,
380 outputs,
381 functions: Vec::new(),
382 input_order: None,
383 })
384 }
385
386 pub fn extract_function_based_components(
388 &self,
389 package_path: &Path,
390 _base_config: &ComponentConfig,
391 ) -> Result<Option<HashMap<String, ComponentConfig>>> {
392 let manifest = self.file_discovery.read_manifest(package_path)?;
393 let base_component_name = self.file_discovery.infer_component_name(package_path);
394
395 let parsed_components =
396 self.manifest_parser
397 .parse_package(package_path, &manifest, &base_component_name)?;
398
399 if parsed_components.len() > 1 {
400 let function_components: HashMap<String, ComponentConfig> =
402 parsed_components.into_iter().collect();
403 Ok(Some(function_components))
404 } else {
405 Ok(None)
407 }
408 }
409
410 pub fn parse_tensor_configs_from_schema(
412 &self,
413 schema: &[serde_json::Value],
414 ) -> Result<HashMap<String, crate::model_config::TensorConfig>> {
415 self.schema_extractor.parse_tensor_configs(schema)
416 }
417
418 pub fn compute_shape_info_generic(
421 &self,
422 components: &HashMap<String, ComponentConfig>,
423 ) -> Result<crate::model_config::ShapeConfig> {
424 self.shape_inference.infer_shapes(components)
425 }
426
427 pub fn generate_naming_config_generic(&self, packages: &[PathBuf]) -> NamingConfig {
429 self.generate_naming_config(packages)
430 }
431
432 pub fn determine_execution_mode_generic(
434 &self,
435 components: &HashMap<String, ComponentConfig>,
436 ) -> String {
437 let component_list: Vec<(String, ComponentConfig)> = components
438 .iter()
439 .map(|(k, v)| (k.clone(), v.clone()))
440 .collect();
441 self.manifest_parser.infer_execution_mode(&component_list)
442 }
443
444 pub fn cache_generated_config(&self, model_id: &str, config: &ModelConfig) -> Result<()> {
446 self.caching.cache_config(model_id, config)
447 }
448}
449
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates `<temp_dir>/<name>.mlpackage` containing a minimal
    /// `Manifest.json` and a placeholder `model.mlmodel`, and returns the
    /// package directory.
    fn create_mock_mlpackage(temp_dir: &Path, name: &str) -> Result<PathBuf> {
        let package_dir = temp_dir.join(format!("{name}.mlpackage"));
        std::fs::create_dir_all(&package_dir)?;

        let manifest_json = serde_json::to_string_pretty(&serde_json::json!({
            "fileFormatVersion": "1.0.0",
            "itemInfoEntries": [
                {
                    "path": "model.mlmodel",
                    "digestType": "SHA256"
                }
            ]
        }))?;
        std::fs::write(package_dir.join("Manifest.json"), manifest_json)?;

        // The payload is never parsed as a real model; it only has to exist.
        std::fs::write(package_dir.join("model.mlmodel"), b"mock model data")?;

        Ok(package_dir)
    }

    /// A new generator can be built and reports no cache entry for an unknown id.
    #[test]
    fn test_modular_config_generator_creation() -> Result<()> {
        let generator = ConfigGenerator::new()?;

        let cached = generator.caching.has_cached_config("nonexistent");
        assert!(!cached);

        Ok(())
    }

    /// File discovery finds every mock package placed in the directory.
    #[test]
    fn test_modular_file_discovery() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let generator = ConfigGenerator::new()?;

        for name in ["embeddings", "transformer", "head"] {
            create_mock_mlpackage(temp_dir.path(), name)?;
        }

        let packages = generator
            .file_discovery
            .find_coreml_packages(temp_dir.path())?;

        assert_eq!(packages.len(), 3);

        // True when some discovered package's file name contains `needle`.
        let has_package = |needle: &str| {
            packages
                .iter()
                .any(|p| p.file_name().unwrap().to_string_lossy().contains(needle))
        };
        assert!(has_package("embeddings"));
        assert!(has_package("transformer"));
        assert!(has_package("head"));

        Ok(())
    }

    /// With a manifest that carries no component metadata, enhanced parsing
    /// still yields a component named after the package file.
    #[test]
    fn test_enhanced_parsing_falls_back_to_filename_when_metadata_empty() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let generator = ConfigGenerator::new()?;

        let pkg_path = create_mock_mlpackage(temp_dir.path(), "embeddings")?;

        let mut components = HashMap::new();
        generator
            .process_package_with_metadata_detection(&pkg_path, &mut components)
            .expect("enhanced parsing should not fail");

        let found = components.contains_key("embeddings");
        assert!(
            found,
            "expected 'embeddings' component, got: {:?}",
            components.keys().collect::<Vec<_>>()
        );
        Ok(())
    }
}