Skip to main content

rlx_flow/blocks/
nomic_layer.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use anyhow::Result;
17use rlx_ir::HirGraphExt;
18use rlx_ir::hir::HirMut;
19
20use super::BlockStage;
21use crate::context::FlowCtx;
22use crate::value::FlowValue;
23
/// Configuration for a single Nomic encoder layer.
///
/// Bundles the parameter-name prefix, the attention geometry, the layer-norm
/// epsilon, and the name of the flow input that carries the attention mask.
#[derive(Clone, Debug)]
pub struct NomicEncoderLayerSpec {
    /// Prefix prepended to every parameter name loaded for this layer
    /// (for example `{layer_prefix}.attn.Wqkv.weight`).
    pub layer_prefix: String,
    /// Width of each of the Q, K and V slices taken from the fused QKV
    /// projection output.
    pub hidden_size: usize,
    /// Head count handed to the attention op.
    pub num_heads: usize,
    /// Per-head dimension handed to the rope and attention ops.
    pub head_dim: usize,
    /// Epsilon handed to both layer norms.
    pub eps: f32,
    /// Name under which the attention-mask input is looked up.
    pub attention_mask_input: String,
}

impl NomicEncoderLayerSpec {
    /// Builds a spec whose attention-mask input is named `"attention_mask"`.
    ///
    /// All other fields are taken verbatim from the arguments; `layer_prefix`
    /// accepts anything convertible into a `String`.
    ///
    /// NOTE(review): `hf` presumably refers to the Hugging Face naming
    /// convention — confirm against callers.
    pub fn hf(
        layer_prefix: impl Into<String>,
        hidden_size: usize,
        num_heads: usize,
        head_dim: usize,
        eps: f32,
    ) -> Self {
        let layer_prefix: String = layer_prefix.into();
        let attention_mask_input = String::from("attention_mask");

        NomicEncoderLayerSpec {
            layer_prefix,
            hidden_size,
            num_heads,
            head_dim,
            eps,
            attention_mask_input,
        }
    }
}
52
53#[derive(Debug, Clone)]
54pub struct NomicEncoderLayerStage {
55    pub spec: NomicEncoderLayerSpec,
56}
57
58impl NomicEncoderLayerStage {
59    pub fn new(spec: NomicEncoderLayerSpec) -> Self {
60        Self { spec }
61    }
62}
63
64impl BlockStage for NomicEncoderLayerStage {
65    fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
66        let spec = &self.spec;
67        let h = spec.hidden_size;
68        let nh = spec.num_heads;
69        let dh = spec.head_dim;
70        let lp = &spec.layer_prefix;
71
72        let cos = ctx
73            .state
74            .rope_cos
75            .ok_or_else(|| anyhow::anyhow!("NomicEncoderLayer requires RopeTables"))?;
76        let sin = ctx
77            .state
78            .rope_sin
79            .ok_or_else(|| anyhow::anyhow!("NomicEncoderLayer requires RopeTables"))?;
80        let mask_id = ctx
81            .state
82            .inputs
83            .get(&spec.attention_mask_input)
84            .map(|(id, _)| *id)
85            .ok_or_else(|| {
86                anyhow::anyhow!(
87                    "NomicEncoderLayer requires input `{}`",
88                    spec.attention_mask_input
89                )
90            })?;
91
92        let qkv_w = ctx.load_param(&format!("{lp}.attn.Wqkv.weight"), true)?;
93        let out_w = ctx.load_param(&format!("{lp}.attn.out_proj.weight"), true)?;
94        let ln1_g = ctx.load_param(&format!("{lp}.norm1.weight"), false)?;
95        let ln1_b = ctx.load_param(&format!("{lp}.norm1.bias"), false)?;
96        let fc11_w = ctx.load_param(&format!("{lp}.mlp.fc11.weight"), true)?;
97        let fc12_w = ctx.load_param(&format!("{lp}.mlp.fc12.weight"), true)?;
98        let fc2_w = ctx.load_param(&format!("{lp}.mlp.fc2.weight"), true)?;
99        let ln2_g = ctx.load_param(&format!("{lp}.norm2.weight"), false)?;
100        let ln2_b = ctx.load_param(&format!("{lp}.norm2.bias"), false)?;
101
102        let mut gb = HirMut::new(ctx.hir());
103        let skip = input.id;
104
105        let qkv = gb.mm(skip, qkv_w);
106        let last_ax = gb.shape(qkv).rank() - 1;
107        let q = gb.narrow_(qkv, last_ax, 0, h);
108        let k = gb.narrow_(qkv, last_ax, h, h);
109        let v = gb.narrow_(qkv, last_ax, 2 * h, h);
110        let q_rope = gb.rope(q, cos, sin, dh);
111        let k_rope = gb.rope(k, cos, sin, dh);
112        let attn = gb.attention_(q_rope, k_rope, v, mask_id, nh, dh);
113
114        let attn_out = gb.mm(attn, out_w);
115        let res1 = gb.add(attn_out, skip);
116        let normed1 = gb.ln(res1, ln1_g, ln1_b, spec.eps);
117
118        let up = gb.mm(normed1, fc11_w);
119        let gate_mm = gb.mm(normed1, fc12_w);
120        let gate = gb.silu(gate_mm);
121        let swiglu = gb.mul(up, gate);
122        let ffn_out = gb.mm(swiglu, fc2_w);
123
124        let res2 = gb.add(ffn_out, normed1);
125        let out = gb.ln(res2, ln2_g, ln2_b, spec.eps);
126
127        Ok(Some(ctx.wrap(out, input.shape.clone())))
128    }
129}