use std::collections::VecDeque;
use std::error::Error;
use std::fs;
use std::io::{BufWriter, IsTerminal, Read};
use anyhow::{anyhow, Context};
use ocrs::{DecodeMethod, DimOrder, ImageSource, OcrEngine, OcrEngineParams, OcrInput};
use rten_imageproc::RotatedRect;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView};
mod models;
use models::{load_model, ModelSource};
mod output;
use output::{
format_json_output, format_text_output, generate_annotated_png, FormatJsonArgs,
GeneratePngArgs, OutputFormat,
};
/// Write a CHW float image to `path` as an 8-bit PNG.
///
/// Dimension 0 is the channel count and selects the PNG color type
/// (1 = grayscale, 3 = RGB, 4 = RGBA). Dimensions 1 and 2 are height and
/// width. Pixel values are clamped to `[0, 1]` and scaled to `0..=255`.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the channel count is unsupported;
/// - a dimension does not fit in a `u32`;
/// - the file cannot be created;
/// - PNG encoding fails;
/// - the buffered output cannot be flushed to disk.
fn write_image(path: &str, img: NdTensorView<f32, 3>) -> anyhow::Result<()> {
    let color_type = match img.size(0) {
        1 => png::ColorType::Grayscale,
        3 => png::ColorType::Rgb,
        4 => png::ColorType::Rgba,
        chans => return Err(anyhow!("Unsupported channel count {}", chans)),
    };
    // PNG dimensions are u32; reject rather than silently truncate.
    let img_width = u32::try_from(img.size(2)).context("Image width exceeds u32 range")?;
    let img_height = u32::try_from(img.size(1)).context("Image height exceeds u32 range")?;

    // PNG stores interleaved pixels, so convert CHW -> HWC before flattening.
    let hwc_img = img.permuted([1, 2, 0]);
    let out_img = image_from_tensor(hwc_img);

    let file = fs::File::create(path)?;
    let mut buf_writer = BufWriter::new(file);
    {
        let mut encoder = png::Encoder::new(&mut buf_writer, img_width, img_height);
        encoder.set_color(color_type);
        let mut png_writer = encoder.write_header()?;
        png_writer.write_image_data(&out_img)?;
        // Finish explicitly (writes the trailing chunk) so that errors are
        // reported instead of being swallowed when the writer is dropped.
        png_writer.finish()?;
    }
    // `BufWriter`'s `Drop` impl ignores flush errors, so flush explicitly.
    std::io::Write::flush(&mut buf_writer)?;
    Ok(())
}
/// Flatten a float tensor into bytes, visiting elements in the tensor's
/// iteration order.
///
/// Each value is clamped to `[0, 1]` and then scaled to `0..=255`. The
/// scaled value is truncated toward zero by the `as u8` cast, not rounded.
fn image_from_tensor(tensor: NdTensorView<f32, 3>) -> Vec<u8> {
    tensor
        .iter()
        .map(|value| value.clamp(0., 1.))
        .map(|unit| (unit * 255.0) as u8)
        .collect()
}
/// Where the CLI reads the input image from.
#[derive(Debug, Clone, PartialEq, Eq)]
enum InputSource {
    /// Read and decode the image file at this path.
    File(String),
    /// Read encoded image bytes from standard input.
    Stdin,
    /// Read the image currently held in the system clipboard.
    Clipboard,
}
/// Save each text line, as preprocessed for the recognition model, to
/// `<output_dir>/line-<index>.png`.
///
/// `output_dir` (and any missing parents) is created first. Each line image
/// comes from `engine.prepare_recognition_input` for that line's word
/// rectangles. Its values are shifted up by 0.5 before writing. Presumably the
/// recognizer input is centered on zero, so this maps it into `[0, 1]`
/// (TODO confirm).
///
/// Stops at, and returns, the first error encountered.
fn write_preprocessed_text_line_images(
    input: &OcrInput,
    engine: &OcrEngine,
    line_rects: &[Vec<RotatedRect>],
    output_dir: &str,
) -> anyhow::Result<()> {
    std::fs::create_dir_all(output_dir)
        .with_context(|| format!("Failed to create dir {}/", output_dir))?;

    line_rects
        .iter()
        .enumerate()
        .try_for_each(|(idx, rects)| -> anyhow::Result<()> {
            let target = format!("{}/line-{}.png", output_dir, idx);

            let mut pixels = engine.prepare_recognition_input(input, rects.as_slice())?;
            pixels.apply(|x| x + 0.5);

            // Add a leading channel axis so the 2D line becomes a 1-channel CHW image.
            let (rows, cols) = (pixels.size(0), pixels.size(1));
            let gray = pixels.into_shape([1, rows, cols]);

            write_image(&target, gray.view())
                .with_context(|| format!("Failed to write line image to {}", target))
        })
}
/// Command-line options, as produced by `parse_args`.
struct Args {
    /// Characters to keep in recognition output (`--allowed-chars`).
    allowed_chars: Option<String>,
    /// Alphabet used by the recognition model (`-a` / `--alphabet`).
    alphabet: Option<String>,
    /// Use beam search instead of greedy decoding (`--beam`).
    beam_search: bool,
    /// Enable debug logging (`--debug`).
    debug: bool,
    /// Path to a custom text detection model (`--detect-model`).
    detection_model: Option<String>,
    /// Where the input image is read from.
    input: InputSource,
    /// Output format: text (default), JSON (`-j`) or annotated PNG (`-p`).
    output_format: OutputFormat,
    /// Output file path (`-o` / `--output`); text/JSON go to stdout when `None`.
    output_path: Option<String>,
    /// Path to a custom text recognition model (`--rec-model`).
    recognition_model: Option<String>,
    /// Export images of detected text lines (`--text-line-images`).
    text_line_images: bool,
    /// Export the text probability map (`--text-map`).
    text_map: bool,
    /// Export the binary text mask (`--text-mask`).
    text_mask: bool,
}
fn parse_args() -> Result<Args, lexopt::Error> {
use lexopt::prelude::*;
let mut values = VecDeque::new();
let mut allowed_chars = None;
let mut alphabet = None;
let mut beam_search = false;
let mut clipboard = false;
let mut debug = false;
let mut detection_model = None;
let mut output_format = OutputFormat::Text;
let mut output_path = None;
let mut recognition_model = None;
let mut text_line_images = false;
let mut text_map = false;
let mut text_mask = false;
let mut parser = lexopt::Parser::from_env();
while let Some(arg) = parser.next()? {
match arg {
Value(val) => values.push_back(val.string()?),
Long("allowed-chars") => {
allowed_chars = Some(parser.value()?.string()?);
}
Short('a') | Long("alphabet") => {
alphabet = Some(parser.value()?.string()?);
}
Long("beam") => {
beam_search = true;
}
Short('c') | Long("clipboard") => {
clipboard = true;
}
Long("debug") => {
debug = true;
}
Long("detect-model") => {
detection_model = Some(parser.value()?.string()?);
}
Short('j') | Long("json") => {
output_format = OutputFormat::Json;
}
Short('o') | Long("output") => {
output_path = Some(parser.value()?.string()?);
}
Short('p') | Long("png") => {
output_format = OutputFormat::Png;
}
Long("rec-model") => {
recognition_model = Some(parser.value()?.string()?);
}
Long("text-line-images") => {
text_line_images = true;
}
Long("text-map") => {
text_map = true;
}
Long("text-mask") => {
text_mask = true;
}
Long("help") => {
println!(
"Extract text from an image.
Usage: {bin_name} [OPTIONS] [image]
If no image path is given, reads from stdin.
Options:
--allowed-chars <chars>
Filter characters produced by text recognition
-a, --alphabet <chars>
Specify the alphabet used by the recognition model
-c, --clipboard
Read image from system clipboard
--detect-model <path>
Use a custom text detection model
-j, --json
Output text and structure in JSON format
-o, --output <path>
Output file path (defaults to stdout)
-p, --png
Output annotated copy of input image in PNG format
--rec-model <path>
Use a custom text recognition model
--version
Display version info
Advanced options:
(Note: These options are unstable and may change between releases)
--beam
Use beam search for decoding
--debug
Enable debug logging
--text-line-images
Export images of identified text lines
--text-map
Generate a text probability map for the input image
--text-mask
Generate a binary text mask for the input image
",
bin_name = parser.bin_name().unwrap_or("ocrs")
);
std::process::exit(0);
}
Long("version") => {
println!("ocrs {}", env!("CARGO_PKG_VERSION"));
std::process::exit(0);
}
_ => return Err(arg.unexpected()),
}
}
let image = values.pop_front();
let stdin_is_pipe = !std::io::stdin().is_terminal();
let input = match (clipboard, image, stdin_is_pipe) {
(true, Some(_), _) => {
return Err("cannot use both --clipboard and an image path".into());
}
(true, _, true) => {
return Err("cannot use both --clipboard and stdin".into());
}
(true, None, false) => InputSource::Clipboard,
(false, Some(path), _) => InputSource::File(path),
(false, None, true) => InputSource::Stdin,
(false, None, false) => {
return Err("missing `<image>` arg (or use --clipboard / pipe to stdin)".into());
}
};
Ok(Args {
alphabet,
beam_search,
debug,
detection_model,
input,
output_format,
output_path,
recognition_model,
text_map,
text_mask,
text_line_images,
allowed_chars,
})
}
/// Default text detection model URL.
///
/// Used as `ModelSource::Url` when `--detect-model` is not given.
const DETECTION_MODEL: &str =
    "https://ocrs-models.s3-accelerate.amazonaws.com/text-detection.rten";

