use std::path::{Path, PathBuf};
use tract_onnx::prelude::*;
#[cfg(feature = "camoufox")]
use crate::browser::Tab;
#[cfg(feature = "camoufox")]
use crate::util::base64_decode;
use crate::{Error, Result};
mod glyph;
pub use glyph::{GlyphMatcher, SampleBank};
/// Charset for the default OCR model, embedded at compile time from `assets/charset.json`.
/// Parsed by `parse_charset`, which reads its top-level `"charset"` array.
const CHARSET_JSON: &str = include_str!("assets/charset.json");
/// Default download URL of the recognition ONNX model; `ensure_model` lets the
/// `DRISSION_OCR_MODEL_URL` environment variable override it.
const MODEL_URL: &str = "https://raw.githubusercontent.com/86maid/ddddocr/master/model/common.onnx";
fn terr(e: impl std::fmt::Display) -> Error {
Error::msg(format!("OCR: {e}"))
}
/// Text recognizer: an ONNX model loaded through tract plus the charset that maps output
/// class indices to strings. Class 0 is decoded as the CTC blank by `ctc_decode`.
pub struct Ocr {
/// Un-specialized inference model; each call pins the input shape to `[1, 1, 64, w]`.
model: InferenceModel,
/// Output class labels: `charset[k]` is the text emitted for class `k`.
charset: Vec<String>,
}
impl Ocr {
pub async fn new() -> Result<Self> {
let path = ensure_model().await?;
let charset = match std::env::var("DRISSION_OCR_CHARSET") {
Ok(p) => load_charset_file(Path::new(&p))?,
Err(_) => parse_charset(CHARSET_JSON)?,
};
Self::from_model_path_with_charset(&path, charset)
}
pub fn from_model_path(onnx: &Path) -> Result<Self> {
Self::from_model_path_with_charset(onnx, parse_charset(CHARSET_JSON)?)
}
pub fn from_model_path_with_charset(onnx: &Path, charset: Vec<String>) -> Result<Self> {
if charset.len() < 2 {
return Err(Error::msg("OCR: 字符集过小(至少 blank + 1 字)"));
}
if charset.first().map(String::as_str) != Some("") {
tracing::warn!(target: "drission::ocr",
"OCR 自定义字符集首项不是 CTC blank(空串);若识别结果全乱,请在 charset[0] 处补一个空串");
}
let model = tract_onnx::onnx().model_for_path(onnx).map_err(terr)?;
Ok(Self { model, charset })
}
pub fn from_files(onnx: &Path, charset: &Path) -> Result<Self> {
Self::from_model_path_with_charset(onnx, load_charset_file(charset)?)
}
pub fn set_model(&mut self, onnx: &Path) -> Result<()> {
self.model = tract_onnx::onnx().model_for_path(onnx).map_err(terr)?;
Ok(())
}
pub fn set_model_with_charset(&mut self, onnx: &Path, charset: Vec<String>) -> Result<()> {
if charset.len() < 2 {
return Err(Error::msg("OCR: 字符集过小(至少 blank + 1 字)"));
}
self.model = tract_onnx::onnx().model_for_path(onnx).map_err(terr)?;
self.charset = charset;
Ok(())
}
pub fn charset_len(&self) -> usize {
self.charset.len()
}
pub async fn default_model_path() -> Result<PathBuf> {
ensure_model().await
}
pub fn recognize(&self, image: &[u8]) -> Result<String> {
let (data, w) = preprocess(image)?;
let runnable = self
.model
.clone()
.with_input_fact(
0,
InferenceFact::dt_shape(f32::datum_type(), tvec![1, 1, 64, w]),
)
.map_err(terr)?
.into_optimized()
.map_err(terr)?
.into_runnable()
.map_err(terr)?;
let input =
tract_ndarray::Array4::<f32>::from_shape_vec((1, 1, 64, w), data).map_err(terr)?;
let out = runnable
.run(tvec![Tensor::from(input).into()])
.map_err(terr)?;
let t = out[0].clone().into_tensor();
let view = t.to_plain_array_view::<f32>().map_err(terr)?;
Ok(ctc_decode(&view, &self.charset))
}
pub fn char_affinity(&self, image: &[u8], chars: &[char]) -> Result<Vec<f32>> {
let (data, w) = preprocess(image)?;
let runnable = self
.model
.clone()
.with_input_fact(
0,
InferenceFact::dt_shape(f32::datum_type(), tvec![1, 1, 64, w]),
)
.map_err(terr)?
.into_typed()
.map_err(terr)?
.into_runnable()
.map_err(terr)?;
let input =
tract_ndarray::Array4::<f32>::from_shape_vec((1, 1, 64, w), data).map_err(terr)?;
let out = runnable
.run(tvec![Tensor::from(input).into()])
.map_err(terr)?;
let t = out[0].clone().into_tensor();
let view = t.to_plain_array_view::<f32>().map_err(terr)?;
let shape = view.shape();
let c = self.charset.len();
let cls_axis = shape
.iter()
.position(|&d| d == c)
.unwrap_or(shape.len() - 1);
let t_axis = (0..shape.len())
.find(|&a| a != cls_axis && shape[a] > 1)
.unwrap_or(0);
let tn = shape[t_axis];
let idxs: Vec<Option<usize>> = chars
.iter()
.map(|ch| {
let s = ch.to_string();
self.charset.iter().position(|x| x == &s)
})
.collect();
let mut best = vec![0f32; chars.len()];
let mut idx = vec![0usize; shape.len()];
for ti in 0..tn {
idx[t_axis] = ti;
let mut maxl = f32::MIN;
for k in 0..c {
idx[cls_axis] = k;
let v = view[idx.as_slice()];
if v > maxl {
maxl = v;
}
}
let mut sum = 0f32;
for k in 0..c {
idx[cls_axis] = k;
sum += (view[idx.as_slice()] - maxl).exp();
}
if sum <= 0.0 {
continue;
}
for (j, oi) in idxs.iter().enumerate() {
if let Some(k) = oi {
idx[cls_axis] = *k;
let p = (view[idx.as_slice()] - maxl).exp() / sum;
if p > best[j] {
best[j] = p;
}
}
}
}
Ok(best)
}
}
/// Axis-aligned detection box in source-image pixel coordinates, with the detector's score.
#[derive(Debug, Clone, Copy)]
pub struct BBox {
    pub x1: u32,
    pub y1: u32,
    pub x2: u32,
    pub y2: u32,
    pub score: f32,
}

