// rocket_lamb/handler.rs
use crate::config::*;
use crate::error::RocketLambError;
use crate::request_ext::RequestExt as _;
use lambda_http::{Body, Handler, Request, RequestExt, Response};
use lambda_runtime::{error::HandlerError, Context};
use rocket::http::{uri::Uri, Header};
use rocket::local::{Client, LocalRequest, LocalResponse};
use rocket::{Rocket, Route};
use std::mem;

/// A Lambda handler for API Gateway events that processes requests using a [Rocket](rocket::Rocket) instance.
pub struct RocketHandler {
    /// The Rocket client used to dispatch requests; it starts out uninitialized and is
    /// built when the first request is handled.
    pub(super) client: LazyClient,
    /// Settings consulted while handling requests: the base path behaviour and the
    /// per-content-type / default response types.
    pub(super) config: Config,
}

/// The lazily-initialized state of the Rocket client owned by a `RocketHandler`.
pub(super) enum LazyClient {
    /// Temporary value left in place while the `Rocket` instance is moved out during
    /// initialization; the handler panics if it finds this state when a request starts.
    Placeholder,
    /// A `Rocket` instance that has not yet been turned into a `Client`.
    Uninitialized(Rocket),
    /// A client that is ready to dispatch requests.
    Ready(Client),
}

impl Handler<Response<Body>> for RocketHandler {
    fn run(&mut self, req: Request, _ctx: Context) -> Result<Response<Body>, HandlerError> {
        self.ensure_client_ready(&req);
        self.process_request(req)
            .map_err(failure::Error::from)
            .map_err(failure::Error::into)
    }
}

impl RocketHandler {
    /// Makes sure `self.client` holds a ready-to-use Rocket `Client`, building it on first use.
    ///
    /// Initialization is deferred until a request is available because, with
    /// `BasePathBehaviour::RemountAndInclude`, the existing routes are additionally mounted
    /// under the base path reported by that request.
    ///
    /// # Panics
    ///
    /// Panics if an earlier initialization attempt was left unfinished (the client is still
    /// `LazyClient::Placeholder`), or if the Rocket client cannot be created.
    fn ensure_client_ready(&mut self, req: &Request) {
        if let LazyClient::Ready(_) = self.client {
            return;
        }

        // Take ownership of the `Rocket` instance, leaving a placeholder behind so that an
        // interrupted initialization is detected on the next call.
        let mut rocket = match mem::replace(&mut self.client, LazyClient::Placeholder) {
            LazyClient::Uninitialized(rocket) => rocket,
            LazyClient::Placeholder => panic!("LazyClient has previously begun initialization."),
            LazyClient::Ready(_) => unreachable!("a ready client returns early above"),
        };

        if self.config.base_path_behaviour == BasePathBehaviour::RemountAndInclude {
            let base_path = req.base_path();
            if !base_path.is_empty() {
                let routes: Vec<Route> = rocket.routes().cloned().collect();
                rocket = rocket.mount(&base_path, routes);
            }
        }

        let client = Client::untracked(rocket).expect("failed to create the Rocket client");
        self.client = LazyClient::Ready(client);
    }

    /// Returns the initialized Rocket client.
    ///
    /// # Panics
    ///
    /// Panics if `ensure_client_ready` has not been called beforehand.
    fn client(&self) -> &Client {
        match &self.client {
            LazyClient::Ready(client) => client,
            _ => panic!("Rocket client wasn't ready. ensure_client_ready should have been called!"),
        }
    }

    /// Converts the Lambda request into a Rocket request, dispatches it, and converts the
    /// Rocket response back into a Lambda response.
    ///
    /// # Errors
    ///
    /// Returns a `RocketLambError` if either conversion fails.
    fn process_request(&self, req: Request) -> Result<Response<Body>, RocketLambError> {
        let local_req = self.create_rocket_request(req)?;
        let local_res = local_req.dispatch();
        self.create_lambda_response(local_res)
    }

    /// Builds a Rocket `LocalRequest` carrying the method, URI, headers and body of `req`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-request error if the HTTP method is not recognised or if a header
    /// value cannot be represented as a string.
    fn create_rocket_request(&self, req: Request) -> Result<LocalRequest, RocketLambError> {
        let method = to_rocket_method(req.method())?;
        let uri = self.get_path_and_query(&req);
        let mut local_req = self.client().req(method, uri);
        for (name, value) in req.headers() {
            let value = value
                .to_str()
                .map_err(|_| invalid_request!("invalid value for header '{}'", name))?;
            local_req.add_header(Header::new(name.to_string(), value.to_string()));
        }
        local_req.set_body(req.into_body());
        Ok(local_req)
    }

