use std::io::{Read, Write};
use std::net::TcpStream;
use std::string::{String, ToString};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use std::vec::Vec;
use crate::codec::{PublishPacket, decode_publish, encode_publish};
use crate::control_packets::{
ConnectBody, DisconnectBody, SubscribeBody, Subscription as MqttSubscription, connect_flags,
encode_connect_body, encode_disconnect_body, encode_subscribe_body,
};
use crate::packet::{ControlPacketType, FixedHeader};
use crate::vbi::{decode_vbi, encode_vbi};
use super::config::DaemonConfig;
#[cfg(feature = "daemon")]
use rustls::{ClientConfig, ClientConnection, StreamOwned};
/// Failure modes of [`MqttClient`] operations.
#[derive(Debug)]
pub enum ClientError {
    /// Socket, TLS or name-resolution failure; the string carries the underlying cause.
    Io(String),
    /// Encoding/decoding failure or a packet that violates the expected protocol flow.
    Codec(String),
    /// The broker answered CONNECT with a failure reason code (>= 0x80).
    ConnAck { reason: u8 },
}
impl core::fmt::Display for ClientError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Io and Codec share the "<kind>: <message>" shape; ConnAck renders its code in hex.
        let (kind, detail) = match self {
            Self::Io(msg) => ("io", msg.as_str()),
            Self::Codec(msg) => ("codec", msg.as_str()),
            Self::ConnAck { reason } => return write!(f, "connack reject: 0x{reason:02x}"),
        };
        write!(f, "{kind}: {detail}")
    }
}
impl std::error::Error for ClientError {}
/// Event surfaced to the inbound loop by `MqttClient::next_event`.
#[derive(Debug, Clone)]
pub enum InboundEvent {
    /// An application message from the broker, with the QoS it was delivered at.
    Publish { topic: String, payload: Vec<u8>, qos: u8 },
    /// The connection ended; the string describes why.
    Disconnected(String),
}
/// Transport used by `MqttClient` when the `daemon` feature is enabled: either a plain TCP
/// socket or a rustls client session layered over one.
#[cfg(feature = "daemon")]
pub(crate) enum MqttStream {
    Plain(TcpStream),
    Tls(Box<StreamOwned<ClientConnection, TcpStream>>),
}
#[cfg(feature = "daemon")]
impl MqttStream {
    /// The underlying TCP socket, whether or not TLS is layered on top of it.
    fn socket(&self) -> &TcpStream {
        match self {
            Self::Plain(s) => s,
            Self::Tls(s) => &s.sock,
        }
    }
    /// Sets the read timeout on the underlying socket (applies to TLS reads too).
    fn set_read_timeout(&self, dur: Option<Duration>) -> std::io::Result<()> {
        self.socket().set_read_timeout(dur)
    }
    /// Sets the write timeout on the underlying socket (applies to TLS writes too).
    fn set_write_timeout(&self, dur: Option<Duration>) -> std::io::Result<()> {
        self.socket().set_write_timeout(dur)
    }
    /// Best-effort close of both directions.
    ///
    /// For TLS, a `close_notify` alert is queued and flushed first. This lets the peer tell
    /// an orderly close from a truncated stream. All errors are ignored because the
    /// connection is being torn down anyway.
    fn shutdown_both(&mut self) {
        if let Self::Tls(s) = self {
            s.conn.send_close_notify();
            // Pushes the queued alert onto the socket; bounded by the socket write timeout.
            let _ = s.flush();
        }
        let _ = self.socket().shutdown(std::net::Shutdown::Both);
    }
}
#[cfg(feature = "daemon")]
impl Read for MqttStream {
    /// Reads plaintext bytes, decrypting through rustls for the TLS variant.
    fn read(&mut self, b: &mut [u8]) -> std::io::Result<usize> {
        match self {
            Self::Plain(s) => s.read(b),
            Self::Tls(s) => s.read(b),
        }
    }
}
#[cfg(feature = "daemon")]
impl Write for MqttStream {
    /// Writes plaintext bytes, encrypting through rustls for the TLS variant.
    fn write(&mut self, b: &[u8]) -> std::io::Result<usize> {
        match self {
            Self::Plain(s) => s.write(b),
            Self::Tls(s) => s.write(b),
        }
    }
    fn flush(&mut self) -> std::io::Result<()> {
        match self {
            Self::Plain(s) => s.flush(),
            Self::Tls(s) => s.flush(),
        }
    }
}
pub struct MqttClient {
#[cfg(feature = "daemon")]
stream: MqttStream,
#[cfg(not(feature = "daemon"))]
stream: TcpStream,
next_packet_id: u16,
}
impl MqttClient {
#[cfg(feature = "daemon")]
pub fn connect_secure(
host: &str,
port: u16,
cfg: &DaemonConfig,
tls_client_cfg: Option<Arc<ClientConfig>>,
) -> Result<Self, ClientError> {
let addr = format!("{host}:{port}");
let tcp = match addr.parse::<std::net::SocketAddr>() {
Ok(sa) => TcpStream::connect_timeout(&sa, Duration::from_secs(10)),
Err(_) => TcpStream::connect(&addr),
}
.map_err(|e| ClientError::Io(format!("connect: {e}")))?;
tcp.set_read_timeout(Some(Duration::from_secs(5)))
.map_err(|e| ClientError::Io(format!("set timeout: {e}")))?;
tcp.set_write_timeout(Some(Duration::from_secs(5)))
.map_err(|e| ClientError::Io(format!("set timeout: {e}")))?;
let stream = match tls_client_cfg {
Some(client_cfg) => {
let server_name_str = if cfg.broker_tls_server_name.is_empty() {
host.to_string()
} else {
cfg.broker_tls_server_name.clone()
};
let server_name = rustls::pki_types::ServerName::try_from(server_name_str.clone())
.map_err(|e| {
ClientError::Io(format!("server name '{server_name_str}': {e}"))
})?;
let conn = ClientConnection::new(client_cfg, server_name)
.map_err(|e| ClientError::Io(format!("rustls client conn: {e}")))?;
MqttStream::Tls(Box::new(StreamOwned::new(conn, tcp)))
}
None => MqttStream::Plain(tcp),
};
let _ = stream.set_read_timeout(Some(Duration::from_secs(5)));
let _ = stream.set_write_timeout(Some(Duration::from_secs(5)));
let mut me = Self {
stream,
next_packet_id: 1,
};
me.send_connect(cfg)?;
me.wait_connack()?;
Ok(me)
}
pub fn connect(host: &str, port: u16, cfg: &DaemonConfig) -> Result<Self, ClientError> {
#[cfg(feature = "daemon")]
{
Self::connect_secure(host, port, cfg, None)
}
#[cfg(not(feature = "daemon"))]
{
let addr = format!("{host}:{port}");
let stream = TcpStream::connect_timeout(
&addr
.parse()
.map_err(|e| ClientError::Io(format!("addr: {e}")))?,
Duration::from_secs(10),
)
.map_err(|e| ClientError::Io(format!("connect: {e}")))?;
stream
.set_read_timeout(Some(Duration::from_secs(5)))
.map_err(|e| ClientError::Io(format!("set timeout: {e}")))?;
let mut me = Self {
stream,
next_packet_id: 1,
};
me.send_connect(cfg)?;
me.wait_connack()?;
Ok(me)
}
}
fn send_connect(&mut self, cfg: &DaemonConfig) -> Result<(), ClientError> {
#[cfg(feature = "daemon")]
let (user, pass) = {
let (u, p) = super::security::outbound_credentials(cfg);
let u = u.or_else(|| cfg.username.clone());
let p = p.or_else(|| cfg.password.as_ref().map(|s| s.as_bytes().to_vec()));
(u, p)
};
#[cfg(not(feature = "daemon"))]
let (user, pass) = (
cfg.username.clone(),
cfg.password.as_ref().map(|s| s.as_bytes().to_vec()),
);
let mut flags: u8 = 0;
if cfg.clean_start {
flags |= connect_flags::CLEAN_START;
}
if user.is_some() {
flags |= connect_flags::USER_NAME;
}
if pass.is_some() {
flags |= connect_flags::PASSWORD;
}
let body = ConnectBody {
protocol_name: "MQTT".to_string(),
protocol_version: 5,
connect_flags: flags,
keep_alive: cfg.keep_alive_secs,
properties: Vec::new(),
client_id: cfg.client_id.clone(),
will_properties: Vec::new(),
will_topic: None,
will_payload: Vec::new(),
user_name: user,
password: pass.unwrap_or_default(),
};
let body_bytes =
encode_connect_body(&body).map_err(|e| ClientError::Codec(format!("{e:?}")))?;
let frame = wrap_packet(ControlPacketType::Connect, 0, &body_bytes)
.map_err(|e| ClientError::Codec(format!("{e:?}")))?;
self.stream
.write_all(&frame)
.map_err(|e| ClientError::Io(format!("write connect: {e}")))?;
Ok(())
}
fn wait_connack(&mut self) -> Result<(), ClientError> {
let (header, body) = self.read_packet()?;
if header.packet_type != ControlPacketType::ConnAck {
return Err(ClientError::Codec(format!(
"expected CONNACK got {:?}",
header.packet_type
)));
}
if body.len() < 2 {
return Err(ClientError::Codec("connack too short".to_string()));
}
let reason = body[1];
if reason >= 0x80 {
return Err(ClientError::ConnAck { reason });
}
Ok(())
}
pub fn subscribe(&mut self, filters: &[(String, u8)]) -> Result<(), ClientError> {
if filters.is_empty() {
return Ok(());
}
let pid = self.next_pid();
let body = SubscribeBody {
packet_id: pid,
properties: Vec::new(),
subscriptions: filters
.iter()
.map(|(filter, qos)| MqttSubscription {
topic_filter: filter.clone(),
options: *qos & 0x03,
})
.collect(),
};
let body_bytes =
encode_subscribe_body(&body).map_err(|e| ClientError::Codec(format!("{e:?}")))?;
let frame = wrap_packet(ControlPacketType::Subscribe, 0b0010, &body_bytes)
.map_err(|e| ClientError::Codec(format!("{e:?}")))?;
self.stream
.write_all(&frame)
.map_err(|e| ClientError::Io(format!("write sub: {e}")))?;
Ok(())
}
pub fn publish(
&mut self,
topic: &str,
payload: &[u8],
qos: u8,
retain: bool,
) -> Result<(), ClientError> {
let pid = if qos > 0 { Some(self.next_pid()) } else { None };
let pkt = PublishPacket {
dup: false,
qos,
retain,
topic: topic.to_string(),
packet_id: pid,
properties: Vec::new(),
payload: payload.to_vec(),
};
let bytes = encode_publish(&pkt).map_err(|e| ClientError::Codec(format!("{e:?}")))?;
self.stream
.write_all(&bytes)
.map_err(|e| ClientError::Io(format!("write pub: {e}")))?;
Ok(())
}
pub fn next_event(&mut self) -> Result<Option<InboundEvent>, ClientError> {
let (header, body) = match self.read_packet_nonblocking() {
Ok(p) => p,
Err(ClientError::Io(m)) if m.contains("WouldBlock") || m.contains("timed out") => {
return Ok(None);
}
Err(ClientError::Io(m)) => {
return Ok(Some(InboundEvent::Disconnected(m)));
}
Err(other) => return Err(other),
};
match header.packet_type {
ControlPacketType::Publish => {
let mut full = Vec::with_capacity(2 + body.len());
let byte0 = (ControlPacketType::Publish.to_bits() << 4) | (header.flags & 0x0F);
full.push(byte0);
let len_u32 =
u32::try_from(body.len()).map_err(|_| ClientError::Codec("len".to_string()))?;
full.extend_from_slice(
&encode_vbi(len_u32).ok_or_else(|| ClientError::Codec("vbi".to_string()))?,
);
full.extend_from_slice(&body);
let (_, pkt) =
decode_publish(&full).map_err(|e| ClientError::Codec(format!("{e:?}")))?;
Ok(Some(InboundEvent::Publish {
topic: pkt.topic,
payload: pkt.payload,
qos: pkt.qos,
}))
}
ControlPacketType::SubAck
| ControlPacketType::PubAck
| ControlPacketType::PubRec
| ControlPacketType::PubRel
| ControlPacketType::PubComp
| ControlPacketType::PingResp => {
Ok(None)
}
ControlPacketType::Disconnect => Ok(Some(InboundEvent::Disconnected(
"broker disconnect".to_string(),
))),
_ => Ok(None),
}
}
pub fn graceful_disconnect(mut self) {
let body = DisconnectBody {
reason_code: 0,
properties: Vec::new(),
};
if let Ok(body_bytes) = encode_disconnect_body(&body) {
if let Ok(frame) = wrap_packet(ControlPacketType::Disconnect, 0, &body_bytes) {
let _ = self.stream.write_all(&frame);
}
}
#[cfg(feature = "daemon")]
{
self.stream.shutdown_both();
}
#[cfg(not(feature = "daemon"))]
{
let _ = self.stream.shutdown(std::net::Shutdown::Both);
}
}
fn next_pid(&mut self) -> u16 {
let p = self.next_packet_id;
self.next_packet_id = self.next_packet_id.wrapping_add(1);
if self.next_packet_id == 0 {
self.next_packet_id = 1;
}
p
}
fn read_packet(&mut self) -> Result<(FixedHeader, Vec<u8>), ClientError> {
self.read_packet_inner()
}
fn read_packet_nonblocking(&mut self) -> Result<(FixedHeader, Vec<u8>), ClientError> {
self.read_packet_inner()
}
fn read_packet_inner(&mut self) -> Result<(FixedHeader, Vec<u8>), ClientError> {
let mut hdr = [0u8; 1];
match self.stream.read_exact(&mut hdr) {
Ok(()) => {}
Err(e) => return Err(ClientError::Io(format!("read header: {e:?}"))),
}
let mut vbi_buf: Vec<u8> = Vec::new();
loop {
let mut byte = [0u8; 1];
self.stream
.read_exact(&mut byte)
.map_err(|e| ClientError::Io(format!("read vbi: {e:?}")))?;
vbi_buf.push(byte[0]);
if byte[0] & 0x80 == 0 {
break;
}
if vbi_buf.len() >= 4 {
return Err(ClientError::Codec("vbi too long".to_string()));
}
}
let (remaining, _) =
decode_vbi(&vbi_buf).map_err(|e| ClientError::Codec(format!("{e:?}")))?;
let mut body = vec![0u8; remaining as usize];
if !body.is_empty() {
self.stream
.read_exact(&mut body)
.map_err(|e| ClientError::Io(format!("read body: {e:?}")))?;
}
let byte0 = hdr[0];
let pt_bits = (byte0 >> 4) & 0x0F;
let packet_type = ControlPacketType::from_bits(pt_bits)
.ok_or_else(|| ClientError::Codec(format!("unknown packet type {pt_bits}")))?;
let flags = byte0 & 0x0F;
Ok((
FixedHeader {
packet_type,
flags,
remaining_length: remaining,
},
body,
))
}
}
/// Frames `body` as a complete MQTT control packet.
///
/// The frame is the first byte (packet type in the high nibble, `flags` in the low nibble),
/// then the variable-byte-integer remaining length, then `body` verbatim.
///
/// # Errors
/// `CodecError::Vbi(VbiError::Malformed)` when `body` is too long to express as a remaining
/// length.
fn wrap_packet(
    packet_type: ControlPacketType,
    flags: u8,
    body: &[u8],
) -> Result<Vec<u8>, crate::codec::CodecError> {
    let malformed = || crate::codec::CodecError::Vbi(crate::vbi::VbiError::Malformed);
    let first_byte = (packet_type.to_bits() << 4) | (flags & 0x0F);
    let remaining = u32::try_from(body.len()).map_err(|_| malformed())?;
    let length_field = encode_vbi(remaining).ok_or_else(malformed)?;
    // 1 header byte + up to 4 VBI bytes + body.
    let mut frame = Vec::with_capacity(5 + body.len());
    frame.push(first_byte);
    frame.extend_from_slice(&length_field);
    frame.extend_from_slice(body);
    Ok(frame)
}
/// Exponential backoff schedule for reconnect attempts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BackoffConfig {
    /// Delay after the first failed attempt, in milliseconds.
    pub initial_ms: u64,
    /// Upper bound for any single delay, in milliseconds.
    pub max_ms: u64,
    /// Factor applied to the delay after each further failed attempt.
    pub multiplier: u64,
    /// Maximum number of connection attempts.
    pub max_attempts: u32,
}
impl Default for BackoffConfig {
    /// 100 ms doubling up to 30 s, with effectively unlimited attempts.
    fn default() -> Self {
        Self {
            initial_ms: 100,
            max_ms: 30_000,
            multiplier: 2,
            max_attempts: u32::MAX,
        }
    }
}
impl BackoffConfig {
    /// Delay to wait after failed attempt `attempt` (0-based).
    ///
    /// The delay is `initial_ms * multiplier^attempt` with saturating arithmetic, clamped to
    /// `max_ms`. It uses `checked_pow`, so the cost is O(log attempt) rather than one
    /// multiplication per attempt. That matters when `multiplier` is 0 or 1 and the product
    /// never reaches `max_ms`.
    #[must_use]
    pub fn delay_for(&self, attempt: u32) -> Duration {
        // An overflowing power saturates, matching a step-by-step saturating multiply.
        let factor = self.multiplier.checked_pow(attempt).unwrap_or(u64::MAX);
        let ms = self.initial_ms.saturating_mul(factor).min(self.max_ms);
        Duration::from_millis(ms)
    }
}
pub fn connect_with_backoff(
host: &str,
port: u16,
cfg: &DaemonConfig,
backoff: BackoffConfig,
stop: &AtomicBool,
) -> Result<MqttClient, ClientError> {
let mut last_err = ClientError::Io("no attempts".to_string());
for attempt in 0..backoff.max_attempts {
if stop.load(Ordering::SeqCst) {
return Err(ClientError::Io("stop signaled".to_string()));
}
match MqttClient::connect(host, port, cfg) {
Ok(c) => return Ok(c),
Err(e) => {
last_err = e;
let d = backoff.delay_for(attempt);
std::thread::sleep(d);
}
}
}
Err(last_err)
}
#[cfg(feature = "daemon")]
pub fn connect_secure_with_backoff(
host: &str,
port: u16,
cfg: &DaemonConfig,
tls_client_cfg: Option<Arc<ClientConfig>>,
backoff: BackoffConfig,
stop: &AtomicBool,
) -> Result<MqttClient, ClientError> {
let mut last_err = ClientError::Io("no attempts".to_string());
for attempt in 0..backoff.max_attempts {
if stop.load(Ordering::SeqCst) {
return Err(ClientError::Io("stop signaled".to_string()));
}
match MqttClient::connect_secure(host, port, cfg, tls_client_cfg.clone()) {
Ok(c) => return Ok(c),
Err(e) => {
last_err = e;
let d = backoff.delay_for(attempt);
std::thread::sleep(d);
}
}
}
Err(last_err)
}
pub fn run_inbound_loop<F>(mut client: MqttClient, stop: Arc<AtomicBool>, mut on_event: F)
where
F: FnMut(InboundEvent),
{
while !stop.load(Ordering::SeqCst) {
match client.next_event() {
Ok(Some(InboundEvent::Disconnected(reason))) => {
on_event(InboundEvent::Disconnected(reason));
break;
}
Ok(Some(ev)) => on_event(ev),
Ok(None) => continue,
Err(e) => {
on_event(InboundEvent::Disconnected(format!("client error: {e}")));
break;
}
}
}
client.graceful_disconnect();
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
    use super::*;

    /// The default schedule starts at 100 ms and doubles each attempt.
    #[test]
    fn backoff_config_default_increments_exponentially() {
        let backoff = BackoffConfig::default();
        let expected = [(0, 100), (1, 200), (2, 400), (3, 800)];
        for (attempt, millis) in expected {
            assert_eq!(backoff.delay_for(attempt), Duration::from_millis(millis));
        }
    }

    /// Once the doubling passes `max_ms`, every later delay equals `max_ms`.
    #[test]
    fn backoff_config_caps_at_max() {
        let capped = BackoffConfig {
            max_ms: 1_000,
            max_attempts: 100,
            ..BackoffConfig::default()
        };
        for attempt in [4, 20] {
            assert_eq!(capped.delay_for(attempt), Duration::from_millis(1_000));
        }
    }

    /// A pre-set stop flag aborts before any connection attempt is made.
    #[test]
    fn backoff_connect_aborts_when_stop_set() {
        let stop = AtomicBool::new(true);
        let cfg = DaemonConfig::default_for_dev();
        let outcome = connect_with_backoff("127.0.0.1", 1, &cfg, BackoffConfig::default(), &stop);
        assert!(outcome.is_err());
    }

    /// PUBLISH framing: type 3 in the high nibble, 1-byte remaining length, then the body.
    #[test]
    fn wrap_packet_publish() {
        let body: &[u8] = b"\x00\x03foo";
        let frame = wrap_packet(ControlPacketType::Publish, 0, body).unwrap();
        let (head, rest) = frame.split_at(2);
        assert_eq!(head, &[0x30, 5]);
        assert_eq!(rest, body);
    }

    /// SUBSCRIBE must carry the reserved flag bits 0b0010 in the low nibble.
    #[test]
    fn wrap_packet_subscribe_has_reserved_bits() {
        let frame = wrap_packet(ControlPacketType::Subscribe, 0b0010, b"x").unwrap();
        assert_eq!(frame[0] & 0x0F, 0b0010);
    }
}