use std::path::{Path, PathBuf};
use tract_onnx::prelude::*;
use crate::browser::Tab;
use crate::util::base64_decode;
use crate::{Error, Result};
/// Charset definition embedded at compile time from `assets/charset.json`.
/// `parse_charset` reads its top-level `"charset"` array.
const CHARSET_JSON: &str = include_str!("assets/charset.json");
/// Default download location of the ONNX model.
/// `ensure_model` uses it when `DRISSION_OCR_MODEL_URL` is not set.
const MODEL_URL: &str = "https://raw.githubusercontent.com/86maid/ddddocr/master/model/common.onnx";
fn terr(e: impl std::fmt::Display) -> Error {
Error::msg(format!("OCR: {e}"))
}
/// OCR engine: an ONNX inference model together with the charset used to
/// turn the model's output classes back into text.
pub struct Ocr {
// Model as loaded by `from_model_path`; `recognize` clones it for every call.
model: InferenceModel,
// Class-index -> text table; `ctc_decode` never emits the entry at index 0.
charset: Vec<String>,
}
impl Ocr {
    /// Creates an engine from the model file resolved by `ensure_model`
    /// (which may download it on first use).
    ///
    /// # Errors
    /// Propagates any failure from locating/downloading the model or from
    /// [`Ocr::from_model_path`].
    pub async fn new() -> Result<Self> {
        let model_path = ensure_model().await?;
        Self::from_model_path(&model_path)
    }

    /// Creates an engine from the ONNX file at `onnx` and the embedded charset.
    ///
    /// # Errors
    /// Fails when the embedded charset cannot be parsed or the model file
    /// cannot be loaded.
    pub fn from_model_path(onnx: &Path) -> Result<Self> {
        // Fields are evaluated in written order: charset first, then the model.
        Ok(Self {
            charset: parse_charset(CHARSET_JSON)?,
            model: tract_onnx::onnx().model_for_path(onnx).map_err(terr)?,
        })
    }

