use super::CudaExecutor;
use crate::cuda::types::{BoundWeight, GemvKernel};
use trueno_gpu::driver::GpuBuffer;
use trueno_gpu::error::GpuError;
impl CudaExecutor {
#[inline]
pub fn bound_gemv(
&mut self,
weight: &BoundWeight,
input: &GpuBuffer<f32>,
output: &GpuBuffer<f32>,
) -> Result<(), GpuError> {
match weight.kernel() {
GemvKernel::Q4K => {
self.q4k_gemv_into(weight.ptr, input, output, weight.out_dim, weight.in_dim)
},
GemvKernel::Q5K => {
self.q5k_gemv_into(weight.ptr, input, output, weight.out_dim, weight.in_dim)
},
GemvKernel::Q6K => {
self.q6k_gemv_into(weight.ptr, input, output, weight.out_dim, weight.in_dim)
},
GemvKernel::Q8_0 => {
self.q8_0_gemv_into(weight.ptr, input, output, weight.out_dim, weight.in_dim)
},
GemvKernel::Q4_0 => {
self.q4_0_gemv_into(weight.ptr, input, output, weight.out_dim, weight.in_dim)
},
GemvKernel::Q5_0 => {
self.q5_0_gemv_into(weight.ptr, input, output, weight.out_dim, weight.in_dim)
},
GemvKernel::Q4_1 => {
self.q4_1_gemv_into(weight.ptr, input, output, weight.out_dim, weight.in_dim)
},
GemvKernel::F32 => {
self.f32_gemv_into(weight.ptr, input, output, weight.out_dim, weight.in_dim)
},
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::cuda::types::WeightQuantType;
#[test]
fn test_bound_weight_kernel_mapping_exhaustive() {
let types = [
(WeightQuantType::Q4K, GemvKernel::Q4K),
(WeightQuantType::Q5K, GemvKernel::Q5K),
(WeightQuantType::Q6K, GemvKernel::Q6K),
(WeightQuantType::Q8_0, GemvKernel::Q8_0),
(WeightQuantType::Q4_0, GemvKernel::Q4_0),
(WeightQuantType::Q5_0, GemvKernel::Q5_0),
(WeightQuantType::Q4_1, GemvKernel::Q4_1),
];
for (qtype, expected_kernel) in types {
let bound = BoundWeight::bind(0x1000, 1024, qtype, 128, 128);
assert_eq!(
bound.kernel(),
expected_kernel,
"WeightQuantType::{:?} should map to GemvKernel::{:?}",
qtype,
expected_kernel
);
}
}
#[test]
fn test_bound_weight_preserves_dimensions() {
let bound = BoundWeight::bind(0xDEAD, 4096, WeightQuantType::Q6K, 512, 256);
assert_eq!(bound.ptr, 0xDEAD);
assert_eq!(bound.len, 4096);
assert_eq!(bound.out_dim, 512);
assert_eq!(bound.in_dim, 256);
assert_eq!(bound.kernel(), GemvKernel::Q6K);
}
#[test]
fn test_bound_layer_weights_construction() {
use crate::cuda::types::{BoundLayerWeights, IndexedLayerWeights};
let indexed = IndexedLayerWeights {
attn_q_ptr: 0x1000,
attn_q_len: 1024,
attn_q_qtype: WeightQuantType::Q4K,
attn_k_ptr: 0x2000,
attn_k_len: 512,
attn_k_qtype: WeightQuantType::Q4K,
attn_v_ptr: 0x3000,
attn_v_len: 768,
attn_v_qtype: WeightQuantType::Q6K,
attn_output_ptr: 0x4000,
attn_output_len: 1024,
attn_output_qtype: WeightQuantType::Q4K,
ffn_gate_ptr: 0x5000,
ffn_gate_len: 2048,
ffn_gate_qtype: WeightQuantType::Q4K,
ffn_up_ptr: 0x6000,
ffn_up_len: 2048,
ffn_up_qtype: WeightQuantType::Q4K,
ffn_down_ptr: 0x7000,
ffn_down_len: 1536,
ffn_down_qtype: WeightQuantType::Q6K,
attn_norm_ptr: 0x8000,
attn_norm_len: 256,
ffn_norm_ptr: 0x9000,
ffn_norm_len: 256,
attn_q_bias_ptr: 0xA000,
attn_q_bias_len: 128,
attn_k_bias_ptr: 0xB000,
attn_k_bias_len: 64,
attn_v_bias_ptr: 0xC000,
attn_v_bias_len: 64,
attn_q_norm_ptr: 0,
attn_q_norm_len: 0,
attn_k_norm_ptr: 0,
attn_k_norm_len: 0,
};
let arch = crate::gguf::ArchConstraints::from_architecture("qwen2");
let validated = crate::cuda::types::ValidatedLayerWeights::validate(indexed, &arch, 0)
.expect("test weights should validate for qwen2 arch");
let bound = BoundLayerWeights::bind(&validated, 256, 256, 64, 512);
assert_eq!(bound.q_proj.kernel(), GemvKernel::Q4K);
assert_eq!(bound.k_proj.kernel(), GemvKernel::Q4K);
assert_eq!(bound.v_proj.kernel(), GemvKernel::Q6K); assert_eq!(bound.o_proj.kernel(), GemvKernel::Q4K);
assert_eq!(bound.ffn_gate.kernel(), GemvKernel::Q4K);
assert_eq!(bound.ffn_up.kernel(), GemvKernel::Q4K);
assert_eq!(bound.ffn_down.kernel(), GemvKernel::Q6K);
assert_eq!(bound.q_proj.out_dim, 256); assert_eq!(bound.q_proj.in_dim, 256); assert_eq!(bound.v_proj.out_dim, 64); assert_eq!(bound.ffn_gate.out_dim, 512); assert_eq!(bound.ffn_down.out_dim, 256); assert_eq!(bound.ffn_down.in_dim, 512);
assert_eq!(bound.attn_norm_ptr, 0x8000);
assert_eq!(bound.attn_q_bias_ptr, 0xA000);
assert_eq!(bound.attn_v_bias_len, 64);
}
}