cobalt_aws/
sqs.rs

1//! A collection of wrappers around the [aws_sdk_sqs](https://docs.rs/aws-sdk-sqs/latest/aws_sdk_sqs/) crate.
2
3use anyhow::Result;
4use aws_sdk_sqs::types::SendMessageBatchRequestEntry;
5use derive_more::{Display, From};
6use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt};
7use serde::{Deserialize, Serialize};
8use std::str::FromStr;
9use thiserror::Error;
10
11/// Re-export of [aws_sdk_sqs::client::Client](https://docs.rs/aws-sdk-sqs/latest/aws_sdk_sqs/client/struct.Client.html).
12///
13pub use aws_sdk_sqs::Client;
14
15const BATCH_SIZE: usize = 10;
16
17/// Send message to a queue from a stream with concurrent invocations of `SendMessageBatch`.
18///
19/// This function retrieves items from the stream until it has exhausted. Any errors
20/// in the stream, or while processing the stream, are returned as soon as they are encountered.
21///
22/// If `concurrency` is `None` then message batches are sent sequentially. A `concurrency` value of
23/// zero, `Some(0)`, is not allowed and will result in an error.
24///
25/// # Example
26///
27/// ```no_run
28/// use aws_config;
29/// use futures::stream;
30/// use std::str::FromStr;
31/// use cobalt_aws::sqs::{Client, send_messages_concurrently, SQSQueueName};
32/// use cobalt_aws::config::load_from_env;
33///
34/// # tokio_test::block_on(async {
35/// let shared_config = load_from_env().await.unwrap();
36/// let client = Client::new(&shared_config);
37///
38/// let messages = stream::iter(vec![Ok("Hello"), Ok("world")]);
39/// let queue_name = &SQSQueueName::from_str("MyQueue").unwrap();
40///
41/// // Send up to 4 concurrent API requests at once.
42/// send_messages_concurrently(&client, queue_name, Some(4), messages).await.unwrap();
43/// # })
44/// ```
45///
46/// # Implementation details
47///
48/// This function uses the [SendMessageBatch](https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_SendMessageBatch.html)
49/// API to send message in batches of 10, which is the maximum allowed batch size.
50pub async fn send_messages_concurrently<Msg: serde::Serialize, St: Stream<Item = Result<Msg>>>(
51    client: &Client,
52    queue_name: &SQSQueueName,
53    concurrency: Option<usize>,
54    msg_stream: St,
55) -> Result<()> {
56    if concurrency == Some(0) {
57        anyhow::bail!("Zero concurrency not allowed.");
58    }
59    let queue_url = client
60        .get_queue_url()
61        .queue_name(queue_name.to_string())
62        .send()
63        .await?
64        .queue_url
65        .ok_or_else(|| anyhow::anyhow!("Failed to get queue URL for {}", queue_name))?;
66    msg_stream
67        .map(|msg| Ok::<_, anyhow::Error>(serde_json::to_string(&msg?)?))
68        .enumerate()
69        .map(|(i, s)| {
70            SendMessageBatchRequestEntry::builder()
71                .message_body(s?)
72                .id(format!("{}", i))
73                .build()
74                .map_err(anyhow::Error::from)
75        })
76        .try_chunks(BATCH_SIZE)
77        .map_err(anyhow::Error::from)
78        .inspect_ok(|entries| tracing::debug!("Sending message batch: {:#?}", entries))
79        .map_ok(|entries| {
80            client
81                .send_message_batch()
82                .queue_url(&queue_url)
83                .set_entries(Some(entries))
84                .send()
85                .map_err(anyhow::Error::from)
86                .map_ok(|_| ())
87        })
88        .try_buffered(concurrency.unwrap_or(1))
89        .try_collect::<()>()
90        .await
91}
92
93/// The name of an AWS SQS queue.
94///
95/// The `FromStr`` implementation of this type ensures the value is a valid AWS
96/// SQS name. This means it is between 1 and 80 characters, and only contains
97/// alphanumberic characters, hyphens (-), and underscores (_).
98///
99/// Ref: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/quotas-queues.html
100#[derive(Clone, Debug, Display, From, Eq, PartialEq, Serialize, Deserialize)]
101pub struct SQSQueueName(String);
102
103impl AsRef<str> for SQSQueueName {
104    fn as_ref(&self) -> &str {
105        &self.0
106    }
107}
108
109#[derive(Error, Debug)]
110pub enum SQSQueueNameError {
111    #[error("Invalid length, expected between 1 and 80 characters, received: {0}")]
112    InvalidLength(usize),
113    #[error("The following characters are accepted: alphanumeric characters, hyphens (-), and underscores (_)")]
114    InvalidCharacters,
115}
116
117const MAX_QUEUE_NAME_LENGTH: usize = 80;
118
119impl FromStr for SQSQueueName {
120    type Err = SQSQueueNameError;
121
122    fn from_str(s: &str) -> Result<Self, Self::Err> {
123        if s.len() > MAX_QUEUE_NAME_LENGTH {
124            Err(SQSQueueNameError::InvalidLength(s.len()))
125        } else if s.is_empty() {
126            Err(SQSQueueNameError::InvalidLength(0))
127        } else if !s
128            .chars()
129            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
130        {
131            Err(SQSQueueNameError::InvalidCharacters)
132        } else {
133            Ok(SQSQueueName(s.to_string()))
134        }
135    }
136}
137
#[cfg(test)]
mod test_send_messages {
    //! Integration tests for `send_messages_concurrently`, run against LocalStack.
    //!
    //! Several tests use the same `test-queue` queue, so every test is marked `#[serial]` and
    //! tests that send to the queue drain it afterwards with `consume_queue`.
    use crate::{config::load_from_env, localstack};

