// ntex_mqtt/v5/client/connector.rs

1use std::{num::NonZeroU16, num::NonZeroU32, rc::Rc};
2
3use ntex_bytes::{ByteString, Bytes, PoolId};
4use ntex_io::{DispatcherConfig, IoBoxed};
5use ntex_net::connect::{self, Address, Connect, Connector};
6use ntex_service::{IntoService, Pipeline, Service};
7use ntex_util::time::{timeout_checked, Seconds};
8
9use super::codec::{self, Decoded, Encoded, Packet};
10use super::{connection::Client, error::ClientError, error::ProtocolError};
11use crate::v5::shared::{MqttShared, MqttSinkPool};
12
13/// Mqtt client connector
14pub struct MqttConnector<A, T> {
15    address: A,
16    connector: Pipeline<T>,
17    pkt: codec::Connect,
18    handshake_timeout: Seconds,
19    min_chunk_size: u32,
20    config: DispatcherConfig,
21    pool: Rc<MqttSinkPool>,
22}
23
24impl<A> MqttConnector<A, ()>
25where
26    A: Address + Clone,
27{
28    #[allow(clippy::new_ret_no_self)]
29    /// Create new mqtt connector
30    pub fn new(address: A) -> MqttConnector<A, Connector<A>> {
31        let config = DispatcherConfig::default();
32        config.set_disconnect_timeout(Seconds(3)).set_keepalive_timeout(Seconds(0));
33        MqttConnector {
34            address,
35            config,
36            pkt: codec::Connect::default(),
37            connector: Pipeline::new(Connector::default()),
38            handshake_timeout: Seconds::ZERO,
39            min_chunk_size: 32 * 1024,
40            pool: Rc::new(MqttSinkPool::default()),
41        }
42    }
43}
44
45impl<A, T> MqttConnector<A, T>
46where
47    A: Address + Clone,
48{
49    #[inline]
50    /// Create new client and provide client id
51    pub fn client_id<U>(mut self, client_id: U) -> Self
52    where
53        ByteString: From<U>,
54    {
55        self.pkt.client_id = client_id.into();
56        self
57    }
58
59    #[inline]
60    /// The handling of the Session state.
61    pub fn clean_start(mut self) -> Self {
62        self.pkt.clean_start = true;
63        self
64    }
65
66    #[inline]
67    /// A time interval measured in seconds.
68    ///
69    /// keep-alive is set to 30 seconds by default.
70    pub fn keep_alive(mut self, val: Seconds) -> Self {
71        self.pkt.keep_alive = val.seconds() as u16;
72        self
73    }
74
75    #[inline]
76    /// Will Message be stored on the Server and associated with the Network Connection.
77    ///
78    /// by default last will value is not set
79    pub fn last_will(mut self, val: codec::LastWill) -> Self {
80        self.pkt.last_will = Some(val);
81        self
82    }
83
84    #[inline]
85    /// Set auth-method and auth-data for connect packet.
86    pub fn auth(mut self, method: ByteString, data: Bytes) -> Self {
87        self.pkt.auth_method = Some(method);
88        self.pkt.auth_data = Some(data);
89        self
90    }
91
92    #[inline]
93    /// Username can be used by the Server for authentication and authorization.
94    pub fn username(mut self, val: ByteString) -> Self {
95        self.pkt.username = Some(val);
96        self
97    }
98
99    #[inline]
100    /// Password can be used by the Server for authentication and authorization.
101    pub fn password(mut self, val: Bytes) -> Self {
102        self.pkt.password = Some(val);
103        self
104    }
105
106    #[inline]
107    /// Max incoming packet size.
108    ///
109    /// To disable max size limit set value to 0.
110    pub fn max_packet_size(mut self, val: u32) -> Self {
111        if let Some(val) = NonZeroU32::new(val) {
112            self.pkt.max_packet_size = Some(val);
113        } else {
114            self.pkt.max_packet_size = None;
115        }
116        self
117    }
118
119    /// Set min payload chunk size.
120    ///
121    /// If the minimum size is set to `0`, incoming payload chunks
122    /// will be processed immediately. Otherwise, the codec will
123    /// accumulate chunks until the total size reaches the specified minimum.
124    /// By default min size is set to `0`
125    pub fn min_chunk_size(mut self, size: u32) -> Self {
126        self.min_chunk_size = size;
127        self
128    }
129
130    #[inline]
131    /// Set `receive max`
132    ///
133    /// Number of in-flight incoming publish packets. By default receive max is set to 16 packets.
134    /// To disable in-flight limit set value to 0.
135    pub fn max_receive(mut self, val: u16) -> Self {
136        if let Some(val) = NonZeroU16::new(val) {
137            self.pkt.receive_max = Some(val);
138        } else {
139            self.pkt.receive_max = None;
140        }
141        self
142    }
143
144    #[inline]
145    /// Update connect user properties
146    pub fn properties<F>(mut self, f: F) -> Self
147    where
148        F: FnOnce(&mut codec::UserProperties),
149    {
150        f(&mut self.pkt.user_properties);
151        self
152    }
153
154    #[inline]
155    /// Update connect packet
156    pub fn packet<F>(mut self, f: F) -> Self
157    where
158        F: FnOnce(&mut codec::Connect),
159    {
160        f(&mut self.pkt);
161        self
162    }
163
164    /// Set handshake timeout.
165    ///
166    /// Handshake includes `connect` packet and response `connect-ack`.
167    /// By default handshake timeuot is disabled.
168    pub fn handshake_timeout(mut self, timeout: Seconds) -> Self {
169        self.handshake_timeout = timeout;
170        self
171    }
172
173    /// Set client connection disconnect timeout.
174    ///
175    /// Defines a timeout for disconnect connection. If a disconnect procedure does not complete
176    /// within this time, the connection get dropped.
177    ///
178    /// To disable timeout set value to 0.
179    ///
180    /// By default disconnect timeout is set to 3 seconds.
181    pub fn disconnect_timeout(self, timeout: Seconds) -> Self {
182        self.config.set_disconnect_timeout(timeout);
183        self
184    }
185
186    /// Set memory pool.
187    ///
188    /// Use specified memory pool for memory allocations. By default P5
189    /// memory pool is used.
190    pub fn memory_pool(self, id: PoolId) -> Self {
191        self.pool.pool.set(id.pool_ref());
192        self
193    }
194
195    /// Use custom connector
196    pub fn connector<U, F>(self, connector: F) -> MqttConnector<A, U>
197    where
198        F: IntoService<U, Connect<A>>,
199        U: Service<Connect<A>, Error = connect::ConnectError>,
200        IoBoxed: From<U::Response>,
201    {
202        MqttConnector {
203            connector: Pipeline::new(connector.into_service()),
204            pkt: self.pkt,
205            address: self.address,
206            config: self.config,
207            handshake_timeout: self.handshake_timeout,
208            min_chunk_size: self.min_chunk_size,
209            pool: self.pool,
210        }
211    }
212}
213
214impl<A, T> MqttConnector<A, T>
215where
216    A: Address + Clone,
217    T: Service<Connect<A>, Error = connect::ConnectError>,
218    IoBoxed: From<T::Response>,
219{
220    /// Connect to mqtt server
221    pub async fn connect(&self) -> Result<Client, ClientError<Box<codec::ConnectAck>>> {
222        timeout_checked(self.handshake_timeout, self._connect())
223            .await
224            .map_err(|_| ClientError::HandshakeTimeout)
225            .and_then(|res| res)
226    }
227
228    async fn _connect(&self) -> Result<Client, ClientError<Box<codec::ConnectAck>>> {
229        let io: IoBoxed = self.connector.call(Connect::new(self.address.clone())).await?.into();
230        let pkt = self.pkt.clone();
231        let keep_alive = pkt.keep_alive;
232        let max_packet_size = pkt.max_packet_size.map(|v| v.get()).unwrap_or(0);
233        let max_receive = pkt.receive_max.map(|v| v.get()).unwrap_or(65535);
234        let pool = self.pool.clone();
235        let config = self.config.clone();
236
237        let codec = codec::Codec::new();
238        codec.set_max_inbound_size(max_packet_size);
239        codec.set_min_chunk_size(self.min_chunk_size);
240
241        io.encode(Encoded::Packet(Packet::Connect(Box::new(pkt))), &codec)?;
242
243        let packet = io.recv(&codec).await.map_err(ClientError::from)?.ok_or_else(|| {
244            log::trace!("Mqtt server is disconnected during handshake");
245            ClientError::Disconnected(None)
246        })?;
247
248        let shared = Rc::new(MqttShared::new(io.get_ref(), codec, pool));
249        match packet {
250            Decoded::Packet(Packet::ConnectAck(pkt), ..) => {
251                log::trace!("Connect ack response from server: {:#?}", pkt);
252                if pkt.reason_code == codec::ConnectAckReason::Success {
253                    // set max outbound (encoder) packet size
254                    if let Some(size) = pkt.max_packet_size {
255                        shared.codec.set_max_outbound_size(size);
256                    }
257                    // server keep-alive
258                    let keep_alive = pkt.server_keepalive_sec.unwrap_or(keep_alive);
259
260                    shared.set_cap(pkt.receive_max.get() as usize);
261
262                    Ok(Client::new(io, shared, pkt, max_receive, Seconds(keep_alive), config))
263                } else {
264                    Err(ClientError::Ack(pkt))
265                }
266            }
267            Decoded::Packet(packet, ..) => Err(ProtocolError::unexpected_packet(
268                packet.packet_type(),
269                "CONNACK packet expected from server first [MQTT-3.2.0-1]",
270            )
271            .into()),
272            Decoded::Publish(..) => Err(ProtocolError::unexpected_packet(
273                crate::types::packet_type::PUBLISH_START,
274                "CONNACK packet expected from server first [MQTT-3.2.0-1]",
275            )
276            .into()),
277            Decoded::PayloadChunk(..) => unreachable!(),
278        }
279    }
280}