use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use crate::error::Error;
use crate::protocol::http::response::Response;
/// Percent-decode an MQTT topic taken from a URL path.
///
/// Each `%XY` where both `X` and `Y` are hex digits becomes the single byte
/// `0xXY`. A `%` that is not followed by two hex digits is copied through
/// unchanged. If the decoded bytes are not valid UTF-8, invalid sequences are
/// replaced with U+FFFD instead of failing.
fn percent_decode_topic(encoded: &str) -> String {
    let src = encoded.as_bytes();
    let mut out: Vec<u8> = Vec::with_capacity(src.len());
    let mut pos = 0;
    while let Some(&current) = src.get(pos) {
        // Only a '%' with two more bytes available can start an escape.
        let escaped = if current == b'%' && pos + 2 < src.len() {
            hex_nibble(src[pos + 1])
                .zip(hex_nibble(src[pos + 2]))
                .map(|(hi, lo)| (hi << 4) | lo)
        } else {
            None
        };
        match escaped {
            Some(byte) => {
                out.push(byte);
                pos += 3;
            }
            None => {
                out.push(current);
                pos += 1;
            }
        }
    }
    match String::from_utf8(out) {
        Ok(text) => text,
        Err(err) => String::from_utf8_lossy(err.as_bytes()).into_owned(),
    }
}
/// Convert one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value 0..=15.
///
/// Returns `None` for any other byte.
const fn hex_nibble(b: u8) -> Option<u8> {
    let value = match b {
        b'0'..=b'9' => b - b'0',
        b'a'..=b'f' => b - (b'a' - 10),
        b'A'..=b'F' => b - (b'A' - 10),
        _ => return None,
    };
    Some(value)
}
/// MQTT quality-of-service level for a PUBLISH, or the level requested in a SUBSCRIBE.
///
/// The discriminants are the numeric QoS values written into packets
/// (encoders use `qos as u8`).
#[repr(u8)]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum QoS {
    /// QoS 0: no acknowledgement is exchanged. This is the default.
    #[default]
    AtMostOnce = 0,
    /// QoS 1: the receiver answers with PUBACK.
    AtLeastOnce = 1,
    /// QoS 2: PUBREC / PUBREL / PUBCOMP handshake.
    ExactlyOnce = 2,
}
/// MQTT control packet types handled by this client.
///
/// Each discriminant is the value carried in the high nibble of a packet's
/// first byte; incoming packets are compared against `PacketType::X as u8`.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PacketType {
    /// Client request to open a session.
    Connect = 1,
    /// Server reply to CONNECT.
    Connack = 2,
    /// Application message.
    Publish = 3,
    /// QoS 1 acknowledgement.
    Puback = 4,
    /// QoS 2 step 1: publish received.
    Pubrec = 5,
    /// QoS 2 step 2: publish release.
    Pubrel = 6,
    /// QoS 2 step 3: publish complete.
    Pubcomp = 7,
    /// Client subscription request.
    Subscribe = 8,
    /// Server reply to SUBSCRIBE.
    Suback = 9,
    /// Clean session shutdown.
    Disconnect = 14,
}
/// Encode `len` as an MQTT variable-length "remaining length" field.
///
/// Each output byte carries seven bits of the value, least-significant group
/// first. The high bit (0x80) is set on every byte except the last.
///
/// NOTE(review): values above 268_435_455 encode to more than four bytes,
/// which MQTT 3.1.1 does not allow. This function does not reject them.
fn encode_remaining_length(len: usize) -> Vec<u8> {
    let mut encoded = Vec::with_capacity(4);
    let mut rest = len;
    loop {
        // Masked to 7 bits, so the cast can never lose information.
        #[allow(clippy::cast_possible_truncation)]
        let digit = (rest & 0x7F) as u8;
        rest >>= 7;
        if rest == 0 {
            encoded.push(digit);
            return encoded;
        }
        encoded.push(digit | 0x80);
    }
}
/// Decode an MQTT variable-length "remaining length" field from the start of `data`.
///
/// Returns `(value, bytes_consumed)`. Decoding stops at the first byte without
/// the 0x80 continuation bit; any bytes after it are ignored.
///
/// # Errors
///
/// * `Error::Http("MQTT remaining length overflow")` if the first four bytes
///   all carry the continuation bit. MQTT allows at most four length bytes.
/// * `Error::Http("MQTT remaining length incomplete")` if `data` ends before
///   a terminating byte.
fn decode_remaining_length(data: &[u8]) -> Result<(usize, usize), Error> {
    // MQTT 3.1.1 caps the field at four bytes (maximum value 268_435_455).
    const MAX_LEN_BYTES: usize = 4;
    let mut value: usize = 0;
    for (i, &byte) in data.iter().take(MAX_LEN_BYTES).enumerate() {
        // `i` is at most 3, so the shift is at most 21 bits and `value` stays
        // below 2^28. This cannot overflow even with a 32-bit usize, because
        // a fifth byte is never accumulated.
        value |= usize::from(byte & 0x7F) << (7 * i);
        if byte & 0x80 == 0 {
            return Ok((value, i + 1));
        }
        if i + 1 == MAX_LEN_BYTES {
            // A fourth byte that still says "more follows" is malformed,
            // regardless of whether more input is available.
            return Err(Error::Http("MQTT remaining length overflow".to_string()));
        }
    }
    Err(Error::Http("MQTT remaining length incomplete".to_string()))
}
/// Build an MQTT 3.1.1 CONNECT packet.
///
/// The client identifier is always 12 bytes long:
/// * It starts with `client_id`, cut back to a UTF-8 character boundary if
///   longer than 12 bytes.
/// * The rest is random ASCII letters and digits.
///
/// The random suffix does two things:
/// * It keeps the identifier free of U+0000, which MQTT 3.1.1 §1.5.3 forbids
///   in strings. A conforming broker closes the connection on seeing it.
/// * It stops concurrent invocations from presenting the same ClientId. A
///   broker drops the existing session when a second client reuses an id
///   (§3.1.4).
///
/// Variable header: protocol name "MQTT", protocol level 4, connect flags
/// 0x02 (clean session, no will/username/password), keep-alive 60 seconds.
fn build_connect_packet(client_id: &str) -> Vec<u8> {
    const MQTT_CLIENTID_LEN: usize = 12;

    // Longest prefix of `client_id` that fits without splitting a UTF-8 sequence.
    let mut prefix_len = client_id.len().min(MQTT_CLIENTID_LEN);
    while !client_id.is_char_boundary(prefix_len) {
        prefix_len -= 1;
    }
    let mut client_id_field = Vec::with_capacity(MQTT_CLIENTID_LEN);
    client_id_field.extend_from_slice(&client_id.as_bytes()[..prefix_len]);
    client_id_field.extend(random_alnum(MQTT_CLIENTID_LEN - prefix_len));

    let mut variable_header = Vec::with_capacity(10);
    variable_header.extend_from_slice(&[0x00, 0x04]); // length of protocol name
    variable_header.extend_from_slice(b"MQTT");
    variable_header.push(0x04); // protocol level
    variable_header.push(0x02); // connect flags: clean session
    variable_header.extend_from_slice(&60_u16.to_be_bytes()); // keep-alive (seconds)

    let mut payload = Vec::with_capacity(2 + MQTT_CLIENTID_LEN);
    // MQTT_CLIENTID_LEN is 12, so the cast cannot truncate.
    #[allow(clippy::cast_possible_truncation)]
    let client_id_len = MQTT_CLIENTID_LEN as u16;
    payload.extend_from_slice(&client_id_len.to_be_bytes());
    payload.extend_from_slice(&client_id_field);

    let remaining_len = variable_header.len() + payload.len();
    let mut packet = vec![0x10];
    packet.extend_from_slice(&encode_remaining_length(remaining_len));
    packet.extend_from_slice(&variable_header);
    packet.extend_from_slice(&payload);
    packet
}

