1use std::sync::Arc;
19
20use anyhow::Result;
21
22use crate::blocks::{
23 AttnMaskStage, BertEncoderLayerStage, BindDecodeInputsStage, BlockStage, ClsTokenPoolStage,
24 CustomStage, DecodeRopeParamsStage, EmbedScaleStage, EmbedStage, GatherAddStage,
25 GatherDecodeRopeStage, GatherFromInputStage, GatherLastTokenStage, GdnScanStage, GeGluStage,
26 GeluFfnStage, GemmaDecodeLayerStage, GemmaKvTapStage, GemmaRmsNormStage, LayerNormStage,
27 LayerScaleStage, LinearStage, LlamaDecodeLayerStage, LlamaDecoderStage, LlamaKvTapStage,
28 LmHeadStage, LogitSoftcapStage, NomicEncoderLayerStage, Qwen3DecodeLayerStage,
29 Qwen3DecoderStage, RepeatStage, ResidualAddStage, ResidualSaveStage, RmsNormStage,
30 RopeTablesStage, SelfAttnPrefillStage, SwiGluStage, VisionSwiGluFfnStage, VitSelfAttnStage,
31};
32use crate::context::FlowCtx;
33use crate::stream::{DualStreamStage, LoadStreamStage, StoreStreamStage};
34use crate::value::FlowValue;
/// One node of a model-building flow.
///
/// A flow is lowered by calling [`FlowStage::emit`], which threads an optional
/// [`FlowValue`] from stage to stage. Variants fall into three groups (as handled by `emit`):
/// - stages that require a flowing value and return their own output,
/// - "side" stages that emit work without consuming the value and pass it through unchanged,
/// - combinators/pass-through stages that receive the optional value as-is.
#[derive(Debug, Clone)]
pub enum FlowStage {
    /// Requires a flowing value; delegates to [`EmbedStage`] and returns its output.
    Embed(EmbedStage),
    /// Side stage: emits [`RopeTablesStage`] without consuming the value, which passes through.
    RopeTables(RopeTablesStage),
    /// Side stage: synthesizes a value via `FlowCtx::synth_zeros(name, len)`, registers it in
    /// `ctx.state.named` under `name`, and records it as `ctx.state.zero_beta` if none is set
    /// yet. The flowing value passes through unchanged.
    ZeroBeta { name: String, len: usize },
    /// Side stage: emits [`BindDecodeInputsStage`]; the flowing value passes through.
    BindDecodeInputs(BindDecodeInputsStage),
    /// Side stage: emits [`AttnMaskStage`]; the flowing value passes through.
    AttnMask(AttnMaskStage),
    /// Requires a flowing value; delegates to [`LlamaDecodeLayerStage`].
    LlamaDecodeLayer(LlamaDecodeLayerStage),
    /// Requires a flowing value; delegates to [`LlamaDecoderStage`].
    LlamaDecoder(LlamaDecoderStage),
    /// Requires a flowing value; runs [`LlamaKvTapStage`] on a clone of it and then returns
    /// the original value unchanged (the stage's own result is discarded).
    LlamaKvTap(LlamaKvTapStage),
    /// Receives the optional flowing value as-is; output is whatever [`RepeatStage`] returns.
    Repeat(RepeatStage),
    /// Requires a flowing value; runs `inner` on it and tags the resulting HIR node with
    /// `name`. The inner stage must produce an output.
    Named { name: String, inner: Arc<FlowStage> },
    /// Runs each stage in order, feeding each one the previous stage's output.
    Sequence(Vec<FlowStage>),
    /// Requires a flowing value; delegates to [`RmsNormStage`].
    RmsNorm(RmsNormStage),
    /// Requires a flowing value; delegates to [`GatherLastTokenStage`].
    GatherLastToken(GatherLastTokenStage),
    /// Requires a flowing value; delegates to [`LmHeadStage`].
    LmHead(LmHeadStage),
    /// Requires a flowing value; delegates to [`LinearStage`].
    Linear(LinearStage),
    /// Requires a flowing value; delegates to [`ResidualSaveStage`].
    ResidualSave(ResidualSaveStage),
    /// Requires a flowing value; delegates to [`ResidualAddStage`].
    ResidualAdd(ResidualAddStage),
    /// Requires a flowing value; delegates to [`SwiGluStage`].
    SwiGlu(SwiGluStage),
    /// Requires a flowing value; delegates to [`SelfAttnPrefillStage`].
    SelfAttnPrefill(SelfAttnPrefillStage),
    /// Requires a flowing value; delegates to [`GdnScanStage`].
    GdnScan(GdnScanStage),
    /// Receives the optional flowing value as-is; delegates to [`StoreStreamStage`].
    StoreStream(StoreStreamStage),
    /// Receives the optional flowing value as-is; delegates to [`LoadStreamStage`].
    LoadStream(LoadStreamStage),
    /// Receives the optional flowing value as-is; delegates to [`DualStreamStage`].
    DualStream(DualStreamStage),
    /// Receives the optional flowing value as-is; delegates to [`CustomStage`].
    Custom(CustomStage),
    /// Requires a flowing value; delegates to [`BertEncoderLayerStage`].
    BertEncoderLayer(BertEncoderLayerStage),
    /// Requires a flowing value; delegates to [`NomicEncoderLayerStage`].
    NomicEncoderLayer(NomicEncoderLayerStage),
    /// Requires a flowing value; delegates to [`Qwen3DecoderStage`].
    Qwen3Decoder(Qwen3DecoderStage),
    /// Requires a flowing value; delegates to [`Qwen3DecodeLayerStage`].
    Qwen3DecodeLayer(Qwen3DecodeLayerStage),
    /// Requires a flowing value; delegates to [`VitSelfAttnStage`].
    VitSelfAttn(VitSelfAttnStage),
    /// Requires a flowing value; delegates to [`LayerScaleStage`].
    LayerScale(LayerScaleStage),
    /// Requires a flowing value; delegates to [`VisionSwiGluFfnStage`].
    VisionSwiGluFfn(VisionSwiGluFfnStage),
    /// Requires a flowing value; delegates to [`ClsTokenPoolStage`].
    ClsTokenPool(ClsTokenPoolStage),
    /// Requires a flowing value; delegates to [`LayerNormStage`].
    LayerNorm(LayerNormStage),
    /// Requires a flowing value; delegates to [`GeluFfnStage`].
    GeluFfn(GeluFfnStage),
    /// Requires a flowing value; delegates to [`GatherFromInputStage`].
    GatherFromInput(GatherFromInputStage),
    /// Requires a flowing value; delegates to [`GatherAddStage`].
    GatherAdd(GatherAddStage),
    /// Requires a flowing value; delegates to [`EmbedScaleStage`].
    EmbedScale(EmbedScaleStage),
    /// Requires a flowing value; delegates to [`GemmaRmsNormStage`].
    GemmaRmsNorm(GemmaRmsNormStage),
    /// Requires a flowing value; delegates to [`LogitSoftcapStage`].
    LogitSoftcap(LogitSoftcapStage),
    /// Side stage: emits [`DecodeRopeParamsStage`]; the flowing value passes through.
    DecodeRopeParams(DecodeRopeParamsStage),
    /// Side stage: emits [`GatherDecodeRopeStage`]; the flowing value passes through.
    GatherDecodeRope(GatherDecodeRopeStage),
    /// Requires a flowing value; delegates to [`GemmaDecodeLayerStage`].
    GemmaDecodeLayer(GemmaDecodeLayerStage),
    /// Requires a flowing value; delegates to [`GemmaKvTapStage`] and returns the stage's
    /// output (unlike `LlamaKvTap`, which passes its input through).
    GemmaKvTap(GemmaKvTapStage),
    /// Requires a flowing value; delegates to [`GeGluStage`].
    GeGlu(GeGluStage),
}
127
128impl FlowStage {
129 pub(crate) fn emit(
130 &self,
131 ctx: &mut FlowCtx<'_>,
132 input: Option<FlowValue>,
133 ) -> Result<Option<FlowValue>> {
134 match self {
135 FlowStage::Embed(s) => {
136 let input = input.ok_or_else(|| anyhow::anyhow!("Embed requires input"))?;
137 s.emit(ctx, input)
138 }
139 FlowStage::RopeTables(s) => {
140 s.emit(ctx)?;
141 Ok(input)
142 }
143 FlowStage::ZeroBeta { name, len } => {
144 let id = ctx.synth_zeros(name, *len);
145 ctx.state.named.insert(name.clone(), id);
146 if ctx.state.zero_beta.is_none() {
147 ctx.state.zero_beta = Some(id);
148 }
149 Ok(input)
150 }
151 FlowStage::BindDecodeInputs(s) => {
152 s.emit(ctx)?;
153 Ok(input)
154 }
155 FlowStage::AttnMask(s) => {
156 s.emit(ctx)?;
157 Ok(input)
158 }
159 FlowStage::LlamaDecodeLayer(s) => {
160 let input =
161 input.ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires input"))?;
162 s.emit(ctx, input)
163 }
164 FlowStage::LlamaDecoder(s) => {
165 let input = input.ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires input"))?;
166 s.emit(ctx, input)
167 }
168 FlowStage::LlamaKvTap(s) => {
169 let input = input.ok_or_else(|| anyhow::anyhow!("LlamaKvTap requires input"))?;
170 s.emit(ctx, input.clone())?;
171 Ok(Some(input))
172 }
173 FlowStage::Repeat(s) => s.emit(ctx, input),
174 FlowStage::Named { name, inner } => {
175 let input = input.ok_or_else(|| anyhow::anyhow!("Named block requires input"))?;
176 let out = inner.emit(ctx, Some(input))?;
177 let value = out.expect("named inner stage produced no output");
178 ctx.hir().node_mut(value.id).name = Some(name.clone());
179 Ok(Some(value))
180 }
181 FlowStage::Sequence(stages) => {
182 let mut value = input;
183 for stage in stages {
184 value = stage.emit(ctx, value)?;
185 }
186 Ok(value)
187 }
188 FlowStage::RmsNorm(s) => {
189 let input = input.ok_or_else(|| anyhow::anyhow!("RmsNorm requires input"))?;
190 s.emit(ctx, input)
191 }
192 FlowStage::GatherLastToken(s) => {
193 let input =
194 input.ok_or_else(|| anyhow::anyhow!("GatherLastToken requires input"))?;
195 s.emit(ctx, input)
196 }
197 FlowStage::LmHead(s) => {
198 let input = input.ok_or_else(|| anyhow::anyhow!("LmHead requires input"))?;
199 s.emit(ctx, input)
200 }
201 FlowStage::Linear(s) => {
202 let input = input.ok_or_else(|| anyhow::anyhow!("Linear requires input"))?;
203 s.emit(ctx, input)
204 }
205 FlowStage::ResidualSave(s) => {
206 let input = input.ok_or_else(|| anyhow::anyhow!("ResidualSave requires input"))?;
207 s.emit(ctx, input)
208 }
209 FlowStage::ResidualAdd(s) => {
210 let input = input.ok_or_else(|| anyhow::anyhow!("ResidualAdd requires input"))?;
211 s.emit(ctx, input)
212 }
213 FlowStage::SwiGlu(s) => {
214 let input = input.ok_or_else(|| anyhow::anyhow!("SwiGlu requires input"))?;
215 s.emit(ctx, input)
216 }
217 FlowStage::SelfAttnPrefill(s) => {
218 let input =
219 input.ok_or_else(|| anyhow::anyhow!("SelfAttnPrefill requires input"))?;
220 s.emit(ctx, input)
221 }
222 FlowStage::GdnScan(s) => {
223 let input = input.ok_or_else(|| anyhow::anyhow!("GdnScan requires input"))?;
224 s.emit(ctx, input)
225 }
226 FlowStage::StoreStream(s) => s.emit(ctx, input),
227 FlowStage::LoadStream(s) => s.emit(ctx, input),
228 FlowStage::DualStream(s) => s.emit(ctx, input),
229 FlowStage::Custom(s) => s.emit(ctx, input),
230 FlowStage::BertEncoderLayer(s) => {
231 let input =
232 input.ok_or_else(|| anyhow::anyhow!("BertEncoderLayer requires input"))?;
233 s.emit(ctx, input)
234 }
235 FlowStage::NomicEncoderLayer(s) => {
236 let input =
237 input.ok_or_else(|| anyhow::anyhow!("NomicEncoderLayer requires input"))?;
238 s.emit(ctx, input)
239 }
240 FlowStage::Qwen3Decoder(s) => {
241 let input = input.ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires input"))?;
242 s.emit(ctx, input)
243 }
244 FlowStage::Qwen3DecodeLayer(s) => {
245 let input =
246 input.ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires input"))?;
247 s.emit(ctx, input)
248 }
249 FlowStage::VitSelfAttn(s) => {
250 let input = input.ok_or_else(|| anyhow::anyhow!("VitSelfAttn requires input"))?;
251 s.emit(ctx, input)
252 }
253 FlowStage::LayerScale(s) => {
254 let input = input.ok_or_else(|| anyhow::anyhow!("LayerScale requires input"))?;
255 s.emit(ctx, input)
256 }
257 FlowStage::VisionSwiGluFfn(s) => {
258 let input =
259 input.ok_or_else(|| anyhow::anyhow!("VisionSwiGluFfn requires input"))?;
260 s.emit(ctx, input)
261 }
262 FlowStage::ClsTokenPool(s) => {
263 let input = input.ok_or_else(|| anyhow::anyhow!("ClsTokenPool requires input"))?;
264 s.emit(ctx, input)
265 }
266 FlowStage::LayerNorm(s) => {
267 let input = input.ok_or_else(|| anyhow::anyhow!("LayerNorm requires input"))?;
268 s.emit(ctx, input)
269 }
270 FlowStage::GeluFfn(s) => {
271 let input = input.ok_or_else(|| anyhow::anyhow!("GeluFfn requires input"))?;
272 s.emit(ctx, input)
273 }
274 FlowStage::GatherFromInput(s) => {
275 let input =
276 input.ok_or_else(|| anyhow::anyhow!("GatherFromInput requires input"))?;
277 s.emit(ctx, input)
278 }
279 FlowStage::GatherAdd(s) => {
280 let input = input.ok_or_else(|| anyhow::anyhow!("GatherAdd requires input"))?;
281 s.emit(ctx, input)
282 }
283 FlowStage::EmbedScale(s) => {
284 let input = input.ok_or_else(|| anyhow::anyhow!("EmbedScale requires input"))?;
285 s.emit(ctx, input)
286 }
287 FlowStage::GemmaRmsNorm(s) => {
288 let input = input.ok_or_else(|| anyhow::anyhow!("GemmaRmsNorm requires input"))?;
289 s.emit(ctx, input)
290 }
291 FlowStage::LogitSoftcap(s) => {
292 let input = input.ok_or_else(|| anyhow::anyhow!("LogitSoftcap requires input"))?;
293 s.emit(ctx, input)
294 }
295 FlowStage::DecodeRopeParams(s) => {
296 s.emit(ctx)?;
297 Ok(input)
298 }
299 FlowStage::GatherDecodeRope(s) => {
300 s.emit(ctx)?;
301 Ok(input)
302 }
303 FlowStage::GemmaDecodeLayer(s) => {
304 let input =
305 input.ok_or_else(|| anyhow::anyhow!("GemmaDecodeLayer requires input"))?;
306 s.emit(ctx, input)
307 }
308 FlowStage::GemmaKvTap(s) => {
309 let input = input.ok_or_else(|| anyhow::anyhow!("GemmaKvTap requires input"))?;
310 s.emit(ctx, input)
311 }
312 FlowStage::GeGlu(s) => {
313 let input = input.ok_or_else(|| anyhow::anyhow!("GeGlu requires input"))?;
314 s.emit(ctx, input)
315 }
316 }
317 }
318}