1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
use aws_smithy_http_server::body::{Body, BoxBody};
use http::{Request, Response};
use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyFunction};
use pyo3_asyncio::TaskLocals;
use tower::{util::BoxService, BoxError, Service};
use crate::util::func_metadata;
use super::{PyMiddlewareError, PyRequest, PyResponse};
/// Type-erased remainder of the service stack that a Python middleware
/// forwards the request to when it calls `next`.
type PyNextInner = BoxService<Request<Body>, Response<BoxBody>, BoxError>;

// Python-visible handle passed to a middleware function as its `next` argument.
//
// The wrapped service lives in an `Option` so it can be moved out at most once;
// after it has been taken the slot holds `None`.
// (Plain `//` comments here on purpose: `///` on a pyclass becomes its Python `__doc__`.)
#[pyo3::pyclass]
struct PyNext(Option<PyNextInner>);

impl PyNext {
    /// Wraps `inner` so it can be handed to Python.
    fn new(inner: PyNextInner) -> Self {
        PyNext(Some(inner))
    }

    /// Moves the wrapped service out, leaving `None` in its place.
    ///
    /// Returns `None` if the service has already been taken.
    fn take_inner(&mut self) -> Option<PyNextInner> {
        std::mem::take(&mut self.0)
    }
}
#[pyo3::pymethods]
impl PyNext {
    // Python entry point for `next(request)`.
    //
    // Moves the request out of `py_req` and the wrapped service out of `self`,
    // then returns a Python awaitable that drives the service and resolves to a
    // `PyResponse`. The request is taken before `self` is checked, so this order
    // decides which error is reported when both have already been taken.
    fn __call__<'p>(&'p mut self, py: Python<'p>, py_req: Py<PyRequest>) -> PyResult<&'p PyAny> {
        let taken_req = py_req.borrow_mut(py).take_inner();
        let request = match taken_req {
            Some(request) => request,
            None => return Err(PyMiddlewareError::RequestGone.into()),
        };

        // A `PyNext` forwards at most one request: its service is moved out here.
        let mut service = match self.take_inner() {
            Some(service) => service,
            None => return Err(PyMiddlewareError::NextAlreadyCalled.into()),
        };

        pyo3_asyncio::tokio::future_into_py(py, async move {
            match service.call(request).await {
                Ok(response) => Ok(Python::with_gil(|py| PyResponse::new(response).into_py(py))),
                // Service errors reach Python as `RuntimeError` carrying the
                // error's display text.
                Err(err) => Err(PyRuntimeError::new_err(err.to_string())),
            }
        })
    }
}
/// A Python middleware function together with the metadata needed to invoke it.
///
/// Built from a Python callable with [`PyMiddlewareHandler::new`];
/// [`call`](Self::call) runs one request through it.
#[derive(Debug, Clone)]
pub struct PyMiddlewareHandler {
    /// Name reported by `func_metadata` for `func`.
    pub name: String,
    /// The Python callable, invoked with `(request, next)`.
    pub func: PyObject,
    /// Whether `func_metadata` flagged `func` as a coroutine; when set, `call`
    /// awaits the value returned by `func`.
    pub is_coroutine: bool,
}

impl PyMiddlewareHandler {
    /// Creates a handler for `func`, reading its name and coroutine flag via
    /// `func_metadata`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `func_metadata`.
    pub fn new(py: Python, func: PyObject) -> PyResult<Self> {
        let metadata = func_metadata(py, &func)?;
        let name = metadata.name;
        let is_coroutine = metadata.is_coroutine;
        Ok(PyMiddlewareHandler {
            name,
            func,
            is_coroutine,
        })
    }

    /// Runs the Python middleware on `req`, passing `next` as its continuation,
    /// and extracts the `Response` from the `PyResponse` it produces.
    ///
    /// For coroutine middleware the returned Python awaitable is turned into a
    /// Rust future and awaited inside `pyo3_asyncio::tokio::scope` using
    /// `locals`. Otherwise the return value is used directly.
    ///
    /// # Errors
    ///
    /// - any Python exception raised by the middleware or while awaiting it;
    /// - an extraction error if `func` is not a `PyFunction` or its result is
    ///   not a `PyResponse`;
    /// - `PyMiddlewareError::ResponseGone` if that response's inner value has
    ///   already been taken.
    pub async fn call(
        self,
        req: Request<Body>,
        next: PyNextInner,
        locals: TaskLocals,
    ) -> PyResult<Response<BoxBody>> {
        let py_req = PyRequest::new(req);
        let py_next = PyNext::new(next);
        let PyMiddlewareHandler {
            func, is_coroutine, ..
        } = self;

        let output: PyObject = if is_coroutine {
            pyo3_asyncio::tokio::scope(locals, async move {
                // Only the conversion into a Rust future happens under the GIL;
                // the GIL is released before awaiting.
                let pending = Python::with_gil(|py| {
                    let callable: &PyFunction = func.extract(py)?;
                    pyo3_asyncio::tokio::into_future(callable.call1((py_req, py_next))?)
                })?;
                pending.await
            })
            .await?
        } else {
            Python::with_gil(|py| -> PyResult<PyObject> {
                let callable: &PyFunction = func.extract(py)?;
                Ok(callable.call1((py_req, py_next))?.into())
            })?
        };

        let taken = Python::with_gil(|py| {
            let response_cell: Py<PyResponse> = output.extract(py)?;
            let mut guard = response_cell.borrow_mut(py);
            Ok::<_, PyErr>(guard.take_inner())
        })?;

        match taken {
            Some(response) => Ok(response),
            None => Err(PyMiddlewareError::ResponseGone.into()),
        }
    }
}