1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! Execute Python middleware handlers.
use aws_smithy_http_server::{body::Body, body::BoxBody, response::IntoResponse};
use http::Request;
use pyo3::prelude::*;

use pyo3_asyncio::TaskLocals;

use crate::{PyMiddlewareException, PyRequest, PyResponse};

use super::PyFuture;

/// Kind of a Python middleware handler.
#[derive(Clone, Copy, Debug)]
pub enum PyMiddlewareType {
    /// Middleware registered to act on the incoming request.
    Request,
    /// Middleware registered to act on the outgoing response.
    Response,
}

/// Rust-side description of a single Python middleware function.
///
/// Besides the callable itself, this records what is needed to invoke it correctly,
/// most importantly whether it must be awaited as a coroutine.
#[derive(Clone, Debug)]
pub struct PyMiddlewareHandler {
    /// Human readable name, used in log messages.
    pub name: String,
    /// The Python callable implementing the middleware.
    pub func: PyObject,
    /// `true` when `func` is an `async def` and its result has to be awaited.
    pub is_coroutine: bool,
    /// Whether the middleware was registered for requests or responses.
    pub _type: PyMiddlewareType,
}

/// Ordered collection of the Python middlewares run by this server.
///
/// The middlewares run sequentially, one after the other, inside the
/// [crate::PyMiddlewareLayer] Tower layer.
#[derive(Clone, Debug)]
pub struct PyMiddlewares {
    /// Handlers in registration order.
    handlers: Vec<PyMiddlewareHandler>,
    /// Turns a middleware failure into a protocol-specific HTTP response.
    into_response: fn(PyMiddlewareException) -> http::Response<BoxBody>,
}

impl PyMiddlewares {
    /// Create a new instance of `PyMiddlewareHandlers` from a list of heandlers.
    pub fn new<P>(handlers: Vec<PyMiddlewareHandler>) -> Self
    where
        PyMiddlewareException: IntoResponse<P>,
    {
        Self {
            handlers,
            into_response: PyMiddlewareException::into_response,
        }
    }

    /// Add a new handler to the list.
    pub fn push(&mut self, handler: PyMiddlewareHandler) {
        self.handlers.push(handler);
    }

    /// Execute a single middleware handler.
    ///
    /// The handler is scheduled on the Python interpreter syncronously or asynchronously,
    /// dependening on the handler signature.
    async fn execute_middleware(
        request: PyRequest,
        handler: PyMiddlewareHandler,
    ) -> Result<(Option<PyRequest>, Option<PyResponse>), PyMiddlewareException> {
        let handle: PyResult<pyo3::Py<pyo3::PyAny>> = if handler.is_coroutine {
            tracing::debug!("Executing Python middleware coroutine `{}`", handler.name);
            let result = pyo3::Python::with_gil(|py| {
                let pyhandler: &pyo3::types::PyFunction = handler.func.extract(py)?;
                let coroutine = pyhandler.call1((request,))?;
                pyo3_asyncio::tokio::into_future(coroutine)
            })?;
            let output = result.await?;
            Ok(output)
        } else {
            tracing::debug!("Executing Python middleware function `{}`", handler.name);
            pyo3::Python::with_gil(|py| {
                let pyhandler: &pyo3::types::PyFunction = handler.func.extract(py)?;
                let output = pyhandler.call1((request,))?;
                Ok(output.into_py(py))
            })
        };
        Python::with_gil(|py| match handle {
            Ok(result) => {
                if let Ok(request) = result.extract::<PyRequest>(py) {
                    return Ok((Some(request), None));
                }
                if let Ok(response) = result.extract::<PyResponse>(py) {
                    return Ok((None, Some(response)));
                }
                Ok((None, None))
            }
            Err(e) => pyo3::Python::with_gil(|py| {
                let traceback = match e.traceback(py) {
                    Some(t) => t.format().unwrap_or_else(|e| e.to_string()),
                    None => "Unknown traceback\n".to_string(),
                };
                tracing::error!("{}{}", traceback, e);
                let variant = e.value(py);
                if let Ok(v) = variant.extract::<PyMiddlewareException>() {
                    Err(v)
                } else {
                    Err(e.into())
                }
            }),
        })
    }

