use super::framed::Network;
use super::mqttbytes::v5::*;
use super::{Incoming, MqttOptions, MqttState, Outgoing, Request, StateError, Transport};
use crate::eventloop::socket_connect;
use crate::framed::AsyncReadWrite;
use flume::{bounded, Receiver, Sender};
use tokio::select;
use tokio::time::{self, error::Elapsed, Instant, Sleep};
use std::collections::VecDeque;
use std::io;
use std::pin::Pin;
use std::time::Duration;
use super::mqttbytes::v5::ConnectReturnCode;
#[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
use crate::tls;
#[cfg(unix)]
use {std::path::Path, tokio::net::UnixStream};
#[cfg(feature = "websocket")]
use {
crate::websockets::{split_url, validate_response_headers, UrlError},
async_tungstenite::tungstenite::client::IntoClientRequest,
ws_stream_tungstenite::WsStream,
};
#[cfg(feature = "proxy")]
use crate::proxy::ProxyError;
/// Errors produced while connecting to the broker or driving the event loop.
///
/// Several variants only exist when the corresponding cargo feature
/// (`websocket`, `proxy`, TLS backends) is enabled.
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
    /// The MQTT protocol state rejected an incoming or outgoing packet.
    #[error("Mqtt state: {0}")]
    MqttState(#[from] StateError),
    /// A deadline elapsed (e.g. connection establishment exceeding `connection_timeout`).
    #[error("Timeout")]
    Timeout(#[from] Elapsed),
    /// WebSocket protocol or handshake failure.
    #[cfg(feature = "websocket")]
    #[error("Websocket: {0}")]
    Websocket(#[from] async_tungstenite::tungstenite::error::Error),
    /// Building the WebSocket HTTP upgrade request failed.
    #[cfg(feature = "websocket")]
    #[error("Websocket Connect: {0}")]
    WsConnect(#[from] http::Error),
    /// TLS setup or handshake failure.
    #[cfg(any(feature = "use-rustls-no-provider", feature = "use-native-tls"))]
    #[error("TLS: {0}")]
    Tls(#[from] tls::Error),
    /// Underlying socket I/O failure.
    #[error("I/O: {0}")]
    Io(#[from] io::Error),
    /// The broker answered CONNECT with a non-success ConnAck code.
    #[error("Connection refused, return code: `{0:?}`")]
    ConnectionRefused(ConnectReturnCode),
    /// A packet other than ConnAck/Auth arrived during the connect handshake.
    #[error("Expected ConnAck packet, received: {0:?}")]
    NotConnAck(Box<Packet>),
    /// The request channel is closed and no pending requests remain.
    #[error("Requests done")]
    RequestsDone,
    /// An AUTH packet received during connect produced no reply to send back.
    #[error("Auth processing error")]
    AuthProcessingError,
    /// The broker address could not be parsed as a WebSocket URL.
    #[cfg(feature = "websocket")]
    #[error("Invalid Url: {0}")]
    InvalidUrl(#[from] UrlError),
    /// Connecting through the configured proxy failed.
    #[cfg(feature = "proxy")]
    #[error("Proxy Connect: {0}")]
    Proxy(#[from] ProxyError),
    /// The WebSocket upgrade response failed header validation.
    // `{0}` is required: without it the validation failure itself is dropped from the
    // message and only the fixed prefix is shown.
    #[cfg(feature = "websocket")]
    #[error("Websocket response validation error: {0}")]
    ResponseValidation(#[from] crate::websockets::ValidationError),
}
/// Drives an MQTT v5 connection: (re)connects on demand, forwards user
/// [`Request`]s to the broker, reads incoming packets and sends keep-alive pings.
///
/// Each call to [`EventLoop::poll`] yields the next [`Event`].
pub struct EventLoop {
    /// Connection options. On connect, `keep_alive` and the session expiry interval
    /// may be overwritten with values from the broker's ConnAck properties.
    pub options: MqttOptions,
    /// Protocol state: inflight accounting, packet-id collision tracking and the
    /// queue of events handed out by `poll`.
    pub state: MqttState,
    /// Receiving end of the user request channel.
    requests_rx: Receiver<Request>,
    /// Sender retained by `EventLoop::new` so the channel never disconnects;
    /// `None` when the caller owns the sender (`new_for_async_client`).
    _requests_tx: Option<Sender<Request>>,
    /// Requests awaiting (re)transmission; served before new channel requests and
    /// regardless of inflight/collision flow control.
    pub pending: VecDeque<Request>,
    /// Open connection, or `None` while disconnected (the next `poll` reconnects).
    network: Option<Network>,
    /// Keep-alive timer; each time it fires a PingReq is sent and the timer is re-armed.
    /// `None` until connected with a non-zero keep-alive, and reset to `None` by `clean`.
    keepalive_timeout: Option<Pin<Box<Sleep>>>,
}
/// Item yielded by [`EventLoop::poll`]: a packet received from, or sent to, the broker.
#[derive(Debug, Clone, PartialEq, Eq)]
// NOTE(review): presumably `Incoming` (a full packet) is much larger than `Outgoing`;
// the lint is allowed rather than boxing so each event avoids an extra allocation — confirm.
#[allow(clippy::large_enum_variant)]
pub enum Event {
    /// A packet read from the network.
    Incoming(Incoming),
    /// A packet written to the network.
    Outgoing(Outgoing),
}
impl EventLoop {
pub fn new(options: MqttOptions, cap: usize) -> EventLoop {
let (requests_tx, requests_rx) = bounded(cap);
Self::with_channel(options, requests_rx, Some(requests_tx))
}
pub(crate) fn new_for_async_client(
options: MqttOptions,
cap: usize,
) -> (EventLoop, Sender<Request>) {
let (requests_tx, requests_rx) = bounded(cap);
let eventloop = Self::with_channel(options, requests_rx, None);
(eventloop, requests_tx)
}
fn with_channel(
options: MqttOptions,
requests_rx: Receiver<Request>,
requests_tx: Option<Sender<Request>>,
) -> EventLoop {
let pending = VecDeque::new();
let inflight_limit = options.outgoing_inflight_upper_limit.unwrap_or(u16::MAX);
let manual_acks = options.manual_acks;
let auth_manager = options.auth_manager();
EventLoop {
options,
state: MqttState::new(inflight_limit, manual_acks, auth_manager),
requests_rx,
_requests_tx: requests_tx,
pending,
network: None,
keepalive_timeout: None,
}
}
pub fn clean(&mut self) {
self.network = None;
self.keepalive_timeout = None;
self.pending.extend(self.state.clean());
let mut requests_in_channel: Vec<_> = self.requests_rx.drain().collect();
requests_in_channel.retain(|request| {
match request {
Request::PubAck(_) => false, _ => true,
}
});
self.pending.extend(requests_in_channel);
}
pub async fn poll(&mut self) -> Result<Event, ConnectionError> {
if self.network.is_none() {
let (network, connack) = time::timeout(
Duration::from_secs(self.options.connection_timeout()),
connect(&mut self.options, &mut self.state),
)
.await??;
if !connack.session_present {
self.pending.clear();
self.state.clear_collision();
}
self.network = Some(network);
if self.keepalive_timeout.is_none() && !self.options.keep_alive.is_zero() {
self.keepalive_timeout = Some(Box::pin(time::sleep(self.options.keep_alive)));
}
self.state
.handle_incoming_packet(Incoming::ConnAck(connack))?;
}
match self.select().await {
Ok(v) => Ok(v),
Err(e) => {
self.clean();
Err(e)
}
}
}
async fn select(&mut self) -> Result<Event, ConnectionError> {
let network = self.network.as_mut().unwrap();
let inflight_full = self.state.inflight >= self.state.max_outgoing_inflight;
let collision = self.state.collision.is_some();
if let Some(event) = self.state.events.pop_front() {
return Ok(event);
}
let mut no_sleep = Box::pin(time::sleep(Duration::ZERO));
select! {
o = Self::next_request(
&mut self.pending,
&self.requests_rx,
self.options.pending_throttle
), if !self.pending.is_empty() || (!inflight_full && !collision) => match o {
Ok(request) => {
if let Some(outgoing) = self.state.handle_outgoing_packet(request)? {
network.write(outgoing).await?;
}
network.flush().await?;
Ok(self.state.events.pop_front().unwrap())
}
Err(_) => Err(ConnectionError::RequestsDone),
},
o = network.readb(&mut self.state) => {
o?;
network.flush().await?;
Ok(self.state.events.pop_front().unwrap())
},
_ = self.keepalive_timeout.as_mut().unwrap_or(&mut no_sleep),
if self.keepalive_timeout.is_some() && !self.options.keep_alive.is_zero() => {
let timeout = self.keepalive_timeout.as_mut().unwrap();
timeout.as_mut().reset(Instant::now() + self.options.keep_alive);
if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq)? {
network.write(outgoing).await?;
}
network.flush().await?;
Ok(self.state.events.pop_front().unwrap())
}
}
}
async fn next_request(
pending: &mut VecDeque<Request>,
rx: &Receiver<Request>,
pending_throttle: Duration,
) -> Result<Request, ConnectionError> {
if !pending.is_empty() {
time::sleep(pending_throttle).await;
Ok(pending.pop_front().unwrap())
} else {
match rx.recv_async().await {
Ok(r) => Ok(r),
Err(_) => Err(ConnectionError::RequestsDone),
}
}
}
}
/// Opens the transport (`network_connect`) and then runs the MQTT CONNECT handshake
/// over it (`mqtt_connect`).
///
/// Returns the ready-to-use [`Network`] together with the broker's successful
/// [`ConnAck`]. `options` may be updated from the ConnAck properties along the way.
async fn connect(
    options: &mut MqttOptions,
    state: &mut MqttState,
) -> Result<(Network, ConnAck), ConnectionError> {
    let mut transport = network_connect(options).await?;
    let handshake = mqtt_connect(options, &mut transport, state).await;
    handshake.map(|ack| (transport, ack))
}
/// Opens the transport-level connection selected by `options.transport()` and wraps
/// it in a [`Network`] that rejects incoming packets above
/// `options.max_incoming_packet_size()`.
///
/// Supported transports depend on enabled features: Unix socket, TCP (optionally
/// through a proxy), TLS, WebSocket and secure WebSocket. No MQTT packets are
/// exchanged here.
///
/// # Errors
///
/// Socket, URL-parsing, proxy, TLS and WebSocket handshake/validation failures, each
/// mapped to the matching [`ConnectionError`] variant.
async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionError> {
    let max_incoming_pkt_size = options.max_incoming_packet_size();

    // Unix domain sockets skip TCP/proxy/TLS entirely; `broker_addr` is the socket path.
    #[cfg(unix)]
    if matches!(options.transport(), Transport::Unix) {
        let file = options.broker_addr.as_str();
        let socket = UnixStream::connect(Path::new(file)).await?;
        let network = Network::new(socket, max_incoming_pkt_size);
        return Ok(network);
    }

    // WebSocket transports take host and port from the URL in `broker_addr`;
    // the other transports use the configured broker address.
    let (domain, port) = match options.transport() {
        #[cfg(feature = "websocket")]
        Transport::Ws => split_url(&options.broker_addr)?,
        #[cfg(all(feature = "use-rustls-no-provider", feature = "websocket"))]
        Transport::Wss(_) => split_url(&options.broker_addr)?,
        _ => options.broker_address(),
    };

    // Raw byte stream: through the configured proxy when the `proxy` feature is enabled
    // and a proxy is set, otherwise a direct socket to `domain:port`.
    let tcp_stream: Box<dyn AsyncReadWrite> = {
        #[cfg(feature = "proxy")]
        match options.proxy() {
            Some(proxy) => {
                proxy
                    .connect(&domain, port, options.network_options())
                    .await?
            }
            None => {
                let addr = format!("{domain}:{port}");
                let tcp = socket_connect(addr, options.network_options()).await?;
                Box::new(tcp)
            }
        }
        #[cfg(not(feature = "proxy"))]
        {
            let addr = format!("{domain}:{port}");
            let tcp = socket_connect(addr, options.network_options()).await?;
            Box::new(tcp)
        }
    };

    // Layer the selected protocol over the raw stream.
    let network = match options.transport() {
        Transport::Tcp => Network::new(tcp_stream, max_incoming_pkt_size),
        // NOTE(review): TLS passes `broker_addr`/`port` rather than the `(domain, port)`
        // pair computed above — presumably identical for plain TLS; confirm.
        #[cfg(any(feature = "use-native-tls", feature = "use-rustls-no-provider"))]
        Transport::Tls(tls_config) => {
            let socket =
                tls::tls_connect(&options.broker_addr, options.port, &tls_config, tcp_stream)
                    .await?;
            Network::new(socket, max_incoming_pkt_size)
        }
        // Already handled by the early return above.
        #[cfg(unix)]
        Transport::Unix => unreachable!(),
        #[cfg(feature = "websocket")]
        Transport::Ws => {
            let mut request = options.broker_addr.as_str().into_client_request()?;
            // Ask for the `mqtt` subprotocol. The literal is a valid header value, so the
            // unwrap cannot fail.
            request
                .headers_mut()
                .insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());
            // Optional user hook to adjust the upgrade request (e.g. extra headers).
            if let Some(request_modifier) = options.request_modifier() {
                request = request_modifier(request).await;
            }
            let (socket, response) =
                async_tungstenite::tokio::client_async(request, tcp_stream).await?;
            validate_response_headers(response)?;
            Network::new(WsStream::new(socket), max_incoming_pkt_size)
        }
        // Same request setup as `Ws`, but the upgrade runs over a rustls connection.
        #[cfg(all(feature = "use-rustls-no-provider", feature = "websocket"))]
        Transport::Wss(tls_config) => {
            let mut request = options.broker_addr.as_str().into_client_request()?;
            request
                .headers_mut()
                .insert("Sec-WebSocket-Protocol", "mqtt".parse().unwrap());
            if let Some(request_modifier) = options.request_modifier() {
                request = request_modifier(request).await;
            }
            let connector = tls::rustls_connector(&tls_config)?;
            let (socket, response) = async_tungstenite::tokio::client_async_tls_with_connector(
                request,
                tcp_stream,
                Some(connector),
            )
            .await?;
            validate_response_headers(response)?;
            Network::new(WsStream::new(socket), max_incoming_pkt_size)
        }
    };
    Ok(network)
}
/// Performs the MQTT CONNECT handshake over an already-open `network`.
///
/// Sends CONNECT, built from `options` including the last will and credentials. Then
/// reads packets until the handshake resolves:
/// - A successful ConnAck is returned. Before that, its properties (when present) are
///   applied: `server_keep_alive` overrides `options.keep_alive`, the maximum packet
///   size caps outgoing packets, and the session expiry interval is stored.
/// - A ConnAck with any other code yields [`ConnectionError::ConnectionRefused`].
/// - AUTH packets go to `state`. Its reply is sent and the loop continues; no reply
///   yields [`ConnectionError::AuthProcessingError`].
/// - Any other packet yields [`ConnectionError::NotConnAck`].
async fn mqtt_connect(
    options: &mut MqttOptions,
    network: &mut Network,
    state: &mut MqttState,
) -> Result<ConnAck, ConnectionError> {
    let connect_packet = Connect {
        client_id: options.client_id(),
        // Whole seconds in a u16; longer intervals saturate at u16::MAX.
        keep_alive: u16::try_from(options.keep_alive().as_secs()).unwrap_or(u16::MAX),
        clean_start: options.clean_start(),
        properties: options.connect_properties(),
    };
    let will = options.last_will();
    let login = options.credentials();
    network.write(Packet::Connect(connect_packet, will, login)).await?;
    network.flush().await?;

    loop {
        let received = network.read().await?;
        match received {
            Incoming::ConnAck(connack) => {
                if connack.code != ConnectReturnCode::Success {
                    return Err(ConnectionError::ConnectionRefused(connack.code));
                }
                if let Some(props) = connack.properties.as_ref() {
                    if let Some(secs) = props.server_keep_alive {
                        options.keep_alive = Duration::from_secs(u64::from(secs));
                    }
                    network.set_max_outgoing_size(props.max_packet_size);
                    if props.session_expiry_interval.is_some() {
                        options.set_session_expiry_interval(props.session_expiry_interval);
                    }
                }
                return Ok(connack);
            }
            Incoming::Auth(auth) => {
                let reply = state.handle_incoming_packet(Incoming::Auth(auth))?;
                match reply {
                    Some(outgoing) => {
                        network.write(outgoing).await?;
                        network.flush().await?;
                    }
                    None => return Err(ConnectionError::AuthProcessingError),
                }
            }
            other => return Err(ConnectionError::NotConnAck(Box::new(other))),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use flume::TryRecvError;

    /// Options shared by every test; none of these tests ever calls `poll`, so no
    /// connection is attempted.
    fn test_options() -> MqttOptions {
        MqttOptions::new("test-client", "localhost", 1883)
    }

    /// `EventLoop::new` holds its own sender, so an idle channel is `Empty`, not `Disconnected`.
    #[test]
    fn eventloop_new_keeps_internal_sender_alive() {
        let eventloop = EventLoop::new(test_options(), 1);
        let received = eventloop.requests_rx.try_recv();
        assert!(matches!(received, Err(TryRecvError::Empty)));
    }

    /// With the async-client constructor the caller owns the only sender; dropping it
    /// disconnects the channel.
    #[test]
    fn async_client_constructor_path_allows_channel_shutdown() {
        let (eventloop, request_tx) = EventLoop::new_for_async_client(test_options(), 1);
        drop(request_tx);
        let received = eventloop.requests_rx.try_recv();
        assert!(matches!(received, Err(TryRecvError::Disconnected)));
    }

    /// Pending requests are served first; once drained, a closed channel yields `RequestsDone`.
    #[tokio::test]
    async fn async_client_path_reports_requests_done_after_pending_drain() {
        let (mut eventloop, request_tx) = EventLoop::new_for_async_client(test_options(), 1);
        eventloop.pending.push_back(Request::PingReq);
        drop(request_tx);

        let first = EventLoop::next_request(
            &mut eventloop.pending,
            &eventloop.requests_rx,
            Duration::ZERO,
        )
        .await;
        assert!(matches!(first, Ok(Request::PingReq)));

        let second = EventLoop::next_request(
            &mut eventloop.pending,
            &eventloop.requests_rx,
            Duration::ZERO,
        )
        .await;
        assert!(matches!(second, Err(ConnectionError::RequestsDone)));
    }
}