impl BBox {
    /// Midpoint of the box, rounded down.
    ///
    /// The sum is computed in `u64`, so coordinates near `u32::MAX` cannot overflow. The fields
    /// are public, so callers may construct arbitrary boxes. The average of two `u32`s always
    /// fits back into `u32`.
    pub fn center(&self) -> (u32, u32) {
        let mid = |a: u32, b: u32| ((u64::from(a) + u64::from(b)) / 2) as u32;
        (mid(self.x1, self.x2), mid(self.y1, self.y2))
    }

    /// Horizontal extent `x2 - x1`, or 0 for an inverted box.
    pub fn width(&self) -> u32 {
        self.x2.saturating_sub(self.x1)
    }

    /// Vertical extent `y2 - y1`, or 0 for an inverted box.
    pub fn height(&self) -> u32 {
        self.y2.saturating_sub(self.y1)
    }
}
/// Side length of the square detector input (the model is pinned to `[1, 3, DET_SIZE, DET_SIZE]`).
const DET_SIZE: usize = 416;
/// Feature-map strides; `det_grids` emits one anchor cell per `(DET_SIZE / stride)^2` position.
const DET_STRIDES: [usize; 3] = [8, 16, 32];
/// Candidates whose `objectness * class` score is below this are dropped in `decode_det`.
const DET_SCORE_THR: f32 = 0.1;
/// IoU above which a lower-scored box is suppressed by `nms_boxes`.
const DET_NMS_THR: f32 = 0.45;
/// Default download URL of the detector model; `DRISSION_DET_MODEL_URL` overrides it.
const DET_MODEL_URL: &str =
"https://raw.githubusercontent.com/86maid/ddddocr/master/model/common_det.onnx";
/// Character detector: a fixed-input-size ONNX model already specialized and optimized by
/// `load_det_model`.
pub struct Det {
model: TypedModel,
}
impl Det {
pub async fn new() -> Result<Self> {
let path = ensure_det_model().await?;
Self::from_model_path(&path)
}
pub fn from_model_path(onnx: &Path) -> Result<Self> {
Ok(Self {
model: load_det_model(onnx)?,
})
}
pub fn set_model(&mut self, onnx: &Path) -> Result<()> {
self.model = load_det_model(onnx)?;
Ok(())
}
pub async fn default_model_path() -> Result<PathBuf> {
ensure_det_model().await
}
pub fn detect(&self, image: &[u8]) -> Result<Vec<BBox>> {
let img = image::load_from_memory(image).map_err(terr)?;
let (ow, oh) = (img.width(), img.height());
if ow == 0 || oh == 0 {
return Err(Error::msg("DET: 空图"));
}
let ratio = (DET_SIZE as f32 / ow as f32).min(DET_SIZE as f32 / oh as f32);
let rw = ((ow as f32 * ratio).round().max(1.0) as u32).min(DET_SIZE as u32);
let rh = ((oh as f32 * ratio).round().max(1.0) as u32).min(DET_SIZE as u32);
let resized = img
.resize_exact(rw, rh, image::imageops::FilterType::Triangle)
.to_rgb8();
let plane = DET_SIZE * DET_SIZE;
let mut data = vec![114f32; 3 * plane];
for y in 0..rh {
for x in 0..rw {
let p = resized.get_pixel(x, y);
let idx = y as usize * DET_SIZE + x as usize;
data[idx] = p[0] as f32;
data[plane + idx] = p[1] as f32;
data[2 * plane + idx] = p[2] as f32;
}
}
let input = tract_ndarray::Array4::<f32>::from_shape_vec((1, 3, DET_SIZE, DET_SIZE), data)
.map_err(terr)?;
let runnable = self.model.clone().into_runnable().map_err(terr)?;
let out = runnable
.run(tvec![Tensor::from(input).into()])
.map_err(terr)?;
let t = out[0].clone().into_tensor();
let view = t.to_plain_array_view::<f32>().map_err(terr)?;
let flat: Vec<f32> = view.iter().copied().collect();
Ok(decode_det(&flat, ratio, ow, oh))
}
}
/// Loads the detector ONNX file, pins input 0 to `f32[1, 3, DET_SIZE, DET_SIZE]` and returns
/// the optimized typed model.
fn load_det_model(onnx: &Path) -> Result<TypedModel> {
    let graph = tract_onnx::onnx().model_for_path(onnx).map_err(terr)?;
    let input_fact = InferenceFact::dt_shape(f32::datum_type(), tvec![1, 3, DET_SIZE, DET_SIZE]);
    let shaped = graph.with_input_fact(0, input_fact).map_err(terr)?;
    shaped.into_optimized().map_err(terr)
}
/// Enumerates every anchor cell as `(col, row, stride)`, stride by stride in `DET_STRIDES`
/// order and row-major within each stride. This matches the order of rows in the detector
/// output consumed by `decode_det`.
fn det_grids() -> Vec<(f32, f32, f32)> {
    DET_STRIDES
        .iter()
        .flat_map(|&stride| {
            let cells = DET_SIZE / stride;
            (0..cells).flat_map(move |row| {
                (0..cells).map(move |col| (col as f32, row as f32, stride as f32))
            })
        })
        .collect()
}
/// Decodes the detector's flat output (rows of 6 values: `dx, dy, log_w, log_h, obj, cls`)
/// into boxes in source-image pixels, then applies NMS.
///
/// `ratio` is the letterbox scale used in `Det::detect`. Coordinates are divided by it and
/// clamped to `[0, ow - 1] x [0, oh - 1]`. Only as many rows as there are anchor cells are
/// read.
fn decode_det(flat: &[f32], ratio: f32, ow: u32, oh: u32) -> Vec<BBox> {
    let (max_x, max_y) = (ow as f32 - 1.0, oh as f32 - 1.0);
    let candidates: Vec<BBox> = det_grids()
        .into_iter()
        .zip(flat.chunks_exact(6))
        .filter_map(|((gx, gy, stride), row)| {
            let score = row[4] * row[5];
            if score < DET_SCORE_THR {
                return None;
            }
            let cx = (row[0] + gx) * stride;
            let cy = (row[1] + gy) * stride;
            let half_w = row[2].exp() * stride / 2.0;
            let half_h = row[3].exp() * stride / 2.0;
            let x1 = ((cx - half_w) / ratio).clamp(0.0, max_x);
            let y1 = ((cy - half_h) / ratio).clamp(0.0, max_y);
            let x2 = ((cx + half_w) / ratio).clamp(0.0, max_x);
            let y2 = ((cy + half_h) / ratio).clamp(0.0, max_y);
            (x2 > x1 && y2 > y1).then(|| BBox {
                x1: x1 as u32,
                y1: y1 as u32,
                x2: x2 as u32,
                y2: y2 as u32,
                score,
            })
        })
        .collect();
    nms_boxes(candidates, DET_NMS_THR)
}
/// Greedy non-maximum suppression.
///
/// Boxes are visited in descending score order (stable sort; incomparable scores compare
/// equal). A box is kept unless its IoU with an already-kept box exceeds `thr`.
fn nms_boxes(mut boxes: Vec<BBox>, thr: f32) -> Vec<BBox> {
    boxes.sort_by(|a, b| match b.score.partial_cmp(&a.score) {
        Some(order) => order,
        None => std::cmp::Ordering::Equal,
    });
    let mut kept: Vec<BBox> = Vec::with_capacity(boxes.len());
    for cand in boxes {
        let suppressed = kept.iter().any(|k| box_iou(&cand, k) > thr);
        if !suppressed {
            kept.push(cand);
        }
    }
    kept
}
/// Intersection-over-union of two boxes; 0.0 when they do not overlap with positive area.
fn box_iou(a: &BBox, b: &BBox) -> f32 {
    let (left, top) = (a.x1.max(b.x1), a.y1.max(b.y1));
    let (right, bottom) = (a.x2.min(b.x2), a.y2.min(b.y2));
    if !(right > left && bottom > top) {
        return 0.0;
    }
    // A positive intersection implies both boxes have x2 > x1 and y2 > y1,
    // so the subtractions below cannot underflow.
    let area = |bx: &BBox| ((bx.x2 - bx.x1) as f32) * ((bx.y2 - bx.y1) as f32);
    let inter = ((right - left) as f32) * ((bottom - top) as f32);
    inter / (area(a) + area(b) - inter)
}
/// True when the center of `b` lies inside `region` (edges inclusive).
fn box_center_in(b: &BBox, region: &BBox) -> bool {
    let (cx, cy) = b.center();
    (region.x1..=region.x2).contains(&cx) && (region.y1..=region.y2).contains(&cy)
}
async fn ensure_det_model() -> Result<PathBuf> {
if let Ok(p) = std::env::var("DRISSION_DET_MODEL") {
let p = PathBuf::from(p);
if p.exists() {
return Ok(p);
}
return Err(Error::msg(format!(
"DET: DRISSION_DET_MODEL 路径不存在: {}",
p.display()
)));
}
let dir = dirs::cache_dir()
.unwrap_or_else(std::env::temp_dir)
.join("drission")
.join("ocr");
std::fs::create_dir_all(&dir).map_err(terr)?;
let path = dir.join("ddddocr_common_det.onnx");
if path.exists()
&& std::fs::metadata(&path)
.map(|m| m.len() > 1_000_000)
.unwrap_or(false)
{
return Ok(path);
}
let url = std::env::var("DRISSION_DET_MODEL_URL").unwrap_or_else(|_| DET_MODEL_URL.to_string());
tracing::info!(target: "drission::ocr", "下载检测模型(仅首次): {url}");
let bytes = reqwest::get(&url)
.await
.map_err(terr)?
.bytes()
.await
.map_err(terr)?;
if bytes.len() < 1_000_000 {
return Err(Error::msg(format!(
"DET: 模型下载异常({} bytes)",
bytes.len()
)));
}
let tmp = path.with_extension("onnx.part");
std::fs::write(&tmp, &bytes).map_err(terr)?;
std::fs::rename(&tmp, &path).map_err(terr)?;
Ok(path)
}
/// One solved click target, as produced by `ClickWord::solve_excluding`.
#[derive(Debug, Clone, Copy)]
pub struct ClickHit {
/// The target character (first non-whitespace char of the requested target string).
pub target: char,
/// The detected box assigned to this target.
pub bbox: BBox,
/// Where to click: `solve_excluding` sets this to `bbox.center()`.
pub point: (u32, u32),
/// OCR affinity of `target` within `bbox` (max per-step softmax probability from `Ocr::char_affinity`).
pub affinity: f32,
/// Template similarity from the sample bank / font matcher; `None` when neither is configured.
pub template: Option<f32>,
}
/// Click-the-characters captcha solver: detects glyph boxes, scores each box against each
/// target char, and computes an assignment of targets to boxes.
pub struct ClickWord {
/// Glyph box detector.
pub det: Det,
/// Recognizer used for per-box character affinities.
pub ocr: Ocr,
/// Optional font-based glyph matcher (template scores).
pub font: Option<GlyphMatcher>,
/// Optional sample bank; preferred over `font` for chars it contains.
pub bank: Option<SampleBank>,
}
/// Weight of the template score relative to OCR affinity when combining per-box scores.
const TEMPLATE_WEIGHT: f32 = 1.5;
impl ClickWord {
    /// Builds a solver with the default detector and recognizer. It also picks up a system font
    /// matcher (if available) and a sample bank from `DRISSION_GLYPH_SAMPLES` (if set and
    /// loadable).
    ///
    /// # Errors
    /// Propagates model download/load errors; font/bank failures are silently ignored.
    pub async fn new() -> Result<Self> {
        let det = Det::new().await?;
        let ocr = Ocr::new().await?;
        Ok(Self::from_models(det, ocr))
    }