/// Default text recognition model URL.
///
/// Used as `ModelSource::Url` when `--rec-model` is not given.
const RECOGNITION_MODEL: &str =
    "https://ocrs-models.s3-accelerate.amazonaws.com/text-recognition.rten";
/// Convert a decoded image into an HWC `u8` tensor of shape `[height, width, 3]`.
///
/// The image is first converted to 8-bit RGB with `into_rgb8`, so any alpha
/// channel is dropped.
fn image_to_tensor(image: image::DynamicImage) -> NdTensor<u8, 3> {
    let rgb = image.into_rgb8();
    let (cols, rows) = rgb.dimensions();
    let shape = [rows as usize, cols as usize, 3];
    NdTensor::from_data(shape, rgb.into_vec())
}
fn load_image_from_file(path: &str) -> anyhow::Result<NdTensor<u8, 3>> {
image::open(path)
.map(image_to_tensor)
.with_context(|| format!("Failed to read image from {}", path))
}
fn load_image_from_stdin() -> anyhow::Result<NdTensor<u8, 3>> {
let mut buf = Vec::new();
std::io::stdin()
.read_to_end(&mut buf)
.context("Failed to read image from stdin")?;
let image = image::load_from_memory(&buf).context("Failed to decode image from stdin")?;
Ok(image_to_tensor(image))
}
/// Read the image currently on the system clipboard as an HWC RGB `u8` tensor.
///
/// arboard returns the pixels as RGBA with one byte per channel. The alpha
/// channel is dropped.
///
/// # Errors
///
/// Fails in any of these cases:
/// - the clipboard cannot be accessed;
/// - the clipboard holds no image;
/// - the pixel buffer's length does not match `width * height * 4`.
#[cfg(feature = "clipboard")]
fn load_image_from_clipboard() -> anyhow::Result<NdTensor<u8, 3>> {
    use arboard::Clipboard;
    let mut clipboard = Clipboard::new().context("Failed to access clipboard")?;
    let image_data = clipboard
        .get_image()
        .context("Failed to get image from clipboard. Is there an image copied?")?;

    let (width, height) = (image_data.width, image_data.height);
    let rgba = &image_data.bytes;

    // Validate the buffer size up front: `NdTensor::from_data` panics on a
    // length/shape mismatch, and a malformed buffer should be a normal error.
    let expected_len = width.checked_mul(height).and_then(|n| n.checked_mul(4));
    if expected_len != Some(rgba.len()) {
        return Err(anyhow!(
            "Clipboard image has unexpected size: {} bytes for a {}x{} RGBA image",
            rgba.len(),
            width,
            height
        ));
    }

    // Convert RGBA -> RGB, reading the clipboard buffer in place and
    // allocating the output once.
    let mut rgb_bytes = Vec::with_capacity(width * height * 3);
    for pixel in rgba.chunks_exact(4) {
        rgb_bytes.extend_from_slice(&pixel[..3]);
    }

    Ok(NdTensor::from_data([height, width, 3], rgb_bytes))
}
/// Fallback used when the `clipboard` feature is disabled.
///
/// Always returns an error explaining how to rebuild with clipboard support.
#[cfg(not(feature = "clipboard"))]
fn load_image_from_clipboard() -> anyhow::Result<NdTensor<u8, 3>> {
    let unsupported = anyhow!(
        "ocrs was compiled without clipboard support. Use `cargo install ocrs-cli --features clipboard` to enable it."
    );
    Err(unsupported)
}
/// Entry point for the `ocrs` CLI.
///
/// 1. Parse arguments.
/// 2. Load the detection and recognition models, from the `--detect-model` /
///    `--rec-model` paths or the default URLs.
/// 3. Read the input image as an HWC RGB tensor.
/// 4. Run word detection, line grouping and text recognition.
/// 5. Write the result as plain text, JSON or an annotated PNG.
///
/// Optional debug artifacts (`--text-map`, `--text-mask`,
/// `--text-line-images`) are written to fixed paths relative to the current
/// directory.
fn main() -> Result<(), Box<dyn Error>> {
    const PNG_PATH_REQUIRED: &str = "Output path must be specified when generating annotated PNG";

    let args = parse_args()?;

    // Reject invalid option combinations before the potentially slow model
    // loading and OCR steps.
    if matches!(args.output_format, OutputFormat::Png) && args.output_path.is_none() {
        return Err(PNG_PATH_REQUIRED.into());
    }

    let detection_model_src = args
        .detection_model
        .as_ref()
        .map_or(ModelSource::Url(DETECTION_MODEL), |path| {
            ModelSource::Path(path)
        });
    let detection_model = load_model(detection_model_src).with_context(|| {
        format!(
            "Failed to load text detection model from {}",
            detection_model_src
        )
    })?;

    let recognition_model_src = args
        .recognition_model
        .as_ref()
        .map_or(ModelSource::Url(RECOGNITION_MODEL), |path| {
            ModelSource::Path(path)
        });
    let recognition_model = load_model(recognition_model_src).with_context(|| {
        format!(
            "Failed to load text recognition model from {}",
            recognition_model_src
        )
    })?;

    #[allow(clippy::needless_update)] // Keeps working if OcrEngineParams gains fields.
    let engine = OcrEngine::new(OcrEngineParams {
        detection_model: Some(detection_model),
        recognition_model: Some(recognition_model),
        debug: args.debug,
        alphabet: args.alphabet,
        decode_method: if args.beam_search {
            DecodeMethod::BeamSearch { width: 100 }
        } else {
            DecodeMethod::Greedy
        },
        allowed_chars: args.allowed_chars,
        ..Default::default()
    })?;

    // `color_img` is laid out HWC: (height, width, channels).
    let (color_img, input_path): (NdTensor<u8, 3>, String) = match &args.input {
        InputSource::Clipboard => (load_image_from_clipboard()?, "<clipboard>".to_string()),
        InputSource::File(path) => (load_image_from_file(path)?, path.clone()),
        InputSource::Stdin => (load_image_from_stdin()?, "<stdin>".to_string()),
    };
    let color_img_source = ImageSource::from_tensor(color_img.view(), DimOrder::Hwc)?;
    let ocr_input = engine.prepare_input(color_img_source)?;

    if args.text_map || args.text_mask {
        let text_map = engine.detect_text_pixels(&ocr_input)?;
        let [height, width] = text_map.shape();
        // Add a channel axis so the map can be written as a 1-channel image.
        let text_map = text_map.into_shape([1, height, width]);
        if args.text_map {
            write_image("text-map.png", text_map.view())?;
        }
        if args.text_mask {
            let threshold = engine.detection_threshold();
            let text_mask = text_map.map(|x| if *x > threshold { 1. } else { 0. });
            write_image("text-mask.png", text_mask.view())?;
        }
    }

    let word_rects = engine.detect_words(&ocr_input)?;
    let line_rects = engine.find_text_lines(&ocr_input, &word_rects);
    if args.text_line_images {
        write_preprocessed_text_line_images(&ocr_input, &engine, &line_rects, "lines")?;
    }
    let line_texts = engine.recognize_text(&ocr_input, &line_rects)?;

    // Text/JSON go to the `--output` file if given, otherwise to stdout.
    let write_output_str = |content: String| -> Result<(), Box<dyn Error>> {
        if let Some(output_path) = &args.output_path {
            std::fs::write(output_path, content.into_bytes())
                .with_context(|| format!("Failed to write output to {}", output_path))?;
        } else {
            println!("{}", content);
        }
        Ok(())
    };

    match args.output_format {
        OutputFormat::Text => {
            let content = format_text_output(&line_texts);
            write_output_str(content)?;
        }
        OutputFormat::Json => {
            let content = format_json_output(FormatJsonArgs {
                input_path: &input_path,
                // The image is HWC, so the first two dims are (height, width).
                input_hw: color_img.shape()[..2].try_into()?,
                text_lines: &line_texts,
            });
            write_output_str(content)?;
        }
        OutputFormat::Png => {
            let png_args = GeneratePngArgs {
                img: color_img.view(),
                line_rects: &line_rects,
                text_lines: &line_texts,
            };
            let annotated_img = generate_annotated_png(png_args);
            let Some(output_path) = args.output_path else {
                // Already rejected right after argument parsing; kept so this
                // arm stays self-contained.
                return Err(PNG_PATH_REQUIRED.into());
            };
            write_image(&output_path, annotated_img.view())
                .with_context(|| format!("Failed to write output to {}", &output_path))?;
        }
    }

    if args.debug {
        // Report on stderr so text/JSON written to stdout stays machine-readable.
        // Image is HWC: size(1) is the width, size(0) the height.
        eprintln!(
            "Found {} words, {} lines in image of size {}x{}",
            word_rects.len(),
            line_rects.len(),
            color_img.size(1),
            color_img.size(0),
        );
    }
    Ok(())
}