use std::borrow::Cow;
use bytemuck::{Pod, Zeroable};
use crate::backend::WgpuCtx;
use crate::error::{Result, RullamaError};
use crate::kernels;
/// Uniform parameters passed to the matmul kernels dispatched from this file.
///
/// `#[repr(C)]` together with `Pod`/`Zeroable` lets the struct be uploaded
/// byte-for-byte with `bytemuck::bytes_of`. It consists of four `u32` fields.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable, Debug)]
struct MatmulParams {
// Length of the input vector `x`.
k: u32,
// Number of output values.
n: u32,
// Padding words, always written as 0.
// NOTE(review): presumably pads the struct to 16 bytes to match the WGSL
// uniform layout — confirm against the shader sources.
_pad0: u32,
_pad1: u32,
}
/// Multiplies a bf16 weight matrix by an `f32` vector on the GPU.
///
/// * `w_bytes` – raw weights, two bytes per element, `k * n` elements in total.
/// * `x` – input vector; must contain exactly `k` values.
/// * `k` – length of the input vector; must be even.
/// * `n` – number of output values.
///
/// Returns the output vector read back from the device.
///
/// # Errors
///
/// Returns [`RullamaError::Inference`] when `k * n * 2` does not fit in
/// `usize`, when `w_bytes` or `x` has the wrong length, or when `k` is odd.
/// Errors from the GPU dispatch itself are passed through unchanged.
pub async fn matmul_bf16(
    ctx: &WgpuCtx,
    w_bytes: &[u8],
    x: &[f32],
    k: usize,
    n: usize,
) -> Result<Vec<f32>> {
    // Checked arithmetic: a wrapped product could otherwise accidentally match
    // the length of a too-short weight buffer (or panic in debug builds).
    let expected = k
        .checked_mul(n)
        .and_then(|elems| elems.checked_mul(2))
        .ok_or_else(|| {
            RullamaError::Inference(format!(
                "bf16 W size k*n*2 overflows usize (k={k}, n={n})"
            ))
        })?;
    if w_bytes.len() != expected {
        return Err(RullamaError::Inference(format!(
            "bf16 W bytes {} != k*n*2 = {}",
            w_bytes.len(),
            expected
        )));
    }
    if x.len() != k {
        return Err(RullamaError::Inference(format!(
            "x.len() {} != k {}",
            x.len(),
            k
        )));
    }
    // NOTE(review): presumably the kernel reads two 16-bit weights per 32-bit
    // word, hence the even-`k` requirement — confirm against the WGSL source.
    if !k.is_multiple_of(2) {
        return Err(RullamaError::Inference(format!(
            "k {k} must be even for bf16 matmul"
        )));
    }
    dispatch_matmul(ctx, kernels::BF16_MATMUL, w_bytes, x, k, n).await
}
/// Multiplies an f16 weight matrix by an `f32` vector on the GPU.
///
/// * `w_bytes` – raw weights, two bytes per element, `k * n` elements in total.
/// * `x` – input vector; must contain exactly `k` values.
/// * `k` – length of the input vector; must be even.
/// * `n` – number of output values.
///
/// Returns the output vector read back from the device.
///
/// # Errors
///
/// Returns [`RullamaError::Inference`] when `k * n * 2` does not fit in
/// `usize`, when `w_bytes` or `x` has the wrong length, or when `k` is odd.
/// Errors from the GPU dispatch itself are passed through unchanged.
pub async fn matmul_f16(
    ctx: &WgpuCtx,
    w_bytes: &[u8],
    x: &[f32],
    k: usize,
    n: usize,
) -> Result<Vec<f32>> {
    // Checked arithmetic: a wrapped product could otherwise accidentally match
    // the length of a too-short weight buffer (or panic in debug builds).
    let expected = k
        .checked_mul(n)
        .and_then(|elems| elems.checked_mul(2))
        .ok_or_else(|| {
            RullamaError::Inference(format!(
                "f16 W size k*n*2 overflows usize (k={k}, n={n})"
            ))
        })?;
    if w_bytes.len() != expected {
        return Err(RullamaError::Inference(format!(
            "f16 W bytes {} != k*n*2 = {}",
            w_bytes.len(),
            expected
        )));
    }
    if x.len() != k {
        return Err(RullamaError::Inference(format!(
            "x.len() {} != k {}",
            x.len(),
            k
        )));
    }
    // NOTE(review): presumably the kernel reads two 16-bit weights per 32-bit
    // word, hence the even-`k` requirement — confirm against the WGSL source.
    if !k.is_multiple_of(2) {
        return Err(RullamaError::Inference(format!(
            "k {k} must be even for f16 matmul"
        )));
    }
    dispatch_matmul(ctx, kernels::F16_MATMUL, w_bytes, x, k, n).await
}
/// Multiplies a Q4_K-quantized weight matrix by an `f32` vector on the GPU.
///
/// Each row of weights occupies `(k / 256) * 144` bytes, i.e. one 144-byte
/// block per 256 weights.
///
/// * `w_bytes` – raw quantized weights, `n` rows of `(k / 256) * 144` bytes.
/// * `x` – input vector; must contain exactly `k` values.
/// * `k` – length of the input vector; must be a multiple of 256.
/// * `n` – number of output values.
///
/// Returns the output vector read back from the device.
///
/// # Errors
///
/// Returns [`RullamaError::Inference`] when `k` is not a multiple of 256, when
/// the expected byte count overflows `usize`, or when `w_bytes` or `x` has the
/// wrong length. Errors from the GPU dispatch are passed through unchanged.
pub async fn matmul_q4_k(
    ctx: &WgpuCtx,
    w_bytes: &[u8],
    x: &[f32],
    k: usize,
    n: usize,
) -> Result<Vec<f32>> {
    if !k.is_multiple_of(256) {
        return Err(RullamaError::Inference(format!(
            "k {k} must be a multiple of 256 for Q4_K matmul"
        )));
    }
    // Cannot overflow: (k / 256) * 144 is always smaller than k.
    let row_bytes = (k / 256) * 144;
    // The row count is caller-controlled, so this product must be checked.
    let expected = row_bytes.checked_mul(n).ok_or_else(|| {
        RullamaError::Inference(format!(
            "Q4_K W size (k/256)*144*n overflows usize (k={k}, n={n})"
        ))
    })?;
    if w_bytes.len() != expected {
        return Err(RullamaError::Inference(format!(
            "Q4_K W bytes {} != (k/256)*144*n = {}",
            w_bytes.len(),
            expected
        )));
    }
    if x.len() != k {
        return Err(RullamaError::Inference(format!(
            "x.len() {} != k {}",
            x.len(),
            k
        )));
    }
    // 144 is itself a multiple of 4, so this guard never fires for Q4_K; it is
    // kept for symmetry with the other quantized entry points.
    if !row_bytes.is_multiple_of(4) {
        return Err(RullamaError::Inference(format!(
            "Q4_K row_bytes {row_bytes} not multiple of 4 (k={k})"
        )));
    }
    dispatch_matmul(ctx, kernels::Q4_K_DEQUANT_MATMUL, w_bytes, x, k, n).await
}
/// Multiplies a Q4_0-quantized weight matrix by an `f32` vector on the GPU.
///
/// Each row of weights occupies `(k / 32) * 18` bytes, i.e. one 18-byte block
/// per 32 weights.
///
/// * `w_bytes` – raw quantized weights, `n` rows of `(k / 32) * 18` bytes.
/// * `x` – input vector; must contain exactly `k` values.
/// * `k` – length of the input vector; must be a multiple of 32, and the
///   resulting row size must be a multiple of 4 bytes.
/// * `n` – number of output values.
///
/// Returns the output vector read back from the device.
///
/// # Errors
///
/// Returns [`RullamaError::Inference`] when `k` is not a multiple of 32, when
/// the expected byte count overflows `usize`, when `w_bytes` or `x` has the
/// wrong length, or when a row is not a whole number of 4-byte words. Errors
/// from the GPU dispatch are passed through unchanged.
pub async fn matmul_q4_0(
    ctx: &WgpuCtx,
    w_bytes: &[u8],
    x: &[f32],
    k: usize,
    n: usize,
) -> Result<Vec<f32>> {
    if !k.is_multiple_of(32) {
        return Err(RullamaError::Inference(format!(
            "k {k} must be a multiple of 32 for Q4_0 matmul"
        )));
    }
    // Cannot overflow: (k / 32) * 18 is always smaller than k.
    let row_bytes = (k / 32) * 18;
    // The row count is caller-controlled, so this product must be checked.
    let expected = row_bytes.checked_mul(n).ok_or_else(|| {
        RullamaError::Inference(format!(
            "Q4_0 W size (k/32)*18*n overflows usize (k={k}, n={n})"
        ))
    })?;
    if w_bytes.len() != expected {
        return Err(RullamaError::Inference(format!(
            "Q4_0 W bytes {} != (k/32)*18*n = {}",
            w_bytes.len(),
            expected
        )));
    }
    if x.len() != k {
        return Err(RullamaError::Inference(format!(
            "x.len() {} != k {}",
            x.len(),
            k
        )));
    }
    // 18 is not a multiple of 4, so an odd number of blocks per row fails here.
    if !row_bytes.is_multiple_of(4) {
        return Err(RullamaError::Inference(format!(
            "Q4_0 row_bytes {row_bytes} not multiple of 4 (k={k})"
        )));
    }
    dispatch_matmul(ctx, kernels::Q4_0_DEQUANT_MATMUL, w_bytes, x, k, n).await
}
/// Multiplies a Q6_K-quantized weight matrix by an `f32` vector on the GPU.
///
/// Each row of weights occupies `(k / 256) * 210` bytes, i.e. one 210-byte
/// block per 256 weights.
///
/// * `w_bytes` – raw quantized weights, `n` rows of `(k / 256) * 210` bytes.
/// * `x` – input vector; must contain exactly `k` values.
/// * `k` – length of the input vector; must be a multiple of 256, and the
///   resulting row size must be a multiple of 4 bytes.
/// * `n` – number of output values.
///
/// Returns the output vector read back from the device.
///
/// # Errors
///
/// Returns [`RullamaError::Inference`] when `k` is not a multiple of 256, when
/// the expected byte count overflows `usize`, when `w_bytes` or `x` has the
/// wrong length, or when a row is not a whole number of 4-byte words. Errors
/// from the GPU dispatch are passed through unchanged.
pub async fn matmul_q6_k(
    ctx: &WgpuCtx,
    w_bytes: &[u8],
    x: &[f32],
    k: usize,
    n: usize,
) -> Result<Vec<f32>> {
    if !k.is_multiple_of(256) {
        return Err(RullamaError::Inference(format!(
            "k {k} must be a multiple of 256 for Q6_K matmul"
        )));
    }
    // Cannot overflow: (k / 256) * 210 is always smaller than k.
    let row_bytes = (k / 256) * 210;
    // The row count is caller-controlled, so this product must be checked.
    let expected = row_bytes.checked_mul(n).ok_or_else(|| {
        RullamaError::Inference(format!(
            "Q6_K W size (k/256)*210*n overflows usize (k={k}, n={n})"
        ))
    })?;
    if w_bytes.len() != expected {
        return Err(RullamaError::Inference(format!(
            "Q6_K W bytes {} != (k/256)*210*n = {}",
            w_bytes.len(),
            expected
        )));
    }
    if x.len() != k {
        return Err(RullamaError::Inference(format!(
            "x.len() {} != k {}",
            x.len(),
            k
        )));
    }
    // 210 is not a multiple of 4, so an odd number of blocks per row fails here.
    if !row_bytes.is_multiple_of(4) {
        return Err(RullamaError::Inference(format!(
            "Q6_K row_bytes {row_bytes} not multiple of 4 (k={k})"
        )));
    }
    dispatch_matmul(ctx, kernels::Q6_K_DEQUANT_MATMUL, w_bytes, x, k, n).await
}
/// Uploads the operands, runs one matmul compute shader and reads the result back.
///
/// * `wgsl` – WGSL source of the kernel; it must expose a `main` entry point
///   and bind, in group 0: the uniform parameters (binding 0), the weights
///   (1), the input vector (2) and the output vector (3).
/// * `w_bytes` – raw weight bytes, uploaded unchanged.
/// * `x` – input vector, uploaded as `f32`.
/// * `k`, `n` – dimensions forwarded to the shader as `u32`.
///
/// Returns `n` `f32` values copied out of the output buffer. When `n == 0`
/// the result is empty, and when `k == 0` it is `n` zeros; neither case
/// touches the GPU.
///
/// # Errors
///
/// Returns [`RullamaError::Inference`] when `k` or `n` does not fit in `u32`
/// or when polling the device fails, and [`RullamaError::BufferMap`] when the
/// readback buffer cannot be mapped.
async fn dispatch_matmul(
    ctx: &WgpuCtx,
    wgsl: &str,
    w_bytes: &[u8],
    x: &[f32],
    k: usize,
    n: usize,
) -> Result<Vec<f32>> {
    // The shader receives both dimensions as u32; refuse values that would be
    // silently truncated by an `as` cast.
    let k_u32 = u32::try_from(k)
        .map_err(|_| RullamaError::Inference(format!("k {k} does not fit in u32")))?;
    let n_u32 = u32::try_from(n)
        .map_err(|_| RullamaError::Inference(format!("n {n} does not fit in u32")))?;

    // Degenerate shapes would need zero-sized buffer bindings, which wgpu
    // rejects. The mathematical answer is known without running anything.
    if n == 0 {
        return Ok(Vec::new());
    }
    if k == 0 {
        return Ok(vec![0.0; n]);
    }

    let device = &ctx.device;
    let queue = &ctx.queue;

    let params = MatmulParams {
        k: k_u32,
        n: n_u32,
        _pad0: 0,
        _pad1: 0,
    };
    let params_buf = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("matmul.params"),
        size: std::mem::size_of::<MatmulParams>() as u64,
        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    queue.write_buffer(&params_buf, 0, bytemuck::bytes_of(&params));

    let w_buf = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("matmul.W"),
        size: w_bytes.len() as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    queue.write_buffer(&w_buf, 0, w_bytes);

    let x_bytes_len = std::mem::size_of_val(x) as u64;
    let x_buf = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("matmul.x"),
        size: x_bytes_len,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    queue.write_buffer(&x_buf, 0, bytemuck::cast_slice(x));

    // Computed in u64 so the byte count cannot wrap on 32-bit targets.
    let y_bytes_len = u64::from(n_u32) * 4;
    let y_buf = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("matmul.y"),
        size: y_bytes_len,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });
    // Separate staging buffer: the output buffer is STORAGE | COPY_SRC, while
    // the buffer that gets mapped for reading is COPY_DST | MAP_READ.
    let read_buf = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("matmul.read"),
        size: y_bytes_len,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });

    // The shader module and pipeline are rebuilt on every call.
    let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
        label: Some("matmul.wgsl"),
        source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(wgsl)),
    });
    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: Some("matmul.pipeline"),
        layout: None,
        module: &module,
        entry_point: Some("main"),
        compilation_options: Default::default(),
        cache: None,
    });
    // `layout: None` above means the bind group layout is taken from the pipeline.
    let bg_layout = pipeline.get_bind_group_layout(0);
    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("matmul.bg"),
        layout: &bg_layout,
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: params_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: w_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: x_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: y_buf.as_entire_binding(),
            },
        ],
    });

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("matmul.encoder"),
    });
    {
        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("matmul.pass"),
            timestamp_writes: None,
        });
        cpass.set_pipeline(&pipeline);
        cpass.set_bind_group(0, &bind_group, &[]);
        // One workgroup per 64 outputs, rounded up.
        // NOTE(review): assumes the kernels use a workgroup size of 64 — confirm.
        let workgroups = n_u32.div_ceil(64);
        cpass.dispatch_workgroups(workgroups, 1, 1);
    }
    encoder.copy_buffer_to_buffer(&y_buf, 0, &read_buf, 0, y_bytes_len);
    queue.submit(Some(encoder.finish()));

    // Request the mapping, then block on the device so the callback has fired
    // before the receiver is awaited.
    let slice = read_buf.slice(..);
    let (sender, receiver) = futures_channel::oneshot::channel();
    slice.map_async(wgpu::MapMode::Read, move |r| {
        let _ = sender.send(r);
    });
    device
        .poll(wgpu::PollType::Wait {
            submission_index: None,
            timeout: None,
        })
        .map_err(|e| RullamaError::Inference(format!("{e:?}")))?;
    receiver
        .await
        .map_err(|e| RullamaError::BufferMap(format!("{e}")))?
        .map_err(|e| RullamaError::BufferMap(format!("{e}")))?;

    let data = slice.get_mapped_range();
    let out: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
    // The mapped view must be dropped before the buffer is unmapped.
    drop(data);
    read_buf.unmap();
    Ok(out)
}
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
use super::*;
use half::f16;
/// Converts each `f32` to `half::f16` and returns the little-endian bytes,
/// two per input value, in input order.
fn f32_to_f16_bytes(values: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(values.len() * 2);
    bytes.extend(
        values
            .iter()
            .flat_map(|&value| f16::from_f32(value).to_le_bytes()),
    );
    bytes
}
/// CPU reference matrix-vector product used as the oracle for the GPU tests.
///
/// `w` is read as `n` consecutive rows of `k` weights; output `j` is the dot
/// product of row `j` with the first `k` entries of `x`, accumulated in `f32`
/// from left to right.
fn cpu_matmul_f32(w: &[f32], x: &[f32], k: usize, n: usize) -> Vec<f32> {
    (0..n)
        .map(|row| {
            let weights = &w[row * k..(row + 1) * k];
            weights
                .iter()
                .zip(&x[..k])
                .fold(0f32, |acc, (wv, xv)| acc + wv * xv)
        })
        .collect()
}
/// Multiplying by an f16 identity matrix must return the input vector.
///
/// Despite the name, the matrix used here is 4x4.
#[test]
fn f16_matmul_3x4_eye() {
    let _ = env_logger::builder().is_test(true).try_init();
    let (k, n) = (4, 4);

    // Identity weights: a single 1.0 on the diagonal of every row.
    let identity: Vec<f32> = (0..n * k)
        .map(|idx| if idx / k == idx % k { 1.0 } else { 0.0 })
        .collect();
    let x: Vec<f32> = (0..k).map(|i| (i as f32 + 1.0) * 0.25).collect();
    let w_f16 = f32_to_f16_bytes(&identity);

    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let y = pollster::block_on(matmul_f16(&ctx, &w_f16, &x, k, n)).expect("matmul");

    for i in 0..n {
        let delta = (y[i] - x[i]).abs();
        assert!(delta < 1e-4, "y[{i}]={} != x[{i}]={}", y[i], x[i]);
    }
}
/// Compares the GPU Q4_0 matmul against a CPU dequantize-then-multiply oracle
/// on pseudo-random blocks.
#[test]
fn q4_0_matmul_gpu_vs_cpu_oracle() {
    let _ = env_logger::builder().is_test(true).try_init();
    let k = 128usize;
    let n = 16usize;
    let blocks_per_row = k / 32;

    // Small linear congruential generator so the inputs are reproducible.
    let mut lcg_state: u32 = 0x1234_5678;
    let mut lcg = || {
        lcg_state = lcg_state.wrapping_mul(1664525).wrapping_add(1013904223);
        lcg_state
    };

    // Each 18-byte block: an f16 scale followed by 16 bytes of packed data.
    let total_blocks = n * blocks_per_row;
    let mut w_bytes = Vec::with_capacity(total_blocks * 18);
    for _ in 0..total_blocks {
        let scale = ((lcg() >> 8) as f32 / 16777216.0 - 0.5) * 0.2;
        w_bytes.extend_from_slice(&f16::from_f32(scale).to_le_bytes());
        w_bytes.extend((0..16).map(|_| (lcg() & 0xFF) as u8));
    }
    let x: Vec<f32> = (0..k)
        .map(|_| (lcg() >> 8) as f32 / 16777216.0 - 0.5)
        .collect();

    let mut w_f32 = vec![0f32; k * n];
    crate::gguf::quant::dequant_q4_0(&w_bytes, &mut w_f32).expect("dequant");
    let cpu = cpu_matmul_f32(&w_f32, &x, k, n);

    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let gpu = pollster::block_on(matmul_q4_0(&ctx, &w_bytes, &x, k, n)).expect("gpu");

    let max_abs = (0..n)
        .map(|i| (gpu[i] - cpu[i]).abs())
        .fold(0f32, f32::max);
    eprintln!("q4_0 matmul GPU vs CPU: max_abs={max_abs:.3e}");
    assert!(
        max_abs < 1e-4,
        "Q4_0 GPU vs CPU max_abs {max_abs} exceeds 1e-4"
    );
}
/// Runs the layer-0 Q, K, V and output projections of a local GGUF model on
/// both CPU and GPU and checks they agree.
///
/// The test skips itself (returns early) when the model file is absent.
#[test]
fn layer0_qkv_o_proj_gpu_vs_cpu() {
    let _ = env_logger::builder().is_test(true).try_init();
    let path = "/Users/nightness/.ollama/models/blobs/sha256-4e30e2665218745ef463f722c0bf86be0cab6ee676320f1cfadf91e989107448";
    if !std::path::Path::new(path).exists() {
        eprintln!("skipping: gemma4 GGUF not available at {path}");
        return;
    }
    let reader = crate::gguf::tensor::reader_from_file_header(path).expect("parse header");
    let d_model = 1536usize;

    // Deterministic pseudo-random inputs in [-0.5, 0.5).
    let mut lcg_state: u32 = 0xCAFE_BABE;
    let mut sample = || {
        lcg_state = lcg_state.wrapping_mul(1664525).wrapping_add(1013904223);
        ((lcg_state >> 8) as f32 / 16777216.0) - 0.5
    };
    let x_qkv: Vec<f32> = std::iter::repeat_with(&mut sample).take(d_model).collect();
    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");

    // Q projection (dequantized as Q4_K).
    let q_desc = reader.tensor("blk.0.attn_q.weight").expect("Q desc").clone();
    let q_bytes = crate::gguf::tensor::read_tensor_raw(path, &reader, "blk.0.attn_q.weight")
        .expect("Q bytes");
    let n_q = q_desc.dims[1] as usize;
    let mut q_dense = vec![0f32; d_model * n_q];
    crate::gguf::quant::dequant_q4_k(&q_bytes, &mut q_dense).expect("Q dequant");
    let cpu_q = cpu_matmul_f32(&q_dense, &x_qkv, d_model, n_q);
    let gpu_q =
        pollster::block_on(matmul_q4_k(&ctx, &q_bytes, &x_qkv, d_model, n_q)).expect("Q gpu");

    // K projection (dequantized as Q4_K).
    let k_desc = reader.tensor("blk.0.attn_k.weight").expect("K desc").clone();
    let k_bytes = crate::gguf::tensor::read_tensor_raw(path, &reader, "blk.0.attn_k.weight")
        .expect("K bytes");
    let n_k = k_desc.dims[1] as usize;
    let mut k_dense = vec![0f32; d_model * n_k];
    crate::gguf::quant::dequant_q4_k(&k_bytes, &mut k_dense).expect("K dequant");
    let cpu_k = cpu_matmul_f32(&k_dense, &x_qkv, d_model, n_k);
    let gpu_k =
        pollster::block_on(matmul_q4_k(&ctx, &k_bytes, &x_qkv, d_model, n_k)).expect("K gpu");

    // V projection (dequantized as Q6_K).
    let v_desc = reader.tensor("blk.0.attn_v.weight").expect("V desc").clone();
    let v_bytes = crate::gguf::tensor::read_tensor_raw(path, &reader, "blk.0.attn_v.weight")
        .expect("V bytes");
    let n_v = v_desc.dims[1] as usize;
    let mut v_dense = vec![0f32; d_model * n_v];
    crate::gguf::quant::dequant_q6_k(&v_bytes, &mut v_dense).expect("V dequant");
    let cpu_v = cpu_matmul_f32(&v_dense, &x_qkv, d_model, n_v);
    let gpu_v =
        pollster::block_on(matmul_q6_k(&ctx, &v_bytes, &x_qkv, d_model, n_v)).expect("V gpu");

    // Output projection: input length n_q, output length d_model.
    let attn_out: Vec<f32> = std::iter::repeat_with(&mut sample).take(n_q).collect();
    let o_desc = reader
        .tensor("blk.0.attn_output.weight")
        .expect("O desc")
        .clone();
    let o_bytes =
        crate::gguf::tensor::read_tensor_raw(path, &reader, "blk.0.attn_output.weight")
            .expect("O bytes");
    assert_eq!(o_desc.dims, vec![n_q as u64, d_model as u64]);
    let mut o_dense = vec![0f32; n_q * d_model];
    crate::gguf::quant::dequant_q4_k(&o_bytes, &mut o_dense).expect("O dequant");
    let cpu_o = cpu_matmul_f32(&o_dense, &attn_out, n_q, d_model);
    let gpu_o = pollster::block_on(matmul_q4_k(&ctx, &o_bytes, &attn_out, n_q, d_model))
        .expect("O gpu");

    let comparisons = [
        ("Q [1536,2048]", &cpu_q, &gpu_q),
        ("K [1536, 256]", &cpu_k, &gpu_k),
        ("V [1536, 256]", &cpu_v, &gpu_v),
        ("O [2048,1536]", &cpu_o, &gpu_o),
    ];
    for (name, c, g) in comparisons {
        // Relative error is only counted where the reference is not tiny.
        let (max_abs, max_rel) =
            (0..c.len()).fold((0f32, 0f32), |(worst_abs, worst_rel), i| {
                let abs = (g[i] - c[i]).abs();
                let rel = if c[i].abs() > 1e-3 {
                    abs / c[i].abs()
                } else {
                    0.0
                };
                (worst_abs.max(abs), worst_rel.max(rel))
            });
        eprintln!("layer0 {name}: max_abs={max_abs:.5e}, max_rel={max_rel:.5e}");
        assert!(max_abs < 1e-2, "{name} max_abs {max_abs} exceeds 1e-2");
    }
}
/// Checks the GPU Q4_K matmul against the CPU oracle on the real layer-0
/// `attn_q` weights of a local GGUF model.
///
/// The test skips itself (returns early) when the model file is absent.
#[test]
fn q4_k_matmul_real_layer0_attn_q() {
    let _ = env_logger::builder().is_test(true).try_init();
    let path = "/Users/nightness/.ollama/models/blobs/sha256-4e30e2665218745ef463f722c0bf86be0cab6ee676320f1cfadf91e989107448";
    if !std::path::Path::new(path).exists() {
        eprintln!("skipping: gemma4 GGUF not available at {path}");
        return;
    }
    let reader = crate::gguf::tensor::reader_from_file_header(path).expect("parse header");
    let tensor_name = "blk.0.attn_q.weight";
    let desc = reader.tensor(tensor_name).expect("tensor").clone();
    assert!(matches!(desc.dtype, crate::gguf::GgmlDtype::Q4_K));
    assert_eq!(desc.dims, vec![1536, 2048]);
    let (k, n) = (desc.dims[0] as usize, desc.dims[1] as usize);
    let w_bytes =
        crate::gguf::tensor::read_tensor_raw(path, &reader, tensor_name).expect("bytes");

    // Deterministic pseudo-random input in [-0.5, 0.5).
    let mut lcg_state: u32 = 0xC0FF_EE42;
    let mut sample = || {
        lcg_state = lcg_state.wrapping_mul(1664525).wrapping_add(1013904223);
        ((lcg_state >> 8) as f32 / 16777216.0) - 0.5
    };
    let x: Vec<f32> = std::iter::repeat_with(&mut sample).take(k).collect();

    let mut w_f32 = vec![0f32; k * n];
    crate::gguf::quant::dequant_q4_k(&w_bytes, &mut w_f32).expect("dequant");
    let cpu_y = cpu_matmul_f32(&w_f32, &x, k, n);

    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let gpu_y = pollster::block_on(matmul_q4_k(&ctx, &w_bytes, &x, k, n)).expect("matmul");

    // Relative error is only counted where the reference is not tiny.
    let (max_abs, max_rel) = (0..n).fold((0f32, 0f32), |(worst_abs, worst_rel), i| {
        let abs = (gpu_y[i] - cpu_y[i]).abs();
        let rel = if cpu_y[i].abs() > 1e-3 {
            abs / cpu_y[i].abs()
        } else {
            0.0
        };
        (worst_abs.max(abs), worst_rel.max(rel))
    });
    eprintln!(
        "q4_k matmul real layer-0 attn_q: max_abs={max_abs:.5}, max_rel={max_rel:.5}, k={k}, n={n}"
    );
    assert!(
        max_abs < 1e-2,
        "Q4_K matmul GPU/CPU disagreement: max_abs={max_abs}"
    );
}
/// Checks the GPU Q6_K matmul against the CPU oracle on the real layer-0
/// `attn_v` weights of a local GGUF model.
///
/// The test skips itself (returns early) when the model file is absent.
#[test]
fn q6_k_matmul_real_layer0_attn_v() {
    let _ = env_logger::builder().is_test(true).try_init();
    let path = "/Users/nightness/.ollama/models/blobs/sha256-4e30e2665218745ef463f722c0bf86be0cab6ee676320f1cfadf91e989107448";
    if !std::path::Path::new(path).exists() {
        eprintln!("skipping: gemma4 GGUF not available at {path}");
        return;
    }
    let reader = crate::gguf::tensor::reader_from_file_header(path).expect("parse header");
    let tensor_name = "blk.0.attn_v.weight";
    let desc = reader.tensor(tensor_name).expect("tensor").clone();
    assert!(matches!(desc.dtype, crate::gguf::GgmlDtype::Q6_K));
    assert_eq!(desc.dims, vec![1536, 256]);
    let (k, n) = (desc.dims[0] as usize, desc.dims[1] as usize);
    let w_bytes =
        crate::gguf::tensor::read_tensor_raw(path, &reader, tensor_name).expect("bytes");

    // Deterministic pseudo-random input in [-0.5, 0.5).
    let mut lcg_state: u32 = 0xDEAD_BEEF;
    let mut sample = || {
        lcg_state = lcg_state.wrapping_mul(1664525).wrapping_add(1013904223);
        ((lcg_state >> 8) as f32 / 16777216.0) - 0.5
    };
    let x: Vec<f32> = std::iter::repeat_with(&mut sample).take(k).collect();

    let mut w_f32 = vec![0f32; k * n];
    crate::gguf::quant::dequant_q6_k(&w_bytes, &mut w_f32).expect("dequant");
    let cpu_y = cpu_matmul_f32(&w_f32, &x, k, n);

    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let gpu_y = pollster::block_on(matmul_q6_k(&ctx, &w_bytes, &x, k, n)).expect("matmul");

    // Relative error is only counted where the reference is not tiny.
    let (max_abs, max_rel) = (0..n).fold((0f32, 0f32), |(worst_abs, worst_rel), i| {
        let abs = (gpu_y[i] - cpu_y[i]).abs();
        let rel = if cpu_y[i].abs() > 1e-3 {
            abs / cpu_y[i].abs()
        } else {
            0.0
        };
        (worst_abs.max(abs), worst_rel.max(rel))
    });
    eprintln!("q6_k matmul real layer-0 attn_v: max_abs={max_abs:.5}, max_rel={max_rel:.5}");
    assert!(
        max_abs < 1e-3,
        "Q6_K matmul GPU/CPU disagreement: max_abs={max_abs}"
    );
}
/// Converts each `f32` to bf16 by keeping the upper 16 bits of its bit
/// pattern (truncation, no rounding) and returns the little-endian bytes,
/// two per input value, in input order.
fn f32_to_bf16_bytes(values: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(values.len() * 2);
    bytes.extend(
        values
            .iter()
            .flat_map(|value| ((value.to_bits() >> 16) as u16).to_le_bytes()),
    );
    bytes
}
/// Compares the GPU bf16 matmul with the CPU oracle on pseudo-random data
/// (k = 64, n = 128).
///
/// The CPU side uses weights with the low 16 bits cleared, matching the
/// truncation performed by `f32_to_bf16_bytes`.
#[test]
fn bf16_matmul_random_64x128() {
    let _ = env_logger::builder().is_test(true).try_init();
    let (k, n) = (64, 128);

    // Deterministic pseudo-random values in [-0.5, 0.5).
    let mut lcg_state: u32 = 0xDEAD_F00D;
    let mut sample = || {
        lcg_state = lcg_state.wrapping_mul(1664525).wrapping_add(1013904223);
        ((lcg_state >> 8) as f32 / 16777216.0) - 0.5
    };
    let w_f32: Vec<f32> = (0..n * k).map(|_| sample() * 0.1).collect();
    let x: Vec<f32> = (0..k).map(|_| sample()).collect();

    let w_bf16_bytes = f32_to_bf16_bytes(&w_f32);
    let w_round_tripped: Vec<f32> = w_f32
        .iter()
        .map(|v| f32::from_bits(v.to_bits() & 0xFFFF0000))
        .collect();
    let cpu_y = cpu_matmul_f32(&w_round_tripped, &x, k, n);

    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let gpu_y = pollster::block_on(matmul_bf16(&ctx, &w_bf16_bytes, &x, k, n)).expect("matmul");

    let mut max_abs = 0f32;
    for i in 0..n {
        let diff = (gpu_y[i] - cpu_y[i]).abs();
        max_abs = max_abs.max(diff);
        assert!(
            diff < 1e-4,
            "bf16 y[{i}] cpu={} gpu={} diff={}",
            cpu_y[i],
            gpu_y[i],
            diff
        );
    }
    eprintln!("bf16_matmul max_abs_diff over {n} outputs = {max_abs:e}");
}
/// Runs the batched bf16 matmul on the GPU and compares every output with the
/// CPU oracle applied to each input row in turn.
#[test]
fn bf16_matmul_batched_matches_cpu() {
    let _ = env_logger::builder().is_test(true).try_init();
    let (k, n, batch) = (64, 32, 6);

    // Deterministic pseudo-random values in [-0.5, 0.5).
    let mut lcg_state: u32 = 0xCAFEFACE;
    let mut sample = || {
        lcg_state = lcg_state.wrapping_mul(1664525).wrapping_add(1013904223);
        ((lcg_state >> 8) as f32 / 16777216.0) - 0.5
    };
    let w_f32: Vec<f32> = (0..n * k).map(|_| sample() * 0.1).collect();
    let x_batch: Vec<f32> = (0..batch * k).map(|_| sample()).collect();

    let w_bf16_bytes = f32_to_bf16_bytes(&w_f32);
    let w_round_tripped: Vec<f32> = w_f32
        .iter()
        .map(|v| f32::from_bits(v.to_bits() & 0xFFFF0000))
        .collect();

    // CPU reference: one matmul per batch row, concatenated.
    let cpu_y: Vec<f32> = x_batch
        .chunks_exact(k)
        .flat_map(|xs| cpu_matmul_f32(&w_round_tripped, xs, k, n))
        .collect();

    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let pipes = crate::backend::Pipelines::new(&ctx.device);
    let upload_usage = wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST;

    let w_buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("test.w"),
        size: w_bf16_bytes.len() as u64,
        usage: upload_usage,
        mapped_at_creation: false,
    });
    ctx.queue.write_buffer(&w_buf, 0, &w_bf16_bytes);

    let x_buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("test.x"),
        size: (x_batch.len() * 4) as u64,
        usage: upload_usage,
        mapped_at_creation: false,
    });
    ctx.queue
        .write_buffer(&x_buf, 0, bytemuck::cast_slice(&x_batch));

    let y_size = (batch * n * 4) as u64;
    let y_buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("test.y"),
        size: y_size,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });
    let read_buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("test.read"),
        size: y_size,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });

    // Record the matmul and the readback copy in one submission.
    let mut enc = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("test.enc"),
        });
    crate::backend::dispatch::matmul_bf16_batched_chained(
        &ctx, &pipes, &mut enc, &w_buf, &x_buf, &y_buf, k, n, batch,
    );
    enc.copy_buffer_to_buffer(&y_buf, 0, &read_buf, 0, y_size);
    ctx.queue.submit(Some(enc.finish()));

    // Map the staging buffer and block until the callback has reported back.
    let slice = read_buf.slice(..);
    let (tx, rx) = std::sync::mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |r| {
        tx.send(r).unwrap();
    });
    ctx.device
        .poll(wgpu::PollType::Wait {
            submission_index: None,
            timeout: None,
        })
        .unwrap();
    rx.recv().unwrap().unwrap();
    let data = slice.get_mapped_range();
    let gpu_y: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
    drop(data);
    read_buf.unmap();

    let max_abs = (0..gpu_y.len())
        .map(|i| (gpu_y[i] - cpu_y[i]).abs())
        .fold(0f32, f32::max);
    eprintln!(
        "bf16_matmul_batched max_abs over {} outputs = {max_abs:e}",
        gpu_y.len()
    );
    assert!(max_abs < 1e-4, "bf16 batched diff: {max_abs}");
}
/// Checks that the tiled batched f16 matmul produces the same output as the
/// naive batched f16 kernel on identical pseudo-random inputs
/// (k = 80, n = 24, batch = 17).
#[test]
fn f16_matmul_batched_tiled_matches_naive() {
    let _ = env_logger::builder().is_test(true).try_init();
    let k = 80;
    let n = 24;
    let batch = 17;

    // Deterministic pseudo-random values in [-0.5, 0.5).
    let mut state: u32 = 0xA5A5A5A5;
    let mut next = || {
        state = state.wrapping_mul(1664525).wrapping_add(1013904223);
        ((state >> 8) as f32 / 16777216.0) - 0.5
    };
    let w_f32: Vec<f32> = (0..n * k).map(|_| next() * 0.1).collect();
    let x_batch: Vec<f32> = (0..batch * k).map(|_| next()).collect();
    let w_f16 = f32_to_f16_bytes(&w_f32);

    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let pipes = crate::backend::Pipelines::new(&ctx.device);

    let w_buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("ft.w"),
        size: w_f16.len() as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    ctx.queue.write_buffer(&w_buf, 0, &w_f16);
    let x_buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("ft.x"),
        size: (x_batch.len() * 4) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    ctx.queue
        .write_buffer(&x_buf, 0, bytemuck::cast_slice(&x_batch));

    let y_size = (batch * n * 4) as u64;
    // Output buffers for the two kernels share size and usage.
    let mk_y = |label| {
        ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: y_size,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        })
    };
    let mk_read = || {
        ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("ft.read"),
            size: y_size,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        })
    };
    // Copies `buf` into a fresh staging buffer, maps it and returns the floats.
    let read = |buf: &wgpu::Buffer| -> Vec<f32> {
        let r = mk_read();
        let mut e = ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        e.copy_buffer_to_buffer(buf, 0, &r, 0, y_size);
        ctx.queue.submit(Some(e.finish()));
        let slice = r.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |x| {
            tx.send(x).unwrap();
        });
        ctx.device
            .poll(wgpu::PollType::Wait {
                submission_index: None,
                timeout: None,
            })
            .unwrap();
        rx.recv().unwrap().unwrap();
        let data = slice.get_mapped_range();
        let out: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
        drop(data);
        r.unmap();
        out
    };

    let y_naive = mk_y("ft.y_naive");
    let y_tiled = mk_y("ft.y_tiled");

    // Naive kernel: bind group and dispatch are set up by hand.
    let mut e1 = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        let params = crate::backend::dispatch::BatchedMatmulParams {
            k: k as u32,
            n: n as u32,
            batch: batch as u32,
            _pad: 0,
        };
        let p_buf = crate::backend::dispatch::write_uniform(
            &ctx.device,
            &ctx.queue,
            "ft.naive.params",
            &params,
        );
        let bg = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("ft.naive.bg"),
            layout: &pipes.f16_matmul_batched.get_bind_group_layout(0),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: p_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: w_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: x_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: y_naive.as_entire_binding(),
                },
            ],
        });
        let mut cp = e1.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        });
        cp.set_pipeline(&pipes.f16_matmul_batched);
        cp.set_bind_group(0, &bg, &[]);
        // One workgroup per 64 outputs in x, one per batch row in y.
        cp.dispatch_workgroups((n as u32).div_ceil(64), batch as u32, 1);
    }
    ctx.queue.submit(Some(e1.finish()));

    // Tiled kernel: recorded through the dispatch helper under test.
    let mut e2 = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    crate::backend::dispatch::matmul_f16_batched_tiled_chained(
        &ctx, &pipes, &mut e2, &w_buf, &x_buf, &y_tiled, k, n, batch,
    );
    ctx.queue.submit(Some(e2.finish()));

    let naive_y = read(&y_naive);
    let tiled_y = read(&y_tiled);
    assert_eq!(naive_y.len(), tiled_y.len(), "readback lengths differ");
    let max_abs = naive_y
        .iter()
        .zip(&tiled_y)
        .map(|(a, b)| (a - b).abs())
        .fold(0f32, f32::max);
    eprintln!(
        "f16 tiled vs naive max_abs over {} outputs = {max_abs:e}",
        naive_y.len()
    );
    assert!(max_abs < 1e-4, "tiled vs naive diff: {max_abs}");
}
/// Checks that the tiled batched bf16 matmul produces the same output as the
/// naive batched bf16 kernel on identical pseudo-random inputs
/// (k = 80, n = 24, batch = 17).
#[test]
fn bf16_matmul_batched_tiled_matches_naive() {
    let _ = env_logger::builder().is_test(true).try_init();
    let k = 80;
    let n = 24;
    let batch = 17;

    // Deterministic pseudo-random values in [-0.5, 0.5).
    let mut state: u32 = 0xBEEFCAFE;
    let mut next = || {
        state = state.wrapping_mul(1664525).wrapping_add(1013904223);
        ((state >> 8) as f32 / 16777216.0) - 0.5
    };
    let w_f32: Vec<f32> = (0..n * k).map(|_| next() * 0.1).collect();
    let x_batch: Vec<f32> = (0..batch * k).map(|_| next()).collect();
    let w_bf16 = f32_to_bf16_bytes(&w_f32);

    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let pipes = crate::backend::Pipelines::new(&ctx.device);

    let w_buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("bft.w"),
        size: w_bf16.len() as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    ctx.queue.write_buffer(&w_buf, 0, &w_bf16);
    let x_buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("bft.x"),
        size: (x_batch.len() * 4) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    ctx.queue
        .write_buffer(&x_buf, 0, bytemuck::cast_slice(&x_batch));

    let y_size = (batch * n * 4) as u64;
    // Output buffers for the two kernels share size and usage.
    let mk_y = |label| {
        ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: y_size,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        })
    };
    let mk_read = || {
        ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("bft.read"),
            size: y_size,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        })
    };
    // Copies `buf` into a fresh staging buffer, maps it and returns the floats.
    let read = |buf: &wgpu::Buffer| -> Vec<f32> {
        let r = mk_read();
        let mut e = ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        e.copy_buffer_to_buffer(buf, 0, &r, 0, y_size);
        ctx.queue.submit(Some(e.finish()));
        let slice = r.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |x| {
            tx.send(x).unwrap();
        });
        ctx.device
            .poll(wgpu::PollType::Wait {
                submission_index: None,
                timeout: None,
            })
            .unwrap();
        rx.recv().unwrap().unwrap();
        let data = slice.get_mapped_range();
        let out: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
        drop(data);
        r.unmap();
        out
    };

    let y_naive = mk_y("bft.y_naive");
    let y_tiled = mk_y("bft.y_tiled");

    // Naive kernel: bind group and dispatch are set up by hand.
    let mut e1 = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        let params = crate::backend::dispatch::BatchedMatmulParams {
            k: k as u32,
            n: n as u32,
            batch: batch as u32,
            _pad: 0,
        };
        let p_buf = crate::backend::dispatch::write_uniform(
            &ctx.device,
            &ctx.queue,
            "bft.naive.params",
            &params,
        );
        let bg = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("bft.naive.bg"),
            layout: &pipes.bf16_matmul_batched.get_bind_group_layout(0),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: p_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: w_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: x_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: y_naive.as_entire_binding(),
                },
            ],
        });
        let mut cp = e1.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        });
        cp.set_pipeline(&pipes.bf16_matmul_batched);
        cp.set_bind_group(0, &bg, &[]);
        // One workgroup per 64 outputs in x, one per batch row in y.
        cp.dispatch_workgroups((n as u32).div_ceil(64), batch as u32, 1);
    }
    ctx.queue.submit(Some(e1.finish()));

    // Tiled kernel: recorded through the dispatch helper under test.
    let mut e2 = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    crate::backend::dispatch::matmul_bf16_batched_tiled_chained(
        &ctx, &pipes, &mut e2, &w_buf, &x_buf, &y_tiled, k, n, batch,
    );
    ctx.queue.submit(Some(e2.finish()));

    let naive_y = read(&y_naive);
    let tiled_y = read(&y_tiled);
    assert_eq!(naive_y.len(), tiled_y.len(), "readback lengths differ");
    let max_abs = naive_y
        .iter()
        .zip(&tiled_y)
        .map(|(a, b)| (a - b).abs())
        .fold(0f32, f32::max);
    eprintln!(
        "bf16 tiled vs naive max_abs over {} outputs = {max_abs:e}",
        naive_y.len()
    );
    assert!(max_abs < 1e-4, "tiled vs naive diff: {max_abs}");
}
/// Checks that the "v3" tiled batched f16 matmul produces the same output as
/// the naive batched f16 kernel on identical pseudo-random inputs
/// (k = 80, n = 72, batch = 35).
#[test]
fn f16_matmul_batched_tiled_v3_matches_naive() {
    let _ = env_logger::builder().is_test(true).try_init();
    let k = 80;
    let n = 72;
    let batch = 35;

    // Deterministic pseudo-random values in [-0.5, 0.5).
    let mut state: u32 = 0x33333333;
    let mut next = || {
        state = state.wrapping_mul(1664525).wrapping_add(1013904223);
        ((state >> 8) as f32 / 16777216.0) - 0.5
    };
    let w_f32: Vec<f32> = (0..n * k).map(|_| next() * 0.1).collect();
    let x_batch: Vec<f32> = (0..batch * k).map(|_| next()).collect();
    let w_f16 = f32_to_f16_bytes(&w_f32);

    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let pipes = crate::backend::Pipelines::new(&ctx.device);

    let w_buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("ft3.w"),
        size: w_f16.len() as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    ctx.queue.write_buffer(&w_buf, 0, &w_f16);
    let x_buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("ft3.x"),
        size: (x_batch.len() * 4) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    ctx.queue
        .write_buffer(&x_buf, 0, bytemuck::cast_slice(&x_batch));

    let y_size = (batch * n * 4) as u64;
    // Output buffers for the two kernels share size and usage.
    let mk_y = |label| {
        ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: y_size,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        })
    };
    // Copies `buf` into a fresh staging buffer, maps it and returns the floats.
    let read = |buf: &wgpu::Buffer| -> Vec<f32> {
        let r = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("ft3.read"),
            size: y_size,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        });
        let mut e = ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        e.copy_buffer_to_buffer(buf, 0, &r, 0, y_size);
        ctx.queue.submit(Some(e.finish()));
        let slice = r.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |x| {
            tx.send(x).unwrap();
        });
        ctx.device
            .poll(wgpu::PollType::Wait {
                submission_index: None,
                timeout: None,
            })
            .unwrap();
        rx.recv().unwrap().unwrap();
        let data = slice.get_mapped_range();
        let out: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
        drop(data);
        r.unmap();
        out
    };

    let y_naive = mk_y("ft3.y_naive");
    let y_v3 = mk_y("ft3.y_v3");

    // Naive kernel: bind group and dispatch are set up by hand.
    let mut e1 = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        let params = crate::backend::dispatch::BatchedMatmulParams {
            k: k as u32,
            n: n as u32,
            batch: batch as u32,
            _pad: 0,
        };
        let p_buf = crate::backend::dispatch::write_uniform(
            &ctx.device,
            &ctx.queue,
            "ft3.naive.params",
            &params,
        );
        let bg = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("ft3.naive.bg"),
            layout: &pipes.f16_matmul_batched.get_bind_group_layout(0),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: p_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: w_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: x_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: y_naive.as_entire_binding(),
                },
            ],
        });
        let mut cp = e1.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        });
        cp.set_pipeline(&pipes.f16_matmul_batched);
        cp.set_bind_group(0, &bg, &[]);
        // One workgroup per 64 outputs in x, one per batch row in y.
        cp.dispatch_workgroups((n as u32).div_ceil(64), batch as u32, 1);
    }
    ctx.queue.submit(Some(e1.finish()));

    // v3 tiled kernel: recorded through the dispatch helper under test.
    let mut e2 = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    crate::backend::dispatch::matmul_f16_batched_tiled_v3_chained(
        &ctx, &pipes, &mut e2, &w_buf, &x_buf, &y_v3, k, n, batch,
    );
    ctx.queue.submit(Some(e2.finish()));

    let naive_y = read(&y_naive);
    let v3_y = read(&y_v3);
    assert_eq!(naive_y.len(), v3_y.len(), "readback lengths differ");
    let max_abs = naive_y
        .iter()
        .zip(&v3_y)
        .map(|(a, b)| (a - b).abs())
        .fold(0f32, f32::max);
    eprintln!(
        "f16 v3 vs naive max_abs over {} outputs = {max_abs:e}",
        naive_y.len()
    );
    assert!(max_abs < 1e-4, "v3 vs naive diff: {max_abs}");
}
/// Checks that `matmul_bf16_batched_tiled_v3_chained` produces the same output
/// as the naive `bf16_matmul_batched` pipeline on identical GPU inputs.
///
/// Both kernels read the same weight and activation buffers and write to their
/// own output buffer; the outputs are read back and every element must agree
/// to within `1e-4`. A NaN difference fails the test explicitly.
#[test]
fn bf16_matmul_batched_tiled_v3_matches_naive() {
    let _ = env_logger::builder().is_test(true).try_init();
    // NOTE(review): these sizes look chosen so they are not multiples of the
    // kernel tile sizes — confirm against the shader.
    let k = 80;
    let n = 72;
    let batch = 35;
    // Deterministic LCG so every run sees the same inputs. The top 24 bits are
    // scaled by 2^-24 and shifted, giving values in [-0.5, 0.5).
    let mut state: u32 = 0x44444444;
    let mut next = || {
        state = state.wrapping_mul(1664525).wrapping_add(1013904223);
        ((state >> 8) as f32 / 16777216.0) - 0.5
    };
    let w_f32: Vec<f32> = (0..n * k).map(|_| next() * 0.1).collect();
    let x_batch: Vec<f32> = (0..batch * k).map(|_| next()).collect();
    let w_bf16 = f32_to_bf16_bytes(&w_f32);
    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let pipes = crate::backend::Pipelines::new(&ctx.device);
    let w_buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("bft3.w"),
        size: w_bf16.len() as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    ctx.queue.write_buffer(&w_buf, 0, &w_bf16);
    let x_buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("bft3.x"),
        size: (x_batch.len() * 4) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    ctx.queue
        .write_buffer(&x_buf, 0, bytemuck::cast_slice(&x_batch));
    // One f32 per (batch row, output column).
    let y_size = (batch * n * 4) as u64;
    let mk_y = |label| {
        ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: y_size,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        })
    };
    // Copies an output buffer into a mappable staging buffer, blocks until the
    // mapping completes, and returns the contents as f32 values.
    let read = |buf: &wgpu::Buffer| -> Vec<f32> {
        let r = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("bft3.read"),
            size: y_size,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        });
        let mut e = ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        e.copy_buffer_to_buffer(buf, 0, &r, 0, y_size);
        ctx.queue.submit(Some(e.finish()));
        let slice = r.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |x| {
            tx.send(x).unwrap();
        });
        ctx.device
            .poll(wgpu::PollType::Wait {
                submission_index: None,
                timeout: None,
            })
            .unwrap();
        rx.recv().unwrap().unwrap();
        let data = slice.get_mapped_range();
        let out: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
        // The mapped view must be dropped before the buffer is unmapped.
        drop(data);
        r.unmap();
        out
    };
    let y_naive = mk_y("bft3.y_naive");
    let y_v3 = mk_y("bft3.y_v3");
    // Reference pass: the naive batched kernel, bound and dispatched by hand.
    let mut e1 = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        let naive_params = crate::backend::dispatch::BatchedMatmulParams {
            k: k as u32,
            n: n as u32,
            batch: batch as u32,
            _pad: 0,
        };
        let p_buf = crate::backend::dispatch::write_uniform(
            &ctx.device,
            &ctx.queue,
            "bft3.naive.params",
            &naive_params,
        );
        let bg = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("bft3.naive.bg"),
            layout: &pipes.bf16_matmul_batched.get_bind_group_layout(0),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: p_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: w_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: x_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: y_naive.as_entire_binding(),
                },
            ],
        });
        let mut cp = e1.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        });
        cp.set_pipeline(&pipes.bf16_matmul_batched);
        cp.set_bind_group(0, &bg, &[]);
        // NOTE(review): 64 presumably matches the shader's workgroup size — confirm.
        cp.dispatch_workgroups((n as u32).div_ceil(64), batch as u32, 1);
    }
    ctx.queue.submit(Some(e1.finish()));
    // Kernel under test, recorded through the production dispatch helper.
    let mut e2 = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    crate::backend::dispatch::matmul_bf16_batched_tiled_v3_chained(
        &ctx, &pipes, &mut e2, &w_buf, &x_buf, &y_v3, k, n, batch,
    );
    ctx.queue.submit(Some(e2.finish()));
    let naive_y = read(&y_naive);
    let v3_y = read(&y_v3);
    let mut max_abs = 0f32;
    for (i, (want, got)) in naive_y.iter().zip(v3_y.iter()).enumerate() {
        let d = (want - got).abs();
        // `NaN > max_abs` is false, so a NaN difference would be skipped by the
        // max tracking and the threshold assert below would pass vacuously.
        assert!(!d.is_nan(), "y[{i}] diff is NaN: naive={want} v3={got}");
        if d > max_abs {
            max_abs = d;
        }
    }
    eprintln!(
        "bf16 v3 vs naive max_abs over {} outputs = {max_abs:e}",
        naive_y.len()
    );
    assert!(max_abs < 1e-4, "v3 vs naive diff: {max_abs}");
}
/// Checks that `matmul_f16_batched_tiled_v2_chained` produces the same output
/// as the naive `f16_matmul_batched` pipeline on identical GPU inputs.
///
/// Both kernels read the same weight and activation buffers and write to their
/// own output buffer; the outputs are read back and every element must agree
/// to within `1e-4`. A NaN difference fails the test explicitly.
#[test]
fn f16_matmul_batched_tiled_v2_matches_naive() {
    let _ = env_logger::builder().is_test(true).try_init();
    // NOTE(review): these sizes look chosen so they are not multiples of the
    // kernel tile sizes — confirm against the shader.
    let k = 80;
    let n = 40;
    let batch = 19;
    // Deterministic LCG so every run sees the same inputs. The top 24 bits are
    // scaled by 2^-24 and shifted, giving values in [-0.5, 0.5).
    let mut state: u32 = 0xA5A5A5A5;
    let mut next = || {
        state = state.wrapping_mul(1664525).wrapping_add(1013904223);
        ((state >> 8) as f32 / 16777216.0) - 0.5
    };
    let w_f32: Vec<f32> = (0..n * k).map(|_| next() * 0.1).collect();
    let x_batch: Vec<f32> = (0..batch * k).map(|_| next()).collect();
    let w_f16 = f32_to_f16_bytes(&w_f32);
    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let pipes = crate::backend::Pipelines::new(&ctx.device);
    let w_buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("ft2.w"),
        size: w_f16.len() as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    ctx.queue.write_buffer(&w_buf, 0, &w_f16);
    let x_buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("ft2.x"),
        size: (x_batch.len() * 4) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    ctx.queue
        .write_buffer(&x_buf, 0, bytemuck::cast_slice(&x_batch));
    // One f32 per (batch row, output column).
    let y_size = (batch * n * 4) as u64;
    let mk_y = |label| {
        ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: y_size,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        })
    };
    // Fresh CPU-mappable staging buffer for one readback.
    let mk_read = || {
        ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("ft2.read"),
            size: y_size,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        })
    };
    // Copies an output buffer into a staging buffer, blocks until the mapping
    // completes, and returns the contents as f32 values.
    let read = |buf: &wgpu::Buffer| -> Vec<f32> {
        let r = mk_read();
        let mut e = ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        e.copy_buffer_to_buffer(buf, 0, &r, 0, y_size);
        ctx.queue.submit(Some(e.finish()));
        let slice = r.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |x| {
            tx.send(x).unwrap();
        });
        ctx.device
            .poll(wgpu::PollType::Wait {
                submission_index: None,
                timeout: None,
            })
            .unwrap();
        rx.recv().unwrap().unwrap();
        let data = slice.get_mapped_range();
        let out: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
        // The mapped view must be dropped before the buffer is unmapped.
        drop(data);
        r.unmap();
        out
    };
    let y_naive = mk_y("ft2.y_naive");
    let y_v2 = mk_y("ft2.y_v2");
    // Reference pass: the naive batched kernel, bound and dispatched by hand.
    let mut e1 = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        let naive_params = crate::backend::dispatch::BatchedMatmulParams {
            k: k as u32,
            n: n as u32,
            batch: batch as u32,
            _pad: 0,
        };
        let p_buf = crate::backend::dispatch::write_uniform(
            &ctx.device,
            &ctx.queue,
            "ft2.naive.params",
            &naive_params,
        );
        let bg = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("ft2.naive.bg"),
            layout: &pipes.f16_matmul_batched.get_bind_group_layout(0),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: p_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: w_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: x_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: y_naive.as_entire_binding(),
                },
            ],
        });
        let mut cp = e1.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        });
        cp.set_pipeline(&pipes.f16_matmul_batched);
        cp.set_bind_group(0, &bg, &[]);
        // NOTE(review): 64 presumably matches the shader's workgroup size — confirm.
        cp.dispatch_workgroups((n as u32).div_ceil(64), batch as u32, 1);
    }
    ctx.queue.submit(Some(e1.finish()));
    // Kernel under test, recorded through the production dispatch helper.
    let mut e2 = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    crate::backend::dispatch::matmul_f16_batched_tiled_v2_chained(
        &ctx, &pipes, &mut e2, &w_buf, &x_buf, &y_v2, k, n, batch,
    );
    ctx.queue.submit(Some(e2.finish()));
    let naive_y = read(&y_naive);
    let v2_y = read(&y_v2);
    let mut max_abs = 0f32;
    for (i, (want, got)) in naive_y.iter().zip(v2_y.iter()).enumerate() {
        let d = (want - got).abs();
        // `NaN > max_abs` is false, so a NaN difference would be skipped by the
        // max tracking and the threshold assert below would pass vacuously.
        assert!(!d.is_nan(), "y[{i}] diff is NaN: naive={want} v2={got}");
        if d > max_abs {
            max_abs = d;
        }
    }
    eprintln!(
        "f16 v2 vs naive max_abs over {} outputs = {max_abs:e}",
        naive_y.len()
    );
    assert!(max_abs < 1e-4, "v2 vs naive diff: {max_abs}");
}
/// Checks that `matmul_bf16_batched_tiled_v2_chained` produces the same output
/// as the naive `bf16_matmul_batched` pipeline on identical GPU inputs.
///
/// Both kernels read the same weight and activation buffers and write to their
/// own output buffer; the outputs are read back and every element must agree
/// to within `1e-4`. A NaN difference fails the test explicitly.
#[test]
fn bf16_matmul_batched_tiled_v2_matches_naive() {
    let _ = env_logger::builder().is_test(true).try_init();
    // NOTE(review): these sizes look chosen so they are not multiples of the
    // kernel tile sizes — confirm against the shader.
    let k = 80;
    let n = 40;
    let batch = 19;
    // Deterministic LCG so every run sees the same inputs. The top 24 bits are
    // scaled by 2^-24 and shifted, giving values in [-0.5, 0.5).
    let mut state: u32 = 0xBEEFCAFE;
    let mut next = || {
        state = state.wrapping_mul(1664525).wrapping_add(1013904223);
        ((state >> 8) as f32 / 16777216.0) - 0.5
    };
    let w_f32: Vec<f32> = (0..n * k).map(|_| next() * 0.1).collect();
    let x_batch: Vec<f32> = (0..batch * k).map(|_| next()).collect();
    let w_bf16 = f32_to_bf16_bytes(&w_f32);
    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let pipes = crate::backend::Pipelines::new(&ctx.device);
    let w_buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("bft2.w"),
        size: w_bf16.len() as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    ctx.queue.write_buffer(&w_buf, 0, &w_bf16);
    let x_buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("bft2.x"),
        size: (x_batch.len() * 4) as u64,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    ctx.queue
        .write_buffer(&x_buf, 0, bytemuck::cast_slice(&x_batch));
    // One f32 per (batch row, output column).
    let y_size = (batch * n * 4) as u64;
    let mk_y = |label| {
        ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: y_size,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        })
    };
    // Fresh CPU-mappable staging buffer for one readback.
    let mk_read = || {
        ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("bft2.read"),
            size: y_size,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        })
    };
    // Copies an output buffer into a staging buffer, blocks until the mapping
    // completes, and returns the contents as f32 values.
    let read = |buf: &wgpu::Buffer| -> Vec<f32> {
        let r = mk_read();
        let mut e = ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        e.copy_buffer_to_buffer(buf, 0, &r, 0, y_size);
        ctx.queue.submit(Some(e.finish()));
        let slice = r.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |x| {
            tx.send(x).unwrap();
        });
        ctx.device
            .poll(wgpu::PollType::Wait {
                submission_index: None,
                timeout: None,
            })
            .unwrap();
        rx.recv().unwrap().unwrap();
        let data = slice.get_mapped_range();
        let out: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
        // The mapped view must be dropped before the buffer is unmapped.
        drop(data);
        r.unmap();
        out
    };
    let y_naive = mk_y("bft2.y_naive");
    let y_v2 = mk_y("bft2.y_v2");
    // Reference pass: the naive batched kernel, bound and dispatched by hand.
    let mut e1 = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        let naive_params = crate::backend::dispatch::BatchedMatmulParams {
            k: k as u32,
            n: n as u32,
            batch: batch as u32,
            _pad: 0,
        };
        let p_buf = crate::backend::dispatch::write_uniform(
            &ctx.device,
            &ctx.queue,
            "bft2.naive.params",
            &naive_params,
        );
        let bg = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("bft2.naive.bg"),
            layout: &pipes.bf16_matmul_batched.get_bind_group_layout(0),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: p_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: w_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: x_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: y_naive.as_entire_binding(),
                },
            ],
        });
        let mut cp = e1.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        });
        cp.set_pipeline(&pipes.bf16_matmul_batched);
        cp.set_bind_group(0, &bg, &[]);
        // NOTE(review): 64 presumably matches the shader's workgroup size — confirm.
        cp.dispatch_workgroups((n as u32).div_ceil(64), batch as u32, 1);
    }
    ctx.queue.submit(Some(e1.finish()));
    // Kernel under test, recorded through the production dispatch helper.
    let mut e2 = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    crate::backend::dispatch::matmul_bf16_batched_tiled_v2_chained(
        &ctx, &pipes, &mut e2, &w_buf, &x_buf, &y_v2, k, n, batch,
    );
    ctx.queue.submit(Some(e2.finish()));
    let naive_y = read(&y_naive);
    let v2_y = read(&y_v2);
    let mut max_abs = 0f32;
    for (i, (want, got)) in naive_y.iter().zip(v2_y.iter()).enumerate() {
        let d = (want - got).abs();
        // `NaN > max_abs` is false, so a NaN difference would be skipped by the
        // max tracking and the threshold assert below would pass vacuously.
        assert!(!d.is_nan(), "y[{i}] diff is NaN: naive={want} v2={got}");
        if d > max_abs {
            max_abs = d;
        }
    }
    eprintln!(
        "bf16 v2 vs naive max_abs over {} outputs = {max_abs:e}",
        naive_y.len()
    );
    assert!(max_abs < 1e-4, "v2 vs naive diff: {max_abs}");
}
/// Runs `matmul_f16` on pseudo-random weights and compares every output
/// against the CPU result from `cpu_matmul_f32`.
///
/// The reference uses the unrounded f32 weights while the GPU path consumes
/// their f16 encoding, so the per-element tolerance is `1e-2`.
#[test]
fn f16_matmul_random_64x128() {
    let _ = env_logger::builder().is_test(true).try_init();
    let (k, n) = (64, 128);

    // Fixed-seed LCG; each draw falls in [-0.5, 0.5).
    let mut seed: u32 = 0x1234_5678;
    let mut draw = || {
        seed = seed.wrapping_mul(1664525).wrapping_add(1013904223);
        ((seed >> 8) as f32 / 16777216.0) - 0.5
    };

    // Weights are drawn first, then the input vector, so the sequence is fixed.
    let weights: Vec<f32> = std::iter::repeat_with(|| draw() * 0.1)
        .take(n * k)
        .collect();
    let input: Vec<f32> = std::iter::repeat_with(|| draw()).take(k).collect();

    let weights_f16 = f32_to_f16_bytes(&weights);
    let cpu_y = cpu_matmul_f32(&weights, &input, k, n);

    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let gpu_y =
        pollster::block_on(matmul_f16(&ctx, &weights_f16, &input, k, n)).expect("matmul");

    let mut max_abs = 0f32;
    for i in 0..n {
        let (gpu, cpu) = (gpu_y[i], cpu_y[i]);
        let diff = (gpu - cpu).abs();
        max_abs = max_abs.max(diff);
        assert!(
            diff < 1e-2,
            "y[{i}] cpu={} gpu={} diff={}",
            cpu,
            gpu,
            diff
        );
    }
    eprintln!("f16_matmul max_abs_diff over {n} outputs = {max_abs:e}");
}
/// Checks that `vision_attention_flash_chained` produces the same output as
/// the original `vision_attention` pipeline on identical Q/K/V buffers.
///
/// Each output element must agree to within `1e-4`; the maximum relative error
/// is reported but not asserted. A NaN difference fails the test explicitly.
#[test]
fn vision_attention_flash_matches_original() {
    let _ = env_logger::builder().is_test(true).try_init();
    let n_patches = 100;
    let n_heads = 3;
    let head_dim = 64;
    let total = n_patches * n_heads * head_dim;
    // Deterministic LCG so every run sees the same inputs. The top 24 bits are
    // scaled by 2^-24 and shifted, giving values in [-0.5, 0.5).
    let mut state: u32 = 0xFEEDFACE;
    let mut next = || {
        state = state.wrapping_mul(1664525).wrapping_add(1013904223);
        ((state >> 8) as f32 / 16777216.0) - 0.5
    };
    let q: Vec<f32> = (0..total).map(|_| next() * 0.1).collect();
    let k: Vec<f32> = (0..total).map(|_| next() * 0.1).collect();
    let v: Vec<f32> = (0..total).map(|_| next() * 0.1).collect();
    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let pipes = crate::backend::Pipelines::new(&ctx.device);
    // Creates a storage buffer and uploads `data` into it.
    let mkbuf = |label: &'static str, data: &[f32]| {
        let buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: (data.len() * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        ctx.queue.write_buffer(&buf, 0, bytemuck::cast_slice(data));
        buf
    };
    let q_buf = mkbuf("vat.q", &q);
    let k_buf = mkbuf("vat.k", &k);
    let v_buf = mkbuf("vat.v", &v);
    let mk_out = |label: &'static str| {
        ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: (total * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        })
    };
    let out_orig = mk_out("vat.out_orig");
    let out_flash = mk_out("vat.out_flash");
    // Reference pass: the original attention kernel, bound and dispatched by hand.
    let mut e1 = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        // NOTE(review): field order presumably mirrors the shader's uniform
        // struct — confirm against the WGSL source.
        #[repr(C)]
        #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
        struct Params {
            head_dim: u32,
            n_heads: u32,
            n_patches: u32,
            _pad: u32,
        }
        let attn_params = Params {
            head_dim: head_dim as u32,
            n_heads: n_heads as u32,
            n_patches: n_patches as u32,
            _pad: 0,
        };
        let p_buf = crate::backend::dispatch::write_uniform(
            &ctx.device,
            &ctx.queue,
            "vat.orig.params",
            &attn_params,
        );
        let bg = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("vat.orig.bg"),
            layout: &pipes.vision_attention.get_bind_group_layout(0),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: p_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: q_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: k_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: v_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 4,
                    resource: out_orig.as_entire_binding(),
                },
            ],
        });
        let mut cp = e1.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        });
        cp.set_pipeline(&pipes.vision_attention);
        cp.set_bind_group(0, &bg, &[]);
        // One workgroup per (patch, head) pair.
        cp.dispatch_workgroups(n_patches as u32, n_heads as u32, 1);
    }
    ctx.queue.submit(Some(e1.finish()));
    // Kernel under test, recorded through the production dispatch helper.
    let mut e2 = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    crate::backend::dispatch::vision_attention_flash_chained(
        &ctx, &pipes, &mut e2, &q_buf, &k_buf, &v_buf, &out_flash, head_dim, n_heads, n_patches,
    );
    ctx.queue.submit(Some(e2.finish()));
    // Copies an output buffer into a mappable staging buffer, blocks until the
    // mapping completes, and returns the contents as f32 values.
    let read = |buf: &wgpu::Buffer| -> Vec<f32> {
        let r = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("vat.read"),
            size: (total * 4) as u64,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        });
        let mut e = ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        e.copy_buffer_to_buffer(buf, 0, &r, 0, (total * 4) as u64);
        ctx.queue.submit(Some(e.finish()));
        let slice = r.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |x| {
            tx.send(x).unwrap();
        });
        ctx.device
            .poll(wgpu::PollType::Wait {
                submission_index: None,
                timeout: None,
            })
            .unwrap();
        rx.recv().unwrap().unwrap();
        let data = slice.get_mapped_range();
        let out: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
        // The mapped view must be dropped before the buffer is unmapped.
        drop(data);
        r.unmap();
        out
    };
    let o_orig = read(&out_orig);
    let o_flash = read(&out_flash);
    let mut max_abs = 0f32;
    let mut max_rel = 0f32;
    for i in 0..total {
        let d = (o_orig[i] - o_flash[i]).abs();
        // `NaN > max_abs` is false, so a NaN difference would be skipped by the
        // max tracking and the threshold assert below would pass vacuously.
        assert!(
            !d.is_nan(),
            "out[{i}] diff is NaN: orig={} flash={}",
            o_orig[i],
            o_flash[i]
        );
        if d > max_abs {
            max_abs = d;
        }
        // Floor the denominator so near-zero reference values do not blow up
        // the relative error.
        let denom = o_orig[i].abs().max(1e-6);
        let r = d / denom;
        if r > max_rel {
            max_rel = r;
        }
    }
    eprintln!("vision_attention flash vs original: max_abs={max_abs:e} max_rel={max_rel:e}");
    assert!(max_abs < 1e-4, "flash diverges: max_abs={max_abs}");
}
/// Checks that `vision_attention_flash_q4_chained` produces the same output as
/// the original `vision_attention` pipeline on identical Q/K/V buffers.
///
/// Each output element must agree to within `1e-4`. A NaN difference fails
/// the test explicitly.
#[test]
fn vision_attention_flash_q4_matches_original() {
    let _ = env_logger::builder().is_test(true).try_init();
    // NOTE(review): 102 patches is presumably chosen to exercise a partial
    // tile in the q4 variant — confirm against the shader.
    let n_patches = 102;
    let n_heads = 3;
    let head_dim = 64;
    let total = n_patches * n_heads * head_dim;
    // Deterministic LCG so every run sees the same inputs. The top 24 bits are
    // scaled by 2^-24 and shifted, giving values in [-0.5, 0.5).
    let mut state: u32 = 0xC0FFEE13;
    let mut next = || {
        state = state.wrapping_mul(1664525).wrapping_add(1013904223);
        ((state >> 8) as f32 / 16777216.0) - 0.5
    };
    let q: Vec<f32> = (0..total).map(|_| next() * 0.1).collect();
    let k: Vec<f32> = (0..total).map(|_| next() * 0.1).collect();
    let v: Vec<f32> = (0..total).map(|_| next() * 0.1).collect();
    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let pipes = crate::backend::Pipelines::new(&ctx.device);
    // Creates a storage buffer and uploads `data` into it.
    let mkbuf = |label: &'static str, data: &[f32]| {
        let buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: (data.len() * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        ctx.queue.write_buffer(&buf, 0, bytemuck::cast_slice(data));
        buf
    };
    let q_buf = mkbuf("vatq.q", &q);
    let k_buf = mkbuf("vatq.k", &k);
    let v_buf = mkbuf("vatq.v", &v);
    let mk_out = |label: &'static str| {
        ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: (total * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        })
    };
    let out_orig = mk_out("vatq.out_orig");
    let out_q4 = mk_out("vatq.out_q4");
    // Reference pass: the original attention kernel, bound and dispatched by hand.
    let mut e1 = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        // NOTE(review): field order presumably mirrors the shader's uniform
        // struct — confirm against the WGSL source.
        #[repr(C)]
        #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
        struct P {
            head_dim: u32,
            n_heads: u32,
            n_patches: u32,
            _pad: u32,
        }
        let attn_params = P {
            head_dim: head_dim as u32,
            n_heads: n_heads as u32,
            n_patches: n_patches as u32,
            _pad: 0,
        };
        let p_buf = crate::backend::dispatch::write_uniform(
            &ctx.device,
            &ctx.queue,
            "vatq.orig.params",
            &attn_params,
        );
        let bg = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("vatq.orig.bg"),
            layout: &pipes.vision_attention.get_bind_group_layout(0),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: p_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: q_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: k_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: v_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 4,
                    resource: out_orig.as_entire_binding(),
                },
            ],
        });
        let mut cp = e1.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        });
        cp.set_pipeline(&pipes.vision_attention);
        cp.set_bind_group(0, &bg, &[]);
        // One workgroup per (patch, head) pair.
        cp.dispatch_workgroups(n_patches as u32, n_heads as u32, 1);
    }
    ctx.queue.submit(Some(e1.finish()));
    // Kernel under test, recorded through the production dispatch helper.
    let mut e2 = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    crate::backend::dispatch::vision_attention_flash_q4_chained(
        &ctx, &pipes, &mut e2, &q_buf, &k_buf, &v_buf, &out_q4, head_dim, n_heads, n_patches,
    );
    ctx.queue.submit(Some(e2.finish()));
    // Copies an output buffer into a mappable staging buffer, blocks until the
    // mapping completes, and returns the contents as f32 values.
    let read = |buf: &wgpu::Buffer| -> Vec<f32> {
        let r = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("vatq.read"),
            size: (total * 4) as u64,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        });
        let mut e = ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        e.copy_buffer_to_buffer(buf, 0, &r, 0, (total * 4) as u64);
        ctx.queue.submit(Some(e.finish()));
        let slice = r.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |x| {
            tx.send(x).unwrap();
        });
        ctx.device
            .poll(wgpu::PollType::Wait {
                submission_index: None,
                timeout: None,
            })
            .unwrap();
        rx.recv().unwrap().unwrap();
        let data = slice.get_mapped_range();
        let out: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
        // The mapped view must be dropped before the buffer is unmapped.
        drop(data);
        r.unmap();
        out
    };
    let o_orig = read(&out_orig);
    let o_q4 = read(&out_q4);
    let mut max_abs = 0f32;
    for i in 0..total {
        let d = (o_orig[i] - o_q4[i]).abs();
        // `NaN > max_abs` is false, so a NaN difference would be skipped by the
        // max tracking and the threshold assert below would pass vacuously.
        assert!(
            !d.is_nan(),
            "out[{i}] diff is NaN: orig={} q4={}",
            o_orig[i],
            o_q4[i]
        );
        if d > max_abs {
            max_abs = d;
        }
    }
    eprintln!("vision_attention Q4 vs original: max_abs={max_abs:e}");
    assert!(max_abs < 1e-4, "Q4 diverges: max_abs={max_abs}");
}
/// Checks that `vision_attention_flash_q8_chained` produces the same output as
/// the original `vision_attention` pipeline on identical Q/K/V buffers.
///
/// Each output element must agree to within `1e-4`. A NaN difference fails
/// the test explicitly.
#[test]
fn vision_attention_flash_q8_matches_original() {
    let _ = env_logger::builder().is_test(true).try_init();
    // NOTE(review): 103 patches is presumably chosen to exercise a partial
    // tile in the q8 variant — confirm against the shader.
    let n_patches = 103;
    let n_heads = 3;
    let head_dim = 64;
    let total = n_patches * n_heads * head_dim;
    // Deterministic LCG so every run sees the same inputs. The top 24 bits are
    // scaled by 2^-24 and shifted, giving values in [-0.5, 0.5).
    let mut state: u32 = 0xC0DEFACE;
    let mut next = || {
        state = state.wrapping_mul(1664525).wrapping_add(1013904223);
        ((state >> 8) as f32 / 16777216.0) - 0.5
    };
    let q: Vec<f32> = (0..total).map(|_| next() * 0.1).collect();
    let k: Vec<f32> = (0..total).map(|_| next() * 0.1).collect();
    let v: Vec<f32> = (0..total).map(|_| next() * 0.1).collect();
    let ctx = pollster::block_on(WgpuCtx::new()).expect("wgpu");
    let pipes = crate::backend::Pipelines::new(&ctx.device);
    // Creates a storage buffer and uploads `data` into it.
    let mkbuf = |label: &'static str, data: &[f32]| {
        let buf = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: (data.len() * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        ctx.queue.write_buffer(&buf, 0, bytemuck::cast_slice(data));
        buf
    };
    let q_buf = mkbuf("vat8.q", &q);
    let k_buf = mkbuf("vat8.k", &k);
    let v_buf = mkbuf("vat8.v", &v);
    let mk_out = |label: &'static str| {
        ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: (total * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        })
    };
    let out_orig = mk_out("vat8.out_orig");
    let out_q8 = mk_out("vat8.out_q8");
    // Reference pass: the original attention kernel, bound and dispatched by hand.
    let mut e1 = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        // NOTE(review): field order presumably mirrors the shader's uniform
        // struct — confirm against the WGSL source.
        #[repr(C)]
        #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
        struct P {
            head_dim: u32,
            n_heads: u32,
            n_patches: u32,
            _pad: u32,
        }
        let attn_params = P {
            head_dim: head_dim as u32,
            n_heads: n_heads as u32,
            n_patches: n_patches as u32,
            _pad: 0,
        };
        let p_buf = crate::backend::dispatch::write_uniform(
            &ctx.device,
            &ctx.queue,
            "vat8.orig.params",
            &attn_params,
        );
        let bg = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("vat8.orig.bg"),
            layout: &pipes.vision_attention.get_bind_group_layout(0),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: p_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: q_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: k_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: v_buf.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 4,
                    resource: out_orig.as_entire_binding(),
                },
            ],
        });
        let mut cp = e1.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        });
        cp.set_pipeline(&pipes.vision_attention);
        cp.set_bind_group(0, &bg, &[]);
        // One workgroup per (patch, head) pair.
        cp.dispatch_workgroups(n_patches as u32, n_heads as u32, 1);
    }
    ctx.queue.submit(Some(e1.finish()));
    // Kernel under test, recorded through the production dispatch helper.
    let mut e2 = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    crate::backend::dispatch::vision_attention_flash_q8_chained(
        &ctx, &pipes, &mut e2, &q_buf, &k_buf, &v_buf, &out_q8, head_dim, n_heads, n_patches,
    );
    ctx.queue.submit(Some(e2.finish()));
    // Copies an output buffer into a mappable staging buffer, blocks until the
    // mapping completes, and returns the contents as f32 values.
    let read = |buf: &wgpu::Buffer| -> Vec<f32> {
        let r = ctx.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("vat8.read"),
            size: (total * 4) as u64,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        });
        let mut e = ctx
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        e.copy_buffer_to_buffer(buf, 0, &r, 0, (total * 4) as u64);
        ctx.queue.submit(Some(e.finish()));
        let slice = r.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |x| {
            tx.send(x).unwrap();
        });
        ctx.device
            .poll(wgpu::PollType::Wait {
                submission_index: None,
                timeout: None,
            })
            .unwrap();
        rx.recv().unwrap().unwrap();
        let data = slice.get_mapped_range();
        let out: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
        // The mapped view must be dropped before the buffer is unmapped.
        drop(data);
        r.unmap();
        out
    };
    let o_orig = read(&out_orig);
    let o_q8 = read(&out_q8);
    let mut max_abs = 0f32;
    for i in 0..total {
        let d = (o_orig[i] - o_q8[i]).abs();
        // `NaN > max_abs` is false, so a NaN difference would be skipped by the
        // max tracking and the threshold assert below would pass vacuously.
        assert!(
            !d.is_nan(),
            "out[{i}] diff is NaN: orig={} q8={}",
            o_orig[i],
            o_q8[i]
        );
        if d > max_abs {
            max_abs = d;
        }
    }
    eprintln!("vision_attention Q8 vs original: max_abs={max_abs:e}");
    assert!(max_abs < 1e-4, "Q8 diverges: max_abs={max_abs}");
}
}