1use crate::CompileOptions;
23use rlx_ir::Graph;
24use rlx_ir::hir::HirModule;
25use rlx_ir::lir::LirModule;
26use std::collections::HashMap;
27use std::sync::Arc;
28
29use crate::cpu_low_precision;
30
31#[allow(dead_code)]
38pub(crate) fn widen_bytes_to_f32(data: &[u8], dtype: rlx_ir::DType) -> Vec<f32> {
39 use rlx_ir::DType;
40 match dtype {
41 DType::F32 => {
42 let n = data.len() / 4;
43 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
44 s.to_vec()
45 }
46 DType::F16 => {
47 let n = data.len() / 2;
48 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n) };
49 s.iter().map(|h| h.to_f32()).collect()
50 }
51 DType::BF16 => {
52 let n = data.len() / 2;
53 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n) };
54 s.iter().map(|h| h.to_f32()).collect()
55 }
56 other => panic!(
57 "widen_bytes_to_f32: dtype {other:?} unsupported on f32-arena backends \
58 (only F32/F16/BF16 are accepted on the host I/O surface)"
59 ),
60 }
61}
62
63#[allow(dead_code)]
68pub(crate) fn narrow_f32_to_bytes(v: &[f32], dt: rlx_ir::DType) -> Vec<u8> {
69 use rlx_ir::DType;
70 match dt {
71 DType::F32 => {
72 let mut bytes = Vec::with_capacity(v.len() * 4);
73 for &x in v {
74 bytes.extend_from_slice(&x.to_le_bytes());
75 }
76 bytes
77 }
78 DType::F16 => {
79 let mut bytes = Vec::with_capacity(v.len() * 2);
80 for &x in v {
81 bytes.extend_from_slice(&half::f16::from_f32(x).to_le_bytes());
82 }
83 bytes
84 }
85 DType::BF16 => {
86 let mut bytes = Vec::with_capacity(v.len() * 2);
87 for &x in v {
88 bytes.extend_from_slice(&half::bf16::from_f32(x).to_le_bytes());
89 }
90 bytes
91 }
92 DType::F64 => {
93 let mut bytes = Vec::with_capacity(v.len() * 8);
94 for &x in v {
95 bytes.extend_from_slice(&(x as f64).to_le_bytes());
96 }
97 bytes
98 }
99 DType::I8 => v.iter().map(|&x| x as i8 as u8).collect(),
100 DType::U8 => v.iter().map(|&x| x as u8).collect(),
101 DType::I16 => {
102 let mut bytes = Vec::with_capacity(v.len() * 2);
103 for &x in v {
104 bytes.extend_from_slice(&(x as i16).to_le_bytes());
105 }
106 bytes
107 }
108 DType::I32 => {
109 let mut bytes = Vec::with_capacity(v.len() * 4);
110 for &x in v {
111 bytes.extend_from_slice(&(x as i32).to_le_bytes());
112 }
113 bytes
114 }
115 DType::U32 => {
116 let mut bytes = Vec::with_capacity(v.len() * 4);
117 for &x in v {
118 bytes.extend_from_slice(&(x as u32).to_le_bytes());
119 }
120 bytes
121 }
122 DType::I64 => {
123 let mut bytes = Vec::with_capacity(v.len() * 8);
124 for &x in v {
125 bytes.extend_from_slice(&(x as i64).to_le_bytes());
126 }
127 bytes
128 }
129 DType::Bool => v
130 .iter()
131 .map(|&x| if x != 0.0 { 1u8 } else { 0u8 })
132 .collect(),
133 DType::C64 => {
134 let mut bytes = Vec::with_capacity(v.len() * 8);
138 for &x in v {
139 bytes.extend_from_slice(&x.to_le_bytes());
140 bytes.extend_from_slice(&0.0_f32.to_le_bytes());
141 }
142 bytes
143 }
144 }
145}
146
147pub trait ExecutableGraph: Send {
149 fn set_param(&mut self, name: &str, data: &[f32]);
151
152 fn clone_box(&self) -> Box<dyn ExecutableGraph> {
159 panic!("clone_box not implemented for this backend");
160 }
161
162 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>>;
164
165 fn run_read_outputs(
168 &mut self,
169 inputs: &[(&str, &[f32])],
170 read_indices: Option<&[usize]>,
171 ) -> Vec<Vec<f32>> {
172 match read_indices {
173 None => self.run(inputs),
174 Some(ix) => {
175 let all = self.run(inputs);
178 ix.iter().filter_map(|&i| all.get(i).cloned()).collect()
179 }
180 }
181 }
182
183 fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
185 let vecs = self.run(inputs);
186 vecs.iter().map(|v| (v.as_ptr(), v.len())).collect()
187 }
188
189 fn run_slots(&mut self, _inputs: &[&[f32]]) -> &[(usize, usize)] {
192 &[] }
194
195 fn arena_ptr(&self) -> *const u8 {
197 std::ptr::null()
198 }
199
200 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
217 let _ = extent;
218 }
219
220 fn set_moe_resident_experts(&mut self, _mask: &[bool]) {}
222
223 fn set_moe_resident_experts_per_layer(&mut self, _masks: &[&[bool]]) {}
225
226 fn enable_moe_topk_capture(&mut self, _num_experts: usize) -> bool {
228 false
229 }
230
231 fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
233 None
234 }
235
236 fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
238 None
239 }
240
241 fn bind_handle(&mut self, _name: &str, _data: &[f32]) -> bool {
245 false
246 }
247
248 fn read_handle(&self, _name: &str) -> Option<Vec<f32>> {
250 None
251 }
252
253 fn bind_gpu_handle(&mut self, _name: &str, _data: &[f32]) -> bool {
255 false
256 }
257
258 fn has_gpu_handle(&self, _name: &str) -> bool {
259 false
260 }
261
262 fn set_gpu_handle_feed(&mut self, _handle_name: &str, _output_index: usize) -> bool {
263 false
264 }
265
266 fn read_gpu_handle(&self, _name: &str) -> Option<Vec<f32>> {
267 None
268 }
269
270 fn read_output_row(&self, _out_idx: usize, _row: usize, _row_inner: usize) -> Option<Vec<f32>> {
273 None
274 }
275
276 fn run_feed_gpu_handle(
278 &mut self,
279 inputs: &[(&str, &[f32])],
280 _handle_name: &str,
281 _output_index: usize,
282 ) -> Option<Vec<f32>> {
283 let _ = inputs;
284 None
285 }
286
287 fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
302 let _ = self.run(inputs);
303 }
304
305 fn sync_pending(&mut self) {}
308
309 fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
318 input_sets.iter().map(|inputs| self.run(inputs)).collect()
319 }
320
321 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
334 if dtype != rlx_ir::DType::F32 {
335 panic!(
336 "backend's default set_param_typed only handles F32; \
337 got {dtype:?}. Override on the backend for typed support."
338 );
339 }
340 if !data.len().is_multiple_of(4) {
341 panic!(
342 "set_param_typed F32: data length {} not a multiple of 4",
343 data.len()
344 );
345 }
346 let n = data.len() / 4;
351 let f32_slice = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
352 self.set_param(name, f32_slice);
353 }
354
355 fn run_typed(
359 &mut self,
360 inputs: &[(&str, &[u8], rlx_ir::DType)],
361 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
362 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
365 for (name, data, dt) in inputs {
366 if *dt != rlx_ir::DType::F32 {
367 panic!(
368 "backend's default run_typed only handles F32 inputs; \
369 got {dt:?} for input '{name}'"
370 );
371 }
372 if data.len() % 4 != 0 {
373 panic!(
374 "run_typed F32 input '{name}': len {} not multiple of 4",
375 data.len()
376 );
377 }
378 let n = data.len() / 4;
379 let v: Vec<f32> =
380 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) }.to_vec();
381 owned.push((name.to_string(), v));
382 }
383 let refs: Vec<(&str, &[f32])> = owned
384 .iter()
385 .map(|(n, d)| (n.as_str(), d.as_slice()))
386 .collect();
387 let outs = self.run(&refs);
388 outs.into_iter()
389 .map(|v| {
390 let bytes =
391 unsafe { std::slice::from_raw_parts(v.as_ptr() as *const u8, v.len() * 4) }
392 .to_vec();
393 (bytes, rlx_ir::DType::F32)
394 })
395 .collect()
396 }
397}
398
399pub trait Backend: Send + Sync {
409 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph>;
411
412 fn compile_lir(&self, lir: LirModule, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
416 self.compile(lir.into_graph(), options)
417 }
418
419 fn compile_hir(
421 &self,
422 hir: HirModule,
423 device: rlx_driver::Device,
424 options: &CompileOptions,
425 ) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
426 let result = crate::stages::compile_hir_stages(device, hir, options)?;
427 crate::stages::maybe_log_fusion(&result.fusion);
428 Ok(self.compile_lir(result.lir, options))
429 }
430
431 fn compile_module(
433 &self,
434 module: rlx_ir::GraphModule,
435 device: rlx_driver::Device,
436 options: &CompileOptions,
437 ) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
438 let result = crate::stages::compile_module_stages(device, module, options)?;
439 crate::stages::maybe_log_fusion(&result.fusion);
440 Ok(self.compile_lir(result.lir, options))
441 }
442
443 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
450 &[]
451 }
452}
453
454#[allow(dead_code)]
457fn prepare_fused_graph(
458 graph: Graph,
459 options: &CompileOptions,
460 supported_ops: &[rlx_ir::OpKind],
461 backend_name: &str,
462) -> Graph {
463 let (mut graph, report) = rlx_opt::prepare_graph_for_backend_with_report(
464 graph,
465 backend_name,
466 supported_ops,
467 options.kernel_dispatch,
468 );
469 rlx_opt::maybe_log_dispatch_report(&report);
470 if !report.compile_ready {
471 panic!(
472 "{}\n{}",
473 rlx_opt::format_legalize_error(backend_name, &report.still_unsupported),
474 rlx_opt::format_dispatch_report(&report)
475 );
476 }
477 graph = crate::precompile::post_fusion_cleanup(graph, options);
478 if let Some(p) = options.policy.clone() {
479 use rlx_opt::pass::Pass as _;
480 graph = rlx_opt::AutoMixedPrecision::new(p).run(graph);
481 }
482 graph
483}
484
485#[allow(dead_code)]
486fn declared_output_dtypes(
487 manifest: &cpu_low_precision::IoDtypeManifest,
488 exec_dtypes: Vec<rlx_ir::DType>,
489) -> Vec<rlx_ir::DType> {
490 exec_dtypes
491 .into_iter()
492 .enumerate()
493 .map(|(i, exec)| manifest.output_dtype(i, exec))
494 .collect()
495}
496
497pub fn compile(backend: &dyn Backend, graph: Graph) -> Box<dyn ExecutableGraph> {
505 backend.compile(graph, &CompileOptions::default())
506}
507
508pub fn compile_hir(
510 backend: &dyn Backend,
511 hir: HirModule,
512 device: rlx_driver::Device,
513 options: &CompileOptions,
514) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
515 backend.compile_hir(hir, device, options)
516}
517
518pub fn compile_module(
520 backend: &dyn Backend,
521 module: rlx_ir::GraphModule,
522 device: rlx_driver::Device,
523 options: &CompileOptions,
524) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
525 backend.compile_module(module, device, options)
526}
527
528pub fn compile_with_precision(
530 backend: &dyn Backend,
531 graph: Graph,
532 precision: crate::Precision,
533) -> Box<dyn ExecutableGraph> {
534 backend.compile(graph, &CompileOptions::new().precision(precision))
535}
536
537fn _legacy_apply_policy(graph: Graph, policy: Option<rlx_opt::PrecisionPolicy>) -> Graph {
542 use rlx_opt::pass::Pass as _;
543 match policy {
544 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
545 None => graph,
546 }
547}
548
549#[cfg(feature = "cpu")]
552pub mod cpu_backend {
553 use super::*;
554 use rlx_cpu::{arena::Arena, thunk};
555 use rlx_ir::{DType, NodeId, Op};
556 use rlx_opt::memory::{self, MemoryPlan};
557 use rlx_driver::arena::{read_typed_to_f32, write_typed_from_f32};
560
    /// Stateless CPU implementation of [`Backend`]; compiles graphs into
    /// arena-backed executables driven by a thunk schedule.
    pub struct CpuBackend;
562
    /// Op kinds the CPU backend declares support for.
    ///
    /// Returned from `CpuBackend::supported_ops`, used to legalize incoming
    /// graphs in `CpuBackend::compile`, and passed to the graph compile stages.
    const CPU_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            StopGradient,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            DenseSolve,
            BatchedDenseSolve,
            Scan,
            ScanBackward,
            ScanBackwardXs,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            BatchNormInference,
            RmsNorm,
            ResizeNearest2x,
            AxialRope2d,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            Im2Col,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            LoraMatMul,
            DequantMatMul,
            SelectiveScan,
            GatedDeltaNet,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            FusedAttentionBlock,
            // Backward / training-related op kinds.
            ReluBackward,
            ActivationBackward,
            FakeQuantize,
            FakeQuantizeBackward,
            MaxPool2dBackward,
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            SoftmaxCrossEntropyWithLogits,
            SoftmaxCrossEntropyBackward,
            AttentionBackward,
            LayerNormBackwardInput,
            LayerNormBackwardGamma,
            BatchNormInferenceBackwardInput,
            BatchNormInferenceBackwardGamma,
            BatchNormInferenceBackwardBeta,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            // Gaussian-splat op kinds.
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            // Custom op kinds.
            Custom,
            CustomFn,
            // FFT, log-mel and complex-number op kinds.
            Fft,
            FftButterflyStage,
            LogMel,
            LogMelBackward,
            ComplexNormSq,
            ComplexNormSqBackward,
            Conjugate,
        ]
    };
