1use crate::lir::LirModule;
22use crate::{DType, NodeId, Shape};
23
/// One bound tensor (an input, parameter, or output) together with its
/// placement in the buffer arena, as populated by `BindingManifest::from_lir`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IoBindingEntry {
    /// Binding name. For inputs and params this is the name recorded in the
    /// plan's I/O manifest; outputs built by `from_lir` are named
    /// positionally (`output0`, `output1`, ...).
    pub name: String,
    /// Graph node this binding refers to.
    pub node: NodeId,
    /// Element type, taken from the node's shape.
    pub dtype: DType,
    /// Shape of the node.
    pub shape: Shape,
    /// Number of elements; `0` when `Shape::num_elements` returns `None`.
    pub elem_count: usize,
    /// `elem_count * dtype.size_bytes()`.
    pub byte_size: usize,
    /// Offset of this tensor in the arena, or `None` if the plan assigned it
    /// no slot. For views this is the root slot's offset plus the view's
    /// `byte_offset`.
    pub arena_offset: Option<usize>,
    /// Size of the arena region: the slot size for regular buffers, or
    /// `byte_size` for views. `None` when there is no slot.
    pub arena_size: Option<usize>,
    /// `true` when the node is a view alias into another node's buffer.
    pub is_view: bool,
}
37
/// Flat description of a lowered module's I/O bindings and arena layout,
/// built from an `LirModule` by `BindingManifest::from_lir`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BindingManifest {
    /// Module name (`LirModule::name`).
    pub graph_name: String,
    /// Total arena size copied from the buffer plan.
    pub arena_size: usize,
    /// Arena alignment copied from the buffer plan.
    pub alignment: usize,
    /// Graph inputs, in the iteration order of the plan's I/O manifest.
    pub inputs: Vec<IoBindingEntry>,
    /// Parameters (weights), in the iteration order of the plan's I/O manifest.
    pub params: Vec<IoBindingEntry>,
    /// Graph outputs, in output order.
    pub outputs: Vec<IoBindingEntry>,
}
48
49impl BindingManifest {
50 pub fn from_lir(lir: &LirModule) -> Self {
51 let graph = lir.as_graph();
52 let plan = lir.plan();
53 let io = &plan.io;
54
55 let mut inputs = Vec::new();
56 for (name, id) in &io.inputs {
57 if let Some(e) = entry_for_node(graph, plan, name.clone(), *id) {
58 inputs.push(e);
59 }
60 }
61
62 let mut params = Vec::new();
63 for (name, id) in &io.params {
64 if let Some(e) = entry_for_node(graph, plan, name.clone(), *id) {
65 params.push(e);
66 }
67 }
68
69 let mut outputs = Vec::new();
70 for (i, id) in io.outputs.iter().enumerate() {
71 let name = format!("output{i}");
72 if let Some(e) = entry_for_node(graph, plan, name, *id) {
73 outputs.push(e);
74 }
75 }
76
77 Self {
78 graph_name: lir.name().to_string(),
79 arena_size: plan.arena_size,
80 alignment: plan.alignment,
81 inputs,
82 params,
83 outputs,
84 }
85 }
86
87 pub fn param_names(&self) -> impl Iterator<Item = &str> {
88 self.params.iter().map(|p| p.name.as_str())
89 }
90
91 pub fn input_names(&self) -> impl Iterator<Item = &str> {
92 self.inputs.iter().map(|p| p.name.as_str())
93 }
94
95 pub fn param_byte_size(&self, name: &str) -> Option<usize> {
96 self.params
97 .iter()
98 .find(|p| p.name == name)
99 .map(|p| p.byte_size)
100 }
101
102 pub fn total_param_bytes(&self) -> usize {
103 self.params.iter().map(|p| p.byte_size).sum()
104 }
105
106 pub fn weight_blocks(&self) -> Vec<WeightBlock> {
108 let mut blocks: std::collections::HashMap<String, Vec<IoBindingEntry>> =
109 std::collections::HashMap::new();
110 for p in &self.params {
111 let block = p.name.split('.').next().unwrap_or(&p.name).to_string();
112 blocks.entry(block).or_default().push(p.clone());
113 }
114 let mut out: Vec<WeightBlock> = blocks
115 .into_iter()
116 .map(|(prefix, params)| {
117 let byte_size = params.iter().map(|e| e.byte_size).sum();
118 WeightBlock {
119 prefix,
120 params,
121 byte_size,
122 }
123 })
124 .collect();
125 out.sort_by(|a, b| a.prefix.cmp(&b.prefix));
126 out
127 }
128}
129
/// A group of parameters sharing a name prefix, as produced by
/// `BindingManifest::weight_blocks`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WeightBlock {
    /// Portion of the parameter names before the first `.` (the whole name
    /// when it contains no `.`).
    pub prefix: String,
    /// Parameters in this block, in manifest order.
    pub params: Vec<IoBindingEntry>,
    /// Sum of `byte_size` over `params`.
    pub byte_size: usize,
}
137
138fn entry_for_node(
139 graph: &crate::Graph,
140 plan: &crate::lir::LirBufferPlan,
141 name: String,
142 id: NodeId,
143) -> Option<IoBindingEntry> {
144 let node = graph.node(id);
145 let elem_count = node.shape.num_elements().unwrap_or(0);
146 let byte_size = elem_count * node.shape.dtype().size_bytes();
147 let (arena_offset, arena_size, is_view) = if let Some(alias) = plan.view_aliases.get(&id) {
148 let root_slot = plan.slot(alias.root)?;
149 (
150 Some(root_slot.offset + alias.byte_offset),
151 Some(byte_size),
152 true,
153 )
154 } else if let Some(slot) = plan.slot(id) {
155 (Some(slot.offset), Some(slot.size), false)
156 } else {
157 (None, None, false)
158 };
159 Some(IoBindingEntry {
160 name,
161 node: id,
162 dtype: node.shape.dtype(),
163 shape: node.shape.clone(),
164 elem_count,
165 byte_size,
166 arena_offset,
167 arena_size,
168 is_view,
169 })
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Graph;
    use crate::lir::{LirBufferPlan, LirBufferSlot, LirIoManifest};

    /// Builds `mm = x @ w` with `x: [2,4] f32`, `w: [4,3] f32`, `mm: [2,3] f32`.
    /// The three buffers are packed back to back in a 104-byte arena
    /// (x at 0..32, w at 32..80, mm at 80..104). Returns the resulting manifest.
    fn build_manifest() -> BindingManifest {
        let mut g = Graph::new("t");
        let x = g.input("x", Shape::new(&[2, 4], DType::F32));
        let w = g.param("w", Shape::new(&[4, 3], DType::F32));
        let mm = g.matmul(x, w, Shape::new(&[2, 3], DType::F32));
        g.set_outputs(vec![mm]);

        let mut plan = LirBufferPlan {
            io: LirIoManifest::collect(&g),
            ..Default::default()
        };
        for (node, offset, size) in [(x, 0, 32), (w, 32, 48), (mm, 80, 24)] {
            plan.assignments.insert(node, LirBufferSlot { offset, size });
        }
        plan.arena_size = 104;

        let lir = LirModule::new(crate::MirModule::from_graph(g), plan);
        BindingManifest::from_lir(&lir)
    }

    #[test]
    fn manifest_lists_params_with_sizes() {
        let m = build_manifest();
        assert_eq!(m.params.len(), 1);
        assert_eq!(m.params[0].name, "w");
        assert_eq!(m.params[0].byte_size, 48);
        assert_eq!(m.params[0].elem_count, 12);
        assert_eq!(m.inputs.len(), 1);
        assert_eq!(m.inputs[0].name, "x");
        assert_eq!(m.outputs.len(), 1);
        assert_eq!(m.outputs[0].name, "output0");
        assert_eq!(m.outputs[0].byte_size, 24);
        assert_eq!(m.arena_size, 104);
    }

    #[test]
    fn manifest_records_arena_placement() {
        let m = build_manifest();
        let w = &m.params[0];
        assert_eq!(w.arena_offset, Some(32));
        assert_eq!(w.arena_size, Some(48));
        assert!(!w.is_view);
    }

    #[test]
    fn param_queries_and_weight_blocks() {
        let m = build_manifest();
        assert_eq!(m.param_byte_size("w"), Some(48));
        assert_eq!(m.param_byte_size("missing"), None);
        assert_eq!(m.total_param_bytes(), 48);

        let blocks = m.weight_blocks();
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0].prefix, "w");
        assert_eq!(blocks[0].byte_size, 48);
        assert_eq!(blocks[0].params.len(), 1);
    }
}