1use crate::CompileOptions;
23use rlx_ir::Graph;
24use rlx_ir::hir::HirModule;
25use rlx_ir::lir::LirModule;
26use std::collections::HashMap;
27use std::sync::Arc;
28
29#[allow(dead_code)]
36pub(crate) fn widen_bytes_to_f32(data: &[u8], dtype: rlx_ir::DType) -> Vec<f32> {
37 use rlx_ir::DType;
38 match dtype {
39 DType::F32 => {
40 let n = data.len() / 4;
41 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
42 s.to_vec()
43 }
44 DType::F16 => {
45 let n = data.len() / 2;
46 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n) };
47 s.iter().map(|h| h.to_f32()).collect()
48 }
49 DType::BF16 => {
50 let n = data.len() / 2;
51 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n) };
52 s.iter().map(|h| h.to_f32()).collect()
53 }
54 other => panic!(
55 "widen_bytes_to_f32: dtype {other:?} unsupported on f32-arena backends \
56 (only F32/F16/BF16 are accepted on the host I/O surface)"
57 ),
58 }
59}
60
61#[allow(dead_code)]
66pub(crate) fn narrow_f32_to_bytes(v: &[f32], dt: rlx_ir::DType) -> Vec<u8> {
67 use rlx_ir::DType;
68 match dt {
69 DType::F32 => {
70 let mut bytes = Vec::with_capacity(v.len() * 4);
71 for &x in v {
72 bytes.extend_from_slice(&x.to_le_bytes());
73 }
74 bytes
75 }
76 DType::F16 => {
77 let mut bytes = Vec::with_capacity(v.len() * 2);
78 for &x in v {
79 bytes.extend_from_slice(&half::f16::from_f32(x).to_le_bytes());
80 }
81 bytes
82 }
83 DType::BF16 => {
84 let mut bytes = Vec::with_capacity(v.len() * 2);
85 for &x in v {
86 bytes.extend_from_slice(&half::bf16::from_f32(x).to_le_bytes());
87 }
88 bytes
89 }
90 DType::F64 => {
91 let mut bytes = Vec::with_capacity(v.len() * 8);
92 for &x in v {
93 bytes.extend_from_slice(&(x as f64).to_le_bytes());
94 }
95 bytes
96 }
97 DType::I8 => v.iter().map(|&x| x as i8 as u8).collect(),
98 DType::U8 => v.iter().map(|&x| x as u8).collect(),
99 DType::I16 => {
100 let mut bytes = Vec::with_capacity(v.len() * 2);
101 for &x in v {
102 bytes.extend_from_slice(&(x as i16).to_le_bytes());
103 }
104 bytes
105 }
106 DType::I32 => {
107 let mut bytes = Vec::with_capacity(v.len() * 4);
108 for &x in v {
109 bytes.extend_from_slice(&(x as i32).to_le_bytes());
110 }
111 bytes
112 }
113 DType::U32 => {
114 let mut bytes = Vec::with_capacity(v.len() * 4);
115 for &x in v {
116 bytes.extend_from_slice(&(x as u32).to_le_bytes());
117 }
118 bytes
119 }
120 DType::I64 => {
121 let mut bytes = Vec::with_capacity(v.len() * 8);
122 for &x in v {
123 bytes.extend_from_slice(&(x as i64).to_le_bytes());
124 }
125 bytes
126 }
127 DType::Bool => v
128 .iter()
129 .map(|&x| if x != 0.0 { 1u8 } else { 0u8 })
130 .collect(),
131 DType::C64 => {
132 let mut bytes = Vec::with_capacity(v.len() * 8);
136 for &x in v {
137 bytes.extend_from_slice(&x.to_le_bytes());
138 bytes.extend_from_slice(&0.0_f32.to_le_bytes());
139 }
140 bytes
141 }
142 }
143}
144
145pub trait ExecutableGraph: Send {
147 fn set_param(&mut self, name: &str, data: &[f32]);
149
150 fn clone_box(&self) -> Box<dyn ExecutableGraph> {
157 panic!("clone_box not implemented for this backend");
158 }
159
160 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>>;
162
163 fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
165 let vecs = self.run(inputs);
166 vecs.iter().map(|v| (v.as_ptr(), v.len())).collect()
167 }
168
169 fn run_slots(&mut self, _inputs: &[&[f32]]) -> &[(usize, usize)] {
172 &[] }
174
175 fn arena_ptr(&self) -> *const u8 {
177 std::ptr::null()
178 }
179
180 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
197 let _ = extent;
198 }
199
200 fn set_moe_resident_experts(&mut self, _mask: &[bool]) {}
202
203 fn set_moe_resident_experts_per_layer(&mut self, _masks: &[&[bool]]) {}
205
206 fn enable_moe_topk_capture(&mut self, _num_experts: usize) -> bool {
208 false
209 }
210
211 fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
213 None
214 }
215
216 fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
218 None
219 }
220
221 fn bind_handle(&mut self, _name: &str, _data: &[f32]) -> bool {
225 false
226 }
227
228 fn read_handle(&self, _name: &str) -> Option<Vec<f32>> {
230 None
231 }
232
233 fn bind_gpu_handle(&mut self, _name: &str, _data: &[f32]) -> bool {
235 false
236 }
237
238 fn has_gpu_handle(&self, _name: &str) -> bool {
239 false
240 }
241
242 fn set_gpu_handle_feed(&mut self, _handle_name: &str, _output_index: usize) -> bool {
243 false
244 }
245
246 fn read_gpu_handle(&self, _name: &str) -> Option<Vec<f32>> {
247 None
248 }
249
250 fn run_feed_gpu_handle(
252 &mut self,
253 inputs: &[(&str, &[f32])],
254 _handle_name: &str,
255 _output_index: usize,
256 ) -> Option<Vec<f32>> {
257 let _ = inputs;
258 None
259 }
260
261 fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
276 let _ = self.run(inputs);
277 }
278
279 fn sync_pending(&mut self) {}
282
283 fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
292 input_sets.iter().map(|inputs| self.run(inputs)).collect()
293 }
294
295 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
308 if dtype != rlx_ir::DType::F32 {
309 panic!(
310 "backend's default set_param_typed only handles F32; \
311 got {dtype:?}. Override on the backend for typed support."
312 );
313 }
314 if !data.len().is_multiple_of(4) {
315 panic!(
316 "set_param_typed F32: data length {} not a multiple of 4",
317 data.len()
318 );
319 }
320 let n = data.len() / 4;
325 let f32_slice = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
326 self.set_param(name, f32_slice);
327 }
328
329 fn run_typed(
333 &mut self,
334 inputs: &[(&str, &[u8], rlx_ir::DType)],
335 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
336 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
339 for (name, data, dt) in inputs {
340 if *dt != rlx_ir::DType::F32 {
341 panic!(
342 "backend's default run_typed only handles F32 inputs; \
343 got {dt:?} for input '{name}'"
344 );
345 }
346 if data.len() % 4 != 0 {
347 panic!(
348 "run_typed F32 input '{name}': len {} not multiple of 4",
349 data.len()
350 );
351 }
352 let n = data.len() / 4;
353 let v: Vec<f32> =
354 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) }.to_vec();
355 owned.push((name.to_string(), v));
356 }
357 let refs: Vec<(&str, &[f32])> = owned
358 .iter()
359 .map(|(n, d)| (n.as_str(), d.as_slice()))
360 .collect();
361 let outs = self.run(&refs);
362 outs.into_iter()
363 .map(|v| {
364 let bytes =
365 unsafe { std::slice::from_raw_parts(v.as_ptr() as *const u8, v.len() * 4) }
366 .to_vec();
367 (bytes, rlx_ir::DType::F32)
368 })
369 .collect()
370 }
371}
372
373pub trait Backend: Send + Sync {
383 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph>;
385
386 fn compile_lir(&self, lir: LirModule, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
390 self.compile(lir.into_graph(), options)
391 }
392
393 fn compile_hir(
395 &self,
396 hir: HirModule,
397 device: rlx_driver::Device,
398 options: &CompileOptions,
399 ) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
400 let result = crate::stages::compile_hir_stages(device, hir, options)?;
401 crate::stages::maybe_log_fusion(&result.fusion);
402 Ok(self.compile_lir(result.lir, options))
403 }
404
405 fn compile_module(
407 &self,
408 module: rlx_ir::GraphModule,
409 device: rlx_driver::Device,
410 options: &CompileOptions,
411 ) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
412 let result = crate::stages::compile_module_stages(device, module, options)?;
413 crate::stages::maybe_log_fusion(&result.fusion);
414 Ok(self.compile_lir(result.lir, options))
415 }
416
417 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
424 &[]
425 }
426}
427
428#[allow(dead_code)]
431fn prepare_fused_graph(
432 graph: Graph,
433 options: &CompileOptions,
434 supported_ops: &[rlx_ir::OpKind],
435 backend_name: &str,
436) -> Graph {
437 let (mut graph, report) = rlx_opt::prepare_graph_for_backend_with_report(
438 graph,
439 backend_name,
440 supported_ops,
441 options.kernel_dispatch,
442 );
443 rlx_opt::maybe_log_dispatch_report(&report);
444 if !report.compile_ready {
445 panic!(
446 "{}\n{}",
447 rlx_opt::format_legalize_error(backend_name, &report.still_unsupported),
448 rlx_opt::format_dispatch_report(&report)
449 );
450 }
451 use rlx_opt::pass::Pass as _;
452 if options.dce {
453 graph = rlx_opt::DeadCodeElimination.run(graph);
454 }
455 if options.constant_folding {
456 graph = rlx_opt::ConstantFolding.run(graph);
457 }
458 if let Some(p) = options.policy.clone() {
459 graph = rlx_opt::AutoMixedPrecision::new(p).run(graph);
460 }
461 graph
462}
463
464pub fn compile(backend: &dyn Backend, graph: Graph) -> Box<dyn ExecutableGraph> {
472 backend.compile(graph, &CompileOptions::default())
473}
474
475pub fn compile_hir(
477 backend: &dyn Backend,
478 hir: HirModule,
479 device: rlx_driver::Device,
480 options: &CompileOptions,
481) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
482 backend.compile_hir(hir, device, options)
483}
484
485pub fn compile_module(
487 backend: &dyn Backend,
488 module: rlx_ir::GraphModule,
489 device: rlx_driver::Device,
490 options: &CompileOptions,
491) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
492 backend.compile_module(module, device, options)
493}
494
495pub fn compile_with_precision(
497 backend: &dyn Backend,
498 graph: Graph,
499 precision: crate::Precision,
500) -> Box<dyn ExecutableGraph> {
501 backend.compile(graph, &CompileOptions::new().precision(precision))
502}
503
504fn _legacy_apply_policy(graph: Graph, policy: Option<rlx_opt::PrecisionPolicy>) -> Graph {
509 use rlx_opt::pass::Pass as _;
510 match policy {
511 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
512 None => graph,
513 }
514}
515
516#[cfg(feature = "cpu")]
519pub mod cpu_backend {
520 use super::*;
521 use rlx_cpu::{arena::Arena, thunk};
522 use rlx_ir::{DType, NodeId, Op};
523 use rlx_opt::memory::{self, MemoryPlan};
524 use rlx_driver::arena::{read_typed_to_f32, write_typed_from_f32};
527
    /// Stateless marker type implementing [`Backend`] for the CPU.
    pub struct CpuBackend;
529
    /// Op kinds the CPU backend declares as supported.
    ///
    /// Returned by `CpuBackend::supported_ops` and passed to legalization and
    /// to the graph compile stages in `CpuBackend::compile`.
    const CPU_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            DenseSolve,
            BatchedDenseSolve,
            Scan,
            ScanBackward,
            ScanBackwardXs,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            RmsNorm,
            ResizeNearest2x,
            AxialRope2d,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            LoraMatMul,
            DequantMatMul,
            SelectiveScan,
            GatedDeltaNet,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            FusedAttentionBlock,
            // Backward / training-related kinds.
            ReluBackward,
            ActivationBackward,
            FakeQuantize,
            FakeQuantizeBackward,
            MaxPool2dBackward,
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            SoftmaxCrossEntropyWithLogits,
            SoftmaxCrossEntropyBackward,
            AttentionBackward,
            LayerNormBackwardInput,
            LayerNormBackwardGamma,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            Custom,
            CustomFn,
            Fft,
            ComplexNormSq,
            ComplexNormSqBackward,
            Conjugate,
        ]
    };
