1use crate::common::{AdapterError, Message, MessageHandler, Result};
2use aws_sdk_sqs::{
3 types::{MessageAttributeValue, QueueAttributeName},
4 Client as SqsClient,
5};
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use tracing::{debug, error, info, warn};
11
/// Settings for [`SqsAdapter`].
///
/// The optional queue attributes are only sent to SQS when they are `Some`, and only when
/// the adapter itself creates a queue; a queue that already exists keeps its own settings.
#[derive(Clone, Debug)]
pub struct SqsConfig {
    /// AWS region the SQS client is built for.
    pub region: String,
    /// Prepended to every topic name to form its queue name (`None` means no prefix).
    pub queue_prefix: Option<String>,
    /// `VisibilityTimeout` attribute, in seconds, for queues the adapter creates.
    pub visibility_timeout_seconds: Option<i32>,
    /// `MessageRetentionPeriod` attribute, in seconds, for queues the adapter creates.
    pub message_retention_seconds: Option<i32>,
    /// `ReceiveMessageWaitTimeSeconds` attribute for queues the adapter creates.
    pub receive_wait_time_seconds: Option<i32>,
}

impl Default for SqsConfig {
    /// `us-east-1`, queue prefix `rohas-`, 30 s visibility timeout, 4-day retention and a
    /// 20 s receive wait time.
    fn default() -> Self {
        const FOUR_DAYS_IN_SECONDS: i32 = 4 * 24 * 60 * 60;

        SqsConfig {
            region: String::from("us-east-1"),
            queue_prefix: Some(String::from("rohas-")),
            visibility_timeout_seconds: Some(30),
            message_retention_seconds: Some(FOUR_DAYS_IN_SECONDS),
            receive_wait_time_seconds: Some(20),
        }
    }
}
32
33pub struct SqsAdapter {
34 client: SqsClient,
35 config: SqsConfig,
36 queue_urls: Arc<RwLock<HashMap<String, String>>>, }
38
39impl SqsAdapter {
40 pub async fn new(config: SqsConfig) -> Result<Self> {
41 let aws_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
42 .region(aws_sdk_sqs::config::Region::new(config.region.clone()))
43 .load()
44 .await;
45
46 let client = SqsClient::new(&aws_config);
47
48 info!(
49 "Initialized SQS adapter for region: {}",
50 config.region
51 );
52
53 Ok(Self {
54 client,
55 config,
56 queue_urls: Arc::new(RwLock::new(HashMap::new())),
57 })
58 }
59
60 async fn get_or_create_queue(&self, topic: &str) -> Result<String> {
61 {
62 let queue_urls = self.queue_urls.read().await;
63 if let Some(url) = queue_urls.get(topic) {
64 return Ok(url.clone());
65 }
66 }
67
68 let queue_name = if let Some(prefix) = &self.config.queue_prefix {
69 format!("{}{}", prefix, topic)
70 } else {
71 topic.to_string()
72 };
73
74 let queue_name = queue_name
75 .chars()
76 .map(|c| {
77 if c.is_alphanumeric() || c == '-' || c == '_' {
78 c
79 } else {
80 '-'
81 }
82 })
83 .collect::<String>();
84
85 let get_queue_result = self
86 .client
87 .get_queue_url()
88 .queue_name(&queue_name)
89 .send()
90 .await;
91
92 let queue_url = match get_queue_result {
93 Ok(response) => {
94 if let Some(url) = response.queue_url() {
95 info!("Found existing queue for topic '{}': {}", topic, url);
96 url.to_string()
97 } else {
98 return Err(AdapterError::QueueNotFound(queue_name));
99 }
100 }
101 Err(_) => {
102 debug!("Queue '{}' not found, creating...", queue_name);
103
104 let mut create_request = self.client.create_queue().queue_name(&queue_name);
105
106 let mut attributes = HashMap::new();
107 if let Some(visibility) = self.config.visibility_timeout_seconds {
108 attributes.insert(
109 QueueAttributeName::VisibilityTimeout,
110 visibility.to_string(),
111 );
112 }
113 if let Some(retention) = self.config.message_retention_seconds {
114 attributes.insert(
115 QueueAttributeName::MessageRetentionPeriod,
116 retention.to_string(),
117 );
118 }
119 if let Some(wait_time) = self.config.receive_wait_time_seconds {
120 attributes.insert(
121 QueueAttributeName::ReceiveMessageWaitTimeSeconds,
122 wait_time.to_string(),
123 );
124 }
125
126 if !attributes.is_empty() {
127 create_request = create_request.set_attributes(Some(attributes));
128 }
129
130 let create_result = create_request.send().await.map_err(|e| {
131 AdapterError::AwsSqs(format!("Failed to create queue '{}': {}", queue_name, e))
132 })?;
133
134 if let Some(url) = create_result.queue_url() {
135 info!("Created queue for topic '{}': {}", topic, url);
136 url.to_string()
137 } else {
138 return Err(AdapterError::AwsSqs(format!(
139 "Queue created but no URL returned for '{}'",
140 queue_name
141 )));
142 }
143 }
144 };
145
146 {
147 let mut queue_urls = self.queue_urls.write().await;
148 queue_urls.insert(topic.to_string(), queue_url.clone());
149 }
150
151 Ok(queue_url)
152 }
153
154 pub async fn publish(
155 &self,
156 topic: impl Into<String>,
157 payload: serde_json::Value,
158 ) -> Result<()> {
159 let topic = topic.into();
160 tracing::info!("SqsAdapter::publish: Starting publish for topic: {}", topic);
161
162 let message = Message::new(topic.clone(), payload);
163
164 let message_body = serde_json::to_string(&message)
165 .map_err(|e| {
166 tracing::error!("SqsAdapter::publish: Serialization error for topic {}: {}", topic, e);
167 AdapterError::Serialization(e)
168 })?;
169
170 tracing::debug!("SqsAdapter::publish: Message serialized, getting/creating queue for topic: {}", topic);
171
172 let queue_url = self.get_or_create_queue(&topic).await.map_err(|e| {
173 tracing::error!("SqsAdapter::publish: Failed to get/create queue for topic {}: {}", topic, e);
174 e
175 })?;
176
177 tracing::info!("SqsAdapter::publish: Queue URL obtained: {} for topic: {}", queue_url, topic);
178
179 let mut attributes = HashMap::new();
180 attributes.insert(
181 "topic".to_string(),
182 MessageAttributeValue::builder()
183 .data_type("String")
184 .string_value(&topic)
185 .build()
186 .map_err(|e| AdapterError::AwsSqs(format!("Failed to build attribute: {}", e)))?,
187 );
188 attributes.insert(
189 "timestamp".to_string(),
190 MessageAttributeValue::builder()
191 .data_type("String")
192 .string_value(&message.timestamp)
193 .build()
194 .map_err(|e| AdapterError::AwsSqs(format!("Failed to build attribute: {}", e)))?,
195 );
196
197 let send_result = self
198 .client
199 .send_message()
200 .queue_url(&queue_url)
201 .message_body(&message_body)
202 .set_message_attributes(Some(attributes))
203 .send()
204 .await;
205
206 match send_result {
207 Ok(response) => {
208 if let Some(message_id) = response.message_id() {
209 info!("Published message to SQS topic: {} (queue: {}, message_id: {})", topic, queue_url, message_id);
210 } else {
211 info!("Published message to SQS topic: {} (queue: {})", topic, queue_url);
212 }
213 Ok(())
214 }
215 Err(e) => {
216 error!("Failed to send message to SQS queue '{}' for topic '{}': {}", queue_url, topic, e);
217 let error_msg = format!(
218 "Failed to send message to queue '{}': {}",
219 queue_url, e
220 );
221 tracing::error!("SqsAdapter::publish: Error details - {}", error_msg);
222 Err(AdapterError::AwsSqs(error_msg))
223 }
224 }
225 }
226
227 pub async fn subscribe<H>(&self, topic: impl Into<String>, handler: Arc<H>) -> Result<()>
228 where
229 H: MessageHandler + 'static,
230 {
231 let topic = topic.into();
232 let queue_url = self.get_or_create_queue(&topic).await?;
233
234 info!("Subscribing to topic: {} (queue: {})", topic, queue_url);
235
236 let client = self.client.clone();
237 let handler = handler.clone();
238 let topic_clone = topic.clone();
239
240 tokio::spawn(async move {
241 info!("SQS subscription polling loop started for topic '{}' (queue: {})", topic_clone, queue_url);
242 let mut poll_count = 0u64;
243 loop {
244 poll_count += 1;
245 if poll_count % 5 == 0 {
246 info!("SQS polling loop still active for topic '{}' (poll #{}), queue: {}", topic_clone, poll_count, queue_url);
247 } else if poll_count <= 3 {
248 info!("SQS polling loop active for topic '{}' (poll #{}), queue: {}", topic_clone, poll_count, queue_url);
249 } else {
250 debug!("Polling SQS queue for topic '{}' (poll #{})...", topic_clone, poll_count);
251 }
252 let receive_result = client
253 .receive_message()
254 .queue_url(&queue_url)
255 .max_number_of_messages(10)
256 .wait_time_seconds(20)
257 .send()
258 .await;
259
260 match receive_result {
261 Ok(response) => {
262 let messages = response.messages();
263 if !messages.is_empty() {
264 info!("Received {} message(s) from SQS queue for topic '{}'", messages.len(), topic_clone);
265 for sqs_message in messages {
266 if let Some(body) = sqs_message.body() {
267 info!("Raw SQS message body for topic '{}': {}", topic_clone, body);
268 match serde_json::from_str::<Message>(body) {
269 Ok(message) => {
270 info!("Successfully parsed SQS message for topic '{}'", topic_clone);
271 info!("Message topic: {}, payload: {:?}", message.topic, message.payload);
272 info!("Calling handler for SQS message...");
273 if let Err(e) = handler.handle(message).await {
274 error!("Handler error for SQS topic '{}': {}", topic_clone, e);
275 } else {
276 info!("Handler completed successfully for SQS topic '{}'", topic_clone);
277 }
278
279 if let Some(receipt_handle) = sqs_message.receipt_handle() {
280 if let Err(e) = client
281 .delete_message()
282 .queue_url(&queue_url)
283 .receipt_handle(receipt_handle)
284 .send()
285 .await
286 {
287 warn!(
288 "Failed to delete message from queue '{}': {}",
289 queue_url, e
290 );
291 }
292 }
293 }
294 Err(e) => {
295 error!(
296 "Failed to deserialize SQS message for topic '{}': {}. Body: {}",
297 topic_clone, e, body
298 );
299 if let Some(receipt_handle) = sqs_message.receipt_handle() {
300 let _ = client
301 .delete_message()
302 .queue_url(&queue_url)
303 .receipt_handle(receipt_handle)
304 .send()
305 .await;
306 }
307 }
308 }
309 }
310 }
311 } else {
312 debug!("No messages received from SQS queue for topic '{}' (this is normal, continuing to poll...)", topic_clone);
313 }
314 }
315 Err(e) => {
316 error!(
317 "Error receiving messages from SQS queue '{}' for topic '{}': {}. Retrying in 5 seconds...",
318 queue_url, topic_clone, e
319 );
320 tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
321 }
322 }
323 }
324 });
325
326 Ok(())
327 }
328
329 pub async fn subscribe_fn<F, Fut>(&self, topic: impl Into<String>, handler: F) -> Result<()>
330 where
331 F: Fn(Message) -> Fut + Send + Sync + 'static,
332 Fut: std::future::Future<Output = Result<()>> + Send + 'static,
333 {
334 struct ClosureHandler<F, Fut>
335 where
336 F: Fn(Message) -> Fut + Send + Sync,
337 Fut: std::future::Future<Output = Result<()>> + Send,
338 {
339 func: F,
340 }
341
342 #[async_trait]
343 impl<F, Fut> MessageHandler for ClosureHandler<F, Fut>
344 where
345 F: Fn(Message) -> Fut + Send + Sync,
346 Fut: std::future::Future<Output = Result<()>> + Send,
347 {
348 async fn handle(&self, message: Message) -> Result<()> {
349 (self.func)(message).await
350 }
351 }
352
353 let handler = Arc::new(ClosureHandler { func: handler });
354 self.subscribe(topic, handler).await
355 }
356
357 pub async fn list_topics(&self) -> Vec<String> {
358 let queue_urls = self.queue_urls.read().await;
359 queue_urls.keys().cloned().collect()
360 }
361}
362