rlx_flow/blocks/
nomic_layer.rs1use anyhow::Result;
5use rlx_ir::HirGraphExt;
6use rlx_ir::hir::HirMut;
7
8use super::BlockStage;
9use crate::context::FlowCtx;
10use crate::value::FlowValue;
11
/// Static configuration for one Nomic encoder layer.
///
/// The values here are consumed by `NomicEncoderLayerStage` when it emits the
/// layer's graph operations.
#[derive(Clone, Debug)]
pub struct NomicEncoderLayerSpec {
    /// Prefix prepended to every parameter name of the layer
    /// (for example `{layer_prefix}.attn.Wqkv.weight`).
    pub layer_prefix: String,
    /// Width of each of the three slices cut from the fused QKV projection.
    pub hidden_size: usize,
    /// Head count handed to the attention operation.
    pub num_heads: usize,
    /// Per-head dimension handed to the rotary-embedding and attention operations.
    pub head_dim: usize,
    /// Epsilon passed to both layer-norm operations.
    pub eps: f32,
    /// Name under which the attention mask is looked up among the flow inputs.
    pub attention_mask_input: String,
}

impl NomicEncoderLayerSpec {
    /// Builds a spec from explicit dimensions, with the mask input name fixed
    /// to `"attention_mask"`.
    ///
    /// NOTE(review): `hf` presumably refers to Hugging Face checkpoint naming;
    /// confirm against the callers.
    pub fn hf(
        layer_prefix: impl Into<String>,
        hidden_size: usize,
        num_heads: usize,
        head_dim: usize,
        eps: f32,
    ) -> Self {
        let layer_prefix: String = layer_prefix.into();
        let attention_mask_input = String::from("attention_mask");
        Self {
            layer_prefix,
            hidden_size,
            num_heads,
            head_dim,
            eps,
            attention_mask_input,
        }
    }
}
40
41#[derive(Debug, Clone)]
42pub struct NomicEncoderLayerStage {
43 pub spec: NomicEncoderLayerSpec,
44}
45
46impl NomicEncoderLayerStage {
47 pub fn new(spec: NomicEncoderLayerSpec) -> Self {
48 Self { spec }
49 }
50}
51
52impl BlockStage for NomicEncoderLayerStage {
53 fn emit(&self, ctx: &mut FlowCtx<'_>, input: FlowValue) -> Result<Option<FlowValue>> {
54 let spec = &self.spec;
55 let h = spec.hidden_size;
56 let nh = spec.num_heads;
57 let dh = spec.head_dim;
58 let lp = &spec.layer_prefix;
59
60 let cos = ctx
61 .state
62 .rope_cos
63 .ok_or_else(|| anyhow::anyhow!("NomicEncoderLayer requires RopeTables"))?;
64 let sin = ctx
65 .state
66 .rope_sin
67 .ok_or_else(|| anyhow::anyhow!("NomicEncoderLayer requires RopeTables"))?;
68 let mask_id = ctx
69 .state
70 .inputs
71 .get(&spec.attention_mask_input)
72 .map(|(id, _)| *id)
73 .ok_or_else(|| {
74 anyhow::anyhow!(
75 "NomicEncoderLayer requires input `{}`",
76 spec.attention_mask_input
77 )
78 })?;
79
80 let qkv_w = ctx.load_param(&format!("{lp}.attn.Wqkv.weight"), true)?;
81 let out_w = ctx.load_param(&format!("{lp}.attn.out_proj.weight"), true)?;
82 let ln1_g = ctx.load_param(&format!("{lp}.norm1.weight"), false)?;
83 let ln1_b = ctx.load_param(&format!("{lp}.norm1.bias"), false)?;
84 let fc11_w = ctx.load_param(&format!("{lp}.mlp.fc11.weight"), true)?;
85 let fc12_w = ctx.load_param(&format!("{lp}.mlp.fc12.weight"), true)?;
86 let fc2_w = ctx.load_param(&format!("{lp}.mlp.fc2.weight"), true)?;
87 let ln2_g = ctx.load_param(&format!("{lp}.norm2.weight"), false)?;
88 let ln2_b = ctx.load_param(&format!("{lp}.norm2.bias"), false)?;
89
90 let mut gb = HirMut::new(ctx.hir());
91 let skip = input.id;
92
93 let qkv = gb.mm(skip, qkv_w);
94 let last_ax = gb.shape(qkv).rank() - 1;
95 let q = gb.narrow_(qkv, last_ax, 0, h);
96 let k = gb.narrow_(qkv, last_ax, h, h);
97 let v = gb.narrow_(qkv, last_ax, 2 * h, h);
98 let q_rope = gb.rope(q, cos, sin, dh);
99 let k_rope = gb.rope(k, cos, sin, dh);
100 let attn = gb.attention_(q_rope, k_rope, v, mask_id, nh, dh);
101
102 let attn_out = gb.mm(attn, out_w);
103 let res1 = gb.add(attn_out, skip);
104 let normed1 = gb.ln(res1, ln1_g, ln1_b, spec.eps);
105
106 let up = gb.mm(normed1, fc11_w);
107 let gate_mm = gb.mm(normed1, fc12_w);
108 let gate = gb.silu(gate_mm);
109 let swiglu = gb.mul(up, gate);
110 let ffn_out = gb.mm(swiglu, fc2_w);
111
112 let res2 = gb.add(ffn_out, normed1);
113 let out = gb.ln(res2, ln2_g, ln2_b, spec.eps);
114
115 Ok(Some(ctx.wrap(out, input.shape.clone())))
116 }
117}