use crate::compile::{BufferRef, Dispatch, ExecutionPlan, ShaderEntry};
use std::collections::{HashMap, HashSet};
/// Shorthand for the blade GPU context used throughout this module.
type Gpu = blade_graphics::Context;
/// Bind group for the scatter-add kernel (`ShaderEntry::ScatterAdd`);
/// layout is derived by `blade_macros::ShaderData`.
#[derive(blade_macros::ShaderData)]
struct ScatterAddData {
    indices: blade_graphics::BufferPiece,
    src: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: ScatterAddParams,
}
/// Pod uniform block for the scatter-add kernel.
#[derive(Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
struct ScatterAddParams {
    total: u32,
    seq_len: u32,
    embed_dim: u32,
    // Pads the struct to a 16-byte multiple for GPU uniform layout.
    _pad: u32,
}
/// Aggregate GPU-buffer memory statistics; human-readable via `Display`.
#[derive(Clone, Debug)]
pub struct MemorySummary {
    pub total_buffer_bytes: usize,
    pub adam_state_bytes: usize,
    pub num_buffers: usize,
    pub largest_buffer_bytes: usize,
}
impl std::fmt::Display for MemorySummary {
    /// Renders the summary as one human-readable line, with byte counts
    /// shown in decimal megabytes (1e6 bytes).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mb = |bytes: usize| bytes as f64 / 1e6;
        write!(
            f,
            "{} buffers, {:.1} MB total ({:.1} MB adam state), largest {:.1} MB",
            self.num_buffers,
            mb(self.total_buffer_bytes),
            mb(self.adam_state_bytes),
            mb(self.largest_buffer_bytes),
        )
    }
}
// Minimum workgroup count before a matmul dispatch is switched to the
// cooperative-matrix path (see `Session::new`). The high-K value is a
// separate knob even though both thresholds are currently 32.
const MIN_COOP_WORKGROUPS: u32 = 32;
const MIN_COOP_WORKGROUPS_HIGH_K: u32 = 32;
/// Bind group shared by the plain and transposed matmul kernels
/// (`MatMul`, `MatMulAT`, `MatMulBT`): C = A × B plus `MatMulParams`.
#[derive(blade_macros::ShaderData)]
struct MatMulData {
    matrix_a: blade_graphics::BufferPiece,
    matrix_b: blade_graphics::BufferPiece,
    matrix_c: blade_graphics::BufferPiece,
    params: MatMulParams,
}
/// Bind group for the cooperative fused RMSNorm+matmul kernel.
/// NOTE(review): marked `#[allow(unused)]` — `shader_data_layout` currently
/// maps `FusedRmsNormMatMul` to `FourBufData`, so this appears to document
/// the intended coop layout rather than being bound anywhere in this chunk.
#[derive(blade_macros::ShaderData)]
#[allow(unused)]
struct FusedRmsNormMatMulCoopData {
    matrix_a: blade_graphics::BufferPiece,
    matrix_b: blade_graphics::BufferPiece,
    rsqrt_buf: blade_graphics::BufferPiece,
    w_norm: blade_graphics::BufferPiece,
    matrix_c: blade_graphics::BufferPiece,
    params: MatMulParams,
}
/// Bind group for the fused matmul-plus-add kernels (`FusedMatMulAdd` and
/// its AT/BT variants): `src` is the extra addend buffer.
#[derive(blade_macros::ShaderData)]
struct FusedMatMulAddData {
    matrix_a: blade_graphics::BufferPiece,
    matrix_b: blade_graphics::BufferPiece,
    matrix_c: blade_graphics::BufferPiece,
    src: blade_graphics::BufferPiece,
    params: MatMulParams,
}
/// Pod uniform block carrying the (m, n, k) problem size for matmul-family
/// kernels; also reused by several attention/norm layouts as a generic
/// four-u32 parameter slot.
#[derive(Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
struct MatMulParams {
    m: u32,
    n: u32,
    k: u32,
    // Pads the struct to a 16-byte multiple for GPU uniform layout.
    _pad: u32,
}
/// Bind group for one-input elementwise kernels (Relu, Sigmoid, reductions,
/// splits, upsample, etc. — see `shader_data_layout`).
#[derive(blade_macros::ShaderData)]
struct UnaryData {
    src: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: UnaryParams,
}
/// Pod uniform block for elementwise kernels: just the element count,
/// padded to 16 bytes.
#[derive(Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
struct UnaryParams {
    len: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}
