use crate::error::BonsaiResult;
use crate::gguf::metadata::MetadataStore;
use crate::gguf::tensor_info::keys;
/// Hyper-parameters of a Qwen3-architecture transformer.
///
/// Obtained either from GGUF metadata via [`Qwen3Config::from_metadata`] or from one of the
/// hard-coded presets (`bonsai_8b`, `bonsai_4b`, `bonsai_1_7b`, their `ternary_*` variants,
/// and `tiny_test`).
#[derive(Debug, Clone)]
pub struct Qwen3Config {
/// Embedding / hidden-state width (metadata suffix `embedding_length`).
pub hidden_size: usize,
/// Inner width of the feed-forward block (metadata suffix `feed_forward_length`).
pub intermediate_size: usize,
/// Number of transformer blocks (metadata suffix `block_count`).
pub num_layers: usize,
/// Number of attention heads (metadata suffix `attention.head_count`).
pub num_attention_heads: usize,
/// Number of key/value heads (metadata suffix `attention.head_count_kv`).
pub num_kv_heads: usize,
/// Dimension of a single attention head.
pub head_dim: usize,
/// Vocabulary size (metadata suffix `vocab_size`).
pub vocab_size: usize,
/// Maximum context length (metadata suffix `context_length`).
pub max_context_length: usize,
/// Epsilon used by RMS normalisation (metadata suffix `attention.layer_norm_rms_epsilon`).
pub rms_norm_eps: f32,
/// RoPE base frequency (metadata suffix `rope.freq_base`).
pub rope_freq_base: f32,
/// Architecture name (defaults to `"qwen3"`); `from_metadata` uses it as the prefix of the
/// hyper-parameter metadata keys.
pub architecture: String,
/// Human-readable model name, e.g. `"Bonsai-8B"`.
pub model_name: String,
}
impl Qwen3Config {
    /// Builds a configuration from GGUF metadata.
    ///
    /// The architecture name (looked up via `keys::GENERAL_ARCHITECTURE`, default `"qwen3"`)
    /// is used as the prefix of the hyper-parameter keys. Each hyper-parameter is resolved,
    /// in order, from:
    ///
    /// 1. the architecture-prefixed key, e.g. `"qwen3.embedding_length"`;
    /// 2. the corresponding generic key from [`keys`];
    /// 3. a hard-coded default. The defaults match [`Self::bonsai_8b`].
    ///
    /// Any lookup error is treated as "absent" and moves on to the next source.
    ///
    /// `head_dim` is read from `"{arch}.attention.key_length"` when that key is present and
    /// non-zero. GGUF stores it separately because the per-head size need not equal
    /// `hidden_size / num_attention_heads`. Some Qwen3 checkpoints set the head size
    /// independently of the hidden size. Only when the key is absent is the ratio used.
    ///
    /// # Errors
    ///
    /// Never returns `Err` at present; every missing value falls back to a default.
    pub fn from_metadata(metadata: &MetadataStore) -> BonsaiResult<Self> {
        let architecture = metadata
            .get_string(keys::GENERAL_ARCHITECTURE)
            .unwrap_or("qwen3")
            .to_string();
        let model_name = metadata
            .get_string(keys::GENERAL_NAME)
            .unwrap_or("Bonsai-8B")
            .to_string();

        // Arch-prefixed key first, then the generic fallback key, then the default.
        let read_u32 = |suffix: &str, fallback_key: &str, default: u32| -> usize {
            metadata
                .get_u32(&format!("{architecture}.{suffix}"))
                .or_else(|_| metadata.get_u32(fallback_key))
                .unwrap_or(default) as usize
        };
        let read_f32 = |suffix: &str, fallback_key: &str, default: f32| -> f32 {
            metadata
                .get_f32(&format!("{architecture}.{suffix}"))
                .or_else(|_| metadata.get_f32(fallback_key))
                .unwrap_or(default)
        };

        let hidden_size = read_u32("embedding_length", keys::LLM_EMBEDDING_LENGTH, 4096);
        let num_layers = read_u32("block_count", keys::LLM_BLOCK_COUNT, 36);
        let num_attention_heads =
            read_u32("attention.head_count", keys::LLM_ATTENTION_HEAD_COUNT, 32);
        let num_kv_heads =
            read_u32("attention.head_count_kv", keys::LLM_ATTENTION_HEAD_COUNT_KV, 8);
        let intermediate_size =
            read_u32("feed_forward_length", keys::LLM_FEED_FORWARD_LENGTH, 14336);
        let vocab_size = read_u32("vocab_size", keys::LLM_VOCAB_SIZE, 151936);
        let max_context_length = read_u32("context_length", keys::LLM_CONTEXT_LENGTH, 65536);
        let rms_norm_eps = read_f32(
            "attention.layer_norm_rms_epsilon",
            keys::LLM_ATTENTION_LAYER_NORM_RMS_EPSILON,
            1e-6,
        );
        let rope_freq_base = read_f32("rope.freq_base", keys::LLM_ROPE_FREQ_BASE, 1_000_000.0);

        let key_length = metadata
            .get_u32(&format!("{architecture}.attention.key_length"))
            .ok();
        let head_dim = Self::resolve_head_dim(key_length, hidden_size, num_attention_heads);

        Ok(Qwen3Config {
            hidden_size,
            intermediate_size,
            num_layers,
            num_attention_heads,
            num_kv_heads,
            head_dim,
            vocab_size,
            max_context_length,
            rms_norm_eps,
            rope_freq_base,
            architecture,
            model_name,
        })
    }

