use wasm_bindgen::prelude::*;
use rten::ops;
use rten::{Model, ModelOptions, OpRegistry};
use rten_imageproc::{min_area_rect, BoundingRect, PointF};
use rten_tensor::prelude::*;
use crate::{ImageSource, OcrEngine as BaseOcrEngine, OcrEngineParams, OcrInput, TextItem};
/// Holder for the models that are handed to [`OcrEngine`] when it is
/// constructed.
///
/// Both models start out unset (`None`) and are filled in through
/// `set_detection_model` / `set_recognition_model`.
#[wasm_bindgen]
pub struct OcrEngineInit {
// Model stored by `set_detection_model`; `None` until one has been loaded.
detection_model: Option<Model>,
// Model stored by `set_recognition_model`; `None` until one has been loaded.
recognition_model: Option<Model>,
}
impl Default for OcrEngineInit {
fn default() -> OcrEngineInit {
OcrEngineInit::new()
}
}
#[wasm_bindgen]
impl OcrEngineInit {
#[wasm_bindgen(constructor)]
pub fn new() -> OcrEngineInit {
OcrEngineInit {
detection_model: None,
recognition_model: None,
}
}
fn op_registry() -> OpRegistry {
let mut reg = OpRegistry::new();
reg.register_op::<ops::Add>();
reg.register_op::<ops::AveragePool>();
reg.register_op::<ops::Cast>();
reg.register_op::<ops::Concat>();
reg.register_op::<ops::ConstantOfShape>();
reg.register_op::<ops::Conv>();
reg.register_op::<ops::ConvTranspose>();
reg.register_op::<ops::GRU>();
reg.register_op::<ops::Gather>();
reg.register_op::<ops::LogSoftmax>();
reg.register_op::<ops::MatMul>();
reg.register_op::<ops::MaxPool>();
reg.register_op::<ops::Pad>();
reg.register_op::<ops::Relu>();
reg.register_op::<ops::Reshape>();
reg.register_op::<ops::Shape>();
reg.register_op::<ops::Sigmoid>();
reg.register_op::<ops::Slice>();
reg.register_op::<ops::Transpose>();
reg.register_op::<ops::Unsqueeze>();
reg
}
#[wasm_bindgen(js_name = setDetectionModel)]
pub fn set_detection_model(&mut self, data: Vec<u8>) -> Result<(), String> {
let model = ModelOptions::with_ops(Self::op_registry())
.load(data)
.map_err(|e| e.to_string())?;
self.detection_model = Some(model);
Ok(())
}
#[wasm_bindgen(js_name = setRecognitionModel)]
pub fn set_recognition_model(&mut self, data: Vec<u8>) -> Result<(), String> {
let model = ModelOptions::with_ops(Self::op_registry())
.load(data)
.map_err(|e| e.to_string())?;
self.recognition_model = Some(model);
Ok(())
}
}
/// Wrapper that exposes the crate's OCR engine (imported here as
/// `BaseOcrEngine`) through `wasm_bindgen`.
#[wasm_bindgen]
pub struct OcrEngine {
// The wrapped engine; every exported method forwards to it.
engine: BaseOcrEngine,
}
#[wasm_bindgen]
impl OcrEngine {
    /// Construct an engine from the models held by `init`.
    ///
    /// `init` is consumed. Every engine parameter other than the two models
    /// is left at its `Default` value.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the underlying engine constructor,
    /// converted to a string.
    #[wasm_bindgen(constructor)]
    pub fn new(init: OcrEngineInit) -> Result<OcrEngine, String> {
        let params = OcrEngineParams {
            detection_model: init.detection_model,
            recognition_model: init.recognition_model,
            ..Default::default()
        };
        let engine = BaseOcrEngine::new(params).map_err(|e| e.to_string())?;
        Ok(OcrEngine { engine })
    }

    /// Prepare an image for use with the other methods of this engine.
    ///
    /// `data` is interpreted by `ImageSource::from_bytes` together with the
    /// `(width, height)` pair, and the result is passed through the engine's
    /// `prepare_input`.
    ///
    /// NOTE(review): the accepted pixel layouts / channel counts for `data`
    /// are decided by `ImageSource::from_bytes` — confirm there.
    ///
    /// # Errors
    ///
    /// Returns the stringified error from either `ImageSource::from_bytes` or
    /// `prepare_input`.
    #[wasm_bindgen(js_name = loadImage)]
    pub fn load_image(&self, width: u32, height: u32, data: &[u8]) -> Result<Image, String> {
        let source = ImageSource::from_bytes(data, (width, height)).map_err(|e| e.to_string())?;
        let input = self
            .engine
            .prepare_input(source)
            .map_err(|e| e.to_string())?;
        Ok(Image { input })
    }