    /// Recognizes the text contained in an encoded image.
    ///
    /// The image is converted by `preprocess` into a `[1, 1, 64, w]` float
    /// tensor, fed through the model, and the first output is decoded with
    /// `ctc_decode`.
    ///
    /// # Errors
    /// Fails when the image cannot be decoded or any model preparation /
    /// inference step reports an error.
    pub fn recognize(&self, image: &[u8]) -> Result<String> {
        let (pixels, width) = preprocess(image)?;

        // The model is cloned and re-optimized on every call with the concrete
        // input width (presumably because the width differs per image).
        let input_fact = InferenceFact::dt_shape(f32::datum_type(), tvec![1, 1, 64, width]);
        let shaped = self
            .model
            .clone()
            .with_input_fact(0, input_fact)
            .map_err(terr)?;
        let optimized = shaped.into_optimized().map_err(terr)?;
        let plan = optimized.into_runnable().map_err(terr)?;

        let input = tract_ndarray::Array4::<f32>::from_shape_vec((1, 1, 64, width), pixels)
            .map_err(terr)?;
        let outputs = plan
            .run(tvec![Tensor::from(input).into()])
            .map_err(terr)?;

        let first = outputs[0].clone().into_tensor();
        let scores = first.to_plain_array_view::<f32>().map_err(terr)?;
        Ok(ctc_decode(&scores, &self.charset))
    }
}
fn parse_charset(s: &str) -> Result<Vec<String>> {
let v: serde_json::Value = serde_json::from_str(s).map_err(terr)?;
let arr = v["charset"]
.as_array()
.ok_or_else(|| Error::msg("OCR: charset 缺失"))?;
Ok(arr
.iter()
.map(|x| x.as_str().unwrap_or("").to_string())
.collect())
}
/// Decodes an image and converts it into the model's input layout.
///
/// The image is resized to a height of 64 with the width scaled by the same
/// factor (at least 1), converted to 8-bit grayscale, and every pixel is
/// mapped from `0..=255` to `-1.0..=1.0`.
///
/// Returns the row-major pixel data together with the resized width.
///
/// # Errors
/// Fails when the bytes cannot be decoded or the image has a zero dimension.
fn preprocess(bytes: &[u8]) -> Result<(Vec<f32>, usize)> {
    let decoded = image::load_from_memory(bytes).map_err(terr)?;
    let src_w = decoded.width();
    let src_h = decoded.height();
    if src_w == 0 || src_h == 0 {
        return Err(Error::msg("OCR: 空图"));
    }

    let scaled_w = (src_w as f32) * 64.0 / (src_h as f32);
    let new_w = scaled_w.round().max(1.0) as usize;

    let resized = decoded.resize_exact(new_w as u32, 64, image::imageops::FilterType::Lanczos3);
    let gray = resized.to_luma8();

    // `pixels()` walks the buffer row by row (x fastest), which is the same
    // order the tensor expects.
    let data: Vec<f32> = gray
        .pixels()
        .map(|px| (px[0] as f32 / 255.0 - 0.5) / 0.5)
        .collect();
    Ok((data, new_w))
}
/// Greedy CTC decoding of a score tensor into text.
///
/// Axis selection is heuristic:
/// * the class axis is the first axis whose length equals `charset.len()`,
///   falling back to the last axis;
/// * the time axis is the first *other* axis longer than 1, falling back to
///   axis 0;
/// * every remaining axis is read at index 0.
///
/// For each time step the highest-scoring class is chosen. Class 0 is treated
/// as the blank symbol and consecutive repeats of the same class are
/// collapsed.
///
/// Returns an empty string for a zero-dimensional or empty tensor. When the
/// class axis is shorter than the charset, only the classes actually present
/// in the tensor are considered.
fn ctc_decode(view: &tract_ndarray::ArrayViewD<f32>, charset: &[String]) -> String {
    let shape = view.shape();
    // A zero-dimensional view has no axes to pick from.
    let Some(last_axis) = shape.len().checked_sub(1) else {
        return String::new();
    };
    // Any zero-length axis means there is nothing to index at all.
    if view.is_empty() {
        return String::new();
    }

    let cls_axis = shape
        .iter()
        .position(|&d| d == charset.len())
        .unwrap_or(last_axis);
    let t_axis = (0..shape.len())
        .find(|&a| a != cls_axis && shape[a] > 1)
        .unwrap_or(0);

    // When the fallback axis was chosen its length may differ from the charset
    // size; never scan past what the tensor really holds.
    let classes = charset.len().min(shape[cls_axis]);
    let steps = shape[t_axis];

    let mut out = String::new();
    let mut prev = usize::MAX;
    let mut idx = vec![0usize; shape.len()];
    for t in 0..steps {
        idx[t_axis] = t;
        let mut best = 0usize;
        let mut best_score = f32::MIN;
        for k in 0..classes {
            idx[cls_axis] = k;
            let score = view[idx.as_slice()];
            if score > best_score {
                best_score = score;
                best = k;
            }
        }
        // Skip the blank (class 0) and collapse runs of the same class.
        if best != 0
            && best != prev
            && let Some(ch) = charset.get(best)
        {
            out.push_str(ch);
        }
        prev = best;
    }
    out
}
async fn ensure_model() -> Result<PathBuf> {
if let Ok(p) = std::env::var("DRISSION_OCR_MODEL") {
let p = PathBuf::from(p);
if p.exists() {
return Ok(p);
}
return Err(Error::msg(format!(
"OCR: DRISSION_OCR_MODEL 路径不存在: {}",
p.display()
)));
}
let dir = dirs::cache_dir()
.unwrap_or_else(std::env::temp_dir)
.join("drission")
.join("ocr");
std::fs::create_dir_all(&dir).map_err(terr)?;
let path = dir.join("ddddocr_common.onnx");
if path.exists()
&& std::fs::metadata(&path)
.map(|m| m.len() > 1_000_000)
.unwrap_or(false)
{
return Ok(path);
}
let url = std::env::var("DRISSION_OCR_MODEL_URL").unwrap_or_else(|_| MODEL_URL.to_string());
tracing::info!(target: "drission::ocr", "下载 OCR 模型(~54MB,仅首次): {url}");
let bytes = reqwest::get(&url)
.await
.map_err(terr)?
.bytes()
.await
.map_err(terr)?;
if bytes.len() < 1_000_000 {
return Err(Error::msg(format!(
"OCR: 模型下载异常({} bytes)",
bytes.len()
)));
}
let tmp = path.with_extension("onnx.part");
std::fs::write(&tmp, &bytes).map_err(terr)?;
std::fs::rename(&tmp, &path).map_err(terr)?;
Ok(path)
}
/// Shared [`Ocr`] instance, initialized on first use by `Tab::ocr_image`
/// through `Ocr::new`.
static DEFAULT_OCR: tokio::sync::OnceCell<Ocr> = tokio::sync::OnceCell::const_new();
impl Tab {
    /// Runs OCR on the image element matched by `selector` and returns the
    /// recognized text.
    ///
    /// The shared OCR engine is initialized first (on first use), then the
    /// image bytes are fetched and recognized.
    ///
    /// # Errors
    /// Fails when the engine cannot be initialized, the element cannot be
    /// resolved or captured, or recognition fails.
    pub async fn ocr_image(&self, selector: &str) -> Result<String> {
        let engine = DEFAULT_OCR.get_or_try_init(Ocr::new).await?;
        let image = self.fetch_image_bytes(selector).await?;
        engine.recognize(&image)
    }

