chunkio/
lib.rs

1mod error;
2use std::{
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use error::ChunkIOError;
8use futures::{Sink, SinkExt, Stream, StreamExt};
9use tokio::io::{AsyncRead, AsyncWrite};
10use tokio_util::{
11    bytes::{Buf, BytesMut},
12    codec::{Decoder, Encoder, Framed},
13};
14
15pub struct ChunkIO<T>(Framed<T, ChunkIOProto>);
16
17impl<T> ChunkIO<T> {
18    pub fn new(io: T) -> ChunkIO<T>
19    where
20        T: AsyncRead + AsyncWrite,
21    {
22        ChunkIO(tokio_util::codec::Framed::new(io, Default::default()))
23    }
24}
25
26impl<T> Stream for ChunkIO<T>
27where
28    T: AsyncRead + AsyncWrite + Unpin,
29{
30    type Item = Result<Vec<u8>, ChunkIOError>;
31
32    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
33        self.0.poll_next_unpin(cx)
34    }
35}
36
37impl<T> Sink<Vec<u8>> for ChunkIO<T>
38where
39    T: AsyncRead + AsyncWrite + Unpin,
40{
41    type Error = ChunkIOError;
42
43    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
44        self.0.poll_ready_unpin(cx)
45    }
46
47    fn start_send(mut self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
48        self.0.start_send_unpin(item)
49    }
50
51    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
52        self.0.poll_flush_unpin(cx)
53    }
54
55    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
56        self.0.poll_close_unpin(cx)
57    }
58}
59
/// Codec state for the chunk protocol: one running byte counter per direction.
///
/// Both counters start at zero (`Default`).
#[derive(Default)]
struct ChunkIOProto {
    /// `(send index, receive index)`.
    ///
    /// `.0` is the total payload bytes encoded so far and is written into each
    /// outgoing chunk header. `.1` is the total payload bytes decoded so far;
    /// an incoming chunk whose header index differs from it is rejected as out
    /// of order.
    current_index: (u64, u64),
}
64
65impl Decoder for ChunkIOProto {
66    type Item = Vec<u8>;
67
68    type Error = ChunkIOError;
69
70    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
71        if src.len() < 2 {
72            return Ok(None);
73        }
74        let index_pointer = (src[0] >> 4) as u64;
75        let len_pointer = (src[0] & 0xf) as u64;
76        if index_pointer > 8 || len_pointer > 8 || len_pointer == 0 {
77            return Err(ChunkIOError::InvalidChunk);
78        }
79        if (src.len() as u64) < (len_pointer + index_pointer + 1) {
80            return Ok(None);
81        }
82        let index = match index_pointer {
83            0 => 0_u64,
84            1 => u8::from_be_bytes([src[1]]) as u64,
85            2 => u16::from_be_bytes([src[1], src[2]]) as u64,
86            3..=4 => u32::from_be_bytes(
87                src[1..index_pointer as usize]
88                    .try_into()
89                    .or(Err(ChunkIOError::InvalidChunk))?,
90            ) as u64,
91            5..=8 => u64::from_be_bytes(
92                src[1..index_pointer as usize]
93                    .try_into()
94                    .or(Err(ChunkIOError::InvalidChunk))?,
95            ),
96            _ => return Err(ChunkIOError::InvalidChunk),
97        };
98        if self.current_index.1 != index {
99            return Err(ChunkIOError::OutOfOrder);
100        }
101
102        let length = match len_pointer {
103            1 => u8::from_be_bytes([src[1 + index_pointer as usize]]) as u64,
104            2 => u16::from_be_bytes([
105                src[1 + index_pointer as usize],
106                src[2 + index_pointer as usize],
107            ]) as u64,
108            3..=4 => u32::from_be_bytes(
109                src[1 + (index_pointer as usize)..((index_pointer + len_pointer) as usize)]
110                    .try_into()
111                    .or(Err(ChunkIOError::InvalidChunk))?,
112            ) as u64,
113            5..=8 => u64::from_be_bytes(
114                src[1 + index_pointer as usize..(index_pointer + len_pointer) as usize]
115                    .try_into()
116                    .or(Err(ChunkIOError::InvalidChunk))?,
117            ),
118            _ => {
119                return Err(ChunkIOError::InvalidChunk);
120            }
121        };
122        if src.len() < (index_pointer + len_pointer + length + 1) as usize {
123            Ok(None)
124        } else {
125            src.advance((index_pointer + len_pointer + 1) as usize);
126            self.current_index.1 += length;
127            Ok(Some(src.split_to(length as usize).to_vec()))
128        }
129    }
130}
131
132impl Encoder<Vec<u8>> for ChunkIOProto {
133    type Error = ChunkIOError;
134
135    fn encode(&mut self, item: Vec<u8>, dst: &mut BytesMut) -> Result<(), Self::Error> {
136        let index = self
137            .current_index
138            .0
139            .to_be_bytes()
140            .into_iter()
141            .skip_while(|x| *x == 0)
142            .collect::<Vec<u8>>();
143        let length = item
144            .len()
145            .to_be_bytes()
146            .into_iter()
147            .skip_while(|x| *x == 0)
148            .collect::<Vec<u8>>();
149        dst.extend_from_slice(&[((index.len() as u8) << 4) | (length.len() as u8)]);
150        dst.extend_from_slice(&index);
151        dst.extend_from_slice(&length);
152        dst.extend_from_slice(&item);
153        self.current_index.0 += item.len() as u64;
154
155        Ok(())
156    }
157}