1use futures::{prelude::*, ready};
22use libp2p_core::{InboundUpgrade, OutboundUpgrade, UpgradeInfo};
23use std::{io, iter, pin::Pin, task::Context, task::Poll};
24
25#[derive(Debug, Copy, Clone)]
26pub struct DeflateConfig {
27 compression: flate2::Compression,
28}
29
30impl Default for DeflateConfig {
31 fn default() -> Self {
32 DeflateConfig {
33 compression: flate2::Compression::fast(),
34 }
35 }
36}
37
38impl UpgradeInfo for DeflateConfig {
39 type Info = &'static [u8];
40 type InfoIter = iter::Once<Self::Info>;
41
42 fn protocol_info(&self) -> Self::InfoIter {
43 iter::once(b"/deflate/1.0.0")
44 }
45}
46
47impl<C> InboundUpgrade<C> for DeflateConfig
48where
49 C: AsyncRead + AsyncWrite,
50{
51 type Output = DeflateOutput<C>;
52 type Error = io::Error;
53 type Future = future::Ready<Result<Self::Output, Self::Error>>;
54
55 fn upgrade_inbound(self, r: C, _: Self::Info) -> Self::Future {
56 future::ok(DeflateOutput::new(r, self.compression))
57 }
58}
59
60impl<C> OutboundUpgrade<C> for DeflateConfig
61where
62 C: AsyncRead + AsyncWrite,
63{
64 type Output = DeflateOutput<C>;
65 type Error = io::Error;
66 type Future = future::Ready<Result<Self::Output, Self::Error>>;
67
68 fn upgrade_outbound(self, w: C, _: Self::Info) -> Self::Future {
69 future::ok(DeflateOutput::new(w, self.compression))
70 }
71}
72
73#[derive(Debug)]
75pub struct DeflateOutput<S> {
76 inner: S,
78 compress: flate2::Compress,
80 decompress: flate2::Decompress,
82 write_out: Vec<u8>,
85 read_interm: Vec<u8>,
88 inner_read_eof: bool,
91}
92
93impl<S> DeflateOutput<S> {
94 fn new(inner: S, compression: flate2::Compression) -> Self {
95 DeflateOutput {
96 inner,
97 compress: flate2::Compress::new(compression, false),
98 decompress: flate2::Decompress::new(false),
99 write_out: Vec::with_capacity(256),
100 read_interm: Vec::with_capacity(256),
101 inner_read_eof: false,
102 }
103 }
104
105 fn flush_write_out(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>
108 where S: AsyncWrite + Unpin
109 {
110 loop {
111 if self.write_out.is_empty() {
112 return Poll::Ready(Ok(()))
113 }
114
115 match AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, &self.write_out) {
116 Poll::Ready(Ok(0)) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())),
117 Poll::Ready(Ok(n)) => self.write_out = self.write_out.split_off(n),
118 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
119 Poll::Pending => return Poll::Pending,
120 };
121 }
122 }
123}
124
125impl<S> AsyncRead for DeflateOutput<S>
126 where S: AsyncRead + Unpin
127{
128 fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize, io::Error>> {
129 let this = &mut *self;
132
133 loop {
134 if this.read_interm.is_empty() && !this.inner_read_eof {
136 this.read_interm.resize(this.read_interm.capacity() + 256, 0);
137
138 match AsyncRead::poll_read(Pin::new(&mut this.inner), cx, &mut this.read_interm) {
139 Poll::Ready(Ok(0)) => {
140 this.inner_read_eof = true;
141 this.read_interm.clear();
142 }
143 Poll::Ready(Ok(n)) => {
144 this.read_interm.truncate(n)
145 },
146 Poll::Ready(Err(err)) => {
147 this.read_interm.clear();
148 return Poll::Ready(Err(err))
149 },
150 Poll::Pending => {
151 this.read_interm.clear();
152 return Poll::Pending
153 },
154 }
155 }
156 debug_assert!(!this.read_interm.is_empty() || this.inner_read_eof);
157
158 let before_out = this.decompress.total_out();
159 let before_in = this.decompress.total_in();
160 let ret = this.decompress.decompress(&this.read_interm, buf, if this.inner_read_eof { flate2::FlushDecompress::Finish } else { flate2::FlushDecompress::None })?;
161
162 let consumed = (this.decompress.total_in() - before_in) as usize;
164 this.read_interm = this.read_interm.split_off(consumed);
165
166 let read = (this.decompress.total_out() - before_out) as usize;
167 if read != 0 || ret == flate2::Status::StreamEnd {
168 return Poll::Ready(Ok(read))
169 }
170 }
171 }
172}
173
174impl<S> AsyncWrite for DeflateOutput<S>
175 where S: AsyncWrite + Unpin
176{
177 fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8])
178 -> Poll<Result<usize, io::Error>>
179 {
180 let this = &mut *self;
183
184 ready!(this.flush_write_out(cx))?;
187
188 if buf.is_empty() {
190 return Poll::Ready(Ok(0));
191 }
192
193 loop {
197 let before_in = this.compress.total_in();
198 this.write_out.reserve(256); let ret = this.compress.compress_vec(buf, &mut this.write_out, flate2::FlushCompress::None)?;
200 let written = (this.compress.total_in() - before_in) as usize;
201
202 if written != 0 || ret == flate2::Status::StreamEnd {
203 return Poll::Ready(Ok(written));
204 }
205 }
206 }
207
208 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
209 let this = &mut *self;
212
213 ready!(this.flush_write_out(cx))?;
214 this.compress.compress_vec(&[], &mut this.write_out, flate2::FlushCompress::Sync)?;
215
216 loop {
217 ready!(this.flush_write_out(cx))?;
218
219 debug_assert!(this.write_out.is_empty());
220 this.write_out.reserve(256); this.compress.compress_vec(&[], &mut this.write_out, flate2::FlushCompress::None)?;
223 if this.write_out.is_empty() {
224 break;
225 }
226 }
227
228 AsyncWrite::poll_flush(Pin::new(&mut this.inner), cx)
229 }
230
231 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
232 let this = &mut *self;
235
236 loop {
237 ready!(this.flush_write_out(cx))?;
238
239 debug_assert!(this.write_out.is_empty());
241 this.write_out.reserve(256); this.compress.compress_vec(&[], &mut this.write_out, flate2::FlushCompress::Finish)?;
243 if this.write_out.is_empty() {
244 break;
245 }
246 }
247
248 AsyncWrite::poll_close(Pin::new(&mut this.inner), cx)
249 }
250}