1use crate::CompileOptions;
23use rlx_ir::Graph;
24use rlx_ir::hir::HirModule;
25use rlx_ir::lir::LirModule;
26use std::collections::HashMap;
27use std::sync::Arc;
28
29use crate::cpu_low_precision;
30
31#[allow(dead_code)]
38pub(crate) fn widen_bytes_to_f32(data: &[u8], dtype: rlx_ir::DType) -> Vec<f32> {
39 use rlx_ir::DType;
40 match dtype {
41 DType::F32 => {
42 let n = data.len() / 4;
43 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
44 s.to_vec()
45 }
46 DType::F16 => {
47 let n = data.len() / 2;
48 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n) };
49 s.iter().map(|h| h.to_f32()).collect()
50 }
51 DType::BF16 => {
52 let n = data.len() / 2;
53 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n) };
54 s.iter().map(|h| h.to_f32()).collect()
55 }
56 other => panic!(
57 "widen_bytes_to_f32: dtype {other:?} unsupported on f32-arena backends \
58 (only F32/F16/BF16 are accepted on the host I/O surface)"
59 ),
60 }
61}
62
63#[allow(dead_code)]
68pub(crate) fn narrow_f32_to_bytes(v: &[f32], dt: rlx_ir::DType) -> Vec<u8> {
69 use rlx_ir::DType;
70 match dt {
71 DType::F32 => {
72 let mut bytes = Vec::with_capacity(v.len() * 4);
73 for &x in v {
74 bytes.extend_from_slice(&x.to_le_bytes());
75 }
76 bytes
77 }
78 DType::F16 => {
79 let mut bytes = Vec::with_capacity(v.len() * 2);
80 for &x in v {
81 bytes.extend_from_slice(&half::f16::from_f32(x).to_le_bytes());
82 }
83 bytes
84 }
85 DType::BF16 => {
86 let mut bytes = Vec::with_capacity(v.len() * 2);
87 for &x in v {
88 bytes.extend_from_slice(&half::bf16::from_f32(x).to_le_bytes());
89 }
90 bytes
91 }
92 DType::F64 => {
93 let mut bytes = Vec::with_capacity(v.len() * 8);
94 for &x in v {
95 bytes.extend_from_slice(&(x as f64).to_le_bytes());
96 }
97 bytes
98 }
99 DType::I8 => v.iter().map(|&x| x as i8 as u8).collect(),
100 DType::U8 => v.iter().map(|&x| x as u8).collect(),
101 DType::I16 => {
102 let mut bytes = Vec::with_capacity(v.len() * 2);
103 for &x in v {
104 bytes.extend_from_slice(&(x as i16).to_le_bytes());
105 }
106 bytes
107 }
108 DType::I32 => {
109 let mut bytes = Vec::with_capacity(v.len() * 4);
110 for &x in v {
111 bytes.extend_from_slice(&(x as i32).to_le_bytes());
112 }
113 bytes
114 }
115 DType::U32 => {
116 let mut bytes = Vec::with_capacity(v.len() * 4);
117 for &x in v {
118 bytes.extend_from_slice(&(x as u32).to_le_bytes());
119 }
120 bytes
121 }
122 DType::I64 => {
123 let mut bytes = Vec::with_capacity(v.len() * 8);
124 for &x in v {
125 bytes.extend_from_slice(&(x as i64).to_le_bytes());
126 }
127 bytes
128 }
129 DType::Bool => v
130 .iter()
131 .map(|&x| if x != 0.0 { 1u8 } else { 0u8 })
132 .collect(),
133 DType::C64 => {
134 let mut bytes = Vec::with_capacity(v.len() * 8);
138 for &x in v {
139 bytes.extend_from_slice(&x.to_le_bytes());
140 bytes.extend_from_slice(&0.0_f32.to_le_bytes());
141 }
142 bytes
143 }
144 }
145}
146
147pub trait ExecutableGraph: Send {
149 fn set_param(&mut self, name: &str, data: &[f32]);
151
152 fn finalize_params(&mut self) {}
155
156 fn clone_box(&self) -> Box<dyn ExecutableGraph> {
163 panic!("clone_box not implemented for this backend");
164 }
165
166 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>>;
168
169 fn run_read_outputs(
172 &mut self,
173 inputs: &[(&str, &[f32])],
174 read_indices: Option<&[usize]>,
175 ) -> Vec<Vec<f32>> {
176 match read_indices {
177 None => self.run(inputs),
178 Some(ix) => {
179 let all = self.run(inputs);
182 ix.iter().filter_map(|&i| all.get(i).cloned()).collect()
183 }
184 }
185 }
186
187 fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
189 let vecs = self.run(inputs);
190 vecs.iter().map(|v| (v.as_ptr(), v.len())).collect()
191 }
192
193 fn run_slots(&mut self, _inputs: &[&[f32]]) -> &[(usize, usize)] {
196 &[] }
198
199 fn arena_ptr(&self) -> *const u8 {
201 std::ptr::null()
202 }
203
204 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
221 let _ = extent;
222 }
223
224 fn set_rng(&mut self, rng: rlx_ir::RngOptions) {
226 let _ = rng;
227 }
228
229 fn rng(&self) -> rlx_ir::RngOptions {
231 rlx_ir::RngOptions::default()
232 }
233
234 fn set_moe_resident_experts(&mut self, _mask: &[bool]) {}
236
237 fn set_moe_resident_experts_per_layer(&mut self, _masks: &[&[bool]]) {}
239
240 fn enable_moe_topk_capture(&mut self, _num_experts: usize) -> bool {
242 false
243 }
244
245 fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
247 None
248 }
249
250 fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
252 None
253 }
254
255 fn bind_handle(&mut self, _name: &str, _data: &[f32]) -> bool {
259 false
260 }
261
262 fn read_handle(&self, _name: &str) -> Option<Vec<f32>> {
264 None
265 }
266
267 fn bind_gpu_handle(&mut self, _name: &str, _data: &[f32]) -> bool {
269 false
270 }
271
272 fn has_gpu_handle(&self, _name: &str) -> bool {
273 false
274 }
275
276 fn set_gpu_handle_feed(&mut self, _handle_name: &str, _output_index: usize) -> bool {
277 false
278 }
279
280 fn read_gpu_handle(&self, _name: &str) -> Option<Vec<f32>> {
281 None
282 }
283
284 fn read_output_row(&self, _out_idx: usize, _row: usize, _row_inner: usize) -> Option<Vec<f32>> {
287 None
288 }
289
290 fn run_feed_gpu_handle(
292 &mut self,
293 inputs: &[(&str, &[f32])],
294 _handle_name: &str,
295 _output_index: usize,
296 ) -> Option<Vec<f32>> {
297 let _ = inputs;
298 None
299 }
300
301 fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
316 let _ = self.run(inputs);
317 }
318
319 fn sync_pending(&mut self) {}
322
323 fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
332 input_sets.iter().map(|inputs| self.run(inputs)).collect()
333 }
334
335 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
348 if dtype != rlx_ir::DType::F32 {
349 panic!(
350 "backend's default set_param_typed only handles F32; \
351 got {dtype:?}. Override on the backend for typed support."
352 );
353 }
354 if !data.len().is_multiple_of(4) {
355 panic!(
356 "set_param_typed F32: data length {} not a multiple of 4",
357 data.len()
358 );
359 }
360 let n = data.len() / 4;
365 let f32_slice = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
366 self.set_param(name, f32_slice);
367 }
368
369 fn run_typed(
373 &mut self,
374 inputs: &[(&str, &[u8], rlx_ir::DType)],
375 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
376 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
379 for (name, data, dt) in inputs {
380 if *dt != rlx_ir::DType::F32 {
381 panic!(
382 "backend's default run_typed only handles F32 inputs; \
383 got {dt:?} for input '{name}'"
384 );
385 }
386 if data.len() % 4 != 0 {
387 panic!(
388 "run_typed F32 input '{name}': len {} not multiple of 4",
389 data.len()
390 );
391 }
392 let n = data.len() / 4;
393 let v: Vec<f32> =
394 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) }.to_vec();
395 owned.push((name.to_string(), v));
396 }
397 let refs: Vec<(&str, &[f32])> = owned
398 .iter()
399 .map(|(n, d)| (n.as_str(), d.as_slice()))
400 .collect();
401 let outs = self.run(&refs);
402 outs.into_iter()
403 .map(|v| {
404 let bytes =
405 unsafe { std::slice::from_raw_parts(v.as_ptr() as *const u8, v.len() * 4) }
406 .to_vec();
407 (bytes, rlx_ir::DType::F32)
408 })
409 .collect()
410 }
411}
412
413pub trait Backend: Send + Sync {
423 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph>;
425
426 fn compile_lir(&self, lir: LirModule, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
430 self.compile(lir.into_graph(), options)
431 }
432
433 fn compile_hir(
435 &self,
436 hir: HirModule,
437 device: rlx_driver::Device,
438 options: &CompileOptions,
439 ) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
440 let result = crate::stages::compile_hir_stages(device, hir, options)?;
441 crate::stages::maybe_log_fusion(&result.fusion);
442 Ok(self.compile_lir(result.lir, options))
443 }
444
445 fn compile_module(
447 &self,
448 module: rlx_ir::GraphModule,
449 device: rlx_driver::Device,
450 options: &CompileOptions,
451 ) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
452 let result = crate::stages::compile_module_stages(device, module, options)?;
453 crate::stages::maybe_log_fusion(&result.fusion);
454 Ok(self.compile_lir(result.lir, options))
455 }
456
457 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
464 &[]
465 }
466}
467
468#[allow(dead_code)]
471fn prepare_fused_graph(
472 graph: Graph,
473 options: &CompileOptions,
474 supported_ops: &[rlx_ir::OpKind],
475 backend_name: &str,
476) -> Graph {
477 let (mut graph, report) = rlx_opt::prepare_graph_for_backend_with_report(
478 graph,
479 backend_name,
480 supported_ops,
481 options.kernel_dispatch,
482 );
483 rlx_opt::maybe_log_dispatch_report(&report);
484 if !report.compile_ready {
485 panic!(
486 "{}\n{}",
487 rlx_opt::format_legalize_error(backend_name, &report.still_unsupported),
488 rlx_opt::format_dispatch_report(&report)
489 );
490 }
491 graph = crate::precompile::post_fusion_cleanup(graph, options);
492 if let Some(p) = options.policy.clone() {
493 use rlx_opt::pass::Pass as _;
494 graph = rlx_opt::AutoMixedPrecision::new(p).run(graph);
495 }
496 graph
497}
498
499#[allow(dead_code)]
500fn declared_output_dtypes(
501 manifest: &cpu_low_precision::IoDtypeManifest,
502 exec_dtypes: Vec<rlx_ir::DType>,
503) -> Vec<rlx_ir::DType> {
504 exec_dtypes
505 .into_iter()
506 .enumerate()
507 .map(|(i, exec)| manifest.output_dtype(i, exec))
508 .collect()
509}
510
511pub fn compile(backend: &dyn Backend, graph: Graph) -> Box<dyn ExecutableGraph> {
519 backend.compile(graph, &CompileOptions::default())
520}
521
522pub fn compile_hir(
524 backend: &dyn Backend,
525 hir: HirModule,
526 device: rlx_driver::Device,
527 options: &CompileOptions,
528) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
529 backend.compile_hir(hir, device, options)
530}
531
532pub fn compile_module(
534 backend: &dyn Backend,
535 module: rlx_ir::GraphModule,
536 device: rlx_driver::Device,
537 options: &CompileOptions,
538) -> Result<Box<dyn ExecutableGraph>, rlx_ir::hir::LowerError> {
539 backend.compile_module(module, device, options)
540}
541
542pub fn compile_with_precision(
544 backend: &dyn Backend,
545 graph: Graph,
546 precision: crate::Precision,
547) -> Box<dyn ExecutableGraph> {
548 backend.compile(graph, &CompileOptions::new().precision(precision))
549}
550
551fn _legacy_apply_policy(graph: Graph, policy: Option<rlx_opt::PrecisionPolicy>) -> Graph {
556 use rlx_opt::pass::Pass as _;
557 match policy {
558 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
559 None => graph,
560 }
561}
562
563#[cfg(feature = "cpu")]
566pub mod cpu_backend {
567 use super::*;
568 use rlx_cpu::{arena::Arena, thunk};
569 use rlx_ir::{DType, NodeId, Op};
570 use rlx_opt::memory::{self, MemoryPlan};
571 use rlx_driver::arena::{read_typed_to_f32, write_typed_from_f32};
574
    /// Host CPU backend.
    ///
    /// Compiles graphs into a single planned byte arena (`Arena::from_plan`)
    /// plus an `rlx_cpu` thunk schedule that executes against it.
    pub struct CpuBackend;
576
    /// Op kinds the CPU backend accepts.
    ///
    /// `CpuBackend::compile` checks graphs against this list with
    /// `rlx_opt::legalize_for_backend` and panics on anything not listed. The
    /// same list is passed to the fusion stages and returned from
    /// `supported_ops`.
    const CPU_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            StopGradient,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            MatMul,
            DotGeneral,
            DenseSolve,
            BatchedDenseSolve,
            Scan,
            ScanBackward,
            ScanBackwardXs,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            BatchNormInference,
            RmsNorm,
            ResizeNearest2x,
            AxialRope2d,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            ArgMax,
            ArgMin,
            TopK,
            Sample,
            RngNormal,
            RngUniform,
            Conv,
            Im2Col,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            LoraMatMul,
            DequantMatMul,
            SelectiveScan,
            GatedDeltaNet,
            Lstm,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            FusedAttentionBlock,
            // Gradient (backward) kernels.
            ReluBackward,
            ActivationBackward,
            FakeQuantize,
            FakeQuantizeBackward,
            MaxPool2dBackward,
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            SoftmaxCrossEntropyWithLogits,
            SoftmaxCrossEntropyBackward,
            AttentionBackward,
            LayerNormBackwardInput,
            LayerNormBackwardGamma,
            BatchNormInferenceBackwardInput,
            BatchNormInferenceBackwardGamma,
            BatchNormInferenceBackwardBeta,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            Custom,
            CustomFn,
            Fft,
            FftButterflyStage,
            LogMel,
            LogMelBackward,
            WelchPeaks,
            ComplexNormSq,
            ComplexNormSqBackward,
            Conjugate,
        ]
    };