    /// Detect text in `image` and return it grouped into lines.
    ///
    /// Runs the engine's `detect_words`, groups the result with
    /// `find_text_lines`, and wraps every group of rectangles in a
    /// [`DetectedLine`].
    ///
    /// # Errors
    ///
    /// Returns the stringified error from `detect_words`.
    #[wasm_bindgen(js_name = detectText)]
    pub fn detect_text(&self, image: &Image) -> Result<Vec<DetectedLine>, String> {
        let input = &image.input;
        let words = self.engine.detect_words(input).map_err(|e| e.to_string())?;
        let lines = self
            .engine
            .find_text_lines(input, &words)
            .into_iter()
            .map(|line| {
                let rects = line.into_iter().map(|rect| RotatedRect { rect }).collect();
                DetectedLine::new(rects)
            })
            .collect();
        Ok(lines)
    }

    /// Recognize the text of previously detected `lines` in `image`.
    ///
    /// The returned vector is built one-to-one from the engine's output. An
    /// entry for which the engine produced `None` becomes a [`TextLine`] that
    /// wraps `None` (its `text()` is empty and its `words()` list is empty).
    ///
    /// # Errors
    ///
    /// Returns the stringified error from the engine's `recognize_text`.
    #[wasm_bindgen(js_name = recognizeText)]
    pub fn recognize_text(
        &self,
        image: &Image,
        lines: Vec<DetectedLine>,
    ) -> Result<Vec<TextLine>, String> {
        // Unwrap the wasm-facing rectangles back into the engine's own type.
        let lines: Vec<Vec<rten_imageproc::RotatedRect>> = lines
            .iter()
            .map(|line| line.words.iter().map(|word| word.rect).collect::<Vec<_>>())
            .collect();

        let text_lines = self
            .engine
            .recognize_text(&image.input, &lines)
            .map_err(|e| e.to_string())?
            .into_iter()
            // The wrapper stores the engine's `Option` as-is; there is no need
            // to take it apart and rebuild it.
            .map(|line| TextLine { line })
            .collect();
        Ok(text_lines)
    }

    /// Return the text of `image` as a single string.
    ///
    /// Forwards directly to the engine's `get_text`.
    ///
    /// # Errors
    ///
    /// Returns the stringified error from the engine.
    #[wasm_bindgen(js_name = getText)]
    pub fn get_text(&self, image: &Image) -> Result<String, String> {
        let text = self
            .engine
            .get_text(&image.input)
            .map_err(|e| e.to_string())?;
        Ok(text)
    }

    /// Detect and recognize the text lines of `image` in one call.
    ///
    /// Chains the engine's `detect_words`, `find_text_lines` and
    /// `recognize_text`. As with `recognize_text` on this wrapper, an entry
    /// for which the engine produced `None` becomes a [`TextLine`] wrapping
    /// `None`.
    ///
    /// # Errors
    ///
    /// Returns the stringified error from `detect_words` or `recognize_text`.
    #[wasm_bindgen(js_name = getTextLines)]
    pub fn get_text_lines(&self, image: &Image) -> Result<Vec<TextLine>, String> {
        let input = &image.input;
        let words = self.engine.detect_words(input).map_err(|e| e.to_string())?;
        let lines = self.engine.find_text_lines(input, &words);
        let text_lines = self
            .engine
            .recognize_text(input, &lines)
            .map_err(|e| e.to_string())?
            .into_iter()
            .map(|line| TextLine { line })
            .collect();
        Ok(text_lines)
    }
}
/// An image that has been prepared by `OcrEngine::load_image` and can be
/// passed to the engine's detection / recognition methods.
#[wasm_bindgen]
pub struct Image {
// The prepared input produced by the engine's `prepare_input`.
input: OcrInput,
}
#[wasm_bindgen]
impl Image {
    /// Number of channels: the size of dimension 0 of the image tensor.
    pub fn channels(&self) -> usize {
        self.input.image.size(0)
    }

    /// Image width: the size of dimension 2 of the image tensor.
    pub fn width(&self) -> usize {
        self.input.image.size(2)
    }

    /// Image height: the size of dimension 1 of the image tensor.
    pub fn height(&self) -> usize {
        self.input.image.size(1)
    }

