1use crate::CompileOptions;
23use rlx_ir::Graph;
24use rlx_ir::hir::HirModule;
25use rlx_ir::lir::LirModule;
26use std::collections::HashMap;
27use std::sync::Arc;
28
29#[allow(dead_code)]
36pub(crate) fn widen_bytes_to_f32(data: &[u8], dtype: rlx_ir::DType) -> Vec<f32> {
37 use rlx_ir::DType;
38 match dtype {
39 DType::F32 => {
40 let n = data.len() / 4;
41 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
42 s.to_vec()
43 }
44 DType::F16 => {
45 let n = data.len() / 2;
46 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n) };
47 s.iter().map(|h| h.to_f32()).collect()
48 }
49 DType::BF16 => {
50 let n = data.len() / 2;
51 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n) };
52 s.iter().map(|h| h.to_f32()).collect()
53 }
54 other => panic!(
55 "widen_bytes_to_f32: dtype {other:?} unsupported on f32-arena backends \
56 (only F32/F16/BF16 are accepted on the host I/O surface)"
57 ),
58 }
59}
60
61#[allow(dead_code)]
66pub(crate) fn narrow_f32_to_bytes(v: &[f32], dt: rlx_ir::DType) -> Vec<u8> {
67 use rlx_ir::DType;
68 match dt {
69 DType::F32 => {
70 let mut bytes = Vec::with_capacity(v.len() * 4);
71 for &x in v {
72 bytes.extend_from_slice(&x.to_le_bytes());
73 }
74 bytes
75 }
76 DType::F16 => {
77 let mut bytes = Vec::with_capacity(v.len() * 2);
78 for &x in v {
79 bytes.extend_from_slice(&half::f16::from_f32(x).to_le_bytes());
80 }
81 bytes
82 }
83 DType::BF16 => {
84 let mut bytes = Vec::with_capacity(v.len() * 2);
85 for &x in v {
86 bytes.extend_from_slice(&half::bf16::from_f32(x).to_le_bytes());
87 }
88 bytes
89 }
90 DType::F64 => {
91 let mut bytes = Vec::with_capacity(v.len() * 8);
92 for &x in v {
93 bytes.extend_from_slice(&(x as f64).to_le_bytes());
94 }
95 bytes
96 }
97 DType::I8 => v.iter().map(|&x| x as i8 as u8).collect(),
98 DType::U8 => v.iter().map(|&x| x as u8).collect(),
99 DType::I16 => {
100 let mut bytes = Vec::with_capacity(v.len() * 2);
101 for &x in v {
102 bytes.extend_from_slice(&(x as i16).to_le_bytes());
103 }
104 bytes
105 }
106 DType::I32 => {
107 let mut bytes = Vec::with_capacity(v.len() * 4);
108 for &x in v {
109 bytes.extend_from_slice(&(x as i32).to_le_bytes());
110 }
111 bytes
112 }
113 DType::U32 => {
114 let mut bytes = Vec::with_capacity(v.len() * 4);
115 for &x in v {
116 bytes.extend_from_slice(&(x as u32).to_le_bytes());
117 }
118 bytes
119 }
120 DType::I64 => {
121 let mut bytes = Vec::with_capacity(v.len() * 8);
122 for &x in v {
123 bytes.extend_from_slice(&(x as i64).to_le_bytes());
124 }
125 bytes
126 }
127 DType::Bool => v
128 .iter()
129 .map(|&x| if x != 0.0 { 1u8 } else { 0u8 })
130 .collect(),
131 DType::C64 => {
132 let mut bytes = Vec::with_capacity(v.len() * 8);
136 for &x in v {
137 bytes.extend_from_slice(&x.to_le_bytes());
138 bytes.extend_from_slice(&0.0_f32.to_le_bytes());
139 }
140 bytes
141 }
142 }
143}
144
145pub trait ExecutableGraph: Send {
147 fn set_param(&mut self, name: &str, data: &[f32]);
149
150 fn clone_box(&self) -> Box<dyn ExecutableGraph> {
157 panic!("clone_box not implemented for this backend");
158 }
159
160 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>>;
162
163 fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
165 let vecs = self.run(inputs);
166 vecs.iter().map(|v| (v.as_ptr(), v.len())).collect()
167 }
168
169 fn run_slots(&mut self, _inputs: &[&[f32]]) -> &[(usize, usize)] {
172 &[] }
174
175 fn arena_ptr(&self) -> *const u8 {
177 std::ptr::null()
178 }
179
180 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
197 let _ = extent;
198 }
199
200 fn set_moe_resident_experts(&mut self, _mask: &[bool]) {}
202
203 fn set_moe_resident_experts_per_layer(&mut self, _masks: &[&[bool]]) {}
205
206 fn enable_moe_topk_capture(&mut self, _num_experts: usize) -> bool {
208 false
209 }
210
211 fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
213 None
214 }
215
216 fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
218 None
219 }
220
221 fn bind_handle(&mut self, _name: &str, _data: &[f32]) -> bool {
225 false
226 }
227
228 fn read_handle(&self, _name: &str) -> Option<Vec<f32>> {
230 None
231 }
232
233 fn bind_gpu_handle(&mut self, _name: &str, _data: &[f32]) -> bool {
235 false
236 }
237
238 fn has_gpu_handle(&self, _name: &str) -> bool {
239 false
240 }
241
242 fn set_gpu_handle_feed(&mut self, _handle_name: &str, _output_index: usize) -> bool {
243 false
244 }
245
246 fn read_gpu_handle(&self, _name: &str) -> Option<Vec<f32>> {
247 None
248 }
249
250 fn run_feed_gpu_handle(
252 &mut self,
253 inputs: &[(&str, &[f32])],
254 _handle_name: &str,
255 _output_index: usize,
256 ) -> Option<Vec<f32>> {
257 let _ = inputs;
258 None
259 }
260
261 fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
276 let _ = self.run(inputs);
277 }
278
279 fn sync_pending(&mut self) {}
282
283 fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
292 input_sets.iter().map(|inputs| self.run(inputs)).collect()
293 }
294
295 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
308 if dtype != rlx_ir::DType::F32 {
309 panic!(
310 "backend's default set_param_typed only handles F32; \
311 got {dtype:?}. Override on the backend for typed support."
312 );
313 }
314 if !data.len().is_multiple_of(4) {
315 panic!(
316 "set_param_typed F32: data length {} not a multiple of 4",
317 data.len()
318 );
319 }
320 let n = data.len() / 4;
325 let f32_slice = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
326 self.set_param(name, f32_slice);
327 }
328
329 fn run_typed(
333 &mut self,
334 inputs: &[(&str, &[u8], rlx_ir::DType)],
335 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
336 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
339 for (name, data, dt) in inputs {
340 if *dt != rlx_ir::DType::F32 {
341 panic!(
342 "backend's default run_typed only handles F32 inputs; \
343 got {dt:?} for input '{name}'"
344 );
345 }
346 if data.len() % 4 != 0 {
347 panic!(
348 "run_typed F32 input '{name}': len {} not multiple of 4",
349 data.len()
350 );
351 }
352 let n = data.len() / 4;
353 let v: Vec<f32> =
354 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) }.to_vec();
355 owned.push((name.to_string(), v));
356 }
357 let refs: Vec<(&str, &[f32])> = owned
358 .iter()
359 .map(|(n, d)| (n.as_str(), d.as_slice()))
360 .collect();
361 let outs = self.run(&refs);
362 outs.into_iter()
363 .map(|v| {
364 let bytes =
365 unsafe { std::slice::from_raw_parts(v.as_ptr() as *const u8, v.len() * 4) }
366 .to_vec();
367 (bytes, rlx_ir::DType::F32)
368 })
369 .collect()
370 }
371}
372
373pub trait Backend: Send + Sync {
383 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph>;
385
386 fn compile_lir(&self, lir: LirModule, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
390 self.compile(lir.into_graph(), options)
391 }
392
393 fn compile_hir(
395 &self,
396 hir: HirModule,
397 device: rlx_driver::Device,
398 options: &CompileOptions,
399 ) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
400 let result = crate::stages::compile_hir_stages(device, hir, options)?;
401 crate::stages::maybe_log_fusion(&result.fusion);
402 Ok(self.compile_lir(result.lir, options))
403 }
404
405 fn compile_module(
407 &self,
408 module: rlx_ir::GraphModule,
409 device: rlx_driver::Device,
410 options: &CompileOptions,
411 ) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
412 let result = crate::stages::compile_module_stages(device, module, options)?;
413 crate::stages::maybe_log_fusion(&result.fusion);
414 Ok(self.compile_lir(result.lir, options))
415 }
416
417 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
424 &[]
425 }
426}
427
428#[allow(dead_code)]
431fn prepare_fused_graph(
432 graph: Graph,
433 options: &CompileOptions,
434 supported_ops: &[rlx_ir::OpKind],
435 backend_name: &str,
436) -> Graph {
437 let (mut graph, report) = rlx_opt::prepare_graph_for_backend_with_report(
438 graph,
439 backend_name,
440 supported_ops,
441 options.kernel_dispatch,
442 );
443 rlx_opt::maybe_log_dispatch_report(&report);
444 if !report.compile_ready {
445 panic!(
446 "{}\n{}",
447 rlx_opt::format_legalize_error(backend_name, &report.still_unsupported),
448 rlx_opt::format_dispatch_report(&report)
449 );
450 }
451 use rlx_opt::pass::Pass as _;
452 if options.dce {
453 graph = rlx_opt::DeadCodeElimination.run(graph);
454 }
455 if options.constant_folding {
456 graph = rlx_opt::ConstantFolding.run(graph);
457 }
458 if let Some(p) = options.policy.clone() {
459 graph = rlx_opt::AutoMixedPrecision::new(p).run(graph);
460 }
461 graph
462}
463
464pub fn compile(backend: &dyn Backend, graph: Graph) -> Box<dyn ExecutableGraph> {
472 backend.compile(graph, &CompileOptions::default())
473}
474
475pub fn compile_hir(
477 backend: &dyn Backend,
478 hir: HirModule,
479 device: rlx_driver::Device,
480 options: &CompileOptions,
481) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
482 backend.compile_hir(hir, device, options)
483}
484
485pub fn compile_module(
487 backend: &dyn Backend,
488 module: rlx_ir::GraphModule,
489 device: rlx_driver::Device,
490 options: &CompileOptions,
491) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
492 backend.compile_module(module, device, options)
493}
494
495pub fn compile_with_precision(
497 backend: &dyn Backend,
498 graph: Graph,
499 precision: crate::Precision,
500) -> Box<dyn ExecutableGraph> {
501 backend.compile(graph, &CompileOptions::new().precision(precision))
502}
503
504fn _legacy_apply_policy(graph: Graph, policy: Option<rlx_opt::PrecisionPolicy>) -> Graph {
509 use rlx_opt::pass::Pass as _;
510 match policy {
511 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
512 None => graph,
513 }
514}
515
516#[cfg(feature = "cpu")]
519pub mod cpu_backend {
520 use super::*;
521 use rlx_cpu::{arena::Arena, thunk};
522 use rlx_ir::{DType, NodeId, Op};
523 use rlx_opt::memory::{self, MemoryPlan};
524 use rlx_driver::arena::{read_typed_to_f32, write_typed_from_f32};
527
528 pub struct CpuBackend;
529
530 const CPU_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
537 use rlx_ir::OpKind::*;
538 &[
539 Input,
540 Param,
541 Constant,
542 Activation,
543 Cast,
544 Binary,
545 Compare,
546 Where,
547 ElementwiseRegion,
548 MatMul,
549 DotGeneral,
550 DenseSolve,
551 BatchedDenseSolve,
552 Scan,
553 ScanBackward,
554 ScanBackwardXs,
555 LayerNorm,
556 LayerNorm2d,
557 GroupNorm,
558 RmsNorm,
559 ResizeNearest2x,
560 AxialRope2d,
561 Attention,
562 Rope,
563 Reshape,
564 Transpose,
565 Narrow,
566 Concat,
567 Expand,
568 Gather,
569 Reduce,
570 Softmax,
571 Cumsum,
572 TopK,
573 Sample,
574 Conv,
575 ConvTranspose2d,
576 Pool,
577 GroupedMatMul,
578 DequantGroupedMatMul,
579 DequantMoEWeights,
580 ScatterAdd,
581 LoraMatMul,
582 DequantMatMul,
583 SelectiveScan,
584 GatedDeltaNet,
585 FusedSwiGLU,
586 FusedMatMulBiasAct,
587 FusedResidualLN,
588 FusedResidualRmsNorm,
589 FusedAttentionBlock,
590 ReluBackward,
595 ActivationBackward,
596 FakeQuantize,
597 FakeQuantizeBackward,
598 MaxPool2dBackward,
599 Conv2dBackwardInput,
600 Conv2dBackwardWeight,
601 SoftmaxCrossEntropyWithLogits,
602 SoftmaxCrossEntropyBackward,
603 AttentionBackward,
604 LayerNormBackwardInput,
605 LayerNormBackwardGamma,
606 RmsNormBackwardInput,
607 RmsNormBackwardGamma,
608 RmsNormBackwardBeta,
609 RopeBackward,
610 CumsumBackward,
611 GatherBackward,
612 GaussianSplatRender,
614 GaussianSplatRenderBackward,
615 GaussianSplatPrepare,
616 GaussianSplatRasterize,
617 Custom,
621 CustomFn,
625 Fft,
629 ComplexNormSq,
634 ComplexNormSqBackward,
635 Conjugate,
636 ]
637 };
638
639 impl Backend for CpuBackend {
640 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
641 CPU_SUPPORTED_OPS
642 }
643
644 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
645 use rlx_opt::pass::Pass as _;
646 let graph = rlx_opt::LowerControlFlow.run(graph);
652 if let Err(errors) = rlx_opt::legalize_for_backend(&graph, CPU_SUPPORTED_OPS) {
656 panic!("{}", rlx_opt::format_legalize_error("cpu", &errors));
657 }
658 let policy = options.policy.clone();
659 let _precision = options.precision;
660 let cfg = rlx_cpu::config::RuntimeConfig::global();
661
662 let graph = if options.dce {
664 rlx_opt::DeadCodeElimination.run(graph)
665 } else {
666 graph
667 };
668 let graph = if options.constant_folding {
669 rlx_opt::ConstantFolding.run(graph)
670 } else {
671 graph
672 };
673
674 let mut compile_opts = options.clone();
676 compile_opts.arena_alignment = cfg.arena_alignment;
677 let compile_result = crate::stages::compile_graph_stages_for_backend(
678 rlx_driver::Device::Cpu,
679 graph,
680 &compile_opts,
681 CPU_SUPPORTED_OPS,
682 );
683 crate::stages::maybe_log_fusion(&compile_result.fusion);
684 let fused = compile_result.lir.into_graph();
685
686 let fused = match policy {
689 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(fused),
690 None => fused,
691 };
692
693 let plan = memory::plan_memory_aligned(&fused, cfg.arena_alignment);
695 if cfg.verbose >= 1 {
696 eprintln!(
697 "[rlx] arena: {} bytes, {} buffers, alignment: {}",
698 plan.arena_size,
699 plan.assignments.len(),
700 cfg.arena_alignment
701 );
702 }
703 Box::new(build_cpu_executable(fused, plan))
704 }
705
706 fn compile_lir(
707 &self,
708 lir: LirModule,
709 options: &CompileOptions,
710 ) -> Box<dyn ExecutableGraph> {
711 let alignment = lir.buffers.alignment.max(options.arena_alignment);
712 let embedded: MemoryPlan = (&lir.buffers).into();
713 let mut graph = lir.into_graph();
714 if let Some(p) = options.policy.clone() {
715 use rlx_opt::pass::Pass;
716 graph = rlx_opt::AutoMixedPrecision::new(p).run(graph);
717 }
718 let plan = if options.policy.is_some() {
719 memory::plan_memory_aligned(&graph, alignment)
720 } else {
721 embedded
722 };
723 let cfg = rlx_cpu::config::RuntimeConfig::global();
724 if cfg.verbose >= 1 {
725 eprintln!(
726 "[rlx] compile_lir: arena {} bytes ({} buffers, alignment {})",
727 plan.arena_size,
728 plan.assignments.len(),
729 alignment,
730 );
731 }
732 Box::new(build_cpu_executable(graph, plan))
733 }
734 }
735
736 fn build_cpu_executable(graph: Graph, plan: MemoryPlan) -> CpuExecutable {
737 let mut arena = Arena::from_plan(plan);
738 let mut input_ids = HashMap::new();
739 let mut param_ids = HashMap::new();
740 let mut node_dtypes: HashMap<NodeId, DType> = HashMap::new();
741 for node in graph.nodes() {
742 node_dtypes.insert(node.id, node.shape.dtype());
743 match &node.op {
744 Op::Input { name } => {
745 input_ids.insert(name.clone(), node.id);
746 }
747 Op::Param { name } => {
748 param_ids.insert(name.clone(), node.id);
749 }
750 _ => {}
751 }
752 }
753
754 let schedule = thunk::compile_thunks(&graph, &arena);
755
756 let mut input_slots = Vec::new();
757 for node in graph.nodes() {
758 if let Op::Input { name } = &node.op {
759 let off = arena.byte_offset(node.id);
760 let len = node.shape.num_elements().unwrap_or(0);
761 input_slots.push((name.clone(), off, len, node.shape.dtype()));
762 }
763 }
764
765 let output_slots: Vec<(usize, usize)> = graph
766 .outputs
767 .iter()
768 .map(|&id| {
769 let off = arena.byte_offset(id);
770 let len = graph.node(id).shape.num_elements().unwrap_or(0);
771 (off, len)
772 })
773 .collect();
774
775 for node in graph.nodes() {
776 if let Op::Constant { data } = &node.op
777 && arena.has_buffer(node.id)
778 && !data.is_empty()
779 {
780 match node.shape.dtype() {
781 DType::F64 => {
782 let off = arena.byte_offset(node.id);
783 let buf = arena.raw_buf_mut();
784 let n = buf.len().saturating_sub(off).min(data.len());
785 buf[off..off + n].copy_from_slice(&data[..n]);
786 }
787 _ => {
788 let buf = arena.slice_mut(node.id);
789 let n_floats = data.len() / 4;
790 let n = buf.len().min(n_floats);
791 for i in 0..n {
792 let bytes = [
793 data[i * 4],
794 data[i * 4 + 1],
795 data[i * 4 + 2],
796 data[i * 4 + 3],
797 ];
798 buf[i] = f32::from_le_bytes(bytes);
799 }
800 }
801 }
802 }
803 }
804
805 CpuExecutable {
806 graph,
807 arena,
808 params: HashMap::new(),
809 input_ids,
810 param_ids,
811 node_dtypes,
812 schedule,
813 input_slots,
814 output_slots,
815 handles: HashMap::new(),
816 active_extent: None,
817 moe_resident: None,
818 moe_resident_layers: None,
819 moe_topk_capture: None,
820 }
821 }
822
823 #[derive(Clone)]
824 struct CpuExecutable {
825 graph: Graph,
826 arena: Arena,
827 params: HashMap<String, Vec<f32>>,
828 input_ids: HashMap<String, NodeId>,
829 param_ids: HashMap<String, NodeId>,
830 node_dtypes: HashMap<NodeId, DType>,
833 schedule: thunk::ThunkSchedule,
834 input_slots: Vec<(String, usize, usize, DType)>,
836 output_slots: Vec<(usize, usize)>,
838 handles: HashMap<String, Vec<f32>>,
843 active_extent: Option<(usize, usize)>,
849 moe_resident: Option<std::sync::Arc<[bool]>>,
850 moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
851 moe_topk_capture: Option<std::sync::Arc<rlx_cpu::moe_topk_capture::MoeTopkCapture>>,
852 }
853
854 unsafe impl Send for CpuExecutable {}
855
856 impl CpuExecutable {
857 fn write_input(&mut self, id: NodeId, data: &[f32]) {
859 let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
860 let off = self.arena.byte_offset(id);
861 let buf = self.arena.raw_buf_mut();
862 let elem_size = dtype.size_bytes();
863 let max_elems = (buf.len() - off) / elem_size;
864 unsafe {
865 write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, data, max_elems);
866 }
867 }
868
869 fn read_output(&self, id: NodeId) -> Vec<f32> {
871 let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
872 let off = self.arena.byte_offset(id);
873 let buf = self.arena.raw_buf();
874 let n_elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
875 unsafe { read_typed_to_f32(buf.as_ptr().add(off), dtype, n_elems) }
876 }
877 }
878
879 impl ExecutableGraph for CpuExecutable {
880 fn clone_box(&self) -> Box<dyn ExecutableGraph> {
881 Box::new(self.clone())
882 }
883 fn set_param(&mut self, name: &str, data: &[f32]) {
884 if let Some(&id) = self.param_ids.get(name)
887 && self.arena.has_buffer(id)
888 {
889 let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
890 let off = self.arena.byte_offset(id);
891 let buf = self.arena.raw_buf_mut();
892 let elem_size = dtype.size_bytes();
893 let max_elems = (buf.len() - off) / elem_size;
894 unsafe {
895 write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, data, max_elems);
896 }
897 return;
898 }
899 self.params.insert(name.to_string(), data.to_vec());
901 }
902
903 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
904 let handle_names: Vec<String> = self.handles.keys().cloned().collect();
907 for name in &handle_names {
908 if let Some(&id) = self.input_ids.get(name)
909 && self.arena.has_buffer(id)
910 {
911 let data = self.handles.get(name).cloned().unwrap_or_default();
912 self.write_input(id, &data);
913 }
914 }
915 for &(name, data) in inputs {
917 if let Some(&id) = self.input_ids.get(name)
918 && self.arena.has_buffer(id)
919 {
920 self.write_input(id, data);
921 }
922 }
923
924 let active_used = if let Some((actual, upper)) = self.active_extent {
929 thunk::execute_thunks_active(
930 &self.schedule,
931 self.arena.raw_buf_mut(),
932 actual,
933 upper,
934 )
935 } else {
936 false
937 };
938 if !active_used {
939 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
941 }
942
943 for (idx, &out_id) in self.graph.outputs.iter().enumerate() {
947 let name = format!("out{idx}");
948 if self.handles.contains_key(&name) {
949 let v = self.read_output(out_id);
950 self.handles.insert(name, v);
951 }
952 }
953
954 self.graph
955 .outputs
956 .iter()
957 .map(|&out_id| self.read_output(out_id))
958 .collect()
959 }
960
961 fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
962 for &(name, data) in inputs {
964 if let Some(&id) = self.input_ids.get(name)
965 && self.arena.has_buffer(id)
966 {
967 self.write_input(id, data);
968 }
969 }
970 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
971 self.graph
975 .outputs
976 .iter()
977 .map(|&out_id| {
978 let (ptr, len) = self.arena.raw_ptr(out_id);
979 (ptr as *const f32, len)
980 })
981 .collect()
982 }
983
984 fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
988 let buf = self.arena.raw_buf_mut();
989 for (i, &data) in inputs.iter().enumerate() {
990 if i < self.input_slots.len() {
991 let (_, off, max_len, dtype) = &self.input_slots[i];
992 unsafe {
993 write_typed_from_f32(buf.as_mut_ptr().add(*off), *dtype, data, *max_len);
994 }
995 }
996 }
997 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
998 &self.output_slots
999 }
1000
1001 fn arena_ptr(&self) -> *const u8 {
1002 self.arena.raw_buf_mut_ptr()
1003 }
1004
1005 fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
1006 self.handles.insert(name.to_string(), data.to_vec());
1011 true
1012 }
1013
1014 fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
1015 self.handles.get(name).cloned()
1016 }
1017
1018 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
1019 self.active_extent = extent;
1020 }
1021
1022 fn set_moe_resident_experts(&mut self, mask: &[bool]) {
1023 self.moe_resident_layers = None;
1024 self.schedule.moe_resident_layers = None;
1025 self.moe_resident = Some(Arc::from(mask));
1026 self.schedule.moe_resident = self.moe_resident.clone();
1027 }
1028
1029 fn set_moe_resident_experts_per_layer(&mut self, masks: &[&[bool]]) {
1030 self.moe_resident = None;
1031 self.schedule.moe_resident = None;
1032 let layers: Vec<Arc<[bool]>> = masks.iter().map(|m| Arc::from(*m)).collect();
1033 let arc = Arc::new(layers);
1034 self.moe_resident_layers = Some(arc.clone());
1035 self.schedule.moe_resident_layers = Some(arc);
1036 }
1037
1038 fn enable_moe_topk_capture(&mut self, num_experts: usize) -> bool {
1039 let cap = rlx_cpu::moe_topk_capture::MoeTopkCapture::new(num_experts);
1040 self.moe_topk_capture = Some(cap.clone());
1041 self.schedule.moe_topk_capture = Some(cap);
1042 true
1043 }
1044
1045 fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
1046 let cap = self.moe_topk_capture.as_ref()?;
1047 let layers = cap.take_layers();
1048 if layers.is_empty() {
1049 None
1050 } else {
1051 Some(layers)
1052 }
1053 }
1054
1055 fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
1056 rlx_cpu::moe_residency::take_last_forward_stats()
1057 }
1058
1059 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1065 if dtype == DType::F64 {
1066 self.set_param_bytes(name, data, dtype);
1067 return;
1068 }
1069 if matches!(dtype, DType::U8 | DType::I8) {
1073 self.set_param_bytes(name, data, dtype);
1074 return;
1075 }
1076 if dtype == DType::F32 {
1077 let n = data.len() / 4;
1078 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
1079 self.set_param(name, s);
1080 } else {
1081 let f32_buf = super::widen_bytes_to_f32(data, dtype);
1082 self.set_param(name, &f32_buf);
1083 }
1084 }
1085
1086 fn run_typed(
1098 &mut self,
1099 inputs: &[(&str, &[u8], rlx_ir::DType)],
1100 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1101 let all_f64 = !inputs.is_empty() && inputs.iter().all(|(_, _, dt)| *dt == DType::F64);
1106
1107 if all_f64 {
1108 for (name, data, _) in inputs {
1109 if let Some(&id) = self.input_ids.get(*name) {
1110 if !self.arena.has_buffer(id) {
1111 continue;
1112 }
1113 let off = self.arena.byte_offset(id);
1114 let buf = self.arena.raw_buf_mut();
1115 let n = data.len();
1116 debug_assert!(
1117 off + n <= buf.len(),
1118 "run_typed: input '{name}' overflows arena slot"
1119 );
1120 buf[off..off + n].copy_from_slice(data);
1121 }
1122 }
1123 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1124 } else {
1125 let mut f32_owned: Vec<(String, Vec<f32>)> = Vec::new();
1130 for (name, data, dt) in inputs {
1131 let direct = matches!(*dt, DType::F64 | DType::I32 | DType::I64 | DType::U32,);
1132 if direct {
1133 if let Some(&id) = self.input_ids.get(*name) {
1134 if !self.arena.has_buffer(id) {
1135 continue;
1136 }
1137 let off = self.arena.byte_offset(id);
1138 let buf = self.arena.raw_buf_mut();
1139 buf[off..off + data.len()].copy_from_slice(data);
1140 }
1141 } else {
1142 let v = super::widen_bytes_to_f32(data, *dt);
1143 f32_owned.push((name.to_string(), v));
1144 }
1145 }
1146 let refs: Vec<(&str, &[f32])> = f32_owned
1147 .iter()
1148 .map(|(n, d)| (n.as_str(), d.as_slice()))
1149 .collect();
1150 let _ = self.run(&refs);
1151 }
1152
1153 self.graph
1155 .outputs
1156 .iter()
1157 .map(|&id| {
1158 let dtype = self.graph.node(id).shape.dtype();
1159 let n_elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
1160 let n_bytes = n_elems * dtype.size_bytes();
1161 let off = self.arena.byte_offset(id);
1162 let bytes = self.arena.raw_buf()[off..off + n_bytes].to_vec();
1163 (bytes, dtype)
1164 })
1165 .collect()
1166 }
1167 }
1168
1169 impl CpuExecutable {
1170 fn set_param_bytes(&mut self, name: &str, data: &[u8], _dtype: rlx_ir::DType) {
1176 if let Some(&id) = self.param_ids.get(name)
1177 && self.arena.has_buffer(id)
1178 {
1179 let off = self.arena.byte_offset(id);
1180 let buf = self.arena.raw_buf_mut();
1181 debug_assert!(
1182 off + data.len() <= buf.len(),
1183 "set_param_bytes: '{name}' would overflow arena slot"
1184 );
1185 buf[off..off + data.len()].copy_from_slice(data);
1186 }
1187 }
1188 }
1189}
1190
#[cfg(feature = "gpu")]
pub mod wgpu_backend {
    //! wgpu backend: compiles graphs into a `WgpuExecutable` whose host I/O is
    //! f32-uniform.