680
681 impl Backend for CpuBackend {
682 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
683 CPU_SUPPORTED_OPS
684 }
685
686 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
687 use rlx_opt::pass::Pass as _;
688 let graph = rlx_opt::LowerControlFlow.run(graph);
694 if let Err(errors) = rlx_opt::legalize_for_backend(&graph, CPU_SUPPORTED_OPS) {
698 panic!("{}", rlx_opt::format_legalize_error("cpu", &errors));
699 }
700 let policy = options.policy.clone();
701 let _precision = options.precision;
702 let cfg = rlx_cpu::config::RuntimeConfig::global();
703
704 let graph = crate::precompile::precompile_cleanup(graph, options);
705
706 let mut compile_opts = options.clone();
708 compile_opts.arena_alignment = cfg.arena_alignment;
709 let compile_result = crate::stages::compile_graph_stages_for_backend(
710 rlx_driver::Device::Cpu,
711 graph,
712 &compile_opts,
713 CPU_SUPPORTED_OPS,
714 );
715 crate::stages::maybe_log_fusion(&compile_result.fusion);
716 let fused = compile_result.lir.into_graph();
717
718 let fused = match policy {
721 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(fused),
722 None => fused,
723 };
724
725 let io_manifest = cpu_low_precision::IoDtypeManifest::from_graph(&fused);
726 let exec_graph = if cpu_low_precision::needs_f32_exec(&fused) {
727 cpu_low_precision::promote_to_f32(fused)
728 } else {
729 fused
730 };
731
732 let plan = memory::plan_memory_aligned(&exec_graph, cfg.arena_alignment);
734 if cfg.verbose >= 1 {
735 eprintln!(
736 "[rlx] arena: {} bytes, {} buffers, alignment: {}",
737 plan.arena_size,
738 plan.assignments.len(),
739 cfg.arena_alignment
740 );
741 }
742 Box::new(build_cpu_executable(exec_graph, plan, io_manifest))
743 }
744
745 fn compile_lir(
746 &self,
747 lir: LirModule,
748 options: &CompileOptions,
749 ) -> Box<dyn ExecutableGraph> {
750 let alignment = lir.buffers.alignment.max(options.arena_alignment);
751 let mut graph = lir.into_graph();
752 {
753 use rlx_opt::pass::Pass as _;
754 graph = rlx_opt::LegalizeBroadcast.run(graph);
755 }
756 if let Some(p) = options.policy.clone() {
757 use rlx_opt::pass::Pass;
758 graph = rlx_opt::AutoMixedPrecision::new(p).run(graph);
759 }
760 let io_manifest = cpu_low_precision::IoDtypeManifest::from_graph(&graph);
761 let promote = cpu_low_precision::needs_f32_exec(&graph);
762 let exec_graph = if promote {
763 cpu_low_precision::promote_to_f32(graph)
764 } else {
765 graph
766 };
767 let plan = memory::plan_memory_aligned(&exec_graph, alignment);
770 let cfg = rlx_cpu::config::RuntimeConfig::global();
771 if cfg.verbose >= 1 {
772 eprintln!(
773 "[rlx] compile_lir: arena {} bytes ({} buffers, alignment {})",
774 plan.arena_size,
775 plan.assignments.len(),
776 alignment,
777 );
778 }
779 Box::new(build_cpu_executable(exec_graph, plan, io_manifest))
780 }
781 }
782
783 fn build_cpu_executable(
784 graph: Graph,
785 plan: MemoryPlan,
786 io_manifest: cpu_low_precision::IoDtypeManifest,
787 ) -> CpuExecutable {
788 let mut arena = Arena::from_plan(plan);
789 let mut input_ids = HashMap::new();
790 let mut param_ids = HashMap::new();
791 let mut node_dtypes: HashMap<NodeId, DType> = HashMap::new();
792 for node in graph.nodes() {
793 node_dtypes.insert(node.id, node.shape.dtype());
794 match &node.op {
795 Op::Input { name } => {
796 input_ids.insert(name.clone(), node.id);
797 }
798 Op::Param { name } => {
799 param_ids.insert(name.clone(), node.id);
800 }
801 _ => {}
802 }
803 }
804
805 let schedule = thunk::compile_thunks(&graph, &arena);
806
807 let mut input_slots = Vec::new();
808 for node in graph.nodes() {
809 if let Op::Input { name } = &node.op {
810 let off = arena.byte_offset(node.id);
811 let len = node.shape.num_elements().unwrap_or(0);
812 input_slots.push((name.clone(), off, len, node.shape.dtype()));
813 }
814 }
815
816 let output_slots: Vec<(usize, usize)> = graph
817 .outputs
818 .iter()
819 .map(|&id| {
820 let off = arena.byte_offset(id);
821 let len = graph.node(id).shape.num_elements().unwrap_or(0);
822 (off, len)
823 })
824 .collect();
825
826 for node in graph.nodes() {
827 if let Op::Constant { data } = &node.op
828 && arena.has_buffer(node.id)
829 && !data.is_empty()
830 {
831 match node.shape.dtype() {
832 DType::F64 | DType::F16 | DType::BF16 => {
833 let off = arena.byte_offset(node.id);
834 let buf = arena.raw_buf_mut();
835 let n = buf.len().saturating_sub(off).min(data.len());
836 buf[off..off + n].copy_from_slice(&data[..n]);
837 }
838 _ => {
839 let buf = arena.slice_mut(node.id);
840 let n_floats = data.len() / 4;
841 let n = buf.len().min(n_floats);
842 for i in 0..n {
843 let bytes = [
844 data[i * 4],
845 data[i * 4 + 1],
846 data[i * 4 + 2],
847 data[i * 4 + 3],
848 ];
849 buf[i] = f32::from_le_bytes(bytes);
850 }
851 }
852 }
853 }
854 }
855
856 CpuExecutable {
857 graph,
858 arena,
859 params: HashMap::new(),
860 typed_params: HashMap::new(),
861 input_ids,
862 param_ids,
863 node_dtypes,
864 io_manifest,
865 schedule,
866 input_slots,
867 output_slots,
868 handles: HashMap::new(),
869 active_extent: None,
870 moe_resident: None,
871 moe_resident_layers: None,
872 moe_topk_capture: None,
873 }
874 }
875
    /// CPU executable: an execution graph, the arena holding its buffers, and
    /// the thunk schedule that runs it.
    #[derive(Clone)]
    struct CpuExecutable {
        /// Graph being executed; consulted for outputs, shapes and constants.
        graph: Graph,
        /// Backing memory for every planned buffer.
        arena: Arena,
        /// Last f32 value per parameter name, replayed into the arena when the
        /// baseline is restored.
        params: HashMap<String, Vec<f32>>,
        /// Raw-byte parameter values (with their dtype), replayed into the
        /// arena when the baseline is restored.
        typed_params: HashMap<String, (Vec<u8>, DType)>,
        /// Input name → node id.
        input_ids: HashMap<String, NodeId>,
        /// Parameter name → node id.
        param_ids: HashMap<String, NodeId>,
        /// Dtype of every node in `graph`.
        node_dtypes: HashMap<NodeId, DType>,
        /// Consulted by `run_typed` to choose the dtype reported for each output.
        io_manifest: cpu_low_precision::IoDtypeManifest,
        /// Thunks compiled from `graph` against `arena`.
        schedule: thunk::ThunkSchedule,
        /// Per input node, in graph order: (name, byte offset, element count, dtype).
        input_slots: Vec<(String, usize, usize, DType)>,
        /// Per graph output: (byte offset, element count).
        output_slots: Vec<(usize, usize)>,
        /// Named f32 buffers bound via `bind_handle`; written into same-named
        /// inputs before a run and refreshed from outputs named `out{idx}`.
        handles: HashMap<String, Vec<f32>>,
        /// Optional `(actual, upper)` pair forwarded to `execute_thunks_active`.
        active_extent: Option<(usize, usize)>,
        /// Single MoE residency mask, mirrored onto the schedule.
        moe_resident: Option<std::sync::Arc<[bool]>>,
        /// Per-layer MoE residency masks, mirrored onto the schedule.
        moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
        /// MoE top-k capture shared with the schedule.
        moe_topk_capture: Option<std::sync::Arc<rlx_cpu::moe_topk_capture::MoeTopkCapture>>,
    }
