use std::fmt;
pub mod fixtures;
pub mod generators;
#[cfg(test)]
pub mod combinatorial_tests;
#[cfg(test)]
pub mod popperian_tests;
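/// On-disk model serialization formats covered by the test matrix.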
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelFormat {
    PyTorch,
    GGUF,
    APR,
    Safetensors,
}
impl fmt::Display for ModelFormat {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ModelFormat::PyTorch => write!(f, "PyTorch"),
            ModelFormat::GGUF => write!(f, "GGUF"),
            ModelFormat::APR => write!(f, "APR"),
            ModelFormat::Safetensors => write!(f, "Safetensors"),
        }
    }
}
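/// Compute device a test case targets; `Cuda(n)` selects GPU ordinal `n`.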
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Device {
    Cpu,
    Cuda(u32),
}
impl fmt::Display for Device {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Device::Cpu => write!(f, "CPU"),
            Device::Cuda(id) => write!(f, "CUDA:{}", id),
        }
    }
}
impl Device {
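    /// True for any CUDA device.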
    pub fn is_cuda(&self) -> bool {
        matches!(self, Device::Cuda(_))
    }
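    /// The CUDA device ordinal, or `None` on CPU.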
    pub fn cuda_id(&self) -> Option<u32> {
        match self {
            Device::Cuda(id) => Some(*id),
            Device::Cpu => None,
        }
    }
}
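/// Weight quantization schemes; the `*_K` variants are GGUF k-quants.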
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[allow(non_camel_case_types)]
pub enum QuantType {
    F32,
    F16,
    BF16,
    Q8_0,
    Q4_0,
    Q4_K,
    Q5_K,
    Q6_K,
}
impl QuantType {
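    /// Approximate bits per weight, including per-block scale overhead:
    /// e.g. `Q8_0` stores 32 weights plus an `f16` scale in 34 bytes,
    /// i.e. 8.5 bpw. The k-quant figures here are rounded.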
    pub fn bits_per_weight(&self) -> f32 {
        match self {
            QuantType::F32 => 32.0,
            QuantType::F16 => 16.0,
            QuantType::BF16 => 16.0,
            QuantType::Q8_0 => 8.5,
            QuantType::Q4_0 => 4.5,
            QuantType::Q4_K => 4.5,
            QuantType::Q5_K => 5.5,
            QuantType::Q6_K => 6.5,
        }
    }
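    /// Whether `format` can represent this quantization natively:
    /// PyTorch and Safetensors are restricted to float dtypes here,
    /// while GGUF and APR accept every variant.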
    pub fn supported_by(&self, format: ModelFormat) -> bool {
        match format {
            ModelFormat::PyTorch => {
                matches!(self, QuantType::F32 | QuantType::F16 | QuantType::BF16)
            }
            ModelFormat::GGUF => true,
            ModelFormat::APR => true,
            ModelFormat::Safetensors => {
                matches!(self, QuantType::F32 | QuantType::F16 | QuantType::BF16)
            }
        }
    }
}
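/// Llama-style decoder hyperparameters shared by all fixtures.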
#[derive(Debug, Clone)]
pub struct ModelConfig {
    pub hidden_dim: usize,
    pub num_layers: usize,
    pub num_heads: usize,
    pub num_kv_heads: usize,
    pub vocab_size: usize,
    pub intermediate_dim: usize,
    pub rope_theta: f32,
    pub max_seq_len: usize,
    pub rms_norm_eps: f32,
}
impl Default for ModelConfig {
    fn default() -> Self {
        Self::tiny()
    }
}
impl ModelConfig {
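    /// Smallest configuration; the `Default` impl delegates here.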
    pub fn tiny() -> Self {
        Self {
            hidden_dim: 64,
            num_layers: 2,
            num_heads: 4,
            num_kv_heads: 2,
            vocab_size: 256,
            intermediate_dim: 128,
            rope_theta: 10000.0,
            max_seq_len: 32,
            rms_norm_eps: 1e-5,
        }
    }
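    /// Small configuration, one step up from `tiny`.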
    pub fn small() -> Self {
        Self {
            hidden_dim: 256,
            num_layers: 4,
            num_heads: 8,
            num_kv_heads: 2,
            vocab_size: 1024,
            intermediate_dim: 512,
            rope_theta: 10000.0,
            max_seq_len: 128,
            rms_norm_eps: 1e-5,
        }
    }
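    /// Hyperparameters matching TinyLlama-1.1B (GQA with 4 KV heads).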
    pub fn tinyllama() -> Self {
        Self {
            hidden_dim: 2048,
            num_layers: 22,
            num_heads: 32,
            num_kv_heads: 4,
            vocab_size: 32000,
            intermediate_dim: 5632,
            rope_theta: 10000.0,
            max_seq_len: 2048,
            rms_norm_eps: 1e-5,
        }
    }
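    /// Hyperparameters matching Qwen2-1.5B (1e6 RoPE theta, 32k context).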
    pub fn qwen_1_5b() -> Self {
        Self {
            hidden_dim: 1536,
            num_layers: 28,
            num_heads: 12,
            num_kv_heads: 2,
            vocab_size: 151936,
            intermediate_dim: 8960,
            rope_theta: 1000000.0,
            max_seq_len: 32768,
            rms_norm_eps: 1e-6,
        }
    }
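    /// Dimension of a single attention head; assumes `num_heads` divides
    /// `hidden_dim` evenly.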
    pub fn head_dim(&self) -> usize {
        self.hidden_dim / self.num_heads
    }
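    /// Width of the query projection (`num_heads * head_dim`).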
    pub fn q_dim(&self) -> usize {
        self.num_heads * self.head_dim()
    }
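    /// Width of the key projection; smaller than `q_dim` under GQA.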
    pub fn k_dim(&self) -> usize {
        self.num_kv_heads * self.head_dim()
    }
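    /// Width of the value projection; identical to `k_dim` in this layout.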
    pub fn v_dim(&self) -> usize {
        self.k_dim()
    }
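    /// Number of query heads sharing each KV head.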
    pub fn gqa_group_size(&self) -> usize {
        self.num_heads / self.num_kv_heads
    }
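    /// Grouped-query attention: fewer KV heads than query heads.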
    pub fn is_gqa(&self) -> bool {
        self.num_kv_heads < self.num_heads
    }
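    /// Multi-query attention: a single KV head shared by all query heads.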
    pub fn is_mqa(&self) -> bool {
        self.num_kv_heads == 1
    }
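    /// Total parameter count for a Llama-style decoder, assuming an
    /// untied LM head (embedding and output projection counted
    /// separately); `tinyllama()` comes out to ~1.10B.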
    pub fn param_count(&self) -> usize {
        let embed = self.vocab_size * self.hidden_dim;
        let per_layer = {
            let q = self.hidden_dim * self.q_dim();
            let k = self.hidden_dim * self.k_dim();
            let v = self.hidden_dim * self.v_dim();
            let o = self.q_dim() * self.hidden_dim;
            let gate = self.hidden_dim * self.intermediate_dim;
            let up = self.hidden_dim * self.intermediate_dim;
            let down = self.intermediate_dim * self.hidden_dim;
            let norms = 2 * self.hidden_dim;
            q + k + v + o + gate + up + down + norms
        };
        let output_norm = self.hidden_dim;
        let lm_head = self.hidden_dim * self.vocab_size;
        embed + (self.num_layers * per_layer) + output_norm + lm_head
    }
}
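/// Everything needed to build a model under test: hyperparameters, an
/// optional quantization, and a seed for deterministic synthetic weights.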
#[derive(Debug, Clone)]
pub struct ConstructorInput {
    pub config: ModelConfig,
    pub quantization: Option<QuantType>,
    pub weights_seed: u64,
}
impl ConstructorInput {
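    /// Full-precision construction with the default seed (42).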
    pub fn new(config: ModelConfig) -> Self {
        Self {
            config,
            quantization: None,
            weights_seed: 42,
        }
    }
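    /// Quantized construction with an explicit weight seed.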
    pub fn with_quant(config: ModelConfig, quant: QuantType, seed: u64) -> Self {
        Self {
            config,
            quantization: Some(quant),
            weights_seed: seed,
        }
    }
}
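/// A single forward-pass request: token ids and the position at which
/// they are evaluated.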
#[derive(Debug, Clone)]
pub struct ForwardInput {
    pub tokens: Vec<u32>,
    pub position: usize,
}
impl ForwardInput {
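    /// Tokens evaluated from position 0.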
    pub fn new(tokens: Vec<u32>) -> Self {
        Self {
            tokens,
            position: 0,
        }
    }
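    /// Tokens evaluated from an explicit starting position.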
    pub fn at_position(tokens: Vec<u32>, position: usize) -> Self {
        Self { tokens, position }
    }
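    /// Number of tokens in this request.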
    pub fn seq_len(&self) -> usize {
        self.tokens.len()
    }
}
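/// One cell of the cross-format test matrix: how to build the model,
/// what to feed it, which formats to convert between, and which device
/// to run on.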
#[derive(Clone)]
pub struct ModelTestCase {
    pub desc: String,
    pub constructor: ConstructorInput,
    pub forward: ForwardInput,
    pub expected_output_norm: Option<f32>,
    pub source_format: ModelFormat,
    pub target_format: Option<ModelFormat>,
    pub device: Device,
}
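// `Debug` is implemented manually so failing tests print a compact
// "{layers}L/{heads}H" summary instead of the full `ModelConfig`.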
impl fmt::Debug for ModelTestCase {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ModelTestCase")
            .field("desc", &self.desc)
            .field("source_format", &self.source_format)
            .field("target_format", &self.target_format)
            .field("device", &self.device)
            .field(
                "config",
                &format!(
                    "{}L/{}H",
                    self.constructor.config.num_layers, self.constructor.config.num_heads
                ),
            )
            .finish_non_exhaustive()
    }
}
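// Illustrative sanity checks for the dimension helpers above; a minimal
// sketch that uses only types defined in this module (the module name
// `shape_sanity` is ours, not part of the existing suites).
#[cfg(test)]
mod shape_sanity {
    use super::*;

    #[test]
    fn tinyllama_gqa_dims() {
        let cfg = ModelConfig::tinyllama();
        // 2048 hidden / 32 heads = 64-dim heads; 4 KV heads => 8-way GQA.
        assert_eq!(cfg.head_dim(), 64);
        assert_eq!(cfg.q_dim(), 2048);
        assert_eq!(cfg.k_dim(), 256);
        assert_eq!(cfg.gqa_group_size(), 8);
        assert!(cfg.is_gqa());
        assert!(!cfg.is_mqa());
    }

    #[test]
    fn tinyllama_param_count_is_about_1_1b() {
        // With an untied LM head the analytic count lands at ~1.100B.
        let params = ModelConfig::tinyllama().param_count() as f64;
        assert!((params / 1e9 - 1.1).abs() < 0.01);
    }
}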
include!("mod_model.rs");