use crate::{framed::Network, Transport};
use crate::{tls, Incoming, MqttState, Packet, Request, StateError};
use crate::{MqttOptions, Outgoing};
use async_channel::{bounded, Receiver, Sender};
#[cfg(feature = "websocket")]
use async_tungstenite::async_std::{connect_async, connect_async_with_tls_connector};
use mqttbytes::v4::*;
use async_std::net::TcpStream;
use async_std::future::TimeoutError;
use futures::select;
use futures::future::FutureExt;
use async_io::Timer;
#[cfg(feature = "websocket")]
use ws_stream_tungstenite::WsStream;
use std::io;
use std::pin::Pin;
use std::time::Duration;
use std::vec::IntoIter;
use crate::cond_fut::cond_fut;
/// Errors returned by `EventLoop::poll`, covering both the connection attempt and the
/// running connection.
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
/// Error reported by the MQTT protocol state (`MqttState`), e.g. `StateError::Collision`.
#[error("Mqtt state: {0}")]
MqttState(#[from] StateError),
/// Sending CONNECT or waiting for CONNACK took longer than `connection_timeout()` seconds
/// (raised by `async_std::future::timeout`).
#[error("Timeout")]
Timeout(#[from] TimeoutError),
/// Packet parsing error from `mqttbytes`.
/// NOTE(review): not constructed in this file and has no `#[from]`, so producers must wrap it
/// explicitly — confirm where it is created.
#[error("Packet parsing error: {0}")]
Mqtt4Bytes(mqttbytes::Error),
/// Error from the `tls` module while establishing a TLS / WSS transport.
#[error("Network: {0}")]
Network(#[from] tls::Error),
/// I/O failure: TCP / WebSocket connect errors, and a rejected or unexpected CONNACK
/// (reported with `io::ErrorKind::InvalidData`).
#[error("I/O: {0}")]
Io(#[from] io::Error),
/// NOTE(review): not constructed in this file; presumably signals the end of the incoming
/// packet stream — confirm against its producer.
#[error("Stream done")]
StreamDone,
/// Receiving from the user request channel failed.
/// NOTE(review): `EventLoop` keeps its own `requests_tx`, so the channel should never close
/// while the event loop is alive; this looks unreachable from `select` — confirm intent.
#[error("Requests done")]
RequestsDone,
/// The event loop's cancel channel yielded while connecting or waiting for events.
#[error("Cancel request by the user")]
Cancel,
}
/// MQTT event loop: owns the connection, the protocol state and the channels that feed it
/// requests. It makes progress only when `poll` is called repeatedly.
pub struct EventLoop {
/// Connection options (broker address, transport, keep-alive, inflight limit, throttling, ...).
pub options: MqttOptions,
/// MQTT protocol state; `select` pops yielded events from `state.events` and flushes
/// `state.write` to the network.
pub state: MqttState,
/// Receiving end of the bounded user request channel created in `EventLoop::new`.
pub requests_rx: Receiver<Request>,
/// Sending end of the request channel; clones are handed out by `handle`.
pub requests_tx: Sender<Request>,
/// Requests returned by `MqttState::clean` when the connection is torn down. They are replayed
/// (throttled by `options.pending_throttle`) before new user requests are accepted.
pub pending: IntoIter<Request>,
/// Active connection; `None` until `poll` connects, and reset to `None` by `clean`.
pub(crate) network: Option<Network>,
/// Keep-alive timer armed after connecting; when it fires, a `PingReq` is sent and the timer
/// is re-armed.
pub(crate) keepalive_timeout: Option<Pin<Box<Timer>>>,
/// Receiving end of the cancel channel; when it yields, `poll` returns
/// `ConnectionError::Cancel`.
pub(crate) cancel_rx: Receiver<()>,
/// Sending end of the cancel channel; clones are handed out by `cancel_handle`.
pub(crate) cancel_tx: Sender<()>,
}
/// An item yielded by `EventLoop::poll`.
#[derive(Debug, PartialEq, Clone)]
pub enum Event {
/// A packet received from the broker, including the CONNACK returned right after connecting.
Incoming(Incoming),
/// An outgoing-side event, presumably queued by `MqttState` while handling a request
/// (confirm in `MqttState`); yielded after the network flush.
Outgoing(Outgoing),
}
impl EventLoop {
    /// Creates a new, not yet connected, event loop.
    ///
    /// `cap` is the capacity of the bounded user request channel. `async_channel::bounded`
    /// panics on a zero capacity, so a `cap` of 0 is treated as 1. No connection is opened
    /// here; the first call to [`EventLoop::poll`] connects.
    pub fn new(options: MqttOptions, cap: usize) -> EventLoop {
        let (cancel_tx, cancel_rx) = bounded(5);
        // Clamp: `bounded(0)` panics, which would turn a bad argument into a crash.
        let (requests_tx, requests_rx) = bounded(cap.max(1));
        let pending = Vec::new().into_iter();
        let max_inflight = options.inflight;
        EventLoop {
            options,
            state: MqttState::new(max_inflight),
            requests_tx,
            requests_rx,
            pending,
            network: None,
            keepalive_timeout: None,
            cancel_rx,
            cancel_tx,
        }
    }

    /// Returns a sender for submitting requests to this event loop.
    pub fn handle(&self) -> Sender<Request> {
        self.requests_tx.clone()
    }

    /// Returns a sender that makes `poll` return `ConnectionError::Cancel` once it yields.
    ///
    /// Only a sender is cloned, so a shared borrow suffices.
    pub(crate) fn cancel_handle(&self) -> Sender<()> {
        self.cancel_tx.clone()
    }

    /// Drops the connection and keep-alive timer, and queues the requests returned by
    /// `MqttState::clean` for replay after the next reconnect.
    fn clean(&mut self) {
        self.network = None;
        self.keepalive_timeout = None;
        let pending = self.state.clean();
        self.pending = pending.into_iter();
    }

    /// Advances the event loop by one step and returns the resulting event.
    ///
    /// Without a connection, one is established first and its CONNACK is returned as
    /// `Event::Incoming`. Otherwise this waits for the next network packet, user request,
    /// pending replay, keep-alive tick or cancellation.
    ///
    /// # Errors
    /// Errors while connecting (including `Cancel`) are returned as-is. Errors on an
    /// established connection tear it down (so the next call reconnects), except a
    /// `StateError::Collision` while `collision_safety()` is enabled, which is returned with
    /// the connection left intact.
    #[must_use = "Eventloop should be iterated over a loop to make progress"]
    pub async fn poll(&mut self) -> Result<Event, ConnectionError> {
        if self.network.is_none() {
            let (network, connack) = connect_or_cancel(&self.options, &self.cancel_rx).await?;
            self.network = Some(network);
            if self.keepalive_timeout.is_none() {
                self.keepalive_timeout = Some(Box::pin(Timer::after(self.options.keep_alive)));
            }
            return Ok(Event::Incoming(connack));
        }

        let result = self.select().await;
        if let Err(e) = &result {
            // With collision safety on, a collision is reported but the connection and state
            // are kept, so the next `poll` continues on the same connection.
            let keep_connection = matches!(e, ConnectionError::MqttState(StateError::Collision(_)))
                && self.options.collision_safety();
            if !keep_connection {
                self.clean();
            }
        }
        result
    }

    /// Waits for whichever source becomes ready first and turns it into an `Event`.
    ///
    /// Events already queued in `state.events` are returned before anything is awaited.
    async fn select(&mut self) -> Result<Event, ConnectionError> {
        let network = self
            .network
            .as_mut()
            .expect("select is only reached from poll once a connection exists");
        let inflight_full = self.state.inflight >= self.options.inflight;
        let throttle = self.options.pending_throttle;
        let pending = self.pending.len() > 0;
        let collision = self.state.collision.is_some();

        if let Some(event) = self.state.events.pop_front() {
            return Ok(event);
        }

        select! {
            // Incoming data: the state processes it; flush whatever it queued in
            // `state.write` in response, then yield the first resulting event.
            o = network.readb(&mut self.state).fuse() => {
                o?;
                network.flush(&mut self.state.write).await?;
                Ok(self.state.events.pop_front().unwrap())
            },
            // New user requests are only taken while the inflight window has room, no replay
            // is pending and no collision is outstanding (assumes `cond_fut` stays pending
            // while its condition is false).
            o = cond_fut(self.requests_rx.recv().fuse(), !inflight_full && !pending && !collision) => match o {
                Ok(request) => {
                    self.state.handle_outgoing_packet(request)?;
                    network.flush(&mut self.state.write).await?;
                    Ok(self.state.events.pop_front().unwrap())
                }
                Err(_) => Err(ConnectionError::RequestsDone),
            },
            // Replay requests left over from the previous connection, each after `throttle`.
            request = cond_fut(next_pending(throttle, &mut self.pending).fuse(), pending) => {
                if let Some(request) = request {
                    self.state.handle_outgoing_packet(request)?;
                    network.flush(&mut self.state.write).await?;
                    Ok(self.state.events.pop_front().unwrap())
                } else {
                    unreachable!("pending branch is only enabled while `pending` is non-empty")
                }
            },
            // Keep-alive tick: re-arm the timer, then send a PINGREQ.
            _ = self.keepalive_timeout.as_mut().expect("keep-alive timer is armed on connect").fuse() => {
                let timeout = self
                    .keepalive_timeout
                    .as_mut()
                    .expect("keep-alive timer is armed on connect");
                timeout.as_mut().set_after(self.options.keep_alive);
                self.state.handle_outgoing_packet(Request::PingReq)?;
                network.flush(&mut self.state.write).await?;
                Ok(self.state.events.pop_front().unwrap())
            }
            _ = self.cancel_rx.recv().fuse() => {
                Err(ConnectionError::Cancel)
            }
        }
    }
}
/// Runs the full connection sequence (`connect`) unless the cancel channel yields first.
///
/// # Errors
/// Returns `ConnectionError::Cancel` if `cancel_rx.recv()` completes before the connection
/// does; otherwise returns whatever `connect` produced.
async fn connect_or_cancel(
    options: &MqttOptions,
    cancel_rx: &Receiver<()>,
) -> Result<(Network, Incoming), ConnectionError> {
    // `futures::select!` picks randomly among ready branches, so branch order carries no
    // priority here.
    select! {
        _ = cancel_rx.recv().fuse() => Err(ConnectionError::Cancel),
        outcome = connect(options).fuse() => {
            outcome
        }
    }
}
/// Opens the transport described by `options` and performs the MQTT CONNECT / CONNACK
/// handshake on it.
///
/// Returns the ready `Network` together with the broker's CONNACK.
///
/// # Errors
/// Propagates any error from `network_connect` or `mqtt_connect` unchanged.
async fn connect(options: &MqttOptions) -> Result<(Network, Incoming), ConnectionError> {
    let mut network = network_connect(options).await?;
    let connack = mqtt_connect(options, &mut network).await?;
    Ok((network, connack))
}
/// Opens the transport selected by `options.transport()` and wraps it in a `Network`.
///
/// # Errors
/// * `ConnectionError::Io` if the TCP connect fails, if the WebSocket handshake fails
///   (`ConnectionRefused`), or, for WebSocket transports, if `broker_addr` cannot be turned
///   into a valid HTTP request (`InvalidInput`).
/// * Errors from the `tls` module (presumably `ConnectionError::Network`) for TLS / WSS setup.
async fn network_connect(options: &MqttOptions) -> Result<Network, ConnectionError> {
    let network = match options.transport() {
        Transport::Tcp => {
            let addr = options.broker_addr.as_str();
            let port = options.port;
            let socket = TcpStream::connect((addr, port)).await?;
            Network::new(socket, options.max_incoming_packet_size)
        }
        Transport::Tls(tls_config) => {
            let socket = tls::tls_connect(options, &tls_config).await?;
            Network::new(socket, options.max_incoming_packet_size)
        }
        #[cfg(feature = "websocket")]
        Transport::Ws => {
            let request = ws_upgrade_request(options)?;
            let (socket, _) = connect_async(request)
                .await
                .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e))?;
            Network::new(WsStream::new(socket), options.max_incoming_packet_size)
        }
        #[cfg(feature = "websocket")]
        Transport::Wss(tls_config) => {
            let request = ws_upgrade_request(options)?;
            let connector = tls::tls_connector(&tls_config).await?;
            let (socket, _) = connect_async_with_tls_connector(request, Some(connector))
                .await
                .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e))?;
            Network::new(WsStream::new(socket), options.max_incoming_packet_size)
        }
    };
    Ok(network)
}

