/// Per-layer weights for one transformer block.
///
/// Dense parameters are stored as flat `f32` buffers. The optional
/// fields select architecture variants: a gated FFN, a linear-attention
/// layer, or a mixture-of-experts FFN.
#[derive(Debug, Clone)]
pub struct BlockWeights {
    /// Pre-attention normalization scale.
    pub attn_norm_weight: Vec<f32>,
    /// Pre-attention normalization bias (may be empty for bias-free models).
    pub attn_norm_bias: Vec<f32>,
    /// Fused query/key/value projection.
    pub qkv_weight: Vec<f32>,
    pub qkv_bias: Vec<f32>,
    /// Attention output projection.
    pub out_weight: Vec<f32>,
    pub out_bias: Vec<f32>,
    /// Pre-FFN normalization scale.
    pub ffn_norm_weight: Vec<f32>,
    pub ffn_norm_bias: Vec<f32>,
    /// First FFN projection (up).
    pub ffn_fc1_weight: Vec<f32>,
    pub ffn_fc1_bias: Vec<f32>,
    /// Second FFN projection (down).
    pub ffn_fc2_weight: Vec<f32>,
    pub ffn_fc2_bias: Vec<f32>,
    /// Gate projection for gated FFN variants (e.g. SwiGLU-style); `None`
    /// for plain two-matrix FFNs.
    pub ffn_gate_weight: Option<Vec<f32>>,
    /// Present only when this layer uses linear attention.
    pub linear_attn: Option<LinearAttnWeights>,
    /// Present only when this layer uses a mixture-of-experts FFN.
    pub moe_experts: Option<MoeExpertWeights>,
}
/// Weights for a linear-attention layer variant.
///
/// Field names suggest a state-space / gated-delta style mechanism
/// (depthwise conv, `a_log` decay, `dt_bias`) — the exact semantics are
/// defined by the kernels that consume these buffers, not visible here.
#[derive(Debug, Clone)]
pub struct LinearAttnWeights {
// z (gate) projection — presumably sized by linear_value_dim; confirm against loader.
pub z_weight: Vec<f32>,
// b projection.
pub b_weight: Vec<f32>,
// a projection.
pub a_weight: Vec<f32>,
// Depthwise 1-D convolution kernel over the fused q/k/v stream.
pub conv1d_weight: Vec<f32>,
// Log-space decay parameters.
pub a_log: Vec<f32>,
// Time-step bias.
pub dt_bias: Vec<f32>,
// Output normalization scale.
pub norm_weight: Vec<f32>,
}
/// Weights for a mixture-of-experts FFN.
#[derive(Debug, Clone)]
pub struct MoeExpertWeights {
// Router that scores experts per token.
pub gate_weight: Vec<f32>,
// Concatenated per-expert gate+up projections — assumed packed by expert
// index; confirm layout against the loader.
pub expert_gate_up: Vec<f32>,
// Concatenated per-expert down projections.
pub expert_down: Vec<f32>,
// Always-active shared expert: gate, up, and down projections.
pub shared_gate: Vec<f32>,
pub shared_up: Vec<f32>,
pub shared_down: Vec<f32>,
// Gating weight applied to the shared expert's output.
pub shared_expert_gate_weight: Vec<f32>,
// Total number of routed experts.
pub num_experts: usize,
// How many experts each token is routed to (top-k).
pub num_experts_per_tok: usize,
// Intermediate (hidden) width of each expert's FFN.
pub expert_intermediate: usize,
}
/// LM-head projection matrix as a flat buffer of length
/// `vocab_size * hidden_dim` (validated in `ValidatedGpuWeights::new`).
///
/// Newtype keeps this layout distinct from [`LmHeadWeightTransposed`]
/// at the type level so the two cannot be swapped accidentally.
#[derive(Debug, Clone)]
pub struct LmHeadWeight(pub Vec<f32>);

/// Transposed LM-head projection matrix, same total length as
/// [`LmHeadWeight`] (validated in `ValidatedGpuWeights::new`).
#[derive(Debug, Clone)]
pub struct LmHeadWeightTransposed(pub Vec<f32>);

impl LmHeadWeight {
    /// Consumes the wrapper and returns the underlying buffer.
    #[inline]
    #[must_use]
    pub fn into_inner(self) -> Vec<f32> {
        self.0
    }

