diidi-travel-common-queue 0.1.16

A collection of common utilities and types for the DiiDi project.
Documentation
use std::collections::BTreeMap;
use std::fmt::Debug;
use std::future::Future;

use async_trait::async_trait;
use bytes::Bytes;
use serde::{Deserialize, Serialize};

use crate::error::{QueueError, QueueResult};
use crate::feature::QueueFeatures;

/// A message published to, or received from, a [`Queue`].
///
/// Construct with [`QueueMessage::new`] and refine with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct QueueMessage {
  /// Topic the message is published to (or was received from).
  pub topic: String,
  /// Raw message body.
  pub payload: Bytes,
  /// Message headers; stored in a `BTreeMap`, so iteration is ordered by header name.
  pub headers: BTreeMap<String, String>,
  /// Optional message key. How a backend uses it is not defined here — TODO confirm per backend.
  pub key: Option<String>,
  /// Optional explicit partition; `None` when not set.
  pub partition: Option<i32>,
  /// Optional dead-letter destination attached to this message; nothing in this module enforces it.
  pub dead_letter: Option<DeadLetterTarget>,
  /// Free-form extra data; `serde_json::Value::Null` when created via `new`.
  pub attributes: serde_json::Value,
}

impl QueueMessage {
  /// Creates a message for `topic` carrying `payload`.
  ///
  /// The new message has no headers, no key, no partition, no dead-letter target,
  /// and `serde_json::Value::Null` attributes.
  pub fn new(topic: impl Into<String>, payload: impl Into<Bytes>) -> Self {
    let topic = topic.into();
    let payload = payload.into();
    Self {
      topic,
      payload,
      headers: BTreeMap::default(),
      key: None,
      partition: None,
      dead_letter: None,
      attributes: serde_json::Value::Null,
    }
  }

  /// Returns the message with its key set, replacing any previous key.
  pub fn with_key(self, key: impl Into<String>) -> Self {
    Self { key: Some(key.into()), ..self }
  }

  /// Returns the message with an extra header; an existing header of the same
  /// name is overwritten.
  pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
    let (name, val) = (key.into(), value.into());
    self.headers.insert(name, val);
    self
  }

  /// Returns the message pinned to an explicit `partition`.
  pub fn with_partition(self, partition: i32) -> Self {
    Self { partition: Some(partition), ..self }
  }

  /// Returns the message with `target` as its dead-letter destination, replacing any previous one.
  pub fn with_dead_letter(self, target: DeadLetterTarget) -> Self {
    Self { dead_letter: Some(target), ..self }
  }

  /// Returns the message with its free-form attributes replaced by `attributes`.
  pub fn with_attributes(self, attributes: serde_json::Value) -> Self {
    Self { attributes, ..self }
  }
}

/// A dead-letter destination that can be attached to a [`QueueMessage`].
///
/// Derives serde `Serialize`/`Deserialize`, so it can be embedded in configuration or
/// serialized alongside messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeadLetterTarget {
  /// Dead-letter topic name.
  pub topic: String,
  /// Optional routing key; presumably only meaningful for routing-based backends — TODO confirm.
  pub routing_key: Option<String>,
}

impl DeadLetterTarget {
  pub fn new(topic: impl Into<String>) -> Self {
    Self { topic: topic.into(), routing_key: None }
  }

  pub fn with_routing_key(mut self, routing_key: impl Into<String>) -> Self {
    self.routing_key = Some(routing_key.into());
    self
  }
}

/// Handle for a received message, passed back to [`Queue::ack`] or [`Queue::nack`]
/// to settle the delivery it came with.
#[derive(Debug, Clone)]
pub struct QueueReceipt {
  /// Backend-defined receipt token (opaque to this module).
  pub token: Bytes,
  /// Extra backend-specific receipt data; `serde_json::Value::Null` when created via `new`.
  pub attributes: serde_json::Value,
}

impl QueueReceipt {
  pub fn new(token: impl Into<Bytes>) -> Self {
    Self { token: token.into(), attributes: serde_json::Value::Null }
  }

  pub fn with_attributes(mut self, attributes: serde_json::Value) -> Self {
    self.attributes = attributes;
    self
  }
}

/// A received message paired with the receipt needed to ack or nack it.
///
/// Returned by [`Queue::receive`]; [`consume_once`] settles it using `receipt`.
#[derive(Debug, Clone)]
pub struct QueueDelivery {
  /// The received message.
  pub message: QueueMessage,
  /// Receipt to pass to [`Queue::ack`] / [`Queue::nack`].
  pub receipt: QueueReceipt,
}