    /// Builds a solver from already-loaded models; font and sample bank are discovered as in `new`.
    pub fn from_models(det: Det, ocr: Ocr) -> Self {
        let font = GlyphMatcher::from_system().ok();
        let bank = load_sample_bank_env();
        Self { det, ocr, font, bank }
    }

    /// Replaces (or clears) the font-based glyph matcher.
    pub fn set_font(&mut self, font: Option<GlyphMatcher>) {
        self.font = font;
    }

    /// Replaces (or clears) the sample bank.
    pub fn set_sample_bank(&mut self, bank: Option<SampleBank>) {
        self.bank = bank;
    }

    /// True when any template source (font matcher or sample bank) is configured.
    pub fn has_font(&self) -> bool {
        self.font.is_some() || self.bank.is_some()
    }

    /// Detects boxes and OCRs each padded crop. OCR failures yield an empty string for that box.
    ///
    /// # Errors
    /// Fails on image decode, detection, or PNG encoding errors.
    pub fn chars(&self, image: &[u8]) -> Result<Vec<(BBox, String)>> {
        let crops = self.crops(image)?;
        Ok(crops
            .into_iter()
            .map(|(b, png)| {
                let text = self.ocr.recognize(&png).unwrap_or_default();
                (b, text)
            })
            .collect())
    }

