use objc2::AnyThread;
use objc2::rc::Retained;
use objc2::runtime::AnyObject;
use objc2_foundation::{NSArray, NSData, NSDictionary, NSString, NSUInteger};
use objc2_vision::{VNImageRequestHandler, VNRecognizeTextRequest, VNRequestTextRecognitionLevel};
use async_trait::async_trait;
use super::{OcrEngine, OcrError, OcrRegion, OcrResult};
/// Language codes advertised by [`AppleVisionEngine`] through
/// `OcrEngine::supported_languages`.
///
/// The list is informational only in this file: the recognition request
/// relies on automatic language detection and is not configured from it.
const SUPPORTED_LANGUAGES: &[&str] = &[
    "en",
    "fr",
    "it",
    "de",
    "es",
    "pt",
    "zh-Hans",
    "zh-Hant",
    "ja",
    "ko",
    "ru",
    "uk",
    "th",
    "vi",
    "ar",
];
/// OCR engine backed by Apple's Vision framework.
///
/// The engine holds no state: every call builds a fresh Vision request, so the
/// value is free to copy and share.
#[derive(Debug, Clone, Copy, Default)]
pub struct AppleVisionEngine;

impl AppleVisionEngine {
    /// Creates a new engine handle.
    ///
    /// Equivalent to [`AppleVisionEngine::default`]; usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
#[async_trait]
impl OcrEngine for AppleVisionEngine {
    /// Stable identifier of this backend.
    fn name(&self) -> &'static str {
        "apple_vision"
    }

