#![allow(unsafe_code)]
use crate::{Document, Error, Extractor, Result};
use std::path::Path;
use std::time::Duration;
use windows::core::{RuntimeType, HSTRING};
use windows::Globalization::Language;
use windows::Graphics::Imaging::BitmapDecoder;
use windows::Media::Ocr::OcrEngine;
use windows::Storage::{FileAccessMode, StorageFile};
use windows_future::{AsyncStatus, IAsyncOperation};
/// Text extractor backed by the built-in Windows OCR engine
/// (`Windows.Media.Ocr`).
///
/// The type is a stateless, zero-sized handle: all work happens inside
/// [`Extractor::extract`], so values are trivially copyable and can be
/// created freely.
#[derive(Debug, Clone, Copy, Default)]
pub struct WindowsOcrExtractor;

impl WindowsOcrExtractor {
    /// Creates a new extractor.
    ///
    /// Equivalent to [`WindowsOcrExtractor::default`]; no OS resources are
    /// acquired here.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
impl Extractor for WindowsOcrExtractor {
    /// File extensions (lower-case, without the leading dot) this backend
    /// claims to handle.
    fn extensions(&self) -> &[&'static str] {
        const IMAGE_EXTENSIONS: &[&str] = &["png", "jpg", "jpeg", "tiff", "tif", "bmp", "gif"];
        IMAGE_EXTENSIONS
    }

    /// Stable identifier of this backend.
    fn name(&self) -> &'static str {
        "ocr-windows"
    }

