1use std::sync::Arc;
7
8use anyhow::Result;
9
10use crate::blocks::{
11 AttnMaskStage, BertEncoderLayerStage, BindDecodeInputsStage, BlockStage, ClsTokenPoolStage,
12 CustomStage, EmbedStage, GatherAddStage, GatherFromInputStage, GatherLastTokenStage,
13 GdnScanStage, GeluFfnStage, LayerNormStage, LayerScaleStage, LinearStage,
14 LlamaDecodeLayerStage, LlamaDecoderStage, LlamaKvTapStage, LmHeadStage, NomicEncoderLayerStage,
15 Qwen3DecodeLayerStage, Qwen3DecoderStage, RepeatStage, ResidualAddStage, ResidualSaveStage,
16 RmsNormStage, RopeTablesStage, SelfAttnPrefillStage, SwiGluStage, VisionSwiGluFfnStage,
17 VitSelfAttnStage,
18};
19use crate::context::FlowCtx;
20use crate::stream::{DualStreamStage, LoadStreamStage, StoreStreamStage};
21use crate::value::FlowValue;
22#[derive(Debug, Clone)]
24pub enum FlowStage {
25 Embed(EmbedStage),
27 RopeTables(RopeTablesStage),
29 ZeroBeta { name: String, len: usize },
31 BindDecodeInputs(BindDecodeInputsStage),
33 AttnMask(AttnMaskStage),
35 LlamaDecodeLayer(LlamaDecodeLayerStage),
37 LlamaDecoder(LlamaDecoderStage),
39 LlamaKvTap(LlamaKvTapStage),
41 Repeat(RepeatStage),
43 Named { name: String, inner: Arc<FlowStage> },
45 Sequence(Vec<FlowStage>),
47 RmsNorm(RmsNormStage),
49 GatherLastToken(GatherLastTokenStage),
51 LmHead(LmHeadStage),
53 Linear(LinearStage),
55 ResidualSave(ResidualSaveStage),
57 ResidualAdd(ResidualAddStage),
59 SwiGlu(SwiGluStage),
61 SelfAttnPrefill(SelfAttnPrefillStage),
63 GdnScan(GdnScanStage),
65 StoreStream(StoreStreamStage),
67 LoadStream(LoadStreamStage),
69 DualStream(DualStreamStage),
71 Custom(CustomStage),
73 BertEncoderLayer(BertEncoderLayerStage),
75 NomicEncoderLayer(NomicEncoderLayerStage),
77 Qwen3Decoder(Qwen3DecoderStage),
79 Qwen3DecodeLayer(Qwen3DecodeLayerStage),
81 VitSelfAttn(VitSelfAttnStage),
83 LayerScale(LayerScaleStage),
85 VisionSwiGluFfn(VisionSwiGluFfnStage),
87 ClsTokenPool(ClsTokenPoolStage),
89 LayerNorm(LayerNormStage),
91 GeluFfn(GeluFfnStage),
93 GatherFromInput(GatherFromInputStage),
95 GatherAdd(GatherAddStage),
97}
98
99impl FlowStage {
100 pub(crate) fn emit(
101 &self,
102 ctx: &mut FlowCtx<'_>,
103 input: Option<FlowValue>,
104 ) -> Result<Option<FlowValue>> {
105 match self {
106 FlowStage::Embed(s) => {
107 let input = input.ok_or_else(|| anyhow::anyhow!("Embed requires input"))?;
108 s.emit(ctx, input)
109 }
110 FlowStage::RopeTables(s) => {
111 s.emit(ctx)?;
112 Ok(input)
113 }
114 FlowStage::ZeroBeta { name, len } => {
115 let id = ctx.synth_zeros(name, *len);
116 ctx.state.named.insert(name.clone(), id);
117 if ctx.state.zero_beta.is_none() {
118 ctx.state.zero_beta = Some(id);
119 }
120 Ok(input)
121 }
122 FlowStage::BindDecodeInputs(s) => {
123 s.emit(ctx)?;
124 Ok(input)
125 }
126 FlowStage::AttnMask(s) => {
127 s.emit(ctx)?;
128 Ok(input)
129 }
130 FlowStage::LlamaDecodeLayer(s) => {
131 let input =
132 input.ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires input"))?;
133 s.emit(ctx, input)
134 }
135 FlowStage::LlamaDecoder(s) => {
136 let input = input.ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires input"))?;
137 s.emit(ctx, input)
138 }
139 FlowStage::LlamaKvTap(s) => {
140 let input = input.ok_or_else(|| anyhow::anyhow!("LlamaKvTap requires input"))?;
141 s.emit(ctx, input.clone())?;
142 Ok(Some(input))
143 }
144 FlowStage::Repeat(s) => s.emit(ctx, input),
145 FlowStage::Named { name, inner } => {
146 let input = input.ok_or_else(|| anyhow::anyhow!("Named block requires input"))?;
147 let out = inner.emit(ctx, Some(input))?;
148 let value = out.expect("named inner stage produced no output");
149 ctx.hir().node_mut(value.id).name = Some(name.clone());
150 Ok(Some(value))
151 }
152 FlowStage::Sequence(stages) => {
153 let mut value = input;
154 for stage in stages {
155 value = stage.emit(ctx, value)?;
156 }
157 Ok(value)
158 }
159 FlowStage::RmsNorm(s) => {
160 let input = input.ok_or_else(|| anyhow::anyhow!("RmsNorm requires input"))?;
161 s.emit(ctx, input)
162 }
163 FlowStage::GatherLastToken(s) => {
164 let input =
165 input.ok_or_else(|| anyhow::anyhow!("GatherLastToken requires input"))?;
166 s.emit(ctx, input)
167 }
168 FlowStage::LmHead(s) => {
169 let input = input.ok_or_else(|| anyhow::anyhow!("LmHead requires input"))?;
170 s.emit(ctx, input)
171 }
172 FlowStage::Linear(s) => {
173 let input = input.ok_or_else(|| anyhow::anyhow!("Linear requires input"))?;
174 s.emit(ctx, input)
175 }
176 FlowStage::ResidualSave(s) => {
177 let input = input.ok_or_else(|| anyhow::anyhow!("ResidualSave requires input"))?;
178 s.emit(ctx, input)
179 }
180 FlowStage::ResidualAdd(s) => {
181 let input = input.ok_or_else(|| anyhow::anyhow!("ResidualAdd requires input"))?;
182 s.emit(ctx, input)
183 }
184 FlowStage::SwiGlu(s) => {
185 let input = input.ok_or_else(|| anyhow::anyhow!("SwiGlu requires input"))?;
186 s.emit(ctx, input)
187 }
188 FlowStage::SelfAttnPrefill(s) => {
189 let input =
190 input.ok_or_else(|| anyhow::anyhow!("SelfAttnPrefill requires input"))?;
191 s.emit(ctx, input)
192 }
193 FlowStage::GdnScan(s) => {
194 let input = input.ok_or_else(|| anyhow::anyhow!("GdnScan requires input"))?;
195 s.emit(ctx, input)
196 }
197 FlowStage::StoreStream(s) => s.emit(ctx, input),
198 FlowStage::LoadStream(s) => s.emit(ctx, input),
199 FlowStage::DualStream(s) => s.emit(ctx, input),
200 FlowStage::Custom(s) => s.emit(ctx, input),
201 FlowStage::BertEncoderLayer(s) => {
202 let input =
203 input.ok_or_else(|| anyhow::anyhow!("BertEncoderLayer requires input"))?;
204 s.emit(ctx, input)
205 }
206 FlowStage::NomicEncoderLayer(s) => {
207 let input =
208 input.ok_or_else(|| anyhow::anyhow!("NomicEncoderLayer requires input"))?;
209 s.emit(ctx, input)
210 }
211 FlowStage::Qwen3Decoder(s) => {
212 let input = input.ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires input"))?;
213 s.emit(ctx, input)
214 }
215 FlowStage::Qwen3DecodeLayer(s) => {
216 let input =
217 input.ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires input"))?;
218 s.emit(ctx, input)
219 }
220 FlowStage::VitSelfAttn(s) => {
221 let input = input.ok_or_else(|| anyhow::anyhow!("VitSelfAttn requires input"))?;
222 s.emit(ctx, input)
223 }
224 FlowStage::LayerScale(s) => {
225 let input = input.ok_or_else(|| anyhow::anyhow!("LayerScale requires input"))?;
226 s.emit(ctx, input)
227 }
228 FlowStage::VisionSwiGluFfn(s) => {
229 let input =
230 input.ok_or_else(|| anyhow::anyhow!("VisionSwiGluFfn requires input"))?;
231 s.emit(ctx, input)
232 }
233 FlowStage::ClsTokenPool(s) => {
234 let input = input.ok_or_else(|| anyhow::anyhow!("ClsTokenPool requires input"))?;
235 s.emit(ctx, input)
236 }
237 FlowStage::LayerNorm(s) => {
238 let input = input.ok_or_else(|| anyhow::anyhow!("LayerNorm requires input"))?;
239 s.emit(ctx, input)
240 }
241 FlowStage::GeluFfn(s) => {
242 let input = input.ok_or_else(|| anyhow::anyhow!("GeluFfn requires input"))?;
243 s.emit(ctx, input)
244 }
245 FlowStage::GatherFromInput(s) => {
246 let input =
247 input.ok_or_else(|| anyhow::anyhow!("GatherFromInput requires input"))?;
248 s.emit(ctx, input)
249 }
250 FlowStage::GatherAdd(s) => {
251 let input = input.ok_or_else(|| anyhow::anyhow!("GatherAdd requires input"))?;
252 s.emit(ctx, input)
253 }
254 }
255 }
256}