1mod error;
2use std::{
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use error::ChunkIOError;
8use futures::{Sink, SinkExt, Stream, StreamExt};
9use tokio::io::{AsyncRead, AsyncWrite};
10use tokio_util::{
11 bytes::{Buf, BytesMut},
12 codec::{Decoder, Encoder, Framed},
13};
14
15pub struct ChunkIO<T>(Framed<T, ChunkIOProto>);
16
17impl<T> ChunkIO<T> {
18 pub fn new(io: T) -> ChunkIO<T>
19 where
20 T: AsyncRead + AsyncWrite,
21 {
22 ChunkIO(tokio_util::codec::Framed::new(io, Default::default()))
23 }
24}
25
26impl<T> Stream for ChunkIO<T>
27where
28 T: AsyncRead + AsyncWrite + Unpin,
29{
30 type Item = Result<Vec<u8>, ChunkIOError>;
31
32 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
33 self.0.poll_next_unpin(cx)
34 }
35}
36
37impl<T> Sink<Vec<u8>> for ChunkIO<T>
38where
39 T: AsyncRead + AsyncWrite + Unpin,
40{
41 type Error = ChunkIOError;
42
43 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
44 self.0.poll_ready_unpin(cx)
45 }
46
47 fn start_send(mut self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
48 self.0.start_send_unpin(item)
49 }
50
51 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
52 self.0.poll_flush_unpin(cx)
53 }
54
55 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
56 self.0.poll_close_unpin(cx)
57 }
58}
59
/// Codec state for the chunk protocol: running payload-byte counters used as chunk indices.
#[derive(Default)]
struct ChunkIOProto {
    /// `.0`: payload bytes encoded so far (the index stamped on the next outgoing chunk).
    /// `.1`: payload bytes decoded so far (the index required on the next incoming chunk).
    current_index: (u64, u64),
}
64
65impl Decoder for ChunkIOProto {
66 type Item = Vec<u8>;
67
68 type Error = ChunkIOError;
69
70 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
71 if src.len() < 2 {
72 return Ok(None);
73 }
74 let index_pointer = (src[0] >> 4) as u64;
75 let len_pointer = (src[0] & 0xf) as u64;
76 if index_pointer > 8 || len_pointer > 8 || len_pointer == 0 {
77 return Err(ChunkIOError::InvalidChunk);
78 }
79 if (src.len() as u64) < (len_pointer + index_pointer + 1) {
80 return Ok(None);
81 }
82 let index = match index_pointer {
83 0 => 0_u64,
84 1 => u8::from_be_bytes([src[1]]) as u64,
85 2 => u16::from_be_bytes([src[1], src[2]]) as u64,
86 3..=4 => u32::from_be_bytes(
87 src[1..index_pointer as usize]
88 .try_into()
89 .or(Err(ChunkIOError::InvalidChunk))?,
90 ) as u64,
91 5..=8 => u64::from_be_bytes(
92 src[1..index_pointer as usize]
93 .try_into()
94 .or(Err(ChunkIOError::InvalidChunk))?,
95 ),
96 _ => return Err(ChunkIOError::InvalidChunk),
97 };
98 if self.current_index.1 != index {
99 return Err(ChunkIOError::OutOfOrder);
100 }
101
102 let length = match len_pointer {
103 1 => u8::from_be_bytes([src[1 + index_pointer as usize]]) as u64,
104 2 => u16::from_be_bytes([
105 src[1 + index_pointer as usize],
106 src[2 + index_pointer as usize],
107 ]) as u64,
108 3..=4 => u32::from_be_bytes(
109 src[1 + (index_pointer as usize)..((index_pointer + len_pointer) as usize)]
110 .try_into()
111 .or(Err(ChunkIOError::InvalidChunk))?,
112 ) as u64,
113 5..=8 => u64::from_be_bytes(
114 src[1 + index_pointer as usize..(index_pointer + len_pointer) as usize]
115 .try_into()
116 .or(Err(ChunkIOError::InvalidChunk))?,
117 ),
118 _ => {
119 return Err(ChunkIOError::InvalidChunk);
120 }
121 };
122 if src.len() < (index_pointer + len_pointer + length + 1) as usize {
123 Ok(None)
124 } else {
125 src.advance((index_pointer + len_pointer + 1) as usize);
126 self.current_index.1 += length;
127 Ok(Some(src.split_to(length as usize).to_vec()))
128 }
129 }
130}
131
132impl Encoder<Vec<u8>> for ChunkIOProto {
133 type Error = ChunkIOError;
134
135 fn encode(&mut self, item: Vec<u8>, dst: &mut BytesMut) -> Result<(), Self::Error> {
136 let index = self
137 .current_index
138 .0
139 .to_be_bytes()
140 .into_iter()
141 .skip_while(|x| *x == 0)
142 .collect::<Vec<u8>>();
143 let length = item
144 .len()
145 .to_be_bytes()
146 .into_iter()
147 .skip_while(|x| *x == 0)
148 .collect::<Vec<u8>>();
149 dst.extend_from_slice(&[((index.len() as u8) << 4) | (length.len() as u8)]);
150 dst.extend_from_slice(&index);
151 dst.extend_from_slice(&length);
152 dst.extend_from_slice(&item);
153 self.current_index.0 += item.len() as u64;
154
155 Ok(())
156 }
157}