1use std::sync::Arc;
7
8use crate::blocks::{
9 GatherAddStage, LayerNormStage, LinearStage, ResidualAddStage, ResidualSaveStage, RmsNormStage,
10 SelfAttnPrefillSpec, SelfAttnPrefillStage, SwiGluStage,
11};
12use crate::stage::FlowStage;
13
14#[derive(Debug, Clone, Default)]
16pub struct LayerStack {
17 name: Option<String>,
18 stages: Vec<FlowStage>,
19}
20
21impl LayerStack {
22 pub fn new() -> Self {
23 Self::default()
24 }
25
26 pub fn named(name: impl Into<String>) -> Self {
27 Self {
28 name: Some(name.into()),
29 stages: Vec::new(),
30 }
31 }
32
33 pub fn layer_norm(
34 mut self,
35 gamma_key: impl Into<String>,
36 beta_key: impl Into<String>,
37 eps: f32,
38 ) -> Self {
39 self.stages.push(FlowStage::LayerNorm(LayerNormStage::new(
40 gamma_key, beta_key, eps,
41 )));
42 self
43 }
44
45 pub fn gather_add(
46 mut self,
47 input_name: impl Into<String>,
48 weight_key: impl Into<String>,
49 ) -> Self {
50 self.stages.push(FlowStage::GatherAdd(GatherAddStage::new(
51 input_name, weight_key, 0,
52 )));
53 self
54 }
55
56 pub fn rms_norm(mut self, weight_key: impl Into<String>, eps: f32) -> Self {
57 self.stages
58 .push(FlowStage::RmsNorm(RmsNormStage::new(weight_key, eps)));
59 self
60 }
61
62 pub fn linear(mut self, weight_key: impl Into<String>, transpose: bool) -> Self {
63 self.stages
64 .push(FlowStage::Linear(LinearStage::new(weight_key, transpose)));
65 self
66 }
67
68 pub fn residual_save(mut self) -> Self {
69 self.stages.push(FlowStage::ResidualSave(ResidualSaveStage));
70 self
71 }
72
73 pub fn residual_add(mut self) -> Self {
74 self.stages.push(FlowStage::ResidualAdd(ResidualAddStage));
75 self
76 }
77
78 pub fn swiglu(
79 mut self,
80 gate_key: impl Into<String>,
81 up_key: impl Into<String>,
82 down_key: impl Into<String>,
83 ) -> Self {
84 self.stages.push(FlowStage::SwiGlu(SwiGluStage::new(
85 gate_key, up_key, down_key,
86 )));
87 self
88 }
89
90 pub fn swiglu_hf_mlp(mut self, prefix: impl Into<String>) -> Self {
91 self.stages
92 .push(FlowStage::SwiGlu(SwiGluStage::hf_mlp(prefix)));
93 self
94 }
95
96 pub fn self_attn_prefill(mut self, spec: SelfAttnPrefillSpec) -> Self {
97 self.stages
98 .push(FlowStage::SelfAttnPrefill(SelfAttnPrefillStage::new(spec)));
99 self
100 }
101
102 pub fn stage(mut self, stage: FlowStage) -> Self {
103 self.stages.push(stage);
104 self
105 }
106
107 pub fn stages(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
108 self.stages.extend(stages);
109 self
110 }
111
112 pub fn build(self) -> FlowStage {
113 let inner = FlowStage::Sequence(self.stages);
114 match self.name {
115 Some(name) => FlowStage::Named {
116 name,
117 inner: Arc::new(inner),
118 },
119 None => inner,
120 }
121 }
122}