    /// Borrows the underlying buffer as a slice.
    #[inline]
    #[must_use]
    pub fn as_slice(&self) -> &[f32] {
        &self.0
    }
}

impl LmHeadWeightTransposed {
    /// Consumes the wrapper and returns the underlying buffer.
    #[inline]
    #[must_use]
    pub fn into_inner(self) -> Vec<f32> {
        self.0
    }

    /// Borrows the underlying buffer as a slice.
    #[inline]
    #[must_use]
    pub fn as_slice(&self) -> &[f32] {
        &self.0
    }
}
/// Identifies which projection matrix a weight buffer belongs to.
#[derive(Debug, Clone, Copy)]
pub enum WeightType {
// Fused query/key/value projection.
Qkv,
// Attention output projection.
Output,
// FFN first (up) projection.
FfnFc1,
// FFN second (down) projection.
FfnFc2,
// Final vocabulary projection.
LmHead,
}
/// Model hyperparameters used to size and validate GPU weight buffers.
#[derive(Debug, Clone)]
pub struct GpuModelConfig {
// Vocabulary size (rows of the embedding / LM-head matrices).
pub vocab_size: usize,
// Model embedding width.
pub hidden_dim: usize,
// Number of attention (query) heads.
pub num_heads: usize,
// Number of key/value heads; fewer than `num_heads` means GQA.
pub num_kv_heads: usize,
// Number of transformer blocks.
pub num_layers: usize,
// FFN intermediate width.
pub intermediate_dim: usize,
// Normalization epsilon.
pub eps: f32,
// RoPE base frequency.
pub rope_theta: f32,
// Per-head dim override; when `None`, derived as hidden_dim / num_heads.
pub explicit_head_dim: Option<usize>,
// Per-layer type tags; "linear"/"linear_attention" marks linear layers.
pub layer_types: Option<Vec<String>>,
// Linear-attention geometry (all `None` for pure-attention models).
pub linear_key_head_dim: Option<usize>,
pub linear_value_head_dim: Option<usize>,
pub linear_num_key_heads: Option<usize>,
pub linear_num_value_heads: Option<usize>,
pub linear_conv_kernel_dim: Option<usize>,
// Architecture-specific constraints from the GGUF loader.
pub constraints: Option<crate::gguf::ArchConstraints>,
// MoE configuration; `num_experts > 1` enables MoE (see `is_moe`).
pub num_experts: Option<usize>,
pub num_experts_per_tok: Option<usize>,
pub expert_intermediate_size: Option<usize>,
}
impl GpuModelConfig {
    /// Per-head dimension: the explicit override if present, otherwise
    /// `hidden_dim / num_heads`.
    #[inline]
    #[must_use]
    pub fn head_dim(&self) -> usize {
        // Lazy divide so an explicit head_dim never evaluates the quotient.
        self.explicit_head_dim
            .unwrap_or_else(|| self.hidden_dim / self.num_heads)
    }

    /// Total key (or value) width across all KV heads.
    #[inline]
    #[must_use]
    pub fn kv_dim(&self) -> usize {
        self.num_kv_heads * self.head_dim()
    }

    /// Output width of the fused QKV projection: Q plus one K and one V
    /// block per KV head.
    #[inline]
    #[must_use]
    pub fn qkv_dim(&self) -> usize {
        let hd = self.head_dim();
        (self.num_heads + 2 * self.num_kv_heads) * hd
    }

    /// Total query width across all heads.
    #[inline]
    #[must_use]
    pub fn q_dim(&self) -> usize {
        self.num_heads * self.head_dim()
    }

    /// `true` when the model uses grouped-query attention
    /// (fewer KV heads than query heads).
    #[inline]
    #[must_use]
    pub fn is_gqa(&self) -> bool {
        self.num_kv_heads < self.num_heads
    }

    /// `true` when `block_idx` is tagged as a linear-attention layer.
    /// Missing `layer_types` or an out-of-range index yields `false`.
    #[inline]
    #[must_use]
    pub fn is_linear_layer(&self, block_idx: usize) -> bool {
        self.layer_types
            .as_ref()
            .and_then(|lt| lt.get(block_idx))
            .is_some_and(|t| t == "linear" || t == "linear_attention")
    }

