1use std::path::Path;
19
20use anyhow::Result;
21use rlx_flow::CompileProfile;
22use rlx_runtime::CompiledGraph;
23use rlx_runtime::Device;
24
25use rlx_core::config::NomicVisionConfig;
26use rlx_core::weight_map::WeightMap;
27use rlx_vision::vision::{VisionPreprocessWeights, build_vision_graph_sized};
28
29pub fn assemble_vision_hidden(
31 pixel_values: &[f32],
32 batch: usize,
33 img: usize,
34 ps: usize,
35 h: usize,
36 preprocess: &VisionPreprocessWeights,
37) -> Vec<f32> {
38 let np = (img / ps) * (img / ps);
39 let seq = np + 1;
40 let patch_dim = 3 * ps * ps;
41 let patches_per_row = img / ps;
42 let pw = preprocess;
43
44 let mut patches = vec![0f32; batch * np * patch_dim];
45 for bi in 0..batch {
46 for py in 0..patches_per_row {
47 for px in 0..patches_per_row {
48 let pi = bi * np + py * patches_per_row + px;
49 let dst = &mut patches[pi * patch_dim..(pi + 1) * patch_dim];
50 let mut di = 0;
51 for c in 0..3usize {
52 for dy in 0..ps {
53 for dx in 0..ps {
54 let y = py * ps + dy;
55 let x = px * ps + dx;
56 dst[di] =
57 pixel_values[bi * 3 * img * img + c * img * img + y * img + x];
58 di += 1;
59 }
60 }
61 }
62 }
63 }
64 }
65
66 let m = batch * np;
67 let k = patch_dim;
68 let n = h;
69 let mut projected = vec![0f32; m * n];
70 rlx_cpu::blas::sgemm_bias(&patches, &pw.proj_w, &pw.proj_b, &mut projected, m, k, n);
71
72 let mut hidden = vec![0f32; batch * seq * h];
73 let cls = &pw.cls_token[..h.min(pw.cls_token.len())];
74 let pos = &pw.pos_embed;
75 for bi in 0..batch {
76 let base = bi * seq * h;
77 hidden[base..base + h].copy_from_slice(cls);
78 let proj_start = bi * np * h;
79 hidden[base + h..base + (np + 1) * h]
80 .copy_from_slice(&projected[proj_start..proj_start + np * h]);
81 let pos_len = (seq * h).min(pos.len());
82 for i in 0..pos_len {
83 hidden[base + i] += pos[i];
84 }
85 }
86 hidden
87}
88
/// A vision model bundled with everything needed to run it: the compiled graph, the
/// configuration it was built from, and the weights used to assemble the graph input.
pub struct RlxVisionModel {
    /// Compiled graph that the `forward*` methods execute.
    compiled: CompiledGraph,
    /// Model configuration; supplies `img_size`, `patch_size` and `hidden_size`.
    config: NomicVisionConfig,
    /// Weights passed to `assemble_vision_hidden` when building the graph input.
    preprocess: VisionPreprocessWeights,
    /// Batch size given at load time. Stored but not read by any method visible here,
    /// hence the lint allowance.
    #[allow(dead_code)]
    compiled_batch: usize,
}
97
98impl RlxVisionModel {
99 pub fn load_sized(config_path: &Path, weights_path: &str, batch: usize) -> Result<Self> {
100 Self::load_sized_on(config_path, weights_path, batch, Device::Cpu)
101 }
102
103 pub fn load_sized_on(
104 config_path: &Path,
105 weights_path: &str,
106 batch: usize,
107 device: Device,
108 ) -> Result<Self> {
109 let config = NomicVisionConfig::from_file(config_path)?;
110 let mut wm = WeightMap::from_file(weights_path)?;
111 let (graph, params, preprocess) = build_vision_graph_sized(&config, &mut wm, batch)?;
112 let mut compiled = rlx_core::flow_bridge::compile_graph_with_profile(
113 device,
114 graph,
115 &CompileProfile::encoder(),
116 )?;
117 for (name, data) in ¶ms {
118 compiled.set_param(name, data);
119 }
120 Ok(Self {
121 compiled,
122 config,
123 preprocess,
124 compiled_batch: batch,
125 })
126 }
127
128 pub fn forward(&mut self, pixel_values: &[f32], batch: usize) -> Vec<f32> {
130 let hidden = assemble_vision_hidden(
131 pixel_values,
132 batch,
133 self.config.img_size,
134 self.config.patch_size,
135 self.config.hidden_size,
136 &self.preprocess,
137 );
138 self.compiled
139 .run(&[("hidden", &hidden)])
140 .into_iter()
141 .next()
142 .unwrap_or_default()
143 }
144
145 pub fn forward_all(&mut self, pixel_values: &[f32], batch: usize) -> Vec<Vec<f32>> {
146 let hidden = assemble_vision_hidden(
147 pixel_values,
148 batch,
149 self.config.img_size,
150 self.config.patch_size,
151 self.config.hidden_size,
152 &self.preprocess,
153 );
154 self.compiled.run(&[("hidden", &hidden)])
155 }
156
157 pub fn forward_slots(&mut self, hidden: &[f32]) -> (*const f32, usize) {
158 let slots = self.compiled.run_slots(&[hidden]);
159 if slots.is_empty() {
160 return (std::ptr::null(), 0);
161 }
162 let (off, len) = slots[0];
163 unsafe {
164 let ptr = self.compiled.arena_ptr().add(off) as *const f32;
165 (ptr, len)
166 }
167 }
168
169 pub fn hidden_size(&self) -> usize {
170 self.config.hidden_size
171 }
172
173 pub fn img_size(&self) -> usize {
174 self.config.img_size
175 }
176
177 pub fn patch_size(&self) -> usize {
178 self.config.patch_size
179 }
180
181 pub fn num_patches(&self) -> usize {
182 (self.config.img_size / self.config.patch_size).pow(2)
183 }
184
185 pub fn preprocess_weights(&self) -> &VisionPreprocessWeights {
186 &self.preprocess
187 }
188}