Skip to main content

easyio/conv/
mod.rs

1use crate::read_full;
2use std::io::{self, Read};
3
4enum ReplacingReaderState {
5    // the buffer has not been initialized yet
6    NotInitialized,
7
8    // the buffer is in this sequence: [4 5 6 7 0 1 2 3]
9    LastReadIsMiddle,
10
11    // the buffer is in this sequence: [0 1 2 3 4 5 6 7]
12    LastReadIsStart,
13}
14
15/// ReplacingReader wraps around an underlying reader and transiently replaces given patterns in the read.
16///
17/// The pattern must no overlap, in such case the behavior is undefined.
18/// The internal buffer is 2 * len(old_pattern), caller can wrap std::io::BufReader if more buffer is required.
19///
20/// A runtime panic will be thrown if old.len() == 0.
21pub struct ReplacingReader<'a> {
22    underlying_reader: &'a mut dyn Read,
23    // buffer is separated into two parts and has a capacity of 2 * old_pattern.len()
24    //
25    // buffer:         X X X A | B C X X
26    // next_match_ptr:       *
27    // read_ptr:       *
28    // next time when read_ptr is about to hit next_match_ptr, we transition to feed new to read() call
29    buffer: Vec<u8>,
30    old_pattern: &'a [u8],
31    new_pattern: &'a [u8],
32    read_ptr: usize,
33
34    state: ReplacingReaderState,
35
36    // this is the location of eof in the buffer, if already met
37    // the last byte should be buffer[eof_position - 1]
38    eof_position: Option<usize>,
39
40    // this is the location of the next match, if present
41    next_match_ptr: Option<usize>,
42
43    // if this is Some, we are in progress of serving from new_pattern,
44    // this should be set to None when serve_new_ptr == Some(new_pattern.size())
45    serve_new_ptr: Option<usize>,
46}
47
48impl ReplacingReader<'_> {
49    pub fn new<'a>(r: &'a mut dyn Read, old: &'a [u8], new: &'a [u8]) -> ReplacingReader<'a> {
50        if old.len() ==  0 { panic!("old pattern can not be empty") };
51
52        let buffer = vec![0; 2 * old.len()];
53        ReplacingReader {
54            underlying_reader: r,
55            old_pattern: old,
56            new_pattern: new,
57            read_ptr: 0,
58            buffer: buffer,
59            state: ReplacingReaderState::NotInitialized,
60            eof_position: None,
61
62            next_match_ptr: None,
63            serve_new_ptr: None,
64        }
65    }
66
67    #[inline(always)]
68    fn try_match_from(&self, start: usize) -> bool {
69        let mut ptr = start;
70        let mut match_len = 0usize;
71        loop {
72            if match_len == self.old_pattern.len() {
73                return true;
74            }
75            if self.buffer[ptr] == self.old_pattern[match_len] {
76                match_len += 1;
77                ptr += 1;
78                if ptr == self.buffer.len() {
79                    ptr = 0;
80                }
81            } else {
82                return false;
83            }
84        }
85    }
86}
87
88impl Read for ReplacingReader<'_> {
89    fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
90        let buf_available = buf.len();
91        // first check if we are already serving new_pattern
92        if let Some(new_ptr) = self.serve_new_ptr {
93            let remaining_new_pattern_len = self.new_pattern.len() - new_ptr;
94            if remaining_new_pattern_len > buf_available {
95                buf.copy_from_slice(&self.new_pattern[new_ptr..new_ptr + buf_available]);
96                self.serve_new_ptr = Some(new_ptr + buf_available);
97                return Ok(buf_available);
98            } else if remaining_new_pattern_len > 0 {
99                buf[..remaining_new_pattern_len].copy_from_slice(&self.new_pattern[new_ptr..]);
100                self.serve_new_ptr = None;
101                return Ok(remaining_new_pattern_len);
102            }
103        }
104
105        // then, if this read is going to enter self.next_match_ptr?
106        if let Some(next_match_ptr) = self.next_match_ptr {
107            if next_match_ptr > self.read_ptr {
108                let remaining_buf_available = next_match_ptr - self.read_ptr;
109                if buf_available >= remaining_buf_available {
110                    // we can read until start of match
111                    buf[..remaining_buf_available]
112                        .copy_from_slice(&self.buffer[self.read_ptr..next_match_ptr]);
113                    self.serve_new_ptr = Some(0);
114                    self.read_ptr = next_match_ptr + self.old_pattern.len();
115                    if self.read_ptr >= self.buffer.len() {
116                        self.read_ptr -= self.buffer.len();
117                    }
118                    self.next_match_ptr = None;
119                    return Ok(remaining_buf_available);
120                } else {
121                    buf.copy_from_slice(&self.buffer[self.read_ptr..self.read_ptr + buf_available]);
122                    self.read_ptr += buf_available;
123                    return Ok(buf_available);
124                }
125            } else if next_match_ptr == self.read_ptr {
126                self.serve_new_ptr = Some(0);
127                self.read_ptr += self.old_pattern.len() ;
128                if self.read_ptr >= self.buffer.len() {
129                    self.read_ptr -= self.buffer.len();
130                }
131                self.next_match_ptr = None;
132                return self.read(buf);
133            } {
134                let remaining_buf_available = self.buffer.len() - self.read_ptr;
135                if buf_available >= remaining_buf_available {
136                    buf[..remaining_buf_available].copy_from_slice(&self.buffer[self.read_ptr..]);
137                    self.read_ptr = 0;
138                    return Ok(remaining_buf_available);
139                } else {
140                    buf.copy_from_slice(&self.buffer[self.read_ptr..self.read_ptr + buf_available]);
141                    self.read_ptr += buf_available;
142                    return Ok(buf_available);
143                }
144            }
145        }
146
147        // initialize the buffer first
148        match self.state {
149            ReplacingReaderState::NotInitialized => {
150                // first we make a full read to fill the buffer
151                match read_full(&mut self.buffer, self.underlying_reader) {
152                    Ok(read_len) => {
153                        if read_len < self.buffer.len() {
154                            // we already hit eof
155                            self.eof_position = Some(read_len);
156                        }
157                        if read_len >= self.old_pattern.len() {
158                            let possible_match_start = read_len - self.old_pattern.len();
159                            for guess_start in 0..possible_match_start {
160                                if self.try_match_from(guess_start) {
161                                    self.next_match_ptr = Some(guess_start);
162                                    break;
163                                }
164                            }
165                        }
166
167                        self.state = ReplacingReaderState::LastReadIsMiddle;
168                        return self.read(buf);
169                    }
170                    Err(e) => return Err(e),
171                };
172            }
173            _ => (),
174        };
175
176        // if we are at the end of stream and no patterns were found, nothing to do except serve the last bit of stream until end.
177        if let Some(eof_position) = self.eof_position {
178            // remaining buffer is from read_ptr to eof_position
179            if eof_position < self.read_ptr {
180                // read at most into the end of buffer
181                let max_read_size = self.buffer.len() - self.read_ptr;
182                if max_read_size >= self.old_pattern.len() {
183                    for guess_start in self.read_ptr..self.read_ptr + 1 + max_read_size - self.old_pattern.len() {
184                        if self.try_match_from(guess_start) {
185                            self.next_match_ptr = Some(guess_start % self.buffer.len());
186                            return self.read(buf);
187                        }
188                    }
189                }
190                if max_read_size > buf_available {
191                    buf.copy_from_slice(&self.buffer[self.read_ptr..self.read_ptr + buf_available]);
192                    self.read_ptr += buf_available;
193                    return Ok(buf_available);
194                } else {
195                    buf[..max_read_size].copy_from_slice(&self.buffer[self.read_ptr..]);
196                    self.read_ptr = 0;
197                    return Ok(max_read_size);
198                }
199            } else if eof_position == self.read_ptr {
200                return Ok(0);
201            } else {
202                let max_read_size = eof_position - self.read_ptr;
203                if max_read_size >= self.old_pattern.len() {
204                    for guess_start in self.read_ptr..self.read_ptr + 1 + max_read_size - self.old_pattern.len() {
205                        if self.try_match_from(guess_start) {
206                            self.next_match_ptr = Some(guess_start);
207                            return self.read(buf);
208                        }
209                    }
210                }
211                if max_read_size > buf_available {
212                    buf.copy_from_slice(&self.buffer[self.read_ptr..self.read_ptr + buf_available]);
213                    self.read_ptr += buf_available;
214                    return Ok(buf_available);
215                } else {
216                    buf[..max_read_size].copy_from_slice(&self.buffer[self.read_ptr..eof_position]);
217                    self.read_ptr += max_read_size;
218                    return Ok(max_read_size);
219                }
220            }
221        }
222
223        // here is the general case: either serve until the older half of buffer was empty or we advance buffer and do the actual pattern matching
224        let wrap_pos = self.old_pattern.len();
225        match self.state {
226            ReplacingReaderState::LastReadIsStart => {
227                if self.read_ptr >= wrap_pos {
228                    let remaining_data_len = self.buffer.len() - self.read_ptr;
229                    if buf_available >= remaining_data_len {
230                        buf[..remaining_data_len].copy_from_slice(&self.buffer[self.read_ptr..]);
231                        self.read_ptr = 0;
232                        return Ok(remaining_data_len);
233                    } else {
234                        buf.copy_from_slice(
235                            &self.buffer[self.read_ptr..self.read_ptr + buf_available],
236                        );
237                        self.read_ptr += buf_available;
238                        return Ok(buf_available);
239                    }
240                }
241                // next we read from the middle
242                match read_full(&mut self.buffer[wrap_pos..], self.underlying_reader) {
243                    Ok(size) => {
244                        let mut last_possible_match_start = wrap_pos;
245                        if size < self.old_pattern.len() {
246                            // eof is met, set eof position
247                            let eof_position = wrap_pos + size;
248                            last_possible_match_start = eof_position - self.old_pattern.len()  ;
249                            self.eof_position = Some(eof_position);
250                        }
251                        let first_possible_match_start = if self.read_ptr<1 {0} else {self.read_ptr};
252                        for guess_start in first_possible_match_start..last_possible_match_start {
253                            if self.try_match_from(guess_start) {
254                                self.next_match_ptr = Some(guess_start);
255                            }
256                        }
257                    }
258                    Err(e) => return Err(e),
259
260                };
261                self.state = ReplacingReaderState::LastReadIsMiddle;
262            }
263            ReplacingReaderState::LastReadIsMiddle => {
264                if self.read_ptr < wrap_pos {
265                    // we still need to serve up to wrap_pos
266                    let remaining_data_len = wrap_pos - self.read_ptr;
267                    if buf_available >= remaining_data_len {
268                        buf[..remaining_data_len]
269                            .copy_from_slice(&self.buffer[self.read_ptr..wrap_pos]);
270                        self.read_ptr = wrap_pos;
271                        return Ok(remaining_data_len);
272                    } else {
273                        buf.copy_from_slice(
274                            &self.buffer[self.read_ptr..self.read_ptr + buf_available],
275                        );
276                        self.read_ptr += buf_available;
277                        return Ok(buf_available);
278                    }
279                }
280                match read_full(&mut self.buffer[..wrap_pos], self.underlying_reader) {
281                    Ok(size) => {
282                        let first_possible_match_start =  if self.read_ptr > wrap_pos {self.read_ptr} else {wrap_pos };
283                        let mut last_possible_match_start = self.buffer.len();
284                        if size < self.old_pattern.len() {
285                            let eof_position = size;
286                            last_possible_match_start =
287                                self.buffer.len() - self.old_pattern.len() + size;
288                            self.eof_position = Some(eof_position);
289                        }
290                        for guess_start in first_possible_match_start..last_possible_match_start {
291                            if self.try_match_from(guess_start % self.buffer.len()) {
292                                self.next_match_ptr = Some(guess_start % self.buffer.len());
293                            }
294                        }
295                    }
296                    Err(e) => return Err(e),
297                }
298                self.state = ReplacingReaderState::LastReadIsStart;
299            }
300            _ => panic!("unknown state"),
301        }
302
303        return self.read(buf);
304    }
305}
306
#[cfg(test)]
mod testconv {

