Skip to main content

rlx_ir/hir/
mod.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! **HIR** — high-level IR.
//!
//! Block-oriented IR for model authors and external graph builders.
//! HIR captures fusion-friendly patterns (SwiGLU FFN, linear layers,
//! residual RMSNorm) as first-class ops and lowers to MIR via
//! [`HirModule::lower_to_mir`].
22
23mod blocks;
24mod conv;
25mod fusion;
26mod graph_ext;
27mod lower;
28
29pub use blocks::lower_llama_decoder_block;
30pub use blocks::lower_qwen35_mtp_head;
31pub use fusion::FusionPolicy;
32pub use graph_ext::{HirGraphExt, HirMut};
33
34use crate::mir::MirModule;
35use crate::op::Activation;
36use crate::op::MaskKind;
37use crate::quant::QuantScheme;
38use crate::{Op, Shape};
39
40pub use lower::LowerError;
41
/// Identifier of a node inside one [`HirModule`].
///
/// Ids are handed out by the module's builder methods and are only
/// meaningful for the module that produced them.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct HirNodeId(pub u32);

impl std::fmt::Display for HirNodeId {
    /// Prints the id as `h<index>` (for example `h0`, `h17`). Width and fill
    /// flags on the outer formatter are intentionally not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let HirNodeId(index) = *self;
        write!(f, "h{index}")
    }
}
52
53/// High-level operation — blocks and escape hatches.
54#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
55#[derive(Debug, Clone, PartialEq)]
56pub enum HirOp {
57    Input {
58        name: String,
59    },
60    Param {
61        name: String,
62    },
63    Constant {
64        data: Vec<u8>,
65    },
66
67    /// `matmul → add(bias)? → activation?`
68    /// Inputs: `[x, weight]` or `[x, weight, bias]`.
69    Linear {
70        activation: Option<Activation>,
71        has_bias: bool,
72    },
73
74    /// Emit [`Op::FusedMatMulBiasAct`] directly.
75    /// Inputs: `[x, weight, bias]`.
76    LinearFused {
77        activation: Option<Activation>,
78    },
79
80    /// Two matmuls sharing the same input (QKV / SwiGLU gate+up).
81    /// Inputs: `[x, w_first, w_second]`. `slot` selects which output.
82    SharedLinearPair {
83        slot: u8,
84    },
85
86    /// Full SwiGLU FFN.
87    /// Inputs: `[x, up_w, gate_w, down_w]`.
88    SwiGLU,
89
90    /// `add(x, residual)` then RMSNorm.
91    /// Inputs: `[x, residual, gamma, beta]`.
92    ResidualRmsNorm {
93        eps: f32,
94    },
95
96    /// Scaled dot-product attention.
97    /// Inputs: `[q, k, v, mask?]` — mask omitted when `mask == None`.
98    Attention {
99        num_heads: usize,
100        head_dim: usize,
101        mask: MaskKind,
102    },
103
104    /// Causal depthwise Conv1d on `[batch, seq, channels]` tensors.
105    /// Inputs: `[input, weight, left_pad]` — see [`conv::lower_depthwise_conv1d_causal`].
106    DepthwiseConv1dCausal {
107        kernel_size: usize,
108    },
109
110    /// Fused dequant + matmul. GGUF schemes take `[x, packed_w]`; legacy
111    /// Int8/NVFP4 schemes take `[x, w_q, scale, zp]`.
112    DequantMatMul {
113        scheme: QuantScheme,
114    },
115
116    /// Gated DeltaNet linear-attention scan (Qwen3.5 trunk).
117    /// Inputs: `[q, k, v, g, beta]` or with carry `[…, state]`.
118    GatedDeltaNet {
119        state_size: usize,
120        carry_state: bool,
121    },
122
123    /// Rotary position embedding. Inputs: `[x, cos, sin]`.
124    RoPE {
125        head_dim: usize,
126        n_rot: usize,
127    },
128
129    /// RMS normalization without residual. Inputs: `[x, gamma, beta]`.
130    RmsNorm {
131        eps: f32,
132    },
133
134    /// LLaMA-style pre-norm decoder block: attn (GQA) + SwiGLU FFN.
135    /// Inputs (causal): `[x, ln1_g, ln1_b, q_w, k_w, v_w, o_w, ln2_g, ln2_b,
136    /// gate_w, up_w, down_w, cos, sin]`. With `MaskKind::Custom` or `Bias`
137    /// append `mask`.
138    LlamaDecoderBlock {
139        num_heads: usize,
140        head_dim: usize,
141        num_kv_heads: usize,
142        eps: f32,
143        mask: MaskKind,
144    },
145
146    /// Qwen3.5 MTP draft head: hnorm∥enorm → eh_proj → full-attn → LM.
147    /// See [`blocks::lower_qwen35_mtp_head`] for the input layout.
148    Qwen35MtpHead {
149        num_heads: usize,
150        num_kv_heads: usize,
151        head_dim: usize,
152        n_rot: usize,
153        n_embd: usize,
154        n_ff: usize,
155        mtp_vocab: usize,
156        eps: f32,
157    },
158
159    /// Escape hatch — embed a single MIR op verbatim.
160    Mir(Op),
161}
162
163/// One node in a HIR module.
164#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
165#[derive(Debug, Clone)]
166pub struct HirNode {
167    pub id: HirNodeId,
168    pub op: HirOp,
169    pub inputs: Vec<HirNodeId>,
170    pub shape: Shape,
171    pub name: Option<String>,
172}
173
174/// High-level module — model builder output.
175#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
176#[derive(Debug, Clone)]
177pub struct HirModule {
178    pub name: String,
179    nodes: Vec<HirNode>,
180    pub outputs: Vec<HirNodeId>,
181    /// How block ops lower to MIR. Default: [`FusionPolicy::Direct`]
182    /// for new model code (fusion as a first-class citizen).
183    pub fusion_policy: FusionPolicy,
184}
185
186impl HirModule {
187    pub fn new(name: impl Into<String>) -> Self {
188        Self {
189            name: name.into(),
190            nodes: Vec::new(),
191            outputs: Vec::new(),
192            fusion_policy: FusionPolicy::Direct,
193        }
194    }
195
196    pub fn with_fusion_policy(mut self, policy: FusionPolicy) -> Self {
197        self.fusion_policy = policy;
198        self
199    }
200
201    pub fn len(&self) -> usize {
202        self.nodes.len()
203    }
204
205    pub fn is_empty(&self) -> bool {
206        self.nodes.is_empty()
207    }
208
209    pub fn nodes(&self) -> &[HirNode] {
210        &self.nodes
211    }
212
213    pub fn node(&self, id: HirNodeId) -> &HirNode {
214        &self.nodes[id.0 as usize]
215    }
216
217    pub fn node_mut(&mut self, id: HirNodeId) -> &mut HirNode {
218        &mut self.nodes[id.0 as usize]
219    }
220
221    /// Build a named block — sets `HirNode::name` on the returned node.
222    pub fn named(
223        &mut self,
224        name: impl Into<String>,
225        build: impl FnOnce(&mut Self) -> HirNodeId,
226    ) -> HirNodeId {
227        let id = build(self);
228        self.node_mut(id).name = Some(name.into());
229        id
230    }
231
232    fn push_block(
233        &mut self,
234        op: HirOp,
235        inputs: Vec<HirNodeId>,
236        shape: Shape,
237        name: Option<String>,
238    ) -> HirNodeId {
239        let name = name.or_else(|| default_hir_block_label(&op));
240        self.push(op, inputs, shape, name)
241    }
242
243    fn push(
244        &mut self,
245        op: HirOp,
246        inputs: Vec<HirNodeId>,
247        shape: Shape,
248        name: Option<String>,
249    ) -> HirNodeId {
250        let id = HirNodeId(self.nodes.len() as u32);
251        self.nodes.push(HirNode {
252            id,
253            op,
254            inputs,
255            shape,
256            name,
257        });
258        id
259    }
260
261    pub fn input(&mut self, name: impl Into<String>, shape: Shape) -> HirNodeId {
262        self.push(HirOp::Input { name: name.into() }, vec![], shape, None)
263    }
264
265    /// `[batch, seq, hidden]` input with symbolic leading axes.
266    pub fn input_batch_seq(
267        &mut self,
268        name: impl Into<String>,
269        batch: u32,
270        seq: u32,
271        hidden: usize,
272        dtype: crate::DType,
273    ) -> HirNodeId {
274        self.input(name, Shape::batch_seq(batch, seq, hidden, dtype))
275    }
276
277    pub fn param(&mut self, name: impl Into<String>, shape: Shape) -> HirNodeId {
278        self.push(HirOp::Param { name: name.into() }, vec![], shape, None)
279    }
280
281    pub fn linear(
282        &mut self,
283        x: HirNodeId,
284        weight: HirNodeId,
285        bias: Option<HirNodeId>,
286        activation: Option<Activation>,
287        out_shape: Shape,
288    ) -> HirNodeId {
289        let mut inputs = vec![x, weight];
290        if let Some(b) = bias {
291            inputs.push(b);
292        }
293        self.push_block(
294            HirOp::Linear {
295                activation,
296                has_bias: bias.is_some(),
297            },
298            inputs,
299            out_shape,
300            None,
301        )
302    }
303
304    /// Emit [`HirOp::LinearFused`] — fused matmul+bias+act at MIR level.
305    pub fn linear_fused(
306        &mut self,
307        x: HirNodeId,
308        weight: HirNodeId,
309        bias: HirNodeId,
310        activation: Option<Activation>,
311        out_shape: Shape,
312    ) -> HirNodeId {
313        self.push_block(
314            HirOp::LinearFused { activation },
315            vec![x, weight, bias],
316            out_shape,
317            None,
318        )
319    }
320
321    /// Two matmuls sharing `x`. Returns `(first, second)` in weight order.
322    pub fn shared_linear_pair(
323        &mut self,
324        x: HirNodeId,
325        w_first: HirNodeId,
326        w_second: HirNodeId,
327        out_shape: Shape,
328    ) -> (HirNodeId, HirNodeId) {
329        let inputs = vec![x, w_first, w_second];
330        let first = self.push_block(
331            HirOp::SharedLinearPair { slot: 0 },
332            inputs.clone(),
333            out_shape.clone(),
334            None,
335        );
336        let second = self.push_block(HirOp::SharedLinearPair { slot: 1 }, inputs, out_shape, None);
337        (first, second)
338    }
339
340    pub fn swiglu_ffn(
341        &mut self,
342        x: HirNodeId,
343        up_w: HirNodeId,
344        gate_w: HirNodeId,
345        down_w: HirNodeId,
346        out_shape: Shape,
347    ) -> HirNodeId {
348        self.push_block(
349            HirOp::SwiGLU,
350            vec![x, up_w, gate_w, down_w],
351            out_shape,
352            None,
353        )
354    }
355
356    pub fn residual_rms_norm(
357        &mut self,
358        x: HirNodeId,
359        residual: HirNodeId,
360        gamma: HirNodeId,
361        beta: HirNodeId,
362        eps: f32,
363        out_shape: Shape,
364    ) -> HirNodeId {
365        self.push_block(
366            HirOp::ResidualRmsNorm { eps },
367            vec![x, residual, gamma, beta],
368            out_shape,
369            None,
370        )
371    }
372
373    /// Scaled dot-product attention — see [`HirOp::Attention`].
374    pub fn attention(
375        &mut self,
376        q: HirNodeId,
377        k: HirNodeId,
378        v: HirNodeId,
379        mask: Option<HirNodeId>,
380        num_heads: usize,
381        head_dim: usize,
382        mask_kind: MaskKind,
383        out_shape: Shape,
384    ) -> HirNodeId {
385        let mut inputs = vec![q, k, v];
386        if let Some(m) = mask {
387            inputs.push(m);
388        }
389        self.push_block(
390            HirOp::Attention {
391                num_heads,
392                head_dim,
393                mask: mask_kind,
394            },
395            inputs,
396            out_shape,
397            None,
398        )
399    }
400
401    /// Causal depthwise Conv1d — Conformer / Wav2Vec2-BERT conv module.
402    ///
403    /// `input` and `left_pad` are `[B, S, C]` / `[B, K-1, C]`; `weight` is
404    /// `[C, 1, 1, K]` in grouped Conv2d layout.
405    pub fn depthwise_conv1d_causal(
406        &mut self,
407        input: HirNodeId,
408        weight: HirNodeId,
409        left_pad: HirNodeId,
410        kernel_size: usize,
411        out_shape: Shape,
412    ) -> HirNodeId {
413        self.push_block(
414            HirOp::DepthwiseConv1dCausal { kernel_size },
415            vec![input, weight, left_pad],
416            out_shape,
417            None,
418        )
419    }
420
421    /// Fused dequant + matmul — see [`HirOp::DequantMatMul`].
422    pub fn dequant_matmul(
423        &mut self,
424        x: HirNodeId,
425        w: HirNodeId,
426        scale: Option<HirNodeId>,
427        zp: Option<HirNodeId>,
428        scheme: QuantScheme,
429        out_shape: Shape,
430    ) -> HirNodeId {
431        let mut inputs = vec![x, w];
432        if !scheme.is_gguf() {
433            inputs.push(scale.expect("DequantMatMul: scale required for non-GGUF schemes"));
434            inputs.push(zp.expect("DequantMatMul: zp required for non-GGUF schemes"));
435        }
436        self.push_block(HirOp::DequantMatMul { scheme }, inputs, out_shape, None)
437    }
438
439    /// Gated DeltaNet without carry state (prefill / reset per batch).
440    pub fn gated_delta_net(
441        &mut self,
442        q: HirNodeId,
443        k: HirNodeId,
444        v: HirNodeId,
445        g: HirNodeId,
446        beta: HirNodeId,
447        state_size: usize,
448        out_shape: Shape,
449    ) -> HirNodeId {
450        self.push_block(
451            HirOp::GatedDeltaNet {
452                state_size,
453                carry_state: false,
454            },
455            vec![q, k, v, g, beta],
456            out_shape,
457            None,
458        )
459    }
460
461    /// Gated DeltaNet with decode carry — threads `state` in/out.
462    pub fn gated_delta_net_carry(
463        &mut self,
464        q: HirNodeId,
465        k: HirNodeId,
466        v: HirNodeId,
467        g: HirNodeId,
468        beta: HirNodeId,
469        state: HirNodeId,
470        state_size: usize,
471        out_shape: Shape,
472    ) -> HirNodeId {
473        self.push_block(
474            HirOp::GatedDeltaNet {
475                state_size,
476                carry_state: true,
477            },
478            vec![q, k, v, g, beta, state],
479            out_shape,
480            None,
481        )
482    }
483
484    /// Rotary position embedding.
485    pub fn rope(
486        &mut self,
487        x: HirNodeId,
488        cos: HirNodeId,
489        sin: HirNodeId,
490        head_dim: usize,
491        n_rot: usize,
492        out_shape: Shape,
493    ) -> HirNodeId {
494        self.push_block(
495            HirOp::RoPE { head_dim, n_rot },
496            vec![x, cos, sin],
497            out_shape,
498            None,
499        )
500    }
501
502    /// RMS normalization (no residual add).
503    pub fn rms_norm(
504        &mut self,
505        x: HirNodeId,
506        gamma: HirNodeId,
507        beta: HirNodeId,
508        eps: f32,
509        out_shape: Shape,
510    ) -> HirNodeId {
511        self.push_block(
512            HirOp::RmsNorm { eps },
513            vec![x, gamma, beta],
514            out_shape,
515            None,
516        )
517    }
518
519    /// LLaMA / LLaMA-3.2 decoder layer (pre-norm GQA + SwiGLU).
520    pub fn llama_decoder_block(
521        &mut self,
522        x: HirNodeId,
523        ln1_g: HirNodeId,
524        ln1_b: HirNodeId,
525        q_w: HirNodeId,
526        k_w: HirNodeId,
527        v_w: HirNodeId,
528        o_w: HirNodeId,
529        ln2_g: HirNodeId,
530        ln2_b: HirNodeId,
531        gate_w: HirNodeId,
532        up_w: HirNodeId,
533        down_w: HirNodeId,
534        cos: HirNodeId,
535        sin: HirNodeId,
536        mask: Option<HirNodeId>,
537        num_heads: usize,
538        head_dim: usize,
539        num_kv_heads: usize,
540        eps: f32,
541        mask_kind: MaskKind,
542        out_shape: Shape,
543    ) -> HirNodeId {
544        let mut ins = vec![
545            x, ln1_g, ln1_b, q_w, k_w, v_w, o_w, ln2_g, ln2_b, gate_w, up_w, down_w, cos, sin,
546        ];
547        if let Some(m) = mask {
548            ins.push(m);
549        }
550        self.push_block(
551            HirOp::LlamaDecoderBlock {
552                num_heads,
553                head_dim,
554                num_kv_heads,
555                eps,
556                mask: mask_kind,
557            },
558            ins,
559            out_shape,
560            Some("llama_decoder_block".into()),
561        )
562    }
563
564    /// Standard pre-norm transformer decoder block — alias for
565    /// [`Self::llama_decoder_block`] (LLaMA / GPT-style layers).
566    pub fn transformer_block(
567        &mut self,
568        x: HirNodeId,
569        ln1_g: HirNodeId,
570        ln1_b: HirNodeId,
571        q_w: HirNodeId,
572        k_w: HirNodeId,
573        v_w: HirNodeId,
574        o_w: HirNodeId,
575        ln2_g: HirNodeId,
576        ln2_b: HirNodeId,
577        gate_w: HirNodeId,
578        up_w: HirNodeId,
579        down_w: HirNodeId,
580        cos: HirNodeId,
581        sin: HirNodeId,
582        mask: Option<HirNodeId>,
583        num_heads: usize,
584        head_dim: usize,
585        num_kv_heads: usize,
586        eps: f32,
587        mask_kind: MaskKind,
588        out_shape: Shape,
589    ) -> HirNodeId {
590        let id = self.llama_decoder_block(
591            x,
592            ln1_g,
593            ln1_b,
594            q_w,
595            k_w,
596            v_w,
597            o_w,
598            ln2_g,
599            ln2_b,
600            gate_w,
601            up_w,
602            down_w,
603            cos,
604            sin,
605            mask,
606            num_heads,
607            head_dim,
608            num_kv_heads,
609            eps,
610            mask_kind,
611            out_shape,
612        );
613        self.node_mut(id).name = Some("transformer_block".into());
614        id
615    }
616
617    /// Qwen3.5 MTP draft head — see [`blocks::lower_qwen35_mtp_head`].
618    #[allow(clippy::too_many_arguments)]
619    pub fn qwen35_mtp_head(
620        &mut self,
621        h_pre_norm: HirNodeId,
622        input_ids: HirNodeId,
623        cos: HirNodeId,
624        sin: HirNodeId,
625        last_token_idx: HirNodeId,
626        embed_w: HirNodeId,
627        hnorm_w: HirNodeId,
628        hnorm_b: HirNodeId,
629        enorm_w: HirNodeId,
630        enorm_b: HirNodeId,
631        eh_w: HirNodeId,
632        fa_attn_norm_w: HirNodeId,
633        fa_attn_norm_b: HirNodeId,
634        fa_q_gate_w: HirNodeId,
635        fa_k_w: HirNodeId,
636        fa_v_w: HirNodeId,
637        fa_q_norm_w: HirNodeId,
638        fa_q_norm_b: HirNodeId,
639        fa_k_norm_w: HirNodeId,
640        fa_k_norm_b: HirNodeId,
641        fa_o_w: HirNodeId,
642        fa_post_norm_w: HirNodeId,
643        fa_post_norm_b: HirNodeId,
644        fa_gate_w: HirNodeId,
645        fa_up_w: HirNodeId,
646        fa_down_w: HirNodeId,
647        head_norm_w: HirNodeId,
648        head_norm_b: HirNodeId,
649        lm_head_w: HirNodeId,
650        num_heads: usize,
651        num_kv_heads: usize,
652        head_dim: usize,
653        n_rot: usize,
654        n_embd: usize,
655        n_ff: usize,
656        mtp_vocab: usize,
657        eps: f32,
658        out_shape: Shape,
659    ) -> HirNodeId {
660        self.push_block(
661            HirOp::Qwen35MtpHead {
662                num_heads,
663                num_kv_heads,
664                head_dim,
665                n_rot,
666                n_embd,
667                n_ff,
668                mtp_vocab,
669                eps,
670            },
671            vec![
672                h_pre_norm,
673                input_ids,
674                cos,
675                sin,
676                last_token_idx,
677                embed_w,
678                hnorm_w,
679                hnorm_b,
680                enorm_w,
681                enorm_b,
682                eh_w,
683                fa_attn_norm_w,
684                fa_attn_norm_b,
685                fa_q_gate_w,
686                fa_k_w,
687                fa_v_w,
688                fa_q_norm_w,
689                fa_q_norm_b,
690                fa_k_norm_w,
691                fa_k_norm_b,
692                fa_o_w,
693                fa_post_norm_w,
694                fa_post_norm_b,
695                fa_gate_w,
696                fa_up_w,
697                fa_down_w,
698                head_norm_w,
699                head_norm_b,
700                lm_head_w,
701            ],
702            out_shape,
703            Some("qwen35_mtp_head".into()),
704        )
705    }
706
707    /// Escape hatch — embed a single MIR [`Op`] verbatim.
708    pub fn mir(&mut self, op: Op, inputs: Vec<HirNodeId>, shape: Shape) -> HirNodeId {
709        self.push(HirOp::Mir(op), inputs, shape, None)
710    }
711
712    pub fn set_outputs(&mut self, outputs: Vec<HirNodeId>) {
713        self.outputs = outputs;
714    }
715
716    /// Lower this module to MIR.
717    pub fn lower_to_mir(self) -> Result<MirModule, LowerError> {
718        lower::lower_module(self)
719    }
720
721    /// Lower with [`FusionPolicy::for_autodiff`] — primitive MIR chains
722    /// that need less unfuse work before `rlx_opt::prepare_graph_for_ad`.
723    pub fn lower_for_autodiff(self) -> Result<MirModule, LowerError> {
724        self.with_fusion_policy(FusionPolicy::for_autodiff())
725            .lower_to_mir()
726    }
727
728    /// Wrap an existing MIR [`Graph`] as a HIR module (`HirOp::Mir` per node).
729    /// Enables `Session::compile_hir` for legacy graph builders during migration.
730    pub fn wrap_mir_graph(graph: crate::Graph) -> Self {
731        use std::collections::HashMap;
732        let mut hir = Self::new(graph.name.clone()).with_fusion_policy(FusionPolicy::Direct);
733        let mut map: HashMap<crate::NodeId, HirNodeId> = HashMap::new();
734        for node in graph.nodes() {
735            let inputs: Vec<HirNodeId> = node.inputs.iter().map(|&id| map[&id]).collect();
736            let id = hir.mir(node.op.clone(), inputs, node.shape.clone());
737            map.insert(node.id, id);
738        }
739        let outputs: Vec<HirNodeId> = graph.outputs.iter().map(|&id| map[&id]).collect();
740        hir.set_outputs(outputs);
741        hir
742    }
743}
744
745pub(crate) fn default_hir_block_label(op: &HirOp) -> Option<String> {
746    Some(match op {
747        HirOp::Linear { .. } => "linear".into(),
748        HirOp::LinearFused { .. } => "linear_fused".into(),
749        HirOp::SharedLinearPair { slot } => return Some(format!("shared_linear_pair[{slot}]")),
750        HirOp::SwiGLU => "swiglu_ffn".into(),
751        HirOp::ResidualRmsNorm { .. } => "residual_rms_norm".into(),
752        HirOp::Attention { .. } => "attention".into(),
753        HirOp::DepthwiseConv1dCausal { .. } => "depthwise_conv1d_causal".into(),
754        HirOp::DequantMatMul { scheme } => format!("dequant_matmul({scheme})"),
755        HirOp::GatedDeltaNet {
756            carry_state: true, ..
757        } => "gated_delta_net_carry".into(),
758        HirOp::GatedDeltaNet { .. } => "gated_delta_net".into(),
759        HirOp::RoPE { .. } => "rope".into(),
760        HirOp::RmsNorm { .. } => "rms_norm".into(),
761        HirOp::Mir(_) => "mir".into(),
762        HirOp::LlamaDecoderBlock { .. } => "llama_decoder_block".into(),
763        HirOp::Qwen35MtpHead { .. } => "qwen35_mtp_head".into(),
764        HirOp::Input { .. } | HirOp::Param { .. } | HirOp::Constant { .. } => return None,
765    })
766}
767
768impl std::fmt::Display for HirModule {
769    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
770        writeln!(f, "hir @{} {{", self.name)?;
771        for node in &self.nodes {
772            write!(f, "  {} = {:?}", node.id, node.op)?;
773            if !node.inputs.is_empty() {
774                write!(f, "(")?;
775                for (i, inp) in node.inputs.iter().enumerate() {
776                    if i > 0 {
777                        write!(f, ", ")?;
778                    }
779                    write!(f, "{inp}")?;
780                }
781                write!(f, ")")?;
782            }
783            writeln!(f, " : {}", node.shape)?;
784        }
785        if !self.outputs.is_empty() {
786            write!(f, "  return ")?;
787            for (i, o) in self.outputs.iter().enumerate() {
788                if i > 0 {
789                    write!(f, ", ")?;
790                }
791                write!(f, "{o}")?;
792            }
793            writeln!(f)?;
794        }
795        write!(f, "}}")
796    }
797}
798
#[cfg(test)]
mod tests {
    // The glob also brings in this module's private imports (`Op`, `Shape`,
    // `QuantScheme`) and the re-exported `FusionPolicy`.
    use super::*;
    use crate::DType;

