1use std::fmt;
30
31#[cfg(any(
32 feature = "ocr-mistral",
33 feature = "ocr-google",
34 feature = "ocr-aws",
35 feature = "ocr-azure"
36))]
37use std::io::Cursor;
38
39#[cfg(feature = "ocr-azure")]
40use std::time::Duration;
41
42#[cfg(any(feature = "ocr-mistral", feature = "ocr-google", feature = "ocr-aws"))]
43use base64::Engine as _;
44#[cfg(feature = "ocr-aws")]
45use hmac::{Hmac, Mac};
46#[cfg(any(
47 feature = "ocr-mistral",
48 feature = "ocr-google",
49 feature = "ocr-aws",
50 feature = "ocr-azure"
51))]
52use image::{DynamicImage, ImageFormat, RgbImage};
53#[cfg(feature = "ocr-google")]
54use pkcs8::DecodePrivateKey as _;
55#[cfg(feature = "ocr-google")]
56use rsa::{
57 pkcs1v15::SigningKey,
58 signature::{SignatureEncoding as _, Signer as _},
59 RsaPrivateKey,
60};
61#[cfg(any(
62 feature = "ocr-mistral",
63 feature = "ocr-google",
64 feature = "ocr-aws",
65 feature = "ocr-azure"
66))]
67use serde::{Deserialize, Serialize};
68#[cfg(feature = "ocr-aws")]
69use sha2::Digest as _;
70#[cfg(any(feature = "ocr-google", feature = "ocr-aws"))]
71use sha2::Sha256;
72#[cfg(feature = "ocr-aws")]
73use time::macros::format_description;
74#[cfg(any(feature = "ocr-google", feature = "ocr-aws"))]
75use time::OffsetDateTime;
76
/// The outcome of one OCR pass over a single image.
#[derive(Debug, Clone)]
pub struct OcrResult {
    /// Full recognized text; how lines/pages are joined depends on the backend.
    pub text: String,
    /// Per-word (or per-line, for AWS Textract) results; may be empty
    /// (e.g. the Mistral backend never produces word boxes).
    pub words: Vec<OcrWord>,
    /// Overall confidence in `[0.0, 1.0]`; backends report 0.0 when no text was found.
    pub confidence: f32,
}
89
/// One recognized text unit (a word, or a line for some backends) with its location.
#[derive(Debug, Clone)]
pub struct OcrWord {
    /// The recognized text for this unit.
    pub text: String,
    /// Bounding box as `[x, y, width, height]`. AWS results are converted to
    /// pixels; Google boxes use the vertex coordinates as returned by the API.
    pub bbox: [f32; 4],
    /// Confidence in `[0.0, 1.0]`.
    pub confidence: f32,
}
100
/// Errors produced by OCR backends.
#[derive(Debug)]
pub enum OcrError {
    /// No backend is configured/available (e.g. required env vars are unset).
    NoEngine,
    /// The input buffer could not be validated or encoded as an image.
    ImageError(String),
    /// The remote OCR request or its response handling failed.
    RecognitionFailed(String),
}
111
112impl fmt::Display for OcrError {
113 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114 match self {
115 OcrError::NoEngine => write!(f, "no OCR engine available"),
116 OcrError::ImageError(e) => write!(f, "image error: {e}"),
117 OcrError::RecognitionFailed(e) => write!(f, "recognition failed: {e}"),
118 }
119 }
120}
121
122impl std::error::Error for OcrError {}
123
/// A pluggable OCR engine.
pub trait OcrBackend: Send + Sync {
    /// Recognize text in a tightly-packed RGB8 buffer of `width * height` pixels.
    ///
    /// # Errors
    /// Returns [`OcrError::ImageError`] for malformed input buffers and
    /// [`OcrError::RecognitionFailed`] when the engine itself fails.
    fn recognize(&self, image_data: &[u8], width: u32, height: u32) -> Result<OcrResult, OcrError>;

    /// Short stable identifier for this backend (e.g. `"mistral"`).
    fn name(&self) -> &str;
}
141
#[cfg(any(
    feature = "ocr-mistral",
    feature = "ocr-google",
    feature = "ocr-aws",
    feature = "ocr-azure"
))]
/// Shared constructor for the blocking HTTP client used by every backend.
fn blocking_http_client() -> reqwest::blocking::Client {
    // One-minute overall timeout for each OCR round trip.
    let timeout = std::time::Duration::from_secs(60);
    reqwest::blocking::Client::builder()
        .timeout(timeout)
        .build()
        .expect("building reqwest blocking client")
}
156
/// Number of bytes a tightly-packed `width * height` RGB8 buffer must hold.
///
/// # Errors
/// Returns [`OcrError::ImageError`] when `width * height * 3` overflows `usize`.
#[cfg(any(
    feature = "ocr-mistral",
    feature = "ocr-google",
    feature = "ocr-aws",
    feature = "ocr-azure"
))]
fn expected_rgb_len(width: u32, height: u32) -> Result<usize, OcrError> {
    let pixels = (width as usize).checked_mul(height as usize);
    match pixels.and_then(|count| count.checked_mul(3)) {
        Some(len) => Ok(len),
        None => Err(OcrError::ImageError("image dimensions overflowed".into())),
    }
}
169
/// Validate a raw RGB8 buffer and encode it as PNG bytes for upload.
///
/// # Errors
/// Returns [`OcrError::ImageError`] when the buffer length does not match
/// `width * height * 3`, the image cannot be constructed, or PNG encoding fails.
#[cfg(any(
    feature = "ocr-mistral",
    feature = "ocr-google",
    feature = "ocr-aws",
    feature = "ocr-azure"
))]
fn rgb_png_bytes(image_data: &[u8], width: u32, height: u32) -> Result<Vec<u8>, OcrError> {
    // Reject mismatched buffers up front so the error names the expected size.
    let expected_len = expected_rgb_len(width, height)?;
    if image_data.len() != expected_len {
        return Err(OcrError::ImageError(format!(
            "expected {expected_len} RGB bytes, got {}",
            image_data.len()
        )));
    }

    let image = RgbImage::from_raw(width, height, image_data.to_vec()).ok_or_else(|| {
        OcrError::ImageError("failed to create RGB image from raw OCR buffer".into())
    })?;

    // Encode in memory; all backends ship the PNG over HTTP, never to disk.
    let mut cursor = Cursor::new(Vec::new());
    DynamicImage::ImageRgb8(image)
        .write_to(&mut cursor, ImageFormat::Png)
        .map_err(|e| OcrError::ImageError(format!("encode PNG for OCR request: {e}")))?;

    Ok(cursor.into_inner())
}
196
/// PNG-encode a raw RGB8 buffer and return it as standard base64.
#[cfg(any(feature = "ocr-mistral", feature = "ocr-google", feature = "ocr-aws"))]
fn rgb_base64_png(image_data: &[u8], width: u32, height: u32) -> Result<String, OcrError> {
    let png_bytes = rgb_png_bytes(image_data, width, height)?;
    Ok(base64::engine::general_purpose::STANDARD.encode(png_bytes))
}
202
/// PNG-encode the raw RGB buffer and wrap it as a base64 `data:` URL.
#[cfg(feature = "ocr-mistral")]
fn rgb_data_url(image_data: &[u8], width: u32, height: u32) -> Result<String, OcrError> {
    rgb_base64_png(image_data, width, height)
        .map(|encoded| format!("data:image/png;base64,{encoded}"))
}
208
/// Axis-aligned bounding box `[x, y, width, height]` enclosing `points`.
///
/// An empty slice yields the degenerate box `[0.0, 0.0, 0.0, 0.0]`.
#[cfg(any(feature = "ocr-google", feature = "ocr-azure"))]
fn bbox_from_points(points: &[(f32, f32)]) -> [f32; 4] {
    if points.is_empty() {
        return [0.0, 0.0, 0.0, 0.0];
    }

    // Seed with infinities so the fold mirrors running min/max accumulation.
    let seed = (
        f32::INFINITY,
        f32::INFINITY,
        f32::NEG_INFINITY,
        f32::NEG_INFINITY,
    );
    let (min_x, min_y, max_x, max_y) =
        points.iter().fold(seed, |(lo_x, lo_y, hi_x, hi_y), &(x, y)| {
            (lo_x.min(x), lo_y.min(y), hi_x.max(x), hi_y.max(y))
        });

    [min_x, min_y, max_x - min_x, max_y - min_y]
}
229
/// Mean of the per-word confidences, or `fallback` when there are no words.
#[cfg(any(feature = "ocr-google", feature = "ocr-aws", feature = "ocr-azure"))]
fn confidence_from_words(words: &[OcrWord], fallback: f32) -> f32 {
    match words.len() {
        0 => fallback,
        count => {
            let total: f32 = words.iter().map(|word| word.confidence).sum();
            total / count as f32
        }
    }
}
238
/// Best-effort home directory: `$HOME` first, then `%USERPROFILE%` (Windows).
#[cfg(any(feature = "ocr-google", feature = "ocr-aws", feature = "ocr-azure"))]
fn home_dir() -> Option<std::path::PathBuf> {
    ["HOME", "USERPROFILE"]
        .iter()
        .find_map(|var| std::env::var_os(var))
        .map(std::path::PathBuf::from)
}
245
/// REST endpoint for Mistral's hosted OCR API.
#[cfg(feature = "ocr-mistral")]
const MISTRAL_OCR_ENDPOINT: &str = "https://api.mistral.ai/v1/ocr";
/// Default OCR model; overridable at runtime via `MISTRAL_OCR_MODEL`.
#[cfg(feature = "ocr-mistral")]
const MISTRAL_OCR_MODEL: &str = "mistral-ocr-latest";
252
/// OCR backend backed by Mistral's hosted OCR API (bearer-token auth).
#[cfg(feature = "ocr-mistral")]
pub struct MistralOcrBackend {
    // Sent as `Authorization: Bearer <key>` on every request.
    api_key: String,
    client: reqwest::blocking::Client,
}
259
#[cfg(feature = "ocr-mistral")]
impl MistralOcrBackend {
    /// Build a backend that authenticates with the given API key.
    pub fn new(api_key: &str) -> Self {
        let client = blocking_http_client();
        Self {
            api_key: api_key.to_owned(),
            client,
        }
    }

    /// Build a backend from the `MISTRAL_API_KEY` environment variable.
    ///
    /// # Errors
    /// Returns [`OcrError::NoEngine`] when the variable is unset.
    pub fn from_env() -> Result<Self, OcrError> {
        match std::env::var("MISTRAL_API_KEY") {
            Ok(key) => Ok(Self::new(&key)),
            Err(_) => Err(OcrError::NoEngine),
        }
    }
}
276
#[cfg(feature = "ocr-mistral")]
impl OcrBackend for MistralOcrBackend {
    /// Upload the frame as a base64 `data:` URL and return the markdown text.
    ///
    /// Mistral returns per-page markdown, so no word boxes are produced and a
    /// flat confidence is reported: 0.95 for non-empty text, 0.0 otherwise.
    fn recognize(&self, image_data: &[u8], width: u32, height: u32) -> Result<OcrResult, OcrError> {
        let image_url = rgb_data_url(image_data, width, height)?;
        // Model can be overridden per-run without rebuilding the backend.
        let model =
            std::env::var("MISTRAL_OCR_MODEL").unwrap_or_else(|_| MISTRAL_OCR_MODEL.to_string());
        let body = MistralOcrRequest {
            model: &model,
            document: MistralOcrDocument {
                kind: "image_url",
                image_url: &image_url,
            },
        };
        let body_json = serde_json::to_vec(&body)
            .map_err(|e| OcrError::RecognitionFailed(format!("serialize request: {e}")))?;

        let response = self
            .client
            .post(MISTRAL_OCR_ENDPOINT)
            .header(
                reqwest::header::AUTHORIZATION,
                format!("Bearer {}", self.api_key),
            )
            .header(reqwest::header::CONTENT_TYPE, "application/json")
            .body(body_json)
            .send()
            .map_err(|e| OcrError::RecognitionFailed(format!("Mistral OCR request failed: {e}")))?;

        // Read the body before checking status so error responses are included
        // in the failure message.
        let status = response.status();
        let response_text = response
            .text()
            .map_err(|e| OcrError::RecognitionFailed(format!("read Mistral OCR response: {e}")))?;

        if !status.is_success() {
            return Err(OcrError::RecognitionFailed(format!(
                "Mistral OCR returned {status}: {response_text}"
            )));
        }

        let parsed: MistralOcrResponse = serde_json::from_str(&response_text)
            .map_err(|e| OcrError::RecognitionFailed(format!("parse Mistral OCR response: {e}")))?;
        let text = mistral_markdown_text(&parsed);
        let confidence = if text.trim().is_empty() { 0.0 } else { 0.95 };

        Ok(OcrResult {
            text,
            words: Vec::new(),
            confidence,
        })
    }

