payload_offloading_for_aws/offload/
offloading.rs

1use super::{error::OffloadInterceptorError, id_provider::IdProvider};
2use aws_config::SdkConfig;
3use aws_sdk_s3::Client as S3Client;
4use core::str;
5use futures::{executor, TryFutureExt};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8
/// Name of the SQS message attribute that marks a message body as offloaded to S3.
///
/// This attribute should be set on SQS messages so that the extended Java client picks it up
/// and downloads the offloaded content automatically.
pub const OFFLOADED_MARKER_ATTRIBUTE: &str = "ExtendedPayloadSize";
12
13/// Where original_body > max_body_size this function will offload to s3.
14pub(crate) fn try_offload_body_blocking<Idp: IdProvider>(
15    original_body: &str,
16    s3_client: &S3Client,
17    id_provider: &Idp,
18    bucket_name: &str,
19    max_body_size: usize,
20) -> Result<Option<String>, OffloadInterceptorError> {
21    let original_body_len = original_body.len();
22    if original_body_len > max_body_size {
23        let bucket_key = id_provider
24            .generate()
25            .map_err(OffloadInterceptorError::FailedToGenerateS3Key)?;
26        let put_obj_output = executor::block_on(async {
27            let bucket_body =
28                aws_sdk_s3::primitives::ByteStream::from(original_body.as_bytes().to_vec());
29
30            s3_client
31                .put_object()
32                .bucket(bucket_name)
33                .key(&bucket_key)
34                .body(bucket_body)
35                .send()
36                .await
37                .map_err(|e| OffloadInterceptorError::FailedToPutS3Object(e.to_string()))
38        })?;
39
40        tracing::info!("Got s3 put result: {:?}", put_obj_output);
41
42        let offloaded_body = format!(
43            "[\
44            \"software.amazon.payloadoffloading.PayloadS3Pointer\",\
45            {{\
46                \"s3BucketName\": \"{bucket_name}\",\
47                \"s3Key\": \"{bucket_key}\"\
48            }}\
49        ]"
50        );
51        Ok(Some(offloaded_body))
52    } else {
53        Ok(None)
54    }
55}
56
57pub fn deserialize_s3_pointer(payload: &str) -> Result<PayloadS3Pointer, OffloadInterceptorError> {
58    let parsed: Value = serde_json::from_str(payload)
59        .map_err(|e| OffloadInterceptorError::DeserialisationError(e.to_string()))?;
60
61    if let Some(array) = parsed.as_array() {
62        if array.len() < 2 {
63            return Err(deserializer_error());
64        }
65
66        let is_s3_pointer_payload = array
67            .first()
68            .and_then(|v| v.as_str())
69            .map(|v| v == "software.amazon.payloadoffloading.PayloadS3Pointer")
70            .unwrap_or(false);
71        if !is_s3_pointer_payload {
72            return Err(deserializer_error());
73        }
74
75        let s3_pointer: PayloadS3Pointer = serde_json::from_value(array[1].clone())
76            .map_err(|e| OffloadInterceptorError::DeserialisationError(e.to_string()))?;
77        return Ok(s3_pointer);
78    }
79
80    Err(deserializer_error())
81}
82
83fn deserializer_error() -> OffloadInterceptorError {
84    OffloadInterceptorError::DeserialisationError("Invalid Format".to_string())
85}
86
87pub fn s3_client(aws_config: &SdkConfig) -> S3Client {
88    S3Client::from_conf(
89        aws_sdk_s3::Config::from(aws_config)
90            .to_builder()
91            .force_path_style(true)
92            .build(),
93    )
94}
95
96pub fn download_from_s3(
97    s3_client: &aws_sdk_s3::Client,
98    s3_pointer: &PayloadS3Pointer,
99) -> Result<String, OffloadInterceptorError> {
100    let downloaded = executor::block_on(
101        async {
102            s3_client
103                .get_object()
104                .bucket(&s3_pointer.s3_bucket_name)
105                .key(&s3_pointer.s3_key)
106                .send()
107                .await
108                .map_err(|e| OffloadInterceptorError::FailedToLoadFromS3(e.to_string()))
109        }
110        .and_then(|r| {
111            r.body
112                .collect()
113                .map_err(|e| OffloadInterceptorError::ByteStreamError(e.to_string()))
114        }),
115    )?;
116
117    let data_string = String::from_utf8(downloaded.into_bytes().to_vec())
118        .map_err(|e| OffloadInterceptorError::ByteStreamError(e.to_string()))?;
119
120    Ok(data_string)
121}
122
123#[derive(Serialize, Deserialize, Debug, PartialEq)]
124#[serde(rename_all = "camelCase")]
125pub struct PayloadS3Pointer {
126    pub s3_bucket_name: String,
127    pub s3_key: String,
128}
129
#[cfg(test)]
mod tests {
    use crate::offload::offloading::{deserialize_s3_pointer, PayloadS3Pointer};

    use super::OffloadInterceptorError;

    #[test]
    fn deserializes_s3_pointer() {
        let offloaded_payload = r#"[
            "software.amazon.payloadoffloading.PayloadS3Pointer",
            {"s3BucketName": "offload-test", "s3Key": "42ced2b1-b2f7-4b59-b1cc-c1a4b5349edf"}
        ]"#;

        let deserialized_ptr = deserialize_s3_pointer(offloaded_payload).unwrap();

        assert_eq!(
            PayloadS3Pointer {
                s3_bucket_name: "offload-test".to_owned(),
                s3_key: "42ced2b1-b2f7-4b59-b1cc-c1a4b5349edf".to_owned(),
            },
            deserialized_ptr
        )
    }

    /// Every rejection path of `deserialize_s3_pointer` must surface as a deserialisation error.
    #[test]
    fn rejects_payloads_that_are_not_s3_pointers() {
        let invalid_payloads = [
            // Not JSON at all.
            "not json",
            // Valid JSON, but not an array.
            r#"{"s3BucketName": "offload-test", "s3Key": "key"}"#,
            // Array too short to hold both the type tag and the pointer.
            r#"["software.amazon.payloadoffloading.PayloadS3Pointer"]"#,
            // Wrong type tag.
            r#"["some.other.Type", {"s3BucketName": "offload-test", "s3Key": "key"}]"#,
            // Pointer object is missing the key field.
            r#"["software.amazon.payloadoffloading.PayloadS3Pointer", {"s3BucketName": "b"}]"#,
        ];

        for payload in invalid_payloads {
            let result = deserialize_s3_pointer(payload);
            assert!(
                matches!(result, Err(OffloadInterceptorError::DeserialisationError(_))),
                "expected a deserialisation error for {payload:?}, got {result:?}"
            );
        }
    }
}