    /// Shorthand for an `F32` shape with dimensions `dims`.
    fn f32_shape(dims: &[usize]) -> Shape {
        Shape::new(dims, DType::F32)
    }

    #[test]
    fn hir_depthwise_conv1d_causal_lowers_to_grouped_conv() {
        let mut hir = HirModule::new("dw");
        let x = hir.input("x", f32_shape(&[2, 8, 16]));
        let weight = hir.param("w", f32_shape(&[16, 1, 1, 3]));
        let left_pad = hir.param("pad", f32_shape(&[2, 2, 16]));
        let out = hir.depthwise_conv1d_causal(x, weight, left_pad, 3, f32_shape(&[2, 8, 16]));
        hir.set_outputs(vec![out]);

        let graph = hir.lower_to_mir().expect("lower").into_graph();
        let has_op = |pred: fn(&Op) -> bool| graph.nodes().iter().any(|n| pred(&n.op));
        assert!(has_op(|op| matches!(op, Op::Conv { .. })));
        assert!(has_op(|op| matches!(op, Op::Concat { .. })));
    }

    #[test]
    fn hir_swiglu_lowers_to_fusable_mir() {
        let mut hir = HirModule::new("ffn").with_fusion_policy(FusionPolicy::Fusable);
        let x = hir.input("x", f32_shape(&[4, 768]));
        let up_w = hir.param("up", f32_shape(&[768, 2048]));
        let gate_w = hir.param("gate", f32_shape(&[768, 2048]));
        let down_w = hir.param("down", f32_shape(&[2048, 768]));
        let out = hir.swiglu_ffn(x, up_w, gate_w, down_w, f32_shape(&[4, 768]));
        hir.set_outputs(vec![out]);

        let graph = hir.lower_to_mir().expect("lower").into_graph();
        assert!(graph.nodes().iter().any(|n| matches!(n.op, Op::MatMul)));
        // Pins the exact node count of the `Fusable` SwiGLU lowering.
        assert_eq!(graph.len(), 9);
    }