    /// Detects boxes and returns each padded crop encoded as PNG.
    ///
    /// # Errors
    /// Fails on image decode, detection, or PNG encoding errors.
    pub fn crops(&self, image: &[u8]) -> Result<Vec<(BBox, Vec<u8>)>> {
        let img = image::load_from_memory(image).map_err(terr)?;
        let (iw, ih) = (img.width(), img.height());
        self.det
            .detect(image)?
            .into_iter()
            .map(|b| Self::encode_png(&crop_padded(&img, &b, iw, ih)).map(|png| (b, png)))
            .collect()
    }

    /// Equivalent to `solve_excluding` with no excluded regions.
    pub fn solve(&self, image: &[u8], targets: &[String]) -> Result<Vec<ClickHit>> {
        self.solve_excluding(image, targets, &[])
    }

    /// Assigns each target (its first non-whitespace char) to a detected box.
    ///
    /// Boxes whose centers fall inside any `exclude` region are ignored. Each box is scored by
    /// OCR affinity (best of several image variants, chosen by margin) plus
    /// `TEMPLATE_WEIGHT * template similarity`. The assignment is computed by
    /// `assign_optimal`. Targets left unassigned produce no hit.
    ///
    /// # Errors
    /// Fails on image decode or detection errors.
    pub fn solve_excluding(
        &self,
        image: &[u8],
        targets: &[String],
        exclude: &[BBox],
    ) -> Result<Vec<ClickHit>> {
        let chars: Vec<char> = targets
            .iter()
            .filter_map(|s| s.trim().chars().next())
            .collect();
        if chars.is_empty() {
            return Ok(Vec::new());
        }
        let img = image::load_from_memory(image).map_err(terr)?;
        let (iw, ih) = (img.width(), img.height());
        let mut boxes = self.det.detect(image)?;
        boxes.retain(|b| !exclude.iter().any(|r| box_center_in(b, r)));

        let mut aff: Vec<Vec<f32>> = Vec::with_capacity(boxes.len());
        let mut tpl: Vec<Vec<f32>> = Vec::with_capacity(boxes.len());
        for b in &boxes {
            let crop = crop_padded(&img, b, iw, ih);
            aff.push(self.best_variant_affinity(&crop, &chars));
            tpl.push(self.template_scores(&crop, &chars));
        }

        let combo: Vec<Vec<f32>> = aff
            .iter()
            .zip(&tpl)
            .map(|(a, t)| {
                a.iter()
                    .zip(t)
                    .map(|(&av, &tv)| av + TEMPLATE_WEIGHT * tv)
                    .collect()
            })
            .collect();
        let assign = assign_optimal(&combo, chars.len());
        let has_font = self.has_font();
        let hits = chars
            .iter()
            .enumerate()
            .filter_map(|(t, &ch)| {
                let bi = assign[t]?;
                let bbox = boxes[bi];
                Some(ClickHit {
                    target: ch,
                    bbox,
                    point: bbox.center(),
                    affinity: aff[bi].get(t).copied().unwrap_or(0.0),
                    template: has_font.then(|| tpl[bi].get(t).copied().unwrap_or(0.0)),
                })
            })
            .collect();
        Ok(hits)
    }

    /// Click points (box centers) for `targets`, in target order, via `solve`.
    pub fn points_for(&self, image: &[u8], targets: &[String]) -> Result<Vec<(u32, u32)>> {
        let hits = self.solve(image, targets)?;
        Ok(hits.into_iter().map(|h| h.point).collect())
    }

    /// Click points for `targets` matched by recognized text (see `match_order`) rather than affinity.
    pub fn points_for_text(&self, image: &[u8], targets: &[String]) -> Result<Vec<(u32, u32)>> {
        let items = self.chars(image)?;
        Ok(match_order(&items, targets))
    }

    /// Encodes an image as PNG bytes.
    fn encode_png(img: &image::DynamicImage) -> Result<Vec<u8>> {
        let mut buf = std::io::Cursor::new(Vec::new());
        img.write_to(&mut buf, image::ImageFormat::Png).map_err(terr)?;
        Ok(buf.into_inner())
    }

    /// OCR affinities of `chars` for the crop variant with the sharpest top-1/top-2 margin.
    /// Variants that fail to encode or score are skipped; all-zero if none succeed.
    fn best_variant_affinity(&self, crop: &image::DynamicImage, chars: &[char]) -> Vec<f32> {
        let mut scored: Vec<Vec<f32>> = glyph_variants(crop)
            .iter()
            .filter_map(|v| Self::encode_png(v).ok())
            .filter_map(|png| self.ocr.char_affinity(&png, chars).ok())
            .collect();
        if scored.is_empty() {
            return vec![0.0; chars.len()];
        }
        let pick = select_by_margin(&scored);
        scored.swap_remove(pick)
    }

