// ferrum_models/definition.rs
1//! Model definition and configuration parsing
2
3use crate::{registry::Architecture, source::ResolvedModelSource};
4use ferrum_types::{
5    Activation, AttentionConfig, FerrumError, ModelInfo, ModelType, NormType, Result, RopeScaling,
6};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10use tracing::{debug, warn};
11
/// Model definition parsed from a Hugging Face-style `config.json`.
///
/// Field names mirror the upstream config keys where possible. Keys not
/// captured by a named field are preserved verbatim in `extra_params`
/// via `#[serde(flatten)]`, so round-tripping does not lose information.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDefinition {
    /// Architecture type (e.g. Llama), detected from `model_type` or
    /// the `architectures` array in the config.
    pub architecture: Architecture,
    /// Hidden size (embedding dimension).
    pub hidden_size: usize,
    /// Intermediate size (FFN dimension).
    pub intermediate_size: usize,
    /// Vocabulary size.
    pub vocab_size: usize,
    /// Number of hidden (transformer) layers.
    pub num_hidden_layers: usize,
    /// Number of attention heads.
    pub num_attention_heads: usize,
    /// Number of key-value heads for grouped-query attention;
    /// `None` means full multi-head attention (KV heads == attention heads).
    pub num_key_value_heads: Option<usize>,
    /// Maximum position embeddings (context length).
    pub max_position_embeddings: usize,
    /// RoPE theta (rotary frequency base); `None` if the config omits it.
    pub rope_theta: Option<f64>,
    /// RoPE scaling config, if the model extends its context window.
    pub rope_scaling: Option<RopeScaling>,
    /// Normalization type (RMSNorm vs LayerNorm).
    pub norm_type: NormType,
    /// Normalization epsilon.
    pub norm_eps: f64,
    /// Attention configuration (bias, sliding window).
    pub attention_config: AttentionConfig,
    /// Activation function used in the FFN.
    pub activation: Activation,
    /// Extra parameters: all config keys not matched above.
    #[serde(flatten)]
    pub extra_params: serde_json::Value,
}
47
48impl Default for ModelDefinition {
49    fn default() -> Self {
50        Self {
51            architecture: Architecture::Llama,
52            hidden_size: 4096,
53            intermediate_size: 11008,
54            vocab_size: 32000,
55            num_hidden_layers: 32,
56            num_attention_heads: 32,
57            num_key_value_heads: None,
58            max_position_embeddings: 2048,
59            rope_theta: Some(10000.0),
60            rope_scaling: None,
61            norm_type: NormType::RMSNorm,
62            norm_eps: 1e-6,
63            attention_config: AttentionConfig {
64                attention_bias: false,
65                sliding_window: None,
66            },
67            activation: Activation::SiLU,
68            extra_params: serde_json::Value::Object(serde_json::Map::new()),
69        }
70    }
71}
72
73impl ModelDefinition {
74    /// Convert to ModelInfo
75    pub fn to_model_info(&self, model_id: impl Into<String>) -> ModelInfo {
76        use ferrum_types::{DataType, Device};
77
78        let model_id_str = model_id.into();
79
80        // Calculate approximate parameter count
81        let params = self.estimate_parameters();
82
83        ModelInfo {
84            model_id: ferrum_types::ModelId::new(model_id_str.clone()),
85            model_type: ModelType::Custom(format!("{:?}", self.architecture)),
86            num_parameters: params as u64,
87            hidden_size: self.hidden_size,
88            num_layers: self.num_hidden_layers,
89            num_heads: self.num_attention_heads,
90            num_kv_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
91            vocab_size: self.vocab_size,
92            max_sequence_length: self.max_position_embeddings,
93            dtype: DataType::FP16, // Default, can be overridden
94            device: Device::CPU,   // Default, will be set by backend
95            version: None,
96            license: None,
97            metadata: HashMap::new(),
98        }
99    }
100
101    /// Estimate parameter count
102    fn estimate_parameters(&self) -> usize {
103        // Rough estimation based on typical transformer architecture
104        let embedding_params = self.vocab_size * self.hidden_size;
105        let layer_params = self.num_hidden_layers
106            * (
107                // Attention: Q, K, V, O projections
108                4 * self.hidden_size * self.hidden_size +
109            // FFN: up, down, gate (if applicable)
110            3 * self.hidden_size * self.intermediate_size +
111            // Layer norms
112            2 * self.hidden_size
113            );
114        let lm_head_params = self.vocab_size * self.hidden_size;
115
116        embedding_params + layer_params + lm_head_params
117    }
118}
119
/// Configuration manager for loading and parsing model configs.
#[derive(Debug, Default)]
pub struct ConfigManager {
    // Intended as a memoization cache of parsed definitions keyed by a
    // model identifier (e.g. config directory path), so repeated loads of
    // the same model can skip disk I/O and re-parsing.
    _cache: HashMap<String, ModelDefinition>,
}
125
126impl ConfigManager {
127    pub fn new() -> Self {
128        Self {
129            _cache: HashMap::new(),
130        }
131    }
132
133    /// Load model definition from a resolved source
134    pub async fn load_from_source(
135        &mut self,
136        source: &ResolvedModelSource,
137    ) -> Result<ModelDefinition> {
138        self.load_from_path(&source.local_path).await
139    }
140
141    /// Load model definition from a directory path
142    pub async fn load_from_path(&mut self, path: &Path) -> Result<ModelDefinition> {
143        let config_path = path.join("config.json");
144
145        if !config_path.exists() {
146            return Err(FerrumError::model(format!(
147                "config.json not found in model directory: {:?}",
148                path
149            )));
150        }
151
152        debug!("Loading model config from: {:?}", config_path);
153
154        let content = tokio::fs::read_to_string(&config_path)
155            .await
156            .map_err(|e| FerrumError::io(format!("Failed to read config.json: {}", e)))?;
157
158        let raw_config: serde_json::Value = serde_json::from_str(&content)
159            .map_err(|e| FerrumError::model(format!("Failed to parse config.json: {}", e)))?;
160
161        self.parse_config(&raw_config)
162    }
163
164    /// Parse config from JSON value
165    fn parse_config(&mut self, raw: &serde_json::Value) -> Result<ModelDefinition> {
166        let obj = raw
167            .as_object()
168            .ok_or_else(|| FerrumError::model("config.json root is not an object"))?;
169
170        // Detect architecture
171        let architecture = self.detect_architecture(raw)?;
172
173        // Parse common fields
174        let hidden_size = obj
175            .get("hidden_size")
176            .and_then(|v| v.as_u64())
177            .unwrap_or(4096) as usize;
178
179        let intermediate_size = obj
180            .get("intermediate_size")
181            .and_then(|v| v.as_u64())
182            .or_else(|| obj.get("ffn_dim").and_then(|v| v.as_u64()))
183            .unwrap_or(11008) as usize;
184
185        let vocab_size = obj
186            .get("vocab_size")
187            .and_then(|v| v.as_u64())
188            .ok_or_else(|| FerrumError::model("vocab_size not found in config"))?
189            as usize;
190
191        let num_hidden_layers = obj
192            .get("num_hidden_layers")
193            .and_then(|v| v.as_u64())
194            .or_else(|| obj.get("n_layer").and_then(|v| v.as_u64()))
195            .unwrap_or(32) as usize;
196
197        let num_attention_heads = obj
198            .get("num_attention_heads")
199            .and_then(|v| v.as_u64())
200            .or_else(|| obj.get("n_head").and_then(|v| v.as_u64()))
201            .unwrap_or(32) as usize;
202
203        let num_key_value_heads = obj
204            .get("num_key_value_heads")
205            .and_then(|v| v.as_u64())
206            .map(|v| v as usize);
207
208        let max_position_embeddings = obj
209            .get("max_position_embeddings")
210            .and_then(|v| v.as_u64())
211            .or_else(|| obj.get("n_positions").and_then(|v| v.as_u64()))
212            .unwrap_or(2048) as usize;
213
214        let rope_theta = obj
215            .get("rope_theta")
216            .and_then(|v| v.as_f64())
217            .or_else(|| obj.get("rotary_emb_base").and_then(|v| v.as_f64()));
218
219        // Parse RoPE scaling
220        let rope_scaling = obj
221            .get("rope_scaling")
222            .and_then(|v| serde_json::from_value(v.clone()).ok());
223
224        // Detect norm type
225        let norm_type = if obj.get("rms_norm_eps").is_some() {
226            NormType::RMSNorm
227        } else {
228            NormType::LayerNorm
229        };
230
231        let norm_eps = obj
232            .get("rms_norm_eps")
233            .or_else(|| obj.get("layer_norm_eps"))
234            .or_else(|| obj.get("layer_norm_epsilon"))
235            .and_then(|v| v.as_f64())
236            .unwrap_or(1e-6);
237
238        // Parse attention config
239        let attention_bias = obj
240            .get("attention_bias")
241            .and_then(|v| v.as_bool())
242            .unwrap_or(false);
243
244        let sliding_window = obj
245            .get("sliding_window")
246            .and_then(|v| v.as_u64())
247            .map(|v| v as usize);
248
249        // Detect activation
250        let activation = obj
251            .get("hidden_act")
252            .and_then(|v| v.as_str())
253            .map(|s| match s {
254                "gelu" | "gelu_new" => Activation::GELU,
255                "silu" => Activation::SiLU,
256                "relu" => Activation::ReLU,
257                "swish" => Activation::Swish,
258                _ => {
259                    warn!("Unknown activation function: {}, defaulting to SiLU", s);
260                    Activation::SiLU
261                }
262            })
263            .unwrap_or(Activation::SiLU);
264
265        Ok(ModelDefinition {
266            architecture,
267            hidden_size,
268            intermediate_size,
269            vocab_size,
270            num_hidden_layers,
271            num_attention_heads,
272            num_key_value_heads,
273            max_position_embeddings,
274            rope_theta,
275            rope_scaling,
276            norm_type,
277            norm_eps,
278            attention_config: AttentionConfig {
279                attention_bias,
280                sliding_window,
281            },
282            activation,
283            extra_params: raw.clone(),
284        })
285    }
286
287    /// Detect architecture from config
288    fn detect_architecture(&self, config: &serde_json::Value) -> Result<Architecture> {
289        let obj = config
290            .as_object()
291            .ok_or_else(|| FerrumError::model("config.json root is not an object"))?;
292
293        // Try model_type field
294        if let Some(model_type) = obj.get("model_type").and_then(|v| v.as_str()) {
295            return Ok(Architecture::from_str(model_type));
296        }
297
298        // Try architectures array
299        if let Some(architectures) = obj.get("architectures").and_then(|v| v.as_array()) {
300            if let Some(arch) = architectures.first().and_then(|v| v.as_str()) {
301                return Ok(Architecture::from_str(arch));
302            }
303        }
304
305        warn!("Could not detect architecture, using default (Llama)");
306        Ok(Architecture::Llama)
307    }
308
309    /// Infer model type from definition
310    pub fn infer_model_type(&self, definition: &ModelDefinition) -> ModelType {
311        ModelType::Custom(format!("{:?}", definition.architecture))
312    }
313}