use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use anyhow::{Context, Result};
use sha2::{Digest, Sha256};
use super::{OcrEngine, default_engine};
/// Hard upper bound on how many images a single page will be OCR'd for.
pub const MAX_IMAGES_PER_PAGE: usize = 10;
/// Images whose alt text is at least this long are assumed to be adequately
/// described already and are skipped.
const MIN_ALT_TEXT_LEN: usize = 20;
/// Cached OCR results expire after 30 days.
const CACHE_TTL_SECS: u64 = 30 * 24 * 60 * 60;
/// Enriches fetched HTML pages by running OCR over images that lack useful
/// alt text, caching recognized text on disk keyed by image content hash.
pub struct FetchOcrEnricher {
    /// Pluggable OCR backend.
    engine: Arc<dyn OcrEngine>,
    /// Directory holding `<sha256>.txt` cache entries.
    cache_dir: PathBuf,
    /// Per-page cap on how many images are OCR'd.
    max_per_page: usize,
}
impl Default for FetchOcrEnricher {
    /// Equivalent to [`FetchOcrEnricher::with_max`] with [`MAX_IMAGES_PER_PAGE`].
    fn default() -> Self {
        Self::with_max(MAX_IMAGES_PER_PAGE)
    }
}
impl FetchOcrEnricher {
    /// Creates an enricher with the default OCR engine, failing if the
    /// on-disk cache directory cannot be created.
    pub fn new() -> Result<Self> {
        let cache_dir = default_cache_dir()?;
        std::fs::create_dir_all(&cache_dir)
            .with_context(|| format!("create OCR cache dir {}", cache_dir.display()))?;
        Ok(Self {
            engine: Arc::from(default_engine()),
            cache_dir,
            max_per_page: MAX_IMAGES_PER_PAGE,
        })
    }

    /// Infallible constructor with a custom per-page image cap.
    ///
    /// Falls back to `/tmp/nab-ocr-cache` when no platform cache location can
    /// be determined. The cache directory is created best-effort here: if the
    /// creation fails, OCR still works but results are recomputed on each
    /// fetch. (Previously this path never created the directory at all, so
    /// every cache write in `ocr_url` failed silently.)
    pub fn with_max(max_per_page: usize) -> Self {
        let cache_dir = default_cache_dir().unwrap_or_else(|_| PathBuf::from("/tmp/nab-ocr-cache"));
        if let Err(e) = std::fs::create_dir_all(&cache_dir) {
            tracing::debug!("OCR cache dir creation failed for {}: {e}", cache_dir.display());
        }
        Self {
            engine: Arc::from(default_engine()),
            cache_dir,
            max_per_page,
        }
    }

    /// Fully dependency-injected constructor (useful for tests).
    pub fn with_engine_and_cache(
        engine: Arc<dyn OcrEngine>,
        cache_dir: PathBuf,
        max_per_page: usize,
    ) -> Self {
        Self {
            engine,
            cache_dir,
            max_per_page,
        }
    }

    /// Whether the underlying OCR engine can run on this machine.
    pub fn is_available(&self) -> bool {
        self.engine.is_available()
    }

    /// Extracts poorly-described `<img>` URLs from `html`, OCRs up to
    /// `max_per_page` of them, and returns `url -> recognized text` for every
    /// image that produced non-empty text.
    ///
    /// Per-image failures are logged at debug level and skipped; this method
    /// never fails as a whole.
    pub async fn enrich_images(
        &self,
        html: &str,
        base_url: &str,
        http_client: &reqwest::Client,
    ) -> HashMap<String, String> {
        let candidates = extract_image_candidates(html, base_url);
        let mut results = HashMap::new();
        for url in candidates.into_iter().take(self.max_per_page) {
            match self.ocr_url(&url, http_client).await {
                Ok(Some(text)) if !text.trim().is_empty() => {
                    results.insert(url, text.trim().to_string());
                }
                Ok(_) => {}
                Err(e) => {
                    tracing::debug!(url = %url, "OCR skipped: {e}");
                }
            }
        }
        results
    }

