candle_coreml/config/
basic.rs

//! Configuration types for CoreML models.

3use serde::{Deserialize, Serialize};
4
5/// Configuration for CoreML models
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct Config {
8    /// Input tensor names in order (e.g., ["input_ids", "token_type_ids", "attention_mask"])
9    pub input_names: Vec<String>,
10    /// Output tensor name (e.g., "logits")
11    pub output_name: String,
12    /// Maximum sequence length
13    pub max_sequence_length: usize,
14    /// Vocabulary size
15    pub vocab_size: usize,
16    /// Model architecture name
17    pub model_type: String,
18}
19
20impl Default for Config {
21    fn default() -> Self {
22        Self {
23            input_names: vec!["input_ids".to_string()],
24            output_name: "logits".to_string(),
25            max_sequence_length: 128,
26            vocab_size: 32000,
27            model_type: "coreml".to_string(),
28        }
29    }
30}
31
32impl Config {
33    /// Create BERT-style config with input_ids, token_type_ids, and attention_mask
34    pub fn bert_config(output_name: &str, max_seq_len: usize, vocab_size: usize) -> Self {
35        Self {
36            input_names: vec![
37                "input_ids".to_string(),
38                "token_type_ids".to_string(),
39                "attention_mask".to_string(),
40            ],
41            output_name: output_name.to_string(),
42            max_sequence_length: max_seq_len,
43            vocab_size,
44            model_type: "bert".to_string(),
45        }
46    }
47}