1use anyhow::Result;
4use aws_sdk_sqs::types::SendMessageBatchRequestEntry;
5use derive_more::{Display, From};
6use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt};
7use serde::{Deserialize, Serialize};
8use std::str::FromStr;
9use thiserror::Error;
10
11pub use aws_sdk_sqs::Client;
14
15const BATCH_SIZE: usize = 10;
16
17pub async fn send_messages_concurrently<Msg: serde::Serialize, St: Stream<Item = Result<Msg>>>(
51 client: &Client,
52 queue_name: &SQSQueueName,
53 concurrency: Option<usize>,
54 msg_stream: St,
55) -> Result<()> {
56 if concurrency == Some(0) {
57 anyhow::bail!("Zero concurrency not allowed.");
58 }
59 let queue_url = client
60 .get_queue_url()
61 .queue_name(queue_name.to_string())
62 .send()
63 .await?
64 .queue_url
65 .ok_or_else(|| anyhow::anyhow!("Failed to get queue URL for {}", queue_name))?;
66 msg_stream
67 .map(|msg| Ok::<_, anyhow::Error>(serde_json::to_string(&msg?)?))
68 .enumerate()
69 .map(|(i, s)| {
70 SendMessageBatchRequestEntry::builder()
71 .message_body(s?)
72 .id(format!("{}", i))
73 .build()
74 .map_err(anyhow::Error::from)
75 })
76 .try_chunks(BATCH_SIZE)
77 .map_err(anyhow::Error::from)
78 .inspect_ok(|entries| tracing::debug!("Sending message batch: {:#?}", entries))
79 .map_ok(|entries| {
80 client
81 .send_message_batch()
82 .queue_url(&queue_url)
83 .set_entries(Some(entries))
84 .send()
85 .map_err(anyhow::Error::from)
86 .map_ok(|_| ())
87 })
88 .try_buffered(concurrency.unwrap_or(1))
89 .try_collect::<()>()
90 .await
91}
92
93#[derive(Clone, Debug, Display, From, Eq, PartialEq, Serialize, Deserialize)]
101pub struct SQSQueueName(String);
102
103impl AsRef<str> for SQSQueueName {
104 fn as_ref(&self) -> &str {
105 &self.0
106 }
107}
108
109#[derive(Error, Debug)]
110pub enum SQSQueueNameError {
111 #[error("Invalid length, expected between 1 and 80 characters, received: {0}")]
112 InvalidLength(usize),
113 #[error("The following characters are accepted: alphanumeric characters, hyphens (-), and underscores (_)")]
114 InvalidCharacters,
115}
116
117const MAX_QUEUE_NAME_LENGTH: usize = 80;
118
119impl FromStr for SQSQueueName {
120 type Err = SQSQueueNameError;
121
122 fn from_str(s: &str) -> Result<Self, Self::Err> {
123 if s.len() > MAX_QUEUE_NAME_LENGTH {
124 Err(SQSQueueNameError::InvalidLength(s.len()))
125 } else if s.is_empty() {
126 Err(SQSQueueNameError::InvalidLength(0))
127 } else if !s
128 .chars()
129 .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')
130 {
131 Err(SQSQueueNameError::InvalidCharacters)
132 } else {
133 Ok(SQSQueueName(s.to_string()))
134 }
135 }
136}
137
#[cfg(test)]
mod test_send_messages {
    use crate::{config::load_from_env, localstack};

    use super::*;
    use aws_sdk_sqs::{
        error::ProvideErrorMetadata, operation::get_queue_url::GetQueueUrlError,
        types::DeleteMessageBatchRequestEntry,
    };
    use futures::stream;
    use serial_test::serial;

    /// Number of messages requested per `ReceiveMessage` call. It is also asserted as
    /// the upper bound on what a single call returns.
    const MAX_MESSAGES: usize = 10;

