Skip to main content

rlx_flow/
stage.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Flow stages — typed block assembly primitives.
17
18use std::sync::Arc;
19
20use anyhow::Result;
21
22use crate::blocks::{
23    AttnMaskStage, BertEncoderLayerStage, BindDecodeInputsStage, BlockStage, ClsTokenPoolStage,
24    CustomStage, DecodeRopeParamsStage, EmbedScaleStage, EmbedStage, GatherAddStage,
25    GatherDecodeRopeStage, GatherFromInputStage, GatherLastTokenStage, GdnScanStage, GeGluStage,
26    GeluFfnStage, GemmaDecodeLayerStage, GemmaKvTapStage, GemmaRmsNormStage, LayerNormStage,
27    LayerScaleStage, LinearStage, LlamaDecodeLayerStage, LlamaDecoderStage, LlamaKvTapStage,
28    LmHeadStage, LogitSoftcapStage, NomicEncoderLayerStage, Qwen3DecodeLayerStage,
29    Qwen3DecoderStage, RepeatStage, ResidualAddStage, ResidualSaveStage, RmsNormStage,
30    RopeTablesStage, SelfAttnPrefillStage, SwiGluStage, VisionSwiGluFfnStage, VitSelfAttnStage,
31};
32use crate::context::FlowCtx;
33use crate::stream::{DualStreamStage, LoadStreamStage, StoreStreamStage};
34use crate::value::FlowValue;
35/// One stage in a model flow. Model authors compose these — not HIR ops.
36#[derive(Debug, Clone)]
37pub enum FlowStage {
38    /// Token embedding lookup.
39    Embed(EmbedStage),
40    /// Precomputed RoPE sin/cos tables as params.
41    RopeTables(RopeTablesStage),
42    /// Ensure a rank-1 zero vector exists for RMSNorm beta slots.
43    ZeroBeta { name: String, len: usize },
44    /// Bind decode inputs (RoPE slice, past K/V, mask) into flow state.
45    BindDecodeInputs(BindDecodeInputsStage),
46    /// Bind or synthesize vision attention mask (all-ones).
47    AttnMask(AttnMaskStage),
48    /// KV-cache decode layer (concat past K/V, causal/custom attention).
49    LlamaDecodeLayer(LlamaDecodeLayerStage),
50    /// LLaMA-style fused prefill decoder layer (GQA + RoPE + SwiGLU).
51    LlamaDecoder(LlamaDecoderStage),
52    /// Side-output K/V projections for a decoder layer (prefill cache export).
53    LlamaKvTap(LlamaKvTapStage),
54    /// Repeat an inner stage `count` times with a per-index name prefix.
55    Repeat(RepeatStage),
56    /// Named nested scope (fusion/debug labeling).
57    Named { name: String, inner: Arc<FlowStage> },
58    /// Run stages in order; side-effect stages may leave the main tensor unchanged.
59    Sequence(Vec<FlowStage>),
60    /// Final RMSNorm before LM head.
61    RmsNorm(RmsNormStage),
62    /// Gather last token along sequence axis (dynamic prefill).
63    GatherLastToken(GatherLastTokenStage),
64    /// Causal LM head matmul.
65    LmHead(LmHeadStage),
66    /// Matmul against a loaded weight (`LinearStage`).
67    Linear(LinearStage),
68    /// Save skip connection for residual add.
69    ResidualSave(ResidualSaveStage),
70    /// Add saved skip connection.
71    ResidualAdd(ResidualAddStage),
72    /// SwiGLU feed-forward (gate/up/down).
73    SwiGlu(SwiGluStage),
74    /// Prefill self-attention (QKV + RoPE + GQA + causal mask).
75    SelfAttnPrefill(SelfAttnPrefillStage),
76    /// Gated DeltaNet scan (inputs via [`FlowState::gdn`]).
77    GdnScan(GdnScanStage),
78    /// Store active flow into a named stream.
79    StoreStream(StoreStreamStage),
80    /// Load active flow from a named stream.
81    LoadStream(LoadStreamStage),
82    /// Dual-stream transform (img/txt, …).
83    DualStream(DualStreamStage),
84    /// Tier-2 custom subgraph (see `rlx_flow::escape`).
85    Custom(CustomStage),
86    /// BERT-style encoder layer (fused QKV + padding-mask attention + GELU FFN).
87    BertEncoderLayer(BertEncoderLayerStage),
88    /// NomicBERT encoder layer (fused QKV + RoPE + padding-mask + SwiGLU).
89    NomicEncoderLayer(NomicEncoderLayerStage),
90    /// Qwen3 decoder layer (QK-norm + GQA + RoPE + SwiGLU).
91    Qwen3Decoder(Qwen3DecoderStage),
92    /// Qwen3 KV-cache decode layer (concat past K/V + QK-norm + GQA).
93    Qwen3DecodeLayer(Qwen3DecodeLayerStage),
94    /// ViT fused QKV self-attention with padding mask.
95    VitSelfAttn(VitSelfAttnStage),
96    /// DINOv2 LayerScale (gamma multiply).
97    LayerScale(LayerScaleStage),
98    /// NomicVision SwiGLU FFN with intermediate LayerNorm.
99    VisionSwiGluFfn(VisionSwiGluFfnStage),
100    /// CLS token pooling `[B, seq, H]` → `[B, H]`.
101    ClsTokenPool(ClsTokenPoolStage),
102    /// LayerNorm with gamma + beta.
103    LayerNorm(LayerNormStage),
104    /// GELU feed-forward (intermediate + output dense).
105    GeluFfn(GeluFfnStage),
106    /// Gather embedding table rows from a named side input.
107    GatherFromInput(GatherFromInputStage),
108    /// Add gather-from-side-input embedding to active hidden tensor.
109    GatherAdd(GatherAddStage),
110    /// Gemma embedding scale (`sqrt(hidden)`).
111    EmbedScale(EmbedScaleStage),
112    /// Gemma RMSNorm (`1 + weight`).
113    GemmaRmsNorm(GemmaRmsNormStage),
114    /// Gemma 2 logit softcap.
115    LogitSoftcap(LogitSoftcapStage),
116    /// Static decode RoPE tables (params in flow state).
117    DecodeRopeParams(DecodeRopeParamsStage),
118    /// Gather decode RoPE row from prefill-style param tables.
119    GatherDecodeRope(GatherDecodeRopeStage),
120    /// Gemma KV-cache decode layer.
121    GemmaDecodeLayer(GemmaDecodeLayerStage),
122    /// Gemma prefill K/V side export.
123    GemmaKvTap(GemmaKvTapStage),
124    /// Gemma GeGLU FFN.
125    GeGlu(GeGluStage),
126}
127
128impl FlowStage {
129    pub(crate) fn emit(
130        &self,
131        ctx: &mut FlowCtx<'_>,
132        input: Option<FlowValue>,
133    ) -> Result<Option<FlowValue>> {
134        match self {
135            FlowStage::Embed(s) => {
136                let input = input.ok_or_else(|| anyhow::anyhow!("Embed requires input"))?;
137                s.emit(ctx, input)
138            }
139            FlowStage::RopeTables(s) => {
140                s.emit(ctx)?;
141                Ok(input)
142            }
143            FlowStage::ZeroBeta { name, len } => {
144                let id = ctx.synth_zeros(name, *len);
145                ctx.state.named.insert(name.clone(), id);
146                if ctx.state.zero_beta.is_none() {
147                    ctx.state.zero_beta = Some(id);
148                }
149                Ok(input)
150            }
151            FlowStage::BindDecodeInputs(s) => {
152                s.emit(ctx)?;
153                Ok(input)
154            }
155            FlowStage::AttnMask(s) => {
156                s.emit(ctx)?;
157                Ok(input)
158            }
159            FlowStage::LlamaDecodeLayer(s) => {
160                let input =
161                    input.ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires input"))?;
162                s.emit(ctx, input)
163            }
164            FlowStage::LlamaDecoder(s) => {
165                let input = input.ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires input"))?;
166                s.emit(ctx, input)
167            }
168            FlowStage::LlamaKvTap(s) => {
169                let input = input.ok_or_else(|| anyhow::anyhow!("LlamaKvTap requires input"))?;
170                s.emit(ctx, input.clone())?;
171                Ok(Some(input))
172            }
173            FlowStage::Repeat(s) => s.emit(ctx, input),
174            FlowStage::Named { name, inner } => {
175                let input = input.ok_or_else(|| anyhow::anyhow!("Named block requires input"))?;
176                let out = inner.emit(ctx, Some(input))?;
177                let value = out.expect("named inner stage produced no output");
178                ctx.hir().node_mut(value.id).name = Some(name.clone());
179                Ok(Some(value))
180            }
181            FlowStage::Sequence(stages) => {
182                let mut value = input;
183                for stage in stages {
184                    value = stage.emit(ctx, value)?;
185                }
186                Ok(value)
187            }
188            FlowStage::RmsNorm(s) => {
189                let input = input.ok_or_else(|| anyhow::anyhow!("RmsNorm requires input"))?;
190                s.emit(ctx, input)
191            }
192            FlowStage::GatherLastToken(s) => {
193                let input =
194                    input.ok_or_else(|| anyhow::anyhow!("GatherLastToken requires input"))?;
195                s.emit(ctx, input)
196            }
197            FlowStage::LmHead(s) => {
198                let input = input.ok_or_else(|| anyhow::anyhow!("LmHead requires input"))?;
199                s.emit(ctx, input)
200            }
201            FlowStage::Linear(s) => {
202                let input = input.ok_or_else(|| anyhow::anyhow!("Linear requires input"))?;
203                s.emit(ctx, input)
204            }
205            FlowStage::ResidualSave(s) => {
206                let input = input.ok_or_else(|| anyhow::anyhow!("ResidualSave requires input"))?;
207                s.emit(ctx, input)
208            }
209            FlowStage::ResidualAdd(s) => {
210                let input = input.ok_or_else(|| anyhow::anyhow!("ResidualAdd requires input"))?;
211                s.emit(ctx, input)
212            }
213            FlowStage::SwiGlu(s) => {
214                let input = input.ok_or_else(|| anyhow::anyhow!("SwiGlu requires input"))?;
215                s.emit(ctx, input)
216            }
217            FlowStage::SelfAttnPrefill(s) => {
218                let input =
219                    input.ok_or_else(|| anyhow::anyhow!("SelfAttnPrefill requires input"))?;
220                s.emit(ctx, input)
221            }
222            FlowStage::GdnScan(s) => {
223                let input = input.ok_or_else(|| anyhow::anyhow!("GdnScan requires input"))?;
224                s.emit(ctx, input)
225            }
226            FlowStage::StoreStream(s) => s.emit(ctx, input),
227            FlowStage::LoadStream(s) => s.emit(ctx, input),
228            FlowStage::DualStream(s) => s.emit(ctx, input),
229            FlowStage::Custom(s) => s.emit(ctx, input),
230            FlowStage::BertEncoderLayer(s) => {
231                let input =
232                    input.ok_or_else(|| anyhow::anyhow!("BertEncoderLayer requires input"))?;
233                s.emit(ctx, input)
234            }
235            FlowStage::NomicEncoderLayer(s) => {
236                let input =
237                    input.ok_or_else(|| anyhow::anyhow!("NomicEncoderLayer requires input"))?;
238                s.emit(ctx, input)
239            }
240            FlowStage::Qwen3Decoder(s) => {
241                let input = input.ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires input"))?;
242                s.emit(ctx, input)
243            }
244            FlowStage::Qwen3DecodeLayer(s) => {
245                let input =
246                    input.ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires input"))?;
247                s.emit(ctx, input)
248            }
249            FlowStage::VitSelfAttn(s) => {
250                let input = input.ok_or_else(|| anyhow::anyhow!("VitSelfAttn requires input"))?;
251                s.emit(ctx, input)
252            }
253            FlowStage::LayerScale(s) => {
254                let input = input.ok_or_else(|| anyhow::anyhow!("LayerScale requires input"))?;
255                s.emit(ctx, input)
256            }
257            FlowStage::VisionSwiGluFfn(s) => {
258                let input =
259                    input.ok_or_else(|| anyhow::anyhow!("VisionSwiGluFfn requires input"))?;
260                s.emit(ctx, input)
261            }
262            FlowStage::ClsTokenPool(s) => {
263                let input = input.ok_or_else(|| anyhow::anyhow!("ClsTokenPool requires input"))?;
264                s.emit(ctx, input)
265            }
266            FlowStage::LayerNorm(s) => {
267                let input = input.ok_or_else(|| anyhow::anyhow!("LayerNorm requires input"))?;
268                s.emit(ctx, input)
269            }
270            FlowStage::GeluFfn(s) => {
271                let input = input.ok_or_else(|| anyhow::anyhow!("GeluFfn requires input"))?;
272                s.emit(ctx, input)
273            }
274            FlowStage::GatherFromInput(s) => {
275                let input =
276                    input.ok_or_else(|| anyhow::anyhow!("GatherFromInput requires input"))?;
277                s.emit(ctx, input)
278            }
279            FlowStage::GatherAdd(s) => {
280                let input = input.ok_or_else(|| anyhow::anyhow!("GatherAdd requires input"))?;
281                s.emit(ctx, input)
282            }
283            FlowStage::EmbedScale(s) => {
284                let input = input.ok_or_else(|| anyhow::anyhow!("EmbedScale requires input"))?;
285                s.emit(ctx, input)
286            }
287            FlowStage::GemmaRmsNorm(s) => {
288                let input = input.ok_or_else(|| anyhow::anyhow!("GemmaRmsNorm requires input"))?;
289                s.emit(ctx, input)
290            }
291            FlowStage::LogitSoftcap(s) => {
292                let input = input.ok_or_else(|| anyhow::anyhow!("LogitSoftcap requires input"))?;
293                s.emit(ctx, input)
294            }
295            FlowStage::DecodeRopeParams(s) => {
296                s.emit(ctx)?;
297                Ok(input)
298            }
299            FlowStage::GatherDecodeRope(s) => {
300                s.emit(ctx)?;
301                Ok(input)
302            }
303            FlowStage::GemmaDecodeLayer(s) => {
304                let input =
305                    input.ok_or_else(|| anyhow::anyhow!("GemmaDecodeLayer requires input"))?;
306                s.emit(ctx, input)
307            }
308            FlowStage::GemmaKvTap(s) => {
309                let input = input.ok_or_else(|| anyhow::anyhow!("GemmaKvTap requires input"))?;
310                s.emit(ctx, input)
311            }
312            FlowStage::GeGlu(s) => {
313                let input = input.ok_or_else(|| anyhow::anyhow!("GeGlu requires input"))?;
314                s.emit(ctx, input)
315            }
316        }
317    }
318}