    /// Language codes this backend advertises (see `SUPPORTED_LANGUAGES`).
    fn supported_languages(&self) -> &'static [&'static str] {
        SUPPORTED_LANGUAGES
    }

    /// Always reports the engine as available.
    ///
    /// NOTE(review): no runtime OS/framework probe is performed here —
    /// presumably this module is only compiled for macOS; confirm at the
    /// `mod` declaration.
    fn is_available(&self) -> bool {
        true
    }

    /// Recognizes text in `image_bytes` without blocking the async executor.
    ///
    /// The bytes are copied into an owned buffer so the synchronous Vision
    /// call can run on Tokio's blocking pool.
    ///
    /// # Errors
    ///
    /// Returns [`OcrError::Framework`] if the blocking task cannot be joined,
    /// or whatever error `run_vision_ocr` produced.
    async fn ocr_image(&self, image_bytes: &[u8]) -> Result<OcrResult, OcrError> {
        // `spawn_blocking` requires a `'static` closure, hence the owned copy.
        let owned_image = image_bytes.to_vec();
        let joined = tokio::task::spawn_blocking(move || run_vision_ocr(&owned_image)).await;

        match joined {
            Ok(outcome) => outcome,
            Err(e) => Err(OcrError::Framework(format!(
                "spawn_blocking join error: {e}"
            ))),
        }
    }
}
fn run_vision_ocr(image_bytes: &[u8]) -> Result<OcrResult, OcrError> {
let ns_data: Retained<NSData> = NSData::from_vec(image_bytes.to_vec());
let options: Retained<NSDictionary<objc2_vision::VNImageOption, AnyObject>> =
NSDictionary::new();
let request = build_text_request();
let handler = VNImageRequestHandler::initWithData_options(
VNImageRequestHandler::alloc(),
&ns_data,
&options,
);
let req_array: Retained<NSArray<VNRecognizeTextRequest>> =
NSArray::from_retained_slice(std::slice::from_ref(&request));
let vn_req_array: &NSArray<objc2_vision::VNRequest> = unsafe {
&*std::ptr::from_ref::<NSArray<VNRecognizeTextRequest>>(req_array.as_ref())
.cast::<NSArray<objc2_vision::VNRequest>>()
};
if let Err(err) = handler.performRequests_error(vn_req_array) {
let msg = err.localizedDescription().to_string();
return Err(OcrError::Framework(msg));
}
Ok(extract_results(&request))
}
/// Builds the `VNRecognizeTextRequest` used for every OCR call.
///
/// The request is configured for the `Accurate` recognition level, with
/// language correction and automatic language detection both enabled.
fn build_text_request() -> Retained<VNRecognizeTextRequest> {
    let text_request = VNRecognizeTextRequest::new();

    text_request.setRecognitionLevel(VNRequestTextRecognitionLevel::Accurate);
    text_request.setUsesLanguageCorrection(true);
    // NOTE(review): automatic language detection may require a minimum OS
    // version — confirm against the deployment target.
    text_request.setAutomaticallyDetectsLanguage(true);

    text_request
}
/// Converts the observations stored on a performed `request` into an
/// [`OcrResult`].
///
/// For each observation only the top candidate is used; observations without
/// any candidate are skipped. The result's `text` is the region texts joined
/// with `'\n'`, `confidence` is the arithmetic mean of the region confidences
/// (`0.0` when there are no regions), and `language` is always `None`.
fn extract_results(request: &VNRecognizeTextRequest) -> OcrResult {
    // Only the single best candidate of each observation is requested.
    const TOP_CANDIDATE_LIMIT: NSUInteger = 1;

    // Missing or empty results both map to an empty OCR result.
    let Some(observations) = request.results().filter(|found| !found.is_empty()) else {
        return OcrResult {
            text: String::new(),
            language: None,
            confidence: 0.0,
            regions: vec![],
        };
    };

    let mut regions = Vec::with_capacity(observations.len());
    let mut confidence_sum = 0.0_f32;

    for observation in &observations {
        let top = observation.topCandidates(TOP_CANDIDATE_LIMIT);
        let Some(best) = top.iter().next() else {
            continue;
        };

        let recognized: Retained<NSString> = best.string();
        let text = recognized.to_string();
        let confidence = best.confidence();
        // NOTE(review): the binding exposes this getter as `unsafe`; it is
        // called on a live observation borrowed from the results array.
        let rect = unsafe { observation.boundingBox() };

        // The rectangle components are narrowed to f32 for `OcrRegion`.
        #[allow(clippy::cast_possible_truncation)]
        let (x, y, width, height) = (
            rect.origin.x as f32,
            rect.origin.y as f32,
            rect.size.width as f32,
            rect.size.height as f32,
        );

        regions.push(OcrRegion {
            text,
            // The y coordinate is mirrored within the unit square
            // (presumably Vision's lower-left origin -> top-left origin).
            bounding_box: [x, 1.0 - y - height, width, height],
            confidence,
        });
        confidence_sum += confidence;
    }

    // Region counts are far below f32's exact-integer range in practice.
    #[allow(clippy::cast_precision_loss)]
    let confidence = match regions.len() {
        0 => 0.0,
        count => confidence_sum / count as f32,
    };

    // Newline-separated concatenation of every region's text.
    let mut text = String::new();
    for (index, region) in regions.iter().enumerate() {
        if index > 0 {
            text.push('\n');
        }
        text.push_str(region.text.as_str());
    }

    OcrResult {
        text,
        language: None,
        confidence,
        regions,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The engine identifies itself with the expected backend name.
    #[test]
    fn name_is_apple_vision() {
        assert_eq!(AppleVisionEngine::new().name(), "apple_vision");
    }

    /// The engine always reports itself as available.
    #[test]
    fn is_available_on_macos() {
        assert!(AppleVisionEngine::new().is_available());
    }

    /// The advertised languages include the required codes and exclude
    /// the ones that must not be listed.
    #[test]
    fn supported_languages_correct_set() {
        let advertised = AppleVisionEngine::new().supported_languages();

        for required in ["en", "ja", "zh-Hans", "ko", "ar", "ru"] {
            assert!(
                advertised.contains(&required),
                "missing language: {required}"
            );
        }
        for excluded in ["fi", "sv"] {
            assert!(
                !advertised.contains(&excluded),
                "{excluded} must not be in list"
            );
        }
    }

    /// `OcrRegion::bounding_box` holds four floats and keeps the stored values.
    #[test]
    fn ocr_region_bounding_box_is_four_floats() {
        let sample = OcrRegion {
            text: "test".to_string(),
            bounding_box: [0.1, 0.2, 0.5, 0.3],
            confidence: 0.9,
        };

        assert_eq!(sample.bounding_box.len(), 4);

        let x_error = (sample.bounding_box[0] - 0.1).abs();
        let width_error = (sample.bounding_box[2] - 0.5).abs();
        assert!(x_error < f32::EPSILON);
        assert!(width_error < f32::EPSILON);
    }
}