// cake_core/models/common/config.rs

1use anyhow::Result;
2use serde::de::{self, Deserializer};
3
/// EOS token ID(s) — deserializes from either a single u32 or an array of u32.
#[derive(Debug, Clone)]
pub enum EosTokenId {
    /// A single EOS token id.
    Single(u32),
    /// Several EOS token ids, any of which ends generation.
    Multiple(Vec<u32>),
}

impl EosTokenId {
    /// Check if the given token ID is an EOS token.
    pub fn is_eos(&self, token_id: u32) -> bool {
        match self {
            Self::Single(id) => token_id == *id,
            Self::Multiple(ids) => ids.iter().any(|&id| id == token_id),
        }
    }
}
20
21impl<'de> serde::Deserialize<'de> for EosTokenId {
22    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
23    where
24        D: Deserializer<'de>,
25    {
26        let value = serde_json::Value::deserialize(deserializer)?;
27        match value {
28            serde_json::Value::Number(n) => n
29                .as_u64()
30                .map(|v| EosTokenId::Single(v as u32))
31                .ok_or_else(|| de::Error::custom("expected u32 for eos_token_id")),
32            serde_json::Value::Array(arr) => {
33                let ids: std::result::Result<Vec<u32>, _> = arr
34                    .iter()
35                    .map(|v| {
36                        v.as_u64()
37                            .map(|n| n as u32)
38                            .ok_or_else(|| de::Error::custom("expected u32 in eos_token_id array"))
39                    })
40                    .collect();
41                ids.map(EosTokenId::Multiple)
42            }
43            _ => Err(de::Error::custom("expected u32 or array for eos_token_id")),
44        }
45    }
46}
47
/// RoPE scaling configuration for models with extended context (e.g. LLaMA 3.1+).
///
/// Every field carries `#[serde(default)]`, so a partially-specified
/// `rope_scaling` object deserializes with zeroed/`None` values for the
/// missing entries rather than failing.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct RopeScaling {
    /// Factor for scaling (typically 8.0 for LLaMA 3.1).
    #[serde(default)]
    pub factor: f32,
    /// High frequency factor (typically 4.0).
    #[serde(default)]
    pub high_freq_factor: f32,
    /// Low frequency factor (typically 1.0).
    #[serde(default)]
    pub low_freq_factor: f32,
    /// Original max position embeddings before scaling (typically 8192).
    #[serde(default)]
    pub original_max_position_embeddings: usize,
    /// RoPE type (e.g. "llama3").
    #[serde(default)]
    pub rope_type: Option<String>,
}
67
/// Configuration for linear (recurrent) attention layers (e.g. Gated DeltaNet in Qwen3.5).
///
/// NOTE(review): `layer_types` presumably holds one entry per decoder layer —
/// confirm against the model loader that consumes this config.
#[derive(Debug, Clone)]
pub struct LinearAttnConfig {
    /// Per-layer type: "linear_attention" or "full_attention".
    pub layer_types: Vec<String>,
    /// Conv1d kernel size for short convolution preprocessing.
    pub conv_kernel_dim: usize,
    /// Number of key heads in linear attention.
    pub num_key_heads: usize,
    /// Per-head key dimension in linear attention.
    pub key_head_dim: usize,
    /// Number of value heads in linear attention.
    pub num_value_heads: usize,
    /// Per-head value dimension in linear attention.
    pub value_head_dim: usize,
}
84
/// Generalized LLM configuration shared by all decoder-only text models.
#[derive(Debug, Clone)]
pub struct Config {
    /// Transformer hidden (embedding) dimension.
    pub hidden_size: usize,
    /// Feed-forward (MLP) intermediate dimension.
    pub intermediate_size: usize,
    /// Vocabulary size of the token embedding and LM head.
    pub vocab_size: usize,
    /// Number of decoder layers.
    pub num_hidden_layers: usize,
    /// Number of attention (query) heads.
    pub num_attention_heads: usize,
    /// Number of key/value heads (smaller than `num_attention_heads` for grouped-query models).
    pub num_key_value_heads: usize,
    /// Epsilon used by the RMS norm layers.
    pub rms_norm_eps: f64,
    /// Base frequency (theta) for rotary position embeddings.
    pub rope_theta: f32,
    /// Beginning-of-sequence token id, if the model defines one.
    pub bos_token_id: Option<u32>,
    /// End-of-sequence token id(s), if the model defines them.
    pub eos_token_id: Option<EosTokenId>,
    /// Optional RoPE scaling for extended-context models.
    pub rope_scaling: Option<RopeScaling>,
    /// Whether the LM head shares weights with the input embedding.
    pub tie_word_embeddings: bool,
    /// Maximum supported sequence length (context window).
    pub max_seq_len: usize,
    /// Whether Q/K/V projections use bias (true for Qwen2, false for LLaMA).
    pub use_qkv_bias: bool,
    /// Weight tensor prefix for the transformer stack (e.g. "model" or "model.language_model").
    pub model_prefix: String,
    /// Explicit head dimension when it differs from hidden_size / num_attention_heads.
    pub head_dim: Option<usize>,
    /// Fraction of head dims to apply rotary embeddings to (1.0 = all, 0.25 = first quarter).
    pub partial_rotary_factor: f32,
    /// Linear attention configuration (None for pure softmax-attention models).
    pub linear_attn: Option<LinearAttnConfig>,
    /// Whether RMS norm uses residual weight: `(1 + weight) * norm(x)` instead of `weight * norm(x)`.
    /// True for Qwen3.5 whose norm weights are initialized to zero with +1 applied at runtime.
    pub residual_rms_norm: bool,
}
115
116/// Load an RMS norm, optionally applying the residual weight pattern `(1 + weight)`.
117/// When `residual` is true (Qwen3.5), the stored weight is treated as a residual and 1.0 is added.
118pub fn load_rms_norm(
119    size: usize,
120    eps: f64,
121    residual: bool,
122    vb: candle_nn::VarBuilder,
123) -> candle_core::Result<candle_nn::RmsNorm> {
124    let weight = vb.get(size, "weight")?;
125    let weight = if residual {
126        (weight + 1.0)?
127    } else {
128        weight
129    };
130    Ok(candle_nn::RmsNorm::new(weight, eps))
131}
132
133/// Auto-detect text model architecture from config.json's "architectures" field.
134pub fn detect_text_model_arch(config_path: &std::path::Path) -> Result<String> {
135    let data = std::fs::read(config_path)
136        .map_err(|e| anyhow!("can't read {}: {:?}", config_path.display(), e))?;
137    let json: serde_json::Value = serde_json::from_slice(&data)
138        .map_err(|e| anyhow!("can't parse {}: {:?}", config_path.display(), e))?;
139
140    if let Some(archs) = json.get("architectures").and_then(|v| v.as_array()) {
141        for arch in archs {
142            if let Some(s) = arch.as_str() {
143                return Ok(s.to_string());
144            }
145        }
146    }
147
148    Ok(String::new())
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse an `EosTokenId` from a JSON snippet, panicking on failure.
    fn parse_eos(json: &str) -> EosTokenId {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn test_eos_token_id_single() {
        let eos = parse_eos("128001");
        assert!(eos.is_eos(128001));
        assert!(!eos.is_eos(128008));
    }

    #[test]
    fn test_eos_token_id_array() {
        let eos = parse_eos("[128001, 128008, 128009]");
        for id in [128001, 128008, 128009] {
            assert!(eos.is_eos(id));
        }
        assert!(!eos.is_eos(0));
    }

    #[test]
    fn test_rope_scaling_deserialization() {
        let json = r#"{
            "factor": 8.0,
            "high_freq_factor": 4.0,
            "low_freq_factor": 1.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3"
        }"#;
        let scaling: RopeScaling = serde_json::from_str(json).unwrap();
        assert_eq!(scaling.factor, 8.0);
        assert_eq!(scaling.high_freq_factor, 4.0);
        assert_eq!(scaling.low_freq_factor, 1.0);
        assert_eq!(scaling.original_max_position_embeddings, 8192);
        assert_eq!(scaling.rope_type.as_deref(), Some("llama3"));
    }
}
189}