    /// Return the image as bytes in (height, width, channel) order.
    ///
    /// The tensor is stored as (channel, height, width), so it is permuted
    /// before iteration. Each element `x` is mapped to `(x + 0.5) * 255`,
    /// rounded to the nearest integer and converted to `u8`; the `as` cast
    /// saturates, so out-of-range values clamp to 0 or 255.
    ///
    /// NOTE(review): this assumes the stored values lie in `[-0.5, 0.5]`
    /// (i.e. they were produced as `pixel / 255 - 0.5`) — confirm against
    /// `prepare_input`.
    pub fn data(&self) -> Vec<u8> {
        self.input
            .image
            .permuted([1, 2, 0])
            .iter()
            // Round before casting: a bare `as u8` truncates toward zero, so a
            // value that comes back as e.g. 0.999999 because of floating-point
            // error would become 0 instead of 1.
            .map(|x| ((x + 0.5) * 255.).round() as u8)
            .collect()
    }
}
/// `wasm_bindgen`-exported wrapper around `rten_imageproc::RotatedRect`.
#[wasm_bindgen]
#[derive(Clone)]
pub struct RotatedRect {
// The wrapped rectangle.
rect: rten_imageproc::RotatedRect,
}
#[wasm_bindgen]
impl RotatedRect {
    /// Return the corner coordinates of the rectangle as a flat list
    /// `[x0, y0, x1, y1, ...]`, in the order the wrapped rectangle reports
    /// its corners.
    pub fn corners(&self) -> Vec<f32> {
        let mut coords = Vec::new();
        for corner in self.rect.corners() {
            coords.push(corner.x);
            coords.push(corner.y);
        }
        coords
    }

    /// Return the bounding rectangle of this rotated rectangle as
    /// `[left, top, right, bottom]`.
    #[wasm_bindgen(js_name = boundingRect)]
    pub fn bounding_rect(&self) -> Vec<f32> {
        let bounds = self.rect.bounding_rect();
        vec![bounds.left(), bounds.top(), bounds.right(), bounds.bottom()]
    }
}
/// A line of text that has been detected but not yet recognized, represented
/// by the rectangles of its words.
#[wasm_bindgen]
#[derive(Clone)]
pub struct DetectedLine {
// Rectangles of the words that make up the line.
words: Vec<RotatedRect>,
}
#[wasm_bindgen]
impl DetectedLine {
    /// Wrap a list of word rectangles as a line. Not exported; used when
    /// converting the engine's output.
    fn new(words: Vec<RotatedRect>) -> DetectedLine {
        DetectedLine { words }
    }

    /// Return the rectangle produced by `min_area_rect` over the corners of
    /// all word rectangles in this line.
    ///
    /// # Panics
    ///
    /// Panics if `min_area_rect` returns `None`.
    ///
    /// NOTE(review): presumably that only happens when the line contains no
    /// words — confirm against `rten_imageproc::min_area_rect`.
    #[wasm_bindgen(js_name = rotatedRect)]
    pub fn rotated_rect(&self) -> RotatedRect {
        let mut points: Vec<PointF> = Vec::new();
        for word in &self.words {
            points.extend(word.rect.corners());
        }
        let rect = min_area_rect(&points).expect("expected non-empty rect");
        RotatedRect { rect }
    }

    /// Return a copy of the word rectangles that make up this line.
    pub fn words(&self) -> Vec<RotatedRect> {
        self.words.to_vec()
    }
}
/// A recognized word: its text together with the rectangle it occupies.
#[wasm_bindgen]
#[derive(Clone)]
pub struct TextWord {
// Rectangle of the word, as reported by the recognized line.
rect: RotatedRect,
// String form of the word.
text: String,
}
#[wasm_bindgen]
impl TextWord {
pub fn text(&self) -> String {
self.text.clone()
}
#[wasm_bindgen(js_name = rotatedRect)]
pub fn rotated_rect(&self) -> RotatedRect {
self.rect.clone()
}
}
/// A line produced by text recognition.
///
/// Wraps the optional line returned by the engine; when the engine returned
/// `None` for a line, the wrapper holds `None` as well.
#[wasm_bindgen]
#[derive(Clone)]
pub struct TextLine {
// The engine's recognized line, or `None` if the engine produced none.
line: Option<super::TextLine>,
}
#[wasm_bindgen]
impl TextLine {
    /// Return the string form of the line, or an empty string when no line
    /// is wrapped.
    pub fn text(&self) -> String {
        match &self.line {
            Some(line) => line.to_string(),
            None => String::new(),
        }
    }

    /// Return the words of the line, each with its text and rectangle, or an
    /// empty list when no line is wrapped.
    pub fn words(&self) -> Vec<TextWord> {
        match self.line.as_ref() {
            None => Vec::new(),
            Some(line) => line
                .words()
                .map(|word| {
                    let text = word.to_string();
                    let rect = RotatedRect {
                        rect: word.rotated_rect(),
                    };
                    TextWord { rect, text }
                })
                .collect(),
        }
    }
}