use std::path::{Path, PathBuf};
use fontdue::{Font, FontSettings};
use crate::{Error, Result};
/// Side length, in pixels, of the square grid every feature vector is sampled on
/// (feature vectors hold `N * N` values).
const N: usize = 32;
/// Rotation angles, in degrees, tried during matching; the best score over all of them is kept.
const ROTS: [f32; 5] = [-20.0, -10.0, 0.0, 10.0, 20.0];
pub struct GlyphMatcher {
font: Font,
}
impl GlyphMatcher {
pub fn from_font_path(path: &Path) -> Result<Self> {
let bytes =
std::fs::read(path).map_err(|e| Error::msg(format!("CJK 字体读取失败: {e}")))?;
let font = Font::from_bytes(bytes, FontSettings::default())
.map_err(|e| Error::msg(format!("CJK 字体解析失败: {e}")))?;
Ok(Self { font })
}
pub fn from_system() -> Result<Self> {
if let Ok(p) = std::env::var("DRISSION_CJK_FONT") {
return Self::from_font_path(Path::new(&p));
}
for p in cjk_font_candidates() {
if p.exists()
&& let Ok(m) = Self::from_font_path(&p)
{
return Ok(m);
}
}
Err(Error::msg(
"未找到系统 CJK 字体(可设 DRISSION_CJK_FONT 指向 .ttf/.ttc)",
))
}
pub fn similarity(&self, crop_feat: &[f32], ch: char) -> f32 {
let upright = rasterize_centered(&self.font, ch, N);
if upright.iter().all(|&v| v == 0.0) {
return 0.0; }
let mut best = 0f32;
for ° in &ROTS {
let t = if deg == 0.0 {
normalize(upright.clone())
} else {
normalize(rotate_grid(&upright, N, deg))
};
let s = dot(crop_feat, &t);
if s > best {
best = s;
}
}
best.clamp(0.0, 1.0)
}
}
/// Converts an image crop into a normalized `N * N` edge feature vector.
///
/// The crop is resized to `N x N` (triangle filter) and converted to grayscale. The 3x3
/// Sobel gradient magnitude is then computed at every interior pixel; the one-pixel
/// border stays at zero. Finally the vector is made zero-mean / unit-norm by `normalize`.
pub fn crop_feat(crop: &image::DynamicImage) -> Vec<f32> {
    let gray = crop
        .resize_exact(N as u32, N as u32, image::imageops::FilterType::Triangle)
        .to_luma8();
    let px = |x: usize, y: usize| f32::from(gray.get_pixel(x as u32, y as u32)[0]);

    let mut edges = vec![0f32; N * N];
    // Rows 1..N-1 only: the Sobel kernel needs a neighbour on every side.
    for (y, row) in edges.chunks_exact_mut(N).enumerate().take(N - 1).skip(1) {
        let (up, down) = (y - 1, y + 1);
        for x in 1..N - 1 {
            let (left, right) = (x - 1, x + 1);
            let gx = (px(right, up) + 2.0 * px(right, y) + px(right, down))
                - (px(left, up) + 2.0 * px(left, y) + px(left, down));
            let gy = (px(left, down) + 2.0 * px(x, down) + px(right, down))
                - (px(left, up) + 2.0 * px(x, up) + px(right, up));
            row[x] = (gx * gx + gy * gy).sqrt();
        }
    }
    normalize(edges)
}
/// Renders `ch` at `0.85 * n` px and pastes its coverage into an `n x n` grid.
///
/// The bitmap is centred by its width and height, and coverage is scaled from `0..=255`
/// to `0.0..=1.0`. Pixels falling outside the grid are clipped. A glyph with an empty
/// bitmap yields an all-zero grid.
fn rasterize_centered(font: &Font, ch: char, n: usize) -> Vec<f32> {
    let mut canvas = vec![0f32; n * n];
    let (metrics, coverage) = font.rasterize(ch, n as f32 * 0.85);
    let (w, h) = (metrics.width, metrics.height);
    if w == 0 || h == 0 {
        return canvas;
    }
    let left = (n as i32 - w as i32) / 2;
    let top = (n as i32 - h as i32) / 2;
    for (row_idx, row) in coverage.chunks_exact(w).enumerate().take(h) {
        let dy = top + row_idx as i32;
        if dy < 0 || dy as usize >= n {
            continue;
        }
        for (col_idx, &alpha) in row.iter().enumerate() {
            let dx = left + col_idx as i32;
            if dx < 0 || dx as usize >= n {
                continue;
            }
            canvas[dy as usize * n + dx as usize] = f32::from(alpha) / 255.0;
        }
    }
    canvas
}
/// Rotates a row-major `n x n` grid by `deg` degrees about its centre.
///
/// Each destination pixel is inverse-mapped into `src` and sampled bilinearly. Samples
/// falling outside the source grid become `0.0`.
fn rotate_grid(src: &[f32], n: usize, deg: f32) -> Vec<f32> {
    let (sin, cos) = deg.to_radians().sin_cos();
    let centre = (n as f32 - 1.0) / 2.0;
    (0..n * n)
        .map(|idx| {
            let rx = (idx % n) as f32 - centre;
            let ry = (idx / n) as f32 - centre;
            let src_x = cos * rx + sin * ry + centre;
            let src_y = -sin * rx + cos * ry + centre;
            bilinear(src, n, src_x, src_y)
        })
        .collect()
}
/// Bilinearly samples the row-major `n x n` grid `src` at fractional position `(x, y)`.
///
/// Returns `0.0` for positions outside `[0, n - 1]` on either axis. On the last row or
/// column, the neighbouring index is clamped to `n - 1`.
fn bilinear(src: &[f32], n: usize, x: f32, y: f32) -> f32 {
    let last = n as f32 - 1.0;
    if x < 0.0 || x > last || y < 0.0 || y > last {
        return 0.0;
    }
    let (col, row) = (x.floor() as usize, y.floor() as usize);
    let col_next = (col + 1).min(n - 1);
    let row_next = (row + 1).min(n - 1);
    let tx = x - col as f32;
    let ty = y - row as f32;

    let top_left = src[row * n + col];
    let top_right = src[row * n + col_next];
    let bottom_left = src[row_next * n + col];
    let bottom_right = src[row_next * n + col_next];

    top_left * (1.0 - tx) * (1.0 - ty)
        + top_right * tx * (1.0 - ty)
        + bottom_left * (1.0 - tx) * ty
        + bottom_right * tx * ty
}
/// Centres `v` to zero mean, then scales it to unit L2 norm.
///
/// If the centred vector's norm is at most `1e-6` (e.g. constant input), it is returned
/// centred but unscaled, which avoids dividing by ~0. An empty vector is returned unchanged.
fn normalize(mut v: Vec<f32>) -> Vec<f32> {
    let count = v.len().max(1) as f32;
    let mean = v.iter().sum::<f32>() / count;
    v.iter_mut().for_each(|x| *x -= mean);

    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 1e-6 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
    v
}
/// Dot product of `a` and `b`; extra elements of the longer slice are ignored.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let pairs = a.iter().zip(b.iter());
    pairs.map(|(&x, &y)| x * y).sum::<f32>()
}
/// Well-known CJK font locations for the current target OS, in preference order.
///
/// The paths are not checked for existence here; callers filter them.
fn cjk_font_candidates() -> Vec<PathBuf> {
    #[cfg(target_os = "macos")]
    const CANDIDATES: [&str; 4] = [
        "/System/Library/Fonts/PingFang.ttc",
        "/System/Library/Fonts/STHeiti Medium.ttc",
        "/System/Library/Fonts/Supplemental/Songti.ttc",
        "/Library/Fonts/Arial Unicode.ttf",
    ];
    #[cfg(target_os = "windows")]
    const CANDIDATES: [&str; 4] = [
        "C:\\Windows\\Fonts\\msyh.ttc",
        "C:\\Windows\\Fonts\\simhei.ttf",
        "C:\\Windows\\Fonts\\simsun.ttc",
        "C:\\Windows\\Fonts\\msyh.ttf",
    ];
    #[cfg(not(any(target_os = "macos", target_os = "windows")))]
    const CANDIDATES: [&str; 4] = [
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
        "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
        "/usr/share/fonts/truetype/arphic/uming.ttc",
        "/usr/share/fonts/noto-cjk/NotoSansCJK-Regular.ttc",
    ];
    CANDIDATES.into_iter().map(PathBuf::from).collect()
}
/// A bank of labelled reference samples, each stored as a [`crop_feat`] feature vector.
pub struct SampleBank {
    /// `(label, feature)` pairs; a label may appear many times.
    samples: Vec<(char, Vec<f32>)>,
}