    /// Execute all the available Python middlewares in order of registration.
    ///
    /// Once the response is returned by the Python interpreter, different scenarios can happen:
    /// * Middleware not returning will let the execution continue to the next middleware without
    ///   changing the original request.
    /// * Middleware returning a modified [PyRequest] will update the original request before
    ///   continuing the execution of the next middleware.
    /// * Middleware returning a [PyResponse] will immediately terminate the request handling and
    ///   return the response constructed from Python.
    /// * Middleware raising [PyMiddlewareException] will immediately terminate the request handling
    ///   and return a protocol specific error, with the option of setting the HTTP return code.
    /// * Middleware raising any other exception will immediately terminate the request handling and
    ///   return a protocol specific error, with HTTP status code 500.
    pub fn run(&mut self, mut request: Request<Body>, locals: TaskLocals) -> PyFuture {
        let handlers = self.handlers.clone();
        let into_response = self.into_response;
        // Run all Python handlers in a loop.
        Box::pin(async move {
            tracing::debug!("Executing Python middleware stack");
            for handler in handlers {
                let name = handler.name.clone();
                let pyrequest = PyRequest::new(&request);
                let loop_locals = locals.clone();
                let result = pyo3_asyncio::tokio::scope(
                    loop_locals,
                    Self::execute_middleware(pyrequest, handler),
                )
                .await;
                match result {
                    Ok((pyrequest, pyresponse)) => {
                        if let Some(pyrequest) = pyrequest {
                            if let Ok(headers) = (&pyrequest.headers).try_into() {
                                tracing::debug!("Python middleware `{name}` returned an HTTP request, override headers with middleware's one");
                                *request.headers_mut() = headers;
                            }
                        }
                        if let Some(pyresponse) = pyresponse {
                            tracing::debug!(
                            "Python middleware `{name}` returned a HTTP response, exit middleware loop"
                        );
                            return Err(pyresponse.into());
                        }
                    }
                    Err(e) => {
                        tracing::debug!(
                            "Middleware `{name}` returned an error, exit middleware loop"
                        );
                        return Err((into_response)(e));
                    }
                }
            }
            tracing::debug!(
                "Python middleware execution finised, returning the request to operation handler"
            );
            Ok(request)
        })
    }
}

#[cfg(test)]
mod tests {
    use aws_smithy_http_server::proto::rest_json_1::RestJson1;
    use http::HeaderValue;
    use hyper::body::to_bytes;
    use pretty_assertions::assert_eq;

    use super::*;

    /// Wraps the Python function `function` from `module` as a synchronous request middleware.
    fn sync_request_handler(
        py: Python<'_>,
        module: &PyModule,
        function: &str,
        name: &str,
    ) -> PyResult<PyMiddlewareHandler> {
        let func = module.getattr(function)?.into_py(py);
        Ok(PyMiddlewareHandler {
            name: name.to_string(),
            func,
            is_coroutine: false,
            _type: PyMiddlewareType::Request,
        })
    }

    /// An HTTP request with an empty body and no headers.
    fn empty_request() -> Request<Body> {
        Request::builder().body(Body::from("")).unwrap()
    }

    #[tokio::test]
    async fn request_middleware_chain_keeps_headers_changes() -> PyResult<()> {
        let locals = crate::tests::initialize();
        let mut middlewares = PyMiddlewares::new::<RestJson1>(vec![]);

        Python::with_gil(|py| -> PyResult<()> {
            let module = PyModule::new(py, "middleware").unwrap();
            module.add_class::<PyRequest>().unwrap();
            module.add_class::<PyMiddlewareException>().unwrap();
            let pycode = r#"
def first_middleware(request: Request):
    request.set_header("x-amzn-answer", "42")
    return request

def second_middleware(request: Request):
    if request.get_header("x-amzn-answer") != "42":
        raise MiddlewareException("wrong answer", 401)
"#;
            py.run(pycode, Some(module.dict()), None)?;
            let exported = module.index()?;
            for (function, name) in [("first_middleware", "first"), ("second_middleware", "second")] {
                let handler = sync_request_handler(py, module, function, name)?;
                exported.append(function)?;
                middlewares.push(handler);
            }
            Ok(())
        })?;

        let forwarded = middlewares.run(empty_request(), locals).await.unwrap();
        assert_eq!(
            forwarded.headers().get("x-amzn-answer"),
            Some(&HeaderValue::from_static("42"))
        );
        Ok(())
    }

