// rlx_ir/binding_manifest.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! Reflected binding layout from a specialized [`LirModule`].
//!
//! Host code uses this manifest to fill weights and inputs without
//! hand-maintaining parallel struct layouts (the shading-system
//! “parameter block” pattern).

9use crate::lir::LirModule;
10use crate::{DType, NodeId, Shape};
11
12/// One named graph boundary with arena layout after buffer planning.
13#[derive(Debug, Clone, PartialEq, Eq)]
14pub struct IoBindingEntry {
15    pub name: String,
16    pub node: NodeId,
17    pub dtype: DType,
18    pub shape: Shape,
19    pub elem_count: usize,
20    pub byte_size: usize,
21    pub arena_offset: Option<usize>,
22    pub arena_size: Option<usize>,
23    pub is_view: bool,
24}
25
26/// Full I/O + parameter manifest for a compiled graph.
27#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct BindingManifest {
29    pub graph_name: String,
30    pub arena_size: usize,
31    pub alignment: usize,
32    pub inputs: Vec<IoBindingEntry>,
33    pub params: Vec<IoBindingEntry>,
34    pub outputs: Vec<IoBindingEntry>,
35}
36
37impl BindingManifest {
38    pub fn from_lir(lir: &LirModule) -> Self {
39        let graph = lir.as_graph();
40        let plan = lir.plan();
41        let io = &plan.io;
42
43        let mut inputs = Vec::new();
44        for (name, id) in &io.inputs {
45            if let Some(e) = entry_for_node(graph, plan, name.clone(), *id) {
46                inputs.push(e);
47            }
48        }
49
50        let mut params = Vec::new();
51        for (name, id) in &io.params {
52            if let Some(e) = entry_for_node(graph, plan, name.clone(), *id) {
53                params.push(e);
54            }
55        }
56
57        let mut outputs = Vec::new();
58        for (i, id) in io.outputs.iter().enumerate() {
59            let name = format!("output{i}");
60            if let Some(e) = entry_for_node(graph, plan, name, *id) {
61                outputs.push(e);
62            }
63        }
64
65        Self {
66            graph_name: lir.name().to_string(),
67            arena_size: plan.arena_size,
68            alignment: plan.alignment,
69            inputs,
70            params,
71            outputs,
72        }
73    }
74
75    pub fn param_names(&self) -> impl Iterator<Item = &str> {
76        self.params.iter().map(|p| p.name.as_str())
77    }
78
79    pub fn input_names(&self) -> impl Iterator<Item = &str> {
80        self.inputs.iter().map(|p| p.name.as_str())
81    }
82
83    pub fn param_byte_size(&self, name: &str) -> Option<usize> {
84        self.params
85            .iter()
86            .find(|p| p.name == name)
87            .map(|p| p.byte_size)
88    }
89
90    pub fn total_param_bytes(&self) -> usize {
91        self.params.iter().map(|p| p.byte_size).sum()
92    }
93
94    /// Group parameters by dot-prefix (`layer0.attn` → block `layer0`).
95    pub fn weight_blocks(&self) -> Vec<WeightBlock> {
96        let mut blocks: std::collections::HashMap<String, Vec<IoBindingEntry>> =
97            std::collections::HashMap::new();
98        for p in &self.params {
99            let block = p.name.split('.').next().unwrap_or(&p.name).to_string();
100            blocks.entry(block).or_default().push(p.clone());
101        }
102        let mut out: Vec<WeightBlock> = blocks
103            .into_iter()
104            .map(|(prefix, params)| {
105                let byte_size = params.iter().map(|e| e.byte_size).sum();
106                WeightBlock {
107                    prefix,
108                    params,
109                    byte_size,
110                }
111            })
112            .collect();
113        out.sort_by(|a, b| a.prefix.cmp(&b.prefix));
114        out
115    }
116}
117
118/// Nested parameter block (Slang PerFrame / material grouping).
119#[derive(Debug, Clone, PartialEq, Eq)]
120pub struct WeightBlock {
121    pub prefix: String,
122    pub params: Vec<IoBindingEntry>,
123    pub byte_size: usize,
124}
125
126fn entry_for_node(
127    graph: &crate::Graph,
128    plan: &crate::lir::LirBufferPlan,
129    name: String,
130    id: NodeId,
131) -> Option<IoBindingEntry> {
132    let node = graph.node(id);
133    let elem_count = node.shape.num_elements().unwrap_or(0);
134    let byte_size = elem_count * node.shape.dtype().size_bytes();
135    let (arena_offset, arena_size, is_view) = if let Some(alias) = plan.view_aliases.get(&id) {
136        let root_slot = plan.slot(alias.root)?;
137        (
138            Some(root_slot.offset + alias.byte_offset),
139            Some(byte_size),
140            true,
141        )
142    } else if let Some(slot) = plan.slot(id) {
143        (Some(slot.offset), Some(slot.size), false)
144    } else {
145        (None, None, false)
146    };
147    Some(IoBindingEntry {
148        name,
149        node: id,
150        dtype: node.shape.dtype(),
151        shape: node.shape.clone(),
152        elem_count,
153        byte_size,
154        arena_offset,
155        arena_size,
156        is_view,
157    })
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Graph;
    use crate::lir::{LirBufferPlan, LirBufferSlot, LirIoManifest};

    #[test]
    fn manifest_lists_params_with_sizes() {
        let mut g = Graph::new("t");
        let x = g.input("x", Shape::new(&[2, 4], DType::F32));
        let w = g.param("w", Shape::new(&[4, 3], DType::F32));
        let mm = g.matmul(x, w, Shape::new(&[2, 3], DType::F32));
        g.set_outputs(vec![mm]);

        let mut plan = LirBufferPlan {
            io: LirIoManifest::collect(&g),
            ..Default::default()
        };
        plan.assignments.insert(
            x,
            LirBufferSlot {
                offset: 0,
                size: 32,
            },
        );
        plan.assignments.insert(
            w,
            LirBufferSlot {
                offset: 32,
                size: 48,
            },
        );
        plan.assignments.insert(
            mm,
            LirBufferSlot {
                offset: 80,
                size: 24,
            },
        );
        plan.arena_size = 104;

        let lir = LirModule::new(crate::MirModule::from_graph(g), plan);
        let m = BindingManifest::from_lir(&lir);

        assert_eq!(m.arena_size, 104);

        assert_eq!(m.inputs.len(), 1);
        assert_eq!(m.inputs[0].name, "x");

        assert_eq!(m.params.len(), 1);
        assert_eq!(m.params[0].name, "w");
        assert_eq!(m.params[0].byte_size, 48);
        assert_eq!(m.params[0].arena_offset, Some(32));
        assert_eq!(m.params[0].arena_size, Some(48));
        assert!(!m.params[0].is_view);

        assert_eq!(m.outputs.len(), 1);
        assert_eq!(m.outputs[0].name, "output0");
        assert_eq!(m.outputs[0].arena_offset, Some(80));

        assert_eq!(m.param_byte_size("w"), Some(48));
        assert_eq!(m.param_byte_size("missing"), None);
        assert_eq!(m.total_param_bytes(), 48);
    }

    #[test]
    fn weight_blocks_group_by_first_dot_segment() {
        let mut g = Graph::new("blocks");
        let x = g.input("x", Shape::new(&[2, 4], DType::F32));
        let w1 = g.param("layer1.w", Shape::new(&[4, 4], DType::F32));
        let w0 = g.param("layer0.w", Shape::new(&[4, 4], DType::F32));
        let p0 = g.param("layer0.proj", Shape::new(&[4, 2], DType::F32));
        let h0 = g.matmul(x, w0, Shape::new(&[2, 4], DType::F32));
        let y0 = g.matmul(h0, p0, Shape::new(&[2, 2], DType::F32));
        let y1 = g.matmul(x, w1, Shape::new(&[2, 4], DType::F32));
        g.set_outputs(vec![y0, y1]);

        // No slot assignments: every binding should be listed but unplaced.
        let plan = LirBufferPlan {
            io: LirIoManifest::collect(&g),
            ..Default::default()
        };
        let lir = LirModule::new(crate::MirModule::from_graph(g), plan);
        let m = BindingManifest::from_lir(&lir);

        let blocks = m.weight_blocks();
        let prefixes: Vec<&str> = blocks.iter().map(|b| b.prefix.as_str()).collect();
        assert_eq!(prefixes, ["layer0", "layer1"]);
        assert_eq!(blocks[0].params.len(), 2);
        assert_eq!(blocks[0].byte_size, 64 + 32);
        assert_eq!(blocks[1].params.len(), 1);
        assert_eq!(blocks[1].byte_size, 64);
        assert_eq!(m.total_param_bytes(), 160);
        assert!(m.params.iter().all(|p| p.arena_offset.is_none()));
    }
}