mqtt_proto/common/
poll.rs1use std::future::Future;
2use std::io;
3use std::mem::{self, MaybeUninit};
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use tokio::io::{AsyncRead, ReadBuf};
8
9use crate::Error;
10
11#[derive(Debug, Clone)]
12pub enum GenericPollPacketState<H> {
13 Header(PollHeaderState),
14 Body(GenericPollBodyState<H>),
15}
16
/// Progress while reading the fixed header byte by byte.
#[derive(Clone, Debug, Default)]
pub struct PollHeaderState {
    /// First byte of the packet, once it has been read.
    pub control_byte: Option<u8>,
    /// Zero-based index of the remaining-length byte being read (never above 3).
    pub var_idx: u8,
    /// Remaining length accumulated so far: 7 bits per byte, lowest group first.
    pub var_int: u32,
}
23
/// Progress while reading the packet body after the header has been decoded.
#[derive(Clone, Debug)]
pub struct GenericPollBodyState<H> {
    /// Header built from the control byte and the remaining length.
    pub header: H,
    /// Size of the whole packet: control byte, remaining-length bytes and body.
    pub total: usize,
    /// Number of body bytes already written into `buf`.
    pub idx: usize,
    /// Body buffer sized to the remaining length; only `buf[..idx]` has been filled.
    pub buf: Vec<MaybeUninit<u8>>,
}
32
/// Fixed-header operations the generic packet poller relies on.
///
/// A header is built from the control byte and the decoded remaining length.
/// It then either yields a body-less packet directly or decodes the body bytes.
pub trait PollHeader {
    /// Error produced while building the header or decoding the body.
    type Error;
    /// Packet type produced by a successful decode.
    type Packet;

    /// Builds a header from control byte `hd` and the decoded `remaining_len`.
    fn new_with(hd: u8, remaining_len: u32) -> Result<Self, Self::Error>
    where
        Self: Sized;

    /// Returns the finished packet when no body has to be read, otherwise `None`.
    fn build_empty_packet(&self) -> Option<Self::Packet>;

    /// Decodes the body from `reader`, which is expected to be advanced past
    /// the consumed bytes.
    ///
    /// The poller treats bytes left over after a successful decode as a
    /// remaining-length mismatch.
    fn block_decode(self, reader: &mut &[u8]) -> Result<Self::Packet, Self::Error>;

    /// Number of body bytes that follow the fixed header.
    fn remaining_len(&self) -> usize;

    /// Whether `err` signals that decoding ran out of input.
    fn is_eof_error(err: &Self::Error) -> bool;
}
46
47impl<H> Default for GenericPollPacketState<H> {
48 fn default() -> Self {
49 GenericPollPacketState::Header(PollHeaderState::default())
50 }
51}
52
53pub struct GenericPollPacket<'a, T, H> {
54 state: &'a mut GenericPollPacketState<H>,
55 reader: &'a mut T,
56}
57
58impl<'a, T, H> GenericPollPacket<'a, T, H> {
59 pub fn new(state: &'a mut GenericPollPacketState<H>, reader: &'a mut T) -> Self {
60 GenericPollPacket { state, reader }
61 }
62}
63
64impl<'a, T, H> Future for GenericPollPacket<'a, T, H>
65where
66 T: AsyncRead + Unpin,
67 H: PollHeader + Copy + Unpin,
68 H::Error: From<io::Error> + From<Error>,
69{
70 type Output = Result<(usize, Vec<MaybeUninit<u8>>, H::Packet), H::Error>;
71
72 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
73 let GenericPollPacket {
74 ref mut state,
75 ref mut reader,
76 } = self.get_mut();
77 loop {
78 match state {
79 GenericPollPacketState::Header(PollHeaderState {
80 control_byte,
81 var_idx,
82 var_int,
83 }) => {
84 let mut buf = [0u8; 1];
85 loop {
86 let mut readbuf = ReadBuf::new(&mut buf);
87 let _size = match Pin::new(&mut *reader).poll_read(cx, &mut readbuf) {
88 Poll::Ready(Ok(())) => {
89 let size = readbuf.filled().len();
90 if size == 0 {
91 return Poll::Ready(Err(Error::IoError(
92 io::ErrorKind::UnexpectedEof,
93 "eof".to_owned(),
94 )
95 .into()));
96 }
97 size
98 }
99 Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())),
100 Poll::Pending => return Poll::Pending,
101 };
102
103 let byte = readbuf.filled()[0];
104 if control_byte.is_none() {
105 *control_byte = Some(byte);
106 } else {
107 *var_int |= (u32::from(byte) & 0x7F) << (7 * u32::from(*var_idx));
108 if byte & 0x80 == 0 {
109 break;
110 } else if *var_idx < 3 {
111 *var_idx += 1;
112 } else {
113 return Poll::Ready(Err(Error::InvalidVarByteInt.into()));
114 }
115 }
116 }
117
118 let header = match H::new_with(control_byte.unwrap(), *var_int) {
119 Ok(header) => header,
120 Err(err) => return Poll::Ready(Err(err)),
121 };
122 if let Some(empty_packet) = header.build_empty_packet() {
123 return Poll::Ready(Ok((2, Vec::new(), empty_packet)));
124 }
125 if header.remaining_len() == 0 {
126 return Poll::Ready(Err(Error::InvalidRemainingLength.into()));
127 }
128 let mut buf: Vec<MaybeUninit<u8>> = Vec::with_capacity(header.remaining_len());
129 unsafe {
130 buf.set_len(header.remaining_len());
131 }
132 **state = GenericPollPacketState::Body(GenericPollBodyState {
133 header,
134 total: 1 + 1 + *var_idx as usize + header.remaining_len(),
135 idx: 0,
136 buf,
137 });
138 }
139 GenericPollPacketState::Body(GenericPollBodyState {
140 header,
141 idx,
142 buf,
143 total,
144 }) => loop {
145 let buf_refmut: &mut [u8] = unsafe { mem::transmute(&mut buf[*idx..]) };
146 let mut readbuf_refmut = ReadBuf::new(buf_refmut);
147 let size = match Pin::new(&mut *reader).poll_read(cx, &mut readbuf_refmut) {
148 Poll::Ready(Ok(())) => {
149 let size = readbuf_refmut.filled().len();
150 if size == 0 {
151 return Poll::Ready(Err(Error::IoError(
152 io::ErrorKind::UnexpectedEof,
153 "eof".to_owned(),
154 )
155 .into()));
156 }
157 size
158 }
159 Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())),
160 Poll::Pending => return Poll::Pending,
161 };
162
163 *idx += size;
164 debug_assert!(*idx <= buf.len());
165
166 if *idx == buf.len() {
167 let mut buf_ref: &[u8] = unsafe { mem::transmute(&buf[..]) };
168 let result = header.block_decode(&mut buf_ref);
169 if result.is_ok() && !buf_ref.is_empty() {
170 return Poll::Ready(Err(Error::InvalidRemainingLength.into()));
171 }
172 if let Err(err) = &result {
173 if H::is_eof_error(err) {
174 return Poll::Ready(Err(Error::InvalidRemainingLength.into()));
175 }
176 }
177 return Poll::Ready(result.map(|packet| (*total, mem::take(buf), packet)));
178 }
179 },
180 }
181 }
182 }
183}