use std::io::{self, Write};
/// Describes where a simulated crash happens in a byte stream.
///
/// A `CrashPoint` created with [`CrashPoint::after_byte`] lets exactly `after`
/// bytes through a wrapped writer; every write attempted after that fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CrashPoint {
    /// Number of bytes allowed through before writes start failing.
    after: usize,
}
impl CrashPoint {
    /// Creates a crash point that allows `n` bytes to be written before failing.
    ///
    /// With `n == 0` the very first write attempt fails.
    #[must_use]
    pub fn after_byte(n: usize) -> Self {
        Self { after: n }
    }

    /// Wraps `writer` so that writes through it fail once this crash point is reached.
    ///
    /// The returned [`CrashWriter`] starts with zero bytes written. Dropping it
    /// unused also drops `writer`, hence `#[must_use]`.
    #[must_use]
    pub fn wrap<W: Write>(self, writer: W) -> CrashWriter<W> {
        CrashWriter {
            inner: writer,
            after: self.after,
            written: 0,
        }
    }
}
/// A writer adapter that forwards bytes to `inner` until a byte budget is spent,
/// then fails every subsequent write.
///
/// Build one with [`CrashPoint::wrap`]. The `Write` bound lives on the impls
/// rather than here, so the type can be named without it.
#[derive(Debug)]
pub struct CrashWriter<W> {
    /// The wrapped destination writer.
    inner: W,
    /// Total number of bytes allowed through before writes fail.
    after: usize,
    /// Number of bytes the inner writer has accepted so far.
    written: usize,
}
impl<W> CrashWriter<W>
where
    W: Write,
{
    /// Returns how many bytes the inner writer has accepted so far.
    pub fn bytes_written(&self) -> usize {
        self.written
    }

    /// Consumes the adapter and hands back the wrapped writer.
    pub fn into_inner(self) -> W {
        let Self { inner, .. } = self;
        inner
    }
}
impl<W: Write> Write for CrashWriter<W> {
    /// Forwards at most the bytes remaining before the crash point to the inner writer.
    ///
    /// A non-empty `buf` that extends past the crash point is truncated. The return
    /// value reports how many bytes the inner writer actually accepted.
    ///
    /// An empty `buf` before the crash point is a no-op that returns `Ok(0)`,
    /// as permitted by the [`Write::write`] contract.
    ///
    /// # Errors
    ///
    /// Returns an `io::ErrorKind::WriteZero` error once `after` bytes have been
    /// written, including for empty buffers. Any error from the inner writer is
    /// propagated unchanged.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // Yields 0 exactly when the budget is exhausted (written >= after).
        let remaining = self.after.saturating_sub(self.written);
        if remaining == 0 {
            return Err(io::Error::new(
                io::ErrorKind::WriteZero,
                "crash point reached",
            ));
        }
        // Nothing to forward; an empty write must not be mistaken for the crash.
        if buf.is_empty() {
            return Ok(0);
        }
        let to_write = remaining.min(buf.len());
        let accepted = self.inner.write(&buf[..to_write])?;
        self.written += accepted;
        Ok(accepted)
    }

    /// Flushes the inner writer. This is forwarded even after the crash point.
    fn flush(&mut self) -> io::Result<()> {
        self.inner.flush()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn crash_after_byte_passes_through_then_truncates() {
        let sink: Vec<u8> = Vec::new();
        let mut w = CrashPoint::after_byte(3).wrap(sink);
        // `write_all` must surface the simulated crash, not report success.
        let err = w.write_all(b"abcd").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::WriteZero);
        let sink = w.into_inner();
        assert_eq!(sink, b"abc");
    }

    #[test]
    fn crash_after_zero_writes_nothing() {
        let sink: Vec<u8> = Vec::new();
        let mut w = CrashPoint::after_byte(0).wrap(sink);
        let err = w.write(b"a").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::WriteZero);
        assert_eq!(w.bytes_written(), 0);
        let sink = w.into_inner();
        assert!(sink.is_empty());
    }

    #[test]
    fn crash_with_large_budget_passes_through() {
        let sink: Vec<u8> = Vec::new();
        let mut w = CrashPoint::after_byte(1_000).wrap(sink);
        w.write_all(b"hello").unwrap();
        let sink = w.into_inner();
        assert_eq!(sink, b"hello");
    }

    #[test]
    fn bytes_written_tracks_progress() {
        let sink: Vec<u8> = Vec::new();
        let mut w = CrashPoint::after_byte(5).wrap(sink);
        w.write_all(b"ab").unwrap();
        assert_eq!(w.bytes_written(), 2);
        // The counter saturates at the crash point even if more is attempted.
        assert!(w.write_all(b"cdefg").is_err());
        assert_eq!(w.bytes_written(), 5);
    }

    #[test]
    fn split_across_writes_still_truncates_at_offset() {
        let sink: Vec<u8> = Vec::new();
        let mut w = CrashPoint::after_byte(4).wrap(sink);
        w.write_all(b"ab").unwrap();
        let err = w.write_all(b"cdef").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::WriteZero);
        let sink = w.into_inner();
        assert_eq!(sink, b"abcd");
    }
}