    /// Waits for LocalStack, then builds an SQS client from the environment config.
    async fn localstack_test_client() -> Client {
        localstack::test_utils::wait_for_localstack().await;
        let shared_config = load_from_env().await.unwrap();
        aws_sdk_sqs::Client::new(&shared_config)
    }

    /// Receives and deletes messages from `queue_name` until a receive returns none.
    /// Returns the message bodies parsed as `usize`, sorted ascending.
    ///
    /// Any SQS error panics. Treating a failed receive as "queue drained" would let
    /// assertions that the queue is empty pass vacuously.
    async fn consume_queue(client: &Client, queue_name: &SQSQueueName) -> Vec<usize> {
        let queue_url = client
            .get_queue_url()
            .queue_name(queue_name.to_string())
            .send()
            .await
            .unwrap()
            .queue_url
            .unwrap();

        let mut results: Vec<usize> = vec![];
        loop {
            let messages = client
                .receive_message()
                .max_number_of_messages(MAX_MESSAGES as i32)
                .wait_time_seconds(1)
                .queue_url(&queue_url)
                .send()
                .await
                .expect("Error receiving messages")
                .messages
                .unwrap_or_default();
            if messages.is_empty() {
                break;
            }
            assert!(messages.len() <= MAX_MESSAGES);

            let mut delete_entries = Vec::with_capacity(messages.len());
            for msg in messages {
                results.push(msg.body.unwrap().parse().unwrap());
                delete_entries.push(
                    DeleteMessageBatchRequestEntry::builder()
                        .set_receipt_handle(msg.receipt_handle)
                        .set_id(msg.message_id)
                        .build()
                        .expect("Errors not expected building results_delete"),
                );
            }

            client
                .delete_message_batch()
                .queue_url(&queue_url)
                .set_entries(Some(delete_entries))
                .send()
                .await
                .expect("Error deleting message batch");
        }
        results.sort_unstable();
        results
    }

    #[tokio::test]
    #[serial]
    async fn test_non_existent_queue() {
        let client = localstack_test_client().await;

        let item_stream = stream::iter(vec![Ok::<u32, _>(1), Ok(2), Ok(3)]);

        let queue_name = &SQSQueueName::from_str("non-existent-queue").unwrap();

        let result = send_messages_concurrently(&client, queue_name, None, item_stream).await;
        let e = result.unwrap_err();
        let e = e
            .source()
            .unwrap()
            .downcast_ref::<GetQueueUrlError>()
            .unwrap();

        assert!(matches!(e, GetQueueUrlError::QueueDoesNotExist(_)));
        assert_eq!(e.code(), Some("AWS.SimpleQueueService.NonExistentQueue"));
    }

    #[tokio::test]
    #[serial]
    async fn test_item_stream_error() {
        let client = localstack_test_client().await;

        let item_stream = stream::iter(vec![
            Ok::<u32, _>(1),
            Ok(2),
            Err(anyhow::anyhow!("some error")),
            Ok(3),
        ]);

        let queue_name = &SQSQueueName::from_str("test-queue").unwrap();

        let result = send_messages_concurrently(&client, queue_name, None, item_stream).await;
        let e = result.unwrap_err();
        assert_eq!(e.to_string(), "some error");

        let values = consume_queue(&client, queue_name).await;
        assert!(values.is_empty());
    }

    #[tokio::test]
    #[serial]
    async fn test_less_than_batch_size() {
        let client = localstack_test_client().await;

        let item_stream = stream::iter((0..5).map(Ok));

        let queue_name = &SQSQueueName::from_str("test-queue").unwrap();

        let result = send_messages_concurrently(&client, queue_name, None, item_stream).await;
        result.unwrap();

        let values = consume_queue(&client, queue_name).await;
        assert_eq!(values, (0..5).collect::<Vec<_>>());
    }

