use libpgs::pgs::{
decode_rle, CompositionObject, DisplaySetBuilder, ObjectBitmap,
PaletteEntry, PcsData, PdsData, SegmentType, WdsData, WindowDefinition,
};
use std::path::Path;
/// Candidate container files exercised by the round-trip test. Files missing
/// on disk are silently skipped (see `available_fixtures`), so the test can
/// run in checkouts that do not ship the binary fixtures.
const FIXTURES: &[&str] = &[
    "tests/fixtures/matroska-with-cues.mkv",
    "tests/fixtures/matroska-no-cues.mkv",
    "tests/fixtures/mpeg-transport-stream.m2ts",
];
/// Returns the subset of `FIXTURES` that actually exists on disk.
fn available_fixtures() -> Vec<&'static str> {
    let mut found = Vec::new();
    for &candidate in FIXTURES {
        if Path::new(candidate).exists() {
            found.push(candidate);
        }
    }
    found
}
/// Nearest-neighbor 2x upscale of an indexed (one byte per pixel) bitmap.
///
/// `pixels` holds at least `w * h` bytes in row-major order; the result is a
/// `(2*w) x (2*h)` bitmap where each source pixel becomes a 2x2 block.
/// Degenerate dimensions (`w == 0` or `h == 0`) yield an empty bitmap.
///
/// # Panics
/// Panics if `pixels` holds fewer than `w * h` bytes.
fn scale_2x(pixels: &[u8], w: usize, h: usize) -> Vec<u8> {
    // Fail loudly with context instead of an opaque index-out-of-bounds panic
    // deep inside the copy loop.
    assert!(
        pixels.len() >= w * h,
        "scale_2x: pixel buffer too small ({} bytes for {}x{})",
        pixels.len(),
        w,
        h
    );
    // Guard degenerate sizes: matches the index-loop form's empty output and
    // keeps `chunks_exact` away from a zero chunk size (which would panic).
    if w == 0 || h == 0 {
        return Vec::new();
    }
    let mut out = Vec::with_capacity(w * h * 4);
    for row in pixels[..w * h].chunks_exact(w) {
        // Double each pixel horizontally once, then emit the row twice to
        // double vertically — avoids the per-pixel index arithmetic (and its
        // bounds check) as well as the redundant zero-fill of `out`.
        let doubled: Vec<u8> = row.iter().flat_map(|&p| [p, p]).collect();
        out.extend_from_slice(&doubled);
        out.extend_from_slice(&doubled);
    }
    out
}
/// Builds a transformed copy of a palette entry: luminance inverted and the
/// two chroma channels (Cb/Cr) swapped; id and alpha pass through untouched.
fn shift_palette(entry: &PaletteEntry) -> PaletteEntry {
    let inverted_luma = 255 - entry.luminance;
    PaletteEntry {
        luminance: inverted_luma,
        cb: entry.cr,
        cr: entry.cb,
        id: entry.id,
        alpha: entry.alpha,
    }
}
/// End-to-end round-trip test over every available fixture:
/// 1. extract all PGS display sets from the container,
/// 2. transform them (invert/swap each palette entry, 2x-scale object
///    bitmaps, double window sizes — all clamped to the video frame),
/// 3. write the transformed sets to a temporary `.sup` file,
/// 4. re-read that file and assert the data survived intact.
#[test]
fn extract_transform_write_reread() {
    // Fixture media files are optional; skip rather than fail when absent.
    let fixtures = available_fixtures();
    if fixtures.is_empty() {
        eprintln!("No fixture files found — skipping round_trip_transform test");
        return;
    }
    for fixture in &fixtures {
        eprintln!("--- Testing round-trip transform on: {fixture}");
        let all_tracks =
            libpgs::extract_all_display_sets(Path::new(fixture)).expect("extraction should succeed");
        let total_ds: usize = all_tracks.iter().map(|t| t.display_sets.len()).sum();
        if total_ds == 0 {
            eprintln!(" No display sets found, skipping");
            continue;
        }
        // Only the first track that actually carries display sets is exercised.
        let track = all_tracks
            .iter()
            .find(|t| !t.display_sets.is_empty())
            .unwrap();
        eprintln!(
            " Track {}: {} display sets, language={:?}",
            track.track.track_id,
            track.display_sets.len(),
            track.track.language
        );
        let mut transformed = Vec::new();
        for ds in &track.display_sets {
            // Each display set must contain a Presentation Composition Segment.
            let pcs_seg = ds
                .segments
                .iter()
                .find(|s| s.segment_type == SegmentType::PresentationComposition)
                .expect("display set must have PCS");
            let pcs = pcs_seg.parse_pcs().expect("PCS should parse");
            // Window definition is treated as optional here.
            let wds_opt = ds
                .segments
                .iter()
                .find(|s| s.segment_type == SegmentType::WindowDefinition)
                .and_then(|s| s.parse_wds());
            // Transform every palette: shift_palette inverts luminance and
            // swaps the Cb/Cr channels of each entry.
            let palettes: Vec<PdsData> = ds
                .segments
                .iter()
                .filter(|s| s.segment_type == SegmentType::PaletteDefinition)
                .filter_map(|s| s.parse_pds())
                .map(|pds| PdsData {
                    id: pds.id,
                    version: pds.version,
                    entries: pds.entries.iter().map(shift_palette).collect(),
                })
                .collect();
            let ods_segments: Vec<_> = ds
                .segments
                .iter()
                .filter(|s| s.segment_type == SegmentType::ObjectDefinition)
                .filter_map(|s| s.parse_ods())
                .collect();
            // An object's RLE data may be fragmented across several ODS
            // segments; merge the fragments per object id. Width/height are
            // Options because only the first fragment carries them.
            // NOTE(review): assumes an object id is defined at most once per
            // display set — a redefinition would be wrongly concatenated.
            let mut ods_by_id: std::collections::BTreeMap<u16, (Option<u16>, Option<u16>, Vec<u8>)> =
                std::collections::BTreeMap::new();
            for ods in &ods_segments {
                let entry = ods_by_id
                    .entry(ods.id)
                    .or_insert_with(|| (None, None, Vec::new()));
                if let Some(w) = ods.width {
                    entry.0 = Some(w);
                }
                if let Some(h) = ods.height {
                    entry.1 = Some(h);
                }
                entry.2.extend_from_slice(&ods.rle_data);
            }
            let mut objects = Vec::new();
            // NOTE(review): scaled_sizes is populated but never read back —
            // candidate for removal.
            let mut scaled_sizes: std::collections::BTreeMap<u16, (u16, u16)> =
                std::collections::BTreeMap::new();
            for (id, (w_opt, h_opt, rle_data)) in &ods_by_id {
                let w = w_opt.expect("ODS must have width");
                let h = h_opt.expect("ODS must have height");
                // Locate this object's composition position so the doubled
                // bitmap can be clamped to the visible video area.
                let co = pcs.objects.iter().find(|o| o.object_id == *id);
                let (ox, oy) = co.map_or((0u16, 0u16), |c| (c.x, c.y));
                let max_w = pcs.video_width.saturating_sub(ox);
                let max_h = pcs.video_height.saturating_sub(oy);
                let new_w = (w * 2).min(max_w);
                let new_h = (h * 2).min(max_h);
                // Decode to indexed pixels, upscale 2x, then crop only if the
                // scaled bitmap would overflow the frame.
                let pixels = decode_rle(rle_data, w, h).expect("RLE decode should succeed");
                let scaled_full = scale_2x(&pixels, w as usize, h as usize);
                let cropped = if new_w == w * 2 && new_h == h * 2 {
                    scaled_full
                } else {
                    // Keep the top-left new_w x new_h region of the scaled bitmap.
                    let full_w = (w * 2) as usize;
                    let mut cropped = vec![0u8; new_w as usize * new_h as usize];
                    for row in 0..new_h as usize {
                        let src_start = row * full_w;
                        let dst_start = row * new_w as usize;
                        cropped[dst_start..dst_start + new_w as usize]
                            .copy_from_slice(&scaled_full[src_start..src_start + new_w as usize]);
                    }
                    cropped
                };
                scaled_sizes.insert(*id, (new_w, new_h));
                objects.push(ObjectBitmap {
                    id: *id,
                    version: 0,
                    width: new_w,
                    height: new_h,
                    pixels: cropped,
                });
            }
            // Rebuild the PCS field-for-field: composition coordinates stay
            // put; only bitmaps and windows were resized.
            let new_pcs = PcsData {
                video_width: pcs.video_width,
                video_height: pcs.video_height,
                composition_number: pcs.composition_number,
                composition_state: pcs.composition_state,
                palette_only: pcs.palette_only,
                palette_id: pcs.palette_id,
                objects: pcs
                    .objects
                    .iter()
                    .map(|co| CompositionObject {
                        object_id: co.object_id,
                        window_id: co.window_id,
                        x: co.x,
                        y: co.y,
                        crop: co.crop.clone(),
                    })
                    .collect(),
            };
            // Double each window so the scaled objects still fit, clamped to
            // the frame edges just like the bitmaps above.
            let new_wds = wds_opt.map(|wds| WdsData {
                windows: wds
                    .windows
                    .iter()
                    .map(|win| {
                        let mut needed_w = win.width * 2;
                        let mut needed_h = win.height * 2;
                        needed_w = needed_w.min(pcs.video_width.saturating_sub(win.x));
                        needed_h = needed_h.min(pcs.video_height.saturating_sub(win.y));
                        WindowDefinition {
                            id: win.id,
                            x: win.x,
                            y: win.y,
                            width: needed_w,
                            height: needed_h,
                        }
                    })
                    .collect(),
            });
            // Reassemble a complete display set from the transformed parts.
            let mut builder = DisplaySetBuilder::new(ds.pts).pcs(new_pcs);
            if let Some(wds) = new_wds {
                builder = builder.wds(wds);
            }
            for pds in palettes {
                builder = builder.palette(pds);
            }
            for obj in objects {
                builder = builder.object(obj);
            }
            let new_ds = builder.build().expect("builder should succeed");
            transformed.push(new_ds);
        }
        eprintln!(" Built {} transformed display sets", transformed.len());
        // Write the transformed sets to a temp .sup named after the fixture.
        let out_path = std::env::temp_dir().join(format!(
            "libpgs_test_transform_{}.sup",
            Path::new(fixture)
                .file_stem()
                .unwrap()
                .to_string_lossy()
        ));
        libpgs::write_sup_file(&transformed, &out_path).expect("write_sup_file should succeed");
        let file_size = std::fs::metadata(&out_path).unwrap().len();
        eprintln!(" Wrote {} bytes to {}", file_size, out_path.display());
        assert!(file_size > 0, "output .sup file should not be empty");
        // Re-read the written file; a .sup container carries a single track.
        let reread = libpgs::extract_all_display_sets(&out_path)
            .expect("re-reading written .sup should succeed");
        assert_eq!(reread.len(), 1, "SUP file should have exactly 1 track");
        let reread_ds = &reread[0].display_sets;
        assert_eq!(
            reread_ds.len(),
            transformed.len(),
            "re-read display set count should match written count"
        );
        // Pairwise verification: timing, composition state, segment counts,
        // PCS geometry, ODS counts, and full palette contents must all
        // survive the write/read round trip.
        for (i, (orig, reread)) in transformed.iter().zip(reread_ds.iter()).enumerate() {
            assert_eq!(
                orig.pts, reread.pts,
                "DS {i}: PTS mismatch"
            );
            assert_eq!(
                orig.composition_state, reread.composition_state,
                "DS {i}: composition state mismatch"
            );
            assert_eq!(
                orig.segments.len(),
                reread.segments.len(),
                "DS {i}: segment count mismatch"
            );
            // NOTE(review): assumes segment 0 is the PCS on both sides —
            // confirm DisplaySetBuilder/reader ordering guarantees this.
            let orig_pcs = orig.segments[0].parse_pcs().unwrap();
            let reread_pcs = reread.segments[0].parse_pcs().unwrap();
            assert_eq!(orig_pcs.video_width, reread_pcs.video_width, "DS {i}: video width");
            assert_eq!(orig_pcs.video_height, reread_pcs.video_height, "DS {i}: video height");
            assert_eq!(orig_pcs.objects.len(), reread_pcs.objects.len(), "DS {i}: object count");
            let orig_ods: Vec<_> = orig
                .segments
                .iter()
                .filter(|s| s.segment_type == SegmentType::ObjectDefinition)
                .collect();
            let reread_ods: Vec<_> = reread
                .segments
                .iter()
                .filter(|s| s.segment_type == SegmentType::ObjectDefinition)
                .collect();
            assert_eq!(orig_ods.len(), reread_ods.len(), "DS {i}: ODS segment count");
            let orig_pds: Vec<_> = orig
                .segments
                .iter()
                .filter_map(|s| s.parse_pds())
                .collect();
            let reread_pds: Vec<_> = reread
                .segments
                .iter()
                .filter_map(|s| s.parse_pds())
                .collect();
            assert_eq!(orig_pds.len(), reread_pds.len(), "DS {i}: PDS count");
            // Every palette entry must round-trip exactly, channel by channel.
            for (j, (op, rp)) in orig_pds.iter().zip(reread_pds.iter()).enumerate() {
                assert_eq!(op.entries.len(), rp.entries.len(), "DS {i} PDS {j}: entry count");
                for (k, (oe, re)) in op.entries.iter().zip(rp.entries.iter()).enumerate() {
                    assert_eq!(oe.luminance, re.luminance, "DS {i} PDS {j} entry {k}: luminance");
                    assert_eq!(oe.cr, re.cr, "DS {i} PDS {j} entry {k}: Cr");
                    assert_eq!(oe.cb, re.cb, "DS {i} PDS {j} entry {k}: Cb");
                    assert_eq!(oe.alpha, re.alpha, "DS {i} PDS {j} entry {k}: alpha");
                }
            }
        }
        // Best-effort cleanup; failure to remove the temp file is not fatal.
        let _ = std::fs::remove_file(&out_path);
        eprintln!(" PASS: round-trip transform verified for {fixture}");
    }
}