1use std::collections::HashMap;
2
3use super::id_provider::{IdProvider, RandomUuidProvider};
4use super::offloading::{s3_client, try_offload_body_blocking, OFFLOADED_MARKER_ATTRIBUTE};
5use super::{error::OffloadInterceptorError, offloading::deserialize_s3_pointer};
6use aws_config::SdkConfig;
7use aws_sdk_s3::Client as S3Client;
8use aws_sdk_sqs::types::{MessageAttributeValue, SendMessageBatchRequestEntry};
9use aws_sdk_sqs::{
10 config::{Config as SqsConfig, Intercept, RuntimeComponents},
11 operation::{
12 receive_message::ReceiveMessageOutput, send_message::SendMessageInput,
13 send_message_batch::SendMessageBatchInput,
14 },
15 Client as SqsClient,
16};
17use aws_smithy_runtime_api::client::interceptors::context::{self};
18use aws_smithy_types::config_bag::ConfigBag;
19use tracing::error;
20
21#[derive(Debug)]
22pub struct S3OffloadInterceptor<Idp: IdProvider> {
23 s3_client: S3Client,
24 id_provider: Idp,
25 bucket_name: String,
26 max_body_size: usize,
27}
28
29impl<Idp: IdProvider> S3OffloadInterceptor<Idp> {
30 pub fn new(
31 aws_config: &SdkConfig,
32 id_provider: Idp,
33 bucket_name: String,
34 max_body_size: usize,
35 ) -> Self {
36 S3OffloadInterceptor {
37 s3_client: s3_client(aws_config),
38 id_provider,
39 bucket_name,
40 max_body_size,
41 }
42 }
43}
44
45impl<Idp: IdProvider + Sync + Send + std::fmt::Debug> Intercept for S3OffloadInterceptor<Idp> {
46 fn name(&self) -> &'static str {
47 "SQSS3OffloadInterceptor"
48 }
49
50 fn modify_before_serialization(
51 &self,
52 context: &mut context::BeforeSerializationInterceptorContextMut<'_>,
53 _runtime_components: &RuntimeComponents,
54 _cfg: &mut ConfigBag,
55 ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
56 tracing::info!("Modifying context before serialization: {:?}", context);
57 let input_mut = context.input_mut();
58 if let Some(original_input) = input_mut.downcast_mut::<SendMessageInput>() {
59 if let Some(modified_input) = try_offload_sqs_send_message_input(
60 original_input,
61 &self.s3_client,
62 &self.id_provider,
63 &self.bucket_name,
64 self.max_body_size,
65 )? {
66 tracing::debug!("Body modified: {:?}", modified_input);
67 *original_input = modified_input;
68 }
69 } else if let Some(original_input) = input_mut.downcast_mut::<SendMessageBatchInput>() {
70 if let Some(modified_input) = try_offload_sqs_send_message_batch_input(
71 original_input,
72 &self.s3_client,
73 &self.id_provider,
74 &self.bucket_name,
75 self.max_body_size,
76 )? {
77 tracing::debug!("Batch Body modified: {:?}", modified_input);
78 *original_input = modified_input;
79 }
80 }
81
82 Ok(())
83 }
84
85 fn modify_before_completion(
86 &self,
87 context: &mut context::FinalizerInterceptorContextMut<'_>,
88 _runtime_components: &RuntimeComponents,
89 _cfg: &mut ConfigBag,
90 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
91 if let Some(Ok(output_mut)) = context.output_or_error_mut() {
92 if let Some(original_output) = output_mut.downcast_mut::<ReceiveMessageOutput>() {
93 if original_output.messages().is_empty() {
94 return Ok(());
95 }
96
97 for m in original_output.messages.as_mut().ok_or(Box::new(
98 OffloadInterceptorError::DeserialisationError(
99 "unable to retrieve messages".to_owned(),
100 ),
101 ))? {
102 if let Some(orig_body) = m.body.as_ref() {
103 let downloaded_body_res = try_downloading_body(&self.s3_client, orig_body);
104 match downloaded_body_res {
105 Ok(Some(downloaded_body)) => {
106 m.body = Some(downloaded_body);
107 }
108 Err(e) => {
109 error!("Error downloading batch entry body \"{}\": {}. Failing back to offloaded body", orig_body, e);
110 }
111 _ => {}
112 }
113 }
114 }
115 }
116 }
117
118 Ok(())
119 }
120}
121
122fn try_offload_sqs_send_message_input<Idp: IdProvider>(
123 original_input: &SendMessageInput,
124 s3_client: &S3Client,
125 id_provider: &Idp,
126 bucket_name: &str,
127 max_body_size: usize,
128) -> Result<Option<SendMessageInput>, OffloadInterceptorError> {
129 let maybe_modified_body_and_size = original_input.message_body().and_then(|body| {
130 try_offload_body_blocking(body, s3_client, id_provider, bucket_name, max_body_size)
131 .inspect_err(|e| error!("Error offloading content: {}", e))
132 .ok()
133 .flatten()
134 .map(|rr| (rr, body.len()))
135 });
136
137 let maybe_modified_input = maybe_modified_body_and_size.map(|(body, original_length)| {
138 let mut modified_builder = original_input.to_owned();
139 modified_builder.message_body = Some(body);
140
141 modified_builder.message_attributes = Some(put_offloaded_content_marker_attribute(
142 modified_builder.message_attributes,
143 original_length,
144 )?);
145
146 Ok(modified_builder)
147 });
148
149 maybe_modified_input.transpose()
150}
151
152fn try_offload_sqs_send_message_batch_input<Idp: IdProvider>(
153 original_input: &mut SendMessageBatchInput,
154 s3_client: &S3Client,
155 id_provider: &Idp,
156 bucket_name: &str,
157 max_body_size: usize,
158) -> Result<Option<SendMessageBatchInput>, OffloadInterceptorError> {
159 let any_offload_candidates = original_input
160 .entries()
161 .iter()
162 .any(|e| e.message_body().len() > max_body_size);
163
164 if !any_offload_candidates {
165 return Ok(None);
166 }
167
168 let mut modified_builder = SendMessageBatchInput::builder();
169 if let Some(value) = original_input.queue_url() {
170 modified_builder = modified_builder.queue_url(value);
171 }
172
173 modified_builder =
174 modified_builder.set_entries(original_input.entries.as_ref().map(|entries| {
175 entries
176 .iter()
177 .flat_map(|orig_entry| {
178 let offloaded = try_offload_body_blocking(
179 orig_entry.message_body(),
180 s3_client,
181 id_provider,
182 bucket_name,
183 max_body_size,
184 );
185
186 let mut modified_entry = orig_entry.to_owned();
187
188 match offloaded {
189 Ok(Some(offloaded_body)) => {
190 modified_entry.message_body = offloaded_body;
191 modified_entry.message_attributes =
192 Some(put_offloaded_content_marker_attribute(
193 modified_entry.message_attributes,
194 orig_entry.message_body().len(),
195 )?);
196 }
197 Err(e) => {
198 error!(
199 "Error offloading batch entry body \"{}\": {}. \
200 Failing back to original body",
201 orig_entry.message_body(),
202 e
203 );
204 }
205 _ => {}
206 }
207
208 Ok::<SendMessageBatchRequestEntry, OffloadInterceptorError>(modified_entry)
209 })
210 .collect()
211 }));
212
213 modified_builder
214 .build()
215 .map_err(|e| OffloadInterceptorError::FailedToBuildType(e.to_string()))
216 .map(Some)
217}
218
219pub fn try_downloading_body(
220 s3_client: &S3Client,
221 b: &str,
222) -> Result<Option<String>, OffloadInterceptorError> {
223 if !b.contains("PayloadS3Pointer") {
224 return Ok(None);
225 }
226
227 let deserialized_ptr = deserialize_s3_pointer(b)?;
228
229 Ok(Some(crate::offload::offloading::download_from_s3(
230 s3_client,
231 &deserialized_ptr,
232 )?))
233}
234
235pub fn offloading_client(
236 aws_config: &SdkConfig,
237 offloading_bucket: &str,
238 max_non_offloaded_size: usize,
239) -> SqsClient {
240 let s3_offload_interceptor = S3OffloadInterceptor::new(
241 aws_config,
242 RandomUuidProvider::default(),
243 offloading_bucket.to_owned(),
244 max_non_offloaded_size,
245 );
246
247 SqsClient::from_conf(
248 SqsConfig::new(aws_config)
249 .to_builder()
250 .interceptor(s3_offload_interceptor)
251 .build(),
252 )
253}
254
255fn put_offloaded_content_marker_attribute(
258 message_attributes: Option<HashMap<String, MessageAttributeValue>>,
259 original_body_length: usize,
260) -> Result<HashMap<String, MessageAttributeValue>, OffloadInterceptorError> {
261 let mut modified_attributes = message_attributes.clone().unwrap_or_default();
262 modified_attributes.insert(
263 OFFLOADED_MARKER_ATTRIBUTE.to_owned(),
264 MessageAttributeValue::builder()
265 .set_data_type(Some("Number".to_owned()))
266 .set_string_value(Some(original_body_length.to_string()))
267 .build()
268 .map_err(|e| {
269 OffloadInterceptorError::FailedToRewriteContents(format!(
270 "Error while building sqs message attributes {e}"
271 ))
272 })?,
273 );
274 Ok(modified_attributes)
275}
276
/// End-to-end tests for the SQS offload interceptor.
///
/// Both S3 and SQS are served by a single `wiremock` server; the SQS mock
/// asserts on the request body that finally leaves the client.
#[cfg(test)]
mod tests {
    use core::str;

