1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

//! Python-compatible middleware [http::Request] implementation.

use std::mem;
use std::sync::Arc;

use aws_smithy_http_server::body::Body;
use http::{request::Parts, Request};
use pyo3::{exceptions::PyRuntimeError, prelude::*};
use tokio::sync::Mutex;

use super::{PyHeaderMap, PyMiddlewareError};

/// Python-compatible [Request] object.
// The wrapped request is stored in three separate pieces so that each can be
// taken or replaced on its own; `take_inner` puts them back together.
#[pyclass(name = "Request")]
#[derive(Debug)]
pub struct PyRequest {
    // Request head (read for method, URI and version). Its header map is moved
    // out in `new`. `None` once `take_inner` has taken it.
    parts: Option<Parts>,
    // Header map of the request, kept apart from `parts` and returned (cloned)
    // by the `headers` getter.
    headers: PyHeaderMap,
    // Request body behind a shared async mutex so the future created by the
    // `body` getter can hold its own handle. The inner `Option` is `None`
    // while the body is being read and after it has been taken.
    body: Arc<Mutex<Option<Body>>>,
}

impl PyRequest {
    /// Wrap a Rust [Request] so it can be handed over to the Python side.
    ///
    /// The request is split into its head and body; the header map is moved out of the
    /// head into a [PyHeaderMap] and the body is stored behind a shared async mutex.
    pub fn new(request: Request<Body>) -> Self {
        let (mut head, payload) = request.into_parts();
        // `mem::take` leaves a default header map behind in `head`; the real one is
        // written back by `take_inner`.
        let header_map = PyHeaderMap::new(mem::take(&mut head.headers));
        Self {
            parts: Some(head),
            headers: header_map,
            body: Arc::new(Mutex::new(Some(payload))),
        }
    }

    /// Reassemble and return the wrapped [Request], emptying this object.
    ///
    /// This would have been `into_inner(self) -> Request<Body>`, but the value crosses the
    /// Python boundary and therefore cannot be consumed by value.
    ///
    /// Returns `None` when any piece is unavailable: the header map cannot be taken, the
    /// request head was already taken, the body is still referenced through another `Arc`
    /// handle, or the body slot is empty. Pieces taken before the failing step are not
    /// put back.
    pub fn take_inner(&mut self) -> Option<Request<Body>> {
        let header_map = self.headers.take_inner()?;
        let mut head = self.parts.take()?;
        head.headers = header_map;

        // Swap a default (empty) slot into `self.body`, then unwrap the `Arc`; this only
        // succeeds if no other handle to the body is alive.
        let shared_body = mem::take(&mut self.body);
        let payload = Arc::try_unwrap(shared_body).ok()?.into_inner()?;

        Some(Request::from_parts(head, payload))
    }
}

#[pymethods]
impl PyRequest {
    /// Return the HTTP method of this request.
    ///
    /// :type str:
    #[getter]
    fn method(&self) -> PyResult<String> {
        self.parts
            .as_ref()
            .map(|parts| parts.method.to_string())
            .ok_or_else(|| PyMiddlewareError::RequestGone.into())
    }

    /// Return the URI of this request.
    ///
    /// :type str:
    #[getter]
    fn uri(&self) -> PyResult<String> {
        self.parts
            .as_ref()
            .map(|parts| parts.uri.to_string())
            .ok_or_else(|| PyMiddlewareError::RequestGone.into())
    }

    /// Return the HTTP version of this request.
    ///
    /// :type str:
    #[getter]
    fn version(&self) -> PyResult<String> {
        self.parts
            .as_ref()
            .map(|parts| format!("{:?}", parts.version))
            .ok_or_else(|| PyMiddlewareError::RequestGone.into())
    }

    /// Return the HTTP headers of this request.
    ///
    /// :type typing.MutableMapping[str, str]:
    #[getter]
    fn headers(&self) -> PyHeaderMap {
        self.headers.clone()
    }

    /// Return the HTTP body of this request.
    /// Note that this is a costly operation because the whole request body is cloned.
    ///
    /// :type typing.Awaitable[bytes]:
    #[getter]
    fn body<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> {
        let body = self.body.clone();
        pyo3_asyncio::tokio::future_into_py(py, async move {
            let body = {
                let mut body_guard = body.lock().await;
                let body = body_guard.take().ok_or(PyMiddlewareError::RequestGone)?;
                let body = hyper::body::to_bytes(body)
                    .await
                    .map_err(|err| PyRuntimeError::new_err(err.to_string()))?;
                let buf = body.clone();
                body_guard.replace(Body::from(body));
                buf
            };
            // TODO(Perf): can we use `PyBytes` here?
            Ok(body.to_vec())
        })
    }

    /// Set the HTTP body of this request.
    #[setter]
    fn set_body(&mut self, buf: &[u8]) {
        self.body = Arc::new(Mutex::new(Some(Body::from(buf.to_owned()))));
    }
}