Skip to main content

rlx_qwen35/vision/
encoder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Qwen3.5 VLM vision encoder — CPU encode + multimodal prompt helpers.
17
18use super::config::MmProjConfig;
19use super::flow::build_qwen35_vision_built;
20use super::preprocess::{build_vision_positions, preprocess_rgb};
21use super::weights::MmProjWeights;
22use anyhow::{Context, Result};
23use rlx_core::flow_util::compile_built;
24use rlx_core::weight_loader::GgufLoader;
25use rlx_runtime::Device;
26use std::path::{Path, PathBuf};
27
/// Vision tower forward output.
///
/// Returned by `Qwen35VisionEncoder::encode_rgb`. Derives `PartialEq` so two
/// outputs can be compared directly (e.g. with `assert_eq!` in tests); the
/// comparison on `embeddings` is exact `f32` equality, not a tolerance check.
#[derive(Debug, Clone, PartialEq)]
pub struct VisionEncodeOutput {
    /// Flat embedding buffer. The encoder verifies that its length equals
    /// `n_tokens` multiplied by the configured LLM hidden size before
    /// returning it.
    pub embeddings: Vec<f32>,
    /// First component of the grid reported by `MmProjConfig::output_grid`
    /// (presumably the horizontal token count — confirm against the config).
    pub grid_x: usize,
    /// Second component of the grid reported by `MmProjConfig::output_grid`
    /// (presumably the vertical token count — confirm against the config).
    pub grid_y: usize,
    /// Total number of output tokens, computed as `grid_x * grid_y`.
    pub n_tokens: usize,
}
36
37/// CPU vision encoder wrapping a compiled mmproj graph.
38pub struct Qwen35VisionEncoder {
39    cfg: MmProjConfig,
40    weights: MmProjWeights,
41    params: std::collections::HashMap<String, Vec<f32>>,
42    graph_key: (usize, usize),
43    compiled: rlx_runtime::CompiledGraph,
44}
45
46impl Qwen35VisionEncoder {
47    /// Load mmproj GGUF from disk and compile for the given image size.
48    pub fn from_mmproj(path: impl AsRef<Path>, img_w: usize, img_h: usize) -> Result<Self> {
49        let path = path.as_ref();
50        let path_str = path.to_str().context("mmproj path utf8")?;
51        let mut loader = GgufLoader::from_file(path_str)?;
52        let cfg = MmProjConfig::from_gguf(loader.file())?;
53        let weights = MmProjWeights::from_loader(&cfg, &mut loader)?;
54        Self::from_parts(cfg, weights, img_w, img_h)
55    }
56
57    /// Build from already-loaded config + weights (tests).
58    pub fn from_parts(
59        cfg: MmProjConfig,
60        weights: MmProjWeights,
61        img_w: usize,
62        img_h: usize,
63    ) -> Result<Self> {
64        let built = build_qwen35_vision_built(&cfg, &weights, img_w, img_h)?;
65        let params = built.params().clone();
66        let compiled = compile_built(built, Device::Cpu)?;
67        Ok(Self {
68            graph_key: (img_w, img_h),
69            cfg,
70            weights,
71            params,
72            compiled,
73        })
74    }
75
76    pub fn config(&self) -> &MmProjConfig {
77        &self.cfg
78    }
79
80    /// Encode an RGB u8 buffer. Recompiles when smart-resize changes dimensions.
81    pub fn encode_rgb(&mut self, rgb: &[u8], w: usize, h: usize) -> Result<VisionEncodeOutput> {
82        let (nchw, tw, th) = preprocess_rgb(rgb, w, h, &self.cfg);
83        self.ensure_compiled(tw, th)?;
84        let (gx, gy) = self.cfg.output_grid(tw, th);
85        let n_tokens = gx * gy;
86        let proj = self.cfg.llm_hidden_size;
87
88        let _positions = build_vision_positions(tw, th, &self.cfg);
89
90        let outs = self.compiled.run(&[("image", &nchw)]);
91        let emb = outs
92            .into_iter()
93            .next()
94            .context("vision graph produced no outputs")?;
95
96        anyhow::ensure!(
97            emb.len() == n_tokens * proj,
98            "vision output len {} != n_tokens*proj {}*{}",
99            emb.len(),
100            n_tokens,
101            proj
102        );
103
104        Ok(VisionEncodeOutput {
105            embeddings: emb,
106            grid_x: gx,
107            grid_y: gy,
108            n_tokens,
109        })
110    }
111
112    fn ensure_compiled(&mut self, img_w: usize, img_h: usize) -> Result<()> {
113        if self.graph_key == (img_w, img_h) {
114            return Ok(());
115        }
116        let built = build_qwen35_vision_built(&self.cfg, &self.weights, img_w, img_h)?;
117        self.params = built.params().clone();
118        self.compiled = compile_built(built, Device::Cpu)?;
119        self.graph_key = (img_w, img_h);
120        Ok(())
121    }
122}
123
124/// Convenience: load encoder from path string.
125pub fn load_vision_encoder(
126    mmproj_path: &str,
127    img_w: usize,
128    img_h: usize,
129) -> Result<Qwen35VisionEncoder> {
130    Qwen35VisionEncoder::from_mmproj(PathBuf::from(mmproj_path), img_w, img_h)
131}
132
/// Load the image at `path` as RGB and encode it with `encoder`.
///
/// Only compiled when the `qwen35-vlm` feature is enabled. An error from
/// `load_rgb_image` is propagated before the encoder is touched; otherwise
/// the result of `Qwen35VisionEncoder::encode_rgb` is returned as-is.
#[cfg(feature = "qwen35-vlm")]
pub fn encode_image_file(
    encoder: &mut Qwen35VisionEncoder,
    path: &str,
) -> Result<VisionEncodeOutput> {
    let (pixels, width, height) = super::preprocess::load_rgb_image(path)?;
    encoder.encode_rgb(&pixels, width, height)
}