1use std::sync::Arc;
19
20use anyhow::Result;
21
22use crate::blocks::{
23 AttnMaskStage, BertEncoderLayerStage, BindDecodeInputsStage, BlockStage, ClsTokenPoolStage,
24 CustomStage, DecodeRopeParamsStage, EmbedScaleStage, EmbedStage, GatherAddStage,
25 GatherFromInputStage, GatherLastTokenStage, GdnScanStage, GeGluStage, GeluFfnStage,
26 GemmaDecodeLayerStage, GemmaKvTapStage, GemmaRmsNormStage, LayerNormStage, LayerScaleStage,
27 LinearStage, LlamaDecodeLayerStage, LlamaDecoderStage, LlamaKvTapStage, LmHeadStage,
28 LogitSoftcapStage, NomicEncoderLayerStage, Qwen3DecodeLayerStage, Qwen3DecoderStage,
29 RepeatStage, ResidualAddStage, ResidualSaveStage, RmsNormStage, RopeTablesStage,
30 SelfAttnPrefillStage, SwiGluStage, VisionSwiGluFfnStage, VitSelfAttnStage,
31};
32use crate::context::FlowCtx;
33use crate::stream::{DualStreamStage, LoadStreamStage, StoreStreamStage};
34use crate::value::FlowValue;
/// A single step in a model-construction flow.
///
/// Stages are emitted into a `FlowCtx` by `FlowStage::emit`, which threads an
/// optional `FlowValue` from each stage to the next. Most variants wrap a stage
/// type from `crate::blocks` or `crate::stream` and delegate to that type's own
/// `emit`; the doc on each variant describes how the incoming value is handled.
#[derive(Debug, Clone)]
pub enum FlowStage {
    /// Wraps `EmbedStage`; requires an input value.
    Embed(EmbedStage),
    /// Wraps `RopeTablesStage`; emits without consuming the input and passes the
    /// incoming value (possibly `None`) through unchanged.
    RopeTables(RopeTablesStage),
    /// Calls `FlowCtx::synth_zeros(name, len)` (presumably a zero-filled value of
    /// length `len` — confirm against `synth_zeros`) and registers the result under
    /// `name` in the context's named-value table. The first `ZeroBeta` emitted also
    /// becomes the context's `zero_beta`. Passes the incoming value through unchanged.
    ZeroBeta { name: String, len: usize },
    /// Wraps `BindDecodeInputsStage`; emits without consuming the input and passes
    /// the incoming value through unchanged.
    BindDecodeInputs(BindDecodeInputsStage),
    /// Wraps `AttnMaskStage`; emits without consuming the input and passes the
    /// incoming value through unchanged.
    AttnMask(AttnMaskStage),
    /// Wraps `LlamaDecodeLayerStage`; requires an input value.
    LlamaDecodeLayer(LlamaDecodeLayerStage),
    /// Wraps `LlamaDecoderStage`; requires an input value.
    LlamaDecoder(LlamaDecoderStage),
    /// Wraps `LlamaKvTapStage`; requires an input value, emits the tap on a copy of
    /// it, and passes the original value through unchanged (the tap's own result
    /// is discarded).
    LlamaKvTap(LlamaKvTapStage),
    /// Wraps `RepeatStage`; the optional incoming value is handed to it as-is.
    Repeat(RepeatStage),
    /// Emits `inner` with the (required) incoming value, then sets the name of the
    /// resulting output node to `name`. The inner stage must produce an output value.
    Named { name: String, inner: Arc<FlowStage> },
    /// Emits each stage in order, feeding each stage's output into the next. Yields
    /// the last stage's output, or the incoming value unchanged when empty.
    Sequence(Vec<FlowStage>),
    /// Wraps `RmsNormStage`; requires an input value.
    RmsNorm(RmsNormStage),
    /// Wraps `GatherLastTokenStage`; requires an input value.
    GatherLastToken(GatherLastTokenStage),
    /// Wraps `LmHeadStage`; requires an input value.
    LmHead(LmHeadStage),
    /// Wraps `LinearStage`; requires an input value.
    Linear(LinearStage),
    /// Wraps `ResidualSaveStage`; requires an input value.
    ResidualSave(ResidualSaveStage),
    /// Wraps `ResidualAddStage`; requires an input value.
    ResidualAdd(ResidualAddStage),
    /// Wraps `SwiGluStage`; requires an input value.
    SwiGlu(SwiGluStage),
    /// Wraps `SelfAttnPrefillStage`; requires an input value.
    SelfAttnPrefill(SelfAttnPrefillStage),
    /// Wraps `GdnScanStage`; requires an input value.
    GdnScan(GdnScanStage),
    /// Wraps `StoreStreamStage`; the optional incoming value is handed to it as-is.
    StoreStream(StoreStreamStage),
    /// Wraps `LoadStreamStage`; the optional incoming value is handed to it as-is.
    LoadStream(LoadStreamStage),
    /// Wraps `DualStreamStage`; the optional incoming value is handed to it as-is.
    DualStream(DualStreamStage),
    /// Wraps `CustomStage`; the optional incoming value is handed to it as-is.
    Custom(CustomStage),
    /// Wraps `BertEncoderLayerStage`; requires an input value.
    BertEncoderLayer(BertEncoderLayerStage),
    /// Wraps `NomicEncoderLayerStage`; requires an input value.
    NomicEncoderLayer(NomicEncoderLayerStage),
    /// Wraps `Qwen3DecoderStage`; requires an input value.
    Qwen3Decoder(Qwen3DecoderStage),
    /// Wraps `Qwen3DecodeLayerStage`; requires an input value.
    Qwen3DecodeLayer(Qwen3DecodeLayerStage),
    /// Wraps `VitSelfAttnStage`; requires an input value.
    VitSelfAttn(VitSelfAttnStage),
    /// Wraps `LayerScaleStage`; requires an input value.
    LayerScale(LayerScaleStage),
    /// Wraps `VisionSwiGluFfnStage`; requires an input value.
    VisionSwiGluFfn(VisionSwiGluFfnStage),
    /// Wraps `ClsTokenPoolStage`; requires an input value.
    ClsTokenPool(ClsTokenPoolStage),
    /// Wraps `LayerNormStage`; requires an input value.
    LayerNorm(LayerNormStage),
    /// Wraps `GeluFfnStage`; requires an input value.
    GeluFfn(GeluFfnStage),
    /// Wraps `GatherFromInputStage`; requires an input value.
    GatherFromInput(GatherFromInputStage),
    /// Wraps `GatherAddStage`; requires an input value.
    GatherAdd(GatherAddStage),
    /// Wraps `EmbedScaleStage`; requires an input value.
    EmbedScale(EmbedScaleStage),
    /// Wraps `GemmaRmsNormStage`; requires an input value.
    GemmaRmsNorm(GemmaRmsNormStage),
    /// Wraps `LogitSoftcapStage`; requires an input value.
    LogitSoftcap(LogitSoftcapStage),
    /// Wraps `DecodeRopeParamsStage`; emits without consuming the input and passes
    /// the incoming value through unchanged.
    DecodeRopeParams(DecodeRopeParamsStage),
    /// Wraps `GemmaDecodeLayerStage`; requires an input value.
    GemmaDecodeLayer(GemmaDecodeLayerStage),
    /// Wraps `GemmaKvTapStage`; requires an input value. Unlike `LlamaKvTap`, the
    /// wrapped stage's own output is returned rather than the incoming value.
    GemmaKvTap(GemmaKvTapStage),
    /// Wraps `GeGluStage`; requires an input value.
    GeGlu(GeGluStage),
}
125
126impl FlowStage {
127 pub(crate) fn emit(
128 &self,
129 ctx: &mut FlowCtx<'_>,
130 input: Option<FlowValue>,
131 ) -> Result<Option<FlowValue>> {
132 match self {
133 FlowStage::Embed(s) => {
134 let input = input.ok_or_else(|| anyhow::anyhow!("Embed requires input"))?;
135 s.emit(ctx, input)
136 }
137 FlowStage::RopeTables(s) => {
138 s.emit(ctx)?;
139 Ok(input)
140 }
141 FlowStage::ZeroBeta { name, len } => {
142 let id = ctx.synth_zeros(name, *len);
143 ctx.state.named.insert(name.clone(), id);
144 if ctx.state.zero_beta.is_none() {
145 ctx.state.zero_beta = Some(id);
146 }
147 Ok(input)
148 }
149 FlowStage::BindDecodeInputs(s) => {
150 s.emit(ctx)?;
151 Ok(input)
152 }
153 FlowStage::AttnMask(s) => {
154 s.emit(ctx)?;
155 Ok(input)
156 }
157 FlowStage::LlamaDecodeLayer(s) => {
158 let input =
159 input.ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires input"))?;
160 s.emit(ctx, input)
161 }
162 FlowStage::LlamaDecoder(s) => {
163 let input = input.ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires input"))?;
164 s.emit(ctx, input)
165 }
166 FlowStage::LlamaKvTap(s) => {
167 let input = input.ok_or_else(|| anyhow::anyhow!("LlamaKvTap requires input"))?;
168 s.emit(ctx, input.clone())?;
169 Ok(Some(input))
170 }
171 FlowStage::Repeat(s) => s.emit(ctx, input),
172 FlowStage::Named { name, inner } => {
173 let input = input.ok_or_else(|| anyhow::anyhow!("Named block requires input"))?;
174 let out = inner.emit(ctx, Some(input))?;
175 let value = out.expect("named inner stage produced no output");
176 ctx.hir().node_mut(value.id).name = Some(name.clone());
177 Ok(Some(value))
178 }
179 FlowStage::Sequence(stages) => {
180 let mut value = input;
181 for stage in stages {
182 value = stage.emit(ctx, value)?;
183 }
184 Ok(value)
185 }
186 FlowStage::RmsNorm(s) => {
187 let input = input.ok_or_else(|| anyhow::anyhow!("RmsNorm requires input"))?;
188 s.emit(ctx, input)
189 }
190 FlowStage::GatherLastToken(s) => {
191 let input =
192 input.ok_or_else(|| anyhow::anyhow!("GatherLastToken requires input"))?;
193 s.emit(ctx, input)
194 }
195 FlowStage::LmHead(s) => {
196 let input = input.ok_or_else(|| anyhow::anyhow!("LmHead requires input"))?;
197 s.emit(ctx, input)
198 }
199 FlowStage::Linear(s) => {
200 let input = input.ok_or_else(|| anyhow::anyhow!("Linear requires input"))?;
201 s.emit(ctx, input)
202 }
203 FlowStage::ResidualSave(s) => {
204 let input = input.ok_or_else(|| anyhow::anyhow!("ResidualSave requires input"))?;
205 s.emit(ctx, input)
206 }
207 FlowStage::ResidualAdd(s) => {
208 let input = input.ok_or_else(|| anyhow::anyhow!("ResidualAdd requires input"))?;
209 s.emit(ctx, input)
210 }
211 FlowStage::SwiGlu(s) => {
212 let input = input.ok_or_else(|| anyhow::anyhow!("SwiGlu requires input"))?;
213 s.emit(ctx, input)
214 }
215 FlowStage::SelfAttnPrefill(s) => {
216 let input =
217 input.ok_or_else(|| anyhow::anyhow!("SelfAttnPrefill requires input"))?;
218 s.emit(ctx, input)
219 }
220 FlowStage::GdnScan(s) => {
221 let input = input.ok_or_else(|| anyhow::anyhow!("GdnScan requires input"))?;
222 s.emit(ctx, input)
223 }
224 FlowStage::StoreStream(s) => s.emit(ctx, input),
225 FlowStage::LoadStream(s) => s.emit(ctx, input),
226 FlowStage::DualStream(s) => s.emit(ctx, input),
227 FlowStage::Custom(s) => s.emit(ctx, input),
228 FlowStage::BertEncoderLayer(s) => {
229 let input =
230 input.ok_or_else(|| anyhow::anyhow!("BertEncoderLayer requires input"))?;
231 s.emit(ctx, input)
232 }
233 FlowStage::NomicEncoderLayer(s) => {
234 let input =
235 input.ok_or_else(|| anyhow::anyhow!("NomicEncoderLayer requires input"))?;
236 s.emit(ctx, input)
237 }
238 FlowStage::Qwen3Decoder(s) => {
239 let input = input.ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires input"))?;
240 s.emit(ctx, input)
241 }
242 FlowStage::Qwen3DecodeLayer(s) => {
243 let input =
244 input.ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires input"))?;
245 s.emit(ctx, input)
246 }
247 FlowStage::VitSelfAttn(s) => {
248 let input = input.ok_or_else(|| anyhow::anyhow!("VitSelfAttn requires input"))?;
249 s.emit(ctx, input)
250 }
251 FlowStage::LayerScale(s) => {
252 let input = input.ok_or_else(|| anyhow::anyhow!("LayerScale requires input"))?;
253 s.emit(ctx, input)
254 }
255 FlowStage::VisionSwiGluFfn(s) => {
256 let input =
257 input.ok_or_else(|| anyhow::anyhow!("VisionSwiGluFfn requires input"))?;
258 s.emit(ctx, input)
259 }
260 FlowStage::ClsTokenPool(s) => {
261 let input = input.ok_or_else(|| anyhow::anyhow!("ClsTokenPool requires input"))?;
262 s.emit(ctx, input)
263 }
264 FlowStage::LayerNorm(s) => {
265 let input = input.ok_or_else(|| anyhow::anyhow!("LayerNorm requires input"))?;
266 s.emit(ctx, input)
267 }
268 FlowStage::GeluFfn(s) => {
269 let input = input.ok_or_else(|| anyhow::anyhow!("GeluFfn requires input"))?;
270 s.emit(ctx, input)
271 }
272 FlowStage::GatherFromInput(s) => {
273 let input =
274 input.ok_or_else(|| anyhow::anyhow!("GatherFromInput requires input"))?;
275 s.emit(ctx, input)
276 }
277 FlowStage::GatherAdd(s) => {
278 let input = input.ok_or_else(|| anyhow::anyhow!("GatherAdd requires input"))?;
279 s.emit(ctx, input)
280 }
281 FlowStage::EmbedScale(s) => {
282 let input = input.ok_or_else(|| anyhow::anyhow!("EmbedScale requires input"))?;
283 s.emit(ctx, input)
284 }
285 FlowStage::GemmaRmsNorm(s) => {
286 let input = input.ok_or_else(|| anyhow::anyhow!("GemmaRmsNorm requires input"))?;
287 s.emit(ctx, input)
288 }
289 FlowStage::LogitSoftcap(s) => {
290 let input = input.ok_or_else(|| anyhow::anyhow!("LogitSoftcap requires input"))?;
291 s.emit(ctx, input)
292 }
293 FlowStage::DecodeRopeParams(s) => {
294 s.emit(ctx)?;
295 Ok(input)
296 }
297 FlowStage::GemmaDecodeLayer(s) => {
298 let input =
299 input.ok_or_else(|| anyhow::anyhow!("GemmaDecodeLayer requires input"))?;
300 s.emit(ctx, input)
301 }
302 FlowStage::GemmaKvTap(s) => {
303 let input = input.ok_or_else(|| anyhow::anyhow!("GemmaKvTap requires input"))?;
304 s.emit(ctx, input)
305 }
306 FlowStage::GeGlu(s) => {
307 let input = input.ok_or_else(|| anyhow::anyhow!("GeGlu requires input"))?;
308 s.emit(ctx, input)
309 }
310 }
311 }
312}