aws_smithy_http_server_python/context/
lambda.rs1use std::collections::HashSet;
9
10use http::Extensions;
11use lambda_http::Context as LambdaContext;
12use pyo3::{types::PyDict, PyObject, PyResult, Python};
13
14use crate::{lambda::PyLambdaContext, rich_py_err, util::is_optional_of};
15
/// Records which attributes of a user-provided Python context object are annotated as
/// `typing.Optional[LambdaContext]`, so that `populate_from_extensions` can set them.
#[derive(Debug, Clone)]
pub struct PyContextLambda {
    /// Names of the context attributes whose type hint is `Optional[LambdaContext]`.
    fields: HashSet<String>,
}
20
21impl PyContextLambda {
22 pub fn new(ctx: PyObject) -> PyResult<Self> {
23 let fields = Python::with_gil(|py| get_lambda_ctx_fields(py, &ctx))?;
24 Ok(Self { fields })
25 }
26
27 pub fn populate_from_extensions(&self, ctx: PyObject, ext: &Extensions) {
28 if self.fields.is_empty() {
29 return;
31 }
32
33 let lambda_ctx = ext
34 .get::<LambdaContext>()
35 .cloned()
36 .map(PyLambdaContext::new);
37
38 Python::with_gil(|py| {
39 for field in self.fields.iter() {
40 if let Err(err) = ctx.setattr(py, field.as_str(), lambda_ctx.clone()) {
41 tracing::warn!(field = ?field, error = ?rich_py_err(err), "could not inject `LambdaContext` to context")
42 }
43 }
44 });
45 }
46}
47
48fn get_lambda_ctx_fields(py: Python, ctx: &PyObject) -> PyResult<HashSet<String>> {
50 let typing = py.import("typing")?;
51 let hints = match typing
52 .call_method1("get_type_hints", (ctx,))
53 .and_then(|res| res.extract::<&PyDict>())
54 {
55 Ok(hints) => hints,
56 Err(_) => {
57 return Ok(HashSet::new());
61 }
62 };
63
64 let mut fields = HashSet::new();
65 for (key, value) in hints {
66 if is_optional_of::<PyLambdaContext>(py, value)? {
67 fields.insert(key.to_string());
68 }
69 }
70 Ok(fields)
71}
72
#[cfg(test)]
mod tests {
    use http::Extensions;
    use lambda_http::Context as LambdaContext;
    use pyo3::{prelude::*, py_run};

    use crate::context::testing::{get_context, lambda_ctx};

    /// Checks three things about an attribute annotated as `typing.Optional[LambdaContext]`:
    /// it is filled from extensions that carry a Lambda context, it is reset to `None` when
    /// they do not, and other user modifications to the context are preserved in between.
    #[test]
    fn py_context_with_lambda_context() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();

        // NOTE(review): `get_context` presumably evaluates this snippet and wraps the
        // resulting `ctx` global — confirm against `context::testing`.
        let ctx = get_context(
            r#"
class Context:
    foo: int = 0
    bar: str = 'qux'
    lambda_ctx: typing.Optional[LambdaContext]

ctx = Context()
ctx.foo = 42
"#,
        );
        // Before any population, the annotated attribute has not been set at all.
        Python::with_gil(|py| {
            py_run!(
                py,
                ctx,
                r#"
assert ctx.foo == 42
assert ctx.bar == 'qux'
assert not hasattr(ctx, 'lambda_ctx')
"#
            );
        });

        // With a `LambdaContext` in the extensions, the annotated attribute is populated.
        ctx.populate_from_extensions(&extensions_with_lambda_ctx(lambda_ctx("my-req-id", "123")));
        Python::with_gil(|py| {
            py_run!(
                py,
                ctx,
                r#"
assert ctx.lambda_ctx.request_id == "my-req-id"
assert ctx.lambda_ctx.deadline == 123
# Make some modifications
ctx.foo += 1
ctx.bar = 'baz'
"#
            );
        });

        // Without a `LambdaContext` in the extensions, the attribute is reset to `None`,
        // while the modifications made by the previous script must still be visible.
        ctx.populate_from_extensions(&empty_extensions());
        Python::with_gil(|py| {
            py_run!(
                py,
                ctx,
                r#"
assert ctx.lambda_ctx is None
# Make sure we are preserving any modifications
assert ctx.foo == 43
assert ctx.bar == 'baz'
"#
            );
        });

        Ok(())
    }

    /// Populating a context that is `None` must not fail and must leave it as `None`.
    #[test]
    fn works_with_none() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();

        let ctx = get_context("ctx = None");
        ctx.populate_from_extensions(&extensions_with_lambda_ctx(lambda_ctx("my-req-id", "123")));
        Python::with_gil(|py| {
            py_run!(py, ctx, "assert ctx is None");
        });

        Ok(())
    }

    /// Builds request extensions containing the given Lambda context.
    fn extensions_with_lambda_ctx(ctx: LambdaContext) -> Extensions {
        let mut exts = empty_extensions();
        exts.insert(ctx);
        exts
    }

    /// Builds request extensions with no Lambda context.
    fn empty_extensions() -> Extensions {
        Extensions::new()
    }
}