1use crate::CompileOptions;
23use rlx_ir::Graph;
24use rlx_ir::hir::HirModule;
25use rlx_ir::lir::LirModule;
26use std::collections::HashMap;
27use std::sync::Arc;
28
29use crate::cpu_low_precision;
30
31#[allow(dead_code)]
38pub(crate) fn widen_bytes_to_f32(data: &[u8], dtype: rlx_ir::DType) -> Vec<f32> {
39 use rlx_ir::DType;
40 match dtype {
41 DType::F32 => {
42 let n = data.len() / 4;
43 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
44 s.to_vec()
45 }
46 DType::F16 => {
47 let n = data.len() / 2;
48 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n) };
49 s.iter().map(|h| h.to_f32()).collect()
50 }
51 DType::BF16 => {
52 let n = data.len() / 2;
53 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n) };
54 s.iter().map(|h| h.to_f32()).collect()
55 }
56 other => panic!(
57 "widen_bytes_to_f32: dtype {other:?} unsupported on f32-arena backends \
58 (only F32/F16/BF16 are accepted on the host I/O surface)"
59 ),
60 }
61}
62
63#[allow(dead_code)]
68pub(crate) fn narrow_f32_to_bytes(v: &[f32], dt: rlx_ir::DType) -> Vec<u8> {
69 use rlx_ir::DType;
70 match dt {
71 DType::F32 => {
72 let mut bytes = Vec::with_capacity(v.len() * 4);
73 for &x in v {
74 bytes.extend_from_slice(&x.to_le_bytes());
75 }
76 bytes
77 }
78 DType::F16 => {
79 let mut bytes = Vec::with_capacity(v.len() * 2);
80 for &x in v {
81 bytes.extend_from_slice(&half::f16::from_f32(x).to_le_bytes());
82 }
83 bytes
84 }
85 DType::BF16 => {
86 let mut bytes = Vec::with_capacity(v.len() * 2);
87 for &x in v {
88 bytes.extend_from_slice(&half::bf16::from_f32(x).to_le_bytes());
89 }
90 bytes
91 }
92 DType::F64 => {
93 let mut bytes = Vec::with_capacity(v.len() * 8);
94 for &x in v {
95 bytes.extend_from_slice(&(x as f64).to_le_bytes());
96 }
97 bytes
98 }
99 DType::I8 => v.iter().map(|&x| x as i8 as u8).collect(),
100 DType::U8 => v.iter().map(|&x| x as u8).collect(),
101 DType::I16 => {
102 let mut bytes = Vec::with_capacity(v.len() * 2);
103 for &x in v {
104 bytes.extend_from_slice(&(x as i16).to_le_bytes());
105 }
106 bytes
107 }
108 DType::I32 => {
109 let mut bytes = Vec::with_capacity(v.len() * 4);
110 for &x in v {
111 bytes.extend_from_slice(&(x as i32).to_le_bytes());
112 }
113 bytes
114 }
115 DType::U32 => {
116 let mut bytes = Vec::with_capacity(v.len() * 4);
117 for &x in v {
118 bytes.extend_from_slice(&(x as u32).to_le_bytes());
119 }
120 bytes
121 }
122 DType::I64 => {
123 let mut bytes = Vec::with_capacity(v.len() * 8);
124 for &x in v {
125 bytes.extend_from_slice(&(x as i64).to_le_bytes());
126 }
127 bytes
128 }
129 DType::Bool => v
130 .iter()
131 .map(|&x| if x != 0.0 { 1u8 } else { 0u8 })
132 .collect(),
133 DType::C64 => {
134 let mut bytes = Vec::with_capacity(v.len() * 8);
138 for &x in v {
139 bytes.extend_from_slice(&x.to_le_bytes());
140 bytes.extend_from_slice(&0.0_f32.to_le_bytes());
141 }
142 bytes
143 }
144 }
145}
146
147pub trait ExecutableGraph: Send {
149 fn set_param(&mut self, name: &str, data: &[f32]);
151
152 fn clone_box(&self) -> Box<dyn ExecutableGraph> {
159 panic!("clone_box not implemented for this backend");
160 }
161
162 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>>;
164
165 fn run_read_outputs(
168 &mut self,
169 inputs: &[(&str, &[f32])],
170 read_indices: Option<&[usize]>,
171 ) -> Vec<Vec<f32>> {
172 match read_indices {
173 None => self.run(inputs),
174 Some(ix) => {
175 let all = self.run(inputs);
178 ix.iter().filter_map(|&i| all.get(i).cloned()).collect()
179 }
180 }
181 }
182
183 fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
185 let vecs = self.run(inputs);
186 vecs.iter().map(|v| (v.as_ptr(), v.len())).collect()
187 }
188
189 fn run_slots(&mut self, _inputs: &[&[f32]]) -> &[(usize, usize)] {
192 &[] }
194
195 fn arena_ptr(&self) -> *const u8 {
197 std::ptr::null()
198 }
199
200 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
217 let _ = extent;
218 }
219
220 fn set_moe_resident_experts(&mut self, _mask: &[bool]) {}
222
223 fn set_moe_resident_experts_per_layer(&mut self, _masks: &[&[bool]]) {}
225
226 fn enable_moe_topk_capture(&mut self, _num_experts: usize) -> bool {
228 false
229 }
230
231 fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
233 None
234 }
235
236 fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
238 None
239 }
240
241 fn bind_handle(&mut self, _name: &str, _data: &[f32]) -> bool {
245 false
246 }
247
248 fn read_handle(&self, _name: &str) -> Option<Vec<f32>> {
250 None
251 }
252
253 fn bind_gpu_handle(&mut self, _name: &str, _data: &[f32]) -> bool {
255 false
256 }
257
258 fn has_gpu_handle(&self, _name: &str) -> bool {
259 false
260 }
261
262 fn set_gpu_handle_feed(&mut self, _handle_name: &str, _output_index: usize) -> bool {
263 false
264 }
265
266 fn read_gpu_handle(&self, _name: &str) -> Option<Vec<f32>> {
267 None
268 }
269
270 fn read_output_row(&self, _out_idx: usize, _row: usize, _row_inner: usize) -> Option<Vec<f32>> {
273 None
274 }
275
276 fn run_feed_gpu_handle(
278 &mut self,
279 inputs: &[(&str, &[f32])],
280 _handle_name: &str,
281 _output_index: usize,
282 ) -> Option<Vec<f32>> {
283 let _ = inputs;
284 None
285 }
286
287 fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
302 let _ = self.run(inputs);
303 }
304
305 fn sync_pending(&mut self) {}
308
309 fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
318 input_sets.iter().map(|inputs| self.run(inputs)).collect()
319 }
320
321 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
334 if dtype != rlx_ir::DType::F32 {
335 panic!(
336 "backend's default set_param_typed only handles F32; \
337 got {dtype:?}. Override on the backend for typed support."
338 );
339 }
340 if !data.len().is_multiple_of(4) {
341 panic!(
342 "set_param_typed F32: data length {} not a multiple of 4",
343 data.len()
344 );
345 }
346 let n = data.len() / 4;
351 let f32_slice = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
352 self.set_param(name, f32_slice);
353 }
354
355 fn run_typed(
359 &mut self,
360 inputs: &[(&str, &[u8], rlx_ir::DType)],
361 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
362 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
365 for (name, data, dt) in inputs {
366 if *dt != rlx_ir::DType::F32 {
367 panic!(
368 "backend's default run_typed only handles F32 inputs; \
369 got {dt:?} for input '{name}'"
370 );
371 }
372 if data.len() % 4 != 0 {
373 panic!(
374 "run_typed F32 input '{name}': len {} not multiple of 4",
375 data.len()
376 );
377 }
378 let n = data.len() / 4;
379 let v: Vec<f32> =
380 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) }.to_vec();
381 owned.push((name.to_string(), v));
382 }
383 let refs: Vec<(&str, &[f32])> = owned
384 .iter()
385 .map(|(n, d)| (n.as_str(), d.as_slice()))
386 .collect();
387 let outs = self.run(&refs);
388 outs.into_iter()
389 .map(|v| {
390 let bytes =
391 unsafe { std::slice::from_raw_parts(v.as_ptr() as *const u8, v.len() * 4) }
392 .to_vec();
393 (bytes, rlx_ir::DType::F32)
394 })
395 .collect()
396 }
397}
398
399pub trait Backend: Send + Sync {
409 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph>;
411
412 fn compile_lir(&self, lir: LirModule, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
416 self.compile(lir.into_graph(), options)
417 }
418
419 fn compile_hir(
421 &self,
422 hir: HirModule,
423 device: rlx_driver::Device,
424 options: &CompileOptions,
425 ) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
426 let result = crate::stages::compile_hir_stages(device, hir, options)?;
427 crate::stages::maybe_log_fusion(&result.fusion);
428 Ok(self.compile_lir(result.lir, options))
429 }
430
431 fn compile_module(
433 &self,
434 module: rlx_ir::GraphModule,
435 device: rlx_driver::Device,
436 options: &CompileOptions,
437 ) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
438 let result = crate::stages::compile_module_stages(device, module, options)?;
439 crate::stages::maybe_log_fusion(&result.fusion);
440 Ok(self.compile_lir(result.lir, options))
441 }
442
443 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
450 &[]
451 }
452}
453
454#[allow(dead_code)]
457fn prepare_fused_graph(
458 graph: Graph,
459 options: &CompileOptions,
460 supported_ops: &[rlx_ir::OpKind],
461 backend_name: &str,
462) -> Graph {
463 let (mut graph, report) = rlx_opt::prepare_graph_for_backend_with_report(
464 graph,
465 backend_name,
466 supported_ops,
467 options.kernel_dispatch,
468 );
469 rlx_opt::maybe_log_dispatch_report(&report);
470 if !report.compile_ready {
471 panic!(
472 "{}\n{}",
473 rlx_opt::format_legalize_error(backend_name, &report.still_unsupported),
474 rlx_opt::format_dispatch_report(&report)
475 );
476 }
477 graph = crate::precompile::post_fusion_cleanup(graph, options);
478 if let Some(p) = options.policy.clone() {
479 use rlx_opt::pass::Pass as _;
480 graph = rlx_opt::AutoMixedPrecision::new(p).run(graph);
481 }
482 graph
483}
484
485#[allow(dead_code)]
486fn declared_output_dtypes(
487 manifest: &cpu_low_precision::IoDtypeManifest,
488 exec_dtypes: Vec<rlx_ir::DType>,
489) -> Vec<rlx_ir::DType> {
490 exec_dtypes
491 .into_iter()
492 .enumerate()
493 .map(|(i, exec)| manifest.output_dtype(i, exec))
494 .collect()
495}
496
497pub fn compile(backend: &dyn Backend, graph: Graph) -> Box<dyn ExecutableGraph> {
505 backend.compile(graph, &CompileOptions::default())
506}
507
508pub fn compile_hir(
510 backend: &dyn Backend,
511 hir: HirModule,
512 device: rlx_driver::Device,
513 options: &CompileOptions,
514) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
515 backend.compile_hir(hir, device, options)
516}
517
518pub fn compile_module(
520 backend: &dyn Backend,
521 module: rlx_ir::GraphModule,
522 device: rlx_driver::Device,
523 options: &CompileOptions,
524) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
525 backend.compile_module(module, device, options)
526}
527
528pub fn compile_with_precision(
530 backend: &dyn Backend,
531 graph: Graph,
532 precision: crate::Precision,
533) -> Box<dyn ExecutableGraph> {
534 backend.compile(graph, &CompileOptions::new().precision(precision))
535}
536
537fn _legacy_apply_policy(graph: Graph, policy: Option<rlx_opt::PrecisionPolicy>) -> Graph {
542 use rlx_opt::pass::Pass as _;
543 match policy {
544 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
545 None => graph,
546 }
547}
548
549#[cfg(feature = "cpu")]
552pub mod cpu_backend {
553 use super::*;
554 use rlx_cpu::{arena::Arena, thunk};
555 use rlx_ir::{DType, NodeId, Op};
556 use rlx_opt::memory::{self, MemoryPlan};
557 use rlx_driver::arena::{read_typed_to_f32, write_typed_from_f32};
560
561 pub struct CpuBackend;
562
563 const CPU_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
570 use rlx_ir::OpKind::*;
571 &[
572 Input,
573 Param,
574 Constant,
575 Activation,
576 Cast,
577 StopGradient,
578 Binary,
579 Compare,
580 Where,
581 ElementwiseRegion,
582 MatMul,
583 DotGeneral,
584 DenseSolve,
585 BatchedDenseSolve,
586 Scan,
587 ScanBackward,
588 ScanBackwardXs,
589 LayerNorm,
590 LayerNorm2d,
591 GroupNorm,
592 BatchNormInference,
593 RmsNorm,
594 ResizeNearest2x,
595 AxialRope2d,
596 Attention,
597 Rope,
598 Reshape,
599 Transpose,
600 Narrow,
601 Concat,
602 Expand,
603 Gather,
604 Reduce,
605 Softmax,
606 Cumsum,
607 TopK,
608 Sample,
609 Conv,
610 Im2Col,
611 ConvTranspose2d,
612 Pool,
613 GroupedMatMul,
614 DequantGroupedMatMul,
615 DequantMoEWeights,
616 ScatterAdd,
617 LoraMatMul,
618 DequantMatMul,
619 SelectiveScan,
620 GatedDeltaNet,
621 FusedSwiGLU,
622 FusedMatMulBiasAct,
623 FusedResidualLN,
624 FusedResidualRmsNorm,
625 FusedAttentionBlock,
626 ReluBackward,
631 ActivationBackward,
632 FakeQuantize,
633 FakeQuantizeBackward,
634 MaxPool2dBackward,
635 Conv2dBackwardInput,
636 Conv2dBackwardWeight,
637 SoftmaxCrossEntropyWithLogits,
638 SoftmaxCrossEntropyBackward,
639 AttentionBackward,
640 LayerNormBackwardInput,
641 LayerNormBackwardGamma,
642 BatchNormInferenceBackwardInput,
643 BatchNormInferenceBackwardGamma,
644 BatchNormInferenceBackwardBeta,
645 RmsNormBackwardInput,
646 RmsNormBackwardGamma,
647 RmsNormBackwardBeta,
648 RopeBackward,
649 CumsumBackward,
650 GatherBackward,
651 GaussianSplatRender,
653 GaussianSplatRenderBackward,
654 GaussianSplatPrepare,
655 GaussianSplatRasterize,
656 Custom,
660 CustomFn,
664 Fft,
668 FftButterflyStage,
669 LogMel,
670 LogMelBackward,
671 WelchPeaks,
672 ComplexNormSq,
677 ComplexNormSqBackward,
678 Conjugate,
679 ]
680 };
681
682 impl Backend for CpuBackend {
683 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
684 CPU_SUPPORTED_OPS
685 }
686
687 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
688 use rlx_opt::pass::Pass as _;
689 let graph = rlx_opt::LowerControlFlow.run(graph);
695 if let Err(errors) = rlx_opt::legalize_for_backend(&graph, CPU_SUPPORTED_OPS) {
699 panic!("{}", rlx_opt::format_legalize_error("cpu", &errors));
700 }
701 let policy = options.policy.clone();
702 let _precision = options.precision;
703 let cfg = rlx_cpu::config::RuntimeConfig::global();
704
705 let graph = crate::precompile::precompile_cleanup(graph, options);
706
707 let mut compile_opts = options.clone();
709 compile_opts.arena_alignment = cfg.arena_alignment;
710 let compile_result = crate::stages::compile_graph_stages_for_backend(
711 rlx_driver::Device::Cpu,
712 graph,
713 &compile_opts,
714 CPU_SUPPORTED_OPS,
715 );
716 crate::stages::maybe_log_fusion(&compile_result.fusion);
717 let fused = compile_result.lir.into_graph();
718
719 let fused = match policy {
722 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(fused),
723 None => fused,
724 };
725
726 let io_manifest = cpu_low_precision::IoDtypeManifest::from_graph(&fused);
727 let exec_graph = if cpu_low_precision::needs_f32_exec(&fused) {
728 cpu_low_precision::promote_to_f32(fused)
729 } else {
730 fused
731 };
732
733 let plan = memory::plan_memory_aligned(&exec_graph, cfg.arena_alignment);
735 if cfg.verbose >= 1 {
736 eprintln!(
737 "[rlx] arena: {} bytes, {} buffers, alignment: {}",
738 plan.arena_size,
739 plan.assignments.len(),
740 cfg.arena_alignment
741 );
742 }
743 Box::new(build_cpu_executable(exec_graph, plan, io_manifest))
744 }
745
746 fn compile_lir(
747 &self,
748 lir: LirModule,
749 options: &CompileOptions,
750 ) -> Box<dyn ExecutableGraph> {
751 let alignment = lir.buffers.alignment.max(options.arena_alignment);
752 let mut graph = lir.into_graph();
753 {
754 use rlx_opt::pass::Pass as _;
755 graph = rlx_opt::LegalizeBroadcast.run(graph);
756 }
757 if let Some(p) = options.policy.clone() {
758 use rlx_opt::pass::Pass;
759 graph = rlx_opt::AutoMixedPrecision::new(p).run(graph);
760 }
761 let io_manifest = cpu_low_precision::IoDtypeManifest::from_graph(&graph);
762 let promote = cpu_low_precision::needs_f32_exec(&graph);
763 let exec_graph = if promote {
764 cpu_low_precision::promote_to_f32(graph)
765 } else {
766 graph
767 };
768 let plan = memory::plan_memory_aligned(&exec_graph, alignment);
771 let cfg = rlx_cpu::config::RuntimeConfig::global();
772 if cfg.verbose >= 1 {
773 eprintln!(
774 "[rlx] compile_lir: arena {} bytes ({} buffers, alignment {})",
775 plan.arena_size,
776 plan.assignments.len(),
777 alignment,
778 );
779 }
780 Box::new(build_cpu_executable(exec_graph, plan, io_manifest))
781 }
782 }
783
784 fn build_cpu_executable(
785 graph: Graph,
786 plan: MemoryPlan,
787 io_manifest: cpu_low_precision::IoDtypeManifest,
788 ) -> CpuExecutable {
789 let mut arena = Arena::from_plan(plan);
790 let mut input_ids = HashMap::new();
791 let mut param_ids = HashMap::new();
792 let mut node_dtypes: HashMap<NodeId, DType> = HashMap::new();
793 for node in graph.nodes() {
794 node_dtypes.insert(node.id, node.shape.dtype());
795 match &node.op {
796 Op::Input { name } => {
797 input_ids.insert(name.clone(), node.id);
798 }
799 Op::Param { name } => {
800 param_ids.insert(name.clone(), node.id);
801 }
802 _ => {}
803 }
804 }
805
806 let schedule = thunk::compile_thunks(&graph, &arena);
807
808 let mut input_slots = Vec::new();
809 for node in graph.nodes() {
810 if let Op::Input { name } = &node.op {
811 let off = arena.byte_offset(node.id);
812 let len = node.shape.num_elements().unwrap_or(0);
813 input_slots.push((name.clone(), off, len, node.shape.dtype()));
814 }
815 }
816
817 let output_slots: Vec<(usize, usize)> = graph
818 .outputs
819 .iter()
820 .map(|&id| {
821 let off = arena.byte_offset(id);
822 let len = graph.node(id).shape.num_elements().unwrap_or(0);
823 (off, len)
824 })
825 .collect();
826
827 for node in graph.nodes() {
828 if let Op::Constant { data } = &node.op
829 && arena.has_buffer(node.id)
830 && !data.is_empty()
831 {
832 match node.shape.dtype() {
833 DType::F64 | DType::F16 | DType::BF16 => {
834 let off = arena.byte_offset(node.id);
835 let buf = arena.raw_buf_mut();
836 let n = buf.len().saturating_sub(off).min(data.len());
837 buf[off..off + n].copy_from_slice(&data[..n]);
838 }
839 _ => {
840 let buf = arena.slice_mut(node.id);
841 let n_floats = data.len() / 4;
842 let n = buf.len().min(n_floats);
843 for i in 0..n {
844 let bytes = [
845 data[i * 4],
846 data[i * 4 + 1],
847 data[i * 4 + 2],
848 data[i * 4 + 3],
849 ];
850 buf[i] = f32::from_le_bytes(bytes);
851 }
852 }
853 }
854 }
855 }
856
857 CpuExecutable {
858 graph,
859 arena,
860 params: HashMap::new(),
861 typed_params: HashMap::new(),
862 input_ids,
863 param_ids,
864 node_dtypes,
865 io_manifest,
866 schedule,
867 input_slots,
868 output_slots,
869 handles: HashMap::new(),
870 active_extent: None,
871 moe_resident: None,
872 moe_resident_layers: None,
873 moe_topk_capture: None,
874 }
875 }
876
877 #[derive(Clone)]
878 struct CpuExecutable {
879 graph: Graph,
880 arena: Arena,
881 params: HashMap<String, Vec<f32>>,
882 typed_params: HashMap<String, (Vec<u8>, DType)>,
884 input_ids: HashMap<String, NodeId>,
885 param_ids: HashMap<String, NodeId>,
886 node_dtypes: HashMap<NodeId, DType>,
889 io_manifest: cpu_low_precision::IoDtypeManifest,
891 schedule: thunk::ThunkSchedule,
892 input_slots: Vec<(String, usize, usize, DType)>,
894 output_slots: Vec<(usize, usize)>,
896 handles: HashMap<String, Vec<f32>>,
901 active_extent: Option<(usize, usize)>,
907 moe_resident: Option<std::sync::Arc<[bool]>>,
908 moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
909 moe_topk_capture: Option<std::sync::Arc<rlx_cpu::moe_topk_capture::MoeTopkCapture>>,
910 }
911
912 unsafe impl Send for CpuExecutable {}
913
914 impl CpuExecutable {
915 fn write_input(&mut self, id: NodeId, data: &[f32]) {
917 let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
918 let off = self.arena.byte_offset(id);
919 let buf = self.arena.raw_buf_mut();
920 let elem_size = dtype.size_bytes();
921 let max_elems = (buf.len() - off) / elem_size;
922 unsafe {
923 write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, data, max_elems);
924 }
925 }
926
927 fn read_output(&self, id: NodeId) -> Vec<f32> {
929 let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
930 let off = self.arena.byte_offset(id);
931 let buf = self.arena.raw_buf();
932 let n_elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
933 unsafe { read_typed_to_f32(buf.as_ptr().add(off), dtype, n_elems) }
934 }
935 }
936
937 impl ExecutableGraph for CpuExecutable {
938 fn clone_box(&self) -> Box<dyn ExecutableGraph> {
939 Box::new(self.clone())
940 }
941 fn set_param(&mut self, name: &str, data: &[f32]) {
942 self.params.insert(name.to_string(), data.to_vec());
943 self.typed_params.remove(name);
944 if let Some(&id) = self.param_ids.get(name)
947 && self.arena.has_buffer(id)
948 {
949 let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
950 let off = self.arena.byte_offset(id);
951 let buf = self.arena.raw_buf_mut();
952 let elem_size = dtype.size_bytes();
953 let max_elems = (buf.len() - off) / elem_size;
954 unsafe {
955 write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, data, max_elems);
956 }
957 }
958 }
959
960 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
961 self.restore_arena_baseline();
962 let handle_names: Vec<String> = self.handles.keys().cloned().collect();
965 for name in &handle_names {
966 if let Some(&id) = self.input_ids.get(name)
967 && self.arena.has_buffer(id)
968 {
969 let data = self.handles.get(name).cloned().unwrap_or_default();
970 self.write_input(id, &data);
971 }
972 }
973 for &(name, data) in inputs {
975 if let Some(&id) = self.input_ids.get(name)
976 && self.arena.has_buffer(id)
977 {
978 self.write_input(id, data);
979 }
980 }
981
982 let active_used = if let Some((actual, upper)) = self.active_extent {
987 thunk::execute_thunks_active(
988 &self.schedule,
989 self.arena.raw_buf_mut(),
990 actual,
991 upper,
992 )
993 } else {
994 false
995 };
996 if !active_used {
997 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
999 }
1000
1001 for (idx, &out_id) in self.graph.outputs.iter().enumerate() {
1005 let name = format!("out{idx}");
1006 if self.handles.contains_key(&name) {
1007 let v = self.read_output(out_id);
1008 self.handles.insert(name, v);
1009 }
1010 }
1011
1012 self.graph
1013 .outputs
1014 .iter()
1015 .map(|&out_id| self.read_output(out_id))
1016 .collect()
1017 }
1018
1019 fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
1020 self.restore_arena_baseline();
1021 for &(name, data) in inputs {
1023 if let Some(&id) = self.input_ids.get(name)
1024 && self.arena.has_buffer(id)
1025 {
1026 self.write_input(id, data);
1027 }
1028 }
1029 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1030 self.graph
1034 .outputs
1035 .iter()
1036 .map(|&out_id| {
1037 let (ptr, len) = self.arena.raw_ptr(out_id);
1038 (ptr as *const f32, len)
1039 })
1040 .collect()
1041 }
1042
1043 fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
1047 self.restore_arena_baseline();
1048 let buf = self.arena.raw_buf_mut();
1049 for (i, &data) in inputs.iter().enumerate() {
1050 if i < self.input_slots.len() {
1051 let (_, off, max_len, dtype) = &self.input_slots[i];
1052 unsafe {
1053 write_typed_from_f32(buf.as_mut_ptr().add(*off), *dtype, data, *max_len);
1054 }
1055 }
1056 }
1057 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1058 &self.output_slots
1059 }
1060
1061 fn arena_ptr(&self) -> *const u8 {
1062 self.arena.raw_buf_mut_ptr()
1063 }
1064
1065 fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
1066 self.handles.insert(name.to_string(), data.to_vec());
1071 true
1072 }
1073
1074 fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
1075 self.handles.get(name).cloned()
1076 }
1077
1078 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
1079 self.active_extent = extent;
1080 }
1081
1082 fn set_moe_resident_experts(&mut self, mask: &[bool]) {
1083 self.moe_resident_layers = None;
1084 self.schedule.moe_resident_layers = None;
1085 self.moe_resident = Some(Arc::from(mask));
1086 self.schedule.moe_resident = self.moe_resident.clone();
1087 }
1088
1089 fn set_moe_resident_experts_per_layer(&mut self, masks: &[&[bool]]) {
1090 self.moe_resident = None;
1091 self.schedule.moe_resident = None;
1092 let layers: Vec<Arc<[bool]>> = masks.iter().map(|m| Arc::from(*m)).collect();
1093 let arc = Arc::new(layers);
1094 self.moe_resident_layers = Some(arc.clone());
1095 self.schedule.moe_resident_layers = Some(arc);
1096 }
1097
1098 fn enable_moe_topk_capture(&mut self, num_experts: usize) -> bool {
1099 let cap = rlx_cpu::moe_topk_capture::MoeTopkCapture::new(num_experts);
1100 self.moe_topk_capture = Some(cap.clone());
1101 self.schedule.moe_topk_capture = Some(cap);
1102 true
1103 }
1104
1105 fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
1106 let cap = self.moe_topk_capture.as_ref()?;
1107 let layers = cap.take_layers();
1108 if layers.is_empty() {
1109 None
1110 } else {
1111 Some(layers)
1112 }
1113 }
1114
1115 fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
1116 rlx_cpu::moe_residency::take_last_forward_stats()
1117 }
1118
1119 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1125 if matches!(dtype, DType::F64 | DType::I64 | DType::I32 | DType::U32) {
1126 self.set_param_bytes(name, data, dtype);
1127 return;
1128 }
1129 if matches!(dtype, DType::U8 | DType::I8) {
1133 self.set_param_bytes(name, data, dtype);
1134 return;
1135 }
1136 if dtype == DType::F32 {
1137 let n = data.len() / 4;
1138 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
1139 self.set_param(name, s);
1140 } else {
1141 let f32_buf = super::widen_bytes_to_f32(data, dtype);
1142 self.set_param(name, &f32_buf);
1143 }
1144 }
1145
1146 fn run_typed(
1158 &mut self,
1159 inputs: &[(&str, &[u8], rlx_ir::DType)],
1160 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1161 let all_f64 = !inputs.is_empty() && inputs.iter().all(|(_, _, dt)| *dt == DType::F64);
1166
1167 if all_f64 {
1168 for (name, data, _) in inputs {
1169 if let Some(&id) = self.input_ids.get(*name) {
1170 if !self.arena.has_buffer(id) {
1171 continue;
1172 }
1173 let off = self.arena.byte_offset(id);
1174 let buf = self.arena.raw_buf_mut();
1175 let n = data.len();
1176 debug_assert!(
1177 off + n <= buf.len(),
1178 "run_typed: input '{name}' overflows arena slot"
1179 );
1180 buf[off..off + n].copy_from_slice(data);
1181 }
1182 }
1183 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1184 } else {
1185 let mut f32_owned: Vec<(String, Vec<f32>)> = Vec::new();
1190 for (name, data, dt) in inputs {
1191 let direct = matches!(
1192 *dt,
1193 DType::F64 | DType::I32 | DType::I64 | DType::U32 | DType::C64
1194 );
1195 if direct {
1196 if let Some(&id) = self.input_ids.get(*name) {
1197 if !self.arena.has_buffer(id) {
1198 continue;
1199 }
1200 let off = self.arena.byte_offset(id);
1201 let buf = self.arena.raw_buf_mut();
1202 buf[off..off + data.len()].copy_from_slice(data);
1203 }
1204 } else {
1205 let v = super::widen_bytes_to_f32(data, *dt);
1206 f32_owned.push((name.to_string(), v));
1207 }
1208 }
1209 for (name, data) in &f32_owned {
1210 if let Some(&id) = self.input_ids.get(name.as_str()) {
1211 if self.arena.has_buffer(id) {
1212 self.write_input(id, data);
1213 }
1214 }
1215 }
1216 let active_used = if let Some((actual, upper)) = self.active_extent {
1217 thunk::execute_thunks_active(
1218 &self.schedule,
1219 self.arena.raw_buf_mut(),
1220 actual,
1221 upper,
1222 )
1223 } else {
1224 false
1225 };
1226 if !active_used {
1227 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1228 }
1229 }
1230
1231 self.graph
1233 .outputs
1234 .iter()
1235 .enumerate()
1236 .map(|(idx, &id)| {
1237 let exec_dtype = self.graph.node(id).shape.dtype();
1238 let declared = self.io_manifest.output_dtype(idx, exec_dtype);
1239 if matches!(
1240 exec_dtype,
1241 DType::F64
1242 | DType::F16
1243 | DType::BF16
1244 | DType::I32
1245 | DType::I64
1246 | DType::U32
1247 | DType::C64
1248 ) {
1249 let n_elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
1250 let n_bytes = n_elems * exec_dtype.size_bytes();
1251 let off = self.arena.byte_offset(id);
1252 let bytes = self.arena.raw_buf()[off..off + n_bytes].to_vec();
1253 return (bytes, declared);
1254 }
1255 let f32_vals = self.read_output(id);
1256 if declared != exec_dtype {
1257 return (super::narrow_f32_to_bytes(&f32_vals, declared), declared);
1258 }
1259 let bytes = f32_vals.iter().flat_map(|v| v.to_le_bytes()).collect();
1260 (bytes, declared)
1261 })
1262 .collect()
1263 }
1264 }
1265
1266 impl CpuExecutable {
1267 fn restore_arena_baseline(&mut self) {
1272 self.arena.raw_buf_mut().fill(0);
1273 let constants: Vec<(NodeId, DType, Vec<u8>)> = self
1274 .graph
1275 .nodes()
1276 .iter()
1277 .filter_map(|node| {
1278 if let Op::Constant { data } = &node.op
1279 && self.arena.has_buffer(node.id)
1280 && !data.is_empty()
1281 {
1282 Some((node.id, node.shape.dtype(), data.clone()))
1283 } else {
1284 None
1285 }
1286 })
1287 .collect();
1288 for (id, dtype, data) in constants {
1289 self.write_constant_to_arena(id, dtype, &data);
1290 }
1291 let params = self.params.clone();
1292 for (name, data) in params {
1293 if let Some(&id) = self.param_ids.get(&name)
1294 && self.arena.has_buffer(id)
1295 {
1296 let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
1297 let off = self.arena.byte_offset(id);
1298 let buf = self.arena.raw_buf_mut();
1299 let elem_size = dtype.size_bytes();
1300 let max_elems = (buf.len() - off) / elem_size;
1301 unsafe {
1302 write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, &data, max_elems);
1303 }
1304 }
1305 }
1306 let typed = self.typed_params.clone();
1307 for (name, (data, dtype)) in typed {
1308 self.write_param_bytes_to_arena(&name, &data);
1309 let _ = dtype;
1310 }
1311 }
1312
1313 fn write_constant_to_arena(&mut self, id: NodeId, dtype: DType, data: &[u8]) {
1314 match dtype {
1315 DType::F64 | DType::F16 | DType::BF16 | DType::U8 | DType::I8 => {
1316 let off = self.arena.byte_offset(id);
1317 let buf = self.arena.raw_buf_mut();
1318 let n = buf.len().saturating_sub(off).min(data.len());
1319 buf[off..off + n].copy_from_slice(&data[..n]);
1320 }
1321 _ => {
1322 let buf = self.arena.slice_mut(id);
1323 let n_floats = data.len() / 4;
1324 let n = buf.len().min(n_floats);
1325 for i in 0..n {
1326 let bytes = [
1327 data[i * 4],
1328 data[i * 4 + 1],
1329 data[i * 4 + 2],
1330 data[i * 4 + 3],
1331 ];
1332 buf[i] = f32::from_le_bytes(bytes);
1333 }
1334 }
1335 }
1336 }
1337
1338 fn set_param_bytes(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1344 self.typed_params
1345 .insert(name.to_string(), (data.to_vec(), dtype));
1346 self.params.remove(name);
1347 self.write_param_bytes_to_arena(name, data);
1348 }
1349
1350 fn write_param_bytes_to_arena(&mut self, name: &str, data: &[u8]) {
1351 if let Some(&id) = self.param_ids.get(name)
1352 && self.arena.has_buffer(id)
1353 {
1354 let off = self.arena.byte_offset(id);
1355 let buf = self.arena.raw_buf_mut();
1356 debug_assert!(
1357 off + data.len() <= buf.len(),
1358 "set_param_bytes: '{name}' would overflow arena slot"
1359 );
1360 buf[off..off + data.len()].copy_from_slice(data);
1361 }
1362 }
1363 }
1364}
1365
1366#[cfg(feature = "gpu")]
1371pub mod wgpu_backend {
1372 use super::*;
1373 use rlx_ir::OpKind;
1374 use rlx_wgpu::backend::WgpuExecutable;
1375
1376 pub struct WgpuBackend;
1377
/// Op kinds the wgpu backend advertises.
///
/// Returned by `WgpuBackend::supported_ops` and passed to legalization,
/// stage compilation and `prepare_fused_graph` in this module.
const WGPU_SUPPORTED_OPS: &[OpKind] = &[
    OpKind::Input,
    OpKind::Param,
    OpKind::Constant,
    OpKind::Activation,
    OpKind::Cast,
    OpKind::StopGradient,
    OpKind::Binary,
    OpKind::Compare,
    OpKind::Where,
    OpKind::ElementwiseRegion,
    OpKind::TransformRegion,
    OpKind::BatchElementwiseRegion,
    OpKind::MatMul,
    OpKind::DotGeneral,
    OpKind::LayerNorm,
    OpKind::RmsNorm,
    OpKind::Attention,
    OpKind::AttentionBackward,
    OpKind::RmsNormBackwardInput,
    OpKind::RmsNormBackwardGamma,
    OpKind::RmsNormBackwardBeta,
    OpKind::LayerNormBackwardInput,
    OpKind::LayerNormBackwardGamma,
    OpKind::RopeBackward,
    OpKind::CumsumBackward,
    OpKind::GatherBackward,
    OpKind::Rope,
    OpKind::Reshape,
    OpKind::Transpose,
    OpKind::Narrow,
    OpKind::Concat,
    OpKind::Expand,
    OpKind::Gather,
    OpKind::Reduce,
    OpKind::Softmax,
    OpKind::Cumsum,
    OpKind::TopK,
    OpKind::Sample,
    OpKind::Conv,
    OpKind::Im2Col,
    OpKind::Pool,
    OpKind::GroupedMatMul,
    OpKind::DequantGroupedMatMul,
    OpKind::DequantMoEWeights,
    OpKind::ScatterAdd,
    OpKind::SelectiveScan,
    OpKind::DequantMatMul,
    OpKind::FusedMatMulBiasAct,
    OpKind::FusedResidualLN,
    OpKind::FusedResidualRmsNorm,
    OpKind::FusedSwiGLU,
    OpKind::FusedAttentionBlock,
    OpKind::FusedTransformerLayer,
    OpKind::Fft,
    OpKind::LogMel,
    OpKind::LogMelBackward,
    OpKind::WelchPeaks,
    OpKind::GaussianSplatRender,
    OpKind::GaussianSplatRenderBackward,
    OpKind::GaussianSplatPrepare,
    OpKind::GaussianSplatRasterize,
    OpKind::Custom,
];
1460
1461 impl Backend for WgpuBackend {
1462 fn supported_ops(&self) -> &'static [OpKind] {
1463 WGPU_SUPPORTED_OPS
1464 }
1465
1466 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
1467 use rlx_opt::pass::Pass as _;
1468 let graph = rlx_opt::LowerControlFlow.run(graph);
1469 let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, WGPU_SUPPORTED_OPS)
1470 .unwrap_or_else(|errors| {
1471 panic!("{}", rlx_opt::format_legalize_error("wgpu", &errors));
1472 });
1473 let graph = crate::precompile::precompile_cleanup(graph, options);
1474 let graph = rlx_opt::LegalizeBroadcast.run(graph);
1478 let compile_result = crate::stages::compile_graph_stages_for_backend(
1487 rlx_driver::Device::Gpu,
1488 graph,
1489 options,
1490 WGPU_SUPPORTED_OPS,
1491 );
1492 crate::stages::maybe_log_fusion(&compile_result.fusion);
1493 let graph = compile_result.lir.into_graph();
1494 let graph = match options.policy.clone() {
1495 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
1496 None => graph,
1497 };
1498 let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
1499 Box::new(WgpuExecutableWrapper {
1500 inner: WgpuExecutable::compile(graph),
1501 io_manifest,
1502 })
1503 }
1504
1505 fn compile_lir(
1506 &self,
1507 lir: LirModule,
1508 options: &CompileOptions,
1509 ) -> Box<dyn ExecutableGraph> {
1510 use rlx_opt::pass::Pass as _;
1511 let graph = rlx_opt::LegalizeBroadcast.run(lir.into_graph());
1514 let graph = prepare_fused_graph(graph, options, WGPU_SUPPORTED_OPS, "wgpu");
1515 let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
1516 Box::new(WgpuExecutableWrapper {
1517 inner: WgpuExecutable::compile(graph),
1518 io_manifest,
1519 })
1520 }
1521 }
1522
/// Pairs a compiled `WgpuExecutable` with the I/O dtype manifest returned by
/// `prepare_f32_exec_graph` at compile time.
struct WgpuExecutableWrapper {
    /// Executable that every `ExecutableGraph` call is forwarded to.
    inner: WgpuExecutable,
    /// Manifest passed to `declared_output_dtypes` when `run_typed` picks the
    /// dtype of each output.
    io_manifest: cpu_low_precision::IoDtypeManifest,
}

