Skip to main content

rlx_flow/
layer.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! Fluent per-layer composer — stack small blocks without IR/Graph imports.
5
6use std::sync::Arc;
7
8use crate::blocks::{
9    GatherAddStage, LayerNormStage, LinearStage, ResidualAddStage, ResidualSaveStage, RmsNormStage,
10    SelfAttnPrefillSpec, SelfAttnPrefillStage, SwiGluStage,
11};
12use crate::stage::FlowStage;
13
14/// Stack transformer sub-blocks into one named layer stage.
15#[derive(Debug, Clone, Default)]
16pub struct LayerStack {
17    name: Option<String>,
18    stages: Vec<FlowStage>,
19}
20
21impl LayerStack {
22    pub fn new() -> Self {
23        Self::default()
24    }
25
26    pub fn named(name: impl Into<String>) -> Self {
27        Self {
28            name: Some(name.into()),
29            stages: Vec::new(),
30        }
31    }
32
33    pub fn layer_norm(
34        mut self,
35        gamma_key: impl Into<String>,
36        beta_key: impl Into<String>,
37        eps: f32,
38    ) -> Self {
39        self.stages.push(FlowStage::LayerNorm(LayerNormStage::new(
40            gamma_key, beta_key, eps,
41        )));
42        self
43    }
44
45    pub fn gather_add(
46        mut self,
47        input_name: impl Into<String>,
48        weight_key: impl Into<String>,
49    ) -> Self {
50        self.stages.push(FlowStage::GatherAdd(GatherAddStage::new(
51            input_name, weight_key, 0,
52        )));
53        self
54    }
55
56    pub fn rms_norm(mut self, weight_key: impl Into<String>, eps: f32) -> Self {
57        self.stages
58            .push(FlowStage::RmsNorm(RmsNormStage::new(weight_key, eps)));
59        self
60    }
61
62    pub fn linear(mut self, weight_key: impl Into<String>, transpose: bool) -> Self {
63        self.stages
64            .push(FlowStage::Linear(LinearStage::new(weight_key, transpose)));
65        self
66    }
67
68    pub fn residual_save(mut self) -> Self {
69        self.stages.push(FlowStage::ResidualSave(ResidualSaveStage));
70        self
71    }
72
73    pub fn residual_add(mut self) -> Self {
74        self.stages.push(FlowStage::ResidualAdd(ResidualAddStage));
75        self
76    }
77
78    pub fn swiglu(
79        mut self,
80        gate_key: impl Into<String>,
81        up_key: impl Into<String>,
82        down_key: impl Into<String>,
83    ) -> Self {
84        self.stages.push(FlowStage::SwiGlu(SwiGluStage::new(
85            gate_key, up_key, down_key,
86        )));
87        self
88    }
89
90    pub fn swiglu_hf_mlp(mut self, prefix: impl Into<String>) -> Self {
91        self.stages
92            .push(FlowStage::SwiGlu(SwiGluStage::hf_mlp(prefix)));
93        self
94    }
95
96    pub fn self_attn_prefill(mut self, spec: SelfAttnPrefillSpec) -> Self {
97        self.stages
98            .push(FlowStage::SelfAttnPrefill(SelfAttnPrefillStage::new(spec)));
99        self
100    }
101
102    pub fn stage(mut self, stage: FlowStage) -> Self {
103        self.stages.push(stage);
104        self
105    }
106
107    pub fn stages(mut self, stages: impl IntoIterator<Item = FlowStage>) -> Self {
108        self.stages.extend(stages);
109        self
110    }
111
112    pub fn build(self) -> FlowStage {
113        let inner = FlowStage::Sequence(self.stages);
114        match self.name {
115            Some(name) => FlowStage::Named {
116                name,
117                inner: Arc::new(inner),
118            },
119            None => inner,
120        }
121    }
122}