use std::cell::RefCell;
use spdf_types::{SpdfError, SpdfResult};
use tesseract::PageSegMode;
use crate::engine::{OcrEngine, OcrOptions, OcrResult};
/// A cached Tesseract instance plus the configuration it was initialised with.
struct CacheEntry {
    /// `(datapath, joined language string)` used to build `tess`; a cached instance is only
    /// reused by a call whose key compares equal.
    key: (Option<String>, String),
    /// The instance itself. Kept in an `Option` so it can be moved out for the tesseract
    /// crate's `self`-consuming calls and stored back afterwards.
    tess: Option<tesseract::Tesseract>,
}

thread_local! {
    /// One lazily-built Tesseract instance per thread (`None` until first use), so repeated
    /// OCR calls on the same thread can skip re-initialisation.
    static TESS_CACHE: RefCell<Option<CacheEntry>> = const { RefCell::new(None) };
}
/// OCR engine that drives a Tesseract installation through the `tesseract` crate.
#[derive(Debug)]
pub struct TesseractEngine {
    /// Data directory handed to `tesseract::Tesseract::new` as-is (`None` is passed through
    /// unchanged; presumably Tesseract then uses its own default location — confirm).
    datapath: Option<String>,
}

impl TesseractEngine {
    /// Creates an engine that will initialise Tesseract with `datapath`.
    pub fn new(datapath: Option<String>) -> Self {
        Self { datapath }
    }

    /// Maps a caller-supplied language tag to a Tesseract language code.
    ///
    /// Matching is ASCII case-insensitive and keyed on the primary subtag of a
    /// BCP-47-style tag, so `fr`, `FR` and `fr-CA` all map to `fra`, and `en-US`/`en-GB`
    /// map to `eng`. Chinese maps to `chi_tra` when a traditional-script or
    /// traditional-region subtag (`hant`, `tw`, `hk`, `mo`) is present, otherwise to
    /// `chi_sim`. Anything unrecognised — including strings that already are Tesseract
    /// codes such as `eng` or `chi_sim` — is returned unchanged (original casing kept).
    fn normalize_language(code: &str) -> String {
        let lower = code.to_ascii_lowercase();
        // Split only on '-': Tesseract's own codes use '_' (e.g. `chi_sim`), which must
        // never be broken apart.
        let mut subtags = lower.split('-');
        let primary = subtags.next().unwrap_or("");
        // Only consulted for Chinese, where script/region selects the character set.
        let traditional = subtags.any(|s| matches!(s, "hant" | "tw" | "hk" | "mo"));
        let tess_code = match primary {
            "en" => "eng",
            "fr" => "fra",
            "de" => "deu",
            "es" => "spa",
            "it" => "ita",
            "pt" => "por",
            "ru" => "rus",
            "nl" => "nld",
            "pl" => "pol",
            "tr" => "tur",
            "vi" => "vie",
            "zh" if traditional => "chi_tra",
            "zh" => "chi_sim",
            "ja" => "jpn",
            "ko" => "kor",
            "ar" => "ara",
            _ => return code.to_string(),
        };
        tess_code.to_string()
    }

