use std::fs;
use std::path::{Path, PathBuf};
use tempfile::TempDir;
use crate::error::{RealizarError, Result};
/// Dimensions and hyper-parameters describing a model fixture.
///
/// Ready-made presets are available through the associated constructors
/// (`tiny`, `small`, `gqa`, `phi`, `qwen`); individual fields can then be
/// adjusted with the `with_*` builder methods or by direct assignment.
///
/// `PartialEq` is derived so configurations can be compared directly (for
/// example with `assert_eq!`). `Eq` is intentionally not derived because the
/// struct contains `f32` fields.
#[derive(Debug, Clone, PartialEq)]
pub struct ModelConfig {
    /// Architecture identifier; the presets use `"llama"`, `"phi2"` and `"qwen2"`.
    pub architecture: String,
    /// Hidden dimension size.
    pub hidden_dim: usize,
    /// Intermediate dimension size (the presets set it to twice `hidden_dim`).
    pub intermediate_dim: usize,
    /// Number of heads.
    pub num_heads: usize,
    /// Number of KV heads; the `gqa` and `qwen` presets set this lower than `num_heads`.
    pub num_kv_heads: usize,
    /// Number of layers.
    pub num_layers: usize,
    /// Vocabulary size.
    pub vocab_size: usize,
    /// Context length.
    pub context_length: usize,
    /// RoPE theta value.
    // NOTE(review): presumably the rotary-embedding base frequency — confirm with the consumers.
    pub rope_theta: f32,
    /// Epsilon value.
    // NOTE(review): presumably the normalization epsilon — confirm with the consumers.
    pub eps: f32,
}
impl Default for ModelConfig {
fn default() -> Self {
Self::tiny()
}
}
impl ModelConfig {
    /// Smallest preset: a single-layer `"llama"` configuration with a hidden
    /// dimension of 64.
    ///
    /// The other presets are written as overrides on top of this one, so every
    /// field is spelled out here.
    #[must_use]
    pub fn tiny() -> Self {
        Self {
            architecture: String::from("llama"),
            hidden_dim: 64,
            intermediate_dim: 128,
            num_heads: 4,
            num_kv_heads: 4,
            num_layers: 1,
            vocab_size: 100,
            context_length: 512,
            rope_theta: 10000.0,
            eps: 1e-5,
        }
    }

    /// Two-layer `"llama"` preset with a hidden dimension of 256, 8 heads,
    /// a vocabulary of 1000 and a context length of 1024.
    #[must_use]
    pub fn small() -> Self {
        Self {
            hidden_dim: 256,
            intermediate_dim: 512,
            num_heads: 8,
            num_kv_heads: 8,
            num_layers: 2,
            vocab_size: 1000,
            context_length: 1024,
            // architecture, rope_theta and eps match `tiny`.
            ..Self::tiny()
        }
    }

    /// `"llama"` preset with 8 heads but only 2 KV heads and a hidden
    /// dimension of 128.
    #[must_use]
    pub fn gqa() -> Self {
        Self {
            hidden_dim: 128,
            intermediate_dim: 256,
            num_heads: 8,
            num_kv_heads: 2,
            // Everything else matches `tiny`.
            ..Self::tiny()
        }
    }

    /// Preset identical to `tiny` except that the architecture is `"phi2"`.
    #[must_use]
    pub fn phi() -> Self {
        Self {
            architecture: String::from("phi2"),
            ..Self::tiny()
        }
    }

    /// `"qwen2"` preset: hidden dimension 128, 8 heads, 4 KV heads,
    /// `rope_theta` of 1000000.0 and `eps` of 1e-6.
    #[must_use]
    pub fn qwen() -> Self {
        Self {
            architecture: String::from("qwen2"),
            hidden_dim: 128,
            intermediate_dim: 256,
            num_heads: 8,
            num_kv_heads: 4,
            rope_theta: 1000000.0,
            eps: 1e-6,
            // num_layers, vocab_size and context_length match `tiny`.
            ..Self::tiny()
        }
    }

    /// Returns the configuration with `architecture` replaced by `arch`.
    #[must_use]
    pub fn with_architecture(self, arch: &str) -> Self {
        Self {
            architecture: arch.to_owned(),
            ..self
        }
    }

    /// Returns the configuration with `hidden_dim` replaced by `dim`.
    ///
    /// Only `hidden_dim` changes; `intermediate_dim` is left as it was.
    #[must_use]
    pub fn with_hidden_dim(self, dim: usize) -> Self {
        Self {
            hidden_dim: dim,
            ..self
        }
    }

    /// Returns the configuration with `num_layers` replaced by `n`.
    #[must_use]
    pub fn with_layers(self, n: usize) -> Self {
        Self {
            num_layers: n,
            ..self
        }
    }

    /// Returns the configuration with `vocab_size` replaced by `n`.
    #[must_use]
    pub fn with_vocab_size(self, n: usize) -> Self {
        Self {
            vocab_size: n,
            ..self
        }
    }

    /// Returns the configuration with both head counts replaced.
    ///
    /// The two values are stored as given; no relationship between them is
    /// checked here.
    #[must_use]
    pub fn with_gqa(self, num_heads: usize, num_kv_heads: usize) -> Self {
        Self {
            num_heads,
            num_kv_heads,
            ..self
        }
    }
}
/// Model file formats known to this module.
///
/// `extension` gives the file extension for each variant and the `Display`
/// implementation gives its human-readable name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelFormat {
    /// GGUF format (extension `gguf`).
    Gguf,
    /// APR format (extension `apr`).
    Apr,
    /// SafeTensors format (extension `safetensors`).
    SafeTensors,
}
impl ModelFormat {
#[must_use]
pub fn extension(&self) -> &'static str {
match self {
Self::Gguf => "gguf",
Self::Apr => "apr",
Self::SafeTensors => "safetensors",
}
}
}
/// Human-readable format name: `GGUF`, `APR` or `SafeTensors`.
impl std::fmt::Display for ModelFormat {
    /// Writes the format name, honouring the formatter's width, fill,
    /// alignment and precision flags.
    ///
    /// The name is emitted through `Formatter::pad` rather than `write!`, so
    /// that `format!("{:<12}", fmt)` pads as requested instead of silently
    /// ignoring the width. Plain `{}` output is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Gguf => "GGUF",
            Self::Apr => "APR",
            Self::SafeTensors => "SafeTensors",
        };
        f.pad(label)
    }
}
/// A model fixture: a path bundled with a temporary directory handle, the
/// model's format and the configuration it was described with.
///
/// All fields are private; nothing in this part of the file constructs or
/// reads them.
// NOTE(review): constructors and accessors presumably live in the files pulled
// in by the `include!` lines below — confirm there.
pub struct ModelFixture {
    /// Filesystem path associated with the fixture.
    // NOTE(review): presumably the model file inside `_temp_dir` — confirm against the constructors.
    path: PathBuf,
    /// Temporary directory handle.
    // NOTE(review): the leading underscore suggests it is never read and is held only so the
    // directory outlives the fixture (`TempDir` removes its directory on drop) — confirm.
    _temp_dir: TempDir,
    /// Format of the fixture.
    format: ModelFormat,
    /// Configuration associated with the fixture.
    config: ModelConfig,
}
// The rest of this module is spliced in textually from sibling files, so the
// items defined there share this file's scope (including its private items).
// NOTE(review): going by the file names, these presumably hold the APR model
// generator and the GGUF fixture code — confirm by opening them.
include!("mod_generate_apr_model.rs");
include!("mod_fixture_gguf.rs");