use std::fmt::{Debug, Formatter};
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
#[cfg(feature = "aws-sns-sqs")]
use crate::backends::sns::topology::QueueRegistry;
use crate::backends::sns::topology::TopicRegistry;
use crate::error::{Result, ShoveError};
/// Connection settings for [`SnsClient`].
///
/// Every field is optional. A `None` value means `SnsClient::new` applies no
/// override for that setting and leaves it to the ambient AWS configuration.
/// `Default` yields the all-`None` config, so callers can write
/// `SnsConfig { region: Some(..), ..Default::default() }`.
#[derive(Clone, Default)]
pub struct SnsConfig {
    /// AWS region override, e.g. `"us-east-1"`.
    pub region: Option<String>,
    /// Endpoint URL override, e.g. a local emulator at `"http://localhost:4566"`.
    pub endpoint_url: Option<String>,
}
impl Debug for SnsConfig {
    /// Renders `SnsConfig { region: .., endpoint_url: .. }`, the same layout
    /// `#[derive(Debug)]` would produce.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `SnsConfig` becomes a
        // compile error here, forcing a decision on whether it is printed.
        let Self {
            region,
            endpoint_url,
        } = self;
        let mut out = f.debug_struct("SnsConfig");
        out.field("region", region);
        out.field("endpoint_url", endpoint_url);
        out.finish()
    }
}
/// SNS client handle plus shared state used by the SNS backend.
///
/// Holds the AWS SNS SDK client and, when the `aws-sns-sqs` feature is
/// enabled, an SQS SDK client built from the same AWS configuration. It also
/// holds topology registries and a shutdown token.
///
/// Cloning shares state: the registries are behind `Arc`, and a cloned
/// `CancellationToken` refers to the same underlying token. Cancelling via
/// one clone is therefore observed by all clones.
#[derive(Clone)]
pub struct SnsClient {
    // SNS SDK client; exposed crate-internally through `inner()`.
    sns_client: aws_sdk_sns::Client,
    // SQS SDK client built from the same loaded AWS config as `sns_client`;
    // exposed crate-internally through `sqs()`.
    #[cfg(feature = "aws-sns-sqs")]
    sqs_client: aws_sdk_sqs::Client,
    // Topic registry shared by all clones of this client.
    topic_registry: Arc<TopicRegistry>,
    // Queue registry shared by all clones of this client.
    #[cfg(feature = "aws-sns-sqs")]
    queue_registry: Arc<QueueRegistry>,
    // Cancelled by `shutdown()`; `ping()` refuses to run once it is cancelled.
    shutdown_token: CancellationToken,
}
impl SnsClient {
    /// Creates a client from the ambient AWS configuration, applying the
    /// overrides in `config`.
    ///
    /// Settings that are `None` in `config` are left to the `aws-config`
    /// default loader. The SNS client and, with the `aws-sns-sqs` feature,
    /// the SQS client are built from the same loaded SDK config, so they
    /// share region, endpoint and credentials.
    ///
    /// # Errors
    ///
    /// Nothing in the current implementation fails; this always returns `Ok`.
    pub async fn new(config: &SnsConfig) -> Result<Self> {
        // `aws_config::defaults(BehaviorVersion::latest())` is the
        // non-deprecated equivalent of `aws_config::from_env()`. It pins the
        // same behavior version that `mock()` uses.
        let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
        if let Some(region) = &config.region {
            loader = loader.region(aws_config::Region::new(region.clone()));
        }
        if let Some(endpoint) = &config.endpoint_url {
            loader = loader.endpoint_url(endpoint);
        }
        let sdk_config = loader.load().await;

        let sns_client = aws_sdk_sns::Client::new(&sdk_config);
        #[cfg(feature = "aws-sns-sqs")]
        let sqs_client = aws_sdk_sqs::Client::new(&sdk_config);

        Ok(Self {
            sns_client,
            #[cfg(feature = "aws-sns-sqs")]
            sqs_client,
            topic_registry: Arc::new(TopicRegistry::new()),
            #[cfg(feature = "aws-sns-sqs")]
            queue_registry: Arc::new(QueueRegistry::new()),
            shutdown_token: CancellationToken::new(),
        })
    }

