rlx_runtime/subgraph.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Sub-graph execution helper.
17//!
18//! Op::If branches and Op::While body/cond are sub-graphs nested inside
19//! a parent graph. To execute them, the backend recursively compiles
20//! and runs the inner graph with bound inputs.
21//!
22//! Strategy: compile the sub-graph lazily on first encounter, cache the
23//! `ExecutableGraph` for repeated invocations (loops). A future
24//! optimization: hoist the compile to the parent's compile-time once
25//! we have a stable IR for sub-graphs.
26
27use crate::CompileOptions;
28use crate::backend::{Backend, ExecutableGraph};
29use rlx_ir::Graph;
30use std::collections::HashMap;
31
32/// Lazily-compiled sub-graph cache.
33/// Keyed by sub-graph name (caller must ensure names are unique within
34/// the parent graph). Backend-agnostic: stores boxed ExecutableGraphs.
35pub struct SubgraphCache {
36 cache: HashMap<String, Box<dyn ExecutableGraph>>,
37 options: CompileOptions,
38}
39
40impl SubgraphCache {
41 pub fn new(options: CompileOptions) -> Self {
42 Self {
43 cache: HashMap::new(),
44 options,
45 }
46 }
47
48 /// Compile a sub-graph if not cached, return mutable executable handle.
49 pub fn get_or_compile<'a>(
50 &'a mut self,
51 backend: &dyn Backend,
52 graph: &Graph,
53 ) -> &'a mut Box<dyn ExecutableGraph> {
54 let key = graph.name.clone();
55 self.cache
56 .entry(key)
57 .or_insert_with(|| backend.compile(graph.clone(), &self.options))
58 }
59
60 /// Run a sub-graph with named inputs, returning its outputs.
61 pub fn run(
62 &mut self,
63 backend: &dyn Backend,
64 graph: &Graph,
65 inputs: &[(&str, &[f32])],
66 ) -> Vec<Vec<f32>> {
67 let exe = self.get_or_compile(backend, graph);
68 exe.run(inputs)
69 }
70}
71
72/// Helper: evaluate an Op::If by running one of two sub-graphs.
73pub fn run_if(
74 cache: &mut SubgraphCache,
75 backend: &dyn Backend,
76 predicate: f32,
77 then_branch: &Graph,
78 else_branch: &Graph,
79 inputs: &[(&str, &[f32])],
80) -> Vec<Vec<f32>> {
81 let chosen = if predicate != 0.0 {
82 then_branch
83 } else {
84 else_branch
85 };
86 cache.run(backend, chosen, inputs)
87}
88
89/// Helper: evaluate an Op::While by repeatedly running cond + body.
90/// `loop_carried` are the values flowing through iterations.
91pub fn run_while(
92 cache: &mut SubgraphCache,
93 backend: &dyn Backend,
94 cond: &Graph,
95 body: &Graph,
96 initial: Vec<Vec<f32>>,
97 input_names: &[&str],
98 max_iterations: Option<usize>,
99) -> Vec<Vec<f32>> {
100 let mut state = initial;
101 let limit = max_iterations.unwrap_or(usize::MAX);
102 for _ in 0..limit {
103 // Build named-input slice for cond + body
104 let bindings: Vec<(&str, &[f32])> = input_names
105 .iter()
106 .zip(state.iter())
107 .map(|(n, v)| (*n, v.as_slice()))
108 .collect();
109 let cond_out = cache.run(backend, cond, &bindings);
110 // Cond is a scalar bool: stop if it's zero / false
111 if cond_out
112 .first()
113 .map(|v| v.first().copied().unwrap_or(0.0))
114 .unwrap_or(0.0)
115 == 0.0
116 {
117 break;
118 }
119 state = cache.run(backend, body, &bindings);
120 }
121 state
122}