Skip to main content

rlx_vision/
flow.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Tier-0 NomicVision encoder flow — native [`ModelFlow`] ViT assembly.
17
18use anyhow::Result;
19use rlx_flow::{BuiltModel, CompileProfile, ModelFlow};
20use rlx_ir::{DType, Shape};
21
22use crate::vision::VisionPreprocessWeights;
23use rlx_core::config::NomicVisionConfig;
24use rlx_core::flow_util::WeightMapSource;
25use rlx_core::weight_map::WeightMap;
26
27#[derive(Debug, Clone)]
28pub struct NomicVisionFlow<'a> {
29    cfg: &'a NomicVisionConfig,
30    batch: usize,
31}
32
33impl<'a> NomicVisionFlow<'a> {
34    pub fn new(cfg: &'a NomicVisionConfig, batch: usize) -> Self {
35        Self { cfg, batch }
36    }
37
38    pub fn build(self, weights: &mut WeightMap) -> Result<NomicVisionBuilt> {
39        build_nomic_vision_built(self.cfg, weights, self.batch)
40    }
41}
42
43pub struct NomicVisionBuilt {
44    pub model: BuiltModel,
45    pub preprocess: VisionPreprocessWeights,
46}
47
48pub fn build_nomic_vision_built(
49    cfg: &NomicVisionConfig,
50    weights: &mut WeightMap,
51    batch: usize,
52) -> Result<NomicVisionBuilt> {
53    let preprocess = extract_vision_preprocess(weights)?;
54    let final_ln = resolve_final_norm_prefix(weights);
55
56    let h = cfg.hidden_size;
57    let nh = cfg.num_attention_heads;
58    let eps = cfg.layer_norm_eps() as f32;
59    let ps = cfg.patch_size;
60    let np = (cfg.img_size / ps) * (cfg.img_size / ps);
61    let seq = np + 1;
62    let f = DType::F32;
63
64    let model = ModelFlow::new("nomic_vision")
65        .with_profile(CompileProfile::encoder())
66        .input("hidden", Shape::new(&[batch, seq, h], f))
67        .attn_mask_ones(batch, seq)
68        .repeat_vision_layers(cfg.num_hidden_layers, h, nh, eps)
69        .layer_norm(
70            format!("{final_ln}.weight"),
71            format!("{final_ln}.bias"),
72            eps,
73        )
74        .cls_token_pool(batch, h)
75        .output("cls")
76        .build(&mut WeightMapSource(weights))?;
77
78    Ok(NomicVisionBuilt { model, preprocess })
79}
80
81fn extract_vision_preprocess(weights: &mut WeightMap) -> Result<VisionPreprocessWeights> {
82    let (proj_w_data, proj_w_shape) = weights.take_transposed("embeddings.proj.weight")?;
83    let (proj_b_data, _) = weights.take("embeddings.proj.bias")?;
84    let (cls_token_data, _) = weights.take("embeddings.cls_token")?;
85    let (pos_embed_data, _) = weights.take("embeddings.pos_embed")?;
86    Ok(VisionPreprocessWeights {
87        proj_w: proj_w_data,
88        proj_w_cols: proj_w_shape.last().copied().unwrap_or(0),
89        proj_b: proj_b_data,
90        cls_token: cls_token_data,
91        pos_embed: pos_embed_data,
92    })
93}
94
95fn resolve_final_norm_prefix(weights: &WeightMap) -> &'static str {
96    if weights.has("norm.weight") {
97        "norm"
98    } else if weights.has("selector.norm1.weight") {
99        "selector.norm1"
100    } else {
101        "encoder.norm"
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Smallest configuration used by the tests: one layer, a 32x32 image
    /// split into 16x16 patches.
    fn tiny_cfg() -> NomicVisionConfig {
        NomicVisionConfig {
            hidden_size: 16,
            num_hidden_layers: 1,
            num_attention_heads: 4,
            n_inner: 32,
            img_size: 32,
            patch_size: 16,
            layer_norm_epsilon: 1e-5,
        }
    }

    /// Builds an all-zero weight map with the tensor names and shapes the
    /// single-layer flow is built against.
    fn synth_weights(cfg: &NomicVisionConfig) -> WeightMap {
        let hidden = cfg.hidden_size;
        let inner = cfg.intermediate_size();
        let patch_dim = 3 * cfg.patch_size * cfg.patch_size;
        let per_side = cfg.img_size / cfg.patch_size;
        let seq = per_side * per_side + 1;

        let mut tensors: HashMap<String, (Vec<f32>, Vec<usize>)> = HashMap::new();
        // Registers a zero-filled tensor whose length is the product of `shape`.
        let mut put = |name: String, shape: &[usize]| {
            let len: usize = shape.iter().product();
            tensors.insert(name, (vec![0.0f32; len], shape.to_vec()));
        };

        // Patch embedding.
        put("embeddings.proj.weight".into(), &[hidden, patch_dim]);
        put("embeddings.proj.bias".into(), &[hidden]);
        put("embeddings.cls_token".into(), &[1, 1, hidden]);
        put("embeddings.pos_embed".into(), &[1, seq, hidden]);

        // Single transformer layer.
        let lp = "layers.0";
        put(format!("{lp}.norm1.weight"), &[hidden]);
        put(format!("{lp}.norm1.bias"), &[hidden]);
        put(format!("{lp}.norm2.weight"), &[hidden]);
        put(format!("{lp}.norm2.bias"), &[hidden]);
        put(format!("{lp}.attn.Wqkv.weight"), &[3 * hidden, hidden]);
        put(format!("{lp}.attn.Wqkv.bias"), &[3 * hidden]);
        put(format!("{lp}.attn.out_proj.weight"), &[hidden, hidden]);
        put(format!("{lp}.attn.out_proj.bias"), &[hidden]);
        put(format!("{lp}.mlp.fc11.weight"), &[inner, hidden]);
        put(format!("{lp}.mlp.fc11.bias"), &[inner]);
        put(format!("{lp}.mlp.fc12.weight"), &[inner, hidden]);
        put(format!("{lp}.mlp.fc12.bias"), &[inner]);
        put(format!("{lp}.mlp.fc2.weight"), &[hidden, inner]);
        put(format!("{lp}.mlp.fc2.bias"), &[hidden]);
        put(format!("{lp}.mlp.norm.weight"), &[inner]);
        put(format!("{lp}.mlp.norm.bias"), &[inner]);

        // Final layer norm.
        put("norm.weight".into(), &[hidden]);
        put("norm.bias".into(), &[hidden]);

        WeightMap::from_tensors(tensors)
    }

    /// The flow builds from synthetic weights and its primary output has
    /// shape `[batch, hidden_size]`.
    #[test]
    fn vision_flow_builds() {
        let cfg = tiny_cfg();
        let mut wm = synth_weights(&cfg);
        let expected = Shape::new(&[1, cfg.hidden_size], DType::F32);

        let built = NomicVisionFlow::new(&cfg, 1).build(&mut wm).unwrap();

        assert_eq!(*built.model.primary_shape(), expected);
    }
}