use std::time::Duration;
use tracing::{debug, trace, warn};
use crate::amqp::{
AMQPClass, AMQPContentHeader, AMQPFrame, AMQPValue, AmqpClient, BasicMethod, FieldTable,
};
use crate::error::Result;
use crate::manifest::{BackupHeaderValue, BackupProperties, BackupRecord};
/// Consumes one queue over a single AMQP channel and hands every incoming frame
/// to a [`MessageAssembler`], yielding [`BackupRecord`]s as messages complete.
pub struct QueueReader {
    /// Client used for every protocol call this reader makes.
    pub client: AmqpClient,
    /// Channel on which QoS, consume and cancel are issued.
    pub channel_id: u16,
    /// Queue being consumed; passed to the assembler for each record.
    pub queue_name: String,
    /// Virtual host of the queue; passed to the assembler for each record.
    pub vhost: String,
}

impl QueueReader {
    /// Builds a reader from its parts; performs no I/O.
    pub fn new(client: AmqpClient, channel_id: u16, queue_name: String, vhost: String) -> Self {
        Self {
            client,
            channel_id,
            queue_name,
            vhost,
        }
    }

    /// Applies the prefetch window to the channel, then registers a consumer on
    /// the queue.
    ///
    /// Returns the generated consumer tag (`rmq-backup-<uuid v4>`), which can be
    /// passed to [`cancel_consume`](Self::cancel_consume) later.
    ///
    /// # Errors
    /// Propagates any error returned by `basic_qos` or `basic_consume`.
    pub async fn start_consume(&mut self, prefetch_count: u16) -> Result<String> {
        let channel = self.channel_id;
        self.client.basic_qos(channel, prefetch_count).await?;

        let tag = format!("rmq-backup-{}", uuid::Uuid::new_v4());
        self.client
            .basic_consume(channel, &self.queue_name, &tag)
            .await?;

        debug!(
            "Consumer started on {} (tag={}, prefetch={})",
            self.queue_name, tag, prefetch_count
        );
        Ok(tag)
    }

    /// Cancels the consumer identified by `consumer_tag` on this reader's channel.
    ///
    /// # Errors
    /// Propagates any error returned by `basic_cancel`.
    pub async fn cancel_consume(&mut self, consumer_tag: &str) -> Result<()> {
        let channel = self.channel_id;
        self.client.basic_cancel(channel, consumer_tag).await?;
        // NOTE(review): per AMQP 0-9-1, basic.cancel by itself does not requeue
        // delivered-but-unacked messages (that happens when the channel closes).
        // Confirm `AmqpClient::basic_cancel` does more before relying on this log line.
        debug!("Consumer cancelled, all unacked messages requeued");
        Ok(())
    }

    /// Waits up to `timeout` for one frame and feeds it to `assembler`.
    ///
    /// Returns `Ok(Some(record))` when the frame completes a message and
    /// `Ok(None)` when it does not. `Ok(None)` is also returned when
    /// `read_frame_timeout` yields no frame (presumably the timeout elapsed).
    ///
    /// # Errors
    /// Propagates errors from reading the frame or from the assembler.
    pub async fn read_next(
        &mut self,
        assembler: &mut MessageAssembler,
        timeout: Duration,
    ) -> Result<Option<BackupRecord>> {
        match self.client.read_frame_timeout(timeout).await? {
            Some(frame) => assembler.process_frame(frame, &self.queue_name, &self.vhost),
            None => Ok(None),
        }
    }

    /// Closes the channel and then the connection. Errors from either step are
    /// discarded, so shutdown is best-effort.
    pub async fn close(mut self) {
        let _ = self.client.close_channel(self.channel_id).await;
        let _ = self.client.close().await;
    }
}
/// Reassembles the AMQP `Basic.Deliver` → content header → body* frame sequence
/// into complete [`BackupRecord`]s, one frame at a time.
pub struct MessageAssembler {
    /// Which part of the current message is expected next.
    state: AssemblyState,
    /// Body bytes collected so far; only populated while in `GotHeader`.
    body_buf: Vec<u8>,
}

