use crate::{devices::*, ids::ModelId, FerrumError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Architecture family of a model.
///
/// Parameterized variants carry a free-form name; `Display` renders them as
/// `code-<name>` / `custom-<name>` and the fixed variants as lowercase names.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelType {
Llama,
Mistral,
Qwen,
Phi,
Gemma,
/// Code-generation model identified by a free-form name.
Code(String),
Embedding,
Clip,
/// Fallback for architectures not covered by the named variants.
Custom(String),
}
impl std::fmt::Display for ModelType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ModelType::Llama => write!(f, "llama"),
ModelType::Mistral => write!(f, "mistral"),
ModelType::Qwen => write!(f, "qwen"),
ModelType::Phi => write!(f, "phi"),
ModelType::Gemma => write!(f, "gemma"),
ModelType::Embedding => write!(f, "embedding"),
ModelType::Clip => write!(f, "clip"),
ModelType::Code(name) => write!(f, "code-{}", name),
ModelType::Custom(name) => write!(f, "custom-{}", name),
}
}
}
/// Static description of a loaded model: architecture dimensions, precision,
/// placement, and free-form metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
pub model_id: ModelId,
pub model_type: ModelType,
/// Total parameter count; drives `estimated_size_bytes`.
pub num_parameters: u64,
/// Hidden-state width; assumed divisible by `num_heads` — see
/// `memory_requirements`, which divides the two.
pub hidden_size: usize,
pub num_layers: usize,
/// Number of attention (query) heads.
pub num_heads: usize,
/// Number of key/value heads (presumably <= `num_heads` under GQA/MQA —
/// TODO confirm with the loaders that populate this).
pub num_kv_heads: usize,
pub vocab_size: usize,
/// Longest supported sequence; checked by `supports_sequence_length`.
pub max_sequence_length: usize,
/// Element precision of the weights (its `size_bytes()` feeds the
/// memory estimates).
pub dtype: DataType,
pub device: Device,
pub version: Option<String>,
pub license: Option<String>,
/// Arbitrary extra fields preserved as raw JSON values.
pub metadata: HashMap<String, serde_json::Value>,
}
impl ModelInfo {
    /// Estimated in-memory footprint of the weights, in bytes.
    ///
    /// Multiplies the parameter count by the element size of `dtype` and adds
    /// a 20% margin for non-parameter overhead. The multiply saturates so an
    /// absurd parameter count cannot wrap around.
    pub fn estimated_size_bytes(&self) -> u64 {
        let param_size = self
            .num_parameters
            .saturating_mul(self.dtype.size_bytes() as u64);
        (param_size as f64 * 1.2) as u64
    }

    /// Returns `true` when a sequence of `length` tokens fits within the
    /// model's maximum context window.
    pub fn supports_sequence_length(&self, length: usize) -> bool {
        length <= self.max_sequence_length
    }

