// sqs_lambda/event_retriever.rs

use std::collections::HashMap;
use std::error::Error;
use std::io::Read;
use std::marker::PhantomData;
use std::time::Duration;

use async_trait::async_trait;
use rusoto_s3::{GetObjectRequest, S3};
use rusoto_sqs::Message as SqsMessage;
use tokio::prelude::*;
use tracing::info;

use crate::event_decoder::PayloadDecoder;

16#[async_trait]
17pub trait PayloadRetriever<T> {
18 type Message;
19 async fn retrieve_event(&mut self, msg: &Self::Message) -> Result<Option<T>, Box<dyn Error>>;
20}
21
22#[derive(Clone)]
23pub struct S3PayloadRetriever<S, SInit, D, E>
24where
25 S: S3 + Clone + Send + Sync + 'static,
26 SInit: (Fn(String) -> S) + Clone + Send + Sync + 'static,
27 D: PayloadDecoder<E> + Clone + Send + 'static,
28 E: Send + 'static,
29{
30 s3_init: SInit,
31 s3_clients: HashMap<String, S>,
32 decoder: D,
33 phantom: PhantomData<E>,
34}
35
36impl<S, SInit, D, E> S3PayloadRetriever<S, SInit, D, E>
37where
38 S: S3 + Clone + Send + Sync + 'static,
39 SInit: (Fn(String) -> S) + Clone + Send + Sync + 'static,
40 D: PayloadDecoder<E> + Clone + Send + 'static,
41 E: Send + 'static,
42{
43 pub fn new(s3: SInit, decoder: D) -> Self {
44 Self {
45 s3_init: s3,
46 s3_clients: HashMap::new(),
47 decoder,
48 phantom: PhantomData,
49 }
50 }
51
52 pub fn get_client(&mut self, region: String) -> S {
53 match self.s3_clients.get(®ion) {
54 Some(s3) => s3.clone(),
55 None => {
56 let client = (self.s3_init)(region.clone());
57 self.s3_clients.insert(region.to_string(), client.clone());
58 client
59 }
60 }
61 }
62}
63
64#[async_trait]
65impl<S, SInit, D, E> PayloadRetriever<E> for S3PayloadRetriever<S, SInit, D, E>
66where
67 S: S3 + Clone + Send + Sync + 'static,
68 SInit: (Fn(String) -> S) + Clone + Send + Sync + 'static,
69 D: PayloadDecoder<E> + Clone + Send + 'static,
70 E: Send + 'static,
71{
72 type Message = SqsMessage;
73 #[tracing::instrument(skip(self, msg))]
74 async fn retrieve_event(&mut self, msg: &Self::Message) -> Result<Option<E>, Box<dyn Error>> {
75 let body = msg.body.as_ref().unwrap();
76 info!("Got body from message: {}", body);
77 let event: serde_json::Value = serde_json::from_str(body)?;
78
79 if let Some(Some(event_str)) = event.get("Event").map(serde_json::Value::as_str) {
80 if event_str == "s3:TestEvent" {
81 return Ok(None);
82 }
83 }
84 let record = &event["Records"][0]["s3"];
85
86 let bucket = record["bucket"]["name"].as_str().expect("bucket name");
87 let key = record["object"]["key"].as_str().expect("object key");
88
89 let region = &event["Records"][0]["awsRegion"].as_str().expect("region");
90 let s3 = self.get_client(region.to_string());
91 let s3_data = s3.get_object(GetObjectRequest {
92 bucket: bucket.to_string(),
93 key: key.to_string(),
94 ..Default::default()
95 });
96
97 let s3_data = tokio::time::timeout(Duration::from_secs(5), s3_data).await??;
98
99 let object_size = record["object"]["size"].as_u64().unwrap_or_default();
100 let prealloc = if object_size < 1024 {
101 1024
102 } else {
103 object_size as usize
104 };
105
106 info!("Retrieved s3 payload with size : {:?}", prealloc);
107
108 let mut body = Vec::with_capacity(prealloc);
109
110 s3_data
111 .body
112 .expect("Missing S3 body")
113 .into_async_read()
114 .read_to_end(&mut body)
115 .await?;
116
117 info!("Read s3 payload body");
118 self.decoder.decode(body).map(Option::from)
119 }
120}