638
639 impl Backend for CpuBackend {
640 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
641 CPU_SUPPORTED_OPS
642 }
643
644 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
645 use rlx_opt::pass::Pass as _;
646 let graph = rlx_opt::LowerControlFlow.run(graph);
652 if let Err(errors) = rlx_opt::legalize_for_backend(&graph, CPU_SUPPORTED_OPS) {
656 panic!("{}", rlx_opt::format_legalize_error("cpu", &errors));
657 }
658 let policy = options.policy.clone();
659 let _precision = options.precision;
660 let cfg = rlx_cpu::config::RuntimeConfig::global();
661
662 let graph = if options.dce {
664 rlx_opt::DeadCodeElimination.run(graph)
665 } else {
666 graph
667 };
668 let graph = if options.constant_folding {
669 rlx_opt::ConstantFolding.run(graph)
670 } else {
671 graph
672 };
673
674 let mut compile_opts = options.clone();
676 compile_opts.arena_alignment = cfg.arena_alignment;
677 let compile_result = crate::stages::compile_graph_stages_for_backend(
678 rlx_driver::Device::Cpu,
679 graph,
680 &compile_opts,
681 CPU_SUPPORTED_OPS,
682 );
683 crate::stages::maybe_log_fusion(&compile_result.fusion);
684 let fused = compile_result.lir.into_graph();
685
686 let fused = match policy {
689 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(fused),
690 None => fused,
691 };
692
693 let plan = memory::plan_memory_aligned(&fused, cfg.arena_alignment);
695 if cfg.verbose >= 1 {
696 eprintln!(
697 "[rlx] arena: {} bytes, {} buffers, alignment: {}",
698 plan.arena_size,
699 plan.assignments.len(),
700 cfg.arena_alignment
701 );
702 }
703 Box::new(build_cpu_executable(fused, plan))
704 }
705
706 fn compile_lir(
707 &self,
708 lir: LirModule,
709 options: &CompileOptions,
710 ) -> Box<dyn ExecutableGraph> {
711 let alignment = lir.buffers.alignment.max(options.arena_alignment);
712 let embedded: MemoryPlan = (&lir.buffers).into();
713 let mut graph = lir.into_graph();
714 if let Some(p) = options.policy.clone() {
715 use rlx_opt::pass::Pass;
716 graph = rlx_opt::AutoMixedPrecision::new(p).run(graph);
717 }
718 let plan = if options.policy.is_some() {
719 memory::plan_memory_aligned(&graph, alignment)
720 } else {
721 embedded
722 };
723 let cfg = rlx_cpu::config::RuntimeConfig::global();
724 if cfg.verbose >= 1 {
725 eprintln!(
726 "[rlx] compile_lir: arena {} bytes ({} buffers, alignment {})",
727 plan.arena_size,
728 plan.assignments.len(),
729 alignment,
730 );
731 }
732 Box::new(build_cpu_executable(graph, plan))
733 }
734 }
735
736 fn build_cpu_executable(graph: Graph, plan: MemoryPlan) -> CpuExecutable {
737 let mut arena = Arena::from_plan(plan);
738 let mut input_ids = HashMap::new();
739 let mut param_ids = HashMap::new();
740 let mut node_dtypes: HashMap<NodeId, DType> = HashMap::new();
741 for node in graph.nodes() {
742 node_dtypes.insert(node.id, node.shape.dtype());
743 match &node.op {
744 Op::Input { name } => {
745 input_ids.insert(name.clone(), node.id);
746 }
747 Op::Param { name } => {
748 param_ids.insert(name.clone(), node.id);
749 }
750 _ => {}
751 }
752 }
753
754 let schedule = thunk::compile_thunks(&graph, &arena);
755
756 let mut input_slots = Vec::new();
757 for node in graph.nodes() {
758 if let Op::Input { name } = &node.op {
759 let off = arena.byte_offset(node.id);
760 let len = node.shape.num_elements().unwrap_or(0);
761 input_slots.push((name.clone(), off, len, node.shape.dtype()));
762 }
763 }
764
765 let output_slots: Vec<(usize, usize)> = graph
766 .outputs
767 .iter()
768 .map(|&id| {
769 let off = arena.byte_offset(id);
770 let len = graph.node(id).shape.num_elements().unwrap_or(0);
771 (off, len)
772 })
773 .collect();
774
775 for node in graph.nodes() {
776 if let Op::Constant { data } = &node.op
777 && arena.has_buffer(node.id)
778 && !data.is_empty()
779 {
780 match node.shape.dtype() {
781 DType::F64 => {
782 let off = arena.byte_offset(node.id);
783 let buf = arena.raw_buf_mut();
784 let n = buf.len().saturating_sub(off).min(data.len());
785 buf[off..off + n].copy_from_slice(&data[..n]);
786 }
787 _ => {
788 let buf = arena.slice_mut(node.id);
789 let n_floats = data.len() / 4;
790 let n = buf.len().min(n_floats);
791 for i in 0..n {
792 let bytes = [
793 data[i * 4],
794 data[i * 4 + 1],
795 data[i * 4 + 2],
796 data[i * 4 + 3],
797 ];
798 buf[i] = f32::from_le_bytes(bytes);
799 }
800 }
801 }
802 }
803 }
804
805 CpuExecutable {
806 graph,
807 arena,
808 params: HashMap::new(),
809 input_ids,
810 param_ids,
811 node_dtypes,
812 schedule,
813 input_slots,
814 output_slots,
815 handles: HashMap::new(),
816 active_extent: None,
817 moe_resident: None,
818 moe_resident_layers: None,
819 moe_topk_capture: None,
820 }
821 }
822
    /// CPU implementation of [`ExecutableGraph`]: a graph, the arena holding
    /// its buffers, and the thunk schedule that executes it.
    #[derive(Clone)]
    struct CpuExecutable {
        /// The compiled graph; consulted for outputs, shapes and dtypes.
        graph: Graph,
        /// Backing storage for node buffers.
        arena: Arena,
        /// Parameters set by name that had no arena buffer to write into.
        params: HashMap<String, Vec<f32>>,
        /// Input name → node id.
        input_ids: HashMap<String, NodeId>,
        /// Parameter name → node id.
        param_ids: HashMap<String, NodeId>,
        /// Element dtype of every node.
        node_dtypes: HashMap<NodeId, DType>,
        /// Thunk schedule produced by `thunk::compile_thunks`.
        schedule: thunk::ThunkSchedule,
        /// Per input: (name, arena byte offset, element count, dtype).
        input_slots: Vec<(String, usize, usize, DType)>,
        /// Per output: (arena byte offset, element count).
        output_slots: Vec<(usize, usize)>,
        /// Named buffers: written into same-named inputs before `run`, and
        /// entries named `out{idx}` are refreshed from outputs after `run`.
        handles: HashMap<String, Vec<f32>>,
        /// When set, `run` first tries `thunk::execute_thunks_active` with
        /// this `(actual, upper)` pair.
        active_extent: Option<(usize, usize)>,
        /// Single MoE residency mask; mirrored into `schedule`.
        moe_resident: Option<std::sync::Arc<[bool]>>,
        /// Per-layer MoE residency masks; mirrored into `schedule`.
        moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
        /// MoE top-k capture shared with `schedule`.
        moe_topk_capture: Option<std::sync::Arc<rlx_cpu::moe_topk_capture::MoeTopkCapture>>,
    }

    // NOTE(review): `Send` is asserted manually, presumably because a field
    // (e.g. `Arena` or `ThunkSchedule`) is not automatically `Send` — confirm
    // the invariant that makes moving this value across threads sound.
    unsafe impl Send for CpuExecutable {}