    /// Estimates the memory needed to serve `batch_size` sequences of
    /// `sequence_length` tokens each.
    ///
    /// All arithmetic is performed in `u64` with saturating operations so the
    /// intermediate products cannot overflow on 32-bit targets (the original
    /// multiplied `usize` values before casting), and a zero `num_heads` no
    /// longer panics on division by zero — it simply yields a zero head
    /// dimension and hence zero KV-cache estimate.
    pub fn memory_requirements(
        &self,
        batch_size: usize,
        sequence_length: usize,
    ) -> ModelMemoryRequirements {
        let elem = self.dtype.size_bytes() as u64;
        let param_memory = self.estimated_size_bytes();

        // Guard against a malformed config with zero attention heads.
        let head_dim = (self.hidden_size as u64)
            .checked_div(self.num_heads as u64)
            .unwrap_or(0);

        // Per token: one K and one V vector per layer per KV head.
        let kv_cache_per_token = (self.num_layers as u64)
            .saturating_mul(self.num_kv_heads as u64)
            .saturating_mul(head_dim)
            .saturating_mul(2)
            .saturating_mul(elem);
        let tokens = (sequence_length as u64).saturating_mul(batch_size as u64);
        let kv_cache_memory = kv_cache_per_token.saturating_mul(tokens);

        // Rough estimate: four hidden-state-sized buffers per token.
        let activation_memory = (self.hidden_size as u64)
            .saturating_mul(tokens)
            .saturating_mul(elem)
            .saturating_mul(4);

        ModelMemoryRequirements {
            parameter_memory: param_memory,
            kv_cache_memory,
            activation_memory,
            total_estimated: param_memory
                .saturating_add(kv_cache_memory)
                .saturating_add(activation_memory),
        }
    }
}
/// Memory estimate in bytes, broken down by component; produced by
/// `ModelInfo::memory_requirements`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMemoryRequirements {
/// Weights plus the overhead margin from `estimated_size_bytes`.
pub parameter_memory: u64,
/// Key/value cache for the requested batch and sequence length.
pub kv_cache_memory: u64,
/// Rough activation-buffer estimate.
pub activation_memory: u64,
/// Sum of the three components above.
pub total_estimated: u64,
}
/// User-supplied configuration for loading and serving a model.
///
/// Defaults come from `ModelConfig::new`; consistency rules are enforced by
/// `ModelConfig::validate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
pub model_id: ModelId,
/// Model location; must be non-empty (see `validate`).
pub model_path: String,
pub model_type: ModelType,
pub dtype: DataType,
pub device: Device,
/// Must be >= 1 (see `validate`).
pub max_batch_size: usize,
/// Must be >= 1 (see `validate`).
pub max_sequence_length: usize,
/// Tensor-parallel degree; `Some(0)` is rejected by `validate`,
/// `None` presumably means "no tensor parallelism" — TODO confirm.
pub tensor_parallel_size: Option<usize>,
/// Pipeline-parallel degree; `Some(0)` is rejected by `validate`.
pub pipeline_parallel_size: Option<usize>,
/// Optional weight-quantization scheme; `None` means unquantized.
pub quantization: Option<QuantizationConfig>,
pub use_flash_attention: bool,
pub use_paged_attention: bool,
pub enable_cuda_graphs: bool,
/// Backend-specific options preserved as raw JSON values.
pub extra_config: HashMap<String, serde_json::Value>,
}
impl ModelConfig {
    /// Creates a configuration with conservative defaults: CPU placement,
    /// FP16 weights, batch size 1, a 2048-token context, and every optional
    /// feature (parallelism, quantization, flash/paged attention, CUDA
    /// graphs) switched off. The model type starts as `Custom("unknown")`
    /// until the caller identifies it.
    pub fn new(model_id: impl Into<ModelId>, model_path: impl Into<String>) -> Self {
        Self {
            model_id: model_id.into(),
            model_path: model_path.into(),
            model_type: ModelType::Custom("unknown".to_string()),
            dtype: DataType::FP16,
            device: Device::CPU,
            max_batch_size: 1,
            max_sequence_length: 2048,
            tensor_parallel_size: None,
            pipeline_parallel_size: None,
            quantization: None,
            use_flash_attention: false,
            use_paged_attention: false,
            enable_cuda_graphs: false,
            extra_config: HashMap::new(),
        }
    }