impl SampleBank {
    /// Builds a bank from a directory laid out as `<dir>/<label dir>/<image files>`.
    ///
    /// Each sub-directory's label is the first character of its name. Every file whose
    /// extension is `png`, `jpg`, `jpeg` or `bmp` (case-insensitive) is decoded and turned
    /// into a feature with [`crop_feat`]. Unreadable sub-directories, unreadable files and
    /// undecodable images are skipped silently (best-effort loading).
    ///
    /// # Errors
    /// Returns an error if `dir` cannot be read, or if no sample was loaded at all.
    pub fn from_dir(dir: &Path) -> Result<Self> {
        let mut samples = Vec::new();
        let rd = std::fs::read_dir(dir).map_err(|e| Error::msg(format!("样本库目录: {e}")))?;
        for ent in rd.flatten() {
            let p = ent.path();
            if !p.is_dir() {
                continue;
            }
            // NOTE(review): only the first char is used, so a directory such as ".git"
            // would be labelled '.'. Confirm the sample dir contains only label dirs.
            let Some(ch) = p
                .file_name()
                .and_then(|s| s.to_str())
                .and_then(|s| s.chars().next())
            else {
                continue;
            };
            let Ok(files) = std::fs::read_dir(&p) else {
                continue;
            };
            for f in files.flatten() {
                let fp = f.path();
                let ok_ext = fp
                    .extension()
                    .and_then(|s| s.to_str())
                    .map(|e| {
                        matches!(
                            e.to_ascii_lowercase().as_str(),
                            "png" | "jpg" | "jpeg" | "bmp"
                        )
                    })
                    .unwrap_or(false);
                if !ok_ext {
                    continue;
                }
                // Decode from bytes so the format is sniffed from content, not the extension.
                if let Ok(bytes) = std::fs::read(&fp)
                    && let Ok(img) = image::load_from_memory(&bytes)
                {
                    samples.push((ch, crop_feat(&img)));
                }
            }
        }
        if samples.is_empty() {
            return Err(Error::msg("样本库为空(目录下需有 {字}/ 子目录及样本图)"));
        }
        Ok(Self { samples })
    }