// NOTE(review): the invariant that makes this `Send` impl sound is not visible
// in this part of the file — TODO confirm and record it as a `SAFETY:` note.
unsafe impl Send for WgpuExecutableWrapper {}
1529
1530 impl ExecutableGraph for WgpuExecutableWrapper {
1531 fn set_param(&mut self, name: &str, data: &[f32]) {
1532 self.inner.set_param(name, data);
1533 }
1534 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
1535 self.inner.run(inputs)
1536 }
1537 fn run_read_outputs(
1538 &mut self,
1539 inputs: &[(&str, &[f32])],
1540 read_indices: Option<&[usize]>,
1541 ) -> Vec<Vec<f32>> {
1542 self.inner.run_read_outputs(inputs, read_indices)
1543 }
1544 fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
1545 self.inner.bind_gpu_handle(name, data)
1546 }
1547 fn has_gpu_handle(&self, name: &str) -> bool {
1548 self.inner.has_gpu_handle(name)
1549 }
1550 fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
1551 self.inner.set_gpu_handle_feed(handle_name, output_index);
1552 true
1553 }
1554 fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
1555 self.inner.read_gpu_handle(name)
1556 }
1557 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
1558 self.inner.set_active_extent(extent);
1559 }
1560
1561 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1564 match dtype {
1565 rlx_ir::DType::U8 | rlx_ir::DType::I8 => {
1566 self.inner.set_param_bytes(name, data);
1567 }
1568 rlx_ir::DType::F32 => {
1569 let n = data.len() / 4;
1570 let f32_slice =
1571 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
1572 self.inner.set_param(name, f32_slice);
1573 }
1574 rlx_ir::DType::F16 => {
1575 let n = data.len() / 2;
1576 let f16_slice =
1577 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n) };
1578 let f32: Vec<f32> = f16_slice.iter().map(|h| h.to_f32()).collect();
1579 self.inner.set_param(name, &f32);
1580 }
1581 rlx_ir::DType::BF16 => {
1582 let n = data.len() / 2;
1583 let bf16_slice = unsafe {
1584 std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n)
1585 };
1586 let f32: Vec<f32> = bf16_slice.iter().map(|h| h.to_f32()).collect();
1587 self.inner.set_param(name, &f32);
1588 }
1589 other => panic!(
1590 "rlx-wgpu set_param_typed: dtype {other:?} unsupported \
1591 (F32, F16, BF16 only — wgpu arena is f32-uniform)"
1592 ),
1593 }
1594 }
1595
1596 fn run_typed(
1599 &mut self,
1600 inputs: &[(&str, &[u8], rlx_ir::DType)],
1601 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1602 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
1603 for (name, data, dt) in inputs {
1604 let v: Vec<f32> = match *dt {
1605 rlx_ir::DType::F32 => {
1606 let n = data.len() / 4;
1607 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) }
1608 .to_vec()
1609 }
1610 rlx_ir::DType::F16 => {
1611 let n = data.len() / 2;
1612 let s = unsafe {
1613 std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n)
1614 };
1615 s.iter().map(|h| h.to_f32()).collect()
1616 }
1617 rlx_ir::DType::BF16 => {
1618 let n = data.len() / 2;
1619 let s = unsafe {
1620 std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n)
1621 };
1622 s.iter().map(|h| h.to_f32()).collect()
1623 }
1624 other => {
1625 panic!("rlx-wgpu run_typed: input '{name}' dtype {other:?} unsupported")
1626 }
1627 };
1628 owned.push((name.to_string(), v));
1629 }
1630 let refs: Vec<(&str, &[f32])> = owned
1631 .iter()
1632 .map(|(n, d)| (n.as_str(), d.as_slice()))
1633 .collect();
1634 let dtypes =
1635 super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
1636 let outs = self.inner.run(&refs);
1637 outs.into_iter()
1638 .zip(
1639 dtypes
1640 .into_iter()
1641 .chain(std::iter::repeat(rlx_ir::DType::F32)),
1642 )
1643 .map(|(v, dt)| (narrow_to_dtype(&v, dt), dt))
1644 .collect()
1645 }
1646 }
1647
1648 fn narrow_to_dtype(v: &[f32], dt: rlx_ir::DType) -> Vec<u8> {
1654 use rlx_ir::DType;
1655 match dt {
1656 DType::F32 => {
1657 let mut bytes = Vec::with_capacity(v.len() * 4);
1658 for &x in v {
1659 bytes.extend_from_slice(&x.to_le_bytes());
1660 }
1661 bytes
1662 }
1663 DType::F16 => {
1664 let mut bytes = Vec::with_capacity(v.len() * 2);
1665 for &x in v {
1666 bytes.extend_from_slice(&half::f16::from_f32(x).to_le_bytes());
1667 }
1668 bytes
1669 }
1670 DType::BF16 => {
1671 let mut bytes = Vec::with_capacity(v.len() * 2);
1672 for &x in v {
1673 bytes.extend_from_slice(&half::bf16::from_f32(x).to_le_bytes());
1674 }
1675 bytes
1676 }
1677 DType::F64 => {
1678 let mut bytes = Vec::with_capacity(v.len() * 8);
1679 for &x in v {
1680 bytes.extend_from_slice(&(x as f64).to_le_bytes());
1681 }
1682 bytes
1683 }
1684 DType::I8 => v.iter().map(|&x| x as i8 as u8).collect(),
1685 DType::U8 => v.iter().map(|&x| x as u8).collect(),
1686 DType::I16 => {
1687 let mut bytes = Vec::with_capacity(v.len() * 2);
1688 for &x in v {
1689 bytes.extend_from_slice(&(x as i16).to_le_bytes());
1690 }
1691 bytes
1692 }
1693 DType::I32 => {
1694 let mut bytes = Vec::with_capacity(v.len() * 4);
1695 for &x in v {
1696 bytes.extend_from_slice(&(x as i32).to_le_bytes());
1697 }
1698 bytes
1699 }
1700 DType::U32 => {
1701 let mut bytes = Vec::with_capacity(v.len() * 4);
1702 for &x in v {
1703 bytes.extend_from_slice(&(x as u32).to_le_bytes());
1704 }
1705 bytes
1706 }
1707 DType::I64 => {
1708 let mut bytes = Vec::with_capacity(v.len() * 8);
1709 for &x in v {
1710 bytes.extend_from_slice(&(x as i64).to_le_bytes());
1711 }
1712 bytes
1713 }
1714 DType::Bool => v
1715 .iter()
1716 .map(|&x| if x != 0.0 { 1u8 } else { 0u8 })
1717 .collect(),
1718 DType::C64 => {
1725 let mut bytes = Vec::with_capacity(v.len() * 4);
1726 for &x in v {
1727 bytes.extend_from_slice(&x.to_le_bytes());
1728 }
1729 bytes
1730 }
1731 }
1732 }
1733}
1734
1735#[cfg(all(feature = "mlx", rlx_mlx_host))]
1738pub mod mlx_backend {
1739 use super::*;
1740 use rlx_mlx::MlxExecutable;
1741
/// Zero-sized backend handle; its `Backend` impl in this module compiles
/// graphs into `MlxExecutable`-backed executables.
pub struct MlxBackend;
1743
/// Op kinds the MLX backend advertises.
///
/// Returned by `MlxBackend::supported_ops` and passed to stage compilation
/// and `prepare_fused_graph` in this module.
const MLX_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
    use rlx_ir::OpKind::*;
    &[
        Input,
        Param,
        Constant,
        Activation,
        Cast,
        StopGradient,
        Binary,
        Compare,
        Where,
        ElementwiseRegion,
        TransformRegion,
        BatchElementwiseRegion,
        MatMul,
        DotGeneral,
        DenseSolve,
        BatchedDenseSolve,
        LayerNorm,
        LayerNorm2d,
        ResizeNearest2x,
        RmsNorm,
        Attention,
        Rope,
        Reshape,
        Transpose,
        Narrow,
        Concat,
        Expand,
        Gather,
        Reduce,
        Softmax,
        Cumsum,
        TopK,
        Sample,
        Conv,
        ConvTranspose2d,
        Pool,
        GroupedMatMul,
        DequantGroupedMatMul,
        DequantMoEWeights,
        ScatterAdd,
        LoraMatMul,
        DequantMatMul,
        SelectiveScan,
        GatedDeltaNet,
        FusedSwiGLU,
        FusedMatMulBiasAct,
        FusedResidualLN,
        FusedResidualRmsNorm,
        FusedAttentionBlock,
        FusedTransformerLayer,
        If,
        While,
        Scan,
        ScanBackward,
        ScanBackwardXs,
        ReluBackward,
        ActivationBackward,
        SoftmaxCrossEntropyWithLogits,
        SoftmaxCrossEntropyBackward,
        AttentionBackward,
        LayerNormBackwardInput,
        LayerNormBackwardGamma,
        Conv2dBackwardInput,
        Conv2dBackwardWeight,
        MaxPool2dBackward,
        FakeQuantize,
        FakeQuantizeBackward,
        Custom,
        Fft,
        LogMel,
        LogMelBackward,
        WelchPeaks,
        GaussianSplatRender,
        GaussianSplatRenderBackward,
    ]
};
1855
1856 impl Backend for MlxBackend {
1857 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
1858 MLX_SUPPORTED_OPS
1859 }
1860
1861 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
1862 let compile_result = crate::stages::compile_graph_stages_for_backend(
1863 rlx_driver::Device::Mlx,
1864 graph,
1865 options,
1866 MLX_SUPPORTED_OPS,
1867 );
1868 crate::stages::maybe_log_fusion(&compile_result.fusion);
1869 self.compile_lir(compile_result.lir, options)
1870 }
1871
1872 fn compile_lir(
1873 &self,
1874 lir: LirModule,
1875 options: &CompileOptions,
1876 ) -> Box<dyn ExecutableGraph> {
1877 use rlx_opt::pass::Pass as _;
1878 let mut graph = lir.into_graph();
1879 graph = rlx_opt::LowerControlFlow.run(graph);
1880 let graph = prepare_fused_graph(graph, options, MLX_SUPPORTED_OPS, "mlx");
1881 Box::new(build_mlx_executable(graph))
1882 }
1883 }
1884
1885 fn build_mlx_executable(graph: Graph) -> MlxExecutableWrapper {
1886 let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
1887 let mode = mlx_mode_from_env();
1888 let mut exe = MlxExecutable::compile_from_fused(graph, mode);
1889 if mode == rlx_mlx::lower::MlxMode::Compiled {
1890 if let Err(e) = exe.warm_compile() {
1891 eprintln!(
1892 "[rlx-runtime] MLX warm_compile failed ({e}); first run will pay the trace cost"
1893 );
1894 }
1895 }
1896 MlxExecutableWrapper {
1897 inner: exe,
1898 io_manifest,
1899 }
1900 }
1901
1902 fn mlx_mode_from_env() -> rlx_mlx::lower::MlxMode {
1903 match rlx_ir::env::var("RLX_MLX_MODE").as_deref() {
1904 Some(s) if s.eq_ignore_ascii_case("eager") => rlx_mlx::lower::MlxMode::Eager,
1905 Some(s) if s.eq_ignore_ascii_case("lazy") => rlx_mlx::lower::MlxMode::Lazy,
1906 Some(s) if s.eq_ignore_ascii_case("compiled") => rlx_mlx::lower::MlxMode::Compiled,
1907 _ => rlx_mlx::lower::MlxMode::Compiled,
1908 }
1909 }
1910
/// Pairs a compiled `MlxExecutable` with the I/O dtype manifest returned by
/// `prepare_f32_exec_graph` at compile time.
struct MlxExecutableWrapper {
    /// Executable that every `ExecutableGraph` call is forwarded to.
    inner: MlxExecutable,
    /// Manifest passed to `declared_output_dtypes` when `run_typed` picks the
    /// dtype of each output.
    io_manifest: cpu_low_precision::IoDtypeManifest,
}