855
856 impl CpuExecutable {
857 fn write_input(&mut self, id: NodeId, data: &[f32]) {
859 let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
860 let off = self.arena.byte_offset(id);
861 let buf = self.arena.raw_buf_mut();
862 let elem_size = dtype.size_bytes();
863 let max_elems = (buf.len() - off) / elem_size;
864 unsafe {
865 write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, data, max_elems);
866 }
867 }
868
869 fn read_output(&self, id: NodeId) -> Vec<f32> {
871 let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
872 let off = self.arena.byte_offset(id);
873 let buf = self.arena.raw_buf();
874 let n_elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
875 unsafe { read_typed_to_f32(buf.as_ptr().add(off), dtype, n_elems) }
876 }
877 }
878
879 impl ExecutableGraph for CpuExecutable {
880 fn clone_box(&self) -> Box<dyn ExecutableGraph> {
881 Box::new(self.clone())
882 }
883 fn set_param(&mut self, name: &str, data: &[f32]) {
884 if let Some(&id) = self.param_ids.get(name)
887 && self.arena.has_buffer(id)
888 {
889 let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
890 let off = self.arena.byte_offset(id);
891 let buf = self.arena.raw_buf_mut();
892 let elem_size = dtype.size_bytes();
893 let max_elems = (buf.len() - off) / elem_size;
894 unsafe {
895 write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, data, max_elems);
896 }
897 return;
898 }
899 self.params.insert(name.to_string(), data.to_vec());
901 }
902
903 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
904 let handle_names: Vec<String> = self.handles.keys().cloned().collect();
907 for name in &handle_names {
908 if let Some(&id) = self.input_ids.get(name)
909 && self.arena.has_buffer(id)
910 {
911 let data = self.handles.get(name).cloned().unwrap_or_default();
912 self.write_input(id, &data);
913 }
914 }
915 for &(name, data) in inputs {
917 if let Some(&id) = self.input_ids.get(name)
918 && self.arena.has_buffer(id)
919 {
920 self.write_input(id, data);
921 }
922 }
923
924 let active_used = if let Some((actual, upper)) = self.active_extent {
929 thunk::execute_thunks_active(
930 &self.schedule,
931 self.arena.raw_buf_mut(),
932 actual,
933 upper,
934 )
935 } else {
936 false
937 };
938 if !active_used {
939 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
941 }
942
943 for (idx, &out_id) in self.graph.outputs.iter().enumerate() {
947 let name = format!("out{idx}");
948 if self.handles.contains_key(&name) {
949 let v = self.read_output(out_id);
950 self.handles.insert(name, v);
951 }
952 }
953
954 self.graph
955 .outputs
956 .iter()
957 .map(|&out_id| self.read_output(out_id))
958 .collect()
959 }
960
961 fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
962 for &(name, data) in inputs {
964 if let Some(&id) = self.input_ids.get(name)
965 && self.arena.has_buffer(id)
966 {
967 self.write_input(id, data);
968 }
969 }
970 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
971 self.graph
975 .outputs
976 .iter()
977 .map(|&out_id| {
978 let (ptr, len) = self.arena.raw_ptr(out_id);
979 (ptr as *const f32, len)
980 })
981 .collect()
982 }
983
984 fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
988 let buf = self.arena.raw_buf_mut();
989 for (i, &data) in inputs.iter().enumerate() {
990 if i < self.input_slots.len() {
991 let (_, off, max_len, dtype) = &self.input_slots[i];
992 unsafe {
993 write_typed_from_f32(buf.as_mut_ptr().add(*off), *dtype, data, *max_len);
994 }
995 }
996 }
997 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
998 &self.output_slots
999 }
1000
1001 fn arena_ptr(&self) -> *const u8 {
1002 self.arena.raw_buf_mut_ptr()
1003 }
1004
1005 fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
1006 self.handles.insert(name.to_string(), data.to_vec());
1011 true
1012 }
1013
1014 fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
1015 self.handles.get(name).cloned()
1016 }
1017
1018 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
1019 self.active_extent = extent;
1020 }
1021
1022 fn set_moe_resident_experts(&mut self, mask: &[bool]) {
1023 self.moe_resident_layers = None;
1024 self.schedule.moe_resident_layers = None;
1025 self.moe_resident = Some(Arc::from(mask));
1026 self.schedule.moe_resident = self.moe_resident.clone();
1027 }
1028
1029 fn set_moe_resident_experts_per_layer(&mut self, masks: &[&[bool]]) {
1030 self.moe_resident = None;
1031 self.schedule.moe_resident = None;
1032 let layers: Vec<Arc<[bool]>> = masks.iter().map(|m| Arc::from(*m)).collect();
1033 let arc = Arc::new(layers);
1034 self.moe_resident_layers = Some(arc.clone());
1035 self.schedule.moe_resident_layers = Some(arc);
1036 }
1037
1038 fn enable_moe_topk_capture(&mut self, num_experts: usize) -> bool {
1039 let cap = rlx_cpu::moe_topk_capture::MoeTopkCapture::new(num_experts);
1040 self.moe_topk_capture = Some(cap.clone());
1041 self.schedule.moe_topk_capture = Some(cap);
1042 true
1043 }
1044
1045 fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
1046 let cap = self.moe_topk_capture.as_ref()?;
1047 let layers = cap.take_layers();
1048 if layers.is_empty() {
1049 None
1050 } else {
1051 Some(layers)
1052 }
1053 }
1054
1055 fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
1056 rlx_cpu::moe_residency::take_last_forward_stats()
1057 }
1058
1059 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1065 if dtype == DType::F64 {
1066 self.set_param_bytes(name, data, dtype);
1067 return;
1068 }
1069 if matches!(dtype, DType::U8 | DType::I8) {
1073 self.set_param_bytes(name, data, dtype);
1074 return;
1075 }
1076 if dtype == DType::F32 {
1077 let n = data.len() / 4;
1078 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
1079 self.set_param(name, s);
1080 } else {
1081 let f32_buf = super::widen_bytes_to_f32(data, dtype);
1082 self.set_param(name, &f32_buf);
1083 }
1084 }
1085
        /// Runs the graph with byte-typed inputs and returns each output's raw
        /// arena bytes together with its dtype.
        ///
        /// When every input is `F64`, the bytes are copied straight into the
        /// arena and the full schedule is executed directly. Otherwise
        /// `F64`/`I32`/`I64`/`U32` inputs are copied verbatim, the remaining
        /// inputs are widened to `f32`, and execution goes through `run`.
        fn run_typed(
            &mut self,
            inputs: &[(&str, &[u8], rlx_ir::DType)],
        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
            let all_f64 = !inputs.is_empty() && inputs.iter().all(|(_, _, dt)| *dt == DType::F64);

            if all_f64 {
                // NOTE(review): this path calls `execute_thunks` directly, so
                // bound handles and `active_extent` (both honoured by `run`)
                // are not applied here — confirm that is intended.
                for (name, data, _) in inputs {
                    if let Some(&id) = self.input_ids.get(*name) {
                        if !self.arena.has_buffer(id) {
                            continue;
                        }
                        let off = self.arena.byte_offset(id);
                        let buf = self.arena.raw_buf_mut();
                        let n = data.len();
                        // Checked only in debug builds, and only against the
                        // whole arena length.
                        debug_assert!(
                            off + n <= buf.len(),
                            "run_typed: input '{name}' overflows arena slot"
                        );
                        buf[off..off + n].copy_from_slice(data);
                    }
                }
                thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
            } else {
                // Inputs that must go through the f32 path of `run`.
                let mut f32_owned: Vec<(String, Vec<f32>)> = Vec::new();
                for (name, data, dt) in inputs {
                    // These dtypes are copied into the arena byte-for-byte.
                    let direct = matches!(*dt, DType::F64 | DType::I32 | DType::I64 | DType::U32,);
                    if direct {
                        if let Some(&id) = self.input_ids.get(*name) {
                            if !self.arena.has_buffer(id) {
                                continue;
                            }
                            let off = self.arena.byte_offset(id);
                            let buf = self.arena.raw_buf_mut();
                            // Slice indexing panics if the range leaves the arena.
                            buf[off..off + data.len()].copy_from_slice(data);
                        }
                    } else {
                        let v = super::widen_bytes_to_f32(data, *dt);
                        f32_owned.push((name.to_string(), v));
                    }
                }
                let refs: Vec<(&str, &[f32])> = f32_owned
                    .iter()
                    .map(|(n, d)| (n.as_str(), d.as_slice()))
                    .collect();
                // The f32 outputs are discarded; raw bytes are read below.
                let _ = self.run(&refs);
            }

            // Copy each output's bytes out of the arena at its native dtype width.
            self.graph
                .outputs
                .iter()
                .map(|&id| {
                    let dtype = self.graph.node(id).shape.dtype();
                    let n_elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
                    let n_bytes = n_elems * dtype.size_bytes();
                    let off = self.arena.byte_offset(id);
                    let bytes = self.arena.raw_buf()[off..off + n_bytes].to_vec();
                    (bytes, dtype)
                })
                .collect()
        }