    fn name(&self) -> &str {
        "mistral"
    }
}
332
/// Request body for the Mistral OCR endpoint.
#[cfg(feature = "ocr-mistral")]
#[derive(Serialize)]
struct MistralOcrRequest<'a> {
    model: &'a str,
    document: MistralOcrDocument<'a>,
}

/// The `document` object: always `type: "image_url"` with an inline data URL here.
#[cfg(feature = "ocr-mistral")]
#[derive(Serialize)]
struct MistralOcrDocument<'a> {
    // `type` is a Rust keyword, hence the rename.
    #[serde(rename = "type")]
    kind: &'a str,
    image_url: &'a str,
}

/// Response envelope: one entry per recognized page.
#[cfg(feature = "ocr-mistral")]
#[derive(Deserialize)]
struct MistralOcrResponse {
    #[serde(default)]
    pages: Vec<MistralOcrPage>,
}

/// One page of OCR output, rendered by Mistral as markdown.
#[cfg(feature = "ocr-mistral")]
#[derive(Deserialize)]
struct MistralOcrPage {
    markdown: Option<String>,
}
360
/// Join the non-empty, trimmed per-page markdown bodies with blank lines.
#[cfg(feature = "ocr-mistral")]
fn mistral_markdown_text(response: &MistralOcrResponse) -> String {
    let mut sections = Vec::new();
    for page in &response.pages {
        if let Some(markdown) = page.markdown.as_deref() {
            let trimmed = markdown.trim();
            if !trimmed.is_empty() {
                sections.push(trimmed);
            }
        }
    }
    sections.join("\n\n")
}
372
/// Google Vision `images:annotate` REST endpoint.
#[cfg(feature = "ocr-google")]
const GOOGLE_VISION_ENDPOINT: &str = "https://vision.googleapis.com/v1/images:annotate";
/// Default OAuth2 token endpoint (used unless the credentials file overrides `token_uri`).
#[cfg(feature = "ocr-google")]
const GOOGLE_TOKEN_ENDPOINT: &str = "https://oauth2.googleapis.com/token";
/// OAuth scope requested for service-account tokens.
#[cfg(feature = "ocr-google")]
const GOOGLE_OAUTH_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform";
381
/// OCR backend backed by Google Cloud Vision.
///
/// Authenticates with either a plain API key (`?key=` query parameter) or a
/// credentials JSON blob (service account or authorized user) that is
/// exchanged for an OAuth bearer token per request; exactly one of the two
/// auth fields is expected to be set.
#[cfg(feature = "ocr-google")]
pub struct GoogleVisionBackend {
    api_key: Option<String>,
    // Raw credentials JSON, parsed lazily on each token exchange.
    service_account_json: Option<String>,
    client: reqwest::blocking::Client,
    // Kept configurable so tests can point at a local mock server.
    endpoint: String,
}
395
#[cfg(feature = "ocr-google")]
impl GoogleVisionBackend {
    /// Build a backend that authenticates with a plain Vision API key.
    pub fn from_api_key(key: &str) -> Self {
        Self::with_auth(
            Some(key.to_string()),
            None,
            GOOGLE_VISION_ENDPOINT.to_string(),
        )
    }

    /// Build a backend from a service-account JSON file.
    ///
    /// # Errors
    /// Fails when the file cannot be read, does not parse, or its `type`
    /// field is missing or not `"service_account"`.
    pub fn from_service_account(json_path: &str) -> Result<Self, OcrError> {
        let json = std::fs::read_to_string(json_path).map_err(|e| {
            OcrError::RecognitionFailed(format!(
                "read Google service-account JSON from {json_path}: {e}"
            ))
        })?;
        let credentials = parse_google_credentials(&json)?;
        match credentials.kind.as_deref() {
            Some("service_account") => Ok(Self::with_auth(
                None,
                Some(json),
                GOOGLE_VISION_ENDPOINT.to_string(),
            )),
            Some(other) => Err(OcrError::RecognitionFailed(format!(
                "expected Google service-account JSON, found credentials type {other}"
            ))),
            None => Err(OcrError::RecognitionFailed(
                "Google service-account JSON is missing the `type` field".into(),
            )),
        }
    }

    /// Build a backend from the environment, trying in order:
    /// 1. `GOOGLE_VISION_API_KEY` (non-blank),
    /// 2. `GOOGLE_APPLICATION_CREDENTIALS` (path to a credentials file),
    /// 3. the gcloud application-default-credentials file, but only when
    ///    `GOOGLE_CLOUD_PROJECT` is set.
    ///
    /// # Errors
    /// Returns [`OcrError::NoEngine`] when none of the sources are available.
    pub fn from_env() -> Result<Self, OcrError> {
        if let Ok(key) = std::env::var("GOOGLE_VISION_API_KEY") {
            if !key.trim().is_empty() {
                return Ok(Self::from_api_key(&key));
            }
        }

        if let Ok(path) = std::env::var("GOOGLE_APPLICATION_CREDENTIALS") {
            return Self::from_credentials_file(&path);
        }

        if std::env::var_os("GOOGLE_CLOUD_PROJECT").is_some() {
            if let Some(path) = google_application_default_credentials_path() {
                if path.is_file() {
                    return Self::from_credentials_path(path);
                }
            }
        }

        Err(OcrError::NoEngine)
    }

    // Common constructor; callers guarantee exactly one auth source is Some.
    fn with_auth(
        api_key: Option<String>,
        service_account_json: Option<String>,
        endpoint: String,
    ) -> Self {
        Self {
            api_key,
            service_account_json,
            client: blocking_http_client(),
            endpoint,
        }
    }

    fn from_credentials_file(path: &str) -> Result<Self, OcrError> {
        Self::from_credentials_path(std::path::PathBuf::from(path))
    }

    // Accepts service-account or authorized-user files; a file with no `type`
    // is still accepted if it carries a refresh token (legacy ADC layout).
    fn from_credentials_path(path: impl AsRef<std::path::Path>) -> Result<Self, OcrError> {
        let path = path.as_ref();
        let json = std::fs::read_to_string(path).map_err(|e| {
            OcrError::RecognitionFailed(format!(
                "read Google application credentials from {}: {e}",
                path.display()
            ))
        })?;
        let credentials = parse_google_credentials(&json)?;
        match credentials.kind.as_deref() {
            Some("service_account") | Some("authorized_user") => Ok(Self::with_auth(
                None,
                Some(json),
                GOOGLE_VISION_ENDPOINT.to_string(),
            )),
            Some(other) => Err(OcrError::RecognitionFailed(format!(
                "unsupported Google credentials type {other}"
            ))),
            None if credentials.refresh_token.is_some() => Ok(Self::with_auth(
                None,
                Some(json),
                GOOGLE_VISION_ENDPOINT.to_string(),
            )),
            None => Err(OcrError::RecognitionFailed(
                "Google credentials JSON is missing both `type` and refresh-token fields".into(),
            )),
        }
    }

    // Test hook: API-key auth against an arbitrary (mock) endpoint.
    #[cfg(test)]
    fn with_endpoint_api_key(key: &str, endpoint: &str) -> Self {
        Self::with_auth(Some(key.to_string()), None, endpoint.to_string())
    }

    // Test hook: credentials-JSON auth against an arbitrary (mock) endpoint.
    #[cfg(test)]
    fn with_endpoint_credentials(credentials_json: &str, endpoint: &str) -> Self {
        Self::with_auth(
            None,
            Some(credentials_json.to_string()),
            endpoint.to_string(),
        )
    }
}
522
#[cfg(feature = "ocr-google")]
impl OcrBackend for GoogleVisionBackend {
    /// Run `TEXT_DETECTION` over the frame via `images:annotate`.
    ///
    /// Full text comes from `fullTextAnnotation.text`, falling back to the
    /// first `textAnnotations` entry. The remaining `textAnnotations` entries
    /// become per-word results; the API gives no per-word confidence for
    /// `TEXT_DETECTION`, so each word is assigned 1.0.
    fn recognize(&self, image_data: &[u8], width: u32, height: u32) -> Result<OcrResult, OcrError> {
        let encoded = rgb_base64_png(image_data, width, height)?;
        let request = GoogleVisionAnnotateEnvelopeRequest {
            requests: vec![GoogleVisionAnnotateRequest {
                image: GoogleVisionAnnotateImage { content: &encoded },
                features: vec![GoogleVisionAnnotateFeature {
                    kind: "TEXT_DETECTION",
                }],
            }],
        };
        let body_json = serde_json::to_vec(&request).map_err(|e| {
            OcrError::RecognitionFailed(format!("serialize Google Vision request: {e}"))
        })?;

        let mut http = self
            .client
            .post(&self.endpoint)
            .header(reqwest::header::CONTENT_TYPE, "application/json");

        // API key goes in the query string; credentials JSON is exchanged for
        // a bearer token on every call.
        if let Some(api_key) = &self.api_key {
            http = http.query(&[("key", api_key.as_str())]);
        } else if let Some(credentials_json) = &self.service_account_json {
            let access_token = google_access_token(&self.client, credentials_json)?;
            http = http.header(
                reqwest::header::AUTHORIZATION,
                format!("Bearer {access_token}"),
            );
        } else {
            return Err(OcrError::NoEngine);
        }

        let response = http.body(body_json).send().map_err(|e| {
            OcrError::RecognitionFailed(format!("Google Vision OCR request failed: {e}"))
        })?;

        let status = response.status();
        let response_text = response.text().map_err(|e| {
            OcrError::RecognitionFailed(format!("read Google Vision OCR response: {e}"))
        })?;

        if !status.is_success() {
            return Err(OcrError::RecognitionFailed(format!(
                "Google Vision returned {status}: {response_text}"
            )));
        }

        let parsed: GoogleVisionAnnotateEnvelopeResponse = serde_json::from_str(&response_text)
            .map_err(|e| {
                OcrError::RecognitionFailed(format!("parse Google Vision OCR response: {e}"))
            })?;
        // One request was sent, so only the first response entry matters.
        let response = parsed.responses.into_iter().next().ok_or_else(|| {
            OcrError::RecognitionFailed("Google Vision response did not contain any entries".into())
        })?;

        // A 200 can still carry a per-image error object.
        if let Some(error) = response.error {
            let message = error
                .message
                .unwrap_or_else(|| "unknown Google Vision error".into());
            return Err(OcrError::RecognitionFailed(format!(
                "Google Vision OCR error: {message}"
            )));
        }

        let text = response
            .full_text_annotation
            .as_ref()
            .and_then(|annotation| annotation.text.clone())
            .or_else(|| {
                response
                    .text_annotations
                    .first()
                    .map(|annotation| annotation.description.clone())
            })
            .unwrap_or_default();

        // Index 0 is the whole-image annotation; skip it for the word list.
        let words = response
            .text_annotations
            .into_iter()
            .enumerate()
            .filter_map(|(index, annotation)| {
                if index == 0 {
                    return None;
                }

                let text = annotation.description.trim().to_string();
                if text.is_empty() {
                    return None;
                }

                let bbox = annotation
                    .bounding_poly
                    .map(|poly| {
                        let points = poly
                            .vertices
                            .into_iter()
                            .map(|vertex| (vertex.x.unwrap_or(0.0), vertex.y.unwrap_or(0.0)))
                            .collect::<Vec<_>>();
                        bbox_from_points(&points)
                    })
                    .unwrap_or([0.0, 0.0, 0.0, 0.0]);

                Some(OcrWord {
                    text,
                    bbox,
                    confidence: 1.0,
                })
            })
            .collect::<Vec<_>>();

        let confidence = if text.trim().is_empty() {
            0.0
        } else {
            confidence_from_words(&words, 1.0)
        };

        Ok(OcrResult {
            text,
            words,
            confidence,
        })
    }

    fn name(&self) -> &str {
        "google-vision"
    }
}
651
/// Top-level `images:annotate` request: a batch of per-image requests.
#[cfg(feature = "ocr-google")]
#[derive(Serialize)]
struct GoogleVisionAnnotateEnvelopeRequest<'a> {
    requests: Vec<GoogleVisionAnnotateRequest<'a>>,
}

/// One image plus the feature(s) to run on it.
#[cfg(feature = "ocr-google")]
#[derive(Serialize)]
struct GoogleVisionAnnotateRequest<'a> {
    image: GoogleVisionAnnotateImage<'a>,
    features: Vec<GoogleVisionAnnotateFeature<'a>>,
}

/// Inline image content, base64-encoded.
#[cfg(feature = "ocr-google")]
#[derive(Serialize)]
struct GoogleVisionAnnotateImage<'a> {
    content: &'a str,
}

/// Feature selector, e.g. `TEXT_DETECTION`.
#[cfg(feature = "ocr-google")]
#[derive(Serialize)]
struct GoogleVisionAnnotateFeature<'a> {
    // `type` is a Rust keyword, hence the rename.
    #[serde(rename = "type")]
    kind: &'a str,
}

