Skip to main content

HirModule

Struct HirModule 

Source
pub struct HirModule {
    pub name: String,
    pub outputs: Vec<HirNodeId>,
    pub fusion_policy: FusionPolicy,
    /* private fields */
}
Expand description

High-level module — model builder output.

Fields§

§name: String
§outputs: Vec<HirNodeId>
§fusion_policy: FusionPolicy

Controls how block ops are lowered to MIR. Defaults to FusionPolicy::Direct for new model code, treating fusion as a first-class citizen.

Implementations§

Source§

impl HirModule

Source

pub fn new(name: impl Into<String>) -> Self

Source

pub fn with_fusion_policy(self, policy: FusionPolicy) -> Self

Source

pub fn len(&self) -> usize

Source

pub fn is_empty(&self) -> bool

Source

pub fn nodes(&self) -> &[HirNode]

Source

pub fn node(&self, id: HirNodeId) -> &HirNode

Source

pub fn node_mut(&mut self, id: HirNodeId) -> &mut HirNode

Source

pub fn named( &mut self, name: impl Into<String>, build: impl FnOnce(&mut Self) -> HirNodeId, ) -> HirNodeId

Build a named block — sets HirNode::name on the returned node.

Source

pub fn input(&mut self, name: impl Into<String>, shape: Shape) -> HirNodeId

Source

pub fn input_batch_seq( &mut self, name: impl Into<String>, batch: u32, seq: u32, hidden: usize, dtype: DType, ) -> HirNodeId

Declares a [batch, seq, hidden] input whose leading (batch and seq) axes are symbolic.

Source

pub fn param(&mut self, name: impl Into<String>, shape: Shape) -> HirNodeId

Source

pub fn linear( &mut self, x: HirNodeId, weight: HirNodeId, bias: Option<HirNodeId>, activation: Option<Activation>, out_shape: Shape, ) -> HirNodeId

Source

pub fn linear_fused( &mut self, x: HirNodeId, weight: HirNodeId, bias: HirNodeId, activation: Option<Activation>, out_shape: Shape, ) -> HirNodeId

Emit HirOp::LinearFused — fused matmul+bias+act at MIR level.

Source

pub fn shared_linear_pair( &mut self, x: HirNodeId, w_first: HirNodeId, w_second: HirNodeId, out_shape: Shape, ) -> (HirNodeId, HirNodeId)

Two matmuls sharing x. Returns (first, second) in weight order.

Source

pub fn swiglu_ffn( &mut self, x: HirNodeId, up_w: HirNodeId, gate_w: HirNodeId, down_w: HirNodeId, out_shape: Shape, ) -> HirNodeId

Source

pub fn residual_rms_norm( &mut self, x: HirNodeId, residual: HirNodeId, gamma: HirNodeId, beta: HirNodeId, eps: f32, out_shape: Shape, ) -> HirNodeId

Source

pub fn attention( &mut self, q: HirNodeId, k: HirNodeId, v: HirNodeId, mask: Option<HirNodeId>, num_heads: usize, head_dim: usize, mask_kind: MaskKind, out_shape: Shape, ) -> HirNodeId

Scaled dot-product attention — see HirOp::Attention.

Source

pub fn depthwise_conv1d_causal( &mut self, input: HirNodeId, weight: HirNodeId, left_pad: HirNodeId, kernel_size: usize, out_shape: Shape, ) -> HirNodeId

Causal depthwise Conv1d — Conformer / Wav2Vec2-BERT conv module.

input has shape [B, S, C] and left_pad has shape [B, K-1, C]; weight has shape [C, 1, 1, K] (grouped Conv2d layout).

Source

pub fn dequant_matmul( &mut self, x: HirNodeId, w: HirNodeId, scale: Option<HirNodeId>, zp: Option<HirNodeId>, scheme: QuantScheme, out_shape: Shape, ) -> HirNodeId

Fused dequant + matmul — see HirOp::DequantMatMul.

Source

pub fn gated_delta_net( &mut self, q: HirNodeId, k: HirNodeId, v: HirNodeId, g: HirNodeId, beta: HirNodeId, state_size: usize, out_shape: Shape, ) -> HirNodeId

Gated DeltaNet without carry state (prefill / reset per batch).

Source

pub fn gated_delta_net_carry( &mut self, q: HirNodeId, k: HirNodeId, v: HirNodeId, g: HirNodeId, beta: HirNodeId, state: HirNodeId, state_size: usize, out_shape: Shape, ) -> HirNodeId

Gated DeltaNet with decode carry — threads state in/out.

Source

pub fn rope( &mut self, x: HirNodeId, cos: HirNodeId, sin: HirNodeId, head_dim: usize, n_rot: usize, out_shape: Shape, ) -> HirNodeId

Rotary position embedding.

Source