    /// Normalises every code and joins them with `+`, Tesseract's multi-language
    /// separator (e.g. `["en", "de"]` -> `"eng+deu"`).
    fn joined_language(langs: &[String]) -> String {
        langs
            .iter()
            .map(|c| Self::normalize_language(c))
            .collect::<Vec<_>>()
            .join("+")
    }
}
impl OcrEngine for TesseractEngine {
    fn name(&self) -> &'static str {
        "tesseract"
    }

    /// Runs Tesseract over an encoded image and returns word-level results.
    ///
    /// A Tesseract instance is cached per thread and reused while the
    /// `(datapath, joined language)` pair stays the same. If any Tesseract step fails,
    /// the instance is discarded, and the next call initialises a fresh one instead of
    /// finding an empty cache slot.
    ///
    /// # Errors
    /// Returns `SpdfError::Ocr` when `options.languages` is empty, or when Tesseract
    /// initialisation, setting the DPI, loading the image, recognition or TSV export fails.
    fn recognize(&self, image: &[u8], options: &OcrOptions) -> SpdfResult<Vec<OcrResult>> {
        if options.languages.is_empty() {
            return Err(SpdfError::Ocr(
                "tesseract: no language codes provided".into(),
            ));
        }
        let lang = Self::joined_language(&options.languages);
        // `user_defined_dpi` is a persistent Tesseract parameter, so it is set on every call.
        // Otherwise a DPI from an earlier call would leak into later calls on the cached
        // instance. "0" is Tesseract's default, meaning "no user-defined DPI".
        let dpi = match options.dpi {
            Some(dpi) => dpi.to_string(),
            None => "0".to_string(),
        };
        let tsv = TESS_CACHE.with(|cell| -> SpdfResult<String> {
            let mut slot = cell.borrow_mut();
            let key = (self.datapath.clone(), lang.clone());
            // Move the cached instance out of the slot. Reuse it only when it was built for
            // this exact configuration and is actually present. Anything else is dropped at
            // the end of this statement, before a replacement is initialised.
            let cached = match slot.take() {
                Some(CacheEntry {
                    key: cached_key,
                    tess: Some(tess),
                }) if cached_key == key => Some(tess),
                _ => None,
            };
            let tess = match cached {
                Some(tess) => tess,
                None => {
                    let mut tess =
                        tesseract::Tesseract::new(self.datapath.as_deref(), Some(&lang))
                            .map_err(|e| {
                                SpdfError::Ocr(format!("tesseract init ({lang}): {e}"))
                            })?;
                    tess.set_page_seg_mode(PageSegMode::PsmAuto);
                    tess
                }
            };
            // Each of these calls consumes `tess` and only returns it on success. On error
            // the instance is dropped and the slot stays empty, so the next call rebuilds it.
            let tess = tess
                .set_variable("user_defined_dpi", &dpi)
                .map_err(|e| SpdfError::Ocr(format!("tesseract set user_defined_dpi: {e}")))?
                .set_image_from_mem(image)
                .map_err(|e| SpdfError::Ocr(format!("tesseract set_image: {e}")))?;
            let mut tess = tess
                .recognize()
                .map_err(|e| SpdfError::Ocr(format!("tesseract recognize: {e}")))?;
            let tsv = tess
                .get_tsv_text(0)
                .map_err(|e| SpdfError::Ocr(format!("tesseract get_tsv_text: {e}")));
            // `get_tsv_text` only borrows the instance, so it is still intact and can be
            // cached even when the TSV export itself failed.
            *slot = Some(CacheEntry {
                key,
                tess: Some(tess),
            });
            tsv
        })?;
        Ok(parse_tsv(&tsv))
    }
}
/// Converts Tesseract TSV output into word-level OCR results.
///
/// Only rows whose level column is 5 are kept. A row is also dropped when it has fewer
/// than 12 columns, a non-integer level, a negative (or unparseable) confidence, a
/// non-positive (or unparseable) width or height, or blank text. Each bounding box is
/// `[left, top, left + width, top + height]` in the TSV's coordinates, and confidence is
/// divided by 100.
fn parse_tsv(tsv: &str) -> Vec<OcrResult> {
    tsv.lines().filter_map(parse_tsv_word).collect()
}