910
    // NOTE(review): no justification for this `unsafe impl` is visible here.
    // Presumably a field type (e.g. the arena or the thunk schedule) is not
    // automatically `Send` — confirm which one and state the invariant that
    // makes moving a `CpuExecutable` across threads sound.
    unsafe impl Send for CpuExecutable {}
912
913 impl CpuExecutable {
914 fn write_input(&mut self, id: NodeId, data: &[f32]) {
916 let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
917 let off = self.arena.byte_offset(id);
918 let buf = self.arena.raw_buf_mut();
919 let elem_size = dtype.size_bytes();
920 let max_elems = (buf.len() - off) / elem_size;
921 unsafe {
922 write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, data, max_elems);
923 }
924 }
925
926 fn read_output(&self, id: NodeId) -> Vec<f32> {
928 let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
929 let off = self.arena.byte_offset(id);
930 let buf = self.arena.raw_buf();
931 let n_elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
932 unsafe { read_typed_to_f32(buf.as_ptr().add(off), dtype, n_elems) }
933 }
934 }
935
936 impl ExecutableGraph for CpuExecutable {
937 fn clone_box(&self) -> Box<dyn ExecutableGraph> {
938 Box::new(self.clone())
939 }
940 fn set_param(&mut self, name: &str, data: &[f32]) {
941 self.params.insert(name.to_string(), data.to_vec());
942 self.typed_params.remove(name);
943 if let Some(&id) = self.param_ids.get(name)
946 && self.arena.has_buffer(id)
947 {
948 let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
949 let off = self.arena.byte_offset(id);
950 let buf = self.arena.raw_buf_mut();
951 let elem_size = dtype.size_bytes();
952 let max_elems = (buf.len() - off) / elem_size;
953 unsafe {
954 write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, data, max_elems);
955 }
956 }
957 }
958
959 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
960 self.restore_arena_baseline();
961 let handle_names: Vec<String> = self.handles.keys().cloned().collect();
964 for name in &handle_names {
965 if let Some(&id) = self.input_ids.get(name)
966 && self.arena.has_buffer(id)
967 {
968 let data = self.handles.get(name).cloned().unwrap_or_default();
969 self.write_input(id, &data);
970 }
971 }
972 for &(name, data) in inputs {
974 if let Some(&id) = self.input_ids.get(name)
975 && self.arena.has_buffer(id)
976 {
977 self.write_input(id, data);
978 }
979 }
980
981 let active_used = if let Some((actual, upper)) = self.active_extent {
986 thunk::execute_thunks_active(
987 &self.schedule,
988 self.arena.raw_buf_mut(),
989 actual,
990 upper,
991 )
992 } else {
993 false
994 };
995 if !active_used {
996 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
998 }
999
1000 for (idx, &out_id) in self.graph.outputs.iter().enumerate() {
1004 let name = format!("out{idx}");
1005 if self.handles.contains_key(&name) {
1006 let v = self.read_output(out_id);
1007 self.handles.insert(name, v);
1008 }
1009 }
1010
1011 self.graph
1012 .outputs
1013 .iter()
1014 .map(|&out_id| self.read_output(out_id))
1015 .collect()
1016 }
1017
1018 fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
1019 self.restore_arena_baseline();
1020 for &(name, data) in inputs {
1022 if let Some(&id) = self.input_ids.get(name)
1023 && self.arena.has_buffer(id)
1024 {
1025 self.write_input(id, data);
1026 }
1027 }
1028 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1029 self.graph
1033 .outputs
1034 .iter()
1035 .map(|&out_id| {
1036 let (ptr, len) = self.arena.raw_ptr(out_id);
1037 (ptr as *const f32, len)
1038 })
1039 .collect()
1040 }
1041
1042 fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
1046 self.restore_arena_baseline();
1047 let buf = self.arena.raw_buf_mut();
1048 for (i, &data) in inputs.iter().enumerate() {
1049 if i < self.input_slots.len() {
1050 let (_, off, max_len, dtype) = &self.input_slots[i];
1051 unsafe {
1052 write_typed_from_f32(buf.as_mut_ptr().add(*off), *dtype, data, *max_len);
1053 }
1054 }
1055 }
1056 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1057 &self.output_slots
1058 }
1059
1060 fn arena_ptr(&self) -> *const u8 {
1061 self.arena.raw_buf_mut_ptr()
1062 }
1063
1064 fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
1065 self.handles.insert(name.to_string(), data.to_vec());
1070 true
1071 }
1072
1073 fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
1074 self.handles.get(name).cloned()
1075 }
1076
1077 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
1078 self.active_extent = extent;
1079 }
1080
1081 fn set_moe_resident_experts(&mut self, mask: &[bool]) {
1082 self.moe_resident_layers = None;
1083 self.schedule.moe_resident_layers = None;
1084 self.moe_resident = Some(Arc::from(mask));
1085 self.schedule.moe_resident = self.moe_resident.clone();
1086 }
1087
1088 fn set_moe_resident_experts_per_layer(&mut self, masks: &[&[bool]]) {
1089 self.moe_resident = None;
1090 self.schedule.moe_resident = None;
1091 let layers: Vec<Arc<[bool]>> = masks.iter().map(|m| Arc::from(*m)).collect();
1092 let arc = Arc::new(layers);
1093 self.moe_resident_layers = Some(arc.clone());
1094 self.schedule.moe_resident_layers = Some(arc);
1095 }
1096
1097 fn enable_moe_topk_capture(&mut self, num_experts: usize) -> bool {
1098 let cap = rlx_cpu::moe_topk_capture::MoeTopkCapture::new(num_experts);
1099 self.moe_topk_capture = Some(cap.clone());
1100 self.schedule.moe_topk_capture = Some(cap);
1101 true
1102 }
1103
1104 fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
1105 let cap = self.moe_topk_capture.as_ref()?;
1106 let layers = cap.take_layers();
1107 if layers.is_empty() {
1108 None
1109 } else {
1110 Some(layers)
1111 }
1112 }
1113
1114 fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
1115 rlx_cpu::moe_residency::take_last_forward_stats()
1116 }
1117
1118 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1124 if matches!(dtype, DType::F64 | DType::I64 | DType::I32 | DType::U32) {
1125 self.set_param_bytes(name, data, dtype);
1126 return;
1127 }
1128 if matches!(dtype, DType::U8 | DType::I8) {
1132 self.set_param_bytes(name, data, dtype);
1133 return;
1134 }
1135 if dtype == DType::F32 {
1136 let n = data.len() / 4;
1137 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
1138 self.set_param(name, s);
1139 } else {
1140 let f32_buf = super::widen_bytes_to_f32(data, dtype);
1141 self.set_param(name, &f32_buf);
1142 }
1143 }
1144
1145 fn run_typed(
1157 &mut self,
1158 inputs: &[(&str, &[u8], rlx_ir::DType)],
1159 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1160 let all_f64 = !inputs.is_empty() && inputs.iter().all(|(_, _, dt)| *dt == DType::F64);
1165
1166 if all_f64 {
1167 for (name, data, _) in inputs {
1168 if let Some(&id) = self.input_ids.get(*name) {
1169 if !self.arena.has_buffer(id) {
1170 continue;
1171 }
1172 let off = self.arena.byte_offset(id);
1173 let buf = self.arena.raw_buf_mut();
1174 let n = data.len();
1175 debug_assert!(
1176 off + n <= buf.len(),
1177 "run_typed: input '{name}' overflows arena slot"
1178 );
1179 buf[off..off + n].copy_from_slice(data);
1180 }
1181 }
1182 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1183 } else {
1184 let mut f32_owned: Vec<(String, Vec<f32>)> = Vec::new();
1189 for (name, data, dt) in inputs {
1190 let direct = matches!(
1191 *dt,
1192 DType::F64 | DType::I32 | DType::I64 | DType::U32 | DType::C64
1193 );
1194 if direct {
1195 if let Some(&id) = self.input_ids.get(*name) {
1196 if !self.arena.has_buffer(id) {
1197 continue;
1198 }
1199 let off = self.arena.byte_offset(id);
1200 let buf = self.arena.raw_buf_mut();
1201 buf[off..off + data.len()].copy_from_slice(data);
1202 }
1203 } else {
1204 let v = super::widen_bytes_to_f32(data, *dt);
1205 f32_owned.push((name.to_string(), v));
1206 }
1207 }
1208 for (name, data) in &f32_owned {
1209 if let Some(&id) = self.input_ids.get(name.as_str()) {
1210 if self.arena.has_buffer(id) {
1211 self.write_input(id, data);
1212 }
1213 }
1214 }
1215 let active_used = if let Some((actual, upper)) = self.active_extent {
1216 thunk::execute_thunks_active(
1217 &self.schedule,
1218 self.arena.raw_buf_mut(),
1219 actual,
1220 upper,
1221 )
1222 } else {
1223 false
1224 };
1225 if !active_used {
1226 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1227 }
1228 }
1229
1230 self.graph
1232 .outputs
1233 .iter()
1234 .enumerate()
1235 .map(|(idx, &id)| {
1236 let exec_dtype = self.graph.node(id).shape.dtype();
1237 let declared = self.io_manifest.output_dtype(idx, exec_dtype);
1238 if matches!(
1239 exec_dtype,
1240 DType::F64
1241 | DType::F16
1242 | DType::BF16
1243 | DType::I32
1244 | DType::I64
1245 | DType::U32
1246 | DType::C64
1247 ) {
1248 let n_elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
1249 let n_bytes = n_elems * exec_dtype.size_bytes();
1250 let off = self.arena.byte_offset(id);
1251 let bytes = self.arena.raw_buf()[off..off + n_bytes].to_vec();
1252 return (bytes, declared);
1253 }
1254 let f32_vals = self.read_output(id);
1255 if declared != exec_dtype {
1256 return (super::narrow_f32_to_bytes(&f32_vals, declared), declared);
1257 }
1258 let bytes = f32_vals.iter().flat_map(|v| v.to_le_bytes()).collect();
1259 (bytes, declared)
1260 })
1261 .collect()
1262 }
1263 }
1264
1265 impl CpuExecutable {
1266 fn restore_arena_baseline(&mut self) {
1271 self.arena.raw_buf_mut().fill(0);
1272 let constants: Vec<(NodeId, DType, Vec<u8>)> = self
1273 .graph
1274 .nodes()
1275 .iter()
1276 .filter_map(|node| {
1277 if let Op::Constant { data } = &node.op
1278 && self.arena.has_buffer(node.id)
1279 && !data.is_empty()
1280 {
1281 Some((node.id, node.shape.dtype(), data.clone()))
1282 } else {
1283 None
1284 }
1285 })
1286 .collect();
1287 for (id, dtype, data) in constants {
1288 self.write_constant_to_arena(id, dtype, &data);
1289 }
1290 let params = self.params.clone();
1291 for (name, data) in params {
1292 if let Some(&id) = self.param_ids.get(&name)
1293 && self.arena.has_buffer(id)
1294 {
1295 let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
1296 let off = self.arena.byte_offset(id);
1297 let buf = self.arena.raw_buf_mut();
1298 let elem_size = dtype.size_bytes();
1299 let max_elems = (buf.len() - off) / elem_size;
1300 unsafe {
1301 write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, &data, max_elems);
1302 }
1303 }
1304 }
1305 let typed = self.typed_params.clone();
1306 for (name, (data, dtype)) in typed {
1307 self.write_param_bytes_to_arena(&name, &data);
1308 let _ = dtype;
1309 }
1310 }
1311
1312 fn write_constant_to_arena(&mut self, id: NodeId, dtype: DType, data: &[u8]) {
1313 match dtype {
1314 DType::F64 | DType::F16 | DType::BF16 | DType::U8 | DType::I8 => {
1315 let off = self.arena.byte_offset(id);
1316 let buf = self.arena.raw_buf_mut();
1317 let n = buf.len().saturating_sub(off).min(data.len());
1318 buf[off..off + n].copy_from_slice(&data[..n]);
1319 }
1320 _ => {
1321 let buf = self.arena.slice_mut(id);
1322 let n_floats = data.len() / 4;
1323 let n = buf.len().min(n_floats);
1324 for i in 0..n {
1325 let bytes = [
1326 data[i * 4],
1327 data[i * 4 + 1],
1328 data[i * 4 + 2],
1329 data[i * 4 + 3],
1330 ];
1331 buf[i] = f32::from_le_bytes(bytes);
1332 }
1333 }
1334 }
1335 }
1336
1337 fn set_param_bytes(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1343 self.typed_params
1344 .insert(name.to_string(), (data.to_vec(), dtype));
1345 self.params.remove(name);
1346 self.write_param_bytes_to_arena(name, data);
1347 }
1348
1349 fn write_param_bytes_to_arena(&mut self, name: &str, data: &[u8]) {
1350 if let Some(&id) = self.param_ids.get(name)
1351 && self.arena.has_buffer(id)
1352 {
1353 let off = self.arena.byte_offset(id);
1354 let buf = self.arena.raw_buf_mut();
1355 debug_assert!(
1356 off + data.len() <= buf.len(),
1357 "set_param_bytes: '{name}' would overflow arena slot"
1358 );
1359 buf[off..off + data.len()].copy_from_slice(data);
1360 }
1361 }
1362 }
1363}
1364
1365#[cfg(feature = "gpu")]
1370pub mod wgpu_backend {
1371 use super::*;
1372 use rlx_ir::OpKind;
1373 use rlx_wgpu::backend::WgpuExecutable;
1374
    /// Stateless wgpu implementation of [`Backend`]; compiles graphs into
    /// `WgpuExecutable`s.
    pub struct WgpuBackend;
1376
    /// Op kinds the wgpu backend declares support for.
    ///
    /// Returned from `WgpuBackend::supported_ops` and passed to the
    /// legalize/rewrite step, the graph compile stages, and
    /// `prepare_fused_graph`.
    const WGPU_SUPPORTED_OPS: &[OpKind] = &[
        OpKind::Input,
        OpKind::Param,
        OpKind::Constant,
        OpKind::Activation,
        OpKind::Cast,
        OpKind::StopGradient,
        OpKind::Binary,
        OpKind::Compare,
        OpKind::Where,
        OpKind::ElementwiseRegion,
        OpKind::TransformRegion,
        OpKind::BatchElementwiseRegion,
        OpKind::MatMul,
        OpKind::DotGeneral,
        OpKind::LayerNorm,
        OpKind::RmsNorm,
        OpKind::Attention,
        OpKind::AttentionBackward,
        OpKind::RmsNormBackwardInput,
        OpKind::RmsNormBackwardGamma,
        OpKind::RmsNormBackwardBeta,
        OpKind::LayerNormBackwardInput,
        OpKind::LayerNormBackwardGamma,
        OpKind::RopeBackward,
        OpKind::CumsumBackward,
        OpKind::GatherBackward,
        OpKind::Rope,
        OpKind::Reshape,
        OpKind::Transpose,
        OpKind::Narrow,
        OpKind::Concat,
        OpKind::Expand,
        OpKind::Gather,
        OpKind::Reduce,
        OpKind::Softmax,
        OpKind::Cumsum,
        OpKind::TopK,
        OpKind::Sample,
        OpKind::Conv,
        OpKind::Im2Col,
        OpKind::Pool,
        OpKind::GroupedMatMul,
        OpKind::DequantGroupedMatMul,
        OpKind::DequantMoEWeights,
        OpKind::ScatterAdd,
        OpKind::SelectiveScan,
        OpKind::DequantMatMul,
        OpKind::FusedMatMulBiasAct,
        OpKind::FusedResidualLN,
        OpKind::FusedResidualRmsNorm,
        OpKind::FusedSwiGLU,
        OpKind::FusedAttentionBlock,
        OpKind::FusedTransformerLayer,
        OpKind::Fft,
        OpKind::LogMel,
        OpKind::LogMelBackward,
        // Gaussian-splat op kinds.
        OpKind::GaussianSplatRender,
        OpKind::GaussianSplatRenderBackward,
        OpKind::GaussianSplatPrepare,
        OpKind::GaussianSplatRasterize,
        OpKind::Custom,
    ];
1458
1459 impl Backend for WgpuBackend {
1460 fn supported_ops(&self) -> &'static [OpKind] {
1461 WGPU_SUPPORTED_OPS
1462 }
1463
1464 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
1465 use rlx_opt::pass::Pass as _;
1466 let graph = rlx_opt::LowerControlFlow.run(graph);
1467 let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, WGPU_SUPPORTED_OPS)
1468 .unwrap_or_else(|errors| {
1469 panic!("{}", rlx_opt::format_legalize_error("wgpu", &errors));
1470 });
1471 let graph = crate::precompile::precompile_cleanup(graph, options);
1472 let graph = rlx_opt::LegalizeBroadcast.run(graph);
1476 let compile_result = crate::stages::compile_graph_stages_for_backend(
1485 rlx_driver::Device::Gpu,
1486 graph,
1487 options,
1488 WGPU_SUPPORTED_OPS,
1489 );
1490 crate::stages::maybe_log_fusion(&compile_result.fusion);
1491 let graph = compile_result.lir.into_graph();
1492 let graph = match options.policy.clone() {
1493 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
1494 None => graph,
1495 };
1496 let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
1497 Box::new(WgpuExecutableWrapper {
1498 inner: WgpuExecutable::compile(graph),
1499 io_manifest,
1500 })
1501 }
1502
1503 fn compile_lir(
1504 &self,
1505 lir: LirModule,
1506 options: &CompileOptions,
1507 ) -> Box<dyn ExecutableGraph> {
1508 use rlx_opt::pass::Pass as _;
1509 let graph = rlx_opt::LegalizeBroadcast.run(lir.into_graph());
1512 let graph = prepare_fused_graph(graph, options, WGPU_SUPPORTED_OPS, "wgpu");
1513 let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
1514 Box::new(WgpuExecutableWrapper {
1515 inner: WgpuExecutable::compile(graph),
1516 io_manifest,
1517 })
1518 }
1519 }
1520
    /// Pairs a compiled `WgpuExecutable` with the I/O dtype manifest produced by
    /// `cpu_low_precision::prepare_f32_exec_graph` at compile time.
    struct WgpuExecutableWrapper {
        /// Compiled wgpu program; every trait method forwards to it.
        inner: WgpuExecutable,
        /// Passed to `declared_output_dtypes` in `run_typed` to pick output dtypes.
        io_manifest: cpu_low_precision::IoDtypeManifest,
    }

    // SAFETY(review): asserts the wrapper may be moved to another thread. The
    // justification is not visible in this file — confirm against the fields of
    // `WgpuExecutable` before relying on it.
    unsafe impl Send for WgpuExecutableWrapper {}
