aws_smithy_http_server_python/
util.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6pub mod collection;
7pub mod error;
8
9use pyo3::{PyAny, PyObject, PyResult, PyTypeInfo, Python};
10
/// Information captured about a Python function, as gathered by [`func_metadata`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FuncMetadata {
    /// The function's `__name__`.
    pub name: String,
    /// Whether `inspect.iscoroutinefunction()` reports the function as a
    /// coroutine function (for example, one declared with `async def`).
    pub is_coroutine: bool,
    /// Number of argument names `inspect.getargs()` reports for the
    /// function's `__code__` object.
    pub num_args: usize,
}
18
19// Returns `FuncMetadata` for given `func`.
20pub fn func_metadata(py: Python, func: &PyObject) -> PyResult<FuncMetadata> {
21    let name = func.getattr(py, "__name__")?.extract::<String>(py)?;
22    let is_coroutine = is_coroutine(py, func)?;
23    let inspect = py.import("inspect")?;
24    let args = inspect
25        .call_method1("getargs", (func.getattr(py, "__code__")?,))?
26        .getattr("args")?
27        .extract::<Vec<String>>()?;
28    Ok(FuncMetadata {
29        name,
30        is_coroutine,
31        num_args: args.len(),
32    })
33}
34
35// Check if a Python function is a coroutine. Since the function has not run yet,
36// we cannot use `asyncio.iscoroutine()`, we need to use `inspect.iscoroutinefunction()`.
37fn is_coroutine(py: Python, func: &PyObject) -> PyResult<bool> {
38    let inspect = py.import("inspect")?;
39    // NOTE: that `asyncio.iscoroutine()` doesn't work here.
40    inspect
41        .call_method1("iscoroutinefunction", (func,))?
42        .extract::<bool>()
43}
44
45// Checks whether given Python type is `Optional[T]`.
46pub fn is_optional_of<T: PyTypeInfo>(py: Python, ty: &PyAny) -> PyResult<bool> {
47    // for reference: https://stackoverflow.com/a/56833826
48
49    // in Python `Optional[T]` is an alias for `Union[T, None]`
50    // so we should check if the type origin is `Union`
51    let union_ty = py.import("typing")?.getattr("Union")?;
52    match ty.getattr("__origin__").map(|origin| origin.is(union_ty)) {
53        Ok(true) => {}
54        // Here we can ignore errors because `__origin__` is not present on all types
55        // and it is not really an error, it is just a type we don't expect
56        _ => return Ok(false),
57    };
58
59    let none = py.None();
60    // in typing, `None` is a special case and it is converted to `type(None)`,
61    // so we are getting type of `None` here to match
62    let none_ty = none.as_ref(py).get_type();
63    let target_ty = py.get_type::<T>();
64
65    // `Union` should be tuple of `(T, NoneType)` or `(NoneType, T)`
66    match ty
67        .getattr("__args__")
68        .and_then(|args| args.extract::<(&PyAny, &PyAny)>())
69    {
70        Ok((first_ty, second_ty)) => Ok(
71            // (T, NoneType)
72            (first_ty.is(target_ty) && second_ty.is(none_ty)) ||
73                // (NoneType, T)
74                (first_ty.is(none_ty) && second_ty.is(target_ty)),
75        ),
76        // Here we can ignore errors because `__args__` is not present on all types
77        // and it is not really an error, it is just a type we don't expect
78        _ => Ok(false),
79    }
80}
81
#[cfg(test)]
mod tests {
    use pyo3::{
        types::{PyBool, PyDict, PyModule, PyString},
        IntoPy,
    };

    use super::*;

    #[test]
    fn function_metadata() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();

        Python::with_gil(|py| {
            let module = PyModule::from_code(
                py,
                r#"
def regular_func(first_arg, second_arg):
    pass

async def async_func():
    pass
"#,
                "",
                "",
            )?;

            // (function name, expected `is_coroutine`, expected `num_args`)
            let expectations = [("regular_func", false, 2), ("async_func", true, 0)];
            for &(func_name, coroutine, arg_count) in expectations.iter() {
                let func = module.getattr(func_name)?.into_py(py);
                let expected = FuncMetadata {
                    name: func_name.to_string(),
                    is_coroutine: coroutine,
                    num_args: arg_count,
                };
                assert_eq!(expected, func_metadata(py, &func)?);
            }

            Ok(())
        })
    }

    #[test]
    fn check_if_is_optional_of() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();

        Python::with_gil(|py| {
            let typing = py.import("typing")?;
            let module = PyModule::from_code(
                py,
                r#"
import typing

class Types:
    opt_of_str: typing.Optional[str] = "hello"
    opt_of_bool: typing.Optional[bool] = None
    regular_str: str = "world"
"#,
                "",
                "",
            )?;

            let types = module.getattr("Types")?.into_py(py);
            let type_hints: &PyDict = typing
                .call_method1("get_type_hints", (types,))?
                .extract()?;

            // Only `Optional[str]` counts as an optional `str`.
            let str_cases = [
                ("opt_of_str", true),
                ("regular_str", false),
                ("opt_of_bool", false),
            ];
            for &(field, expected) in str_cases.iter() {
                let hint = type_hints.get_item(field).unwrap();
                assert_eq!(
                    expected,
                    is_optional_of::<PyString>(py, hint)?,
                    "unexpected result for `{}`",
                    field
                );
            }

            let bool_hint = type_hints.get_item("opt_of_bool").unwrap();
            assert!(is_optional_of::<PyBool>(py, bool_hint)?);

            Ok(())
        })
    }
}