/// Top-level response: one entry per request in the batch.
#[cfg(feature = "ocr-google")]
#[derive(Deserialize)]
struct GoogleVisionAnnotateEnvelopeResponse {
    #[serde(default)]
    responses: Vec<GoogleVisionAnnotateResponse>,
}

/// Per-image annotation result (or a per-image error).
#[cfg(feature = "ocr-google")]
#[derive(Deserialize)]
struct GoogleVisionAnnotateResponse {
    #[serde(rename = "fullTextAnnotation")]
    full_text_annotation: Option<GoogleVisionFullTextAnnotation>,
    #[serde(rename = "textAnnotations", default)]
    text_annotations: Vec<GoogleVisionTextAnnotation>,
    error: Option<GoogleVisionError>,
}

/// Document-level text block.
#[cfg(feature = "ocr-google")]
#[derive(Deserialize)]
struct GoogleVisionFullTextAnnotation {
    text: Option<String>,
}

/// One detected text span with its bounding polygon.
#[cfg(feature = "ocr-google")]
#[derive(Deserialize)]
struct GoogleVisionTextAnnotation {
    description: String,
    #[serde(rename = "boundingPoly")]
    bounding_poly: Option<GoogleVisionBoundingPoly>,
}

/// Polygon around a detection; usually four vertices.
#[cfg(feature = "ocr-google")]
#[derive(Deserialize)]
struct GoogleVisionBoundingPoly {
    #[serde(default)]
    vertices: Vec<GoogleVisionVertex>,
}

/// A polygon vertex; Google omits zero-valued coordinates.
#[cfg(feature = "ocr-google")]
#[derive(Deserialize)]
struct GoogleVisionVertex {
    x: Option<f32>,
    y: Option<f32>,
}

/// Per-image error payload from the Vision API.
#[cfg(feature = "ocr-google")]
#[derive(Deserialize)]
struct GoogleVisionError {
    message: Option<String>,
}

/// Superset of the fields found in service-account and authorized-user
/// credentials files; all optional so either shape parses.
#[cfg(feature = "ocr-google")]
#[derive(Deserialize)]
struct GoogleCredentialsFile {
    // `type` is a Rust keyword, hence the rename.
    #[serde(rename = "type")]
    kind: Option<String>,
    client_email: Option<String>,
    private_key: Option<String>,
    token_uri: Option<String>,
    client_id: Option<String>,
    client_secret: Option<String>,
    refresh_token: Option<String>,
}

/// JWT claims for the service-account OAuth assertion.
#[cfg(feature = "ocr-google")]
#[derive(Serialize)]
struct GoogleServiceAccountClaims<'a> {
    iss: &'a str,
    scope: &'a str,
    aud: &'a str,
    exp: i64,
    iat: i64,
}

/// OAuth token-endpoint response (success or error shape).
#[cfg(feature = "ocr-google")]
#[derive(Deserialize)]
struct GoogleOAuthTokenResponse {
    access_token: Option<String>,
    error: Option<String>,
    error_description: Option<String>,
}
759
/// Deserialize a Google credentials JSON document, wrapping parse failures.
#[cfg(feature = "ocr-google")]
fn parse_google_credentials(json: &str) -> Result<GoogleCredentialsFile, OcrError> {
    match serde_json::from_str(json) {
        Ok(credentials) => Ok(credentials),
        Err(e) => Err(OcrError::RecognitionFailed(format!(
            "parse Google credentials JSON: {e}"
        ))),
    }
}
765
/// Exchange a credentials JSON blob for an OAuth access token.
///
/// Dispatches on the `type` field: service accounts use the signed-JWT flow,
/// authorized users use the refresh-token flow. Legacy files with no `type`
/// but a refresh token are treated as authorized-user credentials.
#[cfg(feature = "ocr-google")]
fn google_access_token(
    client: &reqwest::blocking::Client,
    credentials_json: &str,
) -> Result<String, OcrError> {
    let credentials = parse_google_credentials(credentials_json)?;
    match credentials.kind.as_deref() {
        Some("service_account") => google_service_account_access_token(client, &credentials),
        Some("authorized_user") => google_authorized_user_access_token(client, &credentials),
        Some(other) => Err(OcrError::RecognitionFailed(format!(
            "unsupported Google credentials type {other}"
        ))),
        None if credentials.refresh_token.is_some() => {
            google_authorized_user_access_token(client, &credentials)
        }
        None => Err(OcrError::RecognitionFailed(
            "Google credentials are missing the `type` field".into(),
        )),
    }
}
786
/// OAuth2 JWT-bearer flow: build an RS256-signed JWT from the service-account
/// key and exchange it at the token endpoint for an access token.
#[cfg(feature = "ocr-google")]
fn google_service_account_access_token(
    client: &reqwest::blocking::Client,
    credentials: &GoogleCredentialsFile,
) -> Result<String, OcrError> {
    let client_email = credentials.client_email.as_deref().ok_or_else(|| {
        OcrError::RecognitionFailed("Google service-account JSON is missing `client_email`".into())
    })?;
    let private_key_pem = credentials.private_key.as_deref().ok_or_else(|| {
        OcrError::RecognitionFailed("Google service-account JSON is missing `private_key`".into())
    })?;
    let token_uri = credentials
        .token_uri
        .as_deref()
        .unwrap_or(GOOGLE_TOKEN_ENDPOINT);

    // Token valid for one hour starting now; `aud` must match the token URI.
    let now = OffsetDateTime::now_utc().unix_timestamp();
    let claims = GoogleServiceAccountClaims {
        iss: client_email,
        scope: GOOGLE_OAUTH_SCOPE,
        aud: token_uri,
        exp: now + 3600,
        iat: now,
    };

    let header_json = serde_json::to_vec(&serde_json::json!({
        "alg": "RS256",
        "typ": "JWT",
    }))
    .map_err(|e| OcrError::RecognitionFailed(format!("serialize Google JWT header: {e}")))?;
    let claims_json = serde_json::to_vec(&claims)
        .map_err(|e| OcrError::RecognitionFailed(format!("serialize Google JWT claims: {e}")))?;

    // JWT: base64url(header).base64url(claims), then RS256-sign that string
    // and append base64url(signature).
    let signing_input = format!(
        "{}.{}",
        google_base64_url(&header_json),
        google_base64_url(&claims_json)
    );
    let private_key = RsaPrivateKey::from_pkcs8_pem(private_key_pem).map_err(|e| {
        OcrError::RecognitionFailed(format!("parse Google service-account private key: {e}"))
    })?;
    let signing_key = SigningKey::<Sha256>::new(private_key);
    let signature = signing_key.sign(signing_input.as_bytes());
    let assertion = format!("{signing_input}.{}", google_base64_url(&signature.to_vec()));

    let response = client
        .post(token_uri)
        .form(&[
            ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
            ("assertion", assertion.as_str()),
        ])
        .send()
        .map_err(|e| {
            OcrError::RecognitionFailed(format!(
                "request Google OAuth access token with service account: {e}"
            ))
        })?;

    google_token_response_text(response)
}
847
/// OAuth2 refresh-token flow for authorized-user (application-default)
/// credentials: POST client id/secret + refresh token, get an access token.
#[cfg(feature = "ocr-google")]
fn google_authorized_user_access_token(
    client: &reqwest::blocking::Client,
    credentials: &GoogleCredentialsFile,
) -> Result<String, OcrError> {
    let client_id = credentials.client_id.as_deref().ok_or_else(|| {
        OcrError::RecognitionFailed(
            "Google application-default credentials are missing `client_id`".into(),
        )
    })?;
    let client_secret = credentials.client_secret.as_deref().ok_or_else(|| {
        OcrError::RecognitionFailed(
            "Google application-default credentials are missing `client_secret`".into(),
        )
    })?;
    let refresh_token = credentials.refresh_token.as_deref().ok_or_else(|| {
        OcrError::RecognitionFailed(
            "Google application-default credentials are missing `refresh_token`".into(),
        )
    })?;
    let token_uri = credentials
        .token_uri
        .as_deref()
        .unwrap_or(GOOGLE_TOKEN_ENDPOINT);

    let response = client
        .post(token_uri)
        .form(&[
            ("client_id", client_id),
            ("client_secret", client_secret),
            ("refresh_token", refresh_token),
            ("grant_type", "refresh_token"),
        ])
        .send()
        .map_err(|e| {
            OcrError::RecognitionFailed(format!(
                "request Google OAuth access token with refresh token: {e}"
            ))
        })?;

    google_token_response_text(response)
}
890
/// Extract the access token from an OAuth token-endpoint response,
/// turning HTTP failures and error payloads into `RecognitionFailed`.
#[cfg(feature = "ocr-google")]
fn google_token_response_text(response: reqwest::blocking::Response) -> Result<String, OcrError> {
    // Read the body before checking status so error details reach the message.
    let status = response.status();
    let body = response.text().map_err(|e| {
        OcrError::RecognitionFailed(format!("read Google OAuth token response: {e}"))
    })?;

    if !status.is_success() {
        return Err(OcrError::RecognitionFailed(format!(
            "Google OAuth token exchange returned {status}: {body}"
        )));
    }

    let token: GoogleOAuthTokenResponse = serde_json::from_str(&body).map_err(|e| {
        OcrError::RecognitionFailed(format!("parse Google OAuth token response: {e}"))
    })?;

    if let Some(access_token) = token.access_token {
        return Ok(access_token);
    }

    // A 200 without `access_token` carries an OAuth error object instead.
    let error = token.error.unwrap_or_else(|| "unknown error".into());
    let description = token
        .error_description
        .unwrap_or_else(|| "no error description".into());
    Err(OcrError::RecognitionFailed(format!(
        "Google OAuth token exchange failed: {error}: {description}"
    )))
}
920
/// Base64url (no padding) encoding, as used for JWT segments.
#[cfg(feature = "ocr-google")]
fn google_base64_url(data: &[u8]) -> String {
    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data)
}
925
/// Well-known location of the gcloud application-default-credentials file:
/// `%APPDATA%\gcloud\...` on Windows, `~/.config/gcloud/...` elsewhere.
#[cfg(feature = "ocr-google")]
fn google_application_default_credentials_path() -> Option<std::path::PathBuf> {
    #[cfg(target_os = "windows")]
    {
        std::env::var_os("APPDATA")
            .map(std::path::PathBuf::from)
            .map(|dir| {
                dir.join("gcloud")
                    .join("application_default_credentials.json")
            })
    }

    #[cfg(not(target_os = "windows"))]
    {
        home_dir().map(|dir| {
            dir.join(".config")
                .join("gcloud")
                .join("application_default_credentials.json")
        })
    }
}
947
/// Content type required by the Textract JSON-RPC style API.
#[cfg(feature = "ocr-aws")]
const AWS_TEXTRACT_CONTENT_TYPE: &str = "application/x-amz-json-1.1";
/// `X-Amz-Target` value selecting the DetectDocumentText operation.
#[cfg(feature = "ocr-aws")]
const AWS_TEXTRACT_TARGET: &str = "Textract.DetectDocumentText";
/// Service name used in the SigV4 credential scope.
#[cfg(feature = "ocr-aws")]
const AWS_TEXTRACT_SERVICE: &str = "textract";
956
/// OCR backend backed by AWS Textract with hand-rolled SigV4 request signing.
#[cfg(feature = "ocr-aws")]
pub struct AwsTextractBackend {
    region: String,
    // Options so `recognize` can fail with `NoEngine` instead of panicking;
    // the public constructors always populate both keys.
    access_key: Option<String>,
    secret_key: Option<String>,
    // Present only for temporary (STS) credentials.
    session_token: Option<String>,
    client: reqwest::blocking::Client,
    // Kept configurable so tests can point at a local mock server.
    endpoint: String,
}
967
#[cfg(feature = "ocr-aws")]
impl AwsTextractBackend {
    /// Build a backend with explicit long-lived credentials for `region`.
    pub fn new(region: &str, access_key: &str, secret_key: &str) -> Self {
        Self::with_credentials(
            region,
            access_key,
            secret_key,
            None,
            format!("https://textract.{region}.amazonaws.com/"),
        )
    }

    /// Build a backend from environment variables, falling back to the shared
    /// AWS config profile named by `AWS_PROFILE` (default `"default"`).
    ///
    /// Precedence per value: `AWS_REGION`/`AWS_DEFAULT_REGION`,
    /// `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, and `AWS_SESSION_TOKEN`
    /// each override the corresponding profile value.
    ///
    /// # Errors
    /// Returns [`OcrError::NoEngine`] when the region or either key is missing.
    pub fn from_env() -> Result<Self, OcrError> {
        let profile_name = std::env::var("AWS_PROFILE").unwrap_or_else(|_| "default".into());
        let shared = load_aws_profile(&profile_name).unwrap_or_default();

        let region = std::env::var("AWS_REGION")
            .ok()
            .or_else(|| std::env::var("AWS_DEFAULT_REGION").ok())
            .or(shared.region);
        let access_key = std::env::var("AWS_ACCESS_KEY_ID")
            .ok()
            .or(shared.access_key);
        let secret_key = std::env::var("AWS_SECRET_ACCESS_KEY")
            .ok()
            .or(shared.secret_key);
        let session_token = std::env::var("AWS_SESSION_TOKEN")
            .ok()
            .or(shared.session_token);

        match (region, access_key, secret_key) {
            (Some(region), Some(access_key), Some(secret_key)) => Ok(Self::with_credentials(
                // Fixed: this argument was the mojibake token `®ion` (a
                // mangled `&region` — `&reg` decoded as the ® entity), which
                // does not compile.
                &region,
                &access_key,
                &secret_key,
                session_token.as_deref(),
                format!("https://textract.{region}.amazonaws.com/"),
            )),
            _ => Err(OcrError::NoEngine),
        }
    }

