use std::io::Read;
use std::path::PathBuf;
use std::time::Duration;
use sha2::{Digest, Sha256};
use crate::error::{Error, Result};
const DETECTION_URL: &str = "https://ocrs-models.s3-accelerate.amazonaws.com/text-detection.rten";
const RECOGNITION_URL: &str =
"https://ocrs-models.s3-accelerate.amazonaws.com/text-recognition.rten";
const DETECTION_FILENAME: &str = "text-detection.rten";
const RECOGNITION_FILENAME: &str = "text-recognition.rten";
const DETECTION_ENV: &str = "WAYDRIVER_OCRS_DETECTION_MODEL";
const RECOGNITION_ENV: &str = "WAYDRIVER_OCRS_RECOGNITION_MODEL";
const DETECTION_SHA256: &str = "f15cfb56bd02c4bf478a20343986504a1f01e1665c2b3a0ad66340f054b1b5ca";
const RECOGNITION_SHA256: &str = "e484866d4cce403175bd8d00b128feb08ab42e208de30e42cd9889d8f1735a6e";
pub(super) fn ensure_models() -> Result<(PathBuf, PathBuf)> {
if let (Ok(det), Ok(rec)) = (std::env::var(DETECTION_ENV), std::env::var(RECOGNITION_ENV)) {
tracing::debug!(
detection = %det, recognition = %rec,
"visual: using ocrs model paths from env vars (SHA-256 verification skipped)"
);
return Ok((PathBuf::from(det), PathBuf::from(rec)));
}
let cache_dir = waydriver_cache_dir()?;
let detection_path = cache_dir.join(DETECTION_FILENAME);
let recognition_path = cache_dir.join(RECOGNITION_FILENAME);
ensure_one_model(&cache_dir, &detection_path, DETECTION_URL, DETECTION_SHA256)?;
ensure_one_model(
&cache_dir,
&recognition_path,
RECOGNITION_URL,
RECOGNITION_SHA256,
)?;
Ok((detection_path, recognition_path))
}
fn ensure_one_model(
cache_dir: &std::path::Path,
dest: &std::path::Path,
url: &str,
expected_sha256: &str,
) -> Result<()> {
if dest.exists() {
match verify_sha256(dest, expected_sha256) {
Ok(()) => return Ok(()),
Err(e) => {
tracing::warn!(
path = %dest.display(),
err = %e,
"visual: cached model failed SHA-256 verification, re-downloading"
);
let _ = std::fs::remove_file(dest);
}
}
}
std::fs::create_dir_all(cache_dir).map_err(|e| {
Error::visual(format!(
"failed to create cache dir {}: {e}",
cache_dir.display()
))
})?;
download_to(url, dest, expected_sha256)
}
pub(super) fn verify_sha256(path: &std::path::Path, expected_hex: &str) -> Result<()> {
let mut f = std::fs::File::open(path).map_err(|e| {
Error::visual(format!(
"failed to open {} for SHA-256 verification: {e}",
path.display()
))
})?;
let mut hasher = Sha256::new();
let mut buf = [0u8; 64 * 1024];
loop {
let n = f.read(&mut buf).map_err(|e| {
Error::visual(format!(
"read error during SHA-256 of {}: {e}",
path.display()
))
})?;
if n == 0 {
break;
}
hasher.update(&buf[..n]);
}
let got = hex::encode(hasher.finalize());
if got.eq_ignore_ascii_case(expected_hex) {
Ok(())
} else {
Err(Error::visual(format!(
"SHA-256 mismatch for {}: expected {}, got {}. If upstream rebuilt the ocrs model files, \
bump DETECTION_SHA256 / RECOGNITION_SHA256 in models.rs, or set \
WAYDRIVER_OCRS_DETECTION_MODEL / WAYDRIVER_OCRS_RECOGNITION_MODEL to point at known-good files.",
path.display(),
expected_hex,
got,
)))
}
}
fn waydriver_cache_dir() -> Result<PathBuf> {
let base = if let Ok(xdg) = std::env::var("XDG_CACHE_HOME") {
PathBuf::from(xdg)
} else if let Ok(home) = std::env::var("HOME") {
PathBuf::from(home).join(".cache")
} else {
return Err(Error::visual(
"neither XDG_CACHE_HOME nor HOME is set; cannot locate cache directory \
for ocrs model files. Set WAYDRIVER_OCRS_DETECTION_MODEL and \
WAYDRIVER_OCRS_RECOGNITION_MODEL to override.",
));
};
Ok(base.join("waydriver").join("ocrs-models"))
}
fn download_to(url: &str, dest_path: &std::path::Path, expected_hex: &str) -> Result<()> {
let partial = dest_path.with_extension("rten.partial");
tracing::info!(%url, dest = %dest_path.display(), "visual: downloading ocrs model");
let started = std::time::Instant::now();
let agent = ureq::Agent::config_builder()
.timeout_global(Some(Duration::from_secs(300)))
.build()
.new_agent();
let response = agent
.get(url)
.call()
.map_err(|e| Error::visual(format!("failed to fetch {url}: {e}")))?;
let mut out = std::fs::File::create(&partial)
.map_err(|e| Error::visual(format!("failed to create {}: {e}", partial.display())))?;
let mut reader = response.into_body().into_reader();
let mut buf = [0u8; 64 * 1024];
let mut total = 0u64;
let mut hasher = Sha256::new();
loop {
let n = reader
.read(&mut buf)
.map_err(|e| Error::visual(format!("read error from {url}: {e}")))?;
if n == 0 {
break;
}
use std::io::Write;
out.write_all(&buf[..n])
.map_err(|e| Error::visual(format!("write error: {e}")))?;
hasher.update(&buf[..n]);
total += n as u64;
}
drop(out);
let got = hex::encode(hasher.finalize());
if !got.eq_ignore_ascii_case(expected_hex) {
let _ = std::fs::remove_file(&partial);
return Err(Error::visual(format!(
"downloaded {url} but SHA-256 didn't match: expected {expected_hex}, got {got}. \
If upstream rebuilt the model files, bump the SHA-256 constants in \
crates/waydriver/src/visual/models.rs."
)));
}
std::fs::rename(&partial, dest_path).map_err(|e| {
Error::visual(format!(
"failed to rename {} → {}: {e}",
partial.display(),
dest_path.display()
))
})?;
tracing::info!(
bytes = total,
elapsed_ms = started.elapsed().as_millis(),
dest = %dest_path.display(),
"visual: ocrs model download complete"
);
Ok(())
}