    use super::*;
    use rlx_ir::OpKind;
    use rlx_wgpu::backend::WgpuExecutable;

    /// Backend handle for the wgpu executor.
    pub struct WgpuBackend;

    /// Op kinds the wgpu backend can execute; anything else fails legalization.
    const WGPU_SUPPORTED_OPS: &[OpKind] = &[
        OpKind::Input,
        OpKind::Param,
        OpKind::Constant,
        OpKind::Activation,
        OpKind::Cast,
        OpKind::Binary,
        OpKind::Compare,
        OpKind::Where,
        OpKind::ElementwiseRegion,
        OpKind::MatMul,
        OpKind::DotGeneral,
        OpKind::LayerNorm,
        OpKind::RmsNorm,
        OpKind::Attention,
        OpKind::AttentionBackward,
        OpKind::RmsNormBackwardInput,
        OpKind::RmsNormBackwardGamma,
        OpKind::RmsNormBackwardBeta,
        OpKind::RopeBackward,
        OpKind::CumsumBackward,
        OpKind::GatherBackward,
        OpKind::Rope,
        OpKind::Reshape,
        OpKind::Transpose,
        OpKind::Narrow,
        OpKind::Concat,
        OpKind::Expand,
        OpKind::Gather,
        OpKind::Reduce,
        OpKind::Softmax,
        OpKind::Cumsum,
        OpKind::TopK,
        OpKind::Sample,
        OpKind::Conv,
        OpKind::Pool,
        OpKind::GroupedMatMul,
        OpKind::DequantGroupedMatMul,
        OpKind::DequantMoEWeights,
        OpKind::ScatterAdd,
        OpKind::SelectiveScan,
        OpKind::DequantMatMul,
        OpKind::FusedMatMulBiasAct,
        OpKind::FusedResidualLN,
        OpKind::FusedResidualRmsNorm,
        OpKind::FusedSwiGLU,
        OpKind::FusedAttentionBlock,
        OpKind::FusedTransformerLayer,
        OpKind::Fft,
        OpKind::GaussianSplatRender,
        OpKind::GaussianSplatRenderBackward,
        OpKind::GaussianSplatPrepare,
        OpKind::GaussianSplatRasterize,
        OpKind::Custom,
    ];