// NOTE(review): the invariant that makes this `Send` impl sound is not visible
// in this part of the file — TODO confirm and record it as a `SAFETY:` note.
unsafe impl Send for MlxExecutableWrapper {}
1917
1918 impl ExecutableGraph for MlxExecutableWrapper {
1919 fn set_param(&mut self, name: &str, data: &[f32]) {
1920 self.inner.set_param(name, data);
1921 }
1922 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
1923 self.inner.run(inputs)
1924 }
1925 fn run_read_outputs(
1926 &mut self,
1927 inputs: &[(&str, &[f32])],
1928 read_indices: Option<&[usize]>,
1929 ) -> Vec<Vec<f32>> {
1930 self.inner
1931 .run_read_outputs(inputs, read_indices)
1932 .unwrap_or_else(|e| panic!("MLX run_read_outputs failed: {e}"))
1933 }
1934 fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
1935 self.inner.run_slots(inputs)
1936 }
1937 fn arena_ptr(&self) -> *const u8 {
1938 self.inner.arena_ptr()
1939 }
1940 fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
1941 self.inner.commit_no_wait(inputs);
1942 }
1943 fn sync_pending(&mut self) {
1944 self.inner.sync_pending();
1945 }
1946 fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
1947 self.inner.run_pipelined(input_sets)
1948 }
1949 fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
1950 self.inner.bind_handle(name, data)
1951 }
1952 fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
1953 self.inner.read_handle(name)
1954 }
1955 fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
1956 self.inner.bind_gpu_handle(name, data).is_ok()
1957 }
1958 fn has_gpu_handle(&self, name: &str) -> bool {
1959 self.inner.has_gpu_handle(name)
1960 }
1961 fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
1962 self.inner.set_gpu_handle_feed(handle_name, output_index);
1963 true
1964 }
1965 fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
1966 self.inner.read_gpu_handle(name).ok()
1967 }
1968 fn run_feed_gpu_handle(
1969 &mut self,
1970 inputs: &[(&str, &[f32])],
1971 handle_name: &str,
1972 output_index: usize,
1973 ) -> Option<Vec<f32>> {
1974 self.inner
1975 .run_feed_gpu(inputs, handle_name, output_index)
1976 .ok()
1977 }
1978 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1979 self.inner.set_param_typed(name, data, dtype);
1980 }
1981 fn run_typed(
1982 &mut self,
1983 inputs: &[(&str, &[u8], rlx_ir::DType)],
1984 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1985 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
1986 for (name, data, dt) in inputs {
1987 let v = super::widen_bytes_to_f32(data, *dt);
1988 owned.push((name.to_string(), v));
1989 }
1990 let refs: Vec<(&str, &[f32])> = owned
1991 .iter()
1992 .map(|(n, d)| (n.as_str(), d.as_slice()))
1993 .collect();
1994 let f32_outs = self.inner.run(&refs);
1995 let declared = super::declared_output_dtypes(
1996 &self.io_manifest,
1997 (0..f32_outs.len()).map(|_| rlx_ir::DType::F32).collect(),
1998 );
1999 f32_outs
2000 .into_iter()
2001 .zip(
2002 declared
2003 .into_iter()
2004 .chain(std::iter::repeat(rlx_ir::DType::F32)),
2005 )
2006 .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2007 .collect()
2008 }
2009 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2010 self.inner.set_active_extent(extent);
2011 }
2012 }
2013}
2014
2015#[cfg(all(feature = "metal", target_os = "macos"))]
2016pub mod metal_backend {
2017 use super::*;
2018 use rlx_metal::backend::MetalExecutable;
2019
/// Zero-sized backend handle; its `Backend` impl in this module compiles
/// graphs into `MetalExecutable`-backed executables.
pub struct MetalBackend;
2021
/// Op kinds the Metal backend advertises.
///
/// Returned by `MetalBackend::supported_ops`, passed to legalization, and
/// handed to `MetalExecutable` at compile time in this module.
const METAL_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
    use rlx_ir::OpKind::*;
    &[
        Input,
        Param,
        Constant,
        Activation,
        Cast,
        StopGradient,
        Binary,
        Compare,
        Where,
        ElementwiseRegion,
        TransformRegion,
        BatchElementwiseRegion,
        MatMul,
        DotGeneral,
        LayerNorm,
        LayerNorm2d,
        GroupNorm,
        RmsNorm,
        ResizeNearest2x,
        AxialRope2d,
        Attention,
        AttentionBackward,
        RmsNormBackwardInput,
        RmsNormBackwardGamma,
        RmsNormBackwardBeta,
        RopeBackward,
        CumsumBackward,
        GatherBackward,
        Conv2dBackwardInput,
        Conv2dBackwardWeight,
        MaxPool2dBackward,
        Rope,
        Reshape,
        Transpose,
        Narrow,
        Concat,
        Expand,
        Gather,
        Reduce,
        Softmax,
        TopK,
        Conv,
        Im2Col,
        ConvTranspose2d,
        Pool,
        GroupedMatMul,
        DequantGroupedMatMul,
        DequantMoEWeights,
        ScatterAdd,
        DequantMatMul,
        GatedDeltaNet,
        FusedSwiGLU,
        FusedMatMulBiasAct,
        FusedResidualLN,
        FusedResidualRmsNorm,
        Custom,
        Fft,
        LogMel,
        LogMelBackward,
        WelchPeaks,
        GaussianSplatRender,
        GaussianSplatRenderBackward,
        GaussianSplatPrepare,
        GaussianSplatRasterize,
    ]
};
2111
2112 impl Backend for MetalBackend {
2113 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2114 METAL_SUPPORTED_OPS
2115 }
2116
2117 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2118 use rlx_opt::pass::Pass as _;
2119 let graph = rlx_opt::LowerControlFlow.run(graph);
2123 let dispatch = options.kernel_dispatch;
2124 let graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
2125 graph,
2126 METAL_SUPPORTED_OPS,
2127 dispatch,
2128 )
2129 .unwrap_or_else(|errors| {
2130 panic!("{}", rlx_opt::format_legalize_error("metal", &errors));
2131 });
2132 let graph = crate::precompile::precompile_cleanup(graph, options);
2133
2134 let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2137 Box::new(MetalExecutableWrapper {
2138 inner: MetalExecutable::compile_with_policy(
2139 graph,
2140 options.policy.clone(),
2141 Some(METAL_SUPPORTED_OPS),
2142 ),
2143 io_manifest,
2144 })
2145 }
2146
2147 fn compile_lir(
2148 &self,
2149 lir: LirModule,
2150 options: &CompileOptions,
2151 ) -> Box<dyn ExecutableGraph> {
2152 use rlx_opt::pass::Pass as _;
2153 let mut graph = lir.into_graph();
2154 graph = rlx_opt::LowerControlFlow.run(graph);
2155 let dispatch = options.kernel_dispatch;
2156 let mut graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
2157 graph,
2158 METAL_SUPPORTED_OPS,
2159 dispatch,
2160 )
2161 .unwrap_or_else(|errors| {
2162 panic!("{}", rlx_opt::format_legalize_error("metal", &errors));
2163 });
2164 graph = crate::precompile::precompile_cleanup(graph, options);
2165 let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2166 Box::new(MetalExecutableWrapper {
2167 inner: MetalExecutable::compile_from_fused(
2168 graph,
2169 options.policy.clone(),
2170 Some(METAL_SUPPORTED_OPS),
2171 ),
2172 io_manifest,
2173 })
2174 }
2175 }
2176
/// Pairs a compiled `MetalExecutable` with the I/O dtype manifest returned by
/// `prepare_f32_exec_graph` at compile time.
struct MetalExecutableWrapper {
    /// Executable that every `ExecutableGraph` call is forwarded to.
    inner: MetalExecutable,
    /// Manifest passed to `declared_output_dtypes` when `run_typed` picks the
    /// dtype of each output.
    io_manifest: cpu_low_precision::IoDtypeManifest,
}

