use std::io::{self, Read, Write};
use std::net::TcpStream;
use purecrypto::rng::{OsRng, RngCore};
use crate::error::{Error, Result};
use crate::url::Url;
// MQTT 3.1.1 control-packet type codes. Each value is stored in the high
// nibble of a packet's first byte (shifted left by 4 when encoding, shifted
// right by 4 when decoding).
const PKT_CONNECT: u8 = 0x1;
const PKT_CONNACK: u8 = 0x2;
const PKT_PUBLISH: u8 = 0x3;
const PKT_SUBSCRIBE: u8 = 0x8;
const PKT_SUBACK: u8 = 0x9;
const PKT_PINGRESP: u8 = 0xD;
const PKT_DISCONNECT: u8 = 0xE;
pub fn fetch(url: &Url) -> Result<Vec<u8>> {
let topic = url.path.strip_prefix('/').unwrap_or(&url.path);
if topic.is_empty() {
return Err(Error::InvalidUrl(format!(
"mqtt: no topic in URL path ({:?})",
url.path
)));
}
let (user, pass) = split_userinfo(url.userinfo.as_deref());
let addr = format!("{}:{}", url.host, url.port);
let tcp = TcpStream::connect(&addr)?;
if url.is_tls() {
let mut stream = crate::tls::connect_over(tcp, &url.host)?;
run_session(&mut stream, topic, user, pass)
} else {
let mut stream = tcp;
run_session(&mut stream, topic, user, pass)
}
}
fn run_session<S: Read + Write>(
stream: &mut S,
topic: &str,
user: Option<&str>,
pass: Option<&str>,
) -> Result<Vec<u8>> {
let client_id = random_client_id();
let connect = build_connect(&client_id, user, pass, 60);
stream.write_all(&connect)?;
stream.flush()?;
let (ctype, body) = read_packet(stream)?;
if ctype != PKT_CONNACK {
return Err(Error::BadResponse(format!(
"mqtt: expected CONNACK, got packet type {ctype}"
)));
}
if body.len() < 2 {
return Err(Error::BadResponse("mqtt: short CONNACK".into()));
}
let rc = body[1];
if rc != 0 {
return Err(Error::BadResponse(format!("mqtt: connack {rc}")));
}
let subscribe = build_subscribe(1, topic);
stream.write_all(&subscribe)?;
stream.flush()?;
let (ctype, body) = read_packet(stream)?;
if ctype != PKT_SUBACK {
return Err(Error::BadResponse(format!(
"mqtt: expected SUBACK, got packet type {ctype}"
)));
}
if body.len() < 3 {
return Err(Error::BadResponse("mqtt: short SUBACK".into()));
}
let sub_rc = body[2];
if sub_rc == 0x80 {
return Err(Error::BadResponse("mqtt: suback failure (0x80)".into()));
}
let payload = loop {
let (ctype, body) = read_packet(stream)?;
match ctype {
PKT_PUBLISH => break extract_publish_payload(&body)?,
PKT_PINGRESP => continue,
other => {
return Err(Error::BadResponse(format!(
"mqtt: unexpected packet type {other} before PUBLISH"
)));
}
}
};
let _ = stream.write_all(&[PKT_DISCONNECT << 4, 0x00]);
let _ = stream.flush();
Ok(payload)
}
/// Splits URL userinfo into `(user, password)` at the first `:`.
///
/// - `None` yields `(None, None)`.
/// - A string without `:` is a user name with no password.
/// - Any later `:` characters stay in the password.
///
/// No percent-decoding is performed.
fn split_userinfo(ui: Option<&str>) -> (Option<&str>, Option<&str>) {
    ui.map_or((None, None), |raw| match raw.find(':') {
        // ':' is a single byte, so `colon + 1` is always a char boundary.
        Some(colon) => (Some(&raw[..colon]), Some(&raw[colon + 1..])),
        None => (Some(raw), None),
    })
}
/// Generates a client identifier of the form `rsurl-` followed by 12 lowercase
/// hex digits. The hex digits encode 6 bytes drawn from `OsRng`.
///
/// The result is 18 bytes, within the 23-byte length that MQTT 3.1.1
/// §3.1.3.1 requires servers to accept.
///
/// NOTE(review): that same section only obliges servers to accept
/// `[0-9a-zA-Z]`. The `-` in the prefix relies on the broker's optional
/// leniency.
fn random_client_id() -> String {
    const PREFIX: &str = "rsurl-";
    let mut buf = [0u8; 6];
    OsRng.fill_bytes(&mut buf);
    // Exact final size: the prefix plus two hex digits per random byte.
    let mut s = String::with_capacity(PREFIX.len() + 2 * buf.len());
    s.push_str(PREFIX);
    for b in buf {
        s.push(hex_nibble(b >> 4));
        s.push(hex_nibble(b & 0x0F));
    }
    s
}
/// Converts a 4-bit value (0..=15) to its lowercase hexadecimal digit.
///
/// # Panics
/// Panics (via `unreachable!`) if `n > 15`.
fn hex_nibble(n: u8) -> char {
    // `from_digit` returns lowercase letters for 10..=15 and `None` above 15.
    match char::from_digit(u32::from(n), 16) {
        Some(digit) => digit,
        None => unreachable!(),
    }
}
/// Appends `s` to `out` as an MQTT UTF-8 string: a big-endian `u16` byte
/// length followed by the bytes.
///
/// The length prefix limits strings to 65 535 bytes. Longer input is
/// truncated to the largest prefix that is both at most 65 535 bytes and ends
/// on a `char` boundary. This keeps the encoded string well-formed UTF-8, as
/// MQTT 3.1.1 §1.5.3 requires; a byte-level cut could split a multi-byte
/// character.
fn push_str(out: &mut Vec<u8>, s: &str) {
    let mut len = s.len().min(usize::from(u16::MAX));
    // Index 0 is always a boundary, so this loop terminates (at most 3 steps).
    while !s.is_char_boundary(len) {
        len -= 1;
    }
    let prefix = u16::try_from(len).expect("len was clamped to u16::MAX above");
    out.extend_from_slice(&prefix.to_be_bytes());
    out.extend_from_slice(&s.as_bytes()[..len]);
}
/// Encodes an MQTT 3.1.1 CONNECT packet with the clean-session flag set.
///
/// - `client_id` is written as the first payload field.
/// - `user` and `pass`, when present, follow it and set the corresponding
///   connect flags.
/// - `keep_alive_secs` goes into the variable header.
///
/// MQTT 3.1.1 §3.1.2.9 forbids a password without a user name. If `pass` is
/// `Some` while `user` is `None`, the password is omitted instead of emitting
/// a CONNECT the broker is required to reject.
pub(crate) fn build_connect(
    client_id: &str,
    user: Option<&str>,
    pass: Option<&str>,
    keep_alive_secs: u16,
) -> Vec<u8> {
    // Protocol level 4 identifies MQTT 3.1.1.
    const PROTOCOL_LEVEL: u8 = 4;
    // Connect-flag bits.
    const CLEAN_SESSION: u8 = 0x02;
    const PASSWORD_FLAG: u8 = 0x40;
    const USERNAME_FLAG: u8 = 0x80;

    // Keep the password only when a user name is present (§3.1.2.9).
    let pass = user.and(pass);

    let mut flags = CLEAN_SESSION;
    if user.is_some() {
        flags |= USERNAME_FLAG;
    }
    if pass.is_some() {
        flags |= PASSWORD_FLAG;
    }

    // Variable header: protocol name, protocol level, connect flags, keep-alive.
    let mut vh = Vec::new();
    push_str(&mut vh, "MQTT");
    vh.push(PROTOCOL_LEVEL);
    vh.push(flags);
    vh.extend_from_slice(&keep_alive_secs.to_be_bytes());

    // Payload: client id, then user name, then password, in that fixed order.
    let mut pl = Vec::new();
    push_str(&mut pl, client_id);
    if let Some(u) = user {
        push_str(&mut pl, u);
    }
    if let Some(p) = pass {
        push_str(&mut pl, p);
    }

    // Fixed header: packet type in the high nibble, then the remaining length.
    let mut out = Vec::with_capacity(2 + vh.len() + pl.len());
    out.push(PKT_CONNECT << 4);
    write_remaining_length(&mut out, vh.len() + pl.len());
    out.extend_from_slice(&vh);
    out.extend_from_slice(&pl);
    out
}
/// Encodes an MQTT 3.1.1 SUBSCRIBE packet for a single topic filter.
///
/// The packet requests QoS 0 for `topic` and carries `packet_id`, which the
/// broker echoes in SUBACK.
pub(crate) fn build_subscribe(packet_id: u16, topic: &str) -> Vec<u8> {
    // Variable header (packet identifier) followed by the payload:
    // one length-prefixed topic filter plus its requested QoS byte.
    let mut payload = packet_id.to_be_bytes().to_vec();
    push_str(&mut payload, topic);
    payload.push(0x00);

    // SUBSCRIBE's fixed-header flag nibble must be 0b0010.
    let mut packet = vec![(PKT_SUBSCRIBE << 4) | 0x02];
    write_remaining_length(&mut packet, payload.len());
    packet.extend(payload);
    packet
}
fn extract_publish_payload(body: &[u8]) -> Result<Vec<u8>> {
if body.len() < 2 {
return Err(Error::BadResponse("mqtt: short PUBLISH".into()));
}
let topic_len = u16::from_be_bytes([body[0], body[1]]) as usize;
let after_topic = 2 + topic_len;
if body.len() < after_topic {
return Err(Error::BadResponse(
"mqtt: PUBLISH topic length exceeds packet".into(),
));
}
Ok(body[after_topic..].to_vec())
}
fn read_packet<R: Read>(r: &mut R) -> Result<(u8, Vec<u8>)> {
let mut hdr = [0u8; 1];
read_exact_or_eof(r, &mut hdr)?;
let ctype = hdr[0] >> 4;
let rem = read_remaining_length(r)?;
let mut body = vec![0u8; rem];
if rem > 0 {
read_exact_or_eof(r, &mut body)?;
}
Ok((ctype, body))
}
fn read_exact_or_eof<R: Read>(r: &mut R, buf: &mut [u8]) -> Result<()> {
match r.read_exact(buf) {
Ok(()) => Ok(()),
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Err(Error::UnexpectedEof),
Err(e) => Err(Error::Io(e)),
}
}
/// Decodes the MQTT "remaining length" variable-length integer from `r`.
///
/// Each byte contributes 7 low bits, least-significant group first. A set
/// high bit means another byte follows. At most four bytes are read.
///
/// # Errors
/// - Propagates read errors, e.g. `UnexpectedEof` when the stream ends early.
/// - Returns `InvalidData` if the fourth byte still has its continuation bit
///   set. The would-be fifth byte is not consumed.
pub(crate) fn read_remaining_length<R: Read>(r: &mut R) -> io::Result<usize> {
    let mut total: usize = 0;
    for shift in [0u32, 7, 14, 21] {
        let mut one = [0u8; 1];
        r.read_exact(&mut one)?;
        let [byte] = one;
        // The 7-bit groups occupy disjoint bit ranges, so OR equals addition.
        total |= usize::from(byte & 0x7F) << shift;
        if byte & 0x80 == 0 {
            return Ok(total);
        }
    }
    Err(io::Error::new(
        io::ErrorKind::InvalidData,
        "mqtt: malformed remaining length (5th byte)",
    ))
}
/// Appends `len` to `out` as an MQTT "remaining length" variable-length integer.
///
/// Values above the protocol maximum of 268 435 455 (four 7-bit groups) are
/// clamped to that maximum.
pub(crate) fn write_remaining_length(out: &mut Vec<u8>, len: usize) {
    const MAX_REMAINING_LENGTH: usize = 268_435_455;
    let mut rest = usize::min(len, MAX_REMAINING_LENGTH);
    loop {
        let digit = (rest % 128) as u8;
        rest /= 128;
        if rest == 0 {
            out.push(digit);
            break;
        }
        // More groups follow: set the continuation bit.
        out.push(digit | 0x80);
    }
}
// Unit tests for the MQTT packet encoders and decoders in this module.
// The byte fixtures below are hand-written wire encodings; keep them exact.
#[cfg(test)]
mod tests {
use super::*;
// (value, expected encoding) pairs for the remaining-length varint.
// The cases cover both edges of every 1-, 2-, 3- and 4-byte range, up to the
// largest encodable value (268_435_455).
const RL_CASES: &[(usize, &[u8])] = &[
(0, &[0x00]),
(127, &[0x7F]),
(128, &[0x80, 0x01]),
(16_383, &[0xFF, 0x7F]),
(16_384, &[0x80, 0x80, 0x01]),
(2_097_151, &[0xFF, 0xFF, 0x7F]),
(2_097_152, &[0x80, 0x80, 0x80, 0x01]),
(268_435_455, &[0xFF, 0xFF, 0xFF, 0x7F]),
];
// Encoder output must match each expected byte sequence exactly.
#[test]
fn write_remaining_length_matches_spec_bytes() {
for (value, expected) in RL_CASES {
let mut buf = Vec::new();
write_remaining_length(&mut buf, *value);
assert_eq!(
buf.as_slice(),
*expected,
"encoding of {value} (got {:02X?}, want {:02X?})",
buf,
expected
);
}
}
// Decoder check: decode(encode(v)) == v, and decoding the hand-written
// fixture bytes also yields v.
#[test]
fn read_remaining_length_round_trips() {
for (value, expected) in RL_CASES {
let mut buf = Vec::new();
write_remaining_length(&mut buf, *value);
let mut cur = std::io::Cursor::new(&buf);
let got = read_remaining_length(&mut cur).expect("decode");
assert_eq!(got, *value, "round trip for {value}");
let mut cur2 = std::io::Cursor::new(*expected);
let got2 = read_remaining_length(&mut cur2).expect("decode spec bytes");
assert_eq!(got2, *value, "spec-bytes decode for {value}");
}
}
// Four bytes that all have the continuation bit set would require a fifth
// byte; the decoder must reject this.
#[test]
fn read_remaining_length_rejects_5_byte_varint() {
let bad = [0xFF, 0xFF, 0xFF, 0xFF];
let mut cur = std::io::Cursor::new(&bad[..]);
assert!(read_remaining_length(&mut cur).is_err());
}
// Expected layout:
// - fixed header: 0x10 (CONNECT), remaining length 15;
// - "MQTT" name, protocol level 4, flags 0x02 (clean session), keep-alive 60;
// - client id "abc".
#[test]
fn build_connect_exact_bytes_for_known_input() {
let got = build_connect("abc", None, None, 60);
let expected: Vec<u8> = vec![
0x10, 0x0F, 0x00, 0x04, b'M', b'Q', b'T', b'T', 0x04, 0x02, 0x00, 0x3C, 0x00, 0x03, b'a', b'b', b'c', ];
assert_eq!(got, expected);
}
// Flags 0xC2 = user name (0x80) | password (0x40) | clean session (0x02).
// The payload order is client id, then user, then password.
#[test]
fn build_connect_sets_user_and_password_flags() {
let got = build_connect("id", Some("u"), Some("p"), 30);
let expected: Vec<u8> = vec![
0x10, 0x14, 0x00, 0x04, b'M', b'Q', b'T', b'T', 0x04, 0xC2, 0x00, 0x1E, 0x00, 0x02,
b'i', b'd', 0x00, 0x01, b'u', 0x00, 0x01, b'p',
];
assert_eq!(got, expected);
}
// Expected layout:
// - 0x82 = SUBSCRIBE type (8) with the mandatory 0b0010 flags;
// - remaining length 8, packet id 1;
// - topic "a/b", requested QoS 0.
#[test]
fn build_subscribe_exact_bytes() {
let got = build_subscribe(1, "a/b");
let expected: Vec<u8> = vec![0x82, 0x08, 0x00, 0x01, 0x00, 0x03, b'a', b'/', b'b', 0x00];
assert_eq!(got, expected);
}
// The 2-byte topic length plus the topic bytes are skipped; the rest is payload.
#[test]
fn extract_publish_payload_strips_topic() {
let body = b"\x00\x03topPAY";
let payload = extract_publish_payload(body).unwrap();
assert_eq!(payload, b"PAY");
}
// Userinfo splits at the first ':' only; later colons stay in the password.
#[test]
fn split_userinfo_variants() {
assert_eq!(split_userinfo(None), (None, None));
assert_eq!(split_userinfo(Some("alice")), (Some("alice"), None));
assert_eq!(
split_userinfo(Some("alice:secret")),
(Some("alice"), Some("secret"))
);
assert_eq!(
split_userinfo(Some("alice:s:p")),
(Some("alice"), Some("s:p"))
);
}
// Expected format: "rsurl-" + 12 lowercase hex digits.
// Two consecutive ids should differ; a collision across 48 random bits is
// vanishingly unlikely.
#[test]
fn random_client_id_format() {
let id = random_client_id();
assert!(id.starts_with("rsurl-"), "got {id}");
let suffix = &id["rsurl-".len()..];
assert_eq!(suffix.len(), 12);
assert!(suffix
.chars()
.all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase()));
assert_ne!(random_client_id(), random_client_id());
}
}