700
701 impl Backend for CpuBackend {
702 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
703 CPU_SUPPORTED_OPS
704 }
705
706 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
707 use rlx_opt::pass::Pass as _;
708 static ONNX_KERNELS: std::sync::Once = std::sync::Once::new();
709 ONNX_KERNELS.call_once(rlx_cpu::onnx_ref::register_onnx_reference_kernels);
710 let graph = rlx_opt::LowerControlFlow.run(graph);
716 if let Err(errors) = rlx_opt::legalize_for_backend(&graph, CPU_SUPPORTED_OPS) {
720 panic!("{}", rlx_opt::format_legalize_error("cpu", &errors));
721 }
722 let policy = options.policy.clone();
723 let _precision = options.precision;
724 let cfg = rlx_cpu::config::RuntimeConfig::global();
725
726 let graph = crate::precompile::precompile_cleanup(graph, options);
727
728 let mut compile_opts = options.clone();
730 compile_opts.arena_alignment = cfg.arena_alignment;
731 let compile_result = crate::stages::compile_graph_stages_for_backend(
732 rlx_driver::Device::Cpu,
733 graph,
734 &compile_opts,
735 CPU_SUPPORTED_OPS,
736 );
737 crate::stages::maybe_log_fusion(&compile_result.fusion);
738 let fused = compile_result.lir.into_graph();
739
740 let fused = match policy {
743 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(fused),
744 None => fused,
745 };
746
747 let io_manifest = cpu_low_precision::IoDtypeManifest::from_graph(&fused);
748 let exec_graph = if cpu_low_precision::needs_f32_exec(&fused) {
749 cpu_low_precision::promote_to_f32(fused)
750 } else {
751 fused
752 };
753
754 let plan = memory::plan_memory_aligned(&exec_graph, cfg.arena_alignment);
756 if cfg.verbose >= 1 {
757 eprintln!(
758 "[rlx] arena: {} bytes, {} buffers, alignment: {}",
759 plan.arena_size,
760 plan.assignments.len(),
761 cfg.arena_alignment
762 );
763 }
764 Box::new(build_cpu_executable(
765 exec_graph,
766 plan,
767 io_manifest,
768 options.rng,
769 ))
770 }
771
772 fn compile_lir(
773 &self,
774 lir: LirModule,
775 options: &CompileOptions,
776 ) -> Box<dyn ExecutableGraph> {
777 let alignment = lir.buffers.alignment.max(options.arena_alignment);
778 let mut graph = lir.into_graph();
779 {
780 use rlx_opt::pass::Pass as _;
781 graph = rlx_opt::LegalizeBroadcast.run(graph);
782 }
783 if let Some(p) = options.policy.clone() {
784 use rlx_opt::pass::Pass;
785 graph = rlx_opt::AutoMixedPrecision::new(p).run(graph);
786 }
787 let io_manifest = cpu_low_precision::IoDtypeManifest::from_graph(&graph);
788 let promote = cpu_low_precision::needs_f32_exec(&graph);
789 let exec_graph = if promote {
790 cpu_low_precision::promote_to_f32(graph)
791 } else {
792 graph
793 };
794 let plan = memory::plan_memory_aligned(&exec_graph, alignment);
797 let cfg = rlx_cpu::config::RuntimeConfig::global();
798 if cfg.verbose >= 1 {
799 eprintln!(
800 "[rlx] compile_lir: arena {} bytes ({} buffers, alignment {})",
801 plan.arena_size,
802 plan.assignments.len(),
803 alignment,
804 );
805 }
806 Box::new(build_cpu_executable(
807 exec_graph,
808 plan,
809 io_manifest,
810 options.rng,
811 ))
812 }
813 }
814
815 fn build_cpu_executable(
816 graph: Graph,
817 plan: MemoryPlan,
818 io_manifest: cpu_low_precision::IoDtypeManifest,
819 rng: rlx_ir::RngOptions,
820 ) -> CpuExecutable {
821 let mut arena = Arena::from_plan(plan);
822 let mut input_ids = HashMap::new();
823 let mut param_ids = HashMap::new();
824 let mut node_dtypes: HashMap<NodeId, DType> = HashMap::new();
825 for node in graph.nodes() {
826 node_dtypes.insert(node.id, node.shape.dtype());
827 match &node.op {
828 Op::Input { name } => {
829 input_ids.insert(name.clone(), node.id);
830 }
831 Op::Param { name } => {
832 param_ids.insert(name.clone(), node.id);
833 }
834 _ => {}
835 }
836 }
837
838 let schedule = thunk::compile_thunks_with_rng(&graph, &arena, rng);
839
840 let mut input_slots = Vec::new();
841 for node in graph.nodes() {
842 if let Op::Input { name } = &node.op {
843 let off = arena.byte_offset(node.id);
844 let len = node.shape.num_elements().unwrap_or(0);
845 input_slots.push((name.clone(), off, len, node.shape.dtype()));
846 }
847 }
848
849 let output_slots: Vec<(usize, usize)> = graph
850 .outputs
851 .iter()
852 .map(|&id| {
853 let off = arena.byte_offset(id);
854 let len = graph.node(id).shape.num_elements().unwrap_or(0);
855 (off, len)
856 })
857 .collect();
858
859 for node in graph.nodes() {
860 if let Op::Constant { data } = &node.op
861 && arena.has_buffer(node.id)
862 && !data.is_empty()
863 {
864 match node.shape.dtype() {
865 DType::F64
872 | DType::F16
873 | DType::BF16
874 | DType::I64
875 | DType::I32
876 | DType::U32 => {
877 let off = arena.byte_offset(node.id);
878 let buf = arena.raw_buf_mut();
879 let n = buf.len().saturating_sub(off).min(data.len());
880 buf[off..off + n].copy_from_slice(&data[..n]);
881 }
882 _ => {
883 let buf = arena.slice_mut(node.id);
884 let n_floats = data.len() / 4;
885 let n = buf.len().min(n_floats);
886 for i in 0..n {
887 let bytes = [
888 data[i * 4],
889 data[i * 4 + 1],
890 data[i * 4 + 2],
891 data[i * 4 + 3],
892 ];
893 buf[i] = f32::from_le_bytes(bytes);
894 }
895 }
896 }
897 }
898 }
899
900 CpuExecutable {
901 graph,
902 arena,
903 params: HashMap::new(),
904 typed_params: HashMap::new(),
905 input_ids,
906 param_ids,
907 node_dtypes,
908 io_manifest,
909 schedule,
910 input_slots,
911 output_slots,
912 handles: HashMap::new(),
913 active_extent: None,
914 moe_resident: None,
915 moe_resident_layers: None,
916 moe_topk_capture: None,
917 }
918 }
919
    /// A compiled CPU graph together with its arena and host-side state.
    #[derive(Clone)]
    struct CpuExecutable {
        /// Graph the schedule was compiled from (the f32-promoted graph when
        /// promotion was needed).
        graph: Graph,
        /// Byte arena holding every planned buffer.
        arena: Arena,
        /// Parameters bound via `set_param`; rewritten into the arena each time
        /// the arena is restored.
        params: HashMap<String, Vec<f32>>,
        /// Parameters bound as raw bytes via `set_param_bytes`; also rewritten
        /// on each restore. A name lives in at most one of `params` /
        /// `typed_params`.
        typed_params: HashMap<String, (Vec<u8>, DType)>,
        /// `Op::Input` name -> node id.
        input_ids: HashMap<String, NodeId>,
        /// `Op::Param` name -> node id.
        param_ids: HashMap<String, NodeId>,
        /// Storage dtype of every node; used to convert host f32 data on
        /// write/read.
        node_dtypes: HashMap<NodeId, DType>,
        /// Output dtypes as declared before any f32 promotion; consulted by
        /// `run_typed`.
        io_manifest: cpu_low_precision::IoDtypeManifest,
        /// Thunk schedule executed against `arena`; also holds the RNG and MoE
        /// state.
        schedule: thunk::ThunkSchedule,
        /// (name, byte offset, element count, dtype) per input, in node order;
        /// used by `run_slots`.
        input_slots: Vec<(String, usize, usize, DType)>,
        /// (byte offset, element count) per graph output; returned by
        /// `run_slots`.
        output_slots: Vec<(usize, usize)>,
        /// Host handles from `bind_handle`. `run` writes them into same-named
        /// inputs, and refreshes entries named `out{idx}` from the outputs
        /// afterwards.
        handles: HashMap<String, Vec<f32>>,
        /// Optional (actual, upper) extent forwarded to
        /// `thunk::execute_thunks_active`.
        active_extent: Option<(usize, usize)>,
        /// Single MoE residency mask; mirrored onto `schedule`.
        moe_resident: Option<std::sync::Arc<[bool]>>,
        /// Per-layer MoE residency masks; mirrored onto `schedule`.
        moe_resident_layers: Option<std::sync::Arc<Vec<std::sync::Arc<[bool]>>>>,
        /// MoE top-k capture sink shared with `schedule`.
        moe_topk_capture: Option<std::sync::Arc<rlx_cpu::moe_topk_capture::MoeTopkCapture>>,
    }
954
    // SAFETY: `ExecutableGraph` requires `Send`, and this impl asserts that every
    // field of `CpuExecutable` may be moved to another thread.
    // NOTE(review): which field prevents the auto-`Send` impl (presumably
    // something inside `Arena` or `thunk::ThunkSchedule`) is not visible here.
    // Confirm that no field holds thread-affine state.
    unsafe impl Send for CpuExecutable {}