    /// Template similarity of the crop to each char. The sample bank is used for chars it
    /// contains, else the font matcher, else 0.0. All zeros when neither is configured.
    fn template_scores(&self, crop: &image::DynamicImage, chars: &[char]) -> Vec<f32> {
        if self.font.is_none() && self.bank.is_none() {
            return vec![0.0; chars.len()];
        }
        let cf = glyph::crop_feat(crop);
        chars
            .iter()
            .map(|&ch| {
                match (self.bank.as_ref().filter(|bank| bank.has_char(ch)), &self.font) {
                    (Some(bank), _) => bank.similarity(&cf, ch),
                    (None, Some(font)) => font.similarity(&cf, ch),
                    (None, None) => 0.0,
                }
            })
            .collect()
    }
}
/// Maps each non-blank target (trimmed) to the center of a not-yet-used recognized item.
///
/// An exact text match is preferred. Otherwise the first item whose non-empty text contains
/// the target, or is contained in it, is used. Targets with no match are skipped.
fn match_order(items: &[(BBox, String)], targets: &[String]) -> Vec<(u32, u32)> {
    let mut taken = vec![false; items.len()];
    let mut points = Vec::new();
    for target in targets.iter().map(|t| t.trim()).filter(|t| !t.is_empty()) {
        let exact = (0..items.len()).find(|&i| !taken[i] && items[i].1.trim() == target);
        let chosen = exact.or_else(|| {
            (0..items.len()).find(|&i| {
                let text = items[i].1.trim();
                !taken[i] && !text.is_empty() && (text.contains(target) || target.contains(text))
            })
        });
        if let Some(i) = chosen {
            taken[i] = true;
            points.push(items[i].0.center());
        }
    }
    points
}
/// Optimal one-to-one assignment of targets (columns of `aff`) to boxes (rows of `aff`).
///
/// Objective, compared lexicographically:
/// 1. maximize the number of assigned targets;
/// 2. among those, maximize the summed affinity `aff[box][target]`.
///
/// Missing entries (rows shorter than `n_targets`) count as 0.0. Among equal-scoring
/// assignments, the first one found in search order wins. Returns one entry per target:
/// `Some(box)` or `None`.
///
/// The objective is compared as an exact `(count, sum)` pair. It used to be a single f32 score
/// with a +1000 bonus per assigned target. At ~1000 per target the f32 ulp grows to 1e-4..5e-4,
/// which rounded away small affinity differences and let branch-and-bound prune strictly better
/// assignments.
fn assign_optimal(aff: &[Vec<f32>], n_targets: usize) -> Vec<Option<usize>> {
    if n_targets == 0 {
        return vec![];
    }
    let n_boxes = aff.len();
    // Per-target best achievable affinity (never below 0), used for the optimistic bound.
    let col_max: Vec<f32> = (0..n_targets)
        .map(|t| {
            aff.iter()
                .map(|r| r.get(t).copied().unwrap_or(0.0))
                .fold(0.0f32, f32::max)
        })
        .collect();
    // suffix[t] = sum of col_max[t..]: an upper bound on the affinity still obtainable from t on.
    let mut suffix = vec![0.0f32; n_targets + 1];
    for t in (0..n_targets).rev() {
        suffix[t] = suffix[t + 1] + col_max[t];
    }
    let mut search = AssignSearch {
        aff,
        n_targets,
        suffix,
        used: vec![false; n_boxes],
        cur: vec![None; n_targets],
        best: None,
        best_assign: vec![None; n_targets],
    };
    search.dfs(0, 0, 0.0);
    search.best_assign
}

/// Branch-and-bound state for `assign_optimal`.
struct AssignSearch<'a> {
    /// Affinity matrix, indexed `[box][target]`.
    aff: &'a [Vec<f32>],
    n_targets: usize,
    /// Optimistic remaining-affinity bound per target index.
    suffix: Vec<f32>,
    /// Boxes already taken on the current path.
    used: Vec<bool>,
    /// Assignment on the current path.
    cur: Vec<Option<usize>>,
    /// Best `(assigned count, affinity sum)` found so far.
    best: Option<(usize, f32)>,
    best_assign: Vec<Option<usize>>,
}

