1use crate::{ByteCount, Decode, Encode, Eos, Error, ErrorKind, Result};
3#[cfg(feature = "tokio-async")]
4use pin_project::pin_project;
5use std::cmp;
6use std::io::{self, Read, Write};
7
8pub trait IoDecodeExt: Decode {
10 fn decode_from_read_buf<B>(&mut self, buf: &mut ReadBuf<B>) -> Result<()>
12 where
13 B: AsRef<[u8]>,
14 {
15 let eos = Eos::new(buf.stream_state.is_eos());
16 let size = track!(self.decode(&buf.inner.as_ref()[buf.head..buf.tail], eos))?;
17 buf.head += size;
18 if buf.head == buf.tail {
19 buf.head = 0;
20 buf.tail = 0;
21 }
22 Ok(())
23 }
24
25 fn decode_exact<R: Read>(&mut self, mut reader: R) -> Result<Self::Item> {
31 let mut buf = [0; 1024];
32 loop {
33 let mut size = match self.requiring_bytes() {
34 ByteCount::Finite(n) => cmp::min(n, buf.len() as u64) as usize,
35 ByteCount::Infinite => buf.len(),
36 ByteCount::Unknown => 1,
37 };
38 let eos = if size != 0 {
39 size = track!(reader.read(&mut buf[..size]).map_err(Error::from))?;
40 Eos::new(size == 0)
41 } else {
42 Eos::new(false)
43 };
44
45 let consumed = track!(self.decode(&buf[..size], eos))?;
46 track_assert_eq!(consumed, size, ErrorKind::InconsistentState; self.is_idle(), eos);
47 if self.is_idle() {
48 let item = track!(self.finish_decoding())?;
49 return Ok(item);
50 }
51 }
52 }
53}
54impl<T: Decode> IoDecodeExt for T {}
55
56pub trait IoEncodeExt: Encode {
58 fn encode_to_write_buf<B>(&mut self, buf: &mut WriteBuf<B>) -> Result<()>
61 where
62 B: AsMut<[u8]>,
63 {
64 let eos = Eos::new(buf.stream_state.is_eos());
65 let size = track!(self.encode(&mut buf.inner.as_mut()[buf.tail..], eos))?;
66 buf.tail += size;
67 Ok(())
68 }
69
70 #[cfg(feature = "tokio-async")]
76 fn encode_to_write_buf_async<B>(
77 &mut self,
78 buf: &mut WriteBuf<B>,
79 cx: &mut std::task::Context,
80 ) -> Result<()>
81 where
82 B: AsMut<[u8]>,
83 {
84 let eos = Eos::new(buf.stream_state.is_eos());
85 let size = track!(self.encode(&mut buf.inner.as_mut()[buf.tail..], eos))?;
86 buf.tail += size;
87 buf.waker = Some(cx.waker().clone());
88 Ok(())
89 }
90
91 fn encode_all<W: Write>(&mut self, mut writer: W) -> Result<()> {
96 let mut buf = [0; 1024];
97 while !self.is_idle() {
98 let size = track!(self.encode(&mut buf[..], Eos::new(false)))?;
99 track!(writer.write_all(&buf[..size]).map_err(Error::from))?;
100 if !self.is_idle() {
101 track_assert_ne!(size, 0, ErrorKind::Other);
102 }
103 }
104 Ok(())
105 }
106}
107impl<T: Encode> IoEncodeExt for T {}
108
/// State of an I/O stream as recorded by the buffers in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub enum StreamState {
    /// The most recent I/O operation transferred at least one byte.
    Normal,
    /// The most recent I/O operation transferred zero bytes.
    Eos,
    /// The most recent I/O operation failed with `io::ErrorKind::WouldBlock`.
    WouldBlock,
    /// The most recent I/O operation failed with any other error.
    Error,
}
impl StreamState {
    /// Returns `true` if the state is `StreamState::Normal`.
    pub fn is_normal(self) -> bool {
        matches!(self, StreamState::Normal)
    }

    /// Returns `true` if the state is `StreamState::Error`.
    pub fn is_error(self) -> bool {
        matches!(self, StreamState::Error)
    }

    /// Returns `true` if the state is `StreamState::Eos`.
    pub fn is_eos(self) -> bool {
        matches!(self, StreamState::Eos)
    }

