1use std::sync::Arc;
19
20use crate::blocks::{
21 GatherAddStage, LayerNormStage, LinearStage, ResidualAddStage, ResidualSaveStage, RmsNormStage,
22 SelfAttnPrefillSpec, SelfAttnPrefillStage, SwiGluStage,
23};
24use crate::stage::FlowStage;
25
26#[derive(Debug, Clone, Default)]
28pub struct LayerStack {
29 name: Option<String>,
30 stages: Vec<FlowStage>,
31}
32
33impl LayerStack {
34 pub fn new() -> Self {
35 Self::default()
36 }
37
38 pub fn named(name: impl Into<String>) -> Self {
39 Self {
40 name: Some(name.into()),
41 stages: Vec::new(),
42 }
43 }
44
45 pub fn layer_norm(
46 mut self,
47 gamma_key: impl Into<String>,
48 beta_key: impl Into<String>,
49 eps: f32,
50 ) -> Self {
51 self.stages.push(FlowStage::LayerNorm(LayerNormStage::new(
52 gamma_key, beta_key, eps,
53 )));
54 self
55 }
56
57 pub fn gather_add(
58 mut self,
59 input_name: impl Into<String>,
60 weight_key: impl Into<String>,
61 ) -> Self {
62 self.stages.push(FlowStage::GatherAdd(GatherAddStage::new(
63 input_name, weight_key, 0,
64 )));
65 self
66 }
67
68 pub fn rms_norm(mut self, weight_key: impl Into<String>, eps: f32) -> Self {
69 self.stages
70 .push(FlowStage::RmsNorm(RmsNormStage::new(weight_key, eps)));
71 self
72 }
73
74 pub fn linear(mut self, weight_key: impl Into<String>, transpose: bool) -> Self {
75 self.stages
76 .push(FlowStage::Linear(LinearStage::new(weight_key, transpose)));
77 self
78 }
79
80 pub fn residual_save(mut self) -> Self {
81 self.stages.push(FlowStage::ResidualSave(ResidualSaveStage));
82 self
83 }
84
85 pub fn residual_add(mut self) -> Self {
86 self.stages.push(FlowStage::ResidualAdd(ResidualAddStage));
87 self
88 }
89
90 pub fn swiglu(
91 mut self,
92 gate_key: impl Into<String>,
93 up_key: impl Into<String>,
94 down_key: impl Into<String>,
95 ) -> Self {
96 self.stages.push(FlowStage::SwiGlu(SwiGluStage::new(
97 gate_key, up_key, down_key,
98 )));
99 self
100 }
101
102 pub fn swiglu_hf_mlp(mut self, prefix: impl Into<String>) -> Self {
103 self.stages
104 .push(FlowStage::SwiGlu(SwiGluStage::hf_mlp(prefix)));
105 self
106 }
107
108 pub fn self_attn_prefill(mut self, spec: SelfAttnPrefillSpec) -> Self {
109 self.stages
110 .push(FlowStage::SelfAttnPrefill(SelfAttnPrefillStage::new(spec)));
111 self
112 }
113
114 pub fn stage(mut self, stage: FlowStage) -> Self {
115 self.stages.push(stage);
116 self
117 }
118
119 pub fn stages(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
120 self.stages.extend(stages);
121 self
122 }
123
124 pub fn build(self) -> FlowStage {
125 let inner = FlowStage::Sequence(self.stages);
126 match self.name {
127 Some(name) => FlowStage::Named {
128 name,
129 inner: Arc::new(inner),
130 },
131 None => inner,
132 }
133 }
134}