Skip to main content

rlx_flow/
stage.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Flow stages — typed block assembly primitives.
17
18use std::sync::Arc;
19
20use anyhow::Result;
21
22use crate::blocks::{
23    AttnMaskStage, BertEncoderLayerStage, BindDecodeInputsStage, BlockStage, ClsTokenPoolStage,
24    CustomStage, DecodeRopeParamsStage, EmbedScaleStage, EmbedStage, GatherAddStage,
25    GatherFromInputStage, GatherLastTokenStage, GdnScanStage, GeGluStage, GeluFfnStage,
26    GemmaDecodeLayerStage, GemmaKvTapStage, GemmaRmsNormStage, LayerNormStage, LayerScaleStage,
27    LinearStage, LlamaDecodeLayerStage, LlamaDecoderStage, LlamaKvTapStage, LmHeadStage,
28    LogitSoftcapStage, NomicEncoderLayerStage, Qwen3DecodeLayerStage, Qwen3DecoderStage,
29    RepeatStage, ResidualAddStage, ResidualSaveStage, RmsNormStage, RopeTablesStage,
30    SelfAttnPrefillStage, SwiGluStage, VisionSwiGluFfnStage, VitSelfAttnStage,
31};
32use crate::context::FlowCtx;
33use crate::stream::{DualStreamStage, LoadStreamStage, StoreStreamStage};
34use crate::value::FlowValue;
/// One stage in a model flow. Model authors compose these — not HIR ops.
///
/// Stages are emitted with an optional incoming flow value. Most variants
/// need that value and report an error when it is absent; a handful of setup
/// variants (`RopeTables`, `ZeroBeta`, `BindDecodeInputs`, `AttnMask`,
/// `DecodeRopeParams`) only touch the flow context and hand the incoming
/// value back unchanged.
#[derive(Debug, Clone)]
pub enum FlowStage {
    /// Token embedding lookup.
    Embed(EmbedStage),
    /// Precomputed RoPE sin/cos tables as params.
    ///
    /// The incoming value, if any, is passed through unchanged.
    RopeTables(RopeTablesStage),
    /// Ensure a rank-1 zero vector exists for RMSNorm beta slots.
    ///
    /// At emit time the zeros are registered in the flow state under `name`,
    /// and become the state's default zero-beta if none was set before. The
    /// incoming value, if any, is passed through unchanged.
    ZeroBeta { name: String, len: usize },
    /// Bind decode inputs (RoPE slice, past K/V, mask) into flow state.
    ///
    /// The incoming value, if any, is passed through unchanged.
    BindDecodeInputs(BindDecodeInputsStage),
    /// Bind or synthesize vision attention mask (all-ones).
    ///
    /// The incoming value, if any, is passed through unchanged.
    AttnMask(AttnMaskStage),
    /// KV-cache decode layer (concat past K/V, causal/custom attention).
    LlamaDecodeLayer(LlamaDecodeLayerStage),
    /// LLaMA-style fused prefill decoder layer (GQA + RoPE + SwiGLU).
    LlamaDecoder(LlamaDecoderStage),
    /// Side-output K/V projections for a decoder layer (prefill cache export).
    ///
    /// The main flow value is returned unchanged after the tap is emitted.
    LlamaKvTap(LlamaKvTapStage),
    /// Repeat an inner stage `count` times with a per-index name prefix.
    Repeat(RepeatStage),
    /// Named nested scope (fusion/debug labeling).
    ///
    /// Requires an incoming value, and `inner` must produce an output value;
    /// the HIR node of that output is given `name`.
    Named { name: String, inner: Arc<FlowStage> },
    /// Run stages in order; side-effect stages may leave the main tensor unchanged.
    ///
    /// Each stage receives the previous stage's result; an empty list returns
    /// the incoming value as-is.
    Sequence(Vec<FlowStage>),
    /// Final RMSNorm before LM head.
    RmsNorm(RmsNormStage),
    /// Gather last token along sequence axis (dynamic prefill).
    GatherLastToken(GatherLastTokenStage),
    /// Causal LM head matmul.
    LmHead(LmHeadStage),
    /// Matmul against a loaded weight (`LinearStage`).
    Linear(LinearStage),
    /// Save skip connection for residual add.
    ResidualSave(ResidualSaveStage),
    /// Add saved skip connection.
    ResidualAdd(ResidualAddStage),
    /// SwiGLU feed-forward (gate/up/down).
    SwiGlu(SwiGluStage),
    /// Prefill self-attention (QKV + RoPE + GQA + causal mask).
    SelfAttnPrefill(SelfAttnPrefillStage),
    /// Gated DeltaNet scan (inputs via [`FlowState::gdn`]).
    GdnScan(GdnScanStage),
    /// Store active flow into a named stream.
    StoreStream(StoreStreamStage),
    /// Load active flow from a named stream.
    LoadStream(LoadStreamStage),
    /// Dual-stream transform (img/txt, …).
    DualStream(DualStreamStage),
    /// Tier-2 custom subgraph (see `rlx_flow::escape`).
    Custom(CustomStage),
    /// BERT-style encoder layer (fused QKV + padding-mask attention + GELU FFN).
    BertEncoderLayer(BertEncoderLayerStage),
    /// NomicBERT encoder layer (fused QKV + RoPE + padding-mask + SwiGLU).
    NomicEncoderLayer(NomicEncoderLayerStage),
    /// Qwen3 decoder layer (QK-norm + GQA + RoPE + SwiGLU).
    Qwen3Decoder(Qwen3DecoderStage),
    /// Qwen3 KV-cache decode layer (concat past K/V + QK-norm + GQA).
    Qwen3DecodeLayer(Qwen3DecodeLayerStage),
    /// ViT fused QKV self-attention with padding mask.
    VitSelfAttn(VitSelfAttnStage),
    /// DINOv2 LayerScale (gamma multiply).
    LayerScale(LayerScaleStage),
    /// NomicVision SwiGLU FFN with intermediate LayerNorm.
    VisionSwiGluFfn(VisionSwiGluFfnStage),
    /// CLS token pooling `[B, seq, H]` → `[B, H]`.
    ClsTokenPool(ClsTokenPoolStage),
    /// LayerNorm with gamma + beta.
    LayerNorm(LayerNormStage),
    /// GELU feed-forward (intermediate + output dense).
    GeluFfn(GeluFfnStage),
    /// Gather embedding table rows from a named side input.
    GatherFromInput(GatherFromInputStage),
    /// Add gather-from-side-input embedding to active hidden tensor.
    GatherAdd(GatherAddStage),
    /// Gemma embedding scale (`sqrt(hidden)`).
    EmbedScale(EmbedScaleStage),
    /// Gemma RMSNorm (`1 + weight`).
    GemmaRmsNorm(GemmaRmsNormStage),
    /// Gemma 2 logit softcap.
    LogitSoftcap(LogitSoftcapStage),
    /// Static decode RoPE tables (params in flow state).
    ///
    /// The incoming value, if any, is passed through unchanged.
    DecodeRopeParams(DecodeRopeParamsStage),
    /// Gemma KV-cache decode layer.
    GemmaDecodeLayer(GemmaDecodeLayerStage),
    /// Gemma prefill K/V side export.
    GemmaKvTap(GemmaKvTapStage),
    /// Gemma GeGLU FFN.
    GeGlu(GeGluStage),
}
125
126impl FlowStage {
127    pub(crate) fn emit(
128        &self,
129        ctx: &mut FlowCtx<'_>,
130        input: Option<FlowValue>,
131    ) -> Result<Option<FlowValue>> {
132        match self {
133            FlowStage::Embed(s) => {
134                let input = input.ok_or_else(|| anyhow::anyhow!("Embed requires input"))?;
135                s.emit(ctx, input)
136            }
137            FlowStage::RopeTables(s) => {
138                s.emit(ctx)?;
139                Ok(input)
140            }
141            FlowStage::ZeroBeta { name, len } => {
142                let id = ctx.synth_zeros(name, *len);
143                ctx.state.named.insert(name.clone(), id);
144                if ctx.state.zero_beta.is_none() {
145                    ctx.state.zero_beta = Some(id);
146                }
147                Ok(input)
148            }
149            FlowStage::BindDecodeInputs(s) => {
150                s.emit(ctx)?;
151                Ok(input)
152            }
153            FlowStage::AttnMask(s) => {
154                s.emit(ctx)?;
155                Ok(input)
156            }
157            FlowStage::LlamaDecodeLayer(s) => {
158                let input =
159                    input.ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires input"))?;
160                s.emit(ctx, input)
161            }
162            FlowStage::LlamaDecoder(s) => {
163                let input = input.ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires input"))?;
164                s.emit(ctx, input)
165            }
166            FlowStage::LlamaKvTap(s) => {
167                let input = input.ok_or_else(|| anyhow::anyhow!("LlamaKvTap requires input"))?;
168                s.emit(ctx, input.clone())?;
169                Ok(Some(input))
170            }
171            FlowStage::Repeat(s) => s.emit(ctx, input),
172            FlowStage::Named { name, inner } => {
173                let input = input.ok_or_else(|| anyhow::anyhow!("Named block requires input"))?;
174                let out = inner.emit(ctx, Some(input))?;
175                let value = out.expect("named inner stage produced no output");
176                ctx.hir().node_mut(value.id).name = Some(name.clone());
177                Ok(Some(value))
178            }
179            FlowStage::Sequence(stages) => {
180                let mut value = input;
181                for stage in stages {
182                    value = stage.emit(ctx, value)?;
183                }
184                Ok(value)
185            }
186            FlowStage::RmsNorm(s) => {
187                let input = input.ok_or_else(|| anyhow::anyhow!("RmsNorm requires input"))?;
188                s.emit(ctx, input)
189            }
190            FlowStage::GatherLastToken(s) => {
191                let input =
192                    input.ok_or_else(|| anyhow::anyhow!("GatherLastToken requires input"))?;
193                s.emit(ctx, input)
194            }
195            FlowStage::LmHead(s) => {
196                let input = input.ok_or_else(|| anyhow::anyhow!("LmHead requires input"))?;
197                s.emit(ctx, input)
198            }
199            FlowStage::Linear(s) => {
200                let input = input.ok_or_else(|| anyhow::anyhow!("Linear requires input"))?;
201                s.emit(ctx, input)
202            }
203            FlowStage::ResidualSave(s) => {
204                let input = input.ok_or_else(|| anyhow::anyhow!("ResidualSave requires input"))?;
205                s.emit(ctx, input)
206            }
207            FlowStage::ResidualAdd(s) => {
208                let input = input.ok_or_else(|| anyhow::anyhow!("ResidualAdd requires input"))?;
209                s.emit(ctx, input)
210            }
211            FlowStage::SwiGlu(s) => {
212                let input = input.ok_or_else(|| anyhow::anyhow!("SwiGlu requires input"))?;
213                s.emit(ctx, input)
214            }
215            FlowStage::SelfAttnPrefill(s) => {
216                let input =
217                    input.ok_or_else(|| anyhow::anyhow!("SelfAttnPrefill requires input"))?;
218                s.emit(ctx, input)
219            }
220            FlowStage::GdnScan(s) => {
221                let input = input.ok_or_else(|| anyhow::anyhow!("GdnScan requires input"))?;
222                s.emit(ctx, input)
223            }
224            FlowStage::StoreStream(s) => s.emit(ctx, input),
225            FlowStage::LoadStream(s) => s.emit(ctx, input),
226            FlowStage::DualStream(s) => s.emit(ctx, input),
227            FlowStage::Custom(s) => s.emit(ctx, input),
228            FlowStage::BertEncoderLayer(s) => {
229                let input =
230                    input.ok_or_else(|| anyhow::anyhow!("BertEncoderLayer requires input"))?;
231                s.emit(ctx, input)
232            }
233            FlowStage::NomicEncoderLayer(s) => {
234                let input =
235                    input.ok_or_else(|| anyhow::anyhow!("NomicEncoderLayer requires input"))?;
236                s.emit(ctx, input)
237            }
238            FlowStage::Qwen3Decoder(s) => {
239                let input = input.ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires input"))?;
240                s.emit(ctx, input)
241            }
242            FlowStage::Qwen3DecodeLayer(s) => {
243                let input =
244                    input.ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires input"))?;
245                s.emit(ctx, input)
246            }
247            FlowStage::VitSelfAttn(s) => {
248                let input = input.ok_or_else(|| anyhow::anyhow!("VitSelfAttn requires input"))?;
249                s.emit(ctx, input)
250            }
251            FlowStage::LayerScale(s) => {
252                let input = input.ok_or_else(|| anyhow::anyhow!("LayerScale requires input"))?;
253                s.emit(ctx, input)
254            }
255            FlowStage::VisionSwiGluFfn(s) => {
256                let input =
257                    input.ok_or_else(|| anyhow::anyhow!("VisionSwiGluFfn requires input"))?;
258                s.emit(ctx, input)
259            }
260            FlowStage::ClsTokenPool(s) => {
261                let input = input.ok_or_else(|| anyhow::anyhow!("ClsTokenPool requires input"))?;
262                s.emit(ctx, input)
263            }
264            FlowStage::LayerNorm(s) => {
265                let input = input.ok_or_else(|| anyhow::anyhow!("LayerNorm requires input"))?;
266                s.emit(ctx, input)
267            }
268            FlowStage::GeluFfn(s) => {
269                let input = input.ok_or_else(|| anyhow::anyhow!("GeluFfn requires input"))?;
270                s.emit(ctx, input)
271            }
272            FlowStage::GatherFromInput(s) => {
273                let input =
274                    input.ok_or_else(|| anyhow::anyhow!("GatherFromInput requires input"))?;
275                s.emit(ctx, input)
276            }
277            FlowStage::GatherAdd(s) => {
278                let input = input.ok_or_else(|| anyhow::anyhow!("GatherAdd requires input"))?;
279                s.emit(ctx, input)
280            }
281            FlowStage::EmbedScale(s) => {
282                let input = input.ok_or_else(|| anyhow::anyhow!("EmbedScale requires input"))?;
283                s.emit(ctx, input)
284            }
285            FlowStage::GemmaRmsNorm(s) => {
286                let input = input.ok_or_else(|| anyhow::anyhow!("GemmaRmsNorm requires input"))?;
287                s.emit(ctx, input)
288            }
289            FlowStage::LogitSoftcap(s) => {
290                let input = input.ok_or_else(|| anyhow::anyhow!("LogitSoftcap requires input"))?;
291                s.emit(ctx, input)
292            }
293            FlowStage::DecodeRopeParams(s) => {
294                s.emit(ctx)?;
295                Ok(input)
296            }
297            FlowStage::GemmaDecodeLayer(s) => {
298                let input =
299                    input.ok_or_else(|| anyhow::anyhow!("GemmaDecodeLayer requires input"))?;
300                s.emit(ctx, input)
301            }
302            FlowStage::GemmaKvTap(s) => {
303                let input = input.ok_or_else(|| anyhow::anyhow!("GemmaKvTap requires input"))?;
304                s.emit(ctx, input)
305            }
306            FlowStage::GeGlu(s) => {
307                let input = input.ok_or_else(|| anyhow::anyhow!("GeGlu requires input"))?;
308                s.emit(ctx, input)
309            }
310        }
311    }
312}