use std::collections::BTreeMap;
use std::fmt::Debug;
use std::future::Future;
use async_trait::async_trait;
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use crate::error::{QueueError, QueueResult};
use crate::feature::QueueFeatures;
/// A message handed to [`Queue::publish`] or carried inside a [`QueueDelivery`].
///
/// Build one with [`QueueMessage::new`] and refine it with the `with_*` builders.
#[derive(Debug, Clone)]
pub struct QueueMessage {
/// Topic name the message was created with.
pub topic: String,
/// Raw payload bytes.
pub payload: Bytes,
/// String headers; the `BTreeMap` keeps them ordered by header name.
pub headers: BTreeMap<String, String>,
/// Optional message key; `None` unless set via `with_key`.
// NOTE(review): presumably used by backends for routing/ordering — not visible in this file, confirm.
pub key: Option<String>,
/// Optional explicit partition; `None` unless set via `with_partition`.
pub partition: Option<i32>,
/// Optional dead-letter destination attached to this message.
pub dead_letter: Option<DeadLetterTarget>,
/// Free-form JSON attributes; `serde_json::Value::Null` unless set via `with_attributes`.
pub attributes: serde_json::Value,
}
impl QueueMessage {
    /// Creates a message for `topic` carrying `payload`.
    ///
    /// Headers start empty; `key`, `partition` and `dead_letter` are `None`;
    /// `attributes` is `serde_json::Value::Null`.
    pub fn new(topic: impl Into<String>, payload: impl Into<Bytes>) -> Self {
        Self {
            topic: topic.into(),
            payload: payload.into(),
            headers: BTreeMap::new(),
            key: None,
            partition: None,
            dead_letter: None,
            attributes: serde_json::Value::Null,
        }
    }

    /// Sets the message key, replacing any previously set key.
    #[must_use = "builders consume the message and return the updated value"]
    pub fn with_key(mut self, key: impl Into<String>) -> Self {
        self.key = Some(key.into());
        self
    }

    /// Adds a header. Inserting an existing header name overwrites its value.
    #[must_use = "builders consume the message and return the updated value"]
    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(key.into(), value.into());
        self
    }

    /// Sets the explicit partition, replacing any previous value.
    #[must_use = "builders consume the message and return the updated value"]
    pub fn with_partition(mut self, partition: i32) -> Self {
        self.partition = Some(partition);
        self
    }

    /// Attaches a dead-letter target, replacing any previous one.
    #[must_use = "builders consume the message and return the updated value"]
    pub fn with_dead_letter(mut self, target: DeadLetterTarget) -> Self {
        self.dead_letter = Some(target);
        self
    }

    /// Replaces the JSON attributes wholesale (no merging with the old value).
    #[must_use = "builders consume the message and return the updated value"]
    pub fn with_attributes(mut self, attributes: serde_json::Value) -> Self {
        self.attributes = attributes;
        self
    }
}
/// Dead-letter destination that can be attached to a [`QueueMessage`]
/// through `QueueMessage::with_dead_letter`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeadLetterTarget {
/// Name of the dead-letter topic.
pub topic: String,
/// Optional routing key; `None` unless set via `with_routing_key`.
// NOTE(review): how backends interpret the routing key is not visible here — confirm.
pub routing_key: Option<String>,
}
impl DeadLetterTarget {
    /// Creates a target for `topic` with no routing key.
    pub fn new(topic: impl Into<String>) -> Self {
        Self {
            topic: topic.into(),
            routing_key: None,
        }
    }

    /// Sets the routing key, replacing any previously set one.
    #[must_use = "builders consume the target and return the updated value"]
    pub fn with_routing_key(mut self, routing_key: impl Into<String>) -> Self {
        self.routing_key = Some(routing_key.into());
        self
    }
}
/// Handle for a received delivery; it is what [`Queue::ack`] and [`Queue::nack`] take
/// to settle that delivery.
#[derive(Debug, Clone)]
pub struct QueueReceipt {
/// Token bytes identifying the delivery.
// NOTE(review): format is presumably backend-defined and opaque to callers — confirm.
pub token: Bytes,
/// Free-form JSON attributes; `serde_json::Value::Null` unless set via `with_attributes`.
pub attributes: serde_json::Value,
}
impl QueueReceipt {
    /// Creates a receipt wrapping `token`, with `attributes` set to JSON `Null`.
    pub fn new(token: impl Into<Bytes>) -> Self {
        Self {
            token: token.into(),
            attributes: serde_json::Value::Null,
        }
    }

