payload_offloading_for_aws/offload/
sns.rs

1use super::id_provider::RandomUuidProvider;
2use super::offloading::{s3_client, try_offload_body_blocking, OFFLOADED_MARKER_ATTRIBUTE};
3use super::{error::OffloadInterceptorError, id_provider::IdProvider};
4use aws_config::SdkConfig;
5use aws_sdk_s3::Client as S3Client;
6use aws_sdk_sns::types::{MessageAttributeValue, PublishBatchRequestEntry};
7use aws_sdk_sns::{
8    config::{Config as SnsConfig, Intercept, RuntimeComponents},
9    operation::{publish::PublishInput, publish_batch::PublishBatchInput},
10    Client as SnsClient,
11};
12use aws_smithy_runtime_api::client::interceptors::context::{self};
13use aws_smithy_types::config_bag::ConfigBag;
14use core::str;
15use std::collections::HashMap;
16use tracing::{error, info};
17
18#[derive(Debug)]
19pub struct S3OffloadInterceptor<Idp: IdProvider> {
20    s3_client: S3Client,
21    id_provider: Idp,
22    bucket_name: String,
23    max_body_size: usize,
24}
25
26impl<Idp: IdProvider + Sync + Send + std::fmt::Debug> S3OffloadInterceptor<Idp> {
27    pub fn new(
28        aws_config: &SdkConfig,
29        id_provider: Idp,
30        bucket_name: String,
31        max_body_size: usize,
32    ) -> Self {
33        S3OffloadInterceptor {
34            s3_client: s3_client(aws_config),
35            id_provider,
36            bucket_name,
37            max_body_size,
38        }
39    }
40}
41
42impl<Idp: IdProvider + Sync + Send + std::fmt::Debug> Intercept for S3OffloadInterceptor<Idp> {
43    fn name(&self) -> &'static str {
44        "SNSS3OffloadInterceptor"
45    }
46
47    fn modify_before_serialization(
48        &self,
49        context: &mut context::BeforeSerializationInterceptorContextMut<'_>,
50        _runtime_components: &RuntimeComponents,
51        _cfg: &mut ConfigBag,
52    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
53        info!("Modifying context before serialization: {:?}", context);
54        let input_mut = context.input_mut();
55        if let Some(original_input) = input_mut.downcast_mut::<PublishInput>() {
56            if let Some(modified_input) = try_offload_sns_publish_input(
57                original_input,
58                &self.s3_client,
59                &self.id_provider,
60                &self.bucket_name,
61                self.max_body_size,
62            )? {
63                info!("Sns PublishInput Body modified: {:?}", modified_input);
64                *original_input = modified_input;
65            }
66        } else if let Some(original_input) = input_mut.downcast_mut::<PublishBatchInput>() {
67            if let Some(modified_input) = try_offload_sns_publish_batch_input(
68                original_input,
69                &self.s3_client,
70                &self.id_provider,
71                &self.bucket_name,
72                self.max_body_size,
73            )? {
74                info!("Sns PublishBatchInput Body modified: {:?}", modified_input);
75                *original_input = modified_input;
76            }
77        }
78        Ok(())
79    }
80}
81
82fn try_offload_sns_publish_input<Idp: IdProvider>(
83    original_input: &PublishInput,
84    s3_client: &S3Client,
85    id_provider: &Idp,
86    bucket_name: &str,
87    max_body_size: usize,
88) -> Result<Option<PublishInput>, OffloadInterceptorError> {
89    let maybe_modified_body_and_size = original_input.message().and_then(|body| {
90        try_offload_body_blocking(body, s3_client, id_provider, bucket_name, max_body_size)
91            .inspect_err(|e| error!("Error offloading content: {}", e))
92            .ok()
93            .flatten()
94            .map(|rr| (rr, body.len()))
95    });
96
97    let maybe_modified_input = maybe_modified_body_and_size.map(|(body, original_length)| {
98        let mut modified_builder = original_input.to_owned();
99        modified_builder.message = Some(body);
100
101        modified_builder.message_attributes = Some(put_offloaded_content_marker_attribute(
102            modified_builder.message_attributes,
103            original_length,
104        )?);
105
106        Ok(modified_builder)
107    });
108
109    maybe_modified_input.transpose()
110}
111
112fn try_offload_sns_publish_batch_input<Idp: IdProvider>(
113    original_input: &mut PublishBatchInput,
114    s3_client: &S3Client,
115    id_provider: &Idp,
116    bucket_name: &str,
117    max_body_size: usize,
118) -> Result<Option<PublishBatchInput>, OffloadInterceptorError> {
119    let any_offload_candidates = original_input
120        .publish_batch_request_entries()
121        .iter()
122        .any(|e| e.message().len() > max_body_size);
123
124    if !any_offload_candidates {
125        return Ok(None);
126    }
127
128    let mut modified_builder = PublishBatchInput::builder();
129    if let Some(value) = original_input.topic_arn() {
130        modified_builder = modified_builder.topic_arn(value);
131    }
132
133    modified_builder = modified_builder.set_publish_batch_request_entries(
134        original_input
135            .publish_batch_request_entries
136            .as_ref()
137            .map(|entries| {
138                entries
139                    .iter()
140                    .flat_map(|orig_entry| {
141                        let offloaded = try_offload_body_blocking(
142                            orig_entry.message(),
143                            s3_client,
144                            id_provider,
145                            bucket_name,
146                            max_body_size,
147                        );
148
149                        let mut modified_entry = orig_entry.to_owned();
150
151                        match offloaded {
152                            Ok(Some(offloaded_body)) => {
153                                modified_entry.message = offloaded_body;
154
155                                modified_entry.message_attributes = Some(put_offloaded_content_marker_attribute(
156                                    modified_entry.message_attributes,
157                                    orig_entry.message().len(),
158                                )?)
159                            },
160                            Err(e) => {
161                                error!("Error offloading batch entry body \"{}\": {}. Failing back to original body", orig_entry.message(), e);
162                            },
163                            _ => {}
164                        }
165
166                        Ok::<PublishBatchRequestEntry, OffloadInterceptorError>(modified_entry)
167                    })
168                    .collect()
169            }),
170    );
171
172    modified_builder
173        .build()
174        .map_err(|e| OffloadInterceptorError::FailedToBuildType(e.to_string()))
175        .map(Some)
176}
177
178pub fn offloading_client(
179    aws_config: &SdkConfig,
180    offloading_bucket: &str,
181    max_non_offloaded_size: usize,
182) -> SnsClient {
183    let s3_offload_interceptor = S3OffloadInterceptor::new(
184        aws_config,
185        RandomUuidProvider::default(),
186        offloading_bucket.to_owned(),
187        max_non_offloaded_size,
188    );
189
190    SnsClient::from_conf(
191        SnsConfig::new(aws_config)
192            .to_builder()
193            .interceptor(s3_offload_interceptor)
194            .build(),
195    )
196}
197
198// Need to set that for extended client libs to pick it up and download the contents automatically.
199//   See: https://github.com/awslabs/amazon-sqs-java-extended-client-lib/blob/07d988c424dea7e4e7d128b217182e1414310560/src/main/java/com/amazon/sqs/javamessaging/SQSExtendedClientConstants.java#L24
200fn put_offloaded_content_marker_attribute(
201    message_attributes: Option<HashMap<String, MessageAttributeValue>>,
202    original_body_length: usize,
203) -> Result<HashMap<String, MessageAttributeValue>, OffloadInterceptorError> {
204    let mut modified_attributes = message_attributes.clone().unwrap_or_default();
205    modified_attributes.insert(
206        OFFLOADED_MARKER_ATTRIBUTE.to_owned(),
207        MessageAttributeValue::builder()
208            .set_data_type(Some("Number".to_owned()))
209            .set_string_value(Some(original_body_length.to_string()))
210            .build()
211            .map_err(|e| {
212                OffloadInterceptorError::FailedToRewriteContents(format!(
213                    "Error while building sqs message attributes {e}"
214                ))
215            })?,
216    );
217    Ok(modified_attributes)
218}
219
#[cfg(test)]
mod tests {
    use core::str;