    #[test]
    fn hir_gdn_dequant_rope_rms_lowers() {
        let mut hir = HirModule::new("qwen_block");

        // Gated DeltaNet scan without carry state.
        let q = hir.input("q", f32_shape(&[1, 4, 2, 8]));
        let k = hir.param("k", f32_shape(&[1, 4, 2, 8]));
        let v = hir.param("v", f32_shape(&[1, 4, 2, 8]));
        let gate = hir.param("g", f32_shape(&[1, 4, 2]));
        let beta = hir.param("beta", f32_shape(&[1, 4, 2]));
        let scan = hir.gated_delta_net(q, k, v, gate, beta, 8, f32_shape(&[1, 4, 2, 8]));

        // RoPE followed by RMSNorm.
        let cos = hir.param("cos", f32_shape(&[1, 4, 8]));
        let sin = hir.param("sin", f32_shape(&[1, 4, 8]));
        let x = hir.input("x", f32_shape(&[1, 4, 8]));
        let rotated = hir.rope(x, cos, sin, 8, 8, f32_shape(&[1, 4, 8]));
        let gamma = hir.param("gamma", f32_shape(&[8]));
        let norm_beta = hir.param("beta_n", f32_shape(&[8]));
        let normed = hir.rms_norm(rotated, gamma, norm_beta, 1e-6, f32_shape(&[1, 4, 8]));

        // GGUF dequant matmul: no scale / zero-point operands.
        let hidden = hir.input("hidden", f32_shape(&[4, 128]));
        let packed_w = hir.param("w_q", f32_shape(&[1024]));
        let proj = hir.dequant_matmul(
            hidden,
            packed_w,
            None,
            None,
            QuantScheme::GgufQ4K,
            f32_shape(&[4, 128]),
        );
        hir.set_outputs(vec![scan, normed, proj]);

        let graph = hir.lower_to_mir().expect("lower").into_graph();
        let has_op = |pred: fn(&Op) -> bool| graph.nodes().iter().any(|n| pred(&n.op));
        assert!(has_op(|op| matches!(
            op,
            Op::GatedDeltaNet {
                carry_state: false,
                ..
            }
        )));
        assert!(has_op(|op| matches!(op, Op::Rope { .. })));
        assert!(has_op(|op| matches!(op, Op::RmsNorm { .. })));
        assert!(has_op(|op| matches!(
            op,
            Op::DequantMatMul {
                scheme: QuantScheme::GgufQ4K
            }
        )));
    }
}