956
957 impl CpuExecutable {
958 fn write_input(&mut self, id: NodeId, data: &[f32]) {
960 let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
961 let off = self.arena.byte_offset(id);
962 let buf = self.arena.raw_buf_mut();
963 let elem_size = dtype.size_bytes();
964 let max_elems = (buf.len() - off) / elem_size;
965 unsafe {
966 write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, data, max_elems);
967 }
968 }
969
970 fn read_output(&self, id: NodeId) -> Vec<f32> {
972 let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
973 let off = self.arena.byte_offset(id);
974 let buf = self.arena.raw_buf();
975 let n_elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
976 unsafe { read_typed_to_f32(buf.as_ptr().add(off), dtype, n_elems) }
977 }
978 }
979
980 impl ExecutableGraph for CpuExecutable {
981 fn clone_box(&self) -> Box<dyn ExecutableGraph> {
982 Box::new(self.clone())
983 }
984 fn set_param(&mut self, name: &str, data: &[f32]) {
985 self.params.insert(name.to_string(), data.to_vec());
986 self.typed_params.remove(name);
987 if let Some(&id) = self.param_ids.get(name)
990 && self.arena.has_buffer(id)
991 {
992 let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
993 let off = self.arena.byte_offset(id);
994 let buf = self.arena.raw_buf_mut();
995 let elem_size = dtype.size_bytes();
996 let max_elems = (buf.len() - off) / elem_size;
997 unsafe {
998 write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, data, max_elems);
999 }
1000 }
1001 }
1002
1003 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
1004 self.restore_arena_baseline();
1005 let handle_names: Vec<String> = self.handles.keys().cloned().collect();
1008 for name in &handle_names {
1009 if let Some(&id) = self.input_ids.get(name)
1010 && self.arena.has_buffer(id)
1011 {
1012 let data = self.handles.get(name).cloned().unwrap_or_default();
1013 self.write_input(id, &data);
1014 }
1015 }
1016 for &(name, data) in inputs {
1018 if let Some(&id) = self.input_ids.get(name)
1019 && self.arena.has_buffer(id)
1020 {
1021 self.write_input(id, data);
1022 }
1023 }
1024
1025 let active_used = if let Some((actual, upper)) = self.active_extent {
1030 thunk::execute_thunks_active(
1031 &self.schedule,
1032 self.arena.raw_buf_mut(),
1033 actual,
1034 upper,
1035 )
1036 } else {
1037 false
1038 };
1039 if !active_used {
1040 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1042 }
1043
1044 for (idx, &out_id) in self.graph.outputs.iter().enumerate() {
1048 let name = format!("out{idx}");
1049 if self.handles.contains_key(&name) {
1050 let v = self.read_output(out_id);
1051 self.handles.insert(name, v);
1052 }
1053 }
1054
1055 self.graph
1056 .outputs
1057 .iter()
1058 .map(|&out_id| self.read_output(out_id))
1059 .collect()
1060 }
1061
1062 fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
1063 self.restore_arena_baseline();
1064 for &(name, data) in inputs {
1066 if let Some(&id) = self.input_ids.get(name)
1067 && self.arena.has_buffer(id)
1068 {
1069 self.write_input(id, data);
1070 }
1071 }
1072 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1073 self.graph
1077 .outputs
1078 .iter()
1079 .map(|&out_id| {
1080 let (ptr, len) = self.arena.raw_ptr(out_id);
1081 (ptr as *const f32, len)
1082 })
1083 .collect()
1084 }
1085
1086 fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
1090 self.restore_arena_baseline();
1091 let buf = self.arena.raw_buf_mut();
1092 for (i, &data) in inputs.iter().enumerate() {
1093 if i < self.input_slots.len() {
1094 let (_, off, max_len, dtype) = &self.input_slots[i];
1095 unsafe {
1096 write_typed_from_f32(buf.as_mut_ptr().add(*off), *dtype, data, *max_len);
1097 }
1098 }
1099 }
1100 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1101 &self.output_slots
1102 }
1103
1104 fn arena_ptr(&self) -> *const u8 {
1105 self.arena.raw_buf_mut_ptr()
1106 }
1107
1108 fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
1109 self.handles.insert(name.to_string(), data.to_vec());
1114 true
1115 }
1116
1117 fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
1118 self.handles.get(name).cloned()
1119 }
1120
1121 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
1122 self.active_extent = extent;
1123 }
1124
1125 fn set_rng(&mut self, rng: rlx_ir::RngOptions) {
1126 *self.schedule.rng.write().unwrap() = rng;
1127 }
1128
1129 fn rng(&self) -> rlx_ir::RngOptions {
1130 *self.schedule.rng.read().unwrap()
1131 }
1132
1133 fn set_moe_resident_experts(&mut self, mask: &[bool]) {
1134 self.moe_resident_layers = None;
1135 self.schedule.moe_resident_layers = None;
1136 self.moe_resident = Some(Arc::from(mask));
1137 self.schedule.moe_resident = self.moe_resident.clone();
1138 }
1139
1140 fn set_moe_resident_experts_per_layer(&mut self, masks: &[&[bool]]) {
1141 self.moe_resident = None;
1142 self.schedule.moe_resident = None;
1143 let layers: Vec<Arc<[bool]>> = masks.iter().map(|m| Arc::from(*m)).collect();
1144 let arc = Arc::new(layers);
1145 self.moe_resident_layers = Some(arc.clone());
1146 self.schedule.moe_resident_layers = Some(arc);
1147 }
1148
1149 fn enable_moe_topk_capture(&mut self, num_experts: usize) -> bool {
1150 let cap = rlx_cpu::moe_topk_capture::MoeTopkCapture::new(num_experts);
1151 self.moe_topk_capture = Some(cap.clone());
1152 self.schedule.moe_topk_capture = Some(cap);
1153 true
1154 }
1155
1156 fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
1157 let cap = self.moe_topk_capture.as_ref()?;
1158 let layers = cap.take_layers();
1159 if layers.is_empty() {
1160 None
1161 } else {
1162 Some(layers)
1163 }
1164 }
1165
1166 fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
1167 rlx_cpu::moe_residency::take_last_forward_stats()
1168 }
1169
1170 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1176 if matches!(dtype, DType::F64 | DType::I64 | DType::I32 | DType::U32) {
1177 self.set_param_bytes(name, data, dtype);
1178 return;
1179 }
1180 if matches!(dtype, DType::U8 | DType::I8) {
1184 self.set_param_bytes(name, data, dtype);
1185 return;
1186 }
1187 if dtype == DType::F32 {
1188 let n = data.len() / 4;
1189 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
1190 self.set_param(name, s);
1191 } else {
1192 let f32_buf = super::widen_bytes_to_f32(data, dtype);
1193 self.set_param(name, &f32_buf);
1194 }
1195 }
1196
1197 fn run_typed(
1209 &mut self,
1210 inputs: &[(&str, &[u8], rlx_ir::DType)],
1211 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1212 let all_f64 = !inputs.is_empty() && inputs.iter().all(|(_, _, dt)| *dt == DType::F64);
1217
1218 if all_f64 {
1219 for (name, data, _) in inputs {
1220 if let Some(&id) = self.input_ids.get(*name) {
1221 if !self.arena.has_buffer(id) {
1222 continue;
1223 }
1224 let off = self.arena.byte_offset(id);
1225 let buf = self.arena.raw_buf_mut();
1226 let n = data.len();
1227 debug_assert!(
1228 off + n <= buf.len(),
1229 "run_typed: input '{name}' overflows arena slot"
1230 );
1231 buf[off..off + n].copy_from_slice(data);
1232 }
1233 }
1234 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1235 } else {
1236 let mut f32_owned: Vec<(String, Vec<f32>)> = Vec::new();
1241 for (name, data, dt) in inputs {
1242 let direct = matches!(
1243 *dt,
1244 DType::F64 | DType::I32 | DType::I64 | DType::U32 | DType::C64
1245 );
1246 if direct {
1247 if let Some(&id) = self.input_ids.get(*name) {
1248 if !self.arena.has_buffer(id) {
1249 continue;
1250 }
1251 let off = self.arena.byte_offset(id);
1252 let buf = self.arena.raw_buf_mut();
1253 buf[off..off + data.len()].copy_from_slice(data);
1254 }
1255 } else {
1256 let v = super::widen_bytes_to_f32(data, *dt);
1257 f32_owned.push((name.to_string(), v));
1258 }
1259 }
1260 for (name, data) in &f32_owned {
1261 if let Some(&id) = self.input_ids.get(name.as_str()) {
1262 if self.arena.has_buffer(id) {
1263 self.write_input(id, data);
1264 }
1265 }
1266 }
1267 let active_used = if let Some((actual, upper)) = self.active_extent {
1268 thunk::execute_thunks_active(
1269 &self.schedule,
1270 self.arena.raw_buf_mut(),
1271 actual,
1272 upper,
1273 )
1274 } else {
1275 false
1276 };
1277 if !active_used {
1278 thunk::execute_thunks(&self.schedule, self.arena.raw_buf_mut());
1279 }
1280 }
1281
1282 self.graph
1284 .outputs
1285 .iter()
1286 .enumerate()
1287 .map(|(idx, &id)| {
1288 let exec_dtype = self.graph.node(id).shape.dtype();
1289 let declared = self.io_manifest.output_dtype(idx, exec_dtype);
1290 if matches!(
1291 exec_dtype,
1292 DType::F64
1293 | DType::F16
1294 | DType::BF16
1295 | DType::I32
1296 | DType::I64
1297 | DType::U32
1298 | DType::C64
1299 ) {
1300 let n_elems = self.graph.node(id).shape.num_elements().unwrap_or(0);
1301 let n_bytes = n_elems * exec_dtype.size_bytes();
1302 let off = self.arena.byte_offset(id);
1303 let bytes = self.arena.raw_buf()[off..off + n_bytes].to_vec();
1304 return (bytes, declared);
1305 }
1306 let f32_vals = self.read_output(id);
1307 if declared != exec_dtype {
1308 return (super::narrow_f32_to_bytes(&f32_vals, declared), declared);
1309 }
1310 let bytes = f32_vals.iter().flat_map(|v| v.to_le_bytes()).collect();
1311 (bytes, declared)
1312 })
1313 .collect()
1314 }
1315 }
1316
1317 impl CpuExecutable {
1318 fn restore_arena_baseline(&mut self) {
1323 self.arena.raw_buf_mut().fill(0);
1324 let constants: Vec<(NodeId, DType, Vec<u8>)> = self
1325 .graph
1326 .nodes()
1327 .iter()
1328 .filter_map(|node| {
1329 if let Op::Constant { data } = &node.op
1330 && self.arena.has_buffer(node.id)
1331 && !data.is_empty()
1332 {
1333 Some((node.id, node.shape.dtype(), data.clone()))
1334 } else {
1335 None
1336 }
1337 })
1338 .collect();
1339 for (id, dtype, data) in constants {
1340 self.write_constant_to_arena(id, dtype, &data);
1341 }
1342 let params = self.params.clone();
1343 for (name, data) in params {
1344 if let Some(&id) = self.param_ids.get(&name)
1345 && self.arena.has_buffer(id)
1346 {
1347 let dtype = self.node_dtypes.get(&id).copied().unwrap_or(DType::F32);
1348 let off = self.arena.byte_offset(id);
1349 let buf = self.arena.raw_buf_mut();
1350 let elem_size = dtype.size_bytes();
1351 let max_elems = (buf.len() - off) / elem_size;
1352 unsafe {
1353 write_typed_from_f32(buf.as_mut_ptr().add(off), dtype, &data, max_elems);
1354 }
1355 }
1356 }
1357 let typed = self.typed_params.clone();
1358 for (name, (data, dtype)) in typed {
1359 self.write_param_bytes_to_arena(&name, &data);
1360 let _ = dtype;
1361 }
1362 }
1363
1364 fn write_constant_to_arena(&mut self, id: NodeId, dtype: DType, data: &[u8]) {
1365 match dtype {
1366 DType::F64 | DType::F16 | DType::BF16 | DType::U8 | DType::I8 => {
1367 let off = self.arena.byte_offset(id);
1368 let buf = self.arena.raw_buf_mut();
1369 let n = buf.len().saturating_sub(off).min(data.len());
1370 buf[off..off + n].copy_from_slice(&data[..n]);
1371 }
1372 _ => {
1373 let buf = self.arena.slice_mut(id);
1374 let n_floats = data.len() / 4;
1375 let n = buf.len().min(n_floats);
1376 for i in 0..n {
1377 let bytes = [
1378 data[i * 4],
1379 data[i * 4 + 1],
1380 data[i * 4 + 2],
1381 data[i * 4 + 3],
1382 ];
1383 buf[i] = f32::from_le_bytes(bytes);
1384 }
1385 }
1386 }
1387 }
1388
1389 fn set_param_bytes(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1395 self.typed_params
1396 .insert(name.to_string(), (data.to_vec(), dtype));
1397 self.params.remove(name);
1398 self.write_param_bytes_to_arena(name, data);
1399 }
1400
1401 fn write_param_bytes_to_arena(&mut self, name: &str, data: &[u8]) {
1402 if let Some(&id) = self.param_ids.get(name)
1403 && self.arena.has_buffer(id)
1404 {
1405 let off = self.arena.byte_offset(id);
1406 let buf = self.arena.raw_buf_mut();
1407 debug_assert!(
1408 off + data.len() <= buf.len(),
1409 "set_param_bytes: '{name}' would overflow arena slot"
1410 );
1411 buf[off..off + data.len()].copy_from_slice(data);
1412 }
1413 }
1414 }
1415}
1416
1417#[cfg(feature = "gpu")]
1422pub mod wgpu_backend {
1423 use super::*;
1424 use rlx_ir::OpKind;
1425 use rlx_wgpu::backend::WgpuExecutable;
1426
    /// GPU backend built on `rlx_wgpu`; compiles for `rlx_driver::Device::Gpu`.
    pub struct WgpuBackend;
1428
    /// Op kinds the wgpu backend accepts.
    ///
    /// `WgpuBackend::compile` passes this list to
    /// `rlx_opt::legalize_or_rewrite_for_backend` and panics with the formatted
    /// legalize error when a graph cannot be made to fit it.
    const WGPU_SUPPORTED_OPS: &[OpKind] = &[
        OpKind::Input,
        OpKind::Param,
        OpKind::Constant,
        OpKind::Activation,
        OpKind::Cast,
        OpKind::StopGradient,
        OpKind::Binary,
        OpKind::Compare,
        OpKind::Where,
        OpKind::ElementwiseRegion,
        OpKind::TransformRegion,
        OpKind::BatchElementwiseRegion,
        OpKind::MatMul,
        OpKind::DotGeneral,
        OpKind::LayerNorm,
        OpKind::RmsNorm,
        OpKind::Attention,
        OpKind::AttentionBackward,
        OpKind::RmsNormBackwardInput,
        OpKind::RmsNormBackwardGamma,
        OpKind::RmsNormBackwardBeta,
        OpKind::LayerNormBackwardInput,
        OpKind::LayerNormBackwardGamma,
        OpKind::RopeBackward,
        OpKind::CumsumBackward,
        OpKind::GatherBackward,
        OpKind::Rope,
        OpKind::Reshape,
        OpKind::Transpose,
        OpKind::Narrow,
        OpKind::Concat,
        OpKind::Expand,
        OpKind::Gather,
        OpKind::Reduce,
        OpKind::Softmax,
        OpKind::Cumsum,
        OpKind::TopK,
        OpKind::Sample,
        OpKind::Conv,
        OpKind::Im2Col,
        OpKind::Pool,
        OpKind::GroupedMatMul,
        OpKind::DequantGroupedMatMul,
        OpKind::DequantMoEWeights,
        OpKind::ScatterAdd,
        OpKind::SelectiveScan,
        OpKind::Lstm,
        OpKind::DequantMatMul,
        OpKind::FusedMatMulBiasAct,
        OpKind::FusedResidualLN,
        OpKind::FusedResidualRmsNorm,
        OpKind::FusedSwiGLU,
        OpKind::FusedAttentionBlock,
        OpKind::FusedTransformerLayer,
        OpKind::Fft,
        OpKind::LogMel,
        OpKind::LogMelBackward,
        OpKind::WelchPeaks,
        OpKind::GaussianSplatRender,
        OpKind::GaussianSplatRenderBackward,
        OpKind::GaussianSplatPrepare,
        OpKind::GaussianSplatRasterize,
        OpKind::Custom,
        OpKind::RngNormal,
        OpKind::RngUniform,
    ];
1514
1515 impl Backend for WgpuBackend {
1516 fn supported_ops(&self) -> &'static [OpKind] {
1517 WGPU_SUPPORTED_OPS
1518 }
1519
1520 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
1521 use rlx_opt::pass::Pass as _;
1522 let graph = rlx_opt::LowerControlFlow.run(graph);
1523 let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, WGPU_SUPPORTED_OPS)
1524 .unwrap_or_else(|errors| {
1525 panic!("{}", rlx_opt::format_legalize_error("wgpu", &errors));
1526 });
1527 let graph = crate::precompile::precompile_cleanup(graph, options);
1528 let graph = rlx_opt::LegalizeBroadcast.run(graph);
1532 let compile_result = crate::stages::compile_graph_stages_for_backend(
1541 rlx_driver::Device::Gpu,
1542 graph,
1543 options,
1544 WGPU_SUPPORTED_OPS,
1545 );
1546 crate::stages::maybe_log_fusion(&compile_result.fusion);
1547 let graph = compile_result.lir.into_graph();
1548 let graph = match options.policy.clone() {
1549 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
1550 None => graph,
1551 };
1552 let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
1553 Box::new(WgpuExecutableWrapper {
1554 inner: WgpuExecutable::compile_rng(graph, options.rng),
1555 io_manifest,
1556 })
1557 }
1558
1559 fn compile_lir(
1560 &self,
1561 lir: LirModule,
1562 options: &CompileOptions,
1563 ) -> Box<dyn ExecutableGraph> {
1564 use rlx_opt::pass::Pass as _;
1565 let graph = rlx_opt::LegalizeBroadcast.run(lir.into_graph());
1568 let graph = prepare_fused_graph(graph, options, WGPU_SUPPORTED_OPS, "wgpu");
1569 let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
1570 Box::new(WgpuExecutableWrapper {
1571 inner: WgpuExecutable::compile_rng(graph, options.rng),
1572 io_manifest,
1573 })
1574 }
1575 }
1576
    /// Adapts a compiled `WgpuExecutable` to the `ExecutableGraph` trait.
    struct WgpuExecutableWrapper {
        inner: WgpuExecutable,
        /// Manifest produced by `prepare_f32_exec_graph`. `run_typed` passes it to
        /// `declared_output_dtypes` to choose the dtype each output is narrowed to.
        io_manifest: cpu_low_precision::IoDtypeManifest,
    }
