1use crate::read_full;
2use std::io::{self, Read};
3
/// Tracks which part of the internal buffer was filled by the most recent
/// refill, and therefore which part has to be refilled next.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ReplacingReaderState {
    /// Nothing has been read from the underlying reader yet; the first
    /// `read` call fills the whole buffer.
    NotInitialized,

    /// The most recent refill wrote the region that starts at the middle of
    /// the buffer (the initial whole-buffer fill also ends in this state).
    /// The next refill targets the first half.
    LastReadIsMiddle,

    /// The most recent refill wrote the region at the start of the buffer.
    /// The next refill targets the second half.
    LastReadIsStart,
}
14
15pub struct ReplacingReader<'a> {
22 underlying_reader: &'a mut dyn Read,
23 buffer: Vec<u8>,
30 old_pattern: &'a [u8],
31 new_pattern: &'a [u8],
32 read_ptr: usize,
33
34 state: ReplacingReaderState,
35
36 eof_position: Option<usize>,
39
40 next_match_ptr: Option<usize>,
42
43 serve_new_ptr: Option<usize>,
46}
47
48impl ReplacingReader<'_> {
49 pub fn new<'a>(r: &'a mut dyn Read, old: &'a [u8], new: &'a [u8]) -> ReplacingReader<'a> {
50 if old.len() == 0 { panic!("old pattern can not be empty") };
51
52 let buffer = vec![0; 2 * old.len()];
53 ReplacingReader {
54 underlying_reader: r,
55 old_pattern: old,
56 new_pattern: new,
57 read_ptr: 0,
58 buffer: buffer,
59 state: ReplacingReaderState::NotInitialized,
60 eof_position: None,
61
62 next_match_ptr: None,
63 serve_new_ptr: None,
64 }
65 }
66
67 #[inline(always)]
68 fn try_match_from(&self, start: usize) -> bool {
69 let mut ptr = start;
70 let mut match_len = 0usize;
71 loop {
72 if match_len == self.old_pattern.len() {
73 return true;
74 }
75 if self.buffer[ptr] == self.old_pattern[match_len] {
76 match_len += 1;
77 ptr += 1;
78 if ptr == self.buffer.len() {
79 ptr = 0;
80 }
81 } else {
82 return false;
83 }
84 }
85 }
86}
87
88impl Read for ReplacingReader<'_> {
89 fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
90 let buf_available = buf.len();
91 if let Some(new_ptr) = self.serve_new_ptr {
93 let remaining_new_pattern_len = self.new_pattern.len() - new_ptr;
94 if remaining_new_pattern_len > buf_available {
95 buf.copy_from_slice(&self.new_pattern[new_ptr..new_ptr + buf_available]);
96 self.serve_new_ptr = Some(new_ptr + buf_available);
97 return Ok(buf_available);
98 } else if remaining_new_pattern_len > 0 {
99 buf[..remaining_new_pattern_len].copy_from_slice(&self.new_pattern[new_ptr..]);
100 self.serve_new_ptr = None;
101 return Ok(remaining_new_pattern_len);
102 }
103 }
104
105 if let Some(next_match_ptr) = self.next_match_ptr {
107 if next_match_ptr > self.read_ptr {
108 let remaining_buf_available = next_match_ptr - self.read_ptr;
109 if buf_available >= remaining_buf_available {
110 buf[..remaining_buf_available]
112 .copy_from_slice(&self.buffer[self.read_ptr..next_match_ptr]);
113 self.serve_new_ptr = Some(0);
114 self.read_ptr = next_match_ptr + self.old_pattern.len();
115 if self.read_ptr >= self.buffer.len() {
116 self.read_ptr -= self.buffer.len();
117 }
118 self.next_match_ptr = None;
119 return Ok(remaining_buf_available);
120 } else {
121 buf.copy_from_slice(&self.buffer[self.read_ptr..self.read_ptr + buf_available]);
122 self.read_ptr += buf_available;
123 return Ok(buf_available);
124 }
125 } else if next_match_ptr == self.read_ptr {
126 self.serve_new_ptr = Some(0);
127 self.read_ptr += self.old_pattern.len() ;
128 if self.read_ptr >= self.buffer.len() {
129 self.read_ptr -= self.buffer.len();
130 }
131 self.next_match_ptr = None;
132 return self.read(buf);
133 } {
134 let remaining_buf_available = self.buffer.len() - self.read_ptr;
135 if buf_available >= remaining_buf_available {
136 buf[..remaining_buf_available].copy_from_slice(&self.buffer[self.read_ptr..]);
137 self.read_ptr = 0;
138 return Ok(remaining_buf_available);
139 } else {
140 buf.copy_from_slice(&self.buffer[self.read_ptr..self.read_ptr + buf_available]);
141 self.read_ptr += buf_available;
142 return Ok(buf_available);
143 }
144 }
145 }
146
147 match self.state {
149 ReplacingReaderState::NotInitialized => {
150 match read_full(&mut self.buffer, self.underlying_reader) {
152 Ok(read_len) => {
153 if read_len < self.buffer.len() {
154 self.eof_position = Some(read_len);
156 }
157 if read_len >= self.old_pattern.len() {
158 let possible_match_start = read_len - self.old_pattern.len();
159 for guess_start in 0..possible_match_start {
160 if self.try_match_from(guess_start) {
161 self.next_match_ptr = Some(guess_start);
162 break;
163 }
164 }
165 }
166
167 self.state = ReplacingReaderState::LastReadIsMiddle;
168 return self.read(buf);
169 }
170 Err(e) => return Err(e),
171 };
172 }
173 _ => (),
174 };
175
176 if let Some(eof_position) = self.eof_position {
178 if eof_position < self.read_ptr {
180 let max_read_size = self.buffer.len() - self.read_ptr;
182 if max_read_size >= self.old_pattern.len() {
183 for guess_start in self.read_ptr..self.read_ptr + 1 + max_read_size - self.old_pattern.len() {
184 if self.try_match_from(guess_start) {
185 self.next_match_ptr = Some(guess_start % self.buffer.len());
186 return self.read(buf);
187 }
188 }
189 }
190 if max_read_size > buf_available {
191 buf.copy_from_slice(&self.buffer[self.read_ptr..self.read_ptr + buf_available]);
192 self.read_ptr += buf_available;
193 return Ok(buf_available);
194 } else {
195 buf[..max_read_size].copy_from_slice(&self.buffer[self.read_ptr..]);
196 self.read_ptr = 0;
197 return Ok(max_read_size);
198 }
199 } else if eof_position == self.read_ptr {
200 return Ok(0);
201 } else {
202 let max_read_size = eof_position - self.read_ptr;
203 if max_read_size >= self.old_pattern.len() {
204 for guess_start in self.read_ptr..self.read_ptr + 1 + max_read_size - self.old_pattern.len() {
205 if self.try_match_from(guess_start) {
206 self.next_match_ptr = Some(guess_start);
207 return self.read(buf);
208 }
209 }
210 }
211 if max_read_size > buf_available {
212 buf.copy_from_slice(&self.buffer[self.read_ptr..self.read_ptr + buf_available]);
213 self.read_ptr += buf_available;
214 return Ok(buf_available);
215 } else {
216 buf[..max_read_size].copy_from_slice(&self.buffer[self.read_ptr..eof_position]);
217 self.read_ptr += max_read_size;
218 return Ok(max_read_size);
219 }
220 }
221 }
222
223 let wrap_pos = self.old_pattern.len();
225 match self.state {
226 ReplacingReaderState::LastReadIsStart => {
227 if self.read_ptr >= wrap_pos {
228 let remaining_data_len = self.buffer.len() - self.read_ptr;
229 if buf_available >= remaining_data_len {
230 buf[..remaining_data_len].copy_from_slice(&self.buffer[self.read_ptr..]);
231 self.read_ptr = 0;
232 return Ok(remaining_data_len);
233 } else {
234 buf.copy_from_slice(
235 &self.buffer[self.read_ptr..self.read_ptr + buf_available],
236 );
237 self.read_ptr += buf_available;
238 return Ok(buf_available);
239 }
240 }
241 match read_full(&mut self.buffer[wrap_pos..], self.underlying_reader) {
243 Ok(size) => {
244 let mut last_possible_match_start = wrap_pos;
245 if size < self.old_pattern.len() {
246 let eof_position = wrap_pos + size;
248 last_possible_match_start = eof_position - self.old_pattern.len() ;
249 self.eof_position = Some(eof_position);
250 }
251 let first_possible_match_start = if self.read_ptr<1 {0} else {self.read_ptr};
252 for guess_start in first_possible_match_start..last_possible_match_start {
253 if self.try_match_from(guess_start) {
254 self.next_match_ptr = Some(guess_start);
255 }
256 }
257 }
258 Err(e) => return Err(e),
259
260 };
261 self.state = ReplacingReaderState::LastReadIsMiddle;
262 }
263 ReplacingReaderState::LastReadIsMiddle => {
264 if self.read_ptr < wrap_pos {
265 let remaining_data_len = wrap_pos - self.read_ptr;
267 if buf_available >= remaining_data_len {
268 buf[..remaining_data_len]
269 .copy_from_slice(&self.buffer[self.read_ptr..wrap_pos]);
270 self.read_ptr = wrap_pos;
271 return Ok(remaining_data_len);
272 } else {
273 buf.copy_from_slice(
274 &self.buffer[self.read_ptr..self.read_ptr + buf_available],
275 );
276 self.read_ptr += buf_available;
277 return Ok(buf_available);
278 }
279 }
280 match read_full(&mut self.buffer[..wrap_pos], self.underlying_reader) {
281 Ok(size) => {
282 let first_possible_match_start = if self.read_ptr > wrap_pos {self.read_ptr} else {wrap_pos };
283 let mut last_possible_match_start = self.buffer.len();
284 if size < self.old_pattern.len() {
285 let eof_position = size;
286 last_possible_match_start =
287 self.buffer.len() - self.old_pattern.len() + size;
288 self.eof_position = Some(eof_position);
289 }
290 for guess_start in first_possible_match_start..last_possible_match_start {
291 if self.try_match_from(guess_start % self.buffer.len()) {
292 self.next_match_ptr = Some(guess_start % self.buffer.len());
293 }
294 }
295 }
296 Err(e) => return Err(e),
297 }
298 self.state = ReplacingReaderState::LastReadIsStart;
299 }
300 _ => panic!("unknown state"),
301 }
302
303 return self.read(buf);
304 }
305}
306
#[cfg(test)]
mod testconv {

