1use crate::lir::LirModule;
10use crate::{DType, NodeId, Shape};
11
12#[derive(Debug, Clone, PartialEq, Eq)]
14pub struct IoBindingEntry {
15 pub name: String,
16 pub node: NodeId,
17 pub dtype: DType,
18 pub shape: Shape,
19 pub elem_count: usize,
20 pub byte_size: usize,
21 pub arena_offset: Option<usize>,
22 pub arena_size: Option<usize>,
23 pub is_view: bool,
24}
25
26#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct BindingManifest {
29 pub graph_name: String,
30 pub arena_size: usize,
31 pub alignment: usize,
32 pub inputs: Vec<IoBindingEntry>,
33 pub params: Vec<IoBindingEntry>,
34 pub outputs: Vec<IoBindingEntry>,
35}
36
37impl BindingManifest {
38 pub fn from_lir(lir: &LirModule) -> Self {
39 let graph = lir.as_graph();
40 let plan = lir.plan();
41 let io = &plan.io;
42
43 let mut inputs = Vec::new();
44 for (name, id) in &io.inputs {
45 if let Some(e) = entry_for_node(graph, plan, name.clone(), *id) {
46 inputs.push(e);
47 }
48 }
49
50 let mut params = Vec::new();
51 for (name, id) in &io.params {
52 if let Some(e) = entry_for_node(graph, plan, name.clone(), *id) {
53 params.push(e);
54 }
55 }
56
57 let mut outputs = Vec::new();
58 for (i, id) in io.outputs.iter().enumerate() {
59 let name = format!("output{i}");
60 if let Some(e) = entry_for_node(graph, plan, name, *id) {
61 outputs.push(e);
62 }
63 }
64
65 Self {
66 graph_name: lir.name().to_string(),
67 arena_size: plan.arena_size,
68 alignment: plan.alignment,
69 inputs,
70 params,
71 outputs,
72 }
73 }
74
75 pub fn param_names(&self) -> impl Iterator<Item = &str> {
76 self.params.iter().map(|p| p.name.as_str())
77 }
78
79 pub fn input_names(&self) -> impl Iterator<Item = &str> {
80 self.inputs.iter().map(|p| p.name.as_str())
81 }
82
83 pub fn param_byte_size(&self, name: &str) -> Option<usize> {
84 self.params
85 .iter()
86 .find(|p| p.name == name)
87 .map(|p| p.byte_size)
88 }
89
90 pub fn total_param_bytes(&self) -> usize {
91 self.params.iter().map(|p| p.byte_size).sum()
92 }
93
94 pub fn weight_blocks(&self) -> Vec<WeightBlock> {
96 let mut blocks: std::collections::HashMap<String, Vec<IoBindingEntry>> =
97 std::collections::HashMap::new();
98 for p in &self.params {
99 let block = p.name.split('.').next().unwrap_or(&p.name).to_string();
100 blocks.entry(block).or_default().push(p.clone());
101 }
102 let mut out: Vec<WeightBlock> = blocks
103 .into_iter()
104 .map(|(prefix, params)| {
105 let byte_size = params.iter().map(|e| e.byte_size).sum();
106 WeightBlock {
107 prefix,
108 params,
109 byte_size,
110 }
111 })
112 .collect();
113 out.sort_by(|a, b| a.prefix.cmp(&b.prefix));
114 out
115 }
116}
117
118#[derive(Debug, Clone, PartialEq, Eq)]
120pub struct WeightBlock {
121 pub prefix: String,
122 pub params: Vec<IoBindingEntry>,
123 pub byte_size: usize,
124}
125
126fn entry_for_node(
127 graph: &crate::Graph,
128 plan: &crate::lir::LirBufferPlan,
129 name: String,
130 id: NodeId,
131) -> Option<IoBindingEntry> {
132 let node = graph.node(id);
133 let elem_count = node.shape.num_elements().unwrap_or(0);
134 let byte_size = elem_count * node.shape.dtype().size_bytes();
135 let (arena_offset, arena_size, is_view) = if let Some(alias) = plan.view_aliases.get(&id) {
136 let root_slot = plan.slot(alias.root)?;
137 (
138 Some(root_slot.offset + alias.byte_offset),
139 Some(byte_size),
140 true,
141 )
142 } else if let Some(slot) = plan.slot(id) {
143 (Some(slot.offset), Some(slot.size), false)
144 } else {
145 (None, None, false)
146 };
147 Some(IoBindingEntry {
148 name,
149 node: id,
150 dtype: node.shape.dtype(),
151 shape: node.shape.clone(),
152 elem_count,
153 byte_size,
154 arena_offset,
155 arena_size,
156 is_view,
157 })
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Graph;
    use crate::lir::{LirBufferPlan, LirBufferSlot, LirIoManifest};

    // Builds a one-matmul graph with hand-assigned slots and checks that the
    // manifest reports the parameter and the input under their names, with the
    // parameter's byte size derived from its shape.
    #[test]
    fn manifest_lists_params_with_sizes() {
        let mut graph = Graph::new("t");
        let x = graph.input("x", Shape::new(&[2, 4], DType::F32));
        let w = graph.param("w", Shape::new(&[4, 3], DType::F32));
        let mm = graph.matmul(x, w, Shape::new(&[2, 3], DType::F32));
        graph.set_outputs(vec![mm]);

        let mut plan = LirBufferPlan {
            io: LirIoManifest::collect(&graph),
            ..Default::default()
        };
        // (node, offset, size) for each hand-placed slot.
        for (node, offset, size) in [(x, 0, 32), (w, 32, 48), (mm, 80, 24)] {
            plan.assignments.insert(node, LirBufferSlot { offset, size });
        }
        plan.arena_size = 104;

        let lir = LirModule::new(crate::MirModule::from_graph(graph), plan);
        let manifest = BindingManifest::from_lir(&lir);
        assert_eq!(manifest.params.len(), 1);
        assert_eq!(manifest.params[0].name, "w");
        assert_eq!(manifest.params[0].byte_size, 48);
        assert_eq!(manifest.inputs[0].name, "x");
    }
}