1use std::ops::{Deref, DerefMut};
19
20use crate::hir::{FusionPolicy, HirModule, HirNodeId, LowerError};
21use crate::inspect::{inspect_hir, inspect_lir, inspect_mir};
22use crate::lir::LirModule;
23use crate::mir::MirModule;
24use crate::op::Activation;
25use crate::op::MaskKind;
26use crate::quant::QuantScheme;
27use crate::{Graph, NodeId, Op, Shape};
28
/// Tag naming which representation a [`GraphModule`] currently wraps.
///
/// This is the payload-free counterpart of the module's internal stage and is what
/// `GraphModule::stage` reports.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GraphStage {
    /// The module wraps a `HirModule`.
    Hir,
    /// The module wraps a `MirModule`.
    Mir,
    /// The module wraps a `LirModule`.
    Lir,
}
36
37#[derive(Debug, Clone)]
38enum Stage {
39 Hir(HirModule),
40 Mir(MirModule),
41 Lir(LirModule),
42}
43
44#[derive(Debug, Clone)]
46pub struct GraphModule {
47 stage: Stage,
48}
49
50impl GraphModule {
51 pub fn define(
52 name: impl Into<String>,
53 build: impl FnOnce(&mut HirModule) -> HirNodeId,
54 ) -> Self {
55 let mut hir = HirModule::new(name);
56 let out = build(&mut hir);
57 hir.set_outputs(vec![out]);
58 Self {
59 stage: Stage::Hir(hir),
60 }
61 }
62
63 pub fn hir(name: impl Into<String>) -> Self {
65 Self {
66 stage: Stage::Hir(HirModule::new(name)),
67 }
68 }
69
70 pub fn mir(name: impl Into<String>) -> Self {
72 Self {
73 stage: Stage::Mir(MirModule::new(name)),
74 }
75 }
76
77 pub fn from_hir(hir: HirModule) -> Self {
78 Self {
79 stage: Stage::Hir(hir),
80 }
81 }
82
83 pub fn from_graph(graph: Graph) -> Self {
84 Self {
85 stage: Stage::Mir(MirModule::from_graph(graph)),
86 }
87 }
88
89 pub fn from_mir(mir: MirModule) -> Self {
90 Self {
91 stage: Stage::Mir(mir),
92 }
93 }
94
95 pub fn from_lir(lir: LirModule) -> Self {
96 Self {
97 stage: Stage::Lir(lir),
98 }
99 }
100
101 pub fn block(
102 hir: &mut HirModule,
103 name: impl Into<String>,
104 build: impl FnOnce(&mut HirModule) -> HirNodeId,
105 ) -> HirNodeId {
106 hir.named(name, build)
107 }
108
109 pub fn fusion_policy(&self) -> Option<FusionPolicy> {
110 self.as_hir().map(|h| h.fusion_policy)
111 }
112
113 pub fn with_fusion_policy(mut self, policy: FusionPolicy) -> Self {
114 if let Stage::Hir(h) = &mut self.stage {
115 h.fusion_policy = policy;
116 } else {
117 panic!("GraphModule::with_fusion_policy requires HIR stage");
118 }
119 self
120 }
121
122 pub fn set_outputs(&mut self, outputs: Vec<HirNodeId>) {
127 match &mut self.stage {
128 Stage::Hir(h) => h.set_outputs(outputs),
129 Stage::Mir(m) => m.set_outputs(outputs.into_iter().map(|h| NodeId(h.0)).collect()),
130 Stage::Lir(l) => l
131 .mir
132 .set_outputs(outputs.into_iter().map(|h| NodeId(h.0)).collect()),
133 }
134 }
135
136 pub fn set_hir_outputs(&mut self, outputs: Vec<HirNodeId>) {
137 self.set_outputs(outputs);
138 }
139
140 pub fn finish_hir(mut self, output: HirNodeId) -> Self {
142 self.set_hir_outputs(vec![output]);
143 self
144 }
145
146 fn hir_mut(&mut self) -> &mut HirModule {
147 self.as_hir_mut()
148 .expect("GraphModule: HIR builder methods require HIR stage — use GraphModule::hir() or Graph::define()")
149 }
150
151 pub fn input(&mut self, name: impl Into<String>, shape: Shape) -> HirNodeId {
154 match &mut self.stage {
155 Stage::Hir(h) => h.input(name, shape),
156 Stage::Mir(m) => {
157 let id = m.as_graph_mut().input(name, shape);
158 HirNodeId(id.0)
159 }
160 Stage::Lir(l) => {
161 let id = l.mir.as_graph_mut().input(name, shape);
162 HirNodeId(id.0)
163 }
164 }
165 }
166
167 pub fn param(&mut self, name: impl Into<String>, shape: Shape) -> HirNodeId {
168 match &mut self.stage {
169 Stage::Hir(h) => h.param(name, shape),
170 Stage::Mir(m) => {
171 let id = m.as_graph_mut().param(name, shape);
172 HirNodeId(id.0)
173 }
174 Stage::Lir(l) => {
175 let id = l.mir.as_graph_mut().param(name, shape);
176 HirNodeId(id.0)
177 }
178 }
179 }
180
181 pub fn linear(
182 &mut self,
183 x: HirNodeId,
184 weight: HirNodeId,
185 bias: Option<HirNodeId>,
186 activation: Option<Activation>,
187 out_shape: Shape,
188 ) -> HirNodeId {
189 self.hir_mut()
190 .linear(x, weight, bias, activation, out_shape)
191 }
192
193 pub fn linear_fused(
194 &mut self,
195 x: HirNodeId,
196 weight: HirNodeId,
197 bias: HirNodeId,
198 activation: Option<Activation>,
199 out_shape: Shape,
200 ) -> HirNodeId {
201 self.hir_mut()
202 .linear_fused(x, weight, bias, activation, out_shape)
203 }
204
205 pub fn shared_linear_pair(
206 &mut self,
207 x: HirNodeId,
208 w_first: HirNodeId,
209 w_second: HirNodeId,
210 out_shape: Shape,
211 ) -> (HirNodeId, HirNodeId) {
212 self.hir_mut()
213 .shared_linear_pair(x, w_first, w_second, out_shape)
214 }
215
216 pub fn swiglu_ffn(
217 &mut self,
218 x: HirNodeId,
219 up_w: HirNodeId,
220 gate_w: HirNodeId,
221 down_w: HirNodeId,
222 out_shape: Shape,
223 ) -> HirNodeId {
224 self.hir_mut()
225 .swiglu_ffn(x, up_w, gate_w, down_w, out_shape)
226 }
227
228 pub fn residual_rms_norm(
229 &mut self,
230 x: HirNodeId,
231 residual: HirNodeId,
232 gamma: HirNodeId,
233 beta: HirNodeId,
234 eps: f32,
235 out_shape: Shape,
236 ) -> HirNodeId {
237 self.hir_mut()
238 .residual_rms_norm(x, residual, gamma, beta, eps, out_shape)
239 }
240
241 pub fn attention(
242 &mut self,
243 q: HirNodeId,
244 k: HirNodeId,
245 v: HirNodeId,
246 mask: Option<HirNodeId>,
247 num_heads: usize,
248 head_dim: usize,
249 mask_kind: MaskKind,
250 out_shape: Shape,
251 ) -> HirNodeId {
252 self.hir_mut()
253 .attention(q, k, v, mask, num_heads, head_dim, mask_kind, out_shape)
254 }
255
256 pub fn depthwise_conv1d_causal(
257 &mut self,
258 input: HirNodeId,
259 weight: HirNodeId,
260 left_pad: HirNodeId,
261 kernel_size: usize,
262 out_shape: Shape,
263 ) -> HirNodeId {
264 self.hir_mut()
265 .depthwise_conv1d_causal(input, weight, left_pad, kernel_size, out_shape)
266 }
267
268 pub fn dequant_matmul(
269 &mut self,
270 x: HirNodeId,
271 w: HirNodeId,
272 scale: Option<HirNodeId>,
273 zp: Option<HirNodeId>,
274 scheme: QuantScheme,
275 out_shape: Shape,
276 ) -> HirNodeId {
277 self.hir_mut()
278 .dequant_matmul(x, w, scale, zp, scheme, out_shape)
279 }
280
281 pub fn gated_delta_net(
282 &mut self,
283 q: HirNodeId,
284 k: HirNodeId,
285 v: HirNodeId,
286 g: HirNodeId,
287 beta: HirNodeId,
288 state_size: usize,
289 out_shape: Shape,
290 ) -> HirNodeId {
291 self.hir_mut()
292 .gated_delta_net(q, k, v, g, beta, state_size, out_shape)
293 }
294
295 pub fn gated_delta_net_carry(
296 &mut self,
297 q: HirNodeId,
298 k: HirNodeId,
299 v: HirNodeId,
300 g: HirNodeId,
301 beta: HirNodeId,
302 state: HirNodeId,
303 state_size: usize,
304 out_shape: Shape,
305 ) -> HirNodeId {
306 self.hir_mut()
307 .gated_delta_net_carry(q, k, v, g, beta, state, state_size, out_shape)
308 }
309
310 pub fn rope(
311 &mut self,
312 x: HirNodeId,
313 cos: HirNodeId,
314 sin: HirNodeId,
315 head_dim: usize,
316 n_rot: usize,
317 out_shape: Shape,
318 ) -> HirNodeId {
319 self.hir_mut().rope(x, cos, sin, head_dim, n_rot, out_shape)
320 }
321
322 pub fn rms_norm(
323 &mut self,
324 x: HirNodeId,
325 gamma: HirNodeId,
326 beta: HirNodeId,
327 eps: f32,
328 out_shape: Shape,
329 ) -> HirNodeId {
330 self.hir_mut().rms_norm(x, gamma, beta, eps, out_shape)
331 }
332
333 pub fn hir_mir(&mut self, op: Op, inputs: Vec<HirNodeId>, shape: Shape) -> HirNodeId {
334 self.hir_mut().mir(op, inputs, shape)
335 }
336
337 pub fn named(
338 &mut self,
339 name: impl Into<String>,
340 build: impl FnOnce(&mut HirModule) -> HirNodeId,
341 ) -> HirNodeId {
342 self.hir_mut().named(name, build)
343 }
344
345 pub fn stage(&self) -> GraphStage {
346 match &self.stage {
347 Stage::Hir(_) => GraphStage::Hir,
348 Stage::Mir(_) => GraphStage::Mir,
349 Stage::Lir(_) => GraphStage::Lir,
350 }
351 }
352
353 pub fn name(&self) -> &str {
354 match &self.stage {
355 Stage::Hir(h) => &h.name,
356 Stage::Mir(m) => m.name(),
357 Stage::Lir(l) => l.name(),
358 }
359 }
360
361 pub fn lower(self) -> Result<Self, LowerError> {
362 match self.stage {
363 Stage::Hir(hir) => Ok(Self {
364 stage: Stage::Mir(hir.lower_to_mir()?),
365 }),
366 other => Ok(Self { stage: other }),
367 }
368 }
369
370 pub fn into_hir(self) -> Option<HirModule> {
371 match self.stage {
372 Stage::Hir(h) => Some(h),
373 _ => None,
374 }
375 }
376
377 pub fn into_mir(self) -> Result<MirModule, LowerError> {
378 match self.stage {
379 Stage::Hir(hir) => hir.lower_to_mir(),
380 Stage::Mir(m) => Ok(m),
381 Stage::Lir(l) => Ok(l.mir),
382 }
383 }
384
385 pub fn into_lir(self) -> Option<LirModule> {
386 match self.stage {
387 Stage::Lir(l) => Some(l),
388 _ => None,
389 }
390 }
391
392 pub fn into_graph(self) -> Result<Graph, LowerError> {
393 Ok(self.into_mir()?.into_graph())
394 }
395
396 pub fn as_hir(&self) -> Option<&HirModule> {
397 match &self.stage {
398 Stage::Hir(h) => Some(h),
399 _ => None,
400 }
401 }
402
403 pub fn as_hir_mut(&mut self) -> Option<&mut HirModule> {
404 match &mut self.stage {
405 Stage::Hir(h) => Some(h),
406 _ => None,
407 }
408 }
409
410 pub fn as_mir(&self) -> Option<&MirModule> {
411 match &self.stage {
412 Stage::Mir(m) => Some(m),
413 Stage::Lir(l) => Some(&l.mir),
414 _ => None,
415 }
416 }
417
418 pub fn as_lir(&self) -> Option<&LirModule> {
419 match &self.stage {
420 Stage::Lir(l) => Some(l),
421 _ => None,
422 }
423 }
424
425 pub fn as_graph(&self) -> Option<&Graph> {
426 match &self.stage {
427 Stage::Mir(m) => Some(m.as_graph()),
428 Stage::Lir(l) => Some(l.as_graph()),
429 Stage::Hir(_) => None,
430 }
431 }
432
433 pub fn inspect(&self) -> String {
434 match &self.stage {
435 Stage::Hir(h) => inspect_hir(h),
436 Stage::Mir(m) => inspect_mir(m),
437 Stage::Lir(l) => inspect_lir(l),
438 }
439 }
440}
441
442impl Deref for GraphModule {
443 type Target = Graph;
444
445 fn deref(&self) -> &Graph {
446 self.as_graph()
447 .expect("GraphModule: HIR stage — call lower() before accessing MIR Graph")
448 }
449}
450
451impl DerefMut for GraphModule {
452 fn deref_mut(&mut self) -> &mut Graph {
453 match &mut self.stage {
454 Stage::Mir(m) => m.as_graph_mut(),
455 Stage::Lir(l) => l.mir.as_graph_mut(),
456 Stage::Hir(_) => panic!("GraphModule: HIR stage — use as_hir_mut() or lower() first"),
457 }
458 }
459}
460
461impl From<Graph> for GraphModule {
462 fn from(graph: Graph) -> Self {
463 Self::from_graph(graph)
464 }
465}
466
467impl TryFrom<GraphModule> for Graph {
468 type Error = LowerError;
469
470 fn try_from(module: GraphModule) -> Result<Self, LowerError> {
471 module.into_graph()
472 }
473}
474
475impl From<MirModule> for GraphModule {
476 fn from(mir: MirModule) -> Self {
477 Self::from_mir(mir)
478 }
479}
480
481impl From<HirModule> for GraphModule {
482 fn from(hir: HirModule) -> Self {
483 Self::from_hir(hir)
484 }
485}
486
487impl std::fmt::Display for GraphModule {
488 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
489 match &self.stage {
490 Stage::Hir(h) => write!(f, "{h}"),
491 Stage::Mir(m) => write!(f, "{m}"),
492 Stage::Lir(l) => write!(f, "lir @{}", l.name()),
493 }
494 }
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DType;
    use crate::Graph;
    use crate::Shape;

    /// Shorthand for a shape with the given dimensions and `DType::F32`.
    fn f32_shape(d: &[usize]) -> Shape {
        Shape::new(d, DType::F32)
    }

    /// `define` starts in the HIR stage and `lower` moves the module to the MIR stage.
    #[test]
    fn define_lowers_to_mir_graph() {
        let hir_stage = GraphModule::define("m", |hir| {
            let x = hir.input("x", f32_shape(&[2, 8]));
            let w = hir.param("w", f32_shape(&[8, 8]));
            hir.linear(x, w, None, None, f32_shape(&[2, 8]))
        });
        assert_eq!(hir_stage.stage(), GraphStage::Hir);

        let mir_stage = hir_stage.lower().expect("lower");
        assert_eq!(mir_stage.stage(), GraphStage::Mir);
        assert!(mir_stage.len() >= 3);
    }

    /// A MIR-stage module accepts `input` / `set_outputs` and exposes `len()` through `Deref`.
    #[test]
    fn mir_module_deref_builds_graph() {
        let mut raw = GraphModule::mir("raw");
        let only_input = raw.input("x", f32_shape(&[4]));
        raw.set_outputs(vec![only_input]);
        assert_eq!(raw.len(), 1);
    }

    /// HIR builder methods are reachable on `GraphModule` and the result still lowers.
    #[test]
    fn hir_module_block_builders_via_graph_module() {
        use crate::quant::QuantScheme;

        let mut layer = GraphModule::hir("layer");
        let x = layer.input("x", f32_shape(&[2, 128]));
        let w = layer.param("w", f32_shape(&[128, 128]));
        let y = layer.dequant_matmul(
            x,
            w,
            None,
            None,
            QuantScheme::GgufQ4K,
            f32_shape(&[2, 128]),
        );
        layer.set_outputs(vec![y]);
        assert_eq!(layer.stage(), GraphStage::Hir);

        let lowered = layer.lower().expect("lower");
        assert_eq!(lowered.stage(), GraphStage::Mir);
        assert!(lowered.len() >= 3);
    }

    /// `Graph::hir` and `Graph::define` both hand back HIR-stage modules.
    #[test]
    fn graph_hir_entry_matches_define() {
        let from_hir_entry = Graph::hir("m");
        let from_define_entry = Graph::define("m", |hir| {
            let x = hir.input("x", f32_shape(&[4]));
            hir.rms_norm(x, x, x, 1e-5, f32_shape(&[4]))
        });
        assert_eq!(from_hir_entry.stage(), GraphStage::Hir);
        assert_eq!(from_define_entry.stage(), GraphStage::Hir);
    }
}