/// Progress through the frames of a single delivered message.
enum AssemblyState {
    /// No message in progress.
    WaitingForDeliver,
    /// A `Basic.Deliver` was seen; its content header is expected next.
    GotDeliver {
        delivery_tag: u64,
        redelivered: bool,
        exchange: String,
        routing_key: String,
    },
    /// Deliver and header were seen; body frames are collected until
    /// `body_size` bytes have arrived.
    GotHeader {
        delivery_tag: u64,
        redelivered: bool,
        exchange: String,
        routing_key: String,
        header: Box<AMQPContentHeader>,
        body_size: u64,
    },
}

impl MessageAssembler {
    /// Cap on the body-buffer capacity reserved up front from a header's declared
    /// `body_size`. Larger bodies still work (the buffer grows as frames arrive);
    /// this only stops a bogus size on the wire from forcing a huge allocation.
    const MAX_BODY_PREALLOC: u64 = 1 << 20;

    /// Creates an assembler with no message in progress.
    pub fn new() -> Self {
        Self {
            state: AssemblyState::WaitingForDeliver,
            body_buf: Vec::with_capacity(4096),
        }
    }

    /// Feeds one frame into the assembly state machine.
    ///
    /// Returns `Ok(Some(record))` when `frame` completes a message and `Ok(None)`
    /// otherwise. A message completes on a content header that declares a
    /// zero-length body, or on the body frame that brings the collected bytes up
    /// to the declared size.
    ///
    /// Frames other than `Basic.Deliver`, content headers and bodies are ignored.
    /// Out-of-order frames are logged with `warn!` and dropped rather than
    /// reported as errors. Channel numbers on the frames are not checked.
    /// `queue_name` and `vhost` are copied into the emitted record.
    pub fn process_frame(
        &mut self,
        frame: AMQPFrame,
        queue_name: &str,
        vhost: &str,
    ) -> Result<Option<BackupRecord>> {
        match frame {
            AMQPFrame::Method(_, AMQPClass::Basic(BasicMethod::Deliver(deliver))) => {
                trace!(
                    "Deliver: tag={}, exchange={}, routing_key={}",
                    deliver.delivery_tag,
                    deliver.exchange,
                    deliver.routing_key
                );
                if !matches!(self.state, AssemblyState::WaitingForDeliver) {
                    warn!("Received Deliver before previous message was complete; discarding it");
                }
                // Any partial body belongs to the abandoned message.
                self.body_buf.clear();
                self.state = AssemblyState::GotDeliver {
                    delivery_tag: deliver.delivery_tag,
                    redelivered: deliver.redelivered,
                    exchange: deliver.exchange.to_string(),
                    routing_key: deliver.routing_key.to_string(),
                };
                Ok(None)
            }
            AMQPFrame::Header(_, header) => self.on_header(header, queue_name, vhost),
            AMQPFrame::Body(_, data) => self.on_body(&data, queue_name, vhost),
            _ => Ok(None),
        }
    }

    /// Handles a content-header frame. For an empty body the record is emitted
    /// immediately; otherwise body collection starts.
    fn on_header(
        &mut self,
        header: AMQPContentHeader,
        queue_name: &str,
        vhost: &str,
    ) -> Result<Option<BackupRecord>> {
        // The previous state is consumed regardless of which branch runs below.
        let previous = std::mem::replace(&mut self.state, AssemblyState::WaitingForDeliver);
        if let AssemblyState::GotDeliver {
            delivery_tag,
            redelivered,
            exchange,
            routing_key,
        } = previous
        {
            let body_size = header.body_size;
            if body_size == 0 {
                // No body frames follow a zero-length body.
                let record = self.build_record(
                    delivery_tag,
                    redelivered,
                    exchange,
                    routing_key,
                    &header,
                    None,
                    queue_name,
                    vhost,
                );
                return Ok(Some(record));
            }
            self.body_buf.clear();
            self.body_buf
                .reserve(body_size.min(Self::MAX_BODY_PREALLOC) as usize);
            self.state = AssemblyState::GotHeader {
                delivery_tag,
                redelivered,
                exchange,
                routing_key,
                header: Box::new(header),
                body_size,
            };
        } else {
            warn!("Received Header frame without preceding Deliver");
            self.body_buf.clear();
        }
        Ok(None)
    }

