rlx_runtime/compiled.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Compiled graph — the hot-path execution object.
17
18use crate::backend::ExecutableGraph;
19use rlx_driver::Device;
20
21/// A compiled graph ready for execution.
22///
23/// Created by [`crate::Session::compile`]. Holds the fused + memory-planned
24/// graph and all pre-allocated execution state. Call
25/// [`CompiledGraph::run`] repeatedly with different inputs — zero
26/// allocation per call.
27pub struct CompiledGraph {
28 inner: Box<dyn ExecutableGraph>,
29 device: Device,
30}
31
32impl Clone for CompiledGraph {
33 /// Deep-clones the underlying executable via `ExecutableGraph::clone_box`.
34 /// Backends that don't support cloning will panic at this point.
35 fn clone(&self) -> Self {
36 Self {
37 inner: self.inner.clone_box(),
38 device: self.device,
39 }
40 }
41}
42
43impl CompiledGraph {
44 pub(crate) fn new(inner: Box<dyn ExecutableGraph>, device: Device) -> Self {
45 Self { inner, device }
46 }
47
48 /// Which device this graph runs on.
49 pub fn device(&self) -> Device {
50 self.device
51 }
52
53 /// Set a named parameter (model weight).
54 /// Call once per parameter after compilation.
55 pub fn set_param(&mut self, name: &str, data: &[f32]) {
56 self.inner.set_param(name, data);
57 }
58
59 /// Execute the graph with named inputs.
60 /// Returns one `Vec<f32>` per graph output (copies from arena).
61 pub fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
62 self.inner.run(inputs)
63 }
64
65 /// Execute and return raw pointers to output data (zero-copy).
66 /// Data is valid until the next `run`/`run_raw` call.
67 ///
68 /// # Safety
69 /// The returned pointers point into the arena. Do not use after
70 /// the next call to run/run_raw (arena data will be overwritten).
71 pub fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
72 self.inner.run_raw(inputs)
73 }
74
75 /// Fastest execution: inputs by slot index (order matches graph input declaration).
76 /// Returns output (offset, len) pairs. Read data via `arena_ptr().add(offset)`.
77 /// Zero HashMap lookup, zero Vec allocation, zero name matching.
78 pub fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
79 self.inner.run_slots(inputs)
80 }
81
82 /// Arena pointer for reading output data after `run_slots`.
83 pub fn arena_ptr(&self) -> *const u8 {
84 self.inner.arena_ptr()
85 }
86
87 /// Bind a persistent buffer (KV-cache, optimizer state, etc.).
88 /// Stays alive across `run()` calls; the backend uses it as the
89 /// graph input with the matching name.
90 /// Returns true if the backend supports persistent handles.
91 pub fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
92 self.inner.bind_handle(name, data)
93 }
94
95 /// Read the current contents of a persistent buffer.
96 pub fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
97 self.inner.read_handle(name)
98 }
99
100 /// GPU-resident MLX input (no-op on non-MLX backends).
101 pub fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
102 self.inner.bind_gpu_handle(name, data)
103 }
104
105 pub fn has_gpu_handle(&self, name: &str) -> bool {
106 self.inner.has_gpu_handle(name)
107 }
108
109 pub fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
110 self.inner.set_gpu_handle_feed(handle_name, output_index)
111 }
112
113 pub fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
114 self.inner.read_gpu_handle(name)
115 }
116
117 /// Run, refresh GPU handle from output, return that output vector.
118 pub fn run_feed_gpu_handle(
119 &mut self,
120 inputs: &[(&str, &[f32])],
121 handle_name: &str,
122 output_index: usize,
123 ) -> Option<Vec<f32>> {
124 self.inner
125 .run_feed_gpu_handle(inputs, handle_name, output_index)
126 }
127
128 /// Hint subsequent `run` calls to process only the first `actual`
129 /// rows along the bucket axis (out of `upper`, the compile extent).
130 /// Backends that support per-kernel active-extent dispatch honor
131 /// this; others ignore it. Pass `None` to clear.
132 ///
133 /// See `BucketedCompileCache::run_padded` for the canonical caller.
134 pub fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
135 self.inner.set_active_extent(extent);
136 }
137
138 /// TIDE merged MoE placement (`mask[expert]` device-resident if any layer has it).
139 pub fn set_moe_resident_experts(&mut self, mask: &[bool]) {
140 self.inner.set_moe_resident_experts(mask);
141 }
142
143 /// Per MoE layer placement (forward order). Preferred on CPU over merged mask.
144 pub fn set_moe_resident_experts_per_layer(&mut self, masks: &[&[bool]]) {
145 self.inner.set_moe_resident_experts_per_layer(masks);
146 }
147
148 /// Capture MoE router TopK on next forward (CPU). Returns false if unsupported.
149 pub fn enable_moe_topk_capture(&mut self, num_experts: usize) -> bool {
150 self.inner.enable_moe_topk_capture(num_experts)
151 }
152
153 /// Per-layer expert indices from the last forward (MoE router TopK order).
154 pub fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
155 self.inner.take_moe_topk_capture()
156 }
157
158 /// GroupedMatMul GPU/CPU token accounting from the last forward (CPU).
159 pub fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
160 self.inner.take_moe_residency_stats()
161 }
162
163 // ── Pipelined / async execution (Phase C) ─────────────────────────
164
165 /// Encode + commit a forward pass without waiting for the device.
166 ///
167 /// Outputs of intermediate calls are stomped — use `run_pipelined`
168 /// when you need each call's outputs back. Pair with `sync_pending`
169 /// to drain. CPU is synchronous, so this falls back to `run`.
170 pub fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
171 self.inner.commit_no_wait(inputs);
172 }
173
174 /// Wait for every command queued by `commit_no_wait`. CPU is a no-op.
175 pub fn sync_pending(&mut self) {
176 self.inner.sync_pending();
177 }
178
179 /// Pipelined batch run. Issues one commit per input set, syncs once
180 /// at the end. On Metal, each commit gets its own output snapshot
181 /// (allocated + blit-copied), so subsequent commits stomping the
182 /// shared arena don't corrupt earlier runs' outputs.
183 /// Returns `out[run_idx][output_idx][element_idx]`.
184 pub fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
185 self.inner.run_pipelined(input_sets)
186 }
187
188 /// Set a named parameter from raw bytes in the given dtype. The
189 /// backend handles the widen-to-f32 (or zero-widen, when supported
190 /// natively) on the way in. Lets callers feed F16/BF16 weights
191 /// without a host-side cast.
192 pub fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
193 self.inner.set_param_typed(name, data, dtype);
194 }
195
196 /// Execute with typed inputs and return outputs in their declared
197 /// graph dtype, byte-encoded. Mirrors the wgpu / MLX zero-widen
198 /// semantics on f32-arena backends (CPU + Metal) by widening at
199 /// the boundary.
200 pub fn run_typed(
201 &mut self,
202 inputs: &[(&str, &[u8], rlx_ir::DType)],
203 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
204 self.inner.run_typed(inputs)
205 }
206}
207
#[cfg(test)]
mod tests {
    use crate::*;

