use std::{io::Cursor, pin::Pin, time::Duration};
use image::ImageFormat;
use reqwest::{
Client,
multipart::{Form, Part},
};
use serde::{Deserialize, Serialize};
use crate::ocr::{OcrEngine, OcrOptions, OcrResult};
#[derive(Debug, Serialize, Deserialize)]
pub struct HttpOcrResponseItem {
text: String,
bbox: [f32; 4],
confidence: f32,
#[serde(default)]
polygon: Option<[[f32; 2]; 4]>,
}
impl HttpOcrResponseItem {
fn into_ocr_result(self) -> OcrResult {
OcrResult {
text: self.text,
bbox: self.bbox,
confidence: self.confidence,
polygon: self.polygon,
}
}
}
#[derive(Debug, Deserialize)]
struct ProdOcrItem(Vec<[f32; 2]>, String, f32);
impl ProdOcrItem {
fn into_ocr_result(self) -> OcrResult {
let ProdOcrItem(poly, text, confidence) = self;
let mut min_x = f32::INFINITY;
let mut min_y = f32::INFINITY;
let mut max_x = f32::NEG_INFINITY;
let mut max_y = f32::NEG_INFINITY;
for [x, y] in &poly {
min_x = min_x.min(*x);
min_y = min_y.min(*y);
max_x = max_x.max(*x);
max_y = max_y.max(*y);
}
let polygon = match poly.as_slice() {
[a, b, c, d] => Some([*a, *b, *c, *d]),
_ => None,
};
OcrResult {
text,
bbox: [min_x, min_y, max_x, max_y],
confidence,
polygon,
}
}
}
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum HttpOcrResponse {
Standard { results: Vec<HttpOcrResponseItem> },
Prod { result: Vec<ProdOcrItem> },
}
impl HttpOcrResponse {
fn into_results(self) -> Vec<OcrResult> {
match self {
HttpOcrResponse::Standard { results } => {
results.into_iter().map(|i| i.into_ocr_result()).collect()
}
HttpOcrResponse::Prod { result } => {
result.into_iter().map(|i| i.into_ocr_result()).collect()
}
}
}
}
pub struct HttpOcrEngine {
pub name: String,
server_url: String,
headers: Vec<(String, String)>,
}
impl HttpOcrEngine {
pub fn new(server_url: String) -> Self {
Self::with_headers(server_url, Vec::new())
}
pub fn with_headers(server_url: String, headers: Vec<(String, String)>) -> Self {
Self {
name: "http-ocr".to_string(),
server_url,
headers,
}
}
}
impl OcrEngine for HttpOcrEngine {
fn name(&self) -> &str {
&self.name
}
fn recognize<'a, 'b: 'a, 'c: 'a>(
&'a self,
image_data: &'c [u8],
width: u32,
height: u32,
options: &'b OcrOptions,
) -> Pin<
Box<
dyn Future<Output = Result<Vec<OcrResult>, Box<dyn std::error::Error + Send + Sync>>>
+ Send
+ '_,
>,
> {
Box::pin(async move {
let img: image::RgbImage =
image::ImageBuffer::from_raw(width, height, image_data.to_vec())
.ok_or("failed to create image buffer from raw RGB data")?;
let mut png_bytes = Vec::new();
img.write_to(&mut Cursor::new(&mut png_bytes), ImageFormat::Png)?;
let client = Client::new();
let form = Form::new()
.part(
"file",
Part::bytes(png_bytes)
.file_name("image.png")
.mime_str("image/png")?,
)
.text("language", options.language.clone());
let mut request = client
.post(&self.server_url)
.multipart(form)
.timeout(Duration::from_millis(60000));
for (name, value) in &self.headers {
request = request.header(name.as_str(), value.as_str());
}
let raw = request.send().await?.error_for_status()?.text().await?;
let response: HttpOcrResponse = serde_json::from_str(&raw).map_err(|e| {
let snippet: String = raw.chars().take(200).collect();
format!("OCR server returned unparseable response: {e}; body starts: {snippet}")
})?;
let results = response.into_results();
if std::env::var("LITEPARSE_DEBUG_OCR").is_ok() {
eprintln!(
"[ocr-http] {} bytes -> {} result(s)",
raw.len(),
results.len()
);
}
Ok(results)
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_new_sets_name_and_url() {
let e = HttpOcrEngine::new("http://example.com/ocr".into());
assert_eq!(e.name(), "http-ocr");
assert_eq!(e.server_url, "http://example.com/ocr");
}
#[test]
fn test_response_deserializes() {
let raw = r#"{"results":[{"text":"hi","bbox":[1.0,2.0,3.0,4.0],"confidence":0.85}]}"#;
let parsed: HttpOcrResponse = serde_json::from_str(raw).unwrap();
let results = parsed.into_results();
assert_eq!(results.len(), 1);
assert_eq!(results[0].text, "hi");
assert_eq!(results[0].bbox, [1.0, 2.0, 3.0, 4.0]);
assert!((results[0].confidence - 0.85).abs() < 1e-6);
}
#[test]
fn test_response_deserializes_empty() {
let raw = r#"{"results":[]}"#;
let parsed: HttpOcrResponse = serde_json::from_str(raw).unwrap();
assert!(parsed.into_results().is_empty());
}
#[test]
fn test_prod_response_deserializes() {
let raw = r#"{"document_angle":-90,"result":[[[[10.0,20.0],[60.0,20.0],[60.0,40.0],[10.0,40.0]],"hi",0.85]]}"#;
let parsed: HttpOcrResponse = serde_json::from_str(raw).unwrap();
let results = parsed.into_results();
assert_eq!(results.len(), 1);
assert_eq!(results[0].text, "hi");
assert_eq!(results[0].bbox, [10.0, 20.0, 60.0, 40.0]);
assert!((results[0].confidence - 0.85).abs() < 1e-6);
assert_eq!(
results[0].polygon,
Some([[10.0, 20.0], [60.0, 20.0], [60.0, 40.0], [10.0, 40.0]])
);
}
#[test]
fn test_prod_response_empty() {
let raw = r#"{"document_angle":null,"result":[]}"#;
let parsed: HttpOcrResponse = serde_json::from_str(raw).unwrap();
assert!(parsed.into_results().is_empty());
}
#[tokio::test]
async fn test_recognize_network_error() {
let e = HttpOcrEngine::new("http://127.0.0.1:1/ocr".into());
let opts = OcrOptions {
language: "eng".into(),
dpi: 150.0,
};
let r = e.recognize(&[0u8; 4], 1, 1, &opts).await;
assert!(r.is_err());
}
}