aws_smithy_http_server_python/context/
lambda.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Support for injecting [PyLambdaContext] to [super::PyContext].
7
8use std::collections::HashSet;
9
10use http::Extensions;
11use lambda_http::Context as LambdaContext;
12use pyo3::{types::PyDict, PyObject, PyResult, Python};
13
14use crate::{lambda::PyLambdaContext, rich_py_err, util::is_optional_of};
15
/// Injects [`PyLambdaContext`] into fields of a user-provided Python context object.
///
/// The set of target fields is computed once, when the value is constructed, by inspecting
/// the context's type hints. Per-request population then only has to assign to those fields.
#[derive(Clone, Debug)]
pub struct PyContextLambda {
    /// Names of the context attributes whose type hint was matched by
    /// `is_optional_of::<PyLambdaContext>` (e.g. `typing.Optional[LambdaContext]`).
    fields: HashSet<String>,
}
20
21impl PyContextLambda {
22    pub fn new(ctx: PyObject) -> PyResult<Self> {
23        let fields = Python::with_gil(|py| get_lambda_ctx_fields(py, &ctx))?;
24        Ok(Self { fields })
25    }
26
27    pub fn populate_from_extensions(&self, ctx: PyObject, ext: &Extensions) {
28        if self.fields.is_empty() {
29            // Return early without acquiring GIL
30            return;
31        }
32
33        let lambda_ctx = ext
34            .get::<LambdaContext>()
35            .cloned()
36            .map(PyLambdaContext::new);
37
38        Python::with_gil(|py| {
39            for field in self.fields.iter() {
40                if let Err(err) = ctx.setattr(py, field.as_str(), lambda_ctx.clone()) {
41                    tracing::warn!(field = ?field, error = ?rich_py_err(err), "could not inject `LambdaContext` to context")
42                }
43            }
44        });
45    }
46}
47
48// Inspects the given `PyObject` to detect fields that type-hinted `PyLambdaContext`.
49fn get_lambda_ctx_fields(py: Python, ctx: &PyObject) -> PyResult<HashSet<String>> {
50    let typing = py.import("typing")?;
51    let hints = match typing
52        .call_method1("get_type_hints", (ctx,))
53        .and_then(|res| res.extract::<&PyDict>())
54    {
55        Ok(hints) => hints,
56        Err(_) => {
57            // `get_type_hints` could fail if `ctx` is `None`, which is the default value
58            // for the context if user does not provide a custom class.
59            // In that case, this is not really an error and we should just return an empty set.
60            return Ok(HashSet::new());
61        }
62    };
63
64    let mut fields = HashSet::new();
65    for (key, value) in hints {
66        if is_optional_of::<PyLambdaContext>(py, value)? {
67            fields.insert(key.to_string());
68        }
69    }
70    Ok(fields)
71}
72
#[cfg(test)]
mod tests {
    use http::Extensions;
    use lambda_http::Context as LambdaContext;
    use pyo3::{prelude::*, py_run};

    use crate::context::testing::{get_context, lambda_ctx};

    /// Request extensions carrying the given `LambdaContext`.
    fn extensions_with_lambda_ctx(ctx: LambdaContext) -> Extensions {
        let mut extensions = empty_extensions();
        extensions.insert(ctx);
        extensions
    }

    /// Request extensions without any `LambdaContext`.
    fn empty_extensions() -> Extensions {
        Extensions::default()
    }

    #[test]
    fn py_context_with_lambda_context() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();

        // A custom context class with a single field hinted as `Optional[LambdaContext]`.
        let ctx = get_context(
            r#"
class Context:
    foo: int = 0
    bar: str = 'qux'
    lambda_ctx: typing.Optional[LambdaContext]

ctx = Context()
ctx.foo = 42
"#,
        );

        // Before any request: user-set values are visible and nothing has been injected yet.
        Python::with_gil(|py| {
            py_run!(
                py,
                ctx,
                r#"
assert ctx.foo == 42
assert ctx.bar == 'qux'
assert not hasattr(ctx, 'lambda_ctx')
"#
            );
        });

        // First request carries a `LambdaContext`, which must land in the hinted field.
        let with_lambda = extensions_with_lambda_ctx(lambda_ctx("my-req-id", "123"));
        ctx.populate_from_extensions(&with_lambda);
        Python::with_gil(|py| {
            py_run!(
                py,
                ctx,
                r#"
assert ctx.lambda_ctx.request_id == "my-req-id"
assert ctx.lambda_ctx.deadline == 123
# Make some modifications
ctx.foo += 1
ctx.bar = 'baz'
"#
            );
        });

        // Second request has no `LambdaContext`. The field must become `None` rather than
        // leaking the previous request's value, while the user's own modifications are kept.
        let without_lambda = empty_extensions();
        ctx.populate_from_extensions(&without_lambda);
        Python::with_gil(|py| {
            py_run!(
                py,
                ctx,
                r#"
assert ctx.lambda_ctx is None
# Make sure we are preserving any modifications
assert ctx.foo == 43
assert ctx.bar == 'baz'
"#
            );
        });

        Ok(())
    }

    #[test]
    fn works_with_none() -> PyResult<()> {
        // Users can set the context to `None` either explicitly or implicitly, by not
        // providing a custom context class. Populating it must not fail in that case.
        pyo3::prepare_freethreaded_python();

        let ctx = get_context("ctx = None");
        let with_lambda = extensions_with_lambda_ctx(lambda_ctx("my-req-id", "123"));
        ctx.populate_from_extensions(&with_lambda);
        Python::with_gil(|py| {
            py_run!(py, ctx, "assert ctx is None");
        });

        Ok(())
    }
}