// NOTE(review): the invariant that makes this `Send` impl sound is not visible
// in this part of the file — TODO confirm and record it as a `SAFETY:` note.
unsafe impl Send for MetalExecutableWrapper {}
2183
2184 impl ExecutableGraph for MetalExecutableWrapper {
2185 fn set_param(&mut self, name: &str, data: &[f32]) {
2186 self.inner.set_param(name, data);
2187 }
2188 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2189 self.inner.run(inputs)
2190 }
2191 fn run_read_outputs(
2192 &mut self,
2193 inputs: &[(&str, &[f32])],
2194 read_indices: Option<&[usize]>,
2195 ) -> Vec<Vec<f32>> {
2196 self.inner.run_read_outputs(inputs, read_indices)
2197 }
2198 fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
2199 self.inner.bind_gpu_handle(name, data)
2200 }
2201 fn has_gpu_handle(&self, name: &str) -> bool {
2202 self.inner.has_gpu_handle(name)
2203 }
2204 fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
2205 self.inner.set_gpu_handle_feed(handle_name, output_index);
2206 true
2207 }
2208 fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
2209 self.inner.read_gpu_handle(name)
2210 }
2211 fn read_output_row(
2212 &self,
2213 out_idx: usize,
2214 row: usize,
2215 row_inner: usize,
2216 ) -> Option<Vec<f32>> {
2217 Some(self.inner.read_graph_output_row(out_idx, row, row_inner))
2218 }
2219 fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
2220 self.inner.run_slots(inputs)
2221 }
2222 fn arena_ptr(&self) -> *const u8 {
2223 self.inner.arena_ptr()
2224 }
2225 fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
2226 self.inner.commit_no_wait(inputs);
2227 }
2228 fn sync_pending(&mut self) {
2229 self.inner.sync_pending();
2230 }
2231 fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
2232 self.inner.run_pipelined(input_sets)
2233 }
2234 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2235 self.inner.set_active_extent(extent);
2236 }
2237
2238 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2244 if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
2245 self.inner.set_param_bytes(name, data);
2246 return;
2247 }
2248 if dtype == rlx_ir::DType::F32 {
2249 let n = data.len() / 4;
2250 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2251 self.inner.set_param(name, s);
2252 } else {
2253 let f32_buf = super::widen_bytes_to_f32(data, dtype);
2254 self.inner.set_param(name, &f32_buf);
2255 }
2256 }
2257
2258 fn run_typed(
2266 &mut self,
2267 inputs: &[(&str, &[u8], rlx_ir::DType)],
2268 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2269 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2270 for (name, data, dt) in inputs {
2271 let v = super::widen_bytes_to_f32(data, *dt);
2272 owned.push((name.to_string(), v));
2273 }
2274 let refs: Vec<(&str, &[f32])> = owned
2275 .iter()
2276 .map(|(n, d)| (n.as_str(), d.as_slice()))
2277 .collect();
2278 let dtypes =
2279 super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
2280 let f32_outs = self.inner.run(&refs);
2281 let byte_outs = self.inner.output_bytes_per_node();
2282 f32_outs
2283 .into_iter()
2284 .zip(byte_outs.into_iter())
2285 .zip(
2286 dtypes
2287 .into_iter()
2288 .chain(std::iter::repeat(rlx_ir::DType::F32)),
2289 )
2290 .map(|((f32_v, byte_v), dt)| match dt {
2291 rlx_ir::DType::F64 => (byte_v, dt),
2292 _ => (super::narrow_f32_to_bytes(&f32_v, dt), dt),
2293 })
2294 .collect()
2295 }
2296 }
2297}
2298
2299#[cfg(feature = "cuda")]
2302pub mod cuda_backend {
2303 use super::*;
2304 use rlx_cuda::backend::CudaExecutable;
2305
/// Zero-sized backend handle; its `Backend` impl in this module compiles
/// graphs into `CudaExecutable`-backed executables.
pub struct CudaBackend;
2307
/// Op kinds the CUDA backend advertises.
///
/// Returned by `CudaBackend::supported_ops` and passed to legalization,
/// stage compilation and `prepare_fused_graph` in this module.
const CUDA_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
    use rlx_ir::OpKind::*;
    &[
        Input,
        Param,
        Constant,
        Activation,
        Cast,
        Binary,
        Compare,
        Where,
        ElementwiseRegion,
        TransformRegion,
        BatchElementwiseRegion,
        MatMul,
        DotGeneral,
        LayerNorm,
        LayerNorm2d,
        GroupNorm,
        ResizeNearest2x,
        RmsNorm,
        Attention,
        AttentionBackward,
        RmsNormBackwardInput,
        RmsNormBackwardGamma,
        RmsNormBackwardBeta,
        RopeBackward,
        CumsumBackward,
        GatherBackward,
        Conv2dBackwardInput,
        Conv2dBackwardWeight,
        MaxPool2dBackward,
        Rope,
        Reshape,
        Transpose,
        Narrow,
        Concat,
        Expand,
        Gather,
        Reduce,
        Softmax,
        Cumsum,
        TopK,
        Sample,
        Conv,
        ConvTranspose2d,
        Pool,
        GroupedMatMul,
        DequantGroupedMatMul,
        DequantMoEWeights,
        ScatterAdd,
        DequantMatMul,
        SelectiveScan,
        FusedMatMulBiasAct,
        FusedResidualLN,
        FusedResidualRmsNorm,
        GaussianSplatRender,
        GaussianSplatRenderBackward,
        GaussianSplatPrepare,
        GaussianSplatRasterize,
        Custom,
        Fft,
        LogMel,
        LogMelBackward,
        WelchPeaks,
        Im2Col,
    ]
};
2381
2382 impl Backend for CudaBackend {
2383 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2384 CUDA_SUPPORTED_OPS
2385 }
2386
2387 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2388 use rlx_opt::pass::Pass as _;
2389 let graph = rlx_cuda::unfuse::unfuse(graph);
2392 let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, CUDA_SUPPORTED_OPS)
2393 .unwrap_or_else(|errors| {
2394 panic!("{}", rlx_opt::format_legalize_error("cuda", &errors));
2395 });
2396 let graph = crate::precompile::precompile_cleanup(graph, options);
2397 let graph = rlx_opt::LegalizeBroadcast.run(graph);
2399 let compile_result = crate::stages::compile_graph_stages_for_backend(
2401 rlx_driver::Device::Cuda,
2402 graph,
2403 options,
2404 CUDA_SUPPORTED_OPS,
2405 );
2406 crate::stages::maybe_log_fusion(&compile_result.fusion);
2407 let graph = compile_result.lir.into_graph();
2408 let graph = match options.policy.clone() {
2409 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
2410 None => graph,
2411 };
2412 let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2413 Box::new(CudaExecutableWrapper {
2414 inner: CudaExecutable::compile(graph),
2415 io_manifest,
2416 })
2417 }
2418
2419 fn compile_lir(
2420 &self,
2421 lir: LirModule,
2422 options: &CompileOptions,
2423 ) -> Box<dyn ExecutableGraph> {
2424 use rlx_opt::pass::Pass as _;
2425 let graph = rlx_opt::LegalizeBroadcast.run(lir.into_graph());
2426 let (graph, io_manifest) =
2427 cpu_low_precision::prepare_f32_exec_graph(prepare_fused_graph(
2428 rlx_cuda::unfuse::unfuse(graph),
2429 options,
2430 CUDA_SUPPORTED_OPS,
2431 "cuda",
2432 ));
2433 Box::new(CudaExecutableWrapper {
2434 inner: CudaExecutable::compile(graph),
2435 io_manifest,
2436 })
2437 }
2438 }
2439
/// Pairs a compiled `CudaExecutable` with the I/O dtype manifest returned by
/// `prepare_f32_exec_graph` at compile time.
struct CudaExecutableWrapper {
    /// Executable that every `ExecutableGraph` call is forwarded to.
    inner: CudaExecutable,
    /// Manifest passed to `declared_output_dtypes` when `run_typed` picks the
    /// dtype of each output.
    io_manifest: cpu_low_precision::IoDtypeManifest,
}