    /// Returns `true` if the state is `StreamState::WouldBlock`.
    pub fn would_block(self) -> bool {
        matches!(self, StreamState::WouldBlock)
    }
}
139
140#[derive(Debug)]
142pub struct ReadBuf<B> {
143 pub(crate) inner: B,
144 pub(crate) head: usize,
145 pub(crate) tail: usize,
146 pub(crate) stream_state: StreamState,
147}
148impl<B: AsRef<[u8]> + AsMut<[u8]>> ReadBuf<B> {
149 pub fn new(inner: B) -> Self {
151 ReadBuf {
152 inner,
153 head: 0,
154 tail: 0,
155 stream_state: StreamState::Normal,
156 }
157 }
158
159 pub fn len(&self) -> usize {
161 self.tail - self.head
162 }
163
164 pub fn room(&self) -> usize {
168 self.inner.as_ref().len() - self.tail
169 }
170
171 pub fn capacity(&self) -> usize {
173 self.inner.as_ref().len()
174 }
175
176 pub fn is_empty(&self) -> bool {
178 self.tail == 0
179 }
180
181 pub fn is_full(&self) -> bool {
183 self.tail == self.inner.as_ref().len()
184 }
185
186 pub fn stream_state(&self) -> StreamState {
188 self.stream_state
189 }
190
191 pub fn stream_state_mut(&mut self) -> &mut StreamState {
193 &mut self.stream_state
194 }
195
196 pub fn fill<R: Read>(&mut self, mut reader: R) -> Result<()> {
203 while !self.is_full() {
204 match reader.read(&mut self.inner.as_mut()[self.tail..]) {
205 Err(e) => {
206 if e.kind() == io::ErrorKind::WouldBlock {
207 self.stream_state = StreamState::WouldBlock;
208 break;
209 } else {
210 self.stream_state = StreamState::Error;
211 return Err(track!(Error::from(e)));
212 }
213 }
214 Ok(0) => {
215 self.stream_state = StreamState::Eos;
216 break;
217 }
218 Ok(size) => {
219 self.stream_state = StreamState::Normal;
220 self.tail += size;
221 }
222 }
223 }
224 Ok(())
225 }
226
227 pub fn inner_ref(&self) -> &B {
229 &self.inner
230 }
231
232 pub fn inner_mut(&mut self) -> &mut B {
234 &mut self.inner
235 }
236
237 pub fn into_inner(self) -> B {
239 self.inner
240 }
241}
242impl<B: AsRef<[u8]> + AsMut<[u8]>> Read for ReadBuf<B> {
243 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
244 let size = cmp::min(buf.len(), self.len());
245 buf[..size].copy_from_slice(&self.inner.as_ref()[self.head..][..size]);
246 self.head += size;
247 if self.head == self.tail {
248 self.head = 0;
249 self.tail = 0;
250 }
251 Ok(size)
252 }
253}
254
255#[derive(Debug)]
257pub struct WriteBuf<B> {
258 pub(crate) inner: B,
259 pub(crate) head: usize,
260 pub(crate) tail: usize,
261 pub(crate) stream_state: StreamState,
262 #[cfg(feature = "tokio-async")]
263 pub(crate) waker: Option<std::task::Waker>,
264}
265impl<B: AsRef<[u8]> + AsMut<[u8]>> WriteBuf<B> {
266 pub fn new(inner: B) -> Self {
268 WriteBuf {
269 inner,
270 head: 0,
271 tail: 0,
272 stream_state: StreamState::Normal,
273 #[cfg(feature = "tokio-async")]
274 waker: None,
275 }
276 }
277
278 pub fn len(&self) -> usize {
280 self.tail - self.head
281 }
282
283 pub fn room(&self) -> usize {
287 self.inner.as_ref().len() - self.tail
288 }
289
290 pub fn capacity(&self) -> usize {
292 self.inner.as_ref().len()
293 }
294
295 pub fn is_empty(&self) -> bool {
297 self.tail == 0
298 }
299
300 pub fn is_full(&self) -> bool {
302 self.tail == self.inner.as_ref().len()
303 }
304
305 pub fn stream_state(&self) -> StreamState {
307 self.stream_state
308 }
309
310 pub fn stream_state_mut(&mut self) -> &mut StreamState {
312 &mut self.stream_state
313 }
314
315 pub fn flush<W: Write>(&mut self, mut writer: W) -> Result<()> {
324 while !self.is_empty() {
325 match writer.write(&self.inner.as_ref()[self.head..self.tail]) {
326 Err(e) => {
327 if e.kind() == io::ErrorKind::WouldBlock {
328 self.stream_state = StreamState::WouldBlock;
329 break;
330 } else {
331 self.stream_state = StreamState::Error;
332 return Err(track!(Error::from(e)));
333 }
334 }
335 Ok(0) => {
336 self.stream_state = StreamState::Eos;
337 break;
338 }
339 Ok(size) => {
340 self.stream_state = StreamState::Normal;
341 self.head += size;
342 if self.head == self.tail {
343 self.head = 0;
344 self.tail = 0;
345 }
346 }
347 }
348 }
349 #[cfg(feature = "tokio-async")]
350 if !self.is_full() {
351 if let Some(ref waker) = self.waker {
352 waker.wake_by_ref();
353 }
354 }
355 Ok(())
356 }
357
358 pub fn inner_ref(&self) -> &B {
360 &self.inner
361 }
362
363 pub fn inner_mut(&mut self) -> &mut B {
365 &mut self.inner
366 }
367
368 pub fn into_inner(self) -> B {
370 self.inner
371 }
372}
373impl<B: AsRef<[u8]> + AsMut<[u8]>> Write for WriteBuf<B> {
374 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
375 let size = cmp::min(buf.len(), self.room());
376 self.inner.as_mut()[self.tail..][..size].copy_from_slice(&buf[..size]);
377 self.tail += size;
378 Ok(size)
379 }
380
381 fn flush(&mut self) -> io::Result<()> {
382 Ok(())
383 }
384}
385
386#[cfg_attr(feature = "tokio-async", pin_project)]
388#[derive(Debug)]
389pub struct BufferedIo<T> {
390 #[cfg_attr(feature = "tokio-async", pin)]
391 pub(crate) stream: T,
392 pub(crate) rbuf: ReadBuf<Vec<u8>>,
393 pub(crate) wbuf: WriteBuf<Vec<u8>>,
394}
395impl<T: Read + Write> BufferedIo<T> {
396 pub fn execute_io(&mut self) -> Result<()> {
400 track!(self.rbuf.fill(&mut self.stream))?;
401 track!(self.wbuf.flush(&mut self.stream))?;
402 Ok(())
403 }
404}
405
406impl<T> BufferedIo<T> {
407 pub fn new(stream: T, read_buf_size: usize, write_buf_size: usize) -> Self {
409 BufferedIo {
410 stream,
411 rbuf: ReadBuf::new(vec![0; read_buf_size]),
412 wbuf: WriteBuf::new(vec![0; write_buf_size]),
413 }
414 }
415
416 pub fn is_eos(&self) -> bool {
418 self.rbuf.stream_state().is_eos() || self.wbuf.stream_state().is_eos()
419 }
420
421 pub fn would_block(&self) -> bool {
423 self.rbuf.stream_state().would_block()
424 && (self.wbuf.is_empty() || self.wbuf.stream_state().would_block())
425 }
426
427 pub fn read_buf_ref(&self) -> &ReadBuf<Vec<u8>> {
429 &self.rbuf
430 }
431
432 pub fn read_buf_mut(&mut self) -> &mut ReadBuf<Vec<u8>> {
434 &mut self.rbuf
435 }
436
437 pub fn write_buf_ref(&self) -> &WriteBuf<Vec<u8>> {
439 &self.wbuf
440 }
441
442 pub fn write_buf_mut(&mut self) -> &mut WriteBuf<Vec<u8>> {
444 &mut self.wbuf
445 }
446
447 pub fn stream_ref(&self) -> &T {
449 &self.stream
450 }
451
452 pub fn stream_mut(&mut self) -> &mut T {
454 &mut self.stream
455 }
456
457 pub fn into_stream(self) -> T {
459 self.stream
460 }
461}
462
#[cfg(test)]
mod test {
    use super::*;
    use crate::bytes::{Utf8Decoder, Utf8Encoder};
    use crate::EncodeExt;
    use std::io::{Read, Write};

