// mqtt_proto/common/poll.rs
use core::future::Future;
use core::mem::{self, MaybeUninit};
use core::pin::Pin;
use core::slice;
use core::task::{Context, Poll};

use alloc::vec::Vec;

use crate::{from_read_exact_error, AsyncRead, Error, IoErrorKind};

9
10#[derive(Debug, Clone)]
11pub enum GenericPollPacketState<H> {
12 Header(PollHeaderState),
13 Body(GenericPollBodyState<H>),
14}
15
16#[derive(Debug, Clone, Default)]
17pub struct PollHeaderState {
18 pub control_byte: Option<u8>,
19 pub var_idx: u8,
20 pub var_int: u32,
21}
22
23#[derive(Debug, Clone)]
24pub struct GenericPollBodyState<H> {
25 pub header: H,
26 pub total: usize,
28 pub idx: usize,
29 pub buf: Vec<MaybeUninit<u8>>,
30}
31
/// Protocol-specific fixed header that lets [`GenericPollPacket`] turn raw
/// bytes into packets.
pub trait PollHeader {
    /// Error produced while building the header or decoding the body.
    type Error;
    /// Decoded packet type.
    type Packet;

    /// Builds a header from the control byte `hd` and the decoded Remaining
    /// Length value.
    fn new_with(hd: u8, remaining_len: u32) -> Result<Self, Self::Error>
    where
        Self: Sized;

    /// Returns the finished packet when no body has to be read. When this
    /// returns `None`, `GenericPollPacket` reads `remaining_len()` body bytes
    /// and calls `block_decode`.
    fn build_empty_packet(&self) -> Option<Self::Packet>;

    /// Decodes the packet from the body bytes in `reader`, which is expected
    /// to be advanced past what was consumed. `GenericPollPacket` treats
    /// leftover bytes after a successful decode as an invalid remaining length.
    fn block_decode(self, reader: &mut &[u8]) -> Result<Self::Packet, Self::Error>;

    /// Length of the packet body in bytes.
    fn remaining_len(&self) -> usize;

    /// Returns `true` if `err` is an end-of-input error; `GenericPollPacket`
    /// reports such errors as an invalid remaining length.
    fn is_eof_error(err: &Self::Error) -> bool;
}

46impl<H> Default for GenericPollPacketState<H> {
47 fn default() -> Self {
48 GenericPollPacketState::Header(PollHeaderState::default())
49 }
50}
51
52pub struct GenericPollPacket<'a, T, H> {
53 state: &'a mut GenericPollPacketState<H>,
54 reader: &'a mut T,
55}
56
57impl<'a, T, H> GenericPollPacket<'a, T, H> {
58 pub fn new(state: &'a mut GenericPollPacketState<H>, reader: &'a mut T) -> Self {
59 GenericPollPacket { state, reader }
60 }
61}
62
63impl<'a, T, H> Future for GenericPollPacket<'a, T, H>
64where
65 T: AsyncRead + Unpin,
66 H: PollHeader + Copy + Unpin,
67 H::Error: From<Error>,
68 H::Error: From<T::Error>,
69{
70 type Output = Result<(usize, Vec<MaybeUninit<u8>>, H::Packet), H::Error>;
71
72 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
73 let GenericPollPacket {
74 ref mut state,
75 ref mut reader,
76 } = self.get_mut();
77
78 let future = async move {
79 loop {
80 match state {
81 GenericPollPacketState::Header(PollHeaderState {
82 control_byte,
83 var_idx,
84 var_int,
85 }) => {
86 if control_byte.is_none() {
87 let mut buf = [0u8; 1];
88 reader
89 .read_exact(&mut buf)
90 .await
91 .map_err(from_read_exact_error)?;
92 *control_byte = Some(buf[0]);
93 }
94
95 loop {
96 let mut buf = [0u8; 1];
97 reader
98 .read_exact(&mut buf)
99 .await
100 .map_err(from_read_exact_error)?;
101
102 let byte = buf[0];
103 *var_int |= (u32::from(byte) & 0x7F) << (7 * u32::from(*var_idx));
104 if byte & 0x80 == 0 {
105 break;
106 } else if *var_idx < 3 {
107 *var_idx += 1;
108 } else {
109 return Err(Error::InvalidVarByteInt.into());
110 }
111 }
112
113 let header = match H::new_with(control_byte.unwrap(), *var_int) {
114 Ok(header) => header,
115 Err(err) => return Err(err),
116 };
117
118 if let Some(empty_packet) = header.build_empty_packet() {
119 return Ok((2, Vec::new(), empty_packet));
120 }
121
122 if header.remaining_len() == 0 {
123 return Err(Error::InvalidRemainingLength.into());
124 }
125
126 let mut buf: Vec<MaybeUninit<u8>> =
127 Vec::with_capacity(header.remaining_len());
128 unsafe {
129 buf.set_len(header.remaining_len());
130 }
131
132 **state = GenericPollPacketState::Body(GenericPollBodyState {
133 header,
134 total: 1 + 1 + *var_idx as usize + header.remaining_len(),
135 idx: 0,
136 buf,
137 });
138 }
139 GenericPollPacketState::Body(GenericPollBodyState {
140 header,
141 idx,
142 buf,
143 total,
144 }) => {
145 while *idx < buf.len() {
146 let remaining = buf.len() - *idx;
147 let buf_slice: &mut [u8] = unsafe {
148 core::slice::from_raw_parts_mut(
149 buf[*idx..].as_mut_ptr() as *mut u8,
150 remaining,
151 )
152 };
153
154 match reader.read(buf_slice).await {
155 Ok(0) => {
156 return Err(Error::IoError(IoErrorKind::UnexpectedEof).into())
157 }
158 Ok(n) => *idx += n,
159 Err(e) => return Err(e.into()),
160 }
161 }
162
163 let mut buf_ref: &[u8] = unsafe { mem::transmute(&buf[..]) };
164 let result = header.block_decode(&mut buf_ref);
165 if result.is_ok() && !buf_ref.is_empty() {
166 return Err(Error::InvalidRemainingLength.into());
167 }
168 if let Err(err) = &result {
169 if H::is_eof_error(err) {
170 return Err(Error::InvalidRemainingLength.into());
171 }
172 }
173 return result.map(|packet| (*total, mem::take(buf), packet));
174 }
175 }
176 }
177 };
178
179 futures_lite::pin!(future);
180 future.as_mut().poll(cx)
181 }
182}