1167 }
1168
1169 impl CpuExecutable {
1170 fn set_param_bytes(&mut self, name: &str, data: &[u8], _dtype: rlx_ir::DType) {
1176 if let Some(&id) = self.param_ids.get(name)
1177 && self.arena.has_buffer(id)
1178 {
1179 let off = self.arena.byte_offset(id);
1180 let buf = self.arena.raw_buf_mut();
1181 debug_assert!(
1182 off + data.len() <= buf.len(),
1183 "set_param_bytes: '{name}' would overflow arena slot"
1184 );
1185 buf[off..off + data.len()].copy_from_slice(data);
1186 }
1187 }
1188 }
1189}
1190
1191#[cfg(feature = "gpu")]
1196pub mod wgpu_backend {
1197 use super::*;
1198 use rlx_ir::OpKind;
1199 use rlx_wgpu::backend::WgpuExecutable;
1200
    /// Stateless marker type implementing [`Backend`] for wgpu.
    pub struct WgpuBackend;
1202
    /// Op kinds the wgpu backend declares as supported.
    ///
    /// Returned by `WgpuBackend::supported_ops` and passed to legalization,
    /// the graph compile stages and `prepare_fused_graph`.
    const WGPU_SUPPORTED_OPS: &[OpKind] = &[
        OpKind::Input,
        OpKind::Param,
        OpKind::Constant,
        OpKind::Activation,
        OpKind::Cast,
        OpKind::Binary,
        OpKind::Compare,
        OpKind::Where,
        OpKind::ElementwiseRegion,
        OpKind::MatMul,
        OpKind::DotGeneral,
        OpKind::LayerNorm,
        OpKind::RmsNorm,
        OpKind::Attention,
        OpKind::AttentionBackward,
        OpKind::RmsNormBackwardInput,
        OpKind::RmsNormBackwardGamma,
        OpKind::RmsNormBackwardBeta,
        OpKind::RopeBackward,
        OpKind::CumsumBackward,
        OpKind::GatherBackward,
        OpKind::Rope,
        OpKind::Reshape,
        OpKind::Transpose,
        OpKind::Narrow,
        OpKind::Concat,
        OpKind::Expand,
        OpKind::Gather,
        OpKind::Reduce,
        OpKind::Softmax,
        OpKind::Cumsum,
        OpKind::TopK,
        OpKind::Sample,
        OpKind::Conv,
        OpKind::Pool,
        OpKind::GroupedMatMul,
        OpKind::DequantGroupedMatMul,
        OpKind::DequantMoEWeights,
        OpKind::ScatterAdd,
        OpKind::SelectiveScan,
        OpKind::DequantMatMul,
        OpKind::FusedMatMulBiasAct,
        OpKind::FusedResidualLN,
        OpKind::FusedResidualRmsNorm,
        OpKind::FusedSwiGLU,
        OpKind::FusedAttentionBlock,
        OpKind::FusedTransformerLayer,
        OpKind::Fft,
        OpKind::GaussianSplatRender,
        OpKind::GaussianSplatRenderBackward,
        OpKind::GaussianSplatPrepare,
        OpKind::GaussianSplatRasterize,
        OpKind::Custom,
    ];
1270
1271 impl Backend for WgpuBackend {
1272 fn supported_ops(&self) -> &'static [OpKind] {
1273 WGPU_SUPPORTED_OPS
1274 }
1275
1276 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
1277 use rlx_opt::pass::Pass as _;
1278 let graph = rlx_opt::LowerControlFlow.run(graph);
1279 let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, WGPU_SUPPORTED_OPS)
1280 .unwrap_or_else(|errors| {
1281 panic!("{}", rlx_opt::format_legalize_error("wgpu", &errors));
1282 });
1283 let graph = if options.dce {
1285 rlx_opt::DeadCodeElimination.run(graph)
1286 } else {
1287 graph
1288 };
1289 let graph = if options.constant_folding {
1290 rlx_opt::ConstantFolding.run(graph)
1291 } else {
1292 graph
1293 };
1294 let compile_result = crate::stages::compile_graph_stages_for_backend(
1303 rlx_driver::Device::Gpu,
1304 graph,
1305 options,
1306 WGPU_SUPPORTED_OPS,
1307 );
1308 crate::stages::maybe_log_fusion(&compile_result.fusion);
1309 let graph = compile_result.lir.into_graph();
1310 let graph = match options.policy.clone() {
1311 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
1312 None => graph,
1313 };
1314 Box::new(WgpuExecutableWrapper {
1315 inner: WgpuExecutable::compile(graph),
1316 })
1317 }
1318
1319 fn compile_lir(
1320 &self,
1321 lir: LirModule,
1322 options: &CompileOptions,
1323 ) -> Box<dyn ExecutableGraph> {
1324 let graph = prepare_fused_graph(lir.into_graph(), options, WGPU_SUPPORTED_OPS, "wgpu");
1325 Box::new(WgpuExecutableWrapper {
1326 inner: WgpuExecutable::compile(graph),
1327 })
1328 }
1329 }
1330
    /// Adapter that exposes a `WgpuExecutable` through [`ExecutableGraph`].
    struct WgpuExecutableWrapper {
        /// The wrapped wgpu executable all calls are forwarded to.
        inner: WgpuExecutable,
    }

    // NOTE(review): `Send` is asserted manually, presumably because
    // `WgpuExecutable` is not automatically `Send` — confirm the invariant
    // that makes moving it across threads sound.
    unsafe impl Send for WgpuExecutableWrapper {}