/// Bind group for two-input elementwise kernels (Add, Mul, SwiGLU, Concat,
/// etc. — see `shader_data_layout`).
#[derive(blade_macros::ShaderData)]
struct BinaryData {
    src_a: blade_graphics::BufferPiece,
    src_b: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: UnaryParams,
}
/// Bind group for three-input elementwise kernels (used by `SwiGLUGradGate`).
#[derive(blade_macros::ShaderData)]
struct TernaryData {
    src_a: blade_graphics::BufferPiece,
    src_b: blade_graphics::BufferPiece,
    src_c: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: UnaryParams,
}
/// Bind group for the bias-add kernel: a broadcast `bias` buffer added
/// onto `src`.
#[derive(blade_macros::ShaderData)]
struct BiasAddData {
    src: blade_graphics::BufferPiece,
    bias: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: BiasAddParams,
}
/// Pod uniform block for bias-add (and reused by `RmsNormData`): total
/// element count plus the bias vector length.
#[derive(Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
struct BiasAddParams {
    len: u32,
    bias_len: u32,
    _pad0: u32,
    _pad1: u32,
}
/// Bind group for the SGD parameter-update kernel.
#[derive(blade_macros::ShaderData)]
struct SgdData {
    param: blade_graphics::BufferPiece,
    grad: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: SgdParams,
}
/// Pod uniform block for SGD: element count and learning rate.
#[derive(Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
struct SgdParams {
    len: u32,
    lr: f32,
    _pad0: u32,
    _pad1: u32,
}
/// Bind group for the Adam update kernel; `m` and `v` are the first- and
/// second-moment state buffers (allocated in `Session::new`).
#[derive(blade_macros::ShaderData)]
struct AdamData {
    param: blade_graphics::BufferPiece,
    grad: blade_graphics::BufferPiece,
    m: blade_graphics::BufferPiece,
    v: blade_graphics::BufferPiece,
    params: AdamParams,
}
/// Pod uniform block for Adam hyperparameters. `step` is carried as f32
/// (presumably for bias correction on the GPU — TODO confirm in the shader).
#[derive(Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
struct AdamParams {
    len: u32,
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    step: f32,
    // Pads the struct to a 16-byte multiple for GPU uniform layout.
    _pad0: u32,
    _pad1: u32,
}
/// Bind group for the RMSNorm forward kernel; reuses `BiasAddParams`
/// (total length + per-row length).
#[derive(blade_macros::ShaderData)]
struct RmsNormData {
    src: blade_graphics::BufferPiece,
    bias: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: BiasAddParams,
}
/// Bind group for the embedding-lookup kernel: `indices` selects rows
/// of `src` into `dst`.
#[derive(blade_macros::ShaderData)]
struct EmbeddingData {
    indices: blade_graphics::BufferPiece,
    src: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: UnaryParams,
}
/// Pod uniform block for rotary position embedding kernels.
#[derive(Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
struct RoPEParams {
    seq: u32,
    dim: u32,
    // Bit pattern of an f32 theta (presumably `f32::to_bits` — TODO confirm
    // against the WGSL side).
    theta_bits: u32,
    pos_offset: u32,
    head_dim: u32,
    // Pads the struct to a 16-byte multiple for GPU uniform layout.
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}
/// Bind group for the RoPE forward/backward kernels (`RoPE`, `RoPEGrad`).
#[derive(blade_macros::ShaderData)]
struct RoPEData {
    src: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: RoPEParams,
}
/// Generic four-buffer bind group shared by the norm gradient kernels and
/// the fused RMSNorm+matmul forward (see `shader_data_layout`).
#[derive(blade_macros::ShaderData)]
struct FourBufData {
    src_a: blade_graphics::BufferPiece,
    src_b: blade_graphics::BufferPiece,
    bias: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: MatMulParams,
}
/// Bind group for the single-pass attention kernels (causal, full, cross);
/// `lse` receives the log-sum-exp values used by the backward pass.
#[derive(blade_macros::ShaderData)]
struct AttentionData {
    src_a: blade_graphics::BufferPiece,
    src_b: blade_graphics::BufferPiece,
    bias: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    lse: blade_graphics::BufferPiece,
    params: MatMulParams,
}
/// Bind group for sliding-window attention; same buffers as `AttentionData`
/// but with a dedicated window-aware parameter block.
#[derive(blade_macros::ShaderData)]
struct SlidingWindowAttentionData {
    src_a: blade_graphics::BufferPiece,
    src_b: blade_graphics::BufferPiece,
    bias: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    lse: blade_graphics::BufferPiece,
    params: SlidingWindowAttentionParams,
}
/// Pod uniform block for sliding-window attention (supports grouped-query
/// attention via separate `num_heads` / `num_kv_heads`).
#[derive(Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
struct SlidingWindowAttentionParams {
    seq: u32,
    num_heads: u32,
    num_kv_heads: u32,
    head_dim: u32,
    window_size: u32,
    // Pads the struct to a 16-byte multiple for GPU uniform layout.
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}
/// Bind group for RoPE with a runtime position offset read from
/// `pos_offset_buf` (used during incremental decoding).
#[derive(blade_macros::ShaderData)]
struct RoPEDynamicData {
    src: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    pos_offset_buf: blade_graphics::BufferPiece,
    params: RoPEParams,
}
/// Bind group for the KV-cache write kernel; `kv_pos_buf` holds the
/// write position on the GPU side.
#[derive(blade_macros::ShaderData)]
struct CacheWriteData {
    src: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    kv_pos_buf: blade_graphics::BufferPiece,
    params: UnaryParams,
}
/// Bind group for GroupNorm forward (plain and SiLU-fused variants).
#[derive(blade_macros::ShaderData)]
struct GroupNormData {
    src: blade_graphics::BufferPiece,
    src_b: blade_graphics::BufferPiece,
    bias: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: GroupNormParams,
}
/// Pod uniform block for GroupNorm kernels.
#[derive(Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
struct GroupNormParams {
    batch: u32,
    channels: u32,
    spatial: u32,
    num_groups: u32,
    // Bit pattern of the f32 epsilon (presumably `f32::to_bits` — TODO
    // confirm against the WGSL side).
    eps_bits: u32,
    // Pads the struct to a 16-byte multiple for GPU uniform layout.
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}
/// Bind group for the GroupNorm input-gradient kernel.
#[derive(blade_macros::ShaderData)]
struct GroupNormGradInputData {
    src_a: blade_graphics::BufferPiece,
    src_b: blade_graphics::BufferPiece,
    bias: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: GroupNormParams,
}
/// Bind group for the GroupNorm weight/bias-gradient kernel.
#[derive(blade_macros::ShaderData)]
struct GroupNormGradWeightBiasData {
    src_a: blade_graphics::BufferPiece,
    src_b: blade_graphics::BufferPiece,
    bias: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: GroupNormParams,
}
/// Bind group for Conv2d forward kernels (direct and GEMM variants).
#[derive(blade_macros::ShaderData)]
struct Conv2dData {
    src: blade_graphics::BufferPiece,
    weight: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: Conv2dParams,
}
/// Bind group for Conv2d input-gradient kernels.
#[derive(blade_macros::ShaderData)]
struct Conv2dGradInputData {
    grad_out: blade_graphics::BufferPiece,
    weight: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: Conv2dParams,
}
/// Bind group for Conv2d weight-gradient kernels.
#[derive(blade_macros::ShaderData)]
struct Conv2dGradWeightData {
    grad_out: blade_graphics::BufferPiece,
    src: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: Conv2dParams,
}
/// Pod uniform block for Conv2d kernels (12 u32s = 48 bytes, so no pad
/// field is needed).
/// NOTE(review): the field order is irregular — `padding_h` sits mid-struct
/// while `padding_w` comes last, after `out_h`/`out_w`. This must stay in
/// sync with the generated WGSL struct; do not reorder.
#[derive(Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
struct Conv2dParams {
    batch: u32,
    in_channels: u32,
    in_h: u32,
    in_w: u32,
    out_channels: u32,
    kernel_h: u32,
    kernel_w: u32,
    stride: u32,
    padding_h: u32,
    out_h: u32,
    out_w: u32,
    padding_w: u32,
}
/// Bind group for the max-pooling kernel.
#[derive(blade_macros::ShaderData)]
struct MaxPool2dData {
    src: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: MaxPool2dParams,
}
/// Pod uniform block for MaxPool2d.
#[derive(Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
struct MaxPool2dParams {
    batch: u32,
    channels: u32,
    in_h: u32,
    in_w: u32,
    kernel_h: u32,
    kernel_w: u32,
    stride: u32,
    padding: u32,
    out_h: u32,
    out_w: u32,
    // Pads the struct to a 16-byte multiple for GPU uniform layout.
    _pad0: u32,
    _pad1: u32,
}
/// Bind group for the global average-pooling kernel.
#[derive(blade_macros::ShaderData)]
struct GlobalAvgPoolData {
    src: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: GlobalAvgPoolParams,
}
/// Pod uniform block for global average pooling.
#[derive(Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
struct GlobalAvgPoolParams {
    channels: u32,
    spatial: u32,
    total_out: u32,
    // Pads the struct to a 16-byte multiple for GPU uniform layout.
    _pad: u32,
}
/// Bind group for attention against the KV cache; `kv_pos_buf` carries the
/// current cache length on the GPU side.
#[derive(blade_macros::ShaderData)]
struct CachedAttentionData {
    src_a: blade_graphics::BufferPiece,
    src_b: blade_graphics::BufferPiece,
    bias: blade_graphics::BufferPiece,
    kv_pos_buf: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: MatMulParams,
}
/// Bind group for the LayerNorm forward kernel.
#[derive(blade_macros::ShaderData)]
struct LayerNormData {
    src: blade_graphics::BufferPiece,
    src_b: blade_graphics::BufferPiece,
    bias: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: MatMulParams,
}
/// Bind group for the row-wise softmax kernel.
#[derive(blade_macros::ShaderData)]
struct SoftmaxData {
    src: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: SoftmaxParams,
}
/// Pod uniform block for softmax-family kernels: row count and row width.
#[derive(Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
struct SoftmaxParams {
    batch: u32,
    features: u32,
    // Pads the struct to a 16-byte multiple for GPU uniform layout.
    _pad0: u32,
    _pad1: u32,
}
/// Bind group for the fused cross-entropy kernel, which writes both the
/// per-logit gradient and the loss value.
#[derive(blade_macros::ShaderData)]
struct CrossEntropyData {
    logits: blade_graphics::BufferPiece,
    labels: blade_graphics::BufferPiece,
    grad_out: blade_graphics::BufferPiece,
    loss_out: blade_graphics::BufferPiece,
    params: SoftmaxParams,
}
/// Bind group for the binary cross-entropy kernel (gradient + loss outputs).
#[derive(blade_macros::ShaderData)]
struct BceData {
    pred: blade_graphics::BufferPiece,
    labels: blade_graphics::BufferPiece,
    grad_out: blade_graphics::BufferPiece,
    loss_out: blade_graphics::BufferPiece,
    params: UnaryParams,
}
/// Bind group for the 2-D transpose kernel.
#[derive(blade_macros::ShaderData)]
struct TransposeData {
    src: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: TransposeParams,
}
/// Pod uniform block for transpose: source matrix dimensions.
#[derive(Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
struct TransposeParams {
    m: u32,
    n: u32,
    // Pads the struct to a 16-byte multiple for GPU uniform layout.
    _pad0: u32,
    _pad1: u32,
}
/// Bind group for the multi-head attention forward kernel; `lse` stores
/// log-sum-exp per row for the backward pass.
#[derive(blade_macros::ShaderData)]
struct MultiHeadAttnData {
    src_a: blade_graphics::BufferPiece,
    src_b: blade_graphics::BufferPiece,
    bias: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    lse: blade_graphics::BufferPiece,
    params: MatMulParams,
}
/// Pod uniform block for the attention backward kernels. Unlike the other
/// param structs this also derives `ShaderData`, so it can be bound directly
/// as well as embedded in `MultiHeadAttnGradData`.
#[derive(blade_macros::ShaderData, Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
struct AttentionGradParams {
    q_seq: u32,
    kv_seq: u32,
    packed_heads: u32,
    head_dim: u32,
    window_size: u32,
    // Pads the struct to a 16-byte multiple for GPU uniform layout.
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}
/// Bind group for the multi-head attention backward kernels (dQ/dK/dV):
/// incoming gradient, the forward inputs/outputs, saved `lse`, and the
/// destination gradient buffer.
#[derive(blade_macros::ShaderData)]
struct MultiHeadAttnGradData {
    d_out: blade_graphics::BufferPiece,
    src_a: blade_graphics::BufferPiece,
    src_b: blade_graphics::BufferPiece,
    bias: blade_graphics::BufferPiece,
    lse: blade_graphics::BufferPiece,
    fwd_dst: blade_graphics::BufferPiece,
    dst: blade_graphics::BufferPiece,
    params: AttentionGradParams,
}
/// All compiled compute pipelines for a plan, keyed by shader entry point.
struct Pipelines {
    // Regular pipelines (the fallback `get` always expects to find here).
    map: HashMap<ShaderEntry, blade_graphics::ComputePipeline>,
    // Cooperative-matrix variants, preferred when `Dispatch::use_coop`.
    coop_map: HashMap<ShaderEntry, blade_graphics::ComputePipeline>,
    // Small-tile matmul variants, preferred when `Dispatch::use_small_tiles`.
    small_map: HashMap<ShaderEntry, blade_graphics::ComputePipeline>,
    // Matmuls specialized with a fused epilogue, keyed by (entry, epilogue).
    epilogue_map:
        HashMap<(ShaderEntry, Vec<crate::compile::EpilogueOp>), blade_graphics::ComputePipeline>,
}
impl Pipelines {
    /// Compiles every compute pipeline the plan needs, one generated WGSL
    /// module per shader group, then one pipeline per entry point used from
    /// that group. Builds four families: regular, small-tile, cooperative
    /// (with a plain-module fallback when coop is unsupported), and
    /// epilogue-fused matmuls.
    fn new(
        gpu: &Gpu,
        plan: &ExecutionPlan,
        coop_config: Option<&crate::codegen::CoopConfig>,
    ) -> Self {
        use crate::codegen::ShaderGroup;
        use blade_graphics as bg;
        // First pass over the plan: record which shader groups are needed
        // and which entry points each group must expose.
        let mut needed: HashSet<ShaderGroup> = HashSet::new();
        let mut needed_coop: HashSet<ShaderGroup> = HashSet::new();
        let mut entries_for_group: HashMap<ShaderGroup, HashSet<ShaderEntry>> = HashMap::new();
        for dispatch in &plan.dispatches {
            let group = dispatch.shader.shader_group();
            needed.insert(group);
            entries_for_group
                .entry(group)
                .or_default()
                .insert(dispatch.shader.clone());
            if dispatch.use_small_tiles {
                // Small-tile variants exist only for the matmul family.
                // NOTE(review): the `_ => continue` skips the rest of the
                // loop body, including the `use_coop` check below, for any
                // non-matmul dispatch with `use_small_tiles` set — presumably
                // the two flags are never combined that way; TODO confirm.
                let small_group = match group {
                    ShaderGroup::MatMul => ShaderGroup::MatMulSmall,
                    ShaderGroup::MatMulAdd => ShaderGroup::MatMulSmallAdd,
                    ShaderGroup::MatMulAT => ShaderGroup::MatMulSmallAT,
                    ShaderGroup::MatMulBT => ShaderGroup::MatMulSmallBT,
                    _ => continue,
                };
                needed.insert(small_group);
                entries_for_group
                    .entry(small_group)
                    .or_default()
                    .insert(dispatch.shader.clone());
            }
            if dispatch.use_coop {
                let coop_group = match group {
                    ShaderGroup::MatMul => ShaderGroup::MatMulCoop,
                    ShaderGroup::MatMulAdd => ShaderGroup::MatMulCoopAdd,
                    ShaderGroup::MatMulAT => ShaderGroup::MatMulCoopAT,
                    ShaderGroup::MatMulBT => ShaderGroup::MatMulCoopBT,
                    ShaderGroup::Conv2dGradInputGemm => ShaderGroup::Conv2dGradInputGemmCoop,
                    ShaderGroup::FusedRmsNormMatMul => ShaderGroup::FusedRmsNormMatMulCoop,
                    _ => continue,
                };
                needed_coop.insert(coop_group);
                entries_for_group
                    .entry(coop_group)
                    .or_default()
                    .insert(dispatch.shader.clone());
            }
        }
        // Optimizer pipelines are needed whenever there are trainable params.
        if !plan.param_grad_pairs.is_empty() {
            needed.insert(ShaderGroup::Sgd);
            entries_for_group
                .entry(ShaderGroup::Sgd)
                .or_default()
                .insert(ShaderEntry::SgdUpdate);
            needed.insert(ShaderGroup::Adam);
            entries_for_group
                .entry(ShaderGroup::Adam)
                .or_default()
                .insert(ShaderEntry::AdamUpdate);
        }
        let mut map = HashMap::new();
        let mut coop_map = HashMap::new();
        let mut small_map = HashMap::new();
        // Generates the group's module once, then builds one pipeline per
        // required entry point into `target`.
        let compile_group =
            |group: ShaderGroup,
             target: &mut HashMap<ShaderEntry, blade_graphics::ComputePipeline>| {
                let sm = crate::codegen::generate_module(group);
                let shader = gpu.create_shader(bg::ShaderDesc {
                    source: &sm.source,
                    naga_module: Some(sm.module),
                });
                if let Some(entries) = entries_for_group.get(&group) {
                    for entry in entries {
                        let layout = shader_data_layout(entry);
                        let pipeline = gpu.create_compute_pipeline(bg::ComputePipelineDesc {
                            name: entry.entry_point(),
                            data_layouts: &[&layout],
                            compute: shader.at(entry.entry_point()),
                        });
                        target.insert(entry.clone(), pipeline);
                    }
                }
            };
        // Small-tile groups land in their own map so `get` only prefers them
        // when the dispatch asked for small tiles.
        let small_tile_groups: HashSet<ShaderGroup> = [
            ShaderGroup::MatMulSmall,
            ShaderGroup::MatMulSmallAdd,
            ShaderGroup::MatMulSmallAT,
            ShaderGroup::MatMulSmallBT,
        ]
        .into_iter()
        .collect();
        for &group in &needed {
            if small_tile_groups.contains(&group) {
                compile_group(group, &mut small_map);
            } else {
                compile_group(group, &mut map);
            }
        }
        for &group in &needed_coop {
            if let Some(config) = coop_config {
                let sm = crate::codegen::generate_coop_module(group, config);
                let shader = gpu.create_shader(bg::ShaderDesc {
                    source: &sm.source,
                    naga_module: Some(sm.module),
                });
                if let Some(entries) = entries_for_group.get(&group) {
                    for entry in entries {
                        let layout = shader_data_layout(entry);
                        let pipeline = gpu.create_compute_pipeline(bg::ComputePipelineDesc {
                            name: entry.entry_point(),
                            data_layouts: &[&layout],
                            compute: shader.at(entry.entry_point()),
                        });
                        coop_map.insert(entry.clone(), pipeline);
                    }
                }
            } else {
                // No cooperative-matrix support: compile the plain module
                // into the coop map so coop dispatches still resolve.
                compile_group(group, &mut coop_map);
            }
        }
        // Epilogue-fused matmuls are specialized per (entry, epilogue) pair;
        // duplicate pairs reuse the first compilation.
        let mut epilogue_map = HashMap::new();
        for dispatch in &plan.dispatches {
            if dispatch.epilogue.is_empty() {
                continue;
            }
            let key = (dispatch.shader.clone(), dispatch.epilogue.clone());
            if epilogue_map.contains_key(&key) {
                continue;
            }
            let group = dispatch.shader.shader_group();
            let sm = crate::codegen::generate_matmul_with_epilogue(group, &dispatch.epilogue);
            let shader = gpu.create_shader(bg::ShaderDesc {
                source: &sm.source,
                naga_module: Some(sm.module),
            });
            let layout = shader_data_layout(&dispatch.shader);
            let pipeline = gpu.create_compute_pipeline(bg::ComputePipelineDesc {
                name: dispatch.shader.entry_point(),
                data_layouts: &[&layout],
                compute: shader.at(dispatch.shader.entry_point()),
            });
            epilogue_map.insert(key, pipeline);
        }
        Self {
            map,
            coop_map,
            small_map,
            epilogue_map,
        }
    }
    /// Resolves the pipeline for a dispatch, most specialized first:
    /// epilogue-fused → cooperative → small-tile → regular. Panics (via the
    /// final indexing) if the entry was never compiled into `map`.
    fn get(&self, dispatch: &Dispatch) -> &blade_graphics::ComputePipeline {
        if !dispatch.epilogue.is_empty() {
            let key = (dispatch.shader.clone(), dispatch.epilogue.clone());
            if let Some(p) = self.epilogue_map.get(&key) {
                return p;
            }
        }
        if dispatch.use_coop {
            if let Some(p) = self.coop_map.get(&dispatch.shader) {
                return p;
            }
        }
        if dispatch.use_small_tiles {
            if let Some(p) = self.small_map.get(&dispatch.shader) {
                return p;
            }
        }
        &self.map[&dispatch.shader]
    }
}
fn shader_data_layout(entry: &ShaderEntry) -> blade_graphics::ShaderDataLayout {
use blade_graphics::ShaderData;
match *entry {
ShaderEntry::MatMul | ShaderEntry::MatMulAT | ShaderEntry::MatMulBT => MatMulData::layout(),
ShaderEntry::FusedMatMulAdd
| ShaderEntry::FusedMatMulATAdd
| ShaderEntry::FusedMatMulBTAdd => FusedMatMulAddData::layout(),
ShaderEntry::Relu
| ShaderEntry::Sigmoid
| ShaderEntry::Tanh
| ShaderEntry::Neg
| ShaderEntry::Abs
| ShaderEntry::Log
| ShaderEntry::Recip
| ShaderEntry::Silu => UnaryData::layout(),
ShaderEntry::Add | ShaderEntry::Mul | ShaderEntry::Greater | ShaderEntry::SwiGLU => {
BinaryData::layout()
}
ShaderEntry::BiasAdd => BiasAddData::layout(),
ShaderEntry::SgdUpdate => SgdData::layout(),
ShaderEntry::AdamUpdate => AdamData::layout(),
ShaderEntry::ScatterAdd => ScatterAddData::layout(),
ShaderEntry::SwiGLUConcat | ShaderEntry::SwiGLUConcatGrad => BinaryData::layout(),
ShaderEntry::SumAll | ShaderEntry::MeanAll | ShaderEntry::SumRows => UnaryData::layout(),
ShaderEntry::Softmax => SoftmaxData::layout(),
ShaderEntry::CrossEntropyLoss => CrossEntropyData::layout(),
ShaderEntry::BceLoss => BceData::layout(),
ShaderEntry::Transpose => TransposeData::layout(),
ShaderEntry::RmsNorm => RmsNormData::layout(),
ShaderEntry::Embedding => EmbeddingData::layout(),
ShaderEntry::RoPE | ShaderEntry::RoPEGrad => RoPEData::layout(),
ShaderEntry::CausalAttention => AttentionData::layout(),
ShaderEntry::CausalAttentionRoPE => AttentionData::layout(),
ShaderEntry::SlidingWindowAttention => SlidingWindowAttentionData::layout(),
ShaderEntry::Gelu => UnaryData::layout(),
ShaderEntry::LayerNorm => LayerNormData::layout(),
ShaderEntry::FullAttention | ShaderEntry::CrossAttention => AttentionData::layout(),
ShaderEntry::MultiHeadAttn => MultiHeadAttnData::layout(),
ShaderEntry::MultiHeadAttnGradQ
| ShaderEntry::MultiHeadAttnGradK
| ShaderEntry::MultiHeadAttnGradV => MultiHeadAttnGradData::layout(),
ShaderEntry::SwiGLUGradGate => TernaryData::layout(),
ShaderEntry::SwiGLUGradUp | ShaderEntry::SiluGrad => BinaryData::layout(),
ShaderEntry::RmsNormGradW | ShaderEntry::RmsNormGradX => FourBufData::layout(),
ShaderEntry::LayerNormGradWB | ShaderEntry::LayerNormGradX => FourBufData::layout(),
ShaderEntry::FusedRmsNormMatMul => FourBufData::layout(),
ShaderEntry::RmsNormRsqrt => UnaryData::layout(),
ShaderEntry::GroupNorm | ShaderEntry::GroupNormSilu => GroupNormData::layout(),
ShaderEntry::GroupNormGradInput => GroupNormGradInputData::layout(),
ShaderEntry::GroupNormGradWeightBias => GroupNormGradWeightBiasData::layout(),
ShaderEntry::Concat => BinaryData::layout(),
ShaderEntry::SplitA | ShaderEntry::SplitB => UnaryData::layout(),
ShaderEntry::Upsample2x | ShaderEntry::Upsample2xGrad => UnaryData::layout(),
ShaderEntry::Conv2d => Conv2dData::layout(),
ShaderEntry::Conv2dGemm | ShaderEntry::Conv2dGemmSmall => Conv2dData::layout(),
ShaderEntry::Conv2dGradInput => Conv2dGradInputData::layout(),
ShaderEntry::Conv2dGradInputGemm
| ShaderEntry::Conv2dGradInputGemmSmall
| ShaderEntry::Conv2dGradInputGemmCoop => Conv2dGradInputData::layout(),
ShaderEntry::Conv2dGradWeight
| ShaderEntry::Conv2dGradWeightGemm
| ShaderEntry::Conv2dGradWeightGemmSmall => Conv2dGradWeightData::layout(),
ShaderEntry::RoPEDynamic => RoPEDynamicData::layout(),
ShaderEntry::CacheWrite => CacheWriteData::layout(),
ShaderEntry::CachedAttention => CachedAttentionData::layout(),
ShaderEntry::MaxPool2d => MaxPool2dData::layout(),
ShaderEntry::GlobalAvgPool => GlobalAvgPoolData::layout(),
ShaderEntry::GlobalAvgPoolGrad => UnaryData::layout(),
}
}
/// Reorders dispatches by dependency depth: each dispatch's level is one
/// more than the deepest producer of any of its input buffers, and same-level
/// dispatches keep their original relative order. This groups independent
/// work together so later barrier grouping can batch it into fewer passes.
fn reorder_by_level(dispatches: &mut Vec<Dispatch>) {
    if dispatches.is_empty() {
        return;
    }
    let count = dispatches.len();
    // Most recent producer (dispatch index) for every buffer id seen so far.
    let mut last_writer: HashMap<u32, usize> = HashMap::new();
    let mut depth = vec![0u32; count];
    for (idx, dispatch) in dispatches.iter().enumerate() {
        depth[idx] = dispatch
            .input_buffers
            .iter()
            .filter_map(|buf| last_writer.get(&buf.0))
            .map(|&writer| depth[writer] + 1)
            .max()
            .unwrap_or(0);
        last_writer.insert(dispatch.output_buffer.0, idx);
        for &extra in &dispatch.extra_outputs {
            last_writer.insert(extra.0, idx);
        }
    }
    // Sorting by (level, original index) with an unstable sort is exactly
    // equivalent to a stable sort on level alone.
    let mut order: Vec<usize> = (0..count).collect();
    order.sort_unstable_by_key(|&idx| (depth[idx], idx));
    let original = std::mem::take(dispatches);
    *dispatches = order.iter().map(|&idx| original[idx].clone()).collect();
}
/// Splits the dispatch list into contiguous ranges that can share a single
/// compute pass: a new group starts as soon as a dispatch would read a buffer
/// written earlier in the current group (read-after-write hazard).
fn compute_groups(dispatches: &[Dispatch]) -> Vec<std::ops::Range<usize>> {
    let mut groups = Vec::new();
    // Buffer ids written by dispatches in the group under construction.
    let mut written_this_group = HashSet::<u32>::new();
    let mut group_start = 0;
    for (idx, dispatch) in dispatches.iter().enumerate() {
        let hazard = dispatch
            .input_buffers
            .iter()
            .any(|buf| written_this_group.contains(&buf.0));
        if hazard {
            // Close the current group just before the hazardous dispatch.
            groups.push(group_start..idx);
            group_start = idx;
            written_this_group.clear();
        }
        written_this_group.insert(dispatch.output_buffer.0);
        for &extra in &dispatch.extra_outputs {
            written_this_group.insert(extra.0);
        }
    }
    if !dispatches.is_empty() {
        groups.push(group_start..dispatches.len());
    }
    groups
}
/// A compiled, GPU-resident execution session: owns the device context, all
/// plan buffers, the compiled pipelines, and per-parameter optimizer state.
pub struct Session {
    gpu: Gpu,
    // One host-visible GPU buffer per `plan.buffers` entry.
    buffers: Vec<blade_graphics::Buffer>,
    pipelines: Pipelines,
    plan: ExecutionPlan,
    // Barrier groups: ranges of `plan.dispatches` sharing one compute pass.
    groups: Vec<std::ops::Range<usize>>,
    encoder: blade_graphics::CommandEncoder,
    // Sync point of the last submission, if one is in flight.
    sync_point: Option<blade_graphics::SyncPoint>,
    last_submit_ns: u64,
    profiling: bool,
    // Pending learning-rate override, consumed by code outside this chunk.
    pending_lr: Option<f32>,
    // Per-parameter (m, v) Adam moment buffers, parallel to param_grad_pairs.
    adam_state: Vec<(blade_graphics::Buffer, blade_graphics::Buffer)>,
    adam_step: u32,
    // Presumably (lr, beta1, beta2, eps) queued for the next Adam step —
    // TODO confirm against the optimizer-step code (not in view here).
    pending_adam: Option<(f32, f32, f32, f32)>,
}
impl Session {
fn select_coop_config(
caps: &blade_graphics::CooperativeMatrix,
) -> Option<crate::codegen::CoopConfig> {
use crate::codegen::CoopConfig;
if caps.f16_tile > 0 {
Some(CoopConfig {
tile_size: caps.f16_tile,
use_f16_input: true,
})
} else if caps.f32_tile > 0 {
Some(CoopConfig {
tile_size: caps.f32_tile,
use_f16_input: false,
})
} else {
None
}
}
    /// Compiles the cooperative matmul shader and runs a small known-answer
    /// multiply on the device, returning `true` only when the results match.
    /// Some drivers accept the shader but compute wrong values; this catches
    /// both rejection and miscompilation before coop is enabled for real work.
    fn test_coop_matmul(gpu: &Gpu, config: &crate::codegen::CoopConfig) -> bool {
        use crate::codegen::ShaderGroup;
        use blade_graphics as bg;
        let sm = crate::codegen::generate_coop_module(ShaderGroup::MatMulCoop, config);
        let shader = match gpu.try_create_shader(bg::ShaderDesc {
            source: &sm.source,
            naga_module: Some(sm.module),
        }) {
            Ok(s) => s,
            Err(e) => {
                log::warn!("cooperative matmul shader rejected: {}", e);
                return false;
            }
        };
        let layout = shader_data_layout(&ShaderEntry::MatMul);
        let mut pipeline = gpu.create_compute_pipeline(bg::ComputePipelineDesc {
            name: "main",
            data_layouts: &[&layout],
            compute: shader.at("main"),
        });
        // Problem size: one output tile tall, one tile deep, two tiles wide.
        let ot = config.output_tile() as usize;
        let m: usize = ot;
        let inner: usize = ot;
        let n_out: usize = ot * 2;
        let a_size = (m * inner * 4) as u64;
        let b_size = (inner * n_out * 4) as u64;
        let c_size = (m * n_out * 4) as u64;
        let a_buf = gpu.create_buffer(bg::BufferDesc {
            name: "test_a",
            size: a_size,
            memory: bg::Memory::Shared,
        });
        let b_buf = gpu.create_buffer(bg::BufferDesc {
            name: "test_b",
            size: b_size,
            memory: bg::Memory::Shared,
        });
        let c_buf = gpu.create_buffer(bg::BufferDesc {
            name: "test_c",
            size: c_size,
            memory: bg::Memory::Shared,
        });
        // Fill A[i][*] = i+1 and B[*][j] = j+1 so C[i][j] = k*(i+1)*(j+1).
        // SAFETY: all three buffers are host-visible (Memory::Shared) and
        // were allocated with exactly these element counts × 4 bytes.
        unsafe {
            let a = std::slice::from_raw_parts_mut(a_buf.data() as *mut f32, m * inner);
            let b = std::slice::from_raw_parts_mut(b_buf.data() as *mut f32, inner * n_out);
            let c = std::slice::from_raw_parts_mut(c_buf.data() as *mut f32, m * n_out);
            for i in 0..m {
                for j in 0..inner {
                    a[i * inner + j] = (i + 1) as f32;
                }
            }
            for i in 0..inner {
                for j in 0..n_out {
                    b[i * n_out + j] = (j + 1) as f32;
                }
            }
            c.fill(0.0);
        }
        let mut encoder = gpu.create_command_encoder(bg::CommandEncoderDesc {
            name: "coop_test",
            buffer_count: 2,
        });
        encoder.start();
        {
            let mut pass = encoder.compute("coop_test");
            let mut pc = pass.with(&pipeline);
            let ot = config.output_tile();
            pc.bind(
                0,
                &MatMulData {
                    matrix_a: a_buf.at(0),
                    matrix_b: b_buf.at(0),
                    matrix_c: c_buf.at(0),
                    params: MatMulParams {
                        m: m as u32,
                        n: n_out as u32,
                        k: inner as u32,
                        _pad: 0,
                    },
                },
            );
            // One workgroup per output tile.
            pc.dispatch([(m as u32).div_ceil(ot), (n_out as u32).div_ceil(ot), 1]);
        }
        // Submit and block until the GPU finishes, then read back C.
        let sp = gpu.submit(&mut encoder);
        let _ = gpu.wait_for(&sp, !0);
        // SAFETY: c_buf is host-visible and holds m * n_out f32 values.
        let result =
            unsafe { std::slice::from_raw_parts(c_buf.data() as *const f32, m * n_out).to_vec() };
        // Release all transient GPU objects before validating.
        gpu.destroy_command_encoder(&mut encoder);
        gpu.destroy_compute_pipeline(&mut pipeline);
        gpu.destroy_buffer(a_buf);
        gpu.destroy_buffer(b_buf);
        gpu.destroy_buffer(c_buf);
        // Validate against the closed-form answer; log details only for the
        // first mismatch but keep scanning so `ok` reflects the whole matrix.
        let mut ok = true;
        let mut first_mismatch = true;
        for i in 0..m {
            for j in 0..n_out {
                let expected = (inner as f32) * (i + 1) as f32 * (j + 1) as f32;
                let got = result[i * n_out + j];
                if (got - expected).abs() > 0.5 {
                    if first_mismatch {
                        log::warn!(
                            "coop self-test FAILED (m={m}, n={n_out}, k={inner}, tile={})",
                            config.tile_size
                        );
                        log::warn!(" row 0: {:?}", &result[0..n_out]);
                        if m > 1 {
                            log::warn!(" row 1: {:?}", &result[n_out..2 * n_out]);
                        }
                        log::warn!(
                            " expected row 0: {:?}",
                            (0..n_out)
                                .map(|j| inner as f32 * 1.0 * (j + 1) as f32)
                                .collect::<Vec<_>>()
                        );
                        first_mismatch = false;
                    }
                    ok = false;
                }
            }
        }
        ok
    }
    /// Builds a session from a compiled plan: initializes the GPU context,
    /// probes (and self-tests) cooperative-matrix support, upgrades eligible
    /// matmul dispatches to the coop path, reorders and groups dispatches,
    /// allocates and zeroes all buffers, uploads constants, and compiles all
    /// pipelines plus Adam optimizer state.
    pub fn new(plan: ExecutionPlan) -> Self {
        // Context init: validation in debug builds only; timings always on
        // so `dump_gpu_timings` has data. Device selectable via env var.
        let gpu = unsafe {
            blade_graphics::Context::init(blade_graphics::ContextDesc {
                validation: cfg!(debug_assertions),
                timing: true,
                capture: false,
                overlay: false,
                device_id: std::env::var("MEGANEURA_DEVICE_ID")
                    .ok()
                    .and_then(|s| s.parse().ok()),
                ..Default::default()
            })
        }
        .expect("failed to initialize blade GPU context");
        // Keep coop only if the device advertises it AND passes the
        // known-answer self-test (some drivers miscompute).
        let coop_caps = gpu.capabilities().cooperative_matrix;
        let coop_config = Self::select_coop_config(&coop_caps)
            .filter(|config| Self::test_coop_matmul(&gpu, config));
        if let Some(ref config) = coop_config {
            log::info!(
                "cooperative matrix enabled (tile={}×{}, {}, f32_tile={}, f16_tile={})",
                config.tile_size,
                config.tile_size,
                if config.use_f16_input {
                    "f16→f32"
                } else {
                    "f32"
                },
                coop_caps.f32_tile,
                coop_caps.f16_tile,
            );
        } else {
            let info = gpu.device_information();
            log::warn!(
                "cooperative matrix not available on {} ({}) (f32_tile={}, f16_tile={}); using naive matmul",
                info.device_name,
                info.driver_name,
                coop_caps.f32_tile,
                coop_caps.f16_tile,
            );
        }
        let mut plan = plan;
        if let Some(ref config) = coop_config {
            use crate::codegen::ShaderGroup;
            let output_tile = config.output_tile();
            let _half_tile = config.tile_size;
            // Decide per dispatch whether the coop path is worth it, based on
            // how many output-tile workgroups the problem would launch.
            for dispatch in &mut plan.dispatches {
                let group = dispatch.shader.shader_group();
                // Extract (m, n, k, batch) from the group-specific param
                // packing; non-matmul groups are skipped.
                let (m, n, k, batch) = match group {
                    ShaderGroup::MatMul | ShaderGroup::MatMulAdd => (
                        dispatch.params[0],
                        dispatch.params[2],
                        dispatch.params[1],
                        1u32,
                    ),
                    ShaderGroup::Conv2dGradInputGemm => {
                        let in_ch = dispatch.params[1];
                        let in_h = dispatch.params[2];
                        let in_w = dispatch.params[3];
                        let out_ch = dispatch.params[4];
                        let kh = dispatch.params[5];
                        let kw = dispatch.params[6];
                        (in_ch, in_h * in_w, out_ch * kh * kw, dispatch.params[0])
                    }
                    ShaderGroup::FusedRmsNormMatMul => (
                        dispatch.params[0], dispatch.params[1], dispatch.params[2], 1u32,
                    ),
                    ShaderGroup::MatMulAT | ShaderGroup::MatMulBT => (
                        dispatch.params[0],
                        dispatch.params[1],
                        dispatch.params[2],
                        1u32,
                    ),
                    _ => continue,
                };
                let coop_wgs = m.div_ceil(output_tile) * n.div_ceil(output_tile) * batch;
                let is_conv = matches!(group, ShaderGroup::Conv2dGradInputGemm);
                // Conv grad needs a higher occupancy bar before coop pays off.
                let min_wgs = if is_conv {
                    512
                } else if k >= 1024 {
                    MIN_COOP_WORKGROUPS_HIGH_K
                } else {
                    MIN_COOP_WORKGROUPS
                };
                if coop_wgs >= min_wgs {
                    dispatch.use_coop = true;
                    dispatch.workgroups = [m.div_ceil(output_tile), n.div_ceil(output_tile), batch];
                    // Coop kernels write whole output tiles, so the output
                    // (and the fused-add source, if present) must be grown to
                    // the tile-padded size.
                    let padded_m = m.div_ceil(output_tile) * output_tile;
                    let padded_n = n.div_ceil(output_tile) * output_tile;
                    let padded_bytes = (padded_m * padded_n * batch * 4) as usize;
                    let buf_idx = dispatch.output_buffer.0 as usize;
                    if plan.buffers[buf_idx] < padded_bytes {
                        plan.buffers[buf_idx] = padded_bytes;
                    }
                    if dispatch.input_buffers.len() > 2 {
                        let src_idx = dispatch.input_buffers[2].0 as usize;
                        if plan.buffers[src_idx] < padded_bytes {
                            plan.buffers[src_idx] = padded_bytes;
                        }
                    }
                }
            }
        }
        // Level-sort dispatches, then split them into barrier groups (or one
        // group per dispatch when serial mode is forced via env var).
        reorder_by_level(&mut plan.dispatches);
        let groups = if std::env::var("MEGANEURA_SERIAL_DISPATCH").is_ok() {
            log::info!("MEGANEURA_SERIAL_DISPATCH: forcing one dispatch per pass");
            (0..plan.dispatches.len()).map(|i| i..i + 1).collect()
        } else {
            compute_groups(&plan.dispatches)
        };
        log::info!(
            "{} dispatches → {} barrier groups",
            plan.dispatches.len(),
            groups.len()
        );
        // Sanity pass: warn about read-after-write hazards left inside any
        // single group (should not happen if compute_groups is correct).
        for group in &groups {
            let mut written = HashSet::<u32>::new();
            for i in group.clone() {
                let d = &plan.dispatches[i];
                for ib in &d.input_buffers {
                    if written.contains(&ib.0) {
                        log::warn!(
                            "RAW hazard in group: dispatch {} ({:?}) reads buf {} written earlier in same group",
                            i,
                            d.shader,
                            ib.0
                        );
                    }
                }
                written.insert(d.output_buffer.0);
                for eo in &d.extra_outputs {
                    written.insert(eo.0);
                }
            }
        }
        // Allocate and zero every plan buffer (minimum 4 bytes each).
        let buffers: Vec<blade_graphics::Buffer> = plan
            .buffers
            .iter()
            .enumerate()
            .map(|(i, &size)| {
                let size = size.max(4);
                let buf = gpu.create_buffer(blade_graphics::BufferDesc {
                    name: &format!("buf_{}", i),
                    size: size as u64,
                    memory: blade_graphics::Memory::Shared,
                });
                // SAFETY: buffer is host-visible and `size` bytes long.
                unsafe {
                    std::ptr::write_bytes(buf.data(), 0, size);
                }
                buf
            })
            .collect();
        // Upload compile-time constant data into its buffers.
        for &(buf_ref, ref data) in &plan.constant_buffers {
            let buffer = &buffers[buf_ref.0 as usize];
            // SAFETY: host-visible buffer; assumes the plan sized it to hold
            // at least `data.len()` f32 values.
            unsafe {
                let ptr = buffer.data() as *mut f32;
                std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, data.len());
            }
        }
        let pipelines = Pipelines::new(&gpu, &plan, coop_config.as_ref());
        let encoder = gpu.create_command_encoder(blade_graphics::CommandEncoderDesc {
            name: "meganeura",
            buffer_count: 2,
        });
        // One zeroed (m, v) moment-buffer pair per trainable parameter.
        let adam_state = plan
            .param_grad_pairs
            .iter()
            .enumerate()
            .map(|(i, &(param_buf, _))| {
                let size = (plan.buffers[param_buf.0 as usize] as u64).max(4);
                let m_buf = gpu.create_buffer(blade_graphics::BufferDesc {
                    name: &format!("adam_m_{i}"),
                    size,
                    memory: blade_graphics::Memory::Shared,
                });
                let v_buf = gpu.create_buffer(blade_graphics::BufferDesc {
                    name: &format!("adam_v_{i}"),
                    size,
                    memory: blade_graphics::Memory::Shared,
                });
                // SAFETY: both buffers are host-visible and `size` bytes long.
                unsafe {
                    std::ptr::write_bytes(m_buf.data(), 0, size as usize);
                    std::ptr::write_bytes(v_buf.data(), 0, size as usize);
                }
                (m_buf, v_buf)
            })
            .collect();
        Self {
            gpu,
            buffers,
            pipelines,
            plan,
            groups,
            encoder,
            sync_point: None,
            last_submit_ns: 0,
            profiling: false,
            pending_lr: None,
            adam_state,
            adam_step: 0,
            pending_adam: None,
        }
    }
    /// Enables or disables per-run profiling output.
    pub fn set_profiling(&mut self, enabled: bool) {
        self.profiling = enabled;
    }
    /// Uploads `data` into the named parameter's buffer, then refreshes any
    /// derived parameter buffers (row-major column-wise concatenations of
    /// several source parameters) that include this parameter.
    ///
    /// # Panics
    /// Panics if `name` is not a known parameter.
    pub fn set_parameter(&mut self, name: &str, data: &[f32]) {
        for &(ref param_name, buf_ref) in &self.plan.param_buffers {
            if param_name == name {
                self.upload_buffer(buf_ref, bytemuck::cast_slice(data));
                // Each derived buffer is rows × total_cols, where total_cols
                // is the sum of all source widths; copy this parameter's
                // column slice into place, row by row.
                for entry in &self.plan.derived_params {
                    let derived_buf = &entry.0;
                    let sources = &entry.1;
                    let total_cols: usize = sources.iter().map(|s| s.1).sum();
                    let buf_f32 = self.plan.buffers[derived_buf.0 as usize] / 4;
                    let rows = if total_cols > 0 {
                        buf_f32 / total_cols
                    } else {
                        0
                    };
                    let mut col_offset = 0usize;
                    for src in sources {
                        let src_name = &src.0;
                        let src_cols = src.1;
                        if src_name == name && rows > 0 {
                            let derived_ptr =
                                self.buffers[derived_buf.0 as usize].data() as *mut f32;
                            for r in 0..rows {
                                let src_start = r * src_cols;
                                let dst_start = r * total_cols + col_offset;
                                // SAFETY: assumes `data` holds at least
                                // rows * src_cols values and the derived
                                // buffer holds rows * total_cols — not
                                // checked here; TODO confirm at call sites.
                                unsafe {
                                    std::ptr::copy_nonoverlapping(
                                        data[src_start..].as_ptr(),
                                        derived_ptr.add(dst_start),
                                        src_cols,
                                    );
                                }
                            }
                        }
                        col_offset += src_cols;
                    }
                }
                return;
            }
        }
        panic!("unknown parameter: {}", name);
    }
