Skip to main content

rlx_runtime/
compiled.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Compiled graph — the hot-path execution object.
17
18use crate::backend::ExecutableGraph;
19use rlx_driver::Device;
20
/// A compiled graph ready for execution.
///
/// Created by [`crate::Session::compile`]. Holds the fused + memory-planned
/// graph and all pre-allocated execution state. Call
/// [`CompiledGraph::run`] repeatedly with different inputs — zero
/// allocation per call.
///
/// This type is a thin handle: it owns a boxed backend executable and
/// remembers which [`Device`] it was created with. All execution methods
/// forward to the backend object.
pub struct CompiledGraph {
    // Backend-specific executable; every method on `CompiledGraph` delegates to it.
    inner: Box<dyn ExecutableGraph>,
    // Device value supplied at construction, reported by `device()`.
    device: Device,
}
31
32impl Clone for CompiledGraph {
33    /// Deep-clones the underlying executable via `ExecutableGraph::clone_box`.
34    /// Backends that don't support cloning will panic at this point.
35    fn clone(&self) -> Self {
36        Self {
37            inner: self.inner.clone_box(),
38            device: self.device,
39        }
40    }
41}
42
impl CompiledGraph {
    // Every method below (except `new` and `device`) forwards directly to the
    // backend's `ExecutableGraph` method of the same name. Semantics beyond
    // what is documented here are defined by the backend implementation.

    /// Wrap a boxed backend executable together with its device value.
    ///
    /// Crate-internal constructor; no validation is performed here.
    pub(crate) fn new(inner: Box<dyn ExecutableGraph>, device: Device) -> Self {
        Self { inner, device }
    }

    /// Which device this graph runs on.
    ///
    /// Returns the value that was passed to the constructor.
    pub fn device(&self) -> Device {
        self.device
    }

    /// Set a named parameter (model weight).
    /// Call once per parameter after compilation.
    ///
    /// `data` is passed to the backend as `f32` values; see
    /// [`CompiledGraph::set_param_typed`] for other dtypes.
    pub fn set_param(&mut self, name: &str, data: &[f32]) {
        self.inner.set_param(name, data);
    }