/// Parses a single TSV row, yielding `None` for anything that is not a usable word.
fn parse_tsv_word(row: &str) -> Option<OcrResult> {
    let fields: Vec<&str> = row.split('\t').collect();
    if fields.len() < 12 {
        return None;
    }
    // A non-numeric level (e.g. the header row) is rejected by `ok()?`.
    let level = fields[0].parse::<i32>().ok()?;
    if level != 5 {
        return None;
    }
    // Unparseable geometry falls back to 0 and unparseable confidence to -1; both
    // fallbacks are then caught by the rejection check below.
    let number = |idx: usize, fallback: f64| fields[idx].parse::<f64>().unwrap_or(fallback);
    let [left, top, width, height] = [6, 7, 8, 9].map(|idx| number(idx, 0.0));
    let conf = number(10, -1.0);
    if conf < 0.0 || width <= 0.0 || height <= 0.0 {
        return None;
    }
    let text = fields[11].trim();
    if text.is_empty() {
        return None;
    }
    Some(OcrResult {
        text: text.to_owned(),
        bbox: [left, top, left + width, top + height],
        confidence: conf / 100.0,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Joins TSV rows with newlines. This avoids multi-line string literals, whose
    /// continuation lines would carry source indentation into the data.
    fn tsv_doc(rows: &[&str]) -> String {
        rows.join("\n")
    }

    #[test]
    fn parse_tsv_keeps_only_word_level_rows() {
        let tsv = tsv_doc(&[
            "level\tpage_num\tblock_num\tpar_num\tline_num\tword_num\tleft\ttop\twidth\theight\tconf\ttext",
            "1\t1\t0\t0\t0\t0\t0\t0\t100\t100\t-1\t",
            "2\t1\t1\t0\t0\t0\t0\t0\t100\t100\t-1\t",
            "5\t1\t1\t1\t1\t1\t10\t20\t30\t15\t92\tHello",
            "5\t1\t1\t1\t1\t2\t50\t20\t40\t15\t88\tworld",
            "5\t1\t1\t1\t1\t3\t0\t0\t0\t0\t-1\tnoise",
            "5\t1\t1\t1\t1\t4\t100\t20\t20\t15\t10\t",
        ]);
        let out = parse_tsv(&tsv);
        assert_eq!(out.len(), 2, "only real words survive, got {:?}", out);
        assert_eq!(out[0].text, "Hello");
        assert_eq!(out[0].bbox, [10.0, 20.0, 40.0, 35.0]);
        assert!((out[0].confidence - 0.92).abs() < 1e-9);
        assert_eq!(out[1].text, "world");
    }

    #[test]
    fn parse_tsv_skips_malformed_rows() {
        let tsv = tsv_doc(&[
            // Too few columns.
            "5\t1\t1",
            // Non-numeric level.
            "x\t1\t1\t1\t1\t1\t10\t20\t30\t15\t92\tBadLevel",
            // Unparseable width falls back to 0 and is rejected.
            "5\t1\t1\t1\t1\t1\t10\t20\tw\t15\t92\tBadWidth",
            // Unparseable confidence falls back to -1 and is rejected.
            "5\t1\t1\t1\t1\t1\t10\t20\t30\t15\tconf\tBadConf",
        ]);
        assert!(parse_tsv(&tsv).is_empty());
        assert!(parse_tsv("").is_empty());
    }

    #[test]
    fn parse_tsv_accepts_crlf_and_fractional_confidence() {
        let out = parse_tsv("5\t1\t1\t1\t1\t1\t10\t20\t30\t15\t96.5\tHello\r\n");
        assert_eq!(out.len(), 1, "got {:?}", out);
        assert_eq!(out[0].text, "Hello");
        assert!((out[0].confidence - 0.965).abs() < 1e-9);
    }

    #[test]
    fn normalize_language_maps_common_codes() {
        assert_eq!(TesseractEngine::normalize_language("en"), "eng");
        assert_eq!(TesseractEngine::normalize_language("EN-US"), "eng");
        assert_eq!(TesseractEngine::normalize_language("en-gb"), "eng");
        assert_eq!(TesseractEngine::normalize_language("zh"), "chi_sim");
        assert_eq!(TesseractEngine::normalize_language("zh-cn"), "chi_sim");
        assert_eq!(TesseractEngine::normalize_language("zh-tw"), "chi_tra");
        assert_eq!(TesseractEngine::normalize_language("eng"), "eng");
        assert_eq!(TesseractEngine::normalize_language("klingon"), "klingon");
    }

    #[test]
    fn joined_language_normalises_and_joins_with_plus() {
        let langs = ["en".to_string(), "zh-tw".to_string(), "deu".to_string()];
        assert_eq!(TesseractEngine::joined_language(&langs), "eng+chi_tra+deu");
    }
}