Skip to main content

rlx_runtime/
compiled.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Compiled graph — the hot-path execution object.
17
18use crate::backend::ExecutableGraph;
19use rlx_driver::Device;
20
21/// A compiled graph ready for execution.
22///
23/// Created by [`crate::Session::compile`]. Holds the fused + memory-planned
24/// graph and all pre-allocated execution state. Call
25/// [`CompiledGraph::run`] repeatedly with different inputs — zero
26/// allocation per call.
27pub struct CompiledGraph {
28    inner: Box<dyn ExecutableGraph>,
29    device: Device,
30}
31
32impl Clone for CompiledGraph {
33    /// Deep-clones the underlying executable via `ExecutableGraph::clone_box`.
34    /// Backends that don't support cloning will panic at this point.
35    fn clone(&self) -> Self {
36        Self {
37            inner: self.inner.clone_box(),
38            device: self.device,
39        }
40    }
41}
42
43impl CompiledGraph {
44    pub(crate) fn new(inner: Box<dyn ExecutableGraph>, device: Device) -> Self {
45        Self { inner, device }
46    }
47
48    /// Which device this graph runs on.
49    pub fn device(&self) -> Device {
50        self.device
51    }
52
53    /// Set a named parameter (model weight).
54    /// Call once per parameter after compilation.
55    pub fn set_param(&mut self, name: &str, data: &[f32]) {
56        self.inner.set_param(name, data);
57    }
58
59    /// Execute the graph with named inputs.
60    /// Returns one `Vec<f32>` per graph output (copies from arena).
61    pub fn run(&mut self, inputs: &[(&str, &[f32])]) -> Vec<Vec<f32>> {
62        self.inner.run(inputs)
63    }
64
65    /// Execute and return raw pointers to output data (zero-copy).
66    /// Data is valid until the next `run`/`run_raw` call.
67    ///
68    /// # Safety
69    /// The returned pointers point into the arena. Do not use after
70    /// the next call to run/run_raw (arena data will be overwritten).
71    pub fn run_raw(&mut self, inputs: &[(&str, &[f32])]) -> Vec<(*const f32, usize)> {
72        self.inner.run_raw(inputs)
73    }
74
75    /// Fastest execution: inputs by slot index (order matches graph input declaration).
76    /// Returns output (offset, len) pairs. Read data via `arena_ptr().add(offset)`.
77    /// Zero HashMap lookup, zero Vec allocation, zero name matching.
78    pub fn run_slots(&mut self, inputs: &[&[f32]]) -> &[(usize, usize)] {
79        self.inner.run_slots(inputs)
80    }
81
82    /// Arena pointer for reading output data after `run_slots`.
83    pub fn arena_ptr(&self) -> *const u8 {
84        self.inner.arena_ptr()
85    }
86
87    /// Bind a persistent buffer (KV-cache, optimizer state, etc.).
88    /// Stays alive across `run()` calls; the backend uses it as the
89    /// graph input with the matching name.
90    /// Returns true if the backend supports persistent handles.
91    pub fn bind_handle(&mut self, name: &str, data: &[f32]) -> bool {
92        self.inner.bind_handle(name, data)
93    }
94
95    /// Read the current contents of a persistent buffer.
96    pub fn read_handle(&self, name: &str) -> Option<Vec<f32>> {
97        self.inner.read_handle(name)
98    }
99
100    /// GPU-resident MLX input (no-op on non-MLX backends).
101    pub fn bind_gpu_handle(&mut self, name: &str, data: &[f32]) -> bool {
102        self.inner.bind_gpu_handle(name, data)
103    }
104
105    pub fn has_gpu_handle(&self, name: &str) -> bool {
106        self.inner.has_gpu_handle(name)
107    }
108
109    pub fn set_gpu_handle_feed(&mut self, handle_name: &str, output_index: usize) -> bool {
110        self.inner.set_gpu_handle_feed(handle_name, output_index)
111    }
112
113    pub fn read_gpu_handle(&self, name: &str) -> Option<Vec<f32>> {
114        self.inner.read_gpu_handle(name)
115    }
116
117    /// Run, refresh GPU handle from output, return that output vector.
118    pub fn run_feed_gpu_handle(
119        &mut self,
120        inputs: &[(&str, &[f32])],
121        handle_name: &str,
122        output_index: usize,
123    ) -> Option<Vec<f32>> {
124        self.inner
125            .run_feed_gpu_handle(inputs, handle_name, output_index)
126    }
127
128    /// Hint subsequent `run` calls to process only the first `actual`
129    /// rows along the bucket axis (out of `upper`, the compile extent).
130    /// Backends that support per-kernel active-extent dispatch honor
131    /// this; others ignore it. Pass `None` to clear.
132    ///
133    /// See `BucketedCompileCache::run_padded` for the canonical caller.
134    pub fn set_active_extent(&mut self, extent: Option<(usize, usize)>) {
135        self.inner.set_active_extent(extent);
136    }
137
138    /// TIDE merged MoE placement (`mask[expert]` device-resident if any layer has it).
139    pub fn set_moe_resident_experts(&mut self, mask: &[bool]) {
140        self.inner.set_moe_resident_experts(mask);
141    }
142
143    /// Per MoE layer placement (forward order). Preferred on CPU over merged mask.
144    pub fn set_moe_resident_experts_per_layer(&mut self, masks: &[&[bool]]) {
145        self.inner.set_moe_resident_experts_per_layer(masks);
146    }
147
148    /// Capture MoE router TopK on next forward (CPU). Returns false if unsupported.
149    pub fn enable_moe_topk_capture(&mut self, num_experts: usize) -> bool {
150        self.inner.enable_moe_topk_capture(num_experts)
151    }
152
153    /// Per-layer expert indices from the last forward (MoE router TopK order).
154    pub fn take_moe_topk_capture(&mut self) -> Option<Vec<Vec<u32>>> {
155        self.inner.take_moe_topk_capture()
156    }
157
158    /// GroupedMatMul GPU/CPU token accounting from the last forward (CPU).
159    pub fn take_moe_residency_stats(&mut self) -> Option<crate::MoeResidencyStats> {
160        self.inner.take_moe_residency_stats()
161    }
162
163    // ── Pipelined / async execution (Phase C) ─────────────────────────
164
165    /// Encode + commit a forward pass without waiting for the device.
166    ///
167    /// Outputs of intermediate calls are stomped — use `run_pipelined`
168    /// when you need each call's outputs back. Pair with `sync_pending`
169    /// to drain. CPU is synchronous, so this falls back to `run`.
170    pub fn commit_no_wait(&mut self, inputs: &[(&str, &[f32])]) {
171        self.inner.commit_no_wait(inputs);
172    }
173
174    /// Wait for every command queued by `commit_no_wait`. CPU is a no-op.
175    pub fn sync_pending(&mut self) {
176        self.inner.sync_pending();
177    }
178
179    /// Pipelined batch run. Issues one commit per input set, syncs once
180    /// at the end. On Metal, each commit gets its own output snapshot
181    /// (allocated + blit-copied), so subsequent commits stomping the
182    /// shared arena don't corrupt earlier runs' outputs.
183    /// Returns `out[run_idx][output_idx][element_idx]`.
184    pub fn run_pipelined(&mut self, input_sets: &[Vec<(&str, &[f32])>]) -> Vec<Vec<Vec<f32>>> {
185        self.inner.run_pipelined(input_sets)
186    }
187
188    /// Set a named parameter from raw bytes in the given dtype. The
189    /// backend handles the widen-to-f32 (or zero-widen, when supported
190    /// natively) on the way in. Lets callers feed F16/BF16 weights
191    /// without a host-side cast.
192    pub fn set_param_typed(&mut self, name: &str, data: &[u8], dtype: rlx_ir::DType) {
193        self.inner.set_param_typed(name, data, dtype);
194    }
195
196    /// Execute with typed inputs and return outputs in their declared
197    /// graph dtype, byte-encoded. Mirrors the wgpu / MLX zero-widen
198    /// semantics on f32-arena backends (CPU + Metal) by widening at
199    /// the boundary.
200    pub fn run_typed(
201        &mut self,
202        inputs: &[(&str, &[u8], rlx_ir::DType)],
203    ) -> Vec<(Vec<u8>, rlx_ir::DType)> {
204        self.inner.run_typed(inputs)
205    }
206}
207
#[cfg(test)]
mod tests {
    use crate::*;

