Skip to main content

rlx_vision/
vision.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! NomicVision graph builder — delegates to [`crate::flow::NomicVisionFlow`].

18use anyhow::Result;
19use rlx_core::config::NomicVisionConfig;
20use rlx_core::weight_map::WeightMap;
21use rlx_ir::Graph;
22use std::collections::HashMap;
23
24/// Build a NomicVision encoder IR graph via native [`ModelFlow`].
25pub fn build_vision_graph_sized(
26    cfg: &NomicVisionConfig,
27    weights: &mut WeightMap,
28    batch: usize,
29) -> Result<(Graph, HashMap<String, Vec<f32>>, VisionPreprocessWeights)> {
30    let built = crate::flow::build_nomic_vision_built(cfg, weights, batch)?;
31    let (graph, params) = rlx_core::flow_util::graph_from_built(built.model)?;
32    Ok((graph, params, built.preprocess))
33}
34
/// Preprocessing weights extracted from safetensors for the caller to
/// assemble the `"hidden"` input before graph execution.
///
/// These tensors are returned alongside the IR graph by [`build_vision_graph_sized`]
/// rather than being baked into the graph.
///
/// Shapes use `H` for the hidden size. `np` is presumably the number of image patches —
/// NOTE(review): confirm against the flow builder.
#[derive(Debug, Clone)]
pub struct VisionPreprocessWeights {
    /// Patch projection weight, shape `[patch_dim, H]` (pre-transposed for sgemm).
    pub proj_w: Vec<f32>,
    /// Number of columns in `proj_w` (equal to `hidden_size`, i.e. `H`).
    pub proj_w_cols: usize,
    /// Patch projection bias, shape `[H]`.
    pub proj_b: Vec<f32>,
    /// CLS token, shape `[H]` (or `[1, 1, H]` flattened).
    pub cls_token: Vec<f32>,
    /// Position embeddings, shape `[1 + np, H]` (or `[1, 1 + np, H]` flattened).
    pub pos_embed: Vec<f32>,
}