use std::sync::Arc;
use std::time::{Duration, Instant};
use amq_protocol::frame::{gen_frame, parse_frame, AMQPFrame};
use amq_protocol::protocol::{basic, channel, confirm, connection, AMQPClass};
use amq_protocol::types::{FieldTable, LongString, ShortString};
use amq_protocol::uri::{AMQPScheme, AMQPUri};
use bytes::BytesMut;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::{debug, info, trace, warn};
use crate::config::TlsConfig;
use crate::error::{Error, Result};
enum ConnectionStream {
Plain(TcpStream),
Tls(Box<tokio_rustls::client::TlsStream<TcpStream>>),
}
impl ConnectionStream {
async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
match self {
Self::Plain(s) => s.write_all(buf).await,
Self::Tls(s) => s.write_all(buf).await,
}
}
async fn read_buf(&mut self, buf: &mut BytesMut) -> std::io::Result<usize> {
match self {
Self::Plain(s) => s.read_buf(buf).await,
Self::Tls(s) => s.read_buf(buf).await,
}
}
async fn flush(&mut self) -> std::io::Result<()> {
match self {
Self::Plain(s) => AsyncWriteExt::flush(s).await,
Self::Tls(s) => AsyncWriteExt::flush(s).await,
}
}
}
fn serialize_frame(frame: &AMQPFrame) -> Result<Vec<u8>> {
let mut buf = vec![0u8; frame_serialization_capacity(frame)];
let cursor = std::io::Cursor::new(buf.as_mut_slice());
let (cursor, _written) = cookie_factory::gen(gen_frame(frame), cursor)
.map_err(|e| Error::Amqp(format!("Frame serialization failed: {:?}", e)))?;
let pos = cursor.position() as usize;
buf.truncate(pos);
Ok(buf)
}
fn frame_serialization_capacity(frame: &AMQPFrame) -> usize {
match frame {
AMQPFrame::ProtocolHeader(_) => 8,
AMQPFrame::Heartbeat | AMQPFrame::InvalidHeartbeat(_) => 8,
AMQPFrame::Body(_, body) => body.len() + 8,
_ => 131_072,
}
}
/// Largest body payload that fits in a single frame of at most `frame_max` bytes.
///
/// A `frame_max` of 0 is treated as 131072. Eight bytes are reserved for the frame
/// header and frame-end octet. The result is never below 1, so it is always a valid
/// argument to `slice::chunks` (which panics on 0).
fn body_chunk_size(frame_max: u32) -> usize {
    const DEFAULT_FRAME_MAX: usize = 131_072;
    const FRAME_OVERHEAD: usize = 8;

    let limit = match frame_max {
        0 => DEFAULT_FRAME_MAX,
        n => n as usize,
    };
    std::cmp::max(limit.saturating_sub(FRAME_OVERHEAD), 1)
}
/// Counters describing the publisher-confirm outcome of a batch of published messages.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ConfirmStats {
    /// Messages the broker negatively acknowledged (`basic.nack`).
    pub nacked: u64,
    /// Messages the broker handed back with `basic.return`.
    pub returned: u64,
}

