use crate::buffer::Arena;
use crate::device::vulkan_device;
use crate::kernels::kernels;
use ash::vk;
use rlx_compile::memory::{BufferSlot, MemoryPlan};
use rlx_ir::op::{Activation, BinaryOp, CmpOp, MaskKind, ReduceOp, RopeStyle};
use rlx_ir::{DType, Graph, NodeId, Op, RngOptions};
use std::collections::{HashMap, HashSet};
/// Op kinds the Vulkan backend accepts; handed to
/// `rlx_opt::legalize_or_rewrite_for_backend` by `compile_rng`.
///
/// The list order is unchanged from the original definition.
pub const SUPPORTED_OPS: &[rlx_ir::OpKind] = {
    use rlx_ir::OpKind::*;
    &[
        // Emit no dispatch of their own in `build_schedule` (leaves and views).
        Input, Param, Constant, Cast, StopGradient, Reshape,
        // Not listed in `is_host_fallback`.
        Binary, Compare, Where, Activation, MatMul, Reduce, Softmax, LayerNorm, RmsNorm,
        LayerNorm2d, Rope, Attention, FusedAttentionBlock, Transpose, Narrow, Concat, Expand,
        Gather, Cumsum, Reverse, ArgMax, ArgMin, Pool, ResizeNearest2x, Conv, GroupedMatMul,
        SelectiveScan, Im2Col, ScatterAdd, TopK,
        // Listed in `is_host_fallback`, except `DequantMatMul`.
        Lstm, Gru, Rnn, Mamba2, GatedDeltaNet, ConvTranspose2d, Fft, DequantMatMul,
        DequantGroupedMatMul, DequantMoEWeights, RngNormal, RngUniform, Sample,
    ]
};
/// Returns `true` for the ops this backend runs as host (CPU) steps rather
/// than GPU dispatches.
///
/// NOTE(review): presumably consulted by `build_schedule` to emit
/// `Step::Host`; that call site is outside this view.
fn is_host_fallback(op: &Op) -> bool {
    matches!(
        op,
        Op::Lstm { .. } | Op::Gru { .. } | Op::Rnn { .. }
            | Op::Mamba2 { .. } | Op::GatedDeltaNet { .. }
            | Op::ConvTranspose2d { .. } | Op::Fft { .. }
            | Op::DequantGroupedMatMul { .. } | Op::DequantMoEWeights { .. }
            | Op::RngNormal { .. } | Op::RngUniform { .. } | Op::Sample { .. }
    )
}
/// One entry of the compiled execution schedule produced by `build_schedule`.
#[derive(Clone)]
enum Step {
    /// A single compute dispatch against the shared arena descriptor set.
    Gpu {
        /// Pipeline name, resolved with `Kernels::pipeline` at record time.
        kernel: &'static str,
        /// Push-constant bytes (little-endian words built with `Push`).
        push: Vec<u8>,
        /// Workgroup counts `(x, y, z)` passed to `cmd_dispatch`.
        groups: (u32, u32, u32),
    },
    /// An op evaluated on the CPU through `crate::host::eval`.
    Host {
        op: Op,
        /// Node whose arena slot receives the f32 result.
        out: NodeId,
        out_shape: rlx_ir::Shape,
        /// Nodes read back from the arena as inputs, in order.
        inputs: Vec<NodeId>,
    },
}
/// A unit of cached execution built by `record_segments`.
enum Segment {
    /// A pre-recorded primary command buffer covering a run of consecutive
    /// GPU steps; resubmitted on every cached run.
    Gpu(vk::CommandBuffer),
    /// A host step, copied from the corresponding `Step::Host`.
    Host {
        op: Op,
        out: NodeId,
        out_shape: rlx_ir::Shape,
        inputs: Vec<NodeId>,
    },
}
/// A graph compiled for the Vulkan backend: one device arena holding every
/// node's data, a dispatch schedule, and (optionally) pre-recorded command
/// buffers.
pub struct VulkanExecutable {
    /// The legalized graph the schedule was built from.
    graph: Graph,
    /// Single storage buffer holding every node's slot (see `plan_f32_uniform`).
    arena: Arena,
    /// Ordered GPU/host steps; used directly when `cached` is false.
    schedule: Vec<Step>,
    /// Pre-recorded segments; empty when `cached` is false.
    segments: Vec<Segment>,
    /// Reusable fence for cached submissions; null when `cached` is false.
    fence: vk::Fence,
    /// False when `RLX_VULKAN_NOCACHE=1` was set at build time.
    cached: bool,
    /// Graph input name -> node.
    input_ids: HashMap<String, NodeId>,
    /// Graph parameter name -> node.
    param_ids: HashMap<String, NodeId>,
    /// Graph outputs, in output order.
    output_ids: Vec<NodeId>,
    /// Dtype of each output, parallel to `output_ids`.
    output_dtypes: Vec<DType>,
    /// Pool that owns `desc_set`.
    desc_pool: vk::DescriptorPool,
    /// Binds the whole arena at binding 0.
    desc_set: vk::DescriptorSet,
    rng: RngOptions,
    /// Optional `(actual, upper)` limit applied when propagating handle feeds.
    active_extent: Option<(usize, usize)>,
    /// Host-side copies of bound handles; an empty Vec marks an arena-resident handle.
    gpu_handles: HashMap<String, Vec<f32>>,
    /// Handle input name -> output index copied into it after each run.
    gpu_handle_feeds: HashMap<String, usize>,
    /// Handles whose current value lives only in the arena.
    gpu_handle_resident: HashSet<String>,
    /// Handle input name -> output index used by `feed_kv_row`.
    kv_row_feeds: HashMap<String, usize>,
}
// SAFETY: NOTE(review): this asserts the executable may be moved across
// threads even though the compiler could not prove it (presumably because of
// a raw pointer inside `Arena` or the Vulkan handles stored above). Confirm that
// all GPU work is submitted and waited on from one thread at a time.
unsafe impl Send for VulkanExecutable {}
/// Plans a linear, non-reusing arena layout for `graph`.
///
/// Each node, in node order, gets its own slot. The exception is
/// `Reshape`/`Cast`/`StopGradient`: when their first input already has a slot,
/// they alias that slot.
///
/// Slot size is 1 byte per element for U8/I8 and 4 bytes otherwise. It is at
/// least 4 bytes (e.g. when `num_elements()` is `None`) and is rounded up to a
/// multiple of `align`. The arena is at least `align` bytes.
fn plan_f32_uniform(graph: &Graph, align: usize) -> MemoryPlan {
    let mut assignments: HashMap<NodeId, BufferSlot> = HashMap::new();
    let mut schedule = Vec::with_capacity(graph.nodes().len());
    let mut next_free = 0usize;
    for node in graph.nodes() {
        schedule.push(node.id);
        let is_view = matches!(
            node.op,
            Op::Reshape { .. } | Op::Cast { .. } | Op::StopGradient
        );
        let aliased = if is_view {
            node.inputs
                .first()
                .and_then(|src| assignments.get(src))
                .cloned()
        } else {
            None
        };
        if let Some(slot) = aliased {
            assignments.insert(node.id, slot);
            continue;
        }
        let count = node.shape.num_elements().unwrap_or(0);
        let width = if matches!(node.shape.dtype(), DType::U8 | DType::I8) {
            1
        } else {
            4
        };
        let size = (count * width).max(4).div_ceil(align) * align;
        assignments.insert(
            node.id,
            BufferSlot {
                offset: next_free,
                size,
            },
        );
        next_free += size;
    }
    MemoryPlan {
        arena_size: next_free.max(align),
        assignments,
        schedule,
    }
}
/// Returns the extents of node `id`, mapping every non-static dimension to 0.
fn dims(graph: &Graph, id: NodeId) -> Vec<usize> {
    let shape = &graph.node(id).shape;
    shape
        .dims()
        .iter()
        .map(|dim| {
            if let rlx_ir::Dim::Static(extent) = dim {
                *extent
            } else {
                0
            }
        })
        .collect()
}
/// Number of elements in a tensor with extents `d`.
///
/// An empty slice (a scalar) has one element, which is exactly what the empty
/// product yields, so no special case is needed. Any zero extent gives zero.
fn numel(d: &[usize]) -> usize {
    d.iter().product()
}
/// Row-major (C-contiguous) element strides for extents `d`.
///
/// The last axis has stride 1. An empty `d` yields an empty Vec.
fn contig_strides(d: &[usize]) -> Vec<usize> {
    let Some((_, tail)) = d.split_first() else {
        return Vec::new();
    };
    // Walk from the innermost axis outwards, accumulating the product of the
    // extents to the right of each axis; then flip into axis order.
    let mut strides = Vec::with_capacity(d.len());
    let mut running = 1usize;
    strides.push(running);
    for &extent in tail.iter().rev() {
        running *= extent;
        strides.push(running);
    }
    strides.reverse();
    strides
}
/// Normalizes a possibly-negative `axis` for a tensor of `rank` dimensions.
///
/// Negative axes count from the end and clamp at 0. Non-negative axes clamp
/// to the last axis (0 when `rank` is 0).
fn norm_axis(axis: i32, rank: usize) -> usize {
    match axis {
        neg if neg < 0 => (rank as i32 + neg).max(0) as usize,
        pos => (pos as usize).min(rank.saturating_sub(1)),
    }
}
/// Builder for push-constant payloads: a sequence of 32-bit words serialized
/// little-endian.
#[derive(Default)]
struct Push {
    words: Vec<u32>,
}
impl Push {
    /// Appends one `u32` word.
    fn u(self, v: u32) -> Self {
        self.us(&[v])
    }
    /// Appends an `f32` as its raw IEEE-754 bit pattern.
    fn f(self, v: f32) -> Self {
        self.u(v.to_bits())
    }
    /// Appends every word of `vs`, in order.
    fn us(mut self, vs: &[u32]) -> Self {
        self.words.extend_from_slice(vs);
        self
    }
    /// Serializes the words as little-endian bytes (4 bytes per word).
    fn bytes(self) -> Vec<u8> {
        self.words.iter().flat_map(|w| w.to_le_bytes()).collect()
    }
}
/// `ceil(n / d)`, computed in `u64` and truncated to `u32` by an `as` cast.
///
/// Panics if `d` is 0.
fn ceil_div(n: usize, d: u32) -> u32 {
    let numerator = n as u64;
    let denominator = u64::from(d);
    numerator.div_ceil(denominator) as u32
}
/// Whether an `m x k x n` matmul fits the cooperative-matrix kernel:
/// `m` and `n` must both be multiples of 16. `k` is currently ignored.
fn coop_eligible(m: usize, _k: usize, n: usize) -> bool {
    [m, n].into_iter().all(|extent| extent.is_multiple_of(16))
}
/// Picks the matmul pipeline name for an `m x k x n` product.
///
/// `RLX_VULKAN_MATMUL` may force the kernel:
/// - `scalar` -> `matmul`;
/// - `tiled` -> `matmul_tiled`;
/// - `coop` -> `matmul_coop`, but only when the device reports `coop_matmul`
///   and the shape passes `coop_eligible`; otherwise `matmul_tiled`.
///
/// Without an override, portability devices get `matmul` and all others get
/// `matmul_tiled`. With no device, both flags default to false.
fn matmul_kernel(m: usize, k: usize, n: usize) -> &'static str {
    let (portability, coop) =
        vulkan_device().map_or((false, false), |d| (d.portability, d.coop_matmul));
    let requested = std::env::var("RLX_VULKAN_MATMUL").ok();
    match requested.as_deref() {
        Some("scalar") => "matmul",
        Some("tiled") => "matmul_tiled",
        Some("coop") => {
            if coop && coop_eligible(m, k, n) {
                "matmul_coop"
            } else {
                "matmul_tiled"
            }
        }
        _ if portability => "matmul",
        _ => "matmul_tiled",
    }
}
/// 1-D dispatch size for `n` invocations at `local` invocations per workgroup.
///
/// Always at least one workgroup.
fn groups1d(n: usize, local: u32) -> (u32, u32, u32) {
    let x = (n as u64).div_ceil(u64::from(local)) as u32;
    (x.max(1), 1, 1)
}
fn act_id(a: Activation) -> u32 {
match a {
Activation::Gelu => 0,
Activation::GeluApprox => 1,
Activation::Silu => 2,
Activation::Relu => 3,
Activation::Sigmoid => 4,
Activation::Tanh => 5,
Activation::Exp => 6,
Activation::Log => 7,
Activation::Sqrt => 8,
Activation::Rsqrt => 9,
Activation::Neg => 10,
Activation::Abs => 11,
Activation::Sin => 12,
Activation::Cos => 13,
Activation::Tan => 14,
Activation::Atan => 15,
Activation::Round => 16,
}
}
fn binop_id(op: BinaryOp) -> u32 {
match op {
BinaryOp::Add => 0,
BinaryOp::Sub => 1,
BinaryOp::Mul => 2,
BinaryOp::Div => 3,
BinaryOp::Max => 4,
BinaryOp::Min => 5,
BinaryOp::Pow => 6,
}
}
fn cmp_id(op: CmpOp) -> u32 {
match op {
CmpOp::Eq => 0,
CmpOp::Ne => 1,
CmpOp::Lt => 2,
CmpOp::Le => 3,
CmpOp::Gt => 4,
CmpOp::Ge => 5,
}
}
fn reduce_id(op: ReduceOp) -> u32 {
match op {
ReduceOp::Sum => 0,
ReduceOp::Mean => 1,
ReduceOp::Max => 2,
ReduceOp::Min => 3,
ReduceOp::Prod => 4,
}
}
impl VulkanExecutable {
    /// Compiles `graph` for Vulkan with the default [`RngOptions`].
    ///
    /// # Panics
    /// Same conditions as [`VulkanExecutable::compile_rng`].
    pub fn compile(graph: Graph) -> Self {
        Self::compile_rng(graph, RngOptions::default())
    }