1336
1337 impl ExecutableGraph for WgpuExecutableWrapper {
1338 fn set_param(&mut self, name: &str, data: &[f32]) {
1339 self.inner.set_param(name, data);
1340 }
1341 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
1342 self.inner.run(inputs)
1343 }
1344 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
1345 self.inner.set_active_extent(extent);
1346 }
1347
1348 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1351 match dtype {
1352 rlx_ir::DType::U8 | rlx_ir::DType::I8 => {
1353 self.inner.set_param_bytes(name, data);
1354 }
1355 rlx_ir::DType::F32 => {
1356 let n = data.len() / 4;
1357 let f32_slice =
1358 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
1359 self.inner.set_param(name, f32_slice);
1360 }
1361 rlx_ir::DType::F16 => {
1362 let n = data.len() / 2;
1363 let f16_slice =
1364 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n) };
1365 let f32: Vec<f32> = f16_slice.iter().map(|h| h.to_f32()).collect();
1366 self.inner.set_param(name, &f32);
1367 }
1368 rlx_ir::DType::BF16 => {
1369 let n = data.len() / 2;
1370 let bf16_slice = unsafe {
1371 std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n)
1372 };
1373 let f32: Vec<f32> = bf16_slice.iter().map(|h| h.to_f32()).collect();
1374 self.inner.set_param(name, &f32);
1375 }
1376 other => panic!(
1377 "rlx-wgpu set_param_typed: dtype {other:?} unsupported \
1378 (F32, F16, BF16 only — wgpu arena is f32-uniform)"
1379 ),
1380 }
1381 }
1382
1383 fn run_typed(
1386 &mut self,
1387 inputs: &[(&str, &[u8], rlx_ir::DType)],
1388 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1389 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
1390 for (name, data, dt) in inputs {
1391 let v: Vec<f32> = match *dt {
1392 rlx_ir::DType::F32 => {
1393 let n = data.len() / 4;
1394 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) }
1395 .to_vec()
1396 }
1397 rlx_ir::DType::F16 => {
1398 let n = data.len() / 2;
1399 let s = unsafe {
1400 std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n)
1401 };
1402 s.iter().map(|h| h.to_f32()).collect()
1403 }
1404 rlx_ir::DType::BF16 => {
1405 let n = data.len() / 2;
1406 let s = unsafe {
1407 std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n)
1408 };
1409 s.iter().map(|h| h.to_f32()).collect()
1410 }
1411 other => {
1412 panic!("rlx-wgpu run_typed: input '{name}' dtype {other:?} unsupported")
1413 }
1414 };
1415 owned.push((name.to_string(), v));
1416 }
1417 let refs: Vec<(&str, &[f32])> = owned
1418 .iter()
1419 .map(|(n, d)| (n.as_str(), d.as_slice()))
1420 .collect();
1421 let dtypes = self.inner.output_dtypes();
1422 let outs = self.inner.run(&refs);
1423 outs.into_iter()
1424 .zip(
1425 dtypes
1426 .into_iter()
1427 .chain(std::iter::repeat(rlx_ir::DType::F32)),
1428 )
1429 .map(|(v, dt)| (narrow_to_dtype(&v, dt), dt))
1430 .collect()
1431 }
1432 }
1433
1434 fn narrow_to_dtype(v: &[f32], dt: rlx_ir::DType) -> Vec<u8> {
1440 use rlx_ir::DType;
1441 match dt {
1442 DType::F32 => {
1443 let mut bytes = Vec::with_capacity(v.len() * 4);
1444 for &x in v {
1445 bytes.extend_from_slice(&x.to_le_bytes());
1446 }
1447 bytes
1448 }
1449 DType::F16 => {
1450 let mut bytes = Vec::with_capacity(v.len() * 2);
1451 for &x in v {
1452 bytes.extend_from_slice(&half::f16::from_f32(x).to_le_bytes());
1453 }
1454 bytes
1455 }
1456 DType::BF16 => {
1457 let mut bytes = Vec::with_capacity(v.len() * 2);
1458 for &x in v {
1459 bytes.extend_from_slice(&half::bf16::from_f32(x).to_le_bytes());
1460 }
1461 bytes
1462 }
1463 DType::F64 => {
1464 let mut bytes = Vec::with_capacity(v.len() * 8);
1465 for &x in v {
1466 bytes.extend_from_slice(&(x as f64).to_le_bytes());
1467 }
1468 bytes
1469 }
1470 DType::I8 => v.iter().map(|&x| x as i8 as u8).collect(),
1471 DType::U8 => v.iter().map(|&x| x as u8).collect(),
1472 DType::I16 => {
1473 let mut bytes = Vec::with_capacity(v.len() * 2);
1474 for &x in v {
1475 bytes.extend_from_slice(&(x as i16).to_le_bytes());
1476 }
1477 bytes
1478 }
1479 DType::I32 => {
1480 let mut bytes = Vec::with_capacity(v.len() * 4);
1481 for &x in v {
1482 bytes.extend_from_slice(&(x as i32).to_le_bytes());
1483 }
1484 bytes
1485 }
1486 DType::U32 => {
1487 let mut bytes = Vec::with_capacity(v.len() * 4);
1488 for &x in v {
1489 bytes.extend_from_slice(&(x as u32).to_le_bytes());
1490 }
1491 bytes
1492 }
1493 DType::I64 => {
1494 let mut bytes = Vec::with_capacity(v.len() * 8);
1495 for &x in v {
1496 bytes.extend_from_slice(&(x as i64).to_le_bytes());
1497 }
1498 bytes
1499 }
1500 DType::Bool => v
1501 .iter()
1502 .map(|&x| if x != 0.0 { 1u8 } else { 0u8 })
1503 .collect(),
1504 DType::C64 => {
1511 let mut bytes = Vec::with_capacity(v.len() * 4);
1512 for &x in v {
1513 bytes.extend_from_slice(&x.to_le_bytes());
1514 }
1515 bytes
1516 }
1517 }
1518 }
1519}
1520
1521#[cfg(all(feature = "mlx", target_os = "macos"))]
1524pub mod mlx_backend {
1525 use super::*;
1526 use rlx_mlx::MlxExecutable;
1527
    /// Backend handle for the MLX device; stateless unit type.
    pub struct MlxBackend;

    /// Op kinds the MLX backend reports from `supported_ops()` and passes to the
    /// shared compile helpers as its supported set.
    const MLX_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        // Glob-import the variants so the table can list bare names.
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            DenseSolve,
            BatchedDenseSolve,
            LayerNorm,
            LayerNorm2d,
            RmsNorm,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            LoraMatMul,
            DequantMatMul,
            SelectiveScan,
            GatedDeltaNet,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            FusedAttentionBlock,
            FusedTransformerLayer,
            If,
            While,
            Scan,
            ScanBackward,
            ScanBackwardXs,
            ReluBackward,
            ActivationBackward,
            SoftmaxCrossEntropyWithLogits,
            SoftmaxCrossEntropyBackward,
            AttentionBackward,
            LayerNormBackwardInput,
            LayerNormBackwardGamma,
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            MaxPool2dBackward,
            FakeQuantize,
            FakeQuantizeBackward,
            Custom,
            Fft,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
        ]
    };