    /// Chooses the per-head dimension.
    ///
    /// An explicit, non-zero `key_length` wins. Otherwise the result is
    /// `hidden_size / num_attention_heads`.
    fn resolve_head_dim(
        key_length: Option<u32>,
        hidden_size: usize,
        num_attention_heads: usize,
    ) -> usize {
        key_length
            .filter(|&len| len > 0)
            .map(|len| len as usize)
            // `checked_div` avoids the division-by-zero panic the plain `/` would raise on a
            // corrupt `head_count` of 0.
            // NOTE(review): that case yields 0 here; consider returning a BonsaiError instead
            // once a suitable variant is confirmed.
            .unwrap_or_else(|| hidden_size.checked_div(num_attention_heads).unwrap_or(0))
    }

    /// Small configuration for unit tests: 2 layers, hidden size 64, 4 attention heads /
    /// 2 KV heads of dimension 16, and a 512-token context.
    ///
    /// It keeps the full 151936-entry vocabulary of the real presets.
    pub fn tiny_test() -> Self {
        Qwen3Config {
            hidden_size: 64,
            intermediate_size: 128,
            num_layers: 2,
            num_attention_heads: 4,
            num_kv_heads: 2,
            head_dim: 16,
            vocab_size: 151936,
            max_context_length: 512,
            rms_norm_eps: 1e-6,
            rope_freq_base: 10_000.0,
            architecture: "qwen3".to_string(),
            model_name: "Bonsai-Tiny-Test".to_string(),
        }
    }

    /// Preset for Bonsai-4B.
    ///
    /// 24 layers, hidden size 2560, FFN width 6912, and 20 attention heads / 4 KV heads of
    /// dimension 128.
    pub fn bonsai_4b() -> Self {
        Qwen3Config {
            hidden_size: 2560,
            intermediate_size: 6912,
            num_layers: 24,
            num_attention_heads: 20,
            num_kv_heads: 4,
            head_dim: 128,
            vocab_size: 151936,
            max_context_length: 65536,
            rms_norm_eps: 1e-6,
            rope_freq_base: 1_000_000.0,
            architecture: "qwen3".to_string(),
            model_name: "Bonsai-4B".to_string(),
        }
    }

    /// Preset for Bonsai-1.7B.
    ///
    /// 16 layers, hidden size 1536, FFN width 4096, and 12 attention heads / 2 KV heads of
    /// dimension 128.
    pub fn bonsai_1_7b() -> Self {
        Qwen3Config {
            hidden_size: 1536,
            intermediate_size: 4096,
            num_layers: 16,
            num_attention_heads: 12,
            num_kv_heads: 2,
            head_dim: 128,
            vocab_size: 151936,
            max_context_length: 65536,
            rms_norm_eps: 1e-6,
            rope_freq_base: 1_000_000.0,
            architecture: "qwen3".to_string(),
            model_name: "Bonsai-1.7B".to_string(),
        }
    }

    /// Preset for Bonsai-8B.
    ///
    /// 36 layers, hidden size 4096, FFN width 14336, and 32 attention heads / 8 KV heads of
    /// dimension 128. These are also the fallback values used by `from_metadata`.
    pub fn bonsai_8b() -> Self {
        Qwen3Config {
            hidden_size: 4096,
            intermediate_size: 14336,
            num_layers: 36,
            num_attention_heads: 32,
            num_kv_heads: 8,
            head_dim: 128,
            vocab_size: 151936,
            max_context_length: 65536,
            rms_norm_eps: 1e-6,
            rope_freq_base: 1_000_000.0,
            architecture: "qwen3".to_string(),
            model_name: "Bonsai-8B".to_string(),
        }
    }