    use aws_sdk_sqs::{
        config::Credentials as SqsCredentials, types::SendMessageBatchRequestEntry,
        Client as SqsClient, Config,
    };
    use ctor::ctor;
    use tracing::info;
    use wiremock::{
        matchers::{header, method, path},
        Mock, MockServer, Request, ResponseTemplate,
    };

    use crate::offload::{
        id_provider::FixedIdsProvider,
        sqs::S3OffloadInterceptor,
        test::{
            self, expect_no_s3_put_calls, given_s3_put_fails, given_s3_put_succeeds,
            mock_aws_endpoint_config, s3_config,
        },
    };

    const TEST_BUCKET: &str = "my-bucket";
    const TEST_QUEUE: &str = "my-queue";
    const TEST_RANDOM_ID: &str = "my-id-123";

    /// One-time test setup, run before any test in this binary.
    #[ctor]
    fn before_all() {
        test::init();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn send_batch_no_offload_needed() {
        const MSG_BODY: &str = "send-batch-no-offload-msg";

        let (mock_server, sqs_client) = setup_sdk_client(MSG_BODY.len()).await;

        expect_no_s3_put_calls(&mock_server).await;
        expect_final_sqs_batch_body(&mock_server, MSG_BODY).await;

        when_sending_sqs_batch_with_body(&sqs_client, MSG_BODY).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn send_batch_if_offload_fails_then_original_sent() {
        const MSG_BODY: &str = "send-batch-offload_fails";

        let (mock_server, sqs_client) = setup_sdk_client(MSG_BODY.len() - 1).await;

        given_s3_put_fails(&mock_server).await;
        expect_final_sqs_batch_body(&mock_server, MSG_BODY).await;

        when_sending_sqs_batch_with_body(&sqs_client, MSG_BODY).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn send_batch_offload_succeeds() {
        const MSG_BODY: &str = "send-batch-offload-succeeds";
        let expected_offload_fragment = offloaded_payload(TEST_BUCKET, TEST_RANDOM_ID);

        let (mock_server, sqs_client) = setup_sdk_client(MSG_BODY.len() - 1).await;

        given_s3_put_succeeds(&mock_server).await;
        expect_final_sqs_batch_body(&mock_server, &expected_offload_fragment).await;

        when_sending_sqs_batch_with_body(&sqs_client, MSG_BODY).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn send_message_no_offload_needed() {
        const MSG_BODY: &str = "send-message-no-offload-msg";

        let (mock_server, sqs_client) = setup_sdk_client(MSG_BODY.len()).await;

        expect_no_s3_put_calls(&mock_server).await;
        expect_final_sqs_message_body(&mock_server, MSG_BODY).await;

        when_sending_sqs_message_with_body(&sqs_client, MSG_BODY).await;
    }

    /// Single-message counterpart of `send_batch_if_offload_fails_then_original_sent`.
    #[tokio::test(flavor = "multi_thread")]
    async fn send_message_if_offload_fails_then_original_sent() {
        const MSG_BODY: &str = "send-message-offload_fails";

        let (mock_server, sqs_client) = setup_sdk_client(MSG_BODY.len() - 1).await;

        given_s3_put_fails(&mock_server).await;
        expect_final_sqs_message_body(&mock_server, MSG_BODY).await;

        when_sending_sqs_message_with_body(&sqs_client, MSG_BODY).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn send_message_offload_succeeds() {
        const MSG_BODY: &str = "send-message-offload-succeeds";
        let expected_offload_fragment = offloaded_payload(TEST_BUCKET, TEST_RANDOM_ID);

        let (mock_server, sqs_client) = setup_sdk_client(MSG_BODY.len() - 1).await;

        given_s3_put_succeeds(&mock_server).await;
        expect_final_sqs_message_body(&mock_server, &expected_offload_fragment).await;

        when_sending_sqs_message_with_body(&sqs_client, MSG_BODY).await;
    }

    /// JSON-escaped S3 pointer, followed by the start of the marker attribute,
    /// as it is expected to appear inside the serialized SQS request.
    fn offloaded_payload(bucket_name: &str, bucket_key: &str) -> String {
        format!(
            "[\\\"software.amazon.payloadoffloading.PayloadS3Pointer\\\",\
            {{\\\"s3BucketName\\\": \\\"{bucket_name}\\\",\\\"s3Key\\\": \\\"{bucket_key}\\\"}}]\",\
            \"MessageAttributes\":{{\"ExtendedPayloadSize\""
        )
    }

    /// Starts the mock server and returns it together with an SQS client that
    /// has the interceptor installed with the given size threshold.
    async fn setup_sdk_client(max_non_offload_size: usize) -> (MockServer, aws_sdk_sqs::Client) {
        let mock_server = MockServer::start().await;

        let base_endpoint_url = &mock_server.uri();

        info!("Base endpoint url: {}", base_endpoint_url);
        let s3_offload_interceptor = S3OffloadInterceptor::new(
            &s3_config(base_endpoint_url).await,
            FixedIdsProvider::new(vec![TEST_RANDOM_ID]),
            TEST_BUCKET.to_owned(),
            max_non_offload_size,
        );

        let config = Config::new(
            &mock_aws_endpoint_config(base_endpoint_url, "sqs")
                .credentials_provider(SqsCredentials::for_tests())
                .load()
                .await,
        )
        .to_builder()
        .interceptor(s3_offload_interceptor)
        .build();

        let client = aws_sdk_sqs::Client::from_conf(config);

        (mock_server, client)
    }

    /// Mounts a mock that expects exactly one SQS request for `target` whose
    /// body contains `expected_body_fragment`.
    async fn expect_final_sqs_request_body(
        mock_server: &MockServer,
        target: &'static str,
        mock_name: &'static str,
        expected_body_fragment: String,
    ) {
        Mock::given(method("POST"))
            .and(path("/sqs/"))
            .and(header("x-amz-target", target))
            .respond_with(move |r: &Request| {
                info!("SQS Request: {:?}", r);

                let intercepted_body = str::from_utf8(&r.body).unwrap();
                info!("SQS Body: {:?}", intercepted_body);

                assert!(
                    intercepted_body.contains(&expected_body_fragment),
                    "Body does not contain expected fragment: \nbody=\"{intercepted_body}\", \nexpected=\"{expected_body_fragment}\""
                );

                ResponseTemplate::new(200).set_body_raw("{}", "application/json")
            })
            .expect(1)
            .named(mock_name)
            .mount(mock_server)
            .await;
    }

    /// Expects one `SendMessageBatch` request whose single entry body starts
    /// with `message_body`.
    async fn expect_final_sqs_batch_body(mock_server: &MockServer, message_body: &str) {
        let expected_body_fragment = format!(
            r#"{{"QueueUrl":"{}","Entries":[{{"Id":"someid","MessageBody":"{}"#,
            TEST_QUEUE, message_body
        );

        expect_final_sqs_request_body(
            mock_server,
            "AmazonSQS.SendMessageBatch",
            "Final SQS batch body",
            expected_body_fragment,
        )
        .await;
    }

    /// Expects one `SendMessage` request whose body starts with `message_body`.
    async fn expect_final_sqs_message_body(mock_server: &MockServer, message_body: &str) {
        let expected_body_fragment = format!(
            r#"{{"QueueUrl":"{}","MessageBody":"{}"#,
            TEST_QUEUE, message_body
        );

        expect_final_sqs_request_body(
            mock_server,
            "AmazonSQS.SendMessage",
            "Final SQS body",
            expected_body_fragment,
        )
        .await;
    }

    /// Sends a one-entry batch (entry id `someid`) and panics if the call fails.
    async fn when_sending_sqs_batch_with_body(sqs_client: &SqsClient, message_body: &str) {
        let _res: aws_sdk_sqs::operation::send_message_batch::SendMessageBatchOutput = sqs_client
            .send_message_batch()
            .queue_url(TEST_QUEUE.to_owned())
            .set_entries(Some(vec![SendMessageBatchRequestEntry::builder()
                .id("someid")
                .message_body(message_body)
                .build()
                .unwrap()]))
            .send()
            .await
            .unwrap();
    }

    /// Sends a single message and panics if the call fails.
    async fn when_sending_sqs_message_with_body(sqs_client: &SqsClient, message_body: &str) {
        let _res: aws_sdk_sqs::operation::send_message::SendMessageOutput = sqs_client
            .send_message()
            .queue_url(TEST_QUEUE.to_owned())
            .set_message_body(Some(message_body.to_owned()))
            .send()
            .await
            .unwrap();
    }
}