    #[tokio::test]
    #[serial]
    async fn test_more_than_batch_size() {
        let client = localstack_test_client().await;

        let item_stream = stream::iter((0..25).map(Ok));

        let queue_name = &SQSQueueName::from_str("test-queue").unwrap();

        let result = send_messages_concurrently(&client, queue_name, None, item_stream).await;
        result.unwrap();

        let values = consume_queue(&client, queue_name).await;
        assert_eq!(values, (0..25).collect::<Vec<_>>());
    }

    #[tokio::test]
    #[serial]
    async fn test_concurrent() {
        let client = localstack_test_client().await;

        let item_stream = stream::iter((0..105).map(Ok));

        let queue_name = &SQSQueueName::from_str("test-queue").unwrap();

        let result = send_messages_concurrently(&client, queue_name, Some(5), item_stream).await;
        result.unwrap();

        let values = consume_queue(&client, queue_name).await;
        assert_eq!(values, (0..105).collect::<Vec<_>>());
    }

    #[tokio::test]
    #[serial]
    async fn test_zero_concurrency() {
        let client = localstack_test_client().await;

        let item_stream = stream::iter((0..105).map(Ok));

        let queue_name = &SQSQueueName::from_str("test-queue").unwrap();

        let result = send_messages_concurrently(&client, queue_name, Some(0), item_stream).await;
        let e = result.unwrap_err();

        assert_eq!(e.to_string(), "Zero concurrency not allowed.");

        let values = consume_queue(&client, queue_name).await;
        assert!(values.is_empty());
    }
}
323
#[cfg(any(test, feature = "test-utils"))]
mod test_support {
    use super::*;
    use proptest::prelude::*;
    use proptest::strategy::{BoxedStrategy, Strategy};

    impl Arbitrary for SQSQueueName {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        /// Generates valid queue names of 1 to 80 characters drawn from `[a-zA-Z0-9_-]`.
        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            proptest::string::string_regex("[a-zA-Z0-9_-]{1,80}")
                .expect("Invalid regex pattern for SQSQueueName")
                .prop_map(|name| name.parse::<SQSQueueName>().unwrap())
                .boxed()
        }
    }

    /// Strategy yielding strings that `SQSQueueName::from_str` must reject. Each value
    /// is one of:
    /// - the empty string;
    /// - a name one character over the length limit;
    /// - 1 to 10 characters from the disallowed set `*?%!`.
    // Within this file the function is only used by the `#[cfg(test)]` prop tests, so it
    // is unused when the module is compiled solely for the `test-utils` feature.
    #[allow(dead_code)]
    pub fn sqs_name_arbitrary_invalid() -> BoxedStrategy<String> {
        let empty = Just(String::new());
        let oversized = Just("a".repeat(MAX_QUEUE_NAME_LENGTH + 1));
        let bad_symbols = proptest::string::string_regex("[*?%!]{1,10}")
            .expect("Invalid regex pattern for generating invalid SQSQueueName");

        prop_oneof![empty, oversized, bad_symbols].boxed()
    }
}
365
#[cfg(test)]
mod prop_tests {
    use super::test_support::*;
    use super::SQSQueueName;
    use proptest::prelude::*;
    use std::str::FromStr;

    proptest! {

        // Serializing any valid queue name to JSON and back yields the same name.
        #[test]
        fn sqs_name_round_trip_test(queue_name: SQSQueueName) {
            let serialized = serde_json::to_string(&queue_name).expect("SQS queue name should be valid");
            let deserialized: SQSQueueName = serde_json::from_str(&serialized).expect("Input json should be valid");
            prop_assert_eq!(deserialized, queue_name);
        }

        // Every string from `sqs_name_arbitrary_invalid` is rejected by `from_str`.
        // `prop_assert!` reports a failure to the runner instead of panicking.
        #[test]
        fn test_invalid_sqs_name(invalid_str in sqs_name_arbitrary_invalid()) {
            let parsed = SQSQueueName::from_str(&invalid_str);
            prop_assert!(
                parsed.is_err(),
                "expected {:?} to be rejected, got {:?}",
                invalid_str,
                parsed
            );
        }

    }
}