1use anyhow::Result;
19use rlx_flow::{BuiltModel, CompileProfile, ModelFlow};
20use rlx_ir::{DType, Shape};
21
22use crate::vision::VisionPreprocessWeights;
23use rlx_core::config::NomicVisionConfig;
24use rlx_core::flow_util::WeightMapSource;
25use rlx_core::weight_map::WeightMap;
26
27#[derive(Debug, Clone)]
28pub struct NomicVisionFlow<'a> {
29 cfg: &'a NomicVisionConfig,
30 batch: usize,
31}
32
33impl<'a> NomicVisionFlow<'a> {
34 pub fn new(cfg: &'a NomicVisionConfig, batch: usize) -> Self {
35 Self { cfg, batch }
36 }
37
38 pub fn build(self, weights: &mut WeightMap) -> Result<NomicVisionBuilt> {
39 build_nomic_vision_built(self.cfg, weights, self.batch)
40 }
41}
42
43pub struct NomicVisionBuilt {
44 pub model: BuiltModel,
45 pub preprocess: VisionPreprocessWeights,
46}
47
48pub fn build_nomic_vision_built(
49 cfg: &NomicVisionConfig,
50 weights: &mut WeightMap,
51 batch: usize,
52) -> Result<NomicVisionBuilt> {
53 let preprocess = extract_vision_preprocess(weights)?;
54 let final_ln = resolve_final_norm_prefix(weights);
55
56 let h = cfg.hidden_size;
57 let nh = cfg.num_attention_heads;
58 let eps = cfg.layer_norm_eps() as f32;
59 let ps = cfg.patch_size;
60 let np = (cfg.img_size / ps) * (cfg.img_size / ps);
61 let seq = np + 1;
62 let f = DType::F32;
63
64 let model = ModelFlow::new("nomic_vision")
65 .with_profile(CompileProfile::encoder())
66 .input("hidden", Shape::new(&[batch, seq, h], f))
67 .attn_mask_ones(batch, seq)
68 .repeat_vision_layers(cfg.num_hidden_layers, h, nh, eps)
69 .layer_norm(
70 format!("{final_ln}.weight"),
71 format!("{final_ln}.bias"),
72 eps,
73 )
74 .cls_token_pool(batch, h)
75 .output("cls")
76 .build(&mut WeightMapSource(weights))?;
77
78 Ok(NomicVisionBuilt { model, preprocess })
79}
80
81fn extract_vision_preprocess(weights: &mut WeightMap) -> Result<VisionPreprocessWeights> {
82 let (proj_w_data, proj_w_shape) = weights.take_transposed("embeddings.proj.weight")?;
83 let (proj_b_data, _) = weights.take("embeddings.proj.bias")?;
84 let (cls_token_data, _) = weights.take("embeddings.cls_token")?;
85 let (pos_embed_data, _) = weights.take("embeddings.pos_embed")?;
86 Ok(VisionPreprocessWeights {
87 proj_w: proj_w_data,
88 proj_w_cols: proj_w_shape.last().copied().unwrap_or(0),
89 proj_b: proj_b_data,
90 cls_token: cls_token_data,
91 pos_embed: pos_embed_data,
92 })
93}
94
95fn resolve_final_norm_prefix(weights: &WeightMap) -> &'static str {
96 if weights.has("norm.weight") {
97 "norm"
98 } else if weights.has("selector.norm1.weight") {
99 "selector.norm1"
100 } else {
101 "encoder.norm"
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Minimal configuration: one layer, hidden size 16, 32x32 image split
    /// into 16x16 patches.
    fn tiny_cfg() -> NomicVisionConfig {
        NomicVisionConfig {
            hidden_size: 16,
            num_hidden_layers: 1,
            num_attention_heads: 4,
            n_inner: 32,
            img_size: 32,
            patch_size: 16,
            layer_norm_epsilon: 1e-5,
        }
    }

    /// Builds a weight map holding an all-zero tensor for every key the
    /// single-layer flow is expected to look up, sized from `cfg`.
    fn synth_weights(cfg: &NomicVisionConfig) -> WeightMap {
        let hidden = cfg.hidden_size;
        let inner = cfg.intermediate_size();
        let patch = cfg.patch_size;
        let patch_dim = 3 * patch * patch;
        let side = cfg.img_size / patch;
        let seq = side * side + 1;
        let layer = "layers.0";

        // (tensor name, shape); each tensor's data is zeros of the shape's
        // element count.
        let specs: Vec<(String, Vec<usize>)> = vec![
            ("embeddings.proj.weight".to_string(), vec![hidden, patch_dim]),
            ("embeddings.proj.bias".to_string(), vec![hidden]),
            ("embeddings.cls_token".to_string(), vec![1, 1, hidden]),
            ("embeddings.pos_embed".to_string(), vec![1, seq, hidden]),
            (format!("{layer}.norm1.weight"), vec![hidden]),
            (format!("{layer}.norm1.bias"), vec![hidden]),
            (format!("{layer}.norm2.weight"), vec![hidden]),
            (format!("{layer}.norm2.bias"), vec![hidden]),
            (format!("{layer}.attn.Wqkv.weight"), vec![3 * hidden, hidden]),
            (format!("{layer}.attn.Wqkv.bias"), vec![3 * hidden]),
            (format!("{layer}.attn.out_proj.weight"), vec![hidden, hidden]),
            (format!("{layer}.attn.out_proj.bias"), vec![hidden]),
            (format!("{layer}.mlp.fc11.weight"), vec![inner, hidden]),
            (format!("{layer}.mlp.fc11.bias"), vec![inner]),
            (format!("{layer}.mlp.fc12.weight"), vec![inner, hidden]),
            (format!("{layer}.mlp.fc12.bias"), vec![inner]),
            (format!("{layer}.mlp.fc2.weight"), vec![hidden, inner]),
            (format!("{layer}.mlp.fc2.bias"), vec![hidden]),
            (format!("{layer}.mlp.norm.weight"), vec![inner]),
            (format!("{layer}.mlp.norm.bias"), vec![inner]),
            ("norm.weight".to_string(), vec![hidden]),
            ("norm.bias".to_string(), vec![hidden]),
        ];

        let tensors: HashMap<String, (Vec<f32>, Vec<usize>)> = specs
            .into_iter()
            .map(|(name, shape)| {
                let len: usize = shape.iter().product();
                (name, (vec![0.0f32; len], shape))
            })
            .collect();

        WeightMap::from_tensors(tensors)
    }

    /// The flow builds from synthetic weights and its primary shape is
    /// `[batch, hidden_size]` in F32.
    #[test]
    fn vision_flow_builds() {
        let cfg = tiny_cfg();
        let mut weights = synth_weights(&cfg);
        let flow = NomicVisionFlow::new(&cfg, 1);
        let built = flow.build(&mut weights).unwrap();
        assert_eq!(
            *built.model.primary_shape(),
            Shape::new(&[1, cfg.hidden_size], DType::F32)
        );
    }
}