    /// Replaces the JSON attributes wholesale (no merging with the old value).
    #[must_use = "builders consume the receipt and return the updated value"]
    pub fn with_attributes(mut self, attributes: serde_json::Value) -> Self {
        self.attributes = attributes;
        self
    }
}
/// A received message paired with the receipt used to ack or nack it.
///
/// This is what [`Queue::receive`] returns and what [`consume_once`] passes to its handler.
#[derive(Debug, Clone)]
pub struct QueueDelivery {
/// The delivered message.
pub message: QueueMessage,
/// Receipt to hand back to [`Queue::ack`] or [`Queue::nack`].
pub receipt: QueueReceipt,
}
impl QueueDelivery {
pub fn new(message: QueueMessage, receipt: QueueReceipt) -> Self {
Self { message, receipt }
}
}
/// Negative-acknowledgement mode passed to [`Queue::nack`] / [`Queue::nack_many`].
///
/// `NackAction::as_str` maps each variant to a stable operation label.
// NOTE(review): the per-variant effects below are inferred from the names; the actual
// behavior is defined by each backend's `nack` implementation — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NackAction {
/// Labelled `"nack_requeue"`; presumably puts the message back on the queue.
Requeue,
/// Labelled `"nack_dead_letter"`; presumably routes the message to its dead-letter target.
DeadLetter,
/// Labelled `"nack_drop"`; presumably discards the message.
Drop,
}
/// Outcome a handler returns to [`consume_once`] to say how the delivery is settled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConsumeAction {
/// `consume_once` calls [`Queue::ack`] with the delivery's receipt.
Ack,
/// `consume_once` calls [`Queue::nack`] with the delivery's receipt and this action.
Nack(NackAction),
}
/// Backend-agnostic queue interface.
///
/// Implementors must provide `name`, `publish`, `receive` and `ack`. Every other method
/// has a default: the `*_many` methods delegate to the sequential helpers in
/// [`default_batch`], `nack` reports the operation as unsupported, and `ping` succeeds.
#[async_trait]
pub trait Queue: Send + Sync + Debug + 'static {
/// Static name of this queue implementation; the default `nack` passes it to
/// `QueueError::unsupported`.
fn name(&self) -> &'static str;
/// Features of this queue. The default is `QueueFeatures::default()`.
fn features(&self) -> QueueFeatures {
QueueFeatures::default()
}
/// Publishes a single message.
async fn publish(&self, message: QueueMessage) -> QueueResult<()>;
/// Publishes a batch. The default publishes items one at a time, in order, and stops at
/// the first error; items before the failing one have already been published.
async fn publish_many(&self, items: Vec<QueueMessage>) -> QueueResult<()> {
default_batch::publish_many(self, items).await
}
/// Receives a single delivery.
// NOTE(review): whether this waits for a message or errors when none is available is
// up to the implementation — not visible here.
async fn receive(&self) -> QueueResult<QueueDelivery>;
/// Receives a batch. The default calls `receive` exactly `max` times in sequence and
/// returns the first error it hits.
async fn receive_many(&self, max: usize) -> QueueResult<Vec<QueueDelivery>> {
default_batch::receive_many(self, max).await
}
/// Acknowledges the delivery identified by `receipt`.
async fn ack(&self, receipt: QueueReceipt) -> QueueResult<()>;
/// Acknowledges a batch. The default acks receipts one at a time, in order, and stops at
/// the first error.
async fn ack_many(&self, receipts: Vec<QueueReceipt>) -> QueueResult<()> {
default_batch::ack_many(self, receipts).await
}
/// Negatively acknowledges a delivery. The default ignores the receipt and returns
/// `QueueError::unsupported(self.name(), action.as_str())`.
async fn nack(&self, _receipt: QueueReceipt, action: NackAction) -> QueueResult<()> {
Err(QueueError::unsupported(self.name(), action.as_str()))
}
/// Negatively acknowledges a batch. The default calls `nack` per receipt and stops at the
/// first error, so with the default `nack` any non-empty batch fails on its first receipt.
async fn nack_many(&self, receipts: Vec<QueueReceipt>, action: NackAction) -> QueueResult<()> {
default_batch::nack_many(self, receipts, action).await
}
/// Health check. The default does nothing and returns `Ok(())`.
async fn ping(&self) -> QueueResult<()> {
Ok(())
}
}
impl NackAction {
pub fn as_str(&self) -> &'static str {
match self {
NackAction::Requeue => "nack_requeue",
NackAction::DeadLetter => "nack_dead_letter",
NackAction::Drop => "nack_drop",
}
}
}
/// Receives one delivery from `queue`, runs `handler` on it, and settles the delivery
/// according to the returned [`ConsumeAction`].
///
/// * `ConsumeAction::Ack` calls [`Queue::ack`] with the delivery's receipt.
/// * `ConsumeAction::Nack(action)` calls [`Queue::nack`] with the receipt and `action`.
///
/// # Errors
///
/// Returns the error from `receive`, from `handler`, or from the final `ack`/`nack`.
/// When `handler` itself fails, the error is propagated and the delivery is neither
/// acked nor nacked.
pub async fn consume_once<Q, F, Fut>(queue: &Q, handler: F) -> QueueResult<()>
where
    Q: Queue + ?Sized,
    F: FnOnce(QueueDelivery) -> Fut + Send,
    Fut: Future<Output = QueueResult<ConsumeAction>> + Send,
{
    let delivery = queue.receive().await?;
    // Only the receipt is needed after the handler runs, so keep a copy of just that and
    // move the delivery itself into the handler instead of deep-cloning headers/attributes.
    let receipt = delivery.receipt.clone();
    match handler(delivery).await? {
        ConsumeAction::Ack => queue.ack(receipt).await,
        ConsumeAction::Nack(action) => queue.nack(receipt, action).await,
    }
}
/// Sequential fallbacks backing the default `*_many` methods of [`Queue`].
///
/// Each helper performs the single-item operation once per element, in order, and
/// returns as soon as one call fails.
pub mod default_batch {
    use super::*;