1630
1631 impl Backend for MlxBackend {
1632 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
1633 MLX_SUPPORTED_OPS
1634 }
1635
1636 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
1637 let compile_result = crate::stages::compile_graph_stages_for_backend(
1638 rlx_driver::Device::Mlx,
1639 graph,
1640 options,
1641 MLX_SUPPORTED_OPS,
1642 );
1643 crate::stages::maybe_log_fusion(&compile_result.fusion);
1644 self.compile_lir(compile_result.lir, options)
1645 }
1646
1647 fn compile_lir(
1648 &self,
1649 lir: LirModule,
1650 options: &CompileOptions,
1651 ) -> Box<dyn ExecutableGraph> {
1652 use rlx_opt::pass::Pass as _;
1653 let mut graph = lir.into_graph();
1654 graph = rlx_opt::LowerControlFlow.run(graph);
1655 let graph = prepare_fused_graph(graph, options, MLX_SUPPORTED_OPS, "mlx");
1656 Box::new(build_mlx_executable(graph))
1657 }
1658 }
1659
1660 fn build_mlx_executable(graph: Graph) -> MlxExecutableWrapper {
1661 let mode = mlx_mode_from_env();
1662 let mut exe = MlxExecutable::compile_from_fused(graph, mode);
1663 if mode == rlx_mlx::lower::MlxMode::Compiled {
1664 if let Err(e) = exe.warm_compile() {
1665 eprintln!(
1666 "[rlx-runtime] MLX warm_compile failed ({e}); first run will pay the trace cost"
1667 );
1668 }
1669 }
1670 MlxExecutableWrapper { inner: exe }
1671 }
1672
1673 fn mlx_mode_from_env() -> rlx_mlx::lower::MlxMode {
1674 match rlx_ir::env::var("RLX_MLX_MODE").as_deref() {
1675 Some(s) if s.eq_ignore_ascii_case("eager") => rlx_mlx::lower::MlxMode::Eager,
1676 Some(s) if s.eq_ignore_ascii_case("lazy") => rlx_mlx::lower::MlxMode::Lazy,
1677 Some(s) if s.eq_ignore_ascii_case("compiled") => rlx_mlx::lower::MlxMode::Compiled,
1678 _ => rlx_mlx::lower::MlxMode::Compiled,
1679 }
1680 }
1681
    /// Adapter that exposes an `MlxExecutable` through the `ExecutableGraph` trait.
    struct MlxExecutableWrapper {
        /// The wrapped MLX executable every trait method forwards to.
        inner: MlxExecutable,
    }

    // SAFETY(review): this asserts the wrapper may be moved to another thread.
    // Nothing visible in this module shows why that holds for `MlxExecutable`;
    // confirm against its internals before relying on it.
    unsafe impl Send for MlxExecutableWrapper {}

    /// Forwarding impl: each method calls the corresponding method on `inner`.
    /// The only adaptations are the `Result` -> `bool` / `Option` conversions
    /// noted on the individual methods.
    impl ExecutableGraph for MlxExecutableWrapper {
        fn set_param(&mut self, name: &str, data: &[f32]) {
            self.inner.set_param(name, data);
        }
        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
            self.inner.run(inputs)
        }
        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
            self.inner.run_slots(inputs)
        }
        fn arena_ptr(&self) -> *const u8 {
            self.inner.arena_ptr()
        }
        fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
            self.inner.commit_no_wait(inputs);
        }
        fn sync_pending(&mut self) {
            self.inner.sync_pending();
        }
        fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
            self.inner.run_pipelined(input_sets)
        }
        fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
            self.inner.bind_handle(name, data)
        }
        fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
            self.inner.read_handle(name)
        }
        /// Returns `true` when the inner bind succeeds; the error value is discarded.
        fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
            self.inner.bind_gpu_handle(name, data).is_ok()
        }
        fn has_gpu_handle(&self, name: &str) -> bool {
            self.inner.has_gpu_handle(name)
        }
        /// Forwards to the inner executable and always reports `true`.
        fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
            self.inner.set_gpu_handle_feed(handle_name, output_index);
            true
        }
        /// Returns `None` when the inner read fails; the error value is discarded.
        fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
            self.inner.read_gpu_handle(name).ok()
        }
        /// Calls the inner `run_feed_gpu`; returns `None` on failure, discarding
        /// the error value.
        fn run_feed_gpu_handle(
            &mut self,
            inputs: &[(&str, &[f32])],
            handle_name: &str,
            output_index: usize,
        ) -> Option<Vec<f32>> {
            self.inner
                .run_feed_gpu(inputs, handle_name, output_index)
                .ok()
        }
        /// Typed parameters are handled by the inner executable directly, with
        /// no widening in this wrapper.
        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
            self.inner.set_param_typed(name, data, dtype);
        }
        /// Typed runs are handled by the inner executable directly, with no
        /// widening or narrowing in this wrapper.
        fn run_typed(
            &mut self,
            inputs: &[(&str, &[u8], rlx_ir::DType)],
        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
            self.inner.run_typed(inputs)
        }
        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
            self.inner.set_active_extent(extent);
        }
    }
1752}
1753
1754#[cfg(all(feature = "metal", target_os = "macos"))]
1755pub mod metal_backend {
1756 use super::*;
1757 use rlx_metal::backend::MetalExecutable;
1758
    /// Backend handle for the Metal device; stateless unit type.
    pub struct MetalBackend;

    /// Op kinds the Metal backend reports from `supported_ops()`, legalizes
    /// against, and passes to `MetalExecutable` at compile time.
    const METAL_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        // Glob-import the variants so the table can list bare names.
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            RmsNorm,
            ResizeNearest2x,
            AxialRope2d,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            TopK,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            GatedDeltaNet,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            Custom,
            Fft,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
        ]
    };
