use crate::error::{RealizarError, Result};
use crate::gguf::OwnedQuantizedModel;
use crate::quantize::{dequantize_q4_k, dequantize_q5_k, dequantize_q6_k};
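/// Dequantize every weight tensor in `model` to `f32` for GPU upload.
///
/// Returns `(name, data, rows, cols)` tuples, one per tensor, where
/// `data.len() == rows * cols` (norm vectors and bias slices use `rows = 1`).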
#[provable_contracts_macros::contract("wgpu-forward-pass-v1", equation = "dequant_correctness")]
pub fn dequant_model_weights(
    model: &OwnedQuantizedModel,
) -> Result<Vec<(String, Vec<f32>, usize, usize)>> {
    let config = &model.config;
    let hidden = config.hidden_dim;
    let num_heads = config.num_heads;
    let num_kv_heads = config.num_kv_heads;
    let head_dim = config.head_dim();
    let intermediate = config.intermediate_dim;
    let num_layers = model.layers().len();
    let mut weights = Vec::new();
    eprintln!(
        "[PMAT-333] Dequantizing {} layers (hidden={}, heads={}/{}, intermediate={})",
        num_layers, hidden, num_heads, num_kv_heads, intermediate,
    );
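    // Walk each transformer layer, emitting norm, attention, and FFN tensors.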
    for (i, layer) in model.layers().iter().enumerate() {
        let prefix = format!("layer.{i}");
        weights.push((
            format!("{prefix}.attn_norm"),
            layer.attn_norm_weight.clone(),
            1,
            hidden,
        ));
        if let Some(ref ffn_norm) = layer.ffn_norm_weight {
            weights.push((format!("{prefix}.ffn_norm"), ffn_norm.clone(), 1, hidden));
        }
        let q_dim = num_heads * head_dim;
        let kv_dim = num_kv_heads * head_dim;
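        // A fused QKV tensor stores the Q, K, and V projections contiguously
        // along the output dimension, so it can be split by row offsets.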
        match &layer.qkv_weight {
            crate::gguf::OwnedQKVWeights::Fused(tensor) => {
                let f32_data = dequant_tensor_public(tensor)?;
                let total_out = q_dim + 2 * kv_dim;
                let q_data = f32_data[..q_dim * hidden].to_vec();
                let k_data = f32_data[q_dim * hidden..(q_dim + kv_dim) * hidden].to_vec();
                let v_data = f32_data[(q_dim + kv_dim) * hidden..total_out * hidden].to_vec();
                weights.push((format!("{prefix}.q_proj"), q_data, q_dim, hidden));
                weights.push((format!("{prefix}.k_proj"), k_data, kv_dim, hidden));
                weights.push((format!("{prefix}.v_proj"), v_data, kv_dim, hidden));
            },
            crate::gguf::OwnedQKVWeights::Separate { q, k, v } => {
                weights.push((
                    format!("{prefix}.q_proj"),
                    dequant_tensor_public(q)?,
                    q_dim,
                    hidden,
                ));
                weights.push((
                    format!("{prefix}.k_proj"),
                    dequant_tensor_public(k)?,
                    kv_dim,
                    hidden,
                ));
                weights.push((
                    format!("{prefix}.v_proj"),
                    dequant_tensor_public(v)?,
                    kv_dim,
                    hidden,
                ));
            },
        }
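        // QKV bias, when present, uses the same Q/K/V layout as the fused weight.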
        if let Some(ref bias) = layer.qkv_bias {
            if bias.len() >= q_dim + 2 * kv_dim {
                weights.push((format!("{prefix}.q_bias"), bias[..q_dim].to_vec(), 1, q_dim));
                weights.push((
                    format!("{prefix}.k_bias"),
                    bias[q_dim..q_dim + kv_dim].to_vec(),
                    1,
                    kv_dim,
                ));
                weights.push((
                    format!("{prefix}.v_bias"),
                    bias[q_dim + kv_dim..q_dim + 2 * kv_dim].to_vec(),
                    1,
                    kv_dim,
                ));
            }
        }
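        // Attention output projection maps the q_dim attention output back to hidden.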
        weights.push((
            format!("{prefix}.o_proj"),
            dequant_tensor_public(&layer.attn_output_weight)?,
            hidden,
            q_dim,
        ));
        if let Some(ref gate) = layer.ffn_gate_weight {
            weights.push((
                format!("{prefix}.gate_proj"),
                dequant_tensor_public(gate)?,
                intermediate,
                hidden,
            ));
        }
        weights.push((
            format!("{prefix}.up_proj"),
            dequant_tensor_public(&layer.ffn_up_weight)?,
            intermediate,
            hidden,
        ));
        weights.push((
            format!("{prefix}.down_proj"),
            dequant_tensor_public(&layer.ffn_down_weight)?,
            hidden,
            intermediate,
        ));
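        // Progress heartbeat: log every 7th layer and always the final one.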
        if (i + 1) % 7 == 0 || i == num_layers - 1 {
            eprintln!(" Dequantized layer {}/{}", i + 1, num_layers);
        }
    }
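    // Output head: projects hidden states to vocabulary logits.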
    weights.push((
        "lm_head".to_string(),
        dequant_tensor_public(model.lm_head_weight())?,
        config.vocab_size,
        hidden,
    ));
    // 4 bytes per f32 element.
    let total_bytes: usize = weights.iter().map(|(_, d, _, _)| d.len() * 4).sum();
    eprintln!(
        "[PMAT-333] Dequantized {} weights, {:.1} MB F32",
        weights.len(),
        total_bytes as f64 / 1e6,
    );
    Ok(weights)
}
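/// Collect the raw (still-quantized) bytes of every Q4_K tensor in `model`.
///
/// Only tensors whose GGUF qtype is Q4_K (type id 12) are returned, as
/// `(name, raw_bytes, rows, cols)` tuples, so callers can upload them
/// without host-side dequantization.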
pub fn raw_q4k_weights(model: &OwnedQuantizedModel) -> Vec<(String, Vec<u8>, usize, usize)> {
    const GGUF_TYPE_Q4_K: u32 = 12;
    let config = &model.config;
    let hidden = config.hidden_dim;
    let num_heads = config.num_heads;
    let num_kv_heads = config.num_kv_heads;
    let head_dim = config.head_dim();
    let intermediate = config.intermediate_dim;
    let q_dim = num_heads * head_dim;
    let kv_dim = num_kv_heads * head_dim;
    let mut raw = Vec::new();
    for (i, layer) in model.layers().iter().enumerate() {
        let prefix = format!("layer.{i}");
        let projections: Vec<(&str, &crate::gguf::OwnedQuantizedTensor, usize, usize)> = vec![
            ("o_proj", &layer.attn_output_weight, hidden, q_dim),
            ("up_proj", &layer.ffn_up_weight, intermediate, hidden),
            ("down_proj", &layer.ffn_down_weight, hidden, intermediate),
        ];
        // The optional gate projection must pass the same Q4_K check as the
        // other projections; pushing it unconditionally would mix qtypes.
        if let Some(ref gate) = layer.ffn_gate_weight {
            if gate.qtype == GGUF_TYPE_Q4_K {
                raw.push((
                    format!("{prefix}.gate_proj"),
                    gate.data.clone(),
                    intermediate,
                    hidden,
                ));
            }
        }
        for (name, tensor, rows, cols) in projections {
            if tensor.qtype == GGUF_TYPE_Q4_K {
                raw.push((format!("{prefix}.{name}"), tensor.data.clone(), rows, cols));
            }
        }
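        // Fused QKV tensors are skipped here; only separately stored Q/K/V
        // projections are collected in raw form.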
        if let crate::gguf::OwnedQKVWeights::Separate { q, k, v } = &layer.qkv_weight {
            if q.qtype == GGUF_TYPE_Q4_K {
                raw.push((format!("{prefix}.q_proj"), q.data.clone(), q_dim, hidden));
            }
            if k.qtype == GGUF_TYPE_Q4_K {
                raw.push((format!("{prefix}.k_proj"), k.data.clone(), kv_dim, hidden));
            }
            if v.qtype == GGUF_TYPE_Q4_K {
                raw.push((format!("{prefix}.v_proj"), v.data.clone(), kv_dim, hidden));
            }
        }
    }
    raw
}
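/// Dequantize a single GGUF tensor to `f32` based on its `qtype`.
///
/// Supports Q4_K, Q5_K, Q6_K (block-dequantized) and F32/F16 (decoded from
/// little-endian bytes); any other qtype yields a `FormatError`.
///
/// A minimal sketch of the F32 passthrough, assuming a tensor whose only
/// relevant fields here are `qtype` and `data` (construction of the other
/// fields elided):
///
/// ```ignore
/// // Two f32 values stored as little-endian bytes decode unchanged.
/// let mut data = Vec::new();
/// data.extend_from_slice(&1.0f32.to_le_bytes());
/// data.extend_from_slice(&2.0f32.to_le_bytes());
/// // With tensor.qtype == 0 (GGUF F32) and tensor.data == data:
/// // dequant_tensor_public(&tensor)? == vec![1.0, 2.0]
/// ```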
pub fn dequant_tensor_public(tensor: &crate::gguf::OwnedQuantizedTensor) -> Result<Vec<f32>> {
    const GGUF_TYPE_F32: u32 = 0;
    const GGUF_TYPE_F16: u32 = 1;
    const GGUF_TYPE_Q4_K: u32 = 12;
    const GGUF_TYPE_Q5_K: u32 = 13;
    const GGUF_TYPE_Q6_K: u32 = 14;
    match tensor.qtype {
        GGUF_TYPE_Q4_K => dequantize_q4_k(&tensor.data),
        GGUF_TYPE_Q5_K => dequantize_q5_k(&tensor.data),
        GGUF_TYPE_Q6_K => dequantize_q6_k(&tensor.data),
        GGUF_TYPE_F32 => Ok(tensor
            .data
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect()),
        GGUF_TYPE_F16 => Ok(tensor
            .data
            .chunks_exact(2)
            .map(|c| {
                let bits = u16::from_le_bytes([c[0], c[1]]);
                half::f16::from_bits(bits).to_f32()
            })
            .collect()),
        other => Err(RealizarError::FormatError {
            reason: format!("Unsupported quantization type {other} for WGPU dequant"),
        }),
    }
}