pub struct HirModule {
pub name: String,
pub outputs: Vec<HirNodeId>,
pub fusion_policy: FusionPolicy,
/* private fields */
}
High-level module — model builder output.
Fields
name: String
outputs: Vec<HirNodeId>
fusion_policy: FusionPolicy
How block ops lower to MIR. Default: FusionPolicy::Direct
for new model code (fusion as a first-class citizen).
Implementations
impl HirModule
pub fn new(name: impl Into<String>) -> Self
pub fn with_fusion_policy(self, policy: FusionPolicy) -> Self
pub fn len(&self) -> usize
pub fn is_empty(&self) -> bool
pub fn nodes(&self) -> &[HirNode]
pub fn node(&self, id: HirNodeId) -> &HirNode
pub fn node_mut(&mut self, id: HirNodeId) -> &mut HirNode
pub fn named(
&mut self,
name: impl Into<String>,
build: impl FnOnce(&mut Self) -> HirNodeId,
) -> HirNodeId
pub fn named( &mut self, name: impl Into<String>, build: impl FnOnce(&mut Self) -> HirNodeId, ) -> HirNodeId
Build a named block — sets HirNode::name on the returned node.
pub fn input(&mut self, name: impl Into<String>, shape: Shape) -> HirNodeId
pub fn input_batch_seq(
&mut self,
name: impl Into<String>,
batch: u32,
seq: u32,
hidden: usize,
dtype: DType,
) -> HirNodeId
pub fn input_batch_seq( &mut self, name: impl Into<String>, batch: u32, seq: u32, hidden: usize, dtype: DType, ) -> HirNodeId
[batch, seq, hidden] input with symbolic leading axes.
pub fn param(&mut self, name: impl Into<String>, shape: Shape) -> HirNodeId
pub fn linear( &mut self, x: HirNodeId, weight: HirNodeId, bias: Option<HirNodeId>, activation: Option<Activation>, out_shape: Shape, ) -> HirNodeId
pub fn linear_fused(
&mut self,
x: HirNodeId,
weight: HirNodeId,
bias: HirNodeId,
activation: Option<Activation>,
out_shape: Shape,
) -> HirNodeId
pub fn linear_fused( &mut self, x: HirNodeId, weight: HirNodeId, bias: HirNodeId, activation: Option<Activation>, out_shape: Shape, ) -> HirNodeId
Emit HirOp::LinearFused — fused matmul+bias+act at MIR level.
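(Note: this description appears to belong to a separate method whose signature is missing from this listing; it does not match linear_fused, which returns a single HirNodeId.) Two matmuls sharing x. Returns (first, second) in weight order.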
pub fn swiglu_ffn( &mut self, x: HirNodeId, up_w: HirNodeId, gate_w: HirNodeId, down_w: HirNodeId, out_shape: Shape, ) -> HirNodeId
pub fn residual_rms_norm( &mut self, x: HirNodeId, residual: HirNodeId, gamma: HirNodeId, beta: HirNodeId, eps: f32, out_shape: Shape, ) -> HirNodeId
pub fn attention(
&mut self,
q: HirNodeId,
k: HirNodeId,
v: HirNodeId,
mask: Option<HirNodeId>,
num_heads: usize,
head_dim: usize,
mask_kind: MaskKind,
out_shape: Shape,
) -> HirNodeId
pub fn attention( &mut self, q: HirNodeId, k: HirNodeId, v: HirNodeId, mask: Option<HirNodeId>, num_heads: usize, head_dim: usize, mask_kind: MaskKind, out_shape: Shape, ) -> HirNodeId
Scaled dot-product attention — see HirOp::Attention.
pub fn depthwise_conv1d_causal(
&mut self,
input: HirNodeId,
weight: HirNodeId,
left_pad: HirNodeId,
kernel_size: usize,
out_shape: Shape,
) -> HirNodeId
pub fn depthwise_conv1d_causal( &mut self, input: HirNodeId, weight: HirNodeId, left_pad: HirNodeId, kernel_size: usize, out_shape: Shape, ) -> HirNodeId
Causal depthwise Conv1d — Conformer / Wav2Vec2-BERT conv module.
input and left_pad are [B, S, C] / [B, K-1, C]; weight is
[C, 1, 1, K] in grouped Conv2d layout.
pub fn dequant_matmul(
&mut self,
x: HirNodeId,
w: HirNodeId,
scale: Option<HirNodeId>,
zp: Option<HirNodeId>,
scheme: QuantScheme,
out_shape: Shape,
) -> HirNodeId
pub fn dequant_matmul( &mut self, x: HirNodeId, w: HirNodeId, scale: Option<HirNodeId>, zp: Option<HirNodeId>, scheme: QuantScheme, out_shape: Shape, ) -> HirNodeId
Fused dequant + matmul — see HirOp::DequantMatMul.
pub fn gated_delta_net(
&mut self,
q: HirNodeId,
k: HirNodeId,
v: HirNodeId,
g: HirNodeId,
beta: HirNodeId,
state_size: usize,
out_shape: Shape,
) -> HirNodeId
pub fn gated_delta_net( &mut self, q: HirNodeId, k: HirNodeId, v: HirNodeId, g: HirNodeId, beta: HirNodeId, state_size: usize, out_shape: Shape, ) -> HirNodeId
Gated DeltaNet without carry state (prefill / reset per batch).
pub fn gated_delta_net_carry(
&mut self,
q: HirNodeId,
k: HirNodeId,
v: HirNodeId,
g: HirNodeId,
beta: HirNodeId,
state: HirNodeId,
state_size: usize,
out_shape: Shape,
) -> HirNodeId
pub fn gated_delta_net_carry( &mut self, q: HirNodeId, k: HirNodeId, v: HirNodeId, g: HirNodeId, beta: HirNodeId, state: HirNodeId, state_size: usize, out_shape: Shape, ) -> HirNodeId
Gated DeltaNet with decode carry — threads state in/out.
pub fn rope(
&mut self,
x: HirNodeId,
cos: HirNodeId,
sin: HirNodeId,
head_dim: usize,
n_rot: usize,
out_shape: Shape,
) -> HirNodeId
pub fn rope( &mut self, x: HirNodeId, cos: HirNodeId, sin: HirNodeId, head_dim: usize, n_rot: usize, out_shape: Shape, ) -> HirNodeId
Rotary position embedding.
pub fn rms_norm(
&mut self,
x: HirNodeId,
gamma: HirNodeId,
beta: HirNodeId,
eps: f32,
out_shape: Shape,
) -> HirNodeId
pub fn rms_norm( &mut self, x: HirNodeId, gamma: HirNodeId, beta: HirNodeId, eps: f32, out_shape: Shape, ) -> HirNodeId
RMS normalization (no residual add).
pub fn llama_decoder_block(
&mut self,
x: HirNodeId,
ln1_g: HirNodeId,
ln1_b: HirNodeId,
q_w: HirNodeId,
k_w: HirNodeId,
v_w: HirNodeId,
o_w: HirNodeId,
ln2_g: HirNodeId,
ln2_b: HirNodeId,
gate_w: HirNodeId,
up_w: HirNodeId,
down_w: HirNodeId,
cos: HirNodeId,
sin: HirNodeId,
mask: Option<HirNodeId>,
num_heads: usize,
head_dim: usize,
num_kv_heads: usize,
eps: f32,
mask_kind: MaskKind,
out_shape: Shape,
) -> HirNodeId
pub fn llama_decoder_block( &mut self, x: HirNodeId, ln1_g: HirNodeId, ln1_b: HirNodeId, q_w: HirNodeId, k_w: HirNodeId, v_w: HirNodeId, o_w: HirNodeId, ln2_g: HirNodeId, ln2_b: HirNodeId, gate_w: HirNodeId, up_w: HirNodeId, down_w: HirNodeId, cos: HirNodeId, sin: HirNodeId, mask: Option<HirNodeId>, num_heads: usize, head_dim: usize, num_kv_heads: usize, eps: f32, mask_kind: MaskKind, out_shape: Shape, ) -> HirNodeId
LLaMA / LLaMA-3.2 decoder layer (pre-norm GQA + SwiGLU).
pub fn transformer_block(
&mut self,
x: HirNodeId,
ln1_g: HirNodeId,
ln1_b: HirNodeId,
q_w: HirNodeId,
k_w: HirNodeId,
v_w: HirNodeId,
o_w: HirNodeId,
ln2_g: HirNodeId,
ln2_b: HirNodeId,
gate_w: HirNodeId,
up_w: HirNodeId,
down_w: HirNodeId,
cos: HirNodeId,
sin: HirNodeId,
mask: Option<HirNodeId>,
num_heads: usize,
head_dim: usize,
num_kv_heads: usize,
eps: f32,
mask_kind: MaskKind,
out_shape: Shape,
) -> HirNodeId
pub fn transformer_block( &mut self, x: HirNodeId, ln1_g: HirNodeId, ln1_b: HirNodeId, q_w: HirNodeId, k_w: HirNodeId, v_w: HirNodeId, o_w: HirNodeId, ln2_g: HirNodeId, ln2_b: HirNodeId, gate_w: HirNodeId, up_w: HirNodeId, down_w: HirNodeId, cos: HirNodeId, sin: HirNodeId, mask: Option<HirNodeId>, num_heads: usize, head_dim: usize, num_kv_heads: usize, eps: f32, mask_kind: MaskKind, out_shape: Shape, ) -> HirNodeId
Standard pre-norm transformer decoder block — alias for
Self::llama_decoder_block (LLaMA / GPT-style layers).
pub fn qwen35_mtp_head(
&mut self,
h_pre_norm: HirNodeId,
input_ids: HirNodeId,
cos: HirNodeId,
sin: HirNodeId,
last_token_idx: HirNodeId,
embed_w: HirNodeId,
hnorm_w: HirNodeId,
hnorm_b: HirNodeId,
enorm_w: HirNodeId,
enorm_b: HirNodeId,
eh_w: HirNodeId,
fa_attn_norm_w: HirNodeId,
fa_attn_norm_b: HirNodeId,
fa_q_gate_w: HirNodeId,
fa_k_w: HirNodeId,
fa_v_w: HirNodeId,
fa_q_norm_w: HirNodeId,
fa_q_norm_b: HirNodeId,
fa_k_norm_w: HirNodeId,
fa_k_norm_b: HirNodeId,
fa_o_w: HirNodeId,
fa_post_norm_w: HirNodeId,
fa_post_norm_b: HirNodeId,
fa_gate_w: HirNodeId,
fa_up_w: HirNodeId,
fa_down_w: HirNodeId,
head_norm_w: HirNodeId,
head_norm_b: HirNodeId,
lm_head_w: HirNodeId,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
n_rot: usize,
n_embd: usize,
n_ff: usize,
mtp_vocab: usize,
eps: f32,
out_shape: Shape,
) -> HirNodeId
pub fn qwen35_mtp_head( &mut self, h_pre_norm: HirNodeId, input_ids: HirNodeId, cos: HirNodeId, sin: HirNodeId, last_token_idx: HirNodeId, embed_w: HirNodeId, hnorm_w: HirNodeId, hnorm_b: HirNodeId, enorm_w: HirNodeId, enorm_b: HirNodeId, eh_w: HirNodeId, fa_attn_norm_w: HirNodeId, fa_attn_norm_b: HirNodeId, fa_q_gate_w: HirNodeId, fa_k_w: HirNodeId, fa_v_w: HirNodeId, fa_q_norm_w: HirNodeId, fa_q_norm_b: HirNodeId, fa_k_norm_w: HirNodeId, fa_k_norm_b: HirNodeId, fa_o_w: HirNodeId, fa_post_norm_w: HirNodeId, fa_post_norm_b: HirNodeId, fa_gate_w: HirNodeId, fa_up_w: HirNodeId, fa_down_w: HirNodeId, head_norm_w: HirNodeId, head_norm_b: HirNodeId, lm_head_w: HirNodeId, num_heads: usize, num_kv_heads: usize, head_dim: usize, n_rot: usize, n_embd: usize, n_ff: usize, mtp_vocab: usize, eps: f32, out_shape: Shape, ) -> HirNodeId
Qwen3.5 MTP draft head — see blocks::lower_qwen35_mtp_head.
pub fn mir(&mut self, op: Op, inputs: Vec<HirNodeId>, shape: Shape) -> HirNodeId
Escape hatch — embed a single MIR Op verbatim.
pub fn set_outputs(&mut self, outputs: Vec<HirNodeId>)
pub fn lower_to_mir(self) -> Result<MirModule, LowerError>
Lower this module to MIR.
pub fn lower_for_autodiff(self) -> Result<MirModule, LowerError>
Lower with FusionPolicy::for_autodiff — primitive MIR chains
that need less unfuse work before rlx_opt::prepare_graph_for_ad.
pub fn wrap_mir_graph(graph: Graph) -> Self
Wrap an existing MIR Graph as a HIR module (HirOp::Mir per node).
Enables Session::compile_hir for legacy graph builders during migration.