    /// Checks that the configuration is internally consistent: a non-empty
    /// model path, positive batch size and sequence length, and — when a
    /// parallelism degree is given — a positive value for it.
    pub fn validate(&self) -> Result<()> {
        // Local guard helper: turn a failed predicate into a config error.
        fn ensure(ok: bool, msg: &str) -> Result<()> {
            if ok {
                Ok(())
            } else {
                Err(FerrumError::config(msg))
            }
        }
        ensure(!self.model_path.is_empty(), "Model path cannot be empty")?;
        ensure(self.max_batch_size > 0, "Max batch size must be positive")?;
        ensure(
            self.max_sequence_length > 0,
            "Max sequence length must be positive",
        )?;
        ensure(
            self.tensor_parallel_size != Some(0),
            "Tensor parallel size must be positive",
        )?;
        ensure(
            self.pipeline_parallel_size != Some(0),
            "Pipeline parallel size must be positive",
        )?;
        Ok(())
    }
}
/// Supported weight-quantization schemes and their tuning parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuantizationConfig {
/// GPTQ post-training quantization.
GPTQ {
bits: u8,
group_size: usize,
desc_act: bool,
},
/// Activation-aware weight quantization.
AWQ {
bits: u8,
zero_point: bool,
version: String,
},
/// 8-bit floating point; `kv_cache` presumably enables FP8 for the
/// KV cache as well — TODO confirm with the quantization backend.
FP8 { e4m3: bool, kv_cache: bool },
INT8 { symmetric: bool, per_channel: bool },
INT4 { symmetric: bool, group_size: usize },
/// SmoothQuant (reported as 8-bit by `bits()`).
SmoothQuant { alpha: f32, calibration_size: usize },
}
impl QuantizationConfig {
pub fn bits(&self) -> u8 {
match self {
QuantizationConfig::GPTQ { bits, .. } => *bits,
QuantizationConfig::AWQ { bits, .. } => *bits,
QuantizationConfig::FP8 { .. } => 8,
QuantizationConfig::INT8 { .. } => 8,
QuantizationConfig::INT4 { .. } => 4,
QuantizationConfig::SmoothQuant { .. } => 8,
}
}
pub fn is_high_accuracy(&self) -> bool {
match self {
QuantizationConfig::FP8 { .. } => true,
QuantizationConfig::INT8 { .. } => true,
QuantizationConfig::SmoothQuant { .. } => true,
_ => false,
}
}
}
/// Token accounting for a single request.
///
/// Invariant: `total_tokens == prompt_tokens + completion_tokens`,
/// maintained by `new` and `add_completion_tokens`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
pub prompt_tokens: usize,
pub completion_tokens: usize,
pub total_tokens: usize,
}
impl TokenUsage {
    /// Builds a usage record; `total_tokens` is derived as the sum of the
    /// prompt and completion counts.
    pub fn new(prompt_tokens: usize, completion_tokens: usize) -> Self {
        let total_tokens = prompt_tokens + completion_tokens;
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens,
        }
    }

    /// Records `tokens` additional completion tokens and recomputes the
    /// total so the `total = prompt + completion` invariant still holds.
    pub fn add_completion_tokens(&mut self, tokens: usize) {
        self.completion_tokens += tokens;
        self.total_tokens = self.prompt_tokens + self.completion_tokens;
    }
}
/// Rotary position embedding (RoPE) scaling settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RopeScaling {
/// Scaling strategy name (stringly typed; accepted values are not visible
/// here — TODO confirm against the consumer of this struct).
pub scaling_type: String,
/// Scale factor; presumably applied to position indices — verify.
pub factor: f32,
}
/// Normalization layer variant used by a model architecture.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NormType {
LayerNorm,
RMSNorm,
}
/// Activation function used in a model's feed-forward layers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Activation {
GELU,
SiLU,
ReLU,
Swish,
}
/// Attention-specific model options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionConfig {
/// Whether the attention projections include a bias term.
pub attention_bias: bool,
/// Sliding-window size in tokens; `None` disables windowed attention.
pub sliding_window: Option<usize>,
}
impl Default for AttentionConfig {
fn default() -> Self {
Self {
attention_bias: false,
sliding_window: None,
}
}
}
/// Where model weights are fetched from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelSource {
/// Local filesystem path.
Local(String),
/// Hugging Face Hub repository.
HuggingFace {
repo_id: String,
/// Branch, tag, or commit; `None` presumably selects the repo's
/// default revision — TODO confirm with the downloader.
revision: Option<String>,
/// Override for the local download cache directory.
cache_dir: Option<String>,
},
/// Direct HTTP(S) download with optional request headers.
Url {
url: String,
headers: HashMap<String, String>,
},
/// S3 (or S3-compatible) object storage.
S3 {
bucket: String,
key: String,
region: Option<String>,
/// Custom endpoint, presumably for S3-compatible stores such as
/// MinIO — TODO confirm.
endpoint: Option<String>,
},
}