    /// Total linear-attention key width; 0 when the geometry is unset.
    #[inline]
    #[must_use]
    pub fn linear_key_dim(&self) -> usize {
        self.linear_num_key_heads.unwrap_or(0) * self.linear_key_head_dim.unwrap_or(0)
    }

    /// Total linear-attention value width; 0 when the geometry is unset.
    #[inline]
    #[must_use]
    pub fn linear_value_dim(&self) -> usize {
        self.linear_num_value_heads.unwrap_or(0) * self.linear_value_head_dim.unwrap_or(0)
    }

    /// Width of the fused conv input: q + k (same width) + v.
    #[inline]
    #[must_use]
    pub fn linear_conv_dim(&self) -> usize {
        2 * self.linear_key_dim() + self.linear_value_dim()
    }

    /// `true` when the model routes across more than one expert.
    #[inline]
    #[must_use]
    pub fn is_moe(&self) -> bool {
        self.num_experts.is_some_and(|n| n > 1)
    }
}
/// Decoding parameters for GPU generation.
#[derive(Debug, Clone)]
pub struct GpuGenerateConfig {
// Maximum number of tokens to generate.
pub max_tokens: usize,
// Softmax temperature; 0.0 is used for greedy decoding (see `deterministic`).
pub temperature: f32,
// Top-k sampling cutoff; 1 means greedy.
pub top_k: usize,
// Token ids that terminate generation early.
pub stop_tokens: Vec<usize>,
// Emit per-step trace output when true.
pub trace: bool,
}
impl Default for GpuGenerateConfig {
fn default() -> Self {
Self {
max_tokens: 64,
temperature: 0.0,
top_k: 1,
stop_tokens: Vec::new(),
trace: false,
}
}
}
impl GpuGenerateConfig {
    /// Greedy-decoding preset: temperature 0.0 and top-k 1, no stop tokens.
    #[must_use]
    pub fn deterministic(max_tokens: usize) -> Self {
        // Greedy decoding is just sampling with temperature 0 / top-k 1.
        Self::with_sampling(max_tokens, 0.0, 1)
    }

    /// Sampling preset with the given temperature and top-k cutoff;
    /// stop tokens empty and tracing disabled.
    #[must_use]
    pub fn with_sampling(max_tokens: usize, temperature: f32, top_k: usize) -> Self {
        Self {
            max_tokens,
            temperature,
            top_k,
            stop_tokens: Vec::new(),
            trace: false,
        }
    }

    /// Builder-style setter that replaces the stop-token list.
    #[must_use]
    pub fn with_stop_tokens(self, stop_tokens: Vec<usize>) -> Self {
        Self { stop_tokens, ..self }
    }
}
/// Reusable scratch buffers for attention and FFN computation,
/// allocated once in `new` and zeroed via `reset`.
#[derive(Debug)]
pub struct AttentionBuffers {
// Query scratch, length hidden_dim.
pub q_buffer: Vec<f32>,
// Attention scores, length num_heads * max_seq_len.
pub scores_buffer: Vec<f32>,
// Attention output scratch, length hidden_dim.
pub output_buffer: Vec<f32>,
// KV projection scratch, length hidden_dim.
pub kv_proj_buffer: Vec<f32>,
// FFN intermediate scratch, length intermediate_dim.
pub ffn_buffer: Vec<f32>,
// Sequence length these buffers were sized for.
pub max_seq_len: usize,
}
impl AttentionBuffers {
#[must_use]
pub fn new(config: &GpuModelConfig, max_seq_len: usize) -> Self {
Self {
q_buffer: vec![0.0; config.hidden_dim],
scores_buffer: vec![0.0; config.num_heads * max_seq_len],
output_buffer: vec![0.0; config.hidden_dim],
kv_proj_buffer: vec![0.0; config.hidden_dim],
ffn_buffer: vec![0.0; config.intermediate_dim],
max_seq_len,
}
}
pub fn reset(&mut self) {
self.q_buffer.fill(0.0);
self.scores_buffer.fill(0.0);
self.output_buffer.fill(0.0);
self.kv_proj_buffer.fill(0.0);
self.ffn_buffer.fill(0.0);
}
}
/// Model weights whose buffer lengths have been checked against the
/// config — construct only via `ValidatedGpuWeights::new`.
pub struct ValidatedGpuWeights {
pub(crate) config: GpuModelConfig,
// Embedding table, validated as vocab_size * hidden_dim.
pub(crate) embedding_weights: Vec<f32>,
// One entry per layer, validated against config.num_layers.
pub(crate) block_weights: Vec<BlockWeights>,
// Final norm scale, validated as hidden_dim.
pub(crate) final_norm_weight: Vec<f32>,
// Final norm bias — length not validated; may be empty.
pub(crate) final_norm_bias: Vec<f32>,
// LM head, validated as vocab_size * hidden_dim (both layouts).
pub(crate) lm_head_weight: LmHeadWeight,
pub(crate) lm_head_weight_t: LmHeadWeightTransposed,
// LM head bias — length not validated; may be empty.
pub(crate) lm_head_bias: Vec<f32>,
}
/// Error produced when a weight buffer's length disagrees with the
/// model configuration (PMAT-284 validation).
#[derive(Debug)]
pub struct GpuWeightError {
    /// Name of the offending weight field.
    pub field: &'static str,
    /// Human-readable mismatch description.
    pub reason: String,
}

