1use crossbeam::channel::{self, Sender};
2use nanoid::nanoid;
3use pyo3::Python;
4use pyo3::prelude::*;
5use pyo3::types::PyDict;
6use std::env;
7use std::ffi::{CStr, CString};
8use std::path::Path;
9use std::path::PathBuf;
10use std::thread;
11
/// Points `PYTHONPATH` at the `site-packages` directory of a virtualenv.
///
/// The resulting value is `{venv}/lib/{python_version}/site-packages`, so
/// `python_version` is the directory name under `lib/` (e.g. `python3.12`).
/// Any previous `PYTHONPATH` value is overwritten, not extended.
pub fn set_venv(venv: &str, python_version: &str) {
    let site_packages = format!("{venv}/lib/{python_version}/site-packages");
    // SAFETY: modifying the process environment is only sound while no other
    // thread reads or writes it concurrently. This function does not enforce
    // that; callers are expected to invoke it before starting other threads
    // (e.g. before loading any `PythonModule`).
    unsafe {
        env::set_var("PYTHONPATH", site_packages);
    }
}
22
23pub struct PythonModule {
24 task_sender: Sender<Option<Box<dyn FnOnce(&Python, &Bound<'_, PyAny>) + Send>>>,
25 thread_handle: thread::JoinHandle<PyResult<()>>,
26}
27
28impl Drop for PythonModule {
29 fn drop(&mut self) {
30 self.task_sender.send(None).unwrap();
31 }
32}
33
34impl PythonModule {
35 pub fn action<T: Send + 'static>(
42 &self,
43 call: fn(&Python<'_>, &Bound<'_, PyAny>) -> PyResult<T>,
44 ) -> PyResult<T> {
45 if self.thread_handle.is_finished() {
46 return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
47 "Python thread has exited",
48 ));
49 }
50
51 let (sender, receiver) = std::sync::mpsc::sync_channel(1);
52
53 let task: Box<dyn FnOnce(&Python, &Bound<'_, PyAny>) + Send> =
54 Box::new(move |py: &Python, module: &Bound<'_, PyAny>| {
55 let result = call(py, module);
56 let _ = sender.send(result);
57 });
58
59 self.task_sender
60 .send(Some(task))
61 .map_err(|_| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("Task send failed"))?;
62
63 receiver.recv().unwrap()
64 }
65
66 pub fn new_module(path: &Path) -> PyResult<PythonModule> {
69 let init_file = path.join("__init__.py");
70 Self::new_project(init_file)
71 }
72
73 pub fn new_project(init_file: PathBuf) -> PyResult<PythonModule> {
76 if !init_file.is_file() {
77 return Err(PyErr::new::<pyo3::exceptions::PyFileNotFoundError, _>(
78 format!("No {} found", init_file.display()),
79 ));
80 }
81 let module_name = nanoid!(16);
82 let (task_sender, task_receiver) =
83 channel::unbounded::<Option<Box<dyn FnOnce(&Python, &Bound<'_, PyAny>) + Send>>>();
84 let (init_sender, init_receiver) = std::sync::mpsc::sync_channel::<PyResult<()>>(0);
85
86 let thread_handle = thread::spawn(move || {
87 let v: PyResult<()> = Python::with_gil(|py| {
88 let init = || {
89 let importlib_util = PyModule::import(py, "importlib.util")?;
90
91 let spec = importlib_util
92 .getattr("spec_from_file_location")?
93 .call1((&module_name, init_file))?;
94
95 let module = importlib_util
96 .getattr("module_from_spec")?
97 .call1((spec.clone(),))?;
98 let sys = py.import("sys")?;
99 let modules = sys.getattr("modules")?;
100 modules.set_item(module_name, &module)?;
101 let loader = spec.getattr("loader")?;
102 loader.call_method1("exec_module", (module.clone(),))?;
103 Ok(module)
104 };
105 match init() {
106 Ok(module) => {
107 let _ = init_sender.send(Ok(()));
108 while let Ok(Some(task)) = py.allow_threads(|| task_receiver.recv()) {
109 task(&py, &module);
110 }
111 }
112 Err(e) => {
113 let _ = init_sender.send(Err(e));
114 }
115 }
116
117 Ok(())
118 });
119 v
120 });
121 if let Ok(v) = init_receiver.recv() {
122 v?;
123 }
124
125 Ok(PythonModule {
126 task_sender,
127 thread_handle,
128 })
129 }
130}
131
132pub fn execute_code_(s: &str) -> PyResult<()> {
133 execute_code::<()>(s, |_, _| Ok(()))
134}
135
136pub fn execute_code<T>(
138 s: &str,
139 f: fn(Python<'_>, Bound<'_, PyDict>) -> PyResult<T>,
140) -> PyResult<T> {
141 Python::with_gil(|py| {
142 let c_string = CString::new(s).expect("CString::new failed");
143
144 let c_str: &CStr = c_string.as_c_str();
145 let globals = PyDict::new(py);
146
147 py.run(c_str, Some(&globals), None).unwrap();
148 f(py, globals)
149 })
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Calls `add(1, 2)` on the given module via its worker thread.
    fn call_add(handle: &PythonModule) -> i64 {
        handle
            .action(|_, m| m.call_method1("add", (1, 2))?.extract::<i64>())
            .unwrap()
    }

    #[test]
    fn test_execute_code() {
        let value = execute_code("x = '10'", |_, globals| {
            let item = globals.get_item("x")?.unwrap();
            item.extract::<String>()
        })
        .unwrap();

        assert_eq!(value, "10");
    }

    #[test]
    fn test_load_project() {
        let project = PythonModule::new_project(PathBuf::from("./my-project/main.py")).unwrap();
        assert_eq!(call_add(&project), 3)
    }

    #[test]
    fn test_load_module() {
        let package = PythonModule::new_module(Path::new("./my-module")).unwrap();
        assert_eq!(call_add(&package), 3)
    }
}