    /// Execute the graph with named inputs.
    /// Returns one `Vec<f32>` per graph output (copies from arena).
    ///
    /// Each input is a `(name, data)` pair.
    pub fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
        self.inner.run(inputs)
    }

    /// Run and read back only selected outputs (logits-only decode on MLX).
    ///
    /// `read_indices` is forwarded unchanged to the backend.
    /// NOTE(review): `None` presumably means "read all outputs" — confirm
    /// against the `ExecutableGraph` implementation.
    pub fn run_read_outputs(
        &mut self,
        inputs: &[(&str, &[f32])],
        read_indices: Option<&[usize]>,
    ) -> Vec<Vec<f32>> {
        self.inner.run_read_outputs(inputs, read_indices)
    }

    /// Read one row from a row-major output tensor after a forward pass.
    ///
    /// `out_idx` selects the output and `row` the row within it.
    /// NOTE(review): `row_inner` is presumably the number of elements per
    /// row, and `None` presumably signals an unsupported backend or an
    /// out-of-range request — confirm against the backend.
    pub fn read_output_row(
        &self,
        out_idx: usize,
        row: usize,
        row_inner: usize,
    ) -> Option<Vec<f32>> {
        self.inner.read_output_row(out_idx, row, row_inner)
    }

    /// Execute and return raw pointers to output data (zero-copy).
    /// Data is valid until the next `run`/`run_raw` call.
    ///
    /// Each returned pair is `(pointer, usize)`.
    /// NOTE(review): the `usize` is presumably the element count — confirm.
    ///
    /// # Safety
    /// This method is not declared `unsafe`; the contract below applies to
    /// code that dereferences the returned pointers.
    ///
    /// The returned pointers point into the arena. Do not use after
    /// the next call to run/run_raw (arena data will be overwritten).
    pub fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
        self.inner.run_raw(inputs)
    }

    /// Fastest execution: inputs by slot index (order matches graph input declaration).
    /// Returns output (offset, len) pairs. Read data via `arena_ptr().add(offset)`.
    /// Zero HashMap lookup, zero Vec allocation, zero name matching.
    ///
    /// The returned slice borrows from `self`, so it must be dropped before
    /// the next mutable call on this graph.
    /// NOTE(review): `arena_ptr` yields `*const u8`, so offsets are presumably
    /// in bytes — confirm the unit of `offset` and `len` with the backend.
    pub fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
        self.inner.run_slots(inputs)
    }

    /// Arena pointer for reading output data after `run_slots`.
    ///
    /// Returned as a raw byte pointer; dereferencing it is the caller's
    /// responsibility.
    pub fn arena_ptr(&self) -> *const u8 {
        self.inner.arena_ptr()
    }

    /// Bind a persistent buffer (KV-cache, optimizer state, etc.).
    /// Stays alive across `run()` calls; the backend uses it as the
    /// graph input with the matching name.
    /// Returns true if the backend supports persistent handles.
    pub fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
        self.inner.bind_handle(name, data)
    }

    /// Read the current contents of a persistent buffer.
    ///
    /// NOTE(review): `None` presumably means no handle named `name` is
    /// bound or the backend lacks support — confirm.
    pub fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
        self.inner.read_handle(name)
    }

    /// GPU-resident MLX input (no-op on non-MLX backends).
    ///
    /// Returns the backend's `bool` result unchanged.
    pub fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
        self.inner.bind_gpu_handle(name, data)
    }

    /// Ask the backend whether it has a GPU handle named `name`.
    ///
    /// Forwards to `ExecutableGraph::has_gpu_handle`.
    pub fn has_gpu_handle(&self, name: &str) -> bool {
        self.inner.has_gpu_handle(name)
    }

    /// Forward a GPU-handle feed request to the backend.
    ///
    /// NOTE(review): presumably configures the handle `handle_name` to be
    /// refreshed from output `output_index` on later runs, returning whether
    /// the backend accepted it — confirm against the backend.
    pub fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
        self.inner.set_gpu_handle_feed(handle_name, output_index)
    }

    /// Read back the GPU handle named `name`, if the backend provides one.
    ///
    /// Forwards to `ExecutableGraph::read_gpu_handle`.
    pub fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
        self.inner.read_gpu_handle(name)
    }

    /// Run, refresh GPU handle from output, return that output vector.
    ///
    /// NOTE(review): `None` presumably means the backend does not support
    /// GPU handles or the handle/output could not be resolved — confirm.
    pub fn run_feed_gpu_handle(
        &mut self,
        inputs: &[(&str, &[f32])],
        handle_name: &str,
        output_index: usize,
    ) -> Option<Vec<f32>> {
        self.inner
            .run_feed_gpu_handle(inputs, handle_name, output_index)
    }

    /// Hint subsequent `run` calls to process only the first `actual`
    /// rows along the bucket axis (out of `upper`, the compile extent).
    /// Backends that support per-kernel active-extent dispatch honor
    /// this; others ignore it. Pass `None` to clear.
    ///
    /// NOTE(review): the tuple order is presumably `(actual, upper)` —
    /// confirm against the caller named below.
    ///
    /// See `BucketedCompileCache::run_padded` for the canonical caller.
    pub fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
        self.inner.set_active_extent(extent);
    }

    /// TIDE merged MoE placement (`mask[expert]` device-resident if any layer has it).
    ///
    /// `mask` is indexed by expert.
    pub fn set_moe_resident_experts(&mut self, mask: &[bool]) {
        self.inner.set_moe_resident_experts(mask);
    }

    /// Per MoE layer placement (forward order). Preferred on CPU over merged mask.
    ///
    /// `masks` holds one expert mask per MoE layer.
    pub fn set_moe_resident_experts_per_layer(&mut self, masks: &[&[bool]]) {
        self.inner.set_moe_resident_experts_per_layer(masks);
    }

    /// Capture MoE router TopK on next forward (CPU). Returns false if unsupported.
    ///
    /// Retrieve the captured data with [`CompiledGraph::take_moe_topk_capture`].
    pub fn enable_moe_topk_capture(&mut self, num_experts: usize) -> bool {
        self.inner.enable_moe_topk_capture(num_experts)
    }

    /// Per-layer expert indices from the last forward (MoE router TopK order).
    ///
    /// NOTE(review): the `take_` prefix suggests the stored capture is moved
    /// out, so a second call would return `None` — confirm.
    pub fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
        self.inner.take_moe_topk_capture()
    }

    /// GroupedMatMul GPU/CPU token accounting from the last forward (CPU).
    ///
    /// NOTE(review): as with the TopK capture, this presumably consumes the
    /// stored statistics — confirm.
    pub fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
        self.inner.take_moe_residency_stats()
    }

    // ── Pipelined / async execution (Phase C) ─────────────────────────

    /// Encode + commit a forward pass without waiting for the device.
    ///
    /// Outputs of intermediate calls are stomped — use `run_pipelined`
    /// when you need each call's outputs back. Pair with `sync_pending`
    /// to drain. CPU is synchronous, so this falls back to `run`.
    pub fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
        self.inner.commit_no_wait(inputs);
    }

    /// Wait for every command queued by `commit_no_wait`. CPU is a no-op.
    pub fn sync_pending(&mut self) {
        self.inner.sync_pending();
    }

    /// Pipelined batch run. Issues one commit per input set, syncs once
    /// at the end. On Metal, each commit gets its own output snapshot
    /// (allocated + blit-copied), so subsequent commits stomping the
    /// shared arena don't corrupt earlier runs' outputs.
    /// Returns `out[run_idx][output_idx][element_idx]`.
    ///
    /// Each entry of `input_sets` is one run's list of `(name, data)` inputs.
    pub fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
        self.inner.run_pipelined(input_sets)
    }

    /// Set a named parameter from raw bytes in the given dtype. The
    /// backend handles the widen-to-f32 (or zero-widen, when supported
    /// natively) on the way in. Lets callers feed F16/BF16 weights
    /// without a host-side cast.
    ///
    /// NOTE(review): byte order and length validation of `data` are left to
    /// the backend — nothing is checked in this wrapper.
    pub fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
        self.inner.set_param_typed(name, data, dtype);
    }

    /// Execute with typed inputs and return outputs in their declared
    /// graph dtype, byte-encoded. Mirrors the wgpu / MLX zero-widen
    /// semantics on f32-arena backends (CPU + Metal) by widening at
    /// the boundary.
    ///
    /// Each input is `(name, bytes, dtype)`; each output is `(bytes, dtype)`.
    pub fn run_typed(
        &mut self,
        inputs: &[(&str, &[u8], rlx_ir::DType)],
    ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
        self.inner.run_typed(inputs)
    }
}
226
#[cfg(test)]
mod tests {
    use crate::*;

