1use super::id_provider::RandomUuidProvider;
2use super::offloading::{s3_client, try_offload_body_blocking, OFFLOADED_MARKER_ATTRIBUTE};
3use super::{error::OffloadInterceptorError, id_provider::IdProvider};
4use aws_config::SdkConfig;
5use aws_sdk_s3::Client as S3Client;
6use aws_sdk_sns::types::{MessageAttributeValue, PublishBatchRequestEntry};
7use aws_sdk_sns::{
8 config::{Config as SnsConfig, Intercept, RuntimeComponents},
9 operation::{publish::PublishInput, publish_batch::PublishBatchInput},
10 Client as SnsClient,
11};
12use aws_smithy_runtime_api::client::interceptors::context::{self};
13use aws_smithy_types::config_bag::ConfigBag;
14use core::str;
15use std::collections::HashMap;
16use tracing::{error, info};
17
18#[derive(Debug)]
19pub struct S3OffloadInterceptor<Idp: IdProvider> {
20 s3_client: S3Client,
21 id_provider: Idp,
22 bucket_name: String,
23 max_body_size: usize,
24}
25
26impl<Idp: IdProvider + Sync + Send + std::fmt::Debug> S3OffloadInterceptor<Idp> {
27 pub fn new(
28 aws_config: &SdkConfig,
29 id_provider: Idp,
30 bucket_name: String,
31 max_body_size: usize,
32 ) -> Self {
33 S3OffloadInterceptor {
34 s3_client: s3_client(aws_config),
35 id_provider,
36 bucket_name,
37 max_body_size,
38 }
39 }
40}
41
42impl<Idp: IdProvider + Sync + Send + std::fmt::Debug> Intercept for S3OffloadInterceptor<Idp> {
43 fn name(&self) -> &'static str {
44 "SNSS3OffloadInterceptor"
45 }
46
47 fn modify_before_serialization(
48 &self,
49 context: &mut context::BeforeSerializationInterceptorContextMut<'_>,
50 _runtime_components: &RuntimeComponents,
51 _cfg: &mut ConfigBag,
52 ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
53 info!("Modifying context before serialization: {:?}", context);
54 let input_mut = context.input_mut();
55 if let Some(original_input) = input_mut.downcast_mut::<PublishInput>() {
56 if let Some(modified_input) = try_offload_sns_publish_input(
57 original_input,
58 &self.s3_client,
59 &self.id_provider,
60 &self.bucket_name,
61 self.max_body_size,
62 )? {
63 info!("Sns PublishInput Body modified: {:?}", modified_input);
64 *original_input = modified_input;
65 }
66 } else if let Some(original_input) = input_mut.downcast_mut::<PublishBatchInput>() {
67 if let Some(modified_input) = try_offload_sns_publish_batch_input(
68 original_input,
69 &self.s3_client,
70 &self.id_provider,
71 &self.bucket_name,
72 self.max_body_size,
73 )? {
74 info!("Sns PublishBatchInput Body modified: {:?}", modified_input);
75 *original_input = modified_input;
76 }
77 }
78 Ok(())
79 }
80}
81
82fn try_offload_sns_publish_input<Idp: IdProvider>(
83 original_input: &PublishInput,
84 s3_client: &S3Client,
85 id_provider: &Idp,
86 bucket_name: &str,
87 max_body_size: usize,
88) -> Result<Option<PublishInput>, OffloadInterceptorError> {
89 let maybe_modified_body_and_size = original_input.message().and_then(|body| {
90 try_offload_body_blocking(body, s3_client, id_provider, bucket_name, max_body_size)
91 .inspect_err(|e| error!("Error offloading content: {}", e))
92 .ok()
93 .flatten()
94 .map(|rr| (rr, body.len()))
95 });
96
97 let maybe_modified_input = maybe_modified_body_and_size.map(|(body, original_length)| {
98 let mut modified_builder = original_input.to_owned();
99 modified_builder.message = Some(body);
100
101 modified_builder.message_attributes = Some(put_offloaded_content_marker_attribute(
102 modified_builder.message_attributes,
103 original_length,
104 )?);
105
106 Ok(modified_builder)
107 });
108
109 maybe_modified_input.transpose()
110}
111
112fn try_offload_sns_publish_batch_input<Idp: IdProvider>(
113 original_input: &mut PublishBatchInput,
114 s3_client: &S3Client,
115 id_provider: &Idp,
116 bucket_name: &str,
117 max_body_size: usize,
118) -> Result<Option<PublishBatchInput>, OffloadInterceptorError> {
119 let any_offload_candidates = original_input
120 .publish_batch_request_entries()
121 .iter()
122 .any(|e| e.message().len() > max_body_size);
123
124 if !any_offload_candidates {
125 return Ok(None);
126 }
127
128 let mut modified_builder = PublishBatchInput::builder();
129 if let Some(value) = original_input.topic_arn() {
130 modified_builder = modified_builder.topic_arn(value);
131 }
132
133 modified_builder = modified_builder.set_publish_batch_request_entries(
134 original_input
135 .publish_batch_request_entries
136 .as_ref()
137 .map(|entries| {
138 entries
139 .iter()
140 .flat_map(|orig_entry| {
141 let offloaded = try_offload_body_blocking(
142 orig_entry.message(),
143 s3_client,
144 id_provider,
145 bucket_name,
146 max_body_size,
147 );
148
149 let mut modified_entry = orig_entry.to_owned();
150
151 match offloaded {
152 Ok(Some(offloaded_body)) => {
153 modified_entry.message = offloaded_body;
154
155 modified_entry.message_attributes = Some(put_offloaded_content_marker_attribute(
156 modified_entry.message_attributes,
157 orig_entry.message().len(),
158 )?)
159 },
160 Err(e) => {
161 error!("Error offloading batch entry body \"{}\": {}. Failing back to original body", orig_entry.message(), e);
162 },
163 _ => {}
164 }
165
166 Ok::<PublishBatchRequestEntry, OffloadInterceptorError>(modified_entry)
167 })
168 .collect()
169 }),
170 );
171
172 modified_builder
173 .build()
174 .map_err(|e| OffloadInterceptorError::FailedToBuildType(e.to_string()))
175 .map(Some)
176}
177
178pub fn offloading_client(
179 aws_config: &SdkConfig,
180 offloading_bucket: &str,
181 max_non_offloaded_size: usize,
182) -> SnsClient {
183 let s3_offload_interceptor = S3OffloadInterceptor::new(
184 aws_config,
185 RandomUuidProvider::default(),
186 offloading_bucket.to_owned(),
187 max_non_offloaded_size,
188 );
189
190 SnsClient::from_conf(
191 SnsConfig::new(aws_config)
192 .to_builder()
193 .interceptor(s3_offload_interceptor)
194 .build(),
195 )
196}
197
198fn put_offloaded_content_marker_attribute(
201 message_attributes: Option<HashMap<String, MessageAttributeValue>>,
202 original_body_length: usize,
203) -> Result<HashMap<String, MessageAttributeValue>, OffloadInterceptorError> {
204 let mut modified_attributes = message_attributes.clone().unwrap_or_default();
205 modified_attributes.insert(
206 OFFLOADED_MARKER_ATTRIBUTE.to_owned(),
207 MessageAttributeValue::builder()
208 .set_data_type(Some("Number".to_owned()))
209 .set_string_value(Some(original_body_length.to_string()))
210 .build()
211 .map_err(|e| {
212 OffloadInterceptorError::FailedToRewriteContents(format!(
213 "Error while building sqs message attributes {e}"
214 ))
215 })?,
216 );
217 Ok(modified_attributes)
218}
219
#[cfg(test)]
mod tests {
    use core::str;

