mqtt_proto/common/
poll.rs

1use core::future::Future;
2use core::mem::{self, MaybeUninit};
3use core::pin::Pin;
4use core::task::{Context, Poll};
5
6use alloc::vec::Vec;
7
8use crate::{from_read_exact_error, AsyncRead, Error, IoErrorKind};
9
10#[derive(Debug, Clone)]
11pub enum GenericPollPacketState<H> {
12    Header(PollHeaderState),
13    Body(GenericPollBodyState<H>),
14}
15
/// Bookkeeping for a partially read fixed header.
///
/// All fields start out empty/zero (`Default`), which is the state of a
/// packet for which no byte has been consumed yet.
#[derive(Debug, Clone, Default)]
pub struct PollHeaderState {
    /// First byte of the packet; `None` until it has been read.
    pub control_byte: Option<u8>,
    /// Zero-based position of the remaining-length byte to read next.
    /// The poller rejects encodings that would need an index above 3.
    pub var_idx: u8,
    /// Remaining-length value accumulated so far, 7 bits per byte read,
    /// least-significant group first.
    pub var_int: u32,
}
22
/// Bookkeeping for a packet whose header is parsed and whose body is still
/// being read.
#[derive(Debug, Clone)]
pub struct GenericPollBodyState<H> {
    /// The parsed packet header.
    pub header: H,
    /// Total size of the packet in bytes, including the header.
    pub total: usize,
    /// Number of body bytes already stored at the front of `buf`.
    pub idx: usize,
    /// Body buffer. The poller sizes it to the header's remaining length;
    /// bytes from `idx` onwards have not been written yet and may be
    /// uninitialised.
    pub buf: Vec<MaybeUninit<u8>>,
}
31
/// Protocol-specific header operations required by the packet poller.
pub trait PollHeader {
    /// Error produced while building the header or decoding the body.
    type Error;
    /// Fully decoded packet type.
    type Packet;

    /// Builds a header from the packet's first byte (`hd`) and the decoded
    /// remaining length.
    fn new_with(hd: u8, remaining_len: u32) -> Result<Self, Self::Error>
    where
        Self: Sized;

    /// Returns the packet right away when it has no body ("empty packet"),
    /// or `None` when a body has to be read first.
    fn build_empty_packet(&self) -> Option<Self::Packet>;

    /// Decodes the packet body from `reader`, an in-memory byte slice.
    ///
    /// The poller treats bytes left in `reader` after a successful decode as
    /// an invalid remaining length, so implementations are expected to
    /// advance the slice past everything they consume.
    fn block_decode(self, reader: &mut &[u8]) -> Result<Self::Packet, Self::Error>;

    /// Number of body bytes that follow the header.
    fn remaining_len(&self) -> usize;

    /// Tells whether `err` is an end-of-input error. The poller reports such
    /// errors from `block_decode` as an invalid remaining length.
    fn is_eof_error(err: &Self::Error) -> bool;
}
45
46impl<H> Default for GenericPollPacketState<H> {
47    fn default() -> Self {
48        GenericPollPacketState::Header(PollHeaderState::default())
49    }
50}
51
52pub struct GenericPollPacket<'a, T, H> {
53    state: &'a mut GenericPollPacketState<H>,
54    reader: &'a mut T,
55}
56
57impl<'a, T, H> GenericPollPacket<'a, T, H> {
58    pub fn new(state: &'a mut GenericPollPacketState<H>, reader: &'a mut T) -> Self {
59        GenericPollPacket { state, reader }
60    }
61}
62
63impl<'a, T, H> Future for GenericPollPacket<'a, T, H>
64where
65    T: AsyncRead + Unpin,
66    H: PollHeader + Copy + Unpin,
67    H::Error: From<Error>,
68    H::Error: From<T::Error>,
69{
70    type Output = Result<(usize, Vec<MaybeUninit<u8>>, H::Packet), H::Error>;
71
72    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
73        let GenericPollPacket {
74            ref mut state,
75            ref mut reader,
76        } = self.get_mut();
77
78        let future = async move {
79            loop {
80                match state {
81                    GenericPollPacketState::Header(PollHeaderState {
82                        control_byte,
83                        var_idx,
84                        var_int,
85                    }) => {
86                        if control_byte.is_none() {
87                            let mut buf = [0u8; 1];
88                            reader
89                                .read_exact(&mut buf)
90                                .await
91                                .map_err(from_read_exact_error)?;
92                            *control_byte = Some(buf[0]);
93                        }
94
95                        loop {
96                            let mut buf = [0u8; 1];
97                            reader
98                                .read_exact(&mut buf)
99                                .await
100                                .map_err(from_read_exact_error)?;
101
102                            let byte = buf[0];
103                            *var_int |= (u32::from(byte) & 0x7F) << (7 * u32::from(*var_idx));
104                            if byte & 0x80 == 0 {
105                                break;
106                            } else if *var_idx < 3 {
107                                *var_idx += 1;
108                            } else {
109                                return Err(Error::InvalidVarByteInt.into());
110                            }
111                        }
112
113                        let header = match H::new_with(control_byte.unwrap(), *var_int) {
114                            Ok(header) => header,
115                            Err(err) => return Err(err),
116                        };
117
118                        if let Some(empty_packet) = header.build_empty_packet() {
119                            return Ok((2, Vec::new(), empty_packet));
120                        }
121
122                        if header.remaining_len() == 0 {
123                            return Err(Error::InvalidRemainingLength.into());
124                        }
125
126                        let mut buf: Vec<MaybeUninit<u8>> =
127                            Vec::with_capacity(header.remaining_len());
128                        unsafe {
129                            buf.set_len(header.remaining_len());
130                        }
131
132                        **state = GenericPollPacketState::Body(GenericPollBodyState {
133                            header,
134                            total: 1 + 1 + *var_idx as usize + header.remaining_len(),
135                            idx: 0,
136                            buf,
137                        });
138                    }
139                    GenericPollPacketState::Body(GenericPollBodyState {
140                        header,
141                        idx,
142                        buf,
143                        total,
144                    }) => {
145                        while *idx < buf.len() {
146                            let remaining = buf.len() - *idx;
147                            let buf_slice: &mut [u8] = unsafe {
148                                core::slice::from_raw_parts_mut(
149                                    buf[*idx..].as_mut_ptr() as *mut u8,
150                                    remaining,
151                                )
152                            };
153
154                            match reader.read(buf_slice).await {
155                                Ok(0) => {
156                                    return Err(Error::IoError(IoErrorKind::UnexpectedEof).into())
157                                }
158                                Ok(n) => *idx += n,
159                                Err(e) => return Err(e.into()),
160                            }
161                        }
162
163                        let mut buf_ref: &[u8] = unsafe { mem::transmute(&buf[..]) };
164                        let result = header.block_decode(&mut buf_ref);
165                        if result.is_ok() && !buf_ref.is_empty() {
166                            return Err(Error::InvalidRemainingLength.into());
167                        }
168                        if let Err(err) = &result {
169                            if H::is_eof_error(err) {
170                                return Err(Error::InvalidRemainingLength.into());
171                            }
172                        }
173                        return result.map(|packet| (*total, mem::take(buf), packet));
174                    }
175                }
176            }
177        };
178
179        futures_lite::pin!(future);
180        future.as_mut().poll(cx)
181    }
182}