diidi-travel-common-queue 0.1.16

A collection of common utilities and types for the DiiDi project.
Documentation
use std::sync::Arc;
use std::time::{Duration, Instant};

use async_trait::async_trait;

use crate::queue::{NackAction, Queue, QueueDelivery, QueueMessage, QueueReceipt};
use crate::QueueResult;

/// `tracing` target used by every event this decorator emits, and the value
/// stored as the default `target` by [`LoggingQueue::new`].
pub const LOG_TARGET: &str = "diidi::queue";

/// A [`Queue`] decorator that forwards every operation to an inner queue and
/// emits a `tracing` event describing the outcome and latency of each call.
#[derive(Clone)]
pub struct LoggingQueue {
  /// The wrapped queue that actually performs each operation.
  inner: Arc<dyn Queue>,
  /// Target name configured via `with_target` (defaults to `LOG_TARGET`).
  ///
  /// NOTE: the events emitted by this type always pass the constant
  /// `LOG_TARGET` to `tracing` (whose macros require a constant target), so
  /// this value is only surfaced through the `Debug` output.
  target: &'static str,
  /// Latency at or above which `check_slow` emits a `warn`-level
  /// "queue op slow" event; `None` disables the check.
  slow_threshold: Option<Duration>,
}

impl std::fmt::Debug for LoggingQueue {
  /// Renders the decorator's configuration, showing the wrapped queue by its
  /// provider name.
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    let Self { inner, target, slow_threshold } = self;
    let mut builder = f.debug_struct("LoggingQueue");
    builder.field("inner", &inner.name());
    builder.field("target", target);
    builder.field("slow_threshold", slow_threshold);
    builder.finish()
  }
}

impl LoggingQueue {
  /// Wraps `inner`. The default target is [`LOG_TARGET`] and slow-operation
  /// warnings are disabled.
  #[must_use]
  pub fn new(inner: Arc<dyn Queue>) -> Self {
    Self { inner, target: LOG_TARGET, slow_threshold: None }
  }

  /// Records a custom target name for this decorator.
  ///
  /// Note: `tracing` macros require a constant target, so the events emitted
  /// by this type are always logged under [`LOG_TARGET`]. The value set here
  /// is only reported by the `Debug` impl.
  #[must_use]
  pub fn with_target(mut self, target: &'static str) -> Self {
    self.target = target;
    self
  }

  /// Enables slow-operation reporting. A checked operation whose latency is
  /// `>= threshold` produces a `warn`-level "queue op slow" event.
  #[must_use]
  pub fn with_slow_threshold(mut self, threshold: Duration) -> Self {
    self.slow_threshold = Some(threshold);
    self
  }

  /// Name of the wrapped queue provider, used as the `provider` log field.
  fn provider(&self) -> &'static str {
    self.inner.name()
  }

  /// Emits a `warn` event when `elapsed` meets or exceeds the configured
  /// slow threshold. Does nothing when no threshold is set.
  fn check_slow(&self, op: &'static str, topic: &str, elapsed: Duration) {
    let Some(threshold) = self.slow_threshold else {
      return;
    };
    if elapsed < threshold {
      return;
    }
    // Saturate rather than silently truncate the u128 microsecond count.
    let elapsed_us = u64::try_from(elapsed.as_micros()).unwrap_or(u64::MAX);
    tracing::warn!(
      target: LOG_TARGET,
      provider = self.provider(),
      op,
      topic,
      elapsed_us,
      "queue op slow",
    );
  }
}