1527
1528 impl ExecutableGraph for WgpuExecutableWrapper {
1529 fn set_param(&mut self, name: &str, data: &[f32]) {
1530 self.inner.set_param(name, data);
1531 }
1532 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
1533 self.inner.run(inputs)
1534 }
1535 fn run_read_outputs(
1536 &mut self,
1537 inputs: &[(&str, &[f32])],
1538 read_indices: Option<&[usize]>,
1539 ) -> Vec<Vec<f32>> {
1540 self.inner.run_read_outputs(inputs, read_indices)
1541 }
1542 fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
1543 self.inner.bind_gpu_handle(name, data)
1544 }
1545 fn has_gpu_handle(&self, name: &str) -> bool {
1546 self.inner.has_gpu_handle(name)
1547 }
1548 fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
1549 self.inner.set_gpu_handle_feed(handle_name, output_index);
1550 true
1551 }
1552 fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
1553 self.inner.read_gpu_handle(name)
1554 }
1555 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
1556 self.inner.set_active_extent(extent);
1557 }
1558
1559 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1562 match dtype {
1563 rlx_ir::DType::U8 | rlx_ir::DType::I8 => {
1564 self.inner.set_param_bytes(name, data);
1565 }
1566 rlx_ir::DType::F32 => {
1567 let n = data.len() / 4;
1568 let f32_slice =
1569 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
1570 self.inner.set_param(name, f32_slice);
1571 }
1572 rlx_ir::DType::F16 => {
1573 let n = data.len() / 2;
1574 let f16_slice =
1575 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n) };
1576 let f32: Vec<f32> = f16_slice.iter().map(|h| h.to_f32()).collect();
1577 self.inner.set_param(name, &f32);
1578 }
1579 rlx_ir::DType::BF16 => {
1580 let n = data.len() / 2;
1581 let bf16_slice = unsafe {
1582 std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n)
1583 };
1584 let f32: Vec<f32> = bf16_slice.iter().map(|h| h.to_f32()).collect();
1585 self.inner.set_param(name, &f32);
1586 }
1587 other => panic!(
1588 "rlx-wgpu set_param_typed: dtype {other:?} unsupported \
1589 (F32, F16, BF16 only — wgpu arena is f32-uniform)"
1590 ),
1591 }
1592 }
1593
1594 fn run_typed(
1597 &mut self,
1598 inputs: &[(&str, &[u8], rlx_ir::DType)],
1599 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1600 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
1601 for (name, data, dt) in inputs {
1602 let v: Vec<f32> = match *dt {
1603 rlx_ir::DType::F32 => {
1604 let n = data.len() / 4;
1605 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) }
1606 .to_vec()
1607 }
1608 rlx_ir::DType::F16 => {
1609 let n = data.len() / 2;
1610 let s = unsafe {
1611 std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n)
1612 };
1613 s.iter().map(|h| h.to_f32()).collect()
1614 }
1615 rlx_ir::DType::BF16 => {
1616 let n = data.len() / 2;
1617 let s = unsafe {
1618 std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n)
1619 };
1620 s.iter().map(|h| h.to_f32()).collect()
1621 }
1622 other => {
1623 panic!("rlx-wgpu run_typed: input '{name}' dtype {other:?} unsupported")
1624 }
1625 };
1626 owned.push((name.to_string(), v));
1627 }
1628 let refs: Vec<(&str, &[f32])> = owned
1629 .iter()
1630 .map(|(n, d)| (n.as_str(), d.as_slice()))
1631 .collect();
1632 let dtypes =
1633 super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
1634 let outs = self.inner.run(&refs);
1635 outs.into_iter()
1636 .zip(
1637 dtypes
1638 .into_iter()
1639 .chain(std::iter::repeat(rlx_ir::DType::F32)),
1640 )
1641 .map(|(v, dt)| (narrow_to_dtype(&v, dt), dt))
1642 .collect()
1643 }
1644 }
1645
1646 fn narrow_to_dtype(v: &[f32], dt: rlx_ir::DType) -> Vec<u8> {
1652 use rlx_ir::DType;
1653 match dt {
1654 DType::F32 => {
1655 let mut bytes = Vec::with_capacity(v.len() * 4);
1656 for &x in v {
1657 bytes.extend_from_slice(&x.to_le_bytes());
1658 }
1659 bytes
1660 }
1661 DType::F16 => {
1662 let mut bytes = Vec::with_capacity(v.len() * 2);
1663 for &x in v {
1664 bytes.extend_from_slice(&half::f16::from_f32(x).to_le_bytes());
1665 }
1666 bytes
1667 }
1668 DType::BF16 => {
1669 let mut bytes = Vec::with_capacity(v.len() * 2);
1670 for &x in v {
1671 bytes.extend_from_slice(&half::bf16::from_f32(x).to_le_bytes());
1672 }
1673 bytes
1674 }
1675 DType::F64 => {
1676 let mut bytes = Vec::with_capacity(v.len() * 8);
1677 for &x in v {
1678 bytes.extend_from_slice(&(x as f64).to_le_bytes());
1679 }
1680 bytes
1681 }
1682 DType::I8 => v.iter().map(|&x| x as i8 as u8).collect(),
1683 DType::U8 => v.iter().map(|&x| x as u8).collect(),
1684 DType::I16 => {
1685 let mut bytes = Vec::with_capacity(v.len() * 2);
1686 for &x in v {
1687 bytes.extend_from_slice(&(x as i16).to_le_bytes());
1688 }
1689 bytes
1690 }
1691 DType::I32 => {
1692 let mut bytes = Vec::with_capacity(v.len() * 4);
1693 for &x in v {
1694 bytes.extend_from_slice(&(x as i32).to_le_bytes());
1695 }
1696 bytes
1697 }
1698 DType::U32 => {
1699 let mut bytes = Vec::with_capacity(v.len() * 4);
1700 for &x in v {
1701 bytes.extend_from_slice(&(x as u32).to_le_bytes());
1702 }
1703 bytes
1704 }
1705 DType::I64 => {
1706 let mut bytes = Vec::with_capacity(v.len() * 8);
1707 for &x in v {
1708 bytes.extend_from_slice(&(x as i64).to_le_bytes());
1709 }
1710 bytes
1711 }
1712 DType::Bool => v
1713 .iter()
1714 .map(|&x| if x != 0.0 { 1u8 } else { 0u8 })
1715 .collect(),
1716 DType::C64 => {
1723 let mut bytes = Vec::with_capacity(v.len() * 4);
1724 for &x in v {
1725 bytes.extend_from_slice(&x.to_le_bytes());
1726 }
1727 bytes
1728 }
1729 }
1730 }
1731}
1732
1733#[cfg(all(feature = "mlx", rlx_mlx_host))]
1736pub mod mlx_backend {
1737 use super::*;
1738 use rlx_mlx::MlxExecutable;
1739
    /// Unit type that carries the `Backend` implementation for the MLX path.
    pub struct MlxBackend;
1741
    /// Op kinds the MLX backend advertises.
    ///
    /// Returned by `MlxBackend::supported_ops` and passed to staged compilation
    /// and `prepare_fused_graph` as the set of ops that may remain in the graph.
    const MLX_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            StopGradient,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            DenseSolve,
            BatchedDenseSolve,
            LayerNorm,
            LayerNorm2d,
            ResizeNearest2x,
            RmsNorm,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            LoraMatMul,
            DequantMatMul,
            SelectiveScan,
            GatedDeltaNet,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            FusedAttentionBlock,
            FusedTransformerLayer,
            If,
            While,
            Scan,
            ScanBackward,
            ScanBackwardXs,
            ReluBackward,
            ActivationBackward,
            SoftmaxCrossEntropyWithLogits,
            SoftmaxCrossEntropyBackward,
            AttentionBackward,
            LayerNormBackwardInput,
            LayerNormBackwardGamma,
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            MaxPool2dBackward,
            FakeQuantize,
            FakeQuantizeBackward,
            Custom,
            Fft,
            LogMel,
            LogMelBackward,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
        ]
    };