    /// Handles a content-body frame. The bytes are appended, and the record is
    /// emitted once the declared body size has been reached.
    fn on_body(
        &mut self,
        data: &[u8],
        queue_name: &str,
        vhost: &str,
    ) -> Result<Option<BackupRecord>> {
        match std::mem::replace(&mut self.state, AssemblyState::WaitingForDeliver) {
            AssemblyState::GotHeader {
                delivery_tag,
                redelivered,
                exchange,
                routing_key,
                header,
                body_size,
            } => {
                self.body_buf.extend_from_slice(data);
                let received = self.body_buf.len() as u64;
                if received < body_size {
                    // More body frames are expected: restore the in-progress state.
                    self.state = AssemblyState::GotHeader {
                        delivery_tag,
                        redelivered,
                        exchange,
                        routing_key,
                        header,
                        body_size,
                    };
                    return Ok(None);
                }
                if received > body_size {
                    warn!(
                        "Body frames carried {} bytes but header declared {}",
                        received, body_size
                    );
                }
                // Move the buffer out rather than cloning it. The next header
                // re-reserves capacity.
                let body = std::mem::take(&mut self.body_buf);
                let record = self.build_record(
                    delivery_tag,
                    redelivered,
                    exchange,
                    routing_key,
                    &header,
                    Some(body),
                    queue_name,
                    vhost,
                );
                Ok(Some(record))
            }
            other => {
                // Not collecting a body: drop the bytes instead of buffering them,
                // and leave the current state as it was.
                self.state = other;
                warn!("Received Body frame without preceding Header");
                Ok(None)
            }
        }
    }

    /// Builds a [`BackupRecord`] from the delivery metadata, the content header and
    /// the optional body. The record is stamped with the current UTC time in
    /// milliseconds.
    // Each argument maps to a distinct record field; bundling them into a struct
    // would only move the same list elsewhere.
    #[allow(clippy::too_many_arguments)]
    fn build_record(
        &self,
        delivery_tag: u64,
        redelivered: bool,
        exchange: String,
        routing_key: String,
        header: &AMQPContentHeader,
        body: Option<Vec<u8>>,
        queue_name: &str,
        vhost: &str,
    ) -> BackupRecord {
        let properties = convert_properties(header);
        let headers = header
            .properties
            .headers()
            .as_ref()
            .map(convert_headers)
            .unwrap_or_default();
        BackupRecord {
            body,
            properties,
            headers,
            exchange,
            routing_key,
            delivery_tag,
            redelivered,
            backed_up_at: chrono::Utc::now().timestamp_millis(),
            source_queue: queue_name.to_string(),
            source_vhost: vhost.to_string(),
        }
    }
}