    use super::*;
    use aws_sdk_sqs::{
        error::ProvideErrorMetadata, operation::get_queue_url::GetQueueUrlError,
        types::DeleteMessageBatchRequestEntry,
    };
    use futures::stream;
    use serial_test::serial;

    /// Maximum number of messages requested per `ReceiveMessage` call.
    const MAX_MESSAGES: usize = 10;

    /// Waits for LocalStack to be available and builds an SQS client from the environment.
    async fn localstack_test_client() -> Client {
        localstack::test_utils::wait_for_localstack().await;
        let shared_config = load_from_env().await.unwrap();
        aws_sdk_sqs::Client::new(&shared_config)
    }

    /// Receives and deletes messages from `queue_name` until a receive returns no messages,
    /// returning the message bodies parsed as `usize`, sorted ascending.
    ///
    /// Panics if a receive or delete request fails, so a broken connection cannot be mistaken
    /// for an empty queue.
    async fn consume_queue(client: &Client, queue_name: &SQSQueueName) -> Vec<usize> {
        let queue_url = client
            .get_queue_url()
            .queue_name(queue_name.to_string())
            .send()
            .await
            .unwrap()
            .queue_url
            .unwrap();

        let mut results: Vec<usize> = vec![];
        loop {
            let output = client
                .receive_message()
                .max_number_of_messages(MAX_MESSAGES as i32)
                .wait_time_seconds(1)
                .queue_url(&queue_url)
                .send()
                .await
                .expect("Error receiving messages");

            let messages = match output.messages {
                Some(messages) if !messages.is_empty() => messages,
                _ => break,
            };
            assert!(messages.len() <= MAX_MESSAGES);

            let mut delete_entries = Vec::with_capacity(messages.len());
            for msg in messages {
                results.push(msg.body.unwrap().parse().unwrap());
                delete_entries.push(
                    DeleteMessageBatchRequestEntry::builder()
                        .set_receipt_handle(msg.receipt_handle)
                        .set_id(msg.message_id)
                        .build()
                        .expect("Errors not expected building results_delete"),
                );
            }

            client
                .delete_message_batch()
                .queue_url(&queue_url)
                .set_entries(Some(delete_entries))
                .send()
                .await
                .expect("Error deleting message batch");
        }
        results.sort_unstable();
        results
    }