1852
1853 impl Backend for MlxBackend {
1854 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
1855 MLX_SUPPORTED_OPS
1856 }
1857
1858 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
1859 let compile_result = crate::stages::compile_graph_stages_for_backend(
1860 rlx_driver::Device::Mlx,
1861 graph,
1862 options,
1863 MLX_SUPPORTED_OPS,
1864 );
1865 crate::stages::maybe_log_fusion(&compile_result.fusion);
1866 self.compile_lir(compile_result.lir, options)
1867 }
1868
1869 fn compile_lir(
1870 &self,
1871 lir: LirModule,
1872 options: &CompileOptions,
1873 ) -> Box<dyn ExecutableGraph> {
1874 use rlx_opt::pass::Pass as _;
1875 let mut graph = lir.into_graph();
1876 graph = rlx_opt::LowerControlFlow.run(graph);
1877 let graph = prepare_fused_graph(graph, options, MLX_SUPPORTED_OPS, "mlx");
1878 Box::new(build_mlx_executable(graph))
1879 }
1880 }
1881
1882 fn build_mlx_executable(graph: Graph) -> MlxExecutableWrapper {
1883 let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
1884 let mode = mlx_mode_from_env();
1885 let mut exe = MlxExecutable::compile_from_fused(graph, mode);
1886 if mode == rlx_mlx::lower::MlxMode::Compiled {
1887 if let Err(e) = exe.warm_compile() {
1888 eprintln!(
1889 "[rlx-runtime] MLX warm_compile failed ({e}); first run will pay the trace cost"
1890 );
1891 }
1892 }
1893 MlxExecutableWrapper {
1894 inner: exe,
1895 io_manifest,
1896 }
1897 }
1898
1899 fn mlx_mode_from_env() -> rlx_mlx::lower::MlxMode {
1900 match rlx_ir::env::var("RLX_MLX_MODE").as_deref() {
1901 Some(s) if s.eq_ignore_ascii_case("eager") => rlx_mlx::lower::MlxMode::Eager,
1902 Some(s) if s.eq_ignore_ascii_case("lazy") => rlx_mlx::lower::MlxMode::Lazy,
1903 Some(s) if s.eq_ignore_ascii_case("compiled") => rlx_mlx::lower::MlxMode::Compiled,
1904 _ => rlx_mlx::lower::MlxMode::Compiled,
1905 }
1906 }
1907
    /// Pairs a compiled `MlxExecutable` with the I/O dtype manifest produced by
    /// `cpu_low_precision::prepare_f32_exec_graph` in `build_mlx_executable`.
    struct MlxExecutableWrapper {
        /// Compiled MLX program; every trait method forwards to it.
        inner: MlxExecutable,
        /// Passed to `declared_output_dtypes` in `run_typed` to pick output dtypes.
        io_manifest: cpu_low_precision::IoDtypeManifest,
    }

    // SAFETY(review): asserts the wrapper may be moved to another thread. The
    // justification is not visible in this file — confirm against the fields of
    // `MlxExecutable` before relying on it.
    unsafe impl Send for MlxExecutableWrapper {}
