Skip to main content

rlx_embed/
vision.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! RLX-compiled NomicVision encoder for image embeddings.
17
18use std::path::Path;
19
20use anyhow::Result;
21use rlx_flow::CompileProfile;
22use rlx_runtime::CompiledGraph;
23use rlx_runtime::Device;
24
25use rlx_core::config::NomicVisionConfig;
26use rlx_core::weight_map::WeightMap;
27use rlx_vision::vision::{VisionPreprocessWeights, build_vision_graph_sized};
28
29/// Assemble encoder input `[batch, seq, hidden]` from NCHW pixels + preprocess weights.
30pub fn assemble_vision_hidden(
31    pixel_values: &[f32],
32    batch: usize,
33    img: usize,
34    ps: usize,
35    h: usize,
36    preprocess: &VisionPreprocessWeights,
37) -> Vec<f32> {
38    let np = (img / ps) * (img / ps);
39    let seq = np + 1;
40    let patch_dim = 3 * ps * ps;
41    let patches_per_row = img / ps;
42    let pw = preprocess;
43
44    let mut patches = vec![0f32; batch * np * patch_dim];
45    for bi in 0..batch {
46        for py in 0..patches_per_row {
47            for px in 0..patches_per_row {
48                let pi = bi * np + py * patches_per_row + px;
49                let dst = &mut patches[pi * patch_dim..(pi + 1) * patch_dim];
50                let mut di = 0;
51                for c in 0..3usize {
52                    for dy in 0..ps {
53                        for dx in 0..ps {
54                            let y = py * ps + dy;
55                            let x = px * ps + dx;
56                            dst[di] =
57                                pixel_values[bi * 3 * img * img + c * img * img + y * img + x];
58                            di += 1;
59                        }
60                    }
61                }
62            }
63        }
64    }
65
66    let m = batch * np;
67    let k = patch_dim;
68    let n = h;
69    let mut projected = vec![0f32; m * n];
70    rlx_cpu::blas::sgemm_bias(&patches, &pw.proj_w, &pw.proj_b, &mut projected, m, k, n);
71
72    let mut hidden = vec![0f32; batch * seq * h];
73    let cls = &pw.cls_token[..h.min(pw.cls_token.len())];
74    let pos = &pw.pos_embed;
75    for bi in 0..batch {
76        let base = bi * seq * h;
77        hidden[base..base + h].copy_from_slice(cls);
78        let proj_start = bi * np * h;
79        hidden[base + h..base + (np + 1) * h]
80            .copy_from_slice(&projected[proj_start..proj_start + np * h]);
81        let pos_len = (seq * h).min(pos.len());
82        for i in 0..pos_len {
83            hidden[base + i] += pos[i];
84        }
85    }
86    hidden
87}
88
89/// RLX-compiled NomicVision encoder (patch preprocess host-side, trunk on RLX).
90pub struct RlxVisionModel {
91    compiled: CompiledGraph,
92    config: NomicVisionConfig,
93    preprocess: VisionPreprocessWeights,
94    #[allow(dead_code)]
95    compiled_batch: usize,
96}
97
98impl RlxVisionModel {
99    pub fn load_sized(config_path: &Path, weights_path: &str, batch: usize) -> Result<Self> {
100        Self::load_sized_on(config_path, weights_path, batch, Device::Cpu)
101    }
102
103    pub fn load_sized_on(
104        config_path: &Path,
105        weights_path: &str,
106        batch: usize,
107        device: Device,
108    ) -> Result<Self> {
109        let config = NomicVisionConfig::from_file(config_path)?;
110        let mut wm = WeightMap::from_file(weights_path)?;
111        let (graph, params, preprocess) = build_vision_graph_sized(&config, &mut wm, batch)?;
112        let mut compiled = rlx_core::flow_bridge::compile_graph_with_profile(
113            device,
114            graph,
115            &CompileProfile::encoder(),
116        )?;
117        for (name, data) in &params {
118            compiled.set_param(name, data);
119        }
120        Ok(Self {
121            compiled,
122            config,
123            preprocess,
124            compiled_batch: batch,
125        })
126    }
127
128    /// Forward: `pixel_values` `[batch, 3, img, img]` row-major NCHW → CLS `[batch, hidden]`.
129    pub fn forward(&mut self, pixel_values: &[f32], batch: usize) -> Vec<f32> {
130        let hidden = assemble_vision_hidden(
131            pixel_values,
132            batch,
133            self.config.img_size,
134            self.config.patch_size,
135            self.config.hidden_size,
136            &self.preprocess,
137        );
138        self.compiled
139            .run(&[("hidden", &hidden)])
140            .into_iter()
141            .next()
142            .unwrap_or_default()
143    }
144
145    pub fn forward_all(&mut self, pixel_values: &[f32], batch: usize) -> Vec<Vec<f32>> {
146        let hidden = assemble_vision_hidden(
147            pixel_values,
148            batch,
149            self.config.img_size,
150            self.config.patch_size,
151            self.config.hidden_size,
152            &self.preprocess,
153        );
154        self.compiled.run(&[("hidden", &hidden)])
155    }
156
157    pub fn forward_slots(&mut self, hidden: &[f32]) -> (*const f32, usize) {
158        let slots = self.compiled.run_slots(&[hidden]);
159        if slots.is_empty() {
160            return (std::ptr::null(), 0);
161        }
162        let (off, len) = slots[0];
163        unsafe {
164            let ptr = self.compiled.arena_ptr().add(off) as *const f32;
165            (ptr, len)
166        }
167    }
168
169    pub fn hidden_size(&self) -> usize {
170        self.config.hidden_size
171    }
172
173    pub fn img_size(&self) -> usize {
174        self.config.img_size
175    }
176
177    pub fn patch_size(&self) -> usize {
178        self.config.patch_size
179    }
180
181    pub fn num_patches(&self) -> usize {
182        (self.config.img_size / self.config.patch_size).pow(2)
183    }
184
185    pub fn preprocess_weights(&self) -> &VisionPreprocessWeights {
186        &self.preprocess
187    }
188}