1use std::{num::NonZeroU16, num::NonZeroU32, rc::Rc};
2
3use ntex_bytes::{ByteString, Bytes, PoolId};
4use ntex_io::{DispatcherConfig, IoBoxed};
5use ntex_net::connect::{self, Address, Connect, Connector};
6use ntex_service::{IntoService, Pipeline, Service};
7use ntex_util::time::{timeout_checked, Seconds};
8
9use super::codec::{self, Decoded, Encoded, Packet};
10use super::{connection::Client, error::ClientError, error::ProtocolError};
11use crate::v5::shared::{MqttShared, MqttSinkPool};
12
13pub struct MqttConnector<A, T> {
15 address: A,
16 connector: Pipeline<T>,
17 pkt: codec::Connect,
18 handshake_timeout: Seconds,
19 min_chunk_size: u32,
20 config: DispatcherConfig,
21 pool: Rc<MqttSinkPool>,
22}
23
24impl<A> MqttConnector<A, ()>
25where
26 A: Address + Clone,
27{
28 #[allow(clippy::new_ret_no_self)]
29 pub fn new(address: A) -> MqttConnector<A, Connector<A>> {
31 let config = DispatcherConfig::default();
32 config.set_disconnect_timeout(Seconds(3)).set_keepalive_timeout(Seconds(0));
33 MqttConnector {
34 address,
35 config,
36 pkt: codec::Connect::default(),
37 connector: Pipeline::new(Connector::default()),
38 handshake_timeout: Seconds::ZERO,
39 min_chunk_size: 32 * 1024,
40 pool: Rc::new(MqttSinkPool::default()),
41 }
42 }
43}
44
45impl<A, T> MqttConnector<A, T>
46where
47 A: Address + Clone,
48{
49 #[inline]
50 pub fn client_id<U>(mut self, client_id: U) -> Self
52 where
53 ByteString: From<U>,
54 {
55 self.pkt.client_id = client_id.into();
56 self
57 }
58
59 #[inline]
60 pub fn clean_start(mut self) -> Self {
62 self.pkt.clean_start = true;
63 self
64 }
65
66 #[inline]
67 pub fn keep_alive(mut self, val: Seconds) -> Self {
71 self.pkt.keep_alive = val.seconds() as u16;
72 self
73 }
74
75 #[inline]
76 pub fn last_will(mut self, val: codec::LastWill) -> Self {
80 self.pkt.last_will = Some(val);
81 self
82 }
83
84 #[inline]
85 pub fn auth(mut self, method: ByteString, data: Bytes) -> Self {
87 self.pkt.auth_method = Some(method);
88 self.pkt.auth_data = Some(data);
89 self
90 }
91
92 #[inline]
93 pub fn username(mut self, val: ByteString) -> Self {
95 self.pkt.username = Some(val);
96 self
97 }
98
99 #[inline]
100 pub fn password(mut self, val: Bytes) -> Self {
102 self.pkt.password = Some(val);
103 self
104 }
105
106 #[inline]
107 pub fn max_packet_size(mut self, val: u32) -> Self {
111 if let Some(val) = NonZeroU32::new(val) {
112 self.pkt.max_packet_size = Some(val);
113 } else {
114 self.pkt.max_packet_size = None;
115 }
116 self
117 }
118
119 pub fn min_chunk_size(mut self, size: u32) -> Self {
126 self.min_chunk_size = size;
127 self
128 }
129
130 #[inline]
131 pub fn max_receive(mut self, val: u16) -> Self {
136 if let Some(val) = NonZeroU16::new(val) {
137 self.pkt.receive_max = Some(val);
138 } else {
139 self.pkt.receive_max = None;
140 }
141 self
142 }
143
144 #[inline]
145 pub fn properties<F>(mut self, f: F) -> Self
147 where
148 F: FnOnce(&mut codec::UserProperties),
149 {
150 f(&mut self.pkt.user_properties);
151 self
152 }
153
154 #[inline]
155 pub fn packet<F>(mut self, f: F) -> Self
157 where
158 F: FnOnce(&mut codec::Connect),
159 {
160 f(&mut self.pkt);
161 self
162 }
163
164 pub fn handshake_timeout(mut self, timeout: Seconds) -> Self {
169 self.handshake_timeout = timeout;
170 self
171 }
172
173 pub fn disconnect_timeout(self, timeout: Seconds) -> Self {
182 self.config.set_disconnect_timeout(timeout);
183 self
184 }
185
186 pub fn memory_pool(self, id: PoolId) -> Self {
191 self.pool.pool.set(id.pool_ref());
192 self
193 }
194
195 pub fn connector<U, F>(self, connector: F) -> MqttConnector<A, U>
197 where
198 F: IntoService<U, Connect<A>>,
199 U: Service<Connect<A>, Error = connect::ConnectError>,
200 IoBoxed: From<U::Response>,
201 {
202 MqttConnector {
203 connector: Pipeline::new(connector.into_service()),
204 pkt: self.pkt,
205 address: self.address,
206 config: self.config,
207 handshake_timeout: self.handshake_timeout,
208 min_chunk_size: self.min_chunk_size,
209 pool: self.pool,
210 }
211 }
212}
213
214impl<A, T> MqttConnector<A, T>
215where
216 A: Address + Clone,
217 T: Service<Connect<A>, Error = connect::ConnectError>,
218 IoBoxed: From<T::Response>,
219{
220 pub async fn connect(&self) -> Result<Client, ClientError<Box<codec::ConnectAck>>> {
222 timeout_checked(self.handshake_timeout, self._connect())
223 .await
224 .map_err(|_| ClientError::HandshakeTimeout)
225 .and_then(|res| res)
226 }
227
228 async fn _connect(&self) -> Result<Client, ClientError<Box<codec::ConnectAck>>> {
229 let io: IoBoxed = self.connector.call(Connect::new(self.address.clone())).await?.into();
230 let pkt = self.pkt.clone();
231 let keep_alive = pkt.keep_alive;
232 let max_packet_size = pkt.max_packet_size.map(|v| v.get()).unwrap_or(0);
233 let max_receive = pkt.receive_max.map(|v| v.get()).unwrap_or(65535);
234 let pool = self.pool.clone();
235 let config = self.config.clone();
236
237 let codec = codec::Codec::new();
238 codec.set_max_inbound_size(max_packet_size);
239 codec.set_min_chunk_size(self.min_chunk_size);
240
241 io.encode(Encoded::Packet(Packet::Connect(Box::new(pkt))), &codec)?;
242
243 let packet = io.recv(&codec).await.map_err(ClientError::from)?.ok_or_else(|| {
244 log::trace!("Mqtt server is disconnected during handshake");
245 ClientError::Disconnected(None)
246 })?;
247
248 let shared = Rc::new(MqttShared::new(io.get_ref(), codec, pool));
249 match packet {
250 Decoded::Packet(Packet::ConnectAck(pkt), ..) => {
251 log::trace!("Connect ack response from server: {:#?}", pkt);
252 if pkt.reason_code == codec::ConnectAckReason::Success {
253 if let Some(size) = pkt.max_packet_size {
255 shared.codec.set_max_outbound_size(size);
256 }
257 let keep_alive = pkt.server_keepalive_sec.unwrap_or(keep_alive);
259
260 shared.set_cap(pkt.receive_max.get() as usize);
261
262 Ok(Client::new(io, shared, pkt, max_receive, Seconds(keep_alive), config))
263 } else {
264 Err(ClientError::Ack(pkt))
265 }
266 }
267 Decoded::Packet(packet, ..) => Err(ProtocolError::unexpected_packet(
268 packet.packet_type(),
269 "CONNACK packet expected from server first [MQTT-3.2.0-1]",
270 )
271 .into()),
272 Decoded::Publish(..) => Err(ProtocolError::unexpected_packet(
273 crate::types::packet_type::PUBLISH_START,
274 "CONNACK packet expected from server first [MQTT-3.2.0-1]",
275 )
276 .into()),
277 Decoded::PayloadChunk(..) => unreachable!(),
278 }
279 }
280}