1581
    // SAFETY: NOTE(review) — this asserts that a `WgpuExecutable` may be moved to another
    // thread. Nothing in this file establishes that; the explicit impl suggests the inner
    // type is not `Send` on its own. Confirm its GPU handles are not thread-affine.
    unsafe impl Send for WgpuExecutableWrapper {}
1583
1584 impl ExecutableGraph for WgpuExecutableWrapper {
1585 fn set_param(&mut self, name: &str, data: &[f32]) {
1586 self.inner.set_param(name, data);
1587 }
1588 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
1589 self.inner.run(inputs)
1590 }
1591 fn run_read_outputs(
1592 &mut self,
1593 inputs: &[(&str, &[f32])],
1594 read_indices: Option<&[usize]>,
1595 ) -> Vec<Vec<f32>> {
1596 self.inner.run_read_outputs(inputs, read_indices)
1597 }
1598 fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
1599 self.inner.bind_gpu_handle(name, data)
1600 }
1601 fn has_gpu_handle(&self, name: &str) -> bool {
1602 self.inner.has_gpu_handle(name)
1603 }
1604 fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
1605 self.inner.set_gpu_handle_feed(handle_name, output_index);
1606 true
1607 }
1608 fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
1609 self.inner.read_gpu_handle(name)
1610 }
1611 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
1612 self.inner.set_active_extent(extent);
1613 }
1614
1615 fn set_rng(&mut self, rng: rlx_ir::RngOptions) {
1616 self.inner.set_rng(rng);
1617 }
1618
1619 fn rng(&self) -> rlx_ir::RngOptions {
1620 self.inner.rng()
1621 }
1622
1623 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
1626 match dtype {
1627 rlx_ir::DType::U8 | rlx_ir::DType::I8 => {
1628 self.inner.set_param_bytes(name, data);
1629 }
1630 rlx_ir::DType::F32 => {
1631 let n = data.len() / 4;
1632 let f32_slice =
1633 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
1634 self.inner.set_param(name, f32_slice);
1635 }
1636 rlx_ir::DType::F16 => {
1637 let n = data.len() / 2;
1638 let f16_slice =
1639 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n) };
1640 let f32: Vec<f32> = f16_slice.iter().map(|h| h.to_f32()).collect();
1641 self.inner.set_param(name, &f32);
1642 }
1643 rlx_ir::DType::BF16 => {
1644 let n = data.len() / 2;
1645 let bf16_slice = unsafe {
1646 std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n)
1647 };
1648 let f32: Vec<f32> = bf16_slice.iter().map(|h| h.to_f32()).collect();
1649 self.inner.set_param(name, &f32);
1650 }
1651 other => panic!(
1652 "rlx-wgpu set_param_typed: dtype {other:?} unsupported \
1653 (F32, F16, BF16 only — wgpu arena is f32-uniform)"
1654 ),
1655 }
1656 }
1657
1658 fn run_typed(
1661 &mut self,
1662 inputs: &[(&str, &[u8], rlx_ir::DType)],
1663 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
1664 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
1665 for (name, data, dt) in inputs {
1666 let v: Vec<f32> = match *dt {
1667 rlx_ir::DType::F32 => {
1668 let n = data.len() / 4;
1669 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) }
1670 .to_vec()
1671 }
1672 rlx_ir::DType::F16 => {
1673 let n = data.len() / 2;
1674 let s = unsafe {
1675 std::slice::from_raw_parts(data.as_ptr() as *const half::f16, n)
1676 };
1677 s.iter().map(|h| h.to_f32()).collect()
1678 }
1679 rlx_ir::DType::BF16 => {
1680 let n = data.len() / 2;
1681 let s = unsafe {
1682 std::slice::from_raw_parts(data.as_ptr() as *const half::bf16, n)
1683 };
1684 s.iter().map(|h| h.to_f32()).collect()
1685 }
1686 rlx_ir::DType::I64 => {
1690 let n = data.len() / 8;
1691 let s =
1692 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const i64, n) };
1693 s.iter().map(|&x| x as f32).collect()
1694 }
1695 rlx_ir::DType::I32 => {
1696 let n = data.len() / 4;
1697 let s =
1698 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const i32, n) };
1699 s.iter().map(|&x| x as f32).collect()
1700 }
1701 rlx_ir::DType::U8 | rlx_ir::DType::I8 | rlx_ir::DType::Bool => {
1702 data.iter().map(|&b| b as f32).collect()
1703 }
1704 other => {
1705 panic!("rlx-wgpu run_typed: input '{name}' dtype {other:?} unsupported")
1706 }
1707 };
1708 owned.push((name.to_string(), v));
1709 }
1710 let refs: Vec<(&str, &[f32])> = owned
1711 .iter()
1712 .map(|(n, d)| (n.as_str(), d.as_slice()))
1713 .collect();
1714 let dtypes =
1715 super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
1716 let outs = self.inner.run(&refs);
1717 outs.into_iter()
1718 .zip(
1719 dtypes
1720 .into_iter()
1721 .chain(std::iter::repeat(rlx_ir::DType::F32)),
1722 )
1723 .map(|(v, dt)| (narrow_to_dtype(&v, dt), dt))
1724 .collect()
1725 }
1726
1727 fn clone_box(&self) -> Box<dyn ExecutableGraph> {
1728 Box::new(WgpuExecutableWrapper {
1729 inner: self.inner.clone_for_cache(),
1730 io_manifest: self.io_manifest.clone(),
1731 })
1732 }
1733 }
1734
1735 fn narrow_to_dtype(v: &[f32], dt: rlx_ir::DType) -> Vec<u8> {
1741 use rlx_ir::DType;
1742 match dt {
1743 DType::F32 => {
1744 let mut bytes = Vec::with_capacity(v.len() * 4);
1745 for &x in v {
1746 bytes.extend_from_slice(&x.to_le_bytes());
1747 }
1748 bytes
1749 }
1750 DType::F16 => {
1751 let mut bytes = Vec::with_capacity(v.len() * 2);
1752 for &x in v {
1753 bytes.extend_from_slice(&half::f16::from_f32(x).to_le_bytes());
1754 }
1755 bytes
1756 }
1757 DType::BF16 => {
1758 let mut bytes = Vec::with_capacity(v.len() * 2);
1759 for &x in v {
1760 bytes.extend_from_slice(&half::bf16::from_f32(x).to_le_bytes());
1761 }
1762 bytes
1763 }
1764 DType::F64 => {
1765 let mut bytes = Vec::with_capacity(v.len() * 8);
1766 for &x in v {
1767 bytes.extend_from_slice(&(x as f64).to_le_bytes());
1768 }
1769 bytes
1770 }
1771 DType::I8 => v.iter().map(|&x| x as i8 as u8).collect(),
1772 DType::U8 => v.iter().map(|&x| x as u8).collect(),
1773 DType::I16 => {
1774 let mut bytes = Vec::with_capacity(v.len() * 2);
1775 for &x in v {
1776 bytes.extend_from_slice(&(x as i16).to_le_bytes());
1777 }
1778 bytes
1779 }
1780 DType::I32 => {
1781 let mut bytes = Vec::with_capacity(v.len() * 4);
1782 for &x in v {
1783 bytes.extend_from_slice(&(x as i32).to_le_bytes());
1784 }
1785 bytes
1786 }
1787 DType::U32 => {
1788 let mut bytes = Vec::with_capacity(v.len() * 4);
1789 for &x in v {
1790 bytes.extend_from_slice(&(x as u32).to_le_bytes());
1791 }
1792 bytes
1793 }
1794 DType::I64 => {
1795 let mut bytes = Vec::with_capacity(v.len() * 8);
1796 for &x in v {
1797 bytes.extend_from_slice(&(x as i64).to_le_bytes());
1798 }
1799 bytes
1800 }
1801 DType::Bool => v
1802 .iter()
1803 .map(|&x| if x != 0.0 { 1u8 } else { 0u8 })
1804 .collect(),
1805 DType::C64 => {
1812 let mut bytes = Vec::with_capacity(v.len() * 4);
1813 for &x in v {
1814 bytes.extend_from_slice(&x.to_le_bytes());
1815 }
1816 bytes
1817 }
1818 }
1819 }
1820}
1821
1822#[cfg(all(feature = "mlx", rlx_mlx_host))]
1825pub mod mlx_backend {
1826 use super::*;
1827 use rlx_mlx::MlxExecutable;
1828
    /// Backend handle that compiles graphs into `MlxExecutable`s (see the `Backend` impl below).
    pub struct MlxBackend;
1830
    /// Op kinds the MLX backend accepts.
    ///
    /// Returned by `MlxBackend::supported_ops`. It is passed to
    /// `compile_graph_stages_for_backend` in `compile` and to `prepare_fused_graph` in
    /// `compile_lir`. NOTE(review): `If`/`While`/`Scan` are listed, yet `compile_lir`
    /// runs `LowerControlFlow` before `prepare_fused_graph`. Confirm which form MLX
    /// actually receives.
    const MLX_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            StopGradient,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            DenseSolve,
            BatchedDenseSolve,
            LayerNorm,
            LayerNorm2d,
            ResizeNearest2x,
            RmsNorm,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            RngNormal,
            RngUniform,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            LoraMatMul,
            DequantMatMul,
            SelectiveScan,
            GatedDeltaNet,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            FusedAttentionBlock,
            FusedTransformerLayer,
            // Structured control flow.
            If,
            While,
            Scan,
            // Backward / training kernels.
            ScanBackward,
            ScanBackwardXs,
            ReluBackward,
            ActivationBackward,
            SoftmaxCrossEntropyWithLogits,
            SoftmaxCrossEntropyBackward,
            AttentionBackward,
            LayerNormBackwardInput,
            LayerNormBackwardGamma,
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            MaxPool2dBackward,
            FakeQuantize,
            FakeQuantizeBackward,
            Custom,
            Fft,
            LogMel,
            LogMelBackward,
            WelchPeaks,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
        ]
    };
