1use crate::{framed::Network, Transport};
2use crate::{Incoming, MqttState, NetworkOptions, Packet, Request, StateError};
3use crate::{MqttOptions, Outgoing};
4
5use crate::framed::AsyncReadWrite;
6use crate::mqttbytes::v4::*;
7use flume::{bounded, Receiver, Sender};
8use tokio::net::{lookup_host, TcpSocket, TcpStream};
9use tokio::select;
10use tokio::time::{self, Instant, Sleep};
11
12use std::collections::VecDeque;
13use std::io;
14use std::net::SocketAddr;
15use std::pin::Pin;
16use std::time::Duration;
17
18#[cfg(unix)]
19use {std::path::Path, tokio::net::UnixStream};
20
21#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
22use crate::tls;
23
24#[cfg(feature = "websocket")]
25use {
26 crate::websockets::{split_url, validate_response_headers, UrlError},
27 async_tungstenite::tungstenite::client::IntoClientRequest,
28 ws_stream_tungstenite::WsStream,
29};
30
31#[cfg(feature = "proxy")]
32use crate::proxy::ProxyError;
33
34#[derive(Debug, thiserror::Error)]
36pub enum ConnectionError {
37 #[error("Mqtt state: {0}")]
38 MqttState(#[from] StateError),
39 #[error("Network timeout")]
40 NetworkTimeout,
41 #[error("Flush timeout")]
42 FlushTimeout,
43 #[cfg(feature = "websocket")]
44 #[error("Websocket: {0}")]
45 Websocket(#[from] async_tungstenite::tungstenite::error::Error),
46 #[cfg(feature = "websocket")]
47 #[error("Websocket Connect: {0}")]
48 WsConnect(#[from] http::Error),
49 #[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
50 #[error("TLS: {0}")]
51 Tls(#[from] tls::Error),
52 #[error("I/O: {0}")]
53 Io(#[from] io::Error),
54 #[error("Connection refused, return code: `{0:?}`")]
55 ConnectionRefused(ConnectReturnCode),
56 #[error("Expected ConnAck packet, received: {0:?}")]
57 NotConnAck(Packet),
58 #[error("Requests done")]
59 RequestsDone,
60 #[cfg(feature = "websocket")]
61 #[error("Invalid Url: {0}")]
62 InvalidUrl(#[from] UrlError),
63 #[cfg(feature = "proxy")]
64 #[error("Proxy Connect: {0}")]
65 Proxy(#[from] ProxyError),
66 #[cfg(feature = "websocket")]
67 #[error("Websocket response validation error: ")]
68 ResponseValidation(#[from] crate::websockets::ValidationError),
69}
70
71pub struct EventLoop {
73 pub mqtt_options: MqttOptions,
75 pub state: MqttState,
77 requests_rx: Receiver<Request>,
79 pub(crate) requests_tx: Sender<Request>,
81 pub pending: VecDeque<Request>,
83 pub network: Option<Network>,
85 keepalive_timeout: Option<Pin<Box<Sleep>>>,
87 pub network_options: NetworkOptions,
88}
89
90#[derive(Debug, Clone, PartialEq, Eq)]
92pub enum Event {
93 Incoming(Incoming),
94 Outgoing(Outgoing),
95}
96
97impl EventLoop {
98 pub fn new(mqtt_options: MqttOptions, cap: usize) -> EventLoop {
103 let (requests_tx, requests_rx) = bounded(cap);
104 let pending = VecDeque::new();
105 let max_inflight = mqtt_options.inflight;
106 let manual_acks = mqtt_options.manual_acks;
107
108 EventLoop {
109 mqtt_options,
110 state: MqttState::new(max_inflight, manual_acks),
111 requests_tx,
112 requests_rx,
113 pending,
114 network: None,
115 keepalive_timeout: None,
116 network_options: NetworkOptions::new(),
117 }
118 }
119
120 pub fn clean(&mut self) {
128 self.network = None;
129 self.keepalive_timeout = None;
130 self.pending.extend(self.state.clean());
131
132 let mut requests_in_channel: Vec<_> = self.requests_rx.drain().collect();
134
135 requests_in_channel.retain(|request| {
136 match request {
137 Request::PubAck(..) => false, _ => true,
139 }
140 });
141
142 self.pending.extend(requests_in_channel);
143 }
144
145 pub async fn poll(&mut self) -> Result<Event, ConnectionError> {
150 if self.network.is_none() {
151 let (network, connack) = match time::timeout(
152 Duration::from_secs(self.network_options.connection_timeout()),
153 connect(&self.mqtt_options, self.network_options.clone()),
154 )
155 .await
156 {
157 Ok(inner) => inner?,
158 Err(_) => return Err(ConnectionError::NetworkTimeout),
159 };
160 if !connack.session_present {
162 self.pending.clear();
163 }
164 self.network = Some(network);
165
166 if self.keepalive_timeout.is_none() && !self.mqtt_options.keep_alive.is_zero() {
167 self.keepalive_timeout = Some(Box::pin(time::sleep(self.mqtt_options.keep_alive)));
168 }
169
170 return Ok(Event::Incoming(Packet::ConnAck(connack)));
171 }
172
173 match self.select().await {
174 Ok(v) => Ok(v),
175 Err(e) => {
176 self.clean();
179 Err(e)
180 }
181 }
182 }
183
184 async fn select(&mut self) -> Result<Event, ConnectionError> {
186 let network = self.network.as_mut().unwrap();
187 let inflight_full = self.state.inflight >= self.mqtt_options.inflight;
189 let collision = self.state.collision.is_some();
190 let network_timeout = Duration::from_secs(self.network_options.connection_timeout());
191
192 if let Some(event) = self.state.events.pop_front() {
194 return Ok(event);
195 }
196
197 let mut no_sleep = Box::pin(time::sleep(Duration::ZERO));
198 select! {
201 o = network.readb(&mut self.state) => {
203 o?;
204 match time::timeout(network_timeout, network.flush()).await {
206 Ok(inner) => inner?,
207 Err(_)=> return Err(ConnectionError::FlushTimeout),
208 };
209 Ok(self.state.events.pop_front().unwrap())
210 },
211 o = Self::next_request(
240 &mut self.pending,
241 &self.requests_rx,
242 self.mqtt_options.pending_throttle
243 ), if !self.pending.is_empty() || (!inflight_full && !collision) => match o {
244 Ok(request) => {
245 if let Some(outgoing) = self.state.handle_outgoing_packet(request)? {
246 network.write(outgoing).await?;
247 }
248 match time::timeout(network_timeout, network.flush()).await {
249 Ok(inner) => inner?,
250 Err(_)=> return Err(ConnectionError::FlushTimeout),
251 };
252 Ok(self.state.events.pop_front().unwrap())
253 }
254 Err(_) => Err(ConnectionError::RequestsDone),
255 },
256 _ = self.keepalive_timeout.as_mut().unwrap_or(&mut no_sleep),
259 if self.keepalive_timeout.is_some() && !self.mqtt_options.keep_alive.is_zero() => {
260 let timeout = self.keepalive_timeout.as_mut().unwrap();
261 timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive);
262
263 if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq)? {
264 network.write(outgoing).await?;
265 }
266 match time::timeout(network_timeout, network.flush()).await {
267 Ok(inner) => inner?,
268 Err(_)=> return Err(ConnectionError::FlushTimeout),
269 };
270 Ok(self.state.events.pop_front().unwrap())
271 }
272 }
273 }
274
275 pub fn network_options(&self) -> NetworkOptions {
276 self.network_options.clone()
277 }
278
279 pub fn set_network_options(&mut self, network_options: NetworkOptions) -> &mut Self {
280 self.network_options = network_options;
281 self
282 }
283
284 async fn next_request(
285 pending: &mut VecDeque<Request>,
286 rx: &Receiver<Request>,
287 pending_throttle: Duration,
288 ) -> Result<Request, ConnectionError> {
289 if !pending.is_empty() {
290 time::sleep(pending_throttle).await;
291 Ok(pending.pop_front().unwrap())
294 } else {
295 match rx.recv_async().await {
296 Ok(r) => Ok(r),
297 Err(_) => Err(ConnectionError::RequestsDone),
298 }
299 }
300 }
301}
302
303async fn connect(
309 mqtt_options: &MqttOptions,
310 network_options: NetworkOptions,
311) -> Result<(Network, ConnAck), ConnectionError> {
312 let mut network = network_connect(mqtt_options, network_options).await?;
314
315 let connack = mqtt_connect(mqtt_options, &mut network).await?;
317
318 Ok((network, connack))
319}
320
321pub(crate) async fn socket_connect(
322 host: String,
323 network_options: NetworkOptions,
324) -> io::Result<TcpStream> {
325 let addrs = lookup_host(host).await?;
326 let mut last_err = None;
327
328 for addr in addrs {
329 let socket = match addr {
330 SocketAddr::V4(_) => TcpSocket::new_v4()?,
331 SocketAddr::V6(_) => TcpSocket::new_v6()?,
332 };
333
334 socket.set_nodelay(network_options.tcp_nodelay)?;
335
336 if let Some(send_buff_size) = network_options.tcp_send_buffer_size {
337 socket.set_send_buffer_size(send_buff_size).unwrap();
338 }
339 if let Some(recv_buffer_size) = network_options.tcp_recv_buffer_size {
340 socket.set_recv_buffer_size(recv_buffer_size).unwrap();
341 }
342
343 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
344 {
345 if let Some(bind_device) = &network_options.bind_device {
346 socket.bind_device(Some(bind_device.as_bytes()))?;
350 }
351 }
352
353 match socket.connect(addr).await {
354 Ok(s) => return Ok(s),
355 Err(e) => {
356 last_err = Some(e);
357 }
358 };
359 }
360
361 Err(last_err.unwrap_or_else(|| {
362 io::Error::new(
363 io::ErrorKind::InvalidInput,
364 "could not resolve to any address",
365 )
366 }))
367}
368
369async fn network_connect(
370 options: &MqttOptions,
371 network_options: NetworkOptions,
372) -> Result<Network, ConnectionError> {
373 #[cfg(unix)]
375 if matches!(options.transport(), Transport::Unix) {
376 let file = options.broker_addr.as_str();
377 let socket = UnixStream::connect(Path::new(file)).await?;
378 let network = Network::new(
379 socket,
380 options.max_incoming_packet_size,
381 options.max_outgoing_packet_size,
382 );
383 return Ok(network);
384 }
385
386 let (domain, port) = match options.transport() {
388 #[cfg(feature = "websocket")]
389 Transport::Ws => split_url(&options.broker_addr)?,
390 #[cfg(all(feature = "use-rustls-no-provider", feature = "websocket"))]
391 Transport::Wss(_) => split_url(&options.broker_addr)?,
392 _ => options.broker_address(),
393 };
394
395 let tcp_stream: Box<dyn AsyncReadWrite> = {
396 #[cfg(feature = "proxy")]
397 match options.proxy() {
398 Some(proxy) => proxy.connect(&domain, port, network_options).await?,
399 None => {
400 let addr = format!("{domain}:{port}");
401 let tcp = socket_connect(addr, network_options).await?;
402 Box::new(tcp)
403 }
404 }
405 #[cfg(not(feature = "proxy"))]
406 {
407 let addr = format!("{domain}:{port}");
408 let tcp = socket_connect(addr, network_options).await?;
409 Box::new(tcp)
410 }
411 };
412
413 let network = match options.transport() {
414 Transport::Tcp => Network::new(
415 tcp_stream,
416 options.max_incoming_packet_size,
417 options.max_outgoing_packet_size,
418 ),
419 #[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
420 Transport::Tls(tls_config) => {
421 let socket =
422 tls::tls_connect(&options.broker_addr, options.port, &tls_config, tcp_stream)
423 .await?;
424 Network::new(
425 socket,
426 options.max_incoming_packet_size,
427 options.max_outgoing_packet_size,
428 )
429 }
430 #[cfg(unix)]
431 Transport::Unix => unreachable!(),
432 #[cfg(feature = "websocket")]
433 Transport::Ws => {
434 let mut request = options.broker_addr.as_str().into_client_request()?;
435 request
436 .headers_mut()
437 .insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());
438
439 if let Some(request_modifier) = options.request_modifier() {
440 request = request_modifier(request).await;
441 }
442
443 let (socket, response) =
444 async_tungstenite::tokio::client_async(request, tcp_stream).await?;
445 validate_response_headers(response)?;
446
447 Network::new(
448 WsStream::new(socket),
449 options.max_incoming_packet_size,
450 options.max_outgoing_packet_size,
451 )
452 }
453 #[cfg(all(feature = "use-rustls-no-provider", feature = "websocket"))]
454 Transport::Wss(tls_config) => {
455 let mut request = options.broker_addr.as_str().into_client_request()?;
456 request
457 .headers_mut()
458 .insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());
459
460 if let Some(request_modifier) = options.request_modifier() {
461 request = request_modifier(request).await;
462 }
463
464 let connector = tls::rustls_connector(&tls_config).await?;
465
466 let (socket, response) = async_tungstenite::tokio::client_async_tls_with_connector(
467 request,
468 tcp_stream,
469 Some(connector),
470 )
471 .await?;
472 validate_response_headers(response)?;
473
474 Network::new(
475 WsStream::new(socket),
476 options.max_incoming_packet_size,
477 options.max_outgoing_packet_size,
478 )
479 }
480 };
481
482 Ok(network)
483}
484
485async fn mqtt_connect(
486 options: &MqttOptions,
487 network: &mut Network,
488) -> Result<ConnAck, ConnectionError> {
489 let mut connect = Connect::new(options.client_id());
490 connect.keep_alive = options.keep_alive().as_secs() as u16;
491 connect.clean_session = options.clean_session();
492 connect.last_will = options.last_will();
493 connect.login = options.credentials();
494
495 network.write(Packet::Connect(connect)).await?;
497 network.flush().await?;
498
499 match network.read().await? {
501 Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => Ok(connack),
502 Incoming::ConnAck(connack) => Err(ConnectionError::ConnectionRefused(connack.code)),
503 packet => Err(ConnectionError::NotConnAck(packet)),
504 }
505}