    mod test_replacing_reader {
        use crate::conv::ReplacingReader;
        use std::io::Read;

        /// Streams `input` through a `ReplacingReader` that replaces `old`
        /// with `new` and returns everything the reader produced.
        fn run_string_through(input: &str, old: &str, new: &str) -> String {
            let mut input_bytes = input.as_bytes();
            let mut reader =
                ReplacingReader::new(&mut input_bytes, old.as_bytes(), new.as_bytes());
            let mut ret = String::new();
            reader.read_to_string(&mut ret).unwrap();
            ret
        }

        /// Inputs of every length from 0 to 39 made of repeated patterns;
        /// odd lengths end in a dangling first half of the pattern, which
        /// must pass through unchanged.
        #[test]
        fn test_varying_input_len() {
            let input_pattern = "ab";
            let old_pattern = "ab";
            let new_pattern = "cd";
            for input_len in 0..40 {
                let mut input = input_pattern.repeat(input_len / 2);
                let mut expect = new_pattern.repeat(input_len / 2);
                if input_len % 2 == 1 {
                    let dangling = input_pattern.chars().next().unwrap();
                    input.push(dangling);
                    expect.push(dangling);
                }

                assert_eq!(
                    run_string_through(&input, old_pattern, new_pattern),
                    expect,
                );
            }
        }

