use crate::apply::{self, ApplyConfig, ApplySession, Checkpoint, SequentialCheckpoint};
use crate::chunk::{self, ZiPatchReader};
use crate::{ApplyError, ApplyResult, ParseError};
use std::io::{Read, Seek};
use std::ops::ControlFlow;
impl ApplyConfig {
    /// Consumes this configuration, turns it into an [`ApplySession`], and
    /// applies every chunk of `reader` through that session.
    ///
    /// # Errors
    ///
    /// Returns whatever [`ApplySession::apply_patch`] returns.
    pub fn apply_patch<R: Read>(self, reader: ZiPatchReader<R>) -> ApplyResult<()> {
        let mut session = self.into_session();
        session.apply_patch(reader)
    }

    /// Consumes this configuration, turns it into an [`ApplySession`], and
    /// applies `reader` through that session, optionally resuming from the
    /// checkpoint `from`.
    ///
    /// Returns the checkpoint describing where the run finished.
    ///
    /// # Errors
    ///
    /// Returns whatever [`ApplySession::resume_apply_patch`] returns.
    pub fn resume_apply_patch<R: Read + Seek>(
        self,
        reader: ZiPatchReader<R>,
        from: Option<&SequentialCheckpoint>,
    ) -> ApplyResult<SequentialCheckpoint> {
        let mut session = self.into_session();
        session.resume_apply_patch(reader, from)
    }
}
impl ApplySession {
pub fn apply_patch<R: Read>(&mut self, mut reader: ZiPatchReader<R>) -> ApplyResult<()> {
let span = tracing::info_span!(crate::tracing_schema::span_names::APPLY_PATCH);
let _enter = span.enter();
let started = std::time::Instant::now();
self.patch_name = reader.patch_name().map(str::to_owned);
self.patch_size = None;
let result = run_apply_loop(&mut reader, self, 0);
let flush_result = self.flush();
let (final_result, chunks_applied) = match (result, flush_result) {
(Ok(n), Ok(())) => (Ok(()), n),
(Ok(_), Err(e)) => (Err(ApplyError::from(e)), 0),
(Err(e), _) => (Err(e), 0),
};
if final_result.is_ok() {
tracing::info!(
chunks = chunks_applied,
bytes_read = reader.bytes_read(),
resumed_from_chunk = tracing::field::Empty,
elapsed_ms = started.elapsed().as_millis() as u64,
"apply_patch: patch applied"
);
}
final_result
}
#[allow(clippy::too_many_lines)]
pub fn resume_apply_patch<R: Read + Seek>(
&mut self,
mut reader: ZiPatchReader<R>,
from: Option<&SequentialCheckpoint>,
) -> ApplyResult<SequentialCheckpoint> {
let span = tracing::info_span!(crate::tracing_schema::span_names::RESUME_APPLY_PATCH);
let _enter = span.enter();
let started = std::time::Instant::now();
if let Some(cp) = from {
if !cp
.schema_version
.compatible_with(SequentialCheckpoint::CURRENT_SCHEMA_VERSION)
{
return Err(ApplyError::SchemaVersionMismatch {
kind: "sequential-checkpoint",
found: cp.schema_version,
expected: SequentialCheckpoint::CURRENT_SCHEMA_VERSION,
});
}
}
let reader_name = reader.patch_name().map(str::to_owned);
let total_size = stream_total_size(&mut reader)?;
self.patch_name.clone_from(&reader_name);
self.patch_size = Some(total_size);
let effective_from = from.and_then(|cp| {
let name_match = cp.patch_name == reader_name;
let size_match = match cp.patch_size {
Some(sz) => sz == total_size,
None => true,
};
if name_match && size_match {
Some(cp)
} else {
tracing::warn!(
expected_patch_name = ?reader_name,
expected_patch_size = total_size,
checkpoint_patch_name = ?cp.patch_name,
checkpoint_patch_size = ?cp.patch_size,
"resume_apply_patch: stale checkpoint, restarting from chunk 0"
);
None
}
});
let resumed_from_chunk = effective_from.map(|cp| cp.next_chunk_index);
let skipped_bytes_at_start = effective_from.map_or(0, |cp| cp.bytes_read);
let has_in_flight = effective_from
.and_then(|cp| cp.in_flight.as_ref())
.is_some();
if let Some(cp) = effective_from {
tracing::info!(
patch_name = ?reader_name,
skipped_chunks = cp.next_chunk_index,
skipped_bytes = cp.bytes_read,
has_in_flight,
"resume_apply_patch: resuming patch"
);
fast_forward(&mut reader, cp.next_chunk_index, cp.bytes_read)?;
}
let start_index = effective_from.map_or(0, |cp| cp.next_chunk_index);
let in_flight = effective_from.and_then(|cp| cp.in_flight.clone());
let result: ApplyResult<u64> = (|| {
if let Some(in_flight) = in_flight {
resume_in_flight_chunk(&mut reader, self, start_index, &in_flight)?;
run_apply_loop(&mut reader, self, start_index + 1).map(|n| n + 1)
} else {
run_apply_loop(&mut reader, self, start_index)
}
})();
let flush_result = self.flush();
let (final_result, chunks_applied) = match (result, flush_result) {
(Ok(n), Ok(())) => (Ok(()), n),
(Ok(_), Err(e)) => (Err(ApplyError::from(e)), 0),
(Err(e), _) => (Err(e), 0),
};
match final_result {
Ok(()) => {
let bytes_read = reader.bytes_read();
if let Some(from_chunk) = resumed_from_chunk {
tracing::info!(
chunks = chunks_applied,
bytes_read,
resumed_from_chunk = from_chunk,
skipped_bytes = skipped_bytes_at_start,
elapsed_ms = started.elapsed().as_millis() as u64,
"resume_apply_patch: patch applied"
);
} else {
tracing::info!(
chunks = chunks_applied,
bytes_read,
resumed_from_chunk = tracing::field::Empty,
elapsed_ms = started.elapsed().as_millis() as u64,
"resume_apply_patch: patch applied"
);
}
Ok(SequentialCheckpoint {
schema_version: SequentialCheckpoint::CURRENT_SCHEMA_VERSION,
next_chunk_index: start_index + chunks_applied,
bytes_read,
patch_name: reader_name,
patch_size: Some(total_size),
in_flight: None,
})
}
Err(e) => Err(e),
}
}
}
/// Drains `reader`, applying each chunk to `session` and recording a
/// sequential checkpoint after every successful apply.
///
/// `start_index` is the absolute index given to the first chunk read, so
/// checkpoints and observer events carry patch-relative indices even when
/// the reader was fast-forwarded. Returns how many chunks this call applied.
///
/// # Errors
///
/// Propagates read, apply, and checkpoint-recording errors. Returns
/// [`ApplyError::Cancelled`] when the observer answers
/// [`ControlFlow::Break`] or the session reports a cancel request; in both
/// cases the checkpoint for the just-applied chunk has already been recorded.
fn run_apply_loop<R: Read>(
    reader: &mut ZiPatchReader<R>,
    session: &mut ApplySession,
    start_index: u64,
) -> ApplyResult<u64> {
    let mut applied: u64 = 0;
    loop {
        let Some(record) = reader.next_chunk()? else {
            return Ok(applied);
        };
        let index = start_index + applied;

        session.current_chunk_index = index;
        session.current_chunk_bytes_read = record.bytes_read;
        record.chunk.apply(session)?;

        let bytes_read = record.bytes_read;
        let next_chunk_index = index + 1;
        let snapshot = SequentialCheckpoint {
            schema_version: SequentialCheckpoint::CURRENT_SCHEMA_VERSION,
            next_chunk_index,
            bytes_read,
            patch_name: session.patch_name.clone(),
            patch_size: session.patch_size,
            in_flight: None,
        };
        tracing::debug!(
            next_chunk_index,
            bytes_read,
            in_flight = false,
            "apply_patch: checkpoint recorded"
        );
        session.record_checkpoint(&Checkpoint::Sequential(snapshot))?;

        let observer_stopped = matches!(
            session.observer_mut().on_chunk_applied(apply::ChunkEvent {
                index: index as usize,
                kind: record.tag,
                bytes_read,
            }),
            ControlFlow::Break(())
        );
        // `||` keeps the original order: cancel_requested() is only consulted
        // when the observer did not already stop the run.
        if observer_stopped || session.cancel_requested() {
            return Err(ApplyError::Cancelled);
        }
        applied += 1;
    }
}
/// Returns the length of the reader's underlying stream, measured as the
/// offset of its end, and leaves the stream position where it was.
///
/// # Errors
///
/// Propagates any seek error from the underlying stream.
fn stream_total_size<R: Read + Seek>(reader: &mut ZiPatchReader<R>) -> ApplyResult<u64> {
    use std::io::SeekFrom;

    let stream = reader.inner_mut();
    let saved_position = stream.stream_position()?;
    let end_offset = stream.seek(SeekFrom::End(0))?;
    stream.seek(SeekFrom::Start(saved_position))?;
    Ok(end_offset)
}
/// Reads and discards `target_chunks` chunks so the reader is positioned at
/// the checkpoint's next chunk.
///
/// If the reader's byte counter afterwards differs from
/// `expected_bytes_read`, a warning is logged but the call still succeeds.
///
/// # Errors
///
/// [`ParseError::TruncatedPatch`] (wrapped in [`ApplyError::Parse`]) if the
/// stream runs out of chunks first; any read error is propagated.
fn fast_forward<R: Read>(
    reader: &mut ZiPatchReader<R>,
    target_chunks: u64,
    expected_bytes_read: u64,
) -> ApplyResult<()> {
    for _ in 0..target_chunks {
        if reader.next_chunk()?.is_none() {
            return Err(ApplyError::Parse(ParseError::TruncatedPatch));
        }
    }

    let actual_bytes_read = reader.bytes_read();
    if actual_bytes_read != expected_bytes_read {
        tracing::warn!(
            actual_bytes_read,
            expected_bytes_read,
            target_chunks,
            "resume_apply_patch: bytes_read drift during fast-forward"
        );
    }
    tracing::debug!(
        skipped_chunks = target_chunks,
        bytes_read = actual_bytes_read,
        "resume_apply_patch: fast-forward complete"
    );
    Ok(())
}
fn resume_in_flight_chunk<R: Read>(
reader: &mut ZiPatchReader<R>,
session: &mut ApplySession,
chunk_index: u64,
in_flight: &apply::InFlightAddFile,
) -> ApplyResult<()> {
let Some(rec) = reader.next_chunk()? else {
return Err(ApplyError::Parse(ParseError::TruncatedPatch));
};
session.current_chunk_index = chunk_index;
session.current_chunk_bytes_read = rec.bytes_read;
let (start_block, start_bytes) = match resolve_in_flight_resume(&rec.chunk, session, in_flight)
{
InFlightResume::Resume {
start_block,
start_bytes,
} => (start_block, start_bytes),
InFlightResume::Restart => (0, 0),
};
match &rec.chunk {
chunk::Chunk::Sqpk(chunk::SqpkCommand::File(file))
if matches!(
file.operation,
crate::chunk::sqpk::SqpkFileOperation::AddFile
) =>
{
apply::sqpk::apply_file_add_from(file, session, start_block, start_bytes)?;
}
_ => rec.chunk.apply(session)?,
}
let bytes_read = rec.bytes_read;
let tag = rec.tag;
let next_chunk_index = chunk_index + 1;
let checkpoint = Checkpoint::Sequential(SequentialCheckpoint {
schema_version: SequentialCheckpoint::CURRENT_SCHEMA_VERSION,
next_chunk_index,
bytes_read,
patch_name: session.patch_name.clone(),
patch_size: session.patch_size,
in_flight: None,
});
session.record_checkpoint(&checkpoint)?;
let event = apply::ChunkEvent {
index: chunk_index as usize,
kind: tag,
bytes_read,
};
if let ControlFlow::Break(()) = session.observer_mut().on_chunk_applied(event) {
return Err(ApplyError::Cancelled);
}
if session.cancel_requested() {
return Err(ApplyError::Cancelled);
}
Ok(())
}
/// Outcome of validating a checkpoint's in-flight `AddFile` state against
/// the chunk actually read at the resume position.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InFlightResume {
    /// The recorded state still fits the chunk; continue from here.
    Resume {
        /// Index of the first block of the `AddFile` chunk still to apply.
        start_block: usize,
        /// Value of the checkpoint's `bytes_into_target`, handed to the
        /// resumed apply.
        start_bytes: u64,
    },
    /// The recorded state is unusable; apply the chunk from its beginning.
    Restart,
}
/// Decides whether a checkpoint's in-flight `AddFile` state can be reused
/// for `chunk`, the chunk read at the resume position.
///
/// Returns [`InFlightResume::Restart`] (after logging a warning) when any of
/// the following holds:
/// - `chunk` is not an SQPK file command, or its operation is not `AddFile`;
/// - the chunk's resolved target path differs from `in_flight.target_path`;
/// - the chunk's `file_offset` differs from the recorded one;
/// - the recorded `block_idx` does not fit in `usize` or exceeds the chunk's
///   block count;
/// - the chunk's `file_offset` is 0, `bytes_into_target` is non-zero, and the
///   target on disk is shorter than `bytes_into_target`. A metadata error
///   counts as length 0.
///
/// Otherwise returns [`InFlightResume::Resume`] with the recorded block index
/// and `bytes_into_target`.
fn resolve_in_flight_resume(
    chunk: &chunk::Chunk,
    session: &ApplySession,
    in_flight: &apply::InFlightAddFile,
) -> InFlightResume {
    let chunk::Chunk::Sqpk(chunk::SqpkCommand::File(file)) = chunk else {
        tracing::warn!(
            "resume_apply_patch: in-flight chunk is not an SqpkFile; discarding in-flight state"
        );
        return InFlightResume::Restart;
    };
    if !matches!(
        file.operation,
        crate::chunk::sqpk::SqpkFileOperation::AddFile
    ) {
        tracing::warn!(
            "resume_apply_patch: in-flight chunk is not an AddFile; discarding in-flight state"
        );
        return InFlightResume::Restart;
    }
    let expected_path = apply::path::generic_path(session, &file.path);
    if expected_path != in_flight.target_path {
        tracing::warn!(
            chunk_path = %expected_path.display(),
            in_flight_path = %in_flight.target_path.display(),
            "resume_apply_patch: in-flight target path does not match chunk; discarding"
        );
        return InFlightResume::Restart;
    }
    let chunk_offset = file.file_offset;
    if chunk_offset != in_flight.file_offset {
        tracing::warn!(
            chunk_offset,
            in_flight_offset = in_flight.file_offset,
            "resume_apply_patch: in-flight file_offset does not match chunk; discarding"
        );
        return InFlightResume::Restart;
    }
    // Checked conversion: `block_idx` comes from persisted checkpoint data, and
    // an `as usize` cast could wrap on narrower targets and slip past the
    // range check below.
    let block_count = file.blocks.len();
    let start_block = match usize::try_from(in_flight.block_idx) {
        Ok(idx) if idx <= block_count => idx,
        _ => {
            tracing::warn!(
                block_idx = in_flight.block_idx,
                block_count,
                "resume_apply_patch: in-flight block_idx out of range; discarding"
            );
            return InFlightResume::Restart;
        }
    };
    if chunk_offset == 0 && in_flight.bytes_into_target > 0 {
        let on_disk_len = session
            .vfs()
            .metadata(&in_flight.target_path)
            .map_or(0, |m| m.len);
        if on_disk_len < in_flight.bytes_into_target {
            tracing::warn!(
                target = %in_flight.target_path.display(),
                on_disk_len,
                bytes_into_target = in_flight.bytes_into_target,
                "resume_apply_patch: target file truncated or missing since checkpoint; restarting AddFile"
            );
            return InFlightResume::Restart;
        }
    }
    InFlightResume::Resume {
        start_block,
        start_bytes: in_flight.bytes_into_target,
    }
}