use std::collections::HashMap;
use serde_json::{json, Value};
use crate::{
registry::{ModelMetadata, ModelProcessorSpec, ModelRegistryError, RegistryResult},
types::{FieldLayout, Modality, PromptReplacement, TokenId},
vision::image_processor::PreprocessedImages,
};
/// Processor spec for Qwen vision-language models (e.g. Qwen2-VL).
///
/// Each image placeholder in the prompt is expanded into the config's
/// `vision_start_token_id` followed by one image-pad token per vision token
/// reported by the image preprocessor.
pub(super) struct QwenVLVisionSpec;

impl QwenVLVisionSpec {
    /// Maximum number of images accepted in a single request.
    const MAX_IMAGES_PER_PROMPT: usize = 10;

    /// Reads a single `u32` token id from the model config.
    ///
    /// # Errors
    /// Returns [`ModelRegistryError::MissingConfigField`] naming `field` when the
    /// key is absent.
    fn config_token_id(metadata: &ModelMetadata, field: &'static str) -> RegistryResult<TokenId> {
        metadata
            .config_u32(&[field])
            .map(|v| v as TokenId)
            .ok_or_else(|| ModelRegistryError::MissingConfigField {
                field: field.to_string(),
            })
    }

    /// Token id used to fill the image's embedding slots in the prompt.
    ///
    /// Qwen2-VL configs carry both `image_token_id` and `vision_token_id`.
    /// Upstream implementations locate image embedding positions via
    /// `image_token_id`, so that key is preferred. `vision_token_id` is only a
    /// fallback for configs that lack `image_token_id`.
    ///
    /// # Errors
    /// Returns [`ModelRegistryError::MissingConfigField`] naming
    /// `image_token_id` when neither key is present.
    fn pad_token_id(metadata: &ModelMetadata) -> RegistryResult<TokenId> {
        metadata
            .config_u32(&["image_token_id"])
            .or_else(|| metadata.config_u32(&["vision_token_id"]))
            .map(|v| v as TokenId)
            .ok_or_else(|| ModelRegistryError::MissingConfigField {
                field: "image_token_id".to_string(),
            })
    }

    /// Token id emitted once at the start of every expanded image
    /// (`vision_start_token_id` in the model config).
    ///
    /// # Errors
    /// Returns [`ModelRegistryError::MissingConfigField`] when the key is absent.
    fn start_token_id(metadata: &ModelMetadata) -> RegistryResult<TokenId> {
        Self::config_token_id(metadata, "vision_start_token_id")
    }
}

impl ModelProcessorSpec for QwenVLVisionSpec {
    fn name(&self) -> &'static str {
        "qwen_vl"
    }

    /// Matches model ids that contain both `qwen` and `vl`, case-insensitively.
    fn matches(&self, metadata: &ModelMetadata) -> bool {
        // Lowercase once instead of allocating a fresh string per check.
        let id = metadata.model_id.to_ascii_lowercase();
        id.contains("qwen") && id.contains("vl")
    }

    /// Placeholder text that marks an image position in the prompt.
    fn placeholder_token(&self, _metadata: &ModelMetadata) -> RegistryResult<String> {
        Ok("<image>".to_string())
    }

    /// Returns the image-pad token id (see `pad_token_id`).
    ///
    /// NOTE(review): this is the expanded pad id, not the tokenizer id of the
    /// `<image>` placeholder text. Confirm that matches the trait's contract.
    fn placeholder_token_id(&self, metadata: &ModelMetadata) -> RegistryResult<TokenId> {
        Self::pad_token_id(metadata)
    }

    /// Allows up to `MAX_IMAGES_PER_PROMPT` images per request.
    fn modality_limits(
        &self,
        _metadata: &ModelMetadata,
    ) -> RegistryResult<HashMap<Modality, usize>> {
        Ok(HashMap::from([(Modality::Image, Self::MAX_IMAGES_PER_PROMPT)]))
    }

    /// No extra processor keyword arguments are needed.
    fn processor_kwargs(&self, _metadata: &ModelMetadata) -> RegistryResult<Value> {
        Ok(json!({}))
    }

    /// Builds one replacement per preprocessed image: the vision-start token
    /// followed by `num_img_tokens[i]` copies of the image-pad token.
    ///
    /// NOTE(review): Qwen2-VL's native prompt format also closes each image with
    /// `vision_end_token_id`. Confirm that the surrounding template emits it,
    /// since it is not added here.
    ///
    /// # Errors
    /// Fails if `vision_start_token_id`, or both `image_token_id` and
    /// `vision_token_id`, are missing from the config.
    fn prompt_replacements(
        &self,
        metadata: &ModelMetadata,
        preprocessed: &PreprocessedImages,
    ) -> RegistryResult<Vec<PromptReplacement>> {
        let start_token_id = Self::start_token_id(metadata)?;
        let pad_token_id = Self::pad_token_id(metadata)?;
        let placeholder_token = self.placeholder_token(metadata)?;
        Ok(preprocessed
            .num_img_tokens
            .iter()
            .map(|&num_tokens| {
                let mut tokens = Vec::with_capacity(num_tokens + 1);
                tokens.push(start_token_id);
                tokens.extend(std::iter::repeat_n(pad_token_id, num_tokens));
                PromptReplacement::sequence(Modality::Image, &placeholder_token, tokens)
            })
            .collect())
    }

    /// `pixel_values` is flattened across images and split back per image using
    /// `patches_per_image`. The other fields are batched per image.
    fn field_layouts(&self) -> HashMap<String, FieldLayout> {
        HashMap::from([
            (
                "pixel_values".to_string(),
                FieldLayout::flat("patches_per_image"),
            ),
            ("image_grid_thw".to_string(), FieldLayout::Batched),
            ("patches_per_image".to_string(), FieldLayout::Batched),
        ])
    }

    /// `image_grid_thw` stays on the CPU rather than being moved with the other tensors.
    fn keep_on_cpu_keys(&self) -> Vec<String> {
        vec!["image_grid_thw".to_string()]
    }
}

