diidi-travel-common-queue 0.1.16

A collection of common utilities and types for the DiiDi project.
Documentation
use std::collections::HashMap;
use std::sync::Arc;

use crate::config::QueueConfig;
use crate::decorator::with_logging;
use crate::error::{QueueError, QueueResult};
use crate::provider::QueueProvider;
use crate::queue::Queue;

/// Creates [`Queue`] instances from a [`QueueConfig`] by dispatching to a provider chosen by name.
///
/// Construct one through [`QueueFactory::builder`]. Cloning copies the name-to-provider map, while
/// the provider objects themselves are shared through `Arc`.
#[derive(Default, Clone)]
pub struct QueueFactory {
  /// Providers indexed by the name each one reported (via `name()`) when it was registered.
  providers: HashMap<&'static str, Arc<dyn QueueProvider>>,
}

impl QueueFactory {
  /// Returns an empty [`QueueFactoryBuilder`]; register providers on it, then call `build`.
  pub fn builder() -> QueueFactoryBuilder {
    Default::default()
  }

  /// Instantiates a queue through the provider registered under `config.provider`.
  ///
  /// The matching provider's `build` is awaited with `config.params`, and its result is returned
  /// as-is.
  ///
  /// # Errors
  ///
  /// * [`QueueError::ProviderNotRegistered`] carrying a copy of `config.provider` when no provider
  ///   with that name was registered.
  /// * Any error produced by the provider's `build`.
  pub async fn create(&self, config: &QueueConfig) -> QueueResult<Arc<dyn Queue>> {
    match self.providers.get(config.provider.as_str()) {
      Some(provider) => provider.build(&config.params).await,
      None => Err(QueueError::ProviderNotRegistered(config.provider.clone())),
    }
  }

  /// Like [`QueueFactory::create`], but passes the freshly built queue through `with_logging`.
  ///
  /// # Errors
  ///
  /// Propagates any error from [`QueueFactory::create`]; nothing is wrapped in that case.
  pub async fn create_with_logging(&self, config: &QueueConfig) -> QueueResult<Arc<dyn Queue>> {
    Ok(with_logging(self.create(config).await?))
  }

  /// Names of every registered provider, sorted in ascending (lexicographic) order.
  pub fn registered(&self) -> Vec<&'static str> {
    let mut sorted = Vec::with_capacity(self.providers.len());
    sorted.extend(self.providers.keys().copied());
    sorted.sort_unstable();
    sorted
  }

  /// Reports whether a provider is registered under `name`.
  pub fn has(&self, name: &str) -> bool {
    self.providers.contains_key(name)
  }
}

impl std::fmt::Debug for QueueFactory {
  /// Renders as `QueueFactory { registered: [...] }`, listing only the sorted provider names
  /// rather than the provider objects.
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    let names = self.registered();
    let mut view = f.debug_struct("QueueFactory");
    view.field("registered", &names);
    view.finish()
  }
}

/// Accumulates queue providers and turns them into a [`QueueFactory`].
///
/// Obtain one from [`QueueFactory::builder`] (or `Default`), chain registration calls, then call
/// `build`.
#[derive(Default)]
pub struct QueueFactoryBuilder {
  /// Providers indexed by the name each one reported (via `name()`) when it was registered.
  providers: HashMap<&'static str, Arc<dyn QueueProvider>>,
}

impl std::fmt::Debug for QueueFactoryBuilder {
  /// Renders as `QueueFactoryBuilder { registered: [...] }` with the provider names sorted, in the
  /// same shape as `QueueFactory`'s `Debug` output; the provider objects are not printed.
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    let mut names: Vec<&'static str> = self.providers.keys().copied().collect();
    // HashMap iteration order is unspecified; sort so the output is deterministic.
    names.sort_unstable();
    f.debug_struct("QueueFactoryBuilder").field("registered", &names).finish()
  }
}

impl QueueFactoryBuilder {
  /// Registers `provider` under the name returned by its `name()` method.
  ///
  /// If a provider with the same name was registered earlier, it is replaced: the last
  /// registration for a given name wins.
  #[must_use = "the builder is consumed; dropping the returned builder discards the registration"]
  pub fn register<P: QueueProvider>(self, provider: P) -> Self {
    // Share one insertion path with `register_arc` so both entry points behave identically.
    self.register_arc(Arc::new(provider))
  }

