lib/
lib.rs

1//! -*- mode: rust; -*-
2//!
3//! This file is part of privatemail crate.
4//! Copyright (c) 2022 Nyah Check
5//! See LICENSE for licensing information.
6//!
//! A Rust library for handling SES receipt notifications delivered to AWS Lambda via SNS.
8//!
9//! Authors:
10//! - Nyah Check <hello@nyah.dev>
11//!
12//! Example:
13//!
14//! ```
15//! use crate::lib::config::PrivatEmailConfig;
16//! use serde::{Deserialize, Serialize};
17//!
18//! async fn privatemail_handler() {
19//!     // Initialize PrivatEmailConfig object.
20//!     let email_config = PrivatEmailConfig::default();
21//!
22//! }
23//! ```
24
25#![forbid(unsafe_code)]
26#![allow(clippy::derive_partial_eq_without_eq)]
27
28pub mod config;
29
30use config::PrivatEmailConfig;
31use lambda_runtime::{Error, LambdaEvent};
32use mailparse::parse_mail;
33use rusoto_core::Region;
34use rusoto_ses::{
35    Body, Content, Destination, Message, SendEmailRequest, Ses, SesClient,
36};
37use serde::{Deserialize, Serialize};
38use serde_json::Value;
39use std::{collections::HashMap, env, fmt::Debug};
40use tracing::{error, trace};
41
42/// LambdaResponse: The Outgoing response being passed by the Lambda
43#[derive(Debug, Default, Clone, Serialize)]
44#[serde(default, rename_all = "camelCase")]
45pub struct LambdaResponse {
46    /// is_base_64_encoded response field
47    is_base_64_encoded: bool,
48
49    /// status_code for lambda response
50    status_code: u32,
51
52    /// response headers for lambda response
53    headers: HashMap<String, String>,
54
55    /// response body for LambdaResponse struct
56    body: String,
57}
58
59impl LambdaResponse {
60    /// Given a status_code and response body a new LambdaResponse
61    /// is returned to the calling function
62    pub fn new(status_code: u32, body: &str) -> Self {
63        let mut header = HashMap::new();
64        header.insert("content-type".to_owned(), "application/json".to_owned());
65        LambdaResponse {
66            is_base_64_encoded: false,
67            status_code,
68            headers: header,
69            body: serde_json::to_string(&body).unwrap(),
70        }
71    }
72}
73
74impl std::fmt::Display for LambdaResponse {
75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76        write!(
77            f,
78            "LambdaResponse: status_code: {}, body: {}",
79            self.status_code, self.body
80        )
81    }
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize, Default)]
85pub struct EmailReceiptNotification {
86    #[serde(rename = "notificationType")]
87    notification_type: String,
88    mail: Mail,
89    receipt: Receipt,
90    content: String,
91    // #[serde(flatten)]
92    // other: HashMap<String, Value>,
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize, Default)]
96pub struct Mail {
97    timestamp: String,
98    source: String,
99    #[serde(rename = "messageId")]
100    message_id: String,
101    destination: Vec<String>,
102
103    #[serde(rename = "commonHeaders")]
104    common_headers: CommonHeaders,
105
106    #[serde(flatten)]
107    other: HashMap<String, Value>,
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize, Default)]
111pub struct CommonHeaders {
112    // replyTo: Vec<String>,
113    subject: String,
114    #[serde(rename = "returnPath")]
115    return_path: String,
116    #[serde(flatten)]
117    other: HashMap<String, Value>,
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize, Default)]
121pub struct Receipt {
122    #[serde(rename = "spamVerdict")]
123    spam_verdict: Verdict,
124    #[serde(rename = "virusVerdict")]
125    virus_verdict: Verdict,
126    #[serde(flatten)]
127    other: HashMap<String, Value>,
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize, Default)]
131pub struct Verdict {
132    status: String,
133}
134
135/// PrivatEmail_Handler: processes incoming messages from SNS
136/// and forwards to the appropriate recipient email
137pub async fn privatemail_handler(
138    lambda_event: LambdaEvent<Value>,
139) -> Result<LambdaResponse, Error> {
140    let (event, ctx) = lambda_event.into_parts();
141
142    // install global collector configured based on RUST_LOG env var
143    let xray_trace_id = &ctx.xray_trace_id.as_ref().unwrap();
144    env::set_var("_X_AMZN_TRACE_ID", xray_trace_id);
145
146    // Enable Cloudwatch error logging at runtime
147    trace!("Event: {:#?}, Context: {:#?}", event, ctx);
148
149    // create ses client
150    let ses_client = SesClient::new(Region::default());
151
152    // Initialize the PrivatEmailConfig object
153    let email_config = PrivatEmailConfig::new_from_env();
154
155    // fetch sns payload
156    let sns_payload = event["Records"][0]["Sns"]
157        .as_object()
158        .unwrap_or_else(|| panic!("Missing sns payload"));
159    tracing::info!("Raw Email Info: {:?}", sns_payload);
160
161    // Fetch request payload
162    let sns_payload = event["Records"][0]["Sns"]
163        .as_object()
164        .unwrap_or_else(|| panic!("Missing sns payload"));
165    tracing::info!("Raw Email Info: {:?}", sns_payload);
166
167    // Fetch ses request payload from sns message
168    let ses_mail: EmailReceiptNotification = serde_json::from_str(
169        sns_payload["Message"]
170            .as_str()
171            .unwrap_or_else(|| panic!("Missing Message field")),
172    )?;
173
174    // skip spam messages
175    let ses_receipt = &ses_mail.receipt;
176    if ses_receipt.spam_verdict.status == "FAIL"
177        || ses_receipt.virus_verdict.status == "FAIL"
178    {
179        let err_msg = "Message contains spam or virus, skipping!";
180        error!(err_msg);
181        return Ok(LambdaResponse::new(200, err_msg));
182    }
183
184    // Rewrite Email From header to contain sender's name with forwarder's email address
185    let original_sender: String =
186        ses_mail.mail.common_headers.return_path.to_string();
187    let subject: String = ses_mail.mail.common_headers.subject.to_string();
188
189    // parse email content
190    let mail = parse_mail(ses_mail.content.as_bytes()).unwrap();
191    let content = mail.subparts[1].get_body_raw().unwrap();
192    let msg_body = charset::decode_latin1(&content).to_string();
193    trace!("HTML content: {:#?}", content);
194
195    // Skip mail if it's from blacklisted email
196    for email in
197        email_config.black_list.unwrap_or_else(|| panic!("Missing black list"))
198    {
199        if !email.is_empty() && original_sender.contains(email.as_str()) {
200            let mut err_msg: String =
201                "Message is from blacklisted email: ".to_owned();
202            err_msg.push_str(email.as_str());
203            trace!("`{}`, skipping!", err_msg.as_str());
204            return Ok(LambdaResponse::new(200, err_msg.as_str()));
205        }
206    }
207
208    let ses_email_message = SendEmailRequest {
209        configuration_set_name: Default::default(),
210        destination: Destination {
211            bcc_addresses: Default::default(),
212            cc_addresses: Default::default(),
213            to_addresses: Some(vec![email_config.to_email.to_string()]),
214        },
215        message: Message {
216            body: Body {
217                html: Some(Content {
218                    charset: Default::default(),
219                    data: msg_body,
220                }),
221                text: Default::default(),
222            },
223            subject: Content { charset: Default::default(), data: subject },
224        },
225        reply_to_addresses: Some(vec![original_sender]),
226        return_path: Default::default(),
227        return_path_arn: Default::default(),
228        source: email_config.from_email.to_string(),
229        source_arn: Default::default(),
230        tags: Default::default(),
231    };
232
233    match ses_client.send_email(ses_email_message).await {
234        Ok(email_response) => {
235            trace!("Email forward success: {:?}", email_response);
236            Ok(LambdaResponse::new(200, &email_response.message_id))
237        }
238        Err(error) => {
239            tracing::error!("Error forwarding email: {:?}", error);
240            Err(Box::new(error))
241        }
242    }
243}
244
/// Test module for privatemail package
#[cfg(test)]
mod tests {
    use super::*;
    use lambda_runtime::Context;
    use std::fs;
    use std::path::PathBuf;