1839
1840 impl Backend for MetalBackend {
1841 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
1842 METAL_SUPPORTED_OPS
1843 }
1844
1845 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
1846 use rlx_opt::pass::Pass as _;
1847 let graph = rlx_opt::LowerControlFlow.run(graph);
1851 let mut dispatch = options.kernel_dispatch;
1852 let graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
1853 graph,
1854 METAL_SUPPORTED_OPS,
1855 dispatch,
1856 )
1857 .unwrap_or_else(|errors| {
1858 panic!("{}", rlx_opt::format_legalize_error("metal", &errors));
1859 });
1860 let graph = if options.dce {
1862 rlx_opt::DeadCodeElimination.run(graph)
1863 } else {
1864 graph
1865 };
1866 let graph = if options.constant_folding {
1867 rlx_opt::ConstantFolding.run(graph)
1868 } else {
1869 graph
1870 };
1871
1872 Box::new(MetalExecutableWrapper {
1875 inner: MetalExecutable::compile_with_policy(
1876 graph,
1877 options.policy.clone(),
1878 Some(METAL_SUPPORTED_OPS),
1879 ),
1880 })
1881 }
1882
1883 fn compile_lir(
1884 &self,
1885 lir: LirModule,
1886 options: &CompileOptions,
1887 ) -> Box<dyn ExecutableGraph> {
1888 use rlx_opt::pass::Pass as _;
1889 let mut graph = lir.into_graph();
1890 graph = rlx_opt::LowerControlFlow.run(graph);
1891 let mut dispatch = options.kernel_dispatch;
1892 let mut graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
1893 graph,
1894 METAL_SUPPORTED_OPS,
1895 dispatch,
1896 )
1897 .unwrap_or_else(|errors| {
1898 panic!("{}", rlx_opt::format_legalize_error("metal", &errors));
1899 });
1900 if options.dce {
1901 graph = rlx_opt::DeadCodeElimination.run(graph);
1902 }
1903 if options.constant_folding {
1904 graph = rlx_opt::ConstantFolding.run(graph);
1905 }
1906 Box::new(MetalExecutableWrapper {
1907 inner: MetalExecutable::compile_from_fused(
1908 graph,
1909 options.policy.clone(),
1910 Some(METAL_SUPPORTED_OPS),
1911 ),
1912 })
1913 }
1914 }
1915
1916 struct MetalExecutableWrapper {
1917 inner: MetalExecutable,
1918 }
1919
1920 unsafe impl Send for MetalExecutableWrapper {}
1921
1922 impl ExecutableGraph for MetalExecutableWrapper {
1923 fn set_param(&mut self, name: &str, data: &[f32]) {
1924 self.inner.set_param(name, data);
1925 }
1926 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
1927 self.inner.run(inputs)
1928 }
1929 fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
1930 self.inner.run_slots(inputs)
1931 }
1932 fn arena_ptr(&self) -> *const u8 {
1933 self.inner.arena_ptr()
1934 }
1935 fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
1936 self.inner.commit_no_wait(inputs);
1937 }
1938 fn sync_pending(&mut self) {
1939 self.inner.sync_pending();
1940 }
1941 fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
1942 self.inner.run_pipelined(input_sets)
1943 }
1944 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
1945 self.inner.set_active_extent(extent);
1946 }
1947
1948 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1954 if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
1955 self.inner.set_param_bytes(name, data);
1956 return;
1957 }
1958 if dtype == rlx_ir::DType::F32 {
1959 let n = data.len() / 4;
1960 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
1961 self.inner.set_param(name, s);
1962 } else {
1963 let f32_buf = super::widen_bytes_to_f32(data, dtype);
1964 self.inner.set_param(name, &f32_buf);
1965 }
1966 }
1967
1968 fn run_typed(
1976 &mut self,
1977 inputs: &[(&str, &[u8], rlx_ir::DType)],
1978 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1979 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
1980 for (name, data, dt) in inputs {
1981 let v = super::widen_bytes_to_f32(data, *dt);
1982 owned.push((name.to_string(), v));
1983 }
1984 let refs: Vec<(&str, &[f32])> = owned
1985 .iter()
1986 .map(|(n, d)| (n.as_str(), d.as_slice()))
1987 .collect();
1988 let dtypes = self.inner.output_dtypes();
1989 let f32_outs = self.inner.run(&refs);
1990 let byte_outs = self.inner.output_bytes_per_node();
1991 f32_outs
1992 .into_iter()
1993 .zip(byte_outs.into_iter())
1994 .zip(
1995 dtypes
1996 .into_iter()
1997 .chain(std::iter::repeat(rlx_ir::DType::F32)),
1998 )
1999 .map(|((f32_v, byte_v), dt)| match dt {
2000 rlx_ir::DType::F64 => (byte_v, dt),
2001 _ => (super::narrow_f32_to_bytes(&f32_v, dt), dt),
2002 })
2003 .collect()
2004 }
2005 }
2006}
2007
2008#[cfg(feature = "cuda")]
2011pub mod cuda_backend {
2012 use super::*;
2013 use rlx_cuda::backend::CudaExecutable;
2014
    /// Backend handle for the CUDA device; stateless unit type.
    pub struct CudaBackend;

    /// Op kinds the CUDA backend reports from `supported_ops()` and passes to
    /// legalization and the shared compile helpers as its supported set.
    const CUDA_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        // Glob-import the variants so the table can list bare names.
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            RmsNorm,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            Custom,
            Fft,
        ]
    };
2079
2080 impl Backend for CudaBackend {
2081 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2082 CUDA_SUPPORTED_OPS
2083 }
2084
2085 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2086 let graph = rlx_cuda::unfuse::unfuse(graph);
2089 let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, CUDA_SUPPORTED_OPS)
2090 .unwrap_or_else(|errors| {
2091 panic!("{}", rlx_opt::format_legalize_error("cuda", &errors));
2092 });
2093 use rlx_opt::pass::Pass as _;
2094 let graph = if options.dce {
2095 rlx_opt::DeadCodeElimination.run(graph)
2096 } else {
2097 graph
2098 };
2099 let graph = if options.constant_folding {
2100 rlx_opt::ConstantFolding.run(graph)
2101 } else {
2102 graph
2103 };
2104 let compile_result = crate::stages::compile_graph_stages_for_backend(
2106 rlx_driver::Device::Cuda,
2107 graph,
2108 options,
2109 CUDA_SUPPORTED_OPS,
2110 );
2111 crate::stages::maybe_log_fusion(&compile_result.fusion);
2112 let graph = compile_result.lir.into_graph();
2113 let graph = match options.policy.clone() {
2114 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
2115 None => graph,
2116 };
2117 Box::new(CudaExecutableWrapper {
2118 inner: CudaExecutable::compile(graph),
2119 })
2120 }
2121
2122 fn compile_lir(
2123 &self,
2124 lir: LirModule,
2125 options: &CompileOptions,
2126 ) -> Box<dyn ExecutableGraph> {
2127 let graph = prepare_fused_graph(
2128 rlx_cuda::unfuse::unfuse(lir.into_graph()),
2129 options,
2130 CUDA_SUPPORTED_OPS,
2131 "cuda",
2132 );
2133 Box::new(CudaExecutableWrapper {
2134 inner: CudaExecutable::compile(graph),
2135 })
2136 }
2137 }
2138
2139 struct CudaExecutableWrapper {
2140 inner: CudaExecutable,
2141 }
2142
2143 unsafe impl Send for CudaExecutableWrapper {}
2148
2149 impl ExecutableGraph for CudaExecutableWrapper {
2150 fn set_param(&mut self, name: &str, data: &[f32]) {
2151 self.inner.set_param(name, data);
2152 }
2153 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2154 self.inner.run(inputs)
2155 }
2156 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2157 self.inner.set_active_extent(extent);
2158 }
2159
2160 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2165 if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
2166 self.inner.set_param_bytes(name, data);
2167 return;
2168 }
2169 if dtype == rlx_ir::DType::F32 {
2170 let n = data.len() / 4;
2171 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2172 self.inner.set_param(name, s);
2173 } else {
2174 let f32_buf = super::widen_bytes_to_f32(data, dtype);
2175 self.inner.set_param(name, &f32_buf);
2176 }
2177 }
2178
2179 fn run_typed(
2182 &mut self,
2183 inputs: &[(&str, &[u8], rlx_ir::DType)],
2184 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2185 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2186 for (name, data, dt) in inputs {
2187 let v = super::widen_bytes_to_f32(data, *dt);
2188 owned.push((name.to_string(), v));
2189 }
2190 let refs: Vec<(&str, &[f32])> = owned
2191 .iter()
2192 .map(|(n, d)| (n.as_str(), d.as_slice()))
2193 .collect();
2194 let dtypes = self.inner.output_dtypes();
2195 let outs = self.inner.run(&refs);
2196 outs.into_iter()
2197 .zip(
2198 dtypes
2199 .into_iter()
2200 .chain(std::iter::repeat(rlx_ir::DType::F32)),
2201 )
2202 .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2203 .collect()
2204 }
2205 }
2206}
2207
2208#[cfg(feature = "rocm")]
2211pub mod rocm_backend {
2212 use super::*;
2213 use rlx_rocm::backend::RocmExecutable;
2214
    /// Backend handle for the ROCm device; stateless unit type.
    pub struct RocmBackend;

    /// Op kinds the ROCm backend reports from `supported_ops()` and passes to
    /// rewriting, legalization, and the shared compile helpers as its
    /// supported set.
    const ROCM_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        // Glob-import the variants so the table can list bare names.
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            RmsNorm,
            Attention,
            AttentionBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            Custom,
            Fft,
        ]
    };
