1#[cfg(feature = "pyo3")]
9mod pyo3_runner;
10#[cfg(feature = "rustpython")]
11mod rustpython_runner;
12
13use once_cell::sync::Lazy;
14use serde_json::Value;
15use std::path::{Path, PathBuf};
16use std::sync::mpsc as std_mpsc;
17use std::thread;
18use thiserror::Error;
19use tokio::runtime::Runtime;
20use tokio::sync::{mpsc, oneshot};
21
22#[derive(Debug)]
23pub(crate) enum CmdType {
24 RunFile(PathBuf),
25 RunCode(String),
26 EvalCode(String),
27 ReadVariable(String),
28 CallFunction { name: String, args: Vec<Value> },
29 CallAsyncFunction { name: String, args: Vec<Value> },
30 Stop,
31}
32pub(crate) struct PyCommand {
36 cmd_type: CmdType,
37 responder: oneshot::Sender<Result<Value, String>>,
38}
39
40type Task = Box<dyn FnOnce(&Runtime) -> Result<Value, PyRunnerError> + Send>;
42
43static SYNC_WORKER: Lazy<std_mpsc::Sender<Task>> = Lazy::new(|| {
47 let (tx, rx) = std_mpsc::channel::<Task>();
48
49 thread::spawn(move || {
50 let rt = Runtime::new().expect("Failed to create Tokio runtime for sync worker");
51 while let Ok(task) = rx.recv() {
53 let _ = task(&rt); }
55 });
56 tx
57});
58#[derive(Error, Debug, Clone)]
60pub enum PyRunnerError {
61 #[error("Failed to send command to Python thread. The thread may have panicked.")]
62 SendCommandFailed,
63
64 #[error("Failed to receive result from Python thread. The thread may have panicked.")]
65 ReceiveResultFailed,
66
67 #[error("Python execution error: {0:?}")]
68 PyError(String),
69}
70
71fn cleanup_path_for_python(path: &PathBuf) -> String {
72 dunce::canonicalize(path)
73 .unwrap()
74 .to_string_lossy()
75 .replace("\\", "/")
76}
77
78pub fn print_path_for_python(path: &PathBuf) -> String {
79 #[cfg(not(target_os = "windows"))]
80 {
81 format!("\"{}\"", cleanup_path_for_python(path))
82 }
83 #[cfg(target_os = "windows")]
84 {
85 format!("r\"{}\"", cleanup_path_for_python(path))
86 }
87}
88
89#[derive(Clone)]
91pub struct PyRunner {
92 sender: mpsc::Sender<PyCommand>,
93}
94
95
96impl Default for PyRunner {
97 fn default() -> Self {
98 PyRunner::new()
99 }
100}
101
102impl PyRunner {
103 pub fn new() -> Self {
112 let (sender, receiver) = mpsc::channel::<PyCommand>(32);
114
115 thread::spawn(move || {
118 #[cfg(all(feature = "pyo3", not(feature = "rustpython")))]
119 {
120 use tokio::runtime::Builder;
121 let rt = Builder::new_multi_thread().enable_all().build().unwrap();
122 rt.block_on(pyo3_runner::python_thread_main(receiver));
123 }
124
125 #[cfg(feature = "rustpython")]
126 {
127 rustpython_runner::python_thread_main(receiver);
128 }
129 });
130
131 Self { sender }
132 }
133
134 async fn send_command(&self, cmd_type: CmdType) -> Result<Value, PyRunnerError> {
137 let (responder, receiver) = oneshot::channel();
139 let cmd = PyCommand {
140 cmd_type,
141 responder,
142 };
143
144 self.sender
146 .send(cmd)
147 .await
148 .map_err(|_| PyRunnerError::SendCommandFailed)?;
149
150 receiver
152 .await
153 .map_err(|_| PyRunnerError::ReceiveResultFailed)?
154 .map_err(PyRunnerError::PyError)
155 }
156
157 fn send_command_sync(&self, cmd_type: CmdType) -> Result<Value, PyRunnerError> {
160 let (tx, rx) = std_mpsc::channel();
161 let sender = self.sender.clone();
162
163 let task = Box::new(move |rt: &Runtime| {
164 let result = rt.block_on(async {
165 let (responder, receiver) = oneshot::channel();
168 let cmd = PyCommand {
169 cmd_type,
170 responder,
171 };
172 sender
173 .send(cmd)
174 .await
175 .map_err(|_| PyRunnerError::SendCommandFailed)?;
176 receiver
177 .await
178 .map_err(|_| PyRunnerError::ReceiveResultFailed.clone())?
179 .map_err(PyRunnerError::PyError)
180 });
181 if tx.send(result.clone()).is_err() {
182 return Err(PyRunnerError::SendCommandFailed);
183 }
184 result
185 });
186
187 SYNC_WORKER
188 .send(task)
189 .map_err(|_| PyRunnerError::SendCommandFailed)?;
190 rx.recv().map_err(|_| PyRunnerError::ReceiveResultFailed)?
191 }
192 pub async fn run(&self, code: &str) -> Result<(), PyRunnerError> {
198 self.send_command(CmdType::RunCode(code.into()))
199 .await
200 .map(|_| ())
201 }
202
203 pub fn run_sync(&self, code: &str) -> Result<(), PyRunnerError> {
212 self.send_command_sync(CmdType::RunCode(code.into()))
213 .map(|_| ())
214 }
215
216 pub async fn run_file(&self, file: &Path) -> Result<(), PyRunnerError> {
221 self.send_command(CmdType::RunFile(file.to_path_buf()))
222 .await
223 .map(|_| ())
224 }
225
226 pub fn run_file_sync(&self, file: &Path) -> Result<(), PyRunnerError> {
235 self.send_command_sync(CmdType::RunFile(file.to_path_buf()))
236 .map(|_| ())
237 }
238
239 pub async fn eval(&self, code: &str) -> Result<Value, PyRunnerError> {
246 self.send_command(CmdType::EvalCode(code.into())).await
247 }
248
249 pub fn eval_sync(&self, code: &str) -> Result<Value, PyRunnerError> {
258 self.send_command_sync(CmdType::EvalCode(code.into()))
259 }
260
261 pub async fn read_variable(&self, var_name: &str) -> Result<Value, PyRunnerError> {
268 self.send_command(CmdType::ReadVariable(var_name.into()))
269 .await
270 }
271
272 pub fn read_variable_sync(&self, var_name: &str) -> Result<Value, PyRunnerError> {
281 self.send_command_sync(CmdType::ReadVariable(var_name.into()))
282 }
283
284 pub async fn call_function(
293 &self,
294 name: &str,
295 args: Vec<Value>,
296 ) -> Result<Value, PyRunnerError> {
297 self.send_command(CmdType::CallFunction {
298 name: name.into(),
299 args,
300 })
301 .await
302 }
303
304 pub fn call_function_sync(&self, name: &str, args: Vec<Value>) -> Result<Value, PyRunnerError> {
315 self.send_command_sync(CmdType::CallFunction {
316 name: name.into(),
317 args,
318 })
319 }
320
321 pub async fn call_async_function(
330 &self,
331 name: &str,
332 args: Vec<Value>,
333 ) -> Result<Value, PyRunnerError> {
334 self.send_command(CmdType::CallAsyncFunction {
335 name: name.into(),
336 args,
337 })
338 .await
339 }
340
341 #[cfg(feature = "pyo3")]
351 pub fn call_async_function_sync(
352 &self,
353 name: &str,
354 args: Vec<Value>,
355 ) -> Result<Value, PyRunnerError> {
356 self.send_command_sync(CmdType::CallAsyncFunction {
357 name: name.into(),
358 args,
359 })
360 }
361
362 pub async fn stop(&self) -> Result<(), PyRunnerError> {
364 self.send_command(CmdType::Stop).await?;
366 Ok(())
367 }
368
369 pub fn stop_sync(&self) -> Result<(), PyRunnerError> {
376 self.send_command_sync(CmdType::Stop).map(|_| ())
377 }
378
379 pub async fn set_venv(&self, venv_path: &Path) -> Result<(), PyRunnerError> {
381 if !venv_path.is_dir() {
382 return Err(PyRunnerError::PyError(format!(
383 "Could not find venv directory {}",
384 venv_path.display()
385 )));
386 }
387 let set_venv_code = include_str!("set_venv.py");
388 self.run(set_venv_code).await?;
389
390 let site_packages = if cfg!(target_os = "windows") {
391 venv_path.join("Lib").join("site-packages")
392 } else {
393 let version_code = "f\"python{sys.version_info.major}.{sys.version_info.minor}\"";
394 let py_version = self.eval(version_code).await?;
395 venv_path
396 .join("lib")
397 .join(py_version.as_str().unwrap())
398 .join("site-packages")
399 };
400 #[cfg(all(feature = "pyo3", not(feature = "rustpython")))]
401 let with_pth = "True";
402 #[cfg(feature = "rustpython")]
403 let with_pth = "False";
404
405 self.run(&format!(
406 "add_venv_libs_to_syspath({}, {})",
407 print_path_for_python(&site_packages),
408 with_pth
409 ))
410 .await
411 }
412
413 pub fn set_venv_sync(&self, venv_path: &Path) -> Result<(), PyRunnerError> {
422 if !venv_path.is_dir() {
423 return Err(PyRunnerError::PyError(format!(
424 "Could not find venv directory {}",
425 venv_path.display()
426 )));
427 }
428 let set_venv_code = include_str!("set_venv.py");
429 self.run_sync(set_venv_code)?;
430
431 let site_packages = if cfg!(target_os = "windows") {
432 venv_path.join("Lib").join("site-packages")
433 } else {
434 let version_code = "f\"python{sys.version_info.major}.{sys.version_info.minor}\"";
435 let py_version = self.eval_sync(version_code)?;
436 venv_path
437 .join("lib")
438 .join(py_version.as_str().unwrap())
439 .join("site-packages")
440 };
441 #[cfg(all(feature = "pyo3", not(feature = "rustpython")))]
442 let with_pth = "True";
443 #[cfg(feature = "rustpython")]
444 let with_pth = "False";
445
446 self.run_sync(&format!(
447 "add_venv_libs_to_syspath({}, {})",
448 print_path_for_python(&site_packages),
449 with_pth
450 ))
451 }
452}
453
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::{self, File};
    use std::io::Write;

    /// Python source that leaves `z == 30` in the interpreter's globals.
    const SUM_SCRIPT: &str = r#"
