mod common;
use std::io::Cursor;
use std::sync::{Arc, Mutex};
use common::{FileBlock, adir_body, sqpk_addfile_body, wrap_patch};
use zipatch_rs::test_utils::make_chunk;
use zipatch_rs::{ApplyContext, Checkpoint, CheckpointPolicy, CheckpointSink, ZiPatchReader};
/// Checkpoint sink that throws every record away and reports a fixed policy.
///
/// The tests below install it via `with_checkpoint_sink` purely to select which
/// `CheckpointPolicy` the apply loop runs under.
#[derive(Clone)]
struct PolicySink(CheckpointPolicy);

impl CheckpointSink for PolicySink {
    /// Discards the checkpoint; never fails.
    fn record(&mut self, _checkpoint: &Checkpoint) -> std::io::Result<()> {
        Ok(())
    }

    /// Reports the policy supplied at construction.
    fn policy(&self) -> CheckpointPolicy {
        let Self(policy) = self;
        *policy
    }
}
/// Builds a small patch used by every checkpoint-policy test in this file.
///
/// The patch holds three top-level chunks, in order:
/// 1. `ADIR` for the directory `alpha`;
/// 2. `SQPK` AddFile for `data/target.dat` carrying `n` uncompressed 64-byte blocks
///    (the `0` argument is presumably the file offset — confirm against `common`);
/// 3. `ADIR` for the directory `bravo`.
///
/// The tests count these three chunks as three boundary checkpoints. The closure-sink
/// test additionally expects one mid-block checkpoint per AddFile block.
fn build_patch_with_n_blocks(n: usize) -> Vec<u8> {
    let blocks: Vec<FileBlock> = (0..n)
        .map(|i| FileBlock {
            is_compressed: false,
            // Distinct fill byte per block. `0x10 + i as u8` parses as
            // `0x10 + (i as u8)` and overflows `u8` (panicking in debug builds) once
            // `i >= 240`; wrapping keeps the same bytes for smaller `i` and never panics.
            decompressed: vec![0x10u8.wrapping_add(i as u8); 64],
        })
        .collect();
    wrap_patch(vec![
        make_chunk(b"ADIR", &adir_body("alpha")),
        make_chunk(b"SQPK", &sqpk_addfile_body("data/target.dat", 0, &blocks)),
        make_chunk(b"ADIR", &adir_body("bravo")),
    ])
}
/// Under `CheckpointPolicy::Flush`, each of the three boundary records plus one final
/// post-loop flush must call flush. Mid-block records must not, and `sync_all` must never run.
#[test]
fn flush_policy_flushes_on_every_boundary_not_on_mid_block() {
    let dir = tempfile::tempdir().unwrap();
    let bytes = build_patch_with_n_blocks(2);
    let sink = PolicySink(CheckpointPolicy::Flush);
    let mut ctx = ApplyContext::new(dir.path()).with_checkpoint_sink(sink);

    ZiPatchReader::new(Cursor::new(bytes))
        .unwrap()
        .apply_to(&mut ctx)
        .unwrap();

    assert_eq!(
        ctx.test_flush_count, 4,
        "flush called on wrong count (3 boundary + 1 post-loop)"
    );
    assert_eq!(
        ctx.test_sync_count, 0,
        "sync_all must not fire under Flush policy"
    );
}
/// Under `CheckpointPolicy::Fsync`, every boundary record (three in this patch) must
/// trigger exactly one `sync_all`. The two mid-block records must not.
#[test]
fn fsync_policy_syncs_on_every_boundary_not_on_mid_block() {
    let dir = tempfile::tempdir().unwrap();
    let bytes = build_patch_with_n_blocks(2);
    let sink = PolicySink(CheckpointPolicy::Fsync);
    let mut ctx = ApplyContext::new(dir.path()).with_checkpoint_sink(sink);

    ZiPatchReader::new(Cursor::new(bytes))
        .unwrap()
        .apply_to(&mut ctx)
        .unwrap();

    assert_eq!(
        ctx.test_sync_count, 3,
        "sync_all must fire once per boundary record (3 boundaries)"
    );
}
/// `FsyncEveryN(3)` must count only boundary records toward N.
///
/// The patch has three boundaries and four AddFile blocks. Assuming one mid-block record
/// per block, counting those too would reach a second multiple of 3 and fire a second
/// `sync_all`.
#[test]
fn fsync_every_n_counts_only_boundary_records() {
    let dir = tempfile::tempdir().unwrap();
    let bytes = build_patch_with_n_blocks(4);
    let sink = PolicySink(CheckpointPolicy::FsyncEveryN(3));
    let mut ctx = ApplyContext::new(dir.path()).with_checkpoint_sink(sink);

    ZiPatchReader::new(Cursor::new(bytes))
        .unwrap()
        .apply_to(&mut ctx)
        .unwrap();

    assert_eq!(
        ctx.test_sync_count, 1,
        "FsyncEveryN(3) with 3 boundary records must fire exactly 1 sync_all"
    );
}
/// `FsyncEveryN(2)` over three boundary records reaches a multiple of 2 only once, so
/// exactly one `sync_all` is expected.
#[test]
fn fsync_every_2_fires_once_per_two_boundary_records() {
    let dir = tempfile::tempdir().unwrap();
    let bytes = build_patch_with_n_blocks(2);
    let sink = PolicySink(CheckpointPolicy::FsyncEveryN(2));
    let mut ctx = ApplyContext::new(dir.path()).with_checkpoint_sink(sink);

    ZiPatchReader::new(Cursor::new(bytes))
        .unwrap()
        .apply_to(&mut ctx)
        .unwrap();

    assert_eq!(
        ctx.test_sync_count, 1,
        "FsyncEveryN(2) with 3 boundary records must fire 1 sync_all"
    );
}
/// `FsyncEveryN(1)` must behave like `Fsync`: one `sync_all` per boundary record
/// (three here), none for mid-block records.
#[test]
fn fsync_every_1_fires_on_every_boundary_record() {
    let dir = tempfile::tempdir().unwrap();
    let bytes = build_patch_with_n_blocks(2);
    let sink = PolicySink(CheckpointPolicy::FsyncEveryN(1));
    let mut ctx = ApplyContext::new(dir.path()).with_checkpoint_sink(sink);

    ZiPatchReader::new(Cursor::new(bytes))
        .unwrap()
        .apply_to(&mut ctx)
        .unwrap();

    assert_eq!(
        ctx.test_sync_count, 3,
        "FsyncEveryN(1) must sync on every boundary record (same as Fsync)"
    );
}
/// Applying with no checkpoint sink installed must succeed and still create both
/// `ADIR` directories (`alpha` and `bravo`) from the patch.
#[test]
fn no_sink_installed_does_not_panic_and_apply_succeeds() {
    let patch = build_patch_with_n_blocks(2);
    let tmp = tempfile::tempdir().unwrap();
    // Previously shared a line with the apply call below, which rustfmt rejects and
    // which made the apply step easy to miss when reading.
    let mut ctx = ApplyContext::new(tmp.path());
    ZiPatchReader::new(Cursor::new(patch))
        .unwrap()
        .apply_to(&mut ctx)
        .unwrap();
    assert!(
        tmp.path().join("alpha").is_dir(),
        "ADIR `alpha` was not applied"
    );
    assert!(
        tmp.path().join("bravo").is_dir(),
        "ADIR `bravo` was not applied"
    );
}
/// A plain closure works as a checkpoint sink and observes every record.
///
/// With three AddFile blocks, the test expects three boundary records (one per top-level
/// chunk) plus three mid-block records (those whose `in_flight` is `Some`).
#[test]
fn closure_sink_receives_both_mid_block_and_boundary_records() {
    let tmp = tempfile::tempdir().unwrap();
    let patch = build_patch_with_n_blocks(3);
    let seen: Arc<Mutex<Vec<Checkpoint>>> = Arc::new(Mutex::new(Vec::new()));

    // Give the sink its own handle to the shared log; `seen` stays here for inspection.
    let recorder = {
        let log = Arc::clone(&seen);
        move |checkpoint: &Checkpoint| -> std::io::Result<()> {
            log.lock().unwrap().push(checkpoint.clone());
            Ok(())
        }
    };
    let mut ctx = ApplyContext::new(tmp.path()).with_checkpoint_sink(recorder);

    ZiPatchReader::new(Cursor::new(patch))
        .unwrap()
        .apply_to(&mut ctx)
        .unwrap();

    let got = seen.lock().unwrap();
    assert_eq!(got.len(), 6, "expected 3 boundary + 3 mid-block records");

    let mid_count = got
        .iter()
        .filter(|cp| matches!(cp, Checkpoint::Sequential(seq) if seq.in_flight.is_some()))
        .count();
    assert_eq!(mid_count, 3, "exactly 3 mid-block records");
}