impl ConfirmStats {
    /// Total number of messages that were not accepted: `nacked + returned`.
    pub fn failed(self) -> u64 {
        [self.nacked, self.returned].iter().sum()
    }
}
/// AMQP 0-9-1 client driving a single broker connection over plain TCP or TLS.
///
/// Every operation takes `&mut self`, so frames are written and read strictly in the order
/// the caller issues requests; there is no background reader.
pub struct AmqpClient {
    /// Transport the frames travel over.
    stream: ConnectionStream,
    /// Bytes received from `stream` that have not yet been parsed into a frame.
    read_buf: BytesMut,
    /// Maximum frame size in bytes, taken from `Connection.Tune` during the handshake.
    frame_max: u32,
    /// Heartbeat interval in seconds from `Connection.Tune`; 0 disables heartbeats.
    heartbeat_interval: u16,
    /// When a frame of any kind was last written (updated by every `send_frame`).
    last_heartbeat_sent: Instant,
    /// When a complete frame was last parsed from `read_buf`.
    last_frame_received: Instant,
    /// Channel id the next `open_channel` call will use (channel 0 is the connection itself).
    next_channel_id: u16,
    /// Next publisher-confirm sequence number; 0 while confirms are not enabled.
    next_publish_seq: u64,
}
pub fn amqp_url_with_vhost(base_url: &str, vhost: &str) -> Result<String> {
let mut url =
url::Url::parse(base_url).map_err(|e| Error::Config(format!("Invalid AMQP URL: {e}")))?;
let vhost = if vhost.is_empty() { "/" } else { vhost };
url.path_segments_mut()
.map_err(|_| Error::Config(format!("AMQP URL cannot be used as a base: {base_url}")))?
.clear()
.push(vhost);
Ok(url.to_string())
}
impl AmqpClient {
/// Opens a TCP (optionally TLS) connection to the broker at `amqp_url` and completes the
/// AMQP handshake.
///
/// TLS is used when the URL scheme is `amqps` or when `tls_config` is present with `enabled`
/// set. If TLS is required but no `tls_config` was given, a `TlsConfig` with `enabled: true`
/// and no CA or client certificate is used. `TCP_NODELAY` is enabled on the socket.
///
/// # Errors
///
/// * `Error::Config` — `amqp_url` does not parse as an AMQP URI.
/// * `Error::Connection` — the TCP connect, the socket option, the server name or the TLS
///   handshake failed.
/// * Any error from building the TLS configuration or from the AMQP handshake.
pub async fn connect(amqp_url: &str, tls_config: Option<&TlsConfig>) -> Result<Self> {
    let uri = amqp_url
        .parse::<AMQPUri>()
        .map_err(|e| Error::Config(format!("Invalid AMQP URL: {:?}", e)))?;
    let host = uri.authority.host.clone();
    let port = uri.authority.port;
    let vhost = uri.vhost.clone();
    let (username, password) = (
        uri.authority.userinfo.username.clone(),
        uri.authority.userinfo.password.clone(),
    );
    info!("Connecting to {}:{} (vhost: {})", host, port, vhost);

    let socket = TcpStream::connect(format!("{}:{}", host, port))
        .await
        .map_err(|e| Error::Connection(format!("TCP connect failed: {}", e)))?;
    socket
        .set_nodelay(true)
        .map_err(|e| Error::Connection(format!("Failed to set TCP_NODELAY: {}", e)))?;

    let use_tls =
        matches!(uri.scheme, AMQPScheme::AMQPS) || tls_config.is_some_and(|c| c.enabled);
    let stream = if use_tls {
        let fallback_tls;
        let effective_tls = match tls_config {
            Some(cfg) => cfg,
            None => {
                fallback_tls = TlsConfig {
                    enabled: true,
                    ca_cert: None,
                    client_cert: None,
                    client_key: None,
                };
                &fallback_tls
            }
        };
        let rustls_config = super::tls::build_tls_config(effective_tls)?;
        let connector = tokio_rustls::TlsConnector::from(Arc::new(rustls_config));
        let server_name = host
            .clone()
            .try_into()
            .map_err(|e| Error::Connection(format!("Invalid server name: {}", e)))?;
        let tls_stream = connector
            .connect(server_name, socket)
            .await
            .map_err(|e| Error::Connection(format!("TLS handshake failed: {}", e)))?;
        debug!("TLS connection established");
        ConnectionStream::Tls(Box::new(tls_stream))
    } else {
        ConnectionStream::Plain(socket)
    };

    let mut client = Self {
        stream,
        read_buf: BytesMut::with_capacity(16384),
        // Pre-negotiation defaults; `handshake` replaces both with the broker's Tune values.
        frame_max: 131072,
        heartbeat_interval: 60,
        next_channel_id: 1,
        last_frame_received: Instant::now(),
        last_heartbeat_sent: Instant::now(),
        next_publish_seq: 0,
    };
    client.handshake(&username, &password, &vhost).await?;
    info!(
        "AMQP connection established (frame_max={}, heartbeat={}s)",
        client.frame_max, client.heartbeat_interval
    );
    Ok(client)
}
async fn handshake(&mut self, username: &str, password: &str, vhost: &str) -> Result<()> {
debug!("Sending AMQP protocol header");
self.send_frame(&AMQPFrame::ProtocolHeader(
amq_protocol::frame::ProtocolVersion::amqp_0_9_1(),
))
.await?;
let frame = self.read_frame().await?;
match &frame {
AMQPFrame::Method(0, AMQPClass::Connection(connection::AMQPMethod::Start(start))) => {
debug!(
"Received Connection.Start (v{}.{}, mechanisms: {:?})",
start.version_major,
start.version_minor,
String::from_utf8_lossy(start.mechanisms.as_bytes())
);
}
_ => {
return Err(Error::Amqp(format!(
"Expected Connection.Start, got: {:?}",
frame
)));
}
}
let mut sasl_response = Vec::new();
sasl_response.push(0); sasl_response.extend_from_slice(username.as_bytes());
sasl_response.push(0);
sasl_response.extend_from_slice(password.as_bytes());
let mut client_properties = FieldTable::default();
client_properties.insert(
"product".into(),
amq_protocol::types::AMQPValue::LongString("rabbitmq-backup".into()),
);
client_properties.insert(
"version".into(),
amq_protocol::types::AMQPValue::LongString(
env!("CARGO_PKG_VERSION").as_bytes().to_vec().into(),
),
);
client_properties.insert(
"capabilities".into(),
amq_protocol::types::AMQPValue::FieldTable({
let mut caps = FieldTable::default();
caps.insert(
"consumer_cancel_notify".into(),
amq_protocol::types::AMQPValue::Boolean(true),
);
caps
}),
);
let start_ok = connection::StartOk {
client_properties,
mechanism: ShortString::from("PLAIN"),
response: LongString::from(sasl_response),
locale: ShortString::from("en_US"),
};
debug!("Sending Connection.StartOk (SASL PLAIN)");
self.send_frame(&AMQPFrame::Method(
0,
AMQPClass::Connection(connection::AMQPMethod::StartOk(start_ok)),
))
.await?;
let frame = self.read_frame().await?;
let (channel_max, frame_max, heartbeat) = match &frame {
AMQPFrame::Method(0, AMQPClass::Connection(connection::AMQPMethod::Tune(tune))) => {
debug!(
"Received Connection.Tune (channel_max={}, frame_max={}, heartbeat={})",
tune.channel_max, tune.frame_max, tune.heartbeat
);
(tune.channel_max, tune.frame_max, tune.heartbeat)
}
_ => {
return Err(Error::Amqp(format!(
"Expected Connection.Tune, got: {:?}",
frame
)));
}
};
self.frame_max = if frame_max == 0 { 131072 } else { frame_max };
self.heartbeat_interval = heartbeat;
let tune_ok = connection::TuneOk {
channel_max: if channel_max == 0 { 2047 } else { channel_max },
frame_max: self.frame_max,
heartbeat: self.heartbeat_interval,
};
debug!("Sending Connection.TuneOk");
self.send_frame(&AMQPFrame::Method(
0,
AMQPClass::Connection(connection::AMQPMethod::TuneOk(tune_ok)),
))
.await?;
let open = connection::Open {
virtual_host: ShortString::from(vhost),
};
debug!("Sending Connection.Open (vhost={})", vhost);
self.send_frame(&AMQPFrame::Method(
0,
AMQPClass::Connection(connection::AMQPMethod::Open(open)),
))
.await?;
let frame = self.read_frame().await?;
match &frame {
AMQPFrame::Method(0, AMQPClass::Connection(connection::AMQPMethod::OpenOk(_))) => {
debug!("Received Connection.OpenOk");
}
_ => {
return Err(Error::Amqp(format!(
"Expected Connection.OpenOk, got: {:?}",
frame
)));
}
}
Ok(())
}
pub async fn open_channel(&mut self) -> Result<u16> {
let channel_id = self.next_channel_id;
self.next_channel_id += 1;
debug!("Opening channel {}", channel_id);
self.send_frame(&AMQPFrame::Method(
channel_id,
AMQPClass::Channel(channel::AMQPMethod::Open(channel::Open {})),
))
.await?;
let frame = self.read_frame().await?;
match &frame {
AMQPFrame::Method(ch, AMQPClass::Channel(channel::AMQPMethod::OpenOk(_)))
if *ch == channel_id =>
{
debug!("Channel {} opened", channel_id);
Ok(channel_id)
}
_ => Err(Error::Amqp(format!(
"Expected Channel.OpenOk for channel {}, got: {:?}",
channel_id, frame
))),
}
}
pub async fn send_frame(&mut self, frame: &AMQPFrame) -> Result<()> {
let buf = serialize_frame(frame)?;
self.stream
.write_all(&buf)
.await
.map_err(|e| Error::Connection(format!("Write failed: {}", e)))?;
self.stream
.flush()
.await
.map_err(|e| Error::Connection(format!("Flush failed: {}", e)))?;
self.last_heartbeat_sent = Instant::now();
trace!("Sent frame ({} bytes)", buf.len());
Ok(())
}
pub async fn read_frame(&mut self) -> Result<AMQPFrame> {
let heartbeat_timeout = if self.heartbeat_interval > 0 {
Duration::from_secs(self.heartbeat_interval as u64 / 2)
} else {
Duration::from_secs(30) };
loop {
if !self.read_buf.is_empty() {
match parse_frame(&self.read_buf[..]) {
Ok((remaining, frame)) => {
let consumed = self.read_buf.len() - remaining.len();
let _ = self.read_buf.split_to(consumed);
self.last_frame_received = Instant::now();
match &frame {
AMQPFrame::Heartbeat => {
trace!("Received heartbeat");
continue; }
_ => return Ok(frame),
}
}
Err(e) if format!("{:?}", e).contains("Incomplete") => {
}
Err(e) => {
return Err(Error::Amqp(format!("Frame parse error: {:?}", e)));
}
}
}
match tokio::time::timeout(heartbeat_timeout, self.stream.read_buf(&mut self.read_buf))
.await
{
Ok(Ok(0)) => {
return Err(Error::Connection("Connection closed by peer".to_string()));
}
Ok(Ok(n)) => {
trace!("Read {} bytes from stream", n);
}
Ok(Err(e)) => {
return Err(Error::Connection(format!("Read error: {}", e)));
}
Err(_) => {
if self.heartbeat_interval > 0 {
let since_last_sent = self.last_heartbeat_sent.elapsed();
if since_last_sent
>= Duration::from_secs(self.heartbeat_interval as u64 / 2)
{
trace!("Sending heartbeat (idle for {:?})", since_last_sent);
self.send_frame(&AMQPFrame::Heartbeat).await?;
}
let since_last_received = self.last_frame_received.elapsed();
if since_last_received
> Duration::from_secs(self.heartbeat_interval as u64 * 2)
{
return Err(Error::Connection(format!(
"Heartbeat timeout: no frame received for {:?}",
since_last_received
)));
}
}
}
}
}
}
pub async fn read_frame_timeout(&mut self, timeout: Duration) -> Result<Option<AMQPFrame>> {
let deadline = tokio::time::Instant::now() + timeout;
loop {
if !self.read_buf.is_empty() {
match parse_frame(&self.read_buf[..]) {
Ok((remaining, frame)) => {
let consumed = self.read_buf.len() - remaining.len();
let _ = self.read_buf.split_to(consumed);
self.last_frame_received = Instant::now();
if matches!(&frame, AMQPFrame::Heartbeat) {
continue; }
return Ok(Some(frame));
}
Err(e) if format!("{:?}", e).contains("Incomplete") => {
}
Err(e) => {
return Err(Error::Amqp(format!("Frame parse error: {:?}", e)));
}
}
}
let now = tokio::time::Instant::now();
if now >= deadline {
return Ok(None);
}
match tokio::time::timeout(
deadline.saturating_duration_since(now),
self.stream.read_buf(&mut self.read_buf),
)
.await
{
Ok(Ok(0)) => return Err(Error::Connection("Connection closed".to_string())),
Ok(Ok(_)) => continue,
Ok(Err(e)) => return Err(Error::Connection(format!("Read error: {}", e))),
Err(_) => return Ok(None),
};
}
}
pub async fn close(&mut self) -> Result<()> {
debug!("Closing AMQP connection");
let close = connection::Close {
reply_code: 200,
reply_text: ShortString::from("Normal shutdown"),
class_id: 0,
method_id: 0,
};
self.send_frame(&AMQPFrame::Method(
0,
AMQPClass::Connection(connection::AMQPMethod::Close(close)),
))
.await?;
match tokio::time::timeout(Duration::from_secs(5), self.read_frame()).await {
Ok(Ok(AMQPFrame::Method(
0,
AMQPClass::Connection(connection::AMQPMethod::CloseOk(_)),
))) => {
debug!("Received Connection.CloseOk");
}
_ => {
warn!("Did not receive Connection.CloseOk");
}
}
Ok(())
}
pub async fn close_channel(&mut self, channel_id: u16) -> Result<()> {
debug!("Closing channel {}", channel_id);
let close = channel::Close {
reply_code: 200,
reply_text: ShortString::from("Normal shutdown"),
class_id: 0,
method_id: 0,
};
self.send_frame(&AMQPFrame::Method(
channel_id,
AMQPClass::Channel(channel::AMQPMethod::Close(close)),
))
.await?;
match tokio::time::timeout(Duration::from_secs(5), self.read_frame()).await {
Ok(Ok(AMQPFrame::Method(ch, AMQPClass::Channel(channel::AMQPMethod::CloseOk(_)))))
if ch == channel_id =>
{
debug!("Channel {} closed", channel_id);
}
_ => {
warn!("Did not receive Channel.CloseOk for channel {}", channel_id);
}
}
Ok(())
}
pub async fn basic_qos(&mut self, channel_id: u16, prefetch_count: u16) -> Result<()> {
let qos = basic::Qos {
prefetch_count,
global: false,
};
self.send_frame(&AMQPFrame::Method(
channel_id,
AMQPClass::Basic(basic::AMQPMethod::Qos(qos)),
))
.await?;
let frame = self.read_frame().await?;
match &frame {
AMQPFrame::Method(ch, AMQPClass::Basic(basic::AMQPMethod::QosOk(_)))
if *ch == channel_id =>
{
debug!("basic.qos set (prefetch_count={})", prefetch_count);
Ok(())
}
_ => Err(Error::Amqp(format!(
"Expected Basic.QosOk, got: {:?}",
frame
))),
}
}
pub async fn basic_consume(
&mut self,
channel_id: u16,
queue: &str,
consumer_tag: &str,
) -> Result<String> {
let consume = basic::Consume {
queue: ShortString::from(queue),
consumer_tag: ShortString::from(consumer_tag),
no_local: false,
no_ack: false,
exclusive: false,
nowait: false,
arguments: FieldTable::default(),
};
self.send_frame(&AMQPFrame::Method(
channel_id,
AMQPClass::Basic(basic::AMQPMethod::Consume(consume)),
))
.await?;
let frame = self.read_frame().await?;
match frame {
AMQPFrame::Method(ch, AMQPClass::Basic(basic::AMQPMethod::ConsumeOk(ok)))
if ch == channel_id =>
{
let tag = ok.consumer_tag.to_string();
debug!("Consumer started (tag={})", tag);
Ok(tag)
}
_ => Err(Error::Amqp(format!(
"Expected Basic.ConsumeOk, got: {:?}",
frame
))),
}
}
pub async fn basic_cancel(&mut self, channel_id: u16, consumer_tag: &str) -> Result<()> {
let cancel = basic::Cancel {
consumer_tag: ShortString::from(consumer_tag),
nowait: false,
};
self.send_frame(&AMQPFrame::Method(
channel_id,
AMQPClass::Basic(basic::AMQPMethod::Cancel(cancel)),
))
.await?;
let frame = self.read_frame().await?;
match &frame {
AMQPFrame::Method(ch, AMQPClass::Basic(basic::AMQPMethod::CancelOk(_)))
if *ch == channel_id =>
{
debug!("Consumer cancelled (tag={})", consumer_tag);
Ok(())
}
_ => Err(Error::Amqp(format!(
"Expected Basic.CancelOk, got: {:?}",
frame
))),
}
}
/// Sends `Basic.Get` for `queue` on `channel_id`, requesting manual acknowledgement
/// (`no_ack = false`).
///
/// This only sends the request; no reply frame is read here.
///
/// # Errors
///
/// Any error from sending the frame.
pub async fn basic_get(&mut self, channel_id: u16, queue: &str) -> Result<()> {
    let method = basic::AMQPMethod::Get(basic::Get {
        no_ack: false,
        queue: queue.into(),
    });
    let frame = AMQPFrame::Method(channel_id, AMQPClass::Basic(method));
    self.send_frame(&frame).await
}
/// Negatively acknowledges the single delivery `delivery_tag` on `channel_id`
/// (`multiple = false`), asking the broker to requeue it when `requeue` is true.
///
/// No reply is expected; the frame is only sent.
///
/// # Errors
///
/// Any error from sending the frame.
pub async fn basic_nack(
    &mut self,
    channel_id: u16,
    delivery_tag: u64,
    requeue: bool,
) -> Result<()> {
    let method = basic::AMQPMethod::Nack(basic::Nack {
        requeue,
        delivery_tag,
        multiple: false,
    });
    let frame = AMQPFrame::Method(channel_id, AMQPClass::Basic(method));
    self.send_frame(&frame).await
}
pub async fn confirm_select(&mut self, channel_id: u16) -> Result<()> {
let select = confirm::Select { nowait: false };
self.send_frame(&AMQPFrame::Method(
channel_id,
AMQPClass::Confirm(confirm::AMQPMethod::Select(select)),
))
.await?;
let frame = self.read_frame().await?;
match &frame {
AMQPFrame::Method(ch, AMQPClass::Confirm(confirm::AMQPMethod::SelectOk(_)))
if *ch == channel_id =>
{
self.next_publish_seq = 1; debug!("Publisher confirms enabled on channel {}", channel_id);
Ok(())
}
_ => Err(Error::Amqp(format!(
"Expected Confirm.SelectOk, got: {:?}",
frame
))),
}
}
/// Publishes one message on `channel_id`.
///
/// Sends `Basic.Publish`, a content header carrying `properties`, and then the body split
/// into frames of at most `frame_max` bytes. An empty body produces no body frames.
///
/// Returns the publisher-confirm sequence number assigned to this message, or 0 if confirms
/// have not been enabled with `confirm_select`.
///
/// # Errors
///
/// Any error from sending the frames.
pub async fn basic_publish(
    &mut self,
    channel_id: u16,
    exchange: &str,
    routing_key: &str,
    mandatory: bool,
    properties: &basic::AMQPProperties,
    body: &[u8],
) -> Result<u64> {
    // A counter of 0 means confirms are off, so there is no sequence to advance.
    let seq = self.next_publish_seq;
    if seq != 0 {
        self.next_publish_seq = seq + 1;
    }

    let method = basic::Publish {
        exchange: exchange.into(),
        routing_key: routing_key.into(),
        immediate: false,
        mandatory,
    };
    self.send_frame(&AMQPFrame::Method(
        channel_id,
        AMQPClass::Basic(basic::AMQPMethod::Publish(method)),
    ))
    .await?;

    let content_header = amq_protocol::frame::AMQPContentHeader {
        // 60 is the AMQP 0-9-1 class id of `basic`.
        class_id: 60,
        body_size: body.len() as u64,
        properties: properties.clone(),
    };
    self.send_frame(&AMQPFrame::Header(channel_id, content_header))
        .await?;

    let chunk_len = body_chunk_size(self.frame_max);
    for piece in body.chunks(chunk_len) {
        self.send_frame(&AMQPFrame::Body(channel_id, piece.to_vec()))
            .await?;
    }

    trace!(
        "Published message to {}:{} (seq={}, {} bytes)",
        exchange,
        routing_key,
        seq,
        body.len()
    );
    Ok(seq)
}
pub async fn wait_for_confirms(
&mut self,
channel_id: u16,
up_to_tag: u64,
) -> Result<ConfirmStats> {
let mut confirmed_up_to = 0u64;
let mut stats = ConfirmStats::default();
let timeout = Duration::from_secs(30);
while confirmed_up_to < up_to_tag {
let frame = match self.read_frame_timeout(timeout).await? {
Some(f) => f,
None => {
return Err(Error::Amqp(format!(
"Timeout waiting for confirms (confirmed {}/{})",
confirmed_up_to, up_to_tag
)));
}
};
match frame {
AMQPFrame::Method(ch, AMQPClass::Basic(basic::AMQPMethod::Ack(ack)))
if ch == channel_id =>
{
if ack.multiple {
confirmed_up_to = ack.delivery_tag;
} else {
confirmed_up_to = confirmed_up_to.max(ack.delivery_tag);
}
trace!(
"Confirm ack: tag={}, multiple={}",
ack.delivery_tag,
ack.multiple
);
}
AMQPFrame::Method(ch, AMQPClass::Basic(basic::AMQPMethod::Nack(nack)))
if ch == channel_id =>
{
if nack.multiple {
let range = confirmed_up_to + 1..=nack.delivery_tag;
stats.nacked += range.count() as u64;
confirmed_up_to = nack.delivery_tag;
} else {
stats.nacked += 1;
confirmed_up_to = confirmed_up_to.max(nack.delivery_tag);
}
warn!(
"Confirm nack: tag={}, multiple={}",
nack.delivery_tag, nack.multiple
);
}
AMQPFrame::Method(ch, AMQPClass::Basic(basic::AMQPMethod::Return(ret)))
if ch == channel_id =>
{
stats.returned += 1;
warn!(
"Message returned by broker: code={}, text={}, exchange={}, routing_key={}",
ret.reply_code, ret.reply_text, ret.exchange, ret.routing_key
);
self.discard_returned_content(channel_id, timeout).await?;
}
_ => {
trace!("Ignoring non-confirm frame during wait_for_confirms");
}
}
}
Ok(stats)
}
/// Consumes and drops the content header and body frames that follow a `Basic.Return` on
/// `channel_id`.
///
/// Body frames are read until the byte count announced in the header has been received.
/// Each read waits at most `timeout`.
///
/// # Errors
///
/// * `Error::Amqp` — a frame other than the expected header/body for this channel arrives,
///   or a read times out.
/// * Transport errors from reading frames.
async fn discard_returned_content(&mut self, channel_id: u16, timeout: Duration) -> Result<()> {
    let Some(first) = self.read_frame_timeout(timeout).await? else {
        return Err(Error::Amqp(
            "Timeout waiting for returned message header".to_string(),
        ));
    };
    let header = match first {
        AMQPFrame::Header(ch, header) if ch == channel_id => header,
        unexpected => {
            return Err(Error::Amqp(format!(
                "Expected returned message header, got: {:?}",
                unexpected
            )));
        }
    };

    let mut bytes_left = header.body_size;
    while bytes_left != 0 {
        let Some(next) = self.read_frame_timeout(timeout).await? else {
            return Err(Error::Amqp(
                "Timeout waiting for returned message body".to_string(),
            ));
        };
        match next {
            AMQPFrame::Body(ch, chunk) if ch == channel_id => {
                bytes_left = bytes_left.saturating_sub(chunk.len() as u64);
            }
            unexpected => {
                return Err(Error::Amqp(format!(
                    "Expected returned message body, got: {:?}",
                    unexpected
                )));
            }
        }
    }
    Ok(())
}
pub fn heartbeat_interval(&self) -> u16 {
self.heartbeat_interval
}
pub fn frame_max(&self) -> u32 {
self.frame_max
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The SASL PLAIN layout used by the handshake: NUL, username, NUL, password.
    #[test]
    fn test_sasl_plain_response() {
        let mut response = Vec::new();
        response.push(0);
        response.extend_from_slice(b"guest");
        response.push(0);
        response.extend_from_slice(b"guest");
        assert_eq!(response, b"\0guest\0guest");
    }

    #[test]
    fn test_amqp_url_parsing() {
        let uri: AMQPUri = "amqp://user:pass@localhost:5672/%2f".parse().unwrap();
        assert_eq!(uri.authority.host, "localhost");
        assert_eq!(uri.authority.port, 5672);
        assert_eq!(uri.authority.userinfo.username, "user");
        assert_eq!(uri.authority.userinfo.password, "pass");
        assert_eq!(uri.vhost, "/");
    }

    #[test]
    fn test_amqp_url_default_vhost() {
        let uri: AMQPUri = "amqp://guest:guest@localhost".parse().unwrap();
        assert_eq!(uri.vhost, "/");
        assert_eq!(uri.authority.port, 5672);
    }

    #[test]
    fn test_amqp_url_with_vhost() {
        let url = amqp_url_with_vhost("amqp://guest:guest@localhost:5672/%2f", "/staging").unwrap();
        let uri: AMQPUri = url.parse().unwrap();
        assert_eq!(uri.vhost, "/staging");
        assert_eq!(uri.authority.host, "localhost");
        assert_eq!(uri.authority.port, 5672);
        let url = amqp_url_with_vhost("amqp://guest:guest@localhost:5672/%2f", "tenant-a").unwrap();
        let uri: AMQPUri = url.parse().unwrap();
        assert_eq!(uri.vhost, "tenant-a");
        let url = amqp_url_with_vhost("amqp://guest:guest@localhost:5672/tenant-a", "/").unwrap();
        let uri: AMQPUri = url.parse().unwrap();
        assert_eq!(uri.vhost, "/");
    }

    /// An empty vhost argument selects the default vhost `/`.
    #[test]
    fn test_amqp_url_with_empty_vhost_defaults_to_root() {
        let url = amqp_url_with_vhost("amqp://guest:guest@localhost:5672/tenant-a", "").unwrap();
        let uri: AMQPUri = url.parse().unwrap();
        assert_eq!(uri.vhost, "/");
    }

    #[test]
    fn test_amqp_url_with_vhost_rejects_invalid_url() {
        assert!(amqp_url_with_vhost("not a url", "/").is_err());
    }

    #[test]
    fn test_confirm_stats_failed() {
        assert_eq!(ConfirmStats::default().failed(), 0);
        assert_eq!(ConfirmStats { nacked: 2, returned: 3 }.failed(), 5);
    }

    #[test]
    fn test_frame_serialization_heartbeat() {
        let frame = AMQPFrame::Heartbeat;
        let buf = serialize_frame(&frame).unwrap();
        assert!(!buf.is_empty());
        let (remaining, parsed) = parse_frame(&buf[..]).unwrap();
        assert!(remaining.is_empty());
        assert!(matches!(parsed, AMQPFrame::Heartbeat));
    }

    #[test]
    fn test_frame_serialization_large_body() {
        let body = vec![42u8; 128 * 1024];
        let frame = AMQPFrame::Body(1, body.clone());
        let buf = serialize_frame(&frame).unwrap();
        let (remaining, parsed) = parse_frame(&buf[..]).unwrap();
        assert!(remaining.is_empty());
        match parsed {
            AMQPFrame::Body(ch, parsed_body) => {
                assert_eq!(ch, 1);
                assert_eq!(parsed_body, body);
            }
            other => panic!("Expected Body frame, got {:?}", other),
        }
    }

    #[test]
    fn test_body_chunk_size_respects_frame_overhead() {
        assert_eq!(body_chunk_size(131_072), 131_064);
        assert_eq!(body_chunk_size(8), 1);
        assert_eq!(body_chunk_size(0), 131_064);
        assert_eq!(body_chunk_size(1), 1);
        assert_eq!(body_chunk_size(4096), 4088);
    }

    #[test]
    fn test_frame_serialization_protocol_header() {
        let frame = AMQPFrame::ProtocolHeader(amq_protocol::frame::ProtocolVersion::amqp_0_9_1());
        let buf = serialize_frame(&frame).unwrap();
        assert_eq!(buf.len(), 8);
        assert_eq!(&buf[0..4], b"AMQP");
    }

    #[test]
    fn test_frame_serialization_channel_open() {
        let frame = AMQPFrame::Method(
            1,
            AMQPClass::Channel(channel::AMQPMethod::Open(channel::Open {})),
        );
        let buf = serialize_frame(&frame).unwrap();
        assert!(!buf.is_empty());
        let (remaining, parsed) = parse_frame(&buf[..]).unwrap();
        assert!(remaining.is_empty());
        match parsed {
            AMQPFrame::Method(ch, AMQPClass::Channel(channel::AMQPMethod::Open(_))) => {
                assert_eq!(ch, 1);
            }
            _ => panic!("Expected Channel.Open, got: {:?}", parsed),
        }
    }

    #[test]
    fn test_frame_serialization_basic_qos() {
        let frame = AMQPFrame::Method(
            1,
            AMQPClass::Basic(basic::AMQPMethod::Qos(basic::Qos {
                prefetch_count: 100,
                global: false,
            })),
        );
        let buf = serialize_frame(&frame).unwrap();
        let (_, parsed) = parse_frame(&buf[..]).unwrap();
        match parsed {
            AMQPFrame::Method(1, AMQPClass::Basic(basic::AMQPMethod::Qos(qos))) => {
                assert_eq!(qos.prefetch_count, 100);
                assert!(!qos.global);
            }
            _ => panic!("Expected Basic.Qos"),
        }
    }

    #[test]
    fn test_frame_serialization_basic_nack() {
        let frame = AMQPFrame::Method(
            1,
            AMQPClass::Basic(basic::AMQPMethod::Nack(basic::Nack {
                delivery_tag: 42,
                multiple: false,
                requeue: true,
            })),
        );
        let buf = serialize_frame(&frame).unwrap();
        let (_, parsed) = parse_frame(&buf[..]).unwrap();
        match parsed {
            AMQPFrame::Method(1, AMQPClass::Basic(basic::AMQPMethod::Nack(nack))) => {
                assert_eq!(nack.delivery_tag, 42);
                assert!(!nack.multiple);
                assert!(nack.requeue);
            }
            _ => panic!("Expected Basic.Nack"),
        }
    }

    /// Round-trips the method frame `basic_publish` sends.
    #[test]
    fn test_frame_serialization_basic_publish() {
        let frame = AMQPFrame::Method(
            3,
            AMQPClass::Basic(basic::AMQPMethod::Publish(basic::Publish {
                exchange: ShortString::from("backup"),
                routing_key: ShortString::from("orders.created"),
                mandatory: true,
                immediate: false,
            })),
        );
        let buf = serialize_frame(&frame).unwrap();
        let (remaining, parsed) = parse_frame(&buf[..]).unwrap();
        assert!(remaining.is_empty());
        match parsed {
            AMQPFrame::Method(3, AMQPClass::Basic(basic::AMQPMethod::Publish(publish))) => {
                assert_eq!(publish.exchange.to_string(), "backup");
                assert_eq!(publish.routing_key.to_string(), "orders.created");
                assert!(publish.mandatory);
                assert!(!publish.immediate);
            }
            other => panic!("Expected Basic.Publish, got {:?}", other),
        }
    }
}