    impl Backend for WgpuBackend {
        fn supported_ops(&self) -> &'static [OpKind] {
            WGPU_SUPPORTED_OPS
        }

        /// Legalize, run the optional DCE/constant-folding passes, then run
        /// the shared stage pipeline and optional mixed precision, and build
        /// the wgpu executable.
        ///
        /// # Panics
        /// If the graph contains ops outside `WGPU_SUPPORTED_OPS`.
        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
            use rlx_opt::pass::Pass as _;
            if let Err(errors) = rlx_opt::legalize_for_backend(&graph, WGPU_SUPPORTED_OPS) {
                panic!("{}", rlx_opt::format_legalize_error("wgpu", &errors));
            }
            let graph = if options.dce {
                rlx_opt::DeadCodeElimination.run(graph)
            } else {
                graph
            };
            let graph = if options.constant_folding {
                rlx_opt::ConstantFolding.run(graph)
            } else {
                graph
            };
            let compile_result = crate::stages::compile_graph_stages_for_backend(
                rlx_driver::Device::Gpu,
                graph,
                options,
                WGPU_SUPPORTED_OPS,
            );
            crate::stages::maybe_log_fusion(&compile_result.fusion);
            let graph = compile_result.lir.into_graph();
            let graph = match options.policy.clone() {
                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
                None => graph,
            };
            Box::new(WgpuExecutableWrapper {
                inner: WgpuExecutable::compile(graph),
            })
        }

        /// Re-legalize and clean up the LIR's graph via `prepare_fused_graph`,
        /// then build the wgpu executable.
        fn compile_lir(
            &self,
            lir: LirModule,
            options: &CompileOptions,
        ) -> Box<dyn ExecutableGraph> {
            let graph = prepare_fused_graph(lir.into_graph(), options, WGPU_SUPPORTED_OPS, "wgpu");
            Box::new(WgpuExecutableWrapper {
                inner: WgpuExecutable::compile(graph),
            })
        }
    }

    /// Adapts `WgpuExecutable` to the [`ExecutableGraph`] trait.
    struct WgpuExecutableWrapper {
        inner: WgpuExecutable,
    }

    // SAFETY: NOTE(review): asserted by the original code; presumably the
    // wgpu handles inside `WgpuExecutable` are safe to move between threads.
    // Confirm before relying on it.
    unsafe impl Send for WgpuExecutableWrapper {}

    impl ExecutableGraph for WgpuExecutableWrapper {
        fn set_param(&mut self, name: &str, data: &[f32]) {
            self.inner.set_param(name, data);
        }

        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
            self.inner.run(inputs)
        }

        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
            self.inner.set_active_extent(extent);
        }

        /// `U8`/`I8` bytes are uploaded verbatim, `F32` is passed through
        /// (zero-copy when aligned), and `F16`/`BF16` are widened to f32.
        ///
        /// # Panics
        /// For any other dtype.
        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
            use rlx_ir::DType;
            if matches!(dtype, DType::U8 | DType::I8) {
                self.inner.set_param_bytes(name, data);
                return;
            }
            if dtype == DType::F32 {
                // The old `from_raw_parts` cast was UB for misaligned `data`;
                // borrow only when `align_to` proves alignment.
                // SAFETY: every 4-byte pattern is a valid f32, and `align_to`
                // only yields aligned, in-bounds elements in the middle slice.
                let (head, aligned, tail) = unsafe { data.align_to::<f32>() };
                if head.is_empty() && tail.is_empty() {
                    self.inner.set_param(name, aligned);
                    return;
                }
            }
            match widen_float_bytes(data, dtype) {
                Some(values) => self.inner.set_param(name, &values),
                None => panic!(
                    "rlx-wgpu set_param_typed: dtype {dtype:?} unsupported \
                     (F32, F16, BF16 only — wgpu arena is f32-uniform)"
                ),
            }
        }

        /// Inputs must be `F32`, `F16` or `BF16`; they are widened to f32.
        /// Outputs are narrowed to the executable's reported output dtypes,
        /// falling back to `F32` for any output without one.
        ///
        /// # Panics
        /// If an input has any other dtype.
        fn run_typed(
            &mut self,
            inputs: &[(&str, &[u8], rlx_ir::DType)],
        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
            let owned: Vec<(String, Vec<f32>)> = inputs
                .iter()
                .map(|(name, data, dt)| {
                    let values = widen_float_bytes(data, *dt).unwrap_or_else(|| {
                        panic!("rlx-wgpu run_typed: input '{name}' dtype {dt:?} unsupported")
                    });
                    (name.to_string(), values)
                })
                .collect();
            let refs: Vec<(&str, &[f32])> = owned
                .iter()
                .map(|(n, d)| (n.as_str(), d.as_slice()))
                .collect();
            let dtypes = self.inner.output_dtypes();
            let outs = self.inner.run(&refs);
            outs.into_iter()
                .zip(
                    dtypes
                        .into_iter()
                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
                )
                .map(|(v, dt)| (narrow_to_dtype(&v, dt), dt))
                .collect()
        }
    }

    /// Decode native-order `F32`/`F16`/`BF16` bytes into f32 values.
    ///
    /// Returns `None` for any other dtype. Decoding goes element-by-element,
    /// so `data` needs no particular alignment. Trailing partial elements are
    /// ignored.
    fn widen_float_bytes(data: &[u8], dt: rlx_ir::DType) -> Option<Vec<f32>> {
        use rlx_ir::DType;
        let values = match dt {
            DType::F32 => data
                .chunks_exact(4)
                .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
                .collect(),
            DType::F16 => data
                .chunks_exact(2)
                .map(|c| half::f16::from_bits(u16::from_ne_bytes([c[0], c[1]])).to_f32())
                .collect(),
            DType::BF16 => data
                .chunks_exact(2)
                .map(|c| half::bf16::from_bits(u16::from_ne_bytes([c[0], c[1]])).to_f32())
                .collect(),
            _ => return None,
        };
        Some(values)
    }

    /// Narrow f32 output values into the little-endian bytes of `dt`.
    ///
    /// NOTE(review): unlike the crate-level `narrow_f32_to_bytes`, `C64` here
    /// emits 4 bytes per value. Presumably `v` already holds interleaved
    /// (re, im) pairs on wgpu; confirm.
    fn narrow_to_dtype(v: &[f32], dt: rlx_ir::DType) -> Vec<u8> {
        use rlx_ir::DType;
        match dt {
            DType::F32 => {
                let mut bytes = Vec::with_capacity(v.len() * 4);
                for &x in v {
                    bytes.extend_from_slice(&x.to_le_bytes());
                }
                bytes
            }
            DType::F16 => {
                let mut bytes = Vec::with_capacity(v.len() * 2);
                for &x in v {
                    bytes.extend_from_slice(&half::f16::from_f32(x).to_le_bytes());
                }
                bytes
            }
            DType::BF16 => {
                let mut bytes = Vec::with_capacity(v.len() * 2);
                for &x in v {
                    bytes.extend_from_slice(&half::bf16::from_f32(x).to_le_bytes());
                }
                bytes
            }
            DType::F64 => {
                let mut bytes = Vec::with_capacity(v.len() * 8);
                for &x in v {
                    bytes.extend_from_slice(&(x as f64).to_le_bytes());
                }
                bytes
            }
            DType::I8 => v.iter().map(|&x| x as i8 as u8).collect(),
            DType::U8 => v.iter().map(|&x| x as u8).collect(),
            DType::I16 => {
                let mut bytes = Vec::with_capacity(v.len() * 2);
                for &x in v {
                    bytes.extend_from_slice(&(x as i16).to_le_bytes());
                }
                bytes
            }
            DType::I32 => {
                let mut bytes = Vec::with_capacity(v.len() * 4);
                for &x in v {
                    bytes.extend_from_slice(&(x as i32).to_le_bytes());
                }
                bytes
            }
            DType::U32 => {
                let mut bytes = Vec::with_capacity(v.len() * 4);
                for &x in v {
                    bytes.extend_from_slice(&(x as u32).to_le_bytes());
                }
                bytes
            }
            DType::I64 => {
                let mut bytes = Vec::with_capacity(v.len() * 8);
                for &x in v {
                    bytes.extend_from_slice(&(x as i64).to_le_bytes());
                }
                bytes
            }
            DType::Bool => v
                .iter()
                .map(|&x| if x != 0.0 { 1u8 } else { 0u8 })
                .collect(),
            DType::C64 => {
                let mut bytes = Vec::with_capacity(v.len() * 4);
                for &x in v {
                    bytes.extend_from_slice(&x.to_le_bytes());
                }
                bytes
            }
        }
    }
}
1521
#[cfg(all(feature = "mlx", target_os = "macos"))]
pub mod mlx_backend {
    //! MLX backend glue: stages graphs for `Device::Mlx` and wraps the resulting
    //! `MlxExecutable` behind the `ExecutableGraph` trait.