x = 10
y = 20
z = x + y"#;

    /// Python source defining a two-argument `add` function.
    const ADD_SCRIPT: &str = r#"
def add(a, b):
    return a + b
"#;

    /// `eval` returns the value of a simple expression.
    #[tokio::test]
    async fn test_eval_simple_code() {
        let runner = PyRunner::new();

        let value = runner.eval("10 + 20").await.unwrap();

        assert_eq!(value, Value::from(30));
    }

    /// Variables assigned by `run` can be read back afterwards.
    #[tokio::test]
    async fn test_run_simple_code() {
        let runner = PyRunner::new();

        let outcome = runner.run(SUM_SCRIPT).await;
        assert!(outcome.is_ok());

        let z = runner.read_variable("z").await.unwrap();
        assert_eq!(z, Value::from(30));
    }

    /// A `*_sync` method can be called from inside an async test.
    #[tokio::test]
    async fn test_run_sync_from_async() {
        let runner = PyRunner::new();

        let outcome = runner.run(SUM_SCRIPT).await;
        assert!(outcome.is_ok());

        let z = runner.read_variable_sync("z").unwrap();
        assert_eq!(z, Value::from(30));
    }

    /// A function defined through `run` can be called with JSON arguments.
    #[tokio::test]
    async fn test_run_with_function() {
        let runner = PyRunner::new();
        runner.run(ADD_SCRIPT).await.unwrap();

        let timer = std::time::Instant::now();
        let sum = runner
            .call_function("add", vec![5.into(), 9.into()])
            .await
            .unwrap();
        assert_eq!(sum, Value::from(14));
        let elapsed = timer.elapsed();
        println!(
            "test_run_with_function took: {} microseconds",
            elapsed.as_micros()
        );
    }

    /// Same as above, through the blocking API.
    #[test]
    fn test_sync_run_with_function() {
        let runner = PyRunner::new();
        runner.run_sync(ADD_SCRIPT).unwrap();

        let timer = std::time::Instant::now();
        let sum = runner
            .call_function_sync("add", vec![5.into(), 9.into()])
            .unwrap();
        assert_eq!(sum, Value::from(14));
        let elapsed = timer.elapsed();
        println!(
            "test_run_with_function_sync took: {} microseconds",
            elapsed.as_micros()
        );
    }

    /// Two async Python calls overlap: the shorter sleep finishes first.
    #[cfg(feature = "pyo3")]
    #[tokio::test]
    async fn test_run_with_async_function() {
        let runner = PyRunner::new();
        let script = r#"
import asyncio
counter = 0

async def add_and_sleep(a, b, sleep_time):
    global counter
    await asyncio.sleep(sleep_time)
    counter += 1
    return a + b + counter
"#;
        runner.run(script).await.unwrap();

        let slow = runner.call_async_function("add_and_sleep", vec![5.into(), 10.into(), 1.into()]);
        let fast =
            runner.call_async_function("add_and_sleep", vec![5.into(), 10.into(), 0.1.into()]);
        let (slow, fast) = tokio::join!(slow, fast);

        assert_eq!(slow.unwrap(), Value::from(17));
        assert_eq!(fast.unwrap(), Value::from(16));
    }

    /// An async Python function can be awaited through the blocking API.
    #[cfg(feature = "pyo3")]
    #[test]
    fn test_run_with_async_function_sync() {
        let runner = PyRunner::new();
        let script = r#"
import asyncio

async def add(a, b):
    await asyncio.sleep(0.1)
    return a + b
"#;
        runner.run_sync(script).unwrap();

        let sum = runner
            .call_async_function_sync("add", vec![5.into(), 9.into()])
            .unwrap();
        assert_eq!(sum, Value::from(14));
    }

    /// Commands issued concurrently are answered in the order they were joined.
    #[tokio::test]
    async fn test_concurrent_calls() {
        let runner = PyRunner::new();

        let run_first = runner.run("import time; time.sleep(0.3); result='task1'");
        let read_first = runner.read_variable("result");
        let run_second = runner.run("result='task2'");
        let read_second = runner.read_variable("result");
        let eval_third = runner.eval("'task3'");

        let (ran_first, first, ran_second, second, third) =
            tokio::join!(run_first, read_first, run_second, read_second, eval_third);

        assert!(ran_first.is_ok());
        assert!(first.is_ok());
        assert!(ran_second.is_ok());
        assert!(second.is_ok());
        assert!(third.is_ok());

        assert_eq!(first.unwrap(), Value::from("task1"));
        assert_eq!(second.unwrap(), Value::from("task2"));
        assert_eq!(third.unwrap(), Value::from("task3"));
    }

    /// A Python exception surfaces as an `Err`.
    #[tokio::test]
    async fn test_python_error() {
        let runner = PyRunner::new();

        let outcome = runner.eval("1 / 0").await;

        assert!(outcome.is_err());
    }

    /// The README sample: interpreter state persists between calls.
    #[tokio::test]
    async fn test_sample_readme() {
        let runner = PyRunner::new();
        let script = r#"
counter = 0
def greet(name):
    global counter
    counter = counter + 1
    s = "" if counter < 2 else "s"
    return f"Hello {name}! Called {counter} time{s} from Python."
"#;
        runner.run(script).await.unwrap();

        let first = runner
            .call_function("greet", vec!["World".into()])
            .await
            .unwrap();
        println!("{}", first.as_str().unwrap());
        assert_eq!(
            first.as_str().unwrap(),
            "Hello World! Called 1 time from Python."
        );

        let second = runner
            .call_function("greet", vec!["World".into()])
            .await
            .unwrap();
        assert_eq!(
            second.as_str().unwrap(),
            "Hello World! Called 2 times from Python."
        );
    }

    /// A script run via `run_file` can import a module that sits next to it.
    #[tokio::test]
    async fn test_run_file() {
        let runner = PyRunner::new();
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();

        let mut module = File::create(root.join("mymodule.py")).unwrap();
        writeln!(
            module,
            r#"
def my_func():
    return 42
"#
        )
        .unwrap();

        let main_path = root.join("main.py");
        let mut main_script = File::create(&main_path).unwrap();
        writeln!(
            main_script,
            r#"
import mymodule
result = mymodule.my_func()
"#
        )
        .unwrap();

        runner.run_file(&main_path).await.unwrap();

        let value = runner.read_variable("result").await.unwrap();
        assert_eq!(value, Value::from(42));
    }

    /// After `set_venv`, a package placed in the venv's site-packages is importable.
    #[tokio::test]
    async fn test_set_venv() {
        let runner = PyRunner::new();
        let tmp = tempfile::tempdir().unwrap();
        let venv_root = tmp.path();

        let version = runner
            .eval("f'{__import__(\"sys\").version_info.major}.{__import__(\"sys\").version_info.minor}'")
            .await
            .unwrap();
        let major_minor = version.as_str().unwrap();

        let site_packages = if cfg!(target_os = "windows") {
            venv_root.join("Lib").join("site-packages")
        } else {
            venv_root
                .join("lib")
                .join(format!("python{}", major_minor))
                .join("site-packages")
        };
        fs::create_dir_all(&site_packages).unwrap();

        let package = site_packages.join("dummy_package");
        fs::create_dir(&package).unwrap();
        let mut init_py = File::create(package.join("__init__.py")).unwrap();
        writeln!(init_py, "def dummy_func(): return 'hello from venv'").unwrap();

        runner.set_venv(venv_root).await.unwrap();

        runner.run("import dummy_package").await.unwrap();
        let greeting = runner.eval("dummy_package.dummy_func()").await.unwrap();

        assert_eq!(greeting, Value::from("hello from venv"));
    }
}