use std::sync::Arc;
use anyhow::Result;
use crate::blocks::{
AttnMaskStage, BertEncoderLayerStage, BindDecodeInputsStage, BlockStage, ClsTokenPoolStage,
CustomStage, EmbedStage, GatherAddStage, GatherFromInputStage, GatherLastTokenStage,
GdnScanStage, GeluFfnStage, LayerNormStage, LayerScaleStage, LinearStage,
LlamaDecodeLayerStage, LlamaDecoderStage, LlamaKvTapStage, LmHeadStage, NomicEncoderLayerStage,
Qwen3DecodeLayerStage, Qwen3DecoderStage, RepeatStage, ResidualAddStage, ResidualSaveStage,
RmsNormStage, RopeTablesStage, SelfAttnPrefillStage, SwiGluStage, VisionSwiGluFfnStage,
VitSelfAttnStage,
};
use crate::context::FlowCtx;
use crate::stream::{DualStreamStage, LoadStreamStage, StoreStreamStage};
use crate::value::FlowValue;
/// One stage of a flow pipeline.
///
/// Stages are emitted through `FlowStage::emit`, which threads an optional
/// `FlowValue` from one stage to the next. Variants that wrap a `*Stage` type
/// delegate to that type's own `emit`.
#[derive(Debug, Clone)]
pub enum FlowStage {
    /// Requires an input value.
    Embed(EmbedStage),
    /// Emits using only the context; the input passes through unchanged.
    RopeTables(RopeTablesStage),
    /// Synthesizes a zero-filled value named `name` of length `len`, registers it
    /// under `name` in the context state, and records it as the context's zero beta
    /// if none has been recorded yet. The input passes through unchanged.
    ZeroBeta { name: String, len: usize },
    /// Emits using only the context; the input passes through unchanged.
    BindDecodeInputs(BindDecodeInputsStage),
    /// Emits using only the context; the input passes through unchanged.
    AttnMask(AttnMaskStage),
    /// Requires an input value.
    LlamaDecodeLayer(LlamaDecodeLayerStage),
    /// Requires an input value.
    LlamaDecoder(LlamaDecoderStage),
    /// Requires an input value; emits with a copy of it and then passes the
    /// original input on unchanged.
    LlamaKvTap(LlamaKvTapStage),
    /// Receives the optional input as-is.
    Repeat(RepeatStage),
    /// Emits `inner` on the (required) input, then sets the HIR node name of the
    /// resulting output to `name`. The inner stage is expected to produce an output.
    Named { name: String, inner: Arc<FlowStage> },
    /// Emits each stage in order, feeding every stage the previous stage's output.
    Sequence(Vec<FlowStage>),
    /// Requires an input value.
    RmsNorm(RmsNormStage),
    /// Requires an input value.
    GatherLastToken(GatherLastTokenStage),
    /// Requires an input value.
    LmHead(LmHeadStage),
    /// Requires an input value.
    Linear(LinearStage),
    /// Requires an input value.
    ResidualSave(ResidualSaveStage),
    /// Requires an input value.
    ResidualAdd(ResidualAddStage),
    /// Requires an input value.
    SwiGlu(SwiGluStage),
    /// Requires an input value.
    SelfAttnPrefill(SelfAttnPrefillStage),
    /// Requires an input value.
    GdnScan(GdnScanStage),
    /// Receives the optional input as-is.
    StoreStream(StoreStreamStage),
    /// Receives the optional input as-is.
    LoadStream(LoadStreamStage),
    /// Receives the optional input as-is.
    DualStream(DualStreamStage),
    /// Receives the optional input as-is.
    Custom(CustomStage),
    /// Requires an input value.
    BertEncoderLayer(BertEncoderLayerStage),
    /// Requires an input value.
    NomicEncoderLayer(NomicEncoderLayerStage),
    /// Requires an input value.
    Qwen3Decoder(Qwen3DecoderStage),
    /// Requires an input value.
    Qwen3DecodeLayer(Qwen3DecodeLayerStage),
    /// Requires an input value.
    VitSelfAttn(VitSelfAttnStage),
    /// Requires an input value.
    LayerScale(LayerScaleStage),
    /// Requires an input value.
    VisionSwiGluFfn(VisionSwiGluFfnStage),
    /// Requires an input value.
    ClsTokenPool(ClsTokenPoolStage),
    /// Requires an input value.
    LayerNorm(LayerNormStage),
    /// Requires an input value.
    GeluFfn(GeluFfnStage),
    /// Requires an input value.
    GatherFromInput(GatherFromInputStage),
    /// Requires an input value.
    GatherAdd(GatherAddStage),
}
impl FlowStage {
pub(crate) fn emit(
&self,
ctx: &mut FlowCtx<'_>,
input: Option<FlowValue>,
) -> Result<Option<FlowValue>> {
match self {
FlowStage::Embed(s) => {
let input = input.ok_or_else(|| anyhow::anyhow!("Embed requires input"))?;
s.emit(ctx, input)
}
FlowStage::RopeTables(s) => {
s.emit(ctx)?;
Ok(input)
}
FlowStage::ZeroBeta { name, len } => {
let id = ctx.synth_zeros(name, *len);
ctx.state.named.insert(name.clone(), id);
if ctx.state.zero_beta.is_none() {
ctx.state.zero_beta = Some(id);
}
Ok(input)
}
FlowStage::BindDecodeInputs(s) => {
s.emit(ctx)?;
Ok(input)
}
FlowStage::AttnMask(s) => {
s.emit(ctx)?;
Ok(input)
}
FlowStage::LlamaDecodeLayer(s) => {
let input =
input.ok_or_else(|| anyhow::anyhow!("LlamaDecodeLayer requires input"))?;
s.emit(ctx, input)
}
FlowStage::LlamaDecoder(s) => {
let input = input.ok_or_else(|| anyhow::anyhow!("LlamaDecoder requires input"))?;
s.emit(ctx, input)
}
FlowStage::LlamaKvTap(s) => {
let input = input.ok_or_else(|| anyhow::anyhow!("LlamaKvTap requires input"))?;
s.emit(ctx, input.clone())?;
Ok(Some(input))
}
FlowStage::Repeat(s) => s.emit(ctx, input),
FlowStage::Named { name, inner } => {
let input = input.ok_or_else(|| anyhow::anyhow!("Named block requires input"))?;
let out = inner.emit(ctx, Some(input))?;
let value = out.expect("named inner stage produced no output");
ctx.hir().node_mut(value.id).name = Some(name.clone());
Ok(Some(value))
}
FlowStage::Sequence(stages) => {
let mut value = input;
for stage in stages {
value = stage.emit(ctx, value)?;
}
Ok(value)
}
FlowStage::RmsNorm(s) => {
let input = input.ok_or_else(|| anyhow::anyhow!("RmsNorm requires input"))?;
s.emit(ctx, input)
}
FlowStage::GatherLastToken(s) => {
let input =
input.ok_or_else(|| anyhow::anyhow!("GatherLastToken requires input"))?;
s.emit(ctx, input)
}
FlowStage::LmHead(s) => {
let input = input.ok_or_else(|| anyhow::anyhow!("LmHead requires input"))?;
s.emit(ctx, input)
}
FlowStage::Linear(s) => {
let input = input.ok_or_else(|| anyhow::anyhow!("Linear requires input"))?;
s.emit(ctx, input)
}
FlowStage::ResidualSave(s) => {
let input = input.ok_or_else(|| anyhow::anyhow!("ResidualSave requires input"))?;
s.emit(ctx, input)
}
FlowStage::ResidualAdd(s) => {
let input = input.ok_or_else(|| anyhow::anyhow!("ResidualAdd requires input"))?;
s.emit(ctx, input)
}
FlowStage::SwiGlu(s) => {
let input = input.ok_or_else(|| anyhow::anyhow!("SwiGlu requires input"))?;
s.emit(ctx, input)
}
FlowStage::SelfAttnPrefill(s) => {
let input =
input.ok_or_else(|| anyhow::anyhow!("SelfAttnPrefill requires input"))?;
s.emit(ctx, input)
}
FlowStage::GdnScan(s) => {
let input = input.ok_or_else(|| anyhow::anyhow!("GdnScan requires input"))?;
s.emit(ctx, input)
}
FlowStage::StoreStream(s) => s.emit(ctx, input),
FlowStage::LoadStream(s) => s.emit(ctx, input),
FlowStage::DualStream(s) => s.emit(ctx, input),
FlowStage::Custom(s) => s.emit(ctx, input),
FlowStage::BertEncoderLayer(s) => {
let input =
input.ok_or_else(|| anyhow::anyhow!("BertEncoderLayer requires input"))?;
s.emit(ctx, input)
}
FlowStage::NomicEncoderLayer(s) => {
let input =
input.ok_or_else(|| anyhow::anyhow!("NomicEncoderLayer requires input"))?;
s.emit(ctx, input)
}
FlowStage::Qwen3Decoder(s) => {
let input = input.ok_or_else(|| anyhow::anyhow!("Qwen3Decoder requires input"))?;
s.emit(ctx, input)
}
FlowStage::Qwen3DecodeLayer(s) => {
let input =
input.ok_or_else(|| anyhow::anyhow!("Qwen3DecodeLayer requires input"))?;
s.emit(ctx, input)
}
FlowStage::VitSelfAttn(s) => {
let input = input.ok_or_else(|| anyhow::anyhow!("VitSelfAttn requires input"))?;
s.emit(ctx, input)
}
FlowStage::LayerScale(s) => {
let input = input.ok_or_else(|| anyhow::anyhow!("LayerScale requires input"))?;
s.emit(ctx, input)
}
FlowStage::VisionSwiGluFfn(s) => {
let input =
input.ok_or_else(|| anyhow::anyhow!("VisionSwiGluFfn requires input"))?;
s.emit(ctx, input)
}
FlowStage::ClsTokenPool(s) => {
let input = input.ok_or_else(|| anyhow::anyhow!("ClsTokenPool requires input"))?;
s.emit(ctx, input)
}
FlowStage::LayerNorm(s) => {
let input = input.ok_or_else(|| anyhow::anyhow!("LayerNorm requires input"))?;
s.emit(ctx, input)
}
FlowStage::GeluFfn(s) => {
let input = input.ok_or_else(|| anyhow::anyhow!("GeluFfn requires input"))?;
s.emit(ctx, input)
}
FlowStage::GatherFromInput(s) => {
let input =
input.ok_or_else(|| anyhow::anyhow!("GatherFromInput requires input"))?;
s.emit(ctx, input)
}
FlowStage::GatherAdd(s) => {
let input = input.ok_or_else(|| anyhow::anyhow!("GatherAdd requires input"))?;
s.emit(ctx, input)
}
}
}
}