Skip to main content

rlx_ir/
phase.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Streaming inference phases attached to LIR nodes (plan #16 / #28).
17
18use std::collections::HashMap;
19
20use crate::{Graph, NodeId, Op};
21
/// Stage of a streaming forward pass that a graph node is assigned to.
///
/// Variants are listed in the order the stages run.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Phase {
    /// Setup that happens once: embedding lookup, KV init, prompt prefill.
    Prologue,
    /// Body of the decode loop, executed for every token.
    SteadyState,
    /// Closing work: final projection, sampling, detokenization.
    Epilogue,
}
33
34impl Phase {
35    pub fn order(self) -> u8 {
36        match self {
37            Self::Prologue => 0,
38            Self::SteadyState => 1,
39            Self::Epilogue => 2,
40        }
41    }
42}
43
44/// Per-node phase assignment for a streaming graph.
45#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
46#[derive(Debug, Clone, Default, PartialEq, Eq)]
47pub struct PhaseSchedule {
48    map: HashMap<NodeId, Phase>,
49}
50
51impl PhaseSchedule {
52    pub fn new() -> Self {
53        Self::default()
54    }
55
56    pub fn set(&mut self, node: NodeId, phase: Phase) {
57        self.map.insert(node, phase);
58    }
59
60    pub fn get(&self, node: NodeId) -> Option<Phase> {
61        self.map.get(&node).copied()
62    }
63
64    pub fn iter(&self) -> impl Iterator<Item = (NodeId, Phase)> + '_ {
65        self.map.iter().map(|(&id, &p)| (id, p))
66    }
67
68    pub fn len(&self) -> usize {
69        self.map.len()
70    }
71
72    pub fn is_empty(&self) -> bool {
73        self.map.is_empty()
74    }
75
76    /// Nodes in a given phase, in schedule order when `schedule` is provided,
77    /// otherwise sorted by [`NodeId`].
78    pub fn nodes_in(&self, phase: Phase) -> Vec<NodeId> {
79        self.nodes_in_ordered(phase, None)
80    }
81
82    pub fn nodes_in_ordered(&self, phase: Phase, schedule: Option<&[NodeId]>) -> Vec<NodeId> {
83        if let Some(order) = schedule {
84            return order
85                .iter()
86                .copied()
87                .filter(|id| self.get(*id) == Some(phase))
88                .collect();
89        }
90        let mut v: Vec<NodeId> = self
91            .map
92            .iter()
93            .filter_map(|(&id, &p)| if p == phase { Some(id) } else { None })
94            .collect();
95        v.sort();
96        v
97    }
98}
99
100/// Heuristic phase classifier for optimized MIR graphs.
101pub fn derive_phases(graph: &Graph) -> PhaseSchedule {
102    let mut sched = PhaseSchedule::new();
103    let n = graph.len();
104    if n == 0 {
105        return sched;
106    }
107
108    let mut last_compute_step: Option<usize> = None;
109    let mut last_sample_step: Option<usize> = None;
110    for (step, node) in graph.nodes().iter().enumerate() {
111        match &node.op {
112            Op::Sample { .. } | Op::TopK { .. } | Op::RngNormal { .. } | Op::RngUniform { .. } => {
113                last_sample_step = Some(step);
114            }
115            Op::MatMul
116            | Op::FusedMatMulBiasAct { .. }
117            | Op::Attention { .. }
118            | Op::FusedAttentionBlock { .. }
119            | Op::FusedTransformerLayer { .. }
120            | Op::DotGeneral { .. }
121            | Op::GroupedMatMul
122            | Op::DequantGroupedMatMul { .. }
123            | Op::DequantMoEWeights { .. }
124            | Op::LoraMatMul { .. }
125            | Op::DequantMatMul { .. }
126            | Op::GatedDeltaNet { .. }
127            | Op::Lstm { .. }
128            | Op::Gru { .. }
129            | Op::Rnn { .. }
130            | Op::Mamba2 { .. } => {
131                last_compute_step = Some(step);
132            }
133            _ => {}
134        }
135    }
136
137    for (step, node) in graph.nodes().iter().enumerate() {
138        let phase = match &node.op {
139            Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => Phase::Prologue,
140            Op::Sample { .. } | Op::TopK { .. } | Op::RngNormal { .. } | Op::RngUniform { .. } => {
141                Phase::Epilogue
142            }
143            _ => {
144                if let Some(last) = last_sample_step {
145                    if step > last
146                        || (last_compute_step.is_some() && Some(step) > last_compute_step)
147                    {
148                        Phase::Epilogue
149                    } else {
150                        Phase::SteadyState
151                    }
152                } else if let Some(last) = last_compute_step {
153                    if step > last {
154                        Phase::Epilogue
155                    } else {
156                        Phase::SteadyState
157                    }
158                } else {
159                    Phase::SteadyState
160                }
161            }
162        };
163        sched.set(node.id, phase);
164    }
165    sched
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DType, Shape};

    /// Builds `input x param -> matmul -> sample` and checks the phase
    /// derived for each of the four nodes.
    #[test]
    fn derive_phases_classifies_typical_graph() {
        let dtype = DType::F32;
        let mut graph = Graph::new("derive");
        let x = graph.input("x", Shape::new(&[1, 8], dtype));
        let w = graph.param("w", Shape::new(&[8, 4], dtype));
        let mm = graph.matmul(x, w, Shape::new(&[1, 4], dtype));
        let s = graph.sample(mm, 0, 1.0, 1.0, 0, Shape::new(&[1], dtype));
        graph.set_outputs(vec![s]);

        let sched = derive_phases(&graph);
        let expected = [
            (x, Phase::Prologue),
            (w, Phase::Prologue),
            (mm, Phase::SteadyState),
            (s, Phase::Epilogue),
        ];
        for (node, phase) in expected {
            assert_eq!(sched.get(node), Some(phase));
        }
    }

    /// `Phase::order` must rank the phases strictly in execution order.
    #[test]
    fn phase_ordering_is_deterministic() {
        let ranks = [Phase::Prologue, Phase::SteadyState, Phase::Epilogue].map(Phase::order);
        assert!(ranks.windows(2).all(|pair| pair[0] < pair[1]));
    }
}