rlx_compile/
algebraic_simplify.rs1use rlx_fusion::pass::Pass;
22use rlx_ir::op::BinaryOp;
23use rlx_ir::{Graph, NodeId, Op};
24use std::collections::HashMap;
25
/// Reinterprets a raw byte buffer as little-endian `f32` values.
///
/// The buffer is consumed four bytes at a time; any trailing bytes that do
/// not form a complete 4-byte word are ignored.
fn decode_f32(data: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(data.len() / 4);
    for word in data.chunks_exact(4) {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(word);
        values.push(f32::from_le_bytes(raw));
    }
    values
}
31
/// Serializes `f32` values into a contiguous little-endian byte buffer
/// (four bytes per value, in input order).
fn encode_f32(data: &[f32]) -> Vec<u8> {
    data.iter().flat_map(|value| value.to_le_bytes()).collect()
}
39
40fn constant_f32_values(graph: &Graph, id: NodeId) -> Option<Vec<f32>> {
41 match &graph.node(id).op {
42 Op::Constant { data } => Some(decode_f32(data)),
43 _ => None,
44 }
45}
46
/// Reports whether every element compares equal to zero.
///
/// `-0.0` counts as zero, `NaN` does not, and an empty slice yields `true`.
fn is_all_zero(v: &[f32]) -> bool {
    !v.iter().any(|&elem| elem != 0.0)
}
50
/// Reports whether every element compares equal to one.
///
/// `NaN` never compares equal, and an empty slice yields `true`.
fn is_all_one(v: &[f32]) -> bool {
    !v.iter().any(|&elem| elem != 1.0)
}
54
55fn zeros_like(graph: &mut Graph, shape: &rlx_ir::Shape) -> NodeId {
56 let n = shape.num_elements().unwrap_or(1);
57 graph.add_node(
58 Op::Constant {
59 data: encode_f32(&vec![0.0; n]),
60 },
61 vec![],
62 shape.clone(),
63 )
64}
65
66pub fn algebraic_simplify(graph: &Graph) -> Graph {
68 let mut out = Graph::new(graph.name.clone());
69 let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
70
71 for node in graph.nodes() {
72 let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
73 let simplified = if let Op::Binary(op) = &node.op {
74 if new_inputs.len() != 2 {
75 None
76 } else {
77 let (a, b) = (new_inputs[0], new_inputs[1]);
78 let a_const = constant_f32_values(&out, a);
79 let b_const = constant_f32_values(&out, b);
80 let out_elems = node.shape.num_elements().unwrap_or(0);
81 let const_matches = |c: &[f32]| c.len() == out_elems || c.len() == 1;
82 match (op, a_const.as_deref(), b_const.as_deref()) {
83 (BinaryOp::Add, Some(c), None) if const_matches(c) && is_all_zero(c) => Some(b),
84 (BinaryOp::Add, None, Some(c)) if const_matches(c) && is_all_zero(c) => Some(a),
85 (BinaryOp::Sub, None, Some(c)) if const_matches(c) && is_all_zero(c) => Some(a),
86 (BinaryOp::Mul, Some(c), None)
87 if const_matches(c) && (is_all_zero(c) || is_all_one(c)) =>
88 {
89 if is_all_zero(c) {
90 Some(zeros_like(&mut out, &node.shape))
91 } else {
92 Some(b)
93 }
94 }
95 (BinaryOp::Mul, None, Some(c))
96 if const_matches(c) && (is_all_zero(c) || is_all_one(c)) =>
97 {
98 if is_all_zero(c) {
99 Some(zeros_like(&mut out, &node.shape))
100 } else {
101 Some(a)
102 }
103 }
104 _ => None,
105 }
106 }
107 } else {
108 None
109 };
110
111 let new_id = if let Some(reuse_id) = simplified {
112 reuse_id
113 } else {
114 out.add_node(node.op.clone(), new_inputs, node.shape.clone())
115 };
116 id_map.insert(node.id, new_id);
117 }
118
119 let new_outputs: Vec<NodeId> = graph.outputs.iter().map(|o| id_map[o]).collect();
120 out.set_outputs(new_outputs);
121 out
122}
123
124pub struct AlgebraicSimplify;
125
126impl Pass for AlgebraicSimplify {
127 fn name(&self) -> &str {
128 "algebraic_simplify"
129 }
130
131 fn run(&self, graph: Graph) -> Graph {
132 algebraic_simplify(&graph)
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::Shape;
    use rlx_ir::op::BinaryOp;
    use rlx_ir::*;

    /// Multiplying a 4-element input by a one-element zero constant must
    /// leave a constant node as the graph output.
    #[test]
    fn mul_by_zero_scalar_zeros_output() {
        let vec4 = Shape::new(&[4], DType::F32);
        let one_elem = Shape::new(&[1], DType::F32);

        let mut graph = Graph::new("t");
        let x = graph.input("x", vec4.clone());
        let zero = graph.add_node(
            Op::Constant {
                data: 0.0f32.to_le_bytes().to_vec(),
            },
            vec![],
            one_elem,
        );
        let product = graph.binary(BinaryOp::Mul, x, zero, vec4.clone());
        graph.set_outputs(vec![product]);

        let simplified = algebraic_simplify(&graph);
        let result = simplified.node(simplified.outputs[0]);
        assert!(matches!(result.op, Op::Constant { .. }));
    }
}