Skip to main content

rlx_flow/
layer.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Fluent per-layer composer — stack small blocks without IR/Graph imports.
17
18use std::sync::Arc;
19
20use crate::blocks::{
21    GatherAddStage, LayerNormStage, LinearStage, ResidualAddStage, ResidualSaveStage, RmsNormStage,
22    SelfAttnPrefillSpec, SelfAttnPrefillStage, SwiGluStage,
23};
24use crate::stage::FlowStage;
25
26/// Stack transformer sub-blocks into one named layer stage.
27#[derive(Debug, Clone, Default)]
28pub struct LayerStack {
29    name: Option<String>,
30    stages: Vec<FlowStage>,
31}
32
33impl LayerStack {
34    pub fn new() -> Self {
35        Self::default()
36    }
37
38    pub fn named(name: impl Into<String>) -> Self {
39        Self {
40            name: Some(name.into()),
41            stages: Vec::new(),
42        }
43    }
44
45    pub fn layer_norm(
46        mut self,
47        gamma_key: impl Into<String>,
48        beta_key: impl Into<String>,
49        eps: f32,
50    ) -> Self {
51        self.stages.push(FlowStage::LayerNorm(LayerNormStage::new(
52            gamma_key, beta_key, eps,
53        )));
54        self
55    }
56
57    pub fn gather_add(
58        mut self,
59        input_name: impl Into<String>,
60        weight_key: impl Into<String>,
61    ) -> Self {
62        self.stages.push(FlowStage::GatherAdd(GatherAddStage::new(
63            input_name, weight_key, 0,
64        )));
65        self
66    }
67
68    pub fn rms_norm(mut self, weight_key: impl Into<String>, eps: f32) -> Self {
69        self.stages
70            .push(FlowStage::RmsNorm(RmsNormStage::new(weight_key, eps)));
71        self
72    }
73
74    pub fn linear(mut self, weight_key: impl Into<String>, transpose: bool) -> Self {
75        self.stages
76            .push(FlowStage::Linear(LinearStage::new(weight_key, transpose)));
77        self
78    }
79
80    pub fn residual_save(mut self) -> Self {
81        self.stages.push(FlowStage::ResidualSave(ResidualSaveStage));
82        self
83    }
84
85    pub fn residual_add(mut self) -> Self {
86        self.stages.push(FlowStage::ResidualAdd(ResidualAddStage));
87        self
88    }
89
90    pub fn swiglu(
91        mut self,
92        gate_key: impl Into<String>,
93        up_key: impl Into<String>,
94        down_key: impl Into<String>,
95    ) -> Self {
96        self.stages.push(FlowStage::SwiGlu(SwiGluStage::new(
97            gate_key, up_key, down_key,
98        )));
99        self
100    }
101
102    pub fn swiglu_hf_mlp(mut self, prefix: impl Into<String>) -> Self {
103        self.stages
104            .push(FlowStage::SwiGlu(SwiGluStage::hf_mlp(prefix)));
105        self
106    }
107
108    pub fn self_attn_prefill(mut self, spec: SelfAttnPrefillSpec) -> Self {
109        self.stages
110            .push(FlowStage::SelfAttnPrefill(SelfAttnPrefillStage::new(spec)));
111        self
112    }
113
114    pub fn stage(mut self, stage: FlowStage) -> Self {
115        self.stages.push(stage);
116        self
117    }
118
119    pub fn stages(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
120        self.stages.extend(stages);
121        self
122    }
123
124    pub fn build(self) -> FlowStage {
125        let inner = FlowStage::Sequence(self.stages);
126        match self.name {
127            Some(name) => FlowStage::Named {
128                name,
129                inner: Arc::new(inner),
130            },
131            None => inner,
132        }
133    }
134}