pub fn rms_norm( &mut self, x: HirNodeId, gamma: HirNodeId, beta: HirNodeId, eps: f32, out_shape: Shape, ) -> HirNodeId

RMS normalization (no residual add).

Source

pub fn llama_decoder_block( &mut self, x: HirNodeId, ln1_g: HirNodeId, ln1_b: HirNodeId, q_w: HirNodeId, k_w: HirNodeId, v_w: HirNodeId, o_w: HirNodeId, ln2_g: HirNodeId, ln2_b: HirNodeId, gate_w: HirNodeId, up_w: HirNodeId, down_w: HirNodeId, cos: HirNodeId, sin: HirNodeId, mask: Option<HirNodeId>, num_heads: usize, head_dim: usize, num_kv_heads: usize, eps: f32, mask_kind: MaskKind, out_shape: Shape, ) -> HirNodeId

LLaMA / LLaMA-3.2 decoder layer (pre-norm GQA + SwiGLU).

Source

pub fn transformer_block( &mut self, x: HirNodeId, ln1_g: HirNodeId, ln1_b: HirNodeId, q_w: HirNodeId, k_w: HirNodeId, v_w: HirNodeId, o_w: HirNodeId, ln2_g: HirNodeId, ln2_b: HirNodeId, gate_w: HirNodeId, up_w: HirNodeId, down_w: HirNodeId, cos: HirNodeId, sin: HirNodeId, mask: Option<HirNodeId>, num_heads: usize, head_dim: usize, num_kv_heads: usize, eps: f32, mask_kind: MaskKind, out_shape: Shape, ) -> HirNodeId

Standard pre-norm transformer decoder block — alias for Self::llama_decoder_block (LLaMA / GPT-style layers).

Source

pub fn qwen35_mtp_head( &mut self, h_pre_norm: HirNodeId, input_ids: HirNodeId, cos: HirNodeId, sin: HirNodeId, last_token_idx: HirNodeId, embed_w: HirNodeId, hnorm_w: HirNodeId, hnorm_b: HirNodeId, enorm_w: HirNodeId, enorm_b: HirNodeId, eh_w: HirNodeId, fa_attn_norm_w: HirNodeId, fa_attn_norm_b: HirNodeId, fa_q_gate_w: HirNodeId, fa_k_w: HirNodeId, fa_v_w: HirNodeId, fa_q_norm_w: HirNodeId, fa_q_norm_b: HirNodeId, fa_k_norm_w: HirNodeId, fa_k_norm_b: HirNodeId, fa_o_w: HirNodeId, fa_post_norm_w: HirNodeId, fa_post_norm_b: HirNodeId, fa_gate_w: HirNodeId, fa_up_w: HirNodeId, fa_down_w: HirNodeId, head_norm_w: HirNodeId, head_norm_b: HirNodeId, lm_head_w: HirNodeId, num_heads: usize, num_kv_heads: usize, head_dim: usize, n_rot: usize, n_embd: usize, n_ff: usize, mtp_vocab: usize, eps: f32, out_shape: Shape, ) -> HirNodeId

Qwen3.5 MTP draft head — see blocks::lower_qwen35_mtp_head.

Source

pub fn mir(&mut self, op: Op, inputs: Vec<HirNodeId>, shape: Shape) -> HirNodeId

Escape hatch — embed a single MIR Op verbatim.

Source

pub fn set_outputs(&mut self, outputs: Vec<HirNodeId>)

Source

pub fn lower_to_mir(self) -> Result<MirModule, LowerError>

Lower this module to MIR.

Source

pub fn lower_for_autodiff(self) -> Result<MirModule, LowerError>

Lower with FusionPolicy::for_autodiff, producing primitive MIR chains that require less unfusing before rlx_opt::prepare_graph_for_ad.

Source

pub fn wrap_mir_graph(graph: Graph) -> Self

Wrap an existing MIR Graph as a HIR module, with one HirOp::Mir node per graph node. This lets legacy graph builders use Session::compile_hir during migration.

Source§

impl HirModule

Source

pub fn inspect(&self) -> String

Returns a text dump of the module for inspection. Alias for inspect_hir.

Trait Implementations§

Source§

impl Clone for HirModule

Source§

fn clone(&self) -> HirModule

Returns a duplicate of the value. Read more
1.0.0 (const: unstable) · Source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
Source§

impl Debug for HirModule

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more
Source§

impl Display for HirModule

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more
Source§

impl From<HirModule> for GraphModule

Source§

fn from(hir: HirModule) -> Self

Converts to this type from the input type.

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬 This is a nightly-only experimental API (clone_to_uninit).
Performs copy-assignment from self to dest. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T> ToString for T
where T: Display + ?Sized,

Source§

fn to_string(&self) -> String

Converts the given value to a String. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.