    #[tokio::test]
    async fn request_middleware_return_response() -> PyResult<()> {
        let locals = crate::tests::initialize();
        let mut middlewares = PyMiddlewares::new::<RestJson1>(vec![]);

        Python::with_gil(|py| -> PyResult<()> {
            let module = PyModule::new(py, "middleware").unwrap();
            module.add_class::<PyRequest>().unwrap();
            module.add_class::<PyResponse>().unwrap();
            let pycode = r#"
def middleware(request: Request):
    return Response(200, {}, b"something")"#;
            py.run(pycode, Some(module.dict()), None)?;
            let exported = module.index()?;
            let handler = sync_request_handler(py, module, "middleware", "middleware")?;
            exported.append("middleware")?;
            middlewares.push(handler);
            Ok(())
        })?;

        let response = middlewares.run(empty_request(), locals).await.unwrap_err();
        assert_eq!(response.status(), 200);
        let payload = to_bytes(response.into_body()).await.unwrap();
        assert_eq!(payload, "something".as_bytes());
        Ok(())
    }

    #[tokio::test]
    async fn request_middleware_raise_middleware_exception() -> PyResult<()> {
        let locals = crate::tests::initialize();
        let mut middlewares = PyMiddlewares::new::<RestJson1>(vec![]);

        Python::with_gil(|py| -> PyResult<()> {
            let module = PyModule::new(py, "middleware").unwrap();
            module.add_class::<PyRequest>().unwrap();
            module.add_class::<PyMiddlewareException>().unwrap();
            let pycode = r#"
def middleware(request: Request):
    raise MiddlewareException("error", 503)"#;
            py.run(pycode, Some(module.dict()), None)?;
            let exported = module.index()?;
            let handler = sync_request_handler(py, module, "middleware", "middleware")?;
            exported.append("middleware")?;
            middlewares.push(handler);
            Ok(())
        })?;

        let response = middlewares.run(empty_request(), locals).await.unwrap_err();
        assert_eq!(response.status(), 503);
        assert_eq!(
            response.headers().get("X-Amzn-Errortype"),
            Some(&HeaderValue::from_static("MiddlewareException"))
        );
        let payload = to_bytes(response.into_body()).await.unwrap();
        assert_eq!(payload, r#"{"message":"error"}"#.as_bytes());
        Ok(())
    }

    #[tokio::test]
    async fn request_middleware_raise_python_exception() -> PyResult<()> {
        let locals = crate::tests::initialize();
        let mut middlewares = PyMiddlewares::new::<RestJson1>(vec![]);

        Python::with_gil(|py| -> PyResult<()> {
            let module = PyModule::from_code(
                py,
                r#"
def middleware(request):
    raise ValueError("error")"#,
                "",
                "",
            )?;
            let handler = sync_request_handler(py, module, "middleware", "middleware")?;
            middlewares.push(handler);
            Ok(())
        })?;

        let response = middlewares.run(empty_request(), locals).await.unwrap_err();
        assert_eq!(response.status(), 500);
        assert_eq!(
            response.headers().get("X-Amzn-Errortype"),
            Some(&HeaderValue::from_static("MiddlewareException"))
        );
        let payload = to_bytes(response.into_body()).await.unwrap();
        assert_eq!(payload, r#"{"message":"ValueError: error"}"#.as_bytes());
        Ok(())
    }
}