    /// Builds a client for unit tests without reading the environment.
    ///
    /// Both SDK clients use the latest behavior version and a fixed
    /// `us-east-1` region. No credentials provider or endpoint override is
    /// configured.
    #[cfg(test)]
    pub(crate) fn mock() -> Self {
        let behavior_version = aws_config::BehaviorVersion::latest();
        // Fully qualified SDK config paths are intentional here; presumably
        // the crate enables `clippy::absolute_paths`.
        #[allow(clippy::absolute_paths)]
        let sns_conf = aws_sdk_sns::config::Config::builder()
            .behavior_version(behavior_version)
            .region(aws_config::Region::new("us-east-1"))
            .build();
        let sns_client = aws_sdk_sns::Client::from_conf(sns_conf);
        #[cfg(feature = "aws-sns-sqs")]
        let sqs_client = {
            #[allow(clippy::absolute_paths)]
            let sqs_conf = aws_sdk_sqs::config::Config::builder()
                .behavior_version(behavior_version)
                .region(aws_config::Region::new("us-east-1"))
                .build();
            aws_sdk_sqs::Client::from_conf(sqs_conf)
        };
        Self {
            sns_client,
            #[cfg(feature = "aws-sns-sqs")]
            sqs_client,
            topic_registry: Arc::new(TopicRegistry::new()),
            #[cfg(feature = "aws-sns-sqs")]
            queue_registry: Arc::new(QueueRegistry::new()),
            shutdown_token: CancellationToken::new(),
        }
    }

    /// Borrows the underlying AWS SNS SDK client.
    pub(crate) fn inner(&self) -> &aws_sdk_sns::Client {
        &self.sns_client
    }

    /// Borrows the underlying AWS SQS SDK client.
    #[cfg(feature = "aws-sns-sqs")]
    pub(crate) fn sqs(&self) -> &aws_sdk_sqs::Client {
        &self.sqs_client
    }

    /// Returns the topic registry shared by all clones of this client.
    pub fn topic_registry(&self) -> &Arc<TopicRegistry> {
        &self.topic_registry
    }

    /// Returns the queue registry shared by all clones of this client.
    #[cfg(feature = "aws-sns-sqs")]
    pub fn queue_registry(&self) -> &Arc<QueueRegistry> {
        &self.queue_registry
    }

    /// Returns a handle to this client's shutdown token.
    ///
    /// The returned token shares cancellation state with the client. It
    /// becomes cancelled when `shutdown()` is called on this client or on
    /// any of its clones.
    pub fn shutdown_token(&self) -> CancellationToken {
        self.shutdown_token.clone()
    }

    /// Checks SQS reachability by sending a single `ListQueues` request
    /// (limited to one result), bounded by `timeout`.
    ///
    /// # Errors
    ///
    /// Returns `ShoveError::Connection` in three cases:
    /// - the client has already been shut down;
    /// - the request does not complete within `timeout`;
    /// - the request itself fails.
    #[cfg(feature = "aws-sns-sqs")]
    pub(super) async fn ping(&self, timeout: std::time::Duration) -> Result<()> {
        if self.shutdown_token.is_cancelled() {
            return Err(ShoveError::Connection("client is shut down".into()));
        }
        let fut = self.sqs_client.list_queues().max_results(1).send();
        tokio::time::timeout(timeout, fut)
            .await
            .map_err(|_| ShoveError::Connection(format!("sqs ping timed out after {timeout:?}")))?
            .map_err(|e| ShoveError::Connection(format!("sqs ping failed: {e}")))?;
        Ok(())
    }

    /// Cancels the shared shutdown token, then waits `crate::SHUTDOWN_GRACE`.
    ///
    /// The wait presumably gives tasks observing the token time to wind down
    /// before the caller proceeds.
    pub async fn shutdown(&self) {
        self.shutdown_token.cancel();
        tokio::time::sleep(crate::SHUTDOWN_GRACE).await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Both populated fields appear, labelled, in the Debug output.
    #[test]
    fn config_debug_shows_fields() {
        let config = SnsConfig {
            region: Some("us-east-1".into()),
            endpoint_url: Some("http://localhost:4566".into()),
        };
        assert_eq!(
            format!("{config:?}"),
            r#"SnsConfig { region: Some("us-east-1"), endpoint_url: Some("http://localhost:4566") }"#
        );
    }

    /// Each `None` field is printed under its own label. The previous check
    /// (`contains("None")`) would also pass if one field were dropped.
    #[test]
    fn config_debug_none_fields() {
        let config = SnsConfig {
            region: None,
            endpoint_url: None,
        };
        assert_eq!(
            format!("{config:?}"),
            "SnsConfig { region: None, endpoint_url: None }"
        );
    }

    /// Cloning copies every field.
    #[test]
    fn config_clone() {
        let config = SnsConfig {
            region: Some("eu-west-1".into()),
            endpoint_url: None,
        };
        let cloned = config.clone();
        assert_eq!(cloned.region, config.region);
        assert_eq!(cloned.endpoint_url, config.endpoint_url);
    }
}