    #[test]
    fn decode_from_read_buf_works() {
        // Filling from a 3-byte source that then reports end-of-stream.
        let mut rbuf = ReadBuf::new(vec![0; 1024]);
        track_try_unwrap!(rbuf.fill(&b"foo"[..]));
        assert_eq!(rbuf.len(), 3);
        assert_eq!(rbuf.stream_state(), StreamState::Eos);

        // The decoder consumes the buffered bytes and yields the text.
        let mut decoder = Utf8Decoder::new();
        track_try_unwrap!(decoder.decode_from_read_buf(&mut rbuf));
        let decoded = track_try_unwrap!(decoder.finish_decoding());
        assert_eq!(decoded, "foo");
    }

    #[test]
    fn read_from_read_buf_works() {
        let mut rbuf = ReadBuf::new(vec![0; 1024]);
        track_try_unwrap!(rbuf.fill(&b"foo"[..]));
        assert_eq!(rbuf.len(), 3);
        assert_eq!(rbuf.stream_state(), StreamState::Eos);

        // Draining through the `Read` impl empties the buffer.
        let mut drained = Vec::new();
        rbuf.read_to_end(&mut drained).unwrap();
        assert_eq!(drained, b"foo");
        assert_eq!(rbuf.len(), 0);
    }

    #[test]
    fn encode_to_write_buf_works() {
        let mut encoder = track_try_unwrap!(Utf8Encoder::with_item("foo"));

        let mut wbuf = WriteBuf::new(vec![0; 1024]);
        track_try_unwrap!(encoder.encode_to_write_buf(&mut wbuf));
        assert_eq!(wbuf.len(), 3);

        // Flushing moves the encoded bytes into the sink and empties the buffer.
        let mut sink = Vec::new();
        track_try_unwrap!(wbuf.flush(&mut sink));
        assert_eq!(wbuf.len(), 0);
        assert_eq!(wbuf.stream_state(), StreamState::Normal);
        assert_eq!(sink, b"foo");
    }

    #[test]
    fn write_to_write_buf_works() {
        let mut wbuf = WriteBuf::new(vec![0; 1024]);
        wbuf.write_all(b"foo").unwrap();
        assert_eq!(wbuf.len(), 3);

        // Flushing moves the written bytes into the sink and empties the buffer.
        let mut sink = Vec::new();
        track_try_unwrap!(wbuf.flush(&mut sink));
        assert_eq!(wbuf.len(), 0);
        assert_eq!(wbuf.stream_state(), StreamState::Normal);
        assert_eq!(sink, b"foo");
    }
}