aws_smithy_http_server_python/
util.rs1pub mod collection;
7pub mod error;
8
9use pyo3::{PyAny, PyObject, PyResult, PyTypeInfo, Python};
10
/// Metadata describing a Python callable, as collected by [`func_metadata`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FuncMetadata {
    /// Value of the callable's `__name__` attribute.
    pub name: String,
    /// Whether Python's `inspect.iscoroutinefunction` reports the callable as a coroutine function.
    pub is_coroutine: bool,
    /// Number of names in the `args` list that `inspect.getargs` reports for the callable's code.
    pub num_args: usize,
}
18
19pub fn func_metadata(py: Python, func: &PyObject) -> PyResult<FuncMetadata> {
21 let name = func.getattr(py, "__name__")?.extract::<String>(py)?;
22 let is_coroutine = is_coroutine(py, func)?;
23 let inspect = py.import("inspect")?;
24 let args = inspect
25 .call_method1("getargs", (func.getattr(py, "__code__")?,))?
26 .getattr("args")?
27 .extract::<Vec<String>>()?;
28 Ok(FuncMetadata {
29 name,
30 is_coroutine,
31 num_args: args.len(),
32 })
33}
34
35fn is_coroutine(py: Python, func: &PyObject) -> PyResult<bool> {
38 let inspect = py.import("inspect")?;
39 inspect
41 .call_method1("iscoroutinefunction", (func,))?
42 .extract::<bool>()
43}
44
45pub fn is_optional_of<T: PyTypeInfo>(py: Python, ty: &PyAny) -> PyResult<bool> {
47 let union_ty = py.import("typing")?.getattr("Union")?;
52 match ty.getattr("__origin__").map(|origin| origin.is(union_ty)) {
53 Ok(true) => {}
54 _ => return Ok(false),
57 };
58
59 let none = py.None();
60 let none_ty = none.as_ref(py).get_type();
63 let target_ty = py.get_type::<T>();
64
65 match ty
67 .getattr("__args__")
68 .and_then(|args| args.extract::<(&PyAny, &PyAny)>())
69 {
70 Ok((first_ty, second_ty)) => Ok(
71 (first_ty.is(target_ty) && second_ty.is(none_ty)) ||
73 (first_ty.is(none_ty) && second_ty.is(target_ty)),
75 ),
76 _ => Ok(false),
79 }
80}
81
#[cfg(test)]
mod tests {
    use pyo3::{
        types::{PyBool, PyDict, PyModule, PyString},
        IntoPy,
    };

    use super::*;

    /// `func_metadata` reports name, coroutine flag and argument count for both a
    /// regular function and an `async` function defined in Python.
    #[test]
    fn function_metadata() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();

        Python::with_gil(|py| {
            let module = PyModule::from_code(
                py,
                r#"
def regular_func(first_arg, second_arg):
    pass

async def async_func():
    pass
"#,
                "",
                "",
            )?;

            let regular_func = module.getattr("regular_func")?.into_py(py);
            assert_eq!(
                FuncMetadata {
                    name: "regular_func".to_string(),
                    is_coroutine: false,
                    num_args: 2,
                },
                func_metadata(py, &regular_func)?
            );

            let async_func = module.getattr("async_func")?.into_py(py);
            assert_eq!(
                FuncMetadata {
                    name: "async_func".to_string(),
                    is_coroutine: true,
                    num_args: 0,
                },
                func_metadata(py, &async_func)?
            );

            Ok(())
        })
    }

    /// `is_optional_of::<T>` accepts `Optional[T]` hints and rejects plain types as
    /// well as `Optional` hints for a different type.
    #[test]
    fn check_if_is_optional_of() -> PyResult<()> {
        pyo3::prepare_freethreaded_python();

        Python::with_gil(|py| {
            let typing = py.import("typing")?;
            let module = PyModule::from_code(
                py,
                r#"
import typing

class Types:
    opt_of_str: typing.Optional[str] = "hello"
    opt_of_bool: typing.Optional[bool] = None
    regular_str: str = "world"
"#,
                "",
                "",
            )?;

            let types = module.getattr("Types")?.into_py(py);
            let type_hints = typing
                .call_method1("get_type_hints", (types,))
                .and_then(|res| res.extract::<&PyDict>())?;

            assert!(is_optional_of::<PyString>(
                py,
                type_hints.get_item("opt_of_str").unwrap()
            )?);
            assert!(!is_optional_of::<PyString>(
                py,
                type_hints.get_item("regular_str").unwrap()
            )?);
            assert!(is_optional_of::<PyBool>(
                py,
                type_hints.get_item("opt_of_bool").unwrap()
            )?);
            assert!(!is_optional_of::<PyString>(
                py,
                type_hints.get_item("opt_of_bool").unwrap()
            )?);

            Ok(())
        })
    }
}