Skip to main content

ferrum_types/
models.rs

1//! Model-related types and configurations
2
3use crate::{devices::*, ids::ModelId, FerrumError, Result};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
7/// Model type enumeration
8#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
9pub enum ModelType {
10    /// LLaMA family models
11    Llama,
12    /// Mistral family models
13    Mistral,
14    /// Qwen family models
15    Qwen,
16    /// Phi family models
17    Phi,
18    /// Gemma family models
19    Gemma,
20    /// Code-specific models
21    Code(String),
22    /// Embedding models (BERT, etc.)
23    Embedding,
24    /// CLIP vision-language models
25    Clip,
26    /// Custom model implementation
27    Custom(String),
28}
29
30impl std::fmt::Display for ModelType {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        match self {
33            ModelType::Llama => write!(f, "llama"),
34            ModelType::Mistral => write!(f, "mistral"),
35            ModelType::Qwen => write!(f, "qwen"),
36            ModelType::Phi => write!(f, "phi"),
37            ModelType::Gemma => write!(f, "gemma"),
38            ModelType::Embedding => write!(f, "embedding"),
39            ModelType::Clip => write!(f, "clip"),
40            ModelType::Code(name) => write!(f, "code-{}", name),
41            ModelType::Custom(name) => write!(f, "custom-{}", name),
42        }
43    }
44}
45
46/// Model information and metadata
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ModelInfo {
49    /// Model identifier
50    pub model_id: ModelId,
51    /// Model type/architecture
52    pub model_type: ModelType,
53    /// Number of parameters
54    pub num_parameters: u64,
55    /// Hidden dimension size
56    pub hidden_size: usize,
57    /// Number of transformer layers
58    pub num_layers: usize,
59    /// Number of attention heads
60    pub num_heads: usize,
61    /// Number of key-value heads (for GQA)
62    pub num_kv_heads: usize,
63    /// Vocabulary size
64    pub vocab_size: usize,
65    /// Maximum sequence length
66    pub max_sequence_length: usize,
67    /// Data type used by the model
68    pub dtype: DataType,
69    /// Device where model is loaded
70    pub device: Device,
71    /// Model version or revision
72    pub version: Option<String>,
73    /// Model license
74    pub license: Option<String>,
75    /// Additional model metadata
76    pub metadata: HashMap<String, serde_json::Value>,
77}
78
79impl ModelInfo {
80    /// Calculate approximate model size in bytes
81    pub fn estimated_size_bytes(&self) -> u64 {
82        // Rough estimation: parameters * dtype size + some overhead
83        let param_size = self.num_parameters * self.dtype.size_bytes() as u64;
84        // Add ~20% overhead for embeddings, activations, etc.
85        (param_size as f64 * 1.2) as u64
86    }
87
88    /// Check if model supports a specific sequence length
89    pub fn supports_sequence_length(&self, length: usize) -> bool {
90        length <= self.max_sequence_length
91    }
92
93    /// Get memory requirements for inference
94    pub fn memory_requirements(
95        &self,
96        batch_size: usize,
97        sequence_length: usize,
98    ) -> ModelMemoryRequirements {
99        let param_memory = self.estimated_size_bytes();
100
101        // Estimate KV cache size: layers * heads * seq_len * head_dim * 2 (key + value) * dtype * batch_size
102        let head_dim = self.hidden_size / self.num_heads;
103        let kv_cache_per_token =
104            self.num_layers * self.num_kv_heads * head_dim * 2 * self.dtype.size_bytes();
105        let kv_cache_memory = (kv_cache_per_token * sequence_length * batch_size) as u64;
106
107        // Estimate activation memory (rough approximation)
108        let activation_memory =
109            (self.hidden_size * sequence_length * batch_size * self.dtype.size_bytes()) as u64 * 4;
110
111        ModelMemoryRequirements {
112            parameter_memory: param_memory,
113            kv_cache_memory,
114            activation_memory,
115            total_estimated: param_memory + kv_cache_memory + activation_memory,
116        }
117    }
118}
119
120/// Memory requirements for model inference
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct ModelMemoryRequirements {
123    /// Memory required for model parameters
124    pub parameter_memory: u64,
125    /// Memory required for KV cache
126    pub kv_cache_memory: u64,
127    /// Memory required for activations
128    pub activation_memory: u64,
129    /// Total estimated memory requirement
130    pub total_estimated: u64,
131}
132
133/// Model configuration for runtime
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct ModelConfig {
136    /// Model identifier
137    pub model_id: ModelId,
138    /// Path to model files
139    pub model_path: String,
140    /// Model type/architecture
141    pub model_type: ModelType,
142    /// Data type to use for inference
143    pub dtype: DataType,
144    /// Target device
145    pub device: Device,
146    /// Maximum batch size
147    pub max_batch_size: usize,
148    /// Maximum sequence length
149    pub max_sequence_length: usize,
150    /// Tensor parallelism size
151    pub tensor_parallel_size: Option<usize>,
152    /// Pipeline parallelism size  
153    pub pipeline_parallel_size: Option<usize>,
154    /// Quantization configuration
155    pub quantization: Option<QuantizationConfig>,
156    /// Use flash attention if available
157    pub use_flash_attention: bool,
158    /// Use paged attention for KV cache
159    pub use_paged_attention: bool,
160    /// Enable CUDA graphs for optimization
161    pub enable_cuda_graphs: bool,
162    /// Additional configuration parameters
163    pub extra_config: HashMap<String, serde_json::Value>,
164}
165
166impl ModelConfig {
167    /// Create a new model configuration
168    pub fn new(model_id: impl Into<ModelId>, model_path: impl Into<String>) -> Self {
169        Self {
170            model_id: model_id.into(),
171            model_path: model_path.into(),
172            model_type: ModelType::Custom("unknown".to_string()),
173            dtype: DataType::FP16,
174            device: Device::CPU,
175            max_batch_size: 1,
176            max_sequence_length: 2048,
177            tensor_parallel_size: None,
178            pipeline_parallel_size: None,
179            quantization: None,
180            use_flash_attention: false,
181            use_paged_attention: false,
182            enable_cuda_graphs: false,
183            extra_config: HashMap::new(),
184        }
185    }
186
187    /// Validate the configuration
188    pub fn validate(&self) -> Result<()> {
189        if self.model_path.is_empty() {
190            return Err(FerrumError::config("Model path cannot be empty"));
191        }
192
193        if self.max_batch_size == 0 {
194            return Err(FerrumError::config("Max batch size must be positive"));
195        }
196
197        if self.max_sequence_length == 0 {
198            return Err(FerrumError::config("Max sequence length must be positive"));
199        }
200
201        if let Some(tp_size) = self.tensor_parallel_size {
202            if tp_size == 0 {
203                return Err(FerrumError::config("Tensor parallel size must be positive"));
204            }
205        }
206
207        if let Some(pp_size) = self.pipeline_parallel_size {
208            if pp_size == 0 {
209                return Err(FerrumError::config(
210                    "Pipeline parallel size must be positive",
211                ));
212            }
213        }
214
215        Ok(())
216    }
217}
218
219/// Quantization configuration
220#[derive(Debug, Clone, Serialize, Deserialize)]
221pub enum QuantizationConfig {
222    /// GPTQ quantization
223    GPTQ {
224        bits: u8,
225        group_size: usize,
226        desc_act: bool,
227    },
228    /// AWQ quantization
229    AWQ {
230        bits: u8,
231        zero_point: bool,
232        version: String,
233    },
234    /// FP8 quantization
235    FP8 { e4m3: bool, kv_cache: bool },
236    /// INT8 quantization
237    INT8 { symmetric: bool, per_channel: bool },
238    /// INT4 quantization
239    INT4 { symmetric: bool, group_size: usize },
240    /// SmoothQuant
241    SmoothQuant { alpha: f32, calibration_size: usize },
242}
243
244impl QuantizationConfig {
245    /// Get the number of bits used by this quantization method
246    pub fn bits(&self) -> u8 {
247        match self {
248            QuantizationConfig::GPTQ { bits, .. } => *bits,
249            QuantizationConfig::AWQ { bits, .. } => *bits,
250            QuantizationConfig::FP8 { .. } => 8,
251            QuantizationConfig::INT8 { .. } => 8,
252            QuantizationConfig::INT4 { .. } => 4,
253            QuantizationConfig::SmoothQuant { .. } => 8,
254        }
255    }
256
257    /// Check if this quantization preserves accuracy well
258    pub fn is_high_accuracy(&self) -> bool {
259        match self {
260            QuantizationConfig::FP8 { .. } => true,
261            QuantizationConfig::INT8 { .. } => true,
262            QuantizationConfig::SmoothQuant { .. } => true,
263            _ => false,
264        }
265    }
266}
267
268/// Token usage statistics
269#[derive(Debug, Clone, Serialize, Deserialize)]
270pub struct TokenUsage {
271    /// Number of tokens in the prompt
272    pub prompt_tokens: usize,
273    /// Number of tokens generated
274    pub completion_tokens: usize,
275    /// Total tokens processed
276    pub total_tokens: usize,
277}
278
279impl TokenUsage {
280    /// Create new token usage
281    pub fn new(prompt_tokens: usize, completion_tokens: usize) -> Self {
282        Self {
283            prompt_tokens,
284            completion_tokens,
285            total_tokens: prompt_tokens + completion_tokens,
286        }
287    }
288
289    /// Add completion tokens
290    pub fn add_completion_tokens(&mut self, tokens: usize) {
291        self.completion_tokens += tokens;
292        self.total_tokens = self.prompt_tokens + self.completion_tokens;
293    }
294}
295
296/// RoPE (Rotary Position Embedding) scaling configuration
297#[derive(Debug, Clone, Serialize, Deserialize)]
298pub struct RopeScaling {
299    /// Type of scaling: "linear", "dynamic", etc.
300    pub scaling_type: String,
301    /// Scaling factor
302    pub factor: f32,
303}
304
305/// Normalization type used in the model
306#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
307pub enum NormType {
308    /// Layer Normalization
309    LayerNorm,
310    /// Root Mean Square Normalization
311    RMSNorm,
312}
313
314/// Activation function type
315#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
316pub enum Activation {
317    /// Gaussian Error Linear Unit
318    GELU,
319    /// Sigmoid Linear Unit  
320    SiLU,
321    /// Rectified Linear Unit
322    ReLU,
323    /// Swish activation
324    Swish,
325}
326
327/// Attention configuration for model architecture
328#[derive(Debug, Clone, Serialize, Deserialize)]
329pub struct AttentionConfig {
330    /// Whether attention uses bias
331    pub attention_bias: bool,
332    /// Sliding window size (None for full attention)
333    pub sliding_window: Option<usize>,
334}
335
336impl Default for AttentionConfig {
337    fn default() -> Self {
338        Self {
339            attention_bias: false,
340            sliding_window: None,
341        }
342    }
343}
344
345/// Model loading source specification
346#[derive(Debug, Clone, Serialize, Deserialize)]
347pub enum ModelSource {
348    /// Local file path
349    Local(String),
350    /// Hugging Face Hub model
351    HuggingFace {
352        repo_id: String,
353        revision: Option<String>,
354        cache_dir: Option<String>,
355    },
356    /// URL download
357    Url {
358        url: String,
359        headers: HashMap<String, String>,
360    },
361    /// S3-compatible storage
362    S3 {
363        bucket: String,
364        key: String,
365        region: Option<String>,
366        endpoint: Option<String>,
367    },
368}