1use once_cell::sync::Lazy;
2use pyo3::prelude::*;
3use pyo3::types::{PyModule, PyTuple};
4use serde_json::Value;
5use std::collections::HashMap;
6use std::sync::Mutex;
7
8pub mod serialize;
9use serialize::{FromPython, ToPython};
10
11pub static BRIDGE: Lazy<Mutex<PythonBridge>> = Lazy::new(|| Mutex::new(PythonBridge::default()));
12
13pub struct PythonBridge {
14 modules: HashMap<String, Py<PyModule>>,
15}
16
17impl Default for PythonBridge {
18 fn default() -> Self {
19 Self::new()
20 }
21}
22
23impl PythonBridge {
24 pub fn new() -> Self {
25 pyo3::prepare_freethreaded_python();
26 Self {
27 modules: HashMap::new(),
28 }
29 }
30
31 pub fn ensure_module(&mut self, name: &str) -> PyResult<()> {
32 if self.modules.contains_key(name) {
33 return Ok(());
34 }
35 Python::with_gil(|py| {
36 let module = PyModule::import_bound(py, name)?;
37 self.modules.insert(name.to_string(), module.unbind());
38 Ok(())
39 })
40 }
41
42 pub fn call(&self, module_name: &str, func_name: &str, arg: Value) -> PyResult<Value> {
43 Python::with_gil(|py| {
44 let module = self.modules.get(module_name).ok_or_else(|| {
45 PyErr::new::<pyo3::exceptions::PyImportError, _>(format!(
46 "Module {} not loaded",
47 module_name
48 ))
49 })?;
50 let func = module.bind(py).getattr(func_name)?;
51 let py_arg = arg.to_python(py)?;
52 let result = func.call1((py_arg,))?;
53 Value::from_python(&result)
54 })
55 }
56
57 pub fn call_with_multiple_args(
58 &self,
59 module_name: &str,
60 func_name: &str,
61 args: Vec<Value>,
62 ) -> PyResult<Value> {
63 Python::with_gil(|py| {
64 let module = self.modules.get(module_name).ok_or_else(|| {
65 PyErr::new::<pyo3::exceptions::PyImportError, _>(format!(
66 "Module {} not loaded",
67 module_name
68 ))
69 })?;
70 let func = module.bind(py).getattr(func_name)?;
71 let py_args: Vec<PyObject> = args.iter().map(|a| a.to_python(py).unwrap()).collect();
72 let tuple = PyTuple::new_bound(py, &py_args);
73 let result = func.call1(tuple)?;
74 Value::from_python(&result)
75 })
76 }
77
78 pub fn reset(&mut self) {
79 self.modules.clear();
80 }
81}
82
83pub fn call_python_function(module: &str, func: &str, arg: Value) -> Result<Value, String> {
84 let mut bridge = BRIDGE.lock().map_err(|e| e.to_string())?;
85 bridge.ensure_module(module).map_err(|e| e.to_string())?;
86 bridge.call(module, func, arg).map_err(|e| e.to_string())
87}
88
89pub fn call_python_function_with_args(
90 module: &str,
91 func: &str,
92 args: Vec<Value>,
93) -> Result<Value, String> {
94 let mut bridge = BRIDGE.lock().map_err(|e| e.to_string())?;
95 bridge.ensure_module(module).map_err(|e| e.to_string())?;
96 bridge
97 .call_with_multiple_args(module, func, args)
98 .map_err(|e| e.to_string())
99}
100
101pub fn reset_python_bridge() {
102 let mut bridge = BRIDGE.lock().unwrap();
103 bridge.reset();
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// Single-argument path: `math.sqrt(16.0)` comes back as `4.0`.
    #[test]
    fn test_math_sqrt() {
        let mut guard = BRIDGE.lock().unwrap();
        guard.ensure_module("math").unwrap();

        let input = serde_json::json!(16.0);
        let actual = guard.call("math", "sqrt", input).unwrap();

        assert_eq!(actual, serde_json::json!(4.0));
    }

    /// Multi-argument path: `math.pow(2.0, 3.0)` comes back as `8.0`.
    #[test]
    fn test_math_pow() {
        let mut guard = BRIDGE.lock().unwrap();
        guard.ensure_module("math").unwrap();

        let inputs = vec![serde_json::json!(2.0), serde_json::json!(3.0)];
        let actual = guard
            .call_with_multiple_args("math", "pow", inputs)
            .unwrap();

        assert_eq!(actual, serde_json::json!(8.0));
    }

    /// Calling into a module that cannot be imported yields a descriptive error string.
    #[test]
    fn test_module_not_loaded_error() {
        reset_python_bridge();

        let outcome = call_python_function("nonexistent", "func", serde_json::json!(0));
        let message = outcome.unwrap_err();

        let mentions_failure = message.contains("nonexistent") || message.contains("not found");
        assert!(mentions_failure, "got: {}", message);
    }

    /// `reset` removes modules that `ensure_module` cached.
    #[test]
    fn test_reset_clears_modules() {
        // The lock is held for the whole test so no other test can re-import
        // `math` between the two assertions.
        let mut guard = BRIDGE.lock().unwrap();

        guard.ensure_module("math").unwrap();
        assert!(guard.modules.contains_key("math"));

        guard.reset();
        assert!(!guard.modules.contains_key("math"));
    }
}