    // Common constructor shared by the public builders and test hooks.
    fn with_credentials(
        region: &str,
        access_key: &str,
        secret_key: &str,
        session_token: Option<&str>,
        endpoint: String,
    ) -> Self {
        Self {
            region: region.to_string(),
            access_key: Some(access_key.to_string()),
            secret_key: Some(secret_key.to_string()),
            session_token: session_token.map(str::to_string),
            client: blocking_http_client(),
            endpoint,
        }
    }

    // Test hook: explicit credentials against an arbitrary (mock) endpoint.
    #[cfg(test)]
    fn with_endpoint(region: &str, access_key: &str, secret_key: &str, endpoint: &str) -> Self {
        Self::with_credentials(region, access_key, secret_key, None, endpoint.to_string())
    }
}
1040
1041#[cfg(feature = "ocr-aws")]
1042impl OcrBackend for AwsTextractBackend {
1043 fn recognize(&self, image_data: &[u8], width: u32, height: u32) -> Result<OcrResult, OcrError> {
1044 let access_key = self.access_key.as_deref().ok_or(OcrError::NoEngine)?;
1045 let secret_key = self.secret_key.as_deref().ok_or(OcrError::NoEngine)?;
1046 let encoded = rgb_base64_png(image_data, width, height)?;
1047 let body = AwsTextractRequest {
1048 document: AwsTextractDocument { bytes: &encoded },
1049 };
1050 let body_json = serde_json::to_vec(&body).map_err(|e| {
1051 OcrError::RecognitionFailed(format!("serialize AWS Textract request: {e}"))
1052 })?;
1053
1054 let url = reqwest::Url::parse(&self.endpoint).map_err(|e| {
1055 OcrError::RecognitionFailed(format!("parse AWS Textract endpoint: {e}"))
1056 })?;
1057 let host = host_header_value(&url)?;
1058 let canonical_uri = if url.path().is_empty() {
1059 "/"
1060 } else {
1061 url.path()
1062 };
1063 let canonical_query = url.query().unwrap_or("");
1064 let now = OffsetDateTime::now_utc();
1065 let amz_date = now
1066 .format(&format_description!(
1067 "[year][month][day]T[hour][minute][second]Z"
1068 ))
1069 .map_err(|e| OcrError::RecognitionFailed(format!("format AWS SigV4 timestamp: {e}")))?;
1070 let date_stamp = now
1071 .format(&format_description!("[year][month][day]"))
1072 .map_err(|e| {
1073 OcrError::RecognitionFailed(format!("format AWS SigV4 date stamp: {e}"))
1074 })?;
1075 let payload_hash = sha256_hex(&body_json);
1076
1077 let mut headers = vec![
1078 (
1079 "content-type".to_string(),
1080 AWS_TEXTRACT_CONTENT_TYPE.to_string(),
1081 ),
1082 ("host".to_string(), host.clone()),
1083 ("x-amz-date".to_string(), amz_date.clone()),
1084 ("x-amz-target".to_string(), AWS_TEXTRACT_TARGET.to_string()),
1085 ];
1086 if let Some(token) = &self.session_token {
1087 headers.push(("x-amz-security-token".to_string(), token.clone()));
1088 }
1089 headers.sort_by(|left, right| left.0.cmp(&right.0));
1090
1091 let canonical_headers = headers
1092 .iter()
1093 .map(|(name, value)| format!("{name}:{}\n", value.trim()))
1094 .collect::<String>();
1095 let signed_headers = headers
1096 .iter()
1097 .map(|(name, _)| name.as_str())
1098 .collect::<Vec<_>>()
1099 .join(";");
1100 let canonical_request = format!(
1101 "POST\n{canonical_uri}\n{canonical_query}\n{canonical_headers}\n{signed_headers}\n{payload_hash}"
1102 );
1103 let credential_scope = format!(
1104 "{date_stamp}/{}/{AWS_TEXTRACT_SERVICE}/aws4_request",
1105 self.region
1106 );
1107 let string_to_sign = format!(
1108 "AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{}",
1109 sha256_hex(canonical_request.as_bytes())
1110 );
1111 let signature = aws_sigv4_signature(
1112 secret_key,
1113 &date_stamp,
1114 &self.region,
1115 AWS_TEXTRACT_SERVICE,
1116 &string_to_sign,
1117 )?;
1118 let authorization = format!(
1119 "AWS4-HMAC-SHA256 Credential={access_key}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}"
1120 );
1121
1122 let mut request = self
1123 .client
1124 .post(url)
1125 .header(reqwest::header::CONTENT_TYPE, AWS_TEXTRACT_CONTENT_TYPE)
1126 .header("X-Amz-Target", AWS_TEXTRACT_TARGET)
1127 .header("X-Amz-Date", &amz_date)
1128 .header(reqwest::header::HOST, host)
1129 .header(reqwest::header::AUTHORIZATION, authorization);
1130
1131 if let Some(token) = &self.session_token {
1132 request = request.header("X-Amz-Security-Token", token);
1133 }
1134
1135 let response = request.body(body_json).send().map_err(|e| {
1136 OcrError::RecognitionFailed(format!("AWS Textract OCR request failed: {e}"))
1137 })?;
1138
1139 let status = response.status();
1140 let body = response.text().map_err(|e| {
1141 OcrError::RecognitionFailed(format!("read AWS Textract OCR response: {e}"))
1142 })?;
1143
1144 if !status.is_success() {
1145 return Err(OcrError::RecognitionFailed(format!(
1146 "AWS Textract returned {status}: {body}"
1147 )));
1148 }
1149
1150 let parsed: AwsTextractResponse = serde_json::from_str(&body).map_err(|e| {
1151 OcrError::RecognitionFailed(format!("parse AWS Textract OCR response: {e}"))
1152 })?;
1153
1154 let mut words = Vec::new();
1155 let mut lines = Vec::new();
1156 for block in parsed
1157 .blocks
1158 .into_iter()
1159 .filter(|block| block.block_type.as_deref() == Some("LINE"))
1160 {
1161 let text = block.text.unwrap_or_default().trim().to_string();
1162 if text.is_empty() {
1163 continue;
1164 }
1165
1166 lines.push(text.clone());
1167 let bbox = block
1168 .geometry
1169 .and_then(|geometry| geometry.bounding_box)
1170 .map(|bbox| {
1171 [
1172 bbox.left.unwrap_or(0.0) * width as f32,
1173 bbox.top.unwrap_or(0.0) * height as f32,
1174 bbox.width.unwrap_or(0.0) * width as f32,
1175 bbox.height.unwrap_or(0.0) * height as f32,
1176 ]
1177 })
1178 .unwrap_or([0.0, 0.0, 0.0, 0.0]);
1179 let confidence = block.confidence.unwrap_or(100.0) / 100.0;
1180 words.push(OcrWord {
1181 text,
1182 bbox,
1183 confidence,
1184 });
1185 }
1186
1187 let text = lines.join("\n");
1188 let confidence = if text.trim().is_empty() {
1189 0.0
1190 } else {
1191 confidence_from_words(&words, 1.0)
1192 };
1193
1194 Ok(OcrResult {
1195 text,
1196 words,
1197 confidence,
1198 })
1199 }
1200
    /// Stable backend identifier; the test suite asserts this never changes.
    fn name(&self) -> &str {
        "aws-textract"
    }
1204}
1205
#[cfg(feature = "ocr-aws")]
/// JSON request body sent to Textract; serde renames match the AWS
/// PascalCase field names.
#[derive(Serialize)]
struct AwsTextractRequest<'a> {
    #[serde(rename = "Document")]
    document: AwsTextractDocument<'a>,
}
1212
#[cfg(feature = "ocr-aws")]
/// Inline document payload; `bytes` carries the base64-encoded image.
#[derive(Serialize)]
struct AwsTextractDocument<'a> {
    #[serde(rename = "Bytes")]
    bytes: &'a str,
}
1219
#[cfg(feature = "ocr-aws")]
/// Top-level Textract response; `default` tolerates a missing `Blocks` array.
#[derive(Deserialize)]
struct AwsTextractResponse {
    #[serde(rename = "Blocks", default)]
    blocks: Vec<AwsTextractBlock>,
}
1226
#[cfg(feature = "ocr-aws")]
/// One detected block from the Textract response; `recognize` only consumes
/// blocks whose `block_type` is `"LINE"`.
#[derive(Deserialize)]
struct AwsTextractBlock {
    #[serde(rename = "BlockType")]
    block_type: Option<String>,
    #[serde(rename = "Text")]
    text: Option<String>,
    // Scaled from 0-100 to 0-1 by the caller (see `recognize`).
    #[serde(rename = "Confidence")]
    confidence: Option<f32>,
    #[serde(rename = "Geometry")]
    geometry: Option<AwsTextractGeometry>,
}
1239
#[cfg(feature = "ocr-aws")]
/// Geometry wrapper; only the axis-aligned bounding box is used.
#[derive(Deserialize)]
struct AwsTextractGeometry {
    #[serde(rename = "BoundingBox")]
    bounding_box: Option<AwsTextractBoundingBox>,
}
1246
#[cfg(feature = "ocr-aws")]
/// Relative bounding box; the caller multiplies each field by the image
/// width/height to obtain pixel coordinates.
#[derive(Deserialize)]
struct AwsTextractBoundingBox {
    #[serde(rename = "Left")]
    left: Option<f32>,
    #[serde(rename = "Top")]
    top: Option<f32>,
    #[serde(rename = "Width")]
    width: Option<f32>,
    #[serde(rename = "Height")]
    height: Option<f32>,
}
1259
#[cfg(feature = "ocr-aws")]
/// Credentials resolved from the shared AWS config/credentials files.
/// Every field is optional so a partially-specified profile can still be
/// returned; presumably the caller merges these with env-derived values.
#[derive(Default)]
struct AwsResolvedCredentials {
    region: Option<String>,
    access_key: Option<String>,
    secret_key: Option<String>,
    session_token: Option<String>,
}
1268
1269#[cfg(feature = "ocr-aws")]
1270fn load_aws_profile(profile_name: &str) -> Option<AwsResolvedCredentials> {
1271 let credentials_sections = aws_shared_file("credentials")
1272 .and_then(|path| std::fs::read_to_string(path).ok())
1273 .map(|text| parse_ini_sections(&text))
1274 .unwrap_or_default();
1275 let config_sections = aws_shared_file("config")
1276 .and_then(|path| std::fs::read_to_string(path).ok())
1277 .map(|text| parse_ini_sections(&text))
1278 .unwrap_or_default();
1279
1280 let credentials_section = credentials_sections
1281 .get(profile_name)
1282 .or_else(|| credentials_sections.get(&format!("profile {profile_name}")));
1283 let config_section_name = if profile_name == "default" {
1284 "default".to_string()
1285 } else {
1286 format!("profile {profile_name}")
1287 };
1288 let config_section = config_sections.get(&config_section_name);
1289
1290 Some(AwsResolvedCredentials {
1291 region: config_section.and_then(|section| section.get("region").cloned()),
1292 access_key: credentials_section
1293 .and_then(|section| section.get("aws_access_key_id").cloned()),
1294 secret_key: credentials_section
1295 .and_then(|section| section.get("aws_secret_access_key").cloned()),
1296 session_token: credentials_section
1297 .and_then(|section| section.get("aws_session_token").cloned()),
1298 })
1299}
1300
1301#[cfg(feature = "ocr-aws")]
1302fn aws_shared_file(name: &str) -> Option<std::path::PathBuf> {
1303 #[cfg(target_os = "windows")]
1304 {
1305 home_dir().map(|dir| dir.join(".aws").join(name))
1306 }
1307
1308 #[cfg(not(target_os = "windows"))]
1309 {
1310 home_dir().map(|dir| dir.join(".aws").join(name))
1311 }
1312}
1313
1314#[cfg(feature = "ocr-aws")]
/// Parse an AWS-style INI file into `section -> { key -> value }` maps.
///
/// Section names keep their original case (trimmed); keys are lowercased and
/// values trimmed. Blank lines and lines starting with `#` or `;` are
/// ignored. Key/value pairs appearing before any section header land in a
/// section with an empty name.
fn parse_ini_sections(
    text: &str,
) -> std::collections::HashMap<String, std::collections::HashMap<String, String>> {
    use std::collections::HashMap;

    let mut sections: HashMap<String, HashMap<String, String>> = HashMap::new();
    let mut current_section = String::new();

    for line in text.lines().map(str::trim) {
        if line.is_empty() || line.starts_with('#') || line.starts_with(';') {
            continue;
        }

        match line.strip_prefix('[').and_then(|rest| rest.strip_suffix(']')) {
            Some(header) => {
                // New section header: remember it and make sure an (initially
                // empty) entry exists even if no keys follow.
                current_section = header.trim().to_string();
                sections.entry(current_section.clone()).or_default();
            }
            None => {
                if let Some((key, value)) = line.split_once('=') {
                    sections
                        .entry(current_section.clone())
                        .or_default()
                        .insert(key.trim().to_ascii_lowercase(), value.trim().to_string());
                }
            }
        }
    }

    sections
}
1347
1348#[cfg(feature = "ocr-aws")]
1349fn host_header_value(url: &reqwest::Url) -> Result<String, OcrError> {
1350 let host = url.host_str().ok_or_else(|| {
1351 OcrError::RecognitionFailed("AWS Textract endpoint is missing a host".into())
1352 })?;
1353 Ok(match url.port() {
1354 Some(port) => format!("{host}:{port}"),
1355 None => host.to_string(),
1356 })
1357}
1358
1359#[cfg(feature = "ocr-aws")]
1360fn aws_sigv4_signature(
1361 secret_key: &str,
1362 date_stamp: &str,
1363 region: &str,
1364 service: &str,
1365 string_to_sign: &str,
1366) -> Result<String, OcrError> {
1367 let k_date = hmac_sha256(format!("AWS4{secret_key}").as_bytes(), date_stamp)?;
1368 let k_region = hmac_sha256(&k_date, region)?;
1369 let k_service = hmac_sha256(&k_region, service)?;
1370 let k_signing = hmac_sha256(&k_service, "aws4_request")?;
1371 let signature = hmac_sha256(&k_signing, string_to_sign)?;
1372 Ok(hex_encode(&signature))
1373}
1374
1375#[cfg(feature = "ocr-aws")]
1376fn hmac_sha256(key: &[u8], data: &str) -> Result<Vec<u8>, OcrError> {
1377 type HmacSha256 = Hmac<Sha256>;
1378
1379 let mut mac = HmacSha256::new_from_slice(key)
1380 .map_err(|e| OcrError::RecognitionFailed(format!("build AWS SigV4 HMAC: {e}")))?;
1381 mac.update(data.as_bytes());
1382 Ok(mac.finalize().into_bytes().to_vec())
1383}
1384
#[cfg(feature = "ocr-aws")]
/// SHA-256 digest of `data` as a lowercase hex string (SigV4 payload hash).
fn sha256_hex(data: &[u8]) -> String {
    hex_encode(&Sha256::digest(data))
}
1389
1390#[cfg(feature = "ocr-aws")]
/// Encode `bytes` as a lowercase hexadecimal string.
fn hex_encode(bytes: &[u8]) -> String {
    use std::fmt::Write as _;

    // Fold every byte into the output as two lowercase hex digits.
    bytes
        .iter()
        .fold(String::with_capacity(bytes.len() * 2), |mut output, byte| {
            let _ = write!(&mut output, "{byte:02x}");
            output
        })
}
1399
#[cfg(feature = "ocr-azure")]
/// Default Azure Document Intelligence REST API version; can be overridden
/// at runtime via `AZURE_DOCUMENT_INTELLIGENCE_API_VERSION`.
const AZURE_DOC_INTELLIGENCE_API_VERSION: &str = "2024-11-30";
1404
#[cfg(feature = "ocr-azure")]
/// OCR backend calling the Azure Document Intelligence `prebuilt-read`
/// model over its blocking REST API.
pub struct AzureDocIntelBackend {
    // Endpoint base URL with any trailing '/' stripped (see `new`).
    endpoint: String,
    // Subscription key sent as the `Ocp-Apim-Subscription-Key` header.
    api_key: String,
    client: reqwest::blocking::Client,
}
1412
1413#[cfg(feature = "ocr-azure")]
1414impl AzureDocIntelBackend {
1415 pub fn new(endpoint: &str, api_key: &str) -> Self {
1417 Self {
1418 endpoint: endpoint.trim_end_matches('/').to_string(),
1419 api_key: api_key.to_string(),
1420 client: blocking_http_client(),
1421 }
1422 }
1423
1424 pub fn from_env() -> Result<Self, OcrError> {
1430 let endpoint = std::env::var("AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT")
1431 .map_err(|_| OcrError::NoEngine)?;
1432 let api_key =
1433 std::env::var("AZURE_DOCUMENT_INTELLIGENCE_KEY").map_err(|_| OcrError::NoEngine)?;
1434 Ok(Self::new(&endpoint, &api_key))
1435 }
1436
1437 fn analyze_endpoint(&self) -> String {
1438 let api_version = std::env::var("AZURE_DOCUMENT_INTELLIGENCE_API_VERSION")
1439 .unwrap_or_else(|_| AZURE_DOC_INTELLIGENCE_API_VERSION.to_string());
1440 format!(
1441 "{}/documentintelligence/documentModels/prebuilt-read:analyze?api-version={api_version}",
1442 self.endpoint
1443 )
1444 }
1445
1446 fn poll_interval() -> Duration {
1447 std::env::var("AZURE_DOCUMENT_INTELLIGENCE_POLL_INTERVAL_MS")
1448 .ok()
1449 .and_then(|value| value.parse::<u64>().ok())
1450 .map(Duration::from_millis)
1451 .unwrap_or_else(|| Duration::from_millis(250))
1452 }
1453
1454 fn max_polls() -> usize {
1455 std::env::var("AZURE_DOCUMENT_INTELLIGENCE_MAX_POLLS")
1456 .ok()
1457 .and_then(|value| value.parse::<usize>().ok())
1458 .unwrap_or(120)
1459 }
1460
1461 fn poll_operation(&self, operation_location: &str) -> Result<OcrResult, OcrError> {
1462 for attempt in 0..Self::max_polls() {
1463 let response = self
1464 .client
1465 .get(operation_location)
1466 .header("Ocp-Apim-Subscription-Key", &self.api_key)
1467 .send()
1468 .map_err(|e| {
1469 OcrError::RecognitionFailed(format!(
1470 "Azure Document Intelligence poll request failed: {e}"
1471 ))
1472 })?;
1473
1474 let status = response.status();
1475 let body = response.text().map_err(|e| {
1476 OcrError::RecognitionFailed(format!(
1477 "read Azure Document Intelligence poll response: {e}"
1478 ))
1479 })?;
1480
1481 if !status.is_success() {
1482 return Err(OcrError::RecognitionFailed(format!(
1483 "Azure Document Intelligence poll returned {status}: {body}"
1484 )));
1485 }
1486
1487 let operation: AzureAnalyzeOperation = serde_json::from_str(&body).map_err(|e| {
1488 OcrError::RecognitionFailed(format!(
1489 "parse Azure Document Intelligence poll response: {e}"
1490 ))
1491 })?;
1492
1493 if operation.status.eq_ignore_ascii_case("succeeded") {
1494 return azure_operation_to_result(operation);
1495 }
1496
1497 if operation.status.eq_ignore_ascii_case("failed")
1498 || operation.status.eq_ignore_ascii_case("cancelled")
1499 {
1500 let message = operation
1501 .error
1502 .and_then(|error| error.message)
1503 .unwrap_or_else(|| format!("Azure operation status {}", operation.status));
1504 return Err(OcrError::RecognitionFailed(format!(
1505 "Azure Document Intelligence OCR failed: {message}"
1506 )));
1507 }
1508
1509 if attempt + 1 < Self::max_polls() {
1510 std::thread::sleep(Self::poll_interval());
1511 }
1512 }
1513
1514 Err(OcrError::RecognitionFailed(
1515 "timed out polling Azure Document Intelligence analyze operation".into(),
1516 ))
1517 }
1518}
1519
#[cfg(feature = "ocr-azure")]
impl OcrBackend for AzureDocIntelBackend {
    /// Submit raw RGB pixels (re-encoded as PNG) to the analyze endpoint and
    /// return the parsed OCR result.
    ///
    /// The service normally replies with an `Operation-Location` header that
    /// must be polled for the final result; a response without that header is
    /// treated as carrying the operation document inline.
    fn recognize(&self, image_data: &[u8], width: u32, height: u32) -> Result<OcrResult, OcrError> {
        let png_bytes = rgb_png_bytes(image_data, width, height)?;
        let response = self
            .client
            .post(self.analyze_endpoint())
            .header("Ocp-Apim-Subscription-Key", &self.api_key)
            .header(reqwest::header::CONTENT_TYPE, "application/octet-stream")
            .body(png_bytes)
            .send()
            .map_err(|e| {
                OcrError::RecognitionFailed(format!(
                    "Azure Document Intelligence analyze request failed: {e}"
                ))
            })?;

        let status = response.status();
        // Capture the polling URL first: `text()` below consumes the response,
        // so the header must be read before the body.
        let operation_location = response
            .headers()
            .get("operation-location")
            .and_then(|value| value.to_str().ok())
            .map(str::to_string);
        let body = response.text().map_err(|e| {
            OcrError::RecognitionFailed(format!(
                "read Azure Document Intelligence analyze response: {e}"
            ))
        })?;

        if !status.is_success() {
            return Err(OcrError::RecognitionFailed(format!(
                "Azure Document Intelligence returned {status}: {body}"
            )));
        }

        // Asynchronous path: poll the operation until it completes.
        if let Some(operation_location) = operation_location {
            return self.poll_operation(&operation_location);
        }

        // Fallback: no Operation-Location header, so parse the body itself as
        // the operation document.
        let operation: AzureAnalyzeOperation = serde_json::from_str(&body).map_err(|e| {
            OcrError::RecognitionFailed(format!(
                "parse Azure Document Intelligence analyze response: {e}"
            ))
        })?;
        azure_operation_to_result(operation)
    }