impl AssignSearch<'_> {
    /// Explores assignments for targets `t..`, given `count` assigned targets with summed
    /// affinity `score` on the current path.
    fn dfs(&mut self, t: usize, count: usize, score: f32) {
        if t == self.n_targets {
            let better = match self.best {
                None => true,
                Some((bc, bs)) => count > bc || (count == bc && score > bs),
            };
            if better {
                self.best = Some((count, score));
                self.best_assign.copy_from_slice(&self.cur);
            }
            return;
        }
        if let Some((best_count, best_score)) = self.best {
            // Every assigned target consumes one box, so `count` boxes are in use.
            let free_boxes = self.used.len() - count;
            let max_count = count + (self.n_targets - t).min(free_boxes);
            if max_count < best_count
                || (max_count == best_count && score + self.suffix[t] <= best_score)
            {
                return;
            }
        }
        for b in 0..self.used.len() {
            if self.used[b] {
                continue;
            }
            self.used[b] = true;
            self.cur[t] = Some(b);
            let a = self.aff[b].get(t).copied().unwrap_or(0.0);
            self.dfs(t + 1, count + 1, score + a);
            self.cur[t] = None;
            self.used[b] = false;
        }
        // Option of leaving target `t` unassigned.
        self.cur[t] = None;
        self.dfs(t + 1, count, score);
    }
}
/// Loads a glyph sample bank from the directory in `DRISSION_GLYPH_SAMPLES`.
/// Returns `None` if the variable is unset or loading fails.
fn load_sample_bank_env() -> Option<SampleBank> {
    let dir = std::env::var("DRISSION_GLYPH_SAMPLES").ok()?;
    SampleBank::from_dir(Path::new(&dir)).ok()
}
/// Crops `b` out of `img` with a margin of `max(w, h) / 6` pixels (at least 2) on every side,
/// clipped to the `iw x ih` image. The crop is always at least 1x1.
///
/// The right/bottom edge is exclusive and is clamped to `iw`/`ih`. It was previously clamped
/// to `iw - 1`/`ih - 1`, which silently dropped the last column/row for boxes touching the
/// image border.
fn crop_padded(img: &image::DynamicImage, b: &BBox, iw: u32, ih: u32) -> image::DynamicImage {
    let pad = (b.width().max(b.height()) / 6).max(2);
    let x = b.x1.saturating_sub(pad);
    let y = b.y1.saturating_sub(pad);
    // `saturating_add` avoids u32 overflow for boxes near u32::MAX (BBox fields are public).
    let right = b.x2.saturating_add(pad).min(iw);
    let bottom = b.y2.saturating_add(pad).min(ih);
    let w = right.saturating_sub(x).max(1);
    let h = bottom.saturating_sub(y).max(1);
    img.crop_imm(x, y, w, h)
}
/// The image variants scored for each crop, in this order: the original, an autocontrast
/// version, and an Otsu-binarized version.
fn glyph_variants(crop: &image::DynamicImage) -> Vec<image::DynamicImage> {
    let mut variants = Vec::with_capacity(3);
    variants.push(crop.clone());
    variants.push(autocontrast(crop));
    variants.push(otsu_binarize(crop));
    variants
}
/// Grayscale contrast stretch.
///
/// The low/high cutoffs are the first intensity levels at which the cumulative histogram, from
/// the dark and bright ends respectively, exceeds 2% of the pixels. The `[lo, hi]` range is
/// mapped to `[0, 255]`. If `hi <= lo`, the plain grayscale image is returned.
fn autocontrast(img: &image::DynamicImage) -> image::DynamicImage {
    let mut gray = img.to_luma8();
    let mut hist = [0u32; 256];
    for px in gray.pixels() {
        hist[px.0[0] as usize] += 1;
    }
    let total = (gray.width() * gray.height()).max(1);
    let cut = (total as f32 * 0.02) as u32;
    let lo = hist
        .iter()
        .scan(0u32, |acc, &c| {
            *acc += c;
            Some(*acc)
        })
        .position(|running| running > cut)
        .map_or(0u8, |i| i as u8);
    let hi = hist
        .iter()
        .rev()
        .scan(0u32, |acc, &c| {
            *acc += c;
            Some(*acc)
        })
        .position(|running| running > cut)
        .map_or(255u8, |j| (255 - j) as u8);
    if hi <= lo {
        return image::DynamicImage::ImageLuma8(gray);
    }
    let span = (hi - lo) as f32;
    for px in gray.pixels_mut() {
        let stretched = (px.0[0].saturating_sub(lo) as f32 / span) * 255.0;
        px.0[0] = stretched.clamp(0.0, 255.0) as u8;
    }
    image::DynamicImage::ImageLuma8(gray)
}
/// Binarizes the grayscale image with Otsu's threshold.
///
/// The threshold is the level that maximizes between-class variance. Pixels strictly above it
/// become 255; all others become 0.
fn otsu_binarize(img: &image::DynamicImage) -> image::DynamicImage {
    let mut gray = img.to_luma8();
    let n = (gray.width() * gray.height()).max(1) as f32;
    let mut hist = [0u32; 256];
    for px in gray.pixels() {
        hist[px.0[0] as usize] += 1;
    }
    let sum: f32 = hist
        .iter()
        .enumerate()
        .map(|(level, &count)| level as f32 * count as f32)
        .sum();
    let mut weight_bg = 0f32;
    let mut sum_bg = 0f32;
    let mut best_var = 0f32;
    let mut threshold = 0u8;
    for (level, &count) in hist.iter().enumerate() {
        weight_bg += count as f32;
        if weight_bg == 0.0 {
            continue;
        }
        let weight_fg = n - weight_bg;
        if weight_fg <= 0.0 {
            break;
        }
        sum_bg += level as f32 * count as f32;
        let mean_bg = sum_bg / weight_bg;
        let mean_fg = (sum - sum_bg) / weight_fg;
        let diff = mean_bg - mean_fg;
        let between = weight_bg * weight_fg * diff * diff;
        if between > best_var {
            best_var = between;
            threshold = level as u8;
        }
    }
    for px in gray.pixels_mut() {
        px.0[0] = if px.0[0] > threshold { 255 } else { 0 };
    }
    image::DynamicImage::ImageLuma8(gray)
}
/// Index of the vector with the largest `margin`. The first wins on ties; returns 0 for an empty
/// slice or when no margin exceeds `f32::MIN`.
fn select_by_margin(vectors: &[Vec<f32>]) -> usize {
    vectors
        .iter()
        .map(|v| margin(v))
        .enumerate()
        .fold((0usize, f32::MIN), |(best_i, best_m), (i, m)| {
            if m > best_m {
                (i, m)
            } else {
                (best_i, best_m)
            }
        })
        .0
}