// NOTE(review): the invariant that makes this `Send` impl sound is not visible
// in this part of the file — TODO confirm and record it as a `SAFETY:` note.
unsafe impl Send for CudaExecutableWrapper {}
2450
2451 impl ExecutableGraph for CudaExecutableWrapper {
2452 fn set_param(&mut self, name: &str, data: &[f32]) {
2453 self.inner.set_param(name, data);
2454 }
2455 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2456 self.inner.run(inputs)
2457 }
2458 fn run_read_outputs(
2459 &mut self,
2460 inputs: &[(&str, &[f32])],
2461 read_indices: Option<&[usize]>,
2462 ) -> Vec<Vec<f32>> {
2463 self.inner.run_read_outputs(inputs, read_indices)
2464 }
2465 fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
2466 self.inner.bind_gpu_handle(name, data)
2467 }
2468 fn has_gpu_handle(&self, name: &str) -> bool {
2469 self.inner.has_gpu_handle(name)
2470 }
2471 fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
2472 self.inner.set_gpu_handle_feed(handle_name, output_index);
2473 true
2474 }
2475 fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
2476 self.inner.read_gpu_handle(name)
2477 }
2478 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2479 self.inner.set_active_extent(extent);
2480 }
2481
2482 fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
2483 self.inner.run_slots(inputs)
2484 }
2485
2486 fn arena_ptr(&self) -> *const u8 {
2487 self.inner.arena_ptr()
2488 }
2489
2490 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2495 if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
2496 self.inner.set_param_bytes(name, data);
2497 return;
2498 }
2499 if dtype == rlx_ir::DType::F32 {
2500 let n = data.len() / 4;
2501 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2502 self.inner.set_param(name, s);
2503 } else {
2504 let f32_buf = super::widen_bytes_to_f32(data, dtype);
2505 self.inner.set_param(name, &f32_buf);
2506 }
2507 }
2508
2509 fn run_typed(
2512 &mut self,
2513 inputs: &[(&str, &[u8], rlx_ir::DType)],
2514 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2515 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2516 for (name, data, dt) in inputs {
2517 let v = super::widen_bytes_to_f32(data, *dt);
2518 owned.push((name.to_string(), v));
2519 }
2520 let refs: Vec<(&str, &[f32])> = owned
2521 .iter()
2522 .map(|(n, d)| (n.as_str(), d.as_slice()))
2523 .collect();
2524 let dtypes =
2525 super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
2526 let outs = self.inner.run(&refs);
2527 outs.into_iter()
2528 .zip(
2529 dtypes
2530 .into_iter()
2531 .chain(std::iter::repeat(rlx_ir::DType::F32)),
2532 )
2533 .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2534 .collect()
2535 }
2536 }
2537}
2538
2539#[cfg(feature = "rocm")]
2542pub mod rocm_backend {
2543 use super::*;
2544 use rlx_rocm::backend::RocmExecutable;
2545
/// Zero-sized backend handle; its `Backend` impl in this module compiles
/// graphs into `RocmExecutable`-backed executables.
pub struct RocmBackend;
2547
/// Op kinds the ROCm backend advertises.
///
/// Returned by `RocmBackend::supported_ops` and passed to legalization,
/// stage compilation and `prepare_fused_graph` in this module.
const ROCM_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
    use rlx_ir::OpKind::*;
    &[
        Input,
        Param,
        Constant,
        Activation,
        Cast,
        Binary,
        Compare,
        Where,
        ElementwiseRegion,
        TransformRegion,
        BatchElementwiseRegion,
        MatMul,
        DotGeneral,
        LayerNorm,
        LayerNorm2d,
        GroupNorm,
        ResizeNearest2x,
        RmsNorm,
        Attention,
        AttentionBackward,
        RmsNormBackwardInput,
        RmsNormBackwardGamma,
        RmsNormBackwardBeta,
        RopeBackward,
        CumsumBackward,
        GatherBackward,
        Rope,
        Reshape,
        Transpose,
        Narrow,
        Concat,
        Expand,
        Gather,
        Reduce,
        Softmax,
        Cumsum,
        TopK,
        Sample,
        Conv,
        ConvTranspose2d,
        Pool,
        GroupedMatMul,
        DequantGroupedMatMul,
        DequantMoEWeights,
        ScatterAdd,
        DequantMatMul,
        SelectiveScan,
        FusedMatMulBiasAct,
        FusedResidualLN,
        FusedResidualRmsNorm,
        GaussianSplatRender,
        GaussianSplatRenderBackward,
        GaussianSplatPrepare,
        GaussianSplatRasterize,
        Custom,
        Fft,
        LogMel,
        LogMelBackward,
        WelchPeaks,
        Im2Col,
    ]
};
2615
2616 impl Backend for RocmBackend {
2617 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2618 ROCM_SUPPORTED_OPS
2619 }
2620
2621 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2622 use rlx_opt::pass::Pass as _;
2623 let graph = rlx_rocm::unfuse::unfuse(graph);
2624 let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, ROCM_SUPPORTED_OPS)
2625 .unwrap_or_else(|errors| {
2626 panic!("{}", rlx_opt::format_legalize_error("rocm", &errors));
2627 });
2628 let graph = crate::precompile::precompile_cleanup(graph, options);
2629 let graph = rlx_opt::LegalizeBroadcast.run(graph);
2630 let compile_result = crate::stages::compile_graph_stages_for_backend(
2631 rlx_driver::Device::Rocm,
2632 graph,
2633 options,
2634 ROCM_SUPPORTED_OPS,
2635 );
2636 crate::stages::maybe_log_fusion(&compile_result.fusion);
2637 let graph = compile_result.lir.into_graph();
2638 let graph = match options.policy.clone() {
2639 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
2640 None => graph,
2641 };
2642 let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2643 Box::new(RocmExecutableWrapper {
2644 inner: RocmExecutable::compile(graph),
2645 io_manifest,
2646 })
2647 }
2648
2649 fn compile_lir(
2650 &self,
2651 lir: LirModule,
2652 options: &CompileOptions,
2653 ) -> Box<dyn ExecutableGraph> {
2654 let (graph, io_manifest) =
2655 cpu_low_precision::prepare_f32_exec_graph(prepare_fused_graph(
2656 rlx_rocm::unfuse::unfuse(lir.into_graph()),
2657 options,
2658 ROCM_SUPPORTED_OPS,
2659 "rocm",
2660 ));
2661 Box::new(RocmExecutableWrapper {
2662 inner: RocmExecutable::compile(graph),
2663 io_manifest,
2664 })
2665 }
2666 }
2667
/// Pairs a compiled `RocmExecutable` with the I/O dtype manifest returned by
/// `prepare_f32_exec_graph` at compile time.
struct RocmExecutableWrapper {
    /// Executable that every `ExecutableGraph` call is forwarded to.
    inner: RocmExecutable,
    /// Manifest passed to `declared_output_dtypes` when `run_typed` picks the
    /// dtype of each output.
    io_manifest: cpu_low_precision::IoDtypeManifest,
}