    /// Obtains the encoded bytes of the image element matched by `selector`.
    ///
    /// When the element's source contains an inline `base64,` payload that
    /// decodes to non-empty data, that payload is returned; in every other
    /// case (script failure, non-string result, no payload, decode failure,
    /// empty data) a screenshot of the element is taken instead.
    async fn fetch_image_bytes(&self, selector: &str) -> Result<Vec<u8>> {
        let el = self.ele(selector).await?;

        let inline = match el.run_js("return node.currentSrc||node.src||'';").await {
            Ok(src) => src
                .as_str()
                .and_then(|s| {
                    let marker = s.find("base64,")?;
                    // Skip the 7-byte "base64," marker itself.
                    base64_decode(&s[marker + 7..])
                })
                .filter(|decoded| !decoded.is_empty()),
            Err(_) => None,
        };

        match inline {
            Some(decoded) => Ok(decoded),
            None => el.screenshot_bytes().await,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The embedded charset parses, is large, starts with the blank entry and
    /// contains ordinary digits and letters.
    #[test]
    fn charset_loads_and_blank_first() {
        let charset = parse_charset(CHARSET_JSON).unwrap();
        assert!(charset.len() > 1000);
        assert_eq!(charset[0], "");
        assert!(charset.iter().any(|c| c == "5"));
        assert!(charset.iter().any(|c| c == "z"));
    }

    /// A one-hot `[T, 1, C]` tensor for the class sequence `a a _ b b`
    /// decodes to `"ab"`: repeats are collapsed and the blank is dropped.
    #[test]
    fn ctc_collapses_repeats_and_blanks() {
        let charset: Vec<String> = ["", "a", "b", "c"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let steps = [1usize, 1, 0, 2, 2];

        let mut scores = tract_ndarray::Array3::<f32>::zeros((steps.len(), 1, charset.len()));
        for (t, &class) in steps.iter().enumerate() {
            scores[[t, 0, class]] = 1.0;
        }

        let scores = scores.into_dyn();
        assert_eq!(ctc_decode(&scores.view(), &charset), "ab");
    }
}