1914
1915 impl ExecutableGraph for MlxExecutableWrapper {
1916 fn set_param(&mut self, name: &str, data: &[f32]) {
1917 self.inner.set_param(name, data);
1918 }
1919 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
1920 self.inner.run(inputs)
1921 }
1922 fn run_read_outputs(
1923 &mut self,
1924 inputs: &[(&str, &[f32])],
1925 read_indices: Option<&[usize]>,
1926 ) -> Vec<Vec<f32>> {
1927 self.inner
1928 .run_read_outputs(inputs, read_indices)
1929 .unwrap_or_else(|e| panic!("MLX run_read_outputs failed: {e}"))
1930 }
1931 fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
1932 self.inner.run_slots(inputs)
1933 }
1934 fn arena_ptr(&self) -> *const u8 {
1935 self.inner.arena_ptr()
1936 }
1937 fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
1938 self.inner.commit_no_wait(inputs);
1939 }
1940 fn sync_pending(&mut self) {
1941 self.inner.sync_pending();
1942 }
1943 fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
1944 self.inner.run_pipelined(input_sets)
1945 }
1946 fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
1947 self.inner.bind_handle(name, data)
1948 }
1949 fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
1950 self.inner.read_handle(name)
1951 }
1952 fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
1953 self.inner.bind_gpu_handle(name, data).is_ok()
1954 }
1955 fn has_gpu_handle(&self, name: &str) -> bool {
1956 self.inner.has_gpu_handle(name)
1957 }
1958 fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
1959 self.inner.set_gpu_handle_feed(handle_name, output_index);
1960 true
1961 }
1962 fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
1963 self.inner.read_gpu_handle(name).ok()
1964 }
1965 fn run_feed_gpu_handle(
1966 &mut self,
1967 inputs: &[(&str, &[f32])],
1968 handle_name: &str,
1969 output_index: usize,
1970 ) -> Option<Vec<f32>> {
1971 self.inner
1972 .run_feed_gpu(inputs, handle_name, output_index)
1973 .ok()
1974 }
1975 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1976 self.inner.set_param_typed(name, data, dtype);
1977 }
1978 fn run_typed(
1979 &mut self,
1980 inputs: &[(&str, &[u8], rlx_ir::DType)],
1981 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1982 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
1983 for (name, data, dt) in inputs {
1984 let v = super::widen_bytes_to_f32(data, *dt);
1985 owned.push((name.to_string(), v));
1986 }
1987 let refs: Vec<(&str, &[f32])> = owned
1988 .iter()
1989 .map(|(n, d)| (n.as_str(), d.as_slice()))
1990 .collect();
1991 let f32_outs = self.inner.run(&refs);
1992 let declared = super::declared_output_dtypes(
1993 &self.io_manifest,
1994 (0..f32_outs.len()).map(|_| rlx_ir::DType::F32).collect(),
1995 );
1996 f32_outs
1997 .into_iter()
1998 .zip(
1999 declared
2000 .into_iter()
2001 .chain(std::iter::repeat(rlx_ir::DType::F32)),
2002 )
2003 .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2004 .collect()
2005 }
2006 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2007 self.inner.set_active_extent(extent);
2008 }
2009 }
2010}
2011
2012#[cfg(all(feature = "metal", target_os = "macos"))]
2013pub mod metal_backend {
2014 use super::*;
2015 use rlx_metal::backend::MetalExecutable;
2016
    /// Unit type that carries the `Backend` implementation for the Metal path
    /// (this module is only built with the `metal` feature on macOS).
    pub struct MetalBackend;
2018
    /// Op kinds the Metal backend advertises.
    ///
    /// Returned by `MetalBackend::supported_ops`, used as the legalization
    /// target, and forwarded to the `MetalExecutable` constructors.
    const METAL_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            StopGradient,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            RmsNorm,
            ResizeNearest2x,
            AxialRope2d,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            MaxPool2dBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            TopK,
            Conv,
            Im2Col,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            GatedDeltaNet,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            Custom,
            Fft,
            LogMel,
            LogMelBackward,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
        ]
    };
2107
2108 impl Backend for MetalBackend {
2109 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2110 METAL_SUPPORTED_OPS
2111 }
2112
2113 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2114 use rlx_opt::pass::Pass as _;
2115 let graph = rlx_opt::LowerControlFlow.run(graph);
2119 let dispatch = options.kernel_dispatch;
2120 let graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
2121 graph,
2122 METAL_SUPPORTED_OPS,
2123 dispatch,
2124 )
2125 .unwrap_or_else(|errors| {
2126 panic!("{}", rlx_opt::format_legalize_error("metal", &errors));
2127 });
2128 let graph = crate::precompile::precompile_cleanup(graph, options);
2129
2130 let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2133 Box::new(MetalExecutableWrapper {
2134 inner: MetalExecutable::compile_with_policy(
2135 graph,
2136 options.policy.clone(),
2137 Some(METAL_SUPPORTED_OPS),
2138 ),
2139 io_manifest,
2140 })
2141 }
2142
2143 fn compile_lir(
2144 &self,
2145 lir: LirModule,
2146 options: &CompileOptions,
2147 ) -> Box<dyn ExecutableGraph> {
2148 use rlx_opt::pass::Pass as _;
2149 let mut graph = lir.into_graph();
2150 graph = rlx_opt::LowerControlFlow.run(graph);
2151 let dispatch = options.kernel_dispatch;
2152 let mut graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
2153 graph,
2154 METAL_SUPPORTED_OPS,
2155 dispatch,
2156 )
2157 .unwrap_or_else(|errors| {
2158 panic!("{}", rlx_opt::format_legalize_error("metal", &errors));
2159 });
2160 graph = crate::precompile::precompile_cleanup(graph, options);
2161 let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2162 Box::new(MetalExecutableWrapper {
2163 inner: MetalExecutable::compile_from_fused(
2164 graph,
2165 options.policy.clone(),
2166 Some(METAL_SUPPORTED_OPS),
2167 ),
2168 io_manifest,
2169 })
2170 }
2171 }
2172
    /// Pairs a compiled `MetalExecutable` with the I/O dtype manifest produced by
    /// `cpu_low_precision::prepare_f32_exec_graph` at compile time.
    struct MetalExecutableWrapper {
        /// Compiled Metal program; every trait method forwards to it.
        inner: MetalExecutable,
        /// Passed to `declared_output_dtypes` in `run_typed` to pick output dtypes.
        io_manifest: cpu_low_precision::IoDtypeManifest,
    }

    // SAFETY(review): asserts the wrapper may be moved to another thread. The
    // justification is not visible in this file — confirm against the fields of
    // `MetalExecutable` before relying on it.
    unsafe impl Send for MetalExecutableWrapper {}
