// sqs_lambda/event_retriever.rs

1use std::error::Error;
2use std::io::Read;
3use std::marker::PhantomData;
4use std::time::Duration;
5
6use rusoto_s3::{GetObjectRequest, S3};
7use rusoto_sqs::Message as SqsMessage;
8use tokio::prelude::*;
9use tracing::info;
10
11use async_trait::async_trait;
12
13use crate::event_decoder::PayloadDecoder;
14use std::collections::HashMap;
15
16#[async_trait]
17pub trait PayloadRetriever<T> {
18    type Message;
19    async fn retrieve_event(&mut self, msg: &Self::Message) -> Result<Option<T>, Box<dyn Error>>;
20}
21
22#[derive(Clone)]
23pub struct S3PayloadRetriever<S, SInit, D, E>
24where
25    S: S3 + Clone + Send + Sync + 'static,
26    SInit: (Fn(String) -> S) + Clone + Send + Sync + 'static,
27    D: PayloadDecoder<E> + Clone + Send + 'static,
28    E: Send + 'static,
29{
30    s3_init: SInit,
31    s3_clients: HashMap<String, S>,
32    decoder: D,
33    phantom: PhantomData<E>,
34}
35
36impl<S, SInit, D, E> S3PayloadRetriever<S, SInit, D, E>
37where
38    S: S3 + Clone + Send + Sync + 'static,
39    SInit: (Fn(String) -> S) + Clone + Send + Sync + 'static,
40    D: PayloadDecoder<E> + Clone + Send + 'static,
41    E: Send + 'static,
42{
43    pub fn new(s3: SInit, decoder: D) -> Self {
44        Self {
45            s3_init: s3,
46            s3_clients: HashMap::new(),
47            decoder,
48            phantom: PhantomData,
49        }
50    }
51
52    pub fn get_client(&mut self, region: String) -> S {
53        match self.s3_clients.get(&region) {
54            Some(s3) => s3.clone(),
55            None => {
56                let client = (self.s3_init)(region.clone());
57                self.s3_clients.insert(region.to_string(), client.clone());
58                client
59            }
60        }
61    }
62}
63
64#[async_trait]
65impl<S, SInit, D, E> PayloadRetriever<E> for S3PayloadRetriever<S, SInit, D, E>
66where
67    S: S3 + Clone + Send + Sync + 'static,
68    SInit: (Fn(String) -> S) + Clone + Send + Sync + 'static,
69    D: PayloadDecoder<E> + Clone + Send + 'static,
70    E: Send + 'static,
71{
72    type Message = SqsMessage;
73    #[tracing::instrument(skip(self, msg))]
74    async fn retrieve_event(&mut self, msg: &Self::Message) -> Result<Option<E>, Box<dyn Error>> {
75        let body = msg.body.as_ref().unwrap();
76        info!("Got body from message: {}", body);
77        let event: serde_json::Value = serde_json::from_str(body)?;
78
79        if let Some(Some(event_str)) = event.get("Event").map(serde_json::Value::as_str) {
80            if event_str == "s3:TestEvent" {
81                return Ok(None);
82            }
83        }
84        let record = &event["Records"][0]["s3"];
85
86        let bucket = record["bucket"]["name"].as_str().expect("bucket name");
87        let key = record["object"]["key"].as_str().expect("object key");
88
89        let region = &event["Records"][0]["awsRegion"].as_str().expect("region");
90        let s3 = self.get_client(region.to_string());
91        let s3_data = s3.get_object(GetObjectRequest {
92            bucket: bucket.to_string(),
93            key: key.to_string(),
94            ..Default::default()
95        });
96
97        let s3_data = tokio::time::timeout(Duration::from_secs(5), s3_data).await??;
98
99        let object_size = record["object"]["size"].as_u64().unwrap_or_default();
100        let prealloc = if object_size < 1024 {
101            1024
102        } else {
103            object_size as usize
104        };
105
106        info!("Retrieved s3 payload with size : {:?}", prealloc);
107
108        let mut body = Vec::with_capacity(prealloc);
109
110        s3_data
111            .body
112            .expect("Missing S3 body")
113            .into_async_read()
114            .read_to_end(&mut body)
115            .await?;
116
117        info!("Read s3 payload body");
118        self.decoder.decode(body).map(Option::from)
119    }
120}