Skip to main content

rlx_ir/
lir.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! **LIR** — low-level IR.
17//!
18//! Optimized MIR plus a concrete execution plan: arena layout, topo
19//! schedule, view aliases, streaming phases, and an I/O manifest.
20//! Backends lower LIR to device thunks/kernels without re-running the
21//! optimizer or memory planner when the embedded plan is still valid.
22
23use std::collections::HashMap;
24use std::collections::hash_map::DefaultHasher;
25use std::hash::{Hash, Hasher};
26
27use crate::mir::MirModule;
28use crate::phase::{Phase, PhaseSchedule};
29use crate::{Graph, NodeId, Op};
30
/// One allocation inside the planned arena, addressed as a byte range.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct LirBufferSlot {
    /// Where the slot starts, measured from the beginning of the arena.
    pub offset: usize,
    /// Length of the slot in bytes (these are what `total_unshared_bytes` sums).
    pub size: usize,
}
38
39/// A view node that aliases part of a root buffer (no separate allocation).
40#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
41#[derive(Debug, Clone, Copy, PartialEq, Eq)]
42pub struct LirViewAlias {
43    pub root: NodeId,
44    pub byte_offset: usize,
45}
46
47/// Named graph boundaries — stable handles for runtime I/O wiring.
48#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
49#[derive(Debug, Clone, Default, PartialEq, Eq)]
50pub struct LirIoManifest {
51    pub inputs: Vec<(String, NodeId)>,
52    pub params: Vec<(String, NodeId)>,
53    pub outputs: Vec<NodeId>,
54}
55
56impl LirIoManifest {
57    pub fn collect(graph: &Graph) -> Self {
58        let mut inputs = Vec::new();
59        let mut params = Vec::new();
60        for node in graph.nodes() {
61            match &node.op {
62                Op::Input { name } => inputs.push((name.clone(), node.id)),
63                Op::Param { name } => params.push((name.clone(), node.id)),
64                _ => {}
65            }
66        }
67        Self {
68            inputs,
69            params,
70            outputs: graph.outputs.clone(),
71        }
72    }
73}
74
75/// Liveness-aware buffer assignment + execution metadata.
76#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
77#[derive(Debug, Clone, PartialEq, Eq)]
78pub struct LirBufferPlan {
79    pub arena_size: usize,
80    pub assignments: HashMap<NodeId, LirBufferSlot>,
81    /// Topological execution order (node ids).
82    pub schedule: Vec<NodeId>,
83    /// Pure-view nodes (`Reshape`, identity `Cast`, axis-0 `Narrow`).
84    pub view_aliases: HashMap<NodeId, LirViewAlias>,
85    /// Streaming inference phase per node.
86    pub phases: PhaseSchedule,
87    /// Input / param / output node manifest.
88    pub io: LirIoManifest,
89    /// Arena alignment used when planning (bytes).
90    pub alignment: usize,
91    /// Dynamic symbols referenced by the graph at plan time.
92    pub dynamic_symbols: Vec<u32>,
93}
94
95impl Default for LirBufferPlan {
96    fn default() -> Self {
97        Self {
98            arena_size: 0,
99            assignments: HashMap::new(),
100            schedule: Vec::new(),
101            view_aliases: HashMap::new(),
102            phases: PhaseSchedule::new(),
103            io: LirIoManifest::default(),
104            alignment: 64,
105            dynamic_symbols: Vec::new(),
106        }
107    }
108}
109
110impl LirBufferPlan {
111    pub fn total_unshared_bytes(&self) -> usize {
112        self.assignments.values().map(|s| s.size).sum()
113    }
114
115    pub fn bytes_saved(&self) -> usize {
116        self.total_unshared_bytes().saturating_sub(self.arena_size)
117    }
118
119    pub fn slot(&self, id: NodeId) -> Option<&LirBufferSlot> {
120        self.assignments.get(&id)
121    }
122
123    pub fn is_view(&self, id: NodeId) -> bool {
124        self.view_aliases.contains_key(&id)
125    }
126
127    pub fn phase_of(&self, id: NodeId) -> Option<Phase> {
128        self.phases.get(id)
129    }
130
131    pub fn nodes_in_phase(&self, phase: Phase) -> Vec<NodeId> {
132        self.phases.nodes_in_ordered(phase, Some(&self.schedule))
133    }
134}
135
/// Compile fingerprint of a `LirModule`, used as an AOT cache key.
///
/// The wrapped `u64` is the raw hash value.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct LirFingerprint(pub u64);
139
140impl LirFingerprint {
141    pub fn of(module: &LirModule) -> Self {
142        let mut h = DefaultHasher::new();
143        module.mir.name().hash(&mut h);
144        module.mir.len().hash(&mut h);
145        for node in module.mir.as_graph().nodes() {
146            node.id.0.hash(&mut h);
147            format!("{}", node.op).hash(&mut h);
148            node.shape.hash(&mut h);
149            node.inputs.len().hash(&mut h);
150            for inp in &node.inputs {
151                inp.0.hash(&mut h);
152            }
153        }
154        for out in module.mir.as_graph().outputs.iter() {
155            out.0.hash(&mut h);
156        }
157        module.buffers.arena_size.hash(&mut h);
158        module.buffers.schedule.len().hash(&mut h);
159        module.buffers.alignment.hash(&mut h);
160        module.buffers.view_aliases.len().hash(&mut h);
161        Self(h.finish())
162    }
163}
164
165/// Low-level module — backend compile input after optimization.
166#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
167#[derive(Debug, Clone, PartialEq)]
168pub struct LirModule {
169    pub mir: MirModule,
170    pub buffers: LirBufferPlan,
171}
172
173impl LirModule {
174    pub fn new(mir: MirModule, buffers: LirBufferPlan) -> Self {
175        Self { mir, buffers }
176    }
177
178    pub fn name(&self) -> &str {
179        self.mir.name()
180    }
181
182    pub fn arena_size(&self) -> usize {
183        self.buffers.arena_size
184    }
185
186    pub fn fingerprint(&self) -> LirFingerprint {
187        LirFingerprint::of(self)
188    }
189
190    pub fn plan(&self) -> &LirBufferPlan {
191        &self.buffers
192    }
193
194    /// Extract the optimized MIR graph for legacy backend entry points.
195    pub fn into_graph(self) -> Graph {
196        self.mir.into_graph()
197    }
198
199    pub fn as_graph(&self) -> &Graph {
200        self.mir.as_graph()
201    }
202
203    pub fn has_dynamic_dims(&self) -> bool {
204        crate::dynamic::has_dynamic_dims(self.as_graph())
205    }
206
207    pub fn is_fully_static(&self) -> bool {
208        !self.has_dynamic_dims()
209    }
210
211    pub fn dynamic_symbols(&self) -> &[u32] {
212        &self.buffers.dynamic_symbols
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DType, Shape};

    /// Shorthand for an `F32` shape with the given dims.
    fn f32_shape(dims: &[usize]) -> Shape {
        Shape::new(dims, DType::F32)
    }

    #[test]
    fn io_manifest_collects_boundaries() {
        let mut graph = Graph::new("m");
        let x = graph.input("x", f32_shape(&[4]));
        let w = graph.param("w", f32_shape(&[4, 4]));
        let y = graph.matmul(x, w, f32_shape(&[4, 4]));
        graph.set_outputs(vec![y]);

        let manifest = LirIoManifest::collect(&graph);
        let expected = LirIoManifest {
            inputs: vec![("x".to_string(), x)],
            params: vec![("w".to_string(), w)],
            outputs: vec![y],
        };
        assert_eq!(manifest, expected);
    }

    #[test]
    fn fingerprint_is_stable() {
        let mut graph = Graph::new("m");
        let x = graph.input("x", f32_shape(&[2]));
        graph.set_outputs(vec![x]);

        let mut assignments = HashMap::new();
        assignments.insert(x, LirBufferSlot { offset: 0, size: 8 });
        let plan = LirBufferPlan {
            arena_size: 8,
            assignments,
            schedule: vec![x],
            io: LirIoManifest {
                inputs: vec![("x".to_string(), x)],
                ..LirIoManifest::default()
            },
            ..LirBufferPlan::default()
        };
        let lir = LirModule::new(MirModule::from_graph(graph), plan);

        // Fingerprinting the same module twice must agree.
        let first = lir.fingerprint();
        let second = lir.fingerprint();
        assert_eq!(first, second);
    }
}