candle_coreml/config/generator/
shape_inference.rs1use super::schema_extractor::SchemaExtractor;
6use crate::config::model::{ComponentConfig, ShapeConfig};
7use anyhow::{anyhow, Result};
8use std::collections::HashMap;
9use tracing::debug;
10
11pub struct ShapeInference;
12
13impl Default for ShapeInference {
14 fn default() -> Self {
15 Self::new()
16 }
17}
18
19impl ShapeInference {
20 pub fn new() -> Self {
21 Self
22 }
23
24 pub fn infer_shapes_with_schema_extractor(
27 &self,
28 components: &HashMap<String, ComponentConfig>,
29 schema_extractor: &SchemaExtractor,
30 ) -> Result<ShapeConfig, anyhow::Error> {
31 self.validate_tensor_metadata(components)?;
33
34 let batch_size = self.infer_batch_size(components);
35 let hidden_size = self.infer_hidden_size(components);
36 let context_length = self.infer_context_length(components);
37 let vocab_size = self.infer_vocab_size_with_chunking(components, schema_extractor);
38
39 debug!(
40 "📊 Inferred shapes (validated): batch={}, context={}, hidden={}, vocab={}",
41 batch_size, context_length, hidden_size, vocab_size
42 );
43
44 Ok(ShapeConfig {
45 batch_size,
46 context_length,
47 hidden_size,
48 vocab_size,
49 })
50 }
51
52 pub fn infer_shapes(
55 &self,
56 components: &HashMap<String, ComponentConfig>,
57 ) -> Result<ShapeConfig, anyhow::Error> {
58 self.validate_tensor_metadata(components)?;
60
61 let batch_size = self.infer_batch_size(components);
62 let hidden_size = self.infer_hidden_size(components);
63 let context_length = self.infer_context_length(components);
64 let vocab_size = self.infer_vocab_size(components);
65
66 debug!(
67 "📊 Inferred shapes (validated): batch={}, context={}, hidden={}, vocab={}",
68 batch_size, context_length, hidden_size, vocab_size
69 );
70
71 Ok(ShapeConfig {
72 batch_size,
73 context_length,
74 hidden_size,
75 vocab_size,
76 })
77 }
78
79 fn infer_batch_size(&self, components: &HashMap<String, ComponentConfig>) -> usize {
81 let mut batch_sizes = Vec::new();
82
83 for component in components.values() {
84 for tensor in component.inputs.values().chain(component.outputs.values()) {
85 if !tensor.shape.is_empty() {
86 batch_sizes.push(tensor.shape[0]);
87 }
88 }
89 }
90
91 batch_sizes.into_iter().min().unwrap_or(1)
92 }
93
94 fn infer_hidden_size(&self, components: &HashMap<String, ComponentConfig>) -> usize {
96 let mut from_hidden_states = Vec::new();
97 let mut heuristic = Vec::new();
98
99 for component in components.values() {
100 for (name, tensor) in component.inputs.iter().chain(component.outputs.iter()) {
102 let lname = name.to_lowercase();
104 let is_logits = lname.starts_with("logits");
105
106 if tensor.shape.len() >= 3 {
107 let feat = tensor.shape[2];
108 if name == "hidden_states" {
109 from_hidden_states.push(feat);
110 } else if !is_logits {
111 heuristic.push(feat);
112 }
113 } else if tensor.shape.len() == 2 {
114 let feat = tensor.shape[1];
116 if name == "hidden_states" {
117 from_hidden_states.push(feat);
118 } else if !is_logits && feat > 100 {
119 heuristic.push(feat);
120 }
121 }
122 }
123 }
124
125 if let Some(max_hidden) = from_hidden_states.into_iter().max() {
126 return max_hidden;
127 }
128 heuristic.into_iter().max().unwrap_or(1024)
129 }
130
131 fn infer_context_length(&self, components: &HashMap<String, ComponentConfig>) -> usize {
133 let mut seq_lengths = Vec::new();
134
135 for component in components.values() {
136 for tensor in component.inputs.values().chain(component.outputs.values()) {
137 if tensor.shape.len() >= 2 && tensor.shape[1] > 1 {
138 seq_lengths.push(tensor.shape[1]);
140 }
141 if tensor.shape.len() >= 4 {
143 seq_lengths.push(tensor.shape[3]);
144 }
145 }
146 }
147
148 seq_lengths.into_iter().max().unwrap_or(256)
149 }
150
151 fn infer_vocab_size_with_chunking(
153 &self,
154 components: &HashMap<String, ComponentConfig>,
155 schema_extractor: &SchemaExtractor,
156 ) -> usize {
157 for (name, component) in components {
159 if name == "lm_head" || name.contains("lm_head") {
160 if let Some(vocab_size) =
161 schema_extractor.calculate_vocab_size_from_logits(&component.outputs)
162 {
163 debug!(
164 "📊 Using chunked logits vocab size calculation: {}",
165 vocab_size
166 );
167 return vocab_size;
168 }
169 }
170 }
171
172 debug!("📊 Using legacy vocab size detection");
174 self.infer_vocab_size(components)
175 }
176
177 fn infer_vocab_size(&self, components: &HashMap<String, ComponentConfig>) -> usize {
179 let mut output_sizes = Vec::new();
180
181 for component in components.values() {
182 for tensor in component.outputs.values() {
183 if let Some(&last_dim) = tensor.shape.last() {
184 if last_dim > 1000 {
185 output_sizes.push(last_dim);
187 }
188 }
189 }
190 }
191
192 output_sizes.into_iter().max().unwrap_or(30000)
193 }
194
195 pub fn analyze_components(
197 &self,
198 components: &HashMap<String, ComponentConfig>,
199 ) -> ComponentAnalysis {
200 let mut analysis = ComponentAnalysis::default();
201
202 for (name, component) in components {
203 let comp_analysis = self.analyze_single_component(name, component);
204 analysis.components.insert(name.clone(), comp_analysis);
205 }
206
207 analysis.total_components = components.len();
208 analysis.function_based_components = components
209 .values()
210 .filter(|c| !c.functions.is_empty())
211 .count();
212 analysis.multi_function_components = components
213 .values()
214 .filter(|c| c.functions.len() > 1)
215 .count();
216
217 analysis
218 }
219
220 fn analyze_single_component(
221 &self,
222 name: &str,
223 component: &ComponentConfig,
224 ) -> SingleComponentAnalysis {
225 let input_shapes: Vec<Vec<usize>> =
226 component.inputs.values().map(|t| t.shape.clone()).collect();
227 let output_shapes: Vec<Vec<usize>> = component
228 .outputs
229 .values()
230 .map(|t| t.shape.clone())
231 .collect();
232
233 let max_input_dim = input_shapes.iter().flatten().max().copied().unwrap_or(0);
234 let max_output_dim = output_shapes.iter().flatten().max().copied().unwrap_or(0);
235
236 SingleComponentAnalysis {
237 name: name.to_string(),
238 input_count: component.inputs.len(),
239 output_count: component.outputs.len(),
240 function_count: component.functions.len(),
241 input_shapes,
242 output_shapes,
243 max_input_dimension: max_input_dim,
244 max_output_dimension: max_output_dim,
245 }
246 }
247
248 fn validate_tensor_metadata(
251 &self,
252 components: &HashMap<String, ComponentConfig>,
253 ) -> Result<()> {
254 let mut empty_components = Vec::new();
255 let mut components_with_issues = Vec::new();
256
257 for (name, component) in components {
258 if component.inputs.is_empty() && component.outputs.is_empty() {
260 empty_components.push(name.clone());
261 continue;
262 }
263
264 let has_valid_shapes = component
266 .inputs
267 .values()
268 .chain(component.outputs.values())
269 .any(|tensor| !tensor.shape.is_empty() && tensor.shape.iter().all(|&dim| dim > 0));
270
271 if !has_valid_shapes {
272 components_with_issues.push(name.clone());
273 }
274 }
275
276 if !empty_components.is_empty() {
278 return Err(anyhow!(
279 "Configuration generation failed: Components have empty tensor metadata.\n\
280 \n\
281 🔍 Components with empty tensor maps: {:?}\n\
282 \n\
283 💡 This typically indicates one of these issues:\n\
284 1. CoreML model files lack proper metadata (model.mlmodel missing or corrupt)\n\
285 2. Model packages are incomplete (.mlpackage structure is invalid)\n\
286 3. Metadata extraction failed during parsing\n\
287 \n\
288 🛠️ Possible solutions:\n\
289 1. Re-download the model from the original source\n\
290 2. Verify .mlpackage directory structure contains Data/com.apple.CoreML/model.mlmodel\n\
291 3. Check model compatibility with this version of candle-coreml\n\
292 4. For typo-fixer models: ensure using the correct coreml variant from HuggingFace\n\
293 \n\
294 📝 Expected tensor information:\n\
295 - Embeddings: input_ids → hidden_states\n\
296 - FFN: hidden_states + causal_mask → hidden_states\n\
297 - LM Head: hidden_states → logits (potentially chunked)",
298 empty_components
299 ));
300 }
301
302 if !components_with_issues.is_empty() {
303 return Err(anyhow!(
304 "Configuration generation failed: Components have invalid tensor shape information.\n\
305 \n\
306 🔍 Components with shape issues: {:?}\n\
307 \n\
308 💡 These components have tensor information but with invalid shapes (empty or zero dimensions).\n\
309 \n\
310 🛠️ This suggests corrupted model metadata. Try re-downloading the model.",
311 components_with_issues
312 ));
313 }
314
315 self.validate_model_specific_requirements(components)?;
317
318 debug!(
319 "✅ Tensor metadata validation passed for {} components",
320 components.len()
321 );
322 Ok(())
323 }
324
325 fn validate_model_specific_requirements(
327 &self,
328 components: &HashMap<String, ComponentConfig>,
329 ) -> Result<()> {
330 let looks_like_typo_fixer = components
332 .keys()
333 .any(|name| name.contains("typo") || name.contains("fixer"))
334 || components.values().any(|comp| {
335 comp.file_path
336 .as_ref()
337 .map(|path| path.contains("typo-fixer"))
338 .unwrap_or(false)
339 });
340
341 if looks_like_typo_fixer {
342 let lm_head = components.get("lm_head");
344 if let Some(lm_head) = lm_head {
345 let has_logits = lm_head.outputs.keys().any(|k| k.contains("logits"));
346 let logits_total_size: usize = lm_head
347 .outputs
348 .iter()
349 .filter(|(name, _)| name.contains("logits"))
350 .map(|(_, tensor)| tensor.shape.last().copied().unwrap_or(0))
351 .sum();
352
353 if !has_logits {
354 return Err(anyhow!(
355 "Typo-fixer model validation failed: LM head component lacks logits outputs.\n\
356 \n\
357 🔍 LM head outputs found: {:?}\n\
358 \n\
359 💡 Typo-fixer models require chunked logits outputs (logits_0, logits_1, etc.)\n\
360 🛠️ Ensure you're using the correct typo-fixer coreml model variant.",
361 lm_head.outputs.keys().collect::<Vec<_>>()
362 ));
363 }
364
365 if logits_total_size > 0 && logits_total_size < 100000 {
367 return Err(anyhow!(
368 "Typo-fixer model validation failed: Vocabulary size {} is too small.\n\
369 \n\
370 💡 Expected vocab size ≥ 100,000 for typo-fixer (typically 151,669)\n\
371 🔍 Detected logits total size: {}\n\
372 \n\
373 🛠️ This suggests model metadata extraction issues or wrong model variant.",
374 logits_total_size, logits_total_size
375 ));
376 }
377 }
378 }
379
380 Ok(())
381 }
382}
383
/// Aggregate statistics produced by `ShapeInference::analyze_components`.
#[derive(Debug, Default)]
pub struct ComponentAnalysis {
    /// Total number of components analyzed.
    pub total_components: usize,
    /// Number of components with at least one entry in `functions`.
    pub function_based_components: usize,
    /// Number of components with more than one entry in `functions`.
    pub multi_function_components: usize,
    /// Per-component analysis, keyed by component name.
    pub components: HashMap<String, SingleComponentAnalysis>,
}
391
/// Shape/count summary for one component, built by
/// `ShapeInference::analyze_single_component`.
#[derive(Debug)]
pub struct SingleComponentAnalysis {
    /// Component name as it appears in the components map.
    pub name: String,
    /// Number of input tensors.
    pub input_count: usize,
    /// Number of output tensors.
    pub output_count: usize,
    /// Number of functions the component exposes.
    pub function_count: usize,
    /// Raw shapes of every input tensor (unordered — sourced from a map).
    pub input_shapes: Vec<Vec<usize>>,
    /// Raw shapes of every output tensor (unordered — sourced from a map).
    pub output_shapes: Vec<Vec<usize>>,
    /// Largest single dimension found across all input shapes (0 if none).
    pub max_input_dimension: usize,
    /// Largest single dimension found across all output shapes (0 if none).
    pub max_output_dimension: usize,
}