    use super::*;
    use rlx_mlx::MlxExecutable;

    /// Backend entry point for MLX.
    pub struct MlxBackend;

    /// Op kinds the MLX lowering accepts. They are handed to the staging pipeline and
    /// to `prepare_fused_graph`.
    const MLX_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            DenseSolve,
            BatchedDenseSolve,
            LayerNorm,
            LayerNorm2d,
            RmsNorm,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            LoraMatMul,
            DequantMatMul,
            SelectiveScan,
            GatedDeltaNet,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            FusedAttentionBlock,
            FusedTransformerLayer,
            If,
            While,
            Scan,
            ScanBackward,
            ScanBackwardXs,
            ReluBackward,
            ActivationBackward,
            SoftmaxCrossEntropyWithLogits,
            SoftmaxCrossEntropyBackward,
            AttentionBackward,
            LayerNormBackwardInput,
            LayerNormBackwardGamma,
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            MaxPool2dBackward,
            FakeQuantize,
            FakeQuantizeBackward,
            Custom,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
        ]
    };

    impl Backend for MlxBackend {
        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
            MLX_SUPPORTED_OPS
        }

        /// Runs the staged pipeline for `Device::Mlx`, passes the fusion report to
        /// `maybe_log_fusion`, then finishes through `compile_lir`.
        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
            let staged = crate::stages::compile_graph_stages_for_backend(
                rlx_driver::Device::Mlx,
                graph,
                options,
                MLX_SUPPORTED_OPS,
            );
            crate::stages::maybe_log_fusion(&staged.fusion);
            self.compile_lir(staged.lir, options)
        }

        /// Lowers control flow, applies the shared fused-graph preparation, and builds
        /// the MLX executable.
        fn compile_lir(
            &self,
            lir: LirModule,
            options: &CompileOptions,
        ) -> Box<dyn ExecutableGraph> {
            use rlx_opt::pass::Pass as _;
            let lowered = rlx_opt::LowerControlFlow.run(lir.into_graph());
            let prepared = prepare_fused_graph(lowered, options, MLX_SUPPORTED_OPS, "mlx");
            Box::new(build_mlx_executable(prepared))
        }
    }

    /// Compiles `graph` in the mode selected by `RLX_MLX_MODE`.
    ///
    /// In `Compiled` mode this also attempts `warm_compile`. A failure there is
    /// reported on stderr but is not fatal.
    fn build_mlx_executable(graph: Graph) -> MlxExecutableWrapper {
        let mode = mlx_mode_from_env();
        let mut inner = MlxExecutable::compile_from_fused(graph, mode);
        let wants_warmup = mode == rlx_mlx::lower::MlxMode::Compiled;
        if wants_warmup {
            if let Err(e) = inner.warm_compile() {
                eprintln!(
                    "[rlx-runtime] MLX warm_compile failed ({e}); first run will pay the trace cost"
                );
            }
        }
        MlxExecutableWrapper { inner }
    }

    /// Reads `RLX_MLX_MODE` case-insensitively.
    ///
    /// `eager` selects `Eager` and `lazy` selects `Lazy`. `compiled`, any other
    /// value, or an unset variable all select `Compiled`.
    fn mlx_mode_from_env() -> rlx_mlx::lower::MlxMode {
        let requested = rlx_ir::env::var("RLX_MLX_MODE");
        match requested.as_deref().map(str::to_ascii_lowercase).as_deref() {
            Some("eager") => rlx_mlx::lower::MlxMode::Eager,
            Some("lazy") => rlx_mlx::lower::MlxMode::Lazy,
            _ => rlx_mlx::lower::MlxMode::Compiled,
        }
    }

    /// Owns a compiled `MlxExecutable` and forwards every `ExecutableGraph` call to it.
    struct MlxExecutableWrapper {
        inner: MlxExecutable,
    }

    // NOTE(review): unchecked assertion that `MlxExecutable` may be moved across threads.
    // It is presumably required because `Send` is not auto-derived for the inner type.
    // Confirm the MLX handles it holds are not thread-affine.
    unsafe impl Send for MlxExecutableWrapper {}

    impl ExecutableGraph for MlxExecutableWrapper {
        fn set_param(&mut self, name: &str, data: &[f32]) {
            self.inner.set_param(name, data);
        }
        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
            self.inner.set_param_typed(name, data, dtype);
        }
        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
            self.inner.run(inputs)
        }
        fn run_typed(
            &mut self,
            inputs: &[(&str, &[u8], rlx_ir::DType)],
        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
            self.inner.run_typed(inputs)
        }
        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
            self.inner.run_slots(inputs)
        }
        fn arena_ptr(&self) -> *const u8 {
            self.inner.arena_ptr()
        }
        fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
            self.inner.commit_no_wait(inputs);
        }
        fn sync_pending(&mut self) {
            self.inner.sync_pending();
        }
        fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
            self.inner.run_pipelined(input_sets)
        }
        fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
            self.inner.bind_handle(name, data)
        }
        fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
            self.inner.read_handle(name)
        }
        /// Returns `false` when the inner bind reports an error.
        fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
            matches!(self.inner.bind_gpu_handle(name, data), Ok(_))
        }
        fn has_gpu_handle(&self, name: &str) -> bool {
            self.inner.has_gpu_handle(name)
        }
        /// Always reports success. Whatever the inner call returns is discarded.
        fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
            self.inner.set_gpu_handle_feed(handle_name, output_index);
            true
        }
        fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
            self.inner.read_gpu_handle(name).ok()
        }
        fn run_feed_gpu_handle(
            &mut self,
            inputs: &[(&str, &[f32])],
            handle_name: &str,
            output_index: usize,
        ) -> Option<Vec<f32>> {
            let fed = self.inner.run_feed_gpu(inputs, handle_name, output_index);
            fed.ok()
        }
        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
            self.inner.set_active_extent(extent);
        }
    }
}
1759
#[cfg(all(feature = "metal", target_os = "macos"))]
pub mod metal_backend {
    use super::*;
    use rlx_metal::backend::MetalExecutable;