    #[tokio::test]
    #[serial]
    async fn test_non_existent_queue() {
        let client = localstack_test_client().await;

        let item_stream = stream::iter(vec![Ok::<u32, _>(1), Ok(2), Ok(3)]);

        let queue_name = &SQSQueueName::from_str("non-existent-queue").unwrap();

        let result = send_messages_concurrently(&client, queue_name, None, item_stream).await;
        let e = result.unwrap_err();
        let e = e
            .source()
            .unwrap()
            .downcast_ref::<GetQueueUrlError>()
            .unwrap();

        assert!(matches!(e, GetQueueUrlError::QueueDoesNotExist(_)));
        assert_eq!(e.code(), Some("AWS.SimpleQueueService.NonExistentQueue"));
    }

    #[tokio::test]
    #[serial]
    async fn test_item_stream_error() {
        let client = localstack_test_client().await;

        let item_stream = stream::iter(vec![
            Ok::<u32, _>(1),
            Ok(2),
            Err(anyhow::anyhow!("some error")),
            Ok(3),
        ]);

        let queue_name = &SQSQueueName::from_str("test-queue").unwrap();

        let result = send_messages_concurrently(&client, queue_name, None, item_stream).await;
        let e = result.unwrap_err();
        assert_eq!(e.to_string(), "some error");

        let values = consume_queue(&client, queue_name).await;
        assert!(values.is_empty());
    }

    #[tokio::test]
    #[serial]
    async fn test_less_than_batch_size() {
        let client = localstack_test_client().await;

        let item_stream = stream::iter((0..5).map(Ok));

        let queue_name = &SQSQueueName::from_str("test-queue").unwrap();

        let result = send_messages_concurrently(&client, queue_name, None, item_stream).await;
        result.unwrap();

        let values = consume_queue(&client, queue_name).await;
        assert_eq!(values, (0..5).collect::<Vec<_>>());
    }

    #[tokio::test]
    #[serial]
    async fn test_more_than_batch_size() {
        let client = localstack_test_client().await;

        let item_stream = stream::iter((0..25).map(Ok));

        let queue_name = &SQSQueueName::from_str("test-queue").unwrap();

        let result = send_messages_concurrently(&client, queue_name, None, item_stream).await;
        result.unwrap();

        let values = consume_queue(&client, queue_name).await;
        assert_eq!(values, (0..25).collect::<Vec<_>>());
    }

    #[tokio::test]
    #[serial]
    async fn test_concurrent() {
        let client = localstack_test_client().await;

        let item_stream = stream::iter((0..105).map(Ok));

        let queue_name = &SQSQueueName::from_str("test-queue").unwrap();

        let result = send_messages_concurrently(&client, queue_name, Some(5), item_stream).await;
        result.unwrap();

        let values = consume_queue(&client, queue_name).await;
        assert_eq!(values, (0..105).collect::<Vec<_>>());
    }

    #[tokio::test]
    #[serial]
    async fn test_zero_concurrency() {
        let client = localstack_test_client().await;

        let item_stream = stream::iter((0..105).map(Ok));

        let queue_name = &SQSQueueName::from_str("test-queue").unwrap();

        let result = send_messages_concurrently(&client, queue_name, Some(0), item_stream).await;
        let e = result.unwrap_err();

        assert_eq!(e.to_string(), "Zero concurrency not allowed.");

        let values = consume_queue(&client, queue_name).await;
        assert!(values.is_empty());
    }
}
323
/// Generators for property-based testing of this module's types.
///
/// The module is compiled for the crate's own tests and, when the `test-utils` feature is
/// enabled, made available to other crates. It provides a proptest `Arbitrary` implementation
/// for `SQSQueueName` and `sqs_name_arbitrary_invalid` for strings that must be rejected.
/// Guarding it behind the feature flag keeps test utilities out of production builds while
/// still allowing reuse in other crates' tests.
#[cfg(any(test, feature = "test-utils"))]
pub mod test_support {
    use super::*;
    use proptest::prelude::*;
    use proptest::strategy::{BoxedStrategy, Strategy};