2179
2180 impl ExecutableGraph for MetalExecutableWrapper {
2181 fn set_param(&mut self, name: &str, data: &[f32]) {
2182 self.inner.set_param(name, data);
2183 }
2184 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2185 self.inner.run(inputs)
2186 }
2187 fn run_read_outputs(
2188 &mut self,
2189 inputs: &[(&str, &[f32])],
2190 read_indices: Option<&[usize]>,
2191 ) -> Vec<Vec<f32>> {
2192 self.inner.run_read_outputs(inputs, read_indices)
2193 }
2194 fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
2195 self.inner.bind_gpu_handle(name, data)
2196 }
2197 fn has_gpu_handle(&self, name: &str) -> bool {
2198 self.inner.has_gpu_handle(name)
2199 }
2200 fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
2201 self.inner.set_gpu_handle_feed(handle_name, output_index);
2202 true
2203 }
2204 fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
2205 self.inner.read_gpu_handle(name)
2206 }
2207 fn read_output_row(
2208 &self,
2209 out_idx: usize,
2210 row: usize,
2211 row_inner: usize,
2212 ) -> Option<Vec<f32>> {
2213 Some(self.inner.read_graph_output_row(out_idx, row, row_inner))
2214 }
2215 fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
2216 self.inner.run_slots(inputs)
2217 }
2218 fn arena_ptr(&self) -> *const u8 {
2219 self.inner.arena_ptr()
2220 }
2221 fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
2222 self.inner.commit_no_wait(inputs);
2223 }
2224 fn sync_pending(&mut self) {
2225 self.inner.sync_pending();
2226 }
2227 fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
2228 self.inner.run_pipelined(input_sets)
2229 }
2230 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2231 self.inner.set_active_extent(extent);
2232 }
2233
2234 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2240 if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
2241 self.inner.set_param_bytes(name, data);
2242 return;
2243 }
2244 if dtype == rlx_ir::DType::F32 {
2245 let n = data.len() / 4;
2246 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2247 self.inner.set_param(name, s);
2248 } else {
2249 let f32_buf = super::widen_bytes_to_f32(data, dtype);
2250 self.inner.set_param(name, &f32_buf);
2251 }
2252 }
2253
2254 fn run_typed(
2262 &mut self,
2263 inputs: &[(&str, &[u8], rlx_ir::DType)],
2264 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2265 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2266 for (name, data, dt) in inputs {
2267 let v = super::widen_bytes_to_f32(data, *dt);
2268 owned.push((name.to_string(), v));
2269 }
2270 let refs: Vec<(&str, &[f32])> = owned
2271 .iter()
2272 .map(|(n, d)| (n.as_str(), d.as_slice()))
2273 .collect();
2274 let dtypes =
2275 super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
2276 let f32_outs = self.inner.run(&refs);
2277 let byte_outs = self.inner.output_bytes_per_node();
2278 f32_outs
2279 .into_iter()
2280 .zip(byte_outs.into_iter())
2281 .zip(
2282 dtypes
2283 .into_iter()
2284 .chain(std::iter::repeat(rlx_ir::DType::F32)),
2285 )
2286 .map(|((f32_v, byte_v), dt)| match dt {
2287 rlx_ir::DType::F64 => (byte_v, dt),
2288 _ => (super::narrow_f32_to_bytes(&f32_v, dt), dt),
2289 })
2290 .collect()
2291 }
2292 }
2293}
2294
2295#[cfg(feature = "cuda")]
2298pub mod cuda_backend {
2299 use super::*;
2300 use rlx_cuda::backend::CudaExecutable;
2301
    /// Unit type that carries the `Backend` implementation for the CUDA path
    /// (this module is only built with the `cuda` feature).
    pub struct CudaBackend;
2303
    /// Op kinds the CUDA backend advertises.
    ///
    /// Returned by `CudaBackend::supported_ops` and passed to legalization,
    /// staged compilation and `prepare_fused_graph`.
    ///
    /// NOTE(review): unlike the wgpu, MLX and Metal lists in this file,
    /// `StopGradient` is absent here — confirm that is intentional.
    const CUDA_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            ResizeNearest2x,
            RmsNorm,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            MaxPool2dBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            Custom,
            Fft,
            LogMel,
            LogMelBackward,
            Im2Col,
        ]
    };
2376
2377 impl Backend for CudaBackend {
2378 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2379 CUDA_SUPPORTED_OPS
2380 }
2381
2382 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2383 let graph = rlx_cuda::unfuse::unfuse(graph);
2386 let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, CUDA_SUPPORTED_OPS)
2387 .unwrap_or_else(|errors| {
2388 panic!("{}", rlx_opt::format_legalize_error("cuda", &errors));
2389 });
2390 let graph = crate::precompile::precompile_cleanup(graph, options);
2391 let graph = rlx_opt::LegalizeBroadcast.run(graph);
2393 let compile_result = crate::stages::compile_graph_stages_for_backend(
2395 rlx_driver::Device::Cuda,
2396 graph,
2397 options,
2398 CUDA_SUPPORTED_OPS,
2399 );
2400 crate::stages::maybe_log_fusion(&compile_result.fusion);
2401 let graph = compile_result.lir.into_graph();
2402 let graph = match options.policy.clone() {
2403 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
2404 None => graph,
2405 };
2406 let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2407 Box::new(CudaExecutableWrapper {
2408 inner: CudaExecutable::compile(graph),
2409 io_manifest,
2410 })
2411 }
2412
2413 fn compile_lir(
2414 &self,
2415 lir: LirModule,
2416 options: &CompileOptions,
2417 ) -> Box<dyn ExecutableGraph> {
2418 use rlx_opt::pass::Pass as _;
2419 let graph = rlx_opt::LegalizeBroadcast.run(lir.into_graph());
2420 let (graph, io_manifest) =
2421 cpu_low_precision::prepare_f32_exec_graph(prepare_fused_graph(
2422 rlx_cuda::unfuse::unfuse(graph),
2423 options,
2424 CUDA_SUPPORTED_OPS,
2425 "cuda",
2426 ));
2427 Box::new(CudaExecutableWrapper {
2428 inner: CudaExecutable::compile(graph),
2429 io_manifest,
2430 })
2431 }
2432 }
2433
    /// Pairs a compiled `CudaExecutable` with the I/O dtype manifest produced by
    /// `cpu_low_precision::prepare_f32_exec_graph` at compile time.
    struct CudaExecutableWrapper {
        /// Compiled CUDA program; every trait method forwards to it.
        inner: CudaExecutable,
        /// Passed to `declared_output_dtypes` in `run_typed` to pick output dtypes.
        io_manifest: cpu_low_precision::IoDtypeManifest,
    }

    // SAFETY(review): asserts the wrapper may be moved to another thread. The
    // justification is not visible in this file — confirm against the fields of
    // `CudaExecutable` before relying on it.
    unsafe impl Send for CudaExecutableWrapper {}
2444
2445 impl ExecutableGraph for CudaExecutableWrapper {
2446 fn set_param(&mut self, name: &str, data: &[f32]) {
2447 self.inner.set_param(name, data);
2448 }
2449 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2450 self.inner.run(inputs)
2451 }
2452 fn run_read_outputs(
2453 &mut self,
2454 inputs: &[(&str, &[f32])],
2455 read_indices: Option<&[usize]>,
2456 ) -> Vec<Vec<f32>> {
2457 self.inner.run_read_outputs(inputs, read_indices)
2458 }
2459 fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
2460 self.inner.bind_gpu_handle(name, data)
2461 }
2462 fn has_gpu_handle(&self, name: &str) -> bool {
2463 self.inner.has_gpu_handle(name)
2464 }
2465 fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
2466 self.inner.set_gpu_handle_feed(handle_name, output_index);
2467 true
2468 }
2469 fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
2470 self.inner.read_gpu_handle(name)
2471 }
2472 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2473 self.inner.set_active_extent(extent);
2474 }
2475
2476 fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
2477 self.inner.run_slots(inputs)
2478 }
2479
2480 fn arena_ptr(&self) -> *const u8 {
2481 self.inner.arena_ptr()
2482 }
2483
2484 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2489 if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
2490 self.inner.set_param_bytes(name, data);
2491 return;
2492 }
2493 if dtype == rlx_ir::DType::F32 {
2494 let n = data.len() / 4;
2495 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2496 self.inner.set_param(name, s);
2497 } else {
2498 let f32_buf = super::widen_bytes_to_f32(data, dtype);
2499 self.inner.set_param(name, &f32_buf);
2500 }
2501 }
2502
2503 fn run_typed(
2506 &mut self,
2507 inputs: &[(&str, &[u8], rlx_ir::DType)],
2508 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2509 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2510 for (name, data, dt) in inputs {
2511 let v = super::widen_bytes_to_f32(data, *dt);
2512 owned.push((name.to_string(), v));
2513 }
2514 let refs: Vec<(&str, &[f32])> = owned
2515 .iter()
2516 .map(|(n, d)| (n.as_str(), d.as_slice()))
2517 .collect();
2518 let dtypes =
2519 super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
2520 let outs = self.inner.run(&refs);
2521 outs.into_iter()
2522 .zip(
2523 dtypes
2524 .into_iter()
2525 .chain(std::iter::repeat(rlx_ir::DType::F32)),
2526 )
2527 .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2528 .collect()
2529 }
2530 }
2531}
2532
2533#[cfg(feature = "rocm")]
2536pub mod rocm_backend {
2537 use super::*;
2538 use rlx_rocm::backend::RocmExecutable;
2539
    /// Unit type that carries the `Backend` implementation for the ROCm path
    /// (this module is only built with the `rocm` feature).
    pub struct RocmBackend;
2541
    /// Op kinds the ROCm backend advertises.
    ///
    /// Returned by `RocmBackend::supported_ops` and passed to legalization,
    /// staged compilation and `prepare_fused_graph`.
    ///
    /// NOTE(review): `StopGradient` is absent here although the wgpu, MLX and
    /// Metal lists in this file include it, and the `Conv2dBackward*` /
    /// `MaxPool2dBackward` entries of the CUDA list are absent as well —
    /// confirm both are intentional.
    const ROCM_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            ResizeNearest2x,
            RmsNorm,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            Custom,
            Fft,
            LogMel,
            LogMelBackward,
            Im2Col,
        ]
    };