/// Gap between the two largest values: the largest alone when there is a single value above
/// `f32::MIN`, or 0.0 when there is none.
fn margin(v: &[f32]) -> f32 {
    let (first, second) = v.iter().fold((f32::MIN, f32::MIN), |(a, b), &x| {
        if x > a {
            (x, a)
        } else if x > b {
            (a, x)
        } else {
            (a, b)
        }
    });
    match (first == f32::MIN, second == f32::MIN) {
        (true, _) => 0.0,
        (false, true) => first,
        (false, false) => first - second,
    }
}
fn parse_charset(s: &str) -> Result<Vec<String>> {
let v: serde_json::Value = serde_json::from_str(s).map_err(terr)?;
let arr = v["charset"]
.as_array()
.ok_or_else(|| Error::msg("OCR: charset 缺失"))?;
Ok(arr
.iter()
.map(|x| x.as_str().unwrap_or("").to_string())
.collect())
}
/// Loads a charset from a file in one of three formats, chosen by its first non-whitespace char:
/// - `{` — a JSON object with a `"charset"` array (see `parse_charset`);
/// - `[` — a bare JSON array of strings (non-string entries become `""`);
/// - anything else — one entry per line (CRLF tolerated), so an empty first line yields the
///   CTC blank.
///
/// A leading UTF-8 BOM is ignored. Without stripping it, JSON files fail to parse and line
/// files get `"\u{feff}"` instead of the blank at index 0.
///
/// # Errors
/// Fails if the file cannot be read as UTF-8 or the JSON is malformed.
pub fn load_charset_file(path: &Path) -> Result<Vec<String>> {
    let raw = std::fs::read_to_string(path).map_err(terr)?;
    let s = raw.strip_prefix('\u{feff}').unwrap_or(raw.as_str());
    let t = s.trim_start();
    if t.starts_with('{') {
        return parse_charset(s);
    }
    if t.starts_with('[') {
        let v: serde_json::Value = serde_json::from_str(s).map_err(terr)?;
        let arr = v
            .as_array()
            .ok_or_else(|| Error::msg("OCR: charset 文件不是 JSON 数组"))?;
        return Ok(arr
            .iter()
            .map(|x| x.as_str().unwrap_or("").to_string())
            .collect());
    }
    Ok(s.lines()
        .map(|l| l.trim_end_matches('\r').to_string())
        .collect())
}
/// Decodes an image and prepares it for the recognizer.
///
/// The image is resized to height 64, with the width scaled by aspect ratio (min 1), using
/// Lanczos3. It is converted to grayscale and normalized to `[-1, 1]`. Returns the row-major
/// pixel data and the new width.
///
/// # Errors
/// Fails if the bytes cannot be decoded or the image has a zero dimension.
fn preprocess(bytes: &[u8]) -> Result<(Vec<f32>, usize)> {
    let decoded = image::load_from_memory(bytes).map_err(terr)?;
    let (src_w, src_h) = (decoded.width(), decoded.height());
    if src_w == 0 || src_h == 0 {
        return Err(Error::msg("OCR: 空图"));
    }
    let target_w = ((src_w as f32) * 64.0 / (src_h as f32)).round().max(1.0) as usize;
    let gray = decoded
        .resize_exact(target_w as u32, 64, image::imageops::FilterType::Lanczos3)
        .to_luma8();
    // `pixels()` walks row-major, and `resize_exact` guarantees exactly target_w x 64 pixels.
    let normalized: Vec<f32> = gray
        .pixels()
        .map(|px| (px.0[0] as f32 / 255.0 - 0.5) / 0.5)
        .collect();
    Ok((normalized, target_w))
}
/// Greedy CTC decoding: argmax class per time step, collapsing repeats and dropping the
/// blank (class 0).
///
/// The class axis is the first axis whose length equals `charset.len()`, falling back to the
/// last axis. The time axis is the first other axis longer than 1, falling back to axis 0. All
/// remaining axes are read at index 0.
fn ctc_decode(view: &tract_ndarray::ArrayViewD<f32>, charset: &[String]) -> String {
    let shape = view.shape();
    let classes = charset.len();
    let rank = shape.len();
    let cls_axis = shape
        .iter()
        .position(|&d| d == classes)
        .unwrap_or(rank - 1);
    let t_axis = (0..rank)
        .find(|&a| a != cls_axis && shape[a] > 1)
        .unwrap_or(0);
    let mut cursor = vec![0usize; rank];
    let mut decoded = String::new();
    let mut last = usize::MAX;
    for step in 0..shape[t_axis] {
        cursor[t_axis] = step;
        let mut top = (0usize, f32::MIN);
        for k in 0..classes {
            cursor[cls_axis] = k;
            let v = view[cursor.as_slice()];
            if v > top.1 {
                top = (k, v);
            }
        }
        let k = top.0;
        if k != 0
            && k != last
            && let Some(ch) = charset.get(k)
        {
            decoded.push_str(ch);
        }
        last = k;
    }
    decoded
}
async fn ensure_model() -> Result<PathBuf> {
if let Ok(p) = std::env::var("DRISSION_OCR_MODEL") {
let p = PathBuf::from(p);
if p.exists() {
return Ok(p);
}
return Err(Error::msg(format!(
"OCR: DRISSION_OCR_MODEL 路径不存在: {}",
p.display()
)));
}
let dir = dirs::cache_dir()
.unwrap_or_else(std::env::temp_dir)
.join("drission")
.join("ocr");
std::fs::create_dir_all(&dir).map_err(terr)?;
let path = dir.join("ddddocr_common.onnx");
if path.exists()
&& std::fs::metadata(&path)
.map(|m| m.len() > 1_000_000)
.unwrap_or(false)
{
return Ok(path);
}
let url = std::env::var("DRISSION_OCR_MODEL_URL").unwrap_or_else(|_| MODEL_URL.to_string());
tracing::info!(target: "drission::ocr", "下载 OCR 模型(~54MB,仅首次): {url}");
let bytes = reqwest::get(&url)
.await
.map_err(terr)?
.bytes()
.await
.map_err(terr)?;
if bytes.len() < 1_000_000 {
return Err(Error::msg(format!(
"OCR: 模型下载异常({} bytes)",
bytes.len()
)));
}
let tmp = path.with_extension("onnx.part");
std::fs::write(&tmp, &bytes).map_err(terr)?;
std::fs::rename(&tmp, &path).map_err(terr)?;
Ok(path)
}
/// Lazily initialized default recognizer (built with `Ocr::new` on first `Tab::ocr_image` call).
#[cfg(feature = "camoufox")]
static DEFAULT_OCR: tokio::sync::OnceCell<Ocr> = tokio::sync::OnceCell::const_new();
/// Optional user-supplied recognizer installed by `set_default_ocr`; takes precedence over
/// `DEFAULT_OCR` in `Tab::ocr_image`.
#[cfg(feature = "camoufox")]
static OCR_OVERRIDE: tokio::sync::RwLock<Option<std::sync::Arc<Ocr>>> =
tokio::sync::RwLock::const_new(None);
/// Installs `ocr` as the recognizer used by `Tab::ocr_image`, replacing any previous override.
#[cfg(feature = "camoufox")]
pub async fn set_default_ocr(ocr: Ocr) {
    let shared = std::sync::Arc::new(ocr);
    let mut slot = OCR_OVERRIDE.write().await;
    *slot = Some(shared);
}
#[cfg(feature = "camoufox")]
impl Tab {
    /// OCRs the image element matched by `selector`.
    ///
    /// The override from `set_default_ocr` is used if one is installed; otherwise the lazily
    /// built default recognizer is used.
    ///
    /// # Errors
    /// Fails if the element/image cannot be obtained, the default model cannot be initialized,
    /// or recognition fails.
    pub async fn ocr_image(&self, selector: &str) -> Result<String> {
        let bytes = self.fetch_image_bytes(selector).await?;
        // Take the Arc out in its own statement so the read guard is dropped immediately.
        // Previously the guard was a temporary of the `if let` scrutinee and stayed held
        // throughout the CPU-heavy `recognize` call, blocking `set_default_ocr`.
        let override_ocr = OCR_OVERRIDE.read().await.clone();
        if let Some(ocr) = override_ocr {
            return ocr.recognize(&bytes);
        }
        let ocr = DEFAULT_OCR.get_or_try_init(Ocr::new).await?;
        ocr.recognize(&bytes)
    }