    /// Stable backend identifier.
    fn name(&self) -> &str {
        "azure-doc-intel"
    }
}
1571
#[cfg(feature = "ocr-azure")]
/// Analyze-operation document returned both by the analyze call (inline
/// fallback) and by status polling.
#[derive(Deserialize)]
struct AzureAnalyzeOperation {
    // Compared case-insensitively against "succeeded"/"failed"/"cancelled".
    #[serde(default)]
    status: String,
    #[serde(rename = "analyzeResult")]
    analyze_result: Option<AzureAnalyzeResult>,
    error: Option<AzureAnalyzeError>,
}
1581
#[cfg(feature = "ocr-azure")]
/// Error payload attached to a failed/cancelled operation.
#[derive(Deserialize)]
struct AzureAnalyzeError {
    message: Option<String>,
}
1587
#[cfg(feature = "ocr-azure")]
/// Successful analyze result: full document text plus per-page words.
#[derive(Deserialize)]
struct AzureAnalyzeResult {
    content: Option<String>,
    #[serde(default)]
    pages: Vec<AzureAnalyzePage>,
}
1595
#[cfg(feature = "ocr-azure")]
/// One page of the analyze result; only the word list is consumed.
#[derive(Deserialize)]
struct AzureAnalyzePage {
    #[serde(default)]
    words: Vec<AzureAnalyzeWord>,
}
1602
#[cfg(feature = "ocr-azure")]
/// A recognized word with its confidence and outline polygon.
#[derive(Deserialize)]
struct AzureAnalyzeWord {
    content: String,
    confidence: Option<f32>,
    // Flat x,y coordinate list; the alias covers API versions that name the
    // field "boundingPolygon" instead of "polygon".
    #[serde(default, alias = "boundingPolygon")]
    polygon: Vec<f32>,
}
1611
1612#[cfg(feature = "ocr-azure")]
1613fn azure_operation_to_result(operation: AzureAnalyzeOperation) -> Result<OcrResult, OcrError> {
1614 let analyze_result = operation.analyze_result.ok_or_else(|| {
1615 OcrError::RecognitionFailed(format!(
1616 "Azure Document Intelligence operation ended with status {} but no analyze result",
1617 operation.status
1618 ))
1619 })?;
1620
1621 let mut words = Vec::new();
1622 for page in analyze_result.pages {
1623 for word in page.words {
1624 let points = word
1625 .polygon
1626 .chunks_exact(2)
1627 .map(|pair| (pair[0], pair[1]))
1628 .collect::<Vec<_>>();
1629 words.push(OcrWord {
1630 text: word.content,
1631 bbox: bbox_from_points(&points),
1632 confidence: word.confidence.unwrap_or(1.0),
1633 });
1634 }
1635 }
1636
1637 let text = analyze_result.content.unwrap_or_else(|| {
1638 words
1639 .iter()
1640 .map(|word| word.text.as_str())
1641 .collect::<Vec<_>>()
1642 .join(" ")
1643 });
1644 let confidence = if text.trim().is_empty() {
1645 0.0
1646 } else {
1647 confidence_from_words(&words, 1.0)
1648 };
1649
1650 Ok(OcrResult {
1651 text,
1652 words,
1653 confidence,
1654 })
1655}
1656
#[cfg(feature = "ocr-onnx")]
/// Local OCR backend wrapping the PaddleOCR ONNX engine from `pdf_ocr`.
pub struct PaddleOnnxBackend {
    engine: pdf_ocr::PaddleOcrEngine,
}
1664
#[cfg(feature = "ocr-onnx")]
impl PaddleOnnxBackend {
    /// Initialize the PaddleOCR engine; initialization failures surface as
    /// `RecognitionFailed` with the underlying error message.
    pub fn new() -> Result<Self, OcrError> {
        let engine = pdf_ocr::PaddleOcrEngine::new()
            .map_err(|e| OcrError::RecognitionFailed(format!("init PaddleOCR: {e}")))?;
        Ok(Self { engine })
    }