  /// Registers an already shared provider under the name returned by its `name()` method.
  ///
  /// If a provider with the same name was registered earlier, it is replaced: the last
  /// registration for a given name wins.
  #[must_use = "the builder is consumed; dropping the returned builder discards the registration"]
  pub fn register_arc(mut self, provider: Arc<dyn QueueProvider>) -> Self {
    self.providers.insert(provider.name(), provider);
    self
  }

  /// Finishes building and returns a [`QueueFactory`] holding every registered provider.
  #[must_use]
  pub fn build(self) -> QueueFactory {
    QueueFactory { providers: self.providers }
  }
}

#[cfg(test)]
mod tests {
  use super::*;
  use async_trait::async_trait;
  use bytes::Bytes;

  /// Queue double whose operations all succeed immediately.
  #[derive(Debug, Default)]
  struct NoopQueue;

  #[async_trait]
  impl Queue for NoopQueue {
    fn name(&self) -> &'static str {
      "noop"
    }

    async fn publish(&self, _: crate::queue::QueueMessage) -> QueueResult<()> {
      Ok(())
    }

    async fn receive(&self) -> QueueResult<crate::queue::QueueDelivery> {
      Ok(crate::queue::QueueDelivery::new(
        crate::queue::QueueMessage::new("topic", Bytes::from_static(b"x")),
        crate::queue::QueueReceipt::new("r"),
      ))
    }

    async fn ack(&self, _: crate::queue::QueueReceipt) -> QueueResult<()> {
      Ok(())
    }
  }

  /// Provider registered as `"noop"` that always builds a [`NoopQueue`].
  #[derive(Default)]
  struct NoopProvider;

  #[async_trait]
  impl QueueProvider for NoopProvider {
    fn name(&self) -> &'static str {
      "noop"
    }

    async fn build(&self, _: &serde_json::Value) -> QueueResult<Arc<dyn Queue>> {
      Ok(Arc::new(NoopQueue))
    }
  }

  /// Provider with a caller-chosen name, used to register several distinct providers.
  struct NamedProvider(&'static str);

  #[async_trait]
  impl QueueProvider for NamedProvider {
    fn name(&self) -> &'static str {
      self.0
    }

    async fn build(&self, _: &serde_json::Value) -> QueueResult<Arc<dyn Queue>> {
      Ok(Arc::new(NoopQueue))
    }
  }

  #[tokio::test]
  async fn unknown_provider_errors_at_create() {
    let factory = QueueFactory::builder().register(NoopProvider).build();
    let cfg = QueueConfig::new("ghost");
    let err = factory.create(&cfg).await.unwrap_err();
    assert!(matches!(err, QueueError::ProviderNotRegistered(p) if p == "ghost"));
  }

  #[tokio::test]
  async fn registered_provider_resolves() {
    let factory = QueueFactory::builder().register(NoopProvider).build();
    let queue = factory.create(&QueueConfig::new("noop")).await.unwrap();
    assert_eq!(queue.name(), "noop");
  }

  #[test]
  fn registered_is_sorted_and_has_matches() {
    let factory = QueueFactory::builder()
      .register(NamedProvider("zeta"))
      .register(NoopProvider)
      .register(NamedProvider("alpha"))
      .build();
    assert_eq!(factory.registered(), vec!["alpha", "noop", "zeta"]);
    assert!(factory.has("noop"));
    assert!(factory.has("alpha"));
    assert!(!factory.has("ghost"));
  }

  #[test]
  fn duplicate_names_collapse_to_one_entry() {
    // Registrations are keyed by name, so a second provider named "noop" replaces the first.
    let factory = QueueFactory::builder()
      .register(NoopProvider)
      .register(NamedProvider("noop"))
      .build();
    assert_eq!(factory.registered(), vec!["noop"]);
  }

  #[test]
  fn debug_shows_registered_names() {
    let factory = QueueFactory::builder().register(NoopProvider).build();
    assert_eq!(format!("{:?}", factory), r#"QueueFactory { registered: ["noop"] }"#);
    assert_eq!(format!("{:?}", QueueFactory::default()), "QueueFactory { registered: [] }");
  }
}