    /// Returns the element's image bytes.
    ///
    /// If `currentSrc`/`src` contains `base64,`, the payload after that marker is decoded and
    /// returned when non-empty. Otherwise the element's screenshot is returned.
    async fn fetch_image_bytes(&self, selector: &str) -> Result<Vec<u8>> {
        let el = self.ele(selector).await?;
        if let Ok(src) = el.run_js("return node.currentSrc||node.src||'';").await
            && let Some(s) = src.as_str()
            && let Some(i) = s.find("base64,")
            && let Some(b) = base64_decode(&s[i + 7..])
            && !b.is_empty()
        {
            return Ok(b);
        }
        el.screenshot_bytes().await
    }
}
// Unit tests for the pure helpers (charset parsing, CTC decoding, anchor grid, text matching,
// region filtering, assignment, variant selection). None of them loads an ONNX model.
#[cfg(test)]
mod tests {
use super::*;
// The embedded charset parses, is large, starts with the CTC blank, and contains ASCII entries.
#[test]
fn charset_loads_and_blank_first() {
let cs = parse_charset(CHARSET_JSON).unwrap();
assert!(cs.len() > 1000);
assert_eq!(cs[0], ""); assert!(cs.iter().any(|c| c == "5") && cs.iter().any(|c| c == "z"));
}
// One-hot sequence a,a,<blank>,b,b laid out as (time, 1, classes) decodes to "ab".
#[test]
fn ctc_collapses_repeats_and_blanks() {
let charset = vec![
"".to_string(),
"a".to_string(),
"b".to_string(),
"c".to_string(),
];
let seq = [1usize, 1, 0, 2, 2];
let mut arr = tract_ndarray::Array3::<f32>::zeros((seq.len(), 1, charset.len()));
for (t, &k) in seq.iter().enumerate() {
arr[[t, 0, k]] = 1.0;
}
let dynv = arr.into_dyn();
assert_eq!(ctc_decode(&dynv.view(), &charset), "ab");
}
// 416/8 = 52, 416/16 = 26, 416/32 = 13 cells per side: 52^2 + 26^2 + 13^2 anchors.
#[test]
fn det_grids_count_matches_yolox() {
assert_eq!(det_grids().len(), 2704 + 676 + 169);
}
// Test helper: a 10x10 box centred on (cx, cy); requires cx, cy >= 5.
fn bb(cx: u32, cy: u32) -> BBox {
BBox {
x1: cx - 5,
y1: cy - 5,
x2: cx + 5,
y2: cy + 5,
score: 0.9,
}
}
// Points come back in target order; unknown targets are skipped; each item is used at most once.
#[test]
fn match_order_follows_targets_and_skips_missing() {
let items = vec![
(bb(100, 100), "体".to_string()),
(bb(20, 30), "验".to_string()),
(bb(200, 50), "安".to_string()),
];
let targets = vec!["验".to_string(), "体".to_string()];
assert_eq!(match_order(&items, &targets), vec![(20, 30), (100, 100)]);
let targets2 = vec!["元".to_string(), "安".to_string()];
assert_eq!(match_order(&items, &targets2), vec![(200, 50)]);
let targets3 = vec!["体".to_string(), "体".to_string()];
assert_eq!(match_order(&items, &targets3), vec![(100, 100)]);
}
// Only boxes whose center lies inside the band (x 200..=320, y 0..=40) are reported as inside.
#[test]
fn box_center_in_excludes_toolbar_corner() {
let band = BBox {
x1: 200,
y1: 0,
x2: 320,
y2: 40,
score: 0.0,
};
assert!(box_center_in(&bb(260, 20), &band)); assert!(!box_center_in(&bb(100, 90), &band)); assert!(!box_center_in(&bb(260, 120), &band)); }
// Greedy would give target 0 box 0 (0.9); the optimum swaps for 0.85 + 0.8.
#[test]
fn assign_optimal_beats_greedy_order() {
let aff = vec![vec![0.9, 0.8], vec![0.85, 0.1]];
assert_eq!(assign_optimal(&aff, 2), vec![Some(1), Some(0)]);
}
// With more boxes than targets, each target gets a distinct box.
#[test]
fn assign_optimal_all_distinct_when_enough() {
let aff = vec![vec![0.2, 0.9], vec![0.9, 0.2], vec![0.5, 0.5]];
let a = assign_optimal(&aff, 2);
assert_eq!(a, vec![Some(1), Some(0)]);
assert_ne!(a[0], a[1]); }
// With fewer boxes than targets, the single box goes to its best target; degenerate inputs are handled.
#[test]
fn assign_optimal_partial_when_fewer_boxes() {
let aff = vec![vec![0.1, 0.9, 0.2]];
let a = assign_optimal(&aff, 3);
assert_eq!(a[1], Some(0));
assert_eq!(a.iter().filter(|x| x.is_some()).count(), 1);
assert_eq!(assign_optimal(&Vec::<Vec<f32>>::new(), 2), vec![None, None]);
assert_eq!(assign_optimal(&aff, 0), Vec::<Option<usize>>::new());
}
// The vector with the larger top-1/top-2 gap wins; ties keep the first; empty margin is 0.
#[test]
fn select_by_margin_picks_sharpest() {
let v = vec![vec![0.40, 0.39, 0.38], vec![0.60, 0.10, 0.05]];
assert_eq!(select_by_margin(&v), 1);
let v2 = vec![vec![0.5, 0.2], vec![0.5, 0.2]];
assert_eq!(select_by_margin(&v2), 0);
let v3 = vec![vec![0.2], vec![0.7]];
assert_eq!(select_by_margin(&v3), 1);
assert_eq!(margin(&[]), 0.0);
}
// JSON object, JSON array, and line-per-entry files all load; a leading empty line yields the blank.
#[test]
fn load_charset_file_three_formats() {
let dir = std::env::temp_dir();
let p1 = dir.join("drission_cs_obj.json");
std::fs::write(&p1, r#"{"charset":["","a","b"]}"#).unwrap();
assert_eq!(load_charset_file(&p1).unwrap(), vec!["", "a", "b"]);
let p2 = dir.join("drission_cs_arr.json");
std::fs::write(&p2, r#"["","x","y","z"]"#).unwrap();
assert_eq!(load_charset_file(&p2).unwrap(), vec!["", "x", "y", "z"]);
let p3 = dir.join("drission_cs_lines.txt");
std::fs::write(&p3, "\n甲\n乙\n").unwrap(); assert_eq!(load_charset_file(&p3).unwrap(), vec!["", "甲", "乙"]);
for p in [p1, p2, p3] {
let _ = std::fs::remove_file(p);
}
}
// All three variants (original, autocontrast, Otsu) keep the crop's dimensions.
#[test]
fn glyph_variants_keep_size_and_count() {
let mut im = image::RgbImage::new(9, 7);
for (x, y, p) in im.enumerate_pixels_mut() {
*p = image::Rgb([(x * 25) as u8, (y * 30) as u8, 90]);
}
let d = image::DynamicImage::ImageRgb8(im);
let vs = glyph_variants(&d);
assert_eq!(vs.len(), 3); for v in &vs {
assert_eq!((v.width(), v.height()), (9, 7));
}
}
}