    mod test_replacing_reader {
        use crate::conv::ReplacingReader;
        use std::io::Read;

        /// Streams `input` through a `ReplacingReader` that replaces `old` with `new`
        /// and collects everything it yields.
        fn run_string_through(input: &str, old: &str, new: &str) -> String {
            let mut input_bytes = input.as_bytes();
            let mut reader =
                ReplacingReader::new(&mut input_bytes, old.as_bytes(), new.as_bytes());
            let mut ret = String::new();
            reader.read_to_string(&mut ret).unwrap();
            ret
        }

        /// Back-to-back matches for every input length below 40, with a trailing
        /// partial match when the length is odd.
        #[test]
        fn test_varying_input_len() {
            let input_pattern = "ab";
            let old_pattern = "ab";
            let new_pattern = "cd";
            // First byte of the pattern: a partial match that must pass through unchanged.
            let partial = &input_pattern[..1];
            for input_len in 0..40 {
                let mut input = input_pattern.repeat(input_len / 2);
                let mut expect = new_pattern.repeat(input_len / 2);
                if input_len % 2 == 1 {
                    input.push_str(partial);
                    expect.push_str(partial);
                }

                assert_eq!(run_string_through(&input, old_pattern, new_pattern), expect);
            }
        }

        /// A replacement longer than the pattern it replaces.
        #[test]
        fn test_simple() {
            let input = "abcabcabcabcabc";
            let old = "ab";
            let new = "cde";
            let expect = "cdeccdeccdeccdeccdec";
            assert_eq!(run_string_through(input, old, new), expect);
        }

