1mod blocks;
24mod conv;
25mod fusion;
26mod graph_ext;
27mod lower;
28
29pub use blocks::lower_llama_decoder_block;
30pub use blocks::lower_qwen35_mtp_head;
31pub use fusion::FusionPolicy;
32pub use graph_ext::{HirGraphExt, HirMut};
33
34use crate::mir::MirModule;
35use crate::op::Activation;
36use crate::op::MaskKind;
37use crate::quant::QuantScheme;
38use crate::{Op, Shape};
39
40pub use lower::LowerError;
41
/// Identifier of a node inside a [`HirModule`]; the wrapped `u32` is the node's index.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct HirNodeId(pub u32);

impl std::fmt::Display for HirNodeId {
    /// Renders the id as the letter `h` followed by its index, e.g. `h3`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(index) = *self;
        write!(f, "h{index}")
    }
}
52
/// Operation carried by a [`HirNode`].
///
/// `Input`, `Param` and `Constant` are leaves, and `Mir` embeds a MIR op unchanged. Every other
/// variant is a composite op. Its operand order (the node's `inputs`) is fixed by the matching
/// `HirModule` builder method and is recorded on each variant below. Conversion to MIR goes
/// through the `lower` submodule (`lower::lower_module`).
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub enum HirOp {
    /// Named graph input created by `HirModule::input`; has no operands.
    Input {
        name: String,
    },
    /// Named parameter (e.g. a weight) created by `HirModule::param`; has no operands.
    Param {
        name: String,
    },
    /// Constant with a raw byte payload. No builder in this file creates it.
    /// NOTE(review): how `data` is interpreted (dtype/layout) is presumably given by the node's
    /// shape — confirm in `lower`.
    Constant {
        data: Vec<u8>,
    },

    /// Dense projection. Operands: `[x, weight]`, or `[x, weight, bias]` when `has_bias`.
    /// `activation`, if set, is an activation attached to the projection.
    Linear {
        activation: Option<Activation>,
        has_bias: bool,
    },

    /// Projection whose bias is always present. Operands: `[x, weight, bias]`.
    LinearFused {
        activation: Option<Activation>,
    },

    /// One half of the pair emitted by `HirModule::shared_linear_pair`. Both halves take the
    /// operands `[x, w_first, w_second]`. `slot` is 0 for the first node and 1 for the second.
    /// Presumably it selects which weight's projection this node yields — TODO confirm.
    SharedLinearPair {
        slot: u8,
    },

    /// SwiGLU feed-forward block. Operands: `[x, up_w, gate_w, down_w]`.
    SwiGLU,

    /// RMS norm combined with a residual input. Operands: `[x, residual, gamma, beta]`.
    ResidualRmsNorm {
        eps: f32,
    },

    /// Attention. Operands: `[q, k, v]`, plus an explicit mask node as a fourth operand when
    /// one was supplied. `mask` records the mask kind.
    Attention {
        num_heads: usize,
        head_dim: usize,
        mask: MaskKind,
    },

    /// Causal depthwise 1-D convolution. Operands: `[input, weight, left_pad]`.
    /// The tests in this file check that it lowers to MIR `Conv` and `Concat` ops.
    DepthwiseConv1dCausal {
        kernel_size: usize,
    },

    /// Matmul against a quantized weight. Operands: `[x, w]` when `scheme.is_gguf()`,
    /// otherwise `[x, w, scale, zp]`.
    DequantMatMul {
        scheme: QuantScheme,
    },

    /// Gated delta-net scan. Operands: `[q, k, v, g, beta]`. When `carry_state` is true, an
    /// initial `state` operand is appended (see `HirModule::gated_delta_net_carry`).
    GatedDeltaNet {
        state_size: usize,
        carry_state: bool,
    },

    /// Rotary position embedding. Operands: `[x, cos, sin]`.
    /// NOTE(review): `n_rot` is presumably the number of rotated dims within `head_dim`.
    RoPE {
        head_dim: usize,
        n_rot: usize,
    },

    /// RMS normalization. Operands: `[x, gamma, beta]`.
    RmsNorm {
        eps: f32,
    },

    /// Whole Llama-style decoder block. Operands:
    /// `[x, ln1_g, ln1_b, q_w, k_w, v_w, o_w, ln2_g, ln2_b, gate_w, up_w, down_w, cos, sin]`,
    /// plus an optional trailing mask node.
    LlamaDecoderBlock {
        num_heads: usize,
        head_dim: usize,
        num_kv_heads: usize,
        eps: f32,
        mask: MaskKind,
    },

    /// Qwen3.5 multi-token-prediction head. It has 29 operands, in the parameter order of
    /// `HirModule::qwen35_mtp_head` (from `h_pre_norm` through `lm_head_w`).
    Qwen35MtpHead {
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        n_rot: usize,
        n_embd: usize,
        n_ff: usize,
        mtp_vocab: usize,
        eps: f32,
    },

    /// A MIR op embedded unchanged (see `HirModule::mir` and `HirModule::wrap_mir_graph`).
    Mir(Op),
}
162
/// A single node of a [`HirModule`].
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone)]
pub struct HirNode {
    /// This node's id. `HirModule` assigns it as the node's index in the module.
    pub id: HirNodeId,
    /// The operation the node performs.
    pub op: HirOp,
    /// Operand node ids, in the order `op` expects.
    pub inputs: Vec<HirNodeId>,
    /// Output shape as declared by the builder call that created the node; it is not inferred.
    pub shape: Shape,
    /// Optional label. Composite-op builders fill in a default and `HirModule::named`
    /// overrides it.
    pub name: Option<String>,
}
173
/// A high-level IR module: a list of [`HirNode`]s plus the node ids it returns.
///
/// Nodes are added through the builder methods and referenced by [`HirNodeId`], which is the
/// node's index in the internal list. Convert to MIR with `lower_to_mir`.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone)]
pub struct HirModule {
    /// Module name; `Display` prints it in the `hir @<name> {` header.
    pub name: String,
    // Private so that ids keep matching indices; read through `nodes()` / `node()`.
    nodes: Vec<HirNode>,
    /// Ids of the nodes the module returns (printed as the `return` line by `Display`).
    pub outputs: Vec<HirNodeId>,
    /// Fusion policy handed to lowering together with the module; `new` sets `Direct`.
    pub fusion_policy: FusionPolicy,
}
185
186impl HirModule {
187 pub fn new(name: impl Into<String>) -> Self {
188 Self {
189 name: name.into(),
190 nodes: Vec::new(),
191 outputs: Vec::new(),
192 fusion_policy: FusionPolicy::Direct,
193 }
194 }
195
196 pub fn with_fusion_policy(mut self, policy: FusionPolicy) -> Self {
197 self.fusion_policy = policy;
198 self
199 }
200
201 pub fn len(&self) -> usize {
202 self.nodes.len()
203 }
204
205 pub fn is_empty(&self) -> bool {
206 self.nodes.is_empty()
207 }
208
209 pub fn nodes(&self) -> &[HirNode] {
210 &self.nodes
211 }
212
213 pub fn node(&self, id: HirNodeId) -> &HirNode {
214 &self.nodes[id.0 as usize]
215 }
216
217 pub fn node_mut(&mut self, id: HirNodeId) -> &mut HirNode {
218 &mut self.nodes[id.0 as usize]
219 }
220
221 pub fn named(
223 &mut self,
224 name: impl Into<String>,
225 build: impl FnOnce(&mut Self) -> HirNodeId,
226 ) -> HirNodeId {
227 let id = build(self);
228 self.node_mut(id).name = Some(name.into());
229 id
230 }
231
232 fn push_block(
233 &mut self,
234 op: HirOp,
235 inputs: Vec<HirNodeId>,
236 shape: Shape,
237 name: Option<String>,
238 ) -> HirNodeId {
239 let name = name.or_else(|| default_hir_block_label(&op));
240 self.push(op, inputs, shape, name)
241 }
242
243 fn push(
244 &mut self,
245 op: HirOp,
246 inputs: Vec<HirNodeId>,
247 shape: Shape,
248 name: Option<String>,
249 ) -> HirNodeId {
250 let id = HirNodeId(self.nodes.len() as u32);
251 self.nodes.push(HirNode {
252 id,
253 op,
254 inputs,
255 shape,
256 name,
257 });
258 id
259 }
260
261 pub fn input(&mut self, name: impl Into<String>, shape: Shape) -> HirNodeId {
262 self.push(HirOp::Input { name: name.into() }, vec![], shape, None)
263 }
264
265 pub fn input_batch_seq(
267 &mut self,
268 name: impl Into<String>,
269 batch: u32,
270 seq: u32,
271 hidden: usize,
272 dtype: crate::DType,
273 ) -> HirNodeId {
274 self.input(name, Shape::batch_seq(batch, seq, hidden, dtype))
275 }
276
277 pub fn param(&mut self, name: impl Into<String>, shape: Shape) -> HirNodeId {
278 self.push(HirOp::Param { name: name.into() }, vec![], shape, None)
279 }
280
281 pub fn linear(
282 &mut self,
283 x: HirNodeId,
284 weight: HirNodeId,
285 bias: Option<HirNodeId>,
286 activation: Option<Activation>,
287 out_shape: Shape,
288 ) -> HirNodeId {
289 let mut inputs = vec![x, weight];
290 if let Some(b) = bias {
291 inputs.push(b);
292 }
293 self.push_block(
294 HirOp::Linear {
295 activation,
296 has_bias: bias.is_some(),
297 },
298 inputs,
299 out_shape,
300 None,
301 )
302 }
303
304 pub fn linear_fused(
306 &mut self,
307 x: HirNodeId,
308 weight: HirNodeId,
309 bias: HirNodeId,
310 activation: Option<Activation>,
311 out_shape: Shape,
312 ) -> HirNodeId {
313 self.push_block(
314 HirOp::LinearFused { activation },
315 vec![x, weight, bias],
316 out_shape,
317 None,
318 )
319 }
320
321 pub fn shared_linear_pair(
323 &mut self,
324 x: HirNodeId,
325 w_first: HirNodeId,
326 w_second: HirNodeId,
327 out_shape: Shape,
328 ) -> (HirNodeId, HirNodeId) {
329 let inputs = vec![x, w_first, w_second];
330 let first = self.push_block(
331 HirOp::SharedLinearPair { slot: 0 },
332 inputs.clone(),
333 out_shape.clone(),
334 None,
335 );
336 let second = self.push_block(HirOp::SharedLinearPair { slot: 1 }, inputs, out_shape, None);
337 (first, second)
338 }
339
340 pub fn swiglu_ffn(
341 &mut self,
342 x: HirNodeId,
343 up_w: HirNodeId,
344 gate_w: HirNodeId,
345 down_w: HirNodeId,
346 out_shape: Shape,
347 ) -> HirNodeId {
348 self.push_block(
349 HirOp::SwiGLU,
350 vec![x, up_w, gate_w, down_w],
351 out_shape,
352 None,
353 )
354 }
355
356 pub fn residual_rms_norm(
357 &mut self,
358 x: HirNodeId,
359 residual: HirNodeId,
360 gamma: HirNodeId,
361 beta: HirNodeId,
362 eps: f32,
363 out_shape: Shape,
364 ) -> HirNodeId {
365 self.push_block(
366 HirOp::ResidualRmsNorm { eps },
367 vec![x, residual, gamma, beta],
368 out_shape,
369 None,
370 )
371 }
372
373 pub fn attention(
375 &mut self,
376 q: HirNodeId,
377 k: HirNodeId,
378 v: HirNodeId,
379 mask: Option<HirNodeId>,
380 num_heads: usize,
381 head_dim: usize,
382 mask_kind: MaskKind,
383 out_shape: Shape,
384 ) -> HirNodeId {
385 let mut inputs = vec![q, k, v];
386 if let Some(m) = mask {
387 inputs.push(m);
388 }
389 self.push_block(
390 HirOp::Attention {
391 num_heads,
392 head_dim,
393 mask: mask_kind,
394 },
395 inputs,
396 out_shape,
397 None,
398 )
399 }
400
401 pub fn depthwise_conv1d_causal(
406 &mut self,
407 input: HirNodeId,
408 weight: HirNodeId,
409 left_pad: HirNodeId,
410 kernel_size: usize,
411 out_shape: Shape,
412 ) -> HirNodeId {
413 self.push_block(
414 HirOp::DepthwiseConv1dCausal { kernel_size },
415 vec![input, weight, left_pad],
416 out_shape,
417 None,
418 )
419 }
420
421 pub fn dequant_matmul(
423 &mut self,
424 x: HirNodeId,
425 w: HirNodeId,
426 scale: Option<HirNodeId>,
427 zp: Option<HirNodeId>,
428 scheme: QuantScheme,
429 out_shape: Shape,
430 ) -> HirNodeId {
431 let mut inputs = vec![x, w];
432 if !scheme.is_gguf() {
433 inputs.push(scale.expect("DequantMatMul: scale required for non-GGUF schemes"));
434 inputs.push(zp.expect("DequantMatMul: zp required for non-GGUF schemes"));
435 }
436 self.push_block(HirOp::DequantMatMul { scheme }, inputs, out_shape, None)
437 }
438
439 pub fn gated_delta_net(
441 &mut self,
442 q: HirNodeId,
443 k: HirNodeId,
444 v: HirNodeId,
445 g: HirNodeId,
446 beta: HirNodeId,
447 state_size: usize,
448 out_shape: Shape,
449 ) -> HirNodeId {
450 self.push_block(
451 HirOp::GatedDeltaNet {
452 state_size,
453 carry_state: false,
454 },
455 vec![q, k, v, g, beta],
456 out_shape,
457 None,
458 )
459 }
460
461 pub fn gated_delta_net_carry(
463 &mut self,
464 q: HirNodeId,
465 k: HirNodeId,
466 v: HirNodeId,
467 g: HirNodeId,
468 beta: HirNodeId,
469 state: HirNodeId,
470 state_size: usize,
471 out_shape: Shape,
472 ) -> HirNodeId {
473 self.push_block(
474 HirOp::GatedDeltaNet {
475 state_size,
476 carry_state: true,
477 },
478 vec![q, k, v, g, beta, state],
479 out_shape,
480 None,
481 )
482 }
483
484 pub fn rope(
486 &mut self,
487 x: HirNodeId,
488 cos: HirNodeId,
489 sin: HirNodeId,
490 head_dim: usize,
491 n_rot: usize,
492 out_shape: Shape,
493 ) -> HirNodeId {
494 self.push_block(
495 HirOp::RoPE { head_dim, n_rot },
496 vec![x, cos, sin],
497 out_shape,
498 None,
499 )
500 }
501
502 pub fn rms_norm(
504 &mut self,
505 x: HirNodeId,
506 gamma: HirNodeId,
507 beta: HirNodeId,
508 eps: f32,
509 out_shape: Shape,
510 ) -> HirNodeId {
511 self.push_block(
512 HirOp::RmsNorm { eps },
513 vec![x, gamma, beta],
514 out_shape,
515 None,
516 )
517 }
518
519 pub fn llama_decoder_block(
521 &mut self,
522 x: HirNodeId,
523 ln1_g: HirNodeId,
524 ln1_b: HirNodeId,
525 q_w: HirNodeId,
526 k_w: HirNodeId,
527 v_w: HirNodeId,
528 o_w: HirNodeId,
529 ln2_g: HirNodeId,
530 ln2_b: HirNodeId,
531 gate_w: HirNodeId,
532 up_w: HirNodeId,
533 down_w: HirNodeId,
534 cos: HirNodeId,
535 sin: HirNodeId,
536 mask: Option<HirNodeId>,
537 num_heads: usize,
538 head_dim: usize,
539 num_kv_heads: usize,
540 eps: f32,
541 mask_kind: MaskKind,
542 out_shape: Shape,
543 ) -> HirNodeId {
544 let mut ins = vec![
545 x, ln1_g, ln1_b, q_w, k_w, v_w, o_w, ln2_g, ln2_b, gate_w, up_w, down_w, cos, sin,
546 ];
547 if let Some(m) = mask {
548 ins.push(m);
549 }
550 self.push_block(
551 HirOp::LlamaDecoderBlock {
552 num_heads,
553 head_dim,
554 num_kv_heads,
555 eps,
556 mask: mask_kind,
557 },
558 ins,
559 out_shape,
560 Some("llama_decoder_block".into()),
561 )
562 }
563
564 pub fn transformer_block(
567 &mut self,
568 x: HirNodeId,
569 ln1_g: HirNodeId,
570 ln1_b: HirNodeId,
571 q_w: HirNodeId,
572 k_w: HirNodeId,
573 v_w: HirNodeId,
574 o_w: HirNodeId,
575 ln2_g: HirNodeId,
576 ln2_b: HirNodeId,
577 gate_w: HirNodeId,
578 up_w: HirNodeId,
579 down_w: HirNodeId,
580 cos: HirNodeId,
581 sin: HirNodeId,
582 mask: Option<HirNodeId>,
583 num_heads: usize,
584 head_dim: usize,
585 num_kv_heads: usize,
586 eps: f32,
587 mask_kind: MaskKind,
588 out_shape: Shape,
589 ) -> HirNodeId {
590 let id = self.llama_decoder_block(
591 x,
592 ln1_g,
593 ln1_b,
594 q_w,
595 k_w,
596 v_w,
597 o_w,
598 ln2_g,
599 ln2_b,
600 gate_w,
601 up_w,
602 down_w,
603 cos,
604 sin,
605 mask,
606 num_heads,
607 head_dim,
608 num_kv_heads,
609 eps,
610 mask_kind,
611 out_shape,
612 );
613 self.node_mut(id).name = Some("transformer_block".into());
614 id
615 }
616
617 #[allow(clippy::too_many_arguments)]
619 pub fn qwen35_mtp_head(
620 &mut self,
621 h_pre_norm: HirNodeId,
622 input_ids: HirNodeId,
623 cos: HirNodeId,
624 sin: HirNodeId,
625 last_token_idx: HirNodeId,
626 embed_w: HirNodeId,
627 hnorm_w: HirNodeId,
628 hnorm_b: HirNodeId,
629 enorm_w: HirNodeId,
630 enorm_b: HirNodeId,
631 eh_w: HirNodeId,
632 fa_attn_norm_w: HirNodeId,
633 fa_attn_norm_b: HirNodeId,
634 fa_q_gate_w: HirNodeId,
635 fa_k_w: HirNodeId,
636 fa_v_w: HirNodeId,
637 fa_q_norm_w: HirNodeId,
638 fa_q_norm_b: HirNodeId,
639 fa_k_norm_w: HirNodeId,
640 fa_k_norm_b: HirNodeId,
641 fa_o_w: HirNodeId,
642 fa_post_norm_w: HirNodeId,
643 fa_post_norm_b: HirNodeId,
644 fa_gate_w: HirNodeId,
645 fa_up_w: HirNodeId,
646 fa_down_w: HirNodeId,
647 head_norm_w: HirNodeId,
648 head_norm_b: HirNodeId,
649 lm_head_w: HirNodeId,
650 num_heads: usize,
651 num_kv_heads: usize,
652 head_dim: usize,
653 n_rot: usize,
654 n_embd: usize,
655 n_ff: usize,
656 mtp_vocab: usize,
657 eps: f32,
658 out_shape: Shape,
659 ) -> HirNodeId {
660 self.push_block(
661 HirOp::Qwen35MtpHead {
662 num_heads,
663 num_kv_heads,
664 head_dim,
665 n_rot,
666 n_embd,
667 n_ff,
668 mtp_vocab,
669 eps,
670 },
671 vec![
672 h_pre_norm,
673 input_ids,
674 cos,
675 sin,
676 last_token_idx,
677 embed_w,
678 hnorm_w,
679 hnorm_b,
680 enorm_w,
681 enorm_b,
682 eh_w,
683 fa_attn_norm_w,
684 fa_attn_norm_b,
685 fa_q_gate_w,
686 fa_k_w,
687 fa_v_w,
688 fa_q_norm_w,
689 fa_q_norm_b,
690 fa_k_norm_w,
691 fa_k_norm_b,
692 fa_o_w,
693 fa_post_norm_w,
694 fa_post_norm_b,
695 fa_gate_w,
696 fa_up_w,
697 fa_down_w,
698 head_norm_w,
699 head_norm_b,
700 lm_head_w,
701 ],
702 out_shape,
703 Some("qwen35_mtp_head".into()),
704 )
705 }
706
707 pub fn mir(&mut self, op: Op, inputs: Vec<HirNodeId>, shape: Shape) -> HirNodeId {
709 self.push(HirOp::Mir(op), inputs, shape, None)
710 }
711
712 pub fn set_outputs(&mut self, outputs: Vec<HirNodeId>) {
713 self.outputs = outputs;
714 }
715
716 pub fn lower_to_mir(self) -> Result<MirModule, LowerError> {
718 lower::lower_module(self)
719 }
720
721 pub fn lower_for_autodiff(self) -> Result<MirModule, LowerError> {
724 self.with_fusion_policy(FusionPolicy::for_autodiff())
725 .lower_to_mir()
726 }
727
728 pub fn wrap_mir_graph(graph: crate::Graph) -> Self {
731 use std::collections::HashMap;
732 let mut hir = Self::new(graph.name.clone()).with_fusion_policy(FusionPolicy::Direct);
733 let mut map: HashMap<crate::NodeId, HirNodeId> = HashMap::new();
734 for node in graph.nodes() {
735 let inputs: Vec<HirNodeId> = node.inputs.iter().map(|&id| map[&id]).collect();
736 let id = hir.mir(node.op.clone(), inputs, node.shape.clone());
737 map.insert(node.id, id);
738 }
739 let outputs: Vec<HirNodeId> = graph.outputs.iter().map(|&id| map[&id]).collect();
740 hir.set_outputs(outputs);
741 hir
742 }
743}
744
745pub(crate) fn default_hir_block_label(op: &HirOp) -> Option<String> {
746 Some(match op {
747 HirOp::Linear { .. } => "linear".into(),
748 HirOp::LinearFused { .. } => "linear_fused".into(),
749 HirOp::SharedLinearPair { slot } => return Some(format!("shared_linear_pair[{slot}]")),
750 HirOp::SwiGLU => "swiglu_ffn".into(),
751 HirOp::ResidualRmsNorm { .. } => "residual_rms_norm".into(),
752 HirOp::Attention { .. } => "attention".into(),
753 HirOp::DepthwiseConv1dCausal { .. } => "depthwise_conv1d_causal".into(),
754 HirOp::DequantMatMul { scheme } => format!("dequant_matmul({scheme})"),
755 HirOp::GatedDeltaNet {
756 carry_state: true, ..
757 } => "gated_delta_net_carry".into(),
758 HirOp::GatedDeltaNet { .. } => "gated_delta_net".into(),
759 HirOp::RoPE { .. } => "rope".into(),
760 HirOp::RmsNorm { .. } => "rms_norm".into(),
761 HirOp::Mir(_) => "mir".into(),
762 HirOp::LlamaDecoderBlock { .. } => "llama_decoder_block".into(),
763 HirOp::Qwen35MtpHead { .. } => "qwen35_mtp_head".into(),
764 HirOp::Input { .. } | HirOp::Param { .. } | HirOp::Constant { .. } => return None,
765 })
766}
767
768impl std::fmt::Display for HirModule {
769 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
770 writeln!(f, "hir @{} {{", self.name)?;
771 for node in &self.nodes {
772 write!(f, " {} = {:?}", node.id, node.op)?;
773 if !node.inputs.is_empty() {
774 write!(f, "(")?;
775 for (i, inp) in node.inputs.iter().enumerate() {
776 if i > 0 {
777 write!(f, ", ")?;
778 }
779 write!(f, "{inp}")?;
780 }
781 write!(f, ")")?;
782 }
783 writeln!(f, " : {}", node.shape)?;
784 }
785 if !self.outputs.is_empty() {
786 write!(f, " return ")?;
787 for (i, o) in self.outputs.iter().enumerate() {
788 if i > 0 {
789 write!(f, ", ")?;
790 }
791 write!(f, "{o}")?;
792 }
793 writeln!(f)?;
794 }
795 write!(f, "}}")
796 }
797}
798
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DType;

    /// Shorthand for an `F32` shape with dimensions `d`.
    fn f32_shape(d: &[usize]) -> Shape {
        Shape::new(d, DType::F32)
    }

    // A causal depthwise conv must lower to a MIR convolution plus a concat. The concat is
    // presumably what joins `left_pad` onto the input.
    #[test]
    fn hir_depthwise_conv1d_causal_lowers_to_grouped_conv() {
        use crate::Op;

        let mut hir = HirModule::new("dw");
        let x = hir.input("x", f32_shape(&[2, 8, 16]));
        let w = hir.param("w", f32_shape(&[16, 1, 1, 3]));
        let pad = hir.param("pad", f32_shape(&[2, 2, 16]));
        let out = hir.depthwise_conv1d_causal(x, w, pad, 3, f32_shape(&[2, 8, 16]));
        // Go through the public setter like the other tests rather than poking the field.
        hir.set_outputs(vec![out]);

        let g = hir.lower_to_mir().expect("lower").into_graph();
        assert!(g.nodes().iter().any(|n| matches!(n.op, Op::Conv { .. })));
        assert!(g.nodes().iter().any(|n| matches!(n.op, Op::Concat { .. })));
    }

    // Under the `Fusable` policy SwiGLU lowers to a small MIR graph containing matmuls.
    #[test]
    fn hir_swiglu_lowers_to_fusable_mir() {
        use crate::Op;
        use crate::hir::FusionPolicy;

        let mut hir = HirModule::new("ffn").with_fusion_policy(FusionPolicy::Fusable);
        let x = hir.input("x", f32_shape(&[4, 768]));
        let up_w = hir.param("up", f32_shape(&[768, 2048]));
        let gate_w = hir.param("gate", f32_shape(&[768, 2048]));
        let down_w = hir.param("down", f32_shape(&[2048, 768]));
        let out = hir.swiglu_ffn(x, up_w, gate_w, down_w, f32_shape(&[4, 768]));
        hir.set_outputs(vec![out]);

        let mir = hir.lower_to_mir().expect("lower");
        let g = mir.into_graph();
        assert!(g.nodes().iter().any(|n| matches!(n.op, Op::MatMul)));
        // Exact node count pins the current Fusable SwiGLU lowering; update it if that
        // lowering intentionally changes.
        assert_eq!(g.len(), 9);
    }

    // Several independent composite ops in one module each lower to their MIR counterpart.
    #[test]
    fn hir_gdn_dequant_rope_rms_lowers() {
        use crate::Op;
        use crate::quant::QuantScheme;

        let mut hir = HirModule::new("qwen_block");
        let q = hir.input("q", f32_shape(&[1, 4, 2, 8]));
        let k = hir.param("k", f32_shape(&[1, 4, 2, 8]));
        let v = hir.param("v", f32_shape(&[1, 4, 2, 8]));
        let g_in = hir.param("g", f32_shape(&[1, 4, 2]));
        let beta = hir.param("beta", f32_shape(&[1, 4, 2]));
        let scan = hir.gated_delta_net(q, k, v, g_in, beta, 8, f32_shape(&[1, 4, 2, 8]));

        let cos = hir.param("cos", f32_shape(&[1, 4, 8]));
        let sin = hir.param("sin", f32_shape(&[1, 4, 8]));
        let x = hir.input("x", f32_shape(&[1, 4, 8]));
        let rotated = hir.rope(x, cos, sin, 8, 8, f32_shape(&[1, 4, 8]));

        let gamma = hir.param("gamma", f32_shape(&[8]));
        let beta_n = hir.param("beta_n", f32_shape(&[8]));
        let normed = hir.rms_norm(rotated, gamma, beta_n, 1e-6, f32_shape(&[1, 4, 8]));

        // GGUF schemes take no scale / zero-point operands.
        let x_in = hir.input("hidden", f32_shape(&[4, 128]));
        let w = hir.param("w_q", f32_shape(&[1024]));
        let proj = hir.dequant_matmul(
            x_in,
            w,
            None,
            None,
            QuantScheme::GgufQ4K,
            f32_shape(&[4, 128]),
        );
        hir.set_outputs(vec![scan, normed, proj]);

        let g = hir.lower_to_mir().expect("lower").into_graph();
        assert!(g.nodes().iter().any(|n| matches!(
            n.op,
            Op::GatedDeltaNet {
                carry_state: false,
                ..
            }
        )));
        assert!(g.nodes().iter().any(|n| matches!(n.op, Op::Rope { .. })));
        assert!(g.nodes().iter().any(|n| matches!(n.op, Op::RmsNorm { .. })));
        assert!(g.nodes().iter().any(|n| matches!(
            n.op,
            Op::DequantMatMul {
                scheme: QuantScheme::GgufQ4K
            }
        )));
    }
}