pub fn set_input(&mut self, name: &str, data: &[f32]) {
for &(ref input_name, buf_ref) in &self.plan.input_buffers {
if input_name == name {
self.upload_buffer(buf_ref, bytemuck::cast_slice(data));
return;
}
}
panic!("unknown input: {}", name);
}
pub fn set_input_u32(&mut self, name: &str, data: &[u32]) {
for &(ref input_name, buf_ref) in &self.plan.input_buffers {
if input_name == name {
self.upload_buffer(buf_ref, bytemuck::cast_slice(data));
return;
}
}
panic!("unknown input: {}", name);
}
    /// Copies raw bytes into the GPU buffer identified by `buf_ref`.
    fn upload_buffer(&self, buf_ref: BufferRef, data: &[u8]) {
        let buffer = &self.buffers[buf_ref.0 as usize];
        // SAFETY: buffers are host-visible (Memory::Shared). Assumes
        // `data.len()` does not exceed the buffer's allocation — not checked
        // here; TODO confirm callers always size uploads from the plan.
        unsafe {
            let ptr = buffer.data();
            std::ptr::copy_nonoverlapping(data.as_ptr(), ptr, data.len());
        }
    }
pub fn read_loss(&self) -> f32 {
if let Some(buf_ref) = self.plan.loss_buffer {
let buffer = &self.buffers[buf_ref.0 as usize];
unsafe {
let ptr = buffer.data() as *const f32;
*ptr
}
} else {
0.0
}
}
pub fn trace_dispatches(&self, threshold: f32) {
for (i, d) in self.plan.dispatches.iter().enumerate() {
let buf_size = self.plan.buffers[d.output_buffer.0 as usize];
let n = buf_size / 4;
if n == 0 {
continue;
}
let read_n = n.min(65536);
let mut data = vec![0.0f32; read_n];
self.read_buffer(d.output_buffer, &mut data);
let max_abs = data.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
let has_nan = data.iter().any(|v| v.is_nan());
let has_inf = data.iter().any(|v| v.is_infinite());
if has_nan || has_inf || max_abs > threshold || (max_abs == 0.0 && n > 100) {
let label = if d.label.is_empty() {
format!("{:?}", d.shader)
} else {
d.label.clone()
};
log::warn!(
"dispatch {i}: {label} max_abs={max_abs:.3e} nan={has_nan} inf={has_inf} n={n}"
);
}
}
}
/// Reads `len` `f32` values from the plan's terminal buffer; returns an
/// empty vector when the plan has no such buffer.
pub fn read_output(&self, len: usize) -> Vec<f32> {
    // NOTE(review): this reads `loss_buffer`, not `output_buffers` — the
    // "output" here appears to be the plan's terminal buffer; confirm
    // against callers.
    match self.plan.loss_buffer {
        Some(buf_ref) => {
            let mut out = vec![0.0_f32; len];
            self.read_buffer(buf_ref, &mut out);
            out
        }
        None => Vec::new(),
    }
}
/// Reads `out.len()` `f32` values back from the host-mapped buffer.
///
/// In debug builds, asserts that the read does not run past the buffer's
/// planned byte size.
pub fn read_buffer(&self, buf_ref: BufferRef, out: &mut [f32]) {
    debug_assert!(
        out.len() * std::mem::size_of::<f32>() <= self.plan.buffers[buf_ref.0 as usize],
        "read of {} floats overflows buffer {} ({} bytes)",
        out.len(),
        buf_ref.0,
        self.plan.buffers[buf_ref.0 as usize],
    );
    let buffer = &self.buffers[buf_ref.0 as usize];
    // SAFETY: the buffer is host-mapped and at least `out.len()` f32s long
    // (checked above in debug builds); source and destination cannot overlap.
    unsafe {
        let ptr = buffer.data() as *const f32;
        std::ptr::copy_nonoverlapping(ptr, out.as_mut_ptr(), out.len());
    }
}
/// Reads back the `index`-th entry of the plan's output buffers into `out`.
///
/// Panics if `index` is out of range (slice indexing).
pub fn read_output_by_index(&self, index: usize, out: &mut [f32]) {
    let buf_ref = self.plan.output_buffers[index];
    self.read_buffer(buf_ref, out);
}
/// Number of output buffers declared by the execution plan.
pub fn num_outputs(&self) -> usize {
    self.plan.output_buffers.len()
}
/// Looks up the buffer holding the parameter called `name`, if any.
pub fn param_buffer(&self, name: &str) -> Option<BufferRef> {
    for (param_name, buf_ref) in &self.plan.param_buffers {
        if param_name == name {
            return Some(*buf_ref);
        }
    }
    None
}
/// Reads the named parameter's values into `out`.
///
/// # Panics
/// Panics if no parameter with that name exists.
pub fn read_param(&self, name: &str, out: &mut [f32]) {
    match self.param_buffer(name) {
        Some(buf_ref) => self.read_buffer(buf_ref, out),
        None => panic!("unknown param: {}", name),
    }
}
/// Reads the gradient associated with the named parameter into `out`.
///
/// # Panics
/// Panics if the parameter is unknown or has no registered gradient.
pub fn read_param_grad(&self, name: &str, out: &mut [f32]) {
    let param_buf = match self.param_buffer(name) {
        Some(buf) => buf,
        None => panic!("unknown param: {}", name),
    };
    // Linear scan of the (param, grad) pairing table.
    let mut grad_buf = None;
    for &(p, g) in &self.plan.param_grad_pairs {
        if p == param_buf {
            grad_buf = Some(g);
            break;
        }
    }
    match grad_buf {
        Some(g) => self.read_buffer(g, out),
        None => panic!("no gradient for param: {}", name),
    }
}
/// Uploads `data` into the named parameter's buffer.
///
/// # Panics
/// Panics if no parameter with that name exists.
pub fn upload_param(&self, name: &str, data: &[f32]) {
    match self.param_buffer(name) {
        Some(buf_ref) => self.upload_buffer(buf_ref, bytemuck::cast_slice(data)),
        None => panic!("unknown param: {}", name),
    }
}
pub fn dump_gpu_timings(&self) {
let timings = self.encoder.timings();
if timings.is_empty() {
eprintln!("(no GPU timings available)");
return;
}
let total: std::time::Duration = timings.iter().map(|&(_, d)| d).sum();
eprintln!(
"--- GPU pass timings ({} passes, {:.2}ms total) ---",
timings.len(),
total.as_secs_f64() * 1000.0
);
let mut by_type: std::collections::HashMap<&str, (u32, std::time::Duration)> =
std::collections::HashMap::new();
for &(ref name, dur) in timings {
let entry = by_type.entry(name.as_str()).or_default();
entry.0 += 1;
entry.1 += dur;
}
let mut sorted: Vec<_> = by_type.into_iter().collect();
sorted.sort_by(|a, b| b.1.1.cmp(&a.1.1));
for &(name, (count, dur)) in &sorted {
let pct = dur.as_secs_f64() / total.as_secs_f64() * 100.0;
eprintln!(
" {:>20}: {:>3}x {:>8.2}ms ({:>5.1}%)",
name,
count,
dur.as_secs_f64() * 1000.0,
pct
);
}
eprintln!("---");
}
/// Blocks until the most recently submitted GPU work has finished.
///
/// No-op when nothing is in flight; the stored sync point is consumed.
pub fn wait(&mut self) {
    if let Some(sp) = self.sync_point.take() {
        let _span = tracing::info_span!("wait").entered();
        // `!0` acts as an effectively-infinite timeout; the timed-out
        // result is deliberately ignored.
        let _ = self.gpu.wait_for(&sp, !0);
    }
}
/// Records and submits one full execution of the plan on the GPU.
///
/// Waits for the previous submission, re-records every dispatch, then —
/// if an optimizer update was requested via `set_learning_rate` or
/// `set_adam` — appends the SGD/Adam parameter update to the same
/// submission. Returns without blocking; call `wait()` to synchronize.
pub fn step(&mut self) {
    let _span = tracing::info_span!("step").entered();
    self.wait();
    self.encoder.start();
    self.drain_gpu_timings();
    if self.profiling {
        // Profiling mode: one compute pass per dispatch so each dispatch
        // shows up as a separately timed GPU pass.
        for i in 0..self.plan.dispatches.len() {
            let dispatch = &self.plan.dispatches[i];
            let pipeline = self.pipelines.get(dispatch);
            let mut pass = self.encoder.compute(&dispatch.label);
            let mut pc = pass.with(pipeline);
            Self::bind_dispatch(&self.buffers, dispatch, &mut pc);
            pc.dispatch(dispatch.workgroups);
        }
    } else {
        // Fast path: a single pass with barriers only between dependency
        // groups; dispatches within one group are independent of each other.
        let mut pass = self.encoder.compute("step");
        for gi in 0..self.groups.len() {
            if gi > 0 {
                pass.barrier();
            }
            let group = self.groups[gi].clone();
            for i in group {
                let dispatch = &self.plan.dispatches[i];
                let pipeline = self.pipelines.get(dispatch);
                let mut pc = pass.with(pipeline);
                Self::bind_dispatch(&self.buffers, dispatch, &mut pc);
                pc.dispatch(dispatch.workgroups);
            }
        }
    }
    if !self.plan.param_grad_pairs.is_empty() {
        // Consume any pending optimizer request so the parameter update
        // runs inside this same submission. SGD takes precedence over Adam.
        let lr = self.pending_lr.take();
        if let Some(learning_rate) = lr {
            let pipeline = &self.pipelines.map[&ShaderEntry::SgdUpdate];
            let mut pass = self.encoder.compute("sgd_update");
            for &(param_buf, grad_buf) in &self.plan.param_grad_pairs {
                // Buffer sizes are in bytes; each element is an f32.
                let len = (self.plan.buffers[param_buf.0 as usize] / 4) as u32;
                let mut pc = pass.with(pipeline);
                pc.bind(
                    0,
                    &SgdData {
                        param: self.buffers[param_buf.0 as usize].at(0),
                        grad: self.buffers[grad_buf.0 as usize].at(0),
                        // The update is in-place: dst aliases param.
                        dst: self.buffers[param_buf.0 as usize].at(0),
                        params: SgdParams {
                            len,
                            lr: learning_rate,
                            _pad0: 0,
                            _pad1: 0,
                        },
                    },
                );
                // 256 threads per workgroup, one element per thread.
                pc.dispatch([len.div_ceil(256), 1, 1]);
            }
        } else if let Some((lr, beta1, beta2, eps)) = self.pending_adam.take() {
            self.adam_step += 1;
            let pipeline = &self.pipelines.map[&ShaderEntry::AdamUpdate];
            let mut pass = self.encoder.compute("adam_update");
            for (idx, &(param_buf, grad_buf)) in self.plan.param_grad_pairs.iter().enumerate() {
                let len = (self.plan.buffers[param_buf.0 as usize] / 4) as u32;
                // First/second moment buffers, allocated per parameter.
                let (ref m_buf, ref v_buf) = self.adam_state[idx];
                let mut pc = pass.with(pipeline);
                pc.bind(
                    0,
                    &AdamData {
                        param: self.buffers[param_buf.0 as usize].at(0),
                        grad: self.buffers[grad_buf.0 as usize].at(0),
                        m: m_buf.at(0),
                        v: v_buf.at(0),
                        params: AdamParams {
                            len,
                            lr,
                            beta1,
                            beta2,
                            eps,
                            // Step counter passed as f32 for the bias
                            // correction inside the shader.
                            step: self.adam_step as f32,
                            _pad0: 0,
                            _pad1: 0,
                        },
                    },
                );
                pc.dispatch([len.div_ceil(256), 1, 1]);
            }
        }
    }
    self.last_submit_ns = crate::profiler::now_ns();
    self.sync_point = Some(self.gpu.submit(&mut self.encoder));
}
/// Binds the buffers and scalar parameters for one compiled `Dispatch`
/// onto a pipeline encoder, selecting the shader-specific bind-group
/// layout from `dispatch.shader`.
///
/// Conventions visible below:
/// - `dispatch.params` is a flat `u32` array whose meaning is per-shader;
///   several shaders smuggle extra scalars through `_pad*` fields of a
///   smaller params struct rather than defining a dedicated struct.
/// - Floats travel bit-cast in `params` (see `f32::from_bits` for SGD's
///   learning rate, and `*_bits` fields like `theta_bits` / `eps_bits`).
fn bind_dispatch(
    buffers: &[blade_graphics::Buffer],
    dispatch: &crate::compile::Dispatch,
    pc: &mut impl blade_graphics::traits::PipelineEncoder,
) {
    // Resolve a BufferRef to a bindable piece at offset 0.
    let buf = |r: BufferRef| buffers[r.0 as usize].at(0);
    match dispatch.shader {
        ShaderEntry::MatMul => {
            // Plain MatMul packs dims as params = [m, k, n]; note n and k
            // are read from indices 2 and 1 respectively (the transposed
            // variants below use [m, n, k] instead).
            pc.bind(
                0,
                &MatMulData {
                    matrix_a: buf(dispatch.input_buffers[0]),
                    matrix_b: buf(dispatch.input_buffers[1]),
                    matrix_c: buf(dispatch.output_buffer),
                    params: MatMulParams {
                        m: dispatch.params[0],
                        n: dispatch.params[2],
                        k: dispatch.params[1],
                        _pad: 0,
                    },
                },
            );
        }
        ShaderEntry::MatMulAT | ShaderEntry::MatMulBT => {
            // Transposed variants pack params = [m, n, k] in order.
            pc.bind(
                0,
                &MatMulData {
                    matrix_a: buf(dispatch.input_buffers[0]),
                    matrix_b: buf(dispatch.input_buffers[1]),
                    matrix_c: buf(dispatch.output_buffer),
                    params: MatMulParams {
                        m: dispatch.params[0],
                        n: dispatch.params[1],
                        k: dispatch.params[2],
                        _pad: 0,
                    },
                },
            );
        }
        ShaderEntry::FusedMatMulAdd => {
            // Same [m, k, n] packing as plain MatMul, plus an extra `src`
            // operand that the fused add consumes.
            pc.bind(
                0,
                &FusedMatMulAddData {
                    matrix_a: buf(dispatch.input_buffers[0]),
                    matrix_b: buf(dispatch.input_buffers[1]),
                    matrix_c: buf(dispatch.output_buffer),
                    src: buf(dispatch.input_buffers[2]),
                    params: MatMulParams {
                        m: dispatch.params[0],
                        n: dispatch.params[2],
                        k: dispatch.params[1],
                        _pad: 0,
                    },
                },
            );
        }
        ShaderEntry::FusedMatMulATAdd | ShaderEntry::FusedMatMulBTAdd => {
            // Transposed fused variants use the [m, n, k] packing.
            pc.bind(
                0,
                &FusedMatMulAddData {
                    matrix_a: buf(dispatch.input_buffers[0]),
                    matrix_b: buf(dispatch.input_buffers[1]),
                    matrix_c: buf(dispatch.output_buffer),
                    src: buf(dispatch.input_buffers[2]),
                    params: MatMulParams {
                        m: dispatch.params[0],
                        n: dispatch.params[1],
                        k: dispatch.params[2],
                        _pad: 0,
                    },
                },
            );
        }
        // Elementwise unary ops: params[0] is the element count.
        ShaderEntry::Relu
        | ShaderEntry::Sigmoid
        | ShaderEntry::Tanh
        | ShaderEntry::Neg
        | ShaderEntry::Abs
        | ShaderEntry::Log
        | ShaderEntry::Recip
        | ShaderEntry::Silu => {
            pc.bind(
                0,
                &UnaryData {
                    src: buf(dispatch.input_buffers[0]),
                    dst: buf(dispatch.output_buffer),
                    params: UnaryParams {
                        len: dispatch.params[0],
                        _pad0: 0,
                        _pad1: 0,
                        _pad2: 0,
                    },
                },
            );
        }
        ShaderEntry::SwiGLUConcat | ShaderEntry::SwiGLUConcatGrad => {
            // params[1] rides in `_pad0` as an extra scalar for the shader.
            pc.bind(
                0,
                &BinaryData {
                    src_a: buf(dispatch.input_buffers[0]),
                    src_b: buf(dispatch.input_buffers[1]),
                    dst: buf(dispatch.output_buffer),
                    params: UnaryParams {
                        len: dispatch.params[0],
                        _pad0: dispatch.params[1],
                        _pad1: 0,
                        _pad2: 0,
                    },
                },
            );
        }
        // Elementwise binary ops: params[0] is the element count.
        ShaderEntry::Add | ShaderEntry::Mul | ShaderEntry::Greater | ShaderEntry::SwiGLU => {
            pc.bind(
                0,
                &BinaryData {
                    src_a: buf(dispatch.input_buffers[0]),
                    src_b: buf(dispatch.input_buffers[1]),
                    dst: buf(dispatch.output_buffer),
                    params: UnaryParams {
                        len: dispatch.params[0],
                        _pad0: 0,
                        _pad1: 0,
                        _pad2: 0,
                    },
                },
            );
        }
        ShaderEntry::BiasAdd => {
            pc.bind(
                0,
                &BiasAddData {
                    src: buf(dispatch.input_buffers[0]),
                    bias: buf(dispatch.input_buffers[1]),
                    dst: buf(dispatch.output_buffer),
                    params: BiasAddParams {
                        len: dispatch.params[0],
                        bias_len: dispatch.params[1],
                        _pad0: 0,
                        _pad1: 0,
                    },
                },
            );
        }
        ShaderEntry::SgdUpdate => {
            pc.bind(
                0,
                &SgdData {
                    param: buf(dispatch.input_buffers[0]),
                    grad: buf(dispatch.input_buffers[1]),
                    dst: buf(dispatch.output_buffer),
                    params: SgdParams {
                        len: dispatch.params[0],
                        // Learning rate travels bit-cast through the u32
                        // params array.
                        lr: f32::from_bits(dispatch.params[1]),
                        _pad0: 0,
                        _pad1: 0,
                    },
                },
            );
        }
        ShaderEntry::SumAll | ShaderEntry::MeanAll => {
            pc.bind(
                0,
                &UnaryData {
                    src: buf(dispatch.input_buffers[0]),
                    dst: buf(dispatch.output_buffer),
                    params: UnaryParams {
                        len: dispatch.params[0],
                        _pad0: 0,
                        _pad1: 0,
                        _pad2: 0,
                    },
                },
            );
        }
        ShaderEntry::SumRows => {
            // params[1] (presumably the row width) rides in `_pad0`.
            pc.bind(
                0,
                &UnaryData {
                    src: buf(dispatch.input_buffers[0]),
                    dst: buf(dispatch.output_buffer),
                    params: UnaryParams {
                        len: dispatch.params[0],
                        _pad0: dispatch.params[1],
                        _pad1: 0,
                        _pad2: 0,
                    },
                },
            );
        }
        ShaderEntry::Softmax => {
            pc.bind(
                0,
                &SoftmaxData {
                    src: buf(dispatch.input_buffers[0]),
                    dst: buf(dispatch.output_buffer),
                    params: SoftmaxParams {
                        batch: dispatch.params[0],
                        features: dispatch.params[1],
                        _pad0: 0,
                        _pad1: 0,
                    },
                },
            );
        }
        ShaderEntry::CrossEntropyLoss => {
            // The scalar loss goes to the first extra output if present;
            // otherwise it aliases the gradient output buffer.
            let loss_buf = dispatch
                .extra_outputs
                .first()
                .copied()
                .unwrap_or(dispatch.output_buffer);
            pc.bind(
                0,
                &CrossEntropyData {
                    logits: buf(dispatch.input_buffers[0]),
                    labels: buf(dispatch.input_buffers[1]),
                    grad_out: buf(dispatch.output_buffer),
                    loss_out: buf(loss_buf),
                    params: SoftmaxParams {
                        batch: dispatch.params[0],
                        features: dispatch.params[1],
                        _pad0: 0,
                        _pad1: 0,
                    },
                },
            );
        }
        ShaderEntry::BceLoss => {
            // Here grad and loss share the output buffer unconditionally.
            pc.bind(
                0,
                &BceData {
                    pred: buf(dispatch.input_buffers[0]),
                    labels: buf(dispatch.input_buffers[1]),
                    grad_out: buf(dispatch.output_buffer),
                    loss_out: buf(dispatch.output_buffer),
                    params: UnaryParams {
                        len: dispatch.params[0],
                        _pad0: 0,
                        _pad1: 0,
                        _pad2: 0,
                    },
                },
            );
        }
        ShaderEntry::Transpose => {
            pc.bind(
                0,
                &TransposeData {
                    src: buf(dispatch.input_buffers[0]),
                    dst: buf(dispatch.output_buffer),
                    params: TransposeParams {
                        m: dispatch.params[0],
                        n: dispatch.params[1],
                        _pad0: 0,
                        _pad1: 0,
                    },
                },
            );
        }
        ShaderEntry::RmsNorm => {
            // Reuses BiasAddParams; params[2] rides in `_pad0`.
            pc.bind(
                0,
                &RmsNormData {
                    src: buf(dispatch.input_buffers[0]),
                    bias: buf(dispatch.input_buffers[1]),
                    dst: buf(dispatch.output_buffer),
                    params: BiasAddParams {
                        len: dispatch.params[0],
                        bias_len: dispatch.params[1],
                        _pad0: dispatch.params[2],
                        _pad1: 0,
                    },
                },
            );
        }
        ShaderEntry::Embedding => {
            pc.bind(
                0,
                &EmbeddingData {
                    indices: buf(dispatch.input_buffers[0]),
                    src: buf(dispatch.input_buffers[1]),
                    dst: buf(dispatch.output_buffer),
                    params: UnaryParams {
                        len: dispatch.params[0],
                        _pad0: dispatch.params[1],
                        _pad1: 0,
                        _pad2: 0,
                    },
                },
            );
        }
        ShaderEntry::RoPE | ShaderEntry::RoPEGrad => {
            pc.bind(
                0,
                &RoPEData {
                    src: buf(dispatch.input_buffers[0]),
                    dst: buf(dispatch.output_buffer),
                    params: RoPEParams {
                        seq: dispatch.params[0],
                        dim: dispatch.params[1],
                        // RoPE theta as raw f32 bits.
                        theta_bits: dispatch.params[2],
                        pos_offset: dispatch.params[3],
                        head_dim: dispatch.params[4],
                        _pad0: 0,
                        _pad1: 0,
                        _pad2: 0,
                    },
                },
            );
        }
        ShaderEntry::CausalAttention
        | ShaderEntry::CausalAttentionRoPE
        | ShaderEntry::FullAttention
        | ShaderEntry::CrossAttention => {
            // Log-sum-exp side output for the backward pass.
            let lse_buf = dispatch.extra_outputs[0];
            pc.bind(
                0,
                &AttentionData {
                    src_a: buf(dispatch.input_buffers[0]),
                    src_b: buf(dispatch.input_buffers[1]),
                    bias: buf(dispatch.input_buffers[2]),
                    dst: buf(dispatch.output_buffer),
                    lse: buf(lse_buf),
                    params: MatMulParams {
                        m: dispatch.params[0],
                        n: dispatch.params[1],
                        k: dispatch.params[2],
                        _pad: dispatch.params[3],
                    },
                },
            );
        }
        ShaderEntry::SlidingWindowAttention => {
            pc.bind(
                0,
                &SlidingWindowAttentionData {
                    src_a: buf(dispatch.input_buffers[0]),
                    src_b: buf(dispatch.input_buffers[1]),
                    bias: buf(dispatch.input_buffers[2]),
                    dst: buf(dispatch.output_buffer),
                    lse: buf(dispatch.extra_outputs[0]),
                    params: SlidingWindowAttentionParams {
                        seq: dispatch.params[0],
                        num_heads: dispatch.params[1],
                        num_kv_heads: dispatch.params[2],
                        head_dim: dispatch.params[3],
                        window_size: dispatch.params[4],
                        _pad0: 0,
                        _pad1: 0,
                        _pad2: 0,
                    },
                },
            );
        }
        ShaderEntry::Gelu => {
            pc.bind(
                0,
                &UnaryData {
                    src: buf(dispatch.input_buffers[0]),
                    dst: buf(dispatch.output_buffer),
                    params: UnaryParams {
                        len: dispatch.params[0],
                        _pad0: 0,
                        _pad1: 0,
                        _pad2: 0,
                    },
                },
            );
        }
        ShaderEntry::LayerNorm => {
            pc.bind(
                0,
                &LayerNormData {
                    src: buf(dispatch.input_buffers[0]),
                    src_b: buf(dispatch.input_buffers[1]),
                    bias: buf(dispatch.input_buffers[2]),
                    dst: buf(dispatch.output_buffer),
                    params: MatMulParams {
                        m: dispatch.params[0],
                        n: dispatch.params[1],
                        k: dispatch.params[2],
                        _pad: dispatch.params[3],
                    },
                },
            );
        }
        ShaderEntry::MultiHeadAttn => {
            pc.bind(
                0,
                &MultiHeadAttnData {
                    src_a: buf(dispatch.input_buffers[0]),
                    src_b: buf(dispatch.input_buffers[1]),
                    bias: buf(dispatch.input_buffers[2]),
                    dst: buf(dispatch.output_buffer),
                    lse: buf(dispatch.extra_outputs[0]),
                    params: MatMulParams {
                        m: dispatch.params[0],
                        n: dispatch.params[1],
                        k: dispatch.params[2],
                        _pad: dispatch.params[3],
                    },
                },
            );
        }
        ShaderEntry::MultiHeadAttnGradQ
        | ShaderEntry::MultiHeadAttnGradK
        | ShaderEntry::MultiHeadAttnGradV => {
            // Backward attention consumes the saved forward output and LSE.
            pc.bind(
                0,
                &MultiHeadAttnGradData {
                    d_out: buf(dispatch.input_buffers[0]),
                    src_a: buf(dispatch.input_buffers[1]),
                    src_b: buf(dispatch.input_buffers[2]),
                    bias: buf(dispatch.input_buffers[3]),
                    lse: buf(dispatch.input_buffers[4]),
                    fwd_dst: buf(dispatch.input_buffers[5]),
                    dst: buf(dispatch.output_buffer),
                    params: AttentionGradParams {
                        q_seq: dispatch.params[0],
                        kv_seq: dispatch.params[1],
                        packed_heads: dispatch.params[2],
                        head_dim: dispatch.params[3],
                        window_size: dispatch.params[4],
                        _pad0: 0,
                        _pad1: 0,
                        _pad2: 0,
                    },
                },
            );
        }
        ShaderEntry::SwiGLUGradGate => {
            pc.bind(
                0,
                &TernaryData {
                    src_a: buf(dispatch.input_buffers[0]),
                    src_b: buf(dispatch.input_buffers[1]),
                    src_c: buf(dispatch.input_buffers[2]),
                    dst: buf(dispatch.output_buffer),
                    params: UnaryParams {
                        len: dispatch.params[0],
                        _pad0: 0,
                        _pad1: 0,
                        _pad2: 0,
                    },
                },
            );
        }
        ShaderEntry::SwiGLUGradUp => {
            pc.bind(
                0,
                &BinaryData {
                    src_a: buf(dispatch.input_buffers[0]),
                    src_b: buf(dispatch.input_buffers[1]),
                    dst: buf(dispatch.output_buffer),
                    params: UnaryParams {
                        len: dispatch.params[0],
                        _pad0: 0,
                        _pad1: 0,
                        _pad2: 0,
                    },
                },
            );
        }
        ShaderEntry::SiluGrad => {
            pc.bind(
                0,
                &BinaryData {
                    src_a: buf(dispatch.input_buffers[0]),
                    src_b: buf(dispatch.input_buffers[1]),
                    dst: buf(dispatch.output_buffer),
                    params: UnaryParams {
                        len: dispatch.params[0],
                        _pad0: 0,
                        _pad1: 0,
                        _pad2: 0,
                    },
                },
            );
        }
        ShaderEntry::RmsNormGradW | ShaderEntry::RmsNormGradX => {
            pc.bind(
                0,
                &FourBufData {
                    src_a: buf(dispatch.input_buffers[0]),
                    src_b: buf(dispatch.input_buffers[1]),
                    bias: buf(dispatch.input_buffers[2]),
                    dst: buf(dispatch.output_buffer),
                    params: MatMulParams {
                        m: dispatch.params[0],
                        n: dispatch.params[1],
                        k: dispatch.params[2],
                        _pad: dispatch.params[3],
                    },
                },
            );
        }
        ShaderEntry::LayerNormGradWB | ShaderEntry::LayerNormGradX => {
            pc.bind(
                0,
                &FourBufData {
                    src_a: buf(dispatch.input_buffers[0]),
                    src_b: buf(dispatch.input_buffers[1]),
                    bias: buf(dispatch.input_buffers[2]),
                    dst: buf(dispatch.output_buffer),
                    params: MatMulParams {
                        m: dispatch.params[0],
                        n: dispatch.params[1],
                        k: dispatch.params[2],
                        _pad: dispatch.params[3],
                    },
                },
            );
        }
        ShaderEntry::FusedRmsNormMatMul => {
            {
                // NOTE: input 2 binds to `src_b` and input 1 to `bias` —
                // the fused shader's operand order differs from the
                // dispatch's input order; presumably intentional, verify
                // against the shader source.
                pc.bind(
                    0,
                    &FourBufData {
                        src_a: buf(dispatch.input_buffers[0]),
                        src_b: buf(dispatch.input_buffers[2]),
                        bias: buf(dispatch.input_buffers[1]),
                        dst: buf(dispatch.output_buffer),
                        params: MatMulParams {
                            m: dispatch.params[0],
                            n: dispatch.params[1],
                            k: dispatch.params[2],
                            _pad: dispatch.params[3],
                        },
                    },
                );
            }
        }
        ShaderEntry::RmsNormRsqrt => {
            // All four params slots are forwarded; `_pad*` carry real values.
            pc.bind(
                0,
                &UnaryData {
                    src: buf(dispatch.input_buffers[0]),
                    dst: buf(dispatch.output_buffer),
                    params: UnaryParams {
                        len: dispatch.params[0],
                        _pad0: dispatch.params[1],
                        _pad1: dispatch.params[2],
                        _pad2: dispatch.params[3],
                    },
                },
            );
        }
        ShaderEntry::AdamUpdate => {
            // Adam binds its per-parameter moment buffers directly in
            // adam_step()/step(); it never appears as a plan dispatch.
            unreachable!("AdamUpdate is dispatched via adam_step/set_adam, not bind_dispatch");
        }
        ShaderEntry::ScatterAdd => {
            pc.bind(
                0,
                &ScatterAddData {
                    indices: buf(dispatch.input_buffers[0]),
                    src: buf(dispatch.input_buffers[1]),
                    dst: buf(dispatch.output_buffer),
                    params: ScatterAddParams {
                        total: dispatch.params[0],
                        seq_len: dispatch.params[1],
                        embed_dim: dispatch.params[2],
                        _pad: 0,
                    },
                },
            );
        }
        ShaderEntry::GroupNorm | ShaderEntry::GroupNormSilu => {
            let p = &dispatch.params;
            pc.bind(
                0,
                &GroupNormData {
                    src: buf(dispatch.input_buffers[0]),
                    src_b: buf(dispatch.input_buffers[1]),
                    bias: buf(dispatch.input_buffers[2]),
                    dst: buf(dispatch.output_buffer),
                    params: GroupNormParams {
                        batch: p[0],
                        channels: p[1],
                        spatial: p[2],
                        num_groups: p[3],
                        // Epsilon as raw f32 bits.
                        eps_bits: p[4],
                        _pad0: 0,
                        _pad1: 0,
                        _pad2: 0,
                    },
                },
            );
        }
        ShaderEntry::GroupNormGradInput => {
            let p = &dispatch.params;
            pc.bind(
                0,
                &GroupNormGradInputData {
                    src_a: buf(dispatch.input_buffers[0]),
                    src_b: buf(dispatch.input_buffers[1]),
                    bias: buf(dispatch.input_buffers[2]),
                    dst: buf(dispatch.output_buffer),
                    params: GroupNormParams {
                        batch: p[0],
                        channels: p[1],
                        spatial: p[2],
                        num_groups: p[3],
                        eps_bits: p[4],
                        _pad0: 0,
                        _pad1: 0,
                        _pad2: 0,
                    },
                },
            );
        }
        ShaderEntry::GroupNormGradWeightBias => {
            let p = &dispatch.params;
            pc.bind(
                0,
                &GroupNormGradWeightBiasData {
                    src_a: buf(dispatch.input_buffers[0]),
                    src_b: buf(dispatch.input_buffers[1]),
                    // NOTE(review): `bias` is bound to the same buffer as
                    // `src_b` (input 1). Looks deliberate — the shader may
                    // not read the bias slot here — but confirm against the
                    // shader source.
                    bias: buf(dispatch.input_buffers[1]),
                    dst: buf(dispatch.output_buffer),
                    params: GroupNormParams {
                        batch: p[0],
                        channels: p[1],
                        spatial: p[2],
                        num_groups: p[3],
                        eps_bits: p[4],
                        _pad0: 0,
                        _pad1: 0,
                        _pad2: 0,
                    },
                },
            );
        }
        ShaderEntry::Concat => {
            let p = &dispatch.params;
            pc.bind(
                0,
                &BinaryData {
                    src_a: buf(dispatch.input_buffers[0]),
                    src_b: buf(dispatch.input_buffers[1]),
                    dst: buf(dispatch.output_buffer),
                    params: UnaryParams {
                        len: p[0],
                        _pad0: p[1],
                        _pad1: p[2],
                        _pad2: p[3],
                    },
                },
            );
        }
        ShaderEntry::SplitA | ShaderEntry::SplitB => {
            let p = &dispatch.params;
            pc.bind(
                0,
                &UnaryData {
                    src: buf(dispatch.input_buffers[0]),
                    dst: buf(dispatch.output_buffer),
                    params: UnaryParams {
                        len: p[0],
                        _pad0: p[1],
                        _pad1: p[2],
                        _pad2: p[3],
                    },
                },
            );
        }
        ShaderEntry::Upsample2x | ShaderEntry::Upsample2xGrad => {
            let p = &dispatch.params;
            pc.bind(
                0,
                &UnaryData {
                    src: buf(dispatch.input_buffers[0]),
                    dst: buf(dispatch.output_buffer),
                    params: UnaryParams {
                        len: p[0],
                        _pad0: p[1],
                        _pad1: p[2],
                        _pad2: p[3],
                    },
                },
            );
        }
        ShaderEntry::Conv2d | ShaderEntry::Conv2dGemm | ShaderEntry::Conv2dGemmSmall => {
            let p = &dispatch.params;
            pc.bind(
                0,
                &Conv2dData {
                    src: buf(dispatch.input_buffers[0]),
                    weight: buf(dispatch.input_buffers[1]),
                    dst: buf(dispatch.output_buffer),
                    params: Conv2dParams {
                        batch: p[0],
                        in_channels: p[1],
                        in_h: p[2],
                        in_w: p[3],
                        out_channels: p[4],
                        kernel_h: p[5],
                        kernel_w: p[6],
                        stride: p[7],
                        padding_h: p[8],
                        out_h: p[9],
                        out_w: p[10],
                        padding_w: p[11],
                    },
                },
            );
        }
        ShaderEntry::Conv2dGradInput
        | ShaderEntry::Conv2dGradInputGemm
        | ShaderEntry::Conv2dGradInputGemmSmall
        | ShaderEntry::Conv2dGradInputGemmCoop => {
            let p = &dispatch.params;
            pc.bind(
                0,
                &Conv2dGradInputData {
                    grad_out: buf(dispatch.input_buffers[0]),
                    weight: buf(dispatch.input_buffers[1]),
                    dst: buf(dispatch.output_buffer),
                    params: Conv2dParams {
                        batch: p[0],
                        in_channels: p[1],
                        in_h: p[2],
                        in_w: p[3],
                        out_channels: p[4],
                        kernel_h: p[5],
                        kernel_w: p[6],
                        stride: p[7],
                        padding_h: p[8],
                        out_h: p[9],
                        out_w: p[10],
                        padding_w: p[11],
                    },
                },
            );
        }
        ShaderEntry::Conv2dGradWeight
        | ShaderEntry::Conv2dGradWeightGemm
        | ShaderEntry::Conv2dGradWeightGemmSmall => {
            let p = &dispatch.params;
            pc.bind(
                0,
                &Conv2dGradWeightData {
                    grad_out: buf(dispatch.input_buffers[0]),
                    src: buf(dispatch.input_buffers[1]),
                    dst: buf(dispatch.output_buffer),
                    params: Conv2dParams {
                        batch: p[0],
                        in_channels: p[1],
                        in_h: p[2],
                        in_w: p[3],
                        out_channels: p[4],
                        kernel_h: p[5],
                        kernel_w: p[6],
                        stride: p[7],
                        padding_h: p[8],
                        out_h: p[9],
                        out_w: p[10],
                        padding_w: p[11],
                    },
                },
            );
        }
        ShaderEntry::RoPEDynamic => {
            // The position offset comes from `pos_offset_buf` at runtime,
            // so the static `pos_offset` field is forced to 0 here.
            pc.bind(
                0,
                &RoPEDynamicData {
                    src: buf(dispatch.input_buffers[0]),
                    dst: buf(dispatch.output_buffer),
                    pos_offset_buf: buf(dispatch.input_buffers[1]),
                    params: RoPEParams {
                        seq: dispatch.params[0],
                        dim: dispatch.params[1],
                        theta_bits: dispatch.params[2],
                        pos_offset: 0,
                        head_dim: dispatch.params[4],
                        _pad0: 0,
                        _pad1: 0,
                        _pad2: 0,
                    },
                },
            );
        }
        ShaderEntry::CacheWrite => {
            pc.bind(
                0,
                &CacheWriteData {
                    src: buf(dispatch.input_buffers[0]),
                    dst: buf(dispatch.output_buffer),
                    // NOTE(review): input 1 is skipped — the KV position
                    // buffer sits at index 2; confirm against the compiler's
                    // dispatch layout for CacheWrite.
                    kv_pos_buf: buf(dispatch.input_buffers[2]),
                    params: UnaryParams {
                        len: dispatch.params[0],
                        _pad0: 0,
                        _pad1: 0,
                        _pad2: 0,
                    },
                },
            );
        }
        ShaderEntry::CachedAttention => {
            pc.bind(
                0,
                &CachedAttentionData {
                    src_a: buf(dispatch.input_buffers[0]),
                    src_b: buf(dispatch.input_buffers[1]),
                    bias: buf(dispatch.input_buffers[2]),
                    kv_pos_buf: buf(dispatch.input_buffers[3]),
                    dst: buf(dispatch.output_buffer),
                    params: MatMulParams {
                        m: dispatch.params[0],
                        n: dispatch.params[1],
                        k: dispatch.params[2],
                        _pad: dispatch.params[3],
                    },
                },
            );
        }
        ShaderEntry::MaxPool2d => {
            pc.bind(
                0,
                &MaxPool2dData {
                    src: buf(dispatch.input_buffers[0]),
                    dst: buf(dispatch.output_buffer),
                    params: MaxPool2dParams {
                        batch: dispatch.params[0],
                        channels: dispatch.params[1],
                        in_h: dispatch.params[2],
                        in_w: dispatch.params[3],
                        kernel_h: dispatch.params[4],
                        kernel_w: dispatch.params[5],
                        stride: dispatch.params[6],
                        padding: dispatch.params[7],
                        out_h: dispatch.params[8],
                        out_w: dispatch.params[9],
                        _pad0: dispatch.params[10],
                        _pad1: dispatch.params[11],
                    },
                },
            );
        }
        ShaderEntry::GlobalAvgPool => {
            pc.bind(
                0,
                &GlobalAvgPoolData {
                    src: buf(dispatch.input_buffers[0]),
                    dst: buf(dispatch.output_buffer),
                    params: GlobalAvgPoolParams {
                        channels: dispatch.params[0],
                        spatial: dispatch.params[1],
                        total_out: dispatch.params[2],
                        _pad: dispatch.params[3],
                    },
                },
            );
        }
        ShaderEntry::GlobalAvgPoolGrad => {
            let p = &dispatch.params;
            pc.bind(
                0,
                &UnaryData {
                    src: buf(dispatch.input_buffers[0]),
                    dst: buf(dispatch.output_buffer),
                    params: UnaryParams {
                        len: p[0],
                        _pad0: p[1],
                        _pad1: p[2],
                        _pad2: p[3],
                    },
                },
            );
        }
    }
}
/// Forwards any GPU pass timings collected by the encoder to the global
/// profiler, stamped with the time of the submission that produced them.
fn drain_gpu_timings(&self) {
    let timings = self.encoder.timings();
    if !timings.is_empty() {
        crate::profiler::record_gpu_passes(self.last_submit_ns, timings);
    }
}
/// Runs a standalone SGD parameter update on the GPU as its own
/// submission (unlike `set_learning_rate`, which folds the update into the
/// next `step()`). Returns without blocking; call `wait()` to synchronize.
pub fn sgd_step(&mut self, learning_rate: f32) {
    let _span = tracing::info_span!("sgd_step").entered();
    self.wait();
    self.encoder.start();
    self.drain_gpu_timings();
    let pipeline = &self.pipelines.map[&ShaderEntry::SgdUpdate];
    let mut pass = self.encoder.compute("sgd_update");
    for &(param_buf, grad_buf) in &self.plan.param_grad_pairs {
        // Buffer sizes are in bytes; each element is an f32.
        let len = (self.plan.buffers[param_buf.0 as usize] / 4) as u32;
        let mut pc = pass.with(pipeline);
        pc.bind(
            0,
            &SgdData {
                param: self.buffers[param_buf.0 as usize].at(0),
                grad: self.buffers[grad_buf.0 as usize].at(0),
                // In-place update: dst aliases param.
                dst: self.buffers[param_buf.0 as usize].at(0),
                params: SgdParams {
                    len,
                    lr: learning_rate,
                    _pad0: 0,
                    _pad1: 0,
                },
            },
        );
        // 256 threads per workgroup, one element per thread.
        pc.dispatch([len.div_ceil(256), 1, 1]);
    }
    // End the pass before submitting the encoder.
    drop(pass);
    self.last_submit_ns = crate::profiler::now_ns();
    self.sync_point = Some(self.gpu.submit(&mut self.encoder));
}
/// Requests an SGD update with rate `lr` to be appended to the next
/// `step()` submission (takes precedence over a pending Adam request).
pub fn set_learning_rate(&mut self, lr: f32) {
    self.pending_lr = Some(lr);
}
/// Applies one SGD update on the CPU by touching the host-mapped
/// parameter and gradient buffers directly: `param -= lr * grad`.
pub fn sgd_step_cpu(&mut self, learning_rate: f32) {
    let _span = tracing::info_span!("sgd_step_cpu").entered();
    // Ensure the GPU is idle so nothing races the CPU writes below.
    self.wait();
    for &(param_buf, grad_buf) in &self.plan.param_grad_pairs {
        // Byte size -> f32 element count.
        let count = self.plan.buffers[param_buf.0 as usize] / 4;
        let param = &self.buffers[param_buf.0 as usize];
        let grad = &self.buffers[grad_buf.0 as usize];
        // SAFETY: both buffers are host-mapped and hold `count` f32 values,
        // and no GPU work is in flight after `wait()`.
        unsafe {
            let p = std::slice::from_raw_parts_mut(param.data() as *mut f32, count);
            let g = std::slice::from_raw_parts(grad.data() as *const f32, count);
            for (pv, gv) in p.iter_mut().zip(g) {
                *pv -= learning_rate * *gv;
            }
        }
    }
}
/// Runs a standalone Adam parameter update on the GPU as its own
/// submission, advancing the bias-correction step counter. Returns
/// without blocking; call `wait()` to synchronize.
pub fn adam_step(&mut self, lr: f32, beta1: f32, beta2: f32, eps: f32) {
    let _span = tracing::info_span!("adam_step").entered();
    self.adam_step += 1;
    self.wait();
    self.encoder.start();
    self.drain_gpu_timings();
    let pipeline = &self.pipelines.map[&ShaderEntry::AdamUpdate];
    let mut pass = self.encoder.compute("adam_update");
    for (idx, &(param_buf, grad_buf)) in self.plan.param_grad_pairs.iter().enumerate() {
        // Buffer sizes are in bytes; each element is an f32.
        let len = (self.plan.buffers[param_buf.0 as usize] / 4) as u32;
        // First/second moment buffers, allocated per parameter.
        let (ref m_buf, ref v_buf) = self.adam_state[idx];
        let mut pc = pass.with(pipeline);
        pc.bind(
            0,
            &AdamData {
                param: self.buffers[param_buf.0 as usize].at(0),
                grad: self.buffers[grad_buf.0 as usize].at(0),
                m: m_buf.at(0),
                v: v_buf.at(0),
                params: AdamParams {
                    len,
                    lr,
                    beta1,
                    beta2,
                    eps,
                    // Step counter passed as f32 for the bias correction
                    // inside the shader.
                    step: self.adam_step as f32,
                    _pad0: 0,
                    _pad1: 0,
                },
            },
        );
        // 256 threads per workgroup, one element per thread.
        pc.dispatch([len.div_ceil(256), 1, 1]);
    }
    // End the pass before submitting the encoder.
    drop(pass);
    self.last_submit_ns = crate::profiler::now_ns();
    self.sync_point = Some(self.gpu.submit(&mut self.encoder));
}
/// Requests an Adam update with these hyperparameters to be appended to
/// the next `step()` submission (a pending SGD request takes precedence).
pub fn set_adam(&mut self, lr: f32, beta1: f32, beta2: f32, eps: f32) {
    self.pending_adam = Some((lr, beta1, beta2, eps));
}
/// Summarizes planned buffer memory usage: total bytes, the Adam moment
/// state (two extra buffers per parameter), buffer count, and the single
/// largest allocation.
pub fn memory_summary(&self) -> MemorySummary {
    let mut total = 0usize;
    let mut largest = 0usize;
    for &size in &self.plan.buffers {
        total += size;
        largest = largest.max(size);
    }
    // Each parameter carries two moment buffers of its own size (m and v).
    let adam_bytes: usize = self
        .plan
        .param_grad_pairs
        .iter()
        .map(|&(p, _)| 2 * self.plan.buffers[p.0 as usize])
        .sum();
    MemorySummary {
        total_buffer_bytes: total,
        adam_state_bytes: adam_bytes,
        num_buffers: self.plan.buffers.len(),
        largest_buffer_bytes: largest,
    }
}
/// Borrows the compiled execution plan backing this session.
pub fn plan(&self) -> &ExecutionPlan {
    &self.plan
}
/// Number of barrier-separated dispatch groups used by the fast path of
/// `step()`.
pub fn num_groups(&self) -> usize {
    self.groups.len()
}
/// Borrows the underlying GPU device information.
pub fn device_information(&self) -> &blade_graphics::DeviceInformation {
    self.gpu.device_information()
}
#[allow(clippy::pattern_type_mismatch)]
/// Serializes all parameters plus Adam optimizer state to a safetensors
/// file at `path`. Waits for GPU idle first so the readback is coherent.
///
/// Adam moment tensors are stored under keys "adam_m.<name>" and
/// "adam_v.<name>"; the step counter goes into the file's metadata under
/// "adam_step". All tensors are written as flat F32 vectors.
pub fn save_checkpoint(&mut self, path: &std::path::Path) -> std::io::Result<()> {
    use safetensors::tensor::{Dtype, TensorView};
    self.wait();
    // Copy every tensor to host memory first; TensorViews borrow from here.
    let mut owned_data: Vec<(String, Vec<u8>)> = Vec::new();
    for (name, buf_ref) in &self.plan.param_buffers {
        let byte_len = self.plan.buffers[buf_ref.0 as usize];
        let mut data = vec![0u8; byte_len];
        // SAFETY: the buffer is host-mapped, `byte_len` bytes long, and the
        // GPU is idle after `wait()`.
        unsafe {
            let ptr = self.buffers[buf_ref.0 as usize].data() as *const u8;
            std::ptr::copy_nonoverlapping(ptr, data.as_mut_ptr(), byte_len);
        }
        owned_data.push((name.clone(), data));
    }
    for (idx, &(param_buf, _)) in self.plan.param_grad_pairs.iter().enumerate() {
        // Adam state may be absent (e.g. pure-SGD sessions); stop rather
        // than index past it.
        if idx >= self.adam_state.len() {
            break;
        }
        // Recover the parameter's name for the checkpoint key; fall back
        // to a synthetic name keyed by buffer index.
        let name = self
            .plan
            .param_buffers
            .iter()
            .find(|(_, br)| *br == param_buf)
            .map(|(n, _)| n.clone())
            .unwrap_or_else(|| format!("param_{}", param_buf.0));
        let byte_len = self.plan.buffers[param_buf.0 as usize];
        for (suffix, buf) in [
            ("adam_m", &self.adam_state[idx].0),
            ("adam_v", &self.adam_state[idx].1),
        ] {
            let key = format!("{suffix}.{name}");
            let mut data = vec![0u8; byte_len];
            // SAFETY: moment buffers match the parameter's byte size and
            // are host-mapped; the GPU is idle.
            unsafe {
                let ptr = buf.data() as *const u8;
                std::ptr::copy_nonoverlapping(ptr, data.as_mut_ptr(), byte_len);
            }
            owned_data.push((key, data));
        }
    }
    // Build borrowing views over the owned copies for serialization.
    let views: Vec<(String, TensorView<'_>)> = owned_data
        .iter()
        .map(|(name, data)| {
            let float_len = data.len() / 4;
            (
                name.clone(),
                TensorView::new(Dtype::F32, vec![float_len], data).expect("tensor view"),
            )
        })
        .collect();
    let mut metadata = HashMap::new();
    metadata.insert("adam_step".to_string(), self.adam_step.to_string());
    let buf = safetensors::tensor::serialize(views, &Some(metadata))
        .map_err(|e| std::io::Error::other(e.to_string()))?;
    std::fs::write(path, buf)
}
#[allow(clippy::pattern_type_mismatch)]
/// Restores parameters and Adam optimizer state from a safetensors
/// checkpoint at `path`.
///
/// Tensors missing from the file leave their buffers untouched (with a
/// warning). Tensors whose byte length does not match the destination
/// buffer are skipped with a warning instead of overflowing GPU memory.
/// The Adam step counter is restored from the "adam_step" metadata key.
pub fn load_checkpoint(&mut self, path: &std::path::Path) -> std::io::Result<()> {
    let file_data = std::fs::read(path)?;
    let (_header_size, metadata) = safetensors::SafeTensors::read_metadata(&file_data)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
    let st = safetensors::SafeTensors::deserialize(&file_data)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
    for (name, buf_ref) in &self.plan.param_buffers {
        let expected = self.plan.buffers[buf_ref.0 as usize];
        match st.tensor(name) {
            // Only upload when the checkpoint tensor exactly fits the
            // destination buffer; a mismatch would corrupt GPU memory.
            Ok(tensor) if tensor.data().len() == expected => {
                self.upload_buffer(*buf_ref, tensor.data());
            }
            Ok(tensor) => {
                log::warn!(
                    "checkpoint tensor {name} has {} bytes, expected {expected}; skipping",
                    tensor.data().len()
                );
            }
            Err(_) => log::warn!("checkpoint missing parameter: {name}"),
        }
    }
    for (idx, &(param_buf, _)) in self.plan.param_grad_pairs.iter().enumerate() {
        // Adam state may be absent (e.g. pure-SGD sessions).
        if idx >= self.adam_state.len() {
            break;
        }
        // Recover the parameter's name to build the checkpoint keys.
        let name = self
            .plan
            .param_buffers
            .iter()
            .find(|(_, br)| *br == param_buf)
            .map(|(n, _)| n.as_str())
            .unwrap_or("");
        let expected = self.plan.buffers[param_buf.0 as usize];
        for (suffix, buf) in [
            ("adam_m", &self.adam_state[idx].0),
            ("adam_v", &self.adam_state[idx].1),
        ] {
            let key = format!("{suffix}.{name}");
            if let Ok(tensor) = st.tensor(&key) {
                if tensor.data().len() != expected {
                    log::warn!(
                        "checkpoint tensor {key} has {} bytes, expected {expected}; skipping",
                        tensor.data().len()
                    );
                    continue;
                }
                // SAFETY: the moment buffer is host-mapped and exactly
                // `expected` bytes long (verified above).
                unsafe {
                    std::ptr::copy_nonoverlapping(
                        tensor.data().as_ptr(),
                        buf.data(),
                        tensor.data().len(),
                    );
                }
            }
        }
    }
    if let Some(ref meta) = *metadata.metadata() {
        if let Some(step_str) = meta.get("adam_step") {
            self.adam_step = step_str.parse::<u32>().unwrap_or(0);
        }
    }
    Ok(())
}
}
impl Drop for Session {
    /// Tears down all GPU resources owned by the session.
    ///
    /// Waits for in-flight work first, then destroys the command encoder,
    /// every compiled pipeline variant (regular, cooperative, small), all
    /// plan buffers, and the Adam moment buffers.
    fn drop(&mut self) {
        self.wait();
        self.gpu.destroy_command_encoder(&mut self.encoder);
        for (_, pipeline) in self.pipelines.map.iter_mut() {
            self.gpu.destroy_compute_pipeline(pipeline);
        }
        for (_, pipeline) in self.pipelines.coop_map.iter_mut() {
            self.gpu.destroy_compute_pipeline(pipeline);
        }
        for (_, pipeline) in self.pipelines.small_map.iter_mut() {
            self.gpu.destroy_compute_pipeline(pipeline);
        }
        for buffer in &self.buffers {
            self.gpu.destroy_buffer(*buffer);
        }
        for &(m_buf, v_buf) in &self.adam_state {
            self.gpu.destroy_buffer(m_buf);
            self.gpu.destroy_buffer(v_buf);
        }
    }
}