    /// Appends ` [Image: …]` annotations after markdown image syntax whose
    /// URL appears in `ocr_results`. Returns the input unchanged when there
    /// is nothing to annotate.
    pub fn annotate_markdown(
        &self,
        markdown: &str,
        ocr_results: &HashMap<String, String>,
    ) -> String {
        if ocr_results.is_empty() {
            return markdown.to_string();
        }
        annotate_markdown_images(markdown, ocr_results)
    }

    /// Downloads one image and OCRs it, consulting the on-disk cache first.
    ///
    /// Returns `Ok(None)` for empty response bodies. Cache entries are keyed
    /// by the SHA-256 of the image bytes, so a URL whose content changes is
    /// re-OCR'd automatically.
    async fn ocr_url(&self, url: &str, http_client: &reqwest::Client) -> Result<Option<String>> {
        let response = http_client
            .get(url)
            .send()
            .await
            .with_context(|| format!("fetch image {url}"))?
            // Reject 4xx/5xx up front: an error page body is HTML, not an
            // image, and would otherwise be handed to the OCR engine.
            .error_for_status()
            .with_context(|| format!("image request failed {url}"))?;
        let bytes = response
            .bytes()
            .await
            .with_context(|| format!("read image bytes {url}"))?;
        if bytes.is_empty() {
            return Ok(None);
        }
        let hash = hex_sha256(&bytes);
        let cache_path = self.cache_dir.join(format!("{hash}.txt"));
        if let Some(cached) = read_cache(&cache_path) {
            return Ok(Some(cached));
        }
        let result = self
            .engine
            .ocr_image(&bytes)
            .await
            .with_context(|| format!("OCR failed for {url}"))?;
        let text = result.text;
        // Best-effort cache write; a miss next time is the only consequence.
        if let Err(e) = std::fs::write(&cache_path, &text) {
            tracing::debug!("OCR cache write failed for {hash}: {e}");
        }
        Ok(Some(text))
    }
}
/// Returns absolute URLs of `<img>` elements whose alt text is missing or
/// too thin to be useful. `data:` URIs are skipped; relative `src` values
/// are resolved against `base_url`.
fn extract_image_candidates(html: &str, base_url: &str) -> Vec<String> {
    use scraper::{Html, Selector};
    let doc = Html::parse_document(html);
    let Ok(sel) = Selector::parse("img") else {
        return vec![];
    };
    let base = url::Url::parse(base_url).ok();
    doc.select(&sel)
        .filter_map(|el| {
            // Measure trimmed characters, not raw bytes: byte length
            // overcounts description for multibyte text (7 CJK chars are 21
            // bytes), and a whitespace-only alt previously counted as
            // "well described".
            let alt = el.value().attr("alt").unwrap_or("").trim();
            if alt.chars().count() >= MIN_ALT_TEXT_LEN {
                return None;
            }
            let src = el.value().attr("src")?;
            // Inline data URIs are already local; nothing to fetch.
            if src.starts_with("data:") {
                return None;
            }
            resolve_url(src, base.as_ref())
        })
        .collect()
}
/// Turns an `<img src>` value into an absolute URL string.
///
/// Already-absolute http(s) URLs pass through untouched; anything else
/// (relative paths, protocol-relative `//…`) is resolved against the page's
/// base URL when one parsed successfully.
fn resolve_url(src: &str, base: Option<&url::Url>) -> Option<String> {
    if src.starts_with("https://") || src.starts_with("http://") {
        Some(src.to_owned())
    } else {
        base?.join(src).ok().map(|joined| joined.to_string())
    }
}
fn annotate_markdown_images(markdown: &str, ocr_results: &HashMap<String, String>) -> String {
let mut output = String::with_capacity(markdown.len() + ocr_results.len() * 40);
let chars: Vec<char> = markdown.chars().collect();
let n = chars.len();
let mut i = 0;
while i < n {
if i + 1 < n
&& chars[i] == '!'
&& chars[i + 1] == '['
&& let Some((end, url)) = parse_markdown_image(&chars, i)
{
output.push_str(&markdown[char_byte_offset(&chars, i)..char_byte_offset(&chars, end)]);
if let Some(text) = ocr_results.get(&url) {
let clean = text.replace('\n', " ");
let _ = write!(output, " [Image: {clean}]");
}
i = end;
continue;
}
output.push(chars[i]);
i += 1;
}
output
}
/// Parses one markdown image starting at `start`, which must point at the
/// `!` of a `![` sequence. Returns `(index one past the image, url)` on
/// success, or `None` when the syntax is incomplete (unclosed alt text,
/// missing `(`, or an empty/unterminated URL). Nested square brackets inside
/// the alt text are honoured, and an optional title (`(url "title")`) is
/// skipped up to the closing paren.
fn parse_markdown_image(chars: &[char], start: usize) -> Option<(usize, String)> {
    let len = chars.len();
    let mut pos = start + 2;
    let mut open_brackets = 1usize;
    // Scan past the alt text, tracking bracket nesting depth.
    while pos < len {
        match chars[pos] {
            '[' => open_brackets += 1,
            ']' => {
                open_brackets -= 1;
                if open_brackets == 0 {
                    pos += 1;
                    break;
                }
            }
            _ => {}
        }
        pos += 1;
    }
    // The `(` must immediately follow the closing `]`.
    if pos >= len || open_brackets > 0 || chars[pos] != '(' {
        return None;
    }
    pos += 1;
    let url_begin = pos;
    // The URL runs until a `)` or whitespace (which starts a title).
    while pos < len && chars[pos] != ')' && !chars[pos].is_whitespace() {
        pos += 1;
    }
    if pos == url_begin || pos >= len {
        return None;
    }
    let url: String = chars[url_begin..pos].iter().collect();
    // Skip any title text up to the closing paren, then step past it.
    while pos < len && chars[pos] != ')' {
        pos += 1;
    }
    if pos < len {
        pos += 1;
    }
    Some((pos, url))
}
/// UTF-8 byte length of the first `char_idx` characters.
///
/// Panics (like slice indexing) when `char_idx` exceeds `chars.len()`.
fn char_byte_offset(chars: &[char], char_idx: usize) -> usize {
    let prefix = &chars[..char_idx];
    prefix.iter().fold(0usize, |total, ch| total + ch.len_utf8())
}
fn hex_sha256(bytes: &[u8]) -> String {
let digest = Sha256::digest(bytes);
hex::encode(digest)
}
/// Platform-appropriate OCR cache location: the local-data directory when
/// available, otherwise the home directory, with `nab/cache/ocr` appended.
/// Errors only when neither location can be determined.
fn default_cache_dir() -> Result<PathBuf> {
    match dirs::data_local_dir().or_else(dirs::home_dir) {
        Some(base) => Ok(base.join("nab/cache/ocr")),
        None => Err(anyhow::anyhow!("cannot determine home directory")),
    }
}
/// Reads a cached OCR result, treating every failure as a cache miss.
///
/// Returns `Some(content)` only when the file exists, its mtime is within
/// the TTL, and it reads cleanly as UTF-8. A clock that reports the file as
/// newer than "now" also yields `None` (duration_since errs), which safely
/// degrades to recomputing the OCR.
fn read_cache(path: &std::path::Path) -> Option<String> {
    let modified = std::fs::metadata(path).ok()?.modified().ok()?;
    let age = SystemTime::now().duration_since(modified).ok()?;
    if age > Duration::from_secs(CACHE_TTL_SECS) {
        None
    } else {
        std::fs::read_to_string(path).ok()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn extract_image_candidates_skips_good_alt_text() {
        let html = r#"<img src="photo.jpg" alt="A landscape photo of mountains">"#;
        let candidates = extract_image_candidates(html, "https://example.com/page");
        assert!(candidates.is_empty(), "should skip well-described image");
    }
    #[test]
    fn extract_image_candidates_includes_thin_alt_text() {
        let html = r#"<img src="chart.png" alt="">"#;
        let candidates = extract_image_candidates(html, "https://example.com/page");
        assert_eq!(candidates.len(), 1);
        assert!(candidates[0].contains("chart.png"));
    }
    #[test]
    fn enrich_images_respects_max_per_page() {
        // Offline check of the capping building block used by enrich_images
        // (the async method itself needs a network client).
        let html: String = (1..=5)
            .map(|i| format!(r#"<img src="img{i}.png" alt="">"#))
            .collect::<Vec<_>>()
            .join("\n");
        let candidates = extract_image_candidates(&html, "https://example.com/");
        let capped: Vec<_> = candidates.into_iter().take(3).collect();
        assert_eq!(capped.len(), 3);
    }
    #[test]
    fn annotate_markdown_inserts_ocr_annotation() {
        // The markdown must actually contain image syntax whose URL matches
        // the OCR map, otherwise no annotation can ever be inserted.
        let markdown = "# Title\n\n![chart](https://example.com/chart.png)\n\nSome text.";
        let mut ocr = HashMap::new();
        ocr.insert(
            "https://example.com/chart.png".to_string(),
            "Q3 Revenue: $42M".to_string(),
        );
        let enricher = FetchOcrEnricher::with_max(10);
        let result = enricher.annotate_markdown(markdown, &ocr);
        assert!(
            result.contains("[Image: Q3 Revenue: $42M]"),
            "annotation missing in: {result}"
        );
        assert!(result.contains("# Title"), "original content preserved");
    }
    #[test]
    fn annotate_markdown_leaves_no_match_unchanged() {
        // Non-empty OCR map whose URL does not appear in the markdown:
        // exercises the real annotation pass, not just the early return.
        let markdown = "![pic](https://example.com/a.png)";
        let mut ocr = HashMap::new();
        ocr.insert(
            "https://example.com/b.png".to_string(),
            "unrelated".to_string(),
        );
        let enricher = FetchOcrEnricher::with_max(10);
        let result = enricher.annotate_markdown(markdown, &ocr);
        assert_eq!(result, markdown);
    }
    #[test]
    fn annotate_markdown_collapses_newlines_in_ocr_text() {
        let markdown = "![](https://example.com/img.png)";
        let mut ocr = HashMap::new();
        ocr.insert(
            "https://example.com/img.png".to_string(),
            "Line one\nLine two".to_string(),
        );
        let enricher = FetchOcrEnricher::with_max(10);
        let result = enricher.annotate_markdown(markdown, &ocr);
        assert!(
            result.contains("[Image: Line one Line two]"),
            "got: {result}"
        );
    }
    #[test]
    fn read_cache_returns_none_for_stale_file() {
        use std::io::Write;
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("stale.txt");
        let mut f = std::fs::File::create(&path).expect("create");
        f.write_all(b"cached text").expect("write");
        // One day past the 30-day TTL.
        let old_time = SystemTime::now() - Duration::from_secs(31 * 24 * 60 * 60);
        f.set_modified(old_time).expect("set_modified");
        drop(f);
        let result = read_cache(&path);
        assert!(
            result.is_none(),
            "expected None for stale cache, got: {result:?}"
        );
    }
    #[test]
    fn read_cache_returns_content_for_fresh_file() {
        use std::io::Write;
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("fresh.txt");
        let mut f = std::fs::File::create(&path).expect("create");
        f.write_all(b"recognized text").expect("write");
        drop(f);
        let result = read_cache(&path);
        assert_eq!(result.as_deref(), Some("recognized text"));
    }
}
use std::fmt::Write as _;