Skip to main content

rlx_nomic/
flow.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Tier-0 NomicBERT encoder flow — native [`ModelFlow`] assembly.
17
18use anyhow::Result;
19use rlx_flow::{BuiltModel, CompileProfile, ModelFlow, RopeTablesStage};
20use rlx_ir::{DType, Shape};
21
22use rlx_core::config::NomicBertConfig;
23use rlx_core::flow_util::WeightMapSource;
24use rlx_core::weight_map::WeightMap;
25
/// Builder state for assembling the NomicBERT encoder as a [`ModelFlow`].
///
/// Holds a borrowed model configuration together with the batch size,
/// sequence length and compile profile that `build` feeds into the flow.
#[derive(Debug, Clone)]
pub struct NomicFlow<'a> {
    /// Model hyper-parameters; borrowed, so the config must outlive the flow.
    cfg: &'a NomicBertConfig,
    /// First dimension of every declared input shape.
    batch: usize,
    /// Second dimension of every declared input shape.
    seq: usize,
    /// Compile profile handed to the underlying `ModelFlow`.
    profile: CompileProfile,
}
33
34impl<'a> NomicFlow<'a> {
35    pub fn new(cfg: &'a NomicBertConfig, batch: usize, seq: usize) -> Self {
36        Self {
37            cfg,
38            batch,
39            seq,
40            profile: CompileProfile::encoder(),
41        }
42    }
43
44    pub fn with_profile(mut self, profile: CompileProfile) -> Self {
45        self.profile = profile;
46        self
47    }
48
49    pub fn build(self, weights: &mut WeightMap) -> Result<BuiltModel> {
50        let h = self.cfg.hidden_size;
51        let nh = self.cfg.num_attention_heads;
52        let dh = self.cfg.head_dim;
53        let eps = self.cfg.layer_norm_eps as f32;
54        let f = DType::F32;
55
56        let (cos_data, sin_data) = rope_tables(self.cfg);
57
58        let flow = ModelFlow::new("nomic_bert")
59            .with_profile(self.profile)
60            .input("input_ids", Shape::new(&[self.batch, self.seq], DType::F32))
61            .input("attention_mask", Shape::new(&[self.batch, self.seq], f))
62            .input(
63                "token_type_ids",
64                Shape::new(&[self.batch, self.seq], DType::F32),
65            )
66            .rope_tables(RopeTablesStage::param(
67                self.cfg.max_position_embeddings,
68                dh / 2,
69                cos_data,
70                sin_data,
71            ))
72            .embed("embeddings.word_embeddings.weight")
73            .gather_add("token_type_ids", "embeddings.token_type_embeddings.weight")
74            .layer_norm("emb_ln.weight", "emb_ln.bias", eps)
75            .repeat_nomic_layers(self.cfg.num_hidden_layers, h, nh, dh, eps)
76            .output("hidden_states");
77
78        flow.build(&mut WeightMapSource(weights))
79    }
80}
81
82fn rope_tables(cfg: &NomicBertConfig) -> (Vec<f32>, Vec<f32>) {
83    let dh = cfg.head_dim;
84    let half = dh / 2;
85    let mut cos_data = vec![0f32; cfg.max_position_embeddings * half];
86    let mut sin_data = vec![0f32; cfg.max_position_embeddings * half];
87    for pos in 0..cfg.max_position_embeddings {
88        for i in 0..half {
89            let freq = 1.0 / cfg.rotary_emb_base.powf((2 * i) as f64 / dh as f64);
90            let angle = pos as f64 * freq;
91            let (s, c) = angle.sin_cos();
92            cos_data[pos * half + i] = c as f32;
93            sin_data[pos * half + i] = s as f32;
94        }
95    }
96    (cos_data, sin_data)
97}
98
99pub fn build_nomic_built(
100    cfg: &NomicBertConfig,
101    weights: &mut WeightMap,
102    batch: usize,
103    seq: usize,
104) -> Result<BuiltModel> {
105    NomicFlow::new(cfg, batch, seq).build(weights)
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Smoke test: a single-layer config with all-zero weights builds, and
    /// the resulting HIR has more than ten entries.
    #[test]
    fn nomic_flow_builds() {
        let cfg = NomicBertConfig {
            vocab_size: 32,
            hidden_size: 16,
            num_hidden_layers: 1,
            num_attention_heads: 4,
            intermediate_size: 32,
            max_position_embeddings: 32,
            type_vocab_size: 2,
            layer_norm_eps: 1e-5,
            head_dim: 4,
            rotary_emb_base: 1000.0,
        };
        let hidden = cfg.hidden_size;
        let inter = cfg.intermediate_size;
        let layer = "encoder.layers.0";

        // (tensor name, shape) for every weight supplied to the flow.
        let specs: Vec<(String, Vec<usize>)> = vec![
            (
                "embeddings.word_embeddings.weight".to_string(),
                vec![cfg.vocab_size, hidden],
            ),
            (
                "embeddings.token_type_embeddings.weight".to_string(),
                vec![cfg.type_vocab_size, hidden],
            ),
            ("emb_ln.weight".to_string(), vec![hidden]),
            ("emb_ln.bias".to_string(), vec![hidden]),
            (format!("{layer}.attn.Wqkv.weight"), vec![3 * hidden, hidden]),
            (format!("{layer}.attn.out_proj.weight"), vec![hidden, hidden]),
            (format!("{layer}.norm1.weight"), vec![hidden]),
            (format!("{layer}.norm1.bias"), vec![hidden]),
            (format!("{layer}.mlp.fc11.weight"), vec![inter, hidden]),
            (format!("{layer}.mlp.fc12.weight"), vec![inter, hidden]),
            (format!("{layer}.mlp.fc2.weight"), vec![hidden, inter]),
            (format!("{layer}.norm2.weight"), vec![hidden]),
            (format!("{layer}.norm2.bias"), vec![hidden]),
        ];

        // Zero-filled buffers sized to the product of each shape.
        let tensors: HashMap<String, (Vec<f32>, Vec<usize>)> = specs
            .into_iter()
            .map(|(name, shape)| {
                let numel: usize = shape.iter().product();
                (name, (vec![0.0f32; numel], shape))
            })
            .collect();

        let mut weights = WeightMap::from_tensors(tensors);
        let built = NomicFlow::new(&cfg, 1, 4).build(&mut weights).unwrap();
        let hir = built.into_hir().unwrap();
        assert!(hir.len() > 10);
    }
}