    /// Backend entry point for Apple Metal.
    pub struct MetalBackend;

    /// Op kinds the Metal executable accepts. They are used for legalization and are
    /// also passed to the executable itself.
    const METAL_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            RmsNorm,
            ResizeNearest2x,
            AxialRope2d,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            TopK,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            GatedDeltaNet,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            Custom,
            Fft,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
        ]
    };

    /// Graph preparation shared by `compile` and `compile_lir`.
    ///
    /// Steps:
    /// 1. Lowers control flow.
    /// 2. Legalizes/rewrites against `METAL_SUPPORTED_OPS` using
    ///    `options.kernel_dispatch`.
    /// 3. Runs dead-code elimination and constant folding when the matching options
    ///    are set.
    ///
    /// # Panics
    /// Panics with the formatted legalization report if legalization fails.
    fn prepare_metal_graph(graph: Graph, options: &CompileOptions) -> Graph {
        use rlx_opt::pass::Pass as _;
        let graph = rlx_opt::LowerControlFlow.run(graph);
        let mut graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
            graph,
            METAL_SUPPORTED_OPS,
            options.kernel_dispatch,
        )
        .unwrap_or_else(|errors| {
            panic!("{}", rlx_opt::format_legalize_error("metal", &errors));
        });
        if options.dce {
            graph = rlx_opt::DeadCodeElimination.run(graph);
        }
        if options.constant_folding {
            graph = rlx_opt::ConstantFolding.run(graph);
        }
        graph
    }

    /// Views `data` as native-endian `f32`s without assuming 4-byte alignment.
    ///
    /// Borrows in place when the buffer is aligned and a whole number of `f32`s long.
    /// Otherwise it decodes into an owned copy. Trailing bytes that do not form a whole
    /// `f32` are ignored, matching the previous `len / 4` truncation.
    fn bytes_as_f32(data: &[u8]) -> std::borrow::Cow<'_, [f32]> {
        // SAFETY: every 4-byte bit pattern is a valid `f32`, and `align_to` only places
        // correctly aligned, in-bounds elements in the middle slice.
        let (head, body, tail) = unsafe { data.align_to::<f32>() };
        if head.is_empty() && tail.is_empty() {
            std::borrow::Cow::Borrowed(body)
        } else {
            std::borrow::Cow::Owned(
                data.chunks_exact(4)
                    .map(|b| f32::from_ne_bytes([b[0], b[1], b[2], b[3]]))
                    .collect::<Vec<f32>>(),
            )
        }
    }

    impl Backend for MetalBackend {
        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
            METAL_SUPPORTED_OPS
        }

        /// Prepares the graph and compiles it with `options.policy`.
        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
            let graph = prepare_metal_graph(graph, options);
            Box::new(MetalExecutableWrapper {
                inner: MetalExecutable::compile_with_policy(
                    graph,
                    options.policy.clone(),
                    Some(METAL_SUPPORTED_OPS),
                ),
            })
        }

        /// Prepares an already-staged LIR module and compiles it via `compile_from_fused`.
        fn compile_lir(
            &self,
            lir: LirModule,
            options: &CompileOptions,
        ) -> Box<dyn ExecutableGraph> {
            let graph = prepare_metal_graph(lir.into_graph(), options);
            Box::new(MetalExecutableWrapper {
                inner: MetalExecutable::compile_from_fused(
                    graph,
                    options.policy.clone(),
                    Some(METAL_SUPPORTED_OPS),
                ),
            })
        }
    }

    /// Owns a compiled `MetalExecutable` and forwards `ExecutableGraph` calls to it.
    struct MetalExecutableWrapper {
        inner: MetalExecutable,
    }

    // NOTE(review): unchecked assertion that `MetalExecutable` may be moved across
    // threads. Confirm the Metal objects it holds are not thread-affine.
    unsafe impl Send for MetalExecutableWrapper {}

    impl ExecutableGraph for MetalExecutableWrapper {
        fn set_param(&mut self, name: &str, data: &[f32]) {
            self.inner.set_param(name, data);
        }
        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
            self.inner.run(inputs)
        }
        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
            self.inner.run_slots(inputs)
        }
        fn arena_ptr(&self) -> *const u8 {
            self.inner.arena_ptr()
        }
        fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
            self.inner.commit_no_wait(inputs);
        }
        fn sync_pending(&mut self) {
            self.inner.sync_pending();
        }
        fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
            self.inner.run_pipelined(input_sets)
        }
        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
            self.inner.set_active_extent(extent);
        }

        /// Uploads a parameter from raw bytes.
        ///
        /// Dtype handling:
        /// - `U8`/`I8` bytes go to `set_param_bytes` unchanged.
        /// - `F32` is read without an alignment assumption.
        /// - Every other dtype is widened by `widen_bytes_to_f32`, which panics on dtypes
        ///   other than F16/BF16.
        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
            match dtype {
                rlx_ir::DType::U8 | rlx_ir::DType::I8 => {
                    self.inner.set_param_bytes(name, data);
                }
                rlx_ir::DType::F32 => {
                    let values = bytes_as_f32(data);
                    self.inner.set_param(name, &values[..]);
                }
                _ => {
                    let f32_buf = super::widen_bytes_to_f32(data, dtype);
                    self.inner.set_param(name, &f32_buf);
                }
            }
        }

        /// Runs on dtype-tagged byte inputs widened to `f32`.
        ///
        /// Outputs are built as follows:
        /// - `F64` outputs are taken from `output_bytes_per_node()` as raw bytes.
        /// - All other outputs are narrowed from the f32 results to the dtype in
        ///   `output_dtypes()`, defaulting to `F32` past the end of that list.
        ///
        /// NOTE(review): the f32 results are zipped with `output_bytes_per_node()`, so a
        /// shorter byte list would silently drop trailing outputs. Confirm both always
        /// have one entry per output.
        fn run_typed(
            &mut self,
            inputs: &[(&str, &[u8], rlx_ir::DType)],
        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
            for (name, data, dt) in inputs {
                let v = match *dt {
                    rlx_ir::DType::F32 => bytes_as_f32(data).into_owned(),
                    other => super::widen_bytes_to_f32(data, other),
                };
                owned.push((name.to_string(), v));
            }
            let refs: Vec<(&str, &[f32])> = owned
                .iter()
                .map(|(n, d)| (n.as_str(), d.as_slice()))
                .collect();
            let dtypes = self.inner.output_dtypes();
            let f32_outs = self.inner.run(&refs);
            let byte_outs = self.inner.output_bytes_per_node();
            f32_outs
                .into_iter()
                .zip(byte_outs.into_iter())
                .zip(
                    dtypes
                        .into_iter()
                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
                )
                .map(|((f32_v, byte_v), dt)| match dt {
                    rlx_ir::DType::F64 => (byte_v, dt),
                    _ => (super::narrow_f32_to_bytes(&f32_v, dt), dt),
                })
                .collect()
        }
    }
}
2013
#[cfg(feature = "cuda")]
pub mod cuda_backend {
    use super::*;
    use rlx_cuda::backend::CudaExecutable;

