use std::cmp::{max, min, Ordering};
use std::collections::HashMap;
use std::env;
use std::io::IsTerminal;
use std::path::{Path, PathBuf};
#[cfg(unix)]
use std::os::unix::io::AsRawFd;
#[cfg(unix)]
use libc;
use image::imageops::{FilterType, crop_imm, resize};
use image::{DynamicImage, Rgb, RgbImage, RgbaImage};
use imageproc::drawing::draw_hollow_rect_mut;
use ocr_rs::{OcrEngine, TextBox};
use pgs_rs::parse_pgs;
use pgs_rs::parse::{
CompositionObject, CompositionState, ObjectDefinition, PaletteDefinition, Pgs,
SegmentContents,
};
use serde::{Deserialize, Serialize};
use crate::{SubtitleDocument, TranslatorError};
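/// Environment variable that overrides where OCR model files are cached.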
const OCR_MODEL_DIR_ENV_VAR: &str = "SHINKAI_TRANSLATOR_OCR_MODEL_DIR";
const OCR_MODEL_VERSION: &str = "paddleocr-v5";
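/// Duration granted to a final cue that never sees a terminating event
/// (180_000 ticks on the 90 kHz PTS clock, i.e. two seconds).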
const DEFAULT_LAST_CUE_DURATION_PTS: u32 = 180_000;
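/// Largest PTS gap (45_000 ticks, 500 ms) across which a pending cue may
/// still absorb a continuation of the same on-screen text.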
const MAX_CONTINUATION_GAP_PTS: u32 = 45_000;
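/// PGS timestamps run on the MPEG 90 kHz clock: 90 ticks per millisecond.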
const PTS_PER_MILLISECOND: u64 = 90;
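/// Pixels whose alpha is at or below this value are treated as transparent.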
const MIN_ALPHA: u8 = 8;
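/// OCR language profile requested by the caller; `Auto` resolves a profile
/// from stream metadata and language hints.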
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum PgsOcrLanguage {
Auto,
English,
Latin,
}
impl PgsOcrLanguage {
pub fn parse_name(value: &str) -> Option<Self> {
match value.trim().to_ascii_lowercase().as_str() {
"auto" => Some(Self::Auto),
"english" | "en" => Some(Self::English),
"latin" => Some(Self::Latin),
_ => None,
}
}
}
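/// Configuration for PGS OCR: the language profile plus an optional
/// override of the model cache directory.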
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(default)]
pub struct PgsOcrConfig {
pub language: PgsOcrLanguage,
pub model_cache_dir: Option<PathBuf>,
}
impl Default for PgsOcrConfig {
fn default() -> Self {
Self {
language: PgsOcrLanguage::Auto,
model_cache_dir: env::var(OCR_MODEL_DIR_ENV_VAR).ok().map(PathBuf::from),
}
}
}
impl PgsOcrConfig {
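    /// Validation hook; there are currently no constraints to enforce.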
pub fn validate(&self) -> Result<(), TranslatorError> {
Ok(())
}
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ResolvedOcrProfile {
English,
Latin,
}
#[derive(Clone, Debug)]
struct ModelArtifact {
file_name: &'static str,
url: &'static str,
}
#[derive(Clone, Debug)]
struct ModelBundle {
detection: ModelArtifact,
recognition: ModelArtifact,
charset: ModelArtifact,
}
#[derive(Clone, Debug)]
struct ModelPaths {
detection: PathBuf,
recognition: PathBuf,
charset: PathBuf,
}
#[derive(Clone, Debug)]
struct RecognizedTextBlock {
bbox: TextBox,
text: String,
}
#[derive(Clone, Debug, Eq, PartialEq)]
struct OcrCue {
start_pts: u32,
end_pts: u32,
text: String,
}
#[derive(Clone, Debug, Eq, PartialEq)]
struct DisplayTextEvent {
pts: u32,
is_empty: bool,
text: Option<String>,
}
#[derive(Clone, Debug)]
struct PendingCue {
start_pts: u32,
last_pts: u32,
normalized_text: String,
rendered_text: String,
}
#[derive(Debug)]
struct ResolvedDisplaySet<'a> {
presentation_timestamp: u32,
width: u16,
height: u16,
palette_id: u8,
composition_objects: &'a [CompositionObject],
palettes: HashMap<u8, &'a PaletteDefinition>,
objects: HashMap<u16, &'a ObjectDefinition>,
}
impl ResolvedDisplaySet<'_> {
fn is_empty(&self) -> bool {
self.composition_objects.is_empty()
}
}
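/// Runs OCR over a PGS (`.sup`) subtitle stream and renders the result as
/// SRT text. Model files are downloaded on first use, and the CPU-heavy
/// parsing and recognition work runs on a blocking thread.
///
/// A minimal call sketch (the path and language tag here are illustrative):
///
/// ```ignore
/// let config = PgsOcrConfig::default();
/// let srt = ocr_pgs_file_to_srt(Path::new("subtitles.sup"), Some("eng"), None, &config).await?;
/// ```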
pub async fn ocr_pgs_file_to_srt(
sup_path: &Path,
stream_language: Option<&str>,
source_language_hint: Option<&str>,
config: &PgsOcrConfig,
) -> Result<String, TranslatorError> {
config.validate()?;
let profile = resolve_profile(config.language, stream_language, source_language_hint);
let model_paths = ensure_models(config, profile).await?;
let sup_path = sup_path.to_path_buf();
let config = config.clone();
tokio::task::spawn_blocking(move || {
ocr_pgs_file_to_srt_blocking(&sup_path, &model_paths, profile, &config)
})
.await
.map_err(|error| TranslatorError::Ocr(format!("PGS OCR task failed: {error}")))?
}
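/// Re-applies the fragmentary-line repair pass to an already-parsed subtitle
/// document, cue by cue, merging punctuation-only fragments and dropping
/// noise-stub lines.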
pub fn cleanup_fragmentary_subtitle_document(document: &mut SubtitleDocument) {
for cue in document.cues_mut() {
let lines = cue
.text()
.lines()
.map(ToOwned::to_owned)
.collect::<Vec<_>>();
let repaired = repair_fragmentary_ocr_lines(lines);
if !repaired.is_empty() {
cue.set_text(repaired.join("\n"));
}
}
}
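/// Builds the OCR engine while temporarily redirecting stderr to /dev/null
/// on Unix, since engine initialization can emit noisy diagnostics. The
/// original stderr fd is restored before returning; other platforms build
/// the engine directly.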
fn new_ocr_engine_quiet(
detection: &Path,
recognition: &Path,
charset: &Path,
) -> Result<OcrEngine, TranslatorError> {
#[cfg(unix)]
{
use std::fs::File;
let stderr_fd = std::io::stderr().as_raw_fd();
let saved_fd = unsafe { libc::dup(stderr_fd) };
        // /dev/null must be opened for writing: stderr is a write-side
        // descriptor, and redirecting it onto a read-only fd would make
        // every subsequent write fail with EBADF.
        let devnull = File::options()
            .write(true)
            .open("/dev/null")
            .map_err(|error| TranslatorError::Ocr(format!("failed to open /dev/null: {error}")))?;
        unsafe { libc::dup2(devnull.as_raw_fd(), stderr_fd) };
let result = OcrEngine::new(detection, recognition, charset, None)
.map_err(|error| TranslatorError::Ocr(format!("failed to initialize PaddleOCR engine: {error}")));
unsafe {
libc::dup2(saved_fd, stderr_fd);
libc::close(saved_fd);
}
result
}
#[cfg(not(unix))]
{
OcrEngine::new(detection, recognition, charset, None)
.map_err(|error| TranslatorError::Ocr(format!("failed to initialize PaddleOCR engine: {error}")))
}
}
fn ocr_pgs_file_to_srt_blocking(
sup_path: &Path,
model_paths: &ModelPaths,
_profile: ResolvedOcrProfile,
_config: &PgsOcrConfig,
) -> Result<String, TranslatorError> {
let mut data = std::fs::read(sup_path)?;
let pgs = parse_pgs(data.as_mut_slice())
.map_err(|error| TranslatorError::Ocr(format!("failed to parse PGS subtitle stream: {error}")))?;
let display_sets = collect_display_sets(&pgs)?;
let show_progress = std::io::stderr().is_terminal() && display_sets.len() >= 25;
let ocr = new_ocr_engine_quiet(
&model_paths.detection,
&model_paths.recognition,
&model_paths.charset,
)?;
let mut events = Vec::with_capacity(display_sets.len());
if show_progress {
eprintln!("OCR: processing {} PGS display sets", display_sets.len());
}
for (index, display_set) in display_sets.iter().enumerate() {
let text = recognize_display_set_text(display_set, &ocr)?;
events.push(DisplayTextEvent {
pts: display_set.presentation_timestamp,
is_empty: display_set.is_empty(),
text,
});
if show_progress && ((index + 1) % 25 == 0 || index + 1 == display_sets.len()) {
eprintln!(
"OCR: processed {}/{} display sets",
index + 1,
display_sets.len()
);
}
}
let cues = merge_display_texts(&events);
if cues.is_empty() {
return Err(TranslatorError::Ocr(
"OCR did not extract any subtitle text from the PGS stream".to_owned(),
));
}
Ok(render_cues_to_srt(&cues))
}
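/// Renders one display set, crops it to its visible region, and runs
/// detection plus batched recognition over it. Returns `None` when the set
/// is empty, fully transparent, or yields no usable text.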
fn recognize_display_set_text(
display_set: &ResolvedDisplaySet<'_>,
ocr: &OcrEngine,
) -> Result<Option<String>, TranslatorError> {
if display_set.is_empty() {
return Ok(None);
}
let image = render_display_set_for_ocr(display_set)?;
let Some(bounds) = alpha_bounds(&image) else {
return Ok(None);
};
let cropped = crop_imm(&image, bounds.0, bounds.1, bounds.2, bounds.3).to_image();
let prepared = prepare_ocr_image(&cropped);
let prepared_image = DynamicImage::ImageRgb8(prepared);
let detections = ocr
.det_model()
.detect_and_crop(&prepared_image)
.map_err(|error| TranslatorError::Ocr(format!("PaddleOCR detection failed on rendered PGS frame: {error}")))?;
if detections.is_empty() {
return Ok(None);
}
let cropped_images = detections
.iter()
.map(|(image, _)| image.clone())
.collect::<Vec<_>>();
let recognition_results = ocr
.recognize_batch(&cropped_images)
.map_err(|error| TranslatorError::Ocr(format!("PaddleOCR recognition failed on rendered PGS frame: {error}")))?;
let recognized_blocks = detections
.into_iter()
.zip(recognition_results)
.filter_map(|((_, bbox), recognition)| {
let text = normalize_text_line(&recognition.text);
if text.is_empty() {
None
} else {
Some(RecognizedTextBlock { bbox, text })
}
})
.collect::<Vec<_>>();
let lines = collect_recognized_text_lines(recognized_blocks);
let text = repair_fragmentary_ocr_lines(lines).join("\n");
if text.trim().is_empty() {
Ok(None)
} else {
Ok(Some(text))
}
}
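/// Groups the flat segment list into display sets: each set starts at a
/// presentation composition and runs to an End segment with the same
/// PTS/DTS. Palettes and objects persist across sets within an epoch and
/// are cleared when a new epoch starts.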
fn collect_display_sets<'a>(pgs: &'a Pgs) -> Result<Vec<ResolvedDisplaySet<'a>>, TranslatorError> {
let mut display_sets = Vec::new();
let mut palettes = HashMap::new();
let mut objects = HashMap::new();
let mut index = 0usize;
while index < pgs.segments.len() {
let segment = &pgs.segments[index];
let SegmentContents::PresentationComposition(presentation_composition) = &segment.contents else {
index += 1;
continue;
};
if presentation_composition.composition_state == CompositionState::EpochStart {
palettes.clear();
objects.clear();
}
let presentation_timestamp = segment.pts;
let decoding_timestamp = segment.dts;
let mut display_set = ResolvedDisplaySet {
presentation_timestamp,
width: presentation_composition.width,
height: presentation_composition.height,
palette_id: presentation_composition.palette_id,
composition_objects: &presentation_composition.composition_objects,
palettes: palettes.clone(),
objects: objects.clone(),
};
index += 1;
let mut saw_end = false;
while index < pgs.segments.len() {
let segment = &pgs.segments[index];
if segment.pts != presentation_timestamp || segment.dts != decoding_timestamp {
break;
}
match &segment.contents {
SegmentContents::PresentationComposition(_) => {
return Err(TranslatorError::Ocr(format!(
"unexpected nested PGS presentation composition at {presentation_timestamp}"
)));
}
SegmentContents::WindowDefinition(_) => {}
SegmentContents::PaletteDefinition(palette_definition) => {
palettes.insert(palette_definition.id, palette_definition);
display_set
.palettes
.insert(palette_definition.id, palette_definition);
}
SegmentContents::ObjectDefinition(object_definition) => {
objects.insert(object_definition.id, object_definition);
display_set
.objects
.insert(object_definition.id, object_definition);
}
SegmentContents::End => {
saw_end = true;
index += 1;
break;
}
}
index += 1;
}
if !saw_end {
return Err(TranslatorError::Ocr(format!(
"unterminated PGS display set at {presentation_timestamp}"
)));
}
display_sets.push(display_set);
}
Ok(display_sets)
}
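/// Rasterizes a display set into an RGBA frame. Palette luminance is copied
/// into all three color channels, producing a grayscale render; chroma is
/// ignored, which is sufficient for OCR of subtitle glyphs.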
fn render_display_set_for_ocr(display_set: &ResolvedDisplaySet<'_>) -> Result<RgbaImage, TranslatorError> {
let width = display_set.width as usize;
let height = display_set.height as usize;
let mut rgba = vec![0u8; width * height * 4];
for composition_object in display_set.composition_objects {
let object = display_set.objects.get(&composition_object.id).ok_or_else(|| {
TranslatorError::Ocr(format!(
"PGS object {} is missing from display set at {}",
composition_object.id, display_set.presentation_timestamp
))
})?;
let palette = display_set.palettes.get(&display_set.palette_id);
let mut pixel_offset = ((composition_object.vertical_position as usize * width)
+ composition_object.horizontal_position as usize)
* 4;
for pixel in &object.data.0 {
let palette_entry = palette.and_then(|palette| palette.entries.get(&pixel.color));
for _ in 0..pixel.count {
if !is_cropped(pixel_offset / 4, width, composition_object) {
if let Some(entry) = palette_entry {
rgba[pixel_offset] = entry.luminance;
rgba[pixel_offset + 1] = entry.luminance;
rgba[pixel_offset + 2] = entry.luminance;
rgba[pixel_offset + 3] = entry.alpha;
}
}
move_one_pixel_forward(
&mut pixel_offset,
width,
composition_object.horizontal_position as usize,
object.width as usize,
);
}
}
}
RgbaImage::from_raw(display_set.width as u32, display_set.height as u32, rgba).ok_or_else(|| {
TranslatorError::Ocr("failed to build RGBA frame from rendered PGS subtitle".to_owned())
})
}
fn is_cropped(pixel_index: usize, frame_width: usize, object: &CompositionObject) -> bool {
let Some(cropped) = &object.cropped else {
return false;
};
let x = pixel_index % frame_width;
let y = pixel_index / frame_width;
let crop_left = object.horizontal_position as usize + cropped.horizontal_position as usize;
let crop_top = object.vertical_position as usize + cropped.vertical_position as usize;
let crop_right = crop_left + cropped.width as usize;
let crop_bottom = crop_top + cropped.height as usize;
x < crop_left || x >= crop_right || y < crop_top || y >= crop_bottom
}
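/// Advances the write offset by one pixel, jumping to the object's left
/// edge on the next frame row once the object's right edge is reached.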
fn move_one_pixel_forward(
pixel_offset: &mut usize,
frame_width: usize,
horizontal_position: usize,
object_width: usize,
) {
*pixel_offset += 4;
let x = (*pixel_offset / 4) % frame_width;
if x == horizontal_position + object_width {
*pixel_offset += (frame_width - object_width) * 4;
}
}
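/// Returns the bounding box `(x, y, width, height)` of all visibly opaque
/// pixels, expanded by a few pixels of padding and clamped to the frame.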
fn alpha_bounds(image: &RgbaImage) -> Option<(u32, u32, u32, u32)> {
let mut min_x = u32::MAX;
let mut min_y = u32::MAX;
let mut max_x = 0u32;
let mut max_y = 0u32;
let mut found = false;
for (x, y, pixel) in image.enumerate_pixels() {
if pixel[3] <= MIN_ALPHA {
continue;
}
found = true;
min_x = min(min_x, x);
min_y = min(min_y, y);
max_x = max(max_x, x);
max_y = max(max_y, y);
}
if !found {
return None;
}
let padding = 6u32;
let x = min_x.saturating_sub(padding);
let y = min_y.saturating_sub(padding);
let width = min(image.width(), max_x.saturating_add(padding).saturating_add(1)) - x;
let height = min(image.height(), max_y.saturating_add(padding).saturating_add(1)) - y;
Some((x, y, width.max(1), height.max(1)))
}
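/// Flattens the RGBA render onto a black background by multiplying each
/// channel by alpha, then upscales small renders (a heuristic: 3x below
/// 72 px tall, 2x below 160 px) so thin glyph strokes survive detection.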
fn prepare_ocr_image(image: &RgbaImage) -> RgbImage {
let mut rgb = RgbImage::new(image.width(), image.height());
for (x, y, pixel) in image.enumerate_pixels() {
let alpha = pixel[3] as u16;
let red = ((pixel[0] as u16) * alpha / 255) as u8;
let green = ((pixel[1] as u16) * alpha / 255) as u8;
let blue = ((pixel[2] as u16) * alpha / 255) as u8;
rgb.put_pixel(x, y, Rgb([red, green, blue]));
}
let scale_factor = if rgb.height() < 72 {
3
} else if rgb.height() < 160 {
2
} else {
1
};
if scale_factor == 1 {
rgb
} else {
resize(
&rgb,
rgb.width() * scale_factor,
rgb.height() * scale_factor,
FilterType::CatmullRom,
)
}
}
fn normalize_text_line(line: &str) -> String {
line.split_whitespace().collect::<Vec<_>>().join(" ")
}
fn normalize_for_compare(text: &str) -> String {
text.split_whitespace()
.collect::<Vec<_>>()
.join(" ")
.to_ascii_lowercase()
}
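/// Groups recognized blocks into visual lines by vertical-center proximity,
/// orders each line left to right, and renders it to a single string.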
fn collect_recognized_text_lines(mut blocks: Vec<RecognizedTextBlock>) -> Vec<String> {
if blocks.is_empty() {
return Vec::new();
}
blocks.sort_by(compare_recognized_text_blocks);
let mut lines: Vec<Vec<RecognizedTextBlock>> = Vec::with_capacity(blocks.len());
for block in blocks {
if let Some(current_line) = lines.last_mut() {
if belongs_to_same_line(current_line, &block) {
current_line.push(block);
continue;
}
}
lines.push(vec![block]);
}
lines
.into_iter()
.map(|mut line| {
line.sort_by(|a, b| a.bbox.rect.left().cmp(&b.bbox.rect.left()));
render_recognized_text_line(&line)
})
.filter(|line| !line.is_empty())
.collect()
}
fn compare_recognized_text_blocks(left: &RecognizedTextBlock, right: &RecognizedTextBlock) -> Ordering {
left.bbox
.rect
.top()
.cmp(&right.bbox.rect.top())
.then(left.bbox.rect.left().cmp(&right.bbox.rect.left()))
.then(left.bbox.rect.width().cmp(&right.bbox.rect.width()))
.then(left.bbox.rect.height().cmp(&right.bbox.rect.height()))
}
fn belongs_to_same_line(current_line: &[RecognizedTextBlock], next_block: &RecognizedTextBlock) -> bool {
let current_center = line_center_y(current_line);
let next_center = text_box_center_y(&next_block.bbox);
let current_height = line_height(current_line);
let next_height = next_block.bbox.rect.height() as i32;
let threshold = max(12, max(current_height, next_height) / 2);
(next_center - current_center).abs() <= threshold
}
fn line_center_y(line: &[RecognizedTextBlock]) -> i32 {
let (top, bottom) = line_bounds(line);
top + ((bottom - top) / 2)
}
fn line_height(line: &[RecognizedTextBlock]) -> i32 {
let (top, bottom) = line_bounds(line);
(bottom - top).max(0)
}
fn line_bounds(line: &[RecognizedTextBlock]) -> (i32, i32) {
let top = line
.iter()
.map(|block| block.bbox.rect.top())
.min()
.unwrap_or_default();
let bottom = line
.iter()
.map(|block| block.bbox.rect.top() + block.bbox.rect.height() as i32)
.max()
.unwrap_or_default();
(top, bottom)
}
fn text_box_center_y(bbox: &TextBox) -> i32 {
bbox.rect.top() + (bbox.rect.height() as i32 / 2)
}
fn render_recognized_text_line(line: &[RecognizedTextBlock]) -> String {
let mut rendered = String::new();
for block in line {
let piece = block.text.trim();
if piece.is_empty() {
continue;
}
append_ocr_text_piece(&mut rendered, piece);
}
normalize_text_line(&rendered)
}
fn append_ocr_text_piece(rendered: &mut String, piece: &str) {
if rendered.is_empty() {
rendered.push_str(piece);
return;
}
if is_punctuation_fragment(piece) || ends_with_connector(rendered) {
rendered.push_str(piece);
return;
}
if piece
.chars()
.next()
.is_some_and(|character| matches!(character, ',' | '.' | '!' | '?' | ':' | ';' | ')' | ']' | '}' | '\'' | '"'))
{
rendered.push_str(piece);
return;
}
rendered.push(' ');
rendered.push_str(piece);
}
fn ends_with_connector(text: &str) -> bool {
text.chars()
.last()
.is_some_and(|character| matches!(character, '-' | '’' | '\'' | '/' | '('))
}
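/// Cleans up OCR line lists: punctuation-only fragments are appended to the
/// previous line, stray leading punctuation is stripped, and one- or
/// two-character noise stubs are dropped when other lines exist.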
fn repair_fragmentary_ocr_lines(lines: Vec<String>) -> Vec<String> {
let has_multiple_lines = lines.len() > 1;
let mut repaired: Vec<String> = Vec::with_capacity(lines.len());
for line in lines {
let line = strip_leading_fragment_punctuation(&line);
if is_punctuation_fragment(&line) {
if let Some(previous) = repaired.last_mut() {
let fragment = line.trim();
if !previous.trim_end().ends_with(fragment) {
previous.push_str(fragment);
}
}
continue;
}
if has_multiple_lines && is_noise_stub_line(&line) {
continue;
}
repaired.push(line);
}
repaired
}
fn strip_leading_fragment_punctuation(line: &str) -> String {
let trimmed = line.trim();
let mut characters = trimmed.chars();
let Some(first) = characters.next() else {
return String::new();
};
if first.is_alphanumeric() || first.is_whitespace() {
return trimmed.to_owned();
}
let remainder = characters.as_str().trim_start();
if remainder.is_empty() {
return trimmed.to_owned();
}
if remainder.starts_with(first) {
return trimmed.to_owned();
}
if remainder
.chars()
.next()
.is_some_and(|character| character.is_alphabetic())
{
remainder.to_owned()
} else {
trimmed.to_owned()
}
}
fn is_punctuation_fragment(line: &str) -> bool {
let trimmed = line.trim();
!trimmed.is_empty()
&& trimmed.chars().count() <= 4
&& trimmed
.chars()
.all(|character| !character.is_alphanumeric() && !character.is_whitespace())
}
fn is_noise_stub_line(line: &str) -> bool {
let trimmed = line.trim();
let chars = trimmed.chars().collect::<Vec<_>>();
match chars.len() {
0 => true,
1 => chars[0].is_alphanumeric(),
2 => chars.iter().all(|character| character.is_ascii_uppercase()),
_ => false,
}
}
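/// Folds per-display-set text events into cues. Identical consecutive texts
/// extend the pending cue, a continuation replaces its text in place, an
/// empty display set closes it, and a trailing cue gets a default duration.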
fn merge_display_texts(events: &[DisplayTextEvent]) -> Vec<OcrCue> {
let mut cues = Vec::new();
let mut active: Option<PendingCue> = None;
for event in events {
if event.is_empty {
if let Some(current) = active.take() {
cues.push(OcrCue {
start_pts: current.start_pts,
end_pts: max(current.start_pts + 1, event.pts),
text: current.rendered_text,
});
}
continue;
}
let Some(text) = event.text.as_ref().map(|text| text.trim()).filter(|text| !text.is_empty()) else {
continue;
};
let normalized = normalize_for_compare(text);
match active.take() {
Some(mut current) if current.normalized_text == normalized => {
current.last_pts = event.pts;
active = Some(current);
}
            Some(mut current) if should_extend_pending_cue(&current, text, &normalized, event.pts) => {
current.last_pts = event.pts;
current.normalized_text = normalized;
current.rendered_text = text.to_owned();
active = Some(current);
}
Some(current) => {
cues.push(OcrCue {
start_pts: current.start_pts,
end_pts: max(current.start_pts + 1, event.pts),
text: current.rendered_text,
});
active = Some(PendingCue {
start_pts: event.pts,
last_pts: event.pts,
normalized_text: normalized,
rendered_text: text.to_owned(),
});
}
_ => {
active = Some(PendingCue {
start_pts: event.pts,
last_pts: event.pts,
normalized_text: normalized,
rendered_text: text.to_owned(),
});
}
}
}
if let Some(current) = active.take() {
cues.push(OcrCue {
start_pts: current.start_pts,
end_pts: current.start_pts.saturating_add(DEFAULT_LAST_CUE_DURATION_PTS),
text: current.rendered_text,
});
}
cues
}
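/// A pending cue is extended (rather than closed) when the next event is
/// close in time and either upgrades a fragmentary text or is a superset of
/// the current normalized text.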
fn should_extend_pending_cue(
current: &PendingCue,
next_text: &str,
next_normalized: &str,
next_pts: u32,
) -> bool {
if next_pts.saturating_sub(current.last_pts) > MAX_CONTINUATION_GAP_PTS {
return false;
}
    if is_fragmentary_text(&current.rendered_text) && !is_fragmentary_text(next_text) {
return true;
}
    current.normalized_text.len() >= 2 && next_normalized.contains(&current.normalized_text)
}
fn is_fragmentary_text(text: &str) -> bool {
let trimmed = text.trim();
if trimmed.is_empty() {
return true;
}
let alnum_count = trimmed.chars().filter(|character| character.is_alphanumeric()).count();
let visible_count = trimmed.chars().filter(|character| !character.is_whitespace()).count();
let token_count = trimmed.split_whitespace().count();
(alnum_count == 0 && visible_count <= 4) || (token_count <= 1 && alnum_count <= 3)
}
fn render_cues_to_srt(cues: &[OcrCue]) -> String {
let blocks = cues
.iter()
.enumerate()
.map(|(index, cue)| {
format!(
"{}\n{} --> {}\n{}",
index + 1,
format_srt_timestamp(cue.start_pts),
format_srt_timestamp(cue.end_pts),
cue.text
)
})
.collect::<Vec<_>>()
.join("\n\n");
format!("{blocks}\n")
}
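/// Converts a 90 kHz PTS value to an SRT timestamp, rounding to the nearest
/// millisecond; for example, 90_000 ticks renders as "00:00:01,000".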
fn format_srt_timestamp(pts: u32) -> String {
let total_milliseconds = ((pts as u64) + (PTS_PER_MILLISECOND / 2)) / PTS_PER_MILLISECOND;
let hours = total_milliseconds / 3_600_000;
let minutes = (total_milliseconds % 3_600_000) / 60_000;
let seconds = (total_milliseconds % 60_000) / 1_000;
let milliseconds = total_milliseconds % 1_000;
format!("{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}")
}
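/// Resolves the model cache directory and downloads any missing artifacts.
/// The layout is `<cache>/paddleocr/<version>/<profile>/<file>`, so model
/// upgrades and profile switches never collide.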
async fn ensure_models(
config: &PgsOcrConfig,
profile: ResolvedOcrProfile,
) -> Result<ModelPaths, TranslatorError> {
let bundle = model_bundle(profile);
let cache_dir = config
.model_cache_dir
.clone()
.unwrap_or_else(default_model_cache_dir)
.join("paddleocr")
.join(OCR_MODEL_VERSION)
.join(profile.cache_segment());
tokio::fs::create_dir_all(&cache_dir).await?;
let client = reqwest::Client::builder().build()?;
let detection = ensure_model_artifact(&client, &cache_dir, &bundle.detection).await?;
let recognition = ensure_model_artifact(&client, &cache_dir, &bundle.recognition).await?;
let charset = ensure_model_artifact(&client, &cache_dir, &bundle.charset).await?;
Ok(ModelPaths {
detection,
recognition,
charset,
})
}
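/// Downloads a single model file if it is not already cached, writing to a
/// temporary `.download` path first and renaming into place so a partial
/// download is never mistaken for a complete artifact.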
async fn ensure_model_artifact(
client: &reqwest::Client,
cache_dir: &Path,
artifact: &ModelArtifact,
) -> Result<PathBuf, TranslatorError> {
let path = cache_dir.join(artifact.file_name);
if tokio::fs::metadata(&path).await.is_ok() {
return Ok(path);
}
let response = client
.get(artifact.url)
.send()
.await?
.error_for_status()
.map_err(|error| {
TranslatorError::Ocr(format!(
"failed to download OCR model {}: {error}",
artifact.file_name
))
})?;
let bytes = response.bytes().await?;
let temp_path = path.with_extension("download");
tokio::fs::write(&temp_path, &bytes).await?;
tokio::fs::rename(&temp_path, &path).await?;
Ok(path)
}
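/// Picks a cache root in priority order: the override environment variable,
/// then `XDG_CACHE_HOME`, then `~/.cache`, falling back to the temp dir.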
fn default_model_cache_dir() -> PathBuf {
if let Ok(value) = env::var(OCR_MODEL_DIR_ENV_VAR) {
if !value.trim().is_empty() {
return PathBuf::from(value);
}
}
if let Ok(value) = env::var("XDG_CACHE_HOME") {
if !value.trim().is_empty() {
return PathBuf::from(value).join("shinkai-translator");
}
}
if let Ok(home) = env::var("HOME") {
if !home.trim().is_empty() {
return PathBuf::from(home).join(".cache").join("shinkai-translator");
}
}
env::temp_dir().join("shinkai-translator")
}
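/// Chooses the OCR model profile. Explicit settings win; in auto mode,
/// English stream tags or hints select the English model, and everything
/// else falls back to the broader Latin charset.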
fn resolve_profile(
requested: PgsOcrLanguage,
stream_language: Option<&str>,
source_language_hint: Option<&str>,
) -> ResolvedOcrProfile {
match requested {
PgsOcrLanguage::English => ResolvedOcrProfile::English,
PgsOcrLanguage::Latin => ResolvedOcrProfile::Latin,
PgsOcrLanguage::Auto => {
if stream_language
.is_some_and(|language| matches_language(language, &["eng", "en"][..]))
|| source_language_hint
.is_some_and(|language| normalize_language_name(language).contains("english"))
{
ResolvedOcrProfile::English
} else {
ResolvedOcrProfile::Latin
}
}
}
}
fn matches_language(value: &str, candidates: &[&str]) -> bool {
candidates
.iter()
.any(|candidate| value.trim().eq_ignore_ascii_case(candidate))
}
fn normalize_language_name(value: &str) -> String {
value
.chars()
.map(|character| {
if character.is_ascii_alphanumeric() {
character.to_ascii_lowercase()
} else {
' '
}
})
.collect::<String>()
}
fn model_bundle(profile: ResolvedOcrProfile) -> ModelBundle {
let detection = ModelArtifact {
file_name: "PP-OCRv5_mobile_det.mnn",
url: "https://raw.githubusercontent.com/zibo-chen/rust-paddle-ocr/next/models/PP-OCRv5_mobile_det.mnn",
};
let (recognition, charset) = match profile {
ResolvedOcrProfile::English => (
ModelArtifact {
file_name: "en_PP-OCRv5_mobile_rec_infer.mnn",
url: "https://raw.githubusercontent.com/zibo-chen/rust-paddle-ocr/next/models/en_PP-OCRv5_mobile_rec_infer.mnn",
},
ModelArtifact {
file_name: "ppocr_keys_en.txt",
url: "https://raw.githubusercontent.com/zibo-chen/rust-paddle-ocr/next/models/ppocr_keys_en.txt",
},
),
ResolvedOcrProfile::Latin => (
ModelArtifact {
file_name: "latin_PP-OCRv5_mobile_rec_infer.mnn",
url: "https://raw.githubusercontent.com/zibo-chen/rust-paddle-ocr/next/models/latin_PP-OCRv5_mobile_rec_infer.mnn",
},
ModelArtifact {
file_name: "ppocr_keys_latin.txt",
url: "https://raw.githubusercontent.com/zibo-chen/rust-paddle-ocr/next/models/ppocr_keys_latin.txt",
},
),
};
ModelBundle {
detection,
recognition,
charset,
}
}
impl ResolvedOcrProfile {
fn cache_segment(self) -> &'static str {
match self {
Self::English => "english",
Self::Latin => "latin",
}
}
}
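/// Summary counters produced by a debug OCR run.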
pub struct OcrDebugResult {
pub frames_processed: usize,
pub frames_with_text: usize,
}
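/// Dumps per-frame OCR diagnostics (raw, prepared, and annotated images
/// plus text reports and an index) into `output_dir` for display sets
/// within the optional millisecond time range.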
pub async fn debug_pgs_ocr(
sup_path: &Path,
config: &PgsOcrConfig,
stream_language: Option<&str>,
from_ms: Option<u64>,
to_ms: Option<u64>,
output_dir: &Path,
) -> Result<OcrDebugResult, TranslatorError> {
config.validate()?;
tokio::fs::create_dir_all(output_dir).await?;
let profile = resolve_profile(config.language, stream_language, None);
let model_paths = ensure_models(config, profile).await?;
let sup_path = sup_path.to_path_buf();
let output_dir = output_dir.to_path_buf();
tokio::task::spawn_blocking(move || {
debug_pgs_ocr_blocking(&sup_path, &model_paths, from_ms, to_ms, &output_dir)
})
.await
.map_err(|error| TranslatorError::Ocr(format!("PGS OCR debug task failed: {error}")))?
}
fn pts_from_ms(ms: u64) -> u32 {
    // Clamp rather than cast: a bare `as u32` would silently wrap for
    // inputs past u32::MAX ticks (roughly thirteen hours at 90 kHz).
    ms.saturating_mul(PTS_PER_MILLISECOND).min(u64::from(u32::MAX)) as u32
}
fn pts_to_ms_u64(pts: u32) -> u64 {
((pts as u64) + PTS_PER_MILLISECOND / 2) / PTS_PER_MILLISECOND
}
fn format_debug_filename_timestamp(pts: u32) -> String {
let ms = pts_to_ms_u64(pts);
let hours = ms / 3_600_000;
let minutes = (ms % 3_600_000) / 60_000;
let seconds = (ms % 60_000) / 1_000;
let millis = ms % 1_000;
format!("{hours:02}h{minutes:02}m{seconds:02}s{millis:03}")
}
fn format_debug_readable_timestamp(pts: u32) -> String {
let ms = pts_to_ms_u64(pts);
let hours = ms / 3_600_000;
let minutes = (ms % 3_600_000) / 60_000;
let seconds = (ms % 60_000) / 1_000;
let millis = ms % 1_000;
format!("{hours:02}:{minutes:02}:{seconds:02},{millis:03}")
}
struct DisplaySetDebugInfo {
raw_rgba: RgbaImage,
has_alpha_content: bool,
prepared_rgb: RgbImage,
recognized_blocks: Vec<RecognizedTextBlock>,
lines: Vec<String>,
final_text: Option<String>,
}
fn recognize_display_set_debug(
display_set: &ResolvedDisplaySet<'_>,
ocr: &OcrEngine,
) -> Result<DisplaySetDebugInfo, TranslatorError> {
let raw_rgba = render_display_set_for_ocr(display_set)?;
let Some(bounds) = alpha_bounds(&raw_rgba) else {
return Ok(DisplaySetDebugInfo {
raw_rgba,
has_alpha_content: false,
prepared_rgb: RgbImage::new(1, 1),
recognized_blocks: Vec::new(),
lines: Vec::new(),
final_text: None,
});
};
let cropped = crop_imm(&raw_rgba, bounds.0, bounds.1, bounds.2, bounds.3).to_image();
let prepared_rgb = prepare_ocr_image(&cropped);
let prepared_image = DynamicImage::ImageRgb8(prepared_rgb.clone());
let detections = ocr
.det_model()
.detect_and_crop(&prepared_image)
.map_err(|error| TranslatorError::Ocr(format!("PaddleOCR detection failed: {error}")))?;
if detections.is_empty() {
return Ok(DisplaySetDebugInfo {
raw_rgba,
has_alpha_content: true,
prepared_rgb,
recognized_blocks: Vec::new(),
lines: Vec::new(),
final_text: None,
});
}
let cropped_images = detections.iter().map(|(img, _)| img.clone()).collect::<Vec<_>>();
let recognition_results = ocr
.recognize_batch(&cropped_images)
.map_err(|error| TranslatorError::Ocr(format!("PaddleOCR recognition failed: {error}")))?;
let recognized_blocks = detections
.into_iter()
.zip(recognition_results)
.filter_map(|((_, bbox), recognition)| {
let text = normalize_text_line(&recognition.text);
if text.is_empty() { None } else { Some(RecognizedTextBlock { bbox, text }) }
})
.collect::<Vec<_>>();
let lines = collect_recognized_text_lines(recognized_blocks.clone());
let repaired = repair_fragmentary_ocr_lines(lines.clone());
let final_text_str = repaired.join("\n");
let final_text = if final_text_str.trim().is_empty() { None } else { Some(final_text_str) };
Ok(DisplaySetDebugInfo {
raw_rgba,
has_alpha_content: true,
prepared_rgb,
recognized_blocks,
lines,
final_text,
})
}
fn debug_pgs_ocr_blocking(
sup_path: &Path,
model_paths: &ModelPaths,
from_ms: Option<u64>,
to_ms: Option<u64>,
output_dir: &Path,
) -> Result<OcrDebugResult, TranslatorError> {
let mut data = std::fs::read(sup_path)?;
let pgs = parse_pgs(data.as_mut_slice())
.map_err(|error| TranslatorError::Ocr(format!("failed to parse PGS: {error}")))?;
let display_sets = collect_display_sets(&pgs)?;
let from_pts = from_ms.map(pts_from_ms);
let to_pts = to_ms.map(pts_from_ms);
let matching_indices: Vec<usize> = display_sets
.iter()
.enumerate()
.filter_map(|(i, ds)| {
let pts = ds.presentation_timestamp;
let in_range = from_pts.map_or(true, |from| pts >= from)
&& to_pts.map_or(true, |to| pts <= to);
in_range.then_some(i)
})
.collect();
if matching_indices.is_empty() {
return Err(TranslatorError::Ocr(
"no display sets found in the specified time range".to_owned(),
));
}
let ocr = new_ocr_engine_quiet(
&model_paths.detection,
&model_paths.recognition,
&model_paths.charset,
)?;
let total = matching_indices.len();
let mut frames_with_text = 0usize;
let mut index_lines = vec![
"OCR Debug Report".to_owned(),
format!("Source: {}", sup_path.display()),
format!("Total display sets in range: {}", total),
String::new(),
format!("{:<6} {:<17} {:<9} {}", "Frame", "Timestamp", "State", "OCR Text"),
"-".repeat(80),
];
for (seq, &set_index) in matching_indices.iter().enumerate() {
let display_set = &display_sets[set_index];
let frame_num = seq + 1;
let pts = display_set.presentation_timestamp;
let ts_filename = format_debug_filename_timestamp(pts);
let ts_readable = format_debug_readable_timestamp(pts);
let file_prefix = format!("{frame_num:04}_{ts_filename}");
if display_set.is_empty() {
index_lines.push(format!("{frame_num:<6} {ts_readable:<17} {:<9} -", "empty"));
let empty_path = output_dir.join(format!("{file_prefix}_empty.txt"));
std::fs::write(
&empty_path,
format!("Timestamp: {ts_readable}\nPTS: {pts}\nState: empty display set (no objects)\n"),
)
.map_err(|e| TranslatorError::Ocr(format!("failed to write debug file: {e}")))?;
continue;
}
match recognize_display_set_debug(display_set, &ocr) {
Err(error) => {
index_lines.push(format!(
"{frame_num:<6} {ts_readable:<17} {:<9} ERROR: {error}",
"error"
));
let error_path = output_dir.join(format!("{file_prefix}_error.txt"));
std::fs::write(
&error_path,
format!("Timestamp: {ts_readable}\nPTS: {pts}\nError: {error}\n"),
)
.map_err(|e| TranslatorError::Ocr(format!("failed to write debug file: {e}")))?;
}
Ok(info) => {
let has_text = info.final_text.is_some();
if has_text {
frames_with_text += 1;
}
let ocr_text = build_debug_ocr_report(pts, &ts_readable, &info);
let state = if !info.has_alpha_content {
"blank"
} else if has_text {
"text"
} else {
"no-text"
};
let summary = info
.final_text
.as_deref()
.and_then(|t| t.lines().next())
.unwrap_or("-")
.to_owned();
let raw_path = output_dir.join(format!("{file_prefix}_raw.png"));
DynamicImage::ImageRgba8(info.raw_rgba)
.save(&raw_path)
.map_err(|e| TranslatorError::Ocr(format!("failed to save debug image: {e}")))?;
if info.has_alpha_content {
let prep_path = output_dir.join(format!("{file_prefix}_prepared.png"));
DynamicImage::ImageRgb8(info.prepared_rgb.clone())
.save(&prep_path)
.map_err(|e| TranslatorError::Ocr(format!("failed to save debug image: {e}")))?;
let mut annotated = info.prepared_rgb.clone();
for block in &info.recognized_blocks {
draw_hollow_rect_mut(&mut annotated, block.bbox.rect, Rgb([255, 50, 50]));
}
let ann_path = output_dir.join(format!("{file_prefix}_annotated.png"));
DynamicImage::ImageRgb8(annotated)
.save(&ann_path)
.map_err(|e| TranslatorError::Ocr(format!("failed to save debug image: {e}")))?;
}
let ocr_path = output_dir.join(format!("{file_prefix}_ocr.txt"));
std::fs::write(&ocr_path, &ocr_text)
.map_err(|e| TranslatorError::Ocr(format!("failed to write debug file: {e}")))?;
index_lines.push(format!("{frame_num:<6} {ts_readable:<17} {state:<9} {summary}"));
}
}
}
let index_path = output_dir.join("index.txt");
std::fs::write(&index_path, index_lines.join("\n") + "\n")
.map_err(|error| TranslatorError::Ocr(format!("failed to write index file: {error}")))?;
Ok(OcrDebugResult {
frames_processed: total,
frames_with_text,
})
}
fn build_debug_ocr_report(pts: u32, ts_readable: &str, info: &DisplaySetDebugInfo) -> String {
let mut out = String::new();
out.push_str(&format!("Timestamp: {ts_readable}\n"));
out.push_str(&format!("PTS: {pts}\n"));
if !info.has_alpha_content {
out.push_str("State: blank (no visible content)\n");
return out;
}
out.push_str(&format!(
"Image size: {}x{}\n",
info.prepared_rgb.width(),
info.prepared_rgb.height()
));
out.push_str(&format!("Detected blocks: {}\n", info.recognized_blocks.len()));
out.push('\n');
for (i, block) in info.recognized_blocks.iter().enumerate() {
out.push_str(&format!("Block {}\n", i + 1));
out.push_str(&format!(
" Bbox: left={}, top={}, width={}, height={}\n",
block.bbox.rect.left(),
block.bbox.rect.top(),
block.bbox.rect.width(),
block.bbox.rect.height(),
));
out.push_str(&format!(" Text: {:?}\n", block.text));
out.push('\n');
}
if !info.lines.is_empty() {
out.push_str("Lines after grouping:\n");
for (i, line) in info.lines.iter().enumerate() {
out.push_str(&format!(" {}: {:?}\n", i + 1, line));
}
out.push('\n');
}
out.push_str("Final OCR text:\n");
match &info.final_text {
Some(text) => {
for line in text.lines() {
out.push_str(&format!(" {line}\n"));
}
}
None => out.push_str(" (none)\n"),
}
out
}
#[cfg(test)]
mod tests {
use super::{
DEFAULT_LAST_CUE_DURATION_PTS, DisplayTextEvent, MIN_ALPHA, PgsOcrLanguage,
RgbaImage, RecognizedTextBlock, TextBox, alpha_bounds,
cleanup_fragmentary_subtitle_document, collect_display_sets,
collect_recognized_text_lines, format_srt_timestamp,
merge_display_texts, prepare_ocr_image, render_display_set_for_ocr,
repair_fragmentary_ocr_lines, resolve_profile,
};
use image::Rgba;
use imageproc::rect::Rect;
use pgs_rs::parse::{
CompositionObject, CompositionState, LastInSequence, ObjectDefinition, PaletteDefinition,
PaletteEntry, Pgs, PresentationComposition, RlEncodedPixels, RunLengthEncodedData,
Segment, SegmentContents,
};
use std::collections::HashMap;
use crate::domain::{RenderPlan, SubtitleCue, SubtitleDocument, SubtitleFormat};
#[test]
fn resolve_profile_prefers_english_for_english_hints() {
assert_eq!(
resolve_profile(PgsOcrLanguage::Auto, Some("eng"), None),
super::ResolvedOcrProfile::English
);
assert_eq!(
resolve_profile(PgsOcrLanguage::Auto, None, Some("English")),
super::ResolvedOcrProfile::English
);
assert_eq!(
resolve_profile(PgsOcrLanguage::Auto, Some("por"), Some("Portuguese")),
super::ResolvedOcrProfile::Latin
);
}
#[test]
fn alpha_bounds_finds_visible_subtitle_region() {
let mut image = RgbaImage::new(10, 10);
image.put_pixel(4, 5, Rgba([255, 255, 255, MIN_ALPHA + 1]));
image.put_pixel(5, 6, Rgba([255, 255, 255, MIN_ALPHA + 1]));
let bounds = alpha_bounds(&image).expect("bounds should exist");
assert_eq!(bounds.0, 0);
assert_eq!(bounds.1, 0);
assert!(bounds.2 >= 6);
assert!(bounds.3 >= 7);
}
#[test]
fn prepare_ocr_image_preserves_visible_pixels() {
let mut image = RgbaImage::new(2, 2);
image.put_pixel(0, 0, Rgba([255, 255, 255, 255]));
let prepared = prepare_ocr_image(&image);
assert!(prepared.width() >= 2);
assert!(prepared.height() >= 2);
let pixel = prepared.get_pixel(0, 0);
assert_eq!(pixel[0], 255);
}
#[test]
fn collect_recognized_text_lines_groups_words_by_line() {
let lines = collect_recognized_text_lines(vec![
RecognizedTextBlock {
bbox: TextBox::new(Rect::at(12, 10).of_size(20, 10), 0.99),
text: "Hello".to_owned(),
},
RecognizedTextBlock {
bbox: TextBox::new(Rect::at(35, 11).of_size(24, 10), 0.99),
text: "world".to_owned(),
},
RecognizedTextBlock {
bbox: TextBox::new(Rect::at(10, 40).of_size(14, 10), 0.99),
text: "Tudo".to_owned(),
},
RecognizedTextBlock {
bbox: TextBox::new(Rect::at(27, 41).of_size(16, 10), 0.99),
text: "bem?".to_owned(),
},
]);
assert_eq!(lines, vec!["Hello world".to_owned(), "Tudo bem?".to_owned()]);
}
#[test]
fn merge_display_texts_merges_duplicates_and_closes_on_empty() {
let cues = merge_display_texts(&[
DisplayTextEvent {
pts: 9_000,
is_empty: false,
text: Some("Hello".to_owned()),
},
DisplayTextEvent {
pts: 18_000,
is_empty: false,
text: Some("Hello".to_owned()),
},
DisplayTextEvent {
pts: 27_000,
is_empty: true,
text: None,
},
DisplayTextEvent {
pts: 36_000,
is_empty: false,
text: Some("World".to_owned()),
},
][..]);
assert_eq!(cues.len(), 2);
assert_eq!(cues[0].start_pts, 9_000);
assert_eq!(cues[0].end_pts, 27_000);
assert_eq!(cues[0].text, "Hello");
assert_eq!(cues[1].start_pts, 36_000);
assert_eq!(cues[1].end_pts, 36_000 + DEFAULT_LAST_CUE_DURATION_PTS);
}
#[test]
fn merge_display_texts_promotes_short_fragment_into_following_full_text() {
let cues = merge_display_texts(&[
DisplayTextEvent {
pts: 9_000,
is_empty: false,
text: Some("?".to_owned()),
},
DisplayTextEvent {
pts: 18_000,
is_empty: false,
text: Some("...e algo que voce deve descobrir por conta propria.".to_owned()),
},
DisplayTextEvent {
pts: 27_000,
is_empty: true,
text: None,
},
][..]);
assert_eq!(cues.len(), 1);
assert_eq!(cues[0].start_pts, 9_000);
assert_eq!(cues[0].end_pts, 27_000);
assert_eq!(cues[0].text, "...e algo que voce deve descobrir por conta propria.");
}
#[test]
fn merge_display_texts_promotes_progressive_text_expansion() {
let cues = merge_display_texts(&[
DisplayTextEvent {
pts: 9_000,
is_empty: false,
text: Some("hello".to_owned()),
},
DisplayTextEvent {
pts: 18_000,
is_empty: false,
text: Some("hello world".to_owned()),
},
DisplayTextEvent {
pts: 27_000,
is_empty: true,
text: None,
},
][..]);
assert_eq!(cues.len(), 1);
assert_eq!(cues[0].start_pts, 9_000);
assert_eq!(cues[0].end_pts, 27_000);
assert_eq!(cues[0].text, "hello world");
}
#[test]
fn repair_fragmentary_ocr_lines_drops_leading_punctuation_only_line() {
let repaired = repair_fragmentary_ocr_lines(vec![
"?".to_owned(),
"Por que? Por que voce nao vai-".to_owned(),
]);
assert_eq!(repaired, vec!["Por que? Por que voce nao vai-"]);
}
#[test]
fn repair_fragmentary_ocr_lines_attaches_trailing_punctuation_only_line() {
let repaired = repair_fragmentary_ocr_lines(vec![
"Por que voce morreu'".to_owned(),
"?".to_owned(),
]);
assert_eq!(repaired, vec!["Por que voce morreu'?".to_owned()]);
}
#[test]
fn repair_fragmentary_ocr_lines_drops_single_character_stub_lines() {
let repaired = repair_fragmentary_ocr_lines(vec![
"Voce deveria pelo menos checar seu e-mail uma vez por dia".to_owned(),
"a".to_owned(),
]);
assert_eq!(
repaired,
vec!["Voce deveria pelo menos checar seu e-mail uma vez por dia".to_owned()]
);
}
#[test]
fn repair_fragmentary_ocr_lines_drops_short_uppercase_stub_lines() {
let repaired = repair_fragmentary_ocr_lines(vec![
"Por que ela esta chorando?".to_owned(),
"IS".to_owned(),
]);
assert_eq!(repaired, vec!["Por que ela esta chorando?".to_owned()]);
}
#[test]
fn repair_fragmentary_ocr_lines_strips_single_leading_punctuation_noise() {
let repaired = repair_fragmentary_ocr_lines(vec!["? tenho apenas dado".to_owned()]);
assert_eq!(repaired, vec!["tenho apenas dado".to_owned()]);
}
#[test]
fn cleanup_fragmentary_subtitle_document_repairs_existing_cue_lines() {
let mut document = SubtitleDocument::from_parts(
SubtitleFormat::Srt,
vec![SubtitleCue::new(
"cue-1",
Some("1".to_owned()),
"00:00:01,000",
"00:00:02,000",
None,
"? tenho apenas dado\nT\nPor que voce morreu'?\n?",
Default::default(),
)],
RenderPlan::Srt,
);
cleanup_fragmentary_subtitle_document(&mut document);
assert_eq!(
document.cues()[0].text(),
"tenho apenas dado\nPor que voce morreu'?"
);
}
#[test]
fn format_srt_timestamp_converts_pts_to_clock_time() {
assert_eq!(format_srt_timestamp(90_000), "00:00:01,000");
assert_eq!(format_srt_timestamp(135_000), "00:00:01,500");
}
#[test]
fn collect_display_sets_persists_objects_and_palettes_across_display_sets() {
let pgs = Pgs {
segments: vec![
Segment {
pts: 9_000,
dts: 9_000,
contents: SegmentContents::PresentationComposition(PresentationComposition {
width: 2,
height: 1,
frame_rate: 0,
composition_number: 1,
composition_state: CompositionState::EpochStart,
palette_update: false,
palette_id: 0,
composition_objects: vec![CompositionObject {
id: 0,
window_id: 0,
horizontal_position: 0,
vertical_position: 0,
cropped: None,
}],
}),
},
Segment {
pts: 9_000,
dts: 9_000,
contents: SegmentContents::PaletteDefinition(PaletteDefinition {
id: 0,
version: 0,
entries: HashMap::from([(
1,
PaletteEntry {
id: 1,
luminance: 255,
color_difference_red: 128,
color_difference_blue: 128,
alpha: 255,
},
)]),
}),
},
Segment {
pts: 9_000,
dts: 9_000,
contents: SegmentContents::ObjectDefinition(ObjectDefinition {
id: 0,
version: 0,
last_in_sequence: LastInSequence::FirstAndLast,
width: 2,
height: 1,
data: RunLengthEncodedData(vec![RlEncodedPixels { count: 2, color: 1 }]),
}),
},
Segment {
pts: 9_000,
dts: 9_000,
contents: SegmentContents::End,
},
Segment {
pts: 18_000,
dts: 18_000,
contents: SegmentContents::PresentationComposition(PresentationComposition {
width: 2,
height: 1,
frame_rate: 0,
composition_number: 2,
composition_state: CompositionState::Normal,
palette_update: false,
palette_id: 0,
composition_objects: vec![CompositionObject {
id: 0,
window_id: 0,
horizontal_position: 0,
vertical_position: 0,
cropped: None,
}],
}),
},
Segment {
pts: 18_000,
dts: 18_000,
contents: SegmentContents::End,
},
],
};
let display_sets = collect_display_sets(&pgs).expect("display sets should resolve");
assert_eq!(display_sets.len(), 2);
assert!(display_sets[1].objects.contains_key(&0));
assert!(display_sets[1].palettes.contains_key(&0));
let image = render_display_set_for_ocr(&display_sets[1])
.expect("reused object and palette should still render");
assert_eq!(image.get_pixel(0, 0)[3], 255);
assert_eq!(image.get_pixel(1, 0)[3], 255);
}
}