payload_offloading_for_aws/offload/
sqs.rs

1use std::collections::HashMap;
2
3use super::id_provider::{IdProvider, RandomUuidProvider};
4use super::offloading::{s3_client, try_offload_body_blocking, OFFLOADED_MARKER_ATTRIBUTE};
5use super::{error::OffloadInterceptorError, offloading::deserialize_s3_pointer};
6use aws_config::SdkConfig;
7use aws_sdk_s3::Client as S3Client;
8use aws_sdk_sqs::types::{MessageAttributeValue, SendMessageBatchRequestEntry};
9use aws_sdk_sqs::{
10    config::{Config as SqsConfig, Intercept, RuntimeComponents},
11    operation::{
12        receive_message::ReceiveMessageOutput, send_message::SendMessageInput,
13        send_message_batch::SendMessageBatchInput,
14    },
15    Client as SqsClient,
16};
17use aws_smithy_runtime_api::client::interceptors::context::{self};
18use aws_smithy_types::config_bag::ConfigBag;
19use tracing::error;
20
21#[derive(Debug)]
22pub struct S3OffloadInterceptor<Idp: IdProvider> {
23    s3_client: S3Client,
24    id_provider: Idp,
25    bucket_name: String,
26    max_body_size: usize,
27}
28
29impl<Idp: IdProvider> S3OffloadInterceptor<Idp> {
30    pub fn new(
31        aws_config: &SdkConfig,
32        id_provider: Idp,
33        bucket_name: String,
34        max_body_size: usize,
35    ) -> Self {
36        S3OffloadInterceptor {
37            s3_client: s3_client(aws_config),
38            id_provider,
39            bucket_name,
40            max_body_size,
41        }
42    }
43}
44
45impl<Idp: IdProvider + Sync + Send + std::fmt::Debug> Intercept for S3OffloadInterceptor<Idp> {
46    fn name(&self) -> &'static str {
47        "SQSS3OffloadInterceptor"
48    }
49
50    fn modify_before_serialization(
51        &self,
52        context: &mut context::BeforeSerializationInterceptorContextMut<'_>,
53        _runtime_components: &RuntimeComponents,
54        _cfg: &mut ConfigBag,
55    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
56        tracing::info!("Modifying context before serialization: {:?}", context);
57        let input_mut = context.input_mut();
58        if let Some(original_input) = input_mut.downcast_mut::<SendMessageInput>() {
59            if let Some(modified_input) = try_offload_sqs_send_message_input(
60                original_input,
61                &self.s3_client,
62                &self.id_provider,
63                &self.bucket_name,
64                self.max_body_size,
65            )? {
66                tracing::debug!("Body modified: {:?}", modified_input);
67                *original_input = modified_input;
68            }
69        } else if let Some(original_input) = input_mut.downcast_mut::<SendMessageBatchInput>() {
70            if let Some(modified_input) = try_offload_sqs_send_message_batch_input(
71                original_input,
72                &self.s3_client,
73                &self.id_provider,
74                &self.bucket_name,
75                self.max_body_size,
76            )? {
77                tracing::debug!("Batch Body modified: {:?}", modified_input);
78                *original_input = modified_input;
79            }
80        }
81
82        Ok(())
83    }
84
85    fn modify_before_completion(
86        &self,
87        context: &mut context::FinalizerInterceptorContextMut<'_>,
88        _runtime_components: &RuntimeComponents,
89        _cfg: &mut ConfigBag,
90    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
91        if let Some(Ok(output_mut)) = context.output_or_error_mut() {
92            if let Some(original_output) = output_mut.downcast_mut::<ReceiveMessageOutput>() {
93                if original_output.messages().is_empty() {
94                    return Ok(());
95                }
96
97                for m in original_output.messages.as_mut().ok_or(Box::new(
98                    OffloadInterceptorError::DeserialisationError(
99                        "unable to retrieve messages".to_owned(),
100                    ),
101                ))? {
102                    if let Some(orig_body) = m.body.as_ref() {
103                        let downloaded_body_res = try_downloading_body(&self.s3_client, orig_body);
104                        match downloaded_body_res {
105                            Ok(Some(downloaded_body)) => {
106                                m.body = Some(downloaded_body);
107                            }
108                            Err(e) => {
109                                error!("Error downloading batch entry body \"{}\": {}. Failing back to offloaded body", orig_body, e);
110                            }
111                            _ => {}
112                        }
113                    }
114                }
115            }
116        }
117
118        Ok(())
119    }
120}
121
122fn try_offload_sqs_send_message_input<Idp: IdProvider>(
123    original_input: &SendMessageInput,
124    s3_client: &S3Client,
125    id_provider: &Idp,
126    bucket_name: &str,
127    max_body_size: usize,
128) -> Result<Option<SendMessageInput>, OffloadInterceptorError> {
129    let maybe_modified_body_and_size = original_input.message_body().and_then(|body| {
130        try_offload_body_blocking(body, s3_client, id_provider, bucket_name, max_body_size)
131            .inspect_err(|e| error!("Error offloading content: {}", e))
132            .ok()
133            .flatten()
134            .map(|rr| (rr, body.len()))
135    });
136
137    let maybe_modified_input = maybe_modified_body_and_size.map(|(body, original_length)| {
138        let mut modified_builder = original_input.to_owned();
139        modified_builder.message_body = Some(body);
140
141        modified_builder.message_attributes = Some(put_offloaded_content_marker_attribute(
142            modified_builder.message_attributes,
143            original_length,
144        )?);
145
146        Ok(modified_builder)
147    });
148
149    maybe_modified_input.transpose()
150}
151
152fn try_offload_sqs_send_message_batch_input<Idp: IdProvider>(
153    original_input: &mut SendMessageBatchInput,
154    s3_client: &S3Client,
155    id_provider: &Idp,
156    bucket_name: &str,
157    max_body_size: usize,
158) -> Result<Option<SendMessageBatchInput>, OffloadInterceptorError> {
159    let any_offload_candidates = original_input
160        .entries()
161        .iter()
162        .any(|e| e.message_body().len() > max_body_size);
163
164    if !any_offload_candidates {
165        return Ok(None);
166    }
167
168    let mut modified_builder = SendMessageBatchInput::builder();
169    if let Some(value) = original_input.queue_url() {
170        modified_builder = modified_builder.queue_url(value);
171    }
172
173    modified_builder =
174        modified_builder.set_entries(original_input.entries.as_ref().map(|entries| {
175            entries
176                .iter()
177                .flat_map(|orig_entry| {
178                    let offloaded = try_offload_body_blocking(
179                        orig_entry.message_body(),
180                        s3_client,
181                        id_provider,
182                        bucket_name,
183                        max_body_size,
184                    );
185
186                    let mut modified_entry = orig_entry.to_owned();
187
188                    match offloaded {
189                        Ok(Some(offloaded_body)) => {
190                            modified_entry.message_body = offloaded_body;
191                            modified_entry.message_attributes =
192                                Some(put_offloaded_content_marker_attribute(
193                                    modified_entry.message_attributes,
194                                    orig_entry.message_body().len(),
195                                )?);
196                        }
197                        Err(e) => {
198                            error!(
199                                "Error offloading batch entry body \"{}\": {}. \
200                                Failing back to original body",
201                                orig_entry.message_body(),
202                                e
203                            );
204                        }
205                        _ => {}
206                    }
207
208                    Ok::<SendMessageBatchRequestEntry, OffloadInterceptorError>(modified_entry)
209                })
210                .collect()
211        }));
212
213    modified_builder
214        .build()
215        .map_err(|e| OffloadInterceptorError::FailedToBuildType(e.to_string()))
216        .map(Some)
217}
218
219pub fn try_downloading_body(
220    s3_client: &S3Client,
221    b: &str,
222) -> Result<Option<String>, OffloadInterceptorError> {
223    if !b.contains("PayloadS3Pointer") {
224        return Ok(None);
225    }
226
227    let deserialized_ptr = deserialize_s3_pointer(b)?;
228
229    Ok(Some(crate::offload::offloading::download_from_s3(
230        s3_client,
231        &deserialized_ptr,
232    )?))
233}
234
235pub fn offloading_client(
236    aws_config: &SdkConfig,
237    offloading_bucket: &str,
238    max_non_offloaded_size: usize,
239) -> SqsClient {
240    let s3_offload_interceptor = S3OffloadInterceptor::new(
241        aws_config,
242        RandomUuidProvider::default(),
243        offloading_bucket.to_owned(),
244        max_non_offloaded_size,
245    );
246
247    SqsClient::from_conf(
248        SqsConfig::new(aws_config)
249            .to_builder()
250            .interceptor(s3_offload_interceptor)
251            .build(),
252    )
253}
254
255// Need to set that for extended client libs to pick it up and download the contents automatically.
256//   See: https://github.com/awslabs/amazon-sqs-java-extended-client-lib/blob/07d988c424dea7e4e7d128b217182e1414310560/src/main/java/com/amazon/sqs/javamessaging/SQSExtendedClientConstants.java#L24
257fn put_offloaded_content_marker_attribute(
258    message_attributes: Option<HashMap<String, MessageAttributeValue>>,
259    original_body_length: usize,
260) -> Result<HashMap<String, MessageAttributeValue>, OffloadInterceptorError> {
261    let mut modified_attributes = message_attributes.clone().unwrap_or_default();
262    modified_attributes.insert(
263        OFFLOADED_MARKER_ATTRIBUTE.to_owned(),
264        MessageAttributeValue::builder()
265            .set_data_type(Some("Number".to_owned()))
266            .set_string_value(Some(original_body_length.to_string()))
267            .build()
268            .map_err(|e| {
269                OffloadInterceptorError::FailedToRewriteContents(format!(
270                    "Error while building sqs message attributes {e}"
271                ))
272            })?,
273    );
274    Ok(modified_attributes)
275}
276
#[cfg(test)]
mod tests {
    use core::str;

