Skip to main content

rlx_compile/
const_fold.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Constant Folding — evaluate pure-input subgraphs at compile time.
17//!
18//! A node is foldable when all its inputs are foldable AND the op has
19//! a deterministic, pure evaluation (no I/O, no random). We evaluate
20//! such subgraphs once at compile time and replace them with `Op::Constant`.
21//!
22//! Examples that get folded:
23//! - `1.0 / sqrt(head_dim)` (attention scale factor)
24//! - `cast(known_constant)` to a different dtype (`Param`/`Input` values are runtime data and are never folded)
25//! - small reshapes/expands of constants
26//!
27//! For a typical transformer, constant folding only catches scattered
28//! arithmetic — but it eliminates 10–50 redundant ops over a 12-layer
29//! model and shrinks the arena slightly.
30
31use rlx_fusion::pass::Pass;
32use rlx_ir::op::{Activation, BinaryOp};
33use rlx_ir::{Graph, NodeId, Op};
34use std::collections::{HashMap, HashSet};
35
/// Compiler pass that evaluates pure subgraphs whose inputs are all
/// constants and replaces each evaluated node with an `Op::Constant`.
///
/// The pass carries no state; all work happens in its `Pass::run`
/// implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct ConstantFolding;
37
38/// True if this op can be evaluated symbolically with no runtime state.
39fn is_pure(op: &Op) -> bool {
40    matches!(
41        op,
42        Op::Activation(_)
43            | Op::Binary(_)
44            | Op::Compare(_)
45            | Op::Reshape { .. }
46            | Op::Expand { .. }
47            | Op::Cast { .. }
48    )
49}
50
51/// True if the node's inputs are all known constants (Param, Constant, or
52/// previously-folded result).
53fn is_foldable(node_id: NodeId, graph: &Graph, folded: &HashSet<NodeId>) -> bool {
54    let node = graph.node(node_id);
55    if !is_pure(&node.op) {
56        return false;
57    }
58    node.inputs.iter().all(|i| folded.contains(i))
59}
60
61/// Evaluate a foldable node given precomputed input values.
62/// Returns a flat f32 buffer of the result, or None if not supported.
63fn evaluate(node: &rlx_ir::Node, inputs: &[&Vec<f32>]) -> Option<Vec<f32>> {
64    let total = node.shape.num_elements()?;
65    let mut out = vec![0f32; total];
66
67    match &node.op {
68        Op::Activation(act) => {
69            let x = inputs[0];
70            for (i, &v) in x.iter().enumerate() {
71                out[i] = match act {
72                    Activation::Gelu | Activation::GeluApprox => {
73                        v * 0.5 * (1.0 + (v * std::f32::consts::FRAC_1_SQRT_2).tanh())
74                    }
75                    Activation::Silu => v / (1.0 + (-v).exp()),
76                    Activation::Relu => v.max(0.0),
77                    Activation::Sigmoid => 1.0 / (1.0 + (-v).exp()),
78                    Activation::Tanh => v.tanh(),
79                    Activation::Exp => v.exp(),
80                    Activation::Log => v.ln(),
81                    Activation::Sqrt => v.sqrt(),
82                    Activation::Rsqrt => 1.0 / v.sqrt(),
83                    Activation::Neg => -v,
84                    Activation::Abs => v.abs(),
85                    Activation::Round => v.round(),
86                    Activation::Sin => v.sin(),
87                    Activation::Cos => v.cos(),
88                    Activation::Tan => v.tan(),
89                    Activation::Atan => v.atan(),
90                };
91            }
92            Some(out)
93        }
94        Op::Binary(op) => {
95            let lhs = inputs[0];
96            let rhs = inputs[1];
97            // Naive: support same-shape only. Broadcast handled later.
98            if lhs.len() != total || rhs.len() != total {
99                return None;
100            }
101            for i in 0..total {
102                out[i] = match op {
103                    BinaryOp::Add => lhs[i] + rhs[i],
104                    BinaryOp::Sub => lhs[i] - rhs[i],
105                    BinaryOp::Mul => lhs[i] * rhs[i],
106                    BinaryOp::Div => lhs[i] / rhs[i],
107                    BinaryOp::Max => lhs[i].max(rhs[i]),
108                    BinaryOp::Min => lhs[i].min(rhs[i]),
109                    BinaryOp::Pow => lhs[i].powf(rhs[i]),
110                };
111            }
112            Some(out)
113        }
114        Op::Reshape { .. } | Op::Expand { .. } | Op::Cast { .. } => {
115            // Same data, just reshape/cast. For now: copy through as f32.
116            let src = inputs[0];
117            if src.len() == total {
118                Some(src.clone())
119            } else if src.len() == 1 {
120                Some(vec![src[0]; total])
121            } else {
122                None
123            }
124        }
125        _ => None,
126    }
127}
128
/// Serialises an f32 slice into the raw byte payload of an `Op::Constant`:
/// each value contributes its four little-endian bytes, in slice order.
fn encode_constant(data: &[f32]) -> Vec<u8> {
    data.iter().flat_map(|value| value.to_le_bytes()).collect()
}
137
138impl Pass for ConstantFolding {
139    fn name(&self) -> &str {
140        "constant_folding"
141    }
142
143    fn run(&self, graph: Graph) -> Graph {
144        // Walk in topological order, tracking which nodes are foldable
145        // and accumulating their evaluated values.
146        let mut folded: HashSet<NodeId> = HashSet::new();
147        let mut values: HashMap<NodeId, Vec<f32>> = HashMap::new();
148
149        for node in graph.nodes() {
150            // Constant nodes are trivially foldable (we already have the data).
151            if let Op::Constant { data } = &node.op {
152                folded.insert(node.id);
153                let f32s: Vec<f32> = data
154                    .chunks_exact(4)
155                    .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
156                    .collect();
157                values.insert(node.id, f32s);
158                continue;
159            }
160            // Inputs/Params are NOT foldable (their values are runtime).
161            if matches!(node.op, Op::Input { .. } | Op::Param { .. }) {
162                continue;
163            }
164            // Try to fold pure ops with all-constant inputs.
165            if is_foldable(node.id, &graph, &folded) {
166                let inputs: Vec<&Vec<f32>> = node.inputs.iter().map(|i| &values[i]).collect();
167                if let Some(result) = evaluate(node, &inputs) {
168                    folded.insert(node.id);
169                    values.insert(node.id, result);
170                }
171            }
172        }
173
174        // Rebuild: replace folded nodes with Op::Constant, rewire others.
175        let mut new_graph = Graph::new(&graph.name);
176        let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
177        for node in graph.nodes() {
178            // Foldable downstream nodes get replaced with Constant unless
179            // they're terminal Constants/Params themselves.
180            if folded.contains(&node.id)
181                && !matches!(
182                    node.op,
183                    Op::Constant { .. } | Op::Param { .. } | Op::Input { .. }
184                )
185            {
186                let bytes = encode_constant(&values[&node.id]);
187                let new_id =
188                    new_graph.add_node(Op::Constant { data: bytes }, vec![], node.shape.clone());
189                id_map.insert(node.id, new_id);
190                continue;
191            }
192            // Otherwise copy the node, remapping inputs.
193            let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
194            let new_id = new_graph.add_node(node.op.clone(), new_inputs, node.shape.clone());
195            id_map.insert(node.id, new_id);
196        }
197        let new_outputs: Vec<NodeId> = graph.outputs.iter().map(|i| id_map[i]).collect();
198        new_graph.set_outputs(new_outputs);
199        new_graph
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::*;

    /// Adds a one-element f32 `Op::Constant` holding `value` to `g`.
    fn scalar_const(g: &mut Graph, value: f32) -> NodeId {
        g.add_node(
            Op::Constant {
                data: value.to_le_bytes().to_vec(),
            },
            vec![],
            Shape::new(&[1], DType::F32),
        )
    }

    /// const(2.0) + const(3.0) must collapse into a single const(5.0).
    #[test]
    fn folds_constant_arithmetic() {
        let mut g = Graph::new("test");
        let two = scalar_const(&mut g, 2.0);
        let three = scalar_const(&mut g, 3.0);
        let sum = g.binary(op::BinaryOp::Add, two, three, Shape::new(&[1], DType::F32));
        g.set_outputs(vec![sum]);

        let folded = ConstantFolding.run(g);

        // The former Add node is the graph output and must now be a Constant.
        let out_node = folded.node(folded.outputs[0]);
        match &out_node.op {
            Op::Constant { data } => {
                let v = f32::from_le_bytes([data[0], data[1], data[2], data[3]]);
                assert!((v - 5.0).abs() < 1e-6);
            }
            other => panic!("expected folded Constant, got {:?}", other),
        }
    }

    /// A sum with a runtime input operand must survive the pass untouched.
    #[test]
    fn does_not_fold_input_dependent() {
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[4], DType::F32));
        let zeros = g.add_node(
            Op::Constant {
                data: vec![0u8; 16],
            },
            vec![],
            Shape::new(&[4], DType::F32),
        );
        let sum = g.binary(op::BinaryOp::Add, x, zeros, Shape::new(&[4], DType::F32));
        g.set_outputs(vec![sum]);

        let folded = ConstantFolding.run(g);

        // x + c depends on the input, so the output is still a Binary op.
        let out_node = folded.node(folded.outputs[0]);
        assert!(matches!(out_node.op, Op::Binary(_)));
    }
}
258}