// ferrum_models/definition.rs
1//! Model definition and configuration parsing
2
3use crate::{registry::Architecture, source::ResolvedModelSource};
4use ferrum_types::{
5    Activation, AttentionConfig, FerrumError, ModelInfo, ModelType, NormType, Result, RopeScaling,
6};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10use tracing::{debug, warn};
11
/// Model definition parsed from a Hugging-Face-style `config.json`.
///
/// Populated by [`ConfigManager::parse_config`]; field defaults (see
/// `Default`) correspond to a Llama-7B-shaped model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDefinition {
    /// Architecture type (detected from `model_type` / `architectures`)
    pub architecture: Architecture,
    /// Hidden size (embedding dimension)
    pub hidden_size: usize,
    /// Intermediate size (FFN dimension)
    pub intermediate_size: usize,
    /// Vocabulary size
    pub vocab_size: usize,
    /// Number of hidden layers
    pub num_hidden_layers: usize,
    /// Number of attention heads
    pub num_attention_heads: usize,
    /// Number of key-value heads (for GQA); `None` means full multi-head
    /// attention (one KV head per attention head)
    pub num_key_value_heads: Option<usize>,
    /// Maximum position embeddings (context length)
    pub max_position_embeddings: usize,
    /// RoPE theta (frequency base)
    pub rope_theta: Option<f64>,
    /// RoPE scaling config
    pub rope_scaling: Option<RopeScaling>,
    /// Normalization type
    pub norm_type: NormType,
    /// Normalization epsilon
    pub norm_eps: f64,
    /// Attention configuration
    pub attention_config: AttentionConfig,
    /// Activation function
    pub activation: Activation,
    /// Extra parameters: the raw config JSON is stored here by the parser.
    // NOTE(review): because parse_config stores the *entire* raw config here,
    // serializing this struct will flatten those keys next to the typed fields
    // and may emit known keys (e.g. "hidden_size") twice — confirm whether
    // round-trip serialization is an intended use case.
    #[serde(flatten)]
    pub extra_params: serde_json::Value,
}
47
48impl Default for ModelDefinition {
49    fn default() -> Self {
50        Self {
51            architecture: Architecture::Llama,
52            hidden_size: 4096,
53            intermediate_size: 11008,
54            vocab_size: 32000,
55            num_hidden_layers: 32,
56            num_attention_heads: 32,
57            num_key_value_heads: None,
58            max_position_embeddings: 2048,
59            rope_theta: Some(10000.0),
60            rope_scaling: None,
61            norm_type: NormType::RMSNorm,
62            norm_eps: 1e-6,
63            attention_config: AttentionConfig {
64                attention_bias: false,
65                sliding_window: None,
66            },
67            activation: Activation::SiLU,
68            extra_params: serde_json::Value::Object(serde_json::Map::new()),
69        }
70    }
71}
72
73impl ModelDefinition {
74    /// Convert to ModelInfo
75    pub fn to_model_info(&self, model_id: impl Into<String>) -> ModelInfo {
76        use ferrum_types::{DataType, Device};
77
78        let model_id_str = model_id.into();
79
80        // Calculate approximate parameter count
81        let params = self.estimate_parameters();
82
83        ModelInfo {
84            model_id: ferrum_types::ModelId::new(model_id_str.clone()),
85            model_type: ModelType::Custom(format!("{:?}", self.architecture)),
86            num_parameters: params as u64,
87            hidden_size: self.hidden_size,
88            num_layers: self.num_hidden_layers,
89            num_heads: self.num_attention_heads,
90            num_kv_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
91            vocab_size: self.vocab_size,
92            max_sequence_length: self.max_position_embeddings,
93            dtype: DataType::FP16, // Default, can be overridden
94            device: Device::CPU,   // Default, will be set by backend
95            version: None,
96            license: None,
97            metadata: HashMap::new(),
98        }
99    }
100
101    /// Estimate parameter count
102    fn estimate_parameters(&self) -> usize {
103        // Rough estimation based on typical transformer architecture
104        let embedding_params = self.vocab_size * self.hidden_size;
105        let layer_params = self.num_hidden_layers
106            * (
107                // Attention: Q, K, V, O projections
108                4 * self.hidden_size * self.hidden_size +
109            // FFN: up, down, gate (if applicable)
110            3 * self.hidden_size * self.intermediate_size +
111            // Layer norms
112            2 * self.hidden_size
113            );
114        let lm_head_params = self.vocab_size * self.hidden_size;
115
116        embedding_params + layer_params + lm_head_params
117    }
118}
119
/// Configuration manager for loading and parsing model configs
#[derive(Debug, Default)]
pub struct ConfigManager {
    // Reserved for caching parsed definitions by model id; not read or
    // written yet (underscore prefix silences the dead-field lint).
    _cache: HashMap<String, ModelDefinition>,
}
125
126impl ConfigManager {
127    pub fn new() -> Self {
128        Self {
129            _cache: HashMap::new(),
130        }
131    }
132
133    /// Load model definition from a resolved source
134    pub async fn load_from_source(
135        &mut self,
136        source: &ResolvedModelSource,
137    ) -> Result<ModelDefinition> {
138        self.load_from_path(&source.local_path).await
139    }
140
141    /// Load model definition from a directory path
142    pub async fn load_from_path(&mut self, path: &Path) -> Result<ModelDefinition> {
143        let config_path = path.join("config.json");
144
145        if !config_path.exists() {
146            return Err(FerrumError::model(format!(
147                "config.json not found in model directory: {:?}",
148                path
149            )));
150        }
151
152        debug!("Loading model config from: {:?}", config_path);
153
154        let content = tokio::fs::read_to_string(&config_path)
155            .await
156            .map_err(|e| FerrumError::io(format!("Failed to read config.json: {}", e)))?;
157
158        let raw_config: serde_json::Value = serde_json::from_str(&content)
159            .map_err(|e| FerrumError::model(format!("Failed to parse config.json: {}", e)))?;
160
161        self.parse_config(&raw_config)
162    }
163
164    /// Parse config from JSON value
165    fn parse_config(&mut self, raw: &serde_json::Value) -> Result<ModelDefinition> {
166        let obj = raw
167            .as_object()
168            .ok_or_else(|| FerrumError::model("config.json root is not an object"))?;
169
170        // Detect architecture
171        let architecture = self.detect_architecture(raw)?;
172
173        // Parse common fields (CLIP stores these in text_config/vision_config)
174        let text_cfg = obj.get("text_config");
175        let hidden_size = obj
176            .get("hidden_size")
177            .and_then(|v| v.as_u64())
178            .or_else(|| {
179                text_cfg
180                    .and_then(|tc| tc.get("hidden_size"))
181                    .and_then(|v| v.as_u64())
182            })
183            .unwrap_or(4096) as usize;
184
185        let intermediate_size = obj
186            .get("intermediate_size")
187            .and_then(|v| v.as_u64())
188            .or_else(|| obj.get("ffn_dim").and_then(|v| v.as_u64()))
189            .unwrap_or(11008) as usize;
190
191        // CLIP models store vocab_size in text_config, not at top level
192        let vocab_size = obj
193            .get("vocab_size")
194            .and_then(|v| v.as_u64())
195            .or_else(|| {
196                text_cfg
197                    .and_then(|tc| tc.get("vocab_size"))
198                    .and_then(|v| v.as_u64())
199            })
200            .unwrap_or(0) as usize;
201
202        let num_hidden_layers = obj
203            .get("num_hidden_layers")
204            .and_then(|v| v.as_u64())
205            .or_else(|| obj.get("n_layer").and_then(|v| v.as_u64()))
206            .unwrap_or(32) as usize;
207
208        let num_attention_heads = obj
209            .get("num_attention_heads")
210            .and_then(|v| v.as_u64())
211            .or_else(|| obj.get("n_head").and_then(|v| v.as_u64()))
212            .unwrap_or(32) as usize;
213
214        let num_key_value_heads = obj
215            .get("num_key_value_heads")
216            .and_then(|v| v.as_u64())
217            .map(|v| v as usize);
218
219        let max_position_embeddings = obj
220            .get("max_position_embeddings")
221            .and_then(|v| v.as_u64())
222            .or_else(|| obj.get("n_positions").and_then(|v| v.as_u64()))
223            .unwrap_or(2048) as usize;
224
225        let rope_theta = obj
226            .get("rope_theta")
227            .and_then(|v| v.as_f64())
228            .or_else(|| obj.get("rotary_emb_base").and_then(|v| v.as_f64()));
229
230        // Parse RoPE scaling
231        let rope_scaling = obj
232            .get("rope_scaling")
233            .and_then(|v| serde_json::from_value(v.clone()).ok());
234
235        // Detect norm type
236        let norm_type = if obj.get("rms_norm_eps").is_some() {
237            NormType::RMSNorm
238        } else {
239            NormType::LayerNorm
240        };
241
242        let norm_eps = obj
243            .get("rms_norm_eps")
244            .or_else(|| obj.get("layer_norm_eps"))
245            .or_else(|| obj.get("layer_norm_epsilon"))
246            .and_then(|v| v.as_f64())
247            .unwrap_or(1e-6);
248
249        // Parse attention config
250        let attention_bias = obj
251            .get("attention_bias")
252            .and_then(|v| v.as_bool())
253            .unwrap_or(false);
254
255        let sliding_window = obj
256            .get("sliding_window")
257            .and_then(|v| v.as_u64())
258            .map(|v| v as usize);
259
260        // Detect activation
261        let activation = obj
262            .get("hidden_act")
263            .and_then(|v| v.as_str())
264            .map(|s| match s {
265                "gelu" | "gelu_new" => Activation::GELU,
266                "silu" => Activation::SiLU,
267                "relu" => Activation::ReLU,
268                "swish" => Activation::Swish,
269                _ => {
270                    warn!("Unknown activation function: {}, defaulting to SiLU", s);
271                    Activation::SiLU
272                }
273            })
274            .unwrap_or(Activation::SiLU);
275
276        Ok(ModelDefinition {
277            architecture,
278            hidden_size,
279            intermediate_size,
280            vocab_size,
281            num_hidden_layers,
282            num_attention_heads,
283            num_key_value_heads,
284            max_position_embeddings,
285            rope_theta,
286            rope_scaling,
287            norm_type,
288            norm_eps,
289            attention_config: AttentionConfig {
290                attention_bias,
291                sliding_window,
292            },
293            activation,
294            extra_params: raw.clone(),
295        })
296    }
297
298    /// Detect architecture from config
299    fn detect_architecture(&self, config: &serde_json::Value) -> Result<Architecture> {
300        let obj = config
301            .as_object()
302            .ok_or_else(|| FerrumError::model("config.json root is not an object"))?;
303
304        // Try model_type field
305        if let Some(model_type) = obj.get("model_type").and_then(|v| v.as_str()) {
306            return Ok(Architecture::from_str(model_type));
307        }
308
309        // Try architectures array
310        if let Some(architectures) = obj.get("architectures").and_then(|v| v.as_array()) {
311            if let Some(arch) = architectures.first().and_then(|v| v.as_str()) {
312                return Ok(Architecture::from_str(arch));
313            }
314        }
315
316        warn!("Could not detect architecture, using default (Llama)");
317        Ok(Architecture::Llama)
318    }
319
320    /// Infer model type from definition
321    pub fn infer_model_type(&self, definition: &ModelDefinition) -> ModelType {
322        ModelType::Custom(format!("{:?}", definition.architecture))
323    }
324}