1use std::rc::Rc;
2
3use ntex_bytes::{ByteString, Bytes, PoolId};
4use ntex_io::{DispatcherConfig, IoBoxed};
5use ntex_net::connect::{self, Address, Connect, Connector};
6use ntex_service::{IntoService, Pipeline, Service};
7use ntex_util::time::{timeout_checked, Seconds};
8
9use super::{connection::Client, error::ClientError, error::ProtocolError};
10use crate::v3::codec::{self, Decoded, Encoded, Packet};
11use crate::v3::shared::{MqttShared, MqttSinkPool};
12
13pub struct MqttConnector<A, T> {
15 address: A,
16 connector: Pipeline<T>,
17 pkt: codec::Connect,
18 max_size: u32,
19 max_send: usize,
20 max_receive: usize,
21 min_chunk_size: u32,
22 handshake_timeout: Seconds,
23 config: DispatcherConfig,
24 pool: Rc<MqttSinkPool>,
25}
26
27impl<A> MqttConnector<A, ()>
28where
29 A: Address + Clone,
30{
31 #[allow(clippy::new_ret_no_self)]
32 pub fn new(address: A) -> MqttConnector<A, Connector<A>> {
34 let config = DispatcherConfig::default();
35 config.set_disconnect_timeout(Seconds(3)).set_keepalive_timeout(Seconds(0));
36
37 MqttConnector {
38 address,
39 config,
40 pkt: codec::Connect::default(),
41 connector: Pipeline::new(Connector::default()),
42 max_size: 64 * 1024,
43 max_send: 16,
44 max_receive: 16,
45 min_chunk_size: 32 * 1024,
46 handshake_timeout: Seconds::ZERO,
47 pool: Rc::new(MqttSinkPool::default()),
48 }
49 }
50}
51
52impl<A, T> MqttConnector<A, T>
53where
54 A: Address + Clone,
55{
56 #[inline]
57 pub fn client_id<U>(mut self, client_id: U) -> Self
59 where
60 ByteString: From<U>,
61 {
62 self.pkt.client_id = client_id.into();
63 self
64 }
65
66 #[inline]
67 pub fn clean_session(mut self) -> Self {
69 self.pkt.clean_session = true;
70 self
71 }
72
73 #[inline]
74 pub fn keep_alive(mut self, val: Seconds) -> Self {
78 self.pkt.keep_alive = val.seconds() as u16;
79 self
80 }
81
82 #[inline]
83 pub fn last_will(mut self, val: codec::LastWill) -> Self {
87 self.pkt.last_will = Some(val);
88 self
89 }
90
91 #[inline]
92 pub fn username<U>(mut self, val: U) -> Self
94 where
95 ByteString: From<U>,
96 {
97 self.pkt.username = Some(val.into());
98 self
99 }
100
101 #[inline]
102 pub fn password(mut self, val: Bytes) -> Self {
104 self.pkt.password = Some(val);
105 self
106 }
107
108 #[inline]
109 pub fn max_size(mut self, val: u32) -> Self {
113 self.max_size = val;
114 self
115 }
116
117 #[inline]
118 pub fn max_send(mut self, val: u16) -> Self {
123 self.max_send = val as usize;
124 self
125 }
126
127 #[inline]
128 pub fn max_receive(mut self, val: u16) -> Self {
132 self.max_receive = val as usize;
133 self
134 }
135
136 pub fn min_chunk_size(mut self, size: u32) -> Self {
143 self.min_chunk_size = size;
144 self
145 }
146
147 #[inline]
148 pub fn packet<F>(mut self, f: F) -> Self
150 where
151 F: FnOnce(&mut codec::Connect),
152 {
153 f(&mut self.pkt);
154 self
155 }
156
157 pub fn handshake_timeout(mut self, timeout: Seconds) -> Self {
162 self.handshake_timeout = timeout;
163 self
164 }
165
166 pub fn disconnect_timeout(self, timeout: Seconds) -> Self {
175 self.config.set_disconnect_timeout(timeout);
176 self
177 }
178
179 pub fn memory_pool(self, id: PoolId) -> Self {
184 self.pool.pool.set(id.pool_ref());
185 self
186 }
187
188 pub fn connector<U, F>(self, connector: F) -> MqttConnector<A, U>
190 where
191 F: IntoService<U, Connect<A>>,
192 U: Service<Connect<A>, Error = connect::ConnectError>,
193 IoBoxed: From<U::Response>,
194 {
195 MqttConnector {
196 connector: Pipeline::new(connector.into_service()),
197 pkt: self.pkt,
198 address: self.address,
199 config: self.config,
200 max_size: self.max_size,
201 max_send: self.max_send,
202 max_receive: self.max_receive,
203 min_chunk_size: self.min_chunk_size,
204 handshake_timeout: self.handshake_timeout,
205 pool: self.pool,
206 }
207 }
208}
209
210impl<A, T> MqttConnector<A, T>
211where
212 A: Address + Clone,
213 T: Service<Connect<A>, Error = connect::ConnectError>,
214 IoBoxed: From<T::Response>,
215{
216 pub async fn connect(&self) -> Result<Client, ClientError<codec::ConnectAck>> {
218 timeout_checked(self.handshake_timeout, self._connect())
219 .await
220 .map_err(|_| ClientError::HandshakeTimeout)
221 .and_then(|res| res)
222 }
223
224 async fn _connect(&self) -> Result<Client, ClientError<codec::ConnectAck>> {
225 let io: IoBoxed = self.connector.call(Connect::new(self.address.clone())).await?.into();
226 let pkt = self.pkt.clone();
227 let max_send = self.max_send;
228 let max_receive = self.max_receive;
229 let keepalive_timeout = pkt.keep_alive;
230 let config = self.config.clone();
231 let pool = self.pool.clone();
232 let codec = codec::Codec::new();
233 codec.set_max_size(self.max_size);
234 codec.set_min_chunk_size(self.min_chunk_size);
235
236 io.encode(Encoded::Packet(pkt.into()), &codec)?;
237
238 let packet = io.recv(&codec).await.map_err(ClientError::from)?.ok_or_else(|| {
239 log::trace!("Mqtt server is disconnected during handshake");
240 ClientError::Disconnected(None)
241 })?;
242
243 let shared = Rc::new(MqttShared::new(io.get_ref(), codec, true, pool));
244
245 match packet {
246 Decoded::Packet(codec::Packet::ConnectAck(pkt), _) => {
247 log::trace!("Connect ack response from server: session: present: {:?}, return code: {:?}", pkt.session_present, pkt.return_code);
248 if pkt.return_code == codec::ConnectAckReason::ConnectionAccepted {
249 shared.set_cap(max_send);
250 Ok(Client::new(
251 io,
252 shared,
253 pkt.session_present,
254 Seconds(keepalive_timeout),
255 max_receive,
256 config,
257 ))
258 } else {
259 Err(ClientError::Ack(pkt))
260 }
261 }
262 Decoded::Packet(p, _) => Err(ProtocolError::unexpected_packet(
263 p.packet_type(),
264 "Expected CONNACK packet",
265 )
266 .into()),
267 Decoded::Publish(..) => Err(ProtocolError::unexpected_packet(
268 crate::types::packet_type::PUBLISH_START,
269 "CONNACK packet expected from server first [MQTT-3.2.0-1]",
270 )
271 .into()),
272 Decoded::PayloadChunk(..) => unreachable!(),
273 }
274 }
275}