use std::{io::Cursor, pin::Pin, time::Duration};
use image::ImageFormat;
use reqwest::{
Client,
multipart::{Form, Part},
};
use serde::{Deserialize, Serialize};
use crate::ocr::{OcrEngine, OcrOptions, OcrResult};
/// One recognized text region in the "standard" OCR server response shape,
/// i.e. an element of `{"results": [{"text": ..., "bbox": [...], "confidence": ...}]}`.
///
/// Also derives `Serialize`; presumably so tests or mock servers can emit this shape — confirm.
#[derive(Debug, Serialize, Deserialize)]
pub struct HttpOcrResponseItem {
    /// Recognized text for this region.
    text: String,
    /// Bounding box, copied verbatim into `OcrResult::bbox`.
    /// NOTE(review): assumed to be `[min_x, min_y, max_x, max_y]`, matching what the prod-format
    /// conversion produces — confirm against the server contract.
    bbox: [f32; 4],
    /// Recognition confidence as reported by the server.
    confidence: f32,
    /// Optional four-point polygon; a missing key deserializes to `None`.
    #[serde(default)]
    polygon: Option<[[f32; 2]; 4]>,
}
impl HttpOcrResponseItem {
    /// Converts this wire-format item into the engine-agnostic [`OcrResult`], moving every
    /// field across unchanged.
    fn into_ocr_result(self) -> OcrResult {
        let Self {
            text,
            bbox,
            confidence,
            polygon,
        } = self;
        OcrResult {
            text,
            bbox,
            confidence,
            polygon,
        }
    }
}
/// One entry of the "prod" OCR server response, deserialized from a JSON array of the form
/// `[[[x, y], ...], "text", confidence]`: the polygon vertices, the recognized text, and the
/// confidence score.
#[derive(Debug, Deserialize)]
struct ProdOcrItem(Vec<[f32; 2]>, String, f32);
impl ProdOcrItem {
    /// Converts a prod-format `[polygon, text, confidence]` entry into an [`OcrResult`].
    ///
    /// The bounding box is the axis-aligned envelope `[min_x, min_y, max_x, max_y]` of all
    /// polygon vertices. An empty vertex list yields `[0.0; 4]` instead of the
    /// `[inf, inf, -inf, -inf]` starting values of the fold, so downstream geometry never sees
    /// non-finite coordinates for a degenerate entry. The polygon itself is kept only when it
    /// has exactly four vertices, because `OcrResult::polygon` holds a fixed quad.
    fn into_ocr_result(self) -> OcrResult {
        let ProdOcrItem(poly, text, confidence) = self;
        let bbox = if poly.is_empty() {
            [0.0; 4]
        } else {
            poly.iter().fold(
                [
                    f32::INFINITY,
                    f32::INFINITY,
                    f32::NEG_INFINITY,
                    f32::NEG_INFINITY,
                ],
                |[min_x, min_y, max_x, max_y], &[x, y]| {
                    [min_x.min(x), min_y.min(y), max_x.max(x), max_y.max(y)]
                },
            )
        };
        // `[T; N]: TryFrom<&[T]>` succeeds only when the slice length is exactly N.
        let polygon = <[[f32; 2]; 4]>::try_from(poly.as_slice()).ok();
        OcrResult {
            text,
            bbox,
            confidence,
            polygon,
        }
    }
}
/// Top-level OCR server response, accepting either supported JSON shape.
///
/// `untagged` makes serde try the variants in declaration order, so `{"results": [...]}` is
/// attempted before `{"result": [...]}`. Unknown keys (such as `document_angle`) are ignored.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum HttpOcrResponse {
    /// `{"results": [HttpOcrResponseItem, ...]}`.
    Standard { results: Vec<HttpOcrResponseItem> },
    /// `{"result": [ProdOcrItem, ...]}`.
    Prod { result: Vec<ProdOcrItem> },
}
impl HttpOcrResponse {
fn into_results(self) -> Vec<OcrResult> {
match self {
HttpOcrResponse::Standard { results } => {
results.into_iter().map(|i| i.into_ocr_result()).collect()
}
HttpOcrResponse::Prod { result } => {
result.into_iter().map(|i| i.into_ocr_result()).collect()
}
}
}
}
/// OCR engine that uploads a PNG encoding of the page image to a remote OCR server over HTTP.
pub struct HttpOcrEngine {
    /// Engine name returned by `OcrEngine::name`; the constructors set it to `"http-ocr"`.
    pub name: String,
    /// Endpoint that receives the multipart POST.
    server_url: String,
    /// Extra `(name, value)` headers attached to every request.
    headers: Vec<(String, String)>,
    /// Retry, backoff, timeout and hedging policy.
    retry: OcrRetryConfig,
}
/// Retry, timeout and request-hedging policy used by `HttpOcrEngine`.
#[derive(Debug, Clone)]
pub struct OcrRetryConfig {
    /// Total attempts per recognition, including the first; values below 1 act as 1.
    pub max_attempts: u32,
    /// Backoff after the n-th failed attempt is `base_backoff_ms * 2^(n-1)`.
    pub base_backoff_ms: u64,
    /// Upper bound on the exponential backoff, applied before jitter is added.
    pub max_backoff_ms: u64,
    /// Upper bound on the random extra delay added to every retry wait.
    pub jitter_ms: u64,
    /// Wait used instead of exponential backoff when the failure looks like a dropped connection.
    pub fast_retry_ms: u64,
    /// Timeout applied to each individual HTTP request.
    pub request_timeout_ms: u64,
    /// Start delays for concurrent (hedged) requests within one attempt. With two or more
    /// entries, the first successful response wins. Empty means a single immediate request.
    pub hedge_delays_ms: Vec<u64>,
}

