1use anyhow::Result;
19use rlx_flow::{BuiltModel, CompileProfile, ModelFlow, RopeTablesStage};
20use rlx_ir::{DType, Shape};
21
22use rlx_core::config::NomicBertConfig;
23use rlx_core::flow_util::WeightMapSource;
24use rlx_core::weight_map::WeightMap;
25
26#[derive(Debug, Clone)]
27pub struct NomicFlow<'a> {
28 cfg: &'a NomicBertConfig,
29 batch: usize,
30 seq: usize,
31 profile: CompileProfile,
32}
33
34impl<'a> NomicFlow<'a> {
35 pub fn new(cfg: &'a NomicBertConfig, batch: usize, seq: usize) -> Self {
36 Self {
37 cfg,
38 batch,
39 seq,
40 profile: CompileProfile::encoder(),
41 }
42 }
43
44 pub fn with_profile(mut self, profile: CompileProfile) -> Self {
45 self.profile = profile;
46 self
47 }
48
49 pub fn build(self, weights: &mut WeightMap) -> Result<BuiltModel> {
50 let h = self.cfg.hidden_size;
51 let nh = self.cfg.num_attention_heads;
52 let dh = self.cfg.head_dim;
53 let eps = self.cfg.layer_norm_eps as f32;
54 let f = DType::F32;
55
56 let (cos_data, sin_data) = rope_tables(self.cfg);
57
58 let flow = ModelFlow::new("nomic_bert")
59 .with_profile(self.profile)
60 .input("input_ids", Shape::new(&[self.batch, self.seq], DType::F32))
61 .input("attention_mask", Shape::new(&[self.batch, self.seq], f))
62 .input(
63 "token_type_ids",
64 Shape::new(&[self.batch, self.seq], DType::F32),
65 )
66 .rope_tables(RopeTablesStage::param(
67 self.cfg.max_position_embeddings,
68 dh / 2,
69 cos_data,
70 sin_data,
71 ))
72 .embed("embeddings.word_embeddings.weight")
73 .gather_add("token_type_ids", "embeddings.token_type_embeddings.weight")
74 .layer_norm("emb_ln.weight", "emb_ln.bias", eps)
75 .repeat_nomic_layers(self.cfg.num_hidden_layers, h, nh, dh, eps)
76 .output("hidden_states");
77
78 flow.build(&mut WeightMapSource(weights))
79 }
80}
81
82fn rope_tables(cfg: &NomicBertConfig) -> (Vec<f32>, Vec<f32>) {
83 let dh = cfg.head_dim;
84 let half = dh / 2;
85 let mut cos_data = vec![0f32; cfg.max_position_embeddings * half];
86 let mut sin_data = vec![0f32; cfg.max_position_embeddings * half];
87 for pos in 0..cfg.max_position_embeddings {
88 for i in 0..half {
89 let freq = 1.0 / cfg.rotary_emb_base.powf((2 * i) as f64 / dh as f64);
90 let angle = pos as f64 * freq;
91 let (s, c) = angle.sin_cos();
92 cos_data[pos * half + i] = c as f32;
93 sin_data[pos * half + i] = s as f32;
94 }
95 }
96 (cos_data, sin_data)
97}
98
99pub fn build_nomic_built(
100 cfg: &NomicBertConfig,
101 weights: &mut WeightMap,
102 batch: usize,
103 seq: usize,
104) -> Result<BuiltModel> {
105 NomicFlow::new(cfg, batch, seq).build(weights)
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Small but dimensionally consistent config:
    /// `num_attention_heads * head_dim == hidden_size` (4 * 4 == 16).
    fn tiny_config() -> NomicBertConfig {
        NomicBertConfig {
            vocab_size: 32,
            hidden_size: 16,
            num_hidden_layers: 1,
            num_attention_heads: 4,
            intermediate_size: 32,
            max_position_embeddings: 32,
            type_vocab_size: 2,
            layer_norm_eps: 1e-5,
            head_dim: 4,
            rotary_emb_base: 1000.0,
        }
    }

    /// Zero-filled weights for every tensor the flow reads under `cfg`.
    fn zero_weights(cfg: &NomicBertConfig) -> WeightMap {
        let h = cfg.hidden_size;
        let int_dim = cfg.intermediate_size;
        let mut t: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
        let mut put = |name: String, shape: Vec<usize>| {
            let numel: usize = shape.iter().product();
            t.insert(name, (vec![0.0f32; numel], shape));
        };

        put(
            "embeddings.word_embeddings.weight".into(),
            vec![cfg.vocab_size, h],
        );
        put(
            "embeddings.token_type_embeddings.weight".into(),
            vec![cfg.type_vocab_size, h],
        );
        put("emb_ln.weight".into(), vec![h]);
        put("emb_ln.bias".into(), vec![h]);

        for layer in 0..cfg.num_hidden_layers {
            let lp = format!("encoder.layers.{layer}");
            put(format!("{lp}.attn.Wqkv.weight"), vec![3 * h, h]);
            put(format!("{lp}.attn.out_proj.weight"), vec![h, h]);
            put(format!("{lp}.norm1.weight"), vec![h]);
            put(format!("{lp}.norm1.bias"), vec![h]);
            put(format!("{lp}.mlp.fc11.weight"), vec![int_dim, h]);
            put(format!("{lp}.mlp.fc12.weight"), vec![int_dim, h]);
            put(format!("{lp}.mlp.fc2.weight"), vec![h, int_dim]);
            put(format!("{lp}.norm2.weight"), vec![h]);
            put(format!("{lp}.norm2.bias"), vec![h]);
        }

        WeightMap::from_tensors(t)
    }

    #[test]
    fn nomic_flow_builds() {
        let cfg = tiny_config();
        let mut wm = zero_weights(&cfg);
        let built = NomicFlow::new(&cfg, 1, 4).build(&mut wm).unwrap();
        assert!(built.into_hir().unwrap().len() > 10);
    }

    #[test]
    fn rope_tables_layout_and_values() {
        let cfg = tiny_config();
        let half = cfg.head_dim / 2;
        let (cos, sin) = rope_tables(&cfg);

        assert_eq!(cos.len(), cfg.max_position_embeddings * half);
        assert_eq!(sin.len(), cfg.max_position_embeddings * half);

        // Position 0: every angle is zero.
        assert!(cos[..half].iter().all(|&c| c == 1.0));
        assert!(sin[..half].iter().all(|&s| s == 0.0));

        // Pair 0 always has frequency 1, so its angle equals the position.
        let pos = 3;
        assert!((cos[pos * half] - (pos as f32).cos()).abs() < 1e-6);
        assert!((sin[pos * half] - (pos as f32).sin()).abs() < 1e-6);

        // Pair 1 at position 1 uses frequency base^(-2 / head_dim).
        let freq = cfg.rotary_emb_base.powf(-2.0 / cfg.head_dim as f64);
        assert!((cos[half + 1] as f64 - freq.cos()).abs() < 1e-6);
        assert!((sin[half + 1] as f64 - freq.sin()).abs() < 1e-6);
    }
}