1944
1945 impl Backend for MlxBackend {
1946 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
1947 MLX_SUPPORTED_OPS
1948 }
1949
1950 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
1951 let compile_result = crate::stages::compile_graph_stages_for_backend(
1952 rlx_driver::Device::Mlx,
1953 graph,
1954 options,
1955 MLX_SUPPORTED_OPS,
1956 );
1957 crate::stages::maybe_log_fusion(&compile_result.fusion);
1958 self.compile_lir(compile_result.lir, options)
1959 }
1960
1961 fn compile_lir(
1962 &self,
1963 lir: LirModule,
1964 options: &CompileOptions,
1965 ) -> Box<dyn ExecutableGraph> {
1966 use rlx_opt::pass::Pass as _;
1967 let mut graph = lir.into_graph();
1968 graph = rlx_opt::LowerControlFlow.run(graph);
1969 let graph = prepare_fused_graph(graph, options, MLX_SUPPORTED_OPS, "mlx");
1970 Box::new(build_mlx_executable(graph, options.rng))
1971 }
1972 }
1973
1974 fn build_mlx_executable(graph: Graph, rng: rlx_ir::RngOptions) -> MlxExecutableWrapper {
1975 let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
1976 let mode = mlx_mode_from_env();
1977 let mut exe = MlxExecutable::compile_from_fused_with_rng(graph, mode, rng);
1978 if mode == rlx_mlx::lower::MlxMode::Compiled {
1979 if let Err(e) = exe.warm_compile() {
1980 eprintln!(
1981 "[rlx-runtime] MLX warm_compile failed ({e}); first run will pay the trace cost"
1982 );
1983 }
1984 }
1985 MlxExecutableWrapper {
1986 inner: exe,
1987 io_manifest,
1988 }
1989 }
1990
1991 fn mlx_mode_from_env() -> rlx_mlx::lower::MlxMode {
1992 match rlx_ir::env::var("RLX_MLX_MODE").as_deref() {
1993 Some(s) if s.eq_ignore_ascii_case("eager") => rlx_mlx::lower::MlxMode::Eager,
1994 Some(s) if s.eq_ignore_ascii_case("lazy") => rlx_mlx::lower::MlxMode::Lazy,
1995 Some(s) if s.eq_ignore_ascii_case("compiled") => rlx_mlx::lower::MlxMode::Compiled,
1996 _ => rlx_mlx::lower::MlxMode::Compiled,
1997 }
1998 }
1999
    /// Adapts a compiled `MlxExecutable` to the `ExecutableGraph` trait.
    struct MlxExecutableWrapper {
        inner: MlxExecutable,
        /// Manifest produced by `prepare_f32_exec_graph` in `build_mlx_executable`.
        /// NOTE(review): within this file it is only cloned (`clone_box`); `run_typed`
        /// delegates straight to `inner`. Confirm the declared output dtypes are honoured.
        io_manifest: cpu_low_precision::IoDtypeManifest,
    }
2004
    // SAFETY: NOTE(review) — asserts that an `MlxExecutable` may be moved across threads.
    // Nothing in this file establishes that. Confirm the MLX handles it owns are not
    // thread-affine.
    unsafe impl Send for MlxExecutableWrapper {}
2006
2007 impl ExecutableGraph for MlxExecutableWrapper {
2008 fn set_param(&mut self, name: &str, data: &[f32]) {
2009 self.inner.set_param(name, data);
2010 }
2011 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2012 self.inner.run(inputs)
2013 }
2014 fn run_read_outputs(
2015 &mut self,
2016 inputs: &[(&str, &[f32])],
2017 read_indices: Option<&[usize]>,
2018 ) -> Vec<Vec<f32>> {
2019 self.inner
2020 .run_read_outputs(inputs, read_indices)
2021 .unwrap_or_else(|e| panic!("MLX run_read_outputs failed: {e}"))
2022 }
2023 fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
2024 self.inner.run_slots(inputs)
2025 }
2026 fn arena_ptr(&self) -> *const u8 {
2027 self.inner.arena_ptr()
2028 }
2029 fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
2030 self.inner.commit_no_wait(inputs);
2031 }
2032 fn sync_pending(&mut self) {
2033 self.inner.sync_pending();
2034 }
2035 fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
2036 self.inner.run_pipelined(input_sets)
2037 }
2038 fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
2039 self.inner.bind_handle(name, data)
2040 }
2041 fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
2042 self.inner.read_handle(name)
2043 }
2044 fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
2045 self.inner.bind_gpu_handle(name, data).is_ok()
2046 }
2047 fn has_gpu_handle(&self, name: &str) -> bool {
2048 self.inner.has_gpu_handle(name)
2049 }
2050 fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
2051 self.inner.set_gpu_handle_feed(handle_name, output_index);
2052 true
2053 }
2054 fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
2055 self.inner.read_gpu_handle(name).ok()
2056 }
2057 fn run_feed_gpu_handle(
2058 &mut self,
2059 inputs: &[(&str, &[f32])],
2060 handle_name: &str,
2061 output_index: usize,
2062 ) -> Option<Vec<f32>> {
2063 self.inner
2064 .run_feed_gpu(inputs, handle_name, output_index)
2065 .ok()
2066 }
2067 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2068 self.inner.set_param_typed(name, data, dtype);
2069 }
2070 fn run_typed(
2071 &mut self,
2072 inputs: &[(&str, &[u8], rlx_ir::DType)],
2073 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2074 self.inner.run_typed(inputs)
2075 }
2076 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2077 self.inner.set_active_extent(extent);
2078 }
2079
2080 fn set_rng(&mut self, rng: rlx_ir::RngOptions) {
2081 self.inner.set_rng(rng);
2082 }
2083
2084 fn rng(&self) -> rlx_ir::RngOptions {
2085 self.inner.rng()
2086 }
2087
2088 fn clone_box(&self) -> Box<dyn ExecutableGraph> {
2089 Box::new(MlxExecutableWrapper {
2090 inner: self.inner.clone_for_cache(),
2091 io_manifest: self.io_manifest.clone(),
2092 })
2093 }
2094 }
2095}
2096
/// Op kinds the CoreML backend accepts; returned by `CoremlBackend::supported_ops`.
///
/// Declared outside the `cfg`-gated `coreml_backend` module (at `pub(crate)`
/// visibility), so it exists on every target. Presumably this lets other crate code
/// inspect it; confirm the callers.
pub(crate) const COREML_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
    use rlx_ir::OpKind::*;
    &[
        Input,
        Param,
        Constant,
        Activation,
        Cast,
        Binary,
        MatMul,
        LayerNorm,
        RmsNorm,
        Reduce,
        Softmax,
        Reshape,
        Transpose,
        Narrow,
        Concat,
        Gather,
        Rope,
        Attention,
        Compare,
        Where,
        Expand,
        Cumsum,
        ScatterAdd,
        BatchNormInference,
        GroupNorm,
        LayerNorm2d,
        LoraMatMul,
        Conv,
        ConvTranspose2d,
        Pool,
        TopK,
        AxialRope2d,
        ResizeNearest2x,
        StopGradient,
        GroupedMatMul,
        DequantMatMul,
        DequantMoEWeights,
        DequantGroupedMatMul,
        Quantize,
        Dequantize,
        SelectiveScan,
        GatedDeltaNet,
    ]
};
2150
2151#[cfg(all(feature = "coreml", any(target_os = "macos", target_os = "ios")))]
2159pub mod coreml_backend {
2160 use super::*;
2161 use rlx_coreml::CoremlExecutable;
2162
    /// Backend handle that compiles graphs into `CoremlExecutable`s (see the `Backend` impl below).
    pub struct CoremlBackend;
2164
2165 impl Backend for CoremlBackend {
2166 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2167 super::COREML_SUPPORTED_OPS
2168 }
2169
2170 fn compile(&self, graph: Graph, _options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2171 Box::new(CoremlExecutableWrapper {
2174 inner: CoremlExecutable::compile(graph),
2175 })
2176 }
2177
2178 fn compile_lir(
2179 &self,
2180 lir: LirModule,
2181 options: &CompileOptions,
2182 ) -> Box<dyn ExecutableGraph> {
2183 self.compile(lir.into_graph(), options)
2186 }
2187 }
2188
    /// Adapts a compiled `CoremlExecutable` to the `ExecutableGraph` trait.
    ///
    /// Unlike the other GPU wrappers here, it carries no I/O dtype manifest;
    /// `run_typed` always reports F32 outputs.
    struct CoremlExecutableWrapper {
        inner: CoremlExecutable,
    }
