use super::{
BlockLocalAttnParams, PosEmbedAddParams, TransposeParams, VisionAttnParams, write_uniform,
};
use crate::backend::WgpuCtx;
use crate::backend::pipelines::Pipelines;
/// Records a positional-embedding-add compute pass into `enc`; nothing is submitted here.
///
/// Bindings on `p.pos_embed_add` (group 0):
/// - 0: a `PosEmbedAddParams` uniform (`n_patches`, `hidden_size`, `pos_size`),
/// - 1: `hidden`,
/// - 2: `pos_embd`,
/// - 3: `pos_x`,
/// - 4: `pos_y`.
///
/// The pass is dispatched as a 1-D grid of `ceil(n_patches * hidden_size / 64)` workgroups.
/// Judging by the names, the shader presumably adds the `pos_embd` entry selected by
/// (`pos_x`, `pos_y`) into `hidden` in place, with `pos_size` describing the embedding table
/// extent. Confirm this against the WGSL.
///
/// `write_uniform` and `create_bind_group` run on every call.
pub fn pos_embed_add_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    hidden: &wgpu::Buffer,
    pos_embd: &wgpu::Buffer,
    pos_x: &wgpu::Buffer,
    pos_y: &wgpu::Buffer,
    n_patches: usize,
    hidden_size: usize,
    pos_size: usize,
) {
    let device = &ctx.device;
    let queue = &ctx.queue;
    let params = PosEmbedAddParams {
        n_patches: n_patches as u32,
        hidden_size: hidden_size as u32,
        pos_size: pos_size as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "posembed.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("posembed.bg"),
        layout: &p.pos_embed_add.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: hidden.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: pos_embd.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: pos_x.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 4,
                resource: pos_y.as_entire_binding(),
            },
        ],
    });
    // One invocation per element of the n_patches x hidden_size grid.
    // NOTE(review): this assumes the shader declares @workgroup_size(64). The x workgroup count
    // is capped by `max_compute_workgroups_per_dimension` (65535 by default), so grids above
    // ~4.19M elements would fail validation. Confirm that real inputs stay below it.
    let total = (n_patches * hidden_size) as u32;
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("posembed.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(&p.pos_embed_add);
    cp.set_bind_group(0, &bg, &[]);
    cp.dispatch_workgroups(total.div_ceil(64), 1, 1);
}
/// Records a vision self-attention dispatch into `enc`. The kernel variant is chosen from
/// `head_dim`, `n_patches`, and whether the optional subgroup pipeline exists.
///
/// Selection order:
/// - `head_dim <= 64` and `n_patches >= 8`: `p.vision_attention_flash_subgroup` when it is
///   `Some`, otherwise the 8-queries-per-workgroup flash kernel.
/// - `head_dim <= 64` and `n_patches >= 4`: the 4-queries-per-workgroup flash kernel.
/// - `head_dim <= 64`: the flash kernel with one workgroup per (patch, head).
/// - Otherwise: the generic `p.vision_attention` kernel with one workgroup per (patch, head).
///
/// Every variant binds a `VisionAttnParams` uniform at binding 0 and `q`, `k`, `v`, `out` at
/// bindings 1..=4 of group 0. The tensor layouts are defined by the shaders, which are not
/// visible here.
pub fn vision_attention_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    if head_dim <= 64 {
        if n_patches >= 8 {
            match p.vision_attention_flash_subgroup.as_ref() {
                Some(sub) => vision_attention_flash_subgroup_chained(
                    ctx, p, sub, enc, q, k, v, out, head_dim, n_heads, n_patches,
                ),
                None => vision_attention_flash_q8_chained(
                    ctx, p, enc, q, k, v, out, head_dim, n_heads, n_patches,
                ),
            }
        } else if n_patches >= 4 {
            vision_attention_flash_q4_chained(
                ctx, p, enc, q, k, v, out, head_dim, n_heads, n_patches,
            );
        } else {
            vision_attention_flash_chained(
                ctx, p, enc, q, k, v, out, head_dim, n_heads, n_patches,
            );
        }
        return;
    }

    // Generic fallback for head_dim > 64.
    let device = &ctx.device;
    let queue = &ctx.queue;
    let params = VisionAttnParams {
        head_dim: head_dim as u32,
        n_heads: n_heads as u32,
        n_patches: n_patches as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "vattn.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("vattn.bg"),
        layout: &p.vision_attention.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: q.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: k.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: v.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 4,
                resource: out.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("vattn.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(&p.vision_attention);
    cp.set_bind_group(0, &bg, &[]);
    // x = query patch, y = head.
    cp.dispatch_workgroups(n_patches as u32, n_heads as u32, 1);
}
/// Records one dispatch of `p.vision_attention_flash` into `enc`; nothing is submitted here.
///
/// Group 0 bindings: a `VisionAttnParams` uniform (0), then `q` (1), `k` (2), `v` (3) and
/// `out` (4). The grid is `(n_patches, n_heads, 1)`, i.e. one workgroup per (query patch, head).
/// The caller `vision_attention_chained` only selects this kernel for `head_dim <= 64`.
pub fn vision_attention_flash_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    let device = &ctx.device;
    let queue = &ctx.queue;
    let params = VisionAttnParams {
        head_dim: head_dim as u32,
        n_heads: n_heads as u32,
        n_patches: n_patches as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "vattnf.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("vattnf.bg"),
        layout: &p.vision_attention_flash.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: q.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: k.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: v.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 4,
                resource: out.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("vattnf.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(&p.vision_attention_flash);
    cp.set_bind_group(0, &bg, &[]);
    // x = query patch, y = head.
    cp.dispatch_workgroups(n_patches as u32, n_heads as u32, 1);
}
/// Records one dispatch of `p.vision_attention_flash_q16` into `enc`; nothing is submitted here.
///
/// Group 0 bindings: a `VisionAttnParams` uniform (0), then `q` (1), `k` (2), `v` (3) and
/// `out` (4). The grid is `(ceil(n_patches / 16), n_heads, 1)`. Each workgroup along x
/// presumably handles a block of 16 query patches (confirm against the shader).
pub fn vision_attention_flash_q16_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    let device = &ctx.device;
    let queue = &ctx.queue;
    let params = VisionAttnParams {
        head_dim: head_dim as u32,
        n_heads: n_heads as u32,
        n_patches: n_patches as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "vattnq16.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("vattnq16.bg"),
        layout: &p.vision_attention_flash_q16.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: q.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: k.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: v.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 4,
                resource: out.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("vattnq16.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(&p.vision_attention_flash_q16);
    cp.set_bind_group(0, &bg, &[]);
    // Round up so a trailing partial block of queries still gets a workgroup.
    let n_query_groups = (n_patches as u32).div_ceil(16);
    cp.dispatch_workgroups(n_query_groups, n_heads as u32, 1);
}
/// Records one dispatch of `p.vision_attention_flash_q8` into `enc`; nothing is submitted here.
///
/// Group 0 bindings: a `VisionAttnParams` uniform (0), then `q` (1), `k` (2), `v` (3) and
/// `out` (4). The grid is `(ceil(n_patches / 8), n_heads, 1)`. Each workgroup along x
/// presumably handles a block of 8 query patches (confirm against the shader).
pub fn vision_attention_flash_q8_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    let device = &ctx.device;
    let queue = &ctx.queue;
    let params = VisionAttnParams {
        head_dim: head_dim as u32,
        n_heads: n_heads as u32,
        n_patches: n_patches as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "vattnq8.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("vattnq8.bg"),
        layout: &p.vision_attention_flash_q8.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: q.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: k.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: v.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 4,
                resource: out.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("vattnq8.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(&p.vision_attention_flash_q8);
    cp.set_bind_group(0, &bg, &[]);
    // Round up so a trailing partial block of queries still gets a workgroup.
    let n_query_groups = (n_patches as u32).div_ceil(8);
    cp.dispatch_workgroups(n_query_groups, n_heads as u32, 1);
}
/// Records one dispatch of the caller-supplied subgroup flash-attention pipeline `sub` into
/// `enc`; nothing is submitted here.
///
/// The bind group layout is taken from `sub` itself. Group 0 bindings: a `VisionAttnParams`
/// uniform (0), then `q` (1), `k` (2), `v` (3) and `out` (4). The grid is
/// `(ceil(n_patches / 8), n_heads, 1)`.
///
/// `p` is not used; it is kept so the signature matches the sibling launchers.
pub fn vision_attention_flash_subgroup_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    sub: &wgpu::ComputePipeline,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    let _ = p;
    let device = &ctx.device;
    let queue = &ctx.queue;
    let params = VisionAttnParams {
        head_dim: head_dim as u32,
        n_heads: n_heads as u32,
        n_patches: n_patches as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "vattnSub.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("vattnSub.bg"),
        layout: &sub.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: q.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: k.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: v.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 4,
                resource: out.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("vattnSub.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(sub);
    cp.set_bind_group(0, &bg, &[]);
    // Round up so a trailing partial block of queries still gets a workgroup.
    let n_query_groups = (n_patches as u32).div_ceil(8);
    cp.dispatch_workgroups(n_query_groups, n_heads as u32, 1);
}
/// Records a `p.transpose_phd_to_hpd` pass into `enc` by delegating to the shared
/// `transpose_chained` launcher with the `"tposePHDtoHPD"` label prefix.
///
/// Judging by the name, this reorders a (patch, head, dim) layout into (head, patch, dim). The
/// actual index mapping lives in the shader, so confirm it there.
pub fn transpose_phd_to_hpd_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    src: &wgpu::Buffer,
    dst: &wgpu::Buffer,
    n_patches: usize,
    n_heads: usize,
    head_dim: usize,
) {
    let (pipeline, label) = (&p.transpose_phd_to_hpd, "tposePHDtoHPD");
    transpose_chained(
        ctx, pipeline, label, enc, src, dst, n_patches, n_heads, head_dim,
    );
}
/// Records a `p.transpose_hpd_to_phd` pass into `enc` by delegating to the shared
/// `transpose_chained` launcher with the `"tposeHPDtoPHD"` label prefix.
///
/// Judging by the name, this is the inverse reorder, from (head, patch, dim) back to
/// (patch, head, dim). The actual index mapping lives in the shader, so confirm it there.
pub fn transpose_hpd_to_phd_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    src: &wgpu::Buffer,
    dst: &wgpu::Buffer,
    n_patches: usize,
    n_heads: usize,
    head_dim: usize,
) {
    let (pipeline, label) = (&p.transpose_hpd_to_phd, "tposeHPDtoPHD");
    transpose_chained(
        ctx, pipeline, label, enc, src, dst, n_patches, n_heads, head_dim,
    );
}
/// Shared launcher for the transpose kernels. It records one pass of `pipe` into `enc`;
/// nothing is submitted here.
///
/// Group 0 bindings: a `TransposeParams` uniform (0), `src` (1) and `dst` (2). The uniform
/// buffer, bind group and compute pass are labelled `"{label}.params"`, `"{label}.bg"` and
/// `"{label}.pass"` respectively. The grid is 1-D with
/// `ceil(n_patches * n_heads * head_dim / 64)` workgroups.
fn transpose_chained(
    ctx: &WgpuCtx,
    pipe: &wgpu::ComputePipeline,
    label: &str,
    enc: &mut wgpu::CommandEncoder,
    src: &wgpu::Buffer,
    dst: &wgpu::Buffer,
    n_patches: usize,
    n_heads: usize,
    head_dim: usize,
) {
    let device = &ctx.device;
    let queue = &ctx.queue;
    let params = TransposeParams {
        n_patches: n_patches as u32,
        n_heads: n_heads as u32,
        head_dim: head_dim as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, &format!("{label}.params"), &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some(&format!("{label}.bg")),
        layout: &pipe.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: src.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: dst.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(&format!("{label}.pass")),
        timestamp_writes: None,
    });
    cp.set_pipeline(pipe);
    cp.set_bind_group(0, &bg, &[]);
    // One invocation per element.
    // NOTE(review): this assumes the shader declares @workgroup_size(64). The x workgroup count
    // is capped by `max_compute_workgroups_per_dimension` (65535 by default). Confirm large
    // tensors stay below it.
    let total = (n_patches * n_heads * head_dim) as u32;
    cp.dispatch_workgroups(total.div_ceil(64), 1, 1);
}
/// f16 entry point for the subgroup HPD flash-attention launcher.
///
/// This is a pure forwarding shim. It records exactly the same dispatch as
/// `vision_attention_flash_sub_hpd_chained`, using whatever `pipe` the caller passes in.
/// Presumably `pipe` is the f16 build of that shader; nothing here checks it.
pub fn vision_attention_flash_sub_hpd_f16_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    pipe: &wgpu::ComputePipeline,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    vision_attention_flash_sub_hpd_chained(
        ctx,
        p,
        pipe,
        enc,
        q,
        k,
        v,
        out,
        head_dim,
        n_heads,
        n_patches,
    )
}
/// f16 HPD flash-attention entry point.
///
/// This is a pure forwarding shim to `vision_attention_flash_sub_hpd_chained`, using the
/// caller-supplied `pipe`.
/// NOTE(review): despite the missing `sub` in the name, this forwards to the subgroup HPD
/// launcher. It is currently identical to `vision_attention_flash_sub_hpd_f16_chained`, so
/// only the pipeline the caller passes distinguishes the two. Confirm this is intended.
pub fn vision_attention_flash_hpd_f16_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    pipe: &wgpu::ComputePipeline,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    vision_attention_flash_sub_hpd_chained(
        ctx,
        p,
        pipe,
        enc,
        q,
        k,
        v,
        out,
        head_dim,
        n_heads,
        n_patches,
    )
}
/// Records one dispatch of the caller-supplied pipeline `pipe` (the subgroup HPD f16 variant,
/// per the name) into `enc`; nothing is submitted here.
///
/// The bind group layout is taken from `pipe`. Group 0 bindings: a `VisionAttnParams` uniform
/// (0), then `q` (1), `k` (2), `v` (3) and `out` (4). The grid is
/// `(ceil(n_patches / 16), n_heads, 1)`. It presumably processes 16 query patches per
/// workgroup along x (confirm against the shader).
///
/// `p` is not used; it is kept so the signature matches the sibling launchers.
pub fn vision_attention_flash_sub_hpd_f16_q16_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    pipe: &wgpu::ComputePipeline,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    let _ = p;
    let device = &ctx.device;
    let queue = &ctx.queue;
    let params = VisionAttnParams {
        head_dim: head_dim as u32,
        n_heads: n_heads as u32,
        n_patches: n_patches as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "vattnSubHPDQ16.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("vattnSubHPDQ16.bg"),
        layout: &pipe.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: q.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: k.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: v.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 4,
                resource: out.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("vattnSubHPDQ16.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(pipe);
    cp.set_bind_group(0, &bg, &[]);
    // Round up so a trailing partial block of queries still gets a workgroup.
    let n_query_groups = (n_patches as u32).div_ceil(16);
    cp.dispatch_workgroups(n_query_groups, n_heads as u32, 1);
}
/// Records one dispatch of the caller-supplied subgroup HPD flash-attention pipeline `pipe`
/// into `enc`; nothing is submitted here.
///
/// The bind group layout is taken from `pipe`. Group 0 bindings: a `VisionAttnParams` uniform
/// (0), then `q` (1), `k` (2), `v` (3) and `out` (4). The grid is
/// `(ceil(n_patches / 8), n_heads, 1)`.
///
/// `p` is not used; it is kept so the signature matches the sibling launchers.
pub fn vision_attention_flash_sub_hpd_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    pipe: &wgpu::ComputePipeline,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    let _ = p;
    let device = &ctx.device;
    let queue = &ctx.queue;
    let params = VisionAttnParams {
        head_dim: head_dim as u32,
        n_heads: n_heads as u32,
        n_patches: n_patches as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "vattnSubHPD.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("vattnSubHPD.bg"),
        layout: &pipe.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: q.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: k.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: v.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 4,
                resource: out.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("vattnSubHPD.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(pipe);
    cp.set_bind_group(0, &bg, &[]);
    // Round up so a trailing partial block of queries still gets a workgroup.
    let n_query_groups = (n_patches as u32).div_ceil(8);
    cp.dispatch_workgroups(n_query_groups, n_heads as u32, 1);
}
/// Records one dispatch of the caller-supplied pipeline `pipe` (the subgroup "t64" variant,
/// per the name) into `enc`; nothing is submitted here.
///
/// The bind group layout is taken from `pipe`. Group 0 bindings: a `VisionAttnParams` uniform
/// (0), then `q` (1), `k` (2), `v` (3) and `out` (4). The grid is
/// `(ceil(n_patches / 8), n_heads, 1)`.
///
/// NOTE(review): despite `t64` in the name, queries are grouped by 8 here, as in the other
/// subgroup launchers. `t64` presumably refers to a key tile size inside the shader; confirm
/// that the shader's per-workgroup query count is 8.
///
/// `p` is not used; it is kept so the signature matches the sibling launchers.
pub fn vision_attention_flash_sub_t64_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    pipe: &wgpu::ComputePipeline,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    let _ = p;
    let device = &ctx.device;
    let queue = &ctx.queue;
    let params = VisionAttnParams {
        head_dim: head_dim as u32,
        n_heads: n_heads as u32,
        n_patches: n_patches as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "vattnSubT64.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("vattnSubT64.bg"),
        layout: &pipe.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: q.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: k.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: v.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 4,
                resource: out.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("vattnSubT64.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(pipe);
    cp.set_bind_group(0, &bg, &[]);
    // Round up so a trailing partial block of queries still gets a workgroup.
    let n_query_groups = (n_patches as u32).div_ceil(8);
    cp.dispatch_workgroups(n_query_groups, n_heads as u32, 1);
}
/// Records one dispatch of `p.vision_attention_flash_q4` into `enc`; nothing is submitted here.
///
/// Group 0 bindings: a `VisionAttnParams` uniform (0), then `q` (1), `k` (2), `v` (3) and
/// `out` (4). The grid is `(ceil(n_patches / 4), n_heads, 1)`. Each workgroup along x
/// presumably handles a block of 4 query patches (confirm against the shader).
pub fn vision_attention_flash_q4_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    out: &wgpu::Buffer,
    head_dim: usize,
    n_heads: usize,
    n_patches: usize,
) {
    let device = &ctx.device;
    let queue = &ctx.queue;
    let params = VisionAttnParams {
        head_dim: head_dim as u32,
        n_heads: n_heads as u32,
        n_patches: n_patches as u32,
        _pad: 0,
    };
    let p_buf = write_uniform(device, queue, "vattnq4.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("vattnq4.bg"),
        layout: &p.vision_attention_flash_q4.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: q.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: k.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: v.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 4,
                resource: out.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("vattnq4.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(&p.vision_attention_flash_q4);
    cp.set_bind_group(0, &bg, &[]);
    // Round up so a trailing partial block of queries still gets a workgroup.
    let n_query_groups = (n_patches as u32).div_ceil(4);
    cp.dispatch_workgroups(n_query_groups, n_heads as u32, 1);
}
/// Records one dispatch of `p.block_local_attention` into `enc`; nothing is submitted here.
///
/// All scalar arguments are packed into a `BlockLocalAttnParams` uniform at binding 0 of
/// group 0. The buffers follow: `q_pad` (1), `k_padded` (2), `v_padded` (3), `pos_proj` (4)
/// and `attn_out` (5). The grid is `(padded_len, n_heads, 1)`.
///
/// Per the assertion message, the WGSL is hard-coded to `head_dim == 128`. That is only
/// checked in debug builds, so a release build with another `head_dim` is not rejected here.
/// The meaning of the chunk/context/span/past/future/pad_left/logit_cap parameters is defined
/// by the shader.
pub fn block_local_attention_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    q_pad: &wgpu::Buffer,
    k_padded: &wgpu::Buffer,
    v_padded: &wgpu::Buffer,
    pos_proj: &wgpu::Buffer,
    attn_out: &wgpu::Buffer,
    seq: usize,
    padded_len: usize,
    hidden: usize,
    n_heads: usize,
    head_dim: usize,
    chunk_size: usize,
    context_size: usize,
    max_span: usize,
    max_past: usize,
    max_future: usize,
    pad_left: usize,
    logit_cap: f32,
) {
    debug_assert_eq!(
        head_dim, 128,
        "block_local_attention.wgsl is hard-coded to head_dim=128"
    );
    let device = &ctx.device;
    let queue = &ctx.queue;
    let params = BlockLocalAttnParams {
        seq: seq as u32,
        padded_len: padded_len as u32,
        hidden: hidden as u32,
        n_heads: n_heads as u32,
        head_dim: head_dim as u32,
        chunk_size: chunk_size as u32,
        context_size: context_size as u32,
        max_span: max_span as u32,
        max_past: max_past as u32,
        max_future: max_future as u32,
        pad_left: pad_left as u32,
        logit_cap,
    };
    let p_buf = write_uniform(device, queue, "blattn.params", &params);
    let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("blattn.bg"),
        layout: &p.block_local_attention.get_bind_group_layout(0),
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: p_buf.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: q_pad.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: k_padded.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 3,
                resource: v_padded.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 4,
                resource: pos_proj.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 5,
                resource: attn_out.as_entire_binding(),
            },
        ],
    });
    let mut cp = enc.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some("blattn.pass"),
        timestamp_writes: None,
    });
    cp.set_pipeline(&p.block_local_attention);
    cp.set_bind_group(0, &bg, &[]);
    // x = padded query position, y = head.
    cp.dispatch_workgroups(padded_len as u32, n_heads as u32, 1);
}