    /// Same hyper-parameters as [`Self::bonsai_8b`]; only `model_name` differs.
    pub fn ternary_bonsai_8b() -> Self {
        Self {
            model_name: "Ternary-Bonsai-8B".to_string(),
            ..Self::bonsai_8b()
        }
    }

    /// Same hyper-parameters as [`Self::bonsai_4b`]; only `model_name` differs.
    pub fn ternary_bonsai_4b() -> Self {
        Self {
            model_name: "Ternary-Bonsai-4B".to_string(),
            ..Self::bonsai_4b()
        }
    }

    /// Same hyper-parameters as [`Self::bonsai_1_7b`]; only `model_name` differs.
    pub fn ternary_bonsai_1_7b() -> Self {
        Self {
            model_name: "Ternary-Bonsai-1.7B".to_string(),
            ..Self::bonsai_1_7b()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Expected values are compared as tuples so each preset's shape reads as one row.

    #[test]
    fn default_bonsai_8b_config() {
        let c = Qwen3Config::bonsai_8b();
        assert_eq!(
            (c.hidden_size, c.intermediate_size, c.num_layers, c.num_attention_heads),
            (4096, 14336, 36, 32)
        );
        assert_eq!(
            (c.num_kv_heads, c.head_dim, c.vocab_size, c.max_context_length),
            (8, 128, 151936, 65536)
        );
    }

    #[test]
    fn bonsai_4b_config() {
        let c = Qwen3Config::bonsai_4b();
        assert_eq!(
            (c.hidden_size, c.intermediate_size, c.num_layers, c.num_attention_heads),
            (2560, 6912, 24, 20)
        );
        assert_eq!((c.num_kv_heads, c.head_dim, c.vocab_size), (4, 128, 151936));
    }

    #[test]
    fn bonsai_1_7b_config() {
        let c = Qwen3Config::bonsai_1_7b();
        assert_eq!(
            (c.hidden_size, c.intermediate_size, c.num_layers, c.num_attention_heads),
            (1536, 4096, 16, 12)
        );
        assert_eq!((c.num_kv_heads, c.head_dim, c.vocab_size), (2, 128, 151936));
    }

    #[test]
    fn from_empty_metadata_uses_defaults() {
        let store = MetadataStore::new();
        let c = Qwen3Config::from_metadata(&store)
            .expect("config from empty metadata should use defaults");
        assert_eq!((c.hidden_size, c.num_layers), (4096, 36));
    }

    #[test]
    fn ternary_bonsai_8b_matches_spec() {
        let c = Qwen3Config::ternary_bonsai_8b();
        assert_eq!(
            (c.hidden_size, c.intermediate_size, c.num_layers, c.num_attention_heads),
            (4096, 14336, 36, 32)
        );
        assert_eq!(
            (c.num_kv_heads, c.head_dim, c.vocab_size, c.max_context_length),
            (8, 128, 151936, 65536)
        );
        assert_eq!(
            (c.model_name.as_str(), c.architecture.as_str()),
            ("Ternary-Bonsai-8B", "qwen3")
        );
    }

    #[test]
    fn ternary_bonsai_name_distinct() {
        let pairs = [
            (Qwen3Config::bonsai_8b(), Qwen3Config::ternary_bonsai_8b()),
            (Qwen3Config::bonsai_4b(), Qwen3Config::ternary_bonsai_4b()),
            (Qwen3Config::bonsai_1_7b(), Qwen3Config::ternary_bonsai_1_7b()),
        ];
        for (base, ternary) in &pairs {
            assert_ne!(base.model_name, ternary.model_name);
        }
    }

    #[test]
    fn ternary_bonsai_4b_matches_spec() {
        let c = Qwen3Config::ternary_bonsai_4b();
        assert_eq!(
            (c.hidden_size, c.num_layers, c.model_name.as_str()),
            (2560, 24, "Ternary-Bonsai-4B")
        );
    }

    #[test]
    fn ternary_bonsai_1_7b_matches_spec() {
        let c = Qwen3Config::ternary_bonsai_1_7b();
        assert_eq!(
            (c.hidden_size, c.num_layers, c.model_name.as_str()),
            (1536, 16, "Ternary-Bonsai-1.7B")
        );
    }
}