Skip to main content

streamling_e2e/resources/
sqs.rs

//! SQS resource manager for creating and cleaning up isolated SQS queues per test.
//!
//! Uses ElasticMQ (or any SQS-compatible endpoint) as the SQS backend for testing.
4
5use crate::{E2eError, Result};
6use aws_config::BehaviorVersion;
7use aws_sdk_sqs::Client as SqsClient;
8use std::time::Duration;
9use tracing::info;
10
11/// SQS resource manager
12pub struct SqsResource {
13    /// SQS-compatible endpoint URL (e.g. ElasticMQ)
14    pub endpoint_url: String,
15    /// AWS region (default: us-east-1)
16    pub region: String,
17    /// Name of the isolated SQS queue
18    pub queue_name: String,
19    /// URL of the created SQS queue
20    pub queue_url: String,
21    /// SQS client
22    client: SqsClient,
23}
24
25impl SqsResource {
26    /// Create a new SQS resource with an isolated queue
27    pub async fn new(endpoint_url: &str, queue_name: &str) -> Result<Self> {
28        let region = "us-east-1".to_string();
29
30        let client = Self::create_client(endpoint_url, &region).await;
31
32        // Create the queue
33        let create_result = client
34            .create_queue()
35            .queue_name(queue_name)
36            .send()
37            .await
38            .map_err(|e| {
39                E2eError::Sqs(format!(
40                    "Failed to create SQS queue '{}': {}",
41                    queue_name, e
42                ))
43            })?;
44
45        let mut queue_url = create_result
46            .queue_url()
47            .ok_or_else(|| E2eError::Sqs("Queue URL not returned after creation".to_string()))?
48            .to_string();
49
50        // ElasticMQ may return URLs with its internal port (9324); normalize to use
51        // the endpoint_url so clients connecting via NodePort get the correct URL.
52        if let Some(path_start) = queue_url
53            .find("://")
54            .and_then(|i| queue_url[i + 3..].find('/').map(|j| i + 3 + j))
55        {
56            let path = &queue_url[path_start..];
57            let base = endpoint_url.trim_end_matches('/');
58            queue_url = format!("{}{}", base, path);
59        }
60        info!("Created SQS queue: {} (url: {})", queue_name, queue_url);
61
62        Ok(Self {
63            endpoint_url: endpoint_url.to_string(),
64            region,
65            queue_name: queue_name.to_string(),
66            queue_url,
67            client,
68        })
69    }
70
71    /// Create an SQS client for the SQS-compatible endpoint
72    async fn create_client(endpoint_url: &str, region: &str) -> SqsClient {
73        let sdk_config = aws_config::defaults(BehaviorVersion::latest())
74            .endpoint_url(endpoint_url)
75            .region(aws_types::region::Region::new(region.to_string()))
76            .load()
77            .await;
78
79        SqsClient::new(&sdk_config)
80    }
81
82    /// Send a single message to the queue
83    pub async fn send_message(&self, body: &str) -> Result<()> {
84        self.client
85            .send_message()
86            .queue_url(&self.queue_url)
87            .message_body(body)
88            .send()
89            .await
90            .map_err(|e| E2eError::Sqs(format!("Failed to send message: {}", e)))?;
91
92        Ok(())
93    }
94
95    /// Receive messages from the queue
96    ///
97    /// Returns a list of message bodies. Uses long polling with a wait time of up to 5 seconds.
98    pub async fn receive_messages(&self, max_messages: i32) -> Result<Vec<String>> {
99        let result = self
100            .client
101            .receive_message()
102            .queue_url(&self.queue_url)
103            .max_number_of_messages(max_messages.min(10)) // SQS max is 10 per call
104            .wait_time_seconds(5)
105            .send()
106            .await
107            .map_err(|e| E2eError::Sqs(format!("Failed to receive messages: {}", e)))?;
108
109        let messages = result
110            .messages()
111            .iter()
112            .filter_map(|msg| msg.body().map(|b| b.to_string()))
113            .collect();
114
115        Ok(messages)
116    }
117
118    /// Receive all available messages from the queue, polling until no more messages arrive.
119    ///
120    /// This is useful for verification in e2e tests where we want to read all messages
121    /// that have been written to the queue.
122    pub async fn receive_all_messages(
123        &self,
124        max_messages: usize,
125        max_wait: Duration,
126    ) -> Result<Vec<String>> {
127        let mut all_messages = Vec::new();
128        let start = std::time::Instant::now();
129
130        while all_messages.len() < max_messages && start.elapsed() < max_wait {
131            let batch_size = (max_messages - all_messages.len()).min(10) as i32;
132            let result = self
133                .client
134                .receive_message()
135                .queue_url(&self.queue_url)
136                .max_number_of_messages(batch_size)
137                .wait_time_seconds(2)
138                .send()
139                .await
140                .map_err(|e| E2eError::Sqs(format!("Failed to receive messages: {}", e)))?;
141
142            let messages: Vec<String> = result
143                .messages()
144                .iter()
145                .filter_map(|msg| msg.body().map(|b| b.to_string()))
146                .collect();
147
148            if messages.is_empty() {
149                // No more messages available
150                break;
151            }
152
153            // Delete received messages so they don't show up again
154            for msg in result.messages() {
155                if let Some(receipt_handle) = msg.receipt_handle() {
156                    let _ = self
157                        .client
158                        .delete_message()
159                        .queue_url(&self.queue_url)
160                        .receipt_handle(receipt_handle)
161                        .send()
162                        .await;
163                }
164            }
165
166            all_messages.extend(messages);
167        }
168
169        info!(
170            "Received {} messages from SQS queue {}",
171            all_messages.len(),
172            self.queue_name
173        );
174
175        Ok(all_messages)
176    }
177
178    /// Get the approximate number of messages in the queue
179    pub async fn get_message_count(&self) -> Result<i64> {
180        let result = self
181            .client
182            .get_queue_attributes()
183            .queue_url(&self.queue_url)
184            .attribute_names(aws_sdk_sqs::types::QueueAttributeName::ApproximateNumberOfMessages)
185            .send()
186            .await
187            .map_err(|e| E2eError::Sqs(format!("Failed to get queue attributes: {}", e)))?;
188
189        let count = result
190            .attributes()
191            .and_then(|attrs| {
192                attrs
193                    .get(&aws_sdk_sqs::types::QueueAttributeName::ApproximateNumberOfMessages)
194                    .and_then(|v| v.parse::<i64>().ok())
195            })
196            .unwrap_or(0);
197
198        Ok(count)
199    }
200
201    /// Delete the queue (can be called explicitly if needed)
202    #[allow(dead_code)]
203    pub async fn cleanup(&self) -> Result<()> {
204        self.client
205            .delete_queue()
206            .queue_url(&self.queue_url)
207            .send()
208            .await
209            .map_err(|e| E2eError::Sqs(format!("Failed to delete SQS queue: {}", e)))?;
210
211        info!("Deleted SQS queue: {}", self.queue_name);
212        Ok(())
213    }
214}
215
216impl Drop for SqsResource {
217    fn drop(&mut self) {
218        // Best-effort cleanup
219        if let Ok(handle) = tokio::runtime::Handle::try_current() {
220            let queue_url = self.queue_url.clone();
221            let endpoint_url = self.endpoint_url.clone();
222            let region = self.region.clone();
223            let queue_name = self.queue_name.clone();
224
225            handle.spawn(async move {
226                let client = Self::create_client(&endpoint_url, &region).await;
227
228                if let Err(e) = client.delete_queue().queue_url(&queue_url).send().await {
229                    tracing::warn!("Failed to delete SQS queue {}: {}", queue_name, e);
230                } else {
231                    info!("Deleted SQS queue: {}", queue_name);
232                }
233            });
234        }
235    }
236}