mqtt_proto/common/
poll.rs

1use std::future::Future;
2use std::io;
3use std::mem::{self, MaybeUninit};
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use tokio::io::{AsyncRead, ReadBuf};
8
9use crate::Error;
10
11#[derive(Debug, Clone)]
12pub enum GenericPollPacketState<H> {
13    Header(PollHeaderState),
14    Body(GenericPollBodyState<H>),
15}
16
17#[derive(Debug, Clone, Default)]
18pub struct PollHeaderState {
19    pub control_byte: Option<u8>,
20    pub var_idx: u8,
21    pub var_int: u32,
22}
23
24#[derive(Debug, Clone)]
25pub struct GenericPollBodyState<H> {
26    pub header: H,
27    /// Packet total size (include header)
28    pub total: usize,
29    pub idx: usize,
30    pub buf: Vec<MaybeUninit<u8>>,
31}
32
33pub trait PollHeader {
34    type Error;
35    type Packet;
36
37    fn new_with(hd: u8, remaining_len: u32) -> Result<Self, Self::Error>
38    where
39        Self: Sized;
40    /// Packet without body is empty packet
41    fn build_empty_packet(&self) -> Option<Self::Packet>;
42    fn block_decode(self, reader: &mut &[u8]) -> Result<Self::Packet, Self::Error>;
43    fn remaining_len(&self) -> usize;
44    fn is_eof_error(err: &Self::Error) -> bool;
45}
46
47impl<H> Default for GenericPollPacketState<H> {
48    fn default() -> Self {
49        GenericPollPacketState::Header(PollHeaderState::default())
50    }
51}
52
53pub struct GenericPollPacket<'a, T, H> {
54    state: &'a mut GenericPollPacketState<H>,
55    reader: &'a mut T,
56}
57
58impl<'a, T, H> GenericPollPacket<'a, T, H> {
59    pub fn new(state: &'a mut GenericPollPacketState<H>, reader: &'a mut T) -> Self {
60        GenericPollPacket { state, reader }
61    }
62}
63
64impl<'a, T, H> Future for GenericPollPacket<'a, T, H>
65where
66    T: AsyncRead + Unpin,
67    H: PollHeader + Copy + Unpin,
68    H::Error: From<io::Error> + From<Error>,
69{
70    type Output = Result<(usize, Vec<MaybeUninit<u8>>, H::Packet), H::Error>;
71
72    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
73        let GenericPollPacket {
74            ref mut state,
75            ref mut reader,
76        } = self.get_mut();
77        loop {
78            match state {
79                GenericPollPacketState::Header(PollHeaderState {
80                    control_byte,
81                    var_idx,
82                    var_int,
83                }) => {
84                    let mut buf = [0u8; 1];
85                    loop {
86                        let mut readbuf = ReadBuf::new(&mut buf);
87                        let _size = match Pin::new(&mut *reader).poll_read(cx, &mut readbuf) {
88                            Poll::Ready(Ok(())) => {
89                                let size = readbuf.filled().len();
90                                if size == 0 {
91                                    return Poll::Ready(Err(Error::IoError(
92                                        io::ErrorKind::UnexpectedEof,
93                                        "eof".to_owned(),
94                                    )
95                                    .into()));
96                                }
97                                size
98                            }
99                            Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())),
100                            Poll::Pending => return Poll::Pending,
101                        };
102
103                        let byte = readbuf.filled()[0];
104                        if control_byte.is_none() {
105                            *control_byte = Some(byte);
106                        } else {
107                            *var_int |= (u32::from(byte) & 0x7F) << (7 * u32::from(*var_idx));
108                            if byte & 0x80 == 0 {
109                                break;
110                            } else if *var_idx < 3 {
111                                *var_idx += 1;
112                            } else {
113                                return Poll::Ready(Err(Error::InvalidVarByteInt.into()));
114                            }
115                        }
116                    }
117
118                    let header = match H::new_with(control_byte.unwrap(), *var_int) {
119                        Ok(header) => header,
120                        Err(err) => return Poll::Ready(Err(err)),
121                    };
122                    if let Some(empty_packet) = header.build_empty_packet() {
123                        return Poll::Ready(Ok((2, Vec::new(), empty_packet)));
124                    }
125                    if header.remaining_len() == 0 {
126                        return Poll::Ready(Err(Error::InvalidRemainingLength.into()));
127                    }
128                    let mut buf: Vec<MaybeUninit<u8>> = Vec::with_capacity(header.remaining_len());
129                    unsafe {
130                        buf.set_len(header.remaining_len());
131                    }
132                    **state = GenericPollPacketState::Body(GenericPollBodyState {
133                        header,
134                        total: 1 + 1 + *var_idx as usize + header.remaining_len(),
135                        idx: 0,
136                        buf,
137                    });
138                }
139                GenericPollPacketState::Body(GenericPollBodyState {
140                    header,
141                    idx,
142                    buf,
143                    total,
144                }) => loop {
145                    let buf_refmut: &mut [u8] = unsafe { mem::transmute(&mut buf[*idx..]) };
146                    let mut readbuf_refmut = ReadBuf::new(buf_refmut);
147                    let size = match Pin::new(&mut *reader).poll_read(cx, &mut readbuf_refmut) {
148                        Poll::Ready(Ok(())) => {
149                            let size = readbuf_refmut.filled().len();
150                            if size == 0 {
151                                return Poll::Ready(Err(Error::IoError(
152                                    io::ErrorKind::UnexpectedEof,
153                                    "eof".to_owned(),
154                                )
155                                .into()));
156                            }
157                            size
158                        }
159                        Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())),
160                        Poll::Pending => return Poll::Pending,
161                    };
162
163                    *idx += size;
164                    debug_assert!(*idx <= buf.len());
165
166                    if *idx == buf.len() {
167                        let mut buf_ref: &[u8] = unsafe { mem::transmute(&buf[..]) };
168                        let result = header.block_decode(&mut buf_ref);
169                        if result.is_ok() && !buf_ref.is_empty() {
170                            return Poll::Ready(Err(Error::InvalidRemainingLength.into()));
171                        }
172                        if let Err(err) = &result {
173                            if H::is_eof_error(err) {
174                                return Poll::Ready(Err(Error::InvalidRemainingLength.into()));
175                            }
176                        }
177                        return Poll::Ready(result.map(|packet| (*total, mem::take(buf), packet)));
178                    }
179                },
180            }
181        }
182    }
183}