// NOTE(review): the invariant that makes this `Send` impl sound is not visible
// in this part of the file — TODO confirm and record it as a `SAFETY:` note.
unsafe impl Send for RocmExecutableWrapper {}
2677
2678 impl ExecutableGraph for RocmExecutableWrapper {
2679 fn set_param(&mut self, name: &str, data: &[f32]) {
2680 self.inner.set_param(name, data);
2681 }
2682 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2683 self.inner.run(inputs)
2684 }
2685 fn run_read_outputs(
2686 &mut self,
2687 inputs: &[(&str, &[f32])],
2688 read_indices: Option<&[usize]>,
2689 ) -> Vec<Vec<f32>> {
2690 self.inner.run_read_outputs(inputs, read_indices)
2691 }
2692 fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
2693 self.inner.bind_gpu_handle(name, data)
2694 }
2695 fn has_gpu_handle(&self, name: &str) -> bool {
2696 self.inner.has_gpu_handle(name)
2697 }
2698 fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
2699 self.inner.set_gpu_handle_feed(handle_name, output_index);
2700 true
2701 }
2702 fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
2703 self.inner.read_gpu_handle(name)
2704 }
2705 fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
2706 self.inner.run_slots(inputs)
2707 }
2708 fn arena_ptr(&self) -> *const u8 {
2709 self.inner.arena_ptr()
2710 }
2711 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2712 self.inner.set_active_extent(extent);
2713 }
2714
2715 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2720 if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
2721 self.inner.set_param_bytes(name, data);
2722 return;
2723 }
2724 if dtype == rlx_ir::DType::F32 {
2725 let n = data.len() / 4;
2726 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2727 self.inner.set_param(name, s);
2728 } else {
2729 let f32_buf = super::widen_bytes_to_f32(data, dtype);
2730 self.inner.set_param(name, &f32_buf);
2731 }
2732 }
2733
2734 fn run_typed(
2737 &mut self,
2738 inputs: &[(&str, &[u8], rlx_ir::DType)],
2739 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2740 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2741 for (name, data, dt) in inputs {
2742 let v = super::widen_bytes_to_f32(data, *dt);
2743 owned.push((name.to_string(), v));
2744 }
2745 let refs: Vec<(&str, &[f32])> = owned
2746 .iter()
2747 .map(|(n, d)| (n.as_str(), d.as_slice()))
2748 .collect();
2749 let dtypes =
2750 super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
2751 let outs = self.inner.run(&refs);
2752 outs.into_iter()
2753 .zip(
2754 dtypes
2755 .into_iter()
2756 .chain(std::iter::repeat(rlx_ir::DType::F32)),
2757 )
2758 .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2759 .collect()
2760 }
2761 }
2762}
2763
2764#[cfg(feature = "tpu")]
2767pub mod tpu_backend {
2768 use super::*;
2769 use rlx_tpu::TpuExecutable;
2770
/// Zero-sized handle for the TPU target. Its `Backend` implementation
/// legalizes a graph against `TPU_SUPPORTED_OPS` and compiles it with
/// `rlx_tpu::TpuExecutable`.
pub struct TpuBackend;
2772
2773 const TPU_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
2779 use rlx_ir::OpKind::*;
2780 &[
2781 Input,
2782 Param,
2783 Constant,
2784 Activation,
2785 Cast,
2786 Binary,
2787 Compare,
2788 Where,
2789 ElementwiseRegion,
2790 TransformRegion,
2791 BatchElementwiseRegion,
2792 MatMul,
2793 DotGeneral,
2794 LayerNorm,
2795 RmsNorm,
2796 Attention,
2797 Rope,
2798 Reshape,
2799 Transpose,
2800 Narrow,
2801 Concat,
2802 Expand,
2803 Gather,
2804 Reduce,
2805 Softmax,
2806 Cumsum,
2807 TopK,
2808 Sample,
2809 Conv,
2810 Pool,
2811 GroupedMatMul,
2812 DequantGroupedMatMul,
2813 DequantMoEWeights,
2814 ScatterAdd,
2815 DequantMatMul,
2816 SelectiveScan,
2817 QMatMul,
2819 QConv2d,
2820 Quantize,
2821 Dequantize,
2822 FusedMatMulBiasAct,
2823 FusedResidualLN,
2824 FusedResidualRmsNorm,
2825 Fft,
2826 LogMel,
2827 LogMelBackward,
2828 WelchPeaks,
2829 ]
2831 };
2832
2833 impl Backend for TpuBackend {
2834 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2835 TPU_SUPPORTED_OPS
2836 }
2837
2838 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2839 let graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
2840 graph,
2841 TPU_SUPPORTED_OPS,
2842 options.kernel_dispatch,
2843 )
2844 .unwrap_or_else(|errors| {
2845 panic!("{}", rlx_opt::format_legalize_error("tpu", &errors));
2846 });
2847 use rlx_opt::pass::Pass as _;
2863 let policy = options
2864 .policy
2865 .clone()
2866 .unwrap_or(rlx_opt::PrecisionPolicy::AutoMixedBf16);
2867 let graph = rlx_opt::AutoMixedPrecision::new(policy).run(graph);
2868 let _ = options.dce;
2869 let _ = options.constant_folding;
2870 Box::new(TpuExecutableWrapper {
2871 inner: TpuExecutable::compile(graph),
2872 })
2873 }
2874 }
2875
/// Adapter that exposes a compiled `TpuExecutable` through the
/// `ExecutableGraph` trait.
struct TpuExecutableWrapper {
    /// Compiled executable that the trait methods delegate to.
    inner: TpuExecutable,
}