2608
2609 impl Backend for RocmBackend {
2610 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2611 ROCM_SUPPORTED_OPS
2612 }
2613
2614 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2615 let graph = rlx_rocm::unfuse::unfuse(graph);
2616 let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, ROCM_SUPPORTED_OPS)
2617 .unwrap_or_else(|errors| {
2618 panic!("{}", rlx_opt::format_legalize_error("rocm", &errors));
2619 });
2620 let graph = crate::precompile::precompile_cleanup(graph, options);
2621 let graph = rlx_opt::LegalizeBroadcast.run(graph);
2622 let compile_result = crate::stages::compile_graph_stages_for_backend(
2623 rlx_driver::Device::Rocm,
2624 graph,
2625 options,
2626 ROCM_SUPPORTED_OPS,
2627 );
2628 crate::stages::maybe_log_fusion(&compile_result.fusion);
2629 let graph = compile_result.lir.into_graph();
2630 let graph = match options.policy.clone() {
2631 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
2632 None => graph,
2633 };
2634 let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2635 Box::new(RocmExecutableWrapper {
2636 inner: RocmExecutable::compile(graph),
2637 io_manifest,
2638 })
2639 }
2640
2641 fn compile_lir(
2642 &self,
2643 lir: LirModule,
2644 options: &CompileOptions,
2645 ) -> Box<dyn ExecutableGraph> {
2646 let (graph, io_manifest) =
2647 cpu_low_precision::prepare_f32_exec_graph(prepare_fused_graph(
2648 rlx_rocm::unfuse::unfuse(lir.into_graph()),
2649 options,
2650 ROCM_SUPPORTED_OPS,
2651 "rocm",
2652 ));
2653 Box::new(RocmExecutableWrapper {
2654 inner: RocmExecutable::compile(graph),
2655 io_manifest,
2656 })
2657 }
2658 }
2659
    /// Pairs a compiled `RocmExecutable` with the I/O dtype manifest produced by
    /// `cpu_low_precision::prepare_f32_exec_graph` at compile time.
    struct RocmExecutableWrapper {
        /// Compiled ROCm program; every trait method forwards to it.
        inner: RocmExecutable,
        /// Passed to `declared_output_dtypes` in `run_typed` to pick output dtypes.
        io_manifest: cpu_low_precision::IoDtypeManifest,
    }

    // SAFETY(review): asserts the wrapper may be moved to another thread. The
    // justification is not visible in this file — confirm against the fields of
    // `RocmExecutable` before relying on it.
    unsafe impl Send for RocmExecutableWrapper {}
2669
2670 impl ExecutableGraph for RocmExecutableWrapper {
2671 fn set_param(&mut self, name: &str, data: &[f32]) {
2672 self.inner.set_param(name, data);
2673 }
2674 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2675 self.inner.run(inputs)
2676 }
2677 fn run_read_outputs(
2678 &mut self,
2679 inputs: &[(&str, &[f32])],
2680 read_indices: Option<&[usize]>,
2681 ) -> Vec<Vec<f32>> {
2682 self.inner.run_read_outputs(inputs, read_indices)
2683 }
2684 fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
2685 self.inner.bind_gpu_handle(name, data)
2686 }
2687 fn has_gpu_handle(&self, name: &str) -> bool {
2688 self.inner.has_gpu_handle(name)
2689 }
2690 fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
2691 self.inner.set_gpu_handle_feed(handle_name, output_index);
2692 true
2693 }
2694 fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
2695 self.inner.read_gpu_handle(name)
2696 }
2697 fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
2698 self.inner.run_slots(inputs)
2699 }
2700 fn arena_ptr(&self) -> *const u8 {
2701 self.inner.arena_ptr()
2702 }
2703 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2704 self.inner.set_active_extent(extent);
2705 }
2706
2707 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2712 if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
2713 self.inner.set_param_bytes(name, data);
2714 return;
2715 }
2716 if dtype == rlx_ir::DType::F32 {
2717 let n = data.len() / 4;
2718 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2719 self.inner.set_param(name, s);
2720 } else {
2721 let f32_buf = super::widen_bytes_to_f32(data, dtype);
2722 self.inner.set_param(name, &f32_buf);
2723 }
2724 }
2725
2726 fn run_typed(
2729 &mut self,
2730 inputs: &[(&str, &[u8], rlx_ir::DType)],
2731 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2732 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2733 for (name, data, dt) in inputs {
2734 let v = super::widen_bytes_to_f32(data, *dt);
2735 owned.push((name.to_string(), v));
2736 }
2737 let refs: Vec<(&str, &[f32])> = owned
2738 .iter()
2739 .map(|(n, d)| (n.as_str(), d.as_slice()))
2740 .collect();
2741 let dtypes =
2742 super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
2743 let outs = self.inner.run(&refs);
2744 outs.into_iter()
2745 .zip(
2746 dtypes
2747 .into_iter()
2748 .chain(std::iter::repeat(rlx_ir::DType::F32)),
2749 )
2750 .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2751 .collect()
2752 }
2753 }
2754}
2755
2756#[cfg(feature = "tpu")]
2759pub mod tpu_backend {
2760 use super::*;
2761 use rlx_tpu::TpuExecutable;
2762
2763 pub struct TpuBackend;
2764
2765 const TPU_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
2771 use rlx_ir::OpKind::*;
2772 &[
2773 Input,
2774 Param,
2775 Constant,
2776 Activation,
2777 Cast,
2778 Binary,
2779 Compare,
2780 Where,
2781 ElementwiseRegion,
2782 TransformRegion,
2783 BatchElementwiseRegion,
2784 MatMul,
2785 DotGeneral,
2786 LayerNorm,
2787 RmsNorm,
2788 Attention,
2789 Rope,
2790 Reshape,
2791 Transpose,
2792 Narrow,
2793 Concat,
2794 Expand,
2795 Gather,
2796 Reduce,
2797 Softmax,
2798 Cumsum,
2799 TopK,
2800 Sample,
2801 Conv,
2802 Pool,
2803 GroupedMatMul,
2804 DequantGroupedMatMul,
2805 DequantMoEWeights,
2806 ScatterAdd,
2807 DequantMatMul,
2808 SelectiveScan,
2809 QMatMul,
2811 QConv2d,
2812 Quantize,
2813 Dequantize,
2814 FusedMatMulBiasAct,
2815 FusedResidualLN,
2816 FusedResidualRmsNorm,
2817 Fft,
2818 LogMel,
2819 LogMelBackward,
2820 ]
2822 };
2823
2824 impl Backend for TpuBackend {
2825 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2826 TPU_SUPPORTED_OPS
2827 }
2828
2829 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2830 let graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
2831 graph,
2832 TPU_SUPPORTED_OPS,
2833 options.kernel_dispatch,
2834 )
2835 .unwrap_or_else(|errors| {
2836 panic!("{}", rlx_opt::format_legalize_error("tpu", &errors));
2837 });
2838 use rlx_opt::pass::Pass as _;
2854 let policy = options
2855 .policy
2856 .clone()
2857 .unwrap_or(rlx_opt::PrecisionPolicy::AutoMixedBf16);
2858 let graph = rlx_opt::AutoMixedPrecision::new(policy).run(graph);
2859 let _ = options.dce;
2860 let _ = options.constant_folding;
2861 Box::new(TpuExecutableWrapper {
2862 inner: TpuExecutable::compile(graph),
2863 })
2864 }
2865 }
2866
2867 struct TpuExecutableWrapper {
2868 inner: TpuExecutable,
2869 }
2870
2871 unsafe impl Send for TpuExecutableWrapper {}
2875
2876 impl ExecutableGraph for TpuExecutableWrapper {
2877 fn set_param(&mut self, name: &str, data: &[f32]) {
2878 self.inner.set_param(name, data);
2879 }
2880 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2881 self.inner.run(inputs)
2882 }
2883
2884 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2889 if dtype == rlx_ir::DType::F32 {
2890 let n = data.len() / 4;
2891 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2892 self.inner.set_param(name, s);
2893 } else {
2894 let f32_buf = super::widen_bytes_to_f32(data, dtype);
2895 self.inner.set_param(name, &f32_buf);
2896 }
2897 }
2898
2899 fn run_typed(
2900 &mut self,
2901 inputs: &[(&str, &[u8], rlx_ir::DType)],
2902 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2903 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2904 for (name, data, dt) in inputs {
2905 let v = super::widen_bytes_to_f32(data, *dt);
2906 owned.push((name.to_string(), v));
2907 }
2908 let refs: Vec<(&str, &[f32])> = owned
2909 .iter()
2910 .map(|(n, d)| (n.as_str(), d.as_slice()))
2911 .collect();
2912 let dtypes = self.inner.output_dtypes();
2913 let outs = self.inner.run(&refs);
2914 outs.into_iter()
2915 .zip(
2916 dtypes
2917 .into_iter()
2918 .chain(std::iter::repeat(rlx_ir::DType::F32)),
2919 )
2920 .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2921 .collect()
2922 }
2923 }
2924}