rlx_runtime/compiled.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Compiled graph — the hot-path execution object.
17
18use crate::backend::ExecutableGraph;
19use rlx_driver::Device;
20
21/// A compiled graph ready for execution.
22///
23/// Created by [`crate::Session::compile`]. Holds the fused + memory-planned
24/// graph and all pre-allocated execution state. Call
25/// [`CompiledGraph::run`] repeatedly with different inputs — zero
26/// allocation per call.
27pub struct CompiledGraph {
28 inner: Box<dyn ExecutableGraph>,
29 device: Device,
30}
31
32impl Clone for CompiledGraph {
33 /// Deep-clones the underlying executable via `ExecutableGraph::clone_box`.
34 /// Backends that don't support cloning will panic at this point.
35 fn clone(&self) -> Self {
36 Self {
37 inner: self.inner.clone_box(),
38 device: self.device,
39 }
40 }
41}
42
43impl CompiledGraph {
44 pub(crate) fn new(inner: Box<dyn ExecutableGraph>, device: Device) -> Self {
45 Self { inner, device }
46 }
47
48 /// Which device this graph runs on.
49 pub fn device(&self) -> Device {
50 self.device
51 }
52
53 /// Set a named parameter (model weight).
54 /// Call once per parameter after compilation.
55 pub fn set_param(&mut self, name: &str, data: &[f32]) {
56 self.inner.set_param(name, data);
57 }
58
59 /// Execute the graph with named inputs.
60 /// Returns one `Vec<f32>` per graph output (copies from arena).
61 pub fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
62 self.inner.run(inputs)
63 }
64
65 /// Run and read back only selected outputs (logits-only decode on MLX).
66 pub fn run_read_outputs(
67 &mut self,
68 inputs: &[(&str, &[f32])],
69 read_indices: Option<&[usize]>,
70 ) -> Vec<Vec<f32>> {
71 self.inner.run_read_outputs(inputs, read_indices)
72 }
73
74 /// Read one row from a row-major output tensor after a forward pass.
75 pub fn read_output_row(
76 &self,
77 out_idx: usize,
78 row: usize,
79 row_inner: usize,
80 ) -> Option<Vec<f32>> {
81 self.inner.read_output_row(out_idx, row, row_inner)
82 }
83
84 /// Execute and return raw pointers to output data (zero-copy).
85 /// Data is valid until the next `run`/`run_raw` call.
86 ///
87 /// # Safety
88 /// The returned pointers point into the arena. Do not use after
89 /// the next call to run/run_raw (arena data will be overwritten).
90 pub fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
91 self.inner.run_raw(inputs)
92 }
93
94 /// Fastest execution: inputs by slot index (order matches graph input declaration).
95 /// Returns output (offset, len) pairs. Read data via `arena_ptr().add(offset)`.
96 /// Zero HashMap lookup, zero Vec allocation, zero name matching.
97 pub fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
98 self.inner.run_slots(inputs)
99 }
100
101 /// Arena pointer for reading output data after `run_slots`.
102 pub fn arena_ptr(&self) -> *const u8 {
103 self.inner.arena_ptr()
104 }
105
106 /// Bind a persistent buffer (KV-cache, optimizer state, etc.).
107 /// Stays alive across `run()` calls; the backend uses it as the
108 /// graph input with the matching name.
109 /// Returns true if the backend supports persistent handles.
110 pub fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
111 self.inner.bind_handle(name, data)
112 }
113
114 /// Read the current contents of a persistent buffer.
115 pub fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
116 self.inner.read_handle(name)
117 }
118
119 /// GPU-resident MLX input (no-op on non-MLX backends).
120 pub fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
121 self.inner.bind_gpu_handle(name, data)
122 }
123
124 pub fn has_gpu_handle(&self, name: &str) -> bool {
125 self.inner.has_gpu_handle(name)
126 }
127
128 pub fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
129 self.inner.set_gpu_handle_feed(handle_name, output_index)
130 }
131
132 pub fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
133 self.inner.read_gpu_handle(name)
134 }
135
136 /// Run, refresh GPU handle from output, return that output vector.
137 pub fn run_feed_gpu_handle(
138 &mut self,
139 inputs: &[(&str, &[f32])],
140 handle_name: &str,
141 output_index: usize,
142 ) -> Option<Vec<f32>> {
143 self.inner
144 .run_feed_gpu_handle(inputs, handle_name, output_index)
145 }
146
147 /// Hint subsequent `run` calls to process only the first `actual`
148 /// rows along the bucket axis (out of `upper`, the compile extent).
149 /// Backends that support per-kernel active-extent dispatch honor
150 /// this; others ignore it. Pass `None` to clear.
151 ///
152 /// See `BucketedCompileCache::run_padded` for the canonical caller.
153 pub fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
154 self.inner.set_active_extent(extent);
155 }
156
157 /// TIDE merged MoE placement (`mask[expert]` device-resident if any layer has it).
158 pub fn set_moe_resident_experts(&mut self, mask: &[bool]) {
159 self.inner.set_moe_resident_experts(mask);
160 }
161
162 /// Per MoE layer placement (forward order). Preferred on CPU over merged mask.
163 pub fn set_moe_resident_experts_per_layer(&mut self, masks: &[&[bool]]) {
164 self.inner.set_moe_resident_experts_per_layer(masks);
165 }
166
167 /// Capture MoE router TopK on next forward (CPU). Returns false if unsupported.
168 pub fn enable_moe_topk_capture(&mut self, num_experts: usize) -> bool {
169 self.inner.enable_moe_topk_capture(num_experts)
170 }
171
172 /// Per-layer expert indices from the last forward (MoE router TopK order).
173 pub fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
174 self.inner.take_moe_topk_capture()
175 }
176
177 /// GroupedMatMul GPU/CPU token accounting from the last forward (CPU).
178 pub fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
179 self.inner.take_moe_residency_stats()
180 }
181
182 // ── Pipelined / async execution (Phase C) ─────────────────────────
183
184 /// Encode + commit a forward pass without waiting for the device.
185 ///
186 /// Outputs of intermediate calls are stomped — use `run_pipelined`
187 /// when you need each call's outputs back. Pair with `sync_pending`
188 /// to drain. CPU is synchronous, so this falls back to `run`.
189 pub fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
190 self.inner.commit_no_wait(inputs);
191 }
192
193 /// Wait for every command queued by `commit_no_wait`. CPU is a no-op.
194 pub fn sync_pending(&mut self) {
195 self.inner.sync_pending();
196 }
197
198 /// Pipelined batch run. Issues one commit per input set, syncs once
199 /// at the end. On Metal, each commit gets its own output snapshot
200 /// (allocated + blit-copied), so subsequent commits stomping the
201 /// shared arena don't corrupt earlier runs' outputs.
202 /// Returns `out[run_idx][output_idx][element_idx]`.
203 pub fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
204 self.inner.run_pipelined(input_sets)
205 }
206
207 /// Set a named parameter from raw bytes in the given dtype. The
208 /// backend handles the widen-to-f32 (or zero-widen, when supported
209 /// natively) on the way in. Lets callers feed F16/BF16 weights
210 /// without a host-side cast.
211 pub fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
212 self.inner.set_param_typed(name, data, dtype);
213 }
214
215 /// Execute with typed inputs and return outputs in their declared
216 /// graph dtype, byte-encoded. Mirrors the wgpu / MLX zero-widen
217 /// semantics on f32-arena backends (CPU + Metal) by widening at
218 /// the boundary.
219 pub fn run_typed(
220 &mut self,
221 inputs: &[(&str, &[u8], rlx_ir::DType)],
222 ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
223 self.inner.run_typed(inputs)
224 }
225}
226
#[cfg(test)]
mod tests {
    use crate::*;

