use std::io::{self, Read, Write};
use std::net::TcpStream;
use purecrypto::rng::{OsRng, RngCore};
use crate::error::{Error, Result};
use crate::url::Url;
// MQTT control packet types. Each value lives in the high nibble of a packet's
// first (fixed-header) byte: outgoing packets are built with `TYPE << 4` and
// incoming ones are classified with `byte >> 4` in `read_packet`.
const PKT_CONNECT: u8 = 1;
const PKT_CONNACK: u8 = 2;
const PKT_PUBLISH: u8 = 3;
const PKT_PUBACK: u8 = 4;
const PKT_SUBSCRIBE: u8 = 8;
const PKT_SUBACK: u8 = 9;
const PKT_PINGRESP: u8 = 13;
const PKT_DISCONNECT: u8 = 14;
/// Subscribes to the topic named by `url`'s path and returns the payload of the
/// first message the broker delivers.
///
/// The leading `/` of the path is stripped to form the topic filter. Optional
/// `user[:password]` credentials are taken from the URL's userinfo. TLS is used
/// when `url.is_tls()` is true.
///
/// # Errors
/// - `Error::InvalidUrl` if the path names no topic, or the topic is longer than
///   the 65535-byte MQTT string limit.
/// - Connection or TLS setup failures.
/// - Any protocol error from the session.
pub fn fetch(url: &Url) -> Result<Vec<u8>> {
    let topic = url.path.strip_prefix('/').unwrap_or(&url.path);
    if topic.is_empty() {
        return Err(Error::InvalidUrl(format!(
            "mqtt: no topic in URL path ({:?})",
            url.path
        )));
    }
    // The SUBSCRIBE topic filter is an MQTT string with a u16 length prefix.
    // Reject longer filters rather than subscribing to a silently truncated one.
    if topic.len() > usize::from(u16::MAX) {
        return Err(Error::InvalidUrl(format!(
            "mqtt: topic is {} bytes; MQTT limits it to {}",
            topic.len(),
            u16::MAX
        )));
    }
    let (user, pass) = split_userinfo(url.userinfo.as_deref());
    let addr = format!("{}:{}", url.host, url.port);
    let tcp = TcpStream::connect(&addr)?;
    if url.is_tls() {
        let mut stream = crate::tls::connect_over(tcp, &url.host)?;
        run_session(&mut stream, topic, user, pass)
    } else {
        let mut stream = tcp;
        run_session(&mut stream, topic, user, pass)
    }
}
/// Publishes `payload` to the topic named by `url`'s path at the given QoS (0 or 1).
///
/// The leading `/` of the path is stripped to form the topic. Optional
/// `user[:password]` credentials come from the URL's userinfo. TLS is used when
/// `url.is_tls()` is true. For QoS 1 the call waits for the broker's PUBACK.
///
/// # Errors
/// - `Error::InvalidUrl` for a missing or invalid topic.
/// - `Error::BadResponse` for an unsupported QoS, or when the encoded PUBLISH would
///   exceed the MQTT maximum packet body size.
/// - Connection, TLS and protocol errors.
pub fn publish(url: &Url, payload: &[u8], qos: u8) -> Result<()> {
    // Largest value the 4-byte MQTT remaining-length varint can represent.
    const MAX_REMAINING_LENGTH: usize = 268_435_455;

    let topic = url.path.strip_prefix('/').unwrap_or(&url.path);
    if topic.is_empty() {
        return Err(Error::InvalidUrl(format!(
            "mqtt: no topic in URL path ({:?})",
            url.path
        )));
    }
    validate_publish_topic(topic)?;
    if qos > 1 {
        return Err(Error::BadResponse(format!(
            "mqtt: unsupported publish QoS {qos} (only 0 and 1)"
        )));
    }

    // PUBLISH body = 2-byte topic length + topic + (packet id when QoS > 0) + payload.
    // `write_remaining_length` clamps oversized lengths, which would put a header on
    // the wire that disagrees with the body. Refuse before opening a connection.
    let packet_id_len = if qos > 0 { 2 } else { 0 };
    let body_len = payload
        .len()
        .saturating_add(topic.len())
        .saturating_add(2 + packet_id_len);
    if body_len > MAX_REMAINING_LENGTH {
        return Err(Error::BadResponse(format!(
            "mqtt: PUBLISH of {body_len} bytes exceeds the MQTT maximum of {MAX_REMAINING_LENGTH}"
        )));
    }

    let (user, pass) = split_userinfo(url.userinfo.as_deref());
    let addr = format!("{}:{}", url.host, url.port);
    let tcp = TcpStream::connect(&addr)?;
    if url.is_tls() {
        let mut stream = crate::tls::connect_over(tcp, &url.host)?;
        run_publish(&mut stream, topic, payload, qos, user, pass)
    } else {
        let mut stream = tcp;
        run_publish(&mut stream, topic, payload, qos, user, pass)
    }
}
fn connect_handshake<S: Read + Write>(
stream: &mut S,
user: Option<&str>,
pass: Option<&str>,
) -> Result<()> {
let client_id = random_client_id();
let connect = build_connect(&client_id, user, pass, 60);
stream.write_all(&connect)?;
stream.flush()?;
let (ctype, body) = read_packet(stream)?;
if ctype != PKT_CONNACK {
return Err(Error::BadResponse(format!(
"mqtt: expected CONNACK, got packet type {ctype}"
)));
}
if body.len() < 2 {
return Err(Error::BadResponse("mqtt: short CONNACK".into()));
}
let rc = body[1];
if rc != 0 {
return Err(Error::BadResponse(format!("mqtt: connack {rc}")));
}
Ok(())
}
fn run_publish<S: Read + Write>(
stream: &mut S,
topic: &str,
payload: &[u8],
qos: u8,
user: Option<&str>,
pass: Option<&str>,
) -> Result<()> {
connect_handshake(stream, user, pass)?;
let packet_id = 1u16;
let publish = build_publish(topic, payload, qos, packet_id);
stream.write_all(&publish)?;
stream.flush()?;
if qos == 1 {
loop {
let (ctype, body) = read_packet(stream)?;
match ctype {
PKT_PUBACK => {
let acked = parse_puback(&body)?;
if acked != packet_id {
return Err(Error::BadResponse(format!(
"mqtt: PUBACK packet id {acked} != sent {packet_id}"
)));
}
break;
}
PKT_PINGRESP => continue,
other => {
return Err(Error::BadResponse(format!(
"mqtt: unexpected packet type {other} while awaiting PUBACK"
)));
}
}
}
}
let _ = stream.write_all(&[PKT_DISCONNECT << 4, 0x00]);
let _ = stream.flush();
Ok(())
}
/// Runs a one-shot subscribe session and returns the first delivered payload.
///
/// The session goes CONNECT → SUBSCRIBE to `topic` (requested QoS 0, packet id 1)
/// → SUBACK → first PUBLISH → best-effort DISCONNECT.
///
/// MQTT 3.1.1 §3.8.4 allows the server to start sending matching PUBLISH packets
/// (typically a retained message) before it sends the SUBACK. A PUBLISH seen while
/// waiting for the SUBACK is therefore kept as the result rather than treated as a
/// protocol error. PINGRESP packets are skipped throughout.
///
/// # Errors
/// - Handshake, I/O and EOF errors.
/// - `Error::BadResponse` for any of:
///   - a short or failed (0x80) SUBACK;
///   - a SUBACK for another packet id;
///   - any other unexpected packet type.
fn run_session<S: Read + Write>(
    stream: &mut S,
    topic: &str,
    user: Option<&str>,
    pass: Option<&str>,
) -> Result<Vec<u8>> {
    const SUBSCRIBE_PACKET_ID: u16 = 1;

    connect_handshake(stream, user, pass)?;
    let subscribe = build_subscribe(SUBSCRIBE_PACKET_ID, topic);
    stream.write_all(&subscribe)?;
    stream.flush()?;

    let mut early_payload: Option<Vec<u8>> = None;
    loop {
        let (ctype, body) = read_packet(stream)?;
        match ctype {
            PKT_SUBACK => {
                check_suback(&body, SUBSCRIBE_PACKET_ID)?;
                break;
            }
            PKT_PUBLISH => {
                // Keep only the first delivery, matching the post-SUBACK behavior.
                if early_payload.is_none() {
                    early_payload = Some(extract_publish_payload(&body)?);
                }
            }
            PKT_PINGRESP => {}
            other => {
                return Err(Error::BadResponse(format!(
                    "mqtt: expected SUBACK, got packet type {other}"
                )));
            }
        }
    }

    let payload = match early_payload {
        Some(p) => p,
        None => wait_for_publish(stream)?,
    };

    // DISCONNECT is a courtesy; the payload is already in hand, so write errors are ignored.
    let _ = stream.write_all(&[PKT_DISCONNECT << 4, 0x00]);
    let _ = stream.flush();
    Ok(payload)
}

