pub struct HirModule {
pub name: String,
pub outputs: Vec<HirNodeId>,
pub fusion_policy: FusionPolicy,
/* private fields */
}
Expand description
High-level module — model builder output.
Fields§
§name: String
§outputs: Vec<HirNodeId>
§fusion_policy: FusionPolicy
How block ops lower to MIR. Default: FusionPolicy::Direct
for new model code (fusion as a first-class citizen).
Implementations§
Source§impl HirModule
impl HirModule
pub fn new(name: impl Into<String>) -> HirModule
pub fn with_fusion_policy(self, policy: FusionPolicy) -> HirModule
pub fn len(&self) -> usize
pub fn is_empty(&self) -> bool
pub fn nodes(&self) -> &[HirNode]
pub fn node(&self, id: HirNodeId) -> &HirNode
pub fn node_mut(&mut self, id: HirNodeId) -> &mut HirNode
Source
pub fn named(
&mut self,
name: impl Into<String>,
build: impl FnOnce(&mut HirModule) -> HirNodeId,
) -> HirNodeId
pub fn named( &mut self, name: impl Into<String>, build: impl FnOnce(&mut HirModule) -> HirNodeId, ) -> HirNodeId
Build a named block — sets HirNode::name on the returned node.
pub fn input(&mut self, name: impl Into<String>, shape: Shape) -> HirNodeId
Source
pub fn input_batch_seq(
&mut self,
name: impl Into<String>,
batch: u32,
seq: u32,
hidden: usize,
dtype: DType,
) -> HirNodeId
pub fn input_batch_seq( &mut self, name: impl Into<String>, batch: u32, seq: u32, hidden: usize, dtype: DType, ) -> HirNodeId
[batch, seq, hidden] input with symbolic leading axes.
pub fn param(&mut self, name: impl Into<String>, shape: Shape) -> HirNodeId
pub fn linear( &mut self, x: HirNodeId, weight: HirNodeId, bias: Option<HirNodeId>, activation: Option<Activation>, out_shape: Shape, ) -> HirNodeId
Source
pub fn linear_fused(
&mut self,
x: HirNodeId,
weight: HirNodeId,
bias: HirNodeId,
activation: Option<Activation>,
out_shape: Shape,
) -> HirNodeId
pub fn linear_fused( &mut self, x: HirNodeId, weight: HirNodeId, bias: HirNodeId, activation: Option<Activation>, out_shape: Shape, ) -> HirNodeId
Emit HirOp::LinearFused — fused matmul+bias+act at MIR level.
Two matmuls sharing x. Returns (first, second) in weight order.
pub fn swiglu_ffn( &mut self, x: HirNodeId, up_w: HirNodeId, gate_w: HirNodeId, down_w: HirNodeId, out_shape: Shape, ) -> HirNodeId
pub fn residual_rms_norm( &mut self, x: HirNodeId, residual: HirNodeId, gamma: HirNodeId, beta: HirNodeId, eps: f32, out_shape: Shape, ) -> HirNodeId
Source
pub fn attention(
&mut self,
q: HirNodeId,
k: HirNodeId,
v: HirNodeId,
mask: Option<HirNodeId>,
num_heads: usize,
head_dim: usize,
mask_kind: MaskKind,
out_shape: Shape,
) -> HirNodeId
pub fn attention( &mut self, q: HirNodeId, k: HirNodeId, v: HirNodeId, mask: Option<HirNodeId>, num_heads: usize, head_dim: usize, mask_kind: MaskKind, out_shape: Shape, ) -> HirNodeId
Scaled dot-product attention — see HirOp::Attention.
Source
pub fn depthwise_conv1d_causal(
&mut self,
input: HirNodeId,
weight: HirNodeId,
left_pad: HirNodeId,
kernel_size: usize,
out_shape: Shape,
) -> HirNodeId
pub fn depthwise_conv1d_causal( &mut self, input: HirNodeId, weight: HirNodeId, left_pad: HirNodeId, kernel_size: usize, out_shape: Shape, ) -> HirNodeId
Causal depthwise Conv1d — Conformer / Wav2Vec2-BERT conv module.
input is [B, S, C] and left_pad is [B, K-1, C]; weight is
[C, 1, 1, K] in grouped Conv2d layout.
Source
pub fn dequant_matmul(
&mut self,
x: HirNodeId,
w: HirNodeId,
scale: Option<HirNodeId>,
zp: Option<HirNodeId>,
scheme: QuantScheme,
out_shape: Shape,
) -> HirNodeId
pub fn dequant_matmul( &mut self, x: HirNodeId, w: HirNodeId, scale: Option<HirNodeId>, zp: Option<HirNodeId>, scheme: QuantScheme, out_shape: Shape, ) -> HirNodeId
Fused dequant + matmul — see HirOp::DequantMatMul.
Source
pub fn gated_delta_net(
&mut self,
q: HirNodeId,
k: HirNodeId,
v: HirNodeId,
g: HirNodeId,
beta: HirNodeId,
state_size: usize,
out_shape: Shape,
) -> HirNodeId
pub fn gated_delta_net( &mut self, q: HirNodeId, k: HirNodeId, v: HirNodeId, g: HirNodeId, beta: HirNodeId, state_size: usize, out_shape: Shape, ) -> HirNodeId
Gated DeltaNet without carry state (prefill / reset per batch).
Source
pub fn gated_delta_net_carry(
&mut self,
q: HirNodeId,
k: HirNodeId,
v: HirNodeId,
g: HirNodeId,
beta: HirNodeId,
state: HirNodeId,
state_size: usize,
out_shape: Shape,
) -> HirNodeId
pub fn gated_delta_net_carry( &mut self, q: HirNodeId, k: HirNodeId, v: HirNodeId, g: HirNodeId, beta: HirNodeId, state: HirNodeId, state_size: usize, out_shape: Shape, ) -> HirNodeId
Gated DeltaNet with decode carry — threads state in/out.
Source
pub fn rope(
&mut self,
x: HirNodeId,
cos: HirNodeId,
sin: HirNodeId,
head_dim: usize,
n_rot: usize,
out_shape: Shape,
) -> HirNodeId
pub fn rope( &mut self, x: HirNodeId, cos: HirNodeId, sin: HirNodeId, head_dim: usize, n_rot: usize, out_shape: Shape, ) -> HirNodeId
Rotary position embedding.
Source
pub fn rms_norm(
&mut self,
x: HirNodeId,
gamma: HirNodeId,
beta: HirNodeId,
eps: f32,
out_shape: Shape,
) -> HirNodeId
pub fn rms_norm( &mut self, x: HirNodeId, gamma: HirNodeId, beta: HirNodeId, eps: f32, out_shape: Shape, ) -> HirNodeId
RMS normalization (no residual add).
Source
pub fn llama_decoder_block(
&mut self,
x: HirNodeId,
ln1_g: HirNodeId,
ln1_b: HirNodeId,
q_w: HirNodeId,
k_w: HirNodeId,
v_w: HirNodeId,
o_w: HirNodeId,
ln2_g: HirNodeId,
ln2_b: HirNodeId,
gate_w: HirNodeId,
up_w: HirNodeId,
down_w: HirNodeId,
cos: HirNodeId,
sin: HirNodeId,
mask: Option<HirNodeId>,
num_heads: usize,
head_dim: usize,
num_kv_heads: usize,
eps: f32,
mask_kind: MaskKind,
out_shape: Shape,
) -> HirNodeId
pub fn llama_decoder_block( &mut self, x: HirNodeId, ln1_g: HirNodeId, ln1_b: HirNodeId, q_w: HirNodeId, k_w: HirNodeId, v_w: HirNodeId, o_w: HirNodeId, ln2_g: HirNodeId, ln2_b: HirNodeId, gate_w: HirNodeId, up_w: HirNodeId, down_w: HirNodeId, cos: HirNodeId, sin: HirNodeId, mask: Option<HirNodeId>, num_heads: usize, head_dim: usize, num_kv_heads: usize, eps: f32, mask_kind: MaskKind, out_shape: Shape, ) -> HirNodeId
LLaMA / LLaMA-3.2 decoder layer (pre-norm GQA + SwiGLU).
Source
pub fn transformer_block(
&mut self,
x: HirNodeId,
ln1_g: HirNodeId,
ln1_b: HirNodeId,
q_w: HirNodeId,
k_w: HirNodeId,
v_w: HirNodeId,
o_w: HirNodeId,
ln2_g: HirNodeId,
ln2_b: HirNodeId,
gate_w: HirNodeId,
up_w: HirNodeId,
down_w: HirNodeId,
cos: HirNodeId,
sin: HirNodeId,
mask: Option<HirNodeId>,
num_heads: usize,
head_dim: usize,
num_kv_heads: usize,
eps: f32,
mask_kind: MaskKind,
out_shape: Shape,
) -> HirNodeId
pub fn transformer_block( &mut self, x: HirNodeId, ln1_g: HirNodeId, ln1_b: HirNodeId, q_w: HirNodeId, k_w: HirNodeId, v_w: HirNodeId, o_w: HirNodeId, ln2_g: HirNodeId, ln2_b: HirNodeId, gate_w: HirNodeId, up_w: HirNodeId, down_w: HirNodeId, cos: HirNodeId, sin: HirNodeId, mask: Option<HirNodeId>, num_heads: usize, head_dim: usize, num_kv_heads: usize, eps: f32, mask_kind: MaskKind, out_shape: Shape, ) -> HirNodeId
Standard pre-norm transformer decoder block — alias for
Self::llama_decoder_block (LLaMA / GPT-style layers).
Source
pub fn qwen35_mtp_head(
&mut self,
h_pre_norm: HirNodeId,
input_ids: HirNodeId,
cos: HirNodeId,
sin: HirNodeId,
last_token_idx: HirNodeId,
embed_w: HirNodeId,
hnorm_w: HirNodeId,
hnorm_b: HirNodeId,
enorm_w: HirNodeId,
enorm_b: HirNodeId,
eh_w: HirNodeId,
fa_attn_norm_w: HirNodeId,
fa_attn_norm_b: HirNodeId,
fa_q_gate_w: HirNodeId,
fa_k_w: HirNodeId,
fa_v_w: HirNodeId,
fa_q_norm_w: HirNodeId,
fa_q_norm_b: HirNodeId,
fa_k_norm_w: HirNodeId,
fa_k_norm_b: HirNodeId,
fa_o_w: HirNodeId,
fa_post_norm_w: HirNodeId,
fa_post_norm_b: HirNodeId,
fa_gate_w: HirNodeId,
fa_up_w: HirNodeId,
fa_down_w: HirNodeId,
head_norm_w: HirNodeId,
head_norm_b: HirNodeId,
lm_head_w: HirNodeId,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
n_rot: usize,
n_embd: usize,
n_ff: usize,
mtp_vocab: usize,
eps: f32,
out_shape: Shape,
) -> HirNodeId
pub fn qwen35_mtp_head( &mut self, h_pre_norm: HirNodeId, input_ids: HirNodeId, cos: HirNodeId, sin: HirNodeId, last_token_idx: HirNodeId, embed_w: HirNodeId, hnorm_w: HirNodeId, hnorm_b: HirNodeId, enorm_w: HirNodeId, enorm_b: HirNodeId, eh_w: HirNodeId, fa_attn_norm_w: HirNodeId, fa_attn_norm_b: HirNodeId, fa_q_gate_w: HirNodeId, fa_k_w: HirNodeId, fa_v_w: HirNodeId, fa_q_norm_w: HirNodeId, fa_q_norm_b: HirNodeId, fa_k_norm_w: HirNodeId, fa_k_norm_b: HirNodeId, fa_o_w: HirNodeId, fa_post_norm_w: HirNodeId, fa_post_norm_b: HirNodeId, fa_gate_w: HirNodeId, fa_up_w: HirNodeId, fa_down_w: HirNodeId, head_norm_w: HirNodeId, head_norm_b: HirNodeId, lm_head_w: HirNodeId, num_heads: usize, num_kv_heads: usize, head_dim: usize, n_rot: usize, n_embd: usize, n_ff: usize, mtp_vocab: usize, eps: f32, out_shape: Shape, ) -> HirNodeId
Qwen3.5 MTP draft head — see blocks::lower_qwen35_mtp_head.
Source
pub fn mir(&mut self, op: Op, inputs: Vec<HirNodeId>, shape: Shape) -> HirNodeId
pub fn mir(&mut self, op: Op, inputs: Vec<HirNodeId>, shape: Shape) -> HirNodeId
Escape hatch — embed a single MIR Op verbatim.
pub fn set_outputs(&mut self, outputs: Vec<HirNodeId>)
Source
pub fn lower_to_mir(self) -> Result<MirModule, LowerError>
pub fn lower_to_mir(self) -> Result<MirModule, LowerError>
Lower this module to MIR.
Source
pub fn lower_for_autodiff(self) -> Result<MirModule, LowerError>
pub fn lower_for_autodiff(self) -> Result<MirModule, LowerError>
Lower with FusionPolicy::for_autodiff — primitive MIR chains
that need less unfuse work before rlx_opt::prepare_graph_for_ad.
Source
pub fn wrap_mir_graph(graph: Graph) -> HirModule
pub fn wrap_mir_graph(graph: Graph) -> HirModule
Wrap an existing MIR Graph as a HIR module (HirOp::Mir per node).
Enables Session::compile_hir for legacy graph builders during migration.
Trait Implementations§
Source§impl<'de> Deserialize<'de> for HirModule
impl<'de> Deserialize<'de> for HirModule
Source§fn deserialize<__D>(
__deserializer: __D,
) -> Result<HirModule, <__D as Deserializer<'de>>::Error>
where
__D: Deserializer<'de>,
fn deserialize<__D>(
__deserializer: __D,
) -> Result<HirModule, <__D as Deserializer<'de>>::Error>
where
__D: Deserializer<'de>,
Source§impl From<HirModule> for GraphModule
impl From<HirModule> for GraphModule
Source§fn from(hir: HirModule) -> GraphModule
fn from(hir: HirModule) -> GraphModule
Source§impl Serialize for HirModule
impl Serialize for HirModule
Source§fn serialize<__S>(
&self,
__serializer: __S,
) -> Result<<__S as Serializer>::Ok, <__S as Serializer>::Error>
where
__S: Serializer,
fn serialize<__S>(
&self,
__serializer: __S,
) -> Result<<__S as Serializer>::Ok, <__S as Serializer>::Error>
where
__S: Serializer,
Auto Trait Implementations§
impl Freeze for HirModule
impl RefUnwindSafe for HirModule
impl Send for HirModule
impl Sync for HirModule
impl Unpin for HirModule
impl UnsafeUnpin for HirModule
impl UnwindSafe for HirModule
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T
where
T: ?Sized,
impl<T> BorrowMut<T> for T
where
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Source§impl<T> CloneToUninit for T
where
T: Clone,
impl<T> CloneToUninit for T
where
T: Clone,
impl<T> DeserializeOwned for T
where
T: for<'de> Deserialize<'de>,
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise. Read more