// candle_coreml/config/model.rs

1//! Model configuration system for candle-coreml
2//!
3//! This module provides a configuration system that replaces hardcoded model shapes
4//! with discoverable, flexible configurations. It supports loading configurations from
5//! JSON files generated by the shape discovery tool, as well as built-in configurations
6//! for known models.
7
8use anyhow::{Context, Result};
9use candle_core::{Device, Error as CandleError, Tensor};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::path::Path;
13use tracing::{debug, trace};
14
/// Complete model configuration including shapes, components, and naming patterns.
///
/// Typically deserialized from a JSON file produced by the shape discovery tool
/// (see [`ModelConfig::load_from_file`]); `default_qwen` provides a minimal
/// built-in fallback.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ModelConfig {
    /// Metadata identifying the model this configuration describes.
    pub model_info: ModelInfo,
    /// Global shape parameters (batch size, context length, hidden size, vocab size).
    pub shapes: ShapeConfig,
    /// Per-component tensor configurations, keyed by component name
    /// (e.g. "embeddings", "ffn_prefill", "ffn_infer", "lm_head").
    pub components: HashMap<String, ComponentConfig>,
    /// File naming patterns (deprecated; see `NamingConfig`).
    pub naming: NamingConfig,
    /// Execution mode for FFN: "unified" (single component/function) or "split" (separate prefill/infer components)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ffn_execution: Option<String>,
}
26
/// Model metadata and identification
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ModelInfo {
    /// Model identifier (e.g. "default/qwen"); absent fields deserialize to `None`.
    #[serde(default)]
    pub model_id: Option<String>,
    /// Filesystem path associated with the model, if known.
    pub path: Option<String>,
    /// Model family/type tag (e.g. "qwen").
    pub model_type: String,
    /// Timestamp string recorded when shapes were discovered, if any.
    pub discovered_at: Option<String>,
}
36
/// Overall model shape parameters
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ShapeConfig {
    /// Batch dimension size; must be > 0 (enforced by `ModelConfig::validate`).
    pub batch_size: usize,
    /// Context (sequence) length; must be > 0.
    pub context_length: usize,
    /// Hidden state dimensionality; must be > 0.
    pub hidden_size: usize,
    /// Vocabulary size of the logits output; must be > 0.
    pub vocab_size: usize,
}
45
/// Configuration for a single model component (embeddings, FFN, LM head, etc.)
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ComponentConfig {
    /// Path to the component's model file, when known.
    pub file_path: Option<String>,
    /// Input tensor configurations, keyed by tensor name.
    pub inputs: HashMap<String, TensorConfig>,
    /// Output tensor configurations, keyed by tensor name.
    pub outputs: HashMap<String, TensorConfig>,
    /// Function names exposed by the component (may be empty).
    pub functions: Vec<String>,
    /// Optional deterministic input order; if absent, caller must provide correct order.
    #[serde(default)]
    pub input_order: Option<Vec<String>>,
}
57
/// Configuration for a tensor (shape and data type)
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TensorConfig {
    /// Tensor name as reported by the model.
    pub name: String,
    /// Dimension sizes; must be non-empty (enforced by `ModelConfig::validate`).
    pub shape: Vec<usize>,
    /// Data type label (e.g. "FLOAT16", "INT32", "FLOAT32").
    pub data_type: String,
}
65
/// Model file naming patterns for component discovery.
///
/// Deprecated: pattern-based discovery was removed; explicit `file_path`
/// entries on each `ComponentConfig` are required instead. The fields below
/// remain only so older config files still deserialize.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct NamingConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embeddings_pattern: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ffn_prefill_pattern: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ffn_infer_pattern: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lm_head_pattern: Option<String>,
}
79
80impl ModelConfig {
81    /// Create a minimal default Qwen ModelConfig (no components). Useful for tests and fallbacks.
82    pub fn default_qwen() -> Self {
83        Self {
84            model_info: ModelInfo {
85                model_id: Some("default/qwen".to_string()),
86                path: None,
87                model_type: "qwen".to_string(),
88                discovered_at: None,
89            },
90            shapes: ShapeConfig {
91                batch_size: 1,
92                context_length: 512,
93                hidden_size: 1024,
94                vocab_size: 151_936,
95            },
96            components: HashMap::new(),
97            naming: NamingConfig {
98                embeddings_pattern: None,
99                ffn_prefill_pattern: None,
100                ffn_infer_pattern: None,
101                lm_head_pattern: None,
102            },
103            ffn_execution: None,
104        }
105    }
106    /// Load configuration from a JSON file
107    pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
108        let path = path.as_ref();
109        let content = std::fs::read_to_string(path)
110            .with_context(|| format!("Failed to read config file: {}", path.display()))?;
111
112        let config: ModelConfig = serde_json::from_str(&content)
113            .with_context(|| format!("Failed to parse config file: {}", path.display()))?;
114
115        Ok(config)
116    }
117
118    /// Save configuration to a JSON file
119    pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
120        let path = path.as_ref();
121        let content =
122            serde_json::to_string_pretty(self).context("Failed to serialize configuration")?;
123
124        std::fs::write(path, content)
125            .with_context(|| format!("Failed to write config file: {}", path.display()))?;
126
127        Ok(())
128    }
129
130    /// Get the shape configuration for a specific component and tensor
131    pub fn get_tensor_shape(
132        &self,
133        component: &str,
134        tensor_name: &str,
135        is_input: bool,
136    ) -> Option<&Vec<usize>> {
137        let component_config = self.components.get(component)?;
138
139        let tensor_map = if is_input {
140            &component_config.inputs
141        } else {
142            &component_config.outputs
143        };
144
145        tensor_map.get(tensor_name).map(|tensor| &tensor.shape)
146    }
147
    /// Get the expected input shape for embeddings (`embeddings.inputs.input_ids`).
    pub fn embeddings_input_shape(&self) -> Option<&Vec<usize>> {
        self.get_tensor_shape("embeddings", "input_ids", true)
    }

    /// Get the expected output shape for embeddings (`embeddings.outputs.hidden_states`).
    pub fn embeddings_output_shape(&self) -> Option<&Vec<usize>> {
        self.get_tensor_shape("embeddings", "hidden_states", false)
    }

    /// Get the expected input shape for FFN prefill (`ffn_prefill.inputs.hidden_states`).
    pub fn ffn_prefill_input_shape(&self) -> Option<&Vec<usize>> {
        self.get_tensor_shape("ffn_prefill", "hidden_states", true)
    }

    /// Get the expected input shape for LM head (`lm_head.inputs.hidden_states`).
    pub fn lm_head_input_shape(&self) -> Option<&Vec<usize>> {
        self.get_tensor_shape("lm_head", "hidden_states", true)
    }
