readfilter/
lib.rs

1//! # ReadFilter: Wrappers for `Read`able things.
2//!
3//! A collection of structs you can wrap around something `Read`able in order to transparently
4//! filter out unwanted content.
5use std::io::{self, Read};
6
/// Capacity, in bytes, of the working buffer owned by every filter.
///
/// Temporarily shrinking this (e.g. to 16) is a convenient way to exercise the
/// buffer-boundary code paths while debugging.
const BUF_SIZE: usize = 8 * 1024;

/// State shared by every filter in this crate; it is all the generic `Read`
/// implementation needs to know about a filter.
struct Common<T: Read> {
    /// The wrapped reader that unfiltered bytes are pulled from.
    source: T,
    /// Scratch space: bytes are read from `source` into it and filtered there.
    working_buf: [u8; BUF_SIZE],
    /// Number of already-filtered bytes at the front of `working_buf` that have
    /// not been handed to the caller yet. Zero means the buffer is empty.
    ///
    /// It becomes non-zero when a caller's buffer was too small to take
    /// everything that had been filtered; the leftover bytes stay at the front
    /// of `working_buf` and the next fill must start writing after them.
    unconsumed_bytes: usize,
}

impl<T: Read> Common<T> {
    /// Wraps `source` with a zeroed, empty working buffer.
    fn new(source: T) -> Self {
        let working_buf = [0u8; BUF_SIZE];
        Common {
            source,
            working_buf,
            unconsumed_bytes: 0,
        }
    }
}
41
42/// Removes all non-whitelisted characters from the wrapped stream.
43///
44/// Non-utf8 characters are lost due to internal string conversions using
45/// [String::from_utf8_lossy()][String::from_utf8_lossy].
46///
47/// ```
48/// use readfilter::CharWhitelist;
49/// use std::io::Read;
50/// let buf = "aabbccddee".as_bytes();
51/// let mut wrapped = CharWhitelist::new(buf, "ace");
52/// let mut s = String::new();
53/// wrapped.read_to_string(&mut s).unwrap();
54/// assert_eq!(&s, "aaccee");
55/// ```
56///
57/// ```no_run
58/// use readfilter::CharWhitelist;
59/// use std::fs::OpenOptions;
60/// use std::io::Read;
61/// let mut output = vec![];
62/// let fd = OpenOptions::new().read(true).open("input.txt").unwrap();
63/// let mut input = CharWhitelist::new(fd, "01");
64/// input.read_to_end(&mut output);
65/// // output only contains the '0' and '1' characters from input.txt
66/// ```
67pub struct CharWhitelist<T>
68where
69    T: Read,
70{
71    common: Common<T>,
72    allowed_chars: Vec<char>,
73}
74
75impl<T> CharWhitelist<T>
76where
77    T: Read,
78{
79    pub fn new(source: T, allowed_chars: &str) -> Self {
80        Self {
81            common: Common::new(source),
82            allowed_chars: allowed_chars.chars().collect(),
83        }
84    }
85
86    fn next(&mut self) -> Option<io::Result<usize>> {
87        let mut read_previously = self.common.unconsumed_bytes;
88        // keep looping until we get an error, fail to read any bytes, or have
89        // read a full buffer
90        loop {
91            let read_this_time = self
92                .common
93                .source
94                .read(&mut self.common.working_buf[read_previously..]);
95            // return the error if there is one
96            if let Err(e) = read_this_time {
97                return Some(Err(e));
98            }
99            let read_this_time = read_this_time.unwrap();
100            // if we didn't read anything, time to stop looping
101            if read_this_time < 1 {
102                return Some(Ok(read_previously));
103            }
104            assert!(read_this_time > 0);
105            // convert [u8, BUF_SIZE] (with length read_this_time) to String
106            let buf = String::from_utf8_lossy(
107                &self.common.working_buf[read_previously..read_previously + read_this_time],
108            )
109            .into_owned();
110            for c in buf.chars() {
111                if self.allowed_chars.contains(&c) {
112                    let _ = c.encode_utf8(&mut self.common.working_buf[read_previously..]);
113                    read_previously += c.len_utf8();
114                }
115            }
116        }
117    }
118}
119
120/// Removes comments from the wrapped stream.
121///
122/// A comment starts with `#` and ends with a newline `\n` or the end of file.
123///
124/// ```
125/// use readfilter::CommentStrip;
126/// use std::io::Read;
127/// let mut input = CommentStrip::new("a\nb# foo\nc#bar".as_bytes());
128/// let mut output = String::new();
129/// input.read_to_string(&mut output);
130/// assert_eq!(output, "a\nbc");
131/// ```
132pub struct CommentStrip<T>
133where
134    T: Read,
135{
136    common: Common<T>,
137    /// Sometimes reads from source will end with a comment that isn't finished
138    /// yet. This flag is used to keep track of whether or not we need to keep
139    /// ignoring bytes until the comment ends (i.e. we see a newline)
140    ignore_until_next_newline: bool,
141}
142
143impl<T> CommentStrip<T>
144where
145    T: Read,
146{
147    pub fn new(source: T) -> Self {
148        Self {
149            common: Common::new(source),
150            ignore_until_next_newline: false,
151        }
152    }
153
154    fn next(&mut self) -> Option<io::Result<usize>> {
155        let mut read_previously = self.common.unconsumed_bytes;
156        // keep looping until we get an error, fail to read any bytes, or have
157        // read a full buffer
158        loop {
159            let read_this_time = self
160                .common
161                .source
162                .read(&mut self.common.working_buf[read_previously..]);
163            // return the error if there is one
164            if let Err(e) = read_this_time {
165                return Some(Err(e));
166            }
167            let read_this_time = read_this_time.unwrap();
168            // if we didn't read anything, time to stop looping
169            if read_this_time < 1 {
170                return Some(Ok(read_previously));
171            }
172            assert!(read_this_time > 0);
173            // convert [u8, BUF_SIZE] (with length read_this_time) to String
174            let mut buf = String::from_utf8_lossy(
175                &self.common.working_buf[read_previously..read_previously + read_this_time],
176            )
177            .into_owned();
178            // if needed, ignore bytes up through a newline
179            buf = if self.ignore_until_next_newline {
180                match buf.find('\n') {
181                    // found a newline. ignore bytes up through it, keep bytes after it, and unset
182                    // ignore_until_next_newline
183                    Some(idx) => {
184                        self.ignore_until_next_newline = false;
185                        buf[idx + 1..].to_string()
186                    }
187                    // didn't find newline. ignore all bytes
188                    None => String::new(),
189                }
190            } else {
191                buf
192            };
193            // if buffer empty, end early
194            if buf.is_empty() {
195                return Some(Ok(read_previously));
196            }
197            // loop until no more comments in buf
198            loop {
199                // look for comment character. if found, ignore bytes after it until newline
200                let start_idx = buf.find('#');
201                if start_idx.is_none() {
202                    break;
203                }
204                let start_idx = start_idx.unwrap();
205                buf = match buf[start_idx..].find('\n') {
206                    // newline found. ignore bytes between comment char and newline
207                    Some(end_idx) => {
208                        String::from(&buf[..start_idx]) + &buf[start_idx + end_idx + 1..]
209                    }
210                    // no newline found. ignore all bytes after comment char and set flag to keep
211                    // ignoring on next loop
212                    None => {
213                        self.ignore_until_next_newline = true;
214                        buf[..start_idx].to_string()
215                    }
216                };
217            }
218            let remaining_len = buf.len();
219            if remaining_len > 0 {
220                self.common.working_buf[read_previously..read_previously + remaining_len]
221                    .copy_from_slice(buf.as_bytes());
222                read_previously += remaining_len;
223            }
224        }
225    }
226}
227
/// Each type impls Read. As all the hard work is done in self.next(), this can
/// be generalized. .read() is probably how the user should be using these types.
///
/// The type passed in must have a `common` field holding `working_buf` and
/// `unconsumed_bytes`, and a `next()` method that fills `working_buf` from the
/// front and returns how many valid bytes it now holds.
macro_rules! impl_read_trait_for_stream_iter {
    ($MyType:ty) => {
        impl<T> Read for $MyType
        where
            T: Read,
        {
            /// Copies filtered bytes into `out_buf` and returns how many were
            /// copied; 0 means end of stream (or an empty `out_buf`).
            ///
            /// Bytes that are already buffered are handed out before the
            /// source is touched again, possibly as a short read. That way an
            /// error from the source can never discard bytes that were
            /// already copied into `out_buf`.
            fn read(&mut self, out_buf: &mut [u8]) -> io::Result<usize> {
                if out_buf.is_empty() {
                    return Ok(0);
                }
                if self.common.unconsumed_bytes == 0 {
                    // Nothing buffered: refill the working buffer.
                    match self.next() {
                        Some(Ok(filled)) => self.common.unconsumed_bytes = filled,
                        Some(Err(e)) => return Err(e),
                        // We failed to iterate forward in the stream. Must be done.
                        None => return Ok(0),
                    }
                }
                // Give the caller as much as fits, then move whatever is left
                // to the front of the working buffer for the next call.
                let buffered = self.common.unconsumed_bytes;
                let give = out_buf.len().min(buffered);
                out_buf[..give].copy_from_slice(&self.common.working_buf[..give]);
                self.common.working_buf.copy_within(give..buffered, 0);
                self.common.unconsumed_bytes = buffered - give;
                Ok(give)
            }
        }
    };
}
305
306impl_read_trait_for_stream_iter!(CharWhitelist<T>);
307impl_read_trait_for_stream_iter!(CommentStrip<T>);
308
/// Test helper: drains `buf` to the end and returns what it produced as a
/// `String`. Panics if reading fails or the data is not valid UTF-8.
#[cfg(test)]
fn read_to_string(mut buf: impl Read) -> String {
    let mut collected = Vec::new();
    buf.read_to_end(&mut collected).unwrap();
    String::from_utf8(collected).unwrap()
}
315
/// End-to-end checks of `CommentStrip`, with inputs sized around the working
/// buffer's capacity.
#[cfg(test)]
mod test_comment_strip_iter {
    use super::{read_to_string, CommentStrip, BUF_SIZE};

