Skip to main content

rlx_flow/blocks/
nomic_layer.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4use anyhow::Result;
5use rlx_ir::HirGraphExt;
6use rlx_ir::hir::HirMut;
7
8use super::BlockStage;
9use crate::context::FlowCtx;
10use crate::value::FlowValue;
11
/// Static configuration for a single Nomic encoder layer.
///
/// Consumed by [`NomicEncoderLayerStage`], which reads every field while
/// emitting the layer into the graph.
#[derive(Clone, Debug)]
pub struct NomicEncoderLayerSpec {
    /// Prefix prepended to every parameter name loaded for this layer
    /// (e.g. `{layer_prefix}.attn.Wqkv.weight`).
    pub layer_prefix: String,
    /// Width of each of the Q, K and V slices cut from the fused QKV
    /// projection along its last axis.
    pub hidden_size: usize,
    /// Head count handed to the attention op.
    pub num_heads: usize,
    /// Per-head width handed to both the rotary-embedding and attention ops.
    pub head_dim: usize,
    /// Epsilon passed to both layer-norm ops of the layer.
    pub eps: f32,
    /// Name of the graph input looked up to obtain the attention mask.
    pub attention_mask_input: String,
}
21
22impl NomicEncoderLayerSpec {
23    pub fn hf(
24        layer_prefix: impl Into<String>,
25        hidden_size: usize,
26        num_heads: usize,
27        head_dim: usize,
28        eps: f32,
29    ) -> Self {
30        Self {
31            layer_prefix: layer_prefix.into(),
32            hidden_size,
33            num_heads,
34            head_dim,
35            eps,
36            attention_mask_input: "attention_mask".into(),
37        }
38    }
39}
40
41#[derive(Debug, Clone)]
42pub struct NomicEncoderLayerStage {
43    pub spec: NomicEncoderLayerSpec,
44}
45
46impl NomicEncoderLayerStage {
47    pub fn new(spec: NomicEncoderLayerSpec) -> Self {
48        Self { spec }
49    }
50}
51
52impl BlockStage for NomicEncoderLayerStage {
53    fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
54        let spec = &self.spec;
55        let h = spec.hidden_size;
56        let nh = spec.num_heads;
57        let dh = spec.head_dim;
58        let lp = &spec.layer_prefix;
59
60        let cos = ctx
61            .state
62            .rope_cos
63            .ok_or_else(|| anyhow::anyhow!("NomicEncoderLayer requires RopeTables"))?;
64        let sin = ctx
65            .state
66            .rope_sin
67            .ok_or_else(|| anyhow::anyhow!("NomicEncoderLayer requires RopeTables"))?;
68        let mask_id = ctx
69            .state
70            .inputs
71            .get(&spec.attention_mask_input)
72            .map(|(id, _)| *id)
73            .ok_or_else(|| {
74                anyhow::anyhow!(
75                    "NomicEncoderLayer requires input `{}`",
76                    spec.attention_mask_input
77                )
78            })?;
79
80        let qkv_w = ctx.load_param(&format!("{lp}.attn.Wqkv.weight"), true)?;
81        let out_w = ctx.load_param(&format!("{lp}.attn.out_proj.weight"), true)?;
82        let ln1_g = ctx.load_param(&format!("{lp}.norm1.weight"), false)?;
83        let ln1_b = ctx.load_param(&format!("{lp}.norm1.bias"), false)?;
84        let fc11_w = ctx.load_param(&format!("{lp}.mlp.fc11.weight"), true)?;
85        let fc12_w = ctx.load_param(&format!("{lp}.mlp.fc12.weight"), true)?;
86        let fc2_w = ctx.load_param(&format!("{lp}.mlp.fc2.weight"), true)?;
87        let ln2_g = ctx.load_param(&format!("{lp}.norm2.weight"), false)?;
88        let ln2_b = ctx.load_param(&format!("{lp}.norm2.bias"), false)?;
89
90        let mut gb = HirMut::new(ctx.hir());
91        let skip = input.id;
92
93        let qkv = gb.mm(skip, qkv_w);
94        let last_ax = gb.shape(qkv).rank() - 1;
95        let q = gb.narrow_(qkv, last_ax, 0, h);
96        let k = gb.narrow_(qkv, last_ax, h, h);
97        let v = gb.narrow_(qkv, last_ax, 2 * h, h);
98        let q_rope = gb.rope(q, cos, sin, dh);
99        let k_rope = gb.rope(k, cos, sin, dh);
100        let attn = gb.attention_(q_rope, k_rope, v, mask_id, nh, dh);
101
102        let attn_out = gb.mm(attn, out_w);
103        let res1 = gb.add(attn_out, skip);
104        let normed1 = gb.ln(res1, ln1_g, ln1_b, spec.eps);
105
106        let up = gb.mm(normed1, fc11_w);
107        let gate_mm = gb.mm(normed1, fc12_w);
108        let gate = gb.silu(gate_mm);
109        let swiglu = gb.mul(up, gate);
110        let ffn_out = gb.mm(swiglu, fc2_w);
111
112        let res2 = gb.add(ffn_out, normed1);
113        let out = gb.ln(res2, ln2_g, ln2_b, spec.eps);
114
115        Ok(Some(ctx.wrap(out, input.shape.clone())))
116    }
117}