impl QueueDelivery {
  pub fn new(message: QueueMessage, receipt: QueueReceipt) -> Self {
    Self { message, receipt }
  }
}

/// What to do with a negatively-acknowledged message; passed to [`Queue::nack`].
///
/// How each variant is carried out is up to the backend. The default
/// [`Queue::nack`] rejects every variant as unsupported.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NackAction {
  /// Ask the backend to requeue the message.
  Requeue,
  /// Ask the backend to dead-letter the message (presumably to the message's
  /// [`DeadLetterTarget`] when one is set — TODO confirm per backend).
  DeadLetter,
  /// Ask the backend to discard the message.
  Drop,
}

/// Decision returned by a [`consume_once`] handler about how to settle a delivery.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConsumeAction {
  /// Settle via [`Queue::ack`].
  Ack,
  /// Settle via [`Queue::nack`] with the given [`NackAction`].
  Nack(NackAction),
}

/// A message-queue backend: publish, receive, and settle (ack/nack) messages.
///
/// Implementors must provide [`name`](Queue::name), [`publish`](Queue::publish),
/// [`receive`](Queue::receive) and [`ack`](Queue::ack). The `*_many` batch methods
/// default to the sequential loops in [`default_batch`]. [`nack`](Queue::nack)
/// defaults to an "unsupported" error, and [`ping`](Queue::ping) defaults to `Ok(())`.
#[async_trait]
pub trait Queue: Send + Sync + Debug + 'static {
  /// Short static identifier for this backend. The default `nack` uses it when
  /// building its "unsupported" error.
  fn name(&self) -> &'static str;

  /// Optional capabilities of this backend. Defaults to `QueueFeatures::default()`,
  /// which is an empty feature set.
  fn features(&self) -> QueueFeatures {
    QueueFeatures::default()
  }

  /// Publishes a single message.
  async fn publish(&self, message: QueueMessage) -> QueueResult<()>;

  /// Publishes several messages. The default calls [`publish`](Queue::publish) once
  /// per item, in order, and stops at the first error.
  async fn publish_many(&self, items: Vec<QueueMessage>) -> QueueResult<()> {
    default_batch::publish_many(self, items).await
  }

  /// Receives one delivery: a message plus the receipt used to settle it.
  /// Whether this waits for a message to become available is implementation-defined.
  async fn receive(&self) -> QueueResult<QueueDelivery>;

  /// Receives a batch of deliveries. The default calls [`receive`](Queue::receive)
  /// exactly `max` times and fails on the first error.
  async fn receive_many(&self, max: usize) -> QueueResult<Vec<QueueDelivery>> {
    default_batch::receive_many(self, max).await
  }

  /// Acknowledges the delivery identified by `receipt`.
  async fn ack(&self, receipt: QueueReceipt) -> QueueResult<()>;

  /// Acknowledges several receipts. The default calls [`ack`](Queue::ack) once per
  /// receipt, in order, and stops at the first error.
  async fn ack_many(&self, receipts: Vec<QueueReceipt>) -> QueueResult<()> {
    default_batch::ack_many(self, receipts).await
  }

  /// Negatively acknowledges the delivery identified by the receipt, applying `action`.
  ///
  /// The default returns `QueueError::unsupported(self.name(), action.as_str())`.
  /// Backends that support nack must override it.
  async fn nack(&self, _receipt: QueueReceipt, action: NackAction) -> QueueResult<()> {
    Err(QueueError::unsupported(self.name(), action.as_str()))
  }

  /// Negatively acknowledges several receipts with the same `action`. The default calls
  /// [`nack`](Queue::nack) once per receipt, in order, and stops at the first error.
  async fn nack_many(&self, receipts: Vec<QueueReceipt>, action: NackAction) -> QueueResult<()> {
    default_batch::nack_many(self, receipts, action).await
  }

  /// Health check. The default always returns `Ok(())` without doing any work.
  async fn ping(&self) -> QueueResult<()> {
    Ok(())
  }
}

impl NackAction {
  pub fn as_str(&self) -> &'static str {
    match self {
      NackAction::Requeue => "nack_requeue",
      NackAction::DeadLetter => "nack_dead_letter",
      NackAction::Drop => "nack_drop",
    }
  }
}