    /// Backend entry point for CUDA.
    pub struct CudaBackend;

    /// Op kinds the CUDA executable accepts; used for rewriting, legalization and staging.
    const CUDA_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            RmsNorm,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            Custom,
        ]
    };

    /// Views `data` as native-endian `f32`s without assuming 4-byte alignment.
    ///
    /// Borrows in place when the buffer is aligned and a whole number of `f32`s long.
    /// Otherwise it decodes into an owned copy. Trailing bytes that do not form a whole
    /// `f32` are ignored, matching the previous `len / 4` truncation.
    fn bytes_as_f32(data: &[u8]) -> std::borrow::Cow<'_, [f32]> {
        // SAFETY: every 4-byte bit pattern is a valid `f32`, and `align_to` only places
        // correctly aligned, in-bounds elements in the middle slice.
        let (head, body, tail) = unsafe { data.align_to::<f32>() };
        if head.is_empty() && tail.is_empty() {
            std::borrow::Cow::Borrowed(body)
        } else {
            std::borrow::Cow::Owned(
                data.chunks_exact(4)
                    .map(|b| f32::from_ne_bytes([b[0], b[1], b[2], b[3]]))
                    .collect::<Vec<f32>>(),
            )
        }
    }

    impl Backend for CudaBackend {
        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
            CUDA_SUPPORTED_OPS
        }

        /// Compiles a graph for CUDA. Steps, in order:
        /// 1. Unfuses the graph.
        /// 2. Rewrites and legalizes it against `CUDA_SUPPORTED_OPS`, panicking on
        ///    legalization errors.
        /// 3. Optionally runs DCE and constant folding.
        /// 4. Runs the staged pipeline for `Device::Cuda`.
        /// 5. Applies `AutoMixedPrecision` when a policy is set.
        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
            use rlx_opt::pass::Pass as _;
            let graph = rlx_cuda::unfuse::unfuse(graph);
            let graph = rlx_opt::rewrite_for_backend(graph, CUDA_SUPPORTED_OPS);
            if let Err(errors) = rlx_opt::legalize_for_backend(&graph, CUDA_SUPPORTED_OPS) {
                panic!("{}", rlx_opt::format_legalize_error("cuda", &errors));
            }
            let graph = if options.dce {
                rlx_opt::DeadCodeElimination.run(graph)
            } else {
                graph
            };
            let graph = if options.constant_folding {
                rlx_opt::ConstantFolding.run(graph)
            } else {
                graph
            };
            let compile_result = crate::stages::compile_graph_stages_for_backend(
                rlx_driver::Device::Cuda,
                graph,
                options,
                CUDA_SUPPORTED_OPS,
            );
            crate::stages::maybe_log_fusion(&compile_result.fusion);
            let graph = compile_result.lir.into_graph();
            let graph = match options.policy.clone() {
                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
                None => graph,
            };
            Box::new(CudaExecutableWrapper {
                inner: CudaExecutable::compile(graph),
            })
        }

        /// Unfuses a staged LIR module, applies the shared fused-graph preparation,
        /// and compiles it.
        fn compile_lir(
            &self,
            lir: LirModule,
            options: &CompileOptions,
        ) -> Box<dyn ExecutableGraph> {
            let graph = prepare_fused_graph(
                rlx_cuda::unfuse::unfuse(lir.into_graph()),
                options,
                CUDA_SUPPORTED_OPS,
                "cuda",
            );
            Box::new(CudaExecutableWrapper {
                inner: CudaExecutable::compile(graph),
            })
        }
    }

    /// Owns a compiled `CudaExecutable` and forwards `ExecutableGraph` calls to it.
    struct CudaExecutableWrapper {
        inner: CudaExecutable,
    }

    // NOTE(review): unchecked assertion that `CudaExecutable` may be moved across
    // threads. Confirm its device handles/contexts are not thread-affine.
    unsafe impl Send for CudaExecutableWrapper {}

    impl ExecutableGraph for CudaExecutableWrapper {
        fn set_param(&mut self, name: &str, data: &[f32]) {
            self.inner.set_param(name, data);
        }
        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
            self.inner.run(inputs)
        }
        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
            self.inner.set_active_extent(extent);
        }

        /// Uploads a parameter from raw bytes.
        ///
        /// Dtype handling:
        /// - `U8`/`I8` bytes go to `set_param_bytes` unchanged.
        /// - `F32` is read without an alignment assumption.
        /// - Every other dtype is widened by `widen_bytes_to_f32`, which panics on dtypes
        ///   other than F16/BF16.
        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
            match dtype {
                rlx_ir::DType::U8 | rlx_ir::DType::I8 => {
                    self.inner.set_param_bytes(name, data);
                }
                rlx_ir::DType::F32 => {
                    let values = bytes_as_f32(data);
                    self.inner.set_param(name, &values[..]);
                }
                _ => {
                    let f32_buf = super::widen_bytes_to_f32(data, dtype);
                    self.inner.set_param(name, &f32_buf);
                }
            }
        }

        /// Runs on dtype-tagged byte inputs widened to `f32`.
        ///
        /// Outputs are narrowed to the dtypes from `output_dtypes()`, with `F32` used
        /// past the end of that list.
        fn run_typed(
            &mut self,
            inputs: &[(&str, &[u8], rlx_ir::DType)],
        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
            for (name, data, dt) in inputs {
                let v = match *dt {
                    rlx_ir::DType::F32 => bytes_as_f32(data).into_owned(),
                    other => super::widen_bytes_to_f32(data, other),
                };
                owned.push((name.to_string(), v));
            }
            let refs: Vec<(&str, &[f32])> = owned
                .iter()
                .map(|(n, d)| (n.as_str(), d.as_slice()))
                .collect();
            let dtypes = self.inner.output_dtypes();
            let outs = self.inner.run(&refs);
            outs.into_iter()
                .zip(
                    dtypes
                        .into_iter()
                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
                )
                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
                .collect()
        }
    }
}
2212
#[cfg(feature = "rocm")]
pub mod rocm_backend {
    use super::*;
    use rlx_rocm::backend::RocmExecutable;

