use crate::ocr::error::{OcrError, OcrResult};
use crate::ocr::features::extract;
use crate::ocr::recognize::Prototype;
use image::{GrayImage, Luma};
use std::path::Path;
/// Builds a [`Prototype`] for `label` from an ASCII-art bitmap.
///
/// `art` is read line by line. Trailing whitespace is trimmed and empty lines
/// are skipped, so a newline right after the opening quote is harmless.
/// Every `'#'` becomes a black (0) pixel. Every other character, and any
/// padding to the right of a short row, stays white (255).
///
/// When the art contains at least one `'#'`, the rendered bitmap is cropped
/// tightly to its ink before features are extracted. Blank margins in the art
/// therefore do not influence the result.
pub fn prototype_from_art(label: char, art: &str) -> Prototype {
    let rows: Vec<&str> = art
        .lines()
        .map(str::trim_end)
        .filter(|row| !row.is_empty())
        .collect();
    let height = rows.len() as u32;
    // Measure in chars, not bytes. The pixel loop below indexes columns with
    // `chars().enumerate()`, so a byte-based width (`str::len`) would
    // over-size the canvas for any non-ASCII character in the art.
    let width = rows
        .iter()
        .map(|row| row.chars().count())
        .max()
        .unwrap_or(0) as u32;
    // `max(1)` keeps the canvas non-empty even for blank art.
    let mut full = GrayImage::from_pixel(width.max(1), height.max(1), Luma([255]));
    for (y, row) in rows.iter().enumerate() {
        for (x, ch) in row.chars().enumerate() {
            if ch == '#' {
                full.put_pixel(x as u32, y as u32, Luma([0]));
            }
        }
    }
    let img = tight_crop(&full);
    Prototype {
        label,
        features: extract(&img),
    }
}
/// Crops `src` to the smallest rectangle that contains every "dark" pixel,
/// meaning a pixel whose first channel is below 128.
///
/// If no pixel is dark, an unmodified copy of `src` is returned.
fn tight_crop(src: &GrayImage) -> GrayImage {
    // Running inclusive bounding box of dark pixels: (left, top, right, bottom).
    let mut bounds: Option<(u32, u32, u32, u32)> = None;
    for (x, y, pixel) in src.enumerate_pixels() {
        if pixel[0] >= 128 {
            continue;
        }
        bounds = Some(match bounds {
            None => (x, y, x, y),
            Some((left, top, right, bottom)) => {
                (left.min(x), top.min(y), right.max(x), bottom.max(y))
            }
        });
    }
    match bounds {
        Some((left, top, right, bottom)) => {
            let width = right - left + 1;
            let height = bottom - top + 1;
            image::imageops::crop_imm(src, left, top, width, height).to_image()
        }
        None => src.clone(),
    }
}
/// Built-in glyph bitmaps for the digits `0`–`9` and uppercase `A`–`Z`, stored
/// as `(label, art)` pairs and rendered by `prototype_from_art`.
///
/// Each art string is 9 rows of 7 columns, where `#` marks an ink pixel and
/// `.` a blank one. The newline right after each opening quote produces an
/// empty first line, which `prototype_from_art` skips.
pub const BUNDLED_GLYPHS: &[(char, &str)] = &[
(
'0',
"
.#####.
##...##
##...##
##...##
##...##
##...##
##...##
##...##
.#####.",
),
(
'1',
"
...##..
..###..
.####..
...##..
...##..
...##..
...##..
...##..
######.",
),
(
'2',
"
.#####.
##...##
.....##
....##.
...##..
..##...
.##....
##.....
#######",
),
(
'3',
"
.#####.
##...##
.....##
...###.
.....##
.....##
##...##
##...##
.#####.",
),
(
'4',
"
....##.
...###.
..####.
.##.##.
##..##.
#######
....##.
....##.
....##.",
),
(
'5',
"
#######
##.....
##.....
######.
.....##
.....##
##...##
##...##
.#####.",
),
(
'6',
"
.#####.
##...##
##.....
##.....
######.
##...##
##...##
##...##
.#####.",
),
(
'7',
"
#######
##...##
.....##
....##.
...##..
..##...
..##...
..##...
..##...",
),
(
'8',
"
.#####.
##...##
##...##
##...##
.#####.
##...##
##...##
##...##
.#####.",
),
(
'9',
"
.#####.
##...##
##...##
##...##
.######
.....##
.....##
##...##
.#####.",
),
(
'A',
"
...#...
..###..
..###..
.##.##.
.##.##.
#######
##...##
##...##
##...##",
),
(
'B',
"
######.
##...##
##...##
##...##
######.
##...##
##...##
##...##
######.",
),
(
'C',
"
.#####.
##...##
##.....
##.....
##.....
##.....
##.....
##...##
.#####.",
),
(
'D',
"
#####..
##..##.
##...##
##...##
##...##
##...##
##...##
##..##.
#####..",
),
(
'E',
"
#######
##.....
##.....
##.....
######.
##.....
##.....
##.....
#######",
),
(
'F',
"
#######
##.....
##.....
##.....
######.
##.....
##.....
##.....
##.....",
),
(
'G',
"
.#####.
##...##
##.....
##.....
##..###
##...##
##...##
##...##
.#####.",
),
(
'H',
"
##...##
##...##
##...##
##...##
#######
##...##
##...##
##...##
##...##",
),
(
'I',
"
#####..
.##....
.##....
.##....
.##....
.##....
.##....
.##....
#####..",
),
(
'J',
"
...####
.....##
.....##
.....##
.....##
.....##
##...##
##...##
.#####.",
),
(
'K',
"
##..##.
##.##..
####...
###....
####...
####...
##.##..
##..##.
##...##",
),
(
'L',
"
##.....
##.....
##.....
##.....
##.....
##.....
##.....
##.....
#######",
),
(
'M',
"
##...##
###.###
#######
##.#.##
##...##
##...##
##...##
##...##
##...##",
),
(
'N',
"
##...##
###..##
####.##
#######
##.####
##..###
##...##
##...##
##...##",
),
// NOTE(review): this 'O' bitmap is identical to the '0' bitmap above, so the
// two prototypes get identical features and cannot be told apart. Confirm
// that this is intended.
(
'O',
"
.#####.
##...##
##...##
##...##
##...##
##...##
##...##
##...##
.#####.",
),
(
'P',
"
######.
##...##
##...##
##...##
######.
##.....
##.....
##.....
##.....",
),
(
'Q',
"
.#####.
##...##
##...##
##...##
##...##
##.#.##
##..###
##...##
.######",
),
(
'R',
"
######.
##...##
##...##
##...##
######.
####...
##.##..
##..##.
##...##",
),
(
'S',
"
.#####.
##...##
##.....
##.....
.#####.
.....##
.....##
##...##
.#####.",
),
// NOTE(review): row 2 of 'T' (`#.##.#.`) is asymmetric: it has ink in
// columns 0 and 5 but not 6. This looks like a possible typo. Confirm before
// retraining, since editing it changes the bundled prototype.
(
'T',
"
#######
#.##.#.
..##...
..##...
..##...
..##...
..##...
..##...
..##...",
),
(
'U',
"
##...##
##...##
##...##
##...##
##...##
##...##
##...##
##...##
.#####.",
),
(
'V',
"
##...##
##...##
##...##
##...##
##...##
##...##
.##.##.
..###..
...#...",
),
(
'W',
"
##...##
##...##
##...##
##...##
##.#.##
##.#.##
#######
###.###
##...##",
),
(
'X',
"
##...##
##...##
.##.##.
..###..
...#...
..###..
.##.##.
##...##
##...##",
),
(
'Y',
"
##...##
##...##
.##.##.
..###..
...#...
...#...
...#...
...#...
...#...",
),
(
'Z',
"
#######
#....##
....##.
...##..
..##...
.##....
##.....
##.....
#######",
),
];
/// Renders every entry of [`BUNDLED_GLYPHS`] into a [`Prototype`], preserving
/// table order.
pub fn bundled_prototypes() -> Vec<Prototype> {
    let mut prototypes = Vec::with_capacity(BUNDLED_GLYPHS.len());
    for &(label, art) in BUNDLED_GLYPHS {
        prototypes.push(prototype_from_art(label, art));
    }
    prototypes
}
/// Serializes `prototypes` as pretty-printed JSON and writes the bytes to
/// `path` via `std::fs::write`.
///
/// # Errors
/// Returns [`OcrError::Config`] if serialization fails, or [`OcrError::Io`]
/// if the file cannot be written.
pub fn save_prototypes_json(prototypes: &[Prototype], path: impl AsRef<Path>) -> OcrResult<()> {
    let encoded = serde_json::to_vec_pretty(prototypes)
        .map_err(|err| OcrError::Config(format!("serialize prototypes: {err}")))?;
    let destination: &Path = path.as_ref();
    std::fs::write(destination, encoded).map_err(OcrError::Io)?;
    Ok(())
}
/// Loads a JSON array of prototypes from `path`, in the format written by
/// [`save_prototypes_json`].
///
/// Every prototype's feature vector must contain exactly
/// `crate::ocr::features::FEATURE_COUNT` entries. Files produced with a
/// different feature layout are rejected up front.
///
/// # Errors
/// - [`OcrError::Io`] if the file cannot be read.
/// - [`OcrError::Config`] if the file is not valid prototype JSON, or if any
///   prototype has the wrong number of features. Only the first offending
///   prototype is reported.
pub fn load_prototypes_json(path: impl AsRef<Path>) -> OcrResult<Vec<Prototype>> {
    use crate::ocr::features::FEATURE_COUNT;
    let raw = std::fs::read(path.as_ref()).map_err(OcrError::Io)?;
    let prototypes: Vec<Prototype> = serde_json::from_slice(&raw)
        .map_err(|e| OcrError::Config(format!("parse prototypes: {e}")))?;
    let mismatched = prototypes
        .iter()
        .find(|proto| proto.features.len() != FEATURE_COUNT);
    if let Some(bad) = mismatched {
        return Err(OcrError::Config(format!(
            "prototype '{}' has {} features but this build expects {}. \
             The feature vector was extended in a recent release — retrain with \
             `cargo run --features ocr-train --example train_prototypes`.",
            bad.label,
            bad.features.len(),
            FEATURE_COUNT
        )));
    }
    Ok(prototypes)
}
/// Caps the number of prototypes kept per label at `max_per_label`.
///
/// Labels with at most `max_per_label` prototypes keep all of them, in input
/// order. Larger groups are reduced to `max_per_label` representatives chosen
/// by k-medoids over Euclidean feature distance. A `max_per_label` of 0 is
/// treated as 1.
///
/// The output is grouped by label, with labels in order of first appearance
/// in `prototypes`, so the result is deterministic for a given input.
pub fn dedupe_prototypes(
    prototypes: Vec<Prototype>,
    max_per_label: usize,
) -> Vec<Prototype> {
    let max_per_label = max_per_label.max(1);
    // Group by label while remembering first-appearance order. Iterating a
    // `HashMap` directly would emit labels in an order that varies between
    // runs (randomly seeded hasher). That makes saved prototype files, and any
    // consumer sensitive to prototype order, non-reproducible.
    let mut slot_of: std::collections::HashMap<char, usize> =
        std::collections::HashMap::new();
    let mut groups: Vec<Vec<Prototype>> = Vec::new();
    for p in prototypes {
        let slot = *slot_of.entry(p.label).or_insert_with(|| {
            groups.push(Vec::new());
            groups.len() - 1
        });
        groups[slot].push(p);
    }
    let mut out: Vec<Prototype> = Vec::new();
    for group in groups {
        if group.len() <= max_per_label {
            out.extend(group);
        } else {
            out.extend(k_medoids(group, max_per_label));
        }
    }
    out
}
/// Selects `k` representatives (medoids) from `items` with a greedy swap
/// search over pairwise Euclidean feature distances.
///
/// Initial medoids are spread evenly over the input indices (`i * n / k`).
/// Each pass tries replacing each medoid slot with every current non-medoid
/// item. A swap is kept whenever it strictly lowers the total cost, defined as
/// the sum over all items of the distance to their nearest medoid. The search
/// stops after a pass with no improvement, or after 50 passes.
///
/// If `k >= items.len()`, `items` is returned unchanged. Otherwise the result
/// holds clones of the chosen items, in medoid-slot order.
fn k_medoids(items: Vec<Prototype>, k: usize) -> Vec<Prototype> {
    let n = items.len();
    if n <= k {
        return items;
    }

    // Symmetric n×n distance matrix in row-major order; the diagonal stays 0.
    let mut dist = vec![0f32; n * n];
    for a in 0..n {
        for b in (a + 1)..n {
            let d = euclidean(&items[a].features, &items[b].features);
            dist[a * n + b] = d;
            dist[b * n + a] = d;
        }
    }

    // Total cost of a medoid assignment: each item contributes the distance
    // to its closest medoid.
    let cost_of = |chosen: &[usize]| -> f32 {
        (0..n)
            .map(|i| {
                chosen
                    .iter()
                    .map(|&m| dist[i * n + m])
                    .fold(f32::INFINITY, f32::min)
            })
            .sum()
    };

    const MAX_PASSES: usize = 50;
    let mut chosen: Vec<usize> = (0..k).map(|slot| (slot * n) / k).collect();
    for _ in 0..MAX_PASSES {
        let mut best_cost = cost_of(&chosen);
        let mut changed = false;
        for slot in 0..k {
            for candidate in 0..n {
                if chosen.contains(&candidate) {
                    continue;
                }
                let previous = std::mem::replace(&mut chosen[slot], candidate);
                let trial_cost = cost_of(&chosen);
                if trial_cost < best_cost {
                    best_cost = trial_cost;
                    changed = true;
                } else {
                    chosen[slot] = previous;
                }
            }
        }
        if !changed {
            break;
        }
    }

    chosen.iter().map(|&idx| items[idx].clone()).collect()
}
/// Euclidean (L2) distance between two feature vectors.
///
/// Components are paired by position. If the slices differ in length, the
/// extra tail of the longer one is ignored.
fn euclidean(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta.powi(2)
        })
        .sum();
    squared_sum.sqrt()
}