Skip to main content

ferrum_types/
models.rs

1//! Model-related types and configurations
2
3use crate::{devices::*, ids::ModelId, FerrumError, Result};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
7/// Model type enumeration
8#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
9pub enum ModelType {
10    /// LLaMA family models
11    Llama,
12    /// Mistral family models
13    Mistral,
14    /// Qwen family models
15    Qwen,
16    /// Phi family models
17    Phi,
18    /// Gemma family models
19    Gemma,
20    /// Code-specific models
21    Code(String),
22    /// Custom model implementation
23    Custom(String),
24}
25
26impl std::fmt::Display for ModelType {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        match self {
29            ModelType::Llama => write!(f, "llama"),
30            ModelType::Mistral => write!(f, "mistral"),
31            ModelType::Qwen => write!(f, "qwen"),
32            ModelType::Phi => write!(f, "phi"),
33            ModelType::Gemma => write!(f, "gemma"),
34            ModelType::Code(name) => write!(f, "code-{}", name),
35            ModelType::Custom(name) => write!(f, "custom-{}", name),
36        }
37    }
38}
39
/// Model information and metadata.
///
/// Static description of a loaded model's architecture and placement,
/// used by the sizing helpers (`estimated_size_bytes`,
/// `memory_requirements`) below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model identifier
    pub model_id: ModelId,
    /// Model type/architecture
    pub model_type: ModelType,
    /// Number of parameters
    pub num_parameters: u64,
    /// Hidden dimension size
    pub hidden_size: usize,
    /// Number of transformer layers
    pub num_layers: usize,
    /// Number of attention (query) heads
    pub num_heads: usize,
    /// Number of key-value heads (for GQA; equals `num_heads` for MHA)
    pub num_kv_heads: usize,
    /// Vocabulary size
    pub vocab_size: usize,
    /// Maximum sequence length the model supports
    pub max_sequence_length: usize,
    /// Data type used by the model's weights
    pub dtype: DataType,
    /// Device where model is loaded
    pub device: Device,
    /// Model version or revision
    pub version: Option<String>,
    /// Model license
    pub license: Option<String>,
    /// Additional model metadata (free-form JSON values)
    pub metadata: HashMap<String, serde_json::Value>,
}
72
73impl ModelInfo {
74    /// Calculate approximate model size in bytes
75    pub fn estimated_size_bytes(&self) -> u64 {
76        // Rough estimation: parameters * dtype size + some overhead
77        let param_size = self.num_parameters * self.dtype.size_bytes() as u64;
78        // Add ~20% overhead for embeddings, activations, etc.
79        (param_size as f64 * 1.2) as u64
80    }
81
82    /// Check if model supports a specific sequence length
83    pub fn supports_sequence_length(&self, length: usize) -> bool {
84        length <= self.max_sequence_length
85    }
86
87    /// Get memory requirements for inference
88    pub fn memory_requirements(
89        &self,
90        batch_size: usize,
91        sequence_length: usize,
92    ) -> ModelMemoryRequirements {
93        let param_memory = self.estimated_size_bytes();
94
95        // Estimate KV cache size: layers * heads * seq_len * head_dim * 2 (key + value) * dtype * batch_size
96        let head_dim = self.hidden_size / self.num_heads;
97        let kv_cache_per_token =
98            self.num_layers * self.num_kv_heads * head_dim * 2 * self.dtype.size_bytes();
99        let kv_cache_memory = (kv_cache_per_token * sequence_length * batch_size) as u64;
100
101        // Estimate activation memory (rough approximation)
102        let activation_memory =
103            (self.hidden_size * sequence_length * batch_size * self.dtype.size_bytes()) as u64 * 4;
104
105        ModelMemoryRequirements {
106            parameter_memory: param_memory,
107            kv_cache_memory,
108            activation_memory,
109            total_estimated: param_memory + kv_cache_memory + activation_memory,
110        }
111    }
112}
113
/// Memory requirements for model inference.
///
/// All values are in bytes; `total_estimated` is the sum of the three
/// components.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMemoryRequirements {
    /// Memory required for model parameters (weights)
    pub parameter_memory: u64,
    /// Memory required for the KV cache
    pub kv_cache_memory: u64,
    /// Memory required for activations
    pub activation_memory: u64,
    /// Total estimated memory requirement (sum of the above)
    pub total_estimated: u64,
}
126
/// Model configuration for runtime.
///
/// Created via `ModelConfig::new` and checked with `validate` before use.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Model identifier
    pub model_id: ModelId,
    /// Path to model files (must be non-empty; see `validate`)
    pub model_path: String,
    /// Model type/architecture
    pub model_type: ModelType,
    /// Data type to use for inference
    pub dtype: DataType,
    /// Target device
    pub device: Device,
    /// Maximum batch size (must be > 0)
    pub max_batch_size: usize,
    /// Maximum sequence length (must be > 0)
    pub max_sequence_length: usize,
    /// Tensor parallelism size (`None` = no tensor parallelism)
    pub tensor_parallel_size: Option<usize>,
    /// Pipeline parallelism size (`None` = no pipeline parallelism)
    pub pipeline_parallel_size: Option<usize>,
    /// Quantization configuration (`None` = unquantized)
    pub quantization: Option<QuantizationConfig>,
    /// Use flash attention if available
    pub use_flash_attention: bool,
    /// Use paged attention for the KV cache
    pub use_paged_attention: bool,
    /// Enable CUDA graphs for optimization
    pub enable_cuda_graphs: bool,
    /// Additional configuration parameters (free-form JSON values)
    pub extra_config: HashMap<String, serde_json::Value>,
}
159
160impl ModelConfig {
161    /// Create a new model configuration
162    pub fn new(model_id: impl Into<ModelId>, model_path: impl Into<String>) -> Self {
163        Self {
164            model_id: model_id.into(),
165            model_path: model_path.into(),
166            model_type: ModelType::Custom("unknown".to_string()),
167            dtype: DataType::FP16,
168            device: Device::CPU,
169            max_batch_size: 1,
170            max_sequence_length: 2048,
171            tensor_parallel_size: None,
172            pipeline_parallel_size: None,
173            quantization: None,
174            use_flash_attention: false,
175            use_paged_attention: false,
176            enable_cuda_graphs: false,
177            extra_config: HashMap::new(),
178        }
179    }
180
181    /// Validate the configuration
182    pub fn validate(&self) -> Result<()> {
183        if self.model_path.is_empty() {
184            return Err(FerrumError::config("Model path cannot be empty"));
185        }
186
187        if self.max_batch_size == 0 {
188            return Err(FerrumError::config("Max batch size must be positive"));
189        }
190
191        if self.max_sequence_length == 0 {
192            return Err(FerrumError::config("Max sequence length must be positive"));
193        }
194
195        if let Some(tp_size) = self.tensor_parallel_size {
196            if tp_size == 0 {
197                return Err(FerrumError::config("Tensor parallel size must be positive"));
198            }
199        }
200
201        if let Some(pp_size) = self.pipeline_parallel_size {
202            if pp_size == 0 {
203                return Err(FerrumError::config(
204                    "Pipeline parallel size must be positive",
205                ));
206            }
207        }
208
209        Ok(())
210    }
211}
212
/// Quantization configuration.
///
/// Each variant carries the scheme-specific parameters; see `bits` and
/// `is_high_accuracy` for derived properties.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuantizationConfig {
    /// GPTQ quantization
    GPTQ {
        /// Weight bit-width
        bits: u8,
        /// Quantization group size
        group_size: usize,
        /// Activation-order (desc_act) mode
        desc_act: bool,
    },
    /// AWQ quantization
    AWQ {
        /// Weight bit-width
        bits: u8,
        /// Whether a zero point is used
        zero_point: bool,
        /// AWQ format version string
        version: String,
    },
    /// FP8 quantization
    FP8 { e4m3: bool, kv_cache: bool },
    /// INT8 quantization
    INT8 { symmetric: bool, per_channel: bool },
    /// INT4 quantization
    INT4 { symmetric: bool, group_size: usize },
    /// SmoothQuant
    SmoothQuant { alpha: f32, calibration_size: usize },
}
237
238impl QuantizationConfig {
239    /// Get the number of bits used by this quantization method
240    pub fn bits(&self) -> u8 {
241        match self {
242            QuantizationConfig::GPTQ { bits, .. } => *bits,
243            QuantizationConfig::AWQ { bits, .. } => *bits,
244            QuantizationConfig::FP8 { .. } => 8,
245            QuantizationConfig::INT8 { .. } => 8,
246            QuantizationConfig::INT4 { .. } => 4,
247            QuantizationConfig::SmoothQuant { .. } => 8,
248        }
249    }
250
251    /// Check if this quantization preserves accuracy well
252    pub fn is_high_accuracy(&self) -> bool {
253        match self {
254            QuantizationConfig::FP8 { .. } => true,
255            QuantizationConfig::INT8 { .. } => true,
256            QuantizationConfig::SmoothQuant { .. } => true,
257            _ => false,
258        }
259    }
260}
261
/// Token usage statistics.
///
/// Invariant maintained by `new` and `add_completion_tokens`:
/// `total_tokens == prompt_tokens + completion_tokens`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Number of tokens in the prompt
    pub prompt_tokens: usize,
    /// Number of tokens generated
    pub completion_tokens: usize,
    /// Total tokens processed (prompt + completion)
    pub total_tokens: usize,
}
272
273impl TokenUsage {
274    /// Create new token usage
275    pub fn new(prompt_tokens: usize, completion_tokens: usize) -> Self {
276        Self {
277            prompt_tokens,
278            completion_tokens,
279            total_tokens: prompt_tokens + completion_tokens,
280        }
281    }
282
283    /// Add completion tokens
284    pub fn add_completion_tokens(&mut self, tokens: usize) {
285        self.completion_tokens += tokens;
286        self.total_tokens = self.prompt_tokens + self.completion_tokens;
287    }
288}
289
/// RoPE (Rotary Position Embedding) scaling configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RopeScaling {
    /// Type of scaling: "linear", "dynamic", etc.
    // NOTE(review): stringly-typed; an enum would be safer, but changing
    // the type would alter the serialized format — confirm before touching.
    pub scaling_type: String,
    /// Scaling factor
    pub factor: f32,
}
298
/// Normalization type used in the model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NormType {
    /// Layer Normalization
    LayerNorm,
    /// Root Mean Square Normalization
    RMSNorm,
}
307
/// Activation function type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Activation {
    /// Gaussian Error Linear Unit
    GELU,
    /// Sigmoid Linear Unit
    SiLU,
    /// Rectified Linear Unit
    ReLU,
    /// Swish activation
    Swish,
}
320
/// Attention configuration for model architecture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionConfig {
    /// Whether attention projections use a bias term
    pub attention_bias: bool,
    /// Sliding window size (`None` for full attention)
    pub sliding_window: Option<usize>,
}
329
330impl Default for AttentionConfig {
331    fn default() -> Self {
332        Self {
333            attention_bias: false,
334            sliding_window: None,
335        }
336    }
337}
338
/// Model loading source specification.
///
/// Describes where model artifacts are fetched from; consumers elsewhere
/// resolve each variant into local files.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelSource {
    /// Local file path
    Local(String),
    /// Hugging Face Hub model
    HuggingFace {
        /// Repository id, e.g. "org/model-name"
        repo_id: String,
        /// Git revision/branch/tag (`None` = repository default)
        revision: Option<String>,
        /// Download cache directory (`None` = implementation default)
        cache_dir: Option<String>,
    },
    /// URL download
    Url {
        url: String,
        /// Extra HTTP headers to send with the request
        headers: HashMap<String, String>,
    },
    /// S3-compatible storage
    S3 {
        bucket: String,
        key: String,
        /// AWS region (`None` = client default)
        region: Option<String>,
        /// Custom endpoint for S3-compatible services (`None` = AWS)
        endpoint: Option<String>,
    },
}