replacing_buf_stream/
lib.rs

1use bytes::{Buf, Bytes, BytesMut};
2use futures::Async;
3use tower_web::util::BufStream;
4
5pub struct ReplacingBufStream<Stream: BufStream<Item = std::io::Cursor<Bytes>>> {
6    stream: Stream,
7    to_be_replaced: Bytes,
8    replacement: Bytes,
9    bufs: Vec<std::io::Cursor<Bytes>>,
10}
11
12impl<Stream: BufStream<Item = std::io::Cursor<Bytes>>> ReplacingBufStream<Stream> {
13    pub fn new(
14        inner: Stream,
15        to_be_replaced: Bytes,
16        replacement: Bytes,
17    ) -> ReplacingBufStream<Stream> {
18        ReplacingBufStream {
19            stream: inner,
20            to_be_replaced,
21            replacement,
22            bufs: Vec::new(),
23        }
24    }
25}
26
27impl<Stream: BufStream<Item = std::io::Cursor<Bytes>>> BufStream for ReplacingBufStream<Stream>
28where
29    <Stream as tower_web::util::BufStream>::Error: std::fmt::Debug,
30{
31    type Item = std::io::Cursor<bytes::Bytes>;
32    type Error = Stream::Error;
33
34    fn poll(&mut self) -> futures::Poll<Option<Self::Item>, Self::Error> {
35        match non_matching_prefix_length(&self.bufs, &self.to_be_replaced.as_ref()) {
36            Action::Return(size) => {
37                if self.bufs[0].bytes().len() == size {
38                    Ok(Async::Ready(Some(self.bufs.remove(0))))
39                } else {
40                    let buf = self.bufs.remove(0);
41                    let mut ret = buf.into_inner();
42                    self.bufs
43                        .insert(0, std::io::Cursor::new(ret.split_off(size)));
44                    Ok(Async::Ready(Some(std::io::Cursor::new(ret))))
45                }
46            }
47            Action::Replace => {
48                let mut len = 0;
49                let mut buf = None;
50                while len < self.to_be_replaced.len() {
51                    let buf2 = self.bufs.remove(0);
52                    len += buf2.bytes().len();
53                    buf = Some(buf2)
54                }
55                let buf = buf.unwrap();
56                if len > self.to_be_replaced.len() {
57                    let mut inner = buf.into_inner();
58                    self.bufs.insert(
59                        0,
60                        std::io::Cursor::new(
61                            inner.split_off(inner.len() + self.to_be_replaced.len() - len),
62                        ),
63                    );
64                }
65                Ok(Async::Ready(Some(std::io::Cursor::new(
66                    self.replacement.clone(),
67                ))))
68            }
69            Action::Read => match self.stream.poll() {
70                Ok(Async::Ready(Some(buf))) => {
71                    self.bufs.push(buf);
72                    self.poll()
73                }
74                Ok(Async::Ready(None)) => {
75                    if !self.bufs.is_empty() {
76                        return Ok(Async::Ready(Some(self.bufs.remove(0))));
77                    }
78                    Ok(Async::Ready(None))
79                }
80                Ok(Async::NotReady) => Ok(Async::NotReady),
81                Err(err) => Err(err),
82            },
83        }
84    }
85}
86
/// What the replacing stream should do next with its buffered chunks.
#[derive(Debug, PartialEq, Eq)]
enum Action {
    /// Emit this many bytes from the front of the first buffered chunk.
    Return(usize),
    /// The buffered data starts with the needle: emit the replacement.
    Replace,
    /// More input is required before a decision can be made.
    Read,
}
93
94pub struct FreezingBufStream<Stream: BufStream<Item = std::io::Cursor<BytesMut>>>(pub Stream);
95
96impl<Err, Stream: BufStream<Item = std::io::Cursor<BytesMut>, Error = Err>> BufStream
97    for FreezingBufStream<Stream>
98{
99    type Item = std::io::Cursor<Bytes>;
100    type Error = Err;
101
102    fn poll(&mut self) -> Result<Async<Option<Self::Item>>, Self::Error> {
103        match self.0.poll() {
104            Ok(Async::Ready(Some(bytes))) => {
105                let position = bytes.position() as usize;
106                let bytes = bytes.into_inner();
107                Ok(Async::Ready(Some(std::io::Cursor::new(
108                    bytes.freeze().split_off(position),
109                ))))
110            }
111            Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
112            Ok(Async::NotReady) => Ok(Async::NotReady),
113            Err(err) => Err(err),
114        }
115    }
116}
117
118fn non_matching_prefix_length<B: Buf>(bufs: &[B], prefix: &[u8]) -> Action {
119    if bufs.is_empty() {
120        return Action::Read;
121    }
122    if bufs[0].bytes().len() < prefix.len() {
123        if prefix.starts_with(&bufs[0].bytes()) {
124            return match non_matching_prefix_length(&bufs[1..], &prefix[bufs[0].bytes().len()..]) {
125                Action::Return(_) => Action::Return(bufs[0].bytes().len()),
126                Action::Replace => Action::Replace,
127                Action::Read => Action::Read,
128            };
129        }
130        let mut ret = 0;
131        while !prefix.starts_with(&bufs[0].bytes()[ret..]) && ret < bufs[0].bytes().len() {
132            ret += 1;
133        }
134        return Action::Return(ret);
135    } else {
136        if bufs[0].bytes().starts_with(prefix) {
137            return Action::Replace;
138        }
139        let mut ret = 0;
140        while !bufs[0].bytes()[ret..]
141            .starts_with(&prefix[..std::cmp::min(prefix.len(), bufs[0].bytes().len() - ret)])
142        {
143            ret += 1;
144        }
145        return Action::Return(ret);
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::{non_matching_prefix_length, Action, FreezingBufStream, ReplacingBufStream};
    use bytes::{Bytes, BytesMut};
    use futures::Future;
    use std::path::PathBuf;
    use tokio::runtime::Runtime;
    use tower_web::util::BufStream;

    /// Needle searched for by the `non_matching_prefix_length` unit tests.
    const PREFIX: &[u8] = b"matches";

    /// Wraps every string in its own cursor, one chunk per string.
    fn make_bufs(strs: &[&str]) -> Vec<std::io::Cursor<BytesMut>> {
        strs.iter()
            .map(|s| std::io::Cursor::new(BytesMut::from(s.as_bytes())))
            .collect()
    }

    /// Runs `non_matching_prefix_length` over `chunks` with `PREFIX` as needle.
    fn action_for(chunks: &[&str]) -> Action {
        non_matching_prefix_length(&make_bufs(chunks), PREFIX)
    }

    /// Streams `input` through a replacer that turns "bar" into "blammo".
    fn bar_to_blammo(input: &'static str) -> Vec<u8> {
        ReplacingBufStream::new(
            Bytes::from(input),
            Bytes::from("bar"),
            Bytes::from("blammo"),
        )
        .collect::<Vec<u8>>()
        .wait()
        .unwrap()
    }

    #[test]
    fn no_bufs() {
        assert_eq!(Action::Read, action_for(&[]));
    }

    #[test]
    fn one_buf_is_prefix() {
        assert_eq!(Action::Replace, action_for(&["matches"]));
    }

    #[test]
    fn one_buf_is_prefix_start() {
        assert_eq!(Action::Read, action_for(&["mat"]));
    }

    #[test]
    fn one_buf_is_not_prefix_shorter() {
        assert_eq!(Action::Return(4), action_for(&["nope"]));
    }

    #[test]
    fn one_buf_is_not_prefix_longer() {
        assert_eq!(Action::Return(12), action_for(&["nopenopenope"]));
    }

    #[test]
    fn one_buf_starts_with_prefix() {
        assert_eq!(Action::Replace, action_for(&["matchesyesyesyes"]));
    }

    #[test]
    fn one_buf_contains_prefix() {
        assert_eq!(Action::Return(3), action_for(&["yesmatchesyesyes"]));
    }

    #[test]
    fn one_buf_same_length_ends_prefix() {
        assert_eq!(Action::Return(3), action_for(&["yesmatc"]));
    }

    #[test]
    fn two_buf_exact_match() {
        assert_eq!(Action::Replace, action_for(&["mat", "ches"]));
    }

    #[test]
    fn two_buf_with_suffix() {
        assert_eq!(Action::Replace, action_for(&["mat", "chesyes"]));
    }

    #[test]
    fn two_buf_with_prefix() {
        assert_eq!(Action::Return(3), action_for(&["yesmat", "ches"]));
    }

    #[test]
    fn two_buf_wrong_ending() {
        assert_eq!(Action::Return(3), action_for(&["mat", "chnope"]));
    }

    #[test]
    fn two_buf_prefix_and_wrong_ending() {
        assert_eq!(Action::Return(2), action_for(&["nomat", "chnope"]));
    }

    #[test]
    fn three_bufs() {
        assert_eq!(Action::Replace, action_for(&["mat", "ch", "es"]));
    }

    #[test]
    fn stream() {
        let res = bar_to_blammo("foobarbaz");
        assert_eq!(res, "fooblammobaz".as_bytes().to_vec())
    }

    #[test]
    fn stream_trailing_prefix() {
        let res = bar_to_blammo("foobarba");
        assert_eq!(res, "fooblammoba".as_bytes().to_vec())
    }

    #[test]
    fn large_file() {
        let needle = "https://doc.rust-lang.org/nightly";
        let replacement = "http://127.0.0.1:8080";

        let path: PathBuf = [env!("CARGO_MANIFEST_DIR"), "testdata", "struct.Group.html"]
            .iter()
            .collect();

        // Reference result computed with the standard library.
        let want = std::fs::read_to_string(&path)
            .unwrap()
            .replace(needle, replacement)
            .into_bytes();

        let file = tokio_fs::File::from_std(std::fs::File::open(&path).unwrap());
        let stream = ReplacingBufStream::new(
            FreezingBufStream(file),
            Bytes::from(needle),
            Bytes::from(replacement),
        );
        let res = Runtime::new()
            .unwrap()
            .block_on(stream.collect::<Vec<u8>>())
            .unwrap();
        assert_eq!(want, res)
    }
}