    /// Environment-based constructor kept for parity with the other
    /// backends; currently reads nothing from the environment.
    pub fn from_env() -> Result<Self, OcrError> {
        Self::new()
    }
}
1682
1683#[cfg(feature = "ocr-onnx")]
1684impl OcrBackend for PaddleOnnxBackend {
1685 fn recognize(&self, image_data: &[u8], width: u32, height: u32) -> Result<OcrResult, OcrError> {
1686 use pdf_ocr::OcrEngine;
1687
1688 let result = self
1689 .engine
1690 .recognize(image_data, width, height, 300)
1691 .map_err(|e| OcrError::RecognitionFailed(format!("PaddleOCR recognize: {e}")))?;
1692 let text = result.full_text();
1693 let words = result
1694 .words
1695 .into_iter()
1696 .map(|word| OcrWord {
1697 text: word.text,
1698 bbox: [
1699 word.bbox_px[0] as f32,
1700 word.bbox_px[1] as f32,
1701 word.bbox_px[2].saturating_sub(word.bbox_px[0]) as f32,
1702 word.bbox_px[3].saturating_sub(word.bbox_px[1]) as f32,
1703 ],
1704 confidence: word.confidence,
1705 })
1706 .collect();
1707
1708 Ok(OcrResult {
1709 text,
1710 words,
1711 confidence: result.confidence,
1712 })
1713 }
1714
1715 fn name(&self) -> &str {
1716 "paddle-onnx"
1717 }
1718}
1719
#[cfg(feature = "ocr")]
/// Local OCR backend wrapping the `ocrs` engine (rten detection and
/// recognition models).
pub struct OcrsBackend {
    engine: ocrs::OcrEngine,
}
1746
1747#[cfg(feature = "ocr")]
1748impl OcrsBackend {
1749 pub fn from_bytes(detection: &[u8], recognition: &[u8]) -> Result<Self, OcrError> {
1754 let det = rten::Model::load(detection.to_vec()).map_err(|_| OcrError::NoEngine)?;
1755 let rec = rten::Model::load(recognition.to_vec()).map_err(|_| OcrError::NoEngine)?;
1756 Self::build(det, rec)
1757 }
1758
1759 #[cfg(not(target_arch = "wasm32"))]
1763 pub fn from_files(
1764 detection_path: impl AsRef<std::path::Path>,
1765 recognition_path: impl AsRef<std::path::Path>,
1766 ) -> Result<Self, OcrError> {
1767 let det = rten::Model::load_file(detection_path).map_err(|_| OcrError::NoEngine)?;
1768 let rec = rten::Model::load_file(recognition_path).map_err(|_| OcrError::NoEngine)?;
1769 Self::build(det, rec)
1770 }
1771
1772 #[cfg(not(target_arch = "wasm32"))]
1780 pub fn try_default() -> Result<Self, OcrError> {
1781 let det_path: std::path::PathBuf = std::env::var("OCRS_DETECTION_MODEL")
1782 .ok()
1783 .map(std::path::PathBuf::from)
1784 .or_else(|| default_model_path("text-detection.rten"))
1785 .ok_or(OcrError::NoEngine)?;
1786 let rec_path: std::path::PathBuf = std::env::var("OCRS_RECOGNITION_MODEL")
1787 .ok()
1788 .map(std::path::PathBuf::from)
1789 .or_else(|| default_model_path("text-recognition.rten"))
1790 .ok_or(OcrError::NoEngine)?;
1791 if !det_path.exists() || !rec_path.exists() {
1792 return Err(OcrError::NoEngine);
1793 }
1794 Self::from_files(det_path, rec_path)
1795 }
1796
1797 fn build(detection: rten::Model, recognition: rten::Model) -> Result<Self, OcrError> {
1798 let engine = ocrs::OcrEngine::new(ocrs::OcrEngineParams {
1799 detection_model: Some(detection),
1800 recognition_model: Some(recognition),
1801 ..Default::default()
1802 })
1803 .map_err(|_| OcrError::NoEngine)?;
1804 Ok(Self { engine })
1805 }
1806}
1807
1808#[cfg(all(feature = "ocr", not(target_arch = "wasm32")))]
1811fn default_model_path(filename: &str) -> Option<std::path::PathBuf> {
1812 dirs_sys::home_dir().map(|h: std::path::PathBuf| h.join(".cache").join("ocrs").join(filename))
1813}
1814
1815#[cfg(feature = "ocr")]
1816impl OcrBackend for OcrsBackend {
1817 fn recognize(&self, image_data: &[u8], width: u32, height: u32) -> Result<OcrResult, OcrError> {
1818 let image_source = ocrs::ImageSource::from_bytes(image_data, (width, height))
1819 .map_err(|e| OcrError::ImageError(e.to_string()))?;
1820
1821 let input = self
1822 .engine
1823 .prepare_input(image_source)
1824 .map_err(|e| OcrError::RecognitionFailed(e.to_string()))?;
1825
1826 let text = self
1827 .engine
1828 .get_text(&input)
1829 .map_err(|e| OcrError::RecognitionFailed(e.to_string()))?;
1830
1831 let words = text
1832 .split_whitespace()
1833 .map(|w| OcrWord {
1834 text: w.to_string(),
1835 bbox: [0.0, 0.0, 0.0, 0.0],
1836 confidence: 1.0,
1837 })
1838 .collect::<Vec<_>>();
1839
1840 let confidence = if text.is_empty() { 0.0 } else { 1.0 };
1841
1842 Ok(OcrResult {
1843 text,
1844 words,
1845 confidence,
1846 })
1847 }
1848
1849 fn name(&self) -> &str {
1850 "ocrs"
1851 }
1852}
1853
#[cfg(not(target_arch = "wasm32"))]
/// Pick the first OCR backend that can be constructed, in fixed priority
/// order: Mistral, Google Vision, Azure Document Intelligence, AWS Textract,
/// Paddle ONNX, then the local ocrs models. Each candidate is only compiled
/// in when its feature flag is enabled; `Err(NoEngine)` when none succeed.
pub fn best_available_backend() -> Result<Box<dyn OcrBackend>, OcrError> {
    #[cfg(feature = "ocr-mistral")]
    if let Ok(backend) = MistralOcrBackend::from_env() {
        return Ok(Box::new(backend));
    }

    #[cfg(feature = "ocr-google")]
    if let Ok(backend) = GoogleVisionBackend::from_env() {
        return Ok(Box::new(backend));
    }

    #[cfg(feature = "ocr-azure")]
    if let Ok(backend) = AzureDocIntelBackend::from_env() {
        return Ok(Box::new(backend));
    }

    #[cfg(feature = "ocr-aws")]
    if let Ok(backend) = AwsTextractBackend::from_env() {
        return Ok(Box::new(backend));
    }

    #[cfg(feature = "ocr-onnx")]
    if let Ok(backend) = PaddleOnnxBackend::from_env() {
        return Ok(Box::new(backend));
    }

    #[cfg(feature = "ocr")]
    if let Ok(backend) = OcrsBackend::try_default() {
        return Ok(Box::new(backend));
    }

    Err(OcrError::NoEngine)
}
1899
1900#[cfg(all(feature = "ocr", not(target_arch = "wasm32")))]
1906pub fn ocr_page_default(image_data: &[u8], width: u32, height: u32) -> Result<OcrResult, OcrError> {
1907 let backend = OcrsBackend::try_default()?;
1908 backend.recognize(image_data, width, height)
1909}
1910
1911#[cfg(all(
1912 test,
1913 not(target_arch = "wasm32"),
1914 any(
1915 feature = "ocr-mistral",
1916 feature = "ocr-google",
1917 feature = "ocr-aws",
1918 feature = "ocr-azure"
1919 )
1920))]
1921mod tests {
1922 use super::*;
1923
1924 use std::collections::HashMap;
1925 use std::ffi::OsString;
1926 use std::io::{Read, Write};
1927 use std::net::{TcpListener, TcpStream};
1928 use std::sync::{Mutex, OnceLock};
1929 use std::thread;
1930
1931 fn env_lock() -> &'static Mutex<()> {
1935 static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
1936 LOCK.get_or_init(|| Mutex::new(()))
1937 }
1938
    /// RAII guard that records prior values of environment variables and
    /// restores them when dropped (see the `Drop` impl below).
    struct ScopedEnv {
        saved: Vec<(String, Option<OsString>)>,
    }
1942
1943 impl ScopedEnv {
1944 fn set(changes: &[(&str, Option<&str>)]) -> Self {
1945 let saved = changes
1946 .iter()
1947 .map(|(key, _)| ((*key).to_string(), std::env::var_os(key)))
1948 .collect::<Vec<_>>();
1949
1950 for (key, value) in changes {
1951 match value {
1952 Some(value) => std::env::set_var(key, value),
1953 None => std::env::remove_var(key),
1954 }
1955 }
1956
1957 Self { saved }
1958 }
1959 }
1960
1961 impl Drop for ScopedEnv {
1962 fn drop(&mut self) {
1963 for (key, value) in self.saved.drain(..) {
1964 match value {
1965 Some(value) => std::env::set_var(&key, value),
1966 None => std::env::remove_var(&key),
1967 }
1968 }
1969 }
1970 }
1971
    /// Minimal HTTP server for exercising the OCR backends without real
    /// network access: serves a fixed number of requests on a background
    /// thread (see `MockServer::start`).
    struct MockServer {
        base_url: String,
        join_handle: Option<thread::JoinHandle<()>>,
    }
1976
1977 impl MockServer {
1978 fn start<F>(requests: usize, handler: F) -> Self
1979 where
1980 F: Fn(MockRequest, usize, &str) -> MockResponse + Send + 'static,
1981 {
1982 let listener = TcpListener::bind("127.0.0.1:0").expect("bind mock server");
1983 let address = listener.local_addr().expect("mock server address");
1984 let base_url = format!("http://{address}");
1985 let handler_base_url = base_url.clone();
1986
1987 let join_handle = thread::spawn(move || {
1988 for index in 0..requests {
1989 let (mut stream, _) = listener.accept().expect("accept mock request");
1990 let request = read_http_request(&mut stream);
1991 let response = handler(request, index, &handler_base_url);
1992 write_http_response(&mut stream, response);
1993 }
1994 });
1995
1996 Self {
1997 base_url,
1998 join_handle: Some(join_handle),
1999 }
2000 }
2001
2002 fn url(&self, path: &str) -> String {
2003 format!("{}{}", self.base_url, path)
2004 }
2005
2006 fn finish(mut self) {
2007 if let Some(join_handle) = self.join_handle.take() {
2008 join_handle.join().expect("mock server thread");
2009 }
2010 }
2011 }
2012
    /// HTTP request parsed by `read_http_request`.
    #[derive(Debug)]
    struct MockRequest {
        method: String,
        path: String,
        // Header names are stored lowercased for case-insensitive lookup.
        headers: HashMap<String, String>,
        body: Vec<u8>,
    }
2020
    /// Canned HTTP response returned by a `MockServer` handler.
    struct MockResponse {
        status_code: u16,
        headers: Vec<(String, String)>,
        body: Vec<u8>,
    }
