// candle_coreml/config/basic.rs
use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct Config {
8 pub input_names: Vec<String>,
10 pub output_name: String,
12 pub max_sequence_length: usize,
14 pub vocab_size: usize,
16 pub model_type: String,
18}
19
20impl Default for Config {
21 fn default() -> Self {
22 Self {
23 input_names: vec!["input_ids".to_string()],
24 output_name: "logits".to_string(),
25 max_sequence_length: 128,
26 vocab_size: 32000,
27 model_type: "coreml".to_string(),
28 }
29 }
30}
31
32impl Config {
33 pub fn bert_config(output_name: &str, max_seq_len: usize, vocab_size: usize) -> Self {
35 Self {
36 input_names: vec![
37 "input_ids".to_string(),
38 "token_type_ids".to_string(),
39 "attention_mask".to_string(),
40 ],
41 output_name: output_name.to_string(),
42 max_sequence_length: max_seq_len,
43 vocab_size,
44 model_type: "bert".to_string(),
45 }
46 }
47}