    /// Runs `input` through a fresh `CommentStrip` and collects the output.
    fn strip(input: &[u8]) -> String {
        read_to_string(CommentStrip::new(input))
    }

    /// Builds `#` followed by `padding` spaces, optionally ended by a newline.
    fn padded_comment(padding: usize, newline: bool) -> Vec<u8> {
        let mut bytes = vec![b'#'];
        bytes.extend(std::iter::repeat(b' ').take(padding));
        if newline {
            bytes.push(b'\n');
        }
        bytes
    }

    /// Builds `hashes` `#` characters, a newline, and then `content`.
    fn hashes_then(hashes: usize, content: &str) -> Vec<u8> {
        let mut bytes = vec![b'#'; hashes];
        bytes.push(b'\n');
        bytes.extend_from_slice(content.as_bytes());
        bytes
    }

    #[test]
    fn empty_is_empty() {
        assert_eq!(strip("".as_bytes()).len(), 0);
    }

    #[test]
    fn ignore_all_short() {
        let inputs = ["#foo baz", "#foo baz\n", "#    ", "#    \n", "#", "#\n"];
        for input in inputs.iter() {
            assert_eq!(strip(input.as_bytes()).len(), 0);
        }
    }

    #[test]
    fn ignore_all_long_1() {
        // (spaces after the '#', resulting input length): just less than a
        // full buffer, exactly a full buffer, just over a full buffer, and
        // over 2 buffers in size
        let cases = [
            (BUF_SIZE - 2, BUF_SIZE - 1),
            (BUF_SIZE - 1, BUF_SIZE),
            (BUF_SIZE, BUF_SIZE + 1),
            (BUF_SIZE * 2 + 2, BUF_SIZE * 2 + 3),
        ];
        for &(padding, total) in cases.iter() {
            let input = padded_comment(padding, false);
            assert_eq!(input.len(), total);
            assert_eq!(strip(&input).len(), 0);
        }
    }

