Skip to main content

rlx_compile/
param_specialize.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Replace selected `Op::Param` nodes with `Op::Constant` before compile-time opts.
17//!
18//! Deploy graphs (e.g. pruned ternary FFT) often fix gate masks and zero twiddles
19//! at specialization time while still building the graph with `Graph::param`.
20//! Baking those values here lets constant folding and DCE remove dead paths.
21
22use rlx_fusion::pass::Pass;
23use rlx_ir::{Graph, NodeId, Op};
24use std::collections::HashMap;
25
/// Serialize `data` into a flat byte buffer, 4 little-endian bytes per `f32`,
/// in input order. This is the payload stored in `Op::Constant { data }`.
fn encode_f32(data: &[f32]) -> Vec<u8> {
    data.iter().flat_map(|value| value.to_le_bytes()).collect()
}
33
34/// Substitute listed params with constants. Unlisted params are unchanged.
35pub fn specialize_params(graph: &Graph, bindings: &HashMap<String, Vec<f32>>) -> Graph {
36    if bindings.is_empty() {
37        return graph.clone();
38    }
39    let mut out = Graph::new(graph.name.clone());
40    let mut id_map: HashMap<NodeId, NodeId> = HashMap::new();
41
42    for node in graph.nodes() {
43        let new_id = match &node.op {
44            Op::Param { name } => {
45                if let Some(values) = bindings.get(name) {
46                    let expected = node.shape.num_elements().unwrap_or(values.len());
47                    assert_eq!(
48                        values.len(),
49                        expected,
50                        "param '{name}' binding len {} != shape elements {expected}",
51                        values.len()
52                    );
53                    out.add_node(
54                        Op::Constant {
55                            data: encode_f32(values),
56                        },
57                        vec![],
58                        node.shape.clone(),
59                    )
60                } else {
61                    out.add_node(node.op.clone(), vec![], node.shape.clone())
62                }
63            }
64            _ => {
65                let new_inputs: Vec<NodeId> = node.inputs.iter().map(|i| id_map[i]).collect();
66                out.add_node(node.op.clone(), new_inputs, node.shape.clone())
67            }
68        };
69        id_map.insert(node.id, new_id);
70    }
71
72    let new_outputs: Vec<NodeId> = graph.outputs.iter().map(|o| id_map[o]).collect();
73    out.set_outputs(new_outputs);
74    out
75}
76
/// Pass wrapper for the fusion pipeline / runtime preprocess hook.
///
/// Holds the bindings that [`specialize_params`] bakes into the graph. Params whose
/// names are not keys of `bindings` are left untouched.
#[derive(Debug, Clone, Default)]
pub struct SpecializeParams {
    /// Param name -> values to bake in, one `f32` per element of the param's shape.
    pub bindings: HashMap<String, Vec<f32>>,
}

impl SpecializeParams {
    /// Build the pass from a set of param bindings.
    #[must_use]
    pub fn new(bindings: HashMap<String, Vec<f32>>) -> Self {
        Self { bindings }
    }
}
87
88impl Pass for SpecializeParams {
89    fn name(&self) -> &str {
90        "specialize_params"
91    }
92
93    fn run(&self, graph: Graph) -> Graph {
94        specialize_params(&graph, &self.bindings)
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::Shape;
    use rlx_ir::op::BinaryOp;
    use rlx_ir::*;

    /// Builds `y = x * w`, where `x` is an input and `w` a param, both shaped `[2]`.
    /// Node order is `x` (0), `w` (1), `y` (2).
    fn mul_graph() -> Graph {
        let s = Shape::new(&[2], DType::F32);
        let mut g = Graph::new("t");
        let x = g.input("x", s.clone());
        let w = g.param("w", s.clone());
        let y = g.binary(BinaryOp::Mul, x, w, s.clone());
        g.set_outputs(vec![y]);
        g
    }

    /// Single-entry binding map.
    fn bind(name: &str, values: Vec<f32>) -> HashMap<String, Vec<f32>> {
        let mut bindings = HashMap::new();
        bindings.insert(name.to_string(), values);
        bindings
    }

    #[test]
    fn replaces_bound_param_with_constant() {
        let g = mul_graph();
        let out = specialize_params(&g, &bind("w", vec![0.0, 1.0]));
        let w_node = out.node(out.nodes()[1].id);
        assert!(matches!(w_node.op, Op::Constant { .. }));
    }

    #[test]
    fn constant_holds_little_endian_f32_payload() {
        let g = mul_graph();
        let out = specialize_params(&g, &bind("w", vec![0.0, 1.0]));
        match &out.nodes()[1].op {
            Op::Constant { data } => assert_eq!(data, &encode_f32(&[0.0, 1.0])),
            _ => panic!("bound param was not replaced with Op::Constant"),
        }
    }

    #[test]
    fn leaves_unbound_params_untouched() {
        let s = Shape::new(&[2], DType::F32);
        let mut g = Graph::new("t");
        let w = g.param("w", s.clone());
        let b = g.param("b", s.clone());
        let y = g.binary(BinaryOp::Mul, w, b, s.clone());
        g.set_outputs(vec![y]);

        let out = specialize_params(&g, &bind("w", vec![0.0, 1.0]));
        assert!(matches!(out.nodes()[0].op, Op::Constant { .. }));
        assert!(matches!(&out.nodes()[1].op, Op::Param { name } if name.as_str() == "b"));
    }

    #[test]
    fn empty_bindings_keep_params() {
        let g = mul_graph();
        let out = specialize_params(&g, &HashMap::new());
        assert!(matches!(out.nodes()[1].op, Op::Param { .. }));
    }

    #[test]
    fn outputs_are_remapped() {
        let g = mul_graph();
        let out = specialize_params(&g, &bind("w", vec![0.0, 1.0]));
        assert_eq!(out.nodes().len(), g.nodes().len());
        assert_eq!(out.outputs.len(), 1);
        assert!(out.outputs[0] == out.nodes()[2].id);
    }

    #[test]
    #[should_panic(expected = "binding len 1 != shape elements 2")]
    fn rejects_binding_with_wrong_length() {
        let g = mul_graph();
        let _ = specialize_params(&g, &bind("w", vec![1.0]));
    }
}