    /// Builds the Lambda response from a Rocket response, copying the status, the headers
    /// and the body.
    ///
    /// The body encoding is chosen from the configured response type for the response's
    /// content type (parameters such as `; charset=...` are ignored), falling back to the
    /// configured default:
    ///
    /// * `ResponseType::Auto` - text if the body is valid UTF-8, binary otherwise;
    /// * `ResponseType::Text` - text, failing if the body cannot be read as a string;
    /// * `ResponseType::Binary` - binary.
    ///
    /// # Errors
    ///
    /// Returns an invalid-response error if the body cannot be read or the response cannot
    /// be built.
    fn create_lambda_response(
        &self,
        mut local_res: LocalResponse,
    ) -> Result<Response<Body>, RocketLambError> {
        let mut builder = Response::builder();
        builder.status(local_res.status().code);
        for h in local_res.headers().iter() {
            // Borrow the name and value directly instead of allocating a `String` for each.
            builder.header(h.name(), h.value());
        }

        let response_type = local_res
            .headers()
            .get_one("content-type")
            .unwrap_or_default()
            .split(';')
            .next()
            // Optional whitespace is permitted before the `;` that starts the parameters.
            .map(str::trim)
            .and_then(|ct| self.config.response_types.get(&ct.to_lowercase()))
            .copied()
            .unwrap_or(self.config.default_response_type);
        let body = match (local_res.body(), response_type) {
            (Some(b), ResponseType::Auto) => {
                let bytes = b
                    .into_bytes()
                    .ok_or_else(|| invalid_response!("failed to read response body"))?;
                // Prefer text; fall back to the raw bytes when they are not valid UTF-8.
                match String::from_utf8(bytes) {
                    Ok(s) => Body::Text(s),
                    Err(e) => Body::Binary(e.into_bytes()),
                }
            }
            (Some(b), ResponseType::Text) => Body::Text(
                b.into_string()
                    .ok_or_else(|| invalid_response!("failed to read response body as UTF-8"))?,
            ),
            (Some(b), ResponseType::Binary) => Body::Binary(
                b.into_bytes()
                    .ok_or_else(|| invalid_response!("failed to read response body"))?,
            ),
            (None, _) => Body::Empty,
        };

        builder.body(body).map_err(|e| invalid_response!("{}", e))
    }

    /// Returns the request path followed by the percent-encoded query string.
    ///
    /// Whether the path includes the base path depends on the configured
    /// `BasePathBehaviour`. Every value of a multi-valued query parameter is emitted as its
    /// own `key=value` pair.
    fn get_path_and_query(&self, req: &Request) -> String {
        let mut uri = match self.config.base_path_behaviour {
            BasePathBehaviour::Include | BasePathBehaviour::RemountAndInclude => req.full_path(),
            BasePathBehaviour::Exclude => req.api_path().to_owned(),
        };
        let query = req.query_string_parameters();

        let mut separator = '?';
        for (key, _) in query.iter() {
            // The key is the same for all of its values, so encode it only once.
            let encoded_key = Uri::percent_encode(key);
            for value in query.get_all(key).unwrap_or_default() {
                uri.push(separator);
                uri.push_str(&encoded_key);
                uri.push('=');
                uri.push_str(&Uri::percent_encode(value));
                separator = '&';
            }
        }
        uri
    }
}

/// Maps an `http::Method` onto the equivalent Rocket method.
///
/// # Errors
///
/// Returns an invalid-request error for any method other than GET, PUT, POST, DELETE,
/// OPTIONS, HEAD, TRACE, CONNECT or PATCH.
fn to_rocket_method(method: &http::Method) -> Result<rocket::http::Method, RocketLambError> {
    use rocket::http::Method as RocketMethod;

    let converted = match *method {
        http::Method::GET => RocketMethod::Get,
        http::Method::PUT => RocketMethod::Put,
        http::Method::POST => RocketMethod::Post,
        http::Method::DELETE => RocketMethod::Delete,
        http::Method::OPTIONS => RocketMethod::Options,
        http::Method::HEAD => RocketMethod::Head,
        http::Method::TRACE => RocketMethod::Trace,
        http::Method::CONNECT => RocketMethod::Connect,
        http::Method::PATCH => RocketMethod::Patch,
        _ => return Err(invalid_request!("unknown method '{}'", method)),
    };
    Ok(converted)
}