    /// Number of stored samples (not distinct labels).
    pub fn len(&self) -> usize {
        self.samples.len()
    }

    /// `true` if the bank holds no samples.
    pub fn is_empty(&self) -> bool {
        self.samples.is_empty()
    }

    /// `true` if at least one sample is labelled `ch`.
    pub fn has_char(&self, ch: char) -> bool {
        self.samples.iter().any(|(c, _)| *c == ch)
    }

    /// Returns the best score, clamped to `[0, 1]`, between `crop_feat` and any sample
    /// labelled `ch`.
    ///
    /// `crop_feat` is expected to come from [`crop_feat`], i.e. already normalized with
    /// `N * N` values. The query is also tried rotated by each non-zero angle in `ROTS`.
    /// Returns `0.0` when no sample is labelled `ch`, or when `crop_feat` holds fewer than
    /// `N * N` values. Rotating a slice that short would index out of bounds and panic.
    pub fn similarity(&self, crop_feat: &[f32], ch: char) -> f32 {
        if crop_feat.len() < N * N || !self.has_char(ch) {
            return 0.0;
        }
        // NOTE(review): the query is already zero-mean, so the `0.0` fill that
        // `rotate_grid` uses outside the grid is not the true background level.
        // The effect is likely small; confirm it is acceptable.
        let queries: Vec<Vec<f32>> = ROTS
            .iter()
            .map(|&d| {
                if d == 0.0 {
                    crop_feat.to_vec()
                } else {
                    normalize(rotate_grid(crop_feat, N, d))
                }
            })
            .collect();
        let mut best = 0f32;
        for (_, feat) in self.samples.iter().filter(|(c, _)| *c == ch) {
            for q in &queries {
                let s = dot(q, feat);
                if s > best {
                    best = s;
                }
            }
        }
        best.clamp(0.0, 1.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normalize_zero_mean_unit_norm() {
        let v = normalize(vec![1.0, 2.0, 3.0, 4.0]);
        let mean: f32 = v.iter().sum::<f32>() / v.len() as f32;
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(mean.abs() < 1e-5);
        assert!((norm - 1.0).abs() < 1e-5);
    }

    #[test]
    fn normalize_constant_input_stays_zero() {
        // Zero variance must take the `norm > 1e-6` guard instead of dividing by ~0.
        let v = normalize(vec![2.0; 5]);
        assert!(v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn normalize_empty_is_empty() {
        assert!(normalize(Vec::new()).is_empty());
    }

    #[test]
    fn dot_self_is_one_for_normalized() {
        let v = normalize(vec![0.0, 1.0, 0.0, 2.0, 0.0, 3.0]);
        assert!((dot(&v, &v) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn bilinear_interpolates_inside_and_zeroes_outside() {
        let g = [0.0f32, 1.0, 2.0, 3.0];
        assert!((bilinear(&g, 2, 0.5, 0.5) - 1.5).abs() < 1e-6);
        assert_eq!(bilinear(&g, 2, 1.0, 1.0), 3.0);
        assert_eq!(bilinear(&g, 2, -0.1, 0.0), 0.0);
        assert_eq!(bilinear(&g, 2, 0.0, 1.1), 0.0);
    }

    #[test]
    fn rotate_zero_is_identity() {
        let mut src = vec![0f32; N * N];
        src[10 * N + 12] = 1.0;
        let r = rotate_grid(&src, N, 0.0);
        assert!((r[10 * N + 12] - 1.0).abs() < 1e-4);
    }

    #[test]
    fn rotate_quarter_turn_moves_pixel() {
        // At +90 deg, destination (x, y) samples source (y, N - 1 - x).
        let mut src = vec![0f32; N * N];
        src[10 * N + 12] = 1.0;
        let r = rotate_grid(&src, N, 90.0);
        assert!((r[12 * N + (N - 1 - 10)] - 1.0).abs() < 1e-3);
    }
}