candle_coreml/config/generator/
shape_inference.rs

1//! Generic shape inference from CoreML components
2//!
3//! Infers model dimensions (batch size, hidden size, etc.) from actual tensor configurations
4
5use super::schema_extractor::SchemaExtractor;
6use crate::config::model::{ComponentConfig, ShapeConfig};
7use anyhow::{anyhow, Result};
8use std::collections::HashMap;
9use tracing::debug;
10
/// Stateless helper that infers model dimensions (batch size, context length,
/// hidden size, vocabulary size) from the tensor shapes declared by a model's
/// component configurations.
pub struct ShapeInference;
12
13impl Default for ShapeInference {
14    fn default() -> Self {
15        Self::new()
16    }
17}
18
19impl ShapeInference {
20    pub fn new() -> Self {
21        Self
22    }
23
24    /// Compute overall shape configuration from components (enhanced with metadata-driven detection)
25    /// Returns an error if components have insufficient tensor metadata
26    pub fn infer_shapes_with_schema_extractor(
27        &self,
28        components: &HashMap<String, ComponentConfig>,
29        schema_extractor: &SchemaExtractor,
30    ) -> Result<ShapeConfig, anyhow::Error> {
31        // Strict validation: require sufficient tensor metadata
32        self.validate_tensor_metadata(components)?;
33
34        let batch_size = self.infer_batch_size(components);
35        let hidden_size = self.infer_hidden_size(components);
36        let context_length = self.infer_context_length(components);
37        let vocab_size = self.infer_vocab_size_with_chunking(components, schema_extractor);
38
39        debug!(
40            "📊 Inferred shapes (validated): batch={}, context={}, hidden={}, vocab={}",
41            batch_size, context_length, hidden_size, vocab_size
42        );
43
44        Ok(ShapeConfig {
45            batch_size,
46            context_length,
47            hidden_size,
48            vocab_size,
49        })
50    }
51
52    /// Compute overall shape configuration from components (legacy approach)
53    /// Returns an error if components have insufficient tensor metadata
54    pub fn infer_shapes(
55        &self,
56        components: &HashMap<String, ComponentConfig>,
57    ) -> Result<ShapeConfig, anyhow::Error> {
58        // Strict validation on legacy path as well
59        self.validate_tensor_metadata(components)?;
60
61        let batch_size = self.infer_batch_size(components);
62        let hidden_size = self.infer_hidden_size(components);
63        let context_length = self.infer_context_length(components);
64        let vocab_size = self.infer_vocab_size(components);
65
66        debug!(
67            "📊 Inferred shapes (validated): batch={}, context={}, hidden={}, vocab={}",
68            batch_size, context_length, hidden_size, vocab_size
69        );
70
71        Ok(ShapeConfig {
72            batch_size,
73            context_length,
74            hidden_size,
75            vocab_size,
76        })
77    }
78
79    /// Infer batch size from the smallest batch dimension across all components
80    fn infer_batch_size(&self, components: &HashMap<String, ComponentConfig>) -> usize {
81        let mut batch_sizes = Vec::new();
82
83        for component in components.values() {
84            for tensor in component.inputs.values().chain(component.outputs.values()) {
85                if !tensor.shape.is_empty() {
86                    batch_sizes.push(tensor.shape[0]);
87                }
88            }
89        }
90
91        batch_sizes.into_iter().min().unwrap_or(1)
92    }
93
94    /// Infer hidden size from hidden_states tensors, ignoring logits. Falls back to heuristic if needed.
95    fn infer_hidden_size(&self, components: &HashMap<String, ComponentConfig>) -> usize {
96        let mut from_hidden_states = Vec::new();
97        let mut heuristic = Vec::new();
98
99        for component in components.values() {
100            // Prefer tensors explicitly named "hidden_states"
101            for (name, tensor) in component.inputs.iter().chain(component.outputs.iter()) {
102                // Skip any logits-like tensors
103                let lname = name.to_lowercase();
104                let is_logits = lname.starts_with("logits");
105
106                if tensor.shape.len() >= 3 {
107                    let feat = tensor.shape[2];
108                    if name == "hidden_states" {
109                        from_hidden_states.push(feat);
110                    } else if !is_logits {
111                        heuristic.push(feat);
112                    }
113                } else if tensor.shape.len() == 2 {
114                    // Some manifests might flatten features as [1, hidden]
115                    let feat = tensor.shape[1];
116                    if name == "hidden_states" {
117                        from_hidden_states.push(feat);
118                    } else if !is_logits && feat > 100 {
119                        heuristic.push(feat);
120                    }
121                }
122            }
123        }
124
125        if let Some(max_hidden) = from_hidden_states.into_iter().max() {
126            return max_hidden;
127        }
128        heuristic.into_iter().max().unwrap_or(1024)
129    }
130
131    /// Infer context/sequence length from sequence dimensions
132    fn infer_context_length(&self, components: &HashMap<String, ComponentConfig>) -> usize {
133        let mut seq_lengths = Vec::new();
134
135        for component in components.values() {
136            for tensor in component.inputs.values().chain(component.outputs.values()) {
137                if tensor.shape.len() >= 2 && tensor.shape[1] > 1 {
138                    // 2D+ tensors: sequence dimension is usually index 1
139                    seq_lengths.push(tensor.shape[1]);
140                }
141                // Also check 4D tensors (e.g., attention masks)
142                if tensor.shape.len() >= 4 {
143                    seq_lengths.push(tensor.shape[3]);
144                }
145            }
146        }
147
148        seq_lengths.into_iter().max().unwrap_or(256)
149    }
150
151    /// Infer vocabulary size with chunked logits support (typo-fixer model pattern)
152    fn infer_vocab_size_with_chunking(
153        &self,
154        components: &HashMap<String, ComponentConfig>,
155        schema_extractor: &SchemaExtractor,
156    ) -> usize {
157        // First, try to find the lm_head component and use schema extractor's logic
158        for (name, component) in components {
159            if name == "lm_head" || name.contains("lm_head") {
160                if let Some(vocab_size) =
161                    schema_extractor.calculate_vocab_size_from_logits(&component.outputs)
162                {
163                    debug!(
164                        "📊 Using chunked logits vocab size calculation: {}",
165                        vocab_size
166                    );
167                    return vocab_size;
168                }
169            }
170        }
171
172        // Fallback to legacy logic
173        debug!("📊 Using legacy vocab size detection");
174        self.infer_vocab_size(components)
175    }
176
177    /// Infer vocabulary size from the largest output dimension (legacy)
178    fn infer_vocab_size(&self, components: &HashMap<String, ComponentConfig>) -> usize {
179        let mut output_sizes = Vec::new();
180
181        for component in components.values() {
182            for tensor in component.outputs.values() {
183                if let Some(&last_dim) = tensor.shape.last() {
184                    if last_dim > 1000 {
185                        // Likely a vocabulary or class dimension
186                        output_sizes.push(last_dim);
187                    }
188                }
189            }
190        }
191
192        output_sizes.into_iter().max().unwrap_or(30000)
193    }
194
195    /// Analyze component characteristics for debugging
196    pub fn analyze_components(
197        &self,
198        components: &HashMap<String, ComponentConfig>,
199    ) -> ComponentAnalysis {
200        let mut analysis = ComponentAnalysis::default();
201
202        for (name, component) in components {
203            let comp_analysis = self.analyze_single_component(name, component);
204            analysis.components.insert(name.clone(), comp_analysis);
205        }
206
207        analysis.total_components = components.len();
208        analysis.function_based_components = components
209            .values()
210            .filter(|c| !c.functions.is_empty())
211            .count();
212        analysis.multi_function_components = components
213            .values()
214            .filter(|c| c.functions.len() > 1)
215            .count();
216
217        analysis
218    }
219
220    fn analyze_single_component(
221        &self,
222        name: &str,
223        component: &ComponentConfig,
224    ) -> SingleComponentAnalysis {
225        let input_shapes: Vec<Vec<usize>> =
226            component.inputs.values().map(|t| t.shape.clone()).collect();
227        let output_shapes: Vec<Vec<usize>> = component
228            .outputs
229            .values()
230            .map(|t| t.shape.clone())
231            .collect();
232
233        let max_input_dim = input_shapes.iter().flatten().max().copied().unwrap_or(0);
234        let max_output_dim = output_shapes.iter().flatten().max().copied().unwrap_or(0);
235
236        SingleComponentAnalysis {
237            name: name.to_string(),
238            input_count: component.inputs.len(),
239            output_count: component.outputs.len(),
240            function_count: component.functions.len(),
241            input_shapes,
242            output_shapes,
243            max_input_dimension: max_input_dim,
244            max_output_dimension: max_output_dim,
245        }
246    }
247
248    /// Validate that components have sufficient tensor metadata for shape inference
249    /// Returns an error with actionable guidance if metadata is insufficient
250    fn validate_tensor_metadata(
251        &self,
252        components: &HashMap<String, ComponentConfig>,
253    ) -> Result<()> {
254        let mut empty_components = Vec::new();
255        let mut components_with_issues = Vec::new();
256
257        for (name, component) in components {
258            // Check for completely empty tensor maps
259            if component.inputs.is_empty() && component.outputs.is_empty() {
260                empty_components.push(name.clone());
261                continue;
262            }
263
264            // Check for components with tensors but no shape information
265            let has_valid_shapes = component
266                .inputs
267                .values()
268                .chain(component.outputs.values())
269                .any(|tensor| !tensor.shape.is_empty() && tensor.shape.iter().all(|&dim| dim > 0));
270
271            if !has_valid_shapes {
272                components_with_issues.push(name.clone());
273            }
274        }
275
276        // Fail fast with clear error messages
277        if !empty_components.is_empty() {
278            return Err(anyhow!(
279                "Configuration generation failed: Components have empty tensor metadata.\n\
280                 \n\
281                 🔍 Components with empty tensor maps: {:?}\n\
282                 \n\
283                 💡 This typically indicates one of these issues:\n\
284                    1. CoreML model files lack proper metadata (model.mlmodel missing or corrupt)\n\
285                    2. Model packages are incomplete (.mlpackage structure is invalid)\n\
286                    3. Metadata extraction failed during parsing\n\
287                 \n\
288                 🛠️  Possible solutions:\n\
289                    1. Re-download the model from the original source\n\
290                    2. Verify .mlpackage directory structure contains Data/com.apple.CoreML/model.mlmodel\n\
291                    3. Check model compatibility with this version of candle-coreml\n\
292                    4. For typo-fixer models: ensure using the correct coreml variant from HuggingFace\n\
293                 \n\
294                 📝 Expected tensor information:\n\
295                    - Embeddings: input_ids → hidden_states\n\
296                    - FFN: hidden_states + causal_mask → hidden_states\n\
297                    - LM Head: hidden_states → logits (potentially chunked)",
298                empty_components
299            ));
300        }
301
302        if !components_with_issues.is_empty() {
303            return Err(anyhow!(
304                "Configuration generation failed: Components have invalid tensor shape information.\n\
305                 \n\
306                 🔍 Components with shape issues: {:?}\n\
307                 \n\
308                 💡 These components have tensor information but with invalid shapes (empty or zero dimensions).\n\
309                 \n\
310                 🛠️  This suggests corrupted model metadata. Try re-downloading the model.",
311                components_with_issues
312            ));
313        }
314
315        // Additional validation for specific model patterns
316        self.validate_model_specific_requirements(components)?;
317
318        debug!(
319            "✅ Tensor metadata validation passed for {} components",
320            components.len()
321        );
322        Ok(())
323    }
324
325    /// Validate model-specific requirements (e.g., typo-fixer needs proper vocab size)
326    fn validate_model_specific_requirements(
327        &self,
328        components: &HashMap<String, ComponentConfig>,
329    ) -> Result<()> {
330        // Detect if this looks like a typo-fixer model based on filenames
331        let looks_like_typo_fixer = components
332            .keys()
333            .any(|name| name.contains("typo") || name.contains("fixer"))
334            || components.values().any(|comp| {
335                comp.file_path
336                    .as_ref()
337                    .map(|path| path.contains("typo-fixer"))
338                    .unwrap_or(false)
339            });
340
341        if looks_like_typo_fixer {
342            // For typo-fixer models, validate we can extract proper vocab size from LM head
343            let lm_head = components.get("lm_head");
344            if let Some(lm_head) = lm_head {
345                let has_logits = lm_head.outputs.keys().any(|k| k.contains("logits"));
346                let logits_total_size: usize = lm_head
347                    .outputs
348                    .iter()
349                    .filter(|(name, _)| name.contains("logits"))
350                    .map(|(_, tensor)| tensor.shape.last().copied().unwrap_or(0))
351                    .sum();
352
353                if !has_logits {
354                    return Err(anyhow!(
355                        "Typo-fixer model validation failed: LM head component lacks logits outputs.\n\
356                         \n\
357                         🔍 LM head outputs found: {:?}\n\
358                         \n\
359                         💡 Typo-fixer models require chunked logits outputs (logits_0, logits_1, etc.)\n\
360                         🛠️  Ensure you're using the correct typo-fixer coreml model variant.",
361                        lm_head.outputs.keys().collect::<Vec<_>>()
362                    ));
363                }
364
365                // Typo-fixer should have vocab size around 151,669
366                if logits_total_size > 0 && logits_total_size < 100000 {
367                    return Err(anyhow!(
368                        "Typo-fixer model validation failed: Vocabulary size {} is too small.\n\
369                         \n\
370                         💡 Expected vocab size ≥ 100,000 for typo-fixer (typically 151,669)\n\
371                         🔍 Detected logits total size: {}\n\
372                         \n\
373                         🛠️  This suggests model metadata extraction issues or wrong model variant.",
374                        logits_total_size, logits_total_size
375                    ));
376                }
377            }
378        }
379
380        Ok(())
381    }
382}
383
/// Aggregate debugging summary across all components, as produced by
/// `ShapeInference::analyze_components`.
#[derive(Debug, Default)]
pub struct ComponentAnalysis {
    /// Total number of components analyzed.
    pub total_components: usize,
    /// Number of components declaring at least one function.
    pub function_based_components: usize,
    /// Number of components declaring more than one function.
    pub multi_function_components: usize,
    /// Per-component analysis keyed by component name.
    pub components: HashMap<String, SingleComponentAnalysis>,
}
391
/// Debugging summary for a single component: tensor counts, raw shapes, and
/// the largest dimension seen on each side.
#[derive(Debug)]
pub struct SingleComponentAnalysis {
    /// Component name as it appears in the configuration map.
    pub name: String,
    /// Number of input tensors.
    pub input_count: usize,
    /// Number of output tensors.
    pub output_count: usize,
    /// Number of declared functions.
    pub function_count: usize,
    /// Raw shapes of all input tensors.
    pub input_shapes: Vec<Vec<usize>>,
    /// Raw shapes of all output tensors.
    pub output_shapes: Vec<Vec<usize>>,
    /// Largest single dimension across all input shapes (0 if none).
    pub max_input_dimension: usize,
    /// Largest single dimension across all output shapes (0 if none).
    pub max_output_dimension: usize,
}