/// Validates a SUBACK body for a single-filter SUBSCRIBE.
///
/// The body is a 2-byte packet id followed by one return code. The id must equal
/// `expected_id`, and the return code must not be the 0x80 failure marker.
fn check_suback(body: &[u8], expected_id: u16) -> Result<()> {
    if body.len() < 3 {
        return Err(Error::BadResponse("mqtt: short SUBACK".into()));
    }
    let acked = u16::from_be_bytes([body[0], body[1]]);
    if acked != expected_id {
        return Err(Error::BadResponse(format!(
            "mqtt: SUBACK packet id {acked} != sent {expected_id}"
        )));
    }
    if body[2] == 0x80 {
        return Err(Error::BadResponse("mqtt: suback failure (0x80)".into()));
    }
    Ok(())
}

/// Reads packets until the first PUBLISH and returns its payload, skipping PINGRESP.
fn wait_for_publish<S: Read>(stream: &mut S) -> Result<Vec<u8>> {
    loop {
        let (ctype, body) = read_packet(stream)?;
        match ctype {
            PKT_PUBLISH => return extract_publish_payload(&body),
            PKT_PINGRESP => continue,
            other => {
                return Err(Error::BadResponse(format!(
                    "mqtt: unexpected packet type {other} before PUBLISH"
                )));
            }
        }
    }
}
/// Splits URL userinfo into `(user, password)` at the first `:`.
///
/// - `None` yields `(None, None)`.
/// - A string without `:` is all user name.
/// - Everything after the first `:` (including further colons) is the password.
fn split_userinfo(ui: Option<&str>) -> (Option<&str>, Option<&str>) {
    ui.map_or((None, None), |info| {
        info.split_once(':')
            .map_or((Some(info), None), |(user, password)| {
                (Some(user), Some(password))
            })
    })
}
/// Generates an MQTT client id of the form `rsurl-` followed by 12 lowercase hex digits.
///
/// The hex digits encode 6 random bytes drawn from `OsRng`, giving an 18-character id.
fn random_client_id() -> String {
    const PREFIX: &str = "rsurl-";
    let mut entropy = [0u8; 6];
    OsRng.fill_bytes(&mut entropy);

    let mut id = String::with_capacity(PREFIX.len() + 2 * entropy.len());
    id.push_str(PREFIX);
    // Each byte becomes two hex digits, high nibble first.
    id.extend(
        entropy
            .iter()
            .flat_map(|&byte| [hex_nibble(byte >> 4), hex_nibble(byte & 0x0F)]),
    );
    id
}
/// Returns the lowercase hexadecimal digit for a nibble `0..=15`.
///
/// # Panics
/// Panics (via `unreachable!`) if `n > 15`.
fn hex_nibble(n: u8) -> char {
    // `char::from_digit` yields lowercase letters for digits above 9.
    char::from_digit(u32::from(n), 16).unwrap_or_else(|| unreachable!())
}
/// Appends `s` to `out` as an MQTT UTF-8 string: a big-endian `u16` byte length
/// followed by the bytes.
///
/// MQTT strings hold at most 65535 bytes. Longer input is truncated to the longest
/// prefix that fits *and* ends on a `char` boundary, so the encoded string is always
/// well-formed UTF-8; MQTT 3.1.1 §1.5.3 forbids ill-formed UTF-8.
fn push_str(out: &mut Vec<u8>, s: &str) {
    let mut end = s.len().min(usize::from(u16::MAX));
    // Back off to a code-point boundary; index 0 is always a boundary, so this terminates.
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    let len = u16::try_from(end).expect("end is clamped to u16::MAX");
    out.extend_from_slice(&len.to_be_bytes());
    out.extend_from_slice(&s.as_bytes()[..end]);
}
/// Encodes an MQTT 3.1.1 CONNECT packet.
///
/// The packet requests a clean session (connect flag `0x02`) and uses
/// `keep_alive_secs` as the keep-alive interval. The payload is `client_id`
/// followed by the optional user name and password, with the matching connect
/// flags set (`0x80` user, `0x40` password).
///
/// MQTT 3.1.1 forbids a password without a user name ([MQTT-3.1.2-22]). If `pass`
/// is given without `user`, it is ignored so the broker is never sent a CONNECT it
/// must reject.
pub(crate) fn build_connect(
    client_id: &str,
    user: Option<&str>,
    pass: Option<&str>,
    keep_alive_secs: u16,
) -> Vec<u8> {
    let pass = pass.filter(|_| user.is_some());

    let mut flags: u8 = 0x02;
    if user.is_some() {
        flags |= 0x80;
    }
    if pass.is_some() {
        flags |= 0x40;
    }

    // Variable header: protocol name "MQTT", protocol level 4 (3.1.1), flags, keep-alive.
    let mut vh = Vec::with_capacity(10);
    push_str(&mut vh, "MQTT");
    vh.push(4);
    vh.push(flags);
    vh.extend_from_slice(&keep_alive_secs.to_be_bytes());

    // Payload order is fixed by the spec: client id, then user name, then password.
    let mut pl = Vec::new();
    push_str(&mut pl, client_id);
    if let Some(u) = user {
        push_str(&mut pl, u);
    }
    if let Some(p) = pass {
        push_str(&mut pl, p);
    }

    // Fixed header: one type byte plus a 1-4 byte remaining length.
    let body_len = vh.len() + pl.len();
    let mut out = Vec::with_capacity(5 + body_len);
    out.push(PKT_CONNECT << 4);
    write_remaining_length(&mut out, body_len);
    out.extend_from_slice(&vh);
    out.extend_from_slice(&pl);
    out
}
/// Encodes a SUBSCRIBE packet for a single topic filter at requested QoS 0.
///
/// Layout: fixed header (type 8, flags `0b0010`), remaining length, the big-endian
/// `packet_id`, then the topic filter as an MQTT string and one requested-QoS byte.
pub(crate) fn build_subscribe(packet_id: u16, topic: &str) -> Vec<u8> {
    let mut payload = packet_id.to_be_bytes().to_vec();
    push_str(&mut payload, topic);
    payload.push(0x00); // requested QoS 0

    let mut packet = Vec::with_capacity(2 + payload.len());
    // MQTT 3.1.1 requires the SUBSCRIBE fixed-header flag bits to be 0b0010.
    packet.push((PKT_SUBSCRIBE << 4) | 0x02);
    write_remaining_length(&mut packet, payload.len());
    packet.extend(payload);
    packet
}
/// Checks that `topic` is usable as a PUBLISH topic name.
///
/// Rejects:
/// - topics longer than the 65535-byte MQTT string limit, which would otherwise be
///   silently truncated and published to a different topic;
/// - the subscription wildcards `+` and `#`;
/// - NUL;
/// - any other control character.
///
/// # Errors
/// `Error::InvalidUrl` describing the first problem found.
fn validate_publish_topic(topic: &str) -> Result<()> {
    if topic.len() > usize::from(u16::MAX) {
        return Err(Error::InvalidUrl(format!(
            "mqtt: publish topic is {} bytes; MQTT limits it to {}",
            topic.len(),
            u16::MAX
        )));
    }
    for ch in topic.chars() {
        match ch {
            '+' | '#' => {
                return Err(Error::InvalidUrl(format!(
                    "mqtt: publish topic must not contain wildcard {ch:?}"
                )));
            }
            // Checked before the generic control-char arm so NUL gets its own message.
            '\0' => {
                return Err(Error::InvalidUrl(
                    "mqtt: publish topic must not contain NUL".into(),
                ));
            }
            c if c.is_control() => {
                return Err(Error::InvalidUrl(format!(
                    "mqtt: publish topic must not contain control char {:?}",
                    c
                )));
            }
            _ => {}
        }
    }
    Ok(())
}
/// Encodes a PUBLISH packet (DUP and RETAIN clear).
///
/// The body is the topic as an MQTT string, then `packet_id` (only when `qos` is
/// non-zero), then the raw `payload`. Only the low two bits of `qos` are placed
/// into the fixed-header flags.
pub(crate) fn build_publish(topic: &str, payload: &[u8], qos: u8, packet_id: u16) -> Vec<u8> {
    let mut body = Vec::with_capacity(2 + topic.len() + 2 + payload.len());
    push_str(&mut body, topic);
    if qos != 0 {
        body.extend_from_slice(&packet_id.to_be_bytes());
    }
    body.extend_from_slice(payload);

    let header = (PKT_PUBLISH << 4) | ((qos & 0x03) << 1);
    let mut packet = Vec::with_capacity(2 + body.len());
    packet.push(header);
    write_remaining_length(&mut packet, body.len());
    packet.append(&mut body);
    packet
}
fn parse_puback(body: &[u8]) -> Result<u16> {
if body.len() < 2 {
return Err(Error::BadResponse("mqtt: short PUBACK".into()));
}
Ok(u16::from_be_bytes([body[0], body[1]]))
}
fn extract_publish_payload(body: &[u8]) -> Result<Vec<u8>> {
if body.len() < 2 {
return Err(Error::BadResponse("mqtt: short PUBLISH".into()));
}
let topic_len = u16::from_be_bytes([body[0], body[1]]) as usize;
let after_topic = 2 + topic_len;
if body.len() < after_topic {
return Err(Error::BadResponse(
"mqtt: PUBLISH topic length exceeds packet".into(),
));
}
Ok(body[after_topic..].to_vec())
}
fn read_packet<R: Read>(r: &mut R) -> Result<(u8, Vec<u8>)> {
let mut hdr = [0u8; 1];
read_exact_or_eof(r, &mut hdr)?;
let ctype = hdr[0] >> 4;
let rem = read_remaining_length(r)?;
let mut body = vec![0u8; rem];
if rem > 0 {
read_exact_or_eof(r, &mut body)?;
}
Ok((ctype, body))
}
fn read_exact_or_eof<R: Read>(r: &mut R, buf: &mut [u8]) -> Result<()> {
match r.read_exact(buf) {
Ok(()) => Ok(()),
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Err(Error::UnexpectedEof),
Err(e) => Err(Error::Io(e)),
}
}
/// Decodes the MQTT "remaining length" variable-length integer from `r`.
///
/// The value is stored in 7-bit groups, least significant first. The high bit of
/// each byte means "another byte follows", and at most four bytes are allowed.
///
/// # Errors
/// - Any I/O error from `r`, including `UnexpectedEof` on a truncated varint.
/// - `InvalidData` if the fourth byte still has its continuation bit set.
pub(crate) fn read_remaining_length<R: Read>(r: &mut R) -> io::Result<usize> {
    const MAX_BYTES: u32 = 4;
    let mut total: usize = 0;
    let mut shift: u32 = 0;
    let mut consumed: u32 = 0;
    loop {
        let mut one = [0u8; 1];
        r.read_exact(&mut one)?;
        let [byte] = one;
        // The 7-bit groups never overlap, so OR-ing them in is the same as summing.
        total |= usize::from(byte & 0x7F) << shift;
        if byte & 0x80 == 0 {
            return Ok(total);
        }
        consumed += 1;
        if consumed == MAX_BYTES {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "mqtt: malformed remaining length (5th byte)",
            ));
        }
        shift += 7;
    }
}
/// Appends `len` to `out` as an MQTT "remaining length" varint.
///
/// The value is written in 7-bit groups, least significant first, with the high bit
/// set on every byte except the last. Values above the 4-byte maximum of
/// 268,435,455 are clamped to that maximum.
pub(crate) fn write_remaining_length(out: &mut Vec<u8>, len: usize) {
    const MAX_REMAINING: usize = 268_435_455;
    let mut rest = if len > MAX_REMAINING { MAX_REMAINING } else { len };
    loop {
        let digit = (rest % 128) as u8;
        rest /= 128;
        if rest == 0 {
            // Final group: no continuation bit. Zero still emits one 0x00 byte.
            out.push(digit);
            break;
        }
        out.push(digit | 0x80);
    }
}
// Unit tests for the MQTT wire encoders/decoders and the publish session logic.
// Expected byte vectors are written out literally so any encoding change is caught.
#[cfg(test)]
mod tests {
use super::*;
// In-memory duplex stream: reads come from a pre-scripted "broker" byte buffer,
// and every byte the client writes is captured for later inspection.
struct MockStream {
to_read: std::io::Cursor<Vec<u8>>,
written: Vec<u8>,
}
impl MockStream {
fn new(to_read: Vec<u8>) -> Self {
MockStream {
to_read: std::io::Cursor::new(to_read),
written: Vec::new(),
}
}
fn written(&self) -> &[u8] {
&self.written
}
}
impl Read for MockStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.to_read.read(buf)
}
}
impl Write for MockStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.written.extend_from_slice(buf);
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
// (value, expected encoding) pairs at every varint byte-count boundary, up to the
// 4-byte maximum of 268,435,455.
const RL_CASES: &[(usize, &[u8])] = &[
(0, &[0x00]),
(127, &[0x7F]),
(128, &[0x80, 0x01]),
(16_383, &[0xFF, 0x7F]),
(16_384, &[0x80, 0x80, 0x01]),
(2_097_151, &[0xFF, 0xFF, 0x7F]),
(2_097_152, &[0x80, 0x80, 0x80, 0x01]),
(268_435_455, &[0xFF, 0xFF, 0xFF, 0x7F]),
];
#[test]
fn write_remaining_length_matches_spec_bytes() {
for (value, expected) in RL_CASES {
let mut buf = Vec::new();
write_remaining_length(&mut buf, *value);
assert_eq!(
buf.as_slice(),
*expected,
"encoding of {value} (got {:02X?}, want {:02X?})",
buf,
expected
);
}
}
// Checks both that decode(encode(v)) == v and that the literal table bytes decode to v.
#[test]
fn read_remaining_length_round_trips() {
for (value, expected) in RL_CASES {
let mut buf = Vec::new();
write_remaining_length(&mut buf, *value);
let mut cur = std::io::Cursor::new(&buf);
let got = read_remaining_length(&mut cur).expect("decode");
assert_eq!(got, *value, "round trip for {value}");
let mut cur2 = std::io::Cursor::new(*expected);
let got2 = read_remaining_length(&mut cur2).expect("decode spec bytes");
assert_eq!(got2, *value, "spec-bytes decode for {value}");
}
}
// Four bytes that all carry the continuation bit would need a 5th byte, which is illegal.
#[test]
fn read_remaining_length_rejects_5_byte_varint() {
let bad = [0xFF, 0xFF, 0xFF, 0xFF];
let mut cur = std::io::Cursor::new(&bad[..]);
assert!(read_remaining_length(&mut cur).is_err());
}
// Expected layout: fixed header (0x10, len 15), "MQTT", level 4, flags 0x02
// (clean session), keep-alive 60, then client id "abc".
#[test]
fn build_connect_exact_bytes_for_known_input() {
let got = build_connect("abc", None, None, 60);
let expected: Vec<u8> = vec![
0x10, 0x0F, 0x00, 0x04, b'M', b'Q', b'T', b'T', 0x04, 0x02, 0x00, 0x3C, 0x00, 0x03, b'a', b'b', b'c', ];
assert_eq!(got, expected);
}
// Flags 0xC2 = user (0x80) | password (0x40) | clean session (0x02).
#[test]
fn build_connect_sets_user_and_password_flags() {
let got = build_connect("id", Some("u"), Some("p"), 30);
let expected: Vec<u8> = vec![
0x10, 0x14, 0x00, 0x04, b'M', b'Q', b'T', b'T', 0x04, 0xC2, 0x00, 0x1E, 0x00, 0x02,
b'i', b'd', 0x00, 0x01, b'u', 0x00, 0x01, b'p',
];
assert_eq!(got, expected);
}
// 0x82 = SUBSCRIBE (8 << 4) with the reserved 0b0010 flags; trailing 0x00 is requested QoS 0.
#[test]
fn build_subscribe_exact_bytes() {
let got = build_subscribe(1, "a/b");
let expected: Vec<u8> = vec![0x82, 0x08, 0x00, 0x01, 0x00, 0x03, b'a', b'/', b'b', 0x00];
assert_eq!(got, expected);
}
// At QoS 0 the packet id argument must be ignored entirely.
#[test]
fn build_publish_qos0_has_no_packet_id() {
let got = build_publish("top", b"PAY", 0, 1);
let expected: Vec<u8> = vec![0x30, 0x08, 0x00, 0x03, b't', b'o', b'p', b'P', b'A', b'Y'];
assert_eq!(got, expected);
assert_eq!(build_publish("top", b"PAY", 0, 9999), expected);
}
// 0x32 = PUBLISH (3 << 4) with QoS 1 in bits 1-2; packet id 7 follows the topic.
#[test]
fn build_publish_qos1_has_flag_and_packet_id() {
let got = build_publish("top", b"PAY", 1, 7);
let expected: Vec<u8> = vec![
0x32, 0x0A, 0x00, 0x03, b't', b'o', b'p', 0x00, 0x07, b'P', b'A', b'Y',
];
assert_eq!(got, expected);
}
// 5000-byte payload forces a two-byte remaining length through the encode/decode path.
#[test]
fn build_publish_large_payload_round_trips() {
let payload = vec![0xABu8; 5000];
let pkt = build_publish("t", &payload, 0, 1);
let mut cur = std::io::Cursor::new(&pkt);
let (ctype, body) = read_packet(&mut cur).expect("read PUBLISH");
assert_eq!(ctype, PKT_PUBLISH);
let got = extract_publish_payload(&body).expect("payload");
assert_eq!(got, payload);
}
#[test]
fn parse_puback_reads_packet_id() {
assert_eq!(parse_puback(&[0x00, 0x07]).unwrap(), 7);
assert_eq!(parse_puback(&[0x12, 0x34]).unwrap(), 0x1234);
assert!(parse_puback(&[0x00]).is_err());
}
// Scripted broker replies: CONNACK (accepted), then PUBACK for packet id 1.
// Verifies the exact client byte stream: CONNECT, PUBLISH, DISCONNECT.
#[test]
fn qos1_publish_waits_for_matching_puback() {
let mut from_broker = Vec::new();
from_broker.extend_from_slice(&[(PKT_CONNACK << 4), 0x02, 0x00, 0x00]);
from_broker.extend_from_slice(&[(PKT_PUBACK << 4), 0x02, 0x00, 0x01]);
let mut peer = MockStream::new(from_broker);
run_publish(&mut peer, "a/b", b"hello", 1, None, None).expect("publish ok");
let written = peer.written();
let mut cur = std::io::Cursor::new(&written);
let (c0, _) = read_packet(&mut cur).expect("connect");
assert_eq!(c0, PKT_CONNECT);
let publish_start = cur.position() as usize;
let expected_publish = build_publish("a/b", b"hello", 1, 1);
assert_eq!(
&written[publish_start..publish_start + expected_publish.len()],
expected_publish.as_slice()
);
let disc = &written[publish_start + expected_publish.len()..];
assert_eq!(disc, &[PKT_DISCONNECT << 4, 0x00]);
}
// PUBACK for packet id 2 while the client sent id 1 must be reported as an error.
#[test]
fn qos1_publish_rejects_mismatched_puback() {
let mut from_broker = Vec::new();
from_broker.extend_from_slice(&[(PKT_CONNACK << 4), 0x02, 0x00, 0x00]);
from_broker.extend_from_slice(&[(PKT_PUBACK << 4), 0x02, 0x00, 0x02]);
let mut peer = MockStream::new(from_broker);
let err = run_publish(&mut peer, "a/b", b"hi", 1, None, None).unwrap_err();
match err {
Error::BadResponse(m) => assert!(m.contains("PUBACK"), "got {m}"),
other => panic!("expected BadResponse, got {other:?}"),
}
}
#[test]
fn validate_publish_topic_rejects_wildcards_and_control() {
assert!(validate_publish_topic("sensor/+/temp").is_err());
assert!(validate_publish_topic("sensor/#").is_err());
assert!(validate_publish_topic("a\0b").is_err());
assert!(validate_publish_topic("a\nb").is_err());
assert!(validate_publish_topic("a/b/c").is_ok());
assert!(validate_publish_topic("home/kitchen/temp").is_ok());
}
#[test]
fn extract_publish_payload_strips_topic() {
let body = b"\x00\x03topPAY";
let payload = extract_publish_payload(body).unwrap();
assert_eq!(payload, b"PAY");
}
// Only the first ':' separates user from password; later colons stay in the password.
#[test]
fn split_userinfo_variants() {
assert_eq!(split_userinfo(None), (None, None));
assert_eq!(split_userinfo(Some("alice")), (Some("alice"), None));
assert_eq!(
split_userinfo(Some("alice:secret")),
(Some("alice"), Some("secret"))
);
assert_eq!(
split_userinfo(Some("alice:s:p")),
(Some("alice"), Some("s:p"))
);
}
// Format is "rsurl-" + 12 lowercase hex digits; two draws should almost never collide.
#[test]
fn random_client_id_format() {
let id = random_client_id();
assert!(id.starts_with("rsurl-"), "got {id}");
let suffix = &id["rsurl-".len()..];
assert_eq!(suffix.len(), 12);
assert!(suffix
.chars()
.all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase()));
assert_ne!(random_client_id(), random_client_id());
}
}