2192
    // SAFETY: NOTE(review) — asserts that a `CoremlExecutable` may be moved across
    // threads. Nothing in this file establishes that. Confirm the CoreML model/handles it
    // owns tolerate cross-thread use.
    unsafe impl Send for CoremlExecutableWrapper {}
2195
2196 impl ExecutableGraph for CoremlExecutableWrapper {
2197 fn set_param(&mut self, name: &str, data: &[f32]) {
2198 self.inner.set_param(name, data);
2199 }
2200
2201 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2202 self.inner.set_param_typed(name, data, dtype);
2205 }
2206
2207 fn finalize_params(&mut self) {
2208 self.inner
2209 .finalize()
2210 .unwrap_or_else(|e| panic!("CoreML finalize failed: {e}"));
2211 }
2212
2213 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2214 self.inner
2215 .run(inputs)
2216 .unwrap_or_else(|e| panic!("CoreML run failed: {e}"))
2217 }
2218
2219 fn clone_box(&self) -> Box<dyn ExecutableGraph> {
2220 Box::new(CoremlExecutableWrapper {
2221 inner: self.inner.clone_for_cache(),
2222 })
2223 }
2224
2225 fn run_typed(
2226 &mut self,
2227 inputs: &[(&str, &[u8], rlx_ir::DType)],
2228 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2229 use rlx_ir::DType;
2230 let owned: Vec<(String, Vec<f32>)> = inputs
2233 .iter()
2234 .map(|(name, data, dt)| {
2235 let v: Vec<f32> = match dt {
2236 DType::I64 => data
2237 .chunks_exact(8)
2238 .map(|c| i64::from_le_bytes(c.try_into().unwrap()) as f32)
2239 .collect(),
2240 DType::I32 => data
2241 .chunks_exact(4)
2242 .map(|c| i32::from_le_bytes(c.try_into().unwrap()) as f32)
2243 .collect(),
2244 DType::U8 | DType::Bool => data.iter().map(|&b| b as f32).collect(),
2245 _ => super::widen_bytes_to_f32(data, *dt),
2246 };
2247 (name.to_string(), v)
2248 })
2249 .collect();
2250 let refs: Vec<(&str, &[f32])> = owned
2251 .iter()
2252 .map(|(n, d)| (n.as_str(), d.as_slice()))
2253 .collect();
2254 self.run(&refs)
2255 .into_iter()
2256 .map(|v| {
2257 let bytes: Vec<u8> = v.iter().flat_map(|f| f.to_le_bytes()).collect();
2258 (bytes, DType::F32)
2259 })
2260 .collect()
2261 }
2262 }
2263}
2264
2265#[cfg(all(feature = "metal", target_os = "macos"))]
2266pub mod metal_backend {
2267 use super::*;
2268 use rlx_metal::backend::MetalExecutable;
2269
    /// Backend handle that compiles graphs into `MetalExecutable`s (see the `Backend` impl below).
    pub struct MetalBackend;
2271
    /// Op kinds the Metal backend accepts.
    ///
    /// Returned by `MetalBackend::supported_ops`. It is passed to
    /// `legalize_or_rewrite_for_backend_with_config` and to the `MetalExecutable`
    /// constructors in both `compile` and `compile_lir`.
    const METAL_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            StopGradient,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            RmsNorm,
            ResizeNearest2x,
            AxialRope2d,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            Cumsum,
            CumsumBackward,
            GatherBackward,
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            MaxPool2dBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            TopK,
            RngNormal,
            RngUniform,
            Conv,
            Im2Col,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            GatedDeltaNet,
            Lstm,
            FusedSwiGLU,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            Custom,
            Fft,
            LogMel,
            LogMelBackward,
            WelchPeaks,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
        ]
    };
2365
2366 impl Backend for MetalBackend {
2367 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2368 METAL_SUPPORTED_OPS
2369 }
2370
2371 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2372 use rlx_opt::pass::Pass as _;
2373 let graph = rlx_opt::LowerControlFlow.run(graph);
2377 let dispatch = options.kernel_dispatch;
2378 let graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
2379 graph,
2380 METAL_SUPPORTED_OPS,
2381 dispatch,
2382 )
2383 .unwrap_or_else(|errors| {
2384 panic!("{}", rlx_opt::format_legalize_error("metal", &errors));
2385 });
2386 let graph = crate::precompile::precompile_cleanup(graph, options);
2387
2388 let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2391 Box::new(MetalExecutableWrapper {
2392 inner: MetalExecutable::compile_with_policy(
2393 graph,
2394 options.policy.clone(),
2395 Some(METAL_SUPPORTED_OPS),
2396 options.rng,
2397 ),
2398 io_manifest,
2399 })
2400 }
2401
2402 fn compile_lir(
2403 &self,
2404 lir: LirModule,
2405 options: &CompileOptions,
2406 ) -> Box<dyn ExecutableGraph> {
2407 use rlx_opt::pass::Pass as _;
2408 let mut graph = lir.into_graph();
2409 graph = rlx_opt::LowerControlFlow.run(graph);
2410 let dispatch = options.kernel_dispatch;
2411 let mut graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
2412 graph,
2413 METAL_SUPPORTED_OPS,
2414 dispatch,
2415 )
2416 .unwrap_or_else(|errors| {
2417 panic!("{}", rlx_opt::format_legalize_error("metal", &errors));
2418 });
2419 graph = crate::precompile::precompile_cleanup(graph, options);
2420 let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2421 Box::new(MetalExecutableWrapper {
2422 inner: MetalExecutable::compile_from_fused(
2423 graph,
2424 options.policy.clone(),
2425 Some(METAL_SUPPORTED_OPS),
2426 options.rng,
2427 ),
2428 io_manifest,
2429 })
2430 }
2431 }
2432
    /// Adapts a compiled `MetalExecutable` to the `ExecutableGraph` trait.
    struct MetalExecutableWrapper {
        inner: MetalExecutable,
        /// Manifest produced by `prepare_f32_exec_graph`. NOTE(review): within this file
        /// it is only cloned (`clone_box`); `run_typed` delegates to `inner`. Confirm the
        /// declared output dtypes are honoured there.
        io_manifest: cpu_low_precision::IoDtypeManifest,
    }
2437
    // SAFETY: NOTE(review) — asserts that a `MetalExecutable` may be moved across threads.
    // Nothing in this file establishes that. Confirm the Metal objects it owns are safe
    // to use from another thread.
    unsafe impl Send for MetalExecutableWrapper {}
2439
2440 impl ExecutableGraph for MetalExecutableWrapper {
2441 fn set_param(&mut self, name: &str, data: &[f32]) {
2442 self.inner.set_param(name, data);
2443 }
2444
2445 fn finalize_params(&mut self) {
2446 self.inner.preload_qmatmul_weights();
2447 }
2448
2449 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2450 self.inner.run(inputs)
2451 }
2452 fn run_read_outputs(
2453 &mut self,
2454 inputs: &[(&str, &[f32])],
2455 read_indices: Option<&[usize]>,
2456 ) -> Vec<Vec<f32>> {
2457 self.inner.run_read_outputs(inputs, read_indices)
2458 }
2459 fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
2460 self.inner.bind_gpu_handle(name, data)
2461 }
2462 fn has_gpu_handle(&self, name: &str) -> bool {
2463 self.inner.has_gpu_handle(name)
2464 }
2465 fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
2466 self.inner.set_gpu_handle_feed(handle_name, output_index);
2467 true
2468 }
2469 fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
2470 self.inner.read_gpu_handle(name)
2471 }
2472 fn read_output_row(
2473 &self,
2474 out_idx: usize,
2475 row: usize,
2476 row_inner: usize,
2477 ) -> Option<Vec<f32>> {
2478 Some(self.inner.read_graph_output_row(out_idx, row, row_inner))
2479 }
2480 fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
2481 self.inner.run_slots(inputs)
2482 }
2483 fn arena_ptr(&self) -> *const u8 {
2484 self.inner.arena_ptr()
2485 }
2486 fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
2487 self.inner.commit_no_wait(inputs);
2488 }
2489 fn sync_pending(&mut self) {
2490 self.inner.sync_pending();
2491 }
2492 fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
2493 self.inner.run_pipelined(input_sets)
2494 }
2495 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2496 self.inner.set_active_extent(extent);
2497 }
2498
2499 fn set_rng(&mut self, rng: rlx_ir::RngOptions) {
2500 self.inner.set_rng(rng);
2501 }
2502
2503 fn rng(&self) -> rlx_ir::RngOptions {
2504 self.inner.rng()
2505 }
2506
2507 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2513 if matches!(
2514 dtype,
2515 rlx_ir::DType::U8
2516 | rlx_ir::DType::I8
2517 | rlx_ir::DType::I32
2518 | rlx_ir::DType::I64
2519 | rlx_ir::DType::U32
2520 | rlx_ir::DType::F64
2521 ) {
2522 self.inner.set_param_bytes(name, data);
2523 return;
2524 }
2525 if dtype == rlx_ir::DType::F32 {
2526 let n = data.len() / 4;
2527 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2528 self.inner.set_param(name, s);
2529 } else {
2530 let f32_buf = super::widen_bytes_to_f32(data, dtype);
2531 self.inner.set_param(name, &f32_buf);
2532 }
2533 }
2534
2535 fn run_typed(
2539 &mut self,
2540 inputs: &[(&str, &[u8], rlx_ir::DType)],
2541 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2542 self.inner.run_typed(inputs)
2543 }
2544
2545 fn clone_box(&self) -> Box<dyn ExecutableGraph> {
2546 Box::new(MetalExecutableWrapper {
2547 inner: self.inner.clone_for_cache(),
2548 io_manifest: self.io_manifest.clone(),
2549 })
2550 }
2551 }
2552}
2553
2554#[cfg(feature = "cuda")]
2557pub mod cuda_backend {
2558 use super::*;
2559 use rlx_cuda::backend::CudaExecutable;
2560
    /// Backend handle that compiles graphs into `CudaExecutable`s (see the `Backend` impl below).
    pub struct CudaBackend;
2562
    /// Op kinds the CUDA backend accepts.
    ///
    /// Returned by `CudaBackend::supported_ops` and passed to the legalization and
    /// staging calls in `compile`/`compile_lir`. NOTE(review): unlike the other GPU lists
    /// in this file, `StopGradient` is absent. Presumably legalization rewrites it away;
    /// confirm.
    const CUDA_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            ResizeNearest2x,
            RmsNorm,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Conv2dBackwardInput,
            Conv2dBackwardWeight,
            MaxPool2dBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            Lstm,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            Custom,
            Fft,
            LogMel,
            LogMelBackward,
            WelchPeaks,
            Im2Col,
            RngNormal,
            RngUniform,
        ]
    };