    use aws_sdk_sqs::{
        config::Credentials as SqsCredentials, types::SendMessageBatchRequestEntry,
        Client as SqsClient, Config,
    };
    use ctor::ctor;
    use tracing::info;
    use wiremock::{
        matchers::{header, method, path},
        Mock, MockServer, Request, ResponseTemplate,
    };

    use crate::offload::{
        id_provider::FixedIdsProvider,
        sqs::S3OffloadInterceptor,
        test::{
            self, expect_no_s3_put_calls, given_s3_put_fails, given_s3_put_succeeds,
            mock_aws_endpoint_config, s3_config,
        },
    };

    const TEST_BUCKET: &str = "my-bucket";
    const TEST_QUEUE: &str = "my-queue";
    const TEST_RANDOM_ID: &str = "my-id-123";

    #[ctor]
    fn before_all() {
        test::init();
    }

    // The interceptor performs blocking S3 calls; the default single-threaded tokio test
    // runtime would block on them, so every test uses the `multi_thread` flavor.
    #[tokio::test(flavor = "multi_thread")]
    async fn send_batch_no_offload_needed() {
        const MSG_BODY: &str = "send-batch-no-offload-msg";

        let (mock_server, sqs_client) = setup_sdk_client(MSG_BODY.len()).await;

        expect_no_s3_put_calls(&mock_server).await;
        expect_final_sqs_batch_body(&mock_server, MSG_BODY).await;

        when_sending_sqs_batch_with_body(&sqs_client, MSG_BODY).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn send_batch_if_offload_fails_then_original_sent() {
        const MSG_BODY: &str = "send-batch-offload_fails";

        let (mock_server, sqs_client) = setup_sdk_client(MSG_BODY.len() - 1).await;

        given_s3_put_fails(&mock_server).await;
        expect_final_sqs_batch_body(&mock_server, MSG_BODY).await;
        when_sending_sqs_batch_with_body(&sqs_client, MSG_BODY).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn send_batch_offload_succeeds() {
        const MSG_BODY: &str = "send-batch-offload-succeeds";
        let expected_offload_fragment = offloaded_payload(TEST_BUCKET, TEST_RANDOM_ID);

        let (mock_server, sqs_client) = setup_sdk_client(MSG_BODY.len() - 1).await;

        given_s3_put_succeeds(&mock_server).await;
        expect_final_sqs_batch_body(&mock_server, &expected_offload_fragment).await;
        when_sending_sqs_batch_with_body(&sqs_client, MSG_BODY).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn send_message_no_offload_needed() {
        const MSG_BODY: &str = "send-message-no-offload-msg";

        let (mock_server, sqs_client) = setup_sdk_client(MSG_BODY.len()).await;

        expect_no_s3_put_calls(&mock_server).await;
        expect_final_sqs_message_body(&mock_server, MSG_BODY).await;

        when_sending_sqs_message_with_body(&sqs_client, MSG_BODY).await;
    }

    // Mirrors `send_batch_if_offload_fails_then_original_sent` for the single-message path.
    #[tokio::test(flavor = "multi_thread")]
    async fn send_message_if_offload_fails_then_original_sent() {
        const MSG_BODY: &str = "send-message-offload-fails";

        let (mock_server, sqs_client) = setup_sdk_client(MSG_BODY.len() - 1).await;

        given_s3_put_fails(&mock_server).await;
        expect_final_sqs_message_body(&mock_server, MSG_BODY).await;

        when_sending_sqs_message_with_body(&sqs_client, MSG_BODY).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn send_message_offload_succeeds() {
        const MSG_BODY: &str = "send-message-offload-succeeds";
        let expected_offload_fragment = offloaded_payload(TEST_BUCKET, TEST_RANDOM_ID);

        let (mock_server, sqs_client) = setup_sdk_client(MSG_BODY.len() - 1).await;

        given_s3_put_succeeds(&mock_server).await;
        expect_final_sqs_message_body(&mock_server, &expected_offload_fragment).await;

        when_sending_sqs_message_with_body(&sqs_client, MSG_BODY).await;
    }

    /// JSON fragment (as escaped inside the SQS request body) of an offloaded message: the
    /// S3 pointer body followed by the start of the marker attribute.
    fn offloaded_payload(bucket_name: &str, bucket_key: &str) -> String {
        format!(
            "[\\\"software.amazon.payloadoffloading.PayloadS3Pointer\\\",\
            {{\\\"s3BucketName\\\": \\\"{bucket_name}\\\",\\\"s3Key\\\": \\\"{bucket_key}\\\"}}]\",\
            \"MessageAttributes\":{{\"ExtendedPayloadSize\""
        )
    }

    /// Starts a mock server and builds an SQS client, with the offload interceptor installed,
    /// that talks to it.
    async fn setup_sdk_client(max_non_offload_size: usize) -> (MockServer, SqsClient) {
        let mock_server = MockServer::start().await;

        let base_endpoint_url = &mock_server.uri();

        info!("Base endpoint url: {}", base_endpoint_url);
        let s3_offload_interceptor = S3OffloadInterceptor::new(
            &s3_config(base_endpoint_url).await,
            FixedIdsProvider::new(vec![TEST_RANDOM_ID]),
            TEST_BUCKET.to_owned(),
            max_non_offload_size,
        );

        let config = Config::new(
            &mock_aws_endpoint_config(base_endpoint_url, "sqs")
                .credentials_provider(SqsCredentials::for_tests())
                .load()
                .await,
        )
        .to_builder()
        .interceptor(s3_offload_interceptor)
        .build();

        let client = SqsClient::from_conf(config);

        (mock_server, client)
    }

    /// Expects exactly one `SendMessageBatch` call whose body contains `message_body` as the
    /// first entry's body.
    async fn expect_final_sqs_batch_body(mock_server: &MockServer, message_body: &str) {
        let expected_body_fragment = format!(
            r#"{{"QueueUrl":"{}","Entries":[{{"Id":"someid","MessageBody":"{}"#,
            TEST_QUEUE, message_body
        );

        Mock::given(method("POST"))
            .and(path("/sqs/"))
            .and(header("x-amz-target", "AmazonSQS.SendMessageBatch"))
            .respond_with(move |r: &Request| {
                info!("SQS Request: {:?}", r);

                let intercepted_body = str::from_utf8(&r.body).unwrap();
                info!("SQS Body: {:?}", intercepted_body);

                assert!(
                    intercepted_body.contains(&expected_body_fragment),
                    "Body does not contain expected fragment: \nbody=\"{intercepted_body}\", \nexpected=\"{expected_body_fragment}\""
                );

                ResponseTemplate::new(200).set_body_raw("{}", "application/json")
            })
            .expect(1)
            .named("Final SQS batch body")
            .mount(mock_server)
            .await;
    }

    /// Expects exactly one `SendMessage` call whose body contains `message_body`.
    async fn expect_final_sqs_message_body(mock_server: &MockServer, message_body: &str) {
        let expected_body_fragment = format!(
            r#"{{"QueueUrl":"{}","MessageBody":"{}"#,
            TEST_QUEUE, message_body
        );

        Mock::given(method("POST"))
            .and(path("/sqs/"))
            .and(header("x-amz-target", "AmazonSQS.SendMessage"))
            .respond_with(move |r: &Request| {
                info!("SQS Request: {:?}", r);

                let intercepted_body = str::from_utf8(&r.body).unwrap();
                info!("SQS Body: {:?}", intercepted_body);

                assert!(
                    intercepted_body.contains(&expected_body_fragment),
                    "Body does not contain expected fragment: \nbody=\"{intercepted_body}\", \nexpected=\"{expected_body_fragment}\""
                );

                ResponseTemplate::new(200).set_body_raw("{}", "application/json")
            })
            .expect(1)
            .named("Final SQS body")
            .mount(mock_server)
            .await;
    }

    async fn when_sending_sqs_batch_with_body(sqs_client: &SqsClient, message_body: &str) {
        let _res: aws_sdk_sqs::operation::send_message_batch::SendMessageBatchOutput = sqs_client
            .send_message_batch()
            .queue_url(TEST_QUEUE.to_owned())
            .set_entries(Some(vec![SendMessageBatchRequestEntry::builder()
                .id("someid")
                .message_body(message_body)
                .build()
                .unwrap()]))
            .send()
            .await
            .unwrap();
    }

    async fn when_sending_sqs_message_with_body(sqs_client: &SqsClient, message_body: &str) {
        let _res: aws_sdk_sqs::operation::send_message::SendMessageOutput = sqs_client
            .send_message()
            .queue_url(TEST_QUEUE.to_owned())
            .set_message_body(Some(message_body.to_owned()))
            .send()
            .await
            .unwrap();
    }
}