    /// Lowers `graph` for the Vulkan backend and builds an executable.
    ///
    /// The passes run in this order:
    /// 1. control-flow lowering;
    /// 2. attention-block unfusing;
    /// 3. legalization against [`SUPPORTED_OPS`];
    /// 4. explicit-broadcast legalization.
    ///
    /// # Panics
    /// Panics with a formatted legalization report if the graph cannot be
    /// legalized for Vulkan. Also panics if no Vulkan device or kernel set is
    /// available.
    pub fn compile_rng(graph: Graph, rng: RngOptions) -> Self {
        use rlx_opt::pass::Pass as _;
        let graph = rlx_opt::LowerControlFlow.run(graph);
        let graph = rlx_opt::unfuse::unfuse_attention_block(graph);
        let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, SUPPORTED_OPS)
            .unwrap_or_else(|errs| panic!("{}", rlx_opt::format_legalize_error("vulkan", &errs)));
        let graph = rlx_opt::LegalizeBroadcast.run(graph);
        Self::build(graph, rng)
    }

    /// Builds the executable for an already-legalized `graph`.
    ///
    /// Steps:
    /// - allocate one arena covering every node (see `plan_f32_uniform`);
    /// - upload constant data;
    /// - build the dispatch schedule;
    /// - bind the whole arena as storage buffer 0 of a single descriptor set;
    /// - unless `RLX_VULKAN_NOCACHE=1`, pre-record one command buffer per run
    ///   of consecutive GPU steps.
    ///
    /// `RLX_VULKAN_DEBUG` prints schedule statistics to stderr.
    ///
    /// # Panics
    /// Panics if no Vulkan device or kernel set is available, or if creating
    /// the descriptor pool or set fails.
    fn build(graph: Graph, rng: RngOptions) -> Self {
        let dev = vulkan_device().expect("rlx-vulkan: no device");
        let kern = kernels().expect("rlx-vulkan: no kernels");
        let plan = plan_f32_uniform(&graph, 16);
        let arena = Arena::from_plan(&plan);
        for node in graph.nodes() {
            if let Op::Constant { data } = &node.op
                && arena.has(node.id)
                && !data.is_empty()
            {
                // NOTE(review): `plan_f32_uniform` sizes U8/I8 slots at 1 byte per
                // element, and host ops read them back with `read_bytes`. Constants,
                // however, are always widened to 4-byte f32 here. Confirm that
                // `write_f32` cannot overrun a byte-typed slot, or that no U8/I8
                // constants reach this point.
                let f = widen_const_to_f32(data, node.shape.dtype());
                arena.write_f32(node.id, &f);
            }
        }
        let mut input_ids = HashMap::new();
        let mut param_ids = HashMap::new();
        for node in graph.nodes() {
            match &node.op {
                Op::Input { name } => {
                    input_ids.insert(name.clone(), node.id);
                }
                Op::Param { name } => {
                    param_ids.insert(name.clone(), node.id);
                }
                _ => {}
            }
        }
        let output_ids = graph.outputs.clone();
        let output_dtypes = output_ids
            .iter()
            .map(|&id| graph.node(id).shape.dtype())
            .collect();
        let (schedule, deps) = build_schedule(&graph, &arena);
        // One storage-buffer descriptor: every kernel addresses the arena by
        // element offset, so a single binding suffices.
        let pool_sizes = [vk::DescriptorPoolSize::default()
            .ty(vk::DescriptorType::STORAGE_BUFFER)
            .descriptor_count(1)];
        let desc_pool = unsafe {
            dev.device.create_descriptor_pool(
                &vk::DescriptorPoolCreateInfo::default()
                    .max_sets(1)
                    .pool_sizes(&pool_sizes),
                None,
            )
        }
        .expect("vk descriptor_pool");
        let set_layouts = [kern.dsl];
        let desc_set = unsafe {
            dev.device.allocate_descriptor_sets(
                &vk::DescriptorSetAllocateInfo::default()
                    .descriptor_pool(desc_pool)
                    .set_layouts(&set_layouts),
            )
        }
        .expect("vk descriptor_set")[0];
        let buf_info = [vk::DescriptorBufferInfo::default()
            .buffer(arena.buffer)
            .offset(0)
            .range(vk::WHOLE_SIZE)];
        let write = vk::WriteDescriptorSet::default()
            .dst_set(desc_set)
            .dst_binding(0)
            .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
            .buffer_info(&buf_info);
        unsafe { dev.device.update_descriptor_sets(&[write], &[]) };
        let cached = std::env::var("RLX_VULKAN_NOCACHE").as_deref() != Ok("1");
        let (segments, fence) = if cached {
            let segs = record_segments(dev, kern, desc_set, &schedule, &deps);
            (segs, dev.create_reusable_fence())
        } else {
            (Vec::new(), vk::Fence::null())
        };
        if std::env::var_os("RLX_VULKAN_DEBUG").is_some() {
            let gpu = schedule
                .iter()
                .filter(|s| matches!(s, Step::Gpu { .. }))
                .count();
            let host = schedule.len() - gpu;
            let gpu_segs = segments
                .iter()
                .filter(|s| matches!(s, Segment::Gpu(_)))
                .count();
            let mut hist: HashMap<&'static str, usize> = HashMap::new();
            for s in &schedule {
                if let Step::Gpu { kernel, .. } = s {
                    *hist.entry(kernel).or_default() += 1;
                }
            }
            let mut by_count: Vec<_> = hist.into_iter().collect();
            by_count.sort_by_key(|&(_, c)| std::cmp::Reverse(c));
            eprintln!(
                "[rlx-vulkan] schedule: {gpu} gpu dispatches, {host} host ops; \
                 cached={cached} ({gpu_segs} gpu submit(s)/run)"
            );
            eprintln!("[rlx-vulkan] dispatch histogram: {by_count:?}");
        }
        Self {
            graph,
            arena,
            schedule,
            segments,
            fence,
            cached,
            input_ids,
            param_ids,
            output_ids,
            output_dtypes,
            desc_pool,
            desc_set,
            rng,
            active_extent: None,
            gpu_handles: HashMap::new(),
            gpu_handle_feeds: HashMap::new(),
            gpu_handle_resident: HashSet::new(),
            kv_row_feeds: HashMap::new(),
        }
    }

    /// Uploads f32 `data` into the arena slot of the parameter named `name`.
    ///
    /// Unknown names are silently ignored.
    pub fn set_param(&mut self, name: &str, data: &[f32]) {
        if let Some(&id) = self.param_ids.get(name) {
            self.arena.write_f32(id, data);
        }
    }

    /// Uploads raw bytes into the arena slot of the parameter named `name`.
    /// Use this for U8/I8 parameters, whose slots hold 1 byte per element.
    ///
    /// Unknown names are silently ignored.
    pub fn set_param_bytes(&mut self, name: &str, data: &[u8]) {
        if let Some(&id) = self.param_ids.get(name) {
            self.arena.write_bytes(id, data);
        }
    }

    /// Returns the dtype of each graph output, in output order.
    pub fn output_dtypes(&self) -> Vec<DType> {
        self.output_dtypes.clone()
    }

    /// Sets the `(actual, upper)` extent that limits how much of each fed
    /// output is copied back into its handle input after a run.
    ///
    /// `None` (or `upper == 0`) copies the whole output.
    pub fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
        self.active_extent = extent;
    }

    /// Binds host `data` to the graph input `name` as a persistent handle.
    ///
    /// The data is written immediately. It is re-uploaded before every run
    /// unless that run passes an explicit input with the same name. Binding
    /// clears any arena-resident flag for the handle.
    ///
    /// Returns `false` (and does nothing) if the graph has no input `name`.
    pub fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
        let Some(&id) = self.input_ids.get(name) else {
            return false;
        };
        self.gpu_handle_resident.remove(name);
        self.arena.write_f32(id, data);
        self.gpu_handles.insert(name.to_string(), data.to_vec());
        true
    }

    /// Returns `true` if a handle named `name` has been bound or fed.
    pub fn has_gpu_handle(&self, name: &str) -> bool {
        self.gpu_handles.contains_key(name)
    }

    /// After each run, copy output `output_index` into the input slot of the
    /// handle `handle_name`.
    pub fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) {
        self.gpu_handle_feeds
            .insert(handle_name.to_string(), output_index);
    }

    /// Registers output `output_index` as the row source for handle input
    /// `handle_name`. The copy itself happens in [`VulkanExecutable::feed_kv_row`].
    pub fn register_kv_row_feed(&mut self, handle_name: &str, output_index: usize) {
        self.kv_row_feeds
            .insert(handle_name.to_string(), output_index);
    }

    /// For every registered row feed, copies row `src_row` of the output into
    /// row `dst_row` of the handle input. Rows are `row_elems` f32 elements.
    ///
    /// Each handle is then marked arena-resident. Feeds whose output index or
    /// input name does not resolve are skipped.
    pub fn feed_kv_row(&mut self, src_row: usize, dst_row: usize, row_elems: usize) {
        let feeds: Vec<(String, usize)> = self
            .kv_row_feeds
            .iter()
            .map(|(k, &v)| (k.clone(), v))
            .collect();
        for (name, out_idx) in feeds {
            let Some(&out_id) = self.output_ids.get(out_idx) else {
                continue;
            };
            let Some(&in_id) = self.input_ids.get(name.as_str()) else {
                continue;
            };
            if in_id != out_id {
                self.arena.copy_node_f32_range(
                    in_id,
                    dst_row * row_elems,
                    out_id,
                    src_row * row_elems,
                    row_elems,
                );
            }
            self.gpu_handle_resident.insert(name.clone());
            // Empty host copy: the authoritative value now lives in the arena.
            self.gpu_handles.insert(name.clone(), Vec::new());
        }
    }

    /// Reads the current value of handle `name`.
    ///
    /// Precedence:
    /// 1. the fed output's arena slot;
    /// 2. the input's arena slot, if the handle is resident;
    /// 3. the host-side copy.
    ///
    /// Returns `None` if the handle is unknown.
    pub fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
        if let Some(&out_idx) = self.gpu_handle_feeds.get(name)
            && let Some(&out_id) = self.output_ids.get(out_idx)
        {
            let n = self.graph.node(out_id).shape.num_elements().unwrap_or(0);
            return Some(self.arena.read_f32(out_id, n));
        }
        if self.gpu_handle_resident.contains(name)
            && let Some(&id) = self.input_ids.get(name)
        {
            let n = self.graph.node(id).shape.num_elements().unwrap_or(0);
            return Some(self.arena.read_f32(id, n));
        }
        self.gpu_handles.get(name).cloned()
    }

    /// Reads `row_inner` f32 elements of output `out_idx`, starting at element
    /// `row * row_inner` of that output.
    ///
    /// Returns `None` if `out_idx` is out of range. `row` is not bounds-checked
    /// here.
    pub fn read_output_row(
        &self,
        out_idx: usize,
        row: usize,
        row_inner: usize,
    ) -> Option<Vec<f32>> {
        let id = *self.output_ids.get(out_idx)?;
        let base = self.arena.elem_offset(id) as usize + row * row_inner;
        Some(self.arena.read_f32_at_elem(base, row_inner))
    }

    /// Copies each fed output into its handle's input slot within the arena,
    /// then marks the handle resident.
    ///
    /// With `active_extent = Some((actual, upper))` and `upper > 0`, only
    /// `actual * max(1, out_elems / (upper + 1))` elements are copied.
    /// NOTE(review): this presumably means "the first `actual` rows of an
    /// output with `upper + 1` rows"; confirm.
    fn propagate_gpu_handle_feeds_in_arena(&mut self) {
        let extent = self.active_extent;
        let feeds: Vec<(String, usize)> = self
            .gpu_handle_feeds
            .iter()
            .map(|(k, &v)| (k.clone(), v))
            .collect();
        for (name, out_idx) in feeds {
            let Some(&out_id) = self.output_ids.get(out_idx) else {
                continue;
            };
            let Some(&in_id) = self.input_ids.get(name.as_str()) else {
                continue;
            };
            if in_id != out_id {
                let out_elems = self.graph.node(out_id).shape.num_elements().unwrap_or(0);
                let copy_elems = match extent {
                    Some((actual, upper)) if upper > 0 => actual * (out_elems / (upper + 1)).max(1),
                    _ => out_elems,
                };
                self.arena
                    .copy_node_f32_prefix(in_id, out_id, copy_elems.min(out_elems));
            }
            self.gpu_handle_resident.insert(name.clone());
            self.gpu_handles.insert(name.clone(), Vec::new());
        }
    }

    /// Refreshes the host-side copy of every fed handle from its output slot.
    fn refresh_gpu_handles_from_outputs(&mut self) {
        let feeds: Vec<(String, usize)> = self
            .gpu_handle_feeds
            .iter()
            .map(|(k, &v)| (k.clone(), v))
            .collect();
        for (name, out_idx) in feeds {
            let Some(&out_id) = self.output_ids.get(out_idx) else {
                continue;
            };
            let n = self.graph.node(out_id).shape.num_elements().unwrap_or(0);
            let src = self.arena.read_f32(out_id, n);
            self.gpu_handles.insert(name, src);
        }
    }

    /// Replaces the stored RNG options.
    ///
    /// NOTE(review): `rng` is not passed to `crate::host::eval` in this file;
    /// confirm how host RNG ops observe it.
    pub fn set_rng(&mut self, rng: RngOptions) {
        self.rng = rng;
    }

    /// Returns the stored RNG options.
    pub fn rng(&self) -> RngOptions {
        self.rng
    }

    /// Runs the graph with `inputs` and returns every output as f32.
    pub fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
        self.run_read_outputs(inputs, None)
    }

    /// Runs the graph and returns the outputs at `read_indices` as f32.
    ///
    /// `None` returns all outputs; out-of-range indices are skipped.
    ///
    /// Before execution:
    /// - bound handle data is re-uploaded, except for handles that are
    ///   arena-resident or overridden by an entry in `inputs`;
    /// - `inputs` are written (unknown names are ignored).
    ///
    /// # Panics
    /// Panics if no Vulkan device or kernel set is available.
    pub fn run_read_outputs(
        &mut self,
        inputs: &[(&str, &[f32])],
        read_indices: Option<&[usize]>,
    ) -> Vec<Vec<f32>> {
        for (name, data) in &self.gpu_handles {
            if self.gpu_handle_resident.contains(name) || inputs.iter().any(|(n, _)| n == name) {
                continue;
            }
            if let Some(&id) = self.input_ids.get(name) {
                self.arena.write_f32(id, data);
            }
        }
        for &(name, data) in inputs {
            if let Some(&id) = self.input_ids.get(name) {
                self.arena.write_f32(id, data);
            }
        }
        let dev = vulkan_device().expect("rlx-vulkan: no device");
        let kern = kernels().expect("rlx-vulkan: no kernels");
        let desc_set = self.desc_set;
        let layout = kern.pipeline_layout;
        if self.cached {
            // Cached path: resubmit the pre-recorded command buffers, waiting on
            // the reusable fence, interleaved with host steps.
            for segment in &self.segments {
                match segment {
                    Segment::Gpu(cmd) => {
                        dev.submit_recorded_wait(*cmd, self.fence);
                    }
                    Segment::Host {
                        op,
                        out,
                        out_shape,
                        inputs: in_ids,
                    } => {
                        self.run_host_op(op, *out, out_shape, in_ids);
                    }
                }
            }
            return self.finish_run(read_indices);
        }
        // Uncached path: record and submit each run of consecutive GPU steps on
        // the fly, with a compute->compute memory barrier between dispatches.
        let n = self.schedule.len();
        let mut i = 0;
        while i < n {
            let start = i;
            while i < n && matches!(self.schedule[i], Step::Gpu { .. }) {
                i += 1;
            }
            if i > start {
                // Borrowed rather than cloned: nothing mutates the schedule here.
                let gpu = &self.schedule[start..i];
                dev.submit_and_wait(|cmd| unsafe {
                    dev.device.cmd_bind_descriptor_sets(
                        cmd,
                        vk::PipelineBindPoint::COMPUTE,
                        layout,
                        0,
                        &[desc_set],
                        &[],
                    );
                    let barrier = vk::MemoryBarrier::default()
                        .src_access_mask(vk::AccessFlags::SHADER_WRITE)
                        .dst_access_mask(
                            vk::AccessFlags::SHADER_READ | vk::AccessFlags::SHADER_WRITE,
                        );
                    for (j, step) in gpu.iter().enumerate() {
                        if let Step::Gpu {
                            kernel,
                            push,
                            groups,
                        } = step
                        {
                            let pipeline = kern.pipeline(kernel);
                            dev.device.cmd_bind_pipeline(
                                cmd,
                                vk::PipelineBindPoint::COMPUTE,
                                pipeline,
                            );
                            dev.device.cmd_push_constants(
                                cmd,
                                layout,
                                vk::ShaderStageFlags::COMPUTE,
                                0,
                                push,
                            );
                            dev.device.cmd_dispatch(cmd, groups.0, groups.1, groups.2);
                            if j + 1 < gpu.len() {
                                dev.device.cmd_pipeline_barrier(
                                    cmd,
                                    vk::PipelineStageFlags::COMPUTE_SHADER,
                                    vk::PipelineStageFlags::COMPUTE_SHADER,
                                    vk::DependencyFlags::empty(),
                                    &[barrier],
                                    &[],
                                    &[],
                                );
                            }
                        }
                    }
                });
            }
            if i < n {
                if let Step::Host {
                    op,
                    out,
                    out_shape,
                    inputs: in_ids,
                } = &self.schedule[i]
                {
                    self.run_host_op(op, *out, out_shape, in_ids);
                }
                i += 1;
            }
        }
        self.finish_run(read_indices)
    }

    /// Evaluates one host-fallback op on the CPU and writes its f32 result to
    /// the arena slot of `out`.
    ///
    /// Inputs of dtype U8/I8 are read as raw bytes; all others are read as f32.
    /// Both the cached and uncached run paths use this helper.
    fn run_host_op(
        &self,
        op: &Op,
        out: NodeId,
        out_shape: &rlx_ir::Shape,
        inputs: &[NodeId],
    ) {
        let in_specs: Vec<(rlx_ir::Shape, crate::host::HostBuf)> = inputs
            .iter()
            .map(|&id| {
                let sh = self.graph.node(id).shape.clone();
                let nn = sh.num_elements().unwrap_or(0);
                let buf = if matches!(sh.dtype(), DType::U8 | DType::I8) {
                    crate::host::HostBuf::Bytes(self.arena.read_bytes(id, nn))
                } else {
                    crate::host::HostBuf::F32(self.arena.read_f32(id, nn))
                };
                (sh, buf)
            })
            .collect();
        let result = crate::host::eval(op, out_shape, &in_specs);
        self.arena.write_f32(out, &result);
    }

    /// Applies handle feeds, then reads the requested outputs as f32.
    ///
    /// Host-side handle copies are refreshed only when all outputs are read
    /// (`read_indices` is `None`). Out-of-range indices are skipped.
    fn finish_run(&mut self, read_indices: Option<&[usize]>) -> Vec<Vec<f32>> {
        if !self.gpu_handle_feeds.is_empty() {
            self.propagate_gpu_handle_feeds_in_arena();
            if read_indices.is_none() {
                self.refresh_gpu_handles_from_outputs();
            }
        }
        let want: Vec<usize> = match read_indices {
            Some(ix) => ix.to_vec(),
            None => (0..self.output_ids.len()).collect(),
        };
        want.into_iter()
            .filter_map(|i| {
                let id = *self.output_ids.get(i)?;
                let n = self.graph.node(id).shape.num_elements().unwrap_or(0);
                Some(self.arena.read_f32(id, n))
            })
            .collect()
    }

    /// Creates an independent executable for the same (already-legalized)
    /// graph.
    ///
    /// `build` is re-run, so the clone gets fresh Vulkan objects. Arena
    /// contents, active extent, and all handle bookkeeping are then copied over.
    pub fn clone_for_cache(&self) -> Self {
        let mut twin = Self::build(self.graph.clone(), self.rng);
        twin.active_extent = self.active_extent;
        self.arena.copy_into(&twin.arena);
        twin.gpu_handles = self.gpu_handles.clone();
        twin.gpu_handle_feeds = self.gpu_handle_feeds.clone();
        twin.gpu_handle_resident = self.gpu_handle_resident.clone();
        twin.kv_row_feeds = self.kv_row_feeds.clone();
        twin
    }
}
impl Drop for VulkanExecutable {
    /// Releases the Vulkan objects owned by this executable:
    /// - the pre-recorded command buffers;
    /// - the reusable fence, if one was created;
    /// - the descriptor pool, which also releases `desc_set`.
    ///
    /// Does nothing when no Vulkan device is available.
    fn drop(&mut self) {
        let Some(dev) = vulkan_device() else {
            return;
        };
        let mut cmds: Vec<vk::CommandBuffer> = Vec::with_capacity(self.segments.len());
        for segment in &self.segments {
            match segment {
                Segment::Gpu(cmd) => cmds.push(*cmd),
                Segment::Host { .. } => {}
            }
        }
        if !cmds.is_empty() {
            dev.free_cmds(&cmds);
        }
        if self.fence != vk::Fence::null() {
            dev.destroy_fence(self.fence);
        }
        // SAFETY: NOTE(review): assumes no submission still references the
        // pool. The run paths visible here wait on every submission; confirm.
        unsafe {
            dev.device.destroy_descriptor_pool(self.desc_pool, None);
        }
    }
}
/// Pre-records the schedule into reusable segments.
///
/// Each maximal run of consecutive `Step::Gpu` entries becomes one primary
/// command buffer (`Segment::Gpu`). Each `Step::Host` is copied through as a
/// `Segment::Host`. Schedule order is preserved.
///
/// `deps` must be parallel to `schedule`: it is sliced with the same index
/// range as each GPU run.
///
/// Barrier policy within a command buffer:
/// - Before dispatch `j > 0`, a compute->compute memory barrier is emitted on
///   a hazard against dispatches since the last barrier: the step reads a key
///   that was written (RAW), or writes a key that was written (WAW) or read
///   (WAR).
/// - `RLX_VULKAN_FULLBARRIER=1` forces a barrier before every dispatch after
///   the first.
/// - `RLX_VULKAN_NOBARRIER=1` suppresses all barriers.
///
/// NOTE(review): hazards compare the `u32` keys in `StepDep` for exact
/// equality only (presumably arena element offsets). Overlapping ranges with
/// different start keys are not detected here; confirm `build_schedule`
/// accounts for that.
fn record_segments(
    dev: &crate::device::VulkanDevice,
    kern: &crate::kernels::Kernels,
    desc_set: vk::DescriptorSet,
    schedule: &[Step],
    deps: &[StepDep],
) -> Vec<Segment> {
    let layout = kern.pipeline_layout;
    let no_barrier = std::env::var("RLX_VULKAN_NOBARRIER").as_deref() == Ok("1");
    let full_barrier = std::env::var("RLX_VULKAN_FULLBARRIER").as_deref() == Ok("1");
    let mut segments = Vec::new();
    let n = schedule.len();
    let mut i = 0;
    while i < n {
        let start = i;
        // Extend the current run over consecutive GPU steps.
        while i < n && matches!(schedule[i], Step::Gpu { .. }) {
            i += 1;
        }
        if i > start {
            let run = &schedule[start..i];
            let run_deps = &deps[start..i];
            let cmd = dev.alloc_primary_cmd();
            unsafe {
                // Default begin info (no ONE_TIME_SUBMIT flag): the buffer is
                // resubmitted on every cached run.
                dev.device
                    .begin_command_buffer(cmd, &vk::CommandBufferBeginInfo::default())
                    .expect("vk begin cmd");
                dev.device.cmd_bind_descriptor_sets(
                    cmd,
                    vk::PipelineBindPoint::COMPUTE,
                    layout,
                    0,
                    &[desc_set],
                    &[],
                );
                let barrier = vk::MemoryBarrier::default()
                    .src_access_mask(vk::AccessFlags::SHADER_WRITE)
                    .dst_access_mask(vk::AccessFlags::SHADER_READ | vk::AccessFlags::SHADER_WRITE);
                // Keys written / read since the most recent barrier.
                let mut wrote: HashSet<u32> = HashSet::new();
                let mut read: HashSet<u32> = HashSet::new();
                for (j, step) in run.iter().enumerate() {
                    if let Step::Gpu {
                        kernel,
                        push,
                        groups,
                    } = step
                    {
                        let dep = &run_deps[j];
                        // RAW || WAW || WAR against dispatches since the last barrier.
                        let hazard = !wrote.is_empty()
                            && (dep.reads.iter().any(|r| wrote.contains(r))
                                || wrote.contains(&dep.write)
                                || read.contains(&dep.write));
                        let emit_barrier = j > 0 && !no_barrier && (full_barrier || hazard);
                        if emit_barrier {
                            dev.device.cmd_pipeline_barrier(
                                cmd,
                                vk::PipelineStageFlags::COMPUTE_SHADER,
                                vk::PipelineStageFlags::COMPUTE_SHADER,
                                vk::DependencyFlags::empty(),
                                &[barrier],
                                &[],
                                &[],
                            );
                            // Everything before the barrier is now ordered; start fresh.
                            wrote.clear();
                            read.clear();
                        }
                        let pipeline = kern.pipeline(kernel);
                        dev.device
                            .cmd_bind_pipeline(cmd, vk::PipelineBindPoint::COMPUTE, pipeline);
                        dev.device.cmd_push_constants(
                            cmd,
                            layout,
                            vk::ShaderStageFlags::COMPUTE,
                            0,
                            push,
                        );
                        dev.device.cmd_dispatch(cmd, groups.0, groups.1, groups.2);
                        wrote.insert(dep.write);
                        for &r in &dep.reads {
                            read.insert(r);
                        }
                    }
                }
                dev.device.end_command_buffer(cmd).expect("vk end cmd");
            }
            segments.push(Segment::Gpu(cmd));
        }
        if i < n {
            // The run ended on a host step (the only other `Step` variant).
            if let Step::Host {
                op,
                out,
                out_shape,
                inputs,
            } = &schedule[i]
            {
                segments.push(Segment::Host {
                    op: op.clone(),
                    out: *out,
                    out_shape: out_shape.clone(),
                    inputs: inputs.clone(),
                });
            }
            i += 1;
        }
    }
    segments
}
/// Decodes little-endian constant bytes of dtype `dt` into f32 values.
///
/// Integers are converted with `as f32`, which rounds large magnitudes. Bool
/// bytes become 0.0/1.0. Trailing bytes that do not fill a whole element are
/// ignored (via `chunks_exact`).
///
/// `C64` is decoded as a flat run of f32 words. NOTE(review): this presumably
/// yields interleaved (re, im) pairs; confirm against consumers.
fn widen_const_to_f32(data: &[u8], dt: DType) -> Vec<f32> {
    match dt {
        DType::F32 => data
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect(),
        DType::F16 => data
            .chunks_exact(2)
            .map(|c| half::f16::from_le_bytes([c[0], c[1]]).to_f32())
            .collect(),
        DType::BF16 => data
            .chunks_exact(2)
            .map(|c| half::bf16::from_le_bytes([c[0], c[1]]).to_f32())
            .collect(),
        DType::F64 => data
            .chunks_exact(8)
            .map(|c| f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32)
            .collect(),
        DType::I64 => data
            .chunks_exact(8)
            .map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32)
            .collect(),
        DType::I32 => data
            .chunks_exact(4)
            .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
            .collect(),
        // Decode as unsigned: going through i32 would turn values >= 2^31 negative.
        DType::U32 => data
            .chunks_exact(4)
            .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
            .collect(),
        DType::I16 => data
            .chunks_exact(2)
            .map(|c| i16::from_le_bytes([c[0], c[1]]) as f32)
            .collect(),
        DType::I8 => data.iter().map(|&b| b as i8 as f32).collect(),
        DType::U8 | DType::Bool => data.iter().map(|&b| b as f32).collect(),
        DType::C64 => data
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect(),
    }
}
/// Buffer-access summary for one schedule step, consumed by
/// `record_segments` to decide where barriers are needed.
///
/// NOTE(review): the keys are presumably arena element offsets produced by
/// `build_schedule` (outside this view); confirm.
#[derive(Clone, Default)]
struct StepDep {
    /// Keys of the regions the dispatch reads.
    reads: Vec<u32>,
    /// Key of the region the dispatch writes.
    write: u32,
}
fn build_schedule(graph: &Graph, arena: &Arena) -> (Vec<Step>, Vec<StepDep>) {
let mut steps = Vec::new();
let mut deps: Vec<StepDep> = Vec::new();
for node in graph.nodes() {
let off = |id: NodeId| arena.elem_offset(id);
let out = node.id;
let before = steps.len();
match &node.op {
Op::Input { .. }
| Op::Param { .. }
| Op::Constant { .. }
| Op::Reshape { .. }
| Op::Cast { .. }
| Op::StopGradient => {}
Op::Binary(op) => {
let a = node.inputs[0];
let b = node.inputs[1];
let n = numel(&dims(graph, out));
let an = numel(&dims(graph, a));
let bn = numel(&dims(graph, b));
let push = Push::default()
.u(n as u32)
.u(off(a))
.u(off(b))
.u(off(out))
.u(if an == n { 0 } else { an as u32 })
.u(if bn == n { 0 } else { bn as u32 })
.u(binop_id(*op))
.bytes();
steps.push(Step::Gpu {
kernel: "binary",
push,
groups: groups1d(n, 256),
});
}
Op::Compare(op) => {
let a = node.inputs[0];
let b = node.inputs[1];
let n = numel(&dims(graph, out));
let an = numel(&dims(graph, a));
let bn = numel(&dims(graph, b));
let push = Push::default()
.u(n as u32)
.u(off(a))
.u(off(b))
.u(off(out))
.u(if an == n { 0 } else { an as u32 })
.u(if bn == n { 0 } else { bn as u32 })
.u(cmp_id(*op))
.bytes();
steps.push(Step::Gpu {
kernel: "compare",
push,
groups: groups1d(n, 256),
});
}
Op::Where => {
let c = node.inputs[0];
let a = node.inputs[1];
let b = node.inputs[2];
let n = numel(&dims(graph, out));
let cn = numel(&dims(graph, c));
let an = numel(&dims(graph, a));
let bn = numel(&dims(graph, b));
let push = Push::default()
.u(n as u32)
.u(off(c))
.u(off(a))
.u(off(b))
.u(off(out))
.u(if cn == n { 0 } else { cn as u32 })
.u(if an == n { 0 } else { an as u32 })
.u(if bn == n { 0 } else { bn as u32 })
.bytes();
steps.push(Step::Gpu {
kernel: "where",
push,
groups: groups1d(n, 256),
});
}
Op::Activation(act) => {
let x = node.inputs[0];
let n = numel(&dims(graph, out));
let push = Push::default()
.u(n as u32)
.u(off(x))
.u(off(out))
.u(act_id(*act))
.bytes();
steps.push(Step::Gpu {
kernel: "unary",
push,
groups: groups1d(n, 256),
});
}
Op::MatMul => {
let a = node.inputs[0];
let b = node.inputs[1];
let ad = dims(graph, a);
let bd = dims(graph, b);
let od = dims(graph, out);
let (m, k) = (ad[ad.len() - 2], ad[ad.len() - 1]);
let n = bd[bd.len() - 1];
let batch = if od.len() > 2 {
numel(&od[..od.len() - 2])
} else {
1
};
let a_batch = if ad.len() > 2 {
numel(&ad[..ad.len() - 2])
} else {
1
};
let b_batch = if bd.len() > 2 {
numel(&bd[..bd.len() - 2])
} else {
1
};
let a_bs = if a_batch <= 1 { 0 } else { m * k };
let b_bs = if b_batch <= 1 { 0 } else { k * n };
let push = Push::default()
.u(m as u32)
.u(k as u32)
.u(n as u32)
.u(off(a))
.u(off(b))
.u(off(out))
.u(batch as u32)
.u(a_bs as u32)
.u(b_bs as u32)
.u((m * n) as u32)
.bytes();
steps.push(Step::Gpu {
kernel: matmul_kernel(m, k, n),
push,
groups: (ceil_div(n, 16), ceil_div(m, 16), batch.max(1) as u32),
});
}
Op::Reduce { op, axes, .. } => {
let x = node.inputs[0];
let xd = dims(graph, x);
let rank = xd.len();
let last = rank.saturating_sub(1);
debug_assert!(
axes.as_slice() == [last] || (rank <= 1),
"rlx-vulkan: non-last-axis reduce should have been lowered"
);
let r = *xd.get(last).unwrap_or(&1);
let outer = numel(&xd) / r.max(1);
let push = Push::default()
.u(outer as u32)
.u(r as u32)
.u(off(x))
.u(off(out))
.u(reduce_id(*op))
.bytes();
steps.push(Step::Gpu {
kernel: "reduce",
push,
groups: groups1d(outer, 256),
});
}
Op::Softmax { axis } => {
let x = node.inputs[0];
let xd = dims(graph, x);
let ax = norm_axis(*axis, xd.len());
let axis_len = xd[ax];
let outer = numel(&xd[..ax]);
let inner = numel(&xd[ax + 1..]);
let push = Push::default()
.u(outer as u32)
.u(axis_len as u32)
.u(inner as u32)
.u(off(x))
.u(off(out))
.bytes();
steps.push(Step::Gpu {
kernel: "softmax",
push,
groups: groups1d(outer * inner, 256),
});
}
Op::RmsNorm { axis, eps } => {
let x = node.inputs[0];
let gamma = node.inputs[1];
let beta = node.inputs[2];
let xd = dims(graph, x);
let ax = norm_axis(*axis, xd.len());
debug_assert_eq!(ax, xd.len().saturating_sub(1), "rmsnorm expects last axis");
let n = xd[ax];
let rows = numel(&xd) / n.max(1);
let push = Push::default()
.u(rows as u32)
.u(n as u32)
.u(off(x))
.u(off(gamma))
.u(off(beta))
.u(off(out))
.f(*eps)
.bytes();
steps.push(Step::Gpu {
kernel: "rmsnorm",
push,
groups: groups1d(rows, 64),
});
}
Op::LayerNorm { axis, eps } => {
let x = node.inputs[0];
let gamma = node.inputs[1];
let has_beta = node.inputs.len() >= 3;
let beta = if has_beta { node.inputs[2] } else { gamma };
let xd = dims(graph, x);
let ax = norm_axis(*axis, xd.len());
let n = xd[ax];
let rows = numel(&xd) / n.max(1);
let push = Push::default()
.u(rows as u32)
.u(n as u32)
.u(off(x))
.u(off(gamma))
.u(off(beta))
.u(off(out))
.u(if has_beta { 1 } else { 0 })
.f(*eps)
.bytes();
steps.push(Step::Gpu {
kernel: "layernorm",
push,
groups: groups1d(rows, 64),
});
}
Op::Rope {
head_dim,
n_rot,
style,
} => {
let x = node.inputs[0];
let cos = node.inputs[1];
let sin = node.inputs[2];
let xd = dims(graph, x);
let (batch, seq, hidden) = if xd.len() >= 3 {
(xd[0], xd[1], xd[2])
} else {
let total = numel(&xd);
(1, xd[0], total / xd[0].max(1))
};
let hd = *head_dim;
let nh = hidden / hd.max(1);
let tab_half = hd / 2;
let cos_len = numel(&dims(graph, cos));
let cos_rows = cos_len / tab_half.max(1);
let per_token = (cos_rows == batch * seq && cos_rows != seq) as u32;
let style_id = match style {
RopeStyle::NeoX => 0u32,
RopeStyle::GptJ => 1u32,
};
let push = Push::default()
.u(batch as u32)
.u(seq as u32)
.u(hidden as u32)
.u(hd as u32)
.u(*n_rot as u32)
.u(nh as u32)
.u(tab_half as u32)
.u(hidden as u32) .u(per_token)
.u(style_id)
.u(off(x))
.u(off(cos))
.u(off(sin))
.u(off(out))
.bytes();
steps.push(Step::Gpu {
kernel: "rope",
push,
groups: groups1d(batch * seq * nh, 64),
});
}
Op::Attention {
num_heads,
head_dim,
mask_kind,
score_scale,
..
} => {
let q = node.inputs[0];
let k = node.inputs[1];
let v = node.inputs[2];
let qd = dims(graph, q);
let kd = dims(graph, k);
let nh = *num_heads;
let dh = *head_dim;
let (batch, q_s, k_s, bhsd) = if qd.len() == 4 {
if qd[1] == nh {
(qd[0], qd[2], kd[2], 1u32) } else {
(qd[0], qd[1], kd[1], 0u32) }
} else if qd.len() >= 3 {
(qd[0], qd[1], kd[1], 0u32)
} else {
(1, qd[0], kd[0], 0u32)
};
let hs = (nh * dh) as u32;
let (mask_kind_id, mask_off, window) = match mask_kind {
MaskKind::None => (0u32, 0u32, 0u32),
MaskKind::Causal => (1, 0, 0),
MaskKind::SlidingWindow(w) => (2, 0, *w as u32),
MaskKind::Custom => (3, off(node.inputs[3]), 0),
MaskKind::Bias => (4, off(node.inputs[3]), 0),
};
let scale = score_scale.unwrap_or((dh as f32).powf(-0.5));
let push = Push::default()
.u(batch as u32)
.u(nh as u32)
.u(q_s as u32)
.u(k_s as u32)
.u(dh as u32)
.u(off(q))
.u(off(k))
.u(off(v))
.u(off(out))
.u(hs)
.u(hs)
.u(hs)
.u(bhsd)
.u(mask_kind_id)
.u(mask_off)
.u(window)
.f(scale)
.f(-1.0e30)
.f(0.5)
.bytes();
steps.push(Step::Gpu {
kernel: "attention",
push,
groups: groups1d(batch * nh * q_s, 64),
});
}
Op::Transpose { perm } => {
let x = node.inputs[0];
let xd = dims(graph, x);
let od = dims(graph, out);
let in_str = contig_strides(&xd);
let out_str = contig_strides(&od);
let rank = od.len();
let mut shape = [1u32; 6];
let mut istr = [0u32; 6];
let mut ostr = [0u32; 6];
for ax in 0..rank {
shape[ax] = od[ax] as u32;
istr[ax] = in_str[perm[ax]] as u32;
ostr[ax] = out_str[ax] as u32;
}
let n = numel(&od);
let push = Push::default()
.u(n as u32)
.u(rank as u32)
.u(off(x))
.u(off(out))
.us(&shape)
.us(&istr)
.us(&ostr)
.bytes();
steps.push(Step::Gpu {
kernel: "reindex",
push,
groups: groups1d(n, 256),
});
}
Op::Narrow { axis, start, .. } => {
let x = node.inputs[0];
let xd = dims(graph, x);
let od = dims(graph, out);
let in_str = contig_strides(&xd);
let out_str = contig_strides(&od);
let rank = od.len();
let mut shape = [1u32; 6];
let mut istr = [0u32; 6];
let mut ostr = [0u32; 6];
for ax in 0..rank {
shape[ax] = od[ax] as u32;
istr[ax] = in_str[ax] as u32;
ostr[ax] = out_str[ax] as u32;
}
let in_off = off(x) + (*start * in_str[*axis]) as u32;
let n = numel(&od);
let push = Push::default()
.u(n as u32)
.u(rank as u32)
.u(in_off)
.u(off(out))
.us(&shape)
.us(&istr)
.us(&ostr)
.bytes();
steps.push(Step::Gpu {
kernel: "reindex",
push,
groups: groups1d(n, 256),
});
}
Op::Expand { .. } => {
let x = node.inputs[0];
let xd = dims(graph, x);
let od = dims(graph, out);
let rank = od.len();
let pad = rank - xd.len();
let in_str_full = contig_strides(&xd);
let out_str = contig_strides(&od);
let mut shape = [1u32; 6];
let mut istr = [0u32; 6];
let mut ostr = [0u32; 6];
for ax in 0..rank {
shape[ax] = od[ax] as u32;
ostr[ax] = out_str[ax] as u32;
if ax < pad {
istr[ax] = 0;
} else {
let xi = ax - pad;
istr[ax] = if xd[xi] == 1 && od[ax] != 1 {
0
} else {
in_str_full[xi] as u32
};
}
}
let n = numel(&od);
let push = Push::default()
.u(n as u32)
.u(rank as u32)
.u(off(x))
.u(off(out))
.us(&shape)
.us(&istr)
.us(&ostr)
.bytes();
steps.push(Step::Gpu {
kernel: "reindex",
push,
groups: groups1d(n, 256),
});
}
Op::Concat { axis } => {
let od = dims(graph, out);
let out_str = contig_strides(&od);
let rank = od.len();
let mut axis_cursor = 0usize;
for &inp in &node.inputs {
let id_dims = dims(graph, inp);
let in_str = contig_strides(&id_dims);
let mut shape = [1u32; 6];
let mut istr = [0u32; 6];
let mut ostr = [0u32; 6];
for ax in 0..rank {
shape[ax] = *id_dims.get(ax).unwrap_or(&1) as u32;
istr[ax] = *in_str.get(ax).unwrap_or(&0) as u32;
ostr[ax] = out_str[ax] as u32;
}
let out_off = off(out) + (axis_cursor * out_str[*axis]) as u32;
let n = numel(&id_dims);
let push = Push::default()
.u(n as u32)
.u(rank as u32)
.u(off(inp))
.u(out_off)
.us(&shape)
.us(&istr)
.us(&ostr)
.bytes();
steps.push(Step::Gpu {
kernel: "reindex",
push,
groups: groups1d(n, 256),
});
axis_cursor += *id_dims.get(*axis).unwrap_or(&1);
}
}
Op::Gather { axis } => {
let data = node.inputs[0];
let idx = node.inputs[1];
let dd = dims(graph, data);
let ax = *axis;
let out_outer = numel(&dd[..ax]);
let axis_dim = dd[ax];
let out_inner = numel(&dd[ax + 1..]);
let n_idx = numel(&dims(graph, idx));
let total = out_outer * n_idx * out_inner;
let push = Push::default()
.u(out_outer as u32)
.u(n_idx as u32)
.u(out_inner as u32)
.u(axis_dim as u32)
.u(off(data))
.u(off(idx))
.u(off(out))
.bytes();
steps.push(Step::Gpu {
kernel: "gather",
push,
groups: groups1d(total, 256),
});
}
Op::Cumsum { axis, exclusive } => {
let x = node.inputs[0];
let xd = dims(graph, x);
let ax = norm_axis(*axis, xd.len());
debug_assert_eq!(ax, xd.len().saturating_sub(1), "cumsum expects last axis");
let cols = *xd.get(ax).unwrap_or(&1);
let rows = numel(&xd) / cols.max(1);
let push = Push::default()
.u(rows as u32)
.u(cols as u32)
.u(off(x))
.u(off(out))
.u(if *exclusive { 1 } else { 0 })
.bytes();
steps.push(Step::Gpu {
kernel: "cumsum",
push,
groups: groups1d(rows, 64),
});
}
Op::Reverse { axes } => {
    // The kernel takes up to 6 dims: extents (padded with 1) and a 0/1 flip
    // flag per axis.
    let src = node.inputs[0];
    let src_dims = dims(graph, src);
    let rank = src_dims.len();
    let mut extent = [1u32; 6];
    let mut flip = [0u32; 6];
    for (d, &len) in src_dims.iter().enumerate() {
        extent[d] = len as u32;
        flip[d] = u32::from(axes.contains(&d));
    }
    let count = numel(&src_dims);
    let push = Push::default()
        .u(count as u32)
        .u(rank as u32)
        .u(off(src))
        .u(off(out))
        .us(&extent)
        .us(&flip)
        .bytes();
    steps.push(Step::Gpu {
        kernel: "reverse",
        push,
        groups: groups1d(count, 256),
    });
}
Op::ArgMax { axis, .. } | Op::ArgMin { axis, .. } => {
    // Shared "argreduce" kernel. The input is viewed as [outer, axis_len, inner],
    // with one reduction per (outer, inner) pair.
    let src = node.inputs[0];
    let src_dims = dims(graph, src);
    // Out-of-range axes are clamped to the last dimension.
    let ax = (*axis).min(src_dims.len().saturating_sub(1));
    let axis_len = src_dims[ax];
    let outer = numel(&src_dims[..ax]);
    let inner = numel(&src_dims[ax + 1..]);
    // Kernel mode selector: 0 = argmax, 1 = argmin.
    let mode: u32 = match node.op {
        Op::ArgMax { .. } => 0,
        _ => 1,
    };
    let push = Push::default()
        .u(outer as u32)
        .u(axis_len as u32)
        .u(inner as u32)
        .u(off(src))
        .u(off(out))
        .u(mode)
        .bytes();
    steps.push(Step::Gpu {
        kernel: "argreduce",
        push,
        groups: groups1d(outer * inner, 256),
    });
}
Op::LayerNorm2d { eps } => {
    // The input is read as 4-D [N, C, H, W]. The dispatch covers N * H * W
    // positions (64 per group); presumably each position is normalized across
    // its C channels using gamma/beta.
    let [x, gamma, beta] = [node.inputs[0], node.inputs[1], node.inputs[2]];
    let shape = dims(graph, x);
    let (batch, channels) = (shape[0], shape[1]);
    let spatial = shape[2] * shape[3];
    let positions = batch * spatial;
    let push = Push::default()
        .u(positions as u32)
        .u(channels as u32)
        .u(spatial as u32)
        .u(off(x))
        .u(off(gamma))
        .u(off(beta))
        .u(off(out))
        .f(*eps)
        .bytes();
    steps.push(Step::Gpu {
        kernel: "layernorm2d",
        push,
        groups: groups1d(positions, 64),
    });
}
Op::Pool {
    kind,
    kernel_size,
    stride,
    padding,
} => {
    // 2-D pooling over a 4-D input. The dispatch covers every output element
    // (N * C * OH * OW, 64 per group). The output extent is taken from the
    // scheduled output shape rather than recomputed here.
    let src = node.inputs[0];
    let in_dims = dims(graph, src);
    let out_dims = dims(graph, out);
    let [nn, cc, hh, ww] = [in_dims[0], in_dims[1], in_dims[2], in_dims[3]];
    let [oh, ow] = [out_dims[2], out_dims[3]];
    // Window geometry in push-constant order: kh, kw, sh, sw, ph, pw.
    let window = [
        kernel_size[0] as u32,
        kernel_size[1] as u32,
        stride[0] as u32,
        stride[1] as u32,
        padding[0] as u32,
        padding[1] as u32,
    ];
    let kind_id = reduce_id(*kind);
    let push = Push::default()
        .us(&[nn as u32, cc as u32, hh as u32, ww as u32])
        .us(&[oh as u32, ow as u32])
        .us(&window)
        .u(off(src))
        .u(off(out))
        .u(kind_id)
        .bytes();
    steps.push(Step::Gpu {
        kernel: "pool2d",
        push,
        groups: groups1d(nn * cc * oh * ow, 64),
    });
}
Op::ResizeNearest2x => {
    // Nearest-neighbour 2x upsample of a 4-D input. The dispatch covers
    // 4x the input element count, i.e. presumably one work item per
    // (2H x 2W) output element.
    let src = node.inputs[0];
    let d = dims(graph, src);
    let nchw = [d[0] as u32, d[1] as u32, d[2] as u32, d[3] as u32];
    let out_elems = d[0] * d[1] * d[2] * 4 * d[3];
    let push = Push::default()
        .us(&nchw)
        .u(off(src))
        .u(off(out))
        .bytes();
    steps.push(Step::Gpu {
        kernel: "resize2x",
        push,
        groups: groups1d(out_elems, 256),
    });
}
Op::GroupedMatMul => {
    // M and K come from the last two input dims, N from the weight's last dim.
    // The dispatch is a 2-D grid of 16x16 tiles (x over N, y over M).
    // `idx` presumably selects the weight group per row — TODO confirm in the shader.
    let (input, weight, idx) = (node.inputs[0], node.inputs[1], node.inputs[2]);
    let in_dims = dims(graph, input);
    let w_dims = dims(graph, weight);
    let m = in_dims[in_dims.len() - 2];
    let k = in_dims[in_dims.len() - 1];
    let n = w_dims[w_dims.len() - 1];
    let push = Push::default()
        .u(m as u32)
        .u(k as u32)
        .u(n as u32)
        .u(off(input))
        .u(off(weight))
        .u(off(idx))
        .u(off(out))
        .bytes();
    let tile_grid = (ceil_div(n, 16), ceil_div(m, 16), 1);
    steps.push(Step::Gpu {
        kernel: "grouped_matmul",
        push,
        groups: tile_grid,
    });
}
Op::Conv {
    kernel_size,
    stride,
    padding,
    dilation,
    groups: conv_groups,
} => {
    // Direct 2-D convolution. The dispatch covers N * C_out * OH * OW work
    // items (64 per group). The output extent is taken from the scheduled
    // output shape.
    let src = node.inputs[0];
    let weight = node.inputs[1];
    let bias = node.inputs.get(2).copied();
    let has_bias = bias.is_some();
    let x_dims = dims(graph, src);
    let y_dims = dims(graph, out);
    let (batch, c_in, h_in, w_in) = (x_dims[0], x_dims[1], x_dims[2], x_dims[3]);
    let (c_out, h_out, w_out) = (y_dims[1], y_dims[2], y_dims[3]);
    let push = Push::default()
        .us(&[batch as u32, c_in as u32, h_in as u32, w_in as u32])
        .us(&[c_out as u32, kernel_size[0] as u32, kernel_size[1] as u32])
        .us(&[h_out as u32, w_out as u32])
        .us(&[
            stride[0] as u32,
            stride[1] as u32,
            padding[0] as u32,
            padding[1] as u32,
            dilation[0] as u32,
            dilation[1] as u32,
        ])
        .u(*conv_groups as u32)
        .u(u32::from(has_bias))
        .u(off(src))
        .u(off(weight))
        // Without a bias input, the weight offset fills the slot as a
        // placeholder; presumably ignored by the kernel when has_bias == 0.
        .u(off(bias.unwrap_or(weight)))
        .u(off(out))
        .bytes();
    steps.push(Step::Gpu {
        kernel: "conv2d",
        push,
        groups: groups1d(batch * c_out * h_out * w_out, 64),
    });
}
Op::SelectiveScan { state_size } => {
    // Inputs, in order: x, delta, A, B, C. `x` is read as [batch, seq, hidden].
    // The dispatch covers batch * hidden work items (64 per group);
    // presumably each one walks the sequence dimension serially.
    let [x, delta, a, b, c] = [
        node.inputs[0],
        node.inputs[1],
        node.inputs[2],
        node.inputs[3],
        node.inputs[4],
    ];
    let x_dims = dims(graph, x);
    let (batch, seq, hidden) = (x_dims[0], x_dims[1], x_dims[2]);
    let push = Push::default()
        .u(batch as u32)
        .u(seq as u32)
        .u(hidden as u32)
        .u(*state_size as u32)
        .u(off(x))
        .u(off(delta))
        .u(off(a))
        .u(off(b))
        .u(off(c))
        .u(off(out))
        .bytes();
    steps.push(Step::Gpu {
        kernel: "selective_scan",
        push,
        groups: groups1d(batch * hidden, 64),
    });
}
Op::Im2Col {
    kernel_size,
    stride,
    padding,
    dilation,
} => {
    // Unfolds sliding windows of a 4-D input into columns. The dispatch covers
    // N * HO * WO * C * KH * KW elements (256 per group).
    let x = node.inputs[0];
    let xd = dims(graph, x);
    let (nn, cin, hh, ww) = (xd[0], xd[1], xd[2], xd[3]);
    let (kh, kw) = (kernel_size[0], kernel_size[1]);
    let (sh, sw) = (stride[0], stride[1]);
    let (ph, pw) = (padding[0], padding[1]);
    let (dh, dw) = (dilation[0], dilation[1]);
    // Output extent along one spatial axis:
    //   floor((size + 2*pad - (dil*(k-1) + 1)) / stride) + 1
    // This uses checked subtraction. With plain `usize` arithmetic, a kernel
    // of 0 or a window larger than the padded input wraps in release builds.
    // That yields a huge extent, truncated push constants and a bogus,
    // oversized GPU dispatch. Valid shapes produce exactly the same result.
    let out_extent = |size: usize, pad: usize, k: usize, dil: usize, s: usize, axis: &str| {
        let span = k
            .checked_sub(1)
            .map(|k_minus_1| dil * k_minus_1 + 1)
            .and_then(|eff| (size + 2 * pad).checked_sub(eff))
            .unwrap_or_else(|| {
                panic!(
                    "rlx-vulkan: im2col {axis}: window (kernel {k}, dilation {dil}) \
                     does not fit input extent {size} with padding {pad}"
                )
            });
        span / s + 1
    };
    let ho = out_extent(hh, ph, kh, dh, sh, "height");
    let wo = out_extent(ww, pw, kw, dw, sw, "width");
    let push = Push::default()
        .us(&[nn as u32, cin as u32, hh as u32, ww as u32])
        .us(&[ho as u32, wo as u32])
        .us(&[
            kh as u32, kw as u32, sh as u32, sw as u32, ph as u32, pw as u32,
            dh as u32, dw as u32,
        ])
        .u(off(x))
        .u(off(out))
        .bytes();
    steps.push(Step::Gpu {
        kernel: "im2col",
        push,
        groups: groups1d(nn * ho * wo * cin * kh * kw, 256),
    });
}
Op::ScatterAdd => {
    // `updates` is read as [num_updates, ...row] and the output as
    // [out_rows, ...row]. The dispatch covers out_rows * row output elements;
    // presumably each one accumulates all updates whose index targets its row.
    let (updates, indices) = (node.inputs[0], node.inputs[1]);
    let upd_dims = dims(graph, updates);
    let out_dims = dims(graph, out);
    let num_updates = upd_dims[0];
    let row = numel(&upd_dims[1..]);
    let out_rows = out_dims[0];
    let push = Push::default()
        .u(out_rows as u32)
        .u(row as u32)
        .u(num_updates as u32)
        .u(off(updates))
        .u(off(indices))
        .u(off(out))
        .bytes();
    steps.push(Step::Gpu {
        kernel: "scatter_add",
        push,
        groups: groups1d(out_rows * row, 256),
    });
}
Op::TopK { k } => {
    // Top-k along the last axis. Leading dims are flattened into rows, with
    // one row per work item (64 per group).
    let src = node.inputs[0];
    let src_dims = dims(graph, src);
    let row_len = src_dims.last().copied().unwrap_or(1);
    let rows = numel(&src_dims) / row_len.max(1);
    let push = Push::default()
        .u(rows as u32)
        .u(row_len as u32)
        .u(*k as u32)
        .u(off(src))
        .u(off(out))
        .bytes();
    steps.push(Step::Gpu {
        kernel: "topk",
        push,
        groups: groups1d(rows, 64),
    });
}
Op::DequantMatMul { scheme } => {
    use rlx_ir::quant::QuantScheme;
    // Only GGUF Q4_K / Q6_K single-row products (M == 1) with K a multiple
    // of 256 run on the GPU. 256 is presumably the K-quant super-block size.
    // Every other case falls back to a host step.
    let x = node.inputs[0];
    let x_dims = dims(graph, x);
    let y_dims = dims(graph, out);
    let n = y_dims.last().copied().unwrap_or(1);
    let m = numel(&y_dims) / n.max(1);
    let k = numel(&x_dims) / m.max(1);
    // Kernel scheme selector: 0 = Q4_K, 1 = Q6_K.
    let gpu_scheme = match scheme {
        QuantScheme::GgufQ4K => Some(0u32),
        QuantScheme::GgufQ6K => Some(1u32),
        _ => None,
    };
    let gemv_shape = m == 1 && k.is_multiple_of(256) && n >= 1;
    if let Some(sc) = gpu_scheme.filter(|_| gemv_shape) {
        let w = node.inputs[1];
        let push = Push::default()
            .u(n as u32)
            .u(k as u32)
            .u(off(x))
            .u(off(w))
            .u(off(out))
            .u(sc)
            .bytes();
        steps.push(Step::Gpu {
            kernel: "dequant_matmul",
            push,
            groups: groups1d(n, 64),
        });
    } else {
        steps.push(Step::Host {
            op: node.op.clone(),
            out: node.id,
            out_shape: node.shape.clone(),
            inputs: node.inputs.clone(),
        });
    }
}
// Ops listed in `is_host_fallback` have no GPU kernel. They are recorded as a
// host step carrying a copy of the op, its output id/shape and its inputs.
op if is_host_fallback(op) => steps.push(Step::Host {
    op: op.clone(),
    out: node.id,
    out_shape: node.shape.clone(),
    inputs: node.inputs.clone(),
}),
// Catch-all: an op with neither a GPU lowering above nor a host fallback.
// The panic message states that such ops should have been rejected at
// legalize, so reaching this arm is an internal invariant violation rather
// than a user error. Presumably legalize checks op kinds against
// SUPPORTED_OPS — confirm that list stays in sync with the arms above.
other => panic!(
"rlx-vulkan: op {:?} reached the scheduler but has no kernel \
(should have been rejected at legalize). Pin this graph to Device::Cpu.",
other.kind()
),
}
// Record one dependency entry for every step this node just emitted, so that
// `deps` stays index-aligned with `steps`. Reads are the arena element offsets
// of the node's inputs that live in the arena. The write is the output's
// offset, or 0 when the output has no arena slot.
let added = steps.len() - before;
if added > 0 {
    let reads: Vec<u32> = node
        .inputs
        .iter()
        .copied()
        .filter(|&id| arena.has(id))
        .map(|id| arena.elem_offset(id))
        .collect();
    let write = match arena.has(out) {
        true => arena.elem_offset(out),
        false => 0,
    };
    for _ in 0..added {
        deps.push(StepDep {
            write,
            reads: reads.clone(),
        });
    }
}
}
(steps, deps)
}