use core::convert::TryFrom;

#[cfg(not(feature = "std"))]
use alloc::vec::Vec;

use crate::codec::decode::Cursor;
use crate::codec::encode::{self, encode_fixed_header, encode_string, encode_u16};
use crate::codec::properties::Properties;
use crate::codec::types::{ConnAckPacket, ConnectPacket, PacketType};
use crate::error::Result;
/// Protocol name written at the start of the CONNECT variable header.
const PROTOCOL_NAME: &str = "MQTT";
/// Lowest protocol level for which `ConnectPacket::encode` writes a property block.
const PROTOCOL_VERSION_5: u8 = 5;
impl ConnectPacket {
pub fn encode(&self) -> Result<Vec<u8>> {
let mut body = Vec::new();
encode_string(&mut body, PROTOCOL_NAME)?;
body.push(self.protocol_version);
let mut flags: u8 = 0;
if self.clean_start {
flags |= 0x02;
}
if self.password.is_some() {
flags |= 0x40;
}
if self.username.is_some() {
flags |= 0x80;
}
body.push(flags);
encode_u16(&mut body, self.keep_alive);
if self.protocol_version >= PROTOCOL_VERSION_5 {
self.properties.encode(&mut body)?;
}
encode_string(&mut body, &self.client_id)?;
if let Some(ref username) = self.username {
encode_string(&mut body, username)?;
}
if let Some(ref password) = self.password {
encode::encode_binary(&mut body, password)?;
}
let mut packet = Vec::new();
encode_fixed_header(&mut packet, PacketType::Connect, 0, body.len() as u32)?;
packet.extend_from_slice(&body);
Ok(packet)
}
}
impl ConnAckPacket {
    /// Parses a CONNACK packet from `body`, the bytes that follow its fixed header.
    ///
    /// Reads, in order: the acknowledge-flags byte (only bit 0, Session Present, is
    /// inspected), the reason code, and then a property block if any bytes remain.
    /// An empty remainder yields `Properties::new()` rather than an error.
    ///
    /// # Errors
    ///
    /// Propagates errors from the `Cursor` reads and from `Properties::decode`
    /// (presumably including a body shorter than two bytes — confirm against `Cursor`).
    pub fn decode(body: &[u8]) -> Result<Self> {
        let mut reader = Cursor::new(body);
        let flags_byte = reader.read_u8()?;
        let reason_code = reader.read_u8()?;

        // NOTE(review): bits 7..=1 of the flags byte are ignored, not validated.
        let has_properties = reader.remaining() > 0;
        let properties = if has_properties {
            Properties::decode(&mut reader)?
        } else {
            Properties::new()
        };

        Ok(Self {
            session_present: (flags_byte & 0x01) != 0,
            reason_code,
            properties,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::decode::decode_fixed_header;

    #[test]
    fn connect_encode_minimal() {
        let pkt = ConnectPacket {
            protocol_version: 5,
            clean_start: true,
            keep_alive: 60,
            client_id: String::new(),
            username: None,
            password: None,
            properties: Properties::new(),
        };
        let bytes = pkt.encode().unwrap();
        let (header, hdr_len) = decode_fixed_header(&bytes).unwrap();
        assert_eq!(header.packet_type, PacketType::Connect);
        assert_eq!(header.flags, 0);

        let body = &bytes[hdr_len..];
        let mut cur = Cursor::new(body);
        assert_eq!(cur.read_string().unwrap(), "MQTT");
        assert_eq!(cur.read_u8().unwrap(), 5); // protocol level
        assert_eq!(cur.read_u8().unwrap(), 0x02); // connect flags: Clean Start only
        assert_eq!(cur.read_u16().unwrap(), 60); // keep-alive
        // A v5 CONNECT carries a property block even when it is empty.
        let _ = Properties::decode(&mut cur).unwrap();
        assert_eq!(cur.read_string().unwrap(), ""); // client identifier
        assert_eq!(cur.remaining(), 0);
    }

    #[test]
    fn connect_with_credentials() {
        let pkt = ConnectPacket {
            protocol_version: 5,
            clean_start: true,
            keep_alive: 30,
            client_id: String::from("test-client"),
            username: Some(String::from("user")),
            password: Some(b"pass".to_vec()),
            properties: Properties::new(),
        };
        let bytes = pkt.encode().unwrap();
        let (header, hdr_len) = decode_fixed_header(&bytes).unwrap();
        assert_eq!(header.packet_type, PacketType::Connect);

        let body = &bytes[hdr_len..];
        let mut cur = Cursor::new(body);
        assert_eq!(cur.read_string().unwrap(), "MQTT");
        assert_eq!(cur.read_u8().unwrap(), 5);
        // Clean Start (0x02) | Password (0x40) | User Name (0x80).
        assert_eq!(cur.read_u8().unwrap(), 0xC2);
        assert_eq!(cur.read_u16().unwrap(), 30);
        let _ = Properties::decode(&mut cur).unwrap();
        assert_eq!(cur.read_string().unwrap(), "test-client");
        assert_eq!(cur.read_string().unwrap(), "user");
        // The password is binary data: a two-byte length prefix, then the raw bytes.
        assert_eq!(cur.read_u16().unwrap(), 4);
        let password = [
            cur.read_u8().unwrap(),
            cur.read_u8().unwrap(),
            cur.read_u8().unwrap(),
            cur.read_u8().unwrap(),
        ];
        assert_eq!(&password, b"pass");
        assert_eq!(cur.remaining(), 0);
    }

    #[test]
    fn connect_v4_omits_properties() {
        let pkt = ConnectPacket {
            protocol_version: 4,
            clean_start: false,
            keep_alive: 0,
            client_id: String::from("c"),
            username: None,
            password: None,
            properties: Properties::new(),
        };
        let bytes = pkt.encode().unwrap();
        let (header, hdr_len) = decode_fixed_header(&bytes).unwrap();
        assert_eq!(header.packet_type, PacketType::Connect);

        let body = &bytes[hdr_len..];
        let mut cur = Cursor::new(body);
        assert_eq!(cur.read_string().unwrap(), "MQTT");
        assert_eq!(cur.read_u8().unwrap(), 4);
        assert_eq!(cur.read_u8().unwrap(), 0x00);
        assert_eq!(cur.read_u16().unwrap(), 0);
        // Below protocol level 5 no property block is written: the client id follows directly.
        assert_eq!(cur.read_string().unwrap(), "c");
        assert_eq!(cur.remaining(), 0);
    }

    #[test]
    fn connack_decode_success() {
        let body = [0x00, 0x00, 0x00];
        let pkt = ConnAckPacket::decode(&body).unwrap();
        assert!(!pkt.session_present);
        assert_eq!(pkt.reason_code, 0x00);
    }

    #[test]
    fn connack_decode_refused() {
        let body = [0x00, 0x87, 0x00];
        let pkt = ConnAckPacket::decode(&body).unwrap();
        assert_eq!(pkt.reason_code, 0x87);
    }

    #[test]
    fn connack_decode_session_present() {
        let body = [0x01, 0x00, 0x00];
        let pkt = ConnAckPacket::decode(&body).unwrap();
        assert!(pkt.session_present);
        assert_eq!(pkt.reason_code, 0x00);
    }

    #[test]
    fn connack_decode_without_properties() {
        // Nothing after the reason code: decoding falls back to empty properties.
        let body = [0x00, 0x00];
        let pkt = ConnAckPacket::decode(&body).unwrap();
        assert!(!pkt.session_present);
        assert_eq!(pkt.reason_code, 0x00);
    }

    #[test]
    fn connack_decode_truncated() {
        // The reason code byte is missing.
        assert!(ConnAckPacket::decode(&[0x00]).is_err());
    }
}