    #[test]
    fn ignore_all_long_2() {
        // same sizes as ignore_all_long_1, but the comment ends in a newline
        let cases = [
            (BUF_SIZE - 3, BUF_SIZE - 1),
            (BUF_SIZE - 2, BUF_SIZE),
            (BUF_SIZE - 1, BUF_SIZE + 1),
            (BUF_SIZE * 2 + 1, BUF_SIZE * 2 + 3),
        ];
        for &(padding, total) in cases.iter() {
            let input = padded_comment(padding, true);
            assert_eq!(input.len(), total);
            assert_eq!(strip(&input).len(), 0);
        }
    }

    #[test]
    fn keep_end_short() {
        // (inputs, what must remain once the leading comment is gone)
        let groups: [(&[&str], &str); 3] = [
            (&["#\nfoo", "#  \nfoo"], "foo"),
            (&["#\nfoo  foo", "#  \nfoo  foo"], "foo  foo"),
            (&["#\nfoo \n foo", "#  \nfoo \n foo"], "foo \n foo"),
        ];
        for &(inputs, expected) in groups.iter() {
            for input in inputs.iter() {
                assert_eq!(strip(input.as_bytes()), expected);
            }
        }
    }

    #[test]
    fn keep_end_long() {
        let content = " foo \n foo ";

        // (number of '#', resulting input length): just under BUF_SIZE, equal
        // to BUF_SIZE, just over BUF_SIZE, and a comment that is over BUF_SIZE
        // by itself
        let cases = [
            (BUF_SIZE - content.len() - 1 - 1, BUF_SIZE - 1),
            (BUF_SIZE - content.len() - 1, BUF_SIZE),
            (BUF_SIZE - content.len() - 1 + 1, BUF_SIZE + 1),
            (BUF_SIZE + 1, BUF_SIZE + 2 + content.len()),
        ];
        for &(hashes, total) in cases.iter() {
            let input = hashes_then(hashes, content);
            assert_eq!(input.len(), total);
            assert_eq!(strip(&input), content);
        }
    }
}
450
/// End-to-end checks of `CharWhitelist` with 1-, 2- and 4-byte characters.
#[cfg(test)]
mod test_char_whitelist_iter {
    use super::{read_to_string, CharWhitelist};

