Skip to main content

rlx_flow/
stage.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4//! Flow stages — typed block assembly primitives.
5
6use std::sync::Arc;
7
8use anyhow::Result;
9
10use crate::blocks::{
11    AttnMaskStage, BertEncoderLayerStage, BindDecodeInputsStage, BlockStage, ClsTokenPoolStage,
12    CustomStage, DecodeRopeParamsStage, EmbedScaleStage, EmbedStage, GatherAddStage,
13    GatherFromInputStage, GatherLastTokenStage, GdnScanStage, GeGluStage, GeluFfnStage,
14    GemmaDecodeLayerStage, GemmaKvTapStage, GemmaRmsNormStage, LayerNormStage, LayerScaleStage,
15    LinearStage, LlamaDecodeLayerStage, LlamaDecoderStage, LlamaKvTapStage, LmHeadStage,
16    LogitSoftcapStage, NomicEncoderLayerStage, Qwen3DecodeLayerStage, Qwen3DecoderStage,
17    RepeatStage, ResidualAddStage, ResidualSaveStage, RmsNormStage, RopeTablesStage,
18    SelfAttnPrefillStage, SwiGluStage, VisionSwiGluFfnStage, VitSelfAttnStage,
19};
20use crate::context::FlowCtx;
21use crate::stream::{DualStreamStage, LoadStreamStage, StoreStreamStage};
22use crate::value::FlowValue;
23/// One stage in a model flow. Model authors compose these — not HIR ops.
24#[derive(Debug, Clone)]
25pub enum FlowStage {
26    /// Token embedding lookup.
27    Embed(EmbedStage),
28    /// Precomputed RoPE sin/cos tables as params.
29    RopeTables(RopeTablesStage),
30    /// Ensure a rank-1 zero vector exists for RMSNorm beta slots.
31    ZeroBeta { name: String, len: usize },
32    /// Bind decode inputs (RoPE slice, past K/V, mask) into flow state.
33    BindDecodeInputs(BindDecodeInputsStage),
34    /// Bind or synthesize vision attention mask (all-ones).
35    AttnMask(AttnMaskStage),
36    /// KV-cache decode layer (concat past K/V, causal/custom attention).
37    LlamaDecodeLayer(LlamaDecodeLayerStage),
38    /// LLaMA-style fused prefill decoder layer (GQA + RoPE + SwiGLU).
39    LlamaDecoder(LlamaDecoderStage),
40    /// Side-output K/V projections for a decoder layer (prefill cache export).
41    LlamaKvTap(LlamaKvTapStage),
42    /// Repeat an inner stage `count` times with a per-index name prefix.
43    Repeat(RepeatStage),
44    /// Named nested scope (fusion/debug labeling).
45    Named { name: String, inner: Arc<FlowStage> },
46    /// Run stages in order; side-effect stages may leave the main tensor unchanged.
47    Sequence(Vec<FlowStage>),
48    /// Final RMSNorm before LM head.
49    RmsNorm(RmsNormStage),
50    /// Gather last token along sequence axis (dynamic prefill).
51    GatherLastToken(GatherLastTokenStage),
52    /// Causal LM head matmul.
53    LmHead(LmHeadStage),
54    /// Matmul against a loaded weight (`LinearStage`).
55    Linear(LinearStage),
56    /// Save skip connection for residual add.
57    ResidualSave(ResidualSaveStage),
58    /// Add saved skip connection.
59    ResidualAdd(ResidualAddStage),
60    /// SwiGLU feed-forward (gate/up/down).
61    SwiGlu(SwiGluStage),
62    /// Prefill self-attention (QKV + RoPE + GQA + causal mask).
63    SelfAttnPrefill(SelfAttnPrefillStage),
64    /// Gated DeltaNet scan (inputs via [`FlowState::gdn`]).
65    GdnScan(GdnScanStage),
66    /// Store active flow into a named stream.
67    StoreStream(StoreStreamStage),
68    /// Load active flow from a named stream.
69    LoadStream(LoadStreamStage),
70    /// Dual-stream transform (img/txt, …).
71    DualStream(DualStreamStage),
72    /// Tier-2 custom subgraph (see `rlx_flow::escape`).
73    Custom(CustomStage),
74    /// BERT-style encoder layer (fused QKV + padding-mask attention + GELU FFN).
75    BertEncoderLayer(BertEncoderLayerStage),
76    /// NomicBERT encoder layer (fused QKV + RoPE + padding-mask + SwiGLU).
77    NomicEncoderLayer(NomicEncoderLayerStage),
78    /// Qwen3 decoder layer (QK-norm + GQA + RoPE + SwiGLU).
79    Qwen3Decoder(Qwen3DecoderStage),
80    /// Qwen3 KV-cache decode layer (concat past K/V + QK-norm + GQA).
81    Qwen3DecodeLayer(Qwen3DecodeLayerStage),
82    /// ViT fused QKV self-attention with padding mask.
83    VitSelfAttn(VitSelfAttnStage),
84    /// DINOv2 LayerScale (gamma multiply).
85    LayerScale(LayerScaleStage),
86    /// NomicVision SwiGLU FFN with intermediate LayerNorm.
87    VisionSwiGluFfn(VisionSwiGluFfnStage),
88    /// CLS token pooling `[B, seq, H]` → `[B, H]`.
89    ClsTokenPool(ClsTokenPoolStage),
90    /// LayerNorm with gamma + beta.
91    LayerNorm(LayerNormStage),
92    /// GELU feed-forward (intermediate + output dense).
93    GeluFfn(GeluFfnStage),
94    /// Gather embedding table rows from a named side input.
95    GatherFromInput(GatherFromInputStage),
96    /// Add gather-from-side-input embedding to active hidden tensor.
97    GatherAdd(GatherAddStage),
98    /// Gemma embedding scale (`sqrt(hidden)`).
99    EmbedScale(EmbedScaleStage),
100    /// Gemma RMSNorm (`1 + weight`).
101    GemmaRmsNorm(GemmaRmsNormStage),
102    /// Gemma 2 logit softcap.
103    LogitSoftcap(LogitSoftcapStage),
104    /// Static decode RoPE tables (params in flow state).
105    DecodeRopeParams(DecodeRopeParamsStage),
106    /// Gemma KV-cache decode layer.
107    GemmaDecodeLayer(GemmaDecodeLayerStage),
108    /// Gemma prefill K/V side export.
109    GemmaKvTap(GemmaKvTapStage),
110    /// Gemma GeGLU FFN.
111    GeGlu(GeGluStage),
112}
113
114impl FlowStage {
115    pub(crate) fn emit(
116        &self,
117        ctx: &mut FlowCtx<'_>,
118        input: Option<FlowValue>,
119    ) -> Result<Option<FlowValue>> {
120        match self {
121            FlowStage::Embed(s) => {
122                let input = input.ok_or_else(|| anyhow::anyhow!("Embed requires input"))?;
123                s.emit(ctx, input)
124            }
125            FlowStage::RopeTables(s) => {
126                s.emit(ctx)?;
127                Ok(input)
128            }
129            FlowStage::ZeroBeta { name, len } => {
130                let id = ctx.synth_zeros(name, *len);
131                ctx.state.named.insert(name.clone(), id);
132                if ctx.state.zero_beta.is_none() {
133                    ctx.state.zero_beta = Some(id);
134                }
135                Ok(input)
136            }
137            FlowStage::BindDecodeInputs(s) => {
138                s.emit(ctx)?;
139                Ok(input)
140            }
141            FlowStage::AttnMask(s) => {
142                s.emit(ctx)?;
143                Ok(input)
144            }
145            FlowStage::LlamaDecodeLayer(s) => {
146                let input =
147                    input.ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires input"))?;
148                s.emit(ctx, input)
149            }
150            FlowStage::LlamaDecoder(s) => {
151                let input = input.ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires input"))?;
152                s.emit(ctx, input)
153            }
154            FlowStage::LlamaKvTap(s) => {
155                let input = input.ok_or_else(|| anyhow::anyhow!("LlamaKvTap requires input"))?;
156                s.emit(ctx, input.clone())?;
157                Ok(Some(input))
158            }
159            FlowStage::Repeat(s) => s.emit(ctx, input),
160            FlowStage::Named { name, inner } => {
161                let input = input.ok_or_else(|| anyhow::anyhow!("Named block requires input"))?;
162                let out = inner.emit(ctx, Some(input))?;
163                let value = out.expect("named inner stage produced no output");
164                ctx.hir().node_mut(value.id).name = Some(name.clone());
165                Ok(Some(value))
166            }
167            FlowStage::Sequence(stages) => {
168                let mut value = input;
169                for stage in stages {
170                    value = stage.emit(ctx, value)?;
171                }
172                Ok(value)
173            }
174            FlowStage::RmsNorm(s) => {
175                let input = input.ok_or_else(|| anyhow::anyhow!("RmsNorm requires input"))?;
176                s.emit(ctx, input)
177            }
178            FlowStage::GatherLastToken(s) => {
179                let input =
180                    input.ok_or_else(|| anyhow::anyhow!("GatherLastToken requires input"))?;
181                s.emit(ctx, input)
182            }
183            FlowStage::LmHead(s) => {
184                let input = input.ok_or_else(|| anyhow::anyhow!("LmHead requires input"))?;
185                s.emit(ctx, input)
186            }
187            FlowStage::Linear(s) => {
188                let input = input.ok_or_else(|| anyhow::anyhow!("Linear requires input"))?;
189                s.emit(ctx, input)
190            }
191            FlowStage::ResidualSave(s) => {
192                let input = input.ok_or_else(|| anyhow::anyhow!("ResidualSave requires input"))?;
193                s.emit(ctx, input)
194            }
195            FlowStage::ResidualAdd(s) => {
196                let input = input.ok_or_else(|| anyhow::anyhow!("ResidualAdd requires input"))?;
197                s.emit(ctx, input)
198            }
199            FlowStage::SwiGlu(s) => {
200                let input = input.ok_or_else(|| anyhow::anyhow!("SwiGlu requires input"))?;
201                s.emit(ctx, input)
202            }
203            FlowStage::SelfAttnPrefill(s) => {
204                let input =
205                    input.ok_or_else(|| anyhow::anyhow!("SelfAttnPrefill requires input"))?;
206                s.emit(ctx, input)
207            }
208            FlowStage::GdnScan(s) => {
209                let input = input.ok_or_else(|| anyhow::anyhow!("GdnScan requires input"))?;
210                s.emit(ctx, input)
211            }
212            FlowStage::StoreStream(s) => s.emit(ctx, input),
213            FlowStage::LoadStream(s) => s.emit(ctx, input),
214            FlowStage::DualStream(s) => s.emit(ctx, input),
215            FlowStage::Custom(s) => s.emit(ctx, input),
216            FlowStage::BertEncoderLayer(s) => {
217                let input =
218                    input.ok_or_else(|| anyhow::anyhow!("BertEncoderLayer requires input"))?;
219                s.emit(ctx, input)
220            }
221            FlowStage::NomicEncoderLayer(s) => {
222                let input =
223                    input.ok_or_else(|| anyhow::anyhow!("NomicEncoderLayer requires input"))?;
224                s.emit(ctx, input)
225            }
226            FlowStage::Qwen3Decoder(s) => {
227                let input = input.ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires input"))?;
228                s.emit(ctx, input)
229            }
230            FlowStage::Qwen3DecodeLayer(s) => {
231                let input =
232                    input.ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires input"))?;
233                s.emit(ctx, input)
234            }
235            FlowStage::VitSelfAttn(s) => {
236                let input = input.ok_or_else(|| anyhow::anyhow!("VitSelfAttn requires input"))?;
237                s.emit(ctx, input)
238            }
239            FlowStage::LayerScale(s) => {
240                let input = input.ok_or_else(|| anyhow::anyhow!("LayerScale requires input"))?;
241                s.emit(ctx, input)
242            }
243            FlowStage::VisionSwiGluFfn(s) => {
244                let input =
245                    input.ok_or_else(|| anyhow::anyhow!("VisionSwiGluFfn requires input"))?;
246                s.emit(ctx, input)
247            }
248            FlowStage::ClsTokenPool(s) => {
249                let input = input.ok_or_else(|| anyhow::anyhow!("ClsTokenPool requires input"))?;
250                s.emit(ctx, input)
251            }
252            FlowStage::LayerNorm(s) => {
253                let input = input.ok_or_else(|| anyhow::anyhow!("LayerNorm requires input"))?;
254                s.emit(ctx, input)
255            }
256            FlowStage::GeluFfn(s) => {
257                let input = input.ok_or_else(|| anyhow::anyhow!("GeluFfn requires input"))?;
258                s.emit(ctx, input)
259            }
260            FlowStage::GatherFromInput(s) => {
261                let input =
262                    input.ok_or_else(|| anyhow::anyhow!("GatherFromInput requires input"))?;
263                s.emit(ctx, input)
264            }
265            FlowStage::GatherAdd(s) => {
266                let input = input.ok_or_else(|| anyhow::anyhow!("GatherAdd requires input"))?;
267                s.emit(ctx, input)
268            }
269            FlowStage::EmbedScale(s) => {
270                let input = input.ok_or_else(|| anyhow::anyhow!("EmbedScale requires input"))?;
271                s.emit(ctx, input)
272            }
273            FlowStage::GemmaRmsNorm(s) => {
274                let input = input.ok_or_else(|| anyhow::anyhow!("GemmaRmsNorm requires input"))?;
275                s.emit(ctx, input)
276            }
277            FlowStage::LogitSoftcap(s) => {
278                let input = input.ok_or_else(|| anyhow::anyhow!("LogitSoftcap requires input"))?;
279                s.emit(ctx, input)
280            }
281            FlowStage::DecodeRopeParams(s) => {
282                s.emit(ctx)?;
283                Ok(input)
284            }
285            FlowStage::GemmaDecodeLayer(s) => {
286                let input =
287                    input.ok_or_else(|| anyhow::anyhow!("GemmaDecodeLayer requires input"))?;
288                s.emit(ctx, input)
289            }
290            FlowStage::GemmaKvTap(s) => {
291                let input = input.ok_or_else(|| anyhow::anyhow!("GemmaKvTap requires input"))?;
292                s.emit(ctx, input)
293            }
294            FlowStage::GeGlu(s) => {
295                let input = input.ok_or_else(|| anyhow::anyhow!("GeGlu requires input"))?;
296                s.emit(ctx, input)
297            }
298        }
299    }
300}