    /// Largest capacity [`receive_many`] reserves up front. `max` comes from the caller,
    /// so reserving it verbatim could panic with a capacity overflow or exhaust memory;
    /// beyond this bound the vector simply grows on demand.
    const MAX_PREALLOCATED_DELIVERIES: usize = 1024;

    /// Publishes `items` one by one via [`Queue::publish`].
    ///
    /// Stops at the first error; items before the failing one have already been published.
    pub async fn publish_many<Q: Queue + ?Sized>(queue: &Q, items: Vec<QueueMessage>) -> QueueResult<()> {
        for item in items {
            queue.publish(item).await?;
        }
        Ok(())
    }

    /// Calls [`Queue::receive`] `max` times and collects the deliveries in order.
    ///
    /// Returns an empty vector when `max` is zero. If any `receive` call fails, that error
    /// is returned and the deliveries collected so far are dropped without being settled.
    pub async fn receive_many<Q: Queue + ?Sized>(queue: &Q, max: usize) -> QueueResult<Vec<QueueDelivery>> {
        let mut deliveries = Vec::with_capacity(max.min(MAX_PREALLOCATED_DELIVERIES));
        for _ in 0..max {
            deliveries.push(queue.receive().await?);
        }
        Ok(deliveries)
    }

    /// Acknowledges `receipts` one by one via [`Queue::ack`].
    ///
    /// Stops at the first error; receipts before the failing one have already been acked.
    pub async fn ack_many<Q: Queue + ?Sized>(queue: &Q, receipts: Vec<QueueReceipt>) -> QueueResult<()> {
        for receipt in receipts {
            queue.ack(receipt).await?;
        }
        Ok(())
    }