/// Builds the WebSocket upgrade request for `options.broker_addr`, advertising the `mqttv3.1`
/// subprotocol.
///
/// # Errors
/// Returns `ConnectionError::Io` (`InvalidInput`) when `broker_addr` is not a valid URI. The
/// builder error is user-influenced, so it is reported instead of unwrapped.
#[cfg(feature = "websocket")]
fn ws_upgrade_request(options: &MqttOptions) -> Result<http::Request<()>, ConnectionError> {
    http::Request::builder()
        .method(http::Method::GET)
        .uri(options.broker_addr.as_str())
        .header("Sec-WebSocket-Protocol", "mqttv3.1")
        .body(())
        .map_err(|e| ConnectionError::Io(io::Error::new(io::ErrorKind::InvalidInput, e)))
}
/// Performs the MQTT handshake on an already-open `network`. It sends a CONNECT built from
/// `options` and waits for a successful CONNACK.
///
/// Sending CONNECT and waiting for CONNACK are each bounded by
/// `options.connection_timeout()` seconds.
///
/// # Errors
/// * `ConnectionError::Timeout` if either step exceeds the timeout.
/// * `ConnectionError::Io` with `InvalidData` if the broker rejects the connection or replies
///   with something other than CONNACK; other network errors are propagated.
async fn mqtt_connect(options: &MqttOptions, network: &mut Network) -> Result<Incoming, ConnectionError> {
    // CONNECT carries keep-alive as a u16 number of seconds. Saturate rather than let `as`
    // wrap: 65_536 s would otherwise become 0, which MQTT treats as "keep-alive disabled".
    let keep_alive = options.keep_alive().as_secs().min(u64::from(u16::MAX)) as u16;
    let clean_session = options.clean_session();
    let last_will = options.last_will();
    let mut connect = Connect::new(options.client_id());
    connect.keep_alive = keep_alive;
    connect.clean_session = clean_session;
    connect.last_will = last_will;
    if let Some((username, password)) = options.credentials() {
        connect.login = Some(Login::new(username, password));
    }

    let connection_timeout = Duration::from_secs(options.connection_timeout());

    async_std::future::timeout(connection_timeout, async {
        network.connect(connect).await?;
        Ok::<_, ConnectionError>(())
    })
    .await??;

    // Wait (again bounded by the connection timeout) for the broker's CONNACK.
    let packet = async_std::future::timeout(connection_timeout, async {
        let packet = match network.read().await? {
            Incoming::ConnAck(connack) if connack.code == ConnectReturnCode::Success => Packet::ConnAck(connack),
            Incoming::ConnAck(connack) => {
                let error = format!("Broker rejected. Reason = {:?}", connack.code);
                return Err(io::Error::new(io::ErrorKind::InvalidData, error));
            }
            packet => {
                let error = format!("Expecting connack. Received = {:?}", packet);
                return Err(io::Error::new(io::ErrorKind::InvalidData, error));
            }
        };
        Ok::<_, io::Error>(packet)
    })
    .await??;
    Ok(packet)
}
/// Sleeps for `delay`, then takes the next request out of `pending`.
///
/// `EventLoop::select` calls this with `options.pending_throttle` so that the requests
/// returned by `MqttState::clean` are replayed one at a time. Yields `None` once `pending`
/// is exhausted.
pub(crate) async fn next_pending(delay: Duration, pending: &mut IntoIter<Request>) -> Option<Request> {
    let throttle = Timer::after(delay);
    throttle.await;
    pending.next()
}