#[cfg(test)]
mod tests {
    use serde_json::{json, Value};
    use crate::{
        registry::{
            test_helpers::*, ModelMetadata, ModelRegistry, ModelRegistryError, RegistryResult,
        },
        types::{ImageSize, PromptReplacement},
    };

    /// Expands a single 448x448 image yielding 256 vision tokens using the
    /// registered Qwen-VL spec and the given model config.
    fn expand_single_image(config: &Value) -> RegistryResult<Vec<PromptReplacement>> {
        let tokenizer = TestTokenizer::new(&[("<image>", 999)]);
        let metadata = ModelMetadata {
            model_id: "Qwen2-VL-7B",
            tokenizer: &tokenizer,
            config,
        };
        let registry = ModelRegistry::new();
        let spec = registry.lookup(&metadata).expect("qwen spec");
        spec.prompt_replacements(
            &metadata,
            &test_preprocessed_with_tokens(&[ImageSize::new(448, 448)], &[256]),
        )
    }

    #[test]
    fn qwen_vision_uses_config_token_ids() {
        let config = json!({
            "model_type": "qwen2_vl",
            "vision_start_token_id": 151652,
            "vision_token_id": 151654,
            "image_token_id": 151655,
            "vision_config": {"patch_size": 14}
        });
        let replacements = expand_single_image(&config).unwrap();
        assert_eq!(replacements[0].tokens.len(), 257);
        assert_eq!(replacements[0].tokens[0], 151652);
        // Image slots must use `image_token_id`, not `vision_token_id`.
        assert!(replacements[0].tokens[1..].iter().all(|&t| t == 151655));
    }

    #[test]
    fn qwen_vision_falls_back_to_vision_token_id() {
        let config = json!({
            "model_type": "qwen2_vl",
            "vision_start_token_id": 151652,
            "vision_token_id": 151654,
            "vision_config": {"patch_size": 14}
        });
        let replacements = expand_single_image(&config).unwrap();
        assert_eq!(replacements[0].tokens.len(), 257);
        assert_eq!(replacements[0].tokens[0], 151652);
        assert!(replacements[0].tokens[1..].iter().all(|&t| t == 151654));
    }

    #[test]
    fn qwen_vision_reports_missing_pad_token_id() {
        let config = json!({
            "model_type": "qwen2_vl",
            "vision_start_token_id": 151652,
            "vision_config": {"patch_size": 14}
        });
        let result = expand_single_image(&config);
        assert!(
            matches!(
                &result,
                Err(ModelRegistryError::MissingConfigField { field, .. }) if field == "image_token_id"
            ),
            "expected MissingConfigField for image_token_id"
        );
    }
}