1use std::collections::HashMap;
24use std::collections::hash_map::DefaultHasher;
25use std::hash::{Hash, Hasher};
26
27use crate::mir::MirModule;
28use crate::phase::{Phase, PhaseSchedule};
29use crate::{Graph, NodeId, Op};
30
/// Region reserved for one node's buffer, described by a start and a length.
///
/// `size` is what [`LirBufferPlan::total_unshared_bytes`] adds up, so it is a
/// byte count. NOTE(review): `offset` is presumably a byte offset from the
/// start of the plan's arena — confirm against the planner that fills it in.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LirBufferSlot {
    /// Where the region begins.
    pub offset: usize,
    /// How many bytes the region spans.
    pub size: usize,
}
38
39#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
41#[derive(Debug, Clone, Copy, PartialEq, Eq)]
42pub struct LirViewAlias {
43 pub root: NodeId,
44 pub byte_offset: usize,
45}
46
47#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
49#[derive(Debug, Clone, Default, PartialEq, Eq)]
50pub struct LirIoManifest {
51 pub inputs: Vec<(String, NodeId)>,
52 pub params: Vec<(String, NodeId)>,
53 pub outputs: Vec<NodeId>,
54}
55
56impl LirIoManifest {
57 pub fn collect(graph: &Graph) -> Self {
58 let mut inputs = Vec::new();
59 let mut params = Vec::new();
60 for node in graph.nodes() {
61 match &node.op {
62 Op::Input { name } => inputs.push((name.clone(), node.id)),
63 Op::Param { name } => params.push((name.clone(), node.id)),
64 _ => {}
65 }
66 }
67 Self {
68 inputs,
69 params,
70 outputs: graph.outputs.clone(),
71 }
72 }
73}
74
75#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
77#[derive(Debug, Clone, PartialEq, Eq)]
78pub struct LirBufferPlan {
79 pub arena_size: usize,
80 pub assignments: HashMap<NodeId, LirBufferSlot>,
81 pub schedule: Vec<NodeId>,
83 pub view_aliases: HashMap<NodeId, LirViewAlias>,
85 pub phases: PhaseSchedule,
87 pub io: LirIoManifest,
89 pub alignment: usize,
91 pub dynamic_symbols: Vec<u32>,
93}
94
95impl Default for LirBufferPlan {
96 fn default() -> Self {
97 Self {
98 arena_size: 0,
99 assignments: HashMap::new(),
100 schedule: Vec::new(),
101 view_aliases: HashMap::new(),
102 phases: PhaseSchedule::new(),
103 io: LirIoManifest::default(),
104 alignment: 64,
105 dynamic_symbols: Vec::new(),
106 }
107 }
108}
109
110impl LirBufferPlan {
111 pub fn total_unshared_bytes(&self) -> usize {
112 self.assignments.values().map(|s| s.size).sum()
113 }
114
115 pub fn bytes_saved(&self) -> usize {
116 self.total_unshared_bytes().saturating_sub(self.arena_size)
117 }
118
119 pub fn slot(&self, id: NodeId) -> Option<&LirBufferSlot> {
120 self.assignments.get(&id)
121 }
122
123 pub fn is_view(&self, id: NodeId) -> bool {
124 self.view_aliases.contains_key(&id)
125 }
126
127 pub fn phase_of(&self, id: NodeId) -> Option<Phase> {
128 self.phases.get(id)
129 }
130
131 pub fn nodes_in_phase(&self, phase: Phase) -> Vec<NodeId> {
132 self.phases.nodes_in_ordered(phase, Some(&self.schedule))
133 }
134}
135
/// 64-bit digest of a [`LirModule`], as produced by [`LirFingerprint::of`].
///
/// The wrapped value is public and the type is `Copy`, comparable and
/// hashable, so it can be used directly as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LirFingerprint(pub u64);
139
140impl LirFingerprint {
141 pub fn of(module: &LirModule) -> Self {
142 let mut h = DefaultHasher::new();
143 module.mir.name().hash(&mut h);
144 module.mir.len().hash(&mut h);
145 for node in module.mir.as_graph().nodes() {
146 node.id.0.hash(&mut h);
147 format!("{}", node.op).hash(&mut h);
148 node.shape.hash(&mut h);
149 node.inputs.len().hash(&mut h);
150 for inp in &node.inputs {
151 inp.0.hash(&mut h);
152 }
153 }
154 for out in module.mir.as_graph().outputs.iter() {
155 out.0.hash(&mut h);
156 }
157 module.buffers.arena_size.hash(&mut h);
158 module.buffers.schedule.len().hash(&mut h);
159 module.buffers.alignment.hash(&mut h);
160 module.buffers.view_aliases.len().hash(&mut h);
161 Self(h.finish())
162 }
163}
164
165#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
167#[derive(Debug, Clone, PartialEq)]
168pub struct LirModule {
169 pub mir: MirModule,
170 pub buffers: LirBufferPlan,
171}
172
173impl LirModule {
174 pub fn new(mir: MirModule, buffers: LirBufferPlan) -> Self {
175 Self { mir, buffers }
176 }
177
178 pub fn name(&self) -> &str {
179 self.mir.name()
180 }
181
182 pub fn arena_size(&self) -> usize {
183 self.buffers.arena_size
184 }
185
186 pub fn fingerprint(&self) -> LirFingerprint {
187 LirFingerprint::of(self)
188 }
189
190 pub fn plan(&self) -> &LirBufferPlan {
191 &self.buffers
192 }
193
194 pub fn into_graph(self) -> Graph {
196 self.mir.into_graph()
197 }
198
199 pub fn as_graph(&self) -> &Graph {
200 self.mir.as_graph()
201 }
202
203 pub fn has_dynamic_dims(&self) -> bool {
204 crate::dynamic::has_dynamic_dims(self.as_graph())
205 }
206
207 pub fn is_fully_static(&self) -> bool {
208 !self.has_dynamic_dims()
209 }
210
211 pub fn dynamic_symbols(&self) -> &[u32] {
212 &self.buffers.dynamic_symbols
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DType, Shape};

    /// Shorthand for an `F32` shape with the given dimensions.
    fn f32_shape(dims: &[usize]) -> Shape {
        Shape::new(dims, DType::F32)
    }

    /// `collect` reports the input, the parameter and the output of a small
    /// matmul graph.
    #[test]
    fn io_manifest_collects_boundaries() {
        let mut graph = Graph::new("m");
        let input = graph.input("x", f32_shape(&[4]));
        let weight = graph.param("w", f32_shape(&[4, 4]));
        let product = graph.matmul(input, weight, f32_shape(&[4, 4]));
        graph.set_outputs(vec![product]);

        let manifest = LirIoManifest::collect(&graph);

        assert_eq!(manifest.inputs, vec![(String::from("x"), input)]);
        assert_eq!(manifest.params, vec![(String::from("w"), weight)]);
        assert_eq!(manifest.outputs, vec![product]);
    }

    /// Fingerprinting the same module twice yields the same value.
    #[test]
    fn fingerprint_is_stable() {
        let mut graph = Graph::new("m");
        let input = graph.input("x", f32_shape(&[2]));
        graph.set_outputs(vec![input]);
        let mir = MirModule::from_graph(graph);

        let mut assignments = HashMap::new();
        assignments.insert(input, LirBufferSlot { offset: 0, size: 8 });
        let io = LirIoManifest {
            inputs: vec![(String::from("x"), input)],
            ..Default::default()
        };
        let plan = LirBufferPlan {
            arena_size: 8,
            assignments,
            schedule: vec![input],
            io,
            ..Default::default()
        };

        let module = LirModule::new(mir, plan);
        let first = module.fingerprint();
        let second = module.fingerprint();
        assert_eq!(first, second);
    }
}