/// Return `count` pseudo-random ASCII alphanumeric bytes.
///
/// Each byte comes from the standard library's randomly keyed `RandomState`
/// hasher, applied to the current time, the process id and the byte index.
/// This is good enough for a unique-ish client id. It is not suitable for
/// anything security-sensitive.
fn random_alnum(count: usize) -> Vec<u8> {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    const ALPHABET: &[u8; 62] = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";

    let state = RandomState::new();
    let nanos = SystemTime::now().duration_since(UNIX_EPOCH).map_or(0, |d| d.as_nanos());
    (0..count)
        .map(|i| {
            let mut hasher = state.build_hasher();
            hasher.write_u128(nanos);
            hasher.write_u32(std::process::id());
            hasher.write_usize(i);
            let r = hasher.finish() % 62;
            // `r` < 62, so it always fits in usize.
            #[allow(clippy::cast_possible_truncation)]
            let idx = r as usize;
            ALPHABET[idx]
        })
        .collect()
}
/// Build an MQTT PUBLISH packet that sends `payload` to `topic`.
///
/// The QoS level goes into bits 1-2 of the first byte; DUP and RETAIN stay
/// clear. `packet_id` is written only for QoS 1 and 2 and is ignored for QoS 0.
///
/// NOTE(review): the topic length field is a `u16`. A topic longer than 65535
/// bytes gets a truncated length field while its full bytes are still
/// written, which produces a malformed packet.
fn build_publish_packet(topic: &str, payload: &[u8], qos: QoS, packet_id: u16) -> Vec<u8> {
    let topic_bytes = topic.as_bytes();
    #[allow(clippy::cast_possible_truncation)]
    let topic_len_field = (topic_bytes.len() as u16).to_be_bytes();
    let id_field = packet_id.to_be_bytes();
    let id_bytes: &[u8] = if qos == QoS::AtMostOnce { &[] } else { &id_field };

    let remaining = topic_len_field.len() + topic_bytes.len() + id_bytes.len() + payload.len();
    let mut packet = Vec::with_capacity(5 + remaining);
    packet.push(0x30 | ((qos as u8) << 1));
    packet.extend(encode_remaining_length(remaining));
    for part in [&topic_len_field[..], topic_bytes, id_bytes, payload] {
        packet.extend_from_slice(part);
    }
    packet
}
/// Build a PUBACK packet (type 4, remaining length 2) for `packet_id`.
fn build_puback_packet(packet_id: u16) -> Vec<u8> {
    let [hi, lo] = packet_id.to_be_bytes();
    vec![0x40, 0x02, hi, lo]
}
/// Build a PUBREC packet (type 5, remaining length 2) for `packet_id`.
fn build_pubrec_packet(packet_id: u16) -> Vec<u8> {
    let [hi, lo] = packet_id.to_be_bytes();
    vec![0x50, 0x02, hi, lo]
}
/// Build a PUBREL packet for `packet_id`.
///
/// The first byte is 0x62: type 6 in the high nibble and flags 0b0010 in the
/// low nibble.
fn build_pubrel_packet(packet_id: u16) -> Vec<u8> {
    let [hi, lo] = packet_id.to_be_bytes();
    vec![0x62, 0x02, hi, lo]
}
/// Build a PUBCOMP packet (type 7, remaining length 2) for `packet_id`.
fn build_pubcomp_packet(packet_id: u16) -> Vec<u8> {
    let [hi, lo] = packet_id.to_be_bytes();
    vec![0x70, 0x02, hi, lo]
}
/// Build a SUBSCRIBE packet that requests a single `topic` filter at `qos`.
///
/// The first byte is 0x82: type 8 in the high nibble and flags 0b0010 in the
/// low nibble. The body is the packet identifier, the length-prefixed topic
/// filter, and one requested-QoS byte.
///
/// NOTE(review): a topic longer than 65535 bytes gets a truncated `u16`
/// length field.
fn build_subscribe_packet(topic: &str, packet_id: u16, qos: QoS) -> Vec<u8> {
    let filter = topic.as_bytes();
    #[allow(clippy::cast_possible_truncation)]
    let filter_len = (filter.len() as u16).to_be_bytes();

    let mut body = Vec::with_capacity(filter.len() + 5);
    body.extend_from_slice(&packet_id.to_be_bytes());
    body.extend_from_slice(&filter_len);
    body.extend_from_slice(filter);
    body.push(qos as u8);

    let mut packet = vec![0x82];
    packet.extend(encode_remaining_length(body.len()));
    packet.append(&mut body);
    packet
}
/// Build a DISCONNECT packet: type 14 in the high nibble and no body.
fn build_disconnect_packet() -> Vec<u8> {
    Vec::from([0xE0_u8, 0x00])
}
async fn read_packet<S: AsyncRead + Unpin>(stream: &mut S) -> Result<(u8, u8, Vec<u8>), Error> {
let mut header = [0u8; 1];
let _n = stream
.read_exact(&mut header)
.await
.map_err(|e| Error::Http(format!("MQTT read header error: {e}")))?;
let packet_type = header[0] >> 4;
let flags = header[0] & 0x0F;
let mut len_bytes = Vec::new();
loop {
let mut byte = [0u8; 1];
let _n = stream
.read_exact(&mut byte)
.await
.map_err(|e| Error::Http(format!("MQTT read length error: {e}")))?;
len_bytes.push(byte[0]);
if byte[0] & 0x80 == 0 || len_bytes.len() >= 4 {
break;
}
}
let (remaining_len, _) = decode_remaining_length(&len_bytes)?;
let mut payload = vec![0u8; remaining_len];
if remaining_len > 0 {
let _ = stream.read_exact(&mut payload).await.map_err(|_| {
Error::Http("partial MQTT data received".to_string())
})?;
}
Ok((packet_type, flags, payload))
}
async fn read_connack<S: AsyncRead + Unpin>(stream: &mut S) -> Result<(), Error> {
let mut buf = [0u8; 4];
let _ = stream
.read_exact(&mut buf)
.await
.map_err(|e| Error::Http(format!("MQTT read CONNACK error: {e}")))?;
let ptype = buf[0] >> 4;
if ptype != PacketType::Connack as u8 {
return Err(Error::Http("Weird server reply".to_string()));
}
if buf[1] != 2 {
return Err(Error::Http("Weird server reply".to_string()));
}
if buf[3] != 0 {
return Err(Error::Http("Weird server reply".to_string()));
}
Ok(())
}
/// Publish `payload` to the topic in `url` at QoS 0.
///
/// This is `publish_qos(url, payload, QoS::AtMostOnce)`.
///
/// # Errors
///
/// Propagates every error returned by `publish_qos`.
pub async fn publish(url: &crate::url::Url, payload: &[u8]) -> Result<Response, Error> {
    let level = QoS::AtMostOnce;
    publish_qos(url, payload, level).await
}
pub async fn publish_qos(
url: &crate::url::Url,
payload: &[u8],
qos: QoS,
) -> Result<Response, Error> {
let (host, port) = url.host_and_port()?;
let raw_topic = url.path().trim_start_matches('/');
let topic = percent_decode_topic(raw_topic);
let addr = format!("{host}:{port}");
let mut tcp = tokio::net::TcpStream::connect(&addr).await.map_err(Error::Connect)?;
let connect = build_connect_packet("curl");
tcp.write_all(&connect)
.await
.map_err(|e| Error::Http(format!("MQTT connect write error: {e}")))?;
read_connack(&mut tcp).await?;
if topic.is_empty() {
return Err(Error::UrlParse("No MQTT topic found. Forgot to URL encode it?".to_string()));
}
let packet_id: u16 = 1;
let publish_pkt = build_publish_packet(&topic, payload, qos, packet_id);
tcp.write_all(&publish_pkt)
.await
.map_err(|e| Error::Http(format!("MQTT publish write error: {e}")))?;
match qos {
QoS::AtMostOnce => {} QoS::AtLeastOnce => {
let (ptype, _, ack_payload) = read_packet(&mut tcp).await?;
if ptype != PacketType::Puback as u8 {
return Err(Error::Http(format!("MQTT expected PUBACK, got type {ptype}")));
}
if ack_payload.len() >= 2 {
let ack_id = u16::from_be_bytes([ack_payload[0], ack_payload[1]]);
if ack_id != packet_id {
return Err(Error::Http(format!(
"MQTT PUBACK packet ID mismatch: expected {packet_id}, got {ack_id}"
)));
}
}
}
QoS::ExactlyOnce => {
let (ptype, _, rec_payload) = read_packet(&mut tcp).await?;
if ptype != PacketType::Pubrec as u8 {
return Err(Error::Http(format!("MQTT expected PUBREC, got type {ptype}")));
}
if rec_payload.len() >= 2 {
let rec_id = u16::from_be_bytes([rec_payload[0], rec_payload[1]]);
if rec_id != packet_id {
return Err(Error::Http(format!(
"MQTT PUBREC packet ID mismatch: expected {packet_id}, got {rec_id}"
)));
}
}
let pubrel = build_pubrel_packet(packet_id);
tcp.write_all(&pubrel)
.await
.map_err(|e| Error::Http(format!("MQTT PUBREL write error: {e}")))?;
let (ptype, _, comp_payload) = read_packet(&mut tcp).await?;
if ptype != PacketType::Pubcomp as u8 {
return Err(Error::Http(format!("MQTT expected PUBCOMP, got type {ptype}")));
}
if comp_payload.len() >= 2 {
let comp_id = u16::from_be_bytes([comp_payload[0], comp_payload[1]]);
if comp_id != packet_id {
return Err(Error::Http(format!(
"MQTT PUBCOMP packet ID mismatch: expected {packet_id}, got {comp_id}"
)));
}
}
}
}
tcp.write_all(&build_disconnect_packet())
.await
.map_err(|e| Error::Http(format!("MQTT disconnect write error: {e}")))?;
let headers = std::collections::HashMap::new();
Ok(Response::new(200, headers, Vec::new(), url.as_str().to_string()))
}
/// Subscribe to the topic in `url` at QoS 0 and return the first message received.
///
/// This is `subscribe_qos(url, QoS::AtMostOnce)`.
///
/// # Errors
///
/// Propagates every error returned by `subscribe_qos`.
pub async fn subscribe(url: &crate::url::Url) -> Result<Response, Error> {
    let level = QoS::AtMostOnce;
    subscribe_qos(url, level).await
}
#[allow(clippy::too_many_lines)] pub async fn subscribe_qos(url: &crate::url::Url, qos: QoS) -> Result<Response, Error> {
let (host, port) = url.host_and_port()?;
let raw_topic = url.path().trim_start_matches('/');
let topic = percent_decode_topic(raw_topic);
let addr = format!("{host}:{port}");
let mut tcp = tokio::net::TcpStream::connect(&addr).await.map_err(Error::Connect)?;
let connect = build_connect_packet("curl");
tcp.write_all(&connect)
.await
.map_err(|e| Error::Http(format!("MQTT connect write error: {e}")))?;
read_connack(&mut tcp).await?;
if topic.is_empty() {
return Err(Error::UrlParse("No MQTT topic found. Forgot to URL encode it?".to_string()));
}
let subscribe_pkt = build_subscribe_packet(&topic, 1, qos);
tcp.write_all(&subscribe_pkt)
.await
.map_err(|e| Error::Http(format!("MQTT subscribe write error: {e}")))?;
let mut got_suback = false;
let mut publish_data: Option<(u8, Vec<u8>)> = None;
for _ in 0..2 {
let (ptype, flags, pkt_payload) = read_packet(&mut tcp).await.map_err(|e| {
let msg = e.to_string();
if msg.contains("read")
|| msg.contains("eof")
|| msg.contains("Eof")
|| msg.contains("unexpected end")
{
Error::Http("partial MQTT data received".to_string())
} else {
e
}
})?;
if ptype == PacketType::Suback as u8 {
got_suback = true;
if publish_data.is_some() {
break; }
} else if ptype == PacketType::Publish as u8 {
publish_data = Some((flags, pkt_payload));
if got_suback {
break; }
} else {
return Err(Error::Http(format!(
"MQTT unexpected packet type {ptype} during subscribe"
)));
}
}
let (flags, payload) = publish_data
.ok_or_else(|| Error::Http("MQTT did not receive PUBLISH message".to_string()))?;
let recv_qos = (flags >> 1) & 0x03;
if payload.len() < 2 {
return Err(Error::Http("partial MQTT data received".to_string()));
}
let topic_len = u16::from_be_bytes([payload[0], payload[1]]) as usize;
let (message_start, recv_packet_id) = if recv_qos > 0 {
let id_offset = 2 + topic_len;
if payload.len() < id_offset + 2 {
return Err(Error::Http("partial MQTT data received".to_string()));
}
let pid = u16::from_be_bytes([payload[id_offset], payload[id_offset + 1]]);
(id_offset + 2, Some(pid))
} else {
(2 + topic_len, None)
};
let message =
if message_start <= payload.len() { payload[message_start..].to_vec() } else { Vec::new() };
if let Some(pid) = recv_packet_id {
if recv_qos == 1 {
let puback = build_puback_packet(pid);
tcp.write_all(&puback)
.await
.map_err(|e| Error::Http(format!("MQTT PUBACK write error: {e}")))?;
} else if recv_qos == 2 {
let pubrec = build_pubrec_packet(pid);
tcp.write_all(&pubrec)
.await
.map_err(|e| Error::Http(format!("MQTT PUBREC write error: {e}")))?;
let (ptype, _, _) = read_packet(&mut tcp).await?;
if ptype != PacketType::Pubrel as u8 {
return Err(Error::Http(format!("MQTT expected PUBREL, got type {ptype}")));
}
let pubcomp = build_pubcomp_packet(pid);
tcp.write_all(&pubcomp)
.await
.map_err(|e| Error::Http(format!("MQTT PUBCOMP write error: {e}")))?;
}
}
let _ = tcp.write_all(&build_disconnect_packet()).await;
let mut headers = std::collections::HashMap::new();
let _old = headers.insert("content-length".to_string(), message.len().to_string());
Ok(Response::new(200, headers, message, url.as_str().to_string()))
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn percent_decode_topic_decodes_escapes() {
        assert_eq!(percent_decode_topic("home%2Fkitchen"), "home/kitchen");
        assert_eq!(percent_decode_topic("%41%6a"), "Aj");
        assert_eq!(percent_decode_topic("%C3%A9"), "\u{e9}");
    }

    #[test]
    fn percent_decode_topic_keeps_malformed_escapes() {
        assert_eq!(percent_decode_topic("100%"), "100%");
        assert_eq!(percent_decode_topic("%4"), "%4");
        assert_eq!(percent_decode_topic("%zz"), "%zz");
        assert_eq!(percent_decode_topic(""), "");
    }

    #[test]
    fn percent_decode_topic_replaces_invalid_utf8() {
        assert_eq!(percent_decode_topic("%FF"), "\u{FFFD}");
    }

    #[test]
    fn hex_nibble_accepts_only_hex_digits() {
        assert_eq!(hex_nibble(b'0'), Some(0));
        assert_eq!(hex_nibble(b'9'), Some(9));
        assert_eq!(hex_nibble(b'a'), Some(10));
        assert_eq!(hex_nibble(b'F'), Some(15));
        assert_eq!(hex_nibble(b'g'), None);
        assert_eq!(hex_nibble(b'%'), None);
    }

    #[test]
    fn encode_remaining_length_small() {
        assert_eq!(encode_remaining_length(0), vec![0]);
        assert_eq!(encode_remaining_length(127), vec![127]);
    }

    #[test]
    fn encode_remaining_length_two_bytes() {
        assert_eq!(encode_remaining_length(128), vec![0x80, 0x01]);
    }

    #[test]
    fn encode_remaining_length_large() {
        assert_eq!(encode_remaining_length(16383), vec![0xFF, 0x7F]);
    }

    #[test]
    fn encode_remaining_length_max() {
        assert_eq!(encode_remaining_length(268_435_455), vec![0xFF, 0xFF, 0xFF, 0x7F]);
    }

    #[test]
    fn decode_remaining_length_small() {
        let (len, consumed) = decode_remaining_length(&[64]).unwrap();
        assert_eq!(len, 64);
        assert_eq!(consumed, 1);
    }

    #[test]
    fn decode_remaining_length_two_bytes() {
        let (len, consumed) = decode_remaining_length(&[0x80, 0x01]).unwrap();
        assert_eq!(len, 128);
        assert_eq!(consumed, 2);
    }

    #[test]
    fn decode_remaining_length_four_bytes() {
        let (len, consumed) = decode_remaining_length(&[0xFF, 0xFF, 0xFF, 0x7F]).unwrap();
        assert_eq!(len, 268_435_455);
        assert_eq!(consumed, 4);
    }

    #[test]
    fn decode_remaining_length_rejects_five_bytes() {
        assert!(decode_remaining_length(&[0xFF, 0xFF, 0xFF, 0xFF, 0x01]).is_err());
    }

    #[test]
    fn decode_remaining_length_rejects_truncated_input() {
        assert!(decode_remaining_length(&[0x80]).is_err());
        assert!(decode_remaining_length(&[]).is_err());
    }

    #[test]
    fn roundtrip_remaining_length() {
        for &value in &[0, 1, 127, 128, 255, 16383, 16384, 2_097_151, 268_435_455] {
            let encoded = encode_remaining_length(value);
            let (decoded, _) = decode_remaining_length(&encoded).unwrap();
            assert_eq!(decoded, value, "roundtrip failed for {value}");
        }
    }

    #[test]
    fn connect_packet_structure() {
        let packet = build_connect_packet("test");
        assert_eq!(packet[0], 0x10);
        // 10-byte variable header + 2-byte id length + 12-byte client id.
        assert_eq!(packet[1], 24);
        assert_eq!(packet.len(), 26);
        assert_eq!(&packet[2..4], &[0x00, 0x04]);
        assert_eq!(&packet[4..8], b"MQTT");
        assert_eq!(packet[8], 0x04);
        assert_eq!(packet[9], 0x02);
        assert_eq!(&packet[10..12], &60_u16.to_be_bytes());
        assert_eq!(&packet[12..14], &[0x00, 0x0C]);
        assert!(packet[14..].starts_with(b"test"));
    }

    #[test]
    fn publish_packet_qos0() {
        let packet = build_publish_packet("test/topic", b"hello", QoS::AtMostOnce, 0);
        let mut expected = vec![0x30, 17, 0x00, 10];
        expected.extend_from_slice(b"test/topic");
        expected.extend_from_slice(b"hello");
        assert_eq!(packet, expected);
    }

    #[test]
    fn publish_packet_qos1() {
        let packet = build_publish_packet("test/topic", b"hello", QoS::AtLeastOnce, 42);
        assert_eq!(packet[0] & 0x06, 0x02);
        let topic_len = usize::from(u16::from_be_bytes([packet[2], packet[3]]));
        let id_offset = 2 + 2 + topic_len;
        let packet_id = u16::from_be_bytes([packet[id_offset], packet[id_offset + 1]]);
        assert_eq!(packet_id, 42);
        assert_eq!(&packet[id_offset + 2..], b"hello");
    }

    #[test]
    fn publish_packet_qos2() {
        let packet = build_publish_packet("t", b"x", QoS::ExactlyOnce, 1);
        assert_eq!(packet, vec![0x34, 6, 0, 1, b't', 0, 1, b'x']);
    }

    #[test]
    fn puback_packet_structure() {
        let packet = build_puback_packet(42);
        assert_eq!(packet[0], 0x40);
        assert_eq!(packet[1], 0x02);
        assert_eq!(u16::from_be_bytes([packet[2], packet[3]]), 42);
    }

    #[test]
    fn pubrec_packet_structure() {
        let packet = build_pubrec_packet(7);
        assert_eq!(packet[0], 0x50);
        assert_eq!(u16::from_be_bytes([packet[2], packet[3]]), 7);
    }

    #[test]
    fn pubrel_packet_structure() {
        let packet = build_pubrel_packet(99);
        assert_eq!(packet[0], 0x62);
        assert_eq!(u16::from_be_bytes([packet[2], packet[3]]), 99);
    }

    #[test]
    fn pubcomp_packet_structure() {
        let packet = build_pubcomp_packet(100);
        assert_eq!(packet[0], 0x70);
        assert_eq!(u16::from_be_bytes([packet[2], packet[3]]), 100);
    }

    #[test]
    fn subscribe_packet_qos0() {
        let packet = build_subscribe_packet("test/#", 1, QoS::AtMostOnce);
        let mut expected = vec![0x82, 11, 0x00, 0x01, 0x00, 0x06];
        expected.extend_from_slice(b"test/#");
        expected.push(0x00);
        assert_eq!(packet, expected);
    }

    #[test]
    fn subscribe_packet_qos2() {
        let packet = build_subscribe_packet("test/#", 1, QoS::ExactlyOnce);
        assert_eq!(*packet.last().unwrap(), 0x02);
    }

    #[test]
    fn disconnect_packet() {
        let packet = build_disconnect_packet();
        assert_eq!(packet, vec![0xE0, 0x00]);
    }

    #[test]
    fn qos_default_is_at_most_once() {
        assert_eq!(QoS::default(), QoS::AtMostOnce);
    }

    #[tokio::test]
    async fn read_connack_packet() {
        let data = vec![0x20, 0x02, 0x00, 0x00];
        let mut cursor = std::io::Cursor::new(data);
        let (ptype, flags, payload) = read_packet(&mut cursor).await.unwrap();
        assert_eq!(ptype, PacketType::Connack as u8);
        assert_eq!(flags, 0);
        assert_eq!(payload, vec![0x00, 0x00]);
    }

    #[tokio::test]
    async fn read_puback_packet() {
        let data = vec![0x40, 0x02, 0x00, 0x01];
        let mut cursor = std::io::Cursor::new(data);
        let (ptype, _, payload) = read_packet(&mut cursor).await.unwrap();
        assert_eq!(ptype, PacketType::Puback as u8);
        assert_eq!(u16::from_be_bytes([payload[0], payload[1]]), 1);
    }

    #[tokio::test]
    async fn read_pubrec_packet() {
        let data = vec![0x50, 0x02, 0x00, 0x05];
        let mut cursor = std::io::Cursor::new(data);
        let (ptype, _, payload) = read_packet(&mut cursor).await.unwrap();
        assert_eq!(ptype, PacketType::Pubrec as u8);
        assert_eq!(u16::from_be_bytes([payload[0], payload[1]]), 5);
    }

    #[tokio::test]
    async fn read_pubcomp_packet() {
        let data = vec![0x70, 0x02, 0x00, 0x0A];
        let mut cursor = std::io::Cursor::new(data);
        let (ptype, _, payload) = read_packet(&mut cursor).await.unwrap();
        assert_eq!(ptype, PacketType::Pubcomp as u8);
        assert_eq!(u16::from_be_bytes([payload[0], payload[1]]), 10);
    }

    #[tokio::test]
    async fn read_packet_with_two_byte_length() {
        let mut data = vec![0x30, 0x80, 0x01];
        data.resize(3 + 128, 0xAB);
        let mut cursor = std::io::Cursor::new(data);
        let (ptype, flags, payload) = read_packet(&mut cursor).await.unwrap();
        assert_eq!(ptype, PacketType::Publish as u8);
        assert_eq!(flags, 0);
        assert_eq!(payload.len(), 128);
        assert!(payload.iter().all(|&b| b == 0xAB));
    }

    #[tokio::test]
    async fn read_packet_truncated_body_is_error() {
        let mut cursor = std::io::Cursor::new(vec![0x40, 0x02, 0x00]);
        assert!(read_packet(&mut cursor).await.is_err());
    }
}