    /// Runs OCR on the image at `path` and returns the recognised text.
    ///
    /// The Windows Runtime is initialised for the calling thread first; the
    /// OCR pipeline only runs if that initialisation succeeds.
    ///
    /// # Errors
    ///
    /// Propagates any error from `ensure_mta_initialized` or
    /// `extract_with_windows_ocr`.
    fn extract(&self, path: &Path) -> Result<Document> {
        ensure_mta_initialized().and_then(|()| extract_with_windows_ocr(path))
    }
}
/// Synchronously waits for a WinRT async operation and returns its result.
///
/// The operation's status is polled, sleeping for one millisecond between
/// checks, for as long as it reports `AsyncStatus::Started`. As soon as any
/// other status is observed, `GetResults` is called and its outcome returned.
///
/// # Errors
///
/// Returns the error from `Status()` if a status query fails, or whatever
/// `GetResults()` reports once the operation is no longer in the started
/// state.
fn block_on<T>(op: IAsyncOperation<T>) -> windows::core::Result<T>
where
    T: RuntimeType + 'static,
{
    /// Pause between two consecutive status checks.
    const POLL_INTERVAL: Duration = Duration::from_millis(1);

    loop {
        if op.Status()? != AsyncStatus::Started {
            return op.GetResults();
        }
        std::thread::sleep(POLL_INTERVAL);
    }
}
fn ensure_mta_initialized() -> Result<()> {
use windows::Win32::System::WinRT::{RoInitialize, RO_INIT_MULTITHREADED};
const S_FALSE: i32 = 1;
const RPC_E_CHANGED_MODE: i32 = 0x8001_0106u32 as i32;
let result = unsafe { RoInitialize(RO_INIT_MULTITHREADED) };
match result {
Ok(()) => Ok(()),
Err(e) if e.code().0 == S_FALSE => Ok(()),
Err(e) if e.code().0 == RPC_E_CHANGED_MODE => Err(Error::ParseError(
"Windows OCR needs an MTA thread; the calling thread is in STA mode \
(typically a UI/main thread). Dispatch to a worker thread \
(e.g. tokio::task::spawn_blocking, std::thread::spawn, or Tauri's \
tauri::async_runtime::spawn_blocking) and call extract() from there."
.into(),
)),
Err(e) => Err(Error::ParseError(format!(
"RoInitialize(MTA) failed: {e:?}"
))),
}
}
fn extract_with_windows_ocr(path: &Path) -> Result<Document> {
let absolute_path = path.canonicalize().map_err(|e| {
Error::ParseError(format!("could not canonicalize {}: {e}", path.display()))
})?;
let absolute_str = absolute_path.to_str().ok_or_else(|| {
Error::ParseError(format!(
"canonical path is not valid UTF-8: {}",
absolute_path.display()
))
})?;
let path_h = HSTRING::from(absolute_str);
let file_op = StorageFile::GetFileFromPathAsync(&path_h)
.map_err(|e| Error::ParseError(format!("GetFileFromPathAsync failed: {e:?}")))?;
let file = block_on(file_op)
.map_err(|e| Error::ParseError(format!("StorageFile open await failed: {e:?}")))?;
let stream_op = file
.OpenAsync(FileAccessMode::Read)
.map_err(|e| Error::ParseError(format!("StorageFile::OpenAsync failed: {e:?}")))?;
let stream = block_on(stream_op)
.map_err(|e| Error::ParseError(format!("stream open await failed: {e:?}")))?;
let decoder_op = BitmapDecoder::CreateAsync(&stream)
.map_err(|e| Error::ParseError(format!("BitmapDecoder::CreateAsync failed: {e:?}")))?;
let decoder = block_on(decoder_op)
.map_err(|e| Error::ParseError(format!("BitmapDecoder await failed: {e:?}")))?;
let bitmap_op = decoder
.GetSoftwareBitmapAsync()
.map_err(|e| Error::ParseError(format!("GetSoftwareBitmapAsync failed: {e:?}")))?;
let bitmap = block_on(bitmap_op)
.map_err(|e| Error::ParseError(format!("SoftwareBitmap await failed: {e:?}")))?;
let engine = match OcrEngine::TryCreateFromUserProfileLanguages() {
Ok(e) => e,
Err(_) => {
let en = Language::CreateLanguage(&HSTRING::from("en-US")).map_err(|e| {
Error::ParseError(format!("Language::CreateLanguage(en-US) failed: {e:?}"))
})?;
OcrEngine::TryCreateFromLanguage(&en).map_err(|e| {
Error::ParseError(format!(
"Windows OCR engine init failed — no installed language pack \
supports OCR. Install one via Settings → Time & Language → \
Language → Add a language → Optional features → OCR. \
Underlying error: {e:?}"
))
})?
}
};
let max_dim_u = OcrEngine::MaxImageDimension()
.map_err(|e| Error::ParseError(format!("MaxImageDimension query failed: {e:?}")))?;
let max_dim = i32::try_from(max_dim_u).unwrap_or(i32::MAX);
let w = bitmap
.PixelWidth()
.map_err(|e| Error::ParseError(format!("SoftwareBitmap::PixelWidth failed: {e:?}")))?;
let h = bitmap
.PixelHeight()
.map_err(|e| Error::ParseError(format!("SoftwareBitmap::PixelHeight failed: {e:?}")))?;
if w > max_dim || h > max_dim {
return Err(Error::ParseError(format!(
"image is {w}x{h}, exceeds Windows OCR max dimension of {max_dim}px. \
Downscale before passing in. (Auto-downscale is planned for a future release.)"
)));
}
let result_op = engine
.RecognizeAsync(&bitmap)
.map_err(|e| Error::ParseError(format!("OcrEngine::RecognizeAsync failed: {e:?}")))?;
let result = block_on(result_op)
.map_err(|e| Error::ParseError(format!("OCR result await failed: {e:?}")))?;
let lines = result
.Lines()
.map_err(|e| Error::ParseError(format!("OcrResult::Lines failed: {e:?}")))?;
let mut markdown = String::new();
for line in lines {
let text = line.Text().map(|h| h.to_string()).unwrap_or_default();
if !text.trim().is_empty() {
if !markdown.is_empty() {
markdown.push('\n');
}
markdown.push_str(&text);
}
}
Ok(Document {
markdown,
title: None,
metadata: std::collections::HashMap::new(),
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The backend must advertise the common raster image extensions.
    #[test]
    fn extensions_cover_common_image_formats() {
        let supported = WindowsOcrExtractor.extensions();
        for &required in &["png", "jpg", "jpeg", "tiff", "bmp", "gif"] {
            assert!(
                supported.contains(&required),
                "expected ocr-windows to handle .{required}, got {supported:?}"
            );
        }
    }

    /// The backend name is a stable identifier.
    #[test]
    fn name_identifies_backend() {
        let backend_name = WindowsOcrExtractor.name();
        assert_eq!(backend_name, "ocr-windows");
    }

    /// End-to-end check against a fixture image; ignored by default because
    /// it needs a Windows host and the fixture file.
    #[test]
    #[ignore = "requires a real image file with text in tests/fixtures/ on a Windows host"]
    fn extracts_text_from_a_real_image() {
        let fixture = Path::new("tests/fixtures/hello.png");
        let doc = WindowsOcrExtractor::new()
            .extract(fixture)
            .expect("extraction failed");
        assert!(
            !doc.markdown.is_empty(),
            "expected non-empty markdown from hello.png"
        );
        let lowered = doc.markdown.to_lowercase();
        assert!(
            lowered.contains("hello"),
            "expected 'hello' in OCR output: {:?}",
            doc.markdown
        );
    }

    /// A path that does not exist surfaces as `Error::ParseError`, not a panic.
    #[test]
    fn missing_file_returns_typed_error() {
        let missing = Path::new("C:\\nonexistent-image-here.png");
        let result = WindowsOcrExtractor.extract(missing);
        assert!(matches!(result, Err(Error::ParseError(_))));
    }
}