#[async_trait]
impl Queue for LoggingQueue {
  fn name(&self) -> &'static str {
    self.inner.name()
  }

  fn features(&self) -> crate::feature::QueueFeatures {
    self.inner.features()
  }

  async fn publish(&self, message: QueueMessage) -> QueueResult<()> {
    let topic = message.topic.clone();
    let bytes = message.payload.len();
    let start = Instant::now();
    let res = self.inner.publish(message).await;
    let elapsed = start.elapsed();
    self.check_slow("publish", &topic, elapsed);
    let provider = self.provider();
    match &res {
      Ok(()) => tracing::debug!(
        target: LOG_TARGET,
        provider,
        op = "publish",
        topic,
        bytes,
        elapsed_us = elapsed.as_micros() as u64,
        "queue publish",
      ),
      Err(e) => tracing::warn!(
        target: LOG_TARGET,
        provider,
        op = "publish",
        topic,
        elapsed_us = elapsed.as_micros() as u64,
        error = %e,
        "queue error",
      ),
    }
    res
  }

  async fn receive(&self) -> QueueResult<QueueDelivery> {
    let start = Instant::now();
    let res = self.inner.receive().await;
    let elapsed = start.elapsed();
    let provider = self.provider();
    match &res {
      Ok(delivery) => tracing::debug!(
        target: LOG_TARGET,
        provider,
        op = "receive",
        topic = delivery.message.topic,
        bytes = delivery.message.payload.len(),
        elapsed_us = elapsed.as_micros() as u64,
        "queue receive",
      ),
      Err(e) => tracing::warn!(
        target: LOG_TARGET,
        provider,
        op = "receive",
        elapsed_us = elapsed.as_micros() as u64,
        error = %e,
        "queue error",
      ),
    }
    res
  }

  async fn ack(&self, receipt: QueueReceipt) -> QueueResult<()> {
    let start = Instant::now();
    let res = self.inner.ack(receipt).await;
    let elapsed = start.elapsed();
    let provider = self.provider();
    match &res {
      Ok(()) => tracing::debug!(
        target: LOG_TARGET,
        provider,
        op = "ack",
        elapsed_us = elapsed.as_micros() as u64,
        "queue ack",
      ),
      Err(e) => tracing::warn!(
        target: LOG_TARGET,
        provider,
        op = "ack",
        elapsed_us = elapsed.as_micros() as u64,
        error = %e,
        "queue error",
      ),
    }
    res
  }

  async fn nack(&self, receipt: QueueReceipt, action: NackAction) -> QueueResult<()> {
    let start = Instant::now();
    let res = self.inner.nack(receipt, action).await;
    let elapsed = start.elapsed();
    let provider = self.provider();
    match &res {
      Ok(()) => tracing::debug!(
        target: LOG_TARGET,
        provider,
        op = "nack",
        action = action.as_str(),
        elapsed_us = elapsed.as_micros() as u64,
        "queue nack",
      ),
      Err(e) => tracing::warn!(
        target: LOG_TARGET,
        provider,
        op = "nack",
        action = action.as_str(),
        elapsed_us = elapsed.as_micros() as u64,
        error = %e,
        "queue error",
      ),
    }
    res
  }

  async fn ping(&self) -> QueueResult<()> {
    let start = Instant::now();
    let res = self.inner.ping().await;
    let elapsed = start.elapsed();
    let provider = self.provider();
    match &res {
      Ok(()) => tracing::trace!(
        target: LOG_TARGET,
        provider,
        op = "ping",
        elapsed_us = elapsed.as_micros() as u64,
        "queue ping",
      ),
      Err(e) => tracing::warn!(
        target: LOG_TARGET,
        provider,
        op = "ping",
        error = %e,
        "queue ping failed",
      ),
    }
    res
  }
}

/// Decorates `inner` with a [`LoggingQueue`] using the default target and
/// no slow threshold, and returns it as a shared trait object.
pub fn with_logging(inner: Arc<dyn Queue>) -> Arc<dyn Queue> {
  let decorated = LoggingQueue::new(inner);
  Arc::new(decorated)
}

#[cfg(test)]
mod tests {
  use super::*;
  use async_trait::async_trait;
  use bytes::Bytes;
  use std::sync::atomic::{AtomicUsize, Ordering};

  /// Minimal in-memory queue that counts the calls forwarded to it.
  #[derive(Debug, Default)]
  struct CountingQueue {
    publishes: AtomicUsize,
    receives: AtomicUsize,
    acks: AtomicUsize,
  }

  #[async_trait]
  impl Queue for CountingQueue {
    fn name(&self) -> &'static str {
      "counting"
    }

    async fn publish(&self, _: QueueMessage) -> QueueResult<()> {
      self.publishes.fetch_add(1, Ordering::Relaxed);
      Ok(())
    }

    async fn receive(&self) -> QueueResult<QueueDelivery> {
      self.receives.fetch_add(1, Ordering::Relaxed);
      Ok(QueueDelivery::new(QueueMessage::new("topic", Bytes::from_static(b"x")), QueueReceipt::new("r")))
    }

    async fn ack(&self, _: QueueReceipt) -> QueueResult<()> {
      self.acks.fetch_add(1, Ordering::Relaxed);
      Ok(())
    }
  }

  #[tokio::test]
  async fn decorator_forwards_to_inner() {
    let inner = Arc::new(CountingQueue::default());
    let queue = with_logging(inner.clone());
    queue.publish(QueueMessage::new("topic", Bytes::from_static(b"v"))).await.unwrap();
    assert_eq!(inner.publishes.load(Ordering::Relaxed), 1);
  }

  #[tokio::test]
  async fn decorator_forwards_receive_and_ack() {
    let inner = Arc::new(CountingQueue::default());
    let queue = with_logging(inner.clone());
    let delivery = queue.receive().await.unwrap();
    assert_eq!(delivery.message.payload.len(), 1);
    queue.ack(QueueReceipt::new("r")).await.unwrap();
    assert_eq!(inner.receives.load(Ordering::Relaxed), 1);
    assert_eq!(inner.acks.load(Ordering::Relaxed), 1);
  }

  #[tokio::test]
  async fn slow_threshold_path_still_forwards() {
    let inner = Arc::new(CountingQueue::default());
    // A zero threshold forces the "queue op slow" branch on every publish.
    let queue = LoggingQueue::new(inner.clone()).with_slow_threshold(Duration::ZERO);
    queue.publish(QueueMessage::new("topic", Bytes::from_static(b"v"))).await.unwrap();
    assert_eq!(inner.publishes.load(Ordering::Relaxed), 1);
  }

  #[test]
  fn debug_reports_provider_and_target() {
    let queue = LoggingQueue::new(Arc::new(CountingQueue::default()))
      .with_target("custom")
      .with_slow_threshold(Duration::from_millis(5));
    let rendered = format!("{queue:?}");
    assert!(rendered.contains("\"counting\""));
    assert!(rendered.contains("\"custom\""));
  }
}