Skip to main content

rlx_ir/
phase.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Streaming inference phases attached to LIR nodes (plan #16 / #28).
17
18use std::collections::HashMap;
19
20use crate::{Graph, NodeId, Op};
21
/// Stage of a streaming forward pass that a node is assigned to.
///
/// Variants are listed in the same order as the ranks returned by
/// [`Phase::order`].
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Phase {
    /// Runs once up front: embedding lookup, KV init, prompt prefill.
    Prologue,
    /// Body of the per-token decode loop.
    SteadyState,
    /// Tail work: final projection, sampling, detokenization.
    Epilogue,
}
33
34impl Phase {
35    pub fn order(self) -> u8 {
36        match self {
37            Self::Prologue => 0,
38            Self::SteadyState => 1,
39            Self::Epilogue => 2,
40        }
41    }
42}
43
44/// Per-node phase assignment for a streaming graph.
45#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
46#[derive(Debug, Clone, Default, PartialEq, Eq)]
47pub struct PhaseSchedule {
48    map: HashMap<NodeId, Phase>,
49}
50
51impl PhaseSchedule {
52    pub fn new() -> Self {
53        Self::default()
54    }
55
56    pub fn set(&mut self, node: NodeId, phase: Phase) {
57        self.map.insert(node, phase);
58    }
59
60    pub fn get(&self, node: NodeId) -> Option<Phase> {
61        self.map.get(&node).copied()
62    }
63
64    pub fn iter(&self) -> impl Iterator<Item = (NodeId, Phase)> + '_ {
65        self.map.iter().map(|(&id, &p)| (id, p))
66    }
67
68    pub fn len(&self) -> usize {
69        self.map.len()
70    }
71
72    pub fn is_empty(&self) -> bool {
73        self.map.is_empty()
74    }
75
76    /// Nodes in a given phase, in schedule order when `schedule` is provided,
77    /// otherwise sorted by [`NodeId`].
78    pub fn nodes_in(&self, phase: Phase) -> Vec<NodeId> {
79        self.nodes_in_ordered(phase, None)
80    }
81
82    pub fn nodes_in_ordered(&self, phase: Phase, schedule: Option<&[NodeId]>) -> Vec<NodeId> {
83        if let Some(order) = schedule {
84            return order
85                .iter()
86                .copied()
87                .filter(|id| self.get(*id) == Some(phase))
88                .collect();
89        }
90        let mut v: Vec<NodeId> = self
91            .map
92            .iter()
93            .filter_map(|(&id, &p)| if p == phase { Some(id) } else { None })
94            .collect();
95        v.sort();
96        v
97    }
98}
99
100/// Heuristic phase classifier for optimized MIR graphs.
101pub fn derive_phases(graph: &Graph) -> PhaseSchedule {
102    let mut sched = PhaseSchedule::new();
103    let n = graph.len();
104    if n == 0 {
105        return sched;
106    }
107
108    let mut last_compute_step: Option<usize> = None;
109    let mut last_sample_step: Option<usize> = None;
110    for (step, node) in graph.nodes().iter().enumerate() {
111        match &node.op {
112            Op::Sample { .. } | Op::TopK { .. } => {
113                last_sample_step = Some(step);
114            }
115            Op::MatMul
116            | Op::FusedMatMulBiasAct { .. }
117            | Op::Attention { .. }
118            | Op::FusedAttentionBlock { .. }
119            | Op::FusedTransformerLayer { .. }
120            | Op::DotGeneral { .. }
121            | Op::GroupedMatMul
122            | Op::DequantGroupedMatMul { .. }
123            | Op::DequantMoEWeights { .. }
124            | Op::LoraMatMul { .. }
125            | Op::DequantMatMul { .. }
126            | Op::GatedDeltaNet { .. } => {
127                last_compute_step = Some(step);
128            }
129            _ => {}
130        }
131    }
132
133    for (step, node) in graph.nodes().iter().enumerate() {
134        let phase = match &node.op {
135            Op::Input { .. } | Op::Param { .. } | Op::Constant { .. } => Phase::Prologue,
136            Op::Sample { .. } | Op::TopK { .. } => Phase::Epilogue,
137            _ => {
138                if let Some(last) = last_sample_step {
139                    if step > last
140                        || (last_compute_step.is_some() && Some(step) > last_compute_step)
141                    {
142                        Phase::Epilogue
143                    } else {
144                        Phase::SteadyState
145                    }
146                } else if let Some(last) = last_compute_step {
147                    if step > last {
148                        Phase::Epilogue
149                    } else {
150                        Phase::SteadyState
151                    }
152                } else {
153                    Phase::SteadyState
154                }
155            }
156        };
157        sched.set(node.id, phase);
158    }
159    sched
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DType, Shape};

    /// Input and parameter land in the prologue, the matmul feeding the
    /// sampler in the steady state, and the sampler itself in the epilogue.
    #[test]
    fn derive_phases_classifies_typical_graph() {
        let dtype = DType::F32;
        let mut graph = Graph::new("derive");
        let input = graph.input("x", Shape::new(&[1, 8], dtype));
        let weight = graph.param("w", Shape::new(&[8, 4], dtype));
        let product = graph.matmul(input, weight, Shape::new(&[1, 4], dtype));
        let sampled = graph.sample(product, 0, 1.0, 1.0, 0, Shape::new(&[1], dtype));
        graph.set_outputs(vec![sampled]);

        let phases = derive_phases(&graph);

        let expected = [
            ("x", input, Phase::Prologue),
            ("w", weight, Phase::Prologue),
            ("mm", product, Phase::SteadyState),
            ("s", sampled, Phase::Epilogue),
        ];
        for &(label, node, phase) in &expected {
            assert_eq!(
                phases.get(node),
                Some(phase),
                "unexpected phase for node `{}`",
                label
            );
        }
    }

    /// Ranks strictly increase from prologue through steady state to epilogue.
    #[test]
    fn phase_ordering_is_deterministic() {
        let prologue = Phase::Prologue.order();
        let steady = Phase::SteadyState.order();
        let epilogue = Phase::Epilogue.order();

        assert!(prologue < steady);
        assert!(steady < epilogue);
    }
}