1use std::sync::Arc;
7
8use anyhow::Result;
9
10use crate::blocks::{
11 AttnMaskStage, BertEncoderLayerStage, BindDecodeInputsStage, BlockStage, ClsTokenPoolStage,
12 CustomStage, DecodeRopeParamsStage, EmbedScaleStage, EmbedStage, GatherAddStage,
13 GatherFromInputStage, GatherLastTokenStage, GdnScanStage, GeGluStage, GeluFfnStage,
14 GemmaDecodeLayerStage, GemmaKvTapStage, GemmaRmsNormStage, LayerNormStage, LayerScaleStage,
15 LinearStage, LlamaDecodeLayerStage, LlamaDecoderStage, LlamaKvTapStage, LmHeadStage,
16 LogitSoftcapStage, NomicEncoderLayerStage, Qwen3DecodeLayerStage, Qwen3DecoderStage,
17 RepeatStage, ResidualAddStage, ResidualSaveStage, RmsNormStage, RopeTablesStage,
18 SelfAttnPrefillStage, SwiGluStage, VisionSwiGluFfnStage, VitSelfAttnStage,
19};
20use crate::context::FlowCtx;
21use crate::stream::{DualStreamStage, LoadStreamStage, StoreStreamStage};
22use crate::value::FlowValue;
23#[derive(Debug, Clone)]
25pub enum FlowStage {
26 Embed(EmbedStage),
28 RopeTables(RopeTablesStage),
30 ZeroBeta { name: String, len: usize },
32 BindDecodeInputs(BindDecodeInputsStage),
34 AttnMask(AttnMaskStage),
36 LlamaDecodeLayer(LlamaDecodeLayerStage),
38 LlamaDecoder(LlamaDecoderStage),
40 LlamaKvTap(LlamaKvTapStage),
42 Repeat(RepeatStage),
44 Named { name: String, inner: Arc<FlowStage> },
46 Sequence(Vec<FlowStage>),
48 RmsNorm(RmsNormStage),
50 GatherLastToken(GatherLastTokenStage),
52 LmHead(LmHeadStage),
54 Linear(LinearStage),
56 ResidualSave(ResidualSaveStage),
58 ResidualAdd(ResidualAddStage),
60 SwiGlu(SwiGluStage),
62 SelfAttnPrefill(SelfAttnPrefillStage),
64 GdnScan(GdnScanStage),
66 StoreStream(StoreStreamStage),
68 LoadStream(LoadStreamStage),
70 DualStream(DualStreamStage),
72 Custom(CustomStage),
74 BertEncoderLayer(BertEncoderLayerStage),
76 NomicEncoderLayer(NomicEncoderLayerStage),
78 Qwen3Decoder(Qwen3DecoderStage),
80 Qwen3DecodeLayer(Qwen3DecodeLayerStage),
82 VitSelfAttn(VitSelfAttnStage),
84 LayerScale(LayerScaleStage),
86 VisionSwiGluFfn(VisionSwiGluFfnStage),
88 ClsTokenPool(ClsTokenPoolStage),
90 LayerNorm(LayerNormStage),
92 GeluFfn(GeluFfnStage),
94 GatherFromInput(GatherFromInputStage),
96 GatherAdd(GatherAddStage),
98 EmbedScale(EmbedScaleStage),
100 GemmaRmsNorm(GemmaRmsNormStage),
102 LogitSoftcap(LogitSoftcapStage),
104 DecodeRopeParams(DecodeRopeParamsStage),
106 GemmaDecodeLayer(GemmaDecodeLayerStage),
108 GemmaKvTap(GemmaKvTapStage),
110 GeGlu(GeGluStage),
112}
113
114impl FlowStage {
115 pub(crate) fn emit(
116 &self,
117 ctx: &mut FlowCtx<'_>,
118 input: Option<FlowValue>,
119 ) -> Result<Option<FlowValue>> {
120 match self {
121 FlowStage::Embed(s) => {
122 let input = input.ok_or_else(|| anyhow::anyhow!("Embed requires input"))?;
123 s.emit(ctx, input)
124 }
125 FlowStage::RopeTables(s) => {
126 s.emit(ctx)?;
127 Ok(input)
128 }
129 FlowStage::ZeroBeta { name, len } => {
130 let id = ctx.synth_zeros(name, *len);
131 ctx.state.named.insert(name.clone(), id);
132 if ctx.state.zero_beta.is_none() {
133 ctx.state.zero_beta = Some(id);
134 }
135 Ok(input)
136 }
137 FlowStage::BindDecodeInputs(s) => {
138 s.emit(ctx)?;
139 Ok(input)
140 }
141 FlowStage::AttnMask(s) => {
142 s.emit(ctx)?;
143 Ok(input)
144 }
145 FlowStage::LlamaDecodeLayer(s) => {
146 let input =
147 input.ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires input"))?;
148 s.emit(ctx, input)
149 }
150 FlowStage::LlamaDecoder(s) => {
151 let input = input.ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires input"))?;
152 s.emit(ctx, input)
153 }
154 FlowStage::LlamaKvTap(s) => {
155 let input = input.ok_or_else(|| anyhow::anyhow!("LlamaKvTap requires input"))?;
156 s.emit(ctx, input.clone())?;
157 Ok(Some(input))
158 }
159 FlowStage::Repeat(s) => s.emit(ctx, input),
160 FlowStage::Named { name, inner } => {
161 let input = input.ok_or_else(|| anyhow::anyhow!("Named block requires input"))?;
162 let out = inner.emit(ctx, Some(input))?;
163 let value = out.expect("named inner stage produced no output");
164 ctx.hir().node_mut(value.id).name = Some(name.clone());
165 Ok(Some(value))
166 }
167 FlowStage::Sequence(stages) => {
168 let mut value = input;
169 for stage in stages {
170 value = stage.emit(ctx, value)?;
171 }
172 Ok(value)
173 }
174 FlowStage::RmsNorm(s) => {
175 let input = input.ok_or_else(|| anyhow::anyhow!("RmsNorm requires input"))?;
176 s.emit(ctx, input)
177 }
178 FlowStage::GatherLastToken(s) => {
179 let input =
180 input.ok_or_else(|| anyhow::anyhow!("GatherLastToken requires input"))?;
181 s.emit(ctx, input)
182 }
183 FlowStage::LmHead(s) => {
184 let input = input.ok_or_else(|| anyhow::anyhow!("LmHead requires input"))?;
185 s.emit(ctx, input)
186 }
187 FlowStage::Linear(s) => {
188 let input = input.ok_or_else(|| anyhow::anyhow!("Linear requires input"))?;
189 s.emit(ctx, input)
190 }
191 FlowStage::ResidualSave(s) => {
192 let input = input.ok_or_else(|| anyhow::anyhow!("ResidualSave requires input"))?;
193 s.emit(ctx, input)
194 }
195 FlowStage::ResidualAdd(s) => {
196 let input = input.ok_or_else(|| anyhow::anyhow!("ResidualAdd requires input"))?;
197 s.emit(ctx, input)
198 }
199 FlowStage::SwiGlu(s) => {
200 let input = input.ok_or_else(|| anyhow::anyhow!("SwiGlu requires input"))?;
201 s.emit(ctx, input)
202 }
203 FlowStage::SelfAttnPrefill(s) => {
204 let input =
205 input.ok_or_else(|| anyhow::anyhow!("SelfAttnPrefill requires input"))?;
206 s.emit(ctx, input)
207 }
208 FlowStage::GdnScan(s) => {
209 let input = input.ok_or_else(|| anyhow::anyhow!("GdnScan requires input"))?;
210 s.emit(ctx, input)
211 }
212 FlowStage::StoreStream(s) => s.emit(ctx, input),
213 FlowStage::LoadStream(s) => s.emit(ctx, input),
214 FlowStage::DualStream(s) => s.emit(ctx, input),
215 FlowStage::Custom(s) => s.emit(ctx, input),
216 FlowStage::BertEncoderLayer(s) => {
217 let input =
218 input.ok_or_else(|| anyhow::anyhow!("BertEncoderLayer requires input"))?;
219 s.emit(ctx, input)
220 }
221 FlowStage::NomicEncoderLayer(s) => {
222 let input =
223 input.ok_or_else(|| anyhow::anyhow!("NomicEncoderLayer requires input"))?;
224 s.emit(ctx, input)
225 }
226 FlowStage::Qwen3Decoder(s) => {
227 let input = input.ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires input"))?;
228 s.emit(ctx, input)
229 }
230 FlowStage::Qwen3DecodeLayer(s) => {
231 let input =
232 input.ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires input"))?;
233 s.emit(ctx, input)
234 }
235 FlowStage::VitSelfAttn(s) => {
236 let input = input.ok_or_else(|| anyhow::anyhow!("VitSelfAttn requires input"))?;
237 s.emit(ctx, input)
238 }
239 FlowStage::LayerScale(s) => {
240 let input = input.ok_or_else(|| anyhow::anyhow!("LayerScale requires input"))?;
241 s.emit(ctx, input)
242 }
243 FlowStage::VisionSwiGluFfn(s) => {
244 let input =
245 input.ok_or_else(|| anyhow::anyhow!("VisionSwiGluFfn requires input"))?;
246 s.emit(ctx, input)
247 }
248 FlowStage::ClsTokenPool(s) => {
249 let input = input.ok_or_else(|| anyhow::anyhow!("ClsTokenPool requires input"))?;
250 s.emit(ctx, input)
251 }
252 FlowStage::LayerNorm(s) => {
253 let input = input.ok_or_else(|| anyhow::anyhow!("LayerNorm requires input"))?;
254 s.emit(ctx, input)
255 }
256 FlowStage::GeluFfn(s) => {
257 let input = input.ok_or_else(|| anyhow::anyhow!("GeluFfn requires input"))?;
258 s.emit(ctx, input)
259 }
260 FlowStage::GatherFromInput(s) => {
261 let input =
262 input.ok_or_else(|| anyhow::anyhow!("GatherFromInput requires input"))?;
263 s.emit(ctx, input)
264 }
265 FlowStage::GatherAdd(s) => {
266 let input = input.ok_or_else(|| anyhow::anyhow!("GatherAdd requires input"))?;
267 s.emit(ctx, input)
268 }
269 FlowStage::EmbedScale(s) => {
270 let input = input.ok_or_else(|| anyhow::anyhow!("EmbedScale requires input"))?;
271 s.emit(ctx, input)
272 }
273 FlowStage::GemmaRmsNorm(s) => {
274 let input = input.ok_or_else(|| anyhow::anyhow!("GemmaRmsNorm requires input"))?;
275 s.emit(ctx, input)
276 }
277 FlowStage::LogitSoftcap(s) => {
278 let input = input.ok_or_else(|| anyhow::anyhow!("LogitSoftcap requires input"))?;
279 s.emit(ctx, input)
280 }
281 FlowStage::DecodeRopeParams(s) => {
282 s.emit(ctx)?;
283 Ok(input)
284 }
285 FlowStage::GemmaDecodeLayer(s) => {
286 let input =
287 input.ok_or_else(|| anyhow::anyhow!("GemmaDecodeLayer requires input"))?;
288 s.emit(ctx, input)
289 }
290 FlowStage::GemmaKvTap(s) => {
291 let input = input.ok_or_else(|| anyhow::anyhow!("GemmaKvTap requires input"))?;
292 s.emit(ctx, input)
293 }
294 FlowStage::GeGlu(s) => {
295 let input = input.ok_or_else(|| anyhow::anyhow!("GeGlu requires input"))?;
296 s.emit(ctx, input)
297 }
298 }
299 }
300}