use crate::context::GpuContext;
use crate::error::{GpuError, GpuResult};
use crate::kernels::GpuKernel;
/// Shape of a batched GEMV: one row-major `rows x cols` matrix applied to `batch_size`
/// input vectors of length `cols`.
///
/// [`batched_gemv_f32`] expects `rows * cols` matrix values and `batch_size * cols` vector
/// values, and produces `batch_size * rows` outputs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BatchedGemvConfig {
    /// Number of matrix rows, i.e. the length of each output vector.
    pub rows: usize,
    /// Number of matrix columns, i.e. the length of each input vector.
    pub cols: usize,
    /// Number of input vectors multiplied by the same matrix.
    pub batch_size: usize,
}
/// Multiplies one row-major `rows x cols` matrix by `batch_size` vectors on the GPU.
///
/// `vectors_f32` holds the input vectors back to back (`batch_size * cols` values). The
/// result holds `batch_size * rows` values: one `rows`-long output per input vector, in
/// input order.
///
/// # Errors
///
/// * [`GpuError::NoAdapter`] when the crate is built without the `gpu` feature.
/// * [`GpuError::BufferSize`] when a slice length disagrees with `config`.
/// * Any error propagated from reading the results back from the device.
pub fn batched_gemv_f32(
    ctx: &GpuContext,
    matrix_f32: &[f32],
    vectors_f32: &[f32],
    config: &BatchedGemvConfig,
) -> GpuResult<Vec<f32>> {
    #[cfg(not(feature = "gpu"))]
    {
        // No GPU backend compiled in: consume the arguments and report the missing adapter.
        let (_ctx, _matrix, _vectors, _config) = (ctx, matrix_f32, vectors_f32, config);
        Err(GpuError::NoAdapter)
    }
    #[cfg(feature = "gpu")]
    {
        gpu_batched_gemv_f32(ctx, matrix_f32, vectors_f32, config)
    }
}
/// GPU backend for [`batched_gemv_f32`], compiled only with the `gpu` feature.
///
/// The function first validates the slice lengths against `config`. It then short-circuits
/// degenerate shapes, builds the `batched_gemv_f32.wgsl` compute pipeline once, and
/// dispatches it.
///
/// Each input vector occupies one workgroup along y, and wgpu rejects any dispatch dimension
/// above the device's `max_compute_workgroups_per_dimension`. The batch is therefore split
/// into chunks no larger than that limit. When the batch fits in one dispatch, this is a
/// single upload, dispatch and download.
#[cfg(feature = "gpu")]
fn gpu_batched_gemv_f32(
    ctx: &GpuContext,
    matrix_f32: &[f32],
    vectors_f32: &[f32],
    config: &BatchedGemvConfig,
) -> GpuResult<Vec<f32>> {
    use crate::buffer::{create_output_f32, download_f32, upload_f32, upload_uniform};
    use bytemuck::{Pod, Zeroable};
    use wgpu::{
        BindGroupDescriptor, BindGroupEntry, BindGroupLayoutDescriptor, ComputePassDescriptor,
        ComputePipelineDescriptor, PipelineLayoutDescriptor, ShaderModuleDescriptor, ShaderSource,
    };

    /// Rows covered by one workgroup along x. NOTE(review): must match the x workgroup size
    /// declared in `batched_gemv_f32.wgsl` (the dispatch math has always assumed 64) — confirm.
    const ROWS_PER_WORKGROUP: usize = 64;

    /// Uniform parameter block bound at binding 3 (four `u32`s = 16 bytes).
    #[repr(C)]
    #[derive(Clone, Copy, Pod, Zeroable)]
    struct Params {
        rows: u32,
        cols: u32,
        batch_size: u32,
        _pad: u32,
    }

    let BatchedGemvConfig {
        rows,
        cols,
        batch_size,
    } = *config;

    // `saturating_mul`: an overflowing product becomes `usize::MAX`, a length no `f32` slice
    // can have. Absurd configs therefore surface as `BufferSize` instead of panicking (debug)
    // or wrapping to a value that might accidentally match the slice length (release).
    let expected_matrix = rows.saturating_mul(cols);
    if matrix_f32.len() != expected_matrix {
        return Err(GpuError::BufferSize {
            expected: expected_matrix,
            got: matrix_f32.len(),
        });
    }
    let expected_vectors = batch_size.saturating_mul(cols);
    if vectors_f32.len() != expected_vectors {
        return Err(GpuError::BufferSize {
            expected: expected_vectors,
            got: vectors_f32.len(),
        });
    }
    if rows == 0 || cols == 0 || batch_size == 0 {
        // Nothing to dispatch; an empty dot product is 0.
        return Ok(vec![0.0f32; batch_size * rows]);
    }

    const WGSL: &str = include_str!("../shaders/batched_gemv_f32.wgsl");
    let shader = ctx.device.create_shader_module(ShaderModuleDescriptor {
        label: Some("batched_gemv_f32"),
        source: ShaderSource::Wgsl(std::borrow::Cow::Borrowed(WGSL)),
    });
    let bgl = ctx
        .device
        .create_bind_group_layout(&BindGroupLayoutDescriptor {
            label: Some("batched-gemv-bgl"),
            entries: &[
                bgl_storage_ro(0),
                bgl_storage_ro(1),
                bgl_storage_rw(2),
                bgl_uniform(3),
            ],
        });
    let pipeline_layout = ctx
        .device
        .create_pipeline_layout(&PipelineLayoutDescriptor {
            label: Some("batched-gemv-layout"),
            bind_group_layouts: &[Some(&bgl)],
            immediate_size: 0,
        });
    let pipeline = ctx
        .device
        .create_compute_pipeline(&ComputePipelineDescriptor {
            label: Some("batched-gemv-pipeline"),
            layout: Some(&pipeline_layout),
            module: &shader,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });

    // The matrix is shared by every batch chunk, so it is uploaded exactly once.
    let matrix_buf = upload_f32(&ctx.device, "batched-gemv-matrix", matrix_f32);

    // One workgroup per input vector along y; wgpu rejects dimensions above this limit.
    let max_batch_per_dispatch =
        (ctx.device.limits().max_compute_workgroups_per_dimension as usize).max(1);
    // NOTE(review): `rows`/`cols` are narrowed with `as u32`, and `dispatch_x` is not split.
    // Values large enough to break either would need a matrix far beyond typical
    // storage-binding limits, but they are not checked here.
    let dispatch_x = rows.div_ceil(ROWS_PER_WORKGROUP) as u32;

    let mut output: Vec<f32> = Vec::new();
    // `vectors_f32.len()` is `batch_size * cols`, so every chunk holds whole vectors.
    for chunk_vectors in vectors_f32.chunks(max_batch_per_dispatch.saturating_mul(cols)) {
        let chunk_batch = chunk_vectors.len() / cols;
        let chunk_len = chunk_batch * rows;

        let vectors_buf = upload_f32(&ctx.device, "batched-gemv-vectors", chunk_vectors);
        let output_buf = create_output_f32(&ctx.device, "batched-gemv-output", chunk_len);
        let params = Params {
            rows: rows as u32,
            cols: cols as u32,
            batch_size: chunk_batch as u32,
            _pad: 0,
        };
        let params_buf = upload_uniform(&ctx.device, "batched-gemv-params", &params);

        let bind_group = ctx.device.create_bind_group(&BindGroupDescriptor {
            label: Some("batched-gemv-bg"),
            layout: &bgl,
            entries: &[
                BindGroupEntry {
                    binding: 0,
                    resource: matrix_buf.as_entire_binding(),
                },
                BindGroupEntry {
                    binding: 1,
                    resource: vectors_buf.as_entire_binding(),
                },
                BindGroupEntry {
                    binding: 2,
                    resource: output_buf.as_entire_binding(),
                },
                BindGroupEntry {
                    binding: 3,
                    resource: params_buf.as_entire_binding(),
                },
            ],
        });

        let mut encoder = ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("batched-gemv-encoder"),
            });
        {
            let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor {
                label: Some("batched-gemv-pass"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            pass.dispatch_workgroups(dispatch_x, chunk_batch as u32, 1);
        }
        ctx.queue.submit([encoder.finish()]);

        let chunk_out = download_f32(&ctx.device, &ctx.queue, &output_buf, chunk_len)?;
        // Chunks cover consecutive input vectors, so their outputs are consecutive as well.
        if output.is_empty() {
            output = chunk_out;
        } else {
            output.extend_from_slice(&chunk_out);
        }
    }
    Ok(output)
}
/// Builds a compute-only, non-dynamic buffer binding of the given buffer type.
///
/// This is the shared body of [`bgl_storage_ro`], [`bgl_storage_rw`] and [`bgl_uniform`].
#[cfg(feature = "gpu")]
fn bgl_buffer(binding: u32, ty: wgpu::BufferBindingType) -> wgpu::BindGroupLayoutEntry {
    wgpu::BindGroupLayoutEntry {
        binding,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: wgpu::BindingType::Buffer {
            ty,
            has_dynamic_offset: false,
            min_binding_size: None,
        },
        count: None,
    }
}

