rlx_qwen35/vision/
encoder.rs1use super::config::MmProjConfig;
19use super::flow::build_qwen35_vision_built;
20use super::preprocess::{build_vision_positions, preprocess_rgb};
21use super::weights::MmProjWeights;
22use anyhow::{Context, Result};
23use rlx_core::flow_util::compile_built;
24use rlx_core::weight_loader::GgufLoader;
25use rlx_runtime::Device;
26use std::path::{Path, PathBuf};
27
/// Result of running the vision encoder on one image.
#[derive(Debug, Clone, PartialEq)]
pub struct VisionEncodeOutput {
    /// Flattened projected embeddings. `Qwen35VisionEncoder::encode_rgb` checks that
    /// this holds exactly `n_tokens * llm_hidden_size` values (presumably laid out
    /// token-major — TODO confirm against the graph's output layout).
    pub embeddings: Vec<f32>,
    /// First extent of the token grid, as returned by `MmProjConfig::output_grid`.
    pub grid_x: usize,
    /// Second extent of the token grid, as returned by `MmProjConfig::output_grid`.
    pub grid_y: usize,
    /// Number of visual tokens; equals `grid_x * grid_y`.
    pub n_tokens: usize,
}
36
37pub struct Qwen35VisionEncoder {
39 cfg: MmProjConfig,
40 weights: MmProjWeights,
41 params: std::collections::HashMap<String, Vec<f32>>,
42 graph_key: (usize, usize),
43 compiled: rlx_runtime::CompiledGraph,
44}
45
46impl Qwen35VisionEncoder {
47 pub fn from_mmproj(path: impl AsRef<Path>, img_w: usize, img_h: usize) -> Result<Self> {
49 let path = path.as_ref();
50 let path_str = path.to_str().context("mmproj path utf8")?;
51 let mut loader = GgufLoader::from_file(path_str)?;
52 let cfg = MmProjConfig::from_gguf(loader.file())?;
53 let weights = MmProjWeights::from_loader(&cfg, &mut loader)?;
54 Self::from_parts(cfg, weights, img_w, img_h)
55 }
56
57 pub fn from_parts(
59 cfg: MmProjConfig,
60 weights: MmProjWeights,
61 img_w: usize,
62 img_h: usize,
63 ) -> Result<Self> {
64 let built = build_qwen35_vision_built(&cfg, &weights, img_w, img_h)?;
65 let params = built.params().clone();
66 let compiled = compile_built(built, Device::Cpu)?;
67 Ok(Self {
68 graph_key: (img_w, img_h),
69 cfg,
70 weights,
71 params,
72 compiled,
73 })
74 }
75
76 pub fn config(&self) -> &MmProjConfig {
77 &self.cfg
78 }
79
80 pub fn encode_rgb(&mut self, rgb: &[u8], w: usize, h: usize) -> Result<VisionEncodeOutput> {
82 let (nchw, tw, th) = preprocess_rgb(rgb, w, h, &self.cfg);
83 self.ensure_compiled(tw, th)?;
84 let (gx, gy) = self.cfg.output_grid(tw, th);
85 let n_tokens = gx * gy;
86 let proj = self.cfg.llm_hidden_size;
87
88 let _positions = build_vision_positions(tw, th, &self.cfg);
89
90 let outs = self.compiled.run(&[("image", &nchw)]);
91 let emb = outs
92 .into_iter()
93 .next()
94 .context("vision graph produced no outputs")?;
95
96 anyhow::ensure!(
97 emb.len() == n_tokens * proj,
98 "vision output len {} != n_tokens*proj {}*{}",
99 emb.len(),
100 n_tokens,
101 proj
102 );
103
104 Ok(VisionEncodeOutput {
105 embeddings: emb,
106 grid_x: gx,
107 grid_y: gy,
108 n_tokens,
109 })
110 }
111
112 fn ensure_compiled(&mut self, img_w: usize, img_h: usize) -> Result<()> {
113 if self.graph_key == (img_w, img_h) {
114 return Ok(());
115 }
116 let built = build_qwen35_vision_built(&self.cfg, &self.weights, img_w, img_h)?;
117 self.params = built.params().clone();
118 self.compiled = compile_built(built, Device::Cpu)?;
119 self.graph_key = (img_w, img_h);
120 Ok(())
121 }
122}
123
124pub fn load_vision_encoder(
126 mmproj_path: &str,
127 img_w: usize,
128 img_h: usize,
129) -> Result<Qwen35VisionEncoder> {
130 Qwen35VisionEncoder::from_mmproj(PathBuf::from(mmproj_path), img_w, img_h)
131}
132
/// Decodes the image file at `path` with `preprocess::load_rgb_image` and encodes it
/// with `encoder`.
///
/// Only available with the `qwen35-vlm` feature.
#[cfg(feature = "qwen35-vlm")]
pub fn encode_image_file(
    encoder: &mut Qwen35VisionEncoder,
    path: &str,
) -> Result<VisionEncodeOutput> {
    let decoded = super::preprocess::load_rgb_image(path)?;
    let (pixels, width, height) = decoded;
    encoder.encode_rgb(&pixels, width, height)
}