        /// An empty replacement deletes every match.
        #[test]
        fn test_zero_new() {
            let input = "abcabcabcabcabc";
            let old = "ab";
            let expect = "ccccc";
            assert_eq!(run_string_through(input, old, ""), expect);
        }

        /// Two copies of a pattern ("a", "ab", ..., "abcdefg") inserted at every pair of
        /// distinct positions of a digit string, behind 0 to 4 bytes of `_` prefix.
        #[test]
        fn test_insert_two_places() {
            let base_str = "012345678901234567890123456789";
            let replace_to = "test";

            for n_prefix in 0..5 {
                let prefix = "_".repeat(n_prefix);
                for insert_len in 1..8u8 {
                    let insert_pattern: String =
                        (b'a'..b'a' + insert_len).map(char::from).collect();
                    for insert_pos_1 in 0..base_str.len() {
                        for insert_pos_2 in insert_pos_1 + 1..base_str.len() {
                            let (head, rest) = base_str.split_at(insert_pos_1);
                            let (middle, tail) = rest.split_at(insert_pos_2 - insert_pos_1);

                            let input_str = [
                                prefix.as_str(),
                                head,
                                insert_pattern.as_str(),
                                middle,
                                insert_pattern.as_str(),
                                tail,
                            ]
                            .concat();
                            let expect_str =
                                [prefix.as_str(), head, replace_to, middle, replace_to, tail]
                                    .concat();

                            assert_eq!(
                                run_string_through(&input_str, &insert_pattern, replace_to),
                                expect_str
                            );
                        }
                    }
                }
            }
        }
    }
}