/// Layout entry for a read-only storage buffer visible to compute shaders.
#[cfg(feature = "gpu")]
fn bgl_storage_ro(binding: u32) -> wgpu::BindGroupLayoutEntry {
    bgl_buffer(binding, wgpu::BufferBindingType::Storage { read_only: true })
}

/// Layout entry for a read-write storage buffer visible to compute shaders.
#[cfg(feature = "gpu")]
fn bgl_storage_rw(binding: u32) -> wgpu::BindGroupLayoutEntry {
    bgl_buffer(binding, wgpu::BufferBindingType::Storage { read_only: false })
}

/// Layout entry for a uniform buffer visible to compute shaders.
#[cfg(feature = "gpu")]
fn bgl_uniform(binding: u32) -> wgpu::BindGroupLayoutEntry {
    bgl_buffer(binding, wgpu::BufferBindingType::Uniform)
}
/// Extension of [`GpuKernel`] for kernels that can multiply one quantized weight matrix by a
/// whole batch of input vectors in a single call.
pub trait BatchedGpuKernel: GpuKernel {
    /// Multiplies the `rows x cols` weight matrix encoded in `quant_data` by `batch_size`
    /// input vectors.
    ///
    /// `quant_data` is in the implementor's quantized encoding. NOTE(review): the layouts
    /// below are what the `Q4_0GpuKernel` implementation inherits from [`batched_gemv_f32`];
    /// other implementors are presumably expected to match — confirm.
    /// - `vectors`: `batch_size` vectors of `cols` values each, stored back to back.
    /// - result: `batch_size * rows` values, one `rows`-long output per input vector.
    ///
    /// # Errors
    ///
    /// Implementation-defined. The `Q4_0GpuKernel` implementation returns
    /// [`GpuError::NoAdapter`] when built without the `gpu` feature.
    fn batched_gemv(
        &self,
        ctx: &GpuContext,
        quant_data: &[u8],
        vectors: &[f32],
        rows: usize,
        cols: usize,
        batch_size: usize,
    ) -> GpuResult<Vec<f32>>;
}
use crate::kernels::q4_0::Q4_0GpuKernel;
impl BatchedGpuKernel for Q4_0GpuKernel {
    /// Expands the Q4_0 weights into dense `f32` values, then delegates to the dense batched
    /// kernel [`batched_gemv_f32`].
    ///
    /// Dequantization errors are propagated with `?`. Without the `gpu` feature this always
    /// returns [`GpuError::NoAdapter`].
    fn batched_gemv(
        &self,
        ctx: &GpuContext,
        quant_data: &[u8],
        vectors: &[f32],
        rows: usize,
        cols: usize,
        batch_size: usize,
    ) -> GpuResult<Vec<f32>> {
        #[cfg(not(feature = "gpu"))]
        {
            // No GPU backend compiled in: consume every argument and report it.
            let (_ctx, _weights, _inputs) = (ctx, quant_data, vectors);
            let _shape = (rows, cols, batch_size);
            Err(GpuError::NoAdapter)
        }
        #[cfg(feature = "gpu")]
        {
            let shape = BatchedGemvConfig {
                rows,
                cols,
                batch_size,
            };
            let dense_weights =
                crate::kernels::q4_0::dequant_q4_0_to_f32(quant_data, rows, cols)?;
            batched_gemv_f32(ctx, &dense_weights, vectors, &shape)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Host reference: `out[b * rows + r] = sum_c matrix[r * cols + c] * vectors[b * cols + c]`.
    #[cfg(feature = "gpu")]
    fn cpu_batched_gemv(
        matrix: &[f32],
        vectors: &[f32],
        rows: usize,
        cols: usize,
        batch: usize,
    ) -> Vec<f32> {
        let mut out = vec![0.0f32; batch * rows];
        for b in 0..batch {
            for r in 0..rows {
                let mut acc = 0.0f32;
                for c in 0..cols {
                    acc += matrix[r * cols + c] * vectors[b * cols + c];
                }
                out[b * rows + r] = acc;
            }
        }
        out
    }

    /// Asserts equal lengths and element-wise agreement within `tol`, tagging failures with
    /// `label`.
    #[cfg(feature = "gpu")]
    fn assert_all_close(got: &[f32], want: &[f32], tol: f32, label: &str) {
        assert_eq!(got.len(), want.len(), "{label}: length mismatch");
        for (i, (&g, &w)) in got.iter().zip(want).enumerate() {
            assert!(
                (g - w).abs() < tol,
                "{label} element {i}: got {g}, expected {w}"
            );
        }
    }

    /// 4x4 identity matrix (row-major) shared by the identity tests.
    #[cfg(feature = "gpu")]
    #[rustfmt::skip]
    const IDENTITY_4X4: [f32; 16] = [
        1.0, 0.0, 0.0, 0.0,
        0.0, 1.0, 0.0, 0.0,
        0.0, 0.0, 1.0, 0.0,
        0.0, 0.0, 0.0, 1.0,
    ];

    /// Runs in every configuration: probing for a GPU context must complete without panicking.
    #[test]
    fn test_batched_gemv_no_gpu_graceful() {
        let _ctx = GpuContext::try_init();
    }

    /// Returns a GPU context, or `None` when no adapter is available.
    ///
    /// Callers return early on `None`, so GPU tests pass vacuously on machines without a GPU.
    #[cfg(feature = "gpu")]
    fn try_gpu_ctx() -> Option<GpuContext> {
        GpuContext::try_init()
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_batched_gemv_identity_batch1() {
        let Some(ctx) = try_gpu_ctx() else { return };
        let vectors = [1.0, 2.0, 3.0, 4.0];
        let config = BatchedGemvConfig {
            rows: 4,
            cols: 4,
            batch_size: 1,
        };
        let result = batched_gemv_f32(&ctx, &IDENTITY_4X4, &vectors, &config)
            .expect("batched GEMV should succeed");
        // I * v == v, so the output equals the input.
        assert_all_close(&result, &vectors, 1e-5, "identity batch1");
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_batched_gemv_identity_batch4() {
        let Some(ctx) = try_gpu_ctx() else { return };
        let vectors = [
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 0.5, 0.5, 0.5, 0.5, -1.0, -2.0, -3.0, -4.0,
        ];
        let config = BatchedGemvConfig {
            rows: 4,
            cols: 4,
            batch_size: 4,
        };
        let result = batched_gemv_f32(&ctx, &IDENTITY_4X4, &vectors, &config)
            .expect("batched GEMV should succeed");
        // Each of the four outputs must reproduce its own input vector.
        assert_all_close(&result, &vectors, 1e-5, "identity batch4");
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_batched_gemv_known_values() {
        let Some(ctx) = try_gpu_ctx() else { return };
        let (rows, cols, batch) = (2, 3, 2);
        #[rustfmt::skip]
        let matrix = [
            1.0, 2.0, 3.0,
            4.0, 5.0, 6.0,
        ];
        // Unit vectors e0 and e1 select the first and second matrix columns.
        let vectors = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let expected = cpu_batched_gemv(&matrix, &vectors, rows, cols, batch);
        let config = BatchedGemvConfig {
            rows,
            cols,
            batch_size: batch,
        };
        let result = batched_gemv_f32(&ctx, &matrix, &vectors, &config)
            .expect("batched GEMV should succeed");
        assert_all_close(&result, &expected, 1e-4, "known values");
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_batched_gemv_batch1_matches_single() {
        let Some(ctx) = try_gpu_ctx() else { return };
        let (rows, cols) = (3, 4);
        #[rustfmt::skip]
        let matrix = [
            0.5, -1.0, 2.0, 0.3,
            1.0, 0.0, -0.5, 1.2,
            -0.3, 0.7, 0.1, -0.9,
        ];
        let vector = [1.0, 2.0, 3.0, 4.0];
        // A batch of one must equal a plain single-vector GEMV.
        let expected = cpu_batched_gemv(&matrix, &vector, rows, cols, 1);
        let config = BatchedGemvConfig {
            rows,
            cols,
            batch_size: 1,
        };
        let result =
            batched_gemv_f32(&ctx, &matrix, &vector, &config).expect("batched GEMV should succeed");
        assert_all_close(&result, &expected, 1e-4, "batch1");
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_batched_gemv_various_batch_sizes() {
        let Some(ctx) = try_gpu_ctx() else { return };
        let (rows, cols) = (8, 16);
        let matrix: Vec<f32> = (0..rows * cols).map(|i| i as f32 * 0.01).collect();
        for batch_size in [1, 2, 4, 8] {
            let vectors: Vec<f32> = (0..batch_size * cols)
                .map(|i| ((i % 7) as f32 - 3.0) * 0.1)
                .collect();
            let expected = cpu_batched_gemv(&matrix, &vectors, rows, cols, batch_size);
            let config = BatchedGemvConfig {
                rows,
                cols,
                batch_size,
            };
            let result = batched_gemv_f32(&ctx, &matrix, &vectors, &config)
                .unwrap_or_else(|e| panic!("batch_size={batch_size}: {e}"));
            assert_all_close(
                &result,
                &expected,
                1e-3,
                &format!("batch_size={batch_size}"),
            );
        }
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_batched_gemv_degenerate_shapes() {
        let Some(ctx) = try_gpu_ctx() else { return };

        // rows == 0: no outputs at all.
        let config = BatchedGemvConfig {
            rows: 0,
            cols: 3,
            batch_size: 2,
        };
        let result =
            batched_gemv_f32(&ctx, &[], &[1.0; 6], &config).expect("rows == 0 should succeed");
        assert!(result.is_empty());

        // cols == 0: every dot product is empty, so every output is zero.
        let config = BatchedGemvConfig {
            rows: 3,
            cols: 0,
            batch_size: 2,
        };
        let result = batched_gemv_f32(&ctx, &[], &[], &config).expect("cols == 0 should succeed");
        assert_eq!(result, vec![0.0f32; 6]);
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_batched_gemv_dimension_validation_matrix() {
        let Some(ctx) = try_gpu_ctx() else { return };
        let config = BatchedGemvConfig {
            rows: 4,
            cols: 4,
            batch_size: 1,
        };
        let result = batched_gemv_f32(&ctx, &[1.0; 10], &[1.0; 4], &config);
        assert!(
            matches!(
                result,
                Err(GpuError::BufferSize {
                    expected: 16,
                    got: 10
                })
            ),
            "should reject matrix with wrong size"
        );
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_batched_gemv_dimension_validation_vectors() {
        let Some(ctx) = try_gpu_ctx() else { return };
        let config = BatchedGemvConfig {
            rows: 4,
            cols: 4,
            batch_size: 2,
        };
        let result = batched_gemv_f32(&ctx, &[1.0; 16], &[1.0; 4], &config);
        assert!(
            matches!(
                result,
                Err(GpuError::BufferSize {
                    expected: 8,
                    got: 4
                })
            ),
            "should reject vectors with wrong size"
        );
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_q4_0_batched_kernel() {
        let Some(ctx) = try_gpu_ctx() else { return };
        const Q4_0_BLOCK_SIZE: usize = 32;
        const Q4_0_BLOCK_BYTES: usize = 18;
        // One Q4_0 block: f16 scale (little-endian) followed by 16 bytes of packed nibbles.
        let make_block = |scale: f32, nibbles: &[u8; 16]| -> Vec<u8> {
            let mut block = Vec::with_capacity(Q4_0_BLOCK_BYTES);
            let d_bits = half::f16::from_f32(scale).to_bits();
            block.extend_from_slice(&d_bits.to_le_bytes());
            block.extend_from_slice(nibbles);
            block
        };
        // Row 0: nibble 9 with scale 1.0; row 1: nibble 10 with scale 0.5.
        // With the usual Q4_0 offset of 8, both rows dequantize to all-ones.
        let mut weight_bytes = Vec::new();
        weight_bytes.extend_from_slice(&make_block(1.0, &[0x99u8; 16]));
        weight_bytes.extend_from_slice(&make_block(0.5, &[0xAAu8; 16]));
        let rows = 2;
        let cols = Q4_0_BLOCK_SIZE;
        let batch = 2;
        let vectors = [vec![1.0f32; cols], vec![0.5f32; cols]].concat();
        let result = Q4_0GpuKernel
            .batched_gemv(&ctx, &weight_bytes, &vectors, rows, cols, batch)
            .expect("Q4_0 batched GEMV should succeed");
        let expected = [32.0f32, 32.0, 16.0, 16.0];
        assert_all_close(&result, &expected, 1e-2, "q4_0");
    }
}