use std::sync::Arc;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use shove::sns::*;
use shove::*;
use testcontainers::ImageExt;
use testcontainers::runners::AsyncRunner;
use testcontainers_modules::localstack::LocalStack;
use tokio::sync::Mutex;
/// Message payload carried on the `WorkQueue` topic.
///
/// Serialized when published and handed to `TaskHandler::handle` on the
/// consuming side.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct TaskEvent {
/// Task identifier; `main` generates these as `TASK-000`, `TASK-001`, ...
task_id: String,
/// Free-form description of the work item.
payload: String,
}
// Declares the `WorkQueue` topic type whose message type is `TaskEvent`.
// Its topology is built under the name "sqs-work-queue".
// NOTE(review): `.dlq()` presumably attaches a dead-letter queue to the
// topology — confirm against the shove `TopologyBuilder` docs.
define_topic!(
WorkQueue,
TaskEvent,
TopologyBuilder::new("sqs-work-queue").dlq().build()
);
/// Stateless handler for `WorkQueue` messages.
///
/// `main` passes a factory closure (`|| TaskHandler`) when registering the
/// consumer group.
#[derive(Clone)]
struct TaskHandler;
impl MessageHandler<WorkQueue> for TaskHandler {
    /// No shared context is needed by this handler.
    type Context = ();

    /// Handles a single `TaskEvent`.
    ///
    /// Prints the task id together with the 1-based attempt number
    /// (`retry_count + 1`), pauses for 200 ms (presumably a stand-in for real
    /// processing time), and then acknowledges the message.
    async fn handle(&self, msg: TaskEvent, metadata: MessageMetadata, _: &()) -> Outcome {
        // `retry_count` starts at zero, so shift it to a human-friendly attempt number.
        let attempt = metadata.retry_count + 1;
        println!("[worker] task={} attempt={}", msg.task_id, attempt);

        let processing_time = Duration::from_millis(200);
        tokio::time::sleep(processing_time).await;

        Outcome::Ack
    }
}
/// Runs the consumer-group demo end to end against a LocalStack container.
///
/// Steps, in order:
/// 1. initialise tracing and require `LOCALSTACK_AUTH_TOKEN`;
/// 2. start LocalStack and point an `SnsClient` at its mapped port;
/// 3. declare the `WorkQueue` topology and publish a burst of 50 tasks;
/// 4. register and start a consumer group (`1..=5` consumers) for `WorkQueue`;
/// 5. print queue statistics every 2 seconds for 15 rounds;
/// 6. shut down the consumer groups and the client, then drop the container.
///
/// # Errors
///
/// Returns the first error raised while starting the container, creating the
/// client, declaring the topology, publishing, or registering the group.
/// Failures while fetching queue statistics are only printed.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// Region used both for the AWS environment and the client configuration.
    const REGION: &str = "us-east-1";
    /// Container port that is mapped to the host and used as the endpoint.
    const LOCALSTACK_PORT: u16 = 4566;
    /// Number of queue-depth samples taken before shutting down.
    const MONITOR_ROUNDS: usize = 15;
    /// Pause before each queue-depth sample.
    const MONITOR_INTERVAL: Duration = Duration::from_secs(2);

    // --- logging -----------------------------------------------------------
    // Prefer the filter supplied through the environment; otherwise fall back
    // to the built-in directive string.
    let log_filter = tracing_subscriber::EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| "shove=debug,sqs_consumer_groups=debug".parse().unwrap());
    tracing_subscriber::fmt().with_env_filter(log_filter).init();

    // --- prerequisites -----------------------------------------------------
    let Ok(auth_token) = std::env::var("LOCALSTACK_AUTH_TOKEN") else {
        eprintln!(
            "LOCALSTACK_AUTH_TOKEN is not set. This example requires a LocalStack Pro auth \
             token:\n\n export LOCALSTACK_AUTH_TOKEN=...\n"
        );
        std::process::exit(1);
    };

    // Dummy credentials and region for the AWS SDK.
    let aws_env = [
        ("AWS_ACCESS_KEY_ID", "test"),
        ("AWS_SECRET_ACCESS_KEY", "test"),
        ("AWS_REGION", REGION),
    ];
    for (key, value) in aws_env {
        // SAFETY: this runs before the container, client, or any consumer is
        // created, and nothing has been spawned by this function yet.
        // NOTE(review): the tokio runtime's worker threads already exist at
        // this point; this relies on none of them reading the environment
        // concurrently — confirm.
        unsafe {
            std::env::set_var(key, value);
        }
    }

    // --- infrastructure ----------------------------------------------------
    let localstack = LocalStack::default()
        .with_env_var("LOCALSTACK_AUTH_TOKEN", auth_token)
        .start()
        .await?;
    let port = localstack.get_host_port_ipv4(LOCALSTACK_PORT).await?;

    let client_config = SnsConfig {
        region: Some(REGION.into()),
        endpoint_url: Some(format!("http://localhost:{port}")),
    };
    let client = SnsClient::new(&client_config).await?;

    let topology_declarer = SnsTopologyDeclarer::new(client.clone());
    topology_declarer.declare(WorkQueue::topology()).await?;

    // --- publish a backlog before any consumer is running ------------------
    let publisher = SnsPublisher::new(client.clone(), client.topic_registry().clone());
    let burst_size = 50;
    for i in 0..burst_size {
        publisher
            .publish::<WorkQueue>(&TaskEvent {
                task_id: format!("TASK-{i:03}"),
                payload: format!("work item {i}"),
            })
            .await?;
    }
    println!("published {burst_size} tasks\n");

    // --- consumers ---------------------------------------------------------
    let group_config = SqsConsumerGroupConfig::new(1..=5)
        .with_prefetch_count(10)
        .with_max_retries(3);
    let mut consumer_groups = SqsConsumerGroupRegistry::new(client.clone());
    consumer_groups
        .register::<WorkQueue, TaskHandler>(group_config, || TaskHandler, ())
        .await?;
    consumer_groups.start_all();
    println!("consumer group started (min_consumers=1)\n");
    let consumer_groups = Arc::new(Mutex::new(consumer_groups));

    // --- monitoring --------------------------------------------------------
    let stats_provider =
        SqsQueueStatsProvider::new(client.clone(), client.queue_registry().clone());
    println!("monitoring queue depth — watching backlog drain\n");
    for _round in 0..MONITOR_ROUNDS {
        tokio::time::sleep(MONITOR_INTERVAL).await;
        match stats_provider
            .get_queue_stats(WorkQueue::topology().queue())
            .await
        {
            // A failed sample is reported but does not abort the demo.
            Err(e) => eprintln!("[monitor] failed to fetch stats: {e}"),
            Ok(stats) => println!(
                "[monitor] messages_ready={} in_flight={}",
                stats.messages_ready, stats.messages_not_visible,
            ),
        }
    }

    // --- teardown ----------------------------------------------------------
    println!("\nshutting down...");
    consumer_groups.lock().await.shutdown_all().await;
    client.shutdown().await;
    println!("done");

    // Dropped explicitly, after the consumers and the client have been shut down.
    drop(localstack);
    Ok(())
}