cake_core/models/common/
config.rs1use anyhow::Result;
2use serde::de::{self, Deserializer};
3
/// End-of-sequence token id(s) for a model.
///
/// Holds either one id or a list of ids; see the hand-written `Deserialize`
/// implementation in this file for the accepted JSON shapes.
///
/// `PartialEq`/`Eq` are derived so values can be compared directly (for example
/// with `assert_eq!` in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EosTokenId {
    /// Exactly one end-of-sequence token id.
    Single(u32),
    /// Several token ids, any of which counts as end-of-sequence.
    Multiple(Vec<u32>),
}
10
11impl EosTokenId {
12 pub fn is_eos(&self, token_id: u32) -> bool {
14 match self {
15 EosTokenId::Single(id) => *id == token_id,
16 EosTokenId::Multiple(ids) => ids.contains(&token_id),
17 }
18 }
19}
20
21impl<'de> serde::Deserialize<'de> for EosTokenId {
22 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
23 where
24 D: Deserializer<'de>,
25 {
26 let value = serde_json::Value::deserialize(deserializer)?;
27 match value {
28 serde_json::Value::Number(n) => n
29 .as_u64()
30 .map(|v| EosTokenId::Single(v as u32))
31 .ok_or_else(|| de::Error::custom("expected u32 for eos_token_id")),
32 serde_json::Value::Array(arr) => {
33 let ids: std::result::Result<Vec<u32>, _> = arr
34 .iter()
35 .map(|v| {
36 v.as_u64()
37 .map(|n| n as u32)
38 .ok_or_else(|| de::Error::custom("expected u32 in eos_token_id array"))
39 })
40 .collect();
41 ids.map(EosTokenId::Multiple)
42 }
43 _ => Err(de::Error::custom("expected u32 or array for eos_token_id")),
44 }
45 }
46}
47
48#[derive(Debug, Clone, serde::Deserialize)]
50pub struct RopeScaling {
51 #[serde(default)]
53 pub factor: f32,
54 #[serde(default)]
56 pub high_freq_factor: f32,
57 #[serde(default)]
59 pub low_freq_factor: f32,
60 #[serde(default)]
62 pub original_max_position_embeddings: usize,
63 #[serde(default)]
65 pub rope_type: Option<String>,
66}
67
/// Settings for linear-attention layers.
///
/// This type has no `Deserialize` implementation in this file, so it is
/// presumably assembled by per-architecture config loaders — verify against
/// callers.
///
/// `PartialEq`/`Eq` are derived so configurations can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LinearAttnConfig {
    /// Layer type names.
    // NOTE(review): presumably one entry per hidden layer selecting its attention kind — confirm.
    pub layer_types: Vec<String>,
    /// Convolution kernel dimension.
    pub conv_kernel_dim: usize,
    /// Number of key heads.
    pub num_key_heads: usize,
    /// Dimension of each key head.
    pub key_head_dim: usize,
    /// Number of value heads.
    pub num_value_heads: usize,
    /// Dimension of each value head.
    pub value_head_dim: usize,
}
84
85#[derive(Debug, Clone)]
87pub struct Config {
88 pub hidden_size: usize,
89 pub intermediate_size: usize,
90 pub vocab_size: usize,
91 pub num_hidden_layers: usize,
92 pub num_attention_heads: usize,
93 pub num_key_value_heads: usize,
94 pub rms_norm_eps: f64,
95 pub rope_theta: f32,
96 pub bos_token_id: Option<u32>,
97 pub eos_token_id: Option<EosTokenId>,
98 pub rope_scaling: Option<RopeScaling>,
99 pub tie_word_embeddings: bool,
100 pub max_seq_len: usize,
101 pub use_qkv_bias: bool,
103 pub model_prefix: String,
105 pub head_dim: Option<usize>,
107 pub partial_rotary_factor: f32,
109 pub linear_attn: Option<LinearAttnConfig>,
111 pub residual_rms_norm: bool,
114}
115
116pub fn load_rms_norm(
119 size: usize,
120 eps: f64,
121 residual: bool,
122 vb: candle_nn::VarBuilder,
123) -> candle_core::Result<candle_nn::RmsNorm> {
124 let weight = vb.get(size, "weight")?;
125 let weight = if residual {
126 (weight + 1.0)?
127 } else {
128 weight
129 };
130 Ok(candle_nn::RmsNorm::new(weight, eps))
131}
132
133pub fn detect_text_model_arch(config_path: &std::path::Path) -> Result<String> {
135 let data = std::fs::read(config_path)
136 .map_err(|e| anyhow!("can't read {}: {:?}", config_path.display(), e))?;
137 let json: serde_json::Value = serde_json::from_slice(&data)
138 .map_err(|e| anyhow!("can't parse {}: {:?}", config_path.display(), e))?;
139
140 if let Some(archs) = json.get("architectures").and_then(|v| v.as_array()) {
141 for arch in archs {
142 if let Some(s) = arch.as_str() {
143 return Ok(s.to_string());
144 }
145 }
146 }
147
148 Ok(String::new())
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare JSON number deserializes to a single id.
    #[test]
    fn test_eos_token_id_single() {
        let json = r#"128001"#;
        let eos: EosTokenId = serde_json::from_str(json).unwrap();
        assert!(eos.is_eos(128001));
        assert!(!eos.is_eos(128008));
    }

    /// A JSON array deserializes to a set of ids, each of which matches.
    #[test]
    fn test_eos_token_id_array() {
        let json = r#"[128001, 128008, 128009]"#;
        let eos: EosTokenId = serde_json::from_str(json).unwrap();
        assert!(eos.is_eos(128001));
        assert!(eos.is_eos(128008));
        assert!(eos.is_eos(128009));
        assert!(!eos.is_eos(0));
    }

    /// An empty array is accepted and matches no token.
    #[test]
    fn test_eos_token_id_empty_array() {
        let eos: EosTokenId = serde_json::from_str("[]").unwrap();
        assert!(!eos.is_eos(0));
        assert!(!eos.is_eos(128001));
    }

    /// Strings, negatives, fractions, null and mixed arrays are all rejected.
    #[test]
    fn test_eos_token_id_rejects_invalid_input() {
        let invalid = [
            r#""128001""#,
            r#"-1"#,
            r#"1.5"#,
            r#"null"#,
            r#"{}"#,
            r#"[128001, "x"]"#,
            r#"[-1]"#,
        ];
        for json in invalid {
            assert!(
                serde_json::from_str::<EosTokenId>(json).is_err(),
                "unexpectedly accepted {json}"
            );
        }
    }

    /// All fields are read when present.
    #[test]
    fn test_rope_scaling_deserialization() {
        let json = r#"{
            "factor": 8.0,
            "high_freq_factor": 4.0,
            "low_freq_factor": 1.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3"
        }"#;
        let scaling: RopeScaling = serde_json::from_str(json).unwrap();
        assert_eq!(scaling.factor, 8.0);
        assert_eq!(scaling.high_freq_factor, 4.0);
        assert_eq!(scaling.low_freq_factor, 1.0);
        assert_eq!(scaling.original_max_position_embeddings, 8192);
        assert_eq!(scaling.rope_type.as_deref(), Some("llama3"));
    }

    /// Missing keys fall back to the `#[serde(default)]` values.
    #[test]
    fn test_rope_scaling_defaults() {
        let scaling: RopeScaling = serde_json::from_str("{}").unwrap();
        assert_eq!(scaling.factor, 0.0);
        assert_eq!(scaling.high_freq_factor, 0.0);
        assert_eq!(scaling.low_freq_factor, 0.0);
        assert_eq!(scaling.original_max_position_embeddings, 0);
        assert!(scaling.rope_type.is_none());
    }
}