// rlx_ir/ops/blocks.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Fusion-friendly block builders: canonical subgraph shapes that the
//! optimizer passes in `rlx-opt` already recognize.
//!
//! Model authors should prefer these builders over hand-wiring
//! `MatMul → Add → Activation` chains, so that fusion succeeds regardless
//! of the order in which parameters are declared elsewhere in the graph.
22
23use crate::infer::GraphExt;
24use crate::op::Activation;
25use crate::{Graph, NodeId, Op, Shape};
26
27impl Graph {
28    /// Dense linear layer: `matmul(input, weight)` with optional rank-1 bias.
29    pub fn linear_bias(&mut self, input: NodeId, weight: NodeId, bias: Option<NodeId>) -> NodeId {
30        let mm = self.mm(input, weight);
31        match bias {
32            Some(b) => self.add(mm, b),
33            None => mm,
34        }
35    }
36
37    /// Dense linear with optional bias and epilogue activation.
38    pub fn linear_bias_act(
39        &mut self,
40        input: NodeId,
41        weight: NodeId,
42        bias: Option<NodeId>,
43        activation: Option<Activation>,
44    ) -> NodeId {
45        let x = self.linear_bias(input, weight, bias);
46        activation.map_or(x, |act| self.activation_by_kind(x, act))
47    }
48
49    /// Emit `Op::FusedMatMulBiasAct` directly — deterministic fusion
50    /// without relying on the `FuseMatMulBiasAct` pass.
51    pub fn linear_fused(
52        &mut self,
53        input: NodeId,
54        weight: NodeId,
55        bias: NodeId,
56        activation: Option<Activation>,
57        out_shape: Shape,
58    ) -> NodeId {
59        self.fused_matmul_bias_act(input, weight, bias, activation, out_shape)
60    }
61
62    /// Two matmuls sharing the same input — canonical gate+up / QKV
63    /// pattern for `FuseSharedInputMatMul`.
64    ///
65    /// Returns `(first, second)` in declaration order. For SwiGLU,
66    /// pass **up** weight first and **gate** weight second so the
67    /// post-concat narrow layout matches `FuseSwiGLU` (up @ 0, gate @ N).
68    pub fn shared_matmul_pair(
69        &mut self,
70        input: NodeId,
71        w_first: NodeId,
72        w_second: NodeId,
73    ) -> (NodeId, NodeId) {
74        let first = self.mm(input, w_first);
75        let second = self.mm(input, w_second);
76        (first, second)
77    }
78
79    /// SwiGLU FFN block: shared-input gate+up → `silu(gate) * up` → down proj.
80    ///
81    /// Weight order matches `FuseSwiGLU`'s canonical narrow layout
82    /// (up projection first, gate projection second).
83    pub fn swiglu_ffn(
84        &mut self,
85        input: NodeId,
86        up_w: NodeId,
87        gate_w: NodeId,
88        down_w: NodeId,
89    ) -> NodeId {
90        let (up, gate) = self.shared_matmul_pair(input, up_w, gate_w);
91        let gate_silu = self.silu(gate);
92        let hidden = self.mul(up, gate_silu);
93        self.mm(hidden, down_w)
94    }
95
96    /// Fully fused SwiGLU FFN: concat weights → single matmul →
97    /// [`Op::FusedSwiGLU`] → down projection. Matches the rewrite
98    /// performed by [`FuseSwiGLUDualMatmul`](../../rlx-opt/src/fusion.rs)
99    /// without relying on the pass.
100    pub fn fused_swiglu_ffn(
101        &mut self,
102        input: NodeId,
103        up_w: NodeId,
104        gate_w: NodeId,
105        down_w: NodeId,
106        out_shape: Shape,
107    ) -> NodeId {
108        let wu_shape = self.shape(up_w);
109        let wg_shape = self.shape(gate_w);
110        let k = wu_shape.dim(0).unwrap_static();
111        let n_up = wu_shape.dim(1).unwrap_static();
112        let n_gate = wg_shape.dim(1).unwrap_static();
113        debug_assert_eq!(wu_shape.dim(0), wg_shape.dim(0));
114
115        let concat_shape = Shape::new(&[k, n_up + n_gate], wu_shape.dtype());
116        let concat_w = self.concat(vec![up_w, gate_w], 1, concat_shape);
117
118        let input_shape = self.shape(input);
119        let out_rank = input_shape.rank();
120        let dtype = input_shape.dtype();
121        let mut cat_dims: Vec<usize> = (0..out_rank)
122            .map(|i| input_shape.dim(i).unwrap_static())
123            .collect();
124        cat_dims[out_rank - 1] = n_up + n_gate;
125        let cat_shape = Shape::new(&cat_dims, dtype);
126        let cat_mm = self.matmul(input, concat_w, cat_shape);
127
128        let mut hidden_dims = cat_dims;
129        hidden_dims[out_rank - 1] = n_up;
130        let hidden_shape = Shape::new(&hidden_dims, dtype);
131        let hidden = self.add_node(
132            Op::FusedSwiGLU {
133                cast_to: None,
134                gate_first: false,
135            },
136            vec![cat_mm],
137            hidden_shape,
138        );
139
140        let _ = out_shape;
141        self.mm(hidden, down_w)
142    }
143
144    fn activation_by_kind(&mut self, x: NodeId, act: Activation) -> NodeId {
145        match act {
146            Activation::Gelu => self.gelu(x),
147            Activation::GeluApprox => self.gelu_approx(x),
148            Activation::Silu => self.silu(x),
149            Activation::Relu => self.relu(x),
150            Activation::Exp => self.exp(x),
151            Activation::Sqrt => self.sqrt(x),
152            Activation::Neg => self.neg(x),
153            Activation::Tanh => self.tanh(x),
154            Activation::Sigmoid => {
155                let s = self.shape(x).clone();
156                self.activation(Activation::Sigmoid, x, s)
157            }
158            Activation::Log => {
159                let s = self.shape(x).clone();
160                self.activation(Activation::Log, x, s)
161            }
162            _ => {
163                let s = self.shape(x).clone();
164                self.activation(act, x, s)
165            }
166        }
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::op::BinaryOp;
    use crate::{DType, Op};

    /// Convenience constructor for an `F32` shape with the given dims.
    fn f32_shape(dims: &[usize]) -> Shape {
        Shape::new(dims, DType::F32)
    }

    #[test]
    fn linear_bias_act_emits_canonical_chain() {
        let mut g = Graph::new("linear");
        let x = g.input("x", f32_shape(&[4, 8]));
        let w = g.param("w", f32_shape(&[8, 16]));
        let b = g.param("b", f32_shape(&[16]));
        let out = g.linear_bias_act(x, w, Some(b), Some(Activation::Silu));
        g.set_outputs(vec![out]);

        let act = g.node(out);
        assert!(matches!(act.op, Op::Activation(Activation::Silu)));
        let add = g.node(act.inputs[0]);
        assert!(matches!(add.op, Op::Binary(BinaryOp::Add)));
        let mm = g.node(add.inputs[0]);
        assert!(matches!(mm.op, Op::MatMul));
    }

    #[test]
    fn linear_bias_without_bias_is_bare_matmul() {
        let mut g = Graph::new("linear_no_bias");
        let x = g.input("x", f32_shape(&[4, 8]));
        let w = g.param("w", f32_shape(&[8, 16]));
        let out = g.linear_bias(x, w, None);
        g.set_outputs(vec![out]);

        assert!(matches!(g.node(out).op, Op::MatMul));
    }

    #[test]
    fn linear_bias_act_uses_generic_activation_fallback() {
        let mut g = Graph::new("linear_sigmoid");
        let x = g.input("x", f32_shape(&[4, 8]));
        let w = g.param("w", f32_shape(&[8, 16]));
        let out = g.linear_bias_act(x, w, None, Some(Activation::Sigmoid));
        g.set_outputs(vec![out]);

        let act = g.node(out);
        assert!(matches!(act.op, Op::Activation(Activation::Sigmoid)));
        assert!(matches!(g.node(act.inputs[0]).op, Op::MatMul));
    }

    #[test]
    fn fused_swiglu_ffn_emits_cat_matmul_fused_down_chain() {
        let mut g = Graph::new("swiglu");
        let x = g.input("x", f32_shape(&[4, 8]));
        let up = g.param("up", f32_shape(&[8, 16]));
        let gate = g.param("gate", f32_shape(&[8, 16]));
        let down = g.param("down", f32_shape(&[16, 8]));
        let out = g.fused_swiglu_ffn(x, up, gate, down, f32_shape(&[4, 8]));
        g.set_outputs(vec![out]);

        // Down projection: [4, 16] @ [16, 8] -> [4, 8].
        let down_mm = g.node(out);
        assert!(matches!(down_mm.op, Op::MatMul));
        assert_eq!(g.shape(out).dim(1).unwrap_static(), 8);

        // Fused SwiGLU halves the concatenated width back to n_up.
        let hidden = down_mm.inputs[0];
        assert!(matches!(
            g.node(hidden).op,
            Op::FusedSwiGLU { gate_first: false, .. }
        ));
        assert_eq!(g.shape(hidden).dim(1).unwrap_static(), 16);

        // Single matmul against the concatenated [8, 32] up|gate weight.
        let cat_mm = g.node(hidden).inputs[0];
        assert!(matches!(g.node(cat_mm).op, Op::MatMul));
        assert_eq!(g.shape(cat_mm).dim(1).unwrap_static(), 32);
    }
}