        /// The replacement is longer than the pattern it replaces.
        #[test]
        fn test_simple() {
            let input = "abcabcabcabcabc";
            let old = "ab";
            let new = "cde";
            let expect = "cdeccdeccdeccdeccdec";
            assert_eq!(run_string_through(input, old, new), expect);
        }

        /// An empty replacement deletes every match.
        #[test]
        fn test_zero_new() {
            let input = "abcabcabcabcabc";
            let old = "ab";
            let expect = "ccccc";
            assert_eq!(run_string_through(input, old, ""), expect);
        }

        /// Inserts a pattern of 1 to 7 letters at every pair of distinct
        /// positions of a digit string, behind 0 to 4 prefix characters, and
        /// checks that both occurrences are replaced.
        #[test]
        fn test_insert_two_places() {
            let base_str = "012345678901234567890123456789";
            let replace_to = "test";

            for n_prefix in 0..5 {
                let prefix = "_".repeat(n_prefix);
                for insert_len in 1..8usize {
                    // "a", "ab", "abc", ... depending on `insert_len`.
                    let insert_pattern: String =
                        (b'a'..).take(insert_len).map(char::from).collect();
                    for insert_pos_1 in 0..base_str.len() {
                        for insert_pos_2 in insert_pos_1 + 1..base_str.len() {
                            let head = &base_str[..insert_pos_1];
                            let middle = &base_str[insert_pos_1..insert_pos_2];
                            let tail = &base_str[insert_pos_2..];

                            let input_str = [
                                prefix.as_str(),
                                head,
                                insert_pattern.as_str(),
                                middle,
                                insert_pattern.as_str(),
                                tail,
                            ]
                            .concat();
                            let expect_str = [
                                prefix.as_str(),
                                head,
                                replace_to,
                                middle,
                                replace_to,
                                tail,
                            ]
                            .concat();

                            assert_eq!(
                                run_string_through(&input_str, &insert_pattern, replace_to),
                                expect_str
                            );
                        }
                    }
                }
            }
        }
    }
}