1use bytes::{Buf, Bytes, BytesMut};
2use futures::Async;
3use tower_web::util::BufStream;
4
5pub struct ReplacingBufStream<Stream: BufStream<Item = std::io::Cursor<Bytes>>> {
6 stream: Stream,
7 to_be_replaced: Bytes,
8 replacement: Bytes,
9 bufs: Vec<std::io::Cursor<Bytes>>,
10}
11
12impl<Stream: BufStream<Item = std::io::Cursor<Bytes>>> ReplacingBufStream<Stream> {
13 pub fn new(
14 inner: Stream,
15 to_be_replaced: Bytes,
16 replacement: Bytes,
17 ) -> ReplacingBufStream<Stream> {
18 ReplacingBufStream {
19 stream: inner,
20 to_be_replaced,
21 replacement,
22 bufs: Vec::new(),
23 }
24 }
25}
26
27impl<Stream: BufStream<Item = std::io::Cursor<Bytes>>> BufStream for ReplacingBufStream<Stream>
28where
29 <Stream as tower_web::util::BufStream>::Error: std::fmt::Debug,
30{
31 type Item = std::io::Cursor<bytes::Bytes>;
32 type Error = Stream::Error;
33
34 fn poll(&mut self) -> futures::Poll<Option<Self::Item>, Self::Error> {
35 match non_matching_prefix_length(&self.bufs, &self.to_be_replaced.as_ref()) {
36 Action::Return(size) => {
37 if self.bufs[0].bytes().len() == size {
38 Ok(Async::Ready(Some(self.bufs.remove(0))))
39 } else {
40 let buf = self.bufs.remove(0);
41 let mut ret = buf.into_inner();
42 self.bufs
43 .insert(0, std::io::Cursor::new(ret.split_off(size)));
44 Ok(Async::Ready(Some(std::io::Cursor::new(ret))))
45 }
46 }
47 Action::Replace => {
48 let mut len = 0;
49 let mut buf = None;
50 while len < self.to_be_replaced.len() {
51 let buf2 = self.bufs.remove(0);
52 len += buf2.bytes().len();
53 buf = Some(buf2)
54 }
55 let buf = buf.unwrap();
56 if len > self.to_be_replaced.len() {
57 let mut inner = buf.into_inner();
58 self.bufs.insert(
59 0,
60 std::io::Cursor::new(
61 inner.split_off(inner.len() + self.to_be_replaced.len() - len),
62 ),
63 );
64 }
65 Ok(Async::Ready(Some(std::io::Cursor::new(
66 self.replacement.clone(),
67 ))))
68 }
69 Action::Read => match self.stream.poll() {
70 Ok(Async::Ready(Some(buf))) => {
71 self.bufs.push(buf);
72 self.poll()
73 }
74 Ok(Async::Ready(None)) => {
75 if !self.bufs.is_empty() {
76 return Ok(Async::Ready(Some(self.bufs.remove(0))));
77 }
78 Ok(Async::Ready(None))
79 }
80 Ok(Async::NotReady) => Ok(Async::NotReady),
81 Err(err) => Err(err),
82 },
83 }
84 }
85}
86
/// The next step `ReplacingBufStream::poll` takes, as chosen by
/// `non_matching_prefix_length`.
#[derive(PartialEq, Eq, Debug)]
enum Action {
    /// Emit this many leading bytes of the first buffered chunk unchanged.
    Return(usize),
    /// The buffered bytes begin with the pattern: emit the replacement.
    Replace,
    /// Too little is buffered to decide; read another chunk first.
    Read,
}
93
94pub struct FreezingBufStream<Stream: BufStream<Item = std::io::Cursor<BytesMut>>>(pub Stream);
95
96impl<Err, Stream: BufStream<Item = std::io::Cursor<BytesMut>, Error = Err>> BufStream
97 for FreezingBufStream<Stream>
98{
99 type Item = std::io::Cursor<Bytes>;
100 type Error = Err;
101
102 fn poll(&mut self) -> Result<Async<Option<Self::Item>>, Self::Error> {
103 match self.0.poll() {
104 Ok(Async::Ready(Some(bytes))) => {
105 let position = bytes.position() as usize;
106 let bytes = bytes.into_inner();
107 Ok(Async::Ready(Some(std::io::Cursor::new(
108 bytes.freeze().split_off(position),
109 ))))
110 }
111 Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
112 Ok(Async::NotReady) => Ok(Async::NotReady),
113 Err(err) => Err(err),
114 }
115 }
116}
117
118fn non_matching_prefix_length<B: Buf>(bufs: &[B], prefix: &[u8]) -> Action {
119 if bufs.is_empty() {
120 return Action::Read;
121 }
122 if bufs[0].bytes().len() < prefix.len() {
123 if prefix.starts_with(&bufs[0].bytes()) {
124 return match non_matching_prefix_length(&bufs[1..], &prefix[bufs[0].bytes().len()..]) {
125 Action::Return(_) => Action::Return(bufs[0].bytes().len()),
126 Action::Replace => Action::Replace,
127 Action::Read => Action::Read,
128 };
129 }
130 let mut ret = 0;
131 while !prefix.starts_with(&bufs[0].bytes()[ret..]) && ret < bufs[0].bytes().len() {
132 ret += 1;
133 }
134 return Action::Return(ret);
135 } else {
136 if bufs[0].bytes().starts_with(prefix) {
137 return Action::Replace;
138 }
139 let mut ret = 0;
140 while !bufs[0].bytes()[ret..]
141 .starts_with(&prefix[..std::cmp::min(prefix.len(), bufs[0].bytes().len() - ret)])
142 {
143 ret += 1;
144 }
145 return Action::Return(ret);
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::{non_matching_prefix_length, Action, FreezingBufStream, ReplacingBufStream};
    use bytes::{Bytes, BytesMut};
    use futures::Future;
    use std::path::PathBuf;
    use tokio::runtime::Runtime;
    use tower_web::util::BufStream;

    /// Pattern searched for by the `non_matching_prefix_length` cases below.
    const PREFIX: &[u8] = b"matches";

    /// Asserts that `non_matching_prefix_length` chooses `expected` for the
    /// given chunks (in stream order) and `PREFIX`.
    fn assert_action(chunks: &[&str], expected: Action) {
        let bufs = make_bufs(chunks);
        assert_eq!(expected, non_matching_prefix_length(&bufs, PREFIX));
    }

    #[test]
    fn no_bufs() {
        assert_action(&[], Action::Read);
    }

    #[test]
    fn one_buf_is_prefix() {
        assert_action(&["matches"], Action::Replace);
    }

    #[test]
    fn one_buf_is_prefix_start() {
        assert_action(&["mat"], Action::Read);
    }

    #[test]
    fn one_buf_is_not_prefix_shorter() {
        assert_action(&["nope"], Action::Return(4));
    }

    #[test]
    fn one_buf_is_not_prefix_longer() {
        assert_action(&["nopenopenope"], Action::Return(12));
    }

    #[test]
    fn one_buf_starts_with_prefix() {
        assert_action(&["matchesyesyesyes"], Action::Replace);
    }

    #[test]
    fn one_buf_contains_prefix() {
        assert_action(&["yesmatchesyesyes"], Action::Return(3));
    }

    #[test]
    fn one_buf_same_length_ends_prefix() {
        assert_action(&["yesmatc"], Action::Return(3));
    }

    #[test]
    fn two_buf_exact_match() {
        assert_action(&["mat", "ches"], Action::Replace);
    }

    #[test]
    fn two_buf_with_suffix() {
        assert_action(&["mat", "chesyes"], Action::Replace);
    }

    #[test]
    fn two_buf_with_prefix() {
        assert_action(&["yesmat", "ches"], Action::Return(3));
    }

    #[test]
    fn two_buf_wrong_ending() {
        assert_action(&["mat", "chnope"], Action::Return(3));
    }

    #[test]
    fn two_buf_prefix_and_wrong_ending() {
        assert_action(&["nomat", "chnope"], Action::Return(2));
    }

    #[test]
    fn three_bufs() {
        assert_action(&["mat", "ch", "es"], Action::Replace);
    }

    /// Feeds `input` as a single chunk through `ReplacingBufStream` and
    /// collects everything it yields.
    fn replace_in(input: &str, needle: &str, replacement: &str) -> Vec<u8> {
        let replaced = ReplacingBufStream::new(
            Bytes::from(input),
            Bytes::from(needle),
            Bytes::from(replacement),
        );
        replaced.collect::<Vec<u8>>().wait().unwrap()
    }

    #[test]
    fn stream() {
        assert_eq!(
            replace_in("foobarbaz", "bar", "blammo"),
            b"fooblammobaz".to_vec()
        );
    }

    #[test]
    fn stream_trailing_prefix() {
        assert_eq!(
            replace_in("foobarba", "bar", "blammo"),
            b"fooblammoba".to_vec()
        );
    }

    #[test]
    fn large_file() {
        let needle = "https://doc.rust-lang.org/nightly";
        let replacement = "http://127.0.0.1:8080";

        let path: PathBuf = [env!("CARGO_MANIFEST_DIR"), "testdata", "struct.Group.html"]
            .iter()
            .collect();

        let expected = std::fs::read_to_string(&path)
            .unwrap()
            .replace(needle, replacement)
            .into_bytes();

        let file = tokio_fs::File::from_std(std::fs::File::open(&path).unwrap());
        let replaced = ReplacingBufStream::new(
            FreezingBufStream(file),
            Bytes::from(needle),
            Bytes::from(replacement),
        );
        let actual = Runtime::new()
            .unwrap()
            .block_on(replaced.collect::<Vec<u8>>())
            .unwrap();
        assert_eq!(expected, actual)
    }

    /// Wraps each string in its own `Cursor<BytesMut>`, one chunk per string.
    fn make_bufs(strs: &[&str]) -> Vec<std::io::Cursor<BytesMut>> {
        strs.iter()
            .map(|s| std::io::Cursor::new(BytesMut::from(s.as_bytes())))
            .collect()
    }
}