impl Default for MessageAssembler {
    fn default() -> Self {
        Self::new()
    }
}
/// Copies every basic-class property present on `header` into the owned
/// [`BackupProperties`] form. Properties that are absent map to `None`, and the
/// timestamp is reinterpreted as `i64`.
pub fn convert_properties(header: &AMQPContentHeader) -> BackupProperties {
    // Turns an optional protocol string into an optional owned `String`.
    fn owned<T: ToString>(value: &Option<T>) -> Option<String> {
        value.as_ref().map(ToString::to_string)
    }

    let props = &header.properties;
    BackupProperties {
        content_type: owned(props.content_type()),
        content_encoding: owned(props.content_encoding()),
        delivery_mode: *props.delivery_mode(),
        priority: *props.priority(),
        correlation_id: owned(props.correlation_id()),
        reply_to: owned(props.reply_to()),
        expiration: owned(props.expiration()),
        message_id: owned(props.message_id()),
        timestamp: props.timestamp().map(|ts| ts as i64),
        type_field: owned(props.kind()),
        user_id: owned(props.user_id()),
        app_id: owned(props.app_id()),
        cluster_id: owned(props.cluster_id()),
    }
}
/// Flattens an AMQP field table into `(name, value)` pairs. Each value is
/// converted recursively with [`convert_amqp_value`], and pairs are emitted in
/// the table's own iteration order.
pub fn convert_headers(table: &FieldTable) -> Vec<(String, BackupHeaderValue)> {
    let mut pairs = Vec::new();
    for (name, value) in table.inner().iter() {
        pairs.push((name.to_string(), convert_amqp_value(value)));
    }
    pairs
}
/// Converts one AMQP field value into its backup representation, recursing into
/// nested tables and arrays. Long strings are kept as raw bytes so non-UTF-8
/// content survives the round trip through [`to_amqp_value`].
pub fn convert_amqp_value(value: &AMQPValue) -> BackupHeaderValue {
    match value {
        // Containers: convert recursively.
        AMQPValue::FieldTable(table) => BackupHeaderValue::Table(convert_headers(table)),
        AMQPValue::FieldArray(array) => {
            let items = array.as_slice().iter().map(convert_amqp_value).collect();
            BackupHeaderValue::Array(items)
        }

        // Text and raw bytes.
        AMQPValue::ShortString(text) => BackupHeaderValue::ShortString(text.to_string()),
        AMQPValue::LongString(text) => {
            BackupHeaderValue::LongStringBytes(text.as_bytes().to_vec())
        }
        AMQPValue::ByteArray(bytes) => BackupHeaderValue::Bytes(bytes.as_slice().to_vec()),

        // Integers of every width, signed and unsigned.
        AMQPValue::ShortShortInt(n) => BackupHeaderValue::ShortShortInt(*n),
        AMQPValue::ShortShortUInt(n) => BackupHeaderValue::ShortShortUInt(*n),
        AMQPValue::ShortInt(n) => BackupHeaderValue::ShortInt(*n),
        AMQPValue::ShortUInt(n) => BackupHeaderValue::ShortUInt(*n),
        AMQPValue::LongInt(n) => BackupHeaderValue::LongInt(*n),
        AMQPValue::LongUInt(n) => BackupHeaderValue::LongUInt(*n),
        AMQPValue::LongLongInt(n) => BackupHeaderValue::LongLongInt(*n),

        // Remaining scalars.
        AMQPValue::Boolean(flag) => BackupHeaderValue::Bool(*flag),
        AMQPValue::Float(x) => BackupHeaderValue::Float(*x),
        AMQPValue::Double(x) => BackupHeaderValue::Double(*x),
        AMQPValue::Timestamp(ts) => BackupHeaderValue::Timestamp(*ts as i64),
        AMQPValue::DecimalValue(dec) => BackupHeaderValue::Decimal {
            scale: dec.scale,
            value: dec.value,
        },
        AMQPValue::Void => BackupHeaderValue::Void,
    }
}
pub fn to_amqp_properties(
props: &BackupProperties,
) -> amq_protocol::protocol::basic::AMQPProperties {
use amq_protocol::types::ShortString;
let mut p = amq_protocol::protocol::basic::AMQPProperties::default();
if let Some(ref v) = props.content_type {
p = p.with_content_type(ShortString::from(v.as_str()));
}
if let Some(ref v) = props.content_encoding {
p = p.with_content_encoding(ShortString::from(v.as_str()));
}
if let Some(v) = props.delivery_mode {
p = p.with_delivery_mode(v);
}
if let Some(v) = props.priority {
p = p.with_priority(v);
}
if let Some(ref v) = props.correlation_id {
p = p.with_correlation_id(ShortString::from(v.as_str()));
}
if let Some(ref v) = props.reply_to {
p = p.with_reply_to(ShortString::from(v.as_str()));
}
if let Some(ref v) = props.expiration {
p = p.with_expiration(ShortString::from(v.as_str()));
}
if let Some(ref v) = props.message_id {
p = p.with_message_id(ShortString::from(v.as_str()));
}
if let Some(v) = props.timestamp {
p = p.with_timestamp(v as u64);
}
if let Some(ref v) = props.type_field {
p = p.with_type(ShortString::from(v.as_str()));
}
if let Some(ref v) = props.user_id {
p = p.with_user_id(ShortString::from(v.as_str()));
}
if let Some(ref v) = props.app_id {
p = p.with_app_id(ShortString::from(v.as_str()));
}
if let Some(ref v) = props.cluster_id {
p = p.with_cluster_id(ShortString::from(v.as_str()));
}
p
}
/// Rebuilds an AMQP field table from backed-up `(name, value)` pairs. Each
/// value is converted with [`to_amqp_value`], and pairs are inserted in slice
/// order.
pub fn to_field_table(headers: &[(String, BackupHeaderValue)]) -> FieldTable {
    headers
        .iter()
        .fold(FieldTable::default(), |mut table, (name, value)| {
            table.insert(name.as_str().into(), to_amqp_value(value));
            table
        })
}
/// Converts a backed-up header value back into an AMQP field value, recursing
/// into nested tables and arrays.
///
/// Mapping of the less obvious variants:
/// - `Short` becomes `ShortInt` and `Long` becomes `LongLongInt`.
/// - Both `LongString` and `LongStringBytes` become an AMQP long string.
pub fn to_amqp_value(value: &BackupHeaderValue) -> AMQPValue {
    match value {
        // Containers: convert recursively.
        BackupHeaderValue::Table(entries) => AMQPValue::FieldTable(to_field_table(entries)),
        BackupHeaderValue::Array(items) => {
            let converted: Vec<AMQPValue> = items.iter().map(to_amqp_value).collect();
            AMQPValue::FieldArray(converted.into())
        }

        // Text and raw bytes.
        BackupHeaderValue::ShortString(text) => AMQPValue::ShortString(text.as_str().into()),
        BackupHeaderValue::LongString(text) => {
            AMQPValue::LongString(text.as_bytes().to_vec().into())
        }
        BackupHeaderValue::LongStringBytes(bytes) => AMQPValue::LongString(bytes.clone().into()),
        BackupHeaderValue::Bytes(bytes) => AMQPValue::ByteArray(bytes.clone().into()),

        // Integers, including the two legacy variants.
        BackupHeaderValue::Short(n) => AMQPValue::ShortInt(*n),
        BackupHeaderValue::Long(n) => AMQPValue::LongLongInt(*n),
        BackupHeaderValue::ShortShortInt(n) => AMQPValue::ShortShortInt(*n),
        BackupHeaderValue::ShortShortUInt(n) => AMQPValue::ShortShortUInt(*n),
        BackupHeaderValue::ShortInt(n) => AMQPValue::ShortInt(*n),
        BackupHeaderValue::ShortUInt(n) => AMQPValue::ShortUInt(*n),
        BackupHeaderValue::LongInt(n) => AMQPValue::LongInt(*n),
        BackupHeaderValue::LongUInt(n) => AMQPValue::LongUInt(*n),
        BackupHeaderValue::LongLongInt(n) => AMQPValue::LongLongInt(*n),

        // Remaining scalars.
        BackupHeaderValue::Bool(flag) => AMQPValue::Boolean(*flag),
        BackupHeaderValue::Float(x) => AMQPValue::Float(*x),
        BackupHeaderValue::Double(x) => AMQPValue::Double(*x),
        BackupHeaderValue::Timestamp(ts) => AMQPValue::Timestamp(*ts as u64),
        BackupHeaderValue::Decimal { scale, value } => {
            let decimal = amq_protocol::types::DecimalValue {
                scale: *scale,
                value: *value,
            };
            AMQPValue::DecimalValue(decimal)
        }
        BackupHeaderValue::Void => AMQPValue::Void,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use amq_protocol::protocol::basic::{self as amqp_basic, AMQPProperties};
    use amq_protocol::types::ShortString;

    /// Builds a `Basic.Deliver` frame on channel 1 with empty exchange/routing key.
    fn deliver_frame(delivery_tag: u64) -> AMQPFrame {
        let deliver = amqp_basic::Deliver {
            consumer_tag: "tag".into(),
            delivery_tag,
            redelivered: false,
            exchange: "".into(),
            routing_key: "".into(),
        };
        AMQPFrame::Method(1, AMQPClass::Basic(BasicMethod::Deliver(deliver)))
    }

    /// Builds a content-header frame on channel 1 with default properties.
    fn header_frame(body_size: u64) -> AMQPFrame {
        AMQPFrame::Header(
            1,
            AMQPContentHeader {
                class_id: 60,
                body_size,
                properties: AMQPProperties::default(),
            },
        )
    }

    #[test]
    fn test_convert_properties_all_fields() {
        let mut props = AMQPProperties::default();
        props = props
            .with_content_type(ShortString::from("application/json"))
            .with_content_encoding(ShortString::from("utf-8"))
            .with_delivery_mode(2)
            .with_priority(5)
            .with_correlation_id(ShortString::from("corr-123"))
            .with_reply_to(ShortString::from("reply-queue"))
            .with_expiration(ShortString::from("60000"))
            .with_message_id(ShortString::from("msg-456"))
            .with_timestamp(1700000000)
            .with_type(ShortString::from("order.created"))
            .with_app_id(ShortString::from("my-app"));
        let header = AMQPContentHeader {
            class_id: 60,
            body_size: 100,
            properties: props,
        };
        let converted = convert_properties(&header);
        assert_eq!(converted.content_type, Some("application/json".to_string()));
        assert_eq!(converted.content_encoding, Some("utf-8".to_string()));
        assert_eq!(converted.delivery_mode, Some(2));
        assert_eq!(converted.priority, Some(5));
        assert_eq!(converted.correlation_id, Some("corr-123".to_string()));
        assert_eq!(converted.reply_to, Some("reply-queue".to_string()));
        assert_eq!(converted.expiration, Some("60000".to_string()));
        assert_eq!(converted.message_id, Some("msg-456".to_string()));
        assert_eq!(converted.timestamp, Some(1700000000));
        assert_eq!(converted.type_field, Some("order.created".to_string()));
        assert_eq!(converted.app_id, Some("my-app".to_string()));
    }

    #[test]
    fn test_convert_properties_empty() {
        let header = AMQPContentHeader {
            class_id: 60,
            body_size: 0,
            properties: AMQPProperties::default(),
        };
        let converted = convert_properties(&header);
        assert!(converted.content_type.is_none());
        assert!(converted.delivery_mode.is_none());
        assert!(converted.timestamp.is_none());
    }

    #[test]
    fn test_to_amqp_properties_roundtrip() {
        let original = BackupProperties {
            content_type: Some("application/json".to_string()),
            content_encoding: Some("gzip".to_string()),
            delivery_mode: Some(2),
            priority: Some(9),
            correlation_id: Some("corr".to_string()),
            reply_to: Some("replies".to_string()),
            expiration: Some("1000".to_string()),
            message_id: Some("id-1".to_string()),
            timestamp: Some(1_700_000_000),
            type_field: Some("evt".to_string()),
            user_id: Some("guest".to_string()),
            app_id: Some("app".to_string()),
            cluster_id: Some("cluster".to_string()),
        };
        let header = AMQPContentHeader {
            class_id: 60,
            body_size: 0,
            properties: to_amqp_properties(&original),
        };
        let restored = convert_properties(&header);
        assert_eq!(restored.content_type, original.content_type);
        assert_eq!(restored.content_encoding, original.content_encoding);
        assert_eq!(restored.delivery_mode, original.delivery_mode);
        assert_eq!(restored.priority, original.priority);
        assert_eq!(restored.correlation_id, original.correlation_id);
        assert_eq!(restored.reply_to, original.reply_to);
        assert_eq!(restored.expiration, original.expiration);
        assert_eq!(restored.message_id, original.message_id);
        assert_eq!(restored.timestamp, original.timestamp);
        assert_eq!(restored.type_field, original.type_field);
        assert_eq!(restored.user_id, original.user_id);
        assert_eq!(restored.app_id, original.app_id);
        assert_eq!(restored.cluster_id, original.cluster_id);
    }

    #[test]
    fn test_convert_headers_basic_types() {
        let mut table = FieldTable::default();
        table.insert(
            "string-key".into(),
            AMQPValue::LongString("hello".as_bytes().to_vec().into()),
        );
        table.insert("int-key".into(), AMQPValue::LongInt(42));
        table.insert("bool-key".into(), AMQPValue::Boolean(true));
        let converted = convert_headers(&table);
        assert_eq!(converted.len(), 3);
        let find = |name: &str| converted.iter().find(|(k, _)| k == name).map(|(_, v)| v);
        assert!(matches!(
            find("bool-key"),
            Some(BackupHeaderValue::Bool(true))
        ));
        assert!(matches!(
            find("int-key"),
            Some(BackupHeaderValue::LongInt(42))
        ));
        assert!(matches!(
            find("string-key"),
            Some(BackupHeaderValue::LongStringBytes(s)) if s == b"hello"
        ));
    }

    #[test]
    fn test_field_table_roundtrip() {
        let headers = vec![
            ("count".to_string(), BackupHeaderValue::LongInt(7)),
            ("flag".to_string(), BackupHeaderValue::Bool(false)),
        ];
        let table = to_field_table(&headers);
        let back = convert_headers(&table);
        assert_eq!(back.len(), 2);
        let find = |name: &str| back.iter().find(|(k, _)| k == name).map(|(_, v)| v);
        assert!(matches!(find("count"), Some(BackupHeaderValue::LongInt(7))));
        assert!(matches!(find("flag"), Some(BackupHeaderValue::Bool(false))));
    }

    #[test]
    fn test_legacy_short_and_long_header_values() {
        assert!(matches!(
            to_amqp_value(&BackupHeaderValue::Short(-3)),
            AMQPValue::ShortInt(-3)
        ));
        assert!(matches!(
            to_amqp_value(&BackupHeaderValue::Long(1 << 40)),
            AMQPValue::LongLongInt(v) if v == 1 << 40
        ));
    }

    #[test]
    fn test_header_long_string_binary_roundtrip() {
        let value = AMQPValue::LongString(vec![0, 159, 146, 150, 255].into());
        let backup_value = convert_amqp_value(&value);
        assert!(matches!(
            &backup_value,
            BackupHeaderValue::LongStringBytes(bytes)
                if bytes == &[0, 159, 146, 150, 255]
        ));
        match to_amqp_value(&backup_value) {
            AMQPValue::LongString(restored) => {
                assert_eq!(restored.as_bytes(), &[0, 159, 146, 150, 255]);
            }
            other => panic!("Expected LongString, got {:?}", other),
        }
    }

    #[test]
    fn test_header_decimal_roundtrip() {
        let value = AMQPValue::DecimalValue(amq_protocol::types::DecimalValue {
            scale: 2,
            value: 1234,
        });
        let backup_value = convert_amqp_value(&value);
        assert!(matches!(
            &backup_value,
            BackupHeaderValue::Decimal {
                scale: 2,
                value: 1234
            }
        ));
        match to_amqp_value(&backup_value) {
            AMQPValue::DecimalValue(restored) => {
                assert_eq!(restored.scale, 2);
                assert_eq!(restored.value, 1234);
            }
            other => panic!("Expected DecimalValue, got {:?}", other),
        }
    }

    #[test]
    fn test_convert_headers_nested_table() {
        let mut inner = FieldTable::default();
        inner.insert("nested".into(), AMQPValue::Boolean(true));
        let mut table = FieldTable::default();
        table.insert("outer".into(), AMQPValue::FieldTable(inner));
        let converted = convert_headers(&table);
        assert_eq!(converted.len(), 1);
        match &converted[0].1 {
            BackupHeaderValue::Table(inner_vec) => {
                assert_eq!(inner_vec.len(), 1);
                assert_eq!(inner_vec[0].0, "nested");
                assert!(matches!(inner_vec[0].1, BackupHeaderValue::Bool(true)));
            }
            other => panic!("Expected Table, got {:?}", other),
        }
    }

    #[test]
    fn test_message_assembler_basic() {
        let mut assembler = MessageAssembler::new();
        let deliver = amqp_basic::Deliver {
            consumer_tag: "tag".into(),
            delivery_tag: 1,
            redelivered: false,
            exchange: "test-exchange".into(),
            routing_key: "test.key".into(),
        };
        let frame = AMQPFrame::Method(1, AMQPClass::Basic(BasicMethod::Deliver(deliver)));
        assert!(assembler
            .process_frame(frame, "test-queue", "/")
            .unwrap()
            .is_none());
        let header = AMQPContentHeader {
            class_id: 60,
            body_size: 5,
            properties: AMQPProperties::default()
                .with_content_type(ShortString::from("text/plain")),
        };
        let frame = AMQPFrame::Header(1, header);
        assert!(assembler
            .process_frame(frame, "test-queue", "/")
            .unwrap()
            .is_none());
        let frame = AMQPFrame::Body(1, b"hello".to_vec());
        let record = assembler
            .process_frame(frame, "test-queue", "/")
            .unwrap()
            .expect("Should have assembled a complete record");
        assert_eq!(record.body, Some(b"hello".to_vec()));
        assert_eq!(record.exchange, "test-exchange");
        assert_eq!(record.routing_key, "test.key");
        assert_eq!(record.delivery_tag, 1);
        assert!(!record.redelivered);
        assert_eq!(record.source_queue, "test-queue");
        assert_eq!(record.source_vhost, "/");
        assert_eq!(
            record.properties.content_type,
            Some("text/plain".to_string())
        );
    }

    #[test]
    fn test_message_assembler_empty_body() {
        let mut assembler = MessageAssembler::new();
        let deliver = amqp_basic::Deliver {
            consumer_tag: "tag".into(),
            delivery_tag: 1,
            redelivered: false,
            exchange: "".into(),
            routing_key: "".into(),
        };
        let frame = AMQPFrame::Method(1, AMQPClass::Basic(BasicMethod::Deliver(deliver)));
        assert!(assembler.process_frame(frame, "q", "/").unwrap().is_none());
        let header = AMQPContentHeader {
            class_id: 60,
            body_size: 0,
            properties: AMQPProperties::default(),
        };
        let frame = AMQPFrame::Header(1, header);
        let record = assembler
            .process_frame(frame, "q", "/")
            .unwrap()
            .expect("Should emit record for zero-length body");
        assert!(record.body.is_none());
    }

    #[test]
    fn test_message_assembler_multi_body() {
        let mut assembler = MessageAssembler::new();
        let deliver = amqp_basic::Deliver {
            consumer_tag: "tag".into(),
            delivery_tag: 1,
            redelivered: false,
            exchange: "".into(),
            routing_key: "".into(),
        };
        let frame = AMQPFrame::Method(1, AMQPClass::Basic(BasicMethod::Deliver(deliver)));
        assembler.process_frame(frame, "q", "/").unwrap();
        let header = AMQPContentHeader {
            class_id: 60,
            body_size: 10,
            properties: AMQPProperties::default(),
        };
        let frame = AMQPFrame::Header(1, header);
        assembler.process_frame(frame, "q", "/").unwrap();
        let frame = AMQPFrame::Body(1, b"hello".to_vec());
        assert!(assembler.process_frame(frame, "q", "/").unwrap().is_none());
        let frame = AMQPFrame::Body(1, b"world".to_vec());
        let record = assembler
            .process_frame(frame, "q", "/")
            .unwrap()
            .expect("Should complete after 10 bytes");
        assert_eq!(record.body, Some(b"helloworld".to_vec()));
    }

    #[test]
    fn test_message_assembler_stray_body_does_not_leak_into_next_message() {
        let mut assembler = MessageAssembler::new();
        let stray = AMQPFrame::Body(1, b"junk".to_vec());
        assert!(assembler.process_frame(stray, "q", "/").unwrap().is_none());
        assert!(assembler
            .process_frame(deliver_frame(1), "q", "/")
            .unwrap()
            .is_none());
        assert!(assembler
            .process_frame(header_frame(2), "q", "/")
            .unwrap()
            .is_none());
        let record = assembler
            .process_frame(AMQPFrame::Body(1, b"ok".to_vec()), "q", "/")
            .unwrap()
            .expect("Should complete after 2 bytes");
        assert_eq!(record.body, Some(b"ok".to_vec()));
    }

    #[test]
    fn test_message_assembler_deliver_discards_partial_message() {
        let mut assembler = MessageAssembler::new();
        assembler.process_frame(deliver_frame(1), "q", "/").unwrap();
        assembler.process_frame(header_frame(10), "q", "/").unwrap();
        let partial = AMQPFrame::Body(1, b"hello".to_vec());
        assert!(assembler.process_frame(partial, "q", "/").unwrap().is_none());
        // A new delivery arrives before the first body completed.
        assembler.process_frame(deliver_frame(2), "q", "/").unwrap();
        assembler.process_frame(header_frame(3), "q", "/").unwrap();
        let record = assembler
            .process_frame(AMQPFrame::Body(1, b"abc".to_vec()), "q", "/")
            .unwrap()
            .expect("Second message should complete");
        assert_eq!(record.delivery_tag, 2);
        assert_eq!(record.body, Some(b"abc".to_vec()));
    }

    #[test]
    fn test_message_assembler_header_without_deliver_is_ignored() {
        let mut assembler = MessageAssembler::new();
        assert!(assembler
            .process_frame(header_frame(3), "q", "/")
            .unwrap()
            .is_none());
        let body = AMQPFrame::Body(1, b"abc".to_vec());
        assert!(assembler.process_frame(body, "q", "/").unwrap().is_none());
    }
}