// rlx_compile/algebraic_simplify.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Algebraic simplification on graphs with constant leaves.
//!
//! Complements [`crate::const_fold::ConstantFolding`] by rewriting ops that have
//! one constant operand and one dynamic operand (e.g. `mul(x, 0) → 0`).
20
21use rlx_fusion::pass::Pass;
22use rlx_ir::op::BinaryOp;
23use rlx_ir::{Graph, NodeId, Op};
24use std::collections::HashMap;
25
/// Reinterprets `data` as consecutive little-endian `f32` words.
///
/// Bytes after the last complete 4-byte word are ignored. A payload whose
/// length is not a multiple of 4 is therefore silently truncated.
fn decode_f32(data: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(data.len() / 4);
    for word in data.chunks_exact(4) {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(word);
        values.push(f32::from_le_bytes(raw));
    }
    values
}
31
/// Serializes `data` as consecutive little-endian `f32` words, 4 bytes each.
fn encode_f32(data: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(4 * data.len());
    bytes.extend(data.iter().flat_map(|v| v.to_le_bytes()));
    bytes
}
39
40fn constant_f32_values(graph: &Graph, id: NodeId) -> Option<Vec<f32>> {
41    match &graph.node(id).op {
42        Op::Constant { data } => Some(decode_f32(data)),
43        _ => None,
44    }
45}
46
/// Returns `true` when no element of `v` differs from `0.0`.
///
/// Both `+0.0` and `-0.0` count as zero. NaN never does. An empty slice is
/// vacuously all-zero.
fn is_all_zero(v: &[f32]) -> bool {
    !v.iter().any(|&x| x != 0.0)
}
50
/// Returns `true` when no element of `v` differs from `1.0`.
///
/// NaN never compares equal to `1.0`. An empty slice is vacuously all-one.
fn is_all_one(v: &[f32]) -> bool {
    !v.iter().any(|&x| x != 1.0)
}
54
55fn zeros_like(graph: &mut Graph, shape: &rlx_ir::Shape) -> NodeId {
56    let n = shape.num_elements().unwrap_or(1);
57    graph.add_node(
58        Op::Constant {
59            data: encode_f32(&vec![0.0; n]),
60        },
61        vec![],
62        shape.clone(),
63    )
64}
65
66/// One pass of local binary simplification.
67pub fn algebraic_simplify(graph: &Graph) -> Graph {
68    let mut out = Graph::new(graph.name.clone());
69    let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
70
71    for node in graph.nodes() {
72        let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
73        let simplified = if let Op::Binary(op) = &node.op {
74            if new_inputs.len() != 2 {
75                None
76            } else {
77                let (a, b) = (new_inputs[0], new_inputs[1]);
78                let a_const = constant_f32_values(&out, a);
79                let b_const = constant_f32_values(&out, b);
80                let out_elems = node.shape.num_elements().unwrap_or(0);
81                let const_matches = |c: &[f32]| c.len() == out_elems || c.len() == 1;
82                match (op, a_const.as_deref(), b_const.as_deref()) {
83                    (BinaryOp::Add, Some(c), None) if const_matches(c) && is_all_zero(c) => Some(b),
84                    (BinaryOp::Add, None, Some(c)) if const_matches(c) && is_all_zero(c) => Some(a),
85                    (BinaryOp::Sub, None, Some(c)) if const_matches(c) && is_all_zero(c) => Some(a),
86                    (BinaryOp::Mul, Some(c), None)
87                        if const_matches(c) && (is_all_zero(c) || is_all_one(c)) =>
88                    {
89                        if is_all_zero(c) {
90                            Some(zeros_like(&mut out, &node.shape))
91                        } else {
92                            Some(b)
93                        }
94                    }
95                    (BinaryOp::Mul, None, Some(c))
96                        if const_matches(c) && (is_all_zero(c) || is_all_one(c)) =>
97                    {
98                        if is_all_zero(c) {
99                            Some(zeros_like(&mut out, &node.shape))
100                        } else {
101                            Some(a)
102                        }
103                    }
104                    _ => None,
105                }
106            }
107        } else {
108            None
109        };
110
111        let new_id = if let Some(reuse_id) = simplified {
112            reuse_id
113        } else {
114            out.add_node(node.op.clone(), new_inputs, node.shape.clone())
115        };
116        id_map.insert(node.id, new_id);
117    }
118
119    let new_outputs: Vec<NodeId> = graph.outputs.iter().map(|o| id_map[o]).collect();
120    out.set_outputs(new_outputs);
121    out
122}
123
124pub struct AlgebraicSimplify;
125
126impl Pass for AlgebraicSimplify {
127    fn name(&self) -> &str {
128        "algebraic_simplify"
129    }
130
131    fn run(&self, graph: Graph) -> Graph {
132        algebraic_simplify(&graph)
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::Shape;
    use rlx_ir::op::BinaryOp;
    use rlx_ir::*;

    /// Adds a one-element f32 constant holding `value` to `g`.
    fn scalar_const(g: &mut Graph, value: f32) -> NodeId {
        g.add_node(
            Op::Constant {
                data: value.to_le_bytes().to_vec(),
            },
            vec![],
            Shape::new(&[1], DType::F32),
        )
    }

    /// Builds `x <op> c` (or `c <op> x` when `const_on_lhs`) over a `[4]` f32
    /// input `x` and a scalar constant `c`, then runs the pass.
    fn simplify_with_scalar(op: BinaryOp, value: f32, const_on_lhs: bool) -> Graph {
        let s = Shape::new(&[4], DType::F32);
        let mut g = Graph::new("t");
        let x = g.input("x", s.clone());
        let c = scalar_const(&mut g, value);
        let (lhs, rhs) = if const_on_lhs { (c, x) } else { (x, c) };
        let y = g.binary(op, lhs, rhs, s);
        g.set_outputs(vec![y]);
        algebraic_simplify(&g)
    }

    fn output_is_binary(g: &Graph) -> bool {
        matches!(g.node(g.outputs[0]).op, Op::Binary(_))
    }

    /// True when the output is neither the original binary op nor a constant,
    /// i.e. the non-constant operand was forwarded.
    fn output_is_forwarded_input(g: &Graph) -> bool {
        !matches!(
            g.node(g.outputs[0]).op,
            Op::Binary(_) | Op::Constant { .. }
        )
    }

    fn assert_zero_constant_output(g: &Graph) {
        let root = g.node(g.outputs[0]);
        match &root.op {
            Op::Constant { data } => assert_eq!(data.as_slice(), &[0u8; 16][..]),
            _ => panic!("expected the output to be a zero constant"),
        }
        assert_eq!(root.shape.num_elements(), Some(4));
    }

    #[test]
    fn mul_by_zero_scalar_zeros_output() {
        let out = simplify_with_scalar(BinaryOp::Mul, 0.0, false);
        assert_zero_constant_output(&out);
    }

    #[test]
    fn zero_times_x_zeros_output() {
        let out = simplify_with_scalar(BinaryOp::Mul, 0.0, true);
        assert_zero_constant_output(&out);
    }

    #[test]
    fn add_zero_forwards_other_operand() {
        assert!(output_is_forwarded_input(&simplify_with_scalar(
            BinaryOp::Add,
            0.0,
            false
        )));
        assert!(output_is_forwarded_input(&simplify_with_scalar(
            BinaryOp::Add,
            0.0,
            true
        )));
    }

    #[test]
    fn sub_zero_on_rhs_forwards_lhs() {
        assert!(output_is_forwarded_input(&simplify_with_scalar(
            BinaryOp::Sub,
            0.0,
            false
        )));
    }

    #[test]
    fn zero_minus_x_is_kept() {
        assert!(output_is_binary(&simplify_with_scalar(
            BinaryOp::Sub,
            0.0,
            true
        )));
    }

    #[test]
    fn mul_by_one_forwards_other_operand() {
        assert!(output_is_forwarded_input(&simplify_with_scalar(
            BinaryOp::Mul,
            1.0,
            false
        )));
        assert!(output_is_forwarded_input(&simplify_with_scalar(
            BinaryOp::Mul,
            1.0,
            true
        )));
    }

    #[test]
    fn mul_by_other_constant_is_kept() {
        assert!(output_is_binary(&simplify_with_scalar(
            BinaryOp::Mul,
            2.0,
            false
        )));
    }
}