2639
2640 impl Backend for CudaBackend {
2641 fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
2642 CUDA_SUPPORTED_OPS
2643 }
2644
2645 fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
2646 use rlx_opt::pass::Pass as _;
2647 let graph = rlx_cuda::unfuse::unfuse(graph);
2650 let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, CUDA_SUPPORTED_OPS)
2651 .unwrap_or_else(|errors| {
2652 panic!("{}", rlx_opt::format_legalize_error("cuda", &errors));
2653 });
2654 let graph = crate::precompile::precompile_cleanup(graph, options);
2655 let graph = rlx_opt::LegalizeBroadcast.run(graph);
2657 let compile_result = crate::stages::compile_graph_stages_for_backend(
2659 rlx_driver::Device::Cuda,
2660 graph,
2661 options,
2662 CUDA_SUPPORTED_OPS,
2663 );
2664 crate::stages::maybe_log_fusion(&compile_result.fusion);
2665 let graph = compile_result.lir.into_graph();
2666 let graph = match options.policy.clone() {
2667 Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
2668 None => graph,
2669 };
2670 let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
2671 Box::new(CudaExecutableWrapper {
2672 inner: CudaExecutable::compile_rng(graph, options.rng),
2673 io_manifest,
2674 })
2675 }
2676
2677 fn compile_lir(
2678 &self,
2679 lir: LirModule,
2680 options: &CompileOptions,
2681 ) -> Box<dyn ExecutableGraph> {
2682 use rlx_opt::pass::Pass as _;
2683 let graph = rlx_opt::LegalizeBroadcast.run(lir.into_graph());
2684 let (graph, io_manifest) =
2685 cpu_low_precision::prepare_f32_exec_graph(prepare_fused_graph(
2686 rlx_cuda::unfuse::unfuse(graph),
2687 options,
2688 CUDA_SUPPORTED_OPS,
2689 "cuda",
2690 ));
2691 Box::new(CudaExecutableWrapper {
2692 inner: CudaExecutable::compile_rng(graph, options.rng),
2693 io_manifest,
2694 })
2695 }
2696 }
2697
    /// Adapts a compiled `CudaExecutable` to the `ExecutableGraph` trait.
    struct CudaExecutableWrapper {
        inner: CudaExecutable,
        /// Manifest produced by `prepare_f32_exec_graph`. `run_typed` passes it to
        /// `declared_output_dtypes` to choose the dtype each output is narrowed to.
        io_manifest: cpu_low_precision::IoDtypeManifest,
    }
2702
    // SAFETY: NOTE(review) — asserts that a `CudaExecutable` may be moved across threads.
    // Nothing in this file establishes that. Confirm any CUDA context/stream it owns is
    // not bound to the creating thread.
    unsafe impl Send for CudaExecutableWrapper {}
2708
2709 impl ExecutableGraph for CudaExecutableWrapper {
2710 fn set_param(&mut self, name: &str, data: &[f32]) {
2711 self.inner.set_param(name, data);
2712 }
2713 fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
2714 self.inner.run(inputs)
2715 }
2716 fn run_read_outputs(
2717 &mut self,
2718 inputs: &[(&str, &[f32])],
2719 read_indices: Option<&[usize]>,
2720 ) -> Vec<Vec<f32>> {
2721 self.inner.run_read_outputs(inputs, read_indices)
2722 }
2723 fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
2724 self.inner.bind_gpu_handle(name, data)
2725 }
2726 fn has_gpu_handle(&self, name: &str) -> bool {
2727 self.inner.has_gpu_handle(name)
2728 }
2729 fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
2730 self.inner.set_gpu_handle_feed(handle_name, output_index);
2731 true
2732 }
2733 fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
2734 self.inner.read_gpu_handle(name)
2735 }
2736 fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
2737 self.inner.set_active_extent(extent);
2738 }
2739
2740 fn set_rng(&mut self, rng: rlx_ir::RngOptions) {
2741 self.inner.set_rng(rng);
2742 }
2743
2744 fn rng(&self) -> rlx_ir::RngOptions {
2745 self.inner.rng()
2746 }
2747
2748 fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
2749 self.inner.run_slots(inputs)
2750 }
2751
2752 fn arena_ptr(&self) -> *const u8 {
2753 self.inner.arena_ptr()
2754 }
2755
2756 fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
2761 if matches!(dtype, rlx_ir::DType::U8 | rlx_ir::DType::I8) {
2762 self.inner.set_param_bytes(name, data);
2763 return;
2764 }
2765 if dtype == rlx_ir::DType::F32 {
2766 let n = data.len() / 4;
2767 let s = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, n) };
2768 self.inner.set_param(name, s);
2769 } else {
2770 let f32_buf = super::widen_bytes_to_f32(data, dtype);
2771 self.inner.set_param(name, &f32_buf);
2772 }
2773 }
2774
2775 fn run_typed(
2778 &mut self,
2779 inputs: &[(&str, &[u8], rlx_ir::DType)],
2780 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
2781 let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
2782 for (name, data, dt) in inputs {
2783 let v = super::widen_bytes_to_f32(data, *dt);
2784 owned.push((name.to_string(), v));
2785 }
2786 let refs: Vec<(&str, &[f32])> = owned
2787 .iter()
2788 .map(|(n, d)| (n.as_str(), d.as_slice()))
2789 .collect();
2790 let dtypes =
2791 super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
2792 let outs = self.inner.run(&refs);
2793 outs.into_iter()
2794 .zip(
2795 dtypes
2796 .into_iter()
2797 .chain(std::iter::repeat(rlx_ir::DType::F32)),
2798 )
2799 .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
2800 .collect()
2801 }
2802
2803 fn clone_box(&self) -> Box<dyn ExecutableGraph> {
2804 Box::new(CudaExecutableWrapper {
2805 inner: self.inner.clone_for_cache(),
2806 io_manifest: self.io_manifest.clone(),
2807 })
2808 }
2809 }
2810}
2811
#[cfg(feature = "rocm")]
pub mod rocm_backend {
    //! ROCm backend.
    //!
    //! Legalizes graphs against `ROCM_SUPPORTED_OPS`, runs the shared compile
    //! stages, and wraps the resulting `RocmExecutable` behind the
    //! `ExecutableGraph` trait.

    use super::*;
    use rlx_rocm::backend::RocmExecutable;

    /// Stateless backend handle; all configuration comes from the
    /// `CompileOptions` passed to each compile call.
    pub struct RocmBackend;

    /// Op kinds the ROCm executable accepts.
    ///
    /// `compile` legalizes graphs against this list; legalization errors panic.
    const ROCM_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            LayerNorm2d,
            GroupNorm,
            ResizeNearest2x,
            RmsNorm,
            Attention,
            AttentionBackward,
            RmsNormBackwardInput,
            RmsNormBackwardGamma,
            RmsNormBackwardBeta,
            RopeBackward,
            CumsumBackward,
            GatherBackward,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            ConvTranspose2d,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            Lstm,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            GaussianSplatRender,
            GaussianSplatRenderBackward,
            GaussianSplatPrepare,
            GaussianSplatRasterize,
            Custom,
            Fft,
            LogMel,
            LogMelBackward,
            WelchPeaks,
            Im2Col,
            RngNormal,
            RngUniform,
        ]
    };

    impl Backend for RocmBackend {
        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
            ROCM_SUPPORTED_OPS
        }

        /// Compiles a graph for ROCm.
        ///
        /// The pipeline runs these steps in order:
        /// 1. unfuse;
        /// 2. legalize against `ROCM_SUPPORTED_OPS`;
        /// 3. precompile cleanup;
        /// 4. broadcast legalization;
        /// 5. shared compile stages (fusion report logged);
        /// 6. optional `AutoMixedPrecision` when `options.policy` is set;
        /// 7. f32 execution-graph preparation, which also yields the I/O dtype manifest.
        ///
        /// # Panics
        /// Panics if the graph cannot be legalized for ROCm.
        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
            use rlx_opt::pass::Pass as _;
            let graph = rlx_rocm::unfuse::unfuse(graph);
            let graph = rlx_opt::legalize_or_rewrite_for_backend(graph, ROCM_SUPPORTED_OPS)
                .unwrap_or_else(|errors| {
                    panic!("{}", rlx_opt::format_legalize_error("rocm", &errors));
                });
            let graph = crate::precompile::precompile_cleanup(graph, options);
            let graph = rlx_opt::LegalizeBroadcast.run(graph);
            let compile_result = crate::stages::compile_graph_stages_for_backend(
                rlx_driver::Device::Rocm,
                graph,
                options,
                ROCM_SUPPORTED_OPS,
            );
            crate::stages::maybe_log_fusion(&compile_result.fusion);
            let graph = compile_result.lir.into_graph();
            let graph = match options.policy.clone() {
                Some(p) => rlx_opt::AutoMixedPrecision::new(p).run(graph),
                None => graph,
            };
            let (graph, io_manifest) = cpu_low_precision::prepare_f32_exec_graph(graph);
            Box::new(RocmExecutableWrapper {
                inner: RocmExecutable::compile_rng(graph, options.rng),
                io_manifest,
            })
        }

        /// Compiles an already-lowered LIR module.
        ///
        /// The LIR is converted back to a graph, unfused, and passed through
        /// `prepare_fused_graph` before f32 execution-graph preparation.
        fn compile_lir(
            &self,
            lir: LirModule,
            options: &CompileOptions,
        ) -> Box<dyn ExecutableGraph> {
            let (graph, io_manifest) =
                cpu_low_precision::prepare_f32_exec_graph(prepare_fused_graph(
                    rlx_rocm::unfuse::unfuse(lir.into_graph()),
                    options,
                    ROCM_SUPPORTED_OPS,
                    "rocm",
                ));
            Box::new(RocmExecutableWrapper {
                inner: RocmExecutable::compile_rng(graph, options.rng),
                io_manifest,
            })
        }
    }

    /// Adapts `RocmExecutable` to `ExecutableGraph`.
    ///
    /// Keeps the I/O dtype manifest returned by `prepare_f32_exec_graph` so that
    /// typed runs can report declared output dtypes.
    struct RocmExecutableWrapper {
        inner: RocmExecutable,
        io_manifest: cpu_low_precision::IoDtypeManifest,
    }

    // NOTE(review): manually asserts that the wrapper may be moved across threads.
    // Presumably `RocmExecutable` is not auto-`Send` (e.g. it holds raw
    // device/runtime handles). Confirm the ROCm runtime tolerates use from the
    // owning thread after such a move.
    unsafe impl Send for RocmExecutableWrapper {}

    impl ExecutableGraph for RocmExecutableWrapper {
        fn set_param(&mut self, name: &str, data: &[f32]) {
            self.inner.set_param(name, data);
        }
        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
            self.inner.run(inputs)
        }
        fn run_read_outputs(
            &mut self,
            inputs: &[(&str, &[f32])],
            read_indices: Option<&[usize]>,
        ) -> Vec<Vec<f32>> {
            self.inner.run_read_outputs(inputs, read_indices)
        }
        fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
            self.inner.bind_gpu_handle(name, data)
        }
        fn has_gpu_handle(&self, name: &str) -> bool {
            self.inner.has_gpu_handle(name)
        }
        /// Forwards the feed registration to the inner executable.
        ///
        /// Always returns `true`: the inner call's result is not inspected here.
        fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
            self.inner.set_gpu_handle_feed(handle_name, output_index);
            true
        }
        fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
            self.inner.read_gpu_handle(name)
        }
        fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
            self.inner.run_slots(inputs)
        }
        fn arena_ptr(&self) -> *const u8 {
            self.inner.arena_ptr()
        }
        fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
            self.inner.set_active_extent(extent);
        }

        fn set_rng(&mut self, rng: rlx_ir::RngOptions) {
            self.inner.set_rng(rng);
        }

        fn rng(&self) -> rlx_ir::RngOptions {
            self.inner.rng()
        }

        /// Uploads a parameter supplied as raw bytes of `dtype`.
        ///
        /// * `U8` / `I8`: the bytes go to `set_param_bytes` unchanged.
        /// * `F32`: the bytes are read as native-endian f32s. The buffer is
        ///   borrowed in place when its start is f32-aligned and copied otherwise.
        ///   A `&[u8]` guarantees only 1-byte alignment, so casting it straight
        ///   to `*const f32` would be undefined behavior for misaligned input.
        ///   Trailing bytes that do not form a whole f32 are ignored.
        /// * anything else: widened by `widen_bytes_to_f32`, which panics for
        ///   dtypes other than F32/F16/BF16.
        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
            // Native-endian f32 view of `bytes`: zero-copy when aligned, owned copy otherwise.
            fn f32_view(bytes: &[u8]) -> std::borrow::Cow<'_, [f32]> {
                let width = std::mem::size_of::<f32>();
                if (bytes.as_ptr() as usize) % std::mem::align_of::<f32>() == 0 {
                    // SAFETY: the start pointer is f32-aligned (checked above),
                    // `bytes.len() / width` f32s lie entirely within `bytes`, the
                    // returned slice borrows `bytes` so it cannot outlive it, and
                    // every bit pattern is a valid f32.
                    let view = unsafe {
                        std::slice::from_raw_parts(bytes.as_ptr().cast::<f32>(), bytes.len() / width)
                    };
                    std::borrow::Cow::Borrowed(view)
                } else {
                    std::borrow::Cow::Owned(
                        bytes
                            .chunks_exact(width)
                            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
                            .collect(),
                    )
                }
            }

            match dtype {
                rlx_ir::DType::U8 | rlx_ir::DType::I8 => {
                    self.inner.set_param_bytes(name, data);
                }
                rlx_ir::DType::F32 => {
                    let values = f32_view(data);
                    self.inner.set_param(name, &values);
                }
                _ => {
                    let f32_buf = super::widen_bytes_to_f32(data, dtype);
                    self.inner.set_param(name, &f32_buf);
                }
            }
        }

        /// Runs the graph with byte-encoded inputs and returns byte-encoded outputs.
        ///
        /// Each input is widened to f32 with `widen_bytes_to_f32`. That helper
        /// panics for dtypes other than F32/F16/BF16.
        /// NOTE(review): it also reinterprets the bytes via `from_raw_parts`
        /// without an alignment check.
        ///
        /// Outputs are narrowed to the dtypes reported by `declared_output_dtypes`.
        /// Any outputs beyond that list are emitted as F32.
        fn run_typed(
            &mut self,
            inputs: &[(&str, &[u8], rlx_ir::DType)],
        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
            for (name, data, dt) in inputs {
                let v = super::widen_bytes_to_f32(data, *dt);
                owned.push((name.to_string(), v));
            }
            let refs: Vec<(&str, &[f32])> = owned
                .iter()
                .map(|(n, d)| (n.as_str(), d.as_slice()))
                .collect();
            let dtypes =
                super::declared_output_dtypes(&self.io_manifest, self.inner.output_dtypes());
            let outs = self.inner.run(&refs);
            outs.into_iter()
                .zip(
                    dtypes
                        .into_iter()
                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
                )
                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
                .collect()
        }

        /// Produces an independent wrapper around `inner.clone_for_cache()`,
        /// carrying a clone of the I/O dtype manifest.
        fn clone_box(&self) -> Box<dyn ExecutableGraph> {
            Box::new(RocmExecutableWrapper {
                inner: self.inner.clone_for_cache(),
                io_manifest: self.io_manifest.clone(),
            })
        }
    }
}
3054
#[cfg(feature = "tpu")]
pub mod tpu_backend {
    //! TPU backend.
    //!
    //! Legalizes graphs against `TPU_SUPPORTED_OPS`, applies a mixed-precision
    //! policy, and wraps the resulting `TpuExecutable` behind the
    //! `ExecutableGraph` trait.