    use aws_sdk_sns::{
        config::Credentials as SnsCredentials, types::PublishBatchRequestEntry,
        Client as SnsClient, Config as SnsConfig,
    };
    use ctor::ctor;
    use tracing::info;
    use wiremock::{
        matchers::{body_string_contains, method, path},
        Mock, MockServer, Request, ResponseTemplate,
    };

    use crate::offload::{
        id_provider::FixedIdsProvider,
        sns::S3OffloadInterceptor,
        test::{
            self, expect_no_s3_put_calls, given_s3_put_fails, given_s3_put_succeeds,
            mock_aws_endpoint_config, s3_config,
        },
    };

    const TEST_BUCKET: &str = "my-bucket";
    const TEST_TOPIC: &str = "my-topic";
    const TEST_RANDOM_ID: &str = "my-id-123";

    #[ctor]
    fn before_all() {
        test::init();
    }

    // NOTE(review): the multi-threaded runtime is presumably required because the
    // interceptor offloads through a blocking call (`try_offload_body_blocking`) — confirm.
    #[tokio::test(flavor = "multi_thread")]
    async fn publish_batch_no_offload_needed() {
        const MSG_BODY: &str = "publish-batch-no-offload-msg";

        let (mock_server, sns_client) = setup_sdk_client(MSG_BODY.len()).await;

        expect_no_s3_put_calls(&mock_server).await;
        expect_final_sns_batch_body(&mock_server, MSG_BODY).await;

        when_calling_sns_publish_batch_with_body(&sns_client, MSG_BODY).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn publish_batch_if_offload_fails_then_original_sent() {
        const MSG_BODY: &str = "publish-batch-offload_fails";

        let (mock_server, sns_client) = setup_sdk_client(MSG_BODY.len() - 1).await;

        given_s3_put_fails(&mock_server).await;
        expect_final_sns_batch_body(&mock_server, MSG_BODY).await;

        when_calling_sns_publish_batch_with_body(&sns_client, MSG_BODY).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn publish_batch_offload_succeeds() {
        const MSG_BODY: &str = "publish-batch-offload-succeeds";
        let expected_offload_fragment = batch_offloaded_payload(TEST_BUCKET, TEST_RANDOM_ID);

        let (mock_server, sns_client) = setup_sdk_client(MSG_BODY.len() - 1).await;

        given_s3_put_succeeds(&mock_server).await;
        expect_final_sns_batch_body(&mock_server, &expected_offload_fragment).await;

        when_calling_sns_publish_batch_with_body(&sns_client, MSG_BODY).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn publish_message_no_offload_needed() {
        const MSG_BODY: &str = "publish-message-no-offload-msg";

        let (mock_server, sns_client) = setup_sdk_client(MSG_BODY.len()).await;

        expect_no_s3_put_calls(&mock_server).await;
        expect_final_sns_message_body(&mock_server, MSG_BODY).await;

        when_calling_sns_publish_message_with_body(&sns_client, MSG_BODY).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn publish_message_offload_succeeds() {
        const MSG_BODY: &str = "publish-message-offload-succeeds";
        let expected_offload_fragment = offloaded_payload(TEST_BUCKET, TEST_RANDOM_ID);

        let (mock_server, sns_client) = setup_sdk_client(MSG_BODY.len() - 1).await;

        given_s3_put_succeeds(&mock_server).await;
        expect_final_sns_message_body(&mock_server, &expected_offload_fragment).await;

        when_calling_sns_publish_message_with_body(&sns_client, MSG_BODY).await;
    }

    /// Starts a mock server and builds an SNS client with the offload interceptor attached.
    /// Both S3 and SNS traffic is pointed at the mock server.
    async fn setup_sdk_client(max_non_offload_size: usize) -> (MockServer, SnsClient) {
        let mock_server = MockServer::start().await;

        let base_endpoint_url = &mock_server.uri();

        info!("Base endpoint url: {}", base_endpoint_url);
        let s3_offload_interceptor = S3OffloadInterceptor::new(
            &s3_config(base_endpoint_url).await,
            FixedIdsProvider::new(vec![TEST_RANDOM_ID]),
            TEST_BUCKET.to_owned(),
            max_non_offload_size,
        );

        let sns_config = SnsConfig::new(
            &mock_aws_endpoint_config(base_endpoint_url, "sns")
                .credentials_provider(SnsCredentials::for_tests())
                .load()
                .await,
        )
        .to_builder()
        .interceptor(s3_offload_interceptor)
        .build();

        let sns_client = SnsClient::from_conf(sns_config);

        (mock_server, sns_client)
    }

    /// Expects exactly one `PublishBatch` call whose first entry's message is `message_body`.
    async fn expect_final_sns_batch_body(mock_server: &MockServer, message_body: &str) {
        mount_final_sns_body_mock(
            mock_server,
            "PublishBatch",
            format!("PublishBatchRequestEntries.member.1.Message={message_body}"),
            "<PublishBatchResponse><PublishBatchResult></PublishBatchResult></PublishBatchResponse>",
            "Final SNS batch body",
        )
        .await;
    }

    /// Expects exactly one `Publish` call whose message is `message_body`.
    async fn expect_final_sns_message_body(mock_server: &MockServer, message_body: &str) {
        mount_final_sns_body_mock(
            mock_server,
            "Publish",
            format!("Message={message_body}"),
            "<PublishResponse><PublishResult></PublishResult></PublishResponse>",
            "Final SNS message body",
        )
        .await;
    }

    /// Mounts a mock for the final SNS `action` call that expects exactly one request.
    ///
    /// The mock asserts that the form-encoded request body contains `expected_body_fragment`
    /// and answers with `response_xml`.
    async fn mount_final_sns_body_mock(
        mock_server: &MockServer,
        action: &str,
        expected_body_fragment: String,
        response_xml: &'static str,
        mock_name: &'static str,
    ) {
        Mock::given(method("POST"))
            .and(path("/sns/"))
            .and(body_string_contains(format!("Action={action}")))
            .respond_with(move |r: &Request| {
                info!("SNS Request: {:?}", r);

                let intercepted_body = str::from_utf8(&r.body).unwrap();
                info!("SNS Body: {:?}", intercepted_body);

                assert!(
                    intercepted_body.contains(&expected_body_fragment),
                    "Body does not contain expected fragment: \nbody=\"{intercepted_body}\", \nexpected=\"{expected_body_fragment}\""
                );

                ResponseTemplate::new(200).set_body_raw(response_xml, "text/xml")
            })
            .expect(1)
            .named(mock_name)
            .mount(mock_server)
            .await;
    }

    async fn when_calling_sns_publish_batch_with_body(sns_client: &SnsClient, message_body: &str) {
        let _res = sns_client
            .publish_batch()
            .topic_arn(TEST_TOPIC.to_owned())
            .set_publish_batch_request_entries(Some(vec![PublishBatchRequestEntry::builder()
                .id("someid")
                .message(message_body)
                .build()
                .unwrap()]))
            .send()
            .await
            .unwrap();
    }

    async fn when_calling_sns_publish_message_with_body(
        sns_client: &SnsClient,
        message_body: &str,
    ) {
        let _res = sns_client
            .publish()
            .topic_arn(TEST_TOPIC.to_owned())
            .set_message(Some(message_body.to_owned()))
            .send()
            .await
            .unwrap();
    }

    /// URL-encoded S3 pointer payload that replaces an offloaded message body.
    fn encoded_s3_pointer(bucket_name: &str, bucket_key: &str) -> String {
        let pointer = format!(
            r#"["software.amazon.payloadoffloading.PayloadS3Pointer",{{"s3BucketName": "{bucket_name}","s3Key": "{bucket_key}"}}]"#
        );
        urlencoding::encode(&pointer).into_owned()
    }

    /// Expected body fragment of a `Publish` request whose message was offloaded.
    fn offloaded_payload(bucket_name: &str, bucket_key: &str) -> String {
        format!(
            "{}&MessageAttributes.entry.1.Name=ExtendedPayloadSize",
            encoded_s3_pointer(bucket_name, bucket_key)
        )
    }

    /// Expected body fragment of a `PublishBatch` request whose first entry was offloaded.
    fn batch_offloaded_payload(bucket_name: &str, bucket_key: &str) -> String {
        format!(
            "{}&PublishBatchRequestEntries.member.1.MessageAttributes\
             .entry.1.Name=ExtendedPayloadSize",
            encoded_s3_pointer(bucket_name, bucket_key)
        )
    }
}