    // Arbitrary implementation for SQSQueueName for testing
    impl Arbitrary for SQSQueueName {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        /// Generates names of 1 to `MAX_QUEUE_NAME_LENGTH` characters drawn only from the
        /// allowed character set, so every generated value passes validation.
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            let pattern = format!("[a-zA-Z0-9_-]{{1,{}}}", MAX_QUEUE_NAME_LENGTH);
            proptest::string::string_regex(&pattern)
                .expect("Invalid regex pattern for SQSQueueName")
                .prop_map(|s| {
                    SQSQueueName::from_str(&s)
                        .expect("generated SQS queue name should pass validation")
                })
                .boxed()
        }
    }

    /// Strategy producing strings that `SQSQueueName::from_str` must reject: the empty string,
    /// a string one character longer than `MAX_QUEUE_NAME_LENGTH`, and short strings containing
    /// at least one disallowed character.
    // Kept defensively: in builds where nothing in this crate calls this function (e.g. the
    // `test-utils` feature without `cfg(test)`), it would otherwise be reported as unused.
    #[allow(dead_code)]
    pub fn sqs_name_arbitrary_invalid() -> BoxedStrategy<String> {
        let too_short = Just(String::new());

        let too_long = Just("a".repeat(MAX_QUEUE_NAME_LENGTH + 1));

        // Surround one or more invalid characters with optional valid ones, so the character
        // check is exercised on mixed input, not only on strings made entirely of bad chars.
        // The total length (at most 15) stays within the length limit.
        let invalid_chars = "[a-zA-Z0-9_-]{0,5}[*?%!. ]{1,5}[a-zA-Z0-9_-]{0,5}";
        let invalid_chars = proptest::string::string_regex(invalid_chars)
            .expect("Invalid regex pattern for generating invalid SQSQueueName");

        prop_oneof![too_short, too_long, invalid_chars].boxed()
    }
}
365
#[cfg(test)]
mod prop_tests {
    //! Property-based tests for the parent module.
    //!
    //! In these tests, `prop_assert!` is used extensively for a few key reasons:
    //!
    //! 1. **Integration with `proptest`**: Unlike `assert!` or `assert_eq!`, `prop_assert!`
    //!    and its variants (e.g., `prop_assert_eq!`) are designed to work seamlessly
    //!    within the `proptest` framework. They handle failure reporting in a way that
    //!    integrates with `proptest`'s test case reduction mechanisms, making it easier
    //!    to diagnose and understand failures.
    //!
    //! 2. **Test Case Reduction**: When a `prop_assert!` fails, `proptest` attempts to
    //!    "shrink" the input data to the smallest case that still causes the assertion
    //!    to fail. This simplification process is crucial for debugging and is a major
    //!    advantage of property-based testing. `prop_assert!` ensures that shrinking
    //!    behavior works correctly.
    //!
    //! 3. **Custom Failure Messages**: Like `assert!`, `prop_assert!` allows for custom
    //!    failure messages. This feature is particularly useful in complex tests where
    //!    the default error message may not provide enough context about the failure.
    //!
    //! Using `prop_assert!` correctly is essential for leveraging the full power of
    //! property-based testing with `proptest`.
    use super::test_support::*;
    use crate::sqs::SQSQueueName;
    use proptest::prelude::*;
    use std::str::FromStr;

    proptest! {

        // Tests that serialization and deserialization of `SQSQueueName` is symmetric,
        // ensuring that any `SQSQueueName` can be round-tripped to JSON and back without loss.
        #[test]
        fn sqs_name_round_trip_test(queue_name: SQSQueueName) {
            let serialized = serde_json::to_string(&queue_name).expect("SQS queue name should be valid");
            let deserialized: SQSQueueName = serde_json::from_str(&serialized).expect("Input json should be valid");
            prop_assert_eq!(deserialized, queue_name);
        }

        // Every valid name re-parses from its string form to an equal value.
        #[test]
        fn sqs_name_string_round_trip(queue_name: SQSQueueName) {
            let reparsed = SQSQueueName::from_str(queue_name.as_ref()).ok();
            prop_assert_eq!(reparsed, Some(queue_name));
        }

        // Invalid SQS names should fail to parse.
        #[test]
        fn test_invalid_sqs_name(invalid_str in sqs_name_arbitrary_invalid()) {
            prop_assert!(
                SQSQueueName::from_str(&invalid_str).is_err(),
                "expected {:?} to be rejected",
                invalid_str
            );
        }

    }
}
416}