use crate::PythonBlock;
use crate::run::run_python_code;
use pyo3::{
FromPyObject, IntoPyObject, Py, PyResult, Python,
prelude::*,
types::{PyCFunction, PyDict},
};
/// An execution context for Python code.
///
/// Wraps a Python dictionary that the methods on this type use as the set of
/// global variables: they read from it, write to it, and pass it along when
/// running code.
pub struct Context {
/// The globals dictionary, held as a GIL-independent reference.
pub(crate) globals: Py<PyDict>,
}
impl Context {
#[allow(clippy::new_without_default)]
#[track_caller]
pub fn new() -> Self {
Python::with_gil(Self::new_with_gil)
}
#[track_caller]
pub(crate) fn new_with_gil(py: Python) -> Self {
match Self::try_new(py) {
Ok(x) => x,
Err(err) => panic!("{}", panic_string(py, &err)),
}
}
fn try_new(py: Python) -> PyResult<Self> {
Ok(Self {
globals: py.import("__main__")?.dict().copy()?.into(),
})
}
pub fn globals(&self) -> &Py<PyDict> {
&self.globals
}
pub fn get<T: for<'p> FromPyObject<'p>>(&self, name: &str) -> T {
Python::with_gil(|py| match self.globals.bind(py).get_item(name) {
Err(_) | Ok(None) => {
panic!("Python context does not contain a variable named `{name}`",)
}
Ok(Some(value)) => match FromPyObject::extract_bound(&value) {
Ok(value) => value,
Err(e) => panic!(
"Unable to convert `{name}` to `{ty}`: {e}",
ty = std::any::type_name::<T>(),
),
},
})
}
pub fn set<T: for<'p> IntoPyObject<'p>>(&self, name: &str, value: T) {
Python::with_gil(|py| {
if let Err(e) = self.globals().bind(py).set_item(name, value) {
panic!(
"Unable to set `{name}` from a `{ty}`: {e}",
ty = std::any::type_name::<T>(),
);
}
})
}
pub fn add_wrapped(&self, wrapper: &impl Fn(Python) -> PyResult<Bound<'_, PyCFunction>>) {
Python::with_gil(|py| {
let obj = wrapper(py).unwrap();
let name = obj
.getattr("__name__")
.expect("wrapped item should have a __name__");
if let Err(err) = self.globals().bind(py).set_item(name, obj) {
panic!("{}", panic_string(py, &err));
}
})
}
pub fn run(
&self,
#[cfg(not(doc))] code: PythonBlock<impl FnOnce(&Bound<PyDict>)>,
#[cfg(doc)] code: PythonBlock, ) {
Python::with_gil(|py| self.run_with_gil(py, code));
}
#[cfg(not(doc))]
pub(crate) fn run_with_gil<F: FnOnce(&Bound<PyDict>)>(
&self,
py: Python<'_>,
block: PythonBlock<F>,
) {
(block.set_vars)(self.globals().bind(py));
if let Err(err) = run_python_code(py, self, block.bytecode) {
(block.panic)(panic_string(py, &err));
}
}
}
/// Renders `err` as text suitable for a panic message.
///
/// Prefers the output captured by `py_err_to_string`; if capturing fails for any
/// reason, falls back to the error's own `to_string()` rendering.
fn panic_string(py: Python, err: &PyErr) -> String {
    py_err_to_string(py, err).unwrap_or_else(|_| err.to_string())
}
/// Prints `err` while `sys.stderr` is temporarily replaced by an `io.StringIO`,
/// and returns whatever was written to it.
///
/// The previous `sys.stderr` entry is put back right after printing. Any Python
/// error raised while swapping, restoring, or reading the buffer is returned to
/// the caller.
fn py_err_to_string(py: Python, err: &PyErr) -> Result<String, PyErr> {
    let sys_dict = py.import("sys")?.dict();
    let string_io = py.import("io")?.getattr("StringIO")?;
    let capture = string_io.call0()?;

    // Swap the in-memory buffer in, print, then restore what was there before.
    let saved = sys_dict.get_item("stderr")?;
    sys_dict.set_item("stderr", &capture)?;
    err.print(py);
    sys_dict.set_item("stderr", saved)?;

    let text = capture.call_method0("getvalue")?;
    text.extract()
}