impl std::fmt::Display for GpuWeightError {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(out, "PMAT-284 weight validation '{}': {}", self.field, self.reason)
    }
}

impl std::error::Error for GpuWeightError {}
impl ValidatedGpuWeights {
    /// Checks every sized weight buffer against `config` and, on success,
    /// wraps them into a `ValidatedGpuWeights`.
    ///
    /// # Errors
    /// Returns a [`GpuWeightError`] naming the first field whose length
    /// disagrees with the configuration.
    pub fn new(
        config: GpuModelConfig,
        embedding_weights: Vec<f32>,
        block_weights: Vec<BlockWeights>,
        final_norm_weight: Vec<f32>,
        final_norm_bias: Vec<f32>,
        lm_head_weight: LmHeadWeight,
        lm_head_weight_t: LmHeadWeightTransposed,
        lm_head_bias: Vec<f32>,
    ) -> std::result::Result<Self, GpuWeightError> {
        // The embedding table and both LM-head layouts share one expected size.
        let vocab_by_hidden = config.vocab_size * config.hidden_dim;

        if embedding_weights.len() != vocab_by_hidden {
            return Err(GpuWeightError {
                field: "embedding_weights",
                reason: format!(
                    "expected {} (vocab_size={} * hidden_dim={}), got {}",
                    vocab_by_hidden,
                    config.vocab_size,
                    config.hidden_dim,
                    embedding_weights.len()
                ),
            });
        }

        if block_weights.len() != config.num_layers {
            return Err(GpuWeightError {
                field: "block_weights",
                reason: format!(
                    "expected {} layers, got {}",
                    config.num_layers,
                    block_weights.len()
                ),
            });
        }

        if final_norm_weight.len() != config.hidden_dim {
            return Err(GpuWeightError {
                field: "final_norm_weight",
                reason: format!(
                    "expected {} (hidden_dim), got {}",
                    config.hidden_dim,
                    final_norm_weight.len()
                ),
            });
        }

        if lm_head_weight.0.len() != vocab_by_hidden {
            return Err(GpuWeightError {
                field: "lm_head_weight",
                reason: format!(
                    "expected {} (vocab_size={} * hidden_dim={}), got {}",
                    vocab_by_hidden,
                    config.vocab_size,
                    config.hidden_dim,
                    lm_head_weight.0.len()
                ),
            });
        }

        if lm_head_weight_t.0.len() != vocab_by_hidden {
            return Err(GpuWeightError {
                field: "lm_head_weight_t",
                reason: format!(
                    "expected {} (hidden_dim={} * vocab_size={}), got {}",
                    vocab_by_hidden,
                    config.hidden_dim,
                    config.vocab_size,
                    lm_head_weight_t.0.len()
                ),
            });
        }

        // NOTE(review): final_norm_bias and lm_head_bias lengths are not
        // validated — presumably they may be empty for bias-free models;
        // confirm against the loaders.
        Ok(Self {
            config,
            embedding_weights,
            block_weights,
            final_norm_weight,
            final_norm_bias,
            lm_head_weight,
            lm_head_weight_t,
            lm_head_bias,
        })
    }
}