    /// Filters `input` through a `CharWhitelist` that allows `allowed`.
    fn filter(input: &str, allowed: &str) -> String {
        read_to_string(CharWhitelist::new(input.as_bytes(), allowed))
    }

    #[test]
    fn empty_whitelist() {
        assert_eq!(filter("A\u{00a1}\u{01d6a9}", "").len(), 0);
    }

    #[test]
    fn whitelist_allows_all() {
        let s = "A\u{00a1}\u{1d6a9}";
        assert_eq!(filter(s, s), s);
    }

    #[test]
    fn whitelist_allows_single() {
        let singles = ["A", "\u{00a1}", "\u{1d6a9}"];
        for allowed in singles.iter() {
            assert_eq!(filter("A\u{00a1}\u{1d6a9}", allowed), *allowed);
        }
    }
}
475
/// Checks of a single `read()` call on a `CommentStrip`.
#[cfg(test)]
mod test_comment_strip_iter_read {
    use super::{CommentStrip, BUF_SIZE};
    use std::io::Read;

    /// Does one `read()` of `input` into a `BUF_SIZE` buffer and returns the
    /// bytes that were delivered, as text.
    fn first_read_as_text(input: &str) -> String {
        let mut out_buf = [0; BUF_SIZE];
        let mut csi = CommentStrip::new(input.as_bytes());
        let len = csi.read(&mut out_buf).unwrap();
        String::from_utf8_lossy(&out_buf[..len]).into_owned()
    }

    #[test]
    fn just_comment_returns_empty() {
        let mut one_byte = [0; 1];
        let mut csi = CommentStrip::new("# foo \n".as_bytes());
        assert_eq!(csi.read(&mut one_byte).unwrap(), 0);
    }