2026
2027 impl MockResponse {
2028 fn json(status_code: u16, body: &str) -> Self {
2029 Self {
2030 status_code,
2031 headers: vec![("Content-Type".into(), "application/json".into())],
2032 body: body.as_bytes().to_vec(),
2033 }
2034 }
2035
2036 fn empty(status_code: u16, headers: &[(&str, &str)]) -> Self {
2037 Self {
2038 status_code,
2039 headers: headers
2040 .iter()
2041 .map(|(key, value)| ((*key).to_string(), (*value).to_string()))
2042 .collect(),
2043 body: Vec::new(),
2044 }
2045 }
2046 }
2047
2048 fn read_http_request(stream: &mut TcpStream) -> MockRequest {
2049 let mut buffer = Vec::new();
2050 let mut chunk = [0u8; 4096];
2051 let header_end = loop {
2052 let read = stream.read(&mut chunk).expect("read mock request");
2053 assert!(read > 0, "mock request closed before headers completed");
2054 buffer.extend_from_slice(&chunk[..read]);
2055
2056 if let Some(index) = find_bytes(&buffer, b"\r\n\r\n") {
2057 break index + 4;
2058 }
2059 };
2060
2061 let header_text = String::from_utf8_lossy(&buffer[..header_end]);
2062 let mut lines = header_text.split("\r\n").filter(|line| !line.is_empty());
2063 let request_line = lines.next().expect("request line");
2064 let mut request_parts = request_line.split_whitespace();
2065 let method = request_parts.next().expect("request method").to_string();
2066 let path = request_parts.next().expect("request path").to_string();
2067
2068 let mut headers = HashMap::new();
2069 let mut content_length = 0usize;
2070 for line in lines {
2071 if let Some((name, value)) = line.split_once(':') {
2072 let key = name.trim().to_ascii_lowercase();
2073 let value = value.trim().to_string();
2074 if key == "content-length" {
2075 content_length = value.parse::<usize>().expect("content-length");
2076 }
2077 headers.insert(key, value);
2078 }
2079 }
2080
2081 let mut body = buffer[header_end..].to_vec();
2082 while body.len() < content_length {
2083 let read = stream.read(&mut chunk).expect("read mock request body");
2084 assert!(read > 0, "mock request closed before body completed");
2085 body.extend_from_slice(&chunk[..read]);
2086 }
2087
2088 MockRequest {
2089 method,
2090 path,
2091 headers,
2092 body,
2093 }
2094 }
2095
2096 fn write_http_response(stream: &mut TcpStream, response: MockResponse) {
2097 let reason = match response.status_code {
2098 200 => "OK",
2099 202 => "Accepted",
2100 400 => "Bad Request",
2101 401 => "Unauthorized",
2102 403 => "Forbidden",
2103 404 => "Not Found",
2104 _ => "OK",
2105 };
2106 let mut head = format!(
2107 "HTTP/1.1 {} {}\r\nContent-Length: {}\r\nConnection: close\r\n",
2108 response.status_code,
2109 reason,
2110 response.body.len()
2111 );
2112 for (name, value) in response.headers {
2113 head.push_str(&format!("{name}: {value}\r\n"));
2114 }
2115 head.push_str("\r\n");
2116
2117 stream
2118 .write_all(head.as_bytes())
2119 .expect("write mock response head");
2120 stream
2121 .write_all(&response.body)
2122 .expect("write mock response body");
2123 }
2124
2125 fn find_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
2126 haystack
2127 .windows(needle.len())
2128 .position(|window| window == needle)
2129 }
2130
    /// The backend name is part of the observable API; pin it.
    #[cfg(feature = "ocr-mistral")]
    #[test]
    fn ocr_mistral_name_is_stable() {
        let backend = MistralOcrBackend::new("test-key");
        assert_eq!(backend.name(), "mistral");
    }
2137
    /// `from_env` should construct a backend when `MISTRAL_API_KEY` is set.
    #[cfg(feature = "ocr-mistral")]
    #[test]
    fn ocr_mistral_from_env_reads_api_key() {
        let _env = env_lock().lock().expect("env lock");
        let _scoped = ScopedEnv::set(&[("MISTRAL_API_KEY", Some("test-key"))]);

        let backend = MistralOcrBackend::from_env().expect("Mistral backend from env");
        assert_eq!(backend.name(), "mistral");
    }
2147
    /// A tiny RGB buffer (2x2, 3 bytes per pixel) should round-trip into a
    /// base64 PNG data URL with the expected prefix.
    #[cfg(feature = "ocr-mistral")]
    #[test]
    fn ocr_mistral_encodes_rgb_image_as_png_data_url() {
        let rgb = vec![
            255, 255, 255, 0, 0, 0, 0, 0, 0, 255, 255, 255,
        ];
        let data_url = rgb_data_url(&rgb, 2, 2).expect("PNG data URL");
        assert!(data_url.starts_with("data:image/png;base64,"));
        assert!(data_url.len() > "data:image/png;base64,".len());
    }
2159
    /// Page markdown is trimmed and joined with blank lines; pages without
    /// markdown are skipped entirely.
    #[cfg(feature = "ocr-mistral")]
    #[test]
    fn ocr_mistral_collects_markdown_pages() {
        let response = MistralOcrResponse {
            pages: vec![
                MistralOcrPage {
                    markdown: Some(" Hello ".into()),
                },
                MistralOcrPage {
                    markdown: Some("World".into()),
                },
                MistralOcrPage { markdown: None },
            ],
        };

        assert_eq!(mistral_markdown_text(&response), "Hello\n\nWorld");
    }
2177
    /// Pin the Google backend's stable name.
    #[cfg(feature = "ocr-google")]
    #[test]
    fn ocr_google_name_is_stable() {
        let backend = GoogleVisionBackend::from_api_key("test-key");
        assert_eq!(backend.name(), "google-vision");
    }
2184
2185 #[cfg(feature = "ocr-google")]
2186 #[test]
2187 fn ocr_google_from_env_prefers_api_key() {
2188 let _env = env_lock().lock().expect("env lock");
2189 let temp_path = std::env::temp_dir().join("codex-google-creds.json");
2190 let temp_path_string = temp_path.to_string_lossy().into_owned();
2191 std::fs::write(
2192 &temp_path,
2193 r#"{"type":"service_account","client_email":"test@example.com","private_key":"missing"}"#,
2194 )
2195 .expect("write google credentials fixture");
2196 let _scoped = ScopedEnv::set(&[
2197 ("GOOGLE_VISION_API_KEY", Some("test-key")),
2198 (
2199 "GOOGLE_APPLICATION_CREDENTIALS",
2200 Some(temp_path_string.as_str()),
2201 ),
2202 ("GOOGLE_CLOUD_PROJECT", None),
2203 ]);
2204
2205 let backend = GoogleVisionBackend::from_env().expect("Google backend from env");
2206 assert_eq!(backend.name(), "google-vision");
2207 }
2208
    /// End-to-end against the mock server: the request must carry the API key
    /// in the query string, ask for TEXT_DETECTION, and embed the image as
    /// base64 PNG; the response's word annotations must be parsed into bboxes.
    #[cfg(feature = "ocr-google")]
    #[test]
    fn ocr_google_mock_request_uses_api_key_and_parses_words() {
        let server = MockServer::start(1, |request, _, _| {
            assert_eq!(request.method, "POST");
            assert!(request.path.starts_with("/v1/images:annotate?key=test-key"));
            assert_eq!(
                request.headers.get("content-type").map(String::as_str),
                Some("application/json")
            );

            let body: serde_json::Value =
                serde_json::from_slice(&request.body).expect("google request JSON");
            assert_eq!(
                body["requests"][0]["features"][0]["type"].as_str(),
                Some("TEXT_DETECTION")
            );

            // The embedded image must decode to a PNG (magic bytes check).
            let content = body["requests"][0]["image"]["content"]
                .as_str()
                .expect("google image content");
            let png = base64::engine::general_purpose::STANDARD
                .decode(content)
                .expect("decode google image");
            assert!(png.starts_with(&[0x89, b'P', b'N', b'G']));

            MockResponse::json(
                200,
                r#"{
            "responses": [{
                "fullTextAnnotation": { "text": "Hello World" },
                "textAnnotations": [
                    { "description": "Hello World" },
                    {
                        "description": "Hello",
                        "boundingPoly": {
                            "vertices": [
                                { "x": 1, "y": 2 },
                                { "x": 41, "y": 2 },
                                { "x": 41, "y": 14 },
                                { "x": 1, "y": 14 }
                            ]
                        }
                    },
                    {
                        "description": "World",
                        "boundingPoly": {
                            "vertices": [
                                { "x": 50, "y": 2 },
                                { "x": 92, "y": 2 },
                                { "x": 92, "y": 14 },
                                { "x": 50, "y": 14 }
                            ]
                        }
                    }
                ]
            }]
            }"#,
            )
        });

        let backend = GoogleVisionBackend::with_endpoint_api_key(
            "test-key",
            &server.url("/v1/images:annotate"),
        );
        let (pixels, width, height) = test_text_image("TEST 123");
        let result = backend
            .recognize(&pixels, width, height)
            .expect("Google OCR");

        // The first (whole-text) annotation is skipped; the per-word ones are
        // converted to [x, y, w, h] boxes.
        assert_eq!(result.text, "Hello World");
        assert_eq!(result.words.len(), 2);
        assert_eq!(result.words[0].text, "Hello");
        assert_eq!(result.words[0].bbox, [1.0, 2.0, 40.0, 12.0]);
        assert_eq!(result.words[1].text, "World");

        server.finish();
    }
2287
    /// Service-account flow: first a JWT-bearer token exchange against the
    /// mock token endpoint, then the annotate call carrying the issued
    /// bearer token.
    #[cfg(feature = "ocr-google")]
    #[test]
    fn ocr_google_service_account_flow_uses_jwt_exchange() {
        let server = MockServer::start(2, |request, index, _| match index {
            0 => {
                // Request 1: OAuth token exchange with a signed JWT assertion.
                assert_eq!(request.method, "POST");
                assert_eq!(request.path, "/token");
                let body = String::from_utf8(request.body).expect("token body");
                assert!(body
                    .contains("grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer"));
                assert!(body.contains("assertion="));

                MockResponse::json(200, r#"{"access_token":"ya29.test-token"}"#)
            }
            1 => {
                // Request 2: the annotate call must carry the issued token.
                assert_eq!(request.method, "POST");
                assert_eq!(request.path, "/v1/images:annotate");
                assert_eq!(
                    request.headers.get("authorization").map(String::as_str),
                    Some("Bearer ya29.test-token")
                );

                MockResponse::json(
                    200,
                    r#"{
                "responses": [{
                    "fullTextAnnotation": { "text": "Service Account" },
                    "textAnnotations": [
                        { "description": "Service Account" },
                        {
                            "description": "Service",
                            "boundingPoly": {
                                "vertices": [
                                    { "x": 0, "y": 0 },
                                    { "x": 60, "y": 0 },
                                    { "x": 60, "y": 10 },
                                    { "x": 0, "y": 10 }
                                ]
                            }
                        }
                    ]
                }]
                }"#,
                )
            }
            _ => unreachable!(),
        });

        let credentials_json = format!(
            r#"{{
            "type": "service_account",
            "client_email": "test@example.com",
            "private_key": {private_key:?},
            "token_uri": {token_uri:?}
        }}"#,
            private_key = "-----BEGIN PRIVATE KEY-----\nTEST_PLACEHOLDER_NOT_A_REAL_KEY\n-----END PRIVATE KEY-----\n",
            token_uri = server.url("/token"),
        );

        let backend = GoogleVisionBackend::with_endpoint_credentials(
            &credentials_json,
            &server.url("/v1/images:annotate"),
        );
        let (pixels, width, height) = test_text_image("TEST");
        let result = backend
            .recognize(&pixels, width, height)
            .expect("Google OCR via service account");

        assert_eq!(result.text, "Service Account");
        assert_eq!(result.words.len(), 1);
        assert_eq!(result.words[0].text, "Service");

        server.finish();
    }