    /// Backend entry point for ROCm.
    pub struct RocmBackend;

    /// Op kinds the ROCm executable accepts; used for rewriting, legalization and staging.
    const ROCM_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            RmsNorm,
            Attention,
            AttentionBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            Custom,
        ]
    };

    /// Views `data` as native-endian `f32`s without assuming 4-byte alignment.
    ///
    /// Borrows in place when the buffer is aligned and a whole number of `f32`s long.
    /// Otherwise it decodes into an owned copy. Trailing bytes that do not form a whole
    /// `f32` are ignored, matching the previous `len / 4` truncation.
    fn bytes_as_f32(data: &[u8]) -> std::borrow::Cow<'_, [f32]> {
        // SAFETY: every 4-byte bit pattern is a valid `f32`, and `align_to` only places
        // correctly aligned, in-bounds elements in the middle slice.
        let (head, body, tail) = unsafe { data.align_to::<f32>() };
        if head.is_empty() && tail.is_empty() {
            std::borrow::Cow::Borrowed(body)
        } else {
            std::borrow::Cow::Owned(
                data.chunks_exact(4)
                    .map(|b| f32::from_ne_bytes([b[0], b[1], b[2], b[3]]))
                    .collect::<Vec<f32>>(),
            )
        }
    }

    impl Backend for RocmBackend {
        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
            ROCM_SUPPORTED_OPS
        }

        /// Compiles a graph for ROCm. Steps, in order:
        /// 1. Rewrites and legalizes it against `ROCM_SUPPORTED_OPS`, panicking on
        ///    legalization errors.
        /// 2. Optionally runs DCE and constant folding.
        /// 3. Runs the staged pipeline for `Device::Rocm`.
        /// 4. Applies `AutoMixedPrecision` when a policy is set.
        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
            use rlx_opt::pass::Pass as _;
            let graph = rlx_opt::rewrite_for_backend(graph, ROCM_SUPPORTED_OPS);
            if let Err(errors) = rlx_opt::legalize_for_backend(&graph, ROCM_SUPPORTED_OPS) {
                panic!("{}", rlx_opt::format_legalize_error("rocm", &errors));
            }
            let graph = if options.dce {
                rlx_opt::DeadCodeElimination.run(graph)
            } else {
                graph
            };
            let graph = if options.constant_folding {
                rlx_opt::ConstantFolding.run(graph)
            } else {
                graph
            };
            let compile_result = crate::stages::compile_graph_stages_for_backend(
                rlx_driver::Device::Rocm,
                graph,
                options,
                ROCM_SUPPORTED_OPS,
            );
            crate::stages::maybe_log_fusion(&compile_result.fusion);
            let graph = compile_result.lir.into_graph();
            let graph = match options.policy.clone() {
                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
                None => graph,
            };
            Box::new(RocmExecutableWrapper {
                inner: RocmExecutable::compile(graph),
            })
        }

        /// Applies the shared fused-graph preparation to a staged LIR module and compiles it.
        fn compile_lir(
            &self,
            lir: LirModule,
            options: &CompileOptions,
        ) -> Box<dyn ExecutableGraph> {
            let graph = prepare_fused_graph(lir.into_graph(), options, ROCM_SUPPORTED_OPS, "rocm");
            Box::new(RocmExecutableWrapper {
                inner: RocmExecutable::compile(graph),
            })
        }
    }

    /// Owns a compiled `RocmExecutable` and forwards `ExecutableGraph` calls to it.
    struct RocmExecutableWrapper {
        inner: RocmExecutable,
    }

    // NOTE(review): unchecked assertion that `RocmExecutable` may be moved across
    // threads. Confirm its device handles are not thread-affine.
    unsafe impl Send for RocmExecutableWrapper {}

    impl ExecutableGraph for RocmExecutableWrapper {
        fn set_param(&mut self, name: &str, data: &[f32]) {
            self.inner.set_param(name, data);
        }
        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
            self.inner.run(inputs)
        }
        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
            self.inner.set_active_extent(extent);
        }

        /// Uploads a parameter from raw bytes.
        ///
        /// Dtype handling:
        /// - `U8`/`I8` bytes go to `set_param_bytes` unchanged.
        /// - `F32` is read without an alignment assumption.
        /// - Every other dtype is widened by `widen_bytes_to_f32`, which panics on dtypes
        ///   other than F16/BF16.
        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
            match dtype {
                rlx_ir::DType::U8 | rlx_ir::DType::I8 => {
                    self.inner.set_param_bytes(name, data);
                }
                rlx_ir::DType::F32 => {
                    let values = bytes_as_f32(data);
                    self.inner.set_param(name, &values[..]);
                }
                _ => {
                    let f32_buf = super::widen_bytes_to_f32(data, dtype);
                    self.inner.set_param(name, &f32_buf);
                }
            }
        }

        /// Runs on dtype-tagged byte inputs widened to `f32`.
        ///
        /// Outputs are narrowed to the dtypes from `output_dtypes()`, with `F32` used
        /// past the end of that list.
        fn run_typed(
            &mut self,
            inputs: &[(&str, &[u8], rlx_ir::DType)],
        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
            for (name, data, dt) in inputs {
                let v = match *dt {
                    rlx_ir::DType::F32 => bytes_as_f32(data).into_owned(),
                    other => super::widen_bytes_to_f32(data, other),
                };
                owned.push((name.to_string(), v));
            }
            let refs: Vec<(&str, &[f32])> = owned
                .iter()
                .map(|(n, d)| (n.as_str(), d.as_slice()))
                .collect();
            let dtypes = self.inner.output_dtypes();
            let outs = self.inner.run(&refs);
            outs.into_iter()
                .zip(
                    dtypes
                        .into_iter()
                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
                )
                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
                .collect()
        }
    }
}
2390
#[cfg(feature = "tpu")]
pub mod tpu_backend {
    use super::*;
    use rlx_tpu::TpuExecutable;