    /// Absolute tolerance used when comparing outputs against the rounded
    /// GELU reference values below.
    #[cfg(feature = "cpu")]
    const TOLERANCE: f32 = 0.01;

    /// Asserts that `actual` and `expected` have the same length and agree
    /// element-wise within [`TOLERANCE`], naming the offending index on
    /// failure.
    #[cfg(feature = "cpu")]
    fn assert_all_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len(), "output length mismatch");
        for (idx, (got, want)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - want).abs() < TOLERANCE,
                "output[{}] = {}, expected ≈ {}",
                idx,
                got,
                want
            );
        }
    }

    /// Builds `gelu(x @ w + b)`, compiles it for the CPU device, and checks
    /// every output element of two consecutive runs.
    #[test]
    #[cfg(feature = "cpu")]
    fn end_to_end_session() {
        let mut g = Graph::new("matmul_bias_gelu");
        let x = g.input("x", Shape::new(&[2, 4], DType::F32));
        let w = g.param("w", Shape::new(&[4, 3], DType::F32));
        let b = g.param("b", Shape::new(&[3], DType::F32));
        let mm = g.matmul(x, w, Shape::new(&[2, 3], DType::F32));
        let add = g.binary(op::BinaryOp::Add, mm, b, Shape::new(&[2, 3], DType::F32));
        let out = g.activation(op::Activation::Gelu, add, Shape::new(&[2, 3], DType::F32));
        g.set_outputs(vec![out]);

        // Compile
        let session = Session::new(Device::Cpu);
        let mut compiled = session.compile(g);

        // Set weights
        // w = identity-ish [4, 3]: first 3 rows are I, last row is 0
        compiled.set_param(
            "w",
            &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
        );
        compiled.set_param("b", &[0.5, -0.5, 0.0]);

        // Run
        let x_data = vec![
            1.0, 0.0, 0.0, 0.0, // row 0: [1,0,0,0] @ w = [1,0,0] + bias = [1.5,-0.5,0]
            0.0, 1.0, 0.0, 0.0, // row 1: [0,1,0,0] @ w = [0,1,0] + bias = [0.5, 0.5,0]
        ];
        let outputs = compiled.run(&[("x", &x_data)]);

        assert_eq!(outputs.len(), 1);
        let result = &outputs[0];
        assert_eq!(result.len(), 6); // [2, 3]

        // Row 0: gelu(1.5) ≈ 1.399, gelu(-0.5) ≈ -0.154, gelu(0) = 0
        // Row 1: gelu(0.5) ≈ 0.346, gelu(0.5) ≈ 0.346, gelu(0) = 0
        assert_all_close(result, &[1.399, -0.154, 0.0, 0.346, 0.346, 0.0]);

        // Run again with different input, reusing the same compiled graph.
        let x2 = vec![0.0f32; 8];
        let outputs2 = compiled.run(&[("x", &x2)]);
        assert_eq!(outputs2.len(), 1);

        // All-zero input → every row is gelu(bias) = gelu([0.5, -0.5, 0]).
        let r2 = &outputs2[0];
        assert_all_close(r2, &[0.346, -0.154, 0.0, 0.346, -0.154, 0.0]);
    }

    /// Checks the CPU device's display prefix and feature-gated backend
    /// availability.
    #[test]
    #[cfg(feature = "cpu")]
    fn device_display() {
        use crate::device_ext::is_available;
        assert!(format!("{}", Device::Cpu).starts_with("CPU"));
        assert!(is_available(Device::Cpu));
        // Backend availability is feature-gated; only assert
        // unavailable when the corresponding feature is off.
        #[cfg(not(feature = "gpu"))]
        assert!(!is_available(Device::Gpu));
        #[cfg(not(feature = "cuda"))]
        assert!(!is_available(Device::Cuda));
        #[cfg(not(feature = "rocm"))]
        assert!(!is_available(Device::Rocm));
        #[cfg(not(feature = "tpu"))]
        assert!(!is_available(Device::Tpu));
    }
}