    /// Absolute tolerance when comparing GELU outputs to reference values.
    #[cfg(feature = "cpu")]
    const GELU_TOL: f32 = 0.01;

    /// Asserts that `actual` is within `GELU_TOL` of `expected`. The failure
    /// message names the checked quantity via `what`.
    #[cfg(feature = "cpu")]
    fn assert_close(actual: f32, expected: f32, what: &str) {
        assert!(
            (actual - expected).abs() < GELU_TOL,
            "{} = {} (expected ≈ {})",
            what,
            actual,
            expected
        );
    }

    #[test]
    #[cfg(feature = "cpu")]
    fn end_to_end_session() {
        let mut g = Graph::new("matmul_bias_gelu");
        let x = g.input("x", Shape::new(&[2, 4], DType::F32));
        let w = g.param("w", Shape::new(&[4, 3], DType::F32));
        let b = g.param("b", Shape::new(&[3], DType::F32));
        let mm = g.matmul(x, w, Shape::new(&[2, 3], DType::F32));
        let add = g.binary(op::BinaryOp::Add, mm, b, Shape::new(&[2, 3], DType::F32));
        let out = g.activation(op::Activation::Gelu, add, Shape::new(&[2, 3], DType::F32));
        g.set_outputs(vec![out]);

        // Compile
        let session = Session::new(Device::Cpu);
        let mut compiled = session.compile(g);

        // Set weights
        // w = identity-ish [4, 3]: first 3 rows are I, last row is 0
        compiled.set_param(
            "w",
            &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
        );
        compiled.set_param("b", &[0.5, -0.5, 0.0]);

        // Run
        let x_data = vec![
            1.0, 0.0, 0.0, 0.0, // row 0: [1,0,0,0] @ w = [1,0,0] + bias = [1.5,-0.5,0]
            0.0, 1.0, 0.0, 0.0, // row 1: [0,1,0,0] @ w = [0,1,0] + bias = [0.5, 0.5,0]
        ];
        let outputs = compiled.run(&[("x", &x_data)]);

        assert_eq!(outputs.len(), 1);
        let result = &outputs[0];
        assert_eq!(result.len(), 6); // [2, 3]

        // Reference values: gelu(1.5) ≈ 1.399, gelu(0.5) ≈ 0.346,
        // gelu(-0.5) ≈ -0.154, gelu(0) = 0. Every element is checked.
        let expected = [
            ("gelu(1.5)", 1.399),
            ("gelu(-0.5)", -0.154),
            ("gelu(0)", 0.0),
            ("gelu(0.5)", 0.346),
            ("gelu(0.5)", 0.346),
            ("gelu(0)", 0.0),
        ];
        for (&actual, &(what, want)) in result.iter().zip(expected.iter()) {
            assert_close(actual, want, what);
        }

        // Run again with a different input on the same compiled graph
        // (no recompilation; `run` still allocates its output vectors).
        // All-zero input → every row is gelu(bias) = gelu([0.5, -0.5, 0]).
        let x2 = vec![0.0f32; 8];
        let outputs2 = compiled.run(&[("x", &x2)]);
        assert_eq!(outputs2.len(), 1);
        let r2 = &outputs2[0];
        assert_eq!(r2.len(), 6); // [2, 3]

        let expected2 = [
            ("gelu(0+0.5)", 0.346),
            ("gelu(0-0.5)", -0.154),
            ("gelu(0+0)", 0.0),
            ("gelu(0+0.5)", 0.346),
            ("gelu(0-0.5)", -0.154),
            ("gelu(0+0)", 0.0),
        ];
        for (&actual, &(what, want)) in r2.iter().zip(expected2.iter()) {
            assert_close(actual, want, what);
        }
    }

    #[test]
    #[cfg(feature = "cpu")]
    fn device_display() {
        use crate::device_ext::is_available;
        assert!(format!("{}", Device::Cpu).starts_with("CPU"));
        assert!(is_available(Device::Cpu));
        // Backend availability is feature-gated; only assert
        // unavailable when the corresponding feature is off.
        #[cfg(not(feature = "gpu"))]
        assert!(!is_available(Device::Gpu));
        #[cfg(not(feature = "cuda"))]
        assert!(!is_available(Device::Cuda));
        #[cfg(not(feature = "rocm"))]
        assert!(!is_available(Device::Rocm));
        #[cfg(not(feature = "tpu"))]
        assert!(!is_available(Device::Tpu));
    }
}