    #[test]
    fn just_byte_before_comment() {
        assert_eq!(first_read_as_text("a# foo \n"), "a");
    }

    #[test]
    fn just_byte_after_comment() {
        assert_eq!(first_read_as_text("# foo \na"), "a");
    }

    #[test]
    fn just_byte_before_and_after_comment() {
        assert_eq!(first_read_as_text("a# foo \nB"), "aB");
    }
}
517
/// Generates a test module named `$mod_name` that checks the generic `read()`
/// behaviour of `$MyType`. The inputs contain nothing a filter would remove,
/// so the output must equal the input. `$MyType::new` must accept the source
/// as its only argument.
macro_rules! impl_tests_for_common_read {
    ($mod_name:ident, $MyType:ident) => {
        #[cfg(test)]
        mod $mod_name {
            use super::{$MyType, BUF_SIZE};
            use std::io::Read;

            /// One run of `BUF_SIZE / 2` copies of each letter, in order.
            fn half_buffer_runs(letters: &[u8]) -> Vec<u8> {
                letters
                    .iter()
                    .flat_map(|&letter| std::iter::repeat(letter).take(BUF_SIZE / 2))
                    .collect()
            }

            /// Reads through `out_buf` until a read delivers nothing and
            /// returns everything that was delivered, as text.
            fn read_until_eof<R: Read>(mut reader: R, out_buf: &mut [u8]) -> String {
                let mut acc = String::new();
                loop {
                    let len = reader.read(out_buf).unwrap();
                    acc += &String::from_utf8_lossy(&out_buf[..len]);
                    if len < 1 {
                        break;
                    }
                }
                acc
            }

            #[test]
            fn empty() {
                let in_buf: Vec<u8> = Vec::new();
                let mut out_buf = [0; 1];
                let mut csi = $MyType::new(&in_buf[..]);
                assert_eq!(csi.read(&mut out_buf).unwrap(), 0);
            }

            #[test]
            fn many_tiny_reads() {
                let in_buf = "abc123".as_bytes();
                let mut out_buf = [0; 1];
                let mut csi = $MyType::new(&in_buf[..]);
                let mut acc = String::new();
                for _ in 0..in_buf.len() {
                    let len = csi.read(&mut out_buf).unwrap();
                    acc += &String::from_utf8_lossy(&out_buf[..len]);
                }
                assert_eq!(acc, "abc123");
                // future reads should be length zero
                assert_eq!(csi.read(&mut out_buf).unwrap(), 0);
            }

            #[test]
            fn big_inbuf_tiny_outbuf() {
                let in_buf = half_buffer_runs(b"abcde");
                let mut out_buf = [0; 2];
                let acc = read_until_eof($MyType::new(&in_buf[..]), &mut out_buf);
                assert_eq!(acc, String::from_utf8_lossy(&in_buf[..]));
            }

            #[test]
            fn big_inbuf_just_smaller_outbuf() {
                let in_buf = half_buffer_runs(b"abcd");
                assert_eq!(in_buf.len(), BUF_SIZE * 2);
                let mut out_buf = [0; BUF_SIZE * 2 - 1];
                let acc = read_until_eof($MyType::new(&in_buf[..]), &mut out_buf);
                assert_eq!(acc, String::from_utf8_lossy(&in_buf[..]));
            }

            #[test]
            fn big_inbuf_just_larger_outbuf() {
                let in_buf = half_buffer_runs(b"abcd");
                assert_eq!(in_buf.len(), BUF_SIZE * 2);
                let mut out_buf = [0; BUF_SIZE * 2 + 1];
                let acc = read_until_eof($MyType::new(&in_buf[..]), &mut out_buf);
                assert_eq!(acc, String::from_utf8_lossy(&in_buf[..]));
            }
        }
    };
}
611
612impl_tests_for_common_read!(test_read_common_with_comment_strip_iter, CommentStrip);
613// can't test because its .new() requires more than just the source
614//impl_tests_for_common_read!(test_read_common_with_char_whitelist_iter, CharWhitelist);