/// Receives one delivery from `queue`, passes it to `handler`, and settles it.
///
/// The handler's [`ConsumeAction`] decides the settlement:
/// - [`ConsumeAction::Ack`] calls [`Queue::ack`].
/// - [`ConsumeAction::Nack`] calls [`Queue::nack`] with the chosen [`NackAction`].
///
/// # Errors
///
/// Returns the error from [`Queue::receive`], from `handler`, or from the final
/// `ack`/`nack` call. If `handler` returns an error, the delivery is neither acked
/// nor nacked here.
pub async fn consume_once<Q, F, Fut>(queue: &Q, handler: F) -> QueueResult<()>
where
  Q: Queue + ?Sized,
  F: FnOnce(QueueDelivery) -> Fut + Send,
  Fut: Future<Output = QueueResult<ConsumeAction>> + Send,
{
  let delivery = queue.receive().await?;
  // Only the receipt is needed after the handler runs. Clone just that instead of the
  // whole delivery, which would deep-copy the header map and JSON attributes.
  let receipt = delivery.receipt.clone();
  match handler(delivery).await? {
    ConsumeAction::Ack => queue.ack(receipt).await,
    ConsumeAction::Nack(action) => queue.nack(receipt, action).await,
  }
}

/// Sequential fallbacks used by the default batch methods of [`Queue`].
///
/// Each helper performs the single-item operation once per element, in order, and
/// returns the first error it hits; later elements are not processed. Backends with
/// native batch support can override the corresponding trait method instead.
pub mod default_batch {
  use super::*;

  /// Upper bound on the up-front reservation made by [`receive_many`]. The vector
  /// still grows past this if more deliveries are actually received.
  const MAX_PREALLOCATED_DELIVERIES: usize = 64;

  /// Publishes `items` one at a time via [`Queue::publish`].
  ///
  /// # Errors
  /// Returns the first publish error. Messages before it have already been published.
  pub async fn publish_many<Q: Queue + ?Sized>(queue: &Q, items: Vec<QueueMessage>) -> QueueResult<()> {
    for item in items {
      queue.publish(item).await?;
    }
    Ok(())
  }

  /// Calls [`Queue::receive`] exactly `max` times and collects the deliveries.
  /// Returns an empty vector when `max == 0`.
  ///
  /// # Errors
  /// Returns the first receive error. Deliveries already received by this call are
  /// dropped without being acked or nacked here.
  pub async fn receive_many<Q: Queue + ?Sized>(queue: &Q, max: usize) -> QueueResult<Vec<QueueDelivery>> {
    // `max` is caller-supplied. Reserving it verbatim would try to allocate, and can
    // panic or abort on, enormous buffers for very large values. Cap the initial
    // reservation and let the vector grow as deliveries actually arrive.
    let mut out = Vec::with_capacity(max.min(MAX_PREALLOCATED_DELIVERIES));
    for _ in 0..max {
      out.push(queue.receive().await?);
    }
    Ok(out)
  }

  /// Acknowledges `receipts` one at a time via [`Queue::ack`].
  ///
  /// # Errors
  /// Returns the first ack error. Receipts after it are not acknowledged.
  pub async fn ack_many<Q: Queue + ?Sized>(queue: &Q, receipts: Vec<QueueReceipt>) -> QueueResult<()> {
    for receipt in receipts {
      queue.ack(receipt).await?;
    }
    Ok(())
  }

  /// Negatively acknowledges `receipts` one at a time via [`Queue::nack`], using the
  /// same `action` for each.
  ///
  /// # Errors
  /// Returns the first nack error, including the default "unsupported" error when the
  /// backend does not override `nack`. Receipts after it are not processed.
  pub async fn nack_many<Q: Queue + ?Sized>(
    queue: &Q,
    receipts: Vec<QueueReceipt>,
    action: NackAction,
  ) -> QueueResult<()> {
    for receipt in receipts {
      queue.nack(receipt, action).await?;
    }
    Ok(())
  }
}

#[cfg(test)]
mod tests {
  use super::*;
  use crate::feature::QueueFeature;
  use std::sync::atomic::{AtomicUsize, Ordering};

  /// Test double that counts publishes and acks. `receive` always yields the same
  /// fixed delivery, and `nack` is left at the trait default (unsupported).
  #[derive(Debug, Default)]
  struct CountingQueue {
    publishes: AtomicUsize,
    acks: AtomicUsize,
  }

  #[async_trait]
  impl Queue for CountingQueue {
    fn name(&self) -> &'static str {
      "counting"
    }

    async fn publish(&self, _: QueueMessage) -> QueueResult<()> {
      self.publishes.fetch_add(1, Ordering::Relaxed);
      Ok(())
    }

    async fn receive(&self) -> QueueResult<QueueDelivery> {
      Ok(QueueDelivery::new(QueueMessage::new("topic", Bytes::from_static(b"p")), QueueReceipt::new("r")))
    }