    use aws_sdk_sns::{
        config::Credentials as SnsCredentials, types::PublishBatchRequestEntry,
        Client as SnsClient, Config as SnsConfig,
    };
    use ctor::ctor;
    use tracing::info;
    use wiremock::{
        matchers::{body_string_contains, method, path},
        Mock, MockServer, Request, ResponseTemplate,
    };

    use crate::offload::{
        id_provider::FixedIdsProvider,
        sns::S3OffloadInterceptor,
        test::{
            self, expect_no_s3_put_calls, given_s3_put_fails, given_s3_put_succeeds,
            mock_aws_endpoint_config, s3_config,
        },
    };

    const TEST_BUCKET: &str = "my-bucket";
    const TEST_TOPIC: &str = "my-topic";
    const TEST_RANDOM_ID: &str = "my-id-123";

    #[ctor]
    fn before_all() {
        test::init();
    }

    // By default test runs on 1 thread and blocks, need to use multi_thread for s3 offloads not to block
    #[tokio::test(flavor = "multi_thread")]
    async fn publish_batch_no_offload_needed() {
        const MSG_BODY: &str = "publish-batch-no-offload-msg";

        let (mock_server, sns_client) = setup_sdk_client(MSG_BODY.len()).await;

        expect_no_s3_put_calls(&mock_server).await;
        expect_final_sns_batch_body(&mock_server, MSG_BODY).await;

        when_calling_sns_publish_batch_with_body(&sns_client, MSG_BODY).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn publish_batch_if_offload_fails_then_original_sent() {
        const MSG_BODY: &str = "publish-batch-offload_fails";

        let (mock_server, sns_client) = setup_sdk_client(MSG_BODY.len() - 1).await;

        given_s3_put_fails(&mock_server).await;
        expect_final_sns_batch_body(&mock_server, MSG_BODY).await;

        when_calling_sns_publish_batch_with_body(&sns_client, MSG_BODY).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn publish_batch_offload_succeeds() {
        const MSG_BODY: &str = "publish-batch-offload-succeeds";
        let expected_offload_fragment = batch_offloaded_payload(TEST_BUCKET, TEST_RANDOM_ID);

        let (mock_server, sns_client) = setup_sdk_client(MSG_BODY.len() - 1).await;

        given_s3_put_succeeds(&mock_server).await;
        expect_final_sns_batch_body(&mock_server, &expected_offload_fragment).await;

        when_calling_sns_publish_batch_with_body(&sns_client, MSG_BODY).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn publish_message_no_offload_needed() {
        const MSG_BODY: &str = "publish-message-no-offload-msg";

        let (mock_server, sns_client) = setup_sdk_client(MSG_BODY.len()).await;

        expect_no_s3_put_calls(&mock_server).await;
        expect_final_sns_message_body(&mock_server, MSG_BODY).await;

        when_calling_sns_publish_message_with_body(&sns_client, MSG_BODY).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn publish_message_offload_succeeds() {
        const MSG_BODY: &str = "publish-message-offload-succeeds";
        let expected_offload_fragment = offloaded_payload(TEST_BUCKET, TEST_RANDOM_ID);

        let (mock_server, sns_client) = setup_sdk_client(MSG_BODY.len() - 1).await;

        given_s3_put_succeeds(&mock_server).await;
        expect_final_sns_message_body(&mock_server, &expected_offload_fragment).await;

        when_calling_sns_publish_message_with_body(&sns_client, MSG_BODY).await;
    }

    /// Starts a mock server and an SNS client (with the offload interceptor) pointed at it.
    async fn setup_sdk_client(max_non_offload_size: usize) -> (MockServer, SnsClient) {
        let mock_server = MockServer::start().await;

        let base_endpoint_url = &mock_server.uri();

        info!("Base endpoint url: {}", base_endpoint_url);
        let s3_offload_interceptor = S3OffloadInterceptor::new(
            &s3_config(base_endpoint_url).await,
            FixedIdsProvider::new(vec![TEST_RANDOM_ID]),
            TEST_BUCKET.to_owned(),
            max_non_offload_size,
        );

        let sns_config = SnsConfig::new(
            &mock_aws_endpoint_config(base_endpoint_url, "sns")
                .credentials_provider(SnsCredentials::for_tests())
                .load()
                .await,
        )
        .to_builder()
        .interceptor(s3_offload_interceptor)
        .build();

        let sns_client = SnsClient::from_conf(sns_config);

        (mock_server, sns_client)
    }

    /// Expects exactly one `PublishBatch` call whose first entry's message is `message_body`.
    async fn expect_final_sns_batch_body(mock_server: &MockServer, message_body: &str) {
        let expected_body_fragment =
            format!("PublishBatchRequestEntries.member.1.Message={message_body}");

        Mock::given(method("POST"))
            .and(path("/sns/"))
            .and(body_string_contains("Action=PublishBatch"))
            .respond_with(move |r: &Request| {
                info!("SNS Request: {:?}", r);

                let intercepted_body = str::from_utf8(&r.body).unwrap();
                info!("SNS Body: {:?}", intercepted_body);

                assert!(
                    intercepted_body.contains(&expected_body_fragment),
                    "Body does not contain expected fragment: \nbody=\"{intercepted_body}\", \nexpected=\"{expected_body_fragment}\""
                );

                ResponseTemplate::new(200).set_body_raw(
                    "<PublishBatchResponse><PublishBatchResult></PublishBatchResult></PublishBatchResponse>",
                    "text/xml",
                )
            })
            .expect(1)
            .named("Final SNS batch body")
            .mount(mock_server)
            .await;
    }

    /// Expects exactly one `Publish` call whose message is `message_body`.
    async fn expect_final_sns_message_body(mock_server: &MockServer, message_body: &str) {
        let expected_body_fragment = format!("Message={message_body}");

        Mock::given(method("POST"))
            .and(path("/sns/"))
            .and(body_string_contains("Action=Publish"))
            .respond_with(move |r: &Request| {
                info!("SNS Request: {:?}", r);

                let intercepted_body = str::from_utf8(&r.body).unwrap();
                info!("SNS Body: {:?}", intercepted_body);

                assert!(
                    intercepted_body.contains(&expected_body_fragment),
                    "Body does not contain expected fragment: \nbody=\"{intercepted_body}\", \nexpected=\"{expected_body_fragment}\""
                );

                ResponseTemplate::new(200).set_body_raw(
                    "<PublishResponse><PublishResult></PublishResult></PublishResponse>",
                    "text/xml",
                )
            })
            .expect(1)
            .named("Final SNS message body")
            .mount(mock_server)
            .await;
    }

    async fn when_calling_sns_publish_batch_with_body(sns_client: &SnsClient, message_body: &str) {
        let _res = sns_client
            .publish_batch()
            .topic_arn(TEST_TOPIC.to_owned())
            .set_publish_batch_request_entries(Some(vec![PublishBatchRequestEntry::builder()
                .id("someid")
                .message(message_body)
                .build()
                .unwrap()]))
            .send()
            .await
            .unwrap();
    }

    async fn when_calling_sns_publish_message_with_body(
        sns_client: &SnsClient,
        message_body: &str,
    ) {
        let _res = sns_client
            .publish()
            .topic_arn(TEST_TOPIC.to_owned())
            .set_message(Some(message_body.to_owned()))
            .send()
            .await
            .unwrap();
    }

    /// URL-encoded S3 pointer message that replaces an offloaded body.
    fn encoded_s3_pointer(bucket_name: &str, bucket_key: &str) -> String {
        let msg = format!(
            r#"["software.amazon.payloadoffloading.PayloadS3Pointer",{{"s3BucketName": "{bucket_name}","s3Key": "{bucket_key}"}}]"#
        );
        urlencoding::encode(&msg).into_owned()
    }

    /// Expected body fragment of an offloaded `Publish` request.
    fn offloaded_payload(bucket_name: &str, bucket_key: &str) -> String {
        format!(
            "{}&MessageAttributes.entry.1.Name=ExtendedPayloadSize",
            encoded_s3_pointer(bucket_name, bucket_key)
        )
    }

    /// Expected body fragment of an offloaded first entry of a `PublishBatch` request.
    fn batch_offloaded_payload(bucket_name: &str, bucket_key: &str) -> String {
        format!(
            "{}&PublishBatchRequestEntries.member.1.MessageAttributes\
            .entry.1.Name=ExtendedPayloadSize",
            encoded_s3_pointer(bucket_name, bucket_key)
        )
    }
}