    /// Absolute tolerance used when comparing against the rounded GELU
    /// reference values below.
    #[cfg(feature = "cpu")]
    const TOLERANCE: f32 = 0.01;

    /// Assert that `actual` is within [`TOLERANCE`] of `expected`.
    /// `label` names the quantity in the failure message; `#[track_caller]`
    /// makes a failure point at the calling line instead of this helper.
    #[cfg(feature = "cpu")]
    #[track_caller]
    fn assert_close(actual: f32, expected: f32, label: &str) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "{} = {}, expected ≈ {}",
            label,
            actual,
            expected
        );
    }

    /// Builds `gelu(x @ w + b)`, compiles it for the CPU backend and checks
    /// every output element of two consecutive runs.
    #[test]
    #[cfg(feature = "cpu")]
    fn end_to_end_session() {
        let mut g = Graph::new("matmul_bias_gelu");
        let x = g.input("x", Shape::new(&[2, 4], DType::F32));
        let w = g.param("w", Shape::new(&[4, 3], DType::F32));
        let b = g.param("b", Shape::new(&[3], DType::F32));
        let mm = g.matmul(x, w, Shape::new(&[2, 3], DType::F32));
        let add = g.binary(op::BinaryOp::Add, mm, b, Shape::new(&[2, 3], DType::F32));
        let out = g.activation(op::Activation::Gelu, add, Shape::new(&[2, 3], DType::F32));
        g.set_outputs(vec![out]);

        // Compile
        let session = Session::new(Device::Cpu);
        let mut compiled = session.compile(g);

        // Set weights
        // w = identity-ish [4, 3]: first 3 rows are I, last row is 0
        compiled.set_param(
            "w",
            &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
        );
        compiled.set_param("b", &[0.5, -0.5, 0.0]);

        // Run
        let x_data = vec![
            1.0, 0.0, 0.0, 0.0, // row 0: [1,0,0,0] @ w = [1,0,0] + bias = [1.5,-0.5,0]
            0.0, 1.0, 0.0, 0.0, // row 1: [0,1,0,0] @ w = [0,1,0] + bias = [0.5, 0.5,0]
        ];
        let outputs = compiled.run(&[("x", &x_data)]);

        assert_eq!(outputs.len(), 1);
        let result = &outputs[0];
        assert_eq!(result.len(), 6); // [2, 3]

        // Row 0: gelu(1.5) ≈ 1.399, gelu(-0.5) ≈ -0.154, gelu(0) = 0
        assert_close(result[0], 1.399, "gelu(1.5)");
        assert_close(result[1], -0.154, "gelu(-0.5)");
        assert_close(result[2], 0.0, "gelu(0)");

        // Row 1: gelu(0.5) ≈ 0.346, gelu(0.5) ≈ 0.346, gelu(0) = 0
        assert_close(result[3], 0.346, "gelu(0.5)");
        assert_close(result[4], 0.346, "gelu(0.5)");
        assert_close(result[5], 0.0, "gelu(0)");

        // Run again with different input to check the compiled graph is
        // reusable. (`run` still copies its outputs into new Vecs.)
        let x2 = vec![0.0f32; 8];
        let outputs2 = compiled.run(&[("x", &x2)]);
        assert_eq!(outputs2.len(), 1);
        let r2 = &outputs2[0];
        assert_eq!(r2.len(), 6); // [2, 3]

        // All-zeros input → every row is gelu(bias) = gelu([0.5, -0.5, 0]).
        for row in r2.chunks_exact(3) {
            assert_close(row[0], 0.346, "gelu(0.5)"); // gelu(0+0.5)
            assert_close(row[1], -0.154, "gelu(-0.5)"); // gelu(0-0.5)
            assert_close(row[2], 0.0, "gelu(0)"); // gelu(0+0)
        }
    }

    /// Checks the CPU device's display string and that availability
    /// reporting matches the enabled backend features.
    #[test]
    #[cfg(feature = "cpu")]
    fn device_display() {
        use crate::device_ext::is_available;
        assert!(Device::Cpu.to_string().starts_with("CPU"));
        assert!(is_available(Device::Cpu));
        // Backend availability is feature-gated; only assert
        // unavailable when the corresponding feature is off.
        #[cfg(not(feature = "gpu"))]
        assert!(!is_available(Device::Gpu));
        #[cfg(not(feature = "cuda"))]
        assert!(!is_available(Device::Cuda));
        #[cfg(not(feature = "rocm"))]
        assert!(!is_available(Device::Rocm));
        #[cfg(not(feature = "tpu"))]
        assert!(!is_available(Device::Tpu));
    }
}