py_runner/lib.rs

1use crossbeam::channel::{self, Sender};
2use nanoid::nanoid;
3use pyo3::Python;
4use pyo3::prelude::*;
5use pyo3::types::PyDict;
6use std::env;
7use std::ffi::{CStr, CString};
8use std::path::Path;
9use std::path::PathBuf;
10use std::thread;
11
/// Points Python at a virtualenv by setting the `PYTHONPATH` environment
/// variable to `{venv}/lib/{python_version}/site-packages`.
///
/// Any existing `PYTHONPATH` value is replaced, not extended. CPython reads
/// `PYTHONPATH` when the interpreter initialises, so call this before the first
/// use of Python in the process (e.g. before [`PythonModule::new_project`] or
/// [`execute_code`]).
///
/// ```rust,ignore
/// set_venv("./venv", "python3.11");
/// ```
pub fn set_venv(venv: &str, python_version: &str) {
    // Join path components rather than formatting a string, so a trailing
    // separator on `venv` does not produce a doubled `//`.
    let site_packages = Path::new(venv)
        .join("lib")
        .join(python_version)
        .join("site-packages");
    // SAFETY: `env::set_var` is unsound if another thread reads or writes the
    // process environment at the same time. Callers must only invoke this while
    // no other thread touches the environment, typically at start-up before any
    // Python worker threads exist.
    // NOTE(review): this fn is not `unsafe`, so the compiler cannot enforce the
    // requirement above; consider making it `unsafe fn` in a future breaking release.
    unsafe {
        env::set_var("PYTHONPATH", site_packages);
    }
}
22
23pub struct PythonModule {
24    task_sender: Sender<Option<Box<dyn FnOnce(&Python, &Bound<'_, PyAny>) + Send>>>,
25    thread_handle: thread::JoinHandle<PyResult<()>>,
26}
27
28impl Drop for PythonModule {
29    fn drop(&mut self) {
30        self.task_sender.send(None).unwrap();
31    }
32}
33
34impl PythonModule {
35    /// Runs action on the imported module
36    ///```rs
37    /// module
38    ///    .action(|py, module| module.call_method1("add", (1, 2))?.extract::<i64>())
39    ///    .unwrap();
40    /// ```
41    pub fn action<T: Send + 'static>(
42        &self,
43        call: fn(&Python<'_>, &Bound<'_, PyAny>) -> PyResult<T>,
44    ) -> PyResult<T> {
45        if self.thread_handle.is_finished() {
46            return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
47                "Python thread has exited",
48            ));
49        }
50
51        let (sender, receiver) = std::sync::mpsc::sync_channel(1);
52
53        let task: Box<dyn FnOnce(&Python, &Bound<'_, PyAny>) + Send> =
54            Box::new(move |py: &Python, module: &Bound<'_, PyAny>| {
55                let result = call(py, module);
56                let _ = sender.send(result);
57            });
58
59        self.task_sender
60            .send(Some(task))
61            .map_err(|_| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("Task send failed"))?;
62
63        receiver.recv().unwrap()
64    }
65
66    /// Loads a Python module from a directory
67    /// `let module = PythonModule::new_module(Path::new("./my-module")).unwrap();`
68    pub fn new_module(path: &Path) -> PyResult<PythonModule> {
69        let init_file = path.join("__init__.py");
70        Self::new_project(init_file)
71    }
72
73    /// Loads a Python project from root file
74    /// `let project = PythonModule::new_project(Path::new("./my-project/main.py").into()).unwrap()`
75    pub fn new_project(init_file: PathBuf) -> PyResult<PythonModule> {
76        if !init_file.is_file() {
77            return Err(PyErr::new::<pyo3::exceptions::PyFileNotFoundError, _>(
78                format!("No {} found", init_file.display()),
79            ));
80        }
81        let module_name = nanoid!(16);
82        let (task_sender, task_receiver) =
83            channel::unbounded::<Option<Box<dyn FnOnce(&Python, &Bound<'_, PyAny>) + Send>>>();
84        let (init_sender, init_receiver) = std::sync::mpsc::sync_channel::<PyResult<()>>(0);
85
86        let thread_handle = thread::spawn(move || {
87            let v: PyResult<()> = Python::with_gil(|py| {
88                let init = || {
89                    let importlib_util = PyModule::import(py, "importlib.util")?;
90
91                    let spec = importlib_util
92                        .getattr("spec_from_file_location")?
93                        .call1((&module_name, init_file))?;
94
95                    let module = importlib_util
96                        .getattr("module_from_spec")?
97                        .call1((spec.clone(),))?;
98                    let sys = py.import("sys")?;
99                    let modules = sys.getattr("modules")?;
100                    modules.set_item(module_name, &module)?;
101                    let loader = spec.getattr("loader")?;
102                    loader.call_method1("exec_module", (module.clone(),))?;
103                    Ok(module)
104                };
105                match init() {
106                    Ok(module) => {
107                        let _ = init_sender.send(Ok(()));
108                        while let Ok(Some(task)) = py.allow_threads(|| task_receiver.recv()) {
109                            task(&py, &module);
110                        }
111                    }
112                    Err(e) => {
113                        let _ = init_sender.send(Err(e));
114                    }
115                }
116
117                Ok(())
118            });
119            v
120        });
121        if let Ok(v) = init_receiver.recv() {
122            v?;
123        }
124
125        Ok(PythonModule {
126            task_sender,
127            thread_handle,
128        })
129    }
130}
131
132pub fn execute_code_(s: &str) -> PyResult<()> {
133    execute_code::<()>(s, |_, _| Ok(()))
134}
135
136/// Runs Python code
137pub fn execute_code<T>(
138    s: &str,
139    f: fn(Python<'_>, Bound<'_, PyDict>) -> PyResult<T>,
140) -> PyResult<T> {
141    Python::with_gil(|py| {
142        let c_string = CString::new(s).expect("CString::new failed");
143
144        let c_str: &CStr = c_string.as_c_str();
145        let globals = PyDict::new(py);
146
147        py.run(c_str, Some(&globals), None).unwrap();
148        f(py, globals)
149    })
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Calls the loaded module's `add(1, 2)` on its worker thread and returns the sum.
    fn add_one_and_two(module: &PythonModule) -> i64 {
        module
            .action(|_py, m| m.call_method1("add", (1, 2))?.extract::<i64>())
            .unwrap()
    }

    #[test]
    fn test_execute_code() {
        let value = execute_code("x = '10'", |_py, globals| {
            let item = globals.get_item("x")?.unwrap();
            item.extract::<String>()
        })
        .unwrap();
        assert_eq!(value, "10");
    }

    #[test]
    fn test_load_project() {
        let project = PythonModule::new_project(PathBuf::from("./my-project/main.py")).unwrap();
        assert_eq!(add_one_and_two(&project), 3);
    }

    #[test]
    fn test_load_module() {
        let module = PythonModule::new_module(Path::new("./my-module")).unwrap();
        assert_eq!(add_one_and_two(&module), 3);
    }
}