use baracuda_driver::{init, Context, Device, DeviceBuffer, Stream};
use baracuda_kernels::{
contiguous_stride, ElementKind, FlashSdpaArgs, FlashSdpaDescriptor, FlashSdpaPlan,
PlanPreference, TensorMut, TensorRef, Workspace, WriteSliceArgs, WriteSliceDescriptor,
WriteSlicePlan,
};
/// Returns the parent of each node in the fixed five-node draft tree.
///
/// Entry `i` is `Some(p)` when node `p` is the parent of node `i`, and `None`
/// for the root. The tree shape is:
///
/// ```text
///         0
///        / \
///       1   2
///      / \
///     3   4
/// ```
fn draft_tree_parents() -> [Option<usize>; 5] {
    // Start with every node parentless, then attach the children.
    let mut parent_of = [None; 5];
    parent_of[1] = Some(0);
    parent_of[2] = Some(0);
    parent_of[3] = Some(1);
    parent_of[4] = Some(1);
    parent_of
}
/// Builds the tree-attention mask for the draft tree from [`draft_tree_parents`].
///
/// The result is a flat row-major buffer of shape `[num_heads, q, prefix_len + q]`,
/// where `q` is the number of draft-tree nodes. Entry `(h, qi, kj)` is `0.0` when
/// draft query `qi` may see key `kj` and `f32::NEG_INFINITY` otherwise (presumably
/// added to the attention logits by the kernel).
///
/// Each draft query sees:
/// - the whole committed prefix (keys `0..prefix_len`);
/// - itself and its ancestors in the tree (keys `prefix_len + node`), but no
///   sibling or other-branch nodes.
///
/// The same pattern is replicated for every head.
fn build_tree_mask(prefix_len: usize, num_heads: usize) -> Vec<f32> {
    let parents = draft_tree_parents();
    // Derive the query count from the tree itself so the mask can never disagree
    // with the tree definition.
    let q = parents.len();
    let k_len = prefix_len + q;

    // One head's [q, k_len] pattern; it is identical for all heads.
    let mut head = vec![f32::NEG_INFINITY; q * k_len];
    for (qi, row) in head.chunks_exact_mut(k_len).enumerate() {
        // Every draft token may attend to the full committed prefix.
        row[..prefix_len].fill(0.0);
        // Walk from this node up to the root, unmasking the node and each ancestor.
        let mut node = Some(qi);
        while let Some(n) = node {
            row[prefix_len + n] = 0.0;
            node = parents[n];
        }
    }
    head.repeat(num_heads)
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
init()?;
let device = Device::get(0)?;
let ctx = Context::new(&device)?;
let stream = Stream::new(&ctx)?;
const BATCH: i32 = 1;
const HEADS: i32 = 8;
const HEAD_DIM: i32 = 32;
const PREFIX_LEN: i32 = 64;
const Q_LEN: i32 = 5; const K_CAP: i32 = 256; let k_cur_len = PREFIX_LEN + Q_LEN;
let mut k_cache_h = vec![0f32; (BATCH * HEADS * K_CAP * HEAD_DIM) as usize];
let mut v_cache_h = vec![0f32; (BATCH * HEADS * K_CAP * HEAD_DIM) as usize];
for i in 0..(BATCH * HEADS * PREFIX_LEN * HEAD_DIM) as usize {
k_cache_h[i] = ((i as f32) * 0.011).sin() * 0.5;
v_cache_h[i] = ((i as f32) * 0.013 + 0.3).cos() * 0.5;
}
let mut k_cache = DeviceBuffer::from_slice(&ctx, &k_cache_h)?;
let mut v_cache = DeviceBuffer::from_slice(&ctx, &v_cache_h)?;
let n_qkv = (BATCH * HEADS * Q_LEN * HEAD_DIM) as usize;
let q_draft: Vec<f32> = (0..n_qkv).map(|i| ((i as f32) * 0.017 + 1.1).sin() * 0.5).collect();
let k_draft: Vec<f32> = (0..n_qkv).map(|i| ((i as f32) * 0.019 + 0.7).cos() * 0.5).collect();
let v_draft: Vec<f32> = (0..n_qkv).map(|i| ((i as f32) * 0.023 + 0.2).sin() * 0.5).collect();
let d_q = DeviceBuffer::from_slice(&ctx, &q_draft)?;
let d_k = DeviceBuffer::from_slice(&ctx, &k_draft)?;
let d_v = DeviceBuffer::from_slice(&ctx, &v_draft)?;
let kv_dest_shape = [BATCH, HEADS, K_CAP, HEAD_DIM];
let kv_source_shape = [BATCH, HEADS, Q_LEN, HEAD_DIM];
let kv_desc = WriteSliceDescriptor {
dest_shape: kv_dest_shape,
source_shape: kv_source_shape,
ranges: [
(0, BATCH),
(0, HEADS),
(PREFIX_LEN, PREFIX_LEN + Q_LEN),
(0, HEAD_DIM),
],
element: ElementKind::F32,
};
let kv_plan: WriteSlicePlan<f32, 4> =
WriteSlicePlan::select(&stream, &kv_desc, PlanPreference::default())?;
kv_plan.run(
&stream,
Workspace::None,
WriteSliceArgs {
dest: TensorMut {
data: k_cache.as_slice_mut(),
shape: kv_dest_shape,
stride: contiguous_stride(kv_dest_shape),
},
source: TensorRef {
data: d_k.as_slice(),
shape: kv_source_shape,
stride: contiguous_stride(kv_source_shape),
},
},
)?;
kv_plan.run(
&stream,
Workspace::None,
WriteSliceArgs {
dest: TensorMut {
data: v_cache.as_slice_mut(),
shape: kv_dest_shape,
stride: contiguous_stride(kv_dest_shape),
},
source: TensorRef {
data: d_v.as_slice(),
shape: kv_source_shape,
stride: contiguous_stride(kv_source_shape),
},
},
)?;
let mask_h = build_tree_mask(PREFIX_LEN as usize, HEADS as usize);
let d_mask = DeviceBuffer::from_slice(&ctx, &mask_h)?;
let mut y_out: DeviceBuffer<f32> =
DeviceBuffer::zeros(&ctx, (BATCH * HEADS * Q_LEN * HEAD_DIM) as usize)?;
let mut lse_out: DeviceBuffer<f32> =
DeviceBuffer::zeros(&ctx, (BATCH * HEADS * Q_LEN) as usize)?;
let q_shape = [BATCH, HEADS, Q_LEN, HEAD_DIM];
let k_shape = [BATCH, HEADS, k_cur_len, HEAD_DIM];
let v_shape = [BATCH, HEADS, k_cur_len, HEAD_DIM];
let y_shape = [BATCH, HEADS, Q_LEN, HEAD_DIM];
let lse_shape = [BATCH, HEADS, Q_LEN];
let mask_shape = [BATCH, HEADS, Q_LEN, k_cur_len];
let mut k_packed = vec![0f32; (BATCH * HEADS * k_cur_len * HEAD_DIM) as usize];
let mut v_packed = vec![0f32; (BATCH * HEADS * k_cur_len * HEAD_DIM) as usize];
k_cache.copy_to_host(&mut k_cache_h)?;
v_cache.copy_to_host(&mut v_cache_h)?;
for b in 0..(BATCH as usize) {
for h in 0..(HEADS as usize) {
for kk in 0..(k_cur_len as usize) {
for d in 0..(HEAD_DIM as usize) {
let src = ((b * HEADS as usize + h) * K_CAP as usize + kk)
* HEAD_DIM as usize
+ d;
let dst = ((b * HEADS as usize + h) * k_cur_len as usize + kk)
* HEAD_DIM as usize
+ d;
k_packed[dst] = k_cache_h[src];
v_packed[dst] = v_cache_h[src];
}
}
}
}
let d_k_packed = DeviceBuffer::from_slice(&ctx, &k_packed)?;
let d_v_packed = DeviceBuffer::from_slice(&ctx, &v_packed)?;
let scale = 1.0 / (HEAD_DIM as f32).sqrt();
let desc = FlashSdpaDescriptor::new(
BATCH,
HEADS,
Q_LEN,
k_cur_len,
HEAD_DIM,
HEAD_DIM,
scale,
false, ElementKind::F32,
);
let plan = FlashSdpaPlan::<f32>::select(&stream, &desc, PlanPreference::default())?;
plan.run(
&stream,
Workspace::None,
FlashSdpaArgs {
q: TensorRef {
data: d_q.as_slice(),
shape: q_shape,
stride: contiguous_stride(q_shape),
},
k: TensorRef {
data: d_k_packed.as_slice(),
shape: k_shape,
stride: contiguous_stride(k_shape),
},
v: TensorRef {
data: d_v_packed.as_slice(),
shape: v_shape,
stride: contiguous_stride(v_shape),
},
y: TensorMut {
data: y_out.as_slice_mut(),
shape: y_shape,
stride: contiguous_stride(y_shape),
},
lse: TensorMut {
data: lse_out.as_slice_mut(),
shape: lse_shape,
stride: contiguous_stride(lse_shape),
},
mask: Some(TensorRef {
data: d_mask.as_slice(),
shape: mask_shape,
stride: contiguous_stride(mask_shape),
}),
alibi_slopes: None,
},
)?;
stream.synchronize()?;
let mut y_host = vec![0f32; (BATCH * HEADS * Q_LEN * HEAD_DIM) as usize];
y_out.copy_to_host(&mut y_host)?;
println!("Spec-decode tree attention output (first 8 cells of t0):");
for i in 0..8 {
println!(" y[t0, h0, {i}] = {:.6}", y_host[i]);
}
println!("\nDone — tree-attention FW ran successfully.");
println!("Acceptance / sampling / commit logic is caller-owned");
println!("(see docs/SPEC_DECODE.md for division of responsibilities).");
Ok(())
}