2268
2269 impl Backend for RocmBackend {
2270 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2271 ROCM_SUPPORTED_OPS
2272 }
2273
2274 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2275 let graph = rlx_opt::rewrite_for_backend(graph, ROCM_SUPPORTED_OPS);
2276 if let Err(errors) = rlx_opt::legalize_for_backend(&graph, ROCM_SUPPORTED_OPS) {
2277 panic!("{}", rlx_opt::format_legalize_error("rocm", &errors));
2278 }
2279 use rlx_opt::pass::Pass as _;
2280 let graph = if options.dce {
2281 rlx_opt::DeadCodeElimination.run(graph)
2282 } else {
2283 graph
2284 };
2285 let graph = if options.constant_folding {
2286 rlx_opt::ConstantFolding.run(graph)
2287 } else {
2288 graph
2289 };
2290 let compile_result = crate::stages::compile_graph_stages_for_backend(
2291 rlx_driver::Device::Rocm,
2292 graph,
2293 options,
2294 ROCM_SUPPORTED_OPS,
2295 );
2296 crate::stages::maybe_log_fusion(&compile_result.fusion);
2297 let graph = compile_result.lir.into_graph();
2298 let graph = match options.policy.clone() {
2299 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
2300 None => graph,
2301 };
2302 Box::new(RocmExecutableWrapper {
2303 inner: RocmExecutable::compile(graph),
2304 })
2305 }
2306
2307 fn compile_lir(
2308 &self,
2309 lir: LirModule,
2310 options: &CompileOptions,
2311 ) -> Box<dyn ExecutableGraph> {
2312 let graph = prepare_fused_graph(lir.into_graph(), options, ROCM_SUPPORTED_OPS, "rocm");
2313 Box::new(RocmExecutableWrapper {
2314 inner: RocmExecutable::compile(graph),
2315 })
2316 }
2317 }
2318
2319 struct RocmExecutableWrapper {
2320 inner: RocmExecutable,
2321 }
2322
2323 unsafe impl Send for RocmExecutableWrapper {}
2327
2328 impl ExecutableGraph for RocmExecutableWrapper {
2329 fn set_param(&mut self, name: &str, data: &[f32]) {
2330 self.inner.set_param(name, data);
2331 }
2332 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2333 self.inner.run(inputs)
2334 }
2335 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2336 self.inner.set_active_extent(extent);
2337 }
2338
2339 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2344 if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
2345 self.inner.set_param_bytes(name, data);
2346 return;
2347 }
2348 if dtype == rlx_ir::DType::F32 {
2349 let n = data.len() / 4;
2350 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2351 self.inner.set_param(name, s);
2352 } else {
2353 let f32_buf = super::widen_bytes_to_f32(data, dtype);
2354 self.inner.set_param(name, &f32_buf);
2355 }
2356 }
2357
2358 fn run_typed(
2361 &mut self,
2362 inputs: &[(&str, &[u8], rlx_ir::DType)],
2363 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2364 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2365 for (name, data, dt) in inputs {
2366 let v = super::widen_bytes_to_f32(data, *dt);
2367 owned.push((name.to_string(), v));
2368 }
2369 let refs: Vec<(&str, &[f32])> = owned
2370 .iter()
2371 .map(|(n, d)| (n.as_str(), d.as_slice()))
2372 .collect();
2373 let dtypes = self.inner.output_dtypes();
2374 let outs = self.inner.run(&refs);
2375 outs.into_iter()
2376 .zip(
2377 dtypes
2378 .into_iter()
2379 .chain(std::iter::repeat(rlx_ir::DType::F32)),
2380 )
2381 .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2382 .collect()
2383 }
2384 }
2385}
2386
2387#[cfg(feature = "tpu")]
2390pub mod tpu_backend {
2391 use super::*;
2392 use rlx_tpu::TpuExecutable;
2393
    /// Backend handle for the TPU device; stateless unit type.
    pub struct TpuBackend;

    /// Op kinds the TPU backend reports from `supported_ops()` and legalizes
    /// against in `compile`.
    const TPU_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        // Glob-import the variants so the table can list bare names.
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            RmsNorm,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            QMatMul,
            QConv2d,
            Quantize,
            Dequantize,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            Fft,
        ]
    };
2450
2451 impl Backend for TpuBackend {
2452 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2453 TPU_SUPPORTED_OPS
2454 }
2455
2456 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2457 let graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
2458 graph,
2459 TPU_SUPPORTED_OPS,
2460 options.kernel_dispatch,
2461 )
2462 .unwrap_or_else(|errors| {
2463 panic!("{}", rlx_opt::format_legalize_error("tpu", &errors));
2464 });
2465 use rlx_opt::pass::Pass as _;
2481 let policy = options
2482 .policy
2483 .clone()
2484 .unwrap_or(rlx_opt::PrecisionPolicy::AutoMixedBf16);
2485 let graph = rlx_opt::AutoMixedPrecision::new(policy).run(graph);
2486 let _ = options.dce;
2487 let _ = options.constant_folding;
2488 Box::new(TpuExecutableWrapper {
2489 inner: TpuExecutable::compile(graph),
2490 })
2491 }
2492 }
2493
2494 struct TpuExecutableWrapper {
2495 inner: TpuExecutable,
2496 }
2497
2498 unsafe impl Send for TpuExecutableWrapper {}
2502
2503 impl ExecutableGraph for TpuExecutableWrapper {
2504 fn set_param(&mut self, name: &str, data: &[f32]) {
2505 self.inner.set_param(name, data);
2506 }
2507 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2508 self.inner.run(inputs)
2509 }
2510
2511 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2516 if dtype == rlx_ir::DType::F32 {
2517 let n = data.len() / 4;
2518 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2519 self.inner.set_param(name, s);
2520 } else {
2521 let f32_buf = super::widen_bytes_to_f32(data, dtype);
2522 self.inner.set_param(name, &f32_buf);
2523 }
2524 }
2525
2526 fn run_typed(
2527 &mut self,
2528 inputs: &[(&str, &[u8], rlx_ir::DType)],
2529 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2530 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2531 for (name, data, dt) in inputs {
2532 let v = super::widen_bytes_to_f32(data, *dt);
2533 owned.push((name.to_string(), v));
2534 }
2535 let refs: Vec<(&str, &[f32])> = owned
2536 .iter()
2537 .map(|(n, d)| (n.as_str(), d.as_slice()))
2538 .collect();
2539 let dtypes = self.inner.output_dtypes();
2540 let outs = self.inner.run(&refs);
2541 outs.into_iter()
2542 .zip(
2543 dtypes
2544 .into_iter()
2545 .chain(std::iter::repeat(rlx_ir::DType::F32)),
2546 )
2547 .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2548 .collect()
2549 }
2550 }
2551}