167
168    /// Check if this model supports multi-part logits output
169    pub fn has_multipart_logits(&self) -> bool {
170        if let Some(lm_head) = self.components.get("lm_head") {
171            // Check if there are multiple logits outputs (logits1, logits2, etc.)
172            let logits_outputs: Vec<_> = lm_head
173                .outputs
174                .keys()
175                .filter(|name| name.starts_with("logits") && name.len() > 6) // "logits" + number
176                .collect();
177            return logits_outputs.len() > 1;
178        }
179        false
180    }
181
182    /// Get the number of logits parts for multi-part output
183    pub fn logits_part_count(&self) -> usize {
184        if let Some(lm_head) = self.components.get("lm_head") {
185            let logits_outputs: Vec<_> = lm_head
186                .outputs
187                .keys()
188                .filter(|name| name.starts_with("logits"))
189                .collect();
190            if logits_outputs.is_empty() {
191                1 // Single logits output
192            } else {
193                logits_outputs.len()
194            }
195        } else {
196            1
197        }
198    }
199
200    /// Select the primary logits output name for the LM head.
201    /// Preference order: "logits1" (multipart), then "logits", otherwise the first available key.
202    pub fn lm_head_primary_output_name(&self) -> Option<String> {
203        let lm_head = self.components.get("lm_head")?;
204
205        // Prefer explicit multipart naming starting with logits1
206        if lm_head.outputs.contains_key("logits1") {
207            return Some("logits1".to_string());
208        }
209
210        // Common single output name
211        if lm_head.outputs.contains_key("logits") {
212            return Some("logits".to_string());
213        }
214
215        // Fallback to the first key if any
216        lm_head.outputs.keys().next().map(|k| k.to_string())
217    }
218
219    /// Validate the configuration for consistency
220    pub fn validate(&self) -> Result<()> {
221        // Check that required components exist
222        let required_components = ["embeddings", "lm_head"];
223        for component in required_components {
224            if !self.components.contains_key(component) {
225                return Err(anyhow::anyhow!("Missing required component: {}", component));
226            }
227        }
228
229        // Check shape consistency
230        if self.shapes.batch_size == 0 {
231            return Err(anyhow::anyhow!("batch_size must be greater than 0"));
232        }
233
234        if self.shapes.context_length == 0 {
235            return Err(anyhow::anyhow!("context_length must be greater than 0"));
236        }
237
238        if self.shapes.hidden_size == 0 {
239            return Err(anyhow::anyhow!("hidden_size must be greater than 0"));
240        }
241
242        if self.shapes.vocab_size == 0 {
243            return Err(anyhow::anyhow!("vocab_size must be greater than 0"));
244        }
245
246        // Validate tensor shapes make sense
247        for (component_name, component) in &self.components {
248            for (tensor_name, tensor) in &component.inputs {
249                if tensor.shape.is_empty() {
250                    return Err(anyhow::anyhow!(
251                        "Empty shape for {}.inputs.{}",
252                        component_name,
253                        tensor_name
254                    ));
255                }
256            }
257            for (tensor_name, tensor) in &component.outputs {
258                if tensor.shape.is_empty() {
259                    return Err(anyhow::anyhow!(
260                        "Empty shape for {}.outputs.{}",
261                        component_name,
262                        tensor_name
263                    ));
264                }
265            }
266        }
267
268        Ok(())
269    }
270
271    /// Validate internal wiring between components for basic shape compatibility.
272    /// Examples:
273    ///  - embeddings.outputs.hidden_states == ffn_prefill.inputs.hidden_states
274    ///  - ffn_infer.outputs.output_hidden_states == lm_head.inputs.hidden_states (when ffn_infer exists)
275    pub fn validate_internal_wiring(&self) -> Result<()> {
276        // Embeddings -> FFN prefill hidden_states flow
277        if let (Some(emb_out), Some(ffn_in_hidden)) = (
278            self.get_tensor_shape("embeddings", "hidden_states", false),
279            self.get_tensor_shape("ffn_prefill", "hidden_states", true),
280        ) {
281            if emb_out != ffn_in_hidden {
282                return Err(anyhow::anyhow!(
283                    "Shape mismatch: embeddings.hidden_states {:?} != ffn_prefill.hidden_states {:?}",
284                    emb_out, ffn_in_hidden
285                ));
286            }
287        }
288
289        // FFN infer -> LM head hidden_states flow
290        if self.components.contains_key("ffn_infer") {
291            if let (Some(ffn_out), Some(lm_in)) = (
292                self.get_tensor_shape("ffn_infer", "output_hidden_states", false),
293                self.get_tensor_shape("lm_head", "hidden_states", true),
294            ) {
295                if ffn_out != lm_in {
296                    return Err(anyhow::anyhow!(
297                        "Shape mismatch: ffn_infer.output_hidden_states {:?} != lm_head.hidden_states {:?}",
298                        ffn_out, lm_in
299                    ));
300                }
301            }
302        } else {
303            // If there's no separate ffn_infer, check ffn_prefill output shape matches lm_head (single-token path)
304            if let (Some(ffn_out), Some(lm_in)) = (
305                self.get_tensor_shape("ffn_prefill", "output_hidden_states", false),
306                self.get_tensor_shape("lm_head", "hidden_states", true),
307            ) {
308                if ffn_out != lm_in {
309                    return Err(anyhow::anyhow!(
310                        "Shape mismatch: ffn_prefill.output_hidden_states {:?} != lm_head.hidden_states {:?}",
311                        ffn_out, lm_in
312                    ));
313                }
314            }
315        }
316
317        Ok(())
318    }
319
320    /// Determine if FFN execution should be treated as split (separate infer component)
321    pub fn ffn_is_split(&self) -> bool {
322        if let Some(mode) = self.ffn_execution.as_deref() {
323            return mode == "split";
324        }
325        if let (Some(prefill), Some(infer)) = (
326            self.components.get("ffn_prefill"),
327            self.components.get("ffn_infer"),
328        ) {
329            match (&prefill.file_path, &infer.file_path) {
330                (Some(p), Some(i)) => p != i, // different files => split
331                _ => false,
332            }
333        } else {
334            false
335        }
336    }
337
338    /// Detect if prefill should run in single-token sequential mode based on configured shapes
339    pub fn prefill_is_single_token(&self) -> bool {
340        if let Some(prefill) = self.components.get("ffn_prefill") {
341            if let Some(hs) = prefill.inputs.get("hidden_states") {
342                let is_single = hs.shape.len() == 3 && hs.shape.get(1) == Some(&1);
343                debug!(
344                    "🔍 prefill_is_single_token: shape={:?}, len={}, dim[1]={:?}, result={}",
345                    hs.shape,
346                    hs.shape.len(),
347                    hs.shape.get(1),
348                    is_single
349                );
350                return is_single;
351            }
352        }
353        debug!(
354            "🔍 prefill_is_single_token: no ffn_prefill or hidden_states found, returning false"
355        );
356        false
357    }
358
359    /// Check if the model expects full sequence prefill (as opposed to single-token processing)
360    /// This is typically true for CoreML models with fixed-shape inputs like [1, 128, 1024]
361    pub fn expects_full_sequence_prefill(&self) -> bool {
362        if let Some(prefill) = self.components.get("ffn_prefill") {
363            if let Some(hs) = prefill.inputs.get("hidden_states") {
364                // If the model expects a fixed sequence length > 1, it needs full-sequence prefill
365                let expects_full =
366                    hs.shape.len() == 3 && hs.shape.get(1).is_some_and(|&seq_len| seq_len > 1);
367                trace!(
368                    "🔍 expects_full_sequence_prefill: shape={:?}, len={}, dim[1]={:?}, result={}",
369                    hs.shape,
370                    hs.shape.len(),
371                    hs.shape.get(1),
372                    expects_full
373                );
374                return expects_full;
375            }
376        }
377        trace!("🔍 expects_full_sequence_prefill: no ffn_prefill or hidden_states found, returning false");
378        false
379    }
380
381    // Tensor Creation Methods (moved from QwenConfig for consolidation)
382
383    /// Create embeddings input tensor with proper shape from configuration
384    pub fn create_embeddings_input_tensor(
385        &self,
386        tokens: &[i64],
387        device: &Device,
388    ) -> Result<Tensor, CandleError> {
389        let expected_shape = self
390            .embeddings_input_shape()
391            .ok_or_else(|| CandleError::Msg("No embeddings input shape found".to_string()))?;
392        let expected_len = expected_shape[1]; // [batch, seq_len] -> seq_len
393
394        // Pad or truncate tokens to match expected length
395        let mut padded_tokens = tokens.to_vec();
396        padded_tokens.resize(expected_len, 0); // Pad with 0s
397
398        Tensor::from_vec(
399            padded_tokens,
400            (expected_shape[0], expected_shape[1]),
401            device,
402        )
403    }
404
405    /// Create position IDs tensor for FFN prefill with proper shape
406    pub fn create_ffn_position_ids_tensor(
407        &self,
408        positions: &[i64],
409        device: &Device,
410    ) -> Result<Tensor, CandleError> {
411        let expected_shape = self
412            .get_tensor_shape("ffn_prefill", "position_ids", true)
413            .ok_or_else(|| {
414                CandleError::Msg("No FFN prefill position_ids shape found".to_string())
415            })?;
416
417        // Heuristic: some manifests report position_ids length as [1] even for prefill.
418        // When that happens, derive the true sequence length from other known shapes.
419        let mut expected_len = expected_shape[0];
420        if expected_len == 1 {
421            // Prefer prefill hidden_states seq_len if available and > 1
422            if let Some(hs_shape) = self.get_tensor_shape("ffn_prefill", "hidden_states", true) {
423                if hs_shape.len() == 3 && hs_shape[1] > 1 {
424                    expected_len = hs_shape[1];
425                }
426            }
427            // Or embeddings input seq_len if available and > 1
428            if expected_len == 1 {
429                if let Some(emb) = self.embeddings_input_shape() {
430                    if emb.len() == 2 && emb[1] > 1 {
431                        expected_len = emb[1];
432                    }
433                }
434            }
435            // As a final fallback, keep 1 (some exotic models may actually expect [1])
436        }
437
438        // Create position sequence up to expected length
439        let mut position_ids = Vec::with_capacity(expected_len);
440        for i in 0..expected_len {
441            if i < positions.len() {
442                position_ids.push(positions[i]);
443            } else {
444                position_ids.push(0); // Pad with 0s
445            }
446        }
447
448        Tensor::from_vec(position_ids, (expected_len,), device)
449    }
450
451    /// Create causal mask tensor for FFN with proper shape
452    pub fn create_ffn_causal_mask_tensor(
453        &self,
454        _batch_size: usize,
455        _context_length: usize,
456        device: &Device,
457    ) -> Result<Tensor, CandleError> {
458        // Prefer explicit shape from config; otherwise synthesize a reasonable default
459        let expected_shape_vec =
460            if let Some(shape) = self.get_tensor_shape("ffn_prefill", "causal_mask", true) {
461                shape.clone()
462            } else {
463                // Derive a default square mask [1,1,seq_len,seq_len]
464                let mut seq_len = 0usize;
465                if let Some(hs) = self.get_tensor_shape("ffn_prefill", "hidden_states", true) {
466                    if hs.len() == 3 && hs[1] > 0 {
467                        seq_len = hs[1];
468                    }
469                }
470                if seq_len == 0 {
471                    if let Some(emb) = self.embeddings_input_shape() {
472                        if emb.len() == 2 && emb[1] > 0 {
473                            seq_len = emb[1];
474                        }
475                    }
476                }
477                if seq_len == 0 {
478                    seq_len = self.shapes.context_length;
479                }
480                vec![1, 1, seq_len, seq_len]
481            };
482
483        let mask_rows = expected_shape_vec[2];
484        let mask_context_length = expected_shape_vec[3];
485
486        // Create causal mask data
487        let mut mask_data = vec![f32::NEG_INFINITY; mask_rows * mask_context_length];
488        for i in 0..mask_rows {
489            for j in 0..=i.min(mask_context_length - 1) {
490                mask_data[i * mask_context_length + j] = 0.0;
491            }
492        }
493
494        Tensor::from_vec(
495            mask_data,
496            (
497                expected_shape_vec[0],
498                expected_shape_vec[1],
499                expected_shape_vec[2],
500                expected_shape_vec[3],
501            ),
502            device,
503        )
504    }
505
506    /// Create single token hidden states tensor for LM head
507    pub fn create_single_token_hidden_states(
508        &self,
509        _tokens: &[i64],
510        device: &Device,
511    ) -> Result<Tensor, CandleError> {
512        let expected_shape = self
513            .get_tensor_shape("lm_head", "hidden_states", true)
514            .ok_or_else(|| CandleError::Msg("No LM head hidden_states shape found".to_string()))?;
515
516        // Create dummy tensor with correct shape (would be filled by actual embeddings)
517        let tensor_data = vec![0.0f32; expected_shape.iter().product()];
518        let shape = (expected_shape[0], expected_shape[1], expected_shape[2]);
519
520        Tensor::from_vec(tensor_data, shape, device)
521    }
522
523    /// Create position IDs tensor for inference (single position)
524    pub fn create_infer_position_ids_tensor(
525        &self,
526        position: i64,
527        device: &Device,
528    ) -> Result<Tensor, CandleError> {
529        // Check if we have a dedicated ffn_infer component with specific shape
530        if let Some(infer_shape) = self.get_tensor_shape("ffn_infer", "position_ids", true) {
531            // Use the infer-specific shape
532            if infer_shape.len() == 1 {
533                Tensor::from_vec(vec![position], (infer_shape[0],), device)
534            } else {
535                let size = infer_shape.iter().product();
536                let mut data = vec![0i64; size];
537                data[0] = position;
538                Tensor::from_vec(data, infer_shape.as_slice(), device)
539            }
540        } else {
541            // No dedicated infer component - use single position for inference (original QwenConfig behavior)
542            Tensor::from_vec(vec![position], (1,), device)
543        }
544    }
545
546    /// Create current position tensor for FFN
547    pub fn create_current_pos_tensor(
548        &self,
549        position: i64,
550        device: &Device,
551    ) -> Result<Tensor, CandleError> {
552        // Most models expect [1] shape for current_pos
553        Tensor::from_vec(vec![position], (1,), device)
554    }
555}
556
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Shorthand for building a `TensorConfig` in test fixtures.
    fn tensor(name: &str, shape: Vec<usize>, data_type: &str) -> TensorConfig {
        TensorConfig {
            name: name.to_string(),
            shape,
            data_type: data_type.to_string(),
        }
    }

    /// Build a synthetic two-component (embeddings + lm_head) configuration.
    fn create_test_config() -> ModelConfig {
        let embeddings = ComponentConfig {
            file_path: None,
            inputs: HashMap::from([(
                "input_ids".to_string(),
                tensor("input_ids", vec![1, 64], "INT32"),
            )]),
            outputs: HashMap::from([(
                "hidden_states".to_string(),
                tensor("hidden_states", vec![1, 64, 1024], "FLOAT16"),
            )]),
            functions: vec![],
            input_order: None,
        };

        let lm_head = ComponentConfig {
            file_path: None,
            inputs: HashMap::from([(
                "hidden_states".to_string(),
                tensor("hidden_states", vec![1, 1, 1024], "FLOAT16"),
            )]),
            outputs: HashMap::from([(
                "logits".to_string(),
                tensor("logits", vec![1, 1, 151936], "FLOAT32"),
            )]),
            functions: vec![],
            input_order: None,
        };

        let components = HashMap::from([
            ("embeddings".to_string(), embeddings),
            ("lm_head".to_string(), lm_head),
        ]);

        ModelConfig {
            model_info: ModelInfo {
                model_id: Some("test/model".to_string()),
                path: Some("/test/path".to_string()),
                model_type: "qwen".to_string(),
                discovered_at: Some("2025-08-07T00:00:00".to_string()),
            },
            shapes: ShapeConfig {
                batch_size: 1,
                context_length: 512,
                hidden_size: 1024,
                vocab_size: 151936,
            },
            components,
            naming: NamingConfig {
                embeddings_pattern: None,
                ffn_prefill_pattern: None,
                ffn_infer_pattern: None,
                lm_head_pattern: None,
            },
            ffn_execution: Some("unified".to_string()),
        }
    }

    #[test]
    fn test_config_serialization() {
        let config = create_test_config();

        // Serialize and spot-check that key fields appear in the JSON text.
        let json = serde_json::to_string_pretty(&config).unwrap();
        for needle in ["test/model", "batch_size", "embeddings"] {
            assert!(json.contains(needle));
        }

        // Round-trip: deserialization must reproduce the original values.
        let parsed: ModelConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.model_info.model_id, config.model_info.model_id);
        assert_eq!(parsed.shapes.batch_size, config.shapes.batch_size);
        assert_eq!(parsed.components.len(), config.components.len());
    }

    #[test]
    fn test_config_file_io() {
        let config = create_test_config();
        let temp_file = NamedTempFile::new().unwrap();

        // Write to disk, read back, and verify the config survives the trip.
        config.save_to_file(temp_file.path()).unwrap();
        let loaded = ModelConfig::load_from_file(temp_file.path()).unwrap();
        assert_eq!(loaded.model_info.model_id, config.model_info.model_id);
        assert_eq!(loaded.shapes.hidden_size, config.shapes.hidden_size);
    }

    #[test]
    fn test_shape_accessors() {
        let config = create_test_config();

        // Accessors must surface the shapes declared in the fixture.
        assert_eq!(config.embeddings_input_shape().unwrap(), &vec![1, 64]);
        assert_eq!(
            config.embeddings_output_shape().unwrap(),
            &vec![1, 64, 1024]
        );
        assert_eq!(config.lm_head_input_shape().unwrap(), &vec![1, 1, 1024]);
    }

    #[test]
    fn test_multipart_logits_detection() {
        let config = create_test_config();
        // The default fixture exposes a single "logits" output.
        assert!(!config.has_multipart_logits());

        // Swap in numbered multipart outputs and re-check detection.
        let mut config_multipart = config;
        let lm_head = config_multipart.components.get_mut("lm_head").unwrap();
        lm_head.outputs.clear();
        lm_head.outputs.insert(
            "logits1".to_string(),
            tensor("logits1", vec![1, 1, 9480], "FLOAT32"),
        );
        lm_head.outputs.insert(
            "logits2".to_string(),
            tensor("logits2", vec![1, 1, 9479], "FLOAT32"),
        );

        assert!(config_multipart.has_multipart_logits());
        assert_eq!(config_multipart.logits_part_count(), 2);
    }

    #[test]
    fn test_config_validation() {
        let config = create_test_config();
        assert!(config.validate().is_ok());

        // Internal wiring should be consistent in this synthetic setup
        assert!(config.validate_internal_wiring().is_ok());

        // Removing a required component must fail validation.
        let mut missing_component = config.clone();
        missing_component.components.remove("embeddings");
        assert!(missing_component.validate().is_err());

        // A zeroed global shape parameter must fail validation.
        let mut zero_batch = config;
        zero_batch.shapes.batch_size = 0;
        assert!(zero_batch.validate().is_err());
    }
}