    use super::*;
    use rlx_tpu::TpuExecutable;

    /// Stateless backend handle; all configuration comes from the
    /// `CompileOptions` passed to each compile call.
    pub struct TpuBackend;

    /// Op kinds the TPU executable accepts.
    ///
    /// `compile` legalizes graphs against this list; legalization errors panic.
    const TPU_SUPPORTED_OPS: &[rlx_ir::OpKind] = {
        use rlx_ir::OpKind::*;
        &[
            Input,
            Param,
            Constant,
            Activation,
            Cast,
            Binary,
            Compare,
            Where,
            ElementwiseRegion,
            TransformRegion,
            BatchElementwiseRegion,
            MatMul,
            DotGeneral,
            LayerNorm,
            RmsNorm,
            Attention,
            Rope,
            Reshape,
            Transpose,
            Narrow,
            Concat,
            Expand,
            Gather,
            Reduce,
            Softmax,
            Cumsum,
            TopK,
            Sample,
            Conv,
            Pool,
            GroupedMatMul,
            DequantGroupedMatMul,
            DequantMoEWeights,
            ScatterAdd,
            DequantMatMul,
            SelectiveScan,
            QMatMul,
            QConv2d,
            Quantize,
            Dequantize,
            FusedMatMulBiasAct,
            FusedResidualLN,
            FusedResidualRmsNorm,
            Fft,
            LogMel,
            LogMelBackward,
            WelchPeaks,
            RngNormal,
            RngUniform,
        ]
    };

    impl Backend for TpuBackend {
        fn supported_ops(&self) -> &'static [rlx_ir::OpKind] {
            TPU_SUPPORTED_OPS
        }

        /// Compiles a graph for the TPU.
        ///
        /// The graph is first legalized against `TPU_SUPPORTED_OPS`, using
        /// `options.kernel_dispatch`. It is then run through `AutoMixedPrecision`
        /// with `options.policy`, falling back to
        /// `PrecisionPolicy::AutoMixedBf16` when no policy is set.
        ///
        /// `options.dce` and `options.constant_folding` are read but not applied
        /// by this backend.
        ///
        /// # Panics
        /// Panics if the graph cannot be legalized for the TPU.
        fn compile(&self, graph: Graph, options: &CompileOptions) -> Box<dyn ExecutableGraph> {
            use rlx_opt::pass::Pass as _;
            let graph = rlx_opt::legalize_or_rewrite_for_backend_with_config(
                graph,
                TPU_SUPPORTED_OPS,
                options.kernel_dispatch,
            )
            .unwrap_or_else(|errors| {
                panic!("{}", rlx_opt::format_legalize_error("tpu", &errors));
            });
            let policy = options
                .policy
                .clone()
                .unwrap_or(rlx_opt::PrecisionPolicy::AutoMixedBf16);
            let graph = rlx_opt::AutoMixedPrecision::new(policy).run(graph);
            // Explicitly acknowledged but unused on this backend.
            let _ = options.dce;
            let _ = options.constant_folding;
            Box::new(TpuExecutableWrapper {
                inner: TpuExecutable::compile_rng(graph, options.rng),
            })
        }
    }

    /// Adapts `TpuExecutable` to `ExecutableGraph`.
    struct TpuExecutableWrapper {
        inner: TpuExecutable,
    }

    // NOTE(review): manually asserts that the wrapper may be moved across threads.
    // Presumably `TpuExecutable` is not auto-`Send` (e.g. it holds raw runtime
    // handles). Confirm the TPU runtime tolerates use from the owning thread
    // after such a move.
    unsafe impl Send for TpuExecutableWrapper {}

    impl ExecutableGraph for TpuExecutableWrapper {
        fn set_param(&mut self, name: &str, data: &[f32]) {
            self.inner.set_param(name, data);
        }
        fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
            self.inner.run(inputs)
        }

        /// Uploads a parameter supplied as raw bytes of `dtype`.
        ///
        /// * `F32`: the bytes are read as native-endian f32s. The buffer is
        ///   borrowed in place when its start is f32-aligned and copied otherwise.
        ///   A `&[u8]` guarantees only 1-byte alignment, so casting it straight
        ///   to `*const f32` would be undefined behavior for misaligned input.
        ///   Trailing bytes that do not form a whole f32 are ignored.
        /// * anything else: widened by `widen_bytes_to_f32`, which panics for
        ///   dtypes other than F32/F16/BF16.
        fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
            // Native-endian f32 view of `bytes`: zero-copy when aligned, owned copy otherwise.
            fn f32_view(bytes: &[u8]) -> std::borrow::Cow<'_, [f32]> {
                let width = std::mem::size_of::<f32>();
                if (bytes.as_ptr() as usize) % std::mem::align_of::<f32>() == 0 {
                    // SAFETY: the start pointer is f32-aligned (checked above),
                    // `bytes.len() / width` f32s lie entirely within `bytes`, the
                    // returned slice borrows `bytes` so it cannot outlive it, and
                    // every bit pattern is a valid f32.
                    let view = unsafe {
                        std::slice::from_raw_parts(bytes.as_ptr().cast::<f32>(), bytes.len() / width)
                    };
                    std::borrow::Cow::Borrowed(view)
                } else {
                    std::borrow::Cow::Owned(
                        bytes
                            .chunks_exact(width)
                            .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
                            .collect(),
                    )
                }
            }

            if dtype == rlx_ir::DType::F32 {
                let values = f32_view(data);
                self.inner.set_param(name, &values);
            } else {
                let f32_buf = super::widen_bytes_to_f32(data, dtype);
                self.inner.set_param(name, &f32_buf);
            }
        }

        /// Runs the graph with byte-encoded inputs and returns byte-encoded outputs.
        ///
        /// Each input is widened to f32 with `widen_bytes_to_f32`. That helper
        /// panics for dtypes other than F32/F16/BF16.
        /// NOTE(review): it also reinterprets the bytes via `from_raw_parts`
        /// without an alignment check.
        ///
        /// Outputs are narrowed to the inner executable's `output_dtypes()`.
        /// Any outputs beyond that list are emitted as F32.
        fn run_typed(
            &mut self,
            inputs: &[(&str, &[u8], rlx_ir::DType)],
        ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
            let mut owned: Vec<(String, Vec<f32>)> = Vec::with_capacity(inputs.len());
            for (name, data, dt) in inputs {
                let v = super::widen_bytes_to_f32(data, *dt);
                owned.push((name.to_string(), v));
            }
            let refs: Vec<(&str, &[f32])> = owned
                .iter()
                .map(|(n, d)| (n.as_str(), d.as_slice()))
                .collect();
            let dtypes = self.inner.output_dtypes();
            let outs = self.inner.run(&refs);
            outs.into_iter()
                .zip(
                    dtypes
                        .into_iter()
                        .chain(std::iter::repeat(rlx_ir::DType::F32)),
                )
                .map(|(v, dt)| (super::narrow_f32_to_bytes(&v, dt), dt))
                .collect()
        }

        /// Produces an independent wrapper around `inner.clone_for_cache()`.
        fn clone_box(&self) -> Box<dyn ExecutableGraph> {
            Box::new(TpuExecutableWrapper {
                inner: self.inner.clone_for_cache(),
            })
        }
    }
}