    /// Negatively acknowledges `receipts` one by one via [`Queue::nack`] with the same `action`.
    ///
    /// Stops at the first error; an empty `receipts` never calls `nack` and returns `Ok(())`.
    pub async fn nack_many<Q: Queue + ?Sized>(
        queue: &Q,
        receipts: Vec<QueueReceipt>,
        action: NackAction,
    ) -> QueueResult<()> {
        for receipt in receipts {
            queue.nack(receipt, action).await?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::feature::QueueFeature;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Minimal queue that counts publishes and acks. `receive` always succeeds with a
    /// fixed delivery, and `nack` is left at the trait default (unsupported).
    #[derive(Debug, Default)]
    struct CountingQueue {
        publishes: AtomicUsize,
        acks: AtomicUsize,
    }

    #[async_trait]
    impl Queue for CountingQueue {
        fn name(&self) -> &'static str {
            "counting"
        }

        async fn publish(&self, _: QueueMessage) -> QueueResult<()> {
            self.publishes.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        async fn receive(&self) -> QueueResult<QueueDelivery> {
            Ok(QueueDelivery::new(
                QueueMessage::new("topic", Bytes::from_static(b"p")),
                QueueReceipt::new("r"),
            ))
        }

        async fn ack(&self, _: QueueReceipt) -> QueueResult<()> {
            self.acks.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }
    }

    #[tokio::test]
    async fn default_batch_publish_and_ack() {
        let q = CountingQueue::default();
        let batch = vec![
            QueueMessage::new("t", Bytes::from_static(b"a")),
            QueueMessage::new("t", Bytes::from_static(b"b")),
        ];
        default_batch::publish_many(&q, batch).await.unwrap();
        assert_eq!(q.publishes.load(Ordering::Relaxed), 2);

        // Exercise the batch helpers themselves rather than the single-item methods.
        let deliveries = default_batch::receive_many(&q, 3).await.unwrap();
        assert_eq!(deliveries.len(), 3);
        let receipts: Vec<QueueReceipt> = deliveries.into_iter().map(|d| d.receipt).collect();
        default_batch::ack_many(&q, receipts).await.unwrap();
        assert_eq!(q.acks.load(Ordering::Relaxed), 3);
    }

    #[tokio::test]
    async fn receive_many_zero_is_empty() {
        let q = CountingQueue::default();
        assert!(default_batch::receive_many(&q, 0).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn default_nack_is_unsupported() {
        let q = CountingQueue::default();
        assert!(q.nack(QueueReceipt::new("r"), NackAction::Requeue).await.is_err());
        assert!(q.nack_many(vec![QueueReceipt::new("r")], NackAction::Drop).await.is_err());
        // An empty batch never reaches `nack`, so it succeeds.
        assert!(q.nack_many(Vec::new(), NackAction::Drop).await.is_ok());
    }

    #[test]
    fn nack_action_labels() {
        assert_eq!(NackAction::Requeue.as_str(), "nack_requeue");
        assert_eq!(NackAction::DeadLetter.as_str(), "nack_dead_letter");
        assert_eq!(NackAction::Drop.as_str(), "nack_drop");
    }

    #[test]
    fn features_default_is_empty() {
        assert!(QueueFeatures::default().is_empty());
        assert!(!QueueFeatures::new([QueueFeature::Headers]).is_empty());
    }

    #[test]
    fn message_partition_and_dlt_builders() {
        let dlt = DeadLetterTarget::new("dlt-topic").with_routing_key("rk");
        let msg = QueueMessage::new("topic", Bytes::from_static(b"p"))
            .with_key("k")
            .with_header("h", "v")
            .with_partition(2)
            .with_dead_letter(dlt);
        assert_eq!(msg.key.as_deref(), Some("k"));
        assert_eq!(msg.headers.get("h").map(String::as_str), Some("v"));
        assert_eq!(msg.partition, Some(2));
        assert_eq!(msg.dead_letter.as_ref().map(|d| d.topic.as_str()), Some("dlt-topic"));
        assert_eq!(msg.dead_letter.as_ref().and_then(|d| d.routing_key.as_deref()), Some("rk"));
    }

    #[tokio::test]
    async fn consume_once_ack_path() {
        let queue = CountingQueue::default();
        consume_once(&queue, |_delivery| async { Ok(ConsumeAction::Ack) }).await.unwrap();
        assert_eq!(queue.acks.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn consume_once_nack_path_surfaces_unsupported() {
        let queue = CountingQueue::default();
        let outcome =
            consume_once(&queue, |_delivery| async { Ok(ConsumeAction::Nack(NackAction::Drop)) }).await;
        assert!(outcome.is_err());
        assert_eq!(queue.acks.load(Ordering::Relaxed), 0);
    }
}