Skip to main content

rlx_ir/
binding_manifest.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Reflected binding layout from a specialized [`LirModule`].
17//!
18//! Host code uses this to fill weights and inputs without hand-maintaining parallel
19//! struct layouts (the shading-system “parameter block” pattern).
20
21use crate::lir::LirModule;
22use crate::{DType, NodeId, Shape};
23
24/// One named graph boundary with arena layout after buffer planning.
25#[derive(Debug, Clone, PartialEq, Eq)]
26pub struct IoBindingEntry {
27    pub name: String,
28    pub node: NodeId,
29    pub dtype: DType,
30    pub shape: Shape,
31    pub elem_count: usize,
32    pub byte_size: usize,
33    pub arena_offset: Option<usize>,
34    pub arena_size: Option<usize>,
35    pub is_view: bool,
36}
37
38/// Full I/O + parameter manifest for a compiled graph.
39#[derive(Debug, Clone, PartialEq, Eq)]
40pub struct BindingManifest {
41    pub graph_name: String,
42    pub arena_size: usize,
43    pub alignment: usize,
44    pub inputs: Vec<IoBindingEntry>,
45    pub params: Vec<IoBindingEntry>,
46    pub outputs: Vec<IoBindingEntry>,
47}
48
49impl BindingManifest {
50    pub fn from_lir(lir: &LirModule) -> Self {
51        let graph = lir.as_graph();
52        let plan = lir.plan();
53        let io = &plan.io;
54
55        let mut inputs = Vec::new();
56        for (name, id) in &io.inputs {
57            if let Some(e) = entry_for_node(graph, plan, name.clone(), *id) {
58                inputs.push(e);
59            }
60        }
61
62        let mut params = Vec::new();
63        for (name, id) in &io.params {
64            if let Some(e) = entry_for_node(graph, plan, name.clone(), *id) {
65                params.push(e);
66            }
67        }
68
69        let mut outputs = Vec::new();
70        for (i, id) in io.outputs.iter().enumerate() {
71            let name = format!("output{i}");
72            if let Some(e) = entry_for_node(graph, plan, name, *id) {
73                outputs.push(e);
74            }
75        }
76
77        Self {
78            graph_name: lir.name().to_string(),
79            arena_size: plan.arena_size,
80            alignment: plan.alignment,
81            inputs,
82            params,
83            outputs,
84        }
85    }
86
87    pub fn param_names(&self) -> impl Iterator<Item = &str> {
88        self.params.iter().map(|p| p.name.as_str())
89    }
90
91    pub fn input_names(&self) -> impl Iterator<Item = &str> {
92        self.inputs.iter().map(|p| p.name.as_str())
93    }
94
95    pub fn param_byte_size(&self, name: &str) -> Option<usize> {
96        self.params
97            .iter()
98            .find(|p| p.name == name)
99            .map(|p| p.byte_size)
100    }
101
102    pub fn total_param_bytes(&self) -> usize {
103        self.params.iter().map(|p| p.byte_size).sum()
104    }
105
106    /// Group parameters by dot-prefix (`layer0.attn` → block `layer0`).
107    pub fn weight_blocks(&self) -> Vec<WeightBlock> {
108        let mut blocks: std::collections::HashMap<String, Vec<IoBindingEntry>> =
109            std::collections::HashMap::new();
110        for p in &self.params {
111            let block = p.name.split('.').next().unwrap_or(&p.name).to_string();
112            blocks.entry(block).or_default().push(p.clone());
113        }
114        let mut out: Vec<WeightBlock> = blocks
115            .into_iter()
116            .map(|(prefix, params)| {
117                let byte_size = params.iter().map(|e| e.byte_size).sum();
118                WeightBlock {
119                    prefix,
120                    params,
121                    byte_size,
122                }
123            })
124            .collect();
125        out.sort_by(|a, b| a.prefix.cmp(&b.prefix));
126        out
127    }
128}
129
130/// Nested parameter block (Slang PerFrame / material grouping).
131#[derive(Debug, Clone, PartialEq, Eq)]
132pub struct WeightBlock {
133    pub prefix: String,
134    pub params: Vec<IoBindingEntry>,
135    pub byte_size: usize,
136}
137
138fn entry_for_node(
139    graph: &crate::Graph,
140    plan: &crate::lir::LirBufferPlan,
141    name: String,
142    id: NodeId,
143) -> Option<IoBindingEntry> {
144    let node = graph.node(id);
145    let elem_count = node.shape.num_elements().unwrap_or(0);
146    let byte_size = elem_count * node.shape.dtype().size_bytes();
147    let (arena_offset, arena_size, is_view) = if let Some(alias) = plan.view_aliases.get(&id) {
148        let root_slot = plan.slot(alias.root)?;
149        (
150            Some(root_slot.offset + alias.byte_offset),
151            Some(byte_size),
152            true,
153        )
154    } else if let Some(slot) = plan.slot(id) {
155        (Some(slot.offset), Some(slot.size), false)
156    } else {
157        (None, None, false)
158    };
159    Some(IoBindingEntry {
160        name,
161        node: id,
162        dtype: node.shape.dtype(),
163        shape: node.shape.clone(),
164        elem_count,
165        byte_size,
166        arena_offset,
167        arena_size,
168        is_view,
169    })
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Graph;
    use crate::lir::{LirBufferPlan, LirBufferSlot, LirIoManifest};

    /// A one-matmul graph with hand-assigned slots must report its single
    /// parameter `w` with a 48-byte size and its input `x` by name.
    #[test]
    fn manifest_lists_params_with_sizes() {
        let mut graph = Graph::new("t");
        let input = graph.input("x", Shape::new(&[2, 4], DType::F32));
        let weight = graph.param("w", Shape::new(&[4, 3], DType::F32));
        let product = graph.matmul(input, weight, Shape::new(&[2, 3], DType::F32));
        graph.set_outputs(vec![product]);

        let mut plan = LirBufferPlan {
            io: LirIoManifest::collect(&graph),
            ..Default::default()
        };
        // Slots are packed back to back: x, then w, then the matmul result.
        for (node, offset, size) in [(input, 0, 32), (weight, 32, 48), (product, 80, 24)] {
            plan.assignments.insert(node, LirBufferSlot { offset, size });
        }
        plan.arena_size = 104;

        let lir = LirModule::new(crate::MirModule::from_graph(graph), plan);
        let manifest = BindingManifest::from_lir(&lir);

        assert_eq!(manifest.params.len(), 1);
        let param = &manifest.params[0];
        assert_eq!(param.name, "w");
        assert_eq!(param.byte_size, 48);
        assert_eq!(manifest.inputs[0].name, "x");
    }
}