1use crate::buffer::Arena;
27use crate::device::vulkan_device;
28use crate::kernels::kernels;
29use ash::vk;
30use rlx_compile::memory::{BufferSlot, MemoryPlan};
31use rlx_ir::op::{Activation, BinaryOp, CmpOp, MaskKind, ReduceOp, RopeStyle};
32use rlx_ir::{DType, Graph, NodeId, Op, RngOptions};
33use std::collections::{HashMap, HashSet};
34
/// Op kinds this backend declares as supported.
///
/// The list is handed to `rlx_opt::legalize_or_rewrite_for_backend` in
/// `VulkanExecutable::compile_rng`; several of the kinds near the end of the
/// list also appear in `is_host_fallback`.
pub const SUPPORTED_OPS: &[rlx_ir::OpKind] = {
    use rlx_ir::OpKind::*;
    &[
        Input,
        Param,
        Constant,
        Cast,
        StopGradient,
        Reshape, Binary,
        Compare,
        Where,
        Activation, MatMul,
        Reduce,
        Softmax, LayerNorm,
        RmsNorm,
        LayerNorm2d, Rope,
        Attention, FusedAttentionBlock,
        Transpose,
        Narrow,
        Concat,
        Expand,
        Gather,
        Cumsum,
        Reverse, ArgMax,
        ArgMin,
        Pool,
        ResizeNearest2x,
        Conv, GroupedMatMul, SelectiveScan, Im2Col,
        ScatterAdd,
        TopK, Lstm,
        Gru,
        Rnn,
        Mamba2,
        GatedDeltaNet,
        ConvTranspose2d,
        Fft,
        DequantMatMul,
        DequantGroupedMatMul,
        DequantMoEWeights, RngNormal,
        RngUniform,
        Sample, ]
};
96
97fn is_host_fallback(op: &Op) -> bool {
103 matches!(
104 op,
105 Op::Lstm { .. }
106 | Op::Gru { .. }
107 | Op::Rnn { .. }
108 | Op::Mamba2 { .. }
109 | Op::GatedDeltaNet { .. }
110 | Op::ConvTranspose2d { .. }
111 | Op::Fft { .. }
112 | Op::DequantGroupedMatMul { .. }
113 | Op::DequantMoEWeights { .. }
114 | Op::RngNormal { .. }
115 | Op::RngUniform { .. }
116 | Op::Sample { .. }
117 )
118}
119
/// One entry of the execution schedule produced by `build_schedule`.
#[derive(Clone)]
enum Step {
    /// A single compute dispatch.
    Gpu {
        /// Kernel name, resolved to a pipeline via `Kernels::pipeline`.
        kernel: &'static str,
        /// Raw push-constant bytes handed to `cmd_push_constants`.
        push: Vec<u8>,
        /// Workgroup counts (x, y, z) handed to `cmd_dispatch`.
        groups: (u32, u32, u32),
    },
    /// An op evaluated through `crate::host::eval`.
    Host {
        /// The op to evaluate.
        op: Op,
        /// Node whose arena slot receives the result.
        out: NodeId,
        /// Shape passed to the host evaluator for the result.
        out_shape: rlx_ir::Shape,
        /// Nodes read back from the arena as evaluator inputs.
        inputs: Vec<NodeId>,
    },
}
136
/// One unit of the cached execution plan built by `record_segments`.
enum Segment {
    /// A pre-recorded command buffer covering a run of consecutive GPU steps.
    Gpu(vk::CommandBuffer),
    /// A host-evaluated op; the fields mirror `Step::Host`.
    Host {
        /// The op to evaluate.
        op: Op,
        /// Node whose arena slot receives the result.
        out: NodeId,
        /// Shape passed to the host evaluator for the result.
        out_shape: rlx_ir::Shape,
        /// Nodes read back from the arena as evaluator inputs.
        inputs: Vec<NodeId>,
    },
}
154
/// A graph compiled for the Vulkan backend, together with its arena,
/// schedule and the state used to feed outputs back into inputs.
pub struct VulkanExecutable {
    /// The graph after the passes run in `compile_rng`.
    graph: Graph,
    /// Buffer arena created from the memory plan.
    arena: Arena,
    /// Ordered GPU/host steps from `build_schedule`.
    schedule: Vec<Step>,
    /// Pre-recorded segments; empty when `cached` is false.
    segments: Vec<Segment>,
    /// Fence used for cached submits; `vk::Fence::null()` when `cached` is false.
    fence: vk::Fence,
    /// False only when `RLX_VULKAN_NOCACHE=1` was set at build time.
    cached: bool,
    /// `Op::Input` name -> node id.
    input_ids: HashMap<String, NodeId>,
    /// `Op::Param` name -> node id.
    param_ids: HashMap<String, NodeId>,
    /// Graph output node ids, in graph order.
    output_ids: Vec<NodeId>,
    /// Dtype of each output node, parallel to `output_ids`.
    output_dtypes: Vec<DType>,
    /// Descriptor pool owning `desc_set`; destroyed in `Drop`.
    desc_pool: vk::DescriptorPool,
    /// Descriptor set with the arena buffer at binding 0.
    desc_set: vk::DescriptorSet,
    /// RNG options stored by `set_rng` and returned by `rng`.
    rng: RngOptions,
    /// Optional `(actual, upper)` pair used to limit how much of a fed-back
    /// output is copied into its input.
    active_extent: Option<(usize, usize)>,
    /// Host-side copy of each bound handle, keyed by input name. An empty
    /// vector is stored once the value lives only in the arena.
    gpu_handles: HashMap<String, Vec<f32>>,
    /// Handle name -> output index copied back into that input after a run.
    gpu_handle_feeds: HashMap<String, usize>,
    /// Handles whose current value is already in the arena, so `run` does
    /// not re-upload the host copy.
    gpu_handle_resident: HashSet<String>,
    /// Handle name -> output index used by `feed_kv_row`.
    kv_row_feeds: HashMap<String, usize>,
}
193
// SAFETY: NOTE(review): this is a manual assertion. The struct holds raw
// Vulkan handles and an `Arena`; soundness depends on those being usable from
// whichever thread owns the executable — confirm against `Arena` and the
// device wrapper.
unsafe impl Send for VulkanExecutable {}
195
196fn plan_f32_uniform(graph: &Graph, align: usize) -> MemoryPlan {
199 let mut assignments: HashMap<NodeId, BufferSlot> = HashMap::new();
200 let mut schedule = Vec::with_capacity(graph.nodes().len());
201 let mut cursor = 0usize;
202 for node in graph.nodes() {
203 if matches!(
204 node.op,
205 Op::Reshape { .. } | Op::Cast { .. } | Op::StopGradient
206 ) {
207 if let Some(in_id) = node.inputs.first()
208 && let Some(slot) = assignments.get(in_id)
209 {
210 let aliased = slot.clone();
211 assignments.insert(node.id, aliased);
212 schedule.push(node.id);
213 continue;
214 }
215 }
216 let elems = node.shape.num_elements().unwrap_or(0);
217 let elem_size = match node.shape.dtype() {
225 DType::U8 | DType::I8 => 1,
226 _ => 4,
227 };
228 let bytes = (elems * elem_size).max(4);
229 let aligned = bytes.div_ceil(align) * align;
230 assignments.insert(
231 node.id,
232 BufferSlot {
233 offset: cursor,
234 size: aligned,
235 },
236 );
237 schedule.push(node.id);
238 cursor += aligned;
239 }
240 MemoryPlan {
241 arena_size: cursor.max(align),
242 assignments,
243 schedule,
244 }
245}
246
247fn dims(graph: &Graph, id: NodeId) -> Vec<usize> {
250 graph
251 .node(id)
252 .shape
253 .dims()
254 .iter()
255 .map(|d| match d {
256 rlx_ir::Dim::Static(s) => *s,
257 _ => 0,
258 })
259 .collect()
260}
261
/// Product of all extents in `d`; an empty slice (a scalar) yields 1.
fn numel(d: &[usize]) -> usize {
    // The product of an empty iterator is already 1, so scalars need no
    // special case.
    d.iter().copied().product()
}
267
/// Row-major (contiguous) element strides for a shape `d`.
///
/// The last stride is 1 and each earlier stride is the product of all later
/// extents. An empty shape yields an empty vector.
fn contig_strides(d: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; d.len()];
    let mut running = 1usize;
    // Walk from the second-to-last stride backwards, pairing each with the
    // extent one position to its right. `d[0]` is never multiplied in.
    for (slot, &extent) in strides.iter_mut().rev().skip(1).zip(d.iter().rev()) {
        running *= extent;
        *slot = running;
    }
    strides
}
276
/// Converts a possibly negative axis into an index for a tensor of `rank`.
///
/// Negative axes count from the end and are clamped at 0; non-negative axes
/// are clamped to the last valid index (0 when `rank` is 0).
fn norm_axis(axis: i32, rank: usize) -> usize {
    match axis {
        a if a < 0 => {
            let wrapped = rank as i32 + a;
            if wrapped > 0 { wrapped as usize } else { 0 }
        }
        a => {
            let last = rank.saturating_sub(1);
            std::cmp::min(a as usize, last)
        }
    }
}
284
/// Builder for a push-constant block made of 32-bit words.
#[derive(Default)]
struct Push {
    words: Vec<u32>,
}