2362
2363 #[cfg(feature = "ocr-aws")]
2364 #[test]
2365 fn ocr_aws_name_is_stable() {
2366 let backend = AwsTextractBackend::new("eu-west-1", "access", "secret");
2367 assert_eq!(backend.name(), "aws-textract");
2368 }
2369
2370 #[cfg(feature = "ocr-aws")]
2371 #[test]
2372 fn ocr_aws_from_env_reads_region_and_keys() {
2373 let _env = env_lock().lock().expect("env lock");
2374 let _scoped = ScopedEnv::set(&[
2375 ("AWS_REGION", Some("eu-west-1")),
2376 ("AWS_DEFAULT_REGION", None),
2377 ("AWS_ACCESS_KEY_ID", Some("access")),
2378 ("AWS_SECRET_ACCESS_KEY", Some("secret")),
2379 ("AWS_SESSION_TOKEN", None),
2380 ("AWS_PROFILE", None),
2381 ]);
2382
2383 let backend = AwsTextractBackend::from_env().expect("AWS backend from env");
2384 assert_eq!(backend.name(), "aws-textract");
2385 }
2386
    #[cfg(feature = "ocr-aws")]
    #[test]
    fn ocr_aws_mock_request_signs_sigv4_and_parses_lines() {
        // Single-request mock of Textract DetectDocumentText: verify the AWS
        // JSON 1.1 protocol headers, the SigV4 Authorization scheme, and the
        // base64 PNG payload, then return one LINE block for parsing.
        let server = MockServer::start(1, |request, _, _| {
            // AWS JSON 1.1 protocol: POST to "/" with the operation named in
            // the x-amz-target header.
            assert_eq!(request.method, "POST");
            assert_eq!(request.path, "/");
            assert_eq!(
                request.headers.get("content-type").map(String::as_str),
                Some("application/x-amz-json-1.1")
            );
            assert_eq!(
                request.headers.get("x-amz-target").map(String::as_str),
                Some("Textract.DetectDocumentText")
            );
            // Only the SigV4 scheme prefix is checked; the full signature
            // varies with the request timestamp.
            let authorization = request
                .headers
                .get("authorization")
                .expect("aws authorization");
            assert!(authorization.starts_with("AWS4-HMAC-SHA256 "));

            // The image travels base64-encoded under Document.Bytes and must
            // decode to a PNG (magic-number check).
            let body: serde_json::Value =
                serde_json::from_slice(&request.body).expect("aws request JSON");
            let encoded = body["Document"]["Bytes"]
                .as_str()
                .expect("aws document bytes");
            let png = base64::engine::general_purpose::STANDARD
                .decode(encoded)
                .expect("decode aws image");
            assert!(png.starts_with(&[0x89, b'P', b'N', b'G']));

            MockResponse::json(
                200,
                r#"{
            "Blocks": [
                {
                    "BlockType": "LINE",
                    "Text": "Hello Textract",
                    "Confidence": 97.5,
                    "Geometry": {
                        "BoundingBox": {
                            "Left": 0.1,
                            "Top": 0.2,
                            "Width": 0.5,
                            "Height": 0.1
                        }
                    }
                }
            ]
            }"#,
            )
        });

        let backend =
            AwsTextractBackend::with_endpoint("eu-west-1", "access", "secret", &server.url("/"));
        let (pixels, width, height) = test_text_image("TEST");
        let result = backend.recognize(&pixels, width, height).expect("AWS OCR");

        // One LINE block becomes one word entry; the relative BoundingBox is
        // scaled back to pixels (Left * width) and the 0-100 Confidence is
        // normalized to 0-1 (97.5 -> 0.975).
        assert_eq!(result.text, "Hello Textract");
        assert_eq!(result.words.len(), 1);
        assert_eq!(result.words[0].text, "Hello Textract");
        assert!((result.words[0].bbox[0] - (width as f32 * 0.1)).abs() < 0.001);
        assert!((result.words[0].confidence - 0.975).abs() < 0.0001);

        server.finish();
    }
2452
2453 #[cfg(feature = "ocr-azure")]
2454 #[test]
2455 fn ocr_azure_name_is_stable() {
2456 let backend =
2457 AzureDocIntelBackend::new("https://example.cognitiveservices.azure.com", "key");
2458 assert_eq!(backend.name(), "azure-doc-intel");
2459 }
2460
2461 #[cfg(feature = "ocr-azure")]
2462 #[test]
2463 fn ocr_azure_from_env_reads_endpoint_and_key() {
2464 let _env = env_lock().lock().expect("env lock");
2465 let _scoped = ScopedEnv::set(&[
2466 (
2467 "AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT",
2468 Some("https://example.cognitiveservices.azure.com"),
2469 ),
2470 ("AZURE_DOCUMENT_INTELLIGENCE_KEY", Some("test-key")),
2471 ]);
2472
2473 let backend = AzureDocIntelBackend::from_env().expect("Azure backend from env");
2474 assert_eq!(backend.name(), "azure-doc-intel");
2475 }
2476
    #[cfg(feature = "ocr-azure")]
    #[test]
    fn ocr_azure_mock_request_polls_operation_and_parses_words() {
        // Azure Document Intelligence is asynchronous: request 0 submits the
        // image (202 + Operation-Location header), request 1 polls that
        // operation URL and returns the finished analysis.
        let server = MockServer::start(2, |request, index, base_url| match index {
            0 => {
                // Submission: raw PNG bytes POSTed to the prebuilt-read
                // model, authenticated via the subscription-key header.
                assert_eq!(request.method, "POST");
                assert!(request
                    .path
                    .starts_with("/documentintelligence/documentModels/prebuilt-read:analyze?"));
                assert_eq!(
                    request
                        .headers
                        .get("ocp-apim-subscription-key")
                        .map(String::as_str),
                    Some("test-key")
                );
                assert_eq!(
                    request.headers.get("content-type").map(String::as_str),
                    Some("application/octet-stream")
                );
                // Body must be a PNG (magic-number check).
                assert!(request.body.starts_with(&[0x89, b'P', b'N', b'G']));

                MockResponse::empty(
                    202,
                    &[("Operation-Location", &format!("{base_url}/operations/123"))],
                )
            }
            1 => {
                // Poll: GET on the Operation-Location URL; report success
                // with two recognized words.
                assert_eq!(request.method, "GET");
                assert_eq!(request.path, "/operations/123");

                MockResponse::json(
                    200,
                    r#"{
                "status": "succeeded",
                "analyzeResult": {
                    "content": "Hello Azure",
                    "pages": [
                        {
                            "words": [
                                {
                                    "content": "Hello",
                                    "confidence": 0.98,
                                    "polygon": [1,2, 31,2, 31,12, 1,12]
                                },
                                {
                                    "content": "Azure",
                                    "confidence": 0.96,
                                    "polygon": [40,2, 78,2, 78,12, 40,12]
                                }
                            ]
                        }
                    ]
                }
                }"#,
                )
            }
            _ => unreachable!(),
        });

        let backend = AzureDocIntelBackend::new(&server.base_url, "test-key");
        let (pixels, width, height) = test_text_image("TEST");
        let result = backend
            .recognize(&pixels, width, height)
            .expect("Azure OCR");

        assert_eq!(result.text, "Hello Azure");
        assert_eq!(result.words.len(), 2);
        // The 8-value polygon collapses to an axis-aligned box; the expected
        // value is consistent with an [x, y, width, height] layout
        // (31-1=30, 12-2=10). NOTE(review): the Google test asserts a bbox
        // of [1, 2, 40, 12] — confirm both backends share one convention.
        assert_eq!(result.words[0].bbox, [1.0, 2.0, 30.0, 10.0]);
        // 0.97 is consistent with the mean of the two word confidences
        // ((0.98 + 0.96) / 2).
        assert!((result.confidence - 0.97).abs() < 0.001);

        server.finish();
    }
2550
    #[cfg(any(
        feature = "ocr-mistral",
        feature = "ocr-google",
        feature = "ocr-azure",
        feature = "ocr-aws"
    ))]
    #[test]
    fn ocr_best_available_respects_cloud_priority() {
        // Every enabled cloud backend is given working credentials at once;
        // the backend picked by best_available_backend() reveals the priority
        // order (mistral > google-vision > azure-doc-intel > aws-textract).
        // Exactly one cfg block below applies to the compiled feature set,
        // and each non-final block ends with `return` so later code never
        // runs.
        let _env = env_lock().lock().expect("env lock");

        #[cfg(feature = "ocr-mistral")]
        {
            // Mistral enabled: it must win over every other configured cloud.
            let _scoped = ScopedEnv::set(&[
                ("MISTRAL_API_KEY", Some("mistral-key")),
                ("GOOGLE_VISION_API_KEY", Some("google-key")),
                (
                    "AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT",
                    Some("https://example.azure.com"),
                ),
                ("AZURE_DOCUMENT_INTELLIGENCE_KEY", Some("azure-key")),
                ("AWS_REGION", Some("eu-west-1")),
                ("AWS_ACCESS_KEY_ID", Some("aws-access")),
                ("AWS_SECRET_ACCESS_KEY", Some("aws-secret")),
            ]);

            let backend = best_available_backend().expect("best OCR backend");
            assert_eq!(backend.name(), "mistral");
            return;
        }

        #[cfg(all(not(feature = "ocr-mistral"), feature = "ocr-google"))]
        {
            // No Mistral: Google Vision outranks Azure and AWS.
            let _scoped = ScopedEnv::set(&[
                ("GOOGLE_VISION_API_KEY", Some("google-key")),
                (
                    "AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT",
                    Some("https://example.azure.com"),
                ),
                ("AZURE_DOCUMENT_INTELLIGENCE_KEY", Some("azure-key")),
                ("AWS_REGION", Some("eu-west-1")),
                ("AWS_ACCESS_KEY_ID", Some("aws-access")),
                ("AWS_SECRET_ACCESS_KEY", Some("aws-secret")),
            ]);

            let backend = best_available_backend().expect("best OCR backend");
            assert_eq!(backend.name(), "google-vision");
            return;
        }

        #[cfg(all(
            not(feature = "ocr-mistral"),
            not(feature = "ocr-google"),
            feature = "ocr-azure"
        ))]
        {
            // Neither Mistral nor Google: Azure outranks AWS.
            let _scoped = ScopedEnv::set(&[
                (
                    "AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT",
                    Some("https://example.azure.com"),
                ),
                ("AZURE_DOCUMENT_INTELLIGENCE_KEY", Some("azure-key")),
                ("AWS_REGION", Some("eu-west-1")),
                ("AWS_ACCESS_KEY_ID", Some("aws-access")),
                ("AWS_SECRET_ACCESS_KEY", Some("aws-secret")),
            ]);

            let backend = best_available_backend().expect("best OCR backend");
            assert_eq!(backend.name(), "azure-doc-intel");
            return;
        }

        #[cfg(all(
            not(feature = "ocr-mistral"),
            not(feature = "ocr-google"),
            not(feature = "ocr-azure"),
            feature = "ocr-aws"
        ))]
        {
            // Only AWS remains: it is the final fallback.
            let _scoped = ScopedEnv::set(&[
                ("AWS_REGION", Some("eu-west-1")),
                ("AWS_ACCESS_KEY_ID", Some("aws-access")),
                ("AWS_SECRET_ACCESS_KEY", Some("aws-secret")),
            ]);

            let backend = best_available_backend().expect("best OCR backend");
            assert_eq!(backend.name(), "aws-textract");
        }
    }
2639
2640 fn test_text_image(text: &str) -> (Vec<u8>, u32, u32) {
2641 let scale = 8usize;
2642 let glyph_width = 5usize;
2643 let glyph_height = 7usize;
2644 let spacing = 2usize;
2645 let margin = 12usize;
2646 let width = margin * 2
2647 + text
2648 .chars()
2649 .map(|ch| match ch {
2650 ' ' => scale * 3,
2651 _ => glyph_width * scale + spacing * scale,
2652 })
2653 .sum::<usize>();
2654 let height = margin * 2 + glyph_height * scale;
2655 let mut pixels = vec![255u8; width * height * 3];
2656 let mut cursor_x = margin;
2657
2658 for ch in text.chars() {
2659 if ch == ' ' {
2660 cursor_x += scale * 3;
2661 continue;
2662 }
2663
2664 draw_glyph(
2665 &mut pixels,
2666 width,
2667 cursor_x,
2668 margin,
2669 scale,
2670 glyph_pattern(ch),
2671 );
2672 cursor_x += glyph_width * scale + spacing * scale;
2673 }
2674
2675 (pixels, width as u32, height as u32)
2676 }
2677
2678 fn draw_glyph(
2679 pixels: &mut [u8],
2680 image_width: usize,
2681 offset_x: usize,
2682 offset_y: usize,
2683 scale: usize,
2684 glyph: [&str; 7],
2685 ) {
2686 for (row, pattern) in glyph.into_iter().enumerate() {
2687 for (col, bit) in pattern.bytes().enumerate() {
2688 if bit != b'#' {
2689 continue;
2690 }
2691
2692 for dy in 0..scale {
2693 for dx in 0..scale {
2694 let x = offset_x + col * scale + dx;
2695 let y = offset_y + row * scale + dy;
2696 let idx = (y * image_width + x) * 3;
2697 pixels[idx] = 0;
2698 pixels[idx + 1] = 0;
2699 pixels[idx + 2] = 0;
2700 }
2701 }
2702 }
2703 }
2704 }
2705
2706 fn glyph_pattern(ch: char) -> [&'static str; 7] {
2707 match ch {
2708 '1' => [
2709 "..#..", ".##..", "..#..", "..#..", "..#..", "..#..", ".###.",
2710 ],
2711 '2' => [
2712 ".###.", "#...#", "....#", "...#.", "..#..", ".#...", "#####",
2713 ],
2714 '3' => [
2715 ".###.", "#...#", "....#", "..##.", "....#", "#...#", ".###.",
2716 ],
2717 'E' => [
2718 "#####", "#....", "#....", "####.", "#....", "#....", "#####",
2719 ],
2720 'S' => [
2721 ".####", "#....", "#....", ".###.", "....#", "....#", "####.",
2722 ],
2723 'T' => [
2724 "#####", "..#..", "..#..", "..#..", "..#..", "..#..", "..#..",
2725 ],
2726 _ => panic!("unsupported glyph: {ch}"),
2727 }
2728 }
2729}