// NOTE(review): this asserts the wrapper may be moved to another thread. The
// invariant that makes it sound lives inside `TpuExecutable`, which is not
// visible here — confirm and record it as a `SAFETY:` comment.
unsafe impl Send for TpuExecutableWrapper {}
2884
2885 impl ExecutableGraph for TpuExecutableWrapper {
2886 fn set_param(&mut self, name: &str, data: &[f32]) {
2887 self.inner.set_param(name, data);
2888 }
2889 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2890 self.inner.run(inputs)
2891 }
2892
2893 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2898 if dtype == rlx_ir::DType::F32 {
2899 let n = data.len() / 4;
2900 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2901 self.inner.set_param(name, s);
2902 } else {
2903 let f32_buf = super::widen_bytes_to_f32(data, dtype);
2904 self.inner.set_param(name, &f32_buf);
2905 }
2906 }
2907
2908 fn run_typed(
2909 &mut self,
2910 inputs: &[(&str, &[u8], rlx_ir::DType)],
2911 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2912 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2913 for (name, data, dt) in inputs {
2914 let v = super::widen_bytes_to_f32(data, *dt);
2915 owned.push((name.to_string(), v));
2916 }
2917 let refs: Vec<(&str, &[f32])> = owned
2918 .iter()
2919 .map(|(n, d)| (n.as_str(), d.as_slice()))
2920 .collect();
2921 let dtypes = self.inner.output_dtypes();
2922 let outs = self.inner.run(&refs);
2923 outs.into_iter()
2924 .zip(
2925 dtypes
2926 .into_iter()
2927 .chain(std::iter::repeat(rlx_ir::DType::F32)),
2928 )
2929 .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2930 .collect()
2931 }
2932 }
2933}