    async fn ack(&self, _: QueueReceipt) -> QueueResult<()> {
      self.acks.fetch_add(1, Ordering::Relaxed);
      Ok(())
    }
  }

  #[tokio::test]
  async fn default_batch_publish_and_ack() {
    let q = CountingQueue::default();
    let messages = vec![
      QueueMessage::new("t", Bytes::from_static(b"a")),
      QueueMessage::new("t", Bytes::from_static(b"b")),
    ];
    default_batch::publish_many(&q, messages).await.unwrap();
    assert_eq!(q.publishes.load(Ordering::Relaxed), 2);

    // Go through the batch helper (not the single-item `ack`) so the test covers
    // what its name says.
    let receipts = vec![QueueReceipt::new("r1"), QueueReceipt::new("r2")];
    default_batch::ack_many(&q, receipts).await.unwrap();
    assert_eq!(q.acks.load(Ordering::Relaxed), 2);
  }

  #[tokio::test]
  async fn default_batch_receive_many_returns_exactly_max() {
    let q = CountingQueue::default();
    assert!(default_batch::receive_many(&q, 0).await.unwrap().is_empty());
    let deliveries = default_batch::receive_many(&q, 3).await.unwrap();
    assert_eq!(deliveries.len(), 3);
    assert!(deliveries.iter().all(|d| d.message.topic == "topic"));
  }

  #[tokio::test]
  async fn nack_defaults_to_unsupported() {
    let q = CountingQueue::default();
    assert!(q.nack(QueueReceipt::new("r"), NackAction::Requeue).await.is_err());
    assert!(default_batch::nack_many(&q, vec![QueueReceipt::new("r")], NackAction::Drop)
      .await
      .is_err());
    // An empty batch never reaches `nack`, so it succeeds.
    assert!(default_batch::nack_many(&q, Vec::new(), NackAction::Drop).await.is_ok());
  }

  #[test]
  fn nack_action_labels() {
    assert_eq!(NackAction::Requeue.as_str(), "nack_requeue");
    assert_eq!(NackAction::DeadLetter.as_str(), "nack_dead_letter");
    assert_eq!(NackAction::Drop.as_str(), "nack_drop");
  }

  #[test]
  fn features_default_is_empty() {
    assert!(QueueFeatures::default().is_empty());
    assert!(!QueueFeatures::new([QueueFeature::Headers]).is_empty());
  }

  #[test]
  fn message_partition_and_dlt_builders() {
    let dlt = DeadLetterTarget::new("dlt-topic").with_routing_key("rk");
    let msg = QueueMessage::new("topic", Bytes::from_static(b"p"))
      .with_partition(2)
      .with_dead_letter(dlt.clone());
    assert_eq!(msg.partition, Some(2));
    assert_eq!(msg.dead_letter.as_ref().map(|d| d.topic.as_str()), Some("dlt-topic"));
    assert_eq!(msg.dead_letter.as_ref().and_then(|d| d.routing_key.as_deref()), Some("rk"));
  }

  #[tokio::test]
  async fn consume_once_ack_path() {
    #[derive(Debug, Default)]
    struct TestQueue {
      acks: AtomicUsize,
    }

    #[async_trait]
    impl Queue for TestQueue {
      fn name(&self) -> &'static str {
        "test"
      }

      async fn publish(&self, _: QueueMessage) -> QueueResult<()> {
        Ok(())
      }

      async fn receive(&self) -> QueueResult<QueueDelivery> {
        Ok(QueueDelivery::new(
          QueueMessage::new("topic", Bytes::from_static(b"payload")),
          QueueReceipt::new("receipt"),
        ))
      }

      async fn ack(&self, _: QueueReceipt) -> QueueResult<()> {
        self.acks.fetch_add(1, Ordering::Relaxed);
        Ok(())
      }
    }

    let queue = TestQueue::default();
    consume_once(&queue, |_delivery| async { Ok(ConsumeAction::Ack) }).await.unwrap();
    assert_eq!(queue.acks.load(Ordering::Relaxed), 1);
  }

  #[tokio::test]
  async fn consume_once_nack_uses_default_unsupported() {
    let queue = CountingQueue::default();
    let result = consume_once(&queue, |_delivery| async {
      Ok(ConsumeAction::Nack(NackAction::Requeue))
    })
    .await;
    assert!(result.is_err());
    assert_eq!(queue.acks.load(Ordering::Relaxed), 0);
  }

  #[tokio::test]
  async fn consume_once_handler_error_skips_settlement() {
    let queue = CountingQueue::default();
    let result = consume_once(&queue, |_delivery| async {
      Err(QueueError::unsupported("test", "handler"))
    })
    .await;
    assert!(result.is_err());
    assert_eq!(queue.acks.load(Ordering::Relaxed), 0);
  }
}