    /// Compiles `gelu(x @ w + b)` for the CPU device and checks every output
    /// element across two consecutive runs of the same compiled graph.
    #[test]
    #[cfg(feature = "cpu")]
    fn end_to_end_session() {
        // Absolute tolerance shared by every comparison below.
        const TOL: f32 = 0.01;
        // Failure message keeps the "<label> = <actual>" shape so a broken
        // element is identifiable from the panic text alone.
        let assert_close = |actual: f32, expected: f32, label: &str| {
            assert!((actual - expected).abs() < TOL, "{} = {}", label, actual);
        };

        let mut g = Graph::new("matmul_bias_gelu");
        let x = g.input("x", Shape::new(&[2, 4], DType::F32));
        let w = g.param("w", Shape::new(&[4, 3], DType::F32));
        let b = g.param("b", Shape::new(&[3], DType::F32));
        let mm = g.matmul(x, w, Shape::new(&[2, 3], DType::F32));
        let add = g.binary(op::BinaryOp::Add, mm, b, Shape::new(&[2, 3], DType::F32));
        let out = g.activation(op::Activation::Gelu, add, Shape::new(&[2, 3], DType::F32));
        g.set_outputs(vec![out]);

        // Compile
        let session = Session::new(Device::Cpu);
        let mut compiled = session.compile(g);

        // Set weights
        // w = identity-ish [4, 3]: first 3 rows are I, last row is 0
        compiled.set_param(
            "w",
            &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
        );
        compiled.set_param("b", &[0.5, -0.5, 0.0]);

        // Run
        let x_data = vec![
            1.0, 0.0, 0.0, 0.0, // row 0: [1,0,0,0] @ w = [1,0,0] + bias = [1.5,-0.5,0]
            0.0, 1.0, 0.0, 0.0, // row 1: [0,1,0,0] @ w = [0,1,0] + bias = [0.5, 0.5,0]
        ];
        let outputs = compiled.run(&[("x", &x_data)]);

        assert_eq!(outputs.len(), 1);
        let result = &outputs[0];
        assert_eq!(result.len(), 6); // [2, 3]

        // Expected values for the flattened [2, 3] output, row-major.
        let expected: [(f32, &str); 6] = [
            // row 0: gelu(1.5) ≈ 1.399, gelu(-0.5) ≈ -0.154, gelu(0) = 0
            (1.399, "gelu(1.5)"),
            (-0.154, "gelu(-0.5)"),
            (0.0, "gelu(0)"),
            // row 1: gelu(0.5) ≈ 0.346, gelu(0.5) ≈ 0.346, gelu(0) = 0
            (0.346, "gelu(0.5)"),
            (0.346, "gelu(0.5)"),
            (0.0, "gelu(0)"),
        ];
        for (&actual, &(want, label)) in result.iter().zip(expected.iter()) {
            assert_close(actual, want, label);
        }

        // Run again with different input, reusing the same compiled graph.
        let x2 = vec![0.0f32; 8];
        let outputs2 = compiled.run(&[("x", &x2)]);
        assert_eq!(outputs2.len(), 1);
        let r2 = &outputs2[0];
        // Guard the length so the zip below cannot silently skip elements.
        assert_eq!(r2.len(), 6);

        // All zeros input → gelu(bias) for each output, repeated per row.
        let bias_only: [(f32, &str); 3] = [
            (0.346, "gelu(0.5)"),
            (-0.154, "gelu(-0.5)"),
            (0.0, "gelu(0)"),
        ];
        for (&actual, &(want, label)) in r2.iter().zip(bias_only.iter().cycle()) {
            assert_close(actual, want, label);
        }
    }

    /// Checks the CPU device's display prefix and that backends whose Cargo
    /// feature is disabled report as unavailable.
    #[test]
    #[cfg(feature = "cpu")]
    fn device_display() {
        use crate::device_ext::is_available;
        assert!(Device::Cpu.to_string().starts_with("CPU"));
        assert!(is_available(Device::Cpu));
        // Backend availability is feature-gated; only assert
        // unavailable when the corresponding feature is off.
        #[cfg(not(feature = "gpu"))]
        assert!(!is_available(Device::Gpu));
        #[cfg(not(feature = "cuda"))]
        assert!(!is_available(Device::Cuda));
        #[cfg(not(feature = "rocm"))]
        assert!(!is_available(Device::Rocm));
        #[cfg(not(feature = "tpu"))]
        assert!(!is_available(Device::Tpu));
    }
}