#![allow(unsafe_code)]
use crate::{Document, Error, Extractor, Result};
use objc2::rc::{autoreleasepool, Retained};
use objc2::AnyThread;
use objc2_app_kit::NSImage;
use objc2_foundation::{NSArray, NSString, NSURL};
use objc2_vision::{
VNImageRequestHandler, VNRecognizeTextRequest, VNRecognizedTextObservation, VNRequest,
};
use std::path::Path;
/// Text extractor that runs Apple's Vision framework text recognition (OCR)
/// over raster image files.
///
/// The type carries no state: every instance is interchangeable, so it is
/// trivially cheap to construct, copy, and share.
#[derive(Debug, Clone, Copy, Default)]
pub struct VisionOcrExtractor;

impl VisionOcrExtractor {
    /// Creates a new extractor; equivalent to [`VisionOcrExtractor::default`].
    ///
    /// This is a `const fn`, so it can also initialize `const`/`static` items.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
/// Exposes the Vision OCR backend through the generic [`Extractor`] interface.
impl Extractor for VisionOcrExtractor {
    /// Image file extensions handled by this backend (lowercase, no leading dot).
    fn extensions(&self) -> &[&'static str] {
        const SUPPORTED: &[&str] = &[
            "png", "jpg", "jpeg", "tiff", "tif", "bmp", "gif", "heic", "heif",
        ];
        SUPPORTED
    }

    /// Identifier reported for this backend.
    fn name(&self) -> &'static str {
        "vision-macos"
    }

    /// Runs OCR on the image at `path`.
    ///
    /// The whole extraction executes inside an Objective-C autorelease pool, so
    /// objects autoreleased during the call are released when it finishes.
    fn extract(&self, path: &Path) -> Result<Document> {
        autoreleasepool(|_pool| extract_with_vision(path))
    }
}
fn extract_with_vision(path: &Path) -> Result<Document> {
let path_str = path
.to_str()
.ok_or_else(|| Error::ParseError(format!("path is not valid UTF-8: {}", path.display())))?;
let url = NSURL::fileURLWithPath(&NSString::from_str(path_str));
let nsimage = NSImage::initWithContentsOfURL(NSImage::alloc(), &url).ok_or_else(|| {
Error::ParseError(format!(
"could not load image (unsupported format or corrupt): {path_str}"
))
})?;
let cg_image =
unsafe { nsimage.CGImageForProposedRect_context_hints(std::ptr::null_mut(), None, None) }
.ok_or_else(|| {
Error::ParseError(format!("NSImage→CGImage conversion failed: {path_str}"))
})?;
let request = {
let req = VNRecognizeTextRequest::new();
req.setRecognitionLevel(objc2_vision::VNRequestTextRecognitionLevel::Accurate);
req.setUsesLanguageCorrection(true);
req
};
let handler = unsafe {
let options = objc2_foundation::NSDictionary::<NSString, objc2::runtime::AnyObject>::new();
VNImageRequestHandler::initWithCGImage_options(
VNImageRequestHandler::alloc(),
&cg_image,
&options,
)
};
let request_as_vnrequest: Retained<VNRequest> = request.clone().into_super().into_super();
let requests: Retained<NSArray<VNRequest>> =
NSArray::from_retained_slice(&[request_as_vnrequest]);
handler
.performRequests_error(&requests)
.map_err(|e| Error::ParseError(format!("Vision performRequests failed: {e:?}")))?;
let observations = request.results().unwrap_or_else(NSArray::new);
let mut markdown = String::new();
for obs in &observations {
let text_obs = obs.downcast_ref::<VNRecognizedTextObservation>();
let Some(text_obs) = text_obs else { continue };
let candidates = text_obs.topCandidates(1);
let Some(top) = candidates.iter().next() else {
continue;
};
let line: String = top.string().to_string();
if !line.trim().is_empty() {
if !markdown.is_empty() {
markdown.push('\n');
}
markdown.push_str(&line);
}
}
Ok(Document {
markdown,
title: None,
metadata: std::collections::HashMap::new(),
})
}
/// Unit tests for the Vision OCR backend. The image-based test is ignored by
/// default because it needs a fixture file on disk.
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    #[test]
    fn extensions_cover_common_image_formats() {
        let ext = VisionOcrExtractor.extensions();
        ["png", "jpg", "jpeg", "tiff", "heic"]
            .into_iter()
            .for_each(|required| {
                let handled = ext.iter().any(|candidate| *candidate == required);
                assert!(
                    handled,
                    "expected vision-macos to handle .{required}, got {ext:?}"
                );
            });
    }

    #[test]
    fn name_identifies_backend() {
        let backend = VisionOcrExtractor;
        assert_eq!(backend.name(), "vision-macos");
    }

    #[test]
    #[ignore = "requires a real image file with text in tests/fixtures/"]
    fn extracts_text_from_a_real_image() {
        let fixture = Path::new("tests/fixtures/hello.png");
        let doc = VisionOcrExtractor::new()
            .extract(fixture)
            .expect("extraction failed");
        assert!(
            !doc.markdown.is_empty(),
            "expected non-empty markdown from hello.png"
        );
        let lowered = doc.markdown.to_lowercase();
        assert!(
            lowered.contains("hello"),
            "expected 'hello' in OCR output: {:?}",
            doc.markdown
        );
    }

    #[test]
    fn missing_file_returns_typed_error() {
        let missing = Path::new("/nonexistent-image-here.png");
        let outcome = VisionOcrExtractor.extract(missing);
        assert!(matches!(outcome, Err(Error::ParseError(_))));
    }
}