1use std::io::{self, Write};
13
14#[derive(Debug, Clone, Copy)]
35pub struct CrashPoint {
36 after: usize,
37}
38
39impl CrashPoint {
40 pub fn after_byte(n: usize) -> Self {
42 Self { after: n }
43 }
44
45 pub fn wrap<W: Write>(self, writer: W) -> CrashWriter<W> {
47 CrashWriter {
48 inner: writer,
49 after: self.after,
50 written: 0,
51 }
52 }
53}
54
/// A [`Write`] adapter that forwards bytes to an inner writer until a fixed
/// byte budget is used up, then fails every further write.
///
/// At most `after` bytes are ever forwarded in total. A write that straddles
/// the crash point is therefore truncated exactly at that offset.
#[derive(Debug)]
pub struct CrashWriter<W> {
    /// Destination for bytes accepted before the crash point.
    inner: W,
    /// Total number of bytes allowed through before writes start failing.
    after: usize,
    /// Bytes the inner writer has reported as written so far.
    written: usize,
}

impl<W> CrashWriter<W> {
    /// Returns how many bytes have been forwarded to the inner writer so far.
    pub fn bytes_written(&self) -> usize {
        self.written
    }

    /// Consumes the wrapper and returns the inner writer.
    pub fn into_inner(self) -> W {
        self.inner
    }

    /// Builds the error returned once the byte budget has been used up.
    fn crash_error() -> io::Error {
        io::Error::new(io::ErrorKind::WriteZero, "crash point reached")
    }
}

impl<W: Write> Write for CrashWriter<W> {
    /// Forwards as much of `buf` as the remaining byte budget allows.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::WriteZero`] error once the budget is spent.
    /// Any error from the inner writer is propagated. Before the crash point,
    /// an empty `buf` is passed through to the inner writer rather than being
    /// reported as a crash.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if self.written >= self.after {
            return Err(Self::crash_error());
        }
        // Never let more than the remaining budget reach the inner writer.
        // `remaining` is non-zero here, so the slice is only empty when `buf`
        // itself is empty.
        let remaining = self.after - self.written;
        let accepted = &buf[..remaining.min(buf.len())];
        let n = self.inner.write(accepted)?;
        self.written += n;
        Ok(n)
    }

    /// Flushes the inner writer. Flushing is unaffected by the crash point.
    fn flush(&mut self) -> io::Result<()> {
        self.inner.flush()
    }
}
101
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn crash_after_byte_passes_through_then_truncates() {
        let sink: Vec<u8> = Vec::new();
        let mut w = CrashPoint::after_byte(3).wrap(sink);
        // The crash must be surfaced to the caller, not silently swallowed.
        let err = w.write_all(b"abcd").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::WriteZero);
        assert_eq!(w.bytes_written(), 3);
        let sink = w.into_inner();
        assert_eq!(sink, b"abc");
    }

    #[test]
    fn crash_after_zero_writes_nothing() {
        let sink: Vec<u8> = Vec::new();
        let mut w = CrashPoint::after_byte(0).wrap(sink);
        let err = w.write(b"a").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::WriteZero);
        assert_eq!(w.bytes_written(), 0);
        let sink = w.into_inner();
        assert!(sink.is_empty());
    }

    #[test]
    fn crash_with_large_budget_passes_through() {
        let sink: Vec<u8> = Vec::new();
        let mut w = CrashPoint::after_byte(1_000).wrap(sink);
        w.write_all(b"hello").unwrap();
        assert_eq!(w.bytes_written(), 5);
        let sink = w.into_inner();
        assert_eq!(sink, b"hello");
    }

    #[test]
    fn bytes_written_tracks_progress() {
        let sink: Vec<u8> = Vec::new();
        let mut w = CrashPoint::after_byte(5).wrap(sink);
        w.write_all(b"ab").unwrap();
        assert_eq!(w.bytes_written(), 2);
        w.write_all(b"cde").unwrap();
        assert_eq!(w.bytes_written(), 5);
    }

    #[test]
    fn split_across_writes_still_truncates_at_offset() {
        let sink: Vec<u8> = Vec::new();
        let mut w = CrashPoint::after_byte(4).wrap(sink);
        w.write_all(b"ab").unwrap();
        // Only "cd" fits in the remaining budget; the rest must be rejected.
        assert!(w.write_all(b"cdef").is_err());
        assert_eq!(w.bytes_written(), 4);
        let sink = w.into_inner();
        assert_eq!(sink, b"abcd");
    }
}