Skip to main content

adapter_aws/
sqs.rs

1use crate::common::{AdapterError, Message, MessageHandler, Result};
2use aws_sdk_sqs::{
3    types::{MessageAttributeValue, QueueAttributeName},
4    Client as SqsClient,
5};
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use tracing::{debug, error, info, warn};
11
/// Settings used to build an [`SqsAdapter`] and the queues it creates.
#[derive(Debug, Clone)]
pub struct SqsConfig {
    /// AWS region the SQS client is bound to.
    pub region: String,
    /// Optional string prepended to every topic name to form the queue name.
    pub queue_prefix: Option<String>,
    /// `VisibilityTimeout` attribute (seconds) set on queues the adapter creates;
    /// `None` leaves the attribute unset.
    pub visibility_timeout_seconds: Option<i32>,
    /// `MessageRetentionPeriod` attribute (seconds) set on queues the adapter
    /// creates; `None` leaves the attribute unset.
    pub message_retention_seconds: Option<i32>,
    /// Long-polling wait time: `ReceiveMessageWaitTimeSeconds` attribute (seconds)
    /// set on queues the adapter creates; `None` leaves the attribute unset.
    pub receive_wait_time_seconds: Option<i32>,
}

impl Default for SqsConfig {
    /// Returns the stock configuration: `us-east-1`, a `rohas-` queue prefix,
    /// a 30 s visibility timeout, 4 days of retention and 20 s long polling.
    fn default() -> Self {
        // 4 days expressed in seconds (= 345600).
        const FOUR_DAYS_SECONDS: i32 = 4 * 24 * 60 * 60;

        SqsConfig {
            region: String::from("us-east-1"),
            queue_prefix: Some(String::from("rohas-")),
            visibility_timeout_seconds: Some(30),
            message_retention_seconds: Some(FOUR_DAYS_SECONDS),
            receive_wait_time_seconds: Some(20),
        }
    }
}
32
33pub struct SqsAdapter {
34    client: SqsClient,
35    config: SqsConfig,
36    queue_urls: Arc<RwLock<HashMap<String, String>>>, // topic -> queue_url
37}
38
39impl SqsAdapter {
40    pub async fn new(config: SqsConfig) -> Result<Self> {
41        let aws_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
42            .region(aws_sdk_sqs::config::Region::new(config.region.clone()))
43            .load()
44            .await;
45
46        let client = SqsClient::new(&aws_config);
47
48        info!(
49            "Initialized SQS adapter for region: {}",
50            config.region
51        );
52
53        Ok(Self {
54            client,
55            config,
56            queue_urls: Arc::new(RwLock::new(HashMap::new())),
57        })
58    }
59
60    async fn get_or_create_queue(&self, topic: &str) -> Result<String> {
61        {
62            let queue_urls = self.queue_urls.read().await;
63            if let Some(url) = queue_urls.get(topic) {
64                return Ok(url.clone());
65            }
66        }
67
68        let queue_name = if let Some(prefix) = &self.config.queue_prefix {
69            format!("{}{}", prefix, topic)
70        } else {
71            topic.to_string()
72        };
73
74        let queue_name = queue_name
75            .chars()
76            .map(|c| {
77                if c.is_alphanumeric() || c == '-' || c == '_' {
78                    c
79                } else {
80                    '-'
81                }
82            })
83            .collect::<String>();
84
85        let get_queue_result = self
86            .client
87            .get_queue_url()
88            .queue_name(&queue_name)
89            .send()
90            .await;
91
92        let queue_url = match get_queue_result {
93            Ok(response) => {
94                if let Some(url) = response.queue_url() {
95                    info!("Found existing queue for topic '{}': {}", topic, url);
96                    url.to_string()
97                } else {
98                    return Err(AdapterError::QueueNotFound(queue_name));
99                }
100            }
101            Err(_) => {
102                debug!("Queue '{}' not found, creating...", queue_name);
103
104                let mut create_request = self.client.create_queue().queue_name(&queue_name);
105
106                let mut attributes = HashMap::new();
107                if let Some(visibility) = self.config.visibility_timeout_seconds {
108                    attributes.insert(
109                        QueueAttributeName::VisibilityTimeout,
110                        visibility.to_string(),
111                    );
112                }
113                if let Some(retention) = self.config.message_retention_seconds {
114                    attributes.insert(
115                        QueueAttributeName::MessageRetentionPeriod,
116                        retention.to_string(),
117                    );
118                }
119                if let Some(wait_time) = self.config.receive_wait_time_seconds {
120                    attributes.insert(
121                        QueueAttributeName::ReceiveMessageWaitTimeSeconds,
122                        wait_time.to_string(),
123                    );
124                }
125
126                if !attributes.is_empty() {
127                    create_request = create_request.set_attributes(Some(attributes));
128                }
129
130                let create_result = create_request.send().await.map_err(|e| {
131                    AdapterError::AwsSqs(format!("Failed to create queue '{}': {}", queue_name, e))
132                })?;
133
134                if let Some(url) = create_result.queue_url() {
135                    info!("Created queue for topic '{}': {}", topic, url);
136                    url.to_string()
137                } else {
138                    return Err(AdapterError::AwsSqs(format!(
139                        "Queue created but no URL returned for '{}'",
140                        queue_name
141                    )));
142                }
143            }
144        };
145
146        {
147            let mut queue_urls = self.queue_urls.write().await;
148            queue_urls.insert(topic.to_string(), queue_url.clone());
149        }
150
151        Ok(queue_url)
152    }
153
154    pub async fn publish(
155        &self,
156        topic: impl Into<String>,
157        payload: serde_json::Value,
158    ) -> Result<()> {
159        let topic = topic.into();
160        tracing::info!("SqsAdapter::publish: Starting publish for topic: {}", topic);
161        
162        let message = Message::new(topic.clone(), payload);
163
164        let message_body = serde_json::to_string(&message)
165            .map_err(|e| {
166                tracing::error!("SqsAdapter::publish: Serialization error for topic {}: {}", topic, e);
167                AdapterError::Serialization(e)
168            })?;
169
170        tracing::debug!("SqsAdapter::publish: Message serialized, getting/creating queue for topic: {}", topic);
171        
172        let queue_url = self.get_or_create_queue(&topic).await.map_err(|e| {
173            tracing::error!("SqsAdapter::publish: Failed to get/create queue for topic {}: {}", topic, e);
174            e
175        })?;
176        
177        tracing::info!("SqsAdapter::publish: Queue URL obtained: {} for topic: {}", queue_url, topic);
178
179        let mut attributes = HashMap::new();
180        attributes.insert(
181            "topic".to_string(),
182            MessageAttributeValue::builder()
183                .data_type("String")
184                .string_value(&topic)
185                .build()
186                .map_err(|e| AdapterError::AwsSqs(format!("Failed to build attribute: {}", e)))?,
187        );
188        attributes.insert(
189            "timestamp".to_string(),
190            MessageAttributeValue::builder()
191                .data_type("String")
192                .string_value(&message.timestamp)
193                .build()
194                .map_err(|e| AdapterError::AwsSqs(format!("Failed to build attribute: {}", e)))?,
195        );
196
197        let send_result = self
198            .client
199            .send_message()
200            .queue_url(&queue_url)
201            .message_body(&message_body)
202            .set_message_attributes(Some(attributes))
203            .send()
204            .await;
205
206        match send_result {
207            Ok(response) => {
208                if let Some(message_id) = response.message_id() {
209                    info!("Published message to SQS topic: {} (queue: {}, message_id: {})", topic, queue_url, message_id);
210                } else {
211                    info!("Published message to SQS topic: {} (queue: {})", topic, queue_url);
212                }
213                Ok(())
214            }
215            Err(e) => {
216                error!("Failed to send message to SQS queue '{}' for topic '{}': {}", queue_url, topic, e);
217                let error_msg = format!(
218                    "Failed to send message to queue '{}': {}",
219                    queue_url, e
220                );
221                tracing::error!("SqsAdapter::publish: Error details - {}", error_msg);
222                Err(AdapterError::AwsSqs(error_msg))
223            }
224        }
225    }
226
227    pub async fn subscribe<H>(&self, topic: impl Into<String>, handler: Arc<H>) -> Result<()>
228    where
229        H: MessageHandler + 'static,
230    {
231        let topic = topic.into();
232        let queue_url = self.get_or_create_queue(&topic).await?;
233
234        info!("Subscribing to topic: {} (queue: {})", topic, queue_url);
235
236        let client = self.client.clone();
237        let handler = handler.clone();
238        let topic_clone = topic.clone();
239
240        tokio::spawn(async move {
241            info!("SQS subscription polling loop started for topic '{}' (queue: {})", topic_clone, queue_url);
242            let mut poll_count = 0u64;
243            loop {
244                poll_count += 1;
245                if poll_count % 5 == 0 {
246                    info!("SQS polling loop still active for topic '{}' (poll #{}), queue: {}", topic_clone, poll_count, queue_url);
247                } else if poll_count <= 3 {
248                    info!("SQS polling loop active for topic '{}' (poll #{}), queue: {}", topic_clone, poll_count, queue_url);
249                } else {
250                    debug!("Polling SQS queue for topic '{}' (poll #{})...", topic_clone, poll_count);
251                }
252                let receive_result = client
253                    .receive_message()
254                    .queue_url(&queue_url)
255                    .max_number_of_messages(10)
256                    .wait_time_seconds(20)
257                    .send()
258                    .await;
259
260                match receive_result {
261                    Ok(response) => {
262                        let messages = response.messages();
263                        if !messages.is_empty() {
264                            info!("Received {} message(s) from SQS queue for topic '{}'", messages.len(), topic_clone);
265                            for sqs_message in messages {
266                                if let Some(body) = sqs_message.body() {
267                                    info!("Raw SQS message body for topic '{}': {}", topic_clone, body);
268                                    match serde_json::from_str::<Message>(body) {
269                                        Ok(message) => {
270                                            info!("Successfully parsed SQS message for topic '{}'", topic_clone);
271                                            info!("Message topic: {}, payload: {:?}", message.topic, message.payload);
272                                            info!("Calling handler for SQS message...");
273                                            if let Err(e) = handler.handle(message).await {
274                                                error!("Handler error for SQS topic '{}': {}", topic_clone, e);
275                                            } else {
276                                                info!("Handler completed successfully for SQS topic '{}'", topic_clone);
277                                            }
278
279                                            if let Some(receipt_handle) = sqs_message.receipt_handle() {
280                                                if let Err(e) = client
281                                                    .delete_message()
282                                                    .queue_url(&queue_url)
283                                                    .receipt_handle(receipt_handle)
284                                                    .send()
285                                                    .await
286                                                {
287                                                    warn!(
288                                                        "Failed to delete message from queue '{}': {}",
289                                                        queue_url, e
290                                                    );
291                                                }
292                                            }
293                                        }
294                                        Err(e) => {
295                                            error!(
296                                                "Failed to deserialize SQS message for topic '{}': {}. Body: {}",
297                                                topic_clone, e, body
298                                            );
299                                            if let Some(receipt_handle) = sqs_message.receipt_handle() {
300                                                let _ = client
301                                                    .delete_message()
302                                                    .queue_url(&queue_url)
303                                                    .receipt_handle(receipt_handle)
304                                                    .send()
305                                                    .await;
306                                            }
307                                        }
308                                    }
309                                }
310                            }
311                        } else {
312                            debug!("No messages received from SQS queue for topic '{}' (this is normal, continuing to poll...)", topic_clone);
313                        }
314                    }
315                    Err(e) => {
316                        error!(
317                            "Error receiving messages from SQS queue '{}' for topic '{}': {}. Retrying in 5 seconds...",
318                            queue_url, topic_clone, e
319                        );
320                        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
321                    }
322                }
323            }
324        });
325
326        Ok(())
327    }
328
329    pub async fn subscribe_fn<F, Fut>(&self, topic: impl Into<String>, handler: F) -> Result<()>
330    where
331        F: Fn(Message) -> Fut + Send + Sync + 'static,
332        Fut: std::future::Future<Output = Result<()>> + Send + 'static,
333    {
334        struct ClosureHandler<F, Fut>
335        where
336            F: Fn(Message) -> Fut + Send + Sync,
337            Fut: std::future::Future<Output = Result<()>> + Send,
338        {
339            func: F,
340        }
341
342        #[async_trait]
343        impl<F, Fut> MessageHandler for ClosureHandler<F, Fut>
344        where
345            F: Fn(Message) -> Fut + Send + Sync,
346            Fut: std::future::Future<Output = Result<()>> + Send,
347        {
348            async fn handle(&self, message: Message) -> Result<()> {
349                (self.func)(message).await
350            }
351        }
352
353        let handler = Arc::new(ClosureHandler { func: handler });
354        self.subscribe(topic, handler).await
355    }
356
357    pub async fn list_topics(&self) -> Vec<String> {
358        let queue_urls = self.queue_urls.read().await;
359        queue_urls.keys().cloned().collect()
360    }
361}
362