impl Default for OcrRetryConfig {
    /// Ten attempts with 1 s → 10 s exponential backoff, up to 500 ms of jitter, a 500 ms fast
    /// retry for dropped connections, a 60 s per-request timeout, and no hedging.
    fn default() -> Self {
        const SECOND_MS: u64 = 1_000;
        Self {
            hedge_delays_ms: vec![],
            request_timeout_ms: 60 * SECOND_MS,
            fast_retry_ms: SECOND_MS / 2,
            jitter_ms: SECOND_MS / 2,
            max_backoff_ms: 10 * SECOND_MS,
            base_backoff_ms: SECOND_MS,
            max_attempts: 10,
        }
    }
}
/// Returns a pseudo-random jitter in `0..=max` milliseconds for retry backoff.
///
/// The value comes from the sub-second nanoseconds of the wall clock. This is cheap and needs
/// no dependency, but it is not a real RNG. Because that source is always below 10^9, results
/// never exceed 999_999_999, even for larger `max`. `max == u64::MAX` is handled without
/// computing `max + 1`, which would overflow: a panic in debug builds and a division by zero
/// in release builds.
fn jitter_ms(max: u64) -> u64 {
    if max == 0 {
        return 0;
    }
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| u64::from(d.subsec_nanos()))
        // A clock set before the epoch simply disables jitter.
        .unwrap_or(0);
    // `checked_add` fails only for `max == u64::MAX`, where every u64 is already in range.
    max.checked_add(1).map_or(nanos, |modulus| nanos % modulus)
}
/// Heuristically decides whether `err` means the peer dropped the connection.
///
/// Walks `err` and its whole `source()` chain. Returns `true` as soon as any link's message,
/// lowercased, contains one of the known connection-drop phrases.
fn is_connection_drop(err: &reqwest::Error) -> bool {
    const DROP_MARKERS: [&str; 5] = [
        "connection reset",
        "hang up",
        "broken pipe",
        "connection closed",
        "incompletemessage",
    ];
    let root: &(dyn std::error::Error + 'static) = err;
    std::iter::successors(Some(root), |&link| link.source()).any(|link| {
        let msg = link.to_string().to_ascii_lowercase();
        DROP_MARKERS.iter().any(|&marker| msg.contains(marker))
    })
}
/// Performs a single multipart OCR upload and returns the raw response body.
///
/// The form carries the PNG in the `file` field (file name `image.png`, MIME `image/png`) and
/// the requested `language` as a text field. Every pair in `headers` is attached as-is, and the
/// request is bounded by `timeout_ms`. Non-2xx statuses are converted into errors by
/// `error_for_status`, so callers can inspect `err.status()`.
async fn send_one(
    client: &Client,
    url: &str,
    headers: &[(String, String)],
    png_bytes: &[u8],
    language: &str,
    timeout_ms: u64,
) -> Result<String, reqwest::Error> {
    let image_part = Part::bytes(png_bytes.to_vec())
        .file_name("image.png")
        .mime_str("image/png")?;
    let form = Form::new()
        .part("file", image_part)
        .text("language", language.to_string());
    let base = client
        .post(url)
        .multipart(form)
        .timeout(Duration::from_millis(timeout_ms));
    let request = headers
        .iter()
        .fold(base, |req, (name, value)| req.header(name.as_str(), value.as_str()));
    request.send().await?.error_for_status()?.text().await
}
/// Decides whether a failed request deserves another attempt.
///
/// Timeouts and connect failures always retry. When the server returned a status, only
/// 408, 425, 429, 500, 502, 503 and 504 retry; any other status (e.g. 400, 401, 404) is final.
/// Failures without a status retry when reqwest classifies them as body- or request-related.
fn is_retryable(err: &reqwest::Error) -> bool {
    err.is_timeout()
        || err.is_connect()
        || err.status().map_or_else(
            || err.is_body() || err.is_request(),
            |status| matches!(status.as_u16(), 408 | 425 | 429 | 500 | 502 | 503 | 504),
        )
}
impl HttpOcrEngine {
    /// Creates an engine for `server_url` with no extra headers and the default retry policy.
    pub fn new(server_url: String) -> Self {
        Self::with_headers(server_url, Vec::new())
    }

    /// Creates an engine that attaches `headers` to every request, with the default retry policy.
    pub fn with_headers(server_url: String, headers: Vec<(String, String)>) -> Self {
        Self {
            name: "http-ocr".to_string(),
            server_url,
            headers,
            retry: OcrRetryConfig::default(),
        }
    }

    /// Replaces the retry/timeout/hedging policy (builder style).
    pub fn with_retry(mut self, retry: OcrRetryConfig) -> Self {
        self.retry = retry;
        self
    }

    /// Performs one logical attempt, optionally hedged across several concurrent requests.
    ///
    /// With zero or one entry in `hedge_delays_ms`, a single request is sent after that
    /// delay (if it is non-zero). With two or more entries, one request per entry is spawned,
    /// each starting after its own delay, and the first successful body wins. If every request
    /// fails, the error from the last one to finish is returned.
    ///
    /// Spawned requests are aborted as soon as this future completes *or is dropped*. A
    /// cancelled caller therefore does not leave orphaned uploads running against the server.
    ///
    /// # Panics
    /// Only if every hedge task exits without reporting a result, i.e. they all panicked.
    async fn send_hedged(
        &self,
        client: &Client,
        png_bytes: &[u8],
        language: &str,
    ) -> Result<String, reqwest::Error> {
        /// Aborts the wrapped tasks when dropped; aborting a finished task is a no-op.
        struct AbortOnDrop(Vec<tokio::task::JoinHandle<()>>);
        impl Drop for AbortOnDrop {
            fn drop(&mut self) {
                for handle in &self.0 {
                    handle.abort();
                }
            }
        }

        let delays = &self.retry.hedge_delays_ms;
        let timeout_ms = self.retry.request_timeout_ms;
        if delays.len() <= 1 {
            if let Some(&d) = delays.first().filter(|&&d| d > 0) {
                tokio::time::sleep(Duration::from_millis(d)).await;
            }
            return send_one(
                client,
                &self.server_url,
                &self.headers,
                png_bytes,
                language,
                timeout_ms,
            )
            .await;
        }
        // Share one copy of the encoded image across all hedges instead of copying per task.
        let png: std::sync::Arc<[u8]> = std::sync::Arc::from(png_bytes);
        let (tx, mut rx) = tokio::sync::mpsc::channel(delays.len());
        let mut guard = AbortOnDrop(Vec::with_capacity(delays.len()));
        for &delay in delays {
            let tx = tx.clone();
            let client = client.clone();
            let url = self.server_url.clone();
            let headers = self.headers.clone();
            let png = std::sync::Arc::clone(&png);
            let lang = language.to_string();
            guard.0.push(tokio::spawn(async move {
                if delay > 0 {
                    tokio::time::sleep(Duration::from_millis(delay)).await;
                }
                let res = send_one(&client, &url, &headers, &png, &lang, timeout_ms).await;
                // Capacity equals the number of senders, so this never waits. A send error
                // only means the receiver has already returned.
                let _ = tx.send(res).await;
            }));
        }
        // Drop our own sender so `recv` yields `None` once every task has reported.
        drop(tx);
        let mut last_err: Option<reqwest::Error> = None;
        while let Some(res) = rx.recv().await {
            match res {
                // Returning drops `guard`, which aborts the hedges still in flight.
                Ok(body) => return Ok(body),
                Err(e) => last_err = Some(e),
            }
        }
        Err(last_err.expect("hedge group always yields at least one result"))
    }
}
impl OcrEngine for HttpOcrEngine {
    fn name(&self) -> &str {
        &self.name
    }

    /// Encodes the raw RGB buffer as PNG, uploads it with retries, and parses the reply.
    ///
    /// `image_data` is interpreted as packed 8-bit RGB pixels for a `width` x `height` image.
    ///
    /// Retry policy (from `self.retry`): up to `max_attempts` logical attempts (at least one),
    /// each possibly hedged by `send_hedged`. Errors rejected by `is_retryable` fail
    /// immediately. Failures that look like dropped connections wait `fast_retry_ms`. Other
    /// failures back off exponentially from `base_backoff_ms`, capped at `max_backoff_ms`.
    /// Every wait also gets up to `jitter_ms` of extra delay.
    ///
    /// Setting the `LITEPARSE_DEBUG_OCR` environment variable logs retries and result counts
    /// to stderr.
    ///
    /// # Errors
    /// Fails if:
    /// - the buffer cannot form an image of the given size;
    /// - PNG encoding fails;
    /// - the HTTP client cannot be built;
    /// - the final attempt's request fails;
    /// - the body matches neither supported JSON shape.
    fn recognize<'a, 'b: 'a, 'c: 'a>(
        &'a self,
        image_data: &'c [u8],
        width: u32,
        height: u32,
        options: &'b OcrOptions,
    ) -> Pin<
        Box<
            dyn Future<Output = Result<Vec<OcrResult>, Box<dyn std::error::Error + Send + Sync>>>
                + Send
                + '_,
        >,
    > {
        Box::pin(async move {
            let img: image::RgbImage =
                image::ImageBuffer::from_raw(width, height, image_data.to_vec())
                    .ok_or("failed to create image buffer from raw RGB data")?;
            let mut png_bytes = Vec::new();
            img.write_to(&mut Cursor::new(&mut png_bytes), ImageFormat::Png)?;

            // `Client::new()` panics if the TLS backend or DNS resolver cannot initialise; the
            // builder reports that as an error instead.
            // NOTE(review): building a client per call forgoes connection pooling across
            // pages; consider storing one on the engine.
            let client = Client::builder().build()?;
            let debug = std::env::var("LITEPARSE_DEBUG_OCR").is_ok();
            let max_attempts = self.retry.max_attempts.max(1);
            let mut attempt: u32 = 0;
            let raw = loop {
                attempt += 1;
                let e = match self
                    .send_hedged(&client, &png_bytes, &options.language)
                    .await
                {
                    Ok(body) => break body,
                    Err(e) => e,
                };
                if attempt >= max_attempts || !is_retryable(&e) {
                    return Err(e.into());
                }
                let base = if is_connection_drop(&e) {
                    self.retry.fast_retry_ms
                } else {
                    self.retry
                        .base_backoff_ms
                        .saturating_mul(2u64.saturating_pow(attempt - 1))
                        .min(self.retry.max_backoff_ms)
                };
                // Saturate so a very large configured delay cannot overflow.
                let delay = base.saturating_add(jitter_ms(self.retry.jitter_ms));
                if debug {
                    eprintln!(
                        "[ocr-http] attempt {attempt}/{max_attempts} failed ({e}); retrying in {delay}ms"
                    );
                }
                tokio::time::sleep(Duration::from_millis(delay)).await;
            };
            let response: HttpOcrResponse = serde_json::from_str(&raw).map_err(|e| {
                let snippet: String = raw.chars().take(200).collect();
                format!("OCR server returned unparseable response: {e}; body starts: {snippet}")
            })?;
            let results = response.into_results();
            if debug {
                eprintln!(
                    "[ocr-http] {} bytes -> {} result(s)",
                    raw.len(),
                    results.len()
                );
            }
            Ok(results)
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Engine aimed at a local port that is not expected to have a listener.
    fn unreachable_engine(retry: OcrRetryConfig) -> HttpOcrEngine {
        HttpOcrEngine::new("http://127.0.0.1:1/ocr".into()).with_retry(retry)
    }

    /// English OCR options at 150 DPI, shared by the network tests.
    fn eng_options() -> OcrOptions {
        OcrOptions {
            language: "eng".into(),
            dpi: 150.0,
        }
    }

    /// Deserializes `raw` as an `HttpOcrResponse` and flattens it into results.
    fn parse(raw: &str) -> Vec<OcrResult> {
        let response: HttpOcrResponse = serde_json::from_str(raw).unwrap();
        response.into_results()
    }

    #[test]
    fn test_new_sets_name_and_url() {
        let engine = HttpOcrEngine::new("http://example.com/ocr".into());
        assert_eq!(engine.name(), "http-ocr");
        assert_eq!(engine.server_url, "http://example.com/ocr");
    }

    #[test]
    fn test_response_deserializes() {
        let results =
            parse(r#"{"results":[{"text":"hi","bbox":[1.0,2.0,3.0,4.0],"confidence":0.85}]}"#);
        assert_eq!(results.len(), 1);
        let first = &results[0];
        assert_eq!(first.text, "hi");
        assert_eq!(first.bbox, [1.0, 2.0, 3.0, 4.0]);
        assert!((first.confidence - 0.85).abs() < 1e-6);
    }

    #[test]
    fn test_response_deserializes_empty() {
        assert!(parse(r#"{"results":[]}"#).is_empty());
    }

    #[test]
    fn test_prod_response_deserializes() {
        let results = parse(
            r#"{"document_angle":-90,"result":[[[[10.0,20.0],[60.0,20.0],[60.0,40.0],[10.0,40.0]],"hi",0.85]]}"#,
        );
        assert_eq!(results.len(), 1);
        let first = &results[0];
        assert_eq!(first.text, "hi");
        assert_eq!(first.bbox, [10.0, 20.0, 60.0, 40.0]);
        assert!((first.confidence - 0.85).abs() < 1e-6);
        assert_eq!(
            first.polygon,
            Some([[10.0, 20.0], [60.0, 20.0], [60.0, 40.0], [10.0, 40.0]])
        );
    }

    #[test]
    fn test_prod_response_empty() {
        assert!(parse(r#"{"document_angle":null,"result":[]}"#).is_empty());
    }

    #[tokio::test]
    async fn test_recognize_network_error() {
        let engine = unreachable_engine(OcrRetryConfig {
            max_attempts: 1,
            ..Default::default()
        });
        let opts = eng_options();
        let outcome = engine.recognize(&[0u8; 4], 1, 1, &opts).await;
        assert!(outcome.is_err());
    }

    #[test]
    fn test_default_retry_matches_worker_parity() {
        let OcrRetryConfig {
            max_attempts,
            base_backoff_ms,
            max_backoff_ms,
            ..
        } = OcrRetryConfig::default();
        assert_eq!(
            (max_attempts, base_backoff_ms, max_backoff_ms),
            (10, 1000, 10_000)
        );
    }

    #[test]
    fn test_jitter_within_bounds() {
        assert!((0..1000).all(|_| jitter_ms(500) <= 500));
        assert_eq!(jitter_ms(0), 0);
    }

    #[test]
    fn test_default_has_no_hedging() {
        let config = OcrRetryConfig::default();
        assert!(config.hedge_delays_ms.is_empty());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_hedged_all_fail_returns_error() {
        let engine = unreachable_engine(OcrRetryConfig {
            max_attempts: 1,
            hedge_delays_ms: vec![0, 10],
            ..Default::default()
        });
        let opts = eng_options();
        let outcome = engine.recognize(&[0u8; 4], 1, 1, &opts).await;
        assert!(outcome.is_err());
    }
}