impl Push {
    /// Appends one unsigned word.
    fn u(self, v: u32) -> Self {
        let Push { mut words } = self;
        words.push(v);
        Push { words }
    }

    /// Appends the raw bit pattern of a float as one word.
    fn f(self, v: f32) -> Self {
        self.u(v.to_bits())
    }

    /// Appends every word of `vs` in order.
    fn us(self, vs: &[u32]) -> Self {
        let Push { mut words } = self;
        words.extend(vs.iter().copied());
        Push { words }
    }

    /// Serializes the accumulated words as little-endian bytes.
    fn bytes(self) -> Vec<u8> {
        self.words
            .into_iter()
            .flat_map(u32::to_le_bytes)
            .collect()
    }
}
312
/// Ceiling division of `n` by `d`, computed in 64 bits and then truncated to
/// `u32`.
///
/// # Panics
/// Panics if `d` is 0.
fn ceil_div(n: usize, d: u32) -> u32 {
    let total = n as u64;
    let step = u64::from(d);
    total.div_ceil(step) as u32
}
316
/// Returns `true` when both `m` and `n` are multiples of 16. The `k`
/// dimension is accepted but not inspected.
fn coop_eligible(m: usize, _k: usize, n: usize) -> bool {
    const TILE: usize = 16;
    m % TILE == 0 && n % TILE == 0
}
325
326fn matmul_kernel(m: usize, k: usize, n: usize) -> &'static str {
338 let dev = vulkan_device();
339 let portability = dev.map(|d| d.portability).unwrap_or(false);
340 let coop = dev.map(|d| d.coop_matmul).unwrap_or(false);
341 match std::env::var("RLX_VULKAN_MATMUL").ok().as_deref() {
342 Some("scalar") => "matmul",
343 Some("tiled") => "matmul_tiled",
344 Some("coop") if coop && coop_eligible(m, k, n) => "matmul_coop",
345 Some("coop") => "matmul_tiled",
346 _ if portability => "matmul",
347 _ => "matmul_tiled",
348 }
349}
350
351fn groups1d(n: usize, local: u32) -> (u32, u32, u32) {
356 (ceil_div(n, local).max(1), 1, 1)
357}
358
/// Numeric id pushed to the `unary` kernel for each activation.
///
/// NOTE(review): presumably these values must match the kernel's own
/// dispatch table — the shader source is not visible here; confirm before
/// renumbering.
fn act_id(a: Activation) -> u32 {
    match a {
        Activation::Gelu => 0,
        Activation::GeluApprox => 1,
        Activation::Silu => 2,
        Activation::Relu => 3,
        Activation::Sigmoid => 4,
        Activation::Tanh => 5,
        Activation::Exp => 6,
        Activation::Log => 7,
        Activation::Sqrt => 8,
        Activation::Rsqrt => 9,
        Activation::Neg => 10,
        Activation::Abs => 11,
        Activation::Sin => 12,
        Activation::Cos => 13,
        Activation::Tan => 14,
        Activation::Atan => 15,
        Activation::Round => 16,
    }
}
380
/// Numeric id pushed to the `binary` kernel for each binary op.
///
/// NOTE(review): presumably mirrors the kernel-side numbering — confirm
/// before renumbering.
fn binop_id(op: BinaryOp) -> u32 {
    match op {
        BinaryOp::Add => 0,
        BinaryOp::Sub => 1,
        BinaryOp::Mul => 2,
        BinaryOp::Div => 3,
        BinaryOp::Max => 4,
        BinaryOp::Min => 5,
        BinaryOp::Pow => 6,
    }
}
392
/// Numeric id pushed to the `compare` kernel for each comparison op.
///
/// NOTE(review): presumably mirrors the kernel-side numbering — confirm
/// before renumbering.
fn cmp_id(op: CmpOp) -> u32 {
    match op {
        CmpOp::Eq => 0,
        CmpOp::Ne => 1,
        CmpOp::Lt => 2,
        CmpOp::Le => 3,
        CmpOp::Gt => 4,
        CmpOp::Ge => 5,
    }
}
403
/// Numeric id pushed to the `reduce` kernel for each reduction op.
///
/// NOTE(review): presumably mirrors the kernel-side numbering — confirm
/// before renumbering.
fn reduce_id(op: ReduceOp) -> u32 {
    match op {
        ReduceOp::Sum => 0,
        ReduceOp::Mean => 1,
        ReduceOp::Max => 2,
        ReduceOp::Min => 3,
        ReduceOp::Prod => 4,
    }
}
413
414impl VulkanExecutable {
415 pub fn compile(graph: Graph) -> Self {
416 Self::compile_rng(graph, RngOptions::default())
417 }
418
419 pub fn compile_rng(graph: Graph, rng: RngOptions) -> Self {
423 use rlx_opt::pass::Pass as _;
424
425 let graph = rlx_opt::LowerControlFlow.run(graph);
426 let graph = rlx_opt::unfuse::unfuse_attention_block(graph);
431 let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, SUPPORTED_OPS)
432 .unwrap_or_else(|errs| panic!("{}", rlx_opt::format_legalize_error("vulkan", &errs)));
433 let graph = rlx_opt::LegalizeBroadcast.run(graph);
436
437 Self::build(graph, rng)
438 }
439
440 fn build(graph: Graph, rng: RngOptions) -> Self {
441 let dev = vulkan_device().expect("rlx-vulkan: no device");
442 let kern = kernels().expect("rlx-vulkan: no kernels");
443
444 let plan = plan_f32_uniform(&graph, 16);
445 let arena = Arena::from_plan(&plan);
446
447 for node in graph.nodes() {
449 if let Op::Constant { data } = &node.op
450 && arena.has(node.id)
451 && !data.is_empty()
452 {
453 let f = widen_const_to_f32(data, node.shape.dtype());
454 arena.write_f32(node.id, &f);
455 }
456 }
457
458 let mut input_ids = HashMap::new();
459 let mut param_ids = HashMap::new();
460 for node in graph.nodes() {
461 match &node.op {
462 Op::Input { name } => {
463 input_ids.insert(name.clone(), node.id);
464 }
465 Op::Param { name } => {
466 param_ids.insert(name.clone(), node.id);
467 }
468 _ => {}
469 }
470 }
471
472 let output_ids = graph.outputs.clone();
473 let output_dtypes = output_ids
474 .iter()
475 .map(|&id| graph.node(id).shape.dtype())
476 .collect();
477
478 let (schedule, deps) = build_schedule(&graph, &arena);
479
480 let pool_sizes = [vk::DescriptorPoolSize::default()
482 .ty(vk::DescriptorType::STORAGE_BUFFER)
483 .descriptor_count(1)];
484 let desc_pool = unsafe {
485 dev.device.create_descriptor_pool(
486 &vk::DescriptorPoolCreateInfo::default()
487 .max_sets(1)
488 .pool_sizes(&pool_sizes),
489 None,
490 )
491 }
492 .expect("vk descriptor_pool");
493 let set_layouts = [kern.dsl];
494 let desc_set = unsafe {
495 dev.device.allocate_descriptor_sets(
496 &vk::DescriptorSetAllocateInfo::default()
497 .descriptor_pool(desc_pool)
498 .set_layouts(&set_layouts),
499 )
500 }
501 .expect("vk descriptor_set")[0];
502 let buf_info = [vk::DescriptorBufferInfo::default()
503 .buffer(arena.buffer)
504 .offset(0)
505 .range(vk::WHOLE_SIZE)];
506 let write = vk::WriteDescriptorSet::default()
507 .dst_set(desc_set)
508 .dst_binding(0)
509 .descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
510 .buffer_info(&buf_info);
511 unsafe { dev.device.update_descriptor_sets(&[write], &[]) };
512
513 let cached = std::env::var("RLX_VULKAN_NOCACHE").as_deref() != Ok("1");
520 let (segments, fence) = if cached {
521 let segs = record_segments(dev, kern, desc_set, &schedule, &deps);
522 (segs, dev.create_reusable_fence())
523 } else {
524 (Vec::new(), vk::Fence::null())
525 };
526
527 if std::env::var_os("RLX_VULKAN_DEBUG").is_some() {
528 let gpu = schedule
529 .iter()
530 .filter(|s| matches!(s, Step::Gpu { .. }))
531 .count();
532 let host = schedule.len() - gpu;
533 let gpu_segs = segments
534 .iter()
535 .filter(|s| matches!(s, Segment::Gpu(_)))
536 .count();
537 let mut hist: HashMap<&'static str, usize> = HashMap::new();
538 for s in &schedule {
539 if let Step::Gpu { kernel, .. } = s {
540 *hist.entry(kernel).or_default() += 1;
541 }
542 }
543 let mut by_count: Vec<_> = hist.into_iter().collect();
544 by_count.sort_by_key(|&(_, c)| std::cmp::Reverse(c));
545 eprintln!(
546 "[rlx-vulkan] schedule: {gpu} gpu dispatches, {host} host ops; \
547 cached={cached} ({gpu_segs} gpu submit(s)/run)"
548 );
549 eprintln!("[rlx-vulkan] dispatch histogram: {by_count:?}");
550 }
551
552 Self {
553 graph,
554 arena,
555 schedule,
556 segments,
557 fence,
558 cached,
559 input_ids,
560 param_ids,
561 output_ids,
562 output_dtypes,
563 desc_pool,
564 desc_set,
565 rng,
566 active_extent: None,
567 gpu_handles: HashMap::new(),
568 gpu_handle_feeds: HashMap::new(),
569 gpu_handle_resident: HashSet::new(),
570 kv_row_feeds: HashMap::new(),
571 }
572 }
573
574 pub fn set_param(&mut self, name: &str, data: &[f32]) {
575 if let Some(&id) = self.param_ids.get(name) {
576 self.arena.write_f32(id, data);
577 }
578 }
579
580 pub fn set_param_bytes(&mut self, name: &str, data: &[u8]) {
583 if let Some(&id) = self.param_ids.get(name) {
584 self.arena.write_bytes(id, data);
585 }
586 }
587
588 pub fn output_dtypes(&self) -> Vec<DType> {
589 self.output_dtypes.clone()
590 }
591
592 pub fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
593 self.active_extent = extent;
594 }
595
596 pub fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
601 let Some(&id) = self.input_ids.get(name) else {
602 return false;
603 };
604 self.gpu_handle_resident.remove(name);
606 self.arena.write_f32(id, data);
607 self.gpu_handles.insert(name.to_string(), data.to_vec());
609 true
610 }
611
612 pub fn has_gpu_handle(&self, name: &str) -> bool {
613 self.gpu_handles.contains_key(name)
614 }
615
616 pub fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) {
617 self.gpu_handle_feeds
618 .insert(handle_name.to_string(), output_index);
619 }
620
621 pub fn register_kv_row_feed(&mut self, handle_name: &str, output_index: usize) {
628 self.kv_row_feeds
629 .insert(handle_name.to_string(), output_index);
630 }
631
632 pub fn feed_kv_row(&mut self, src_row: usize, dst_row: usize, row_elems: usize) {
636 let feeds: Vec<(String, usize)> = self
637 .kv_row_feeds
638 .iter()
639 .map(|(k, &v)| (k.clone(), v))
640 .collect();
641 for (name, out_idx) in feeds {
642 let Some(&out_id) = self.output_ids.get(out_idx) else {
643 continue;
644 };
645 let Some(&in_id) = self.input_ids.get(name.as_str()) else {
646 continue;
647 };
648 if in_id != out_id {
649 self.arena.copy_node_f32_range(
650 in_id,
651 dst_row * row_elems,
652 out_id,
653 src_row * row_elems,
654 row_elems,
655 );
656 }
657 self.gpu_handle_resident.insert(name.clone());
658 self.gpu_handles.insert(name.clone(), Vec::new());
659 }
660 }
661
662 pub fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
665 if let Some(&out_idx) = self.gpu_handle_feeds.get(name)
666 && let Some(&out_id) = self.output_ids.get(out_idx)
667 {
668 let n = self.graph.node(out_id).shape.num_elements().unwrap_or(0);
669 return Some(self.arena.read_f32(out_id, n));
670 }
671 if self.gpu_handle_resident.contains(name)
672 && let Some(&id) = self.input_ids.get(name)
673 {
674 let n = self.graph.node(id).shape.num_elements().unwrap_or(0);
675 return Some(self.arena.read_f32(id, n));
676 }
677 self.gpu_handles.get(name).cloned()
678 }
679
680 pub fn read_output_row(
685 &self,
686 out_idx: usize,
687 row: usize,
688 row_inner: usize,
689 ) -> Option<Vec<f32>> {
690 let id = *self.output_ids.get(out_idx)?;
691 let base = self.arena.elem_offset(id) as usize + row * row_inner;
692 Some(self.arena.read_f32_at_elem(base, row_inner))
693 }
694
695 fn propagate_gpu_handle_feeds_in_arena(&mut self) {
700 let extent = self.active_extent;
701 let feeds: Vec<(String, usize)> = self
702 .gpu_handle_feeds
703 .iter()
704 .map(|(k, &v)| (k.clone(), v))
705 .collect();
706 for (name, out_idx) in feeds {
707 let Some(&out_id) = self.output_ids.get(out_idx) else {
708 continue;
709 };
710 let Some(&in_id) = self.input_ids.get(name.as_str()) else {
711 continue;
712 };
713 if in_id != out_id {
714 let out_elems = self.graph.node(out_id).shape.num_elements().unwrap_or(0);
715 let copy_elems = match extent {
716 Some((actual, upper)) if upper > 0 => actual * (out_elems / (upper + 1)).max(1),
717 _ => out_elems,
718 };
719 self.arena
720 .copy_node_f32_prefix(in_id, out_id, copy_elems.min(out_elems));
721 }
722 self.gpu_handle_resident.insert(name.clone());
723 self.gpu_handles.insert(name.clone(), Vec::new());
725 }
726 }
727
728 fn refresh_gpu_handles_from_outputs(&mut self) {
730 let feeds: Vec<(String, usize)> = self
731 .gpu_handle_feeds
732 .iter()
733 .map(|(k, &v)| (k.clone(), v))
734 .collect();
735 for (name, out_idx) in feeds {
736 let Some(&out_id) = self.output_ids.get(out_idx) else {
737 continue;
738 };
739 let n = self.graph.node(out_id).shape.num_elements().unwrap_or(0);
740 let src = self.arena.read_f32(out_id, n);
741 self.gpu_handles.insert(name, src);
742 }
743 }
744
745 pub fn set_rng(&mut self, rng: RngOptions) {
746 self.rng = rng;
747 }
748
749 pub fn rng(&self) -> RngOptions {
750 self.rng
751 }
752
753 pub fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
754 self.run_read_outputs(inputs, None)
755 }
756
757 pub fn run_read_outputs(
758 &mut self,
759 inputs: &[(&str, &[f32])],
760 read_indices: Option<&[usize]>,
761 ) -> Vec<Vec<f32>> {
762 for (name, data) in &self.gpu_handles {
767 if self.gpu_handle_resident.contains(name) || inputs.iter().any(|(n, _)| n == name) {
768 continue;
769 }
770 if let Some(&id) = self.input_ids.get(name) {
771 self.arena.write_f32(id, data);
772 }
773 }
774 for &(name, data) in inputs {
776 if let Some(&id) = self.input_ids.get(name) {
777 self.arena.write_f32(id, data);
778 }
779 }
780
781 let dev = vulkan_device().expect("rlx-vulkan: no device");
786 let kern = kernels().expect("rlx-vulkan: no kernels");
787 let desc_set = self.desc_set;
788 let layout = kern.pipeline_layout;
789
790 if self.cached {
791 let nseg = self.segments.len();
796 for si in 0..nseg {
797 match &self.segments[si] {
798 Segment::Gpu(cmd) => {
799 let cmd = *cmd;
800 dev.submit_recorded_wait(cmd, self.fence);
801 }
802 Segment::Host {
803 op,
804 out,
805 out_shape,
806 inputs: in_ids,
807 } => {
808 let in_specs: Vec<(rlx_ir::Shape, crate::host::HostBuf)> = in_ids
809 .iter()
810 .map(|&id| {
811 let sh = self.graph.node(id).shape.clone();
812 let nn = sh.num_elements().unwrap_or(0);
813 let buf = if matches!(sh.dtype(), DType::U8 | DType::I8) {
814 crate::host::HostBuf::Bytes(self.arena.read_bytes(id, nn))
815 } else {
816 crate::host::HostBuf::F32(self.arena.read_f32(id, nn))
817 };
818 (sh, buf)
819 })
820 .collect();
821 let result = crate::host::eval(op, out_shape, &in_specs);
822 self.arena.write_f32(*out, &result);
823 }
824 }
825 }
826 return self.finish_run(read_indices);
828 }
829
830 let n = self.schedule.len();
831 let mut i = 0;
832 while i < n {
833 let start = i;
834 while i < n && matches!(self.schedule[i], Step::Gpu { .. }) {
835 i += 1;
836 }
837 if i > start {
838 let gpu = self.schedule[start..i].to_vec();
839 dev.submit_and_wait(|cmd| unsafe {
840 dev.device.cmd_bind_descriptor_sets(
841 cmd,
842 vk::PipelineBindPoint::COMPUTE,
843 layout,
844 0,
845 &[desc_set],
846 &[],
847 );
848 let barrier = vk::MemoryBarrier::default()
849 .src_access_mask(vk::AccessFlags::SHADER_WRITE)
850 .dst_access_mask(
851 vk::AccessFlags::SHADER_READ | vk::AccessFlags::SHADER_WRITE,
852 );
853 for (j, step) in gpu.iter().enumerate() {
854 if let Step::Gpu {
855 kernel,
856 push,
857 groups,
858 } = step
859 {
860 let pipeline = kern.pipeline(kernel);
861 dev.device.cmd_bind_pipeline(
862 cmd,
863 vk::PipelineBindPoint::COMPUTE,
864 pipeline,
865 );
866 dev.device.cmd_push_constants(
867 cmd,
868 layout,
869 vk::ShaderStageFlags::COMPUTE,
870 0,
871 push,
872 );
873 dev.device.cmd_dispatch(cmd, groups.0, groups.1, groups.2);
874 if j + 1 < gpu.len() {
875 dev.device.cmd_pipeline_barrier(
876 cmd,
877 vk::PipelineStageFlags::COMPUTE_SHADER,
878 vk::PipelineStageFlags::COMPUTE_SHADER,
879 vk::DependencyFlags::empty(),
880 &[barrier],
881 &[],
882 &[],
883 );
884 }
885 }
886 }
887 });
888 }
889 if i < n {
890 if let Step::Host {
891 op,
892 out,
893 out_shape,
894 inputs: in_ids,
895 } = self.schedule[i].clone()
896 {
897 let in_specs: Vec<(rlx_ir::Shape, crate::host::HostBuf)> = in_ids
898 .iter()
899 .map(|&id| {
900 let sh = self.graph.node(id).shape.clone();
901 let nn = sh.num_elements().unwrap_or(0);
902 let buf = if matches!(sh.dtype(), DType::U8 | DType::I8) {
905 crate::host::HostBuf::Bytes(self.arena.read_bytes(id, nn))
906 } else {
907 crate::host::HostBuf::F32(self.arena.read_f32(id, nn))
908 };
909 (sh, buf)
910 })
911 .collect();
912 let result = crate::host::eval(&op, &out_shape, &in_specs);
913 self.arena.write_f32(out, &result);
914 }
915 i += 1;
916 }
917 }
918
919 self.finish_run(read_indices)
920 }
921
922 fn finish_run(&mut self, read_indices: Option<&[usize]>) -> Vec<Vec<f32>> {
929 if !self.gpu_handle_feeds.is_empty() {
930 self.propagate_gpu_handle_feeds_in_arena();
931 if read_indices.is_none() {
932 self.refresh_gpu_handles_from_outputs();
933 }
934 }
935
936 let want: Vec<usize> = match read_indices {
937 Some(ix) => ix.to_vec(),
938 None => (0..self.output_ids.len()).collect(),
939 };
940 want.into_iter()
941 .filter_map(|i| {
942 let id = *self.output_ids.get(i)?;
943 let n = self.graph.node(id).shape.num_elements().unwrap_or(0);
944 Some(self.arena.read_f32(id, n))
945 })
946 .collect()
947 }
948
949 pub fn clone_for_cache(&self) -> Self {
952 let mut twin = Self::build(self.graph.clone(), self.rng);
953 twin.active_extent = self.active_extent;
954 self.arena.copy_into(&twin.arena);
958 twin.gpu_handles = self.gpu_handles.clone();
959 twin.gpu_handle_feeds = self.gpu_handle_feeds.clone();
960 twin.gpu_handle_resident = self.gpu_handle_resident.clone();
961 twin.kv_row_feeds = self.kv_row_feeds.clone();
962 twin
963 }
964}
965
966impl Drop for VulkanExecutable {
967 fn drop(&mut self) {
968 if let Some(dev) = vulkan_device() {
969 let cmds: Vec<vk::CommandBuffer> = self
972 .segments
973 .iter()
974 .filter_map(|s| match s {
975 Segment::Gpu(cmd) => Some(*cmd),
976 Segment::Host { .. } => None,
977 })
978 .collect();
979 if !cmds.is_empty() {
980 dev.free_cmds(&cmds);
981 }
982 if self.fence != vk::Fence::null() {
983 dev.destroy_fence(self.fence);
984 }
985 unsafe {
986 dev.device.destroy_descriptor_pool(self.desc_pool, None);
987 }
988 }
989 }
990}
991
/// Pre-records the schedule into replayable segments.
///
/// Each maximal run of consecutive `Step::Gpu` entries becomes one command
/// buffer (`Segment::Gpu`); each `Step::Host` entry is copied into a
/// `Segment::Host`. `deps` is indexed with the same positions as `schedule`
/// (assumes the two slices are parallel — `build_schedule` produces both).
///
/// Barriers inside a run are emitted only when a hazard is detected, unless
/// overridden: `RLX_VULKAN_NOBARRIER=1` suppresses all barriers and
/// `RLX_VULKAN_FULLBARRIER=1` emits one before every dispatch after the first.
///
/// # Panics
/// Panics if beginning or ending a command buffer fails.
fn record_segments(
    dev: &crate::device::VulkanDevice,
    kern: &crate::kernels::Kernels,
    desc_set: vk::DescriptorSet,
    schedule: &[Step],
    deps: &[StepDep],
) -> Vec<Segment> {
    let layout = kern.pipeline_layout;
    let no_barrier = std::env::var("RLX_VULKAN_NOBARRIER").as_deref() == Ok("1");
    let full_barrier = std::env::var("RLX_VULKAN_FULLBARRIER").as_deref() == Ok("1");
    let mut segments = Vec::new();
    let n = schedule.len();
    let mut i = 0;
    while i < n {
        // Advance `i` past the current run of GPU steps.
        let start = i;
        while i < n && matches!(schedule[i], Step::Gpu { .. }) {
            i += 1;
        }
        if i > start {
            let run = &schedule[start..i];
            let run_deps = &deps[start..i];
            let cmd = dev.alloc_primary_cmd();
            unsafe {
                dev.device
                    .begin_command_buffer(cmd, &vk::CommandBufferBeginInfo::default())
                    .expect("vk begin cmd");
                dev.device.cmd_bind_descriptor_sets(
                    cmd,
                    vk::PipelineBindPoint::COMPUTE,
                    layout,
                    0,
                    &[desc_set],
                    &[],
                );
                let barrier = vk::MemoryBarrier::default()
                    .src_access_mask(vk::AccessFlags::SHADER_WRITE)
                    .dst_access_mask(vk::AccessFlags::SHADER_READ | vk::AccessFlags::SHADER_WRITE);
                // Keys written / read by dispatches recorded since the last
                // barrier (or since the start of the run).
                let mut wrote: HashSet<u32> = HashSet::new();
                let mut read: HashSet<u32> = HashSet::new();
                for (j, step) in run.iter().enumerate() {
                    if let Step::Gpu {
                        kernel,
                        push,
                        groups,
                    } = step
                    {
                        let dep = &run_deps[j];
                        // Hazard: this step reads something pending-written,
                        // or writes something pending-written or pending-read.
                        let hazard = !wrote.is_empty()
                            && (dep.reads.iter().any(|r| wrote.contains(r))
                                || wrote.contains(&dep.write)
                                || read.contains(&dep.write));
                        let emit_barrier = j > 0 && !no_barrier && (full_barrier || hazard);
                        if emit_barrier {
                            dev.device.cmd_pipeline_barrier(
                                cmd,
                                vk::PipelineStageFlags::COMPUTE_SHADER,
                                vk::PipelineStageFlags::COMPUTE_SHADER,
                                vk::DependencyFlags::empty(),
                                &[barrier],
                                &[],
                                &[],
                            );
                            // Everything before the barrier is now ordered.
                            wrote.clear();
                            read.clear();
                        }
                        let pipeline = kern.pipeline(kernel);
                        dev.device
                            .cmd_bind_pipeline(cmd, vk::PipelineBindPoint::COMPUTE, pipeline);
                        dev.device.cmd_push_constants(
                            cmd,
                            layout,
                            vk::ShaderStageFlags::COMPUTE,
                            0,
                            push,
                        );
                        dev.device.cmd_dispatch(cmd, groups.0, groups.1, groups.2);
                        // Track this dispatch's accesses for later steps.
                        wrote.insert(dep.write);
                        for &r in &dep.reads {
                            read.insert(r);
                        }
                    }
                }
                dev.device.end_command_buffer(cmd).expect("vk end cmd");
            }
            segments.push(Segment::Gpu(cmd));
        }
        if i < n {
            // The step at `i` ended the GPU run, so it is a host step.
            if let Step::Host {
                op,
                out,
                out_shape,
                inputs,
            } = &schedule[i]
            {
                segments.push(Segment::Host {
                    op: op.clone(),
                    out: *out,
                    out_shape: out_shape.clone(),
                    inputs: inputs.clone(),
                });
            }
            i += 1;
        }
    }
    segments
}
1118
1119fn widen_const_to_f32(data: &[u8], dt: DType) -> Vec<f32> {
1121 match dt {
1122 DType::F32 => data
1123 .chunks_exact(4)
1124 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
1125 .collect(),
1126 DType::F16 => data
1127 .chunks_exact(2)
1128 .map(|c| half::f16::from_le_bytes([c[0], c[1]]).to_f32())
1129 .collect(),
1130 DType::BF16 => data
1131 .chunks_exact(2)
1132 .map(|c| half::bf16::from_le_bytes([c[0], c[1]]).to_f32())
1133 .collect(),
1134 DType::F64 => data
1135 .chunks_exact(8)
1136 .map(|c| f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32)
1137 .collect(),
1138 DType::I64 => data
1139 .chunks_exact(8)
1140 .map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32)
1141 .collect(),
1142 DType::I32 | DType::U32 => data
1143 .chunks_exact(4)
1144 .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
1145 .collect(),
1146 DType::I16 => data
1147 .chunks_exact(2)
1148 .map(|c| i16::from_le_bytes([c[0], c[1]]) as f32)
1149 .collect(),
1150 DType::I8 => data.iter().map(|&b| b as i8 as f32).collect(),
1151 DType::U8 | DType::Bool => data.iter().map(|&b| b as f32).collect(),
1152 DType::C64 => data
1153 .chunks_exact(4)
1154 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
1155 .collect(),
1156 }
1157}
1158
/// Read/write footprint of one schedule step, used by `record_segments` to
/// decide where a barrier is needed.
#[derive(Clone, Default)]
struct StepDep {
    /// Keys of the buffers this step reads (presumably arena element
    /// offsets — the code that fills them is outside this view; confirm).
    reads: Vec<u32>,
    /// Key of the buffer this step writes.
    write: u32,
}
1171
1172fn build_schedule(graph: &Graph, arena: &Arena) -> (Vec<Step>, Vec<StepDep>) {
1178 let mut steps = Vec::new();
1179 let mut deps: Vec<StepDep> = Vec::new();
1180 for node in graph.nodes() {
1181 let off = |id: NodeId| arena.elem_offset(id);
1182 let out = node.id;
1183 let before = steps.len();
1184 match &node.op {
1185 Op::Input { .. }
1187 | Op::Param { .. }
1188 | Op::Constant { .. }
1189 | Op::Reshape { .. }
1190 | Op::Cast { .. }
1191 | Op::StopGradient => {}
1192
1193 Op::Binary(op) => {
1194 let a = node.inputs[0];
1195 let b = node.inputs[1];
1196 let n = numel(&dims(graph, out));
1197 let an = numel(&dims(graph, a));
1198 let bn = numel(&dims(graph, b));
1199 let push = Push::default()
1200 .u(n as u32)
1201 .u(off(a))
1202 .u(off(b))
1203 .u(off(out))
1204 .u(if an == n { 0 } else { an as u32 })
1205 .u(if bn == n { 0 } else { bn as u32 })
1206 .u(binop_id(*op))
1207 .bytes();
1208 steps.push(Step::Gpu {
1209 kernel: "binary",
1210 push,
1211 groups: groups1d(n, 256),
1212 });
1213 }
1214
1215 Op::Compare(op) => {
1216 let a = node.inputs[0];
1217 let b = node.inputs[1];
1218 let n = numel(&dims(graph, out));
1219 let an = numel(&dims(graph, a));
1220 let bn = numel(&dims(graph, b));
1221 let push = Push::default()
1222 .u(n as u32)
1223 .u(off(a))
1224 .u(off(b))
1225 .u(off(out))
1226 .u(if an == n { 0 } else { an as u32 })
1227 .u(if bn == n { 0 } else { bn as u32 })
1228 .u(cmp_id(*op))
1229 .bytes();
1230 steps.push(Step::Gpu {
1231 kernel: "compare",
1232 push,
1233 groups: groups1d(n, 256),
1234 });
1235 }
1236
1237 Op::Where => {
1238 let c = node.inputs[0];
1239 let a = node.inputs[1];
1240 let b = node.inputs[2];
1241 let n = numel(&dims(graph, out));
1242 let cn = numel(&dims(graph, c));
1243 let an = numel(&dims(graph, a));
1244 let bn = numel(&dims(graph, b));
1245 let push = Push::default()
1246 .u(n as u32)
1247 .u(off(c))
1248 .u(off(a))
1249 .u(off(b))
1250 .u(off(out))
1251 .u(if cn == n { 0 } else { cn as u32 })
1252 .u(if an == n { 0 } else { an as u32 })
1253 .u(if bn == n { 0 } else { bn as u32 })
1254 .bytes();
1255 steps.push(Step::Gpu {
1256 kernel: "where",
1257 push,
1258 groups: groups1d(n, 256),
1259 });
1260 }
1261
1262 Op::Activation(act) => {
1263 let x = node.inputs[0];
1264 let n = numel(&dims(graph, out));
1265 let push = Push::default()
1266 .u(n as u32)
1267 .u(off(x))
1268 .u(off(out))
1269 .u(act_id(*act))
1270 .bytes();
1271 steps.push(Step::Gpu {
1272 kernel: "unary",
1273 push,
1274 groups: groups1d(n, 256),
1275 });
1276 }
1277
1278 Op::MatMul => {
1279 let a = node.inputs[0];
1280 let b = node.inputs[1];
1281 let ad = dims(graph, a);
1282 let bd = dims(graph, b);
1283 let od = dims(graph, out);
1284 let (m, k) = (ad[ad.len() - 2], ad[ad.len() - 1]);
1285 let n = bd[bd.len() - 1];
1286 let batch = if od.len() > 2 {
1287 numel(&od[..od.len() - 2])
1288 } else {
1289 1
1290 };
1291 let a_batch = if ad.len() > 2 {
1292 numel(&ad[..ad.len() - 2])
1293 } else {
1294 1
1295 };
1296 let b_batch = if bd.len() > 2 {
1297 numel(&bd[..bd.len() - 2])
1298 } else {
1299 1
1300 };
1301 let a_bs = if a_batch <= 1 { 0 } else { m * k };
1302 let b_bs = if b_batch <= 1 { 0 } else { k * n };
1303 let push = Push::default()
1304 .u(m as u32)
1305 .u(k as u32)
1306 .u(n as u32)
1307 .u(off(a))
1308 .u(off(b))
1309 .u(off(out))
1310 .u(batch as u32)
1311 .u(a_bs as u32)
1312 .u(b_bs as u32)
1313 .u((m * n) as u32)
1314 .bytes();
1315 steps.push(Step::Gpu {
1316 kernel: matmul_kernel(m, k, n),
1317 push,
1318 groups: (ceil_div(n, 16), ceil_div(m, 16), batch.max(1) as u32),
1319 });
1320 }
1321
1322 Op::Reduce { op, axes, .. } => {
1323 let x = node.inputs[0];
1324 let xd = dims(graph, x);
1325 let rank = xd.len();
1326 let last = rank.saturating_sub(1);
1328 debug_assert!(
1329 axes.as_slice() == [last] || (rank <= 1),
1330 "rlx-vulkan: non-last-axis reduce should have been lowered"
1331 );
1332 let r = *xd.get(last).unwrap_or(&1);
1333 let outer = numel(&xd) / r.max(1);
1334 let push = Push::default()
1335 .u(outer as u32)
1336 .u(r as u32)
1337 .u(off(x))
1338 .u(off(out))
1339 .u(reduce_id(*op))
1340 .bytes();
1341 steps.push(Step::Gpu {
1342 kernel: "reduce",
1343 push,
1344 groups: groups1d(outer, 256),
1345 });
1346 }
1347
1348 Op::Softmax { axis } => {
1349 let x = node.inputs[0];
1350 let xd = dims(graph, x);
1351 let ax = norm_axis(*axis, xd.len());
1352 let axis_len = xd[ax];
1353 let outer = numel(&xd[..ax]);
1354 let inner = numel(&xd[ax + 1..]);
1355 let push = Push::default()
1356 .u(outer as u32)
1357 .u(axis_len as u32)
1358 .u(inner as u32)
1359 .u(off(x))
1360 .u(off(out))
1361 .bytes();
1362 steps.push(Step::Gpu {
1363 kernel: "softmax",
1364 push,
1365 groups: groups1d(outer * inner, 256),
1366 });
1367 }
1368
1369 Op::RmsNorm { axis, eps } => {
1370 let x = node.inputs[0];
1372 let gamma = node.inputs[1];
1373 let beta = node.inputs[2];
1374 let xd = dims(graph, x);
1375 let ax = norm_axis(*axis, xd.len());
1376 debug_assert_eq!(ax, xd.len().saturating_sub(1), "rmsnorm expects last axis");
1377 let n = xd[ax];
1378 let rows = numel(&xd) / n.max(1);
1379 let push = Push::default()
1380 .u(rows as u32)
1381 .u(n as u32)
1382 .u(off(x))
1383 .u(off(gamma))
1384 .u(off(beta))
1385 .u(off(out))
1386 .f(*eps)
1387 .bytes();
1388 steps.push(Step::Gpu {
1389 kernel: "rmsnorm",
1390 push,
1391 groups: groups1d(rows, 64),
1392 });
1393 }
1394
1395 Op::LayerNorm { axis, eps } => {
1396 let x = node.inputs[0];
1397 let gamma = node.inputs[1];
1398 let has_beta = node.inputs.len() >= 3;
1399 let beta = if has_beta { node.inputs[2] } else { gamma };
1400 let xd = dims(graph, x);
1401 let ax = norm_axis(*axis, xd.len());
1402 let n = xd[ax];
1403 let rows = numel(&xd) / n.max(1);
1404 let push = Push::default()
1405 .u(rows as u32)
1406 .u(n as u32)
1407 .u(off(x))
1408 .u(off(gamma))
1409 .u(off(beta))
1410 .u(off(out))
1411 .u(if has_beta { 1 } else { 0 })
1412 .f(*eps)
1413 .bytes();
1414 steps.push(Step::Gpu {
1415 kernel: "layernorm",
1416 push,
1417 groups: groups1d(rows, 64),
1418 });
1419 }
1420
1421 Op::Rope {
1422 head_dim,
1423 n_rot,
1424 style,
1425 } => {
1426 let x = node.inputs[0];
1427 let cos = node.inputs[1];
1428 let sin = node.inputs[2];
1429 let xd = dims(graph, x);
1430 let (batch, seq, hidden) = if xd.len() >= 3 {
1431 (xd[0], xd[1], xd[2])
1432 } else {
1433 let total = numel(&xd);
1434 (1, xd[0], total / xd[0].max(1))
1435 };
1436 let hd = *head_dim;
1437 let nh = hidden / hd.max(1);
1438 let tab_half = hd / 2;
1439 let cos_len = numel(&dims(graph, cos));
1440 let cos_rows = cos_len / tab_half.max(1);
1441 let per_token = (cos_rows == batch * seq && cos_rows != seq) as u32;
1442 let style_id = match style {
1443 RopeStyle::NeoX => 0u32,
1444 RopeStyle::GptJ => 1u32,
1445 };
1446 let push = Push::default()
1447 .u(batch as u32)
1448 .u(seq as u32)
1449 .u(hidden as u32)
1450 .u(hd as u32)
1451 .u(*n_rot as u32)
1452 .u(nh as u32)
1453 .u(tab_half as u32)
1454 .u(hidden as u32) .u(per_token)
1456 .u(style_id)
1457 .u(off(x))
1458 .u(off(cos))
1459 .u(off(sin))
1460 .u(off(out))
1461 .bytes();
1462 steps.push(Step::Gpu {
1463 kernel: "rope",
1464 push,
1465 groups: groups1d(batch * seq * nh, 64),
1466 });
1467 }
1468
1469 Op::Attention {
1470 num_heads,
1471 head_dim,
1472 mask_kind,
1473 score_scale,
1474 ..
1475 } => {
1476 let q = node.inputs[0];
1477 let k = node.inputs[1];
1478 let v = node.inputs[2];
1479 let qd = dims(graph, q);
1480 let kd = dims(graph, k);
1481 let nh = *num_heads;
1482 let dh = *head_dim;
1483 let (batch, q_s, k_s, bhsd) = if qd.len() == 4 {
1484 if qd[1] == nh {
1485 (qd[0], qd[2], kd[2], 1u32) } else {
1487 (qd[0], qd[1], kd[1], 0u32) }
1489 } else if qd.len() >= 3 {
1490 (qd[0], qd[1], kd[1], 0u32)
1491 } else {
1492 (1, qd[0], kd[0], 0u32)
1493 };
1494 let hs = (nh * dh) as u32;
1495 let (mask_kind_id, mask_off, window) = match mask_kind {
1496 MaskKind::None => (0u32, 0u32, 0u32),
1497 MaskKind::Causal => (1, 0, 0),
1498 MaskKind::SlidingWindow(w) => (2, 0, *w as u32),
1499 MaskKind::Custom => (3, off(node.inputs[3]), 0),
1500 MaskKind::Bias => (4, off(node.inputs[3]), 0),
1501 };
1502 let scale = score_scale.unwrap_or((dh as f32).powf(-0.5));
1503 let push = Push::default()
1504 .u(batch as u32)
1505 .u(nh as u32)
1506 .u(q_s as u32)
1507 .u(k_s as u32)
1508 .u(dh as u32)
1509 .u(off(q))
1510 .u(off(k))
1511 .u(off(v))
1512 .u(off(out))
1513 .u(hs)
1514 .u(hs)
1515 .u(hs)
1516 .u(bhsd)
1517 .u(mask_kind_id)
1518 .u(mask_off)
1519 .u(window)
1520 .f(scale)
1521 .f(-1.0e30)
1522 .f(0.5)
1523 .bytes();
1524 steps.push(Step::Gpu {
1525 kernel: "attention",
1526 push,
1527 groups: groups1d(batch * nh * q_s, 64),
1528 });
1529 }
1530
1531 Op::Transpose { perm } => {
1532 let x = node.inputs[0];
1533 let xd = dims(graph, x);
1534 let od = dims(graph, out);
1535 let in_str = contig_strides(&xd);
1536 let out_str = contig_strides(&od);
1537 let rank = od.len();
1538 let mut shape = [1u32; 6];
1539 let mut istr = [0u32; 6];
1540 let mut ostr = [0u32; 6];
1541 for ax in 0..rank {
1542 shape[ax] = od[ax] as u32;
1543 istr[ax] = in_str[perm[ax]] as u32;
1544 ostr[ax] = out_str[ax] as u32;
1545 }
1546 let n = numel(&od);
1547 let push = Push::default()
1548 .u(n as u32)
1549 .u(rank as u32)
1550 .u(off(x))
1551 .u(off(out))
1552 .us(&shape)
1553 .us(&istr)
1554 .us(&ostr)
1555 .bytes();
1556 steps.push(Step::Gpu {
1557 kernel: "reindex",
1558 push,
1559 groups: groups1d(n, 256),
1560 });
1561 }
1562
1563 Op::Narrow { axis, start, .. } => {
1564 let x = node.inputs[0];
1565 let xd = dims(graph, x);
1566 let od = dims(graph, out);
1567 let in_str = contig_strides(&xd);
1568 let out_str = contig_strides(&od);
1569 let rank = od.len();
1570 let mut shape = [1u32; 6];
1571 let mut istr = [0u32; 6];
1572 let mut ostr = [0u32; 6];
1573 for ax in 0..rank {
1574 shape[ax] = od[ax] as u32;
1575 istr[ax] = in_str[ax] as u32;
1576 ostr[ax] = out_str[ax] as u32;
1577 }
1578 let in_off = off(x) + (*start * in_str[*axis]) as u32;
1579 let n = numel(&od);
1580 let push = Push::default()
1581 .u(n as u32)
1582 .u(rank as u32)
1583 .u(in_off)
1584 .u(off(out))
1585 .us(&shape)
1586 .us(&istr)
1587 .us(&ostr)
1588 .bytes();
1589 steps.push(Step::Gpu {
1590 kernel: "reindex",
1591 push,
1592 groups: groups1d(n, 256),
1593 });
1594 }
1595
1596 Op::Expand { .. } => {
1597 let x = node.inputs[0];
1598 let xd = dims(graph, x);
1599 let od = dims(graph, out);
1600 let rank = od.len();
1601 let pad = rank - xd.len();
1603 let in_str_full = contig_strides(&xd);
1604 let out_str = contig_strides(&od);
1605 let mut shape = [1u32; 6];
1606 let mut istr = [0u32; 6];
1607 let mut ostr = [0u32; 6];
1608 for ax in 0..rank {
1609 shape[ax] = od[ax] as u32;
1610 ostr[ax] = out_str[ax] as u32;
1611 if ax < pad {
1612 istr[ax] = 0;
1613 } else {
1614 let xi = ax - pad;
1615 istr[ax] = if xd[xi] == 1 && od[ax] != 1 {
1616 0
1617 } else {
1618 in_str_full[xi] as u32
1619 };
1620 }
1621 }
1622 let n = numel(&od);
1623 let push = Push::default()
1624 .u(n as u32)
1625 .u(rank as u32)
1626 .u(off(x))
1627 .u(off(out))
1628 .us(&shape)
1629 .us(&istr)
1630 .us(&ostr)
1631 .bytes();
1632 steps.push(Step::Gpu {
1633 kernel: "reindex",
1634 push,
1635 groups: groups1d(n, 256),
1636 });
1637 }
1638
1639 Op::Concat { axis } => {
1640 let od = dims(graph, out);
1641 let out_str = contig_strides(&od);
1642 let rank = od.len();
1643 let mut axis_cursor = 0usize;
1644 for &inp in &node.inputs {
1645 let id_dims = dims(graph, inp);
1646 let in_str = contig_strides(&id_dims);
1647 let mut shape = [1u32; 6];
1648 let mut istr = [0u32; 6];
1649 let mut ostr = [0u32; 6];
1650 for ax in 0..rank {
1651 shape[ax] = *id_dims.get(ax).unwrap_or(&1) as u32;
1652 istr[ax] = *in_str.get(ax).unwrap_or(&0) as u32;
1653 ostr[ax] = out_str[ax] as u32;
1654 }
1655 let out_off = off(out) + (axis_cursor * out_str[*axis]) as u32;
1656 let n = numel(&id_dims);
1657 let push = Push::default()
1658 .u(n as u32)
1659 .u(rank as u32)
1660 .u(off(inp))
1661 .u(out_off)
1662 .us(&shape)
1663 .us(&istr)
1664 .us(&ostr)
1665 .bytes();
1666 steps.push(Step::Gpu {
1667 kernel: "reindex",
1668 push,
1669 groups: groups1d(n, 256),
1670 });
1671 axis_cursor += *id_dims.get(*axis).unwrap_or(&1);
1672 }
1673 }
1674
1675 Op::Gather { axis } => {
1676 let data = node.inputs[0];
1677 let idx = node.inputs[1];
1678 let dd = dims(graph, data);
1679 let ax = *axis;
1680 let out_outer = numel(&dd[..ax]);
1681 let axis_dim = dd[ax];
1682 let out_inner = numel(&dd[ax + 1..]);
1683 let n_idx = numel(&dims(graph, idx));
1684 let total = out_outer * n_idx * out_inner;
1685 let push = Push::default()
1686 .u(out_outer as u32)
1687 .u(n_idx as u32)
1688 .u(out_inner as u32)
1689 .u(axis_dim as u32)
1690 .u(off(data))
1691 .u(off(idx))
1692 .u(off(out))
1693 .bytes();
1694 steps.push(Step::Gpu {
1695 kernel: "gather",
1696 push,
1697 groups: groups1d(total, 256),
1698 });
1699 }
1700
1701 Op::Cumsum { axis, exclusive } => {
1702 let x = node.inputs[0];
1703 let xd = dims(graph, x);
1704 let ax = norm_axis(*axis, xd.len());
1705 debug_assert_eq!(ax, xd.len().saturating_sub(1), "cumsum expects last axis");
1706 let cols = *xd.get(ax).unwrap_or(&1);
1707 let rows = numel(&xd) / cols.max(1);
1708 let push = Push::default()
1709 .u(rows as u32)
1710 .u(cols as u32)
1711 .u(off(x))
1712 .u(off(out))
1713 .u(if *exclusive { 1 } else { 0 })
1714 .bytes();
1715 steps.push(Step::Gpu {
1716 kernel: "cumsum",
1717 push,
1718 groups: groups1d(rows, 64),
1719 });
1720 }
1721
1722 Op::Reverse { axes } => {
1723 let x = node.inputs[0];
1724 let xd = dims(graph, x);
1725 let rank = xd.len();
1726 let mut shape = [1u32; 6];
1727 let mut flip = [0u32; 6];
1728 for ax in 0..rank {
1729 shape[ax] = xd[ax] as u32;
1730 flip[ax] = if axes.contains(&ax) { 1 } else { 0 };
1731 }
1732 let n = numel(&xd);
1733 let push = Push::default()
1734 .u(n as u32)
1735 .u(rank as u32)
1736 .u(off(x))
1737 .u(off(out))
1738 .us(&shape)
1739 .us(&flip)
1740 .bytes();
1741 steps.push(Step::Gpu {
1742 kernel: "reverse",
1743 push,
1744 groups: groups1d(n, 256),
1745 });
1746 }
1747
1748 Op::ArgMax { axis, .. } | Op::ArgMin { axis, .. } => {
1749 let x = node.inputs[0];
1750 let xd = dims(graph, x);
1751 let ax = (*axis).min(xd.len().saturating_sub(1));
1752 let axis_len = xd[ax];
1753 let outer = numel(&xd[..ax]);
1754 let inner = numel(&xd[ax + 1..]);
1755 let op_id = if matches!(node.op, Op::ArgMax { .. }) {
1756 0
1757 } else {
1758 1
1759 };
1760 let push = Push::default()
1761 .u(outer as u32)
1762 .u(axis_len as u32)
1763 .u(inner as u32)
1764 .u(off(x))
1765 .u(off(out))
1766 .u(op_id)
1767 .bytes();
1768 steps.push(Step::Gpu {
1769 kernel: "argreduce",
1770 push,
1771 groups: groups1d(outer * inner, 256),
1772 });
1773 }
1774
1775 Op::LayerNorm2d { eps } => {
1776 let x = node.inputs[0];
1778 let gamma = node.inputs[1];
1779 let beta = node.inputs[2];
1780 let xd = dims(graph, x);
1781 let (nn, cc, hw) = (xd[0], xd[1], xd[2] * xd[3]);
1782 let positions = nn * hw;
1783 let push = Push::default()
1784 .u(positions as u32)
1785 .u(cc as u32)
1786 .u(hw as u32)
1787 .u(off(x))
1788 .u(off(gamma))
1789 .u(off(beta))
1790 .u(off(out))
1791 .f(*eps)
1792 .bytes();
1793 steps.push(Step::Gpu {
1794 kernel: "layernorm2d",
1795 push,
1796 groups: groups1d(positions, 64),
1797 });
1798 }
1799
1800 Op::Pool {
1801 kind,
1802 kernel_size,
1803 stride,
1804 padding,
1805 } => {
1806 let x = node.inputs[0];
1808 let xd = dims(graph, x);
1809 let od = dims(graph, out);
1810 let (nn, cc, hh, ww) = (xd[0], xd[1], xd[2], xd[3]);
1811 let (oh, ow) = (od[2], od[3]);
1812 let (kh, kw) = (kernel_size[0], kernel_size[1]);
1813 let (sh, sw) = (stride[0], stride[1]);
1814 let (ph, pw) = (padding[0], padding[1]);
1815 let kind_id = reduce_id(*kind); let push = Push::default()
1817 .us(&[nn as u32, cc as u32, hh as u32, ww as u32])
1818 .us(&[oh as u32, ow as u32])
1819 .us(&[
1820 kh as u32, kw as u32, sh as u32, sw as u32, ph as u32, pw as u32,
1821 ])
1822 .u(off(x))
1823 .u(off(out))
1824 .u(kind_id)
1825 .bytes();
1826 steps.push(Step::Gpu {
1827 kernel: "pool2d",
1828 push,
1829 groups: groups1d(nn * cc * oh * ow, 64),
1830 });
1831 }
1832
1833 Op::ResizeNearest2x => {
1834 let x = node.inputs[0];
1835 let xd = dims(graph, x);
1836 let (nn, cc, hh, ww) = (xd[0], xd[1], xd[2], xd[3]);
1837 let push = Push::default()
1838 .us(&[nn as u32, cc as u32, hh as u32, ww as u32])
1839 .u(off(x))
1840 .u(off(out))
1841 .bytes();
1842 steps.push(Step::Gpu {
1843 kernel: "resize2x",
1844 push,
1845 groups: groups1d(nn * cc * hh * 4 * ww, 256),
1846 });
1847 }
1848
1849 Op::GroupedMatMul => {
1850 let input = node.inputs[0];
1852 let weight = node.inputs[1];
1853 let idx = node.inputs[2];
1854 let id = dims(graph, input);
1855 let wd = dims(graph, weight);
1856 let (m, k) = (id[id.len() - 2], id[id.len() - 1]);
1857 let n = wd[wd.len() - 1];
1858 let push = Push::default()
1859 .u(m as u32)
1860 .u(k as u32)
1861 .u(n as u32)
1862 .u(off(input))
1863 .u(off(weight))
1864 .u(off(idx))
1865 .u(off(out))
1866 .bytes();
1867 steps.push(Step::Gpu {
1868 kernel: "grouped_matmul",
1869 push,
1870 groups: (ceil_div(n, 16), ceil_div(m, 16), 1),
1871 });
1872 }
1873
1874 Op::Conv {
1875 kernel_size,
1876 stride,
1877 padding,
1878 dilation,
1879 groups,
1880 } => {
1881 let x = node.inputs[0];
1883 let weight = node.inputs[1];
1884 let has_bias = node.inputs.len() > 2;
1885 let bias = if has_bias { node.inputs[2] } else { weight };
1886 let xd = dims(graph, x);
1887 let od = dims(graph, out);
1888 let (nn, cin, hh, ww) = (xd[0], xd[1], xd[2], xd[3]);
1889 let (cout, oh, ow) = (od[1], od[2], od[3]);
1890 let (kh, kw) = (kernel_size[0], kernel_size[1]);
1891 let (sh, sw) = (stride[0], stride[1]);
1892 let (ph, pw) = (padding[0], padding[1]);
1893 let (dh, dw) = (dilation[0], dilation[1]);
1894 let push = Push::default()
1895 .us(&[nn as u32, cin as u32, hh as u32, ww as u32])
1896 .us(&[cout as u32, kh as u32, kw as u32])
1897 .us(&[oh as u32, ow as u32])
1898 .us(&[
1899 sh as u32, sw as u32, ph as u32, pw as u32, dh as u32, dw as u32,
1900 ])
1901 .u(*groups as u32)
1902 .u(if has_bias { 1 } else { 0 })
1903 .u(off(x))
1904 .u(off(weight))
1905 .u(off(bias))
1906 .u(off(out))
1907 .bytes();
1908 steps.push(Step::Gpu {
1909 kernel: "conv2d",
1910 push,
1911 groups: groups1d(nn * cout * oh * ow, 64),
1912 });
1913 }
1914
1915 Op::SelectiveScan { state_size } => {
1916 let x = node.inputs[0];
1918 let delta = node.inputs[1];
1919 let a = node.inputs[2];
1920 let bmat = node.inputs[3];
1921 let cmat = node.inputs[4];
1922 let xd = dims(graph, x);
1923 let (bb, ss, hh) = (xd[0], xd[1], xd[2]);
1924 let nn = *state_size;
1925 let push = Push::default()
1926 .u(bb as u32)
1927 .u(ss as u32)
1928 .u(hh as u32)
1929 .u(nn as u32)
1930 .u(off(x))
1931 .u(off(delta))
1932 .u(off(a))
1933 .u(off(bmat))
1934 .u(off(cmat))
1935 .u(off(out))
1936 .bytes();
1937 steps.push(Step::Gpu {
1938 kernel: "selective_scan",
1939 push,
1940 groups: groups1d(bb * hh, 64),
1941 });
1942 }
1943
1944 Op::Im2Col {
1945 kernel_size,
1946 stride,
1947 padding,
1948 dilation,
1949 } => {
1950 let x = node.inputs[0];
1952 let xd = dims(graph, x);
1953 let (nn, cin, hh, ww) = (xd[0], xd[1], xd[2], xd[3]);
1954 let (kh, kw) = (kernel_size[0], kernel_size[1]);
1955 let (sh, sw) = (stride[0], stride[1]);
1956 let (ph, pw) = (padding[0], padding[1]);
1957 let (dh, dw) = (dilation[0], dilation[1]);
1958 let eff_h = dh * (kh - 1) + 1;
1959 let eff_w = dw * (kw - 1) + 1;
1960 let ho = (hh + 2 * ph - eff_h) / sh + 1;
1961 let wo = (ww + 2 * pw - eff_w) / sw + 1;
1962 let push = Push::default()
1963 .us(&[nn as u32, cin as u32, hh as u32, ww as u32])
1964 .us(&[ho as u32, wo as u32])
1965 .us(&[
1966 kh as u32, kw as u32, sh as u32, sw as u32, ph as u32, pw as u32,
1967 dh as u32, dw as u32,
1968 ])
1969 .u(off(x))
1970 .u(off(out))
1971 .bytes();
1972 steps.push(Step::Gpu {
1973 kernel: "im2col",
1974 push,
1975 groups: groups1d(nn * ho * wo * cin * kh * kw, 256),
1976 });
1977 }
1978
1979 Op::ScatterAdd => {
1980 let updates = node.inputs[0];
1982 let indices = node.inputs[1];
1983 let ud = dims(graph, updates);
1984 let od = dims(graph, out);
1985 let num_updates = ud[0];
1986 let trailing = numel(&ud[1..]);
1987 let out_dim = od[0];
1988 let push = Push::default()
1989 .u(out_dim as u32)
1990 .u(trailing as u32)
1991 .u(num_updates as u32)
1992 .u(off(updates))
1993 .u(off(indices))
1994 .u(off(out))
1995 .bytes();
1996 steps.push(Step::Gpu {
1997 kernel: "scatter_add",
1998 push,
1999 groups: groups1d(out_dim * trailing, 256),
2000 });
2001 }
2002
2003 Op::TopK { k } => {
2004 let x = node.inputs[0];
2005 let xd = dims(graph, x);
2006 let n = *xd.last().unwrap_or(&1);
2007 let rows = numel(&xd) / n.max(1);
2008 let push = Push::default()
2009 .u(rows as u32)
2010 .u(n as u32)
2011 .u(*k as u32)
2012 .u(off(x))
2013 .u(off(out))
2014 .bytes();
2015 steps.push(Step::Gpu {
2016 kernel: "topk",
2017 push,
2018 groups: groups1d(rows, 64),
2019 });
2020 }
2021
2022 Op::DequantMatMul { scheme } => {
2026 use rlx_ir::quant::QuantScheme;
2027 let x = node.inputs[0];
2028 let xd = dims(graph, x);
2029 let od = dims(graph, out);
2030 let n = *od.last().unwrap_or(&1);
2031 let m = numel(&od) / n.max(1);
2032 let k = numel(&xd) / m.max(1);
2033 let gpu_scheme = match scheme {
2034 QuantScheme::GgufQ4K => Some(0u32),
2035 QuantScheme::GgufQ6K => Some(1u32),
2036 _ => None,
2037 };
2038 match gpu_scheme {
2039 Some(sc) if m == 1 && k.is_multiple_of(256) && n >= 1 => {
2040 let w = node.inputs[1];
2041 let push = Push::default()
2042 .u(n as u32)
2043 .u(k as u32)
2044 .u(off(x))
2045 .u(off(w))
2046 .u(off(out))
2047 .u(sc)
2048 .bytes();
2049 steps.push(Step::Gpu {
2050 kernel: "dequant_matmul",
2051 push,
2052 groups: groups1d(n, 64),
2053 });
2054 }
2055 _ => {
2056 steps.push(Step::Host {
2057 op: node.op.clone(),
2058 out: node.id,
2059 out_shape: node.shape.clone(),
2060 inputs: node.inputs.clone(),
2061 });
2062 }
2063 }
2064 }
2065
2066 op if is_host_fallback(op) => {
2067 steps.push(Step::Host {
2068 op: node.op.clone(),
2069 out: node.id,
2070 out_shape: node.shape.clone(),
2071 inputs: node.inputs.clone(),
2072 });
2073 }
2074
2075 other => panic!(
2076 "rlx-vulkan: op {:?} reached the scheduler but has no kernel \
2077 (should have been rejected at legalize). Pin this graph to Device::Cpu.",
2078 other.kind()
2079 ),
2080 }
2081
2082 let added = steps.len() - before;
2087 if added > 0 {
2088 let reads: Vec<u32> = node
2089 .inputs
2090 .iter()
2091 .filter(|&&id| arena.has(id))
2092 .map(|&id| arena.elem_offset(id))
2093 .collect();
2094 let write = if arena.has(out) {
2095 arena.elem_offset(out)
2096 } else {
2097 0
2098 };
2099 for _ in 0..added {
2100 deps.push(StepDep {
2101 reads: reads.clone(),
2102 write,
2103 });
2104 }
2105 }
2106 }
2107 (steps, deps)
2108}