use alloc::rc::Rc;
use alloc::vec::Vec;
use core::cell::RefCell;
use crate::encoding::{CompressionLevel, FrameCompressor, MatchGeneratorDriver, Matcher, Sequence};
/// One match triple observed while compressing, tagged with where it was seen.
///
/// Values are stored exactly as the capturing matcher received them, narrowed
/// to `u32` with `as` casts.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct CapturedRawSequence {
/// Value of the matcher's block counter when the triple was emitted. The
/// counter starts at 0 and advances once per matcher block call.
pub block_idx: u32,
/// Zero-based position of this triple among the triples emitted during the
/// same `start_matching` call.
pub seq_in_block: u32,
/// Number of literal bytes carried by the triple (`literals.len()`).
pub ll: u32,
/// Match offset as reported by the matcher.
pub of: u32,
/// Match length as reported by the matcher.
pub ml: u32,
}
/// Everything recorded during one capture run.
#[derive(Clone, Debug, Default)]
pub struct SequenceCapture {
/// Every captured triple, in the order the matcher emitted them.
pub sequences: Vec<CapturedRawSequence>,
/// One entry per matcher block call, in call order. For a matched block this
/// is the length of the last literals-only sequence seen (0 if none was
/// emitted); for a skipped block it is the length of the whole last space.
pub block_tail_lengths: Vec<u32>,
}
/// `Matcher` wrapper that forwards every call to a `MatchGeneratorDriver`
/// and records the sequences and block tails that pass through it.
struct CapturingMatcher {
// Real matcher that does all of the work.
inner: MatchGeneratorDriver,
// Captured triples; reference-counted so the creator can keep a handle and
// read the data back after the matcher has been dropped.
recorded: Rc<RefCell<Vec<CapturedRawSequence>>>,
// Per-block literal tail lengths, shared the same way as `recorded`.
block_tail_lengths: Rc<RefCell<Vec<u32>>>,
// Index assigned to the next block; bumped after every block call and
// zeroed by `reset`.
current_block: u32,
}
impl Matcher for CapturingMatcher {
fn get_next_space(&mut self) -> Vec<u8> {
self.inner.get_next_space()
}
fn get_last_space(&mut self) -> &[u8] {
self.inner.get_last_space()
}
fn commit_space(&mut self, space: Vec<u8>) {
self.inner.commit_space(space);
}
fn skip_matching(&mut self) {
let tail_ll = self.inner.get_last_space().len() as u32;
self.inner.skip_matching();
self.block_tail_lengths.borrow_mut().push(tail_ll);
self.current_block = self.current_block.saturating_add(1);
}
fn skip_matching_with_hint(&mut self, incompressible_hint: Option<bool>) {
let tail_ll = self.inner.get_last_space().len() as u32;
self.inner.skip_matching_with_hint(incompressible_hint);
self.block_tail_lengths.borrow_mut().push(tail_ll);
self.current_block = self.current_block.saturating_add(1);
}
fn start_matching(&mut self, mut handle_sequence: impl for<'a> FnMut(Sequence<'a>)) {
let recorded = self.recorded.clone();
let block_idx = self.current_block;
let mut seq_in_block: u32 = 0;
let mut block_tail_ll: u32 = 0;
self.inner.start_matching(|seq| {
match &seq {
Sequence::Triple {
literals,
offset,
match_len,
} => {
recorded.borrow_mut().push(CapturedRawSequence {
block_idx,
seq_in_block,
ll: literals.len() as u32,
of: *offset as u32,
ml: *match_len as u32,
});
seq_in_block = seq_in_block.saturating_add(1);
}
Sequence::Literals { literals } => {
block_tail_ll = literals.len() as u32;
}
}
handle_sequence(seq);
});
self.block_tail_lengths.borrow_mut().push(block_tail_ll);
self.current_block = self.current_block.saturating_add(1);
}
fn reset(&mut self, level: CompressionLevel) {
self.inner.reset(level);
self.recorded.borrow_mut().clear();
self.block_tail_lengths.borrow_mut().clear();
self.current_block = 0;
}
fn set_source_size_hint(&mut self, size: u64) {
self.inner.set_source_size_hint(size);
}
fn prime_with_dictionary(&mut self, dict_content: &[u8], offset_hist: [u32; 3]) {
self.inner.prime_with_dictionary(dict_content, offset_hist);
}
fn seed_dictionary_entropy(
&mut self,
huff: Option<&crate::huff0::huff0_encoder::HuffmanTable>,
ll: Option<&crate::fse::fse_encoder::FSETable>,
ml: Option<&crate::fse::fse_encoder::FSETable>,
of: Option<&crate::fse::fse_encoder::FSETable>,
) {
self.inner.seed_dictionary_entropy(huff, ll, ml, of);
}
fn supports_dictionary_priming(&self) -> bool {
self.inner.supports_dictionary_priming()
}
fn window_size(&self) -> u64 {
self.inner.window_size()
}
}
/// Compresses `input` at `level` through a capturing matcher and returns the
/// triples and per-block literal tails that the matcher produced.
///
/// The compressed bytes themselves are only used for validation and are then
/// discarded.
///
/// # Panics
///
/// - if `input` is empty;
/// - if `level` is `CompressionLevel::Uncompressed` or `Level(n)` with `n >= 16`;
/// - if Σ(ll + ml) over the triples plus Σ(tails) does not equal `input.len()`;
/// - if the emitted frame cannot be parsed, or contains any raw or RLE block.
pub fn compress_and_collect_sequences(input: &[u8], level: CompressionLevel) -> SequenceCapture {
    assert!(
        !input.is_empty(),
        "compress_and_collect_sequences requires non-empty input: \
         the frame compressor emits a zero-length raw block for \
         empty input without invoking the matcher, so no block \
         metadata is recorded.",
    );
    assert!(
        !matches!(level, CompressionLevel::Uncompressed),
        "compress_and_collect_sequences does not support \
         CompressionLevel::Uncompressed: raw-block emission bypasses \
         the matcher entirely, so no sequences or block tails are \
         recorded. Use a compressible level (Fastest / Level(N) / \
         Default / Better) for sequence-stream audits.",
    );
    if let CompressionLevel::Level(n) = level {
        assert!(
            n < 16,
            "compress_and_collect_sequences does not support post-split \
             levels (Level(n) where n >= 16): `compress_block_with_post_split` \
             emits multiple physical blocks per matcher call, which the \
             current per-matcher-call block counter cannot track. The \
             tool is validated for Fastest / Default / Better / Best / \
             Level(1..=15); higher numeric levels (including levels above \
             22 which clamp to Level 22 params) need per-physical-block \
             hooks that don't exist yet.",
        );
    }

    // Shared buffers: the matcher writes into them, this function reads them
    // back once the matcher is gone.
    let triple_log = Rc::new(RefCell::new(Vec::<CapturedRawSequence>::new()));
    let tail_log = Rc::new(RefCell::new(Vec::<u32>::new()));
    let matcher = CapturingMatcher {
        inner: MatchGeneratorDriver::new(1024 * 128, 1),
        recorded: Rc::clone(&triple_log),
        block_tail_lengths: Rc::clone(&tail_log),
        current_block: 0,
    };

    let mut output: Vec<u8> = Vec::new();
    {
        // The compressor owns the matcher; leaving this scope drops both, which
        // releases the matcher's handles on the shared buffers.
        let mut compressor: FrameCompressor<&[u8], &mut Vec<u8>, CapturingMatcher> =
            FrameCompressor::new_with_matcher(matcher, level);
        compressor.set_source(input);
        compressor.set_drain(&mut output);
        compressor.set_source_size_hint(input.len() as u64);
        compressor.compress();
    }

    let sequences = Rc::try_unwrap(triple_log)
        .expect("CapturingMatcher dropped with compressor; recorder is single-owner")
        .into_inner();
    let block_tail_lengths = Rc::try_unwrap(tail_log)
        .expect("CapturingMatcher dropped with compressor; tail-length vec is single-owner")
        .into_inner();

    // Every input byte must be covered by a triple or by a block tail.
    let bytes_in_triples: u64 = sequences
        .iter()
        .map(|s| u64::from(s.ll) + u64::from(s.ml))
        .sum();
    let bytes_in_tails: u64 = block_tail_lengths.iter().copied().map(u64::from).sum();
    let reconstructed: u64 = bytes_in_triples + bytes_in_tails;
    assert_eq!(
        reconstructed,
        input.len() as u64,
        "sequence_capture: matcher-bypassing block path (RLE block? raw-frame fast-path?) \
         left the captured stream short: Σ(ll+ml)+Σ(tails)={reconstructed}, input.len()={}. \
         The current wrapper only sees blocks routed through `Matcher` methods on \
         `CapturingMatcher`. Use a non-RLE-friendly fixture or extend capture to \
         cover the bypassing path before relying on cumulative-position alignment.",
        input.len(),
    );

    // Cross-check against the bytes actually written: a raw or RLE block on
    // the wire carries no sequences even though the matcher recorded some.
    let raw_or_rle = detect_raw_or_rle_blocks_in_frame(&output).expect(
        "sequence_capture: failed to parse emitted frame header — refusing to \
         return a possibly-misaligned capture without raw-block detection",
    );
    assert!(
        raw_or_rle.is_empty(),
        "compress_and_collect_sequences: emitted frame contains {} raw/RLE block(s) at \
         on-wire indices {:?}. The matcher recorded triples for those blocks but the \
         on-wire form has no sequences for them — alignment against FFI delimiters \
         would silently shift. Use a more compressible fixture (or a smaller block \
         size) that keeps every block on the compressed path.",
        raw_or_rle.len(),
        raw_or_rle,
    );

    SequenceCapture {
        sequences,
        block_tail_lengths,
    }
}
/// Walks the block headers of a single zstd frame and returns the zero-based
/// indices of every block stored as raw (type 0) or RLE (type 1).
///
/// Only the structure is inspected: the frame header is skipped according to
/// its descriptor byte and block payloads are stepped over, never decoded.
///
/// # Errors
///
/// Returns a static message when the magic number is missing, when the frame
/// header, a block header, a block payload or the content checksum runs past
/// the end of `frame`, or when a block uses the reserved type 3.
fn detect_raw_or_rle_blocks_in_frame(frame: &[u8]) -> Result<Vec<usize>, &'static str> {
    const ZSTD_MAGIC: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD];
    const OVERFLOW: &str = "cursor overflow";

    if frame.len() < 6 || !frame.starts_with(&ZSTD_MAGIC) {
        return Err("frame missing zstd magic");
    }

    // Frame header descriptor: the byte right after the magic number.
    let descriptor = frame[4];
    let has_checksum = descriptor & 0b0000_0100 != 0;
    let single_segment = descriptor & 0b0010_0000 != 0;

    // The window descriptor byte is present only without the single-segment flag.
    let window_descriptor_len: usize = if single_segment { 0 } else { 1 };
    let dict_id_len: usize = match descriptor & 0b11 {
        0 => 0,
        1 => 1,
        2 => 2,
        _ => 4,
    };
    // A zero size flag still means a one-byte field for single-segment frames.
    let content_size_len: usize = match descriptor >> 6 {
        0 if single_segment => 1,
        0 => 0,
        1 => 2,
        2 => 4,
        _ => 8,
    };

    let mut pos = 5_usize
        .checked_add(window_descriptor_len)
        .and_then(|p| p.checked_add(dict_id_len))
        .and_then(|p| p.checked_add(content_size_len))
        .ok_or(OVERFLOW)?;
    if pos > frame.len() {
        return Err("truncated frame header");
    }

    let mut flagged = Vec::new();
    let mut index: usize = 0;
    let mut finished = false;
    while !finished {
        let header_end = pos.checked_add(3).ok_or(OVERFLOW)?;
        if header_end > frame.len() {
            return Err("truncated block header");
        }
        // Three-byte little-endian block header:
        // bit 0 = last block, bits 1-2 = block type, remaining bits = block size.
        let header = frame[pos..header_end]
            .iter()
            .rev()
            .fold(0_u32, |acc, &byte| (acc << 8) | u32::from(byte));
        let declared_size = (header >> 3) as usize;

        let payload_len = match (header >> 1) & 0b11 {
            0 => {
                flagged.push(index);
                declared_size
            }
            1 => {
                // An RLE block stores one byte regardless of its declared size.
                flagged.push(index);
                1
            }
            2 => declared_size,
            _ => return Err("reserved block_type in frame"),
        };

        pos = header_end.checked_add(payload_len).ok_or(OVERFLOW)?;
        index += 1;
        if pos > frame.len() {
            return Err("block content extends past frame end");
        }
        finished = header & 1 != 0;
    }

    if has_checksum {
        match pos.checked_add(4) {
            Some(end) if end <= frame.len() => {}
            _ => return Err("truncated content checksum"),
        }
    }
    Ok(flagged)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encoding::CompressionLevel;
    use alloc::vec::Vec;

    /// Sixteen-byte motif repeated by the compressible fixtures.
    const MOTIF: [u8; 16] = *b"PATTERN_1234_END";

    /// Builds a `len`-byte buffer by repeating `motif` end to end.
    fn repeated(motif: &[u8], len: usize) -> Vec<u8> {
        motif.iter().copied().cycle().take(len).collect()
    }

    /// Input bytes accounted for by a capture: Σ(ll + ml) over all triples
    /// plus the sum of all block tails.
    fn covered_bytes(capture: &SequenceCapture) -> u64 {
        let in_triples: u64 = capture
            .sequences
            .iter()
            .map(|s| u64::from(s.ll) + u64::from(s.ml))
            .sum();
        let in_tails: u64 = capture
            .block_tail_lengths
            .iter()
            .copied()
            .map(u64::from)
            .sum();
        in_triples + in_tails
    }

    /// A small repeating input yields a single block whose triples are
    /// well-formed, consecutively numbered, and cover the whole input.
    #[test]
    fn captures_at_least_one_triple_on_repeating_pattern() {
        let data = repeated(&MOTIF, 16 * 1024);
        let captured = compress_and_collect_sequences(&data, CompressionLevel::Level(3));
        let seqs = &captured.sequences;

        assert!(
            !seqs.is_empty(),
            "expected at least one Triple sequence on 16KB repeating pattern, got 0",
        );
        assert!(
            !seqs.iter().any(|s| s.block_idx != 0),
            "16KB repeating pattern produced multi-block sequence stream: {:?}",
            seqs.iter().map(|s| s.block_idx).collect::<Vec<_>>(),
        );
        for s in seqs {
            assert!(s.of > 0, "non-positive offset captured: {:?}", s);
            assert!(s.ml > 0, "non-positive match length captured: {:?}", s);
        }
        for (position, s) in seqs.iter().enumerate() {
            assert_eq!(
                s.seq_in_block, position as u32,
                "seq_in_block discontinuity at idx {}: {:?}",
                position, seqs,
            );
        }

        assert_eq!(captured.block_tail_lengths.len(), 1);
        // Exactly one tail exists at this point, so the total over all tails
        // is that single tail.
        let cumulative = covered_bytes(&captured);
        assert_eq!(
            cumulative,
            data.len() as u64,
            "Σ(ll+ml) + tail must reconstruct the input length exactly: \
             seqs sum + tail {} should == input {}",
            cumulative,
            data.len(),
        );
    }

    /// Pseudo-random bytes end up as a raw block on the wire, which the
    /// capture refuses.
    #[test]
    #[should_panic(expected = "raw/RLE block")]
    fn rejects_incompressible_input_with_raw_on_wire_block() {
        // Linear congruential generator; one byte taken from each state.
        let mut state: u32 = 0x1234_5678;
        let mut data: Vec<u8> = Vec::with_capacity(1024);
        for _ in 0..1024 {
            state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
            data.push((state >> 16) as u8);
        }
        let _ = compress_and_collect_sequences(&data, CompressionLevel::Level(3));
    }

    /// A constant run ends up as an RLE block on the wire, which the capture
    /// refuses.
    #[test]
    #[should_panic(expected = "raw/RLE block")]
    fn rejects_constant_run_with_rle_on_wire_block() {
        let data = repeated(b"A", 16 * 1024);
        let _ = compress_and_collect_sequences(&data, CompressionLevel::Level(3));
    }

    /// An input larger than one block yields several tails, full coverage, and
    /// block indices that are in range and never skip ahead.
    #[test]
    fn captures_multi_block_tails_and_indices() {
        let data = repeated(&MOTIF, 200 * 1024);
        let captured = compress_and_collect_sequences(&data, CompressionLevel::Level(3));
        let tails = &captured.block_tail_lengths;

        assert!(
            tails.len() >= 2,
            "expected ≥2 block tail entries for 200 KiB input, got {}: {:?}",
            tails.len(),
            tails,
        );

        let cumulative = covered_bytes(&captured);
        assert_eq!(
            cumulative,
            data.len() as u64,
            "multi-block Σ(ll+ml)+Σ(tails) mismatch: got {}, want {} \
             (blocks={}, triples={})",
            cumulative,
            data.len(),
            tails.len(),
            captured.sequences.len(),
        );

        let num_blocks = tails.len() as u32;
        for s in &captured.sequences {
            assert!(
                s.block_idx < num_blocks,
                "triple block_idx={} out of range (num_blocks={})",
                s.block_idx,
                num_blocks,
            );
        }

        // Indices may repeat or advance by one, never jump further.
        let mut max_seen: i64 = -1;
        for s in &captured.sequences {
            let idx = i64::from(s.block_idx);
            assert!(
                idx <= max_seen + 1,
                "non-monotonic / gapped block_idx in triple stream: \
                 max_seen={max_seen}, observed={idx}",
            );
            max_seen = max_seen.max(idx);
        }
    }

    #[test]
    #[should_panic(expected = "requires non-empty input")]
    fn rejects_empty_input() {
        let nothing: &[u8] = &[];
        let _ = compress_and_collect_sequences(nothing, CompressionLevel::Level(3));
    }

    #[test]
    #[should_panic(expected = "CompressionLevel::Uncompressed")]
    fn rejects_uncompressed_level() {
        let input = b"hello there general kenobi";
        let _ = compress_and_collect_sequences(input, CompressionLevel::Uncompressed);
    }

    #[test]
    #[should_panic(expected = "does not support post-split levels")]
    fn rejects_post_split_numeric_level() {
        let input = b"hello there general kenobi";
        let _ = compress_and_collect_sequences(input, CompressionLevel::Level(16));
    }

    /// The `Best` preset is accepted and its capture covers the input.
    #[test]
    fn captures_through_best_preset() {
        let mut motif = [0u8; 32];
        for (slot, i) in motif.iter_mut().zip(0u8..) {
            *slot = i.wrapping_mul(11).wrapping_add(3);
        }
        let data = repeated(&motif, 32 * 1024);
        let captured = compress_and_collect_sequences(&data, CompressionLevel::Best);
        let cumulative = covered_bytes(&captured);
        assert_eq!(cumulative, data.len() as u64);
    }

    /// Level 15, the highest numeric level the capture accepts, covers the
    /// input.
    #[test]
    fn captures_through_pre_split_level_15() {
        let data = repeated(&MOTIF, 32 * 1024);
        let captured = compress_and_collect_sequences(&data, CompressionLevel::Level(15));
        let cumulative = covered_bytes(&captured);
        assert_eq!(cumulative, data.len() as u64);
    }
}