    /// Reads `tests/payload/<file_name>` relative to the crate root and parses
    /// it as JSON.
    ///
    /// # Panics
    ///
    /// Panics, naming the path, if the file cannot be read or is not valid
    /// JSON.
    fn read_test_event(file_name: &str) -> Value {
        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests")
            .join("payload")
            .join(file_name);
        println!("Cur Dir: {}", path.display());

        let input_str = fs::read_to_string(&path)
            .unwrap_or_else(|e| panic!("failed to read {}: {}", path.display(), e));
        trace!("Input str: {}", input_str);

        serde_json::from_str(&input_str)
            .unwrap_or_else(|e| panic!("invalid JSON in {}: {}", path.display(), e))
    }

    // NOTE(review): both integration tests mutate process-wide environment
    // variables (TO_EMAIL, FROM_EMAIL, BLACK_LIST). They can interfere with
    // each other when run in parallel; run them with `--test-threads=1`.

    #[tokio::test]
    #[ignore = "skipping integration because of IAM requirements"]
    async fn handler_with_success() {
        env::set_var("TO_EMAIL", "nyah@hey.com");
        env::set_var("FROM_EMAIL", "test@nyah.dev");
        let test_event = read_test_event("test_event.json");

        let response = privatemail_handler(LambdaEvent {
            payload: test_event,
            context: Context::default(),
        })
        .await
        .expect("expected Ok(_) response");

        assert_eq!(response.status_code, 200);
    }

    #[tokio::test]
    #[ignore = "skipping integration because of IAM requirements"]
    async fn handler_with_black_listed_email() {
        env::set_var("TO_EMAIL", "test@nyah.dev");
        env::set_var("FROM_EMAIL", "fufu@achu.soup");
        env::set_var("BLACK_LIST", "achu.soup");
        let test_event = read_test_event("test_event.json");

        let response = privatemail_handler(LambdaEvent {
            payload: test_event,
            context: Context::default(),
        })
        .await
        .expect("expected Ok(_) response");

        assert_eq!(response.status_code, 200);
    }
}