use std::borrow::Cow;
use async_trait::async_trait;
use base64::Engine;
use liter_llm::{ChatCompletionRequest, ContentPart, ImageUrl, LlmClient, Message, UserContent, UserMessage};
use crate::core::config::LlmConfig;
use crate::plugins::{OcrBackend, OcrBackendType, Plugin};
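
/// OCR backend that delegates text extraction to a vision-language model (VLM)
/// via the LLM client configured in `vlm_config`.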
pub struct VlmOcrBackend;
impl Plugin for VlmOcrBackend {
    fn name(&self) -> &str {
        "vlm"
    }

    fn version(&self) -> String {
        env!("CARGO_PKG_VERSION").to_string()
    }

    fn initialize(&self) -> crate::Result<()> {
        Ok(())
    }

    fn shutdown(&self) -> crate::Result<()> {
        Ok(())
    }
}

#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl OcrBackend for VlmOcrBackend {
    async fn process_image(
        &self,
        image_bytes: &[u8],
        config: &crate::OcrConfig,
    ) -> crate::Result<crate::ExtractionResult> {
        let vlm_config = config
            .vlm_config
            .as_ref()
            .ok_or_else(|| crate::KreuzbergError::validation("VLM OCR requires vlm_config to be set"))?;
        // Sniff the MIME type from the image bytes; fall back to PNG if detection fails.
        let mime = infer::get(image_bytes).map(|t| t.mime_type()).unwrap_or("image/png");
        let (text, usage) = vlm_ocr(
            image_bytes,
            mime,
            &config.language,
            vlm_config,
            config.vlm_prompt.as_deref(),
        )
        .await?;
        Ok(crate::ExtractionResult {
            content: text,
            mime_type: Cow::Borrowed("text/plain"),
            llm_usage: usage.map(|u| vec![u]),
            ..Default::default()
        })
    }

    fn supports_language(&self, _lang: &str) -> bool {
        true
    }

    fn backend_type(&self) -> OcrBackendType {
        OcrBackendType::Custom
    }
}
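
/// Performs OCR on a single image by sending it to the configured vision-language model.
///
/// The image is embedded in the request as a base64 `data:` URL, and the prompt is rendered
/// from `vlm_prompt` (or the default `VLM_OCR_TEMPLATE`) with the target `language`.
/// Returns the extracted text together with token usage, when the provider reports it.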
pub async fn vlm_ocr(
    image_bytes: &[u8],
    image_mime_type: &str,
    language: &str,
    config: &LlmConfig,
    vlm_prompt: Option<&str>,
) -> crate::Result<(String, Option<crate::types::LlmUsage>)> {
    let client = super::client::create_client(config)?;
    // Embed the image inline as a base64-encoded data URL so no separate upload is needed.
    let b64 = base64::engine::general_purpose::STANDARD.encode(image_bytes);
    let data_url = format!("data:{image_mime_type};base64,{b64}");
    // Use the caller-supplied prompt when present; otherwise fall back to the default template.
    let template = vlm_prompt.unwrap_or(super::prompts::VLM_OCR_TEMPLATE);
    let ctx = minijinja::context! { language => language };
    let prompt = super::prompts::render_template(template, &ctx)?;
    let message = Message::User(UserMessage {
        content: UserContent::Parts(vec![
            ContentPart::Text { text: prompt },
            ContentPart::ImageUrl {
                image_url: ImageUrl {
                    url: data_url,
                    detail: None,
                },
            },
        ]),
        name: None,
    });
    let mut request = ChatCompletionRequest::default();
    request.model = config.model.clone();
    request.messages = vec![message];
    request.temperature = config.temperature;
    request.max_tokens = config.max_tokens;
    let response = client.chat(request).await.map_err(|e| {
        crate::KreuzbergError::ocr(format!(
            "VLM OCR request failed: model={}, language={}, image_size={}KB: {e}",
            config.model,
            language,
            image_bytes.len() / 1024
        ))
    })?;
    let usage = super::usage::extract_usage_from_chat(&response, "vlm_ocr");
    let text = response
        .choices
        .first()
        .and_then(|choice| choice.message.content.as_deref())
        .ok_or_else(|| crate::KreuzbergError::ocr(format!("VLM OCR returned no content (model={})", config.model)))?
        .to_string();
    Ok((text, usage))
}

#[cfg(test)]
mod tests {
    fn render_ocr_prompt(language: &str) -> String {
        let ctx = minijinja::context! { language => language };
        super::super::prompts::render_template(super::super::prompts::VLM_OCR_TEMPLATE, &ctx).unwrap()
    }

    #[test]
    fn test_vlm_ocr_prompt_non_english_includes_language() {
        let prompt = render_ocr_prompt("deu");
        assert!(prompt.contains("language: deu"));
    }

    #[test]
    fn test_vlm_ocr_prompt_english_no_language_hint() {
        let prompt = render_ocr_prompt("eng");
        assert!(!prompt.contains("language:"));
    }

    #[test]
    fn test_vlm_ocr_prompt_en_no_language_hint() {
        let prompt = render_ocr_prompt("en");
        assert!(!prompt.contains("language:"));
    }

    #[test]
    fn test_vlm_prompt_custom_template_is_used_issue_760() {
        let custom_prompt = "Extract all text from this document image. \
            Preserve formatting and use latex for mathematical formulas.";
        let ctx = minijinja::context! { language => "eng" };
        let prompt = super::super::prompts::render_template(custom_prompt, &ctx).unwrap();
        assert!(prompt.contains("latex"), "custom prompt must be used; got: {prompt}");
        assert!(
            prompt.contains("Preserve formatting"),
            "custom prompt must be used; got: {prompt}"
        );
        assert!(
            !prompt.contains("Extract all visible text"),
            "default template must NOT be used when custom prompt is set; got: {prompt}"
        );
    }

    #[test]
    fn test_vlm_prompt_none_falls_back_to_default() {
        let ctx = minijinja::context! { language => "eng" };
        let prompt = super::super::prompts::render_template(super::super::prompts::VLM_OCR_TEMPLATE, &ctx).unwrap();
        assert!(
            prompt.contains("Extract all visible text"),
            "default template must be used when vlm_prompt is None; got: {prompt}"
        );
    }
}