    /// Backend entry point for TPU.
    pub struct TpuBackend;

    /// Op kinds the TPU executable accepts; used for legalization.
    const TPU_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            RmsNorm,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            QMatMul,
            QConv2d,
            Quantize,
            Dequantize,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
        ]
    };

    /// Views `data` as native-endian `f32`s without assuming 4-byte alignment.
    ///
    /// Borrows in place when the buffer is aligned and a whole number of `f32`s long.
    /// Otherwise it decodes into an owned copy. Trailing bytes that do not form a whole
    /// `f32` are ignored, matching the previous `len / 4` truncation.
    fn bytes_as_f32(data: &[u8]) -> std::borrow::Cow<'_, [f32]> {
        // SAFETY: every 4-byte bit pattern is a valid `f32`, and `align_to` only places
        // correctly aligned, in-bounds elements in the middle slice.
        let (head, body, tail) = unsafe { data.align_to::<f32>() };
        if head.is_empty() && tail.is_empty() {
            std::borrow::Cow::Borrowed(body)
        } else {
            std::borrow::Cow::Owned(
                data.chunks_exact(4)
                    .map(|b| f32::from_ne_bytes([b[0], b[1], b[2], b[3]]))
                    .collect::<Vec<f32>>(),
            )
        }
    }

    impl Backend for TpuBackend {
        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
            TPU_SUPPORTED_OPS
        }

        /// Compiles a graph for TPU. Steps, in order:
        /// 1. Legalizes/rewrites against `TPU_SUPPORTED_OPS` with
        ///    `options.kernel_dispatch`, panicking on failure.
        /// 2. Always applies `AutoMixedPrecision`. The policy is `options.policy`, or
        ///    `AutoMixedBf16` when none is given.
        ///
        /// `options.dce` and `options.constant_folding` are read and deliberately ignored.
        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
            use rlx_opt::pass::Pass as _;
            let graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
                graph,
                TPU_SUPPORTED_OPS,
                options.kernel_dispatch,
            )
            .unwrap_or_else(|errors| {
                panic!("{}", rlx_opt::format_legalize_error("tpu", &errors));
            });
            let policy = options
                .policy
                .clone()
                .unwrap_or(rlx_opt::PrecisionPolicy::AutoMixedBf16);
            let graph = rlx_opt::AutoMixedPrecision::new(policy).run(graph);
            let _ = options.dce;
            let _ = options.constant_folding;
            Box::new(TpuExecutableWrapper {
                inner: TpuExecutable::compile(graph),
            })
        }
    }

    /// Owns a compiled `TpuExecutable` and forwards `ExecutableGraph` calls to it.
    struct TpuExecutableWrapper {
        inner: TpuExecutable,
    }

    // NOTE(review): unchecked assertion that `TpuExecutable` may be moved across
    // threads. Confirm its device handles are not thread-affine.
    unsafe impl Send for TpuExecutableWrapper {}

    impl ExecutableGraph for TpuExecutableWrapper {
        fn set_param(&mut self, name: &str, data: &[f32]) {
            self.inner.set_param(name, data);
        }
        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
            self.inner.run(inputs)
        }

        /// Uploads a parameter from raw bytes.
        ///
        /// `F32` is read without an alignment assumption. Every other dtype is widened by
        /// `widen_bytes_to_f32`, which panics on dtypes other than F16/BF16. Unlike the
        /// GPU wrappers there is no raw `U8`/`I8` path here.
        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
            if dtype == rlx_ir::DType::F32 {
                let values = bytes_as_f32(data);
                self.inner.set_param(name, &values[..]);
            } else {
                let f32_buf = super::widen_bytes_to_f32(data, dtype);
                self.inner.set_param(name, &f32_buf);
            }
        }

        /// Runs on dtype-tagged byte inputs widened to `f32`.
        ///
        /// Outputs are narrowed to the dtypes from `output_dtypes()`, with `F32` used
        /// past the end of that list.
        fn run_typed(
            &mut self,
            inputs: &[(&str, &[u8], rlx_ir::DType)],
        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
            for (name, data, dt) in inputs {
                let v = match *dt {
                    rlx_ir::DType::F32 => bytes_as_f32(data).into_owned(),
                    other => super::widen_bytes_to_f32(data, other),
                };
                owned.push((name.to_string(), v));
            }
            let refs: Vec<(&str, &[f32])> = owned
                .iter()
                .map(|(n, d)| (n.as_str(), d.as_slice()))
                .collect();
            let dtypes = self.inner.output_dtypes();
            let outs = self.inner.run(&refs);
            outs.into_iter()
                .zip(
                    dtypes
                        .into_iter()
                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
                )
                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
                .collect()
        }
    }
}