Skip to main content

rlx_flow/
stage.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! Flow stages — typed block assembly primitives.
5
6use std::sync::Arc;
7
8use anyhow::Result;
9
10use crate::blocks::{
11    AttnMaskStage, BertEncoderLayerStage, BindDecodeInputsStage, BlockStage, ClsTokenPoolStage,
12    CustomStage, EmbedStage, GatherAddStage, GatherFromInputStage, GatherLastTokenStage,
13    GdnScanStage, GeluFfnStage, LayerNormStage, LayerScaleStage, LinearStage,
14    LlamaDecodeLayerStage, LlamaDecoderStage, LlamaKvTapStage, LmHeadStage, NomicEncoderLayerStage,
15    Qwen3DecodeLayerStage, Qwen3DecoderStage, RepeatStage, ResidualAddStage, ResidualSaveStage,
16    RmsNormStage, RopeTablesStage, SelfAttnPrefillStage, SwiGluStage, VisionSwiGluFfnStage,
17    VitSelfAttnStage,
18};
19use crate::context::FlowCtx;
20use crate::stream::{DualStreamStage, LoadStreamStage, StoreStreamStage};
21use crate::value::FlowValue;
22/// One stage in a model flow. Model authors compose these — not HIR ops.
23#[derive(Debug, Clone)]
24pub enum FlowStage {
25    /// Token embedding lookup.
26    Embed(EmbedStage),
27    /// Precomputed RoPE sin/cos tables as params.
28    RopeTables(RopeTablesStage),
29    /// Ensure a rank-1 zero vector exists for RMSNorm beta slots.
30    ZeroBeta { name: String, len: usize },
31    /// Bind decode inputs (RoPE slice, past K/V, mask) into flow state.
32    BindDecodeInputs(BindDecodeInputsStage),
33    /// Bind or synthesize vision attention mask (all-ones).
34    AttnMask(AttnMaskStage),
35    /// KV-cache decode layer (concat past K/V, causal/custom attention).
36    LlamaDecodeLayer(LlamaDecodeLayerStage),
37    /// LLaMA-style fused prefill decoder layer (GQA + RoPE + SwiGLU).
38    LlamaDecoder(LlamaDecoderStage),
39    /// Side-output K/V projections for a decoder layer (prefill cache export).
40    LlamaKvTap(LlamaKvTapStage),
41    /// Repeat an inner stage `count` times with a per-index name prefix.
42    Repeat(RepeatStage),
43    /// Named nested scope (fusion/debug labeling).
44    Named { name: String, inner: Arc<FlowStage> },
45    /// Run stages in order; side-effect stages may leave the main tensor unchanged.
46    Sequence(Vec<FlowStage>),
47    /// Final RMSNorm before LM head.
48    RmsNorm(RmsNormStage),
49    /// Gather last token along sequence axis (dynamic prefill).
50    GatherLastToken(GatherLastTokenStage),
51    /// Causal LM head matmul.
52    LmHead(LmHeadStage),
53    /// Matmul against a loaded weight (`LinearStage`).
54    Linear(LinearStage),
55    /// Save skip connection for residual add.
56    ResidualSave(ResidualSaveStage),
57    /// Add saved skip connection.
58    ResidualAdd(ResidualAddStage),
59    /// SwiGLU feed-forward (gate/up/down).
60    SwiGlu(SwiGluStage),
61    /// Prefill self-attention (QKV + RoPE + GQA + causal mask).
62    SelfAttnPrefill(SelfAttnPrefillStage),
63    /// Gated DeltaNet scan (inputs via [`FlowState::gdn`]).
64    GdnScan(GdnScanStage),
65    /// Store active flow into a named stream.
66    StoreStream(StoreStreamStage),
67    /// Load active flow from a named stream.
68    LoadStream(LoadStreamStage),
69    /// Dual-stream transform (img/txt, …).
70    DualStream(DualStreamStage),
71    /// Tier-2 custom subgraph (see `rlx_flow::escape`).
72    Custom(CustomStage),
73    /// BERT-style encoder layer (fused QKV + padding-mask attention + GELU FFN).
74    BertEncoderLayer(BertEncoderLayerStage),
75    /// NomicBERT encoder layer (fused QKV + RoPE + padding-mask + SwiGLU).
76    NomicEncoderLayer(NomicEncoderLayerStage),
77    /// Qwen3 decoder layer (QK-norm + GQA + RoPE + SwiGLU).
78    Qwen3Decoder(Qwen3DecoderStage),
79    /// Qwen3 KV-cache decode layer (concat past K/V + QK-norm + GQA).
80    Qwen3DecodeLayer(Qwen3DecodeLayerStage),
81    /// ViT fused QKV self-attention with padding mask.
82    VitSelfAttn(VitSelfAttnStage),
83    /// DINOv2 LayerScale (gamma multiply).
84    LayerScale(LayerScaleStage),
85    /// NomicVision SwiGLU FFN with intermediate LayerNorm.
86    VisionSwiGluFfn(VisionSwiGluFfnStage),
87    /// CLS token pooling `[B, seq, H]` → `[B, H]`.
88    ClsTokenPool(ClsTokenPoolStage),
89    /// LayerNorm with gamma + beta.
90    LayerNorm(LayerNormStage),
91    /// GELU feed-forward (intermediate + output dense).
92    GeluFfn(GeluFfnStage),
93    /// Gather embedding table rows from a named side input.
94    GatherFromInput(GatherFromInputStage),
95    /// Add gather-from-side-input embedding to active hidden tensor.
96    GatherAdd(GatherAddStage),
97}
98
99impl FlowStage {
100    pub(crate) fn emit(
101        &self,
102        ctx: &mut FlowCtx<'_>,
103        input: Option<FlowValue>,
104    ) -> Result<Option<FlowValue>> {
105        match self {
106            FlowStage::Embed(s) => {
107                let input = input.ok_or_else(|| anyhow::anyhow!("Embed requires input"))?;
108                s.emit(ctx, input)
109            }
110            FlowStage::RopeTables(s) => {
111                s.emit(ctx)?;
112                Ok(input)
113            }
114            FlowStage::ZeroBeta { name, len } => {
115                let id = ctx.synth_zeros(name, *len);
116                ctx.state.named.insert(name.clone(), id);
117                if ctx.state.zero_beta.is_none() {
118                    ctx.state.zero_beta = Some(id);
119                }
120                Ok(input)
121            }
122            FlowStage::BindDecodeInputs(s) => {
123                s.emit(ctx)?;
124                Ok(input)
125            }
126            FlowStage::AttnMask(s) => {
127                s.emit(ctx)?;
128                Ok(input)
129            }
130            FlowStage::LlamaDecodeLayer(s) => {
131                let input =
132                    input.ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires input"))?;
133                s.emit(ctx, input)
134            }
135            FlowStage::LlamaDecoder(s) => {
136                let input = input.ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires input"))?;
137                s.emit(ctx, input)
138            }
139            FlowStage::LlamaKvTap(s) => {
140                let input = input.ok_or_else(|| anyhow::anyhow!("LlamaKvTap requires input"))?;
141                s.emit(ctx, input.clone())?;
142                Ok(Some(input))
143            }
144            FlowStage::Repeat(s) => s.emit(ctx, input),
145            FlowStage::Named { name, inner } => {
146                let input = input.ok_or_else(|| anyhow::anyhow!("Named block requires input"))?;
147                let out = inner.emit(ctx, Some(input))?;
148                let value = out.expect("named inner stage produced no output");
149                ctx.hir().node_mut(value.id).name = Some(name.clone());
150                Ok(Some(value))
151            }
152            FlowStage::Sequence(stages) => {
153                let mut value = input;
154                for stage in stages {
155                    value = stage.emit(ctx, value)?;
156                }
157                Ok(value)
158            }
159            FlowStage::RmsNorm(s) => {
160                let input = input.ok_or_else(|| anyhow::anyhow!("RmsNorm requires input"))?;
161                s.emit(ctx, input)
162            }
163            FlowStage::GatherLastToken(s) => {
164                let input =
165                    input.ok_or_else(|| anyhow::anyhow!("GatherLastToken requires input"))?;
166                s.emit(ctx, input)
167            }
168            FlowStage::LmHead(s) => {
169                let input = input.ok_or_else(|| anyhow::anyhow!("LmHead requires input"))?;
170                s.emit(ctx, input)
171            }
172            FlowStage::Linear(s) => {
173                let input = input.ok_or_else(|| anyhow::anyhow!("Linear requires input"))?;
174                s.emit(ctx, input)
175            }
176            FlowStage::ResidualSave(s) => {
177                let input = input.ok_or_else(|| anyhow::anyhow!("ResidualSave requires input"))?;
178                s.emit(ctx, input)
179            }
180            FlowStage::ResidualAdd(s) => {
181                let input = input.ok_or_else(|| anyhow::anyhow!("ResidualAdd requires input"))?;
182                s.emit(ctx, input)
183            }
184            FlowStage::SwiGlu(s) => {
185                let input = input.ok_or_else(|| anyhow::anyhow!("SwiGlu requires input"))?;
186                s.emit(ctx, input)
187            }
188            FlowStage::SelfAttnPrefill(s) => {
189                let input =
190                    input.ok_or_else(|| anyhow::anyhow!("SelfAttnPrefill requires input"))?;
191                s.emit(ctx, input)
192            }
193            FlowStage::GdnScan(s) => {
194                let input = input.ok_or_else(|| anyhow::anyhow!("GdnScan requires input"))?;
195                s.emit(ctx, input)
196            }
197            FlowStage::StoreStream(s) => s.emit(ctx, input),
198            FlowStage::LoadStream(s) => s.emit(ctx, input),
199            FlowStage::DualStream(s) => s.emit(ctx, input),
200            FlowStage::Custom(s) => s.emit(ctx, input),
201            FlowStage::BertEncoderLayer(s) => {
202                let input =
203                    input.ok_or_else(|| anyhow::anyhow!("BertEncoderLayer requires input"))?;
204                s.emit(ctx, input)
205            }
206            FlowStage::NomicEncoderLayer(s) => {
207                let input =
208                    input.ok_or_else(|| anyhow::anyhow!("NomicEncoderLayer requires input"))?;
209                s.emit(ctx, input)
210            }
211            FlowStage::Qwen3Decoder(s) => {
212                let input = input.ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires input"))?;
213                s.emit(ctx, input)
214            }
215            FlowStage::Qwen3DecodeLayer(s) => {
216                let input =
217                    input.ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires input"))?;
218                s.emit(ctx, input)
219            }
220            FlowStage::VitSelfAttn(s) => {
221                let input = input.ok_or_else(|| anyhow::anyhow!("VitSelfAttn requires input"))?;
222                s.emit(ctx, input)
223            }
224            FlowStage::LayerScale(s) => {
225                let input = input.ok_or_else(|| anyhow::anyhow!("LayerScale requires input"))?;
226                s.emit(ctx, input)
227            }
228            FlowStage::VisionSwiGluFfn(s) => {
229                let input =
230                    input.ok_or_else(|| anyhow::anyhow!("VisionSwiGluFfn requires input"))?;
231                s.emit(ctx, input)
232            }
233            FlowStage::ClsTokenPool(s) => {
234                let input = input.ok_or_else(|| anyhow::anyhow!("ClsTokenPool requires input"))?;
235                s.emit(ctx, input)
236            }
237            FlowStage::LayerNorm(s) => {
238                let input = input.ok_or_else(|| anyhow::anyhow!("LayerNorm requires input"))?;
239                s.emit(ctx, input)
240            }
241            FlowStage::GeluFfn(s) => {
242                let input = input.ok_or_else(|| anyhow::anyhow!("GeluFfn requires input"))?;
243                s.emit(ctx, input)
244            }
245            FlowStage::GatherFromInput(s) => {
246                let input =
247                    input.ok_or_else(|| anyhow::anyhow!("GatherFromInput requires input"))?;
248                s.emit(ctx, input)
249            }
250            FlowStage::GatherAdd(s) => {
251                let input = input.ok_or_else(|| anyhow::anyhow!("GatherAdd requires input"))?;
252                s.emit(ctx, input)
253            }
254        }
255    }
256}