Skip to main content

tupa_pyffi/
lib.rs

1use once_cell::sync::Lazy;
2use pyo3::prelude::*;
3use pyo3::types::{PyModule, PyTuple};
4use serde_json::Value;
5use std::collections::HashMap;
6use std::sync::Mutex;
7
8pub mod serialize;
9use serialize::{FromPython, ToPython};
10
11pub static BRIDGE: Lazy<Mutex<PythonBridge>> = Lazy::new(|| Mutex::new(PythonBridge::default()));
12
13pub struct PythonBridge {
14    modules: HashMap<String, Py<PyModule>>,
15}
16
17impl Default for PythonBridge {
18    fn default() -> Self {
19        Self::new()
20    }
21}
22
23impl PythonBridge {
24    pub fn new() -> Self {
25        pyo3::prepare_freethreaded_python();
26        Self {
27            modules: HashMap::new(),
28        }
29    }
30
31    pub fn ensure_module(&mut self, name: &str) -> PyResult<()> {
32        if self.modules.contains_key(name) {
33            return Ok(());
34        }
35        Python::with_gil(|py| {
36            let module = PyModule::import_bound(py, name)?;
37            self.modules.insert(name.to_string(), module.unbind());
38            Ok(())
39        })
40    }
41
42    pub fn call(&self, module_name: &str, func_name: &str, arg: Value) -> PyResult<Value> {
43        Python::with_gil(|py| {
44            let module = self.modules.get(module_name).ok_or_else(|| {
45                PyErr::new::<pyo3::exceptions::PyImportError, _>(format!(
46                    "Module {} not loaded",
47                    module_name
48                ))
49            })?;
50            let func = module.bind(py).getattr(func_name)?;
51            let py_arg = arg.to_python(py)?;
52            let result = func.call1((py_arg,))?;
53            Value::from_python(&result)
54        })
55    }
56
57    pub fn call_with_multiple_args(
58        &self,
59        module_name: &str,
60        func_name: &str,
61        args: Vec<Value>,
62    ) -> PyResult<Value> {
63        Python::with_gil(|py| {
64            let module = self.modules.get(module_name).ok_or_else(|| {
65                PyErr::new::<pyo3::exceptions::PyImportError, _>(format!(
66                    "Module {} not loaded",
67                    module_name
68                ))
69            })?;
70            let func = module.bind(py).getattr(func_name)?;
71            let py_args: Vec<PyObject> = args.iter().map(|a| a.to_python(py).unwrap()).collect();
72            let tuple = PyTuple::new_bound(py, &py_args);
73            let result = func.call1(tuple)?;
74            Value::from_python(&result)
75        })
76    }
77
78    pub fn reset(&mut self) {
79        self.modules.clear();
80    }
81}
82
83pub fn call_python_function(module: &str, func: &str, arg: Value) -> Result<Value, String> {
84    let mut bridge = BRIDGE.lock().map_err(|e| e.to_string())?;
85    bridge.ensure_module(module).map_err(|e| e.to_string())?;
86    bridge.call(module, func, arg).map_err(|e| e.to_string())
87}
88
89pub fn call_python_function_with_args(
90    module: &str,
91    func: &str,
92    args: Vec<Value>,
93) -> Result<Value, String> {
94    let mut bridge = BRIDGE.lock().map_err(|e| e.to_string())?;
95    bridge.ensure_module(module).map_err(|e| e.to_string())?;
96    bridge
97        .call_with_multiple_args(module, func, args)
98        .map_err(|e| e.to_string())
99}
100
101pub fn reset_python_bridge() {
102    let mut bridge = BRIDGE.lock().unwrap();
103    bridge.reset();
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::MutexGuard;

    /// Locks the global bridge and recovers the guard if the lock is poisoned.
    ///
    /// A failing assertion panics while holding the guard. Without recovery,
    /// every other test would then fail with a `PoisonError` instead of its
    /// own result.
    fn lock_bridge() -> MutexGuard<'static, PythonBridge> {
        BRIDGE.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_math_sqrt() {
        let mut bridge = lock_bridge();
        bridge.ensure_module("math").unwrap();
        let result = bridge
            .call("math", "sqrt", serde_json::json!(16.0))
            .unwrap();
        assert_eq!(result, serde_json::json!(4.0));
    }

    #[test]
    fn test_math_pow() {
        let mut bridge = lock_bridge();
        bridge.ensure_module("math").unwrap();
        let result = bridge
            .call_with_multiple_args(
                "math",
                "pow",
                vec![serde_json::json!(2.0), serde_json::json!(3.0)],
            )
            .unwrap();
        assert_eq!(result, serde_json::json!(8.0));
    }

    #[test]
    fn test_module_not_loaded_error() {
        // Goes through the public API, so the failure comes from importing a
        // module that does not exist.
        reset_python_bridge();
        let result = call_python_function("nonexistent", "func", serde_json::json!(0)).unwrap_err();
        assert!(
            result.contains("nonexistent") || result.contains("not found"),
            "got: {}",
            result
        );
    }

    #[test]
    fn test_call_without_ensure_module_errors() {
        // Hold the lock so no concurrent test can re-insert "math" after the reset.
        let mut bridge = lock_bridge();
        bridge.reset();
        let err = bridge
            .call("math", "sqrt", serde_json::json!(4.0))
            .unwrap_err();
        let message = err.to_string();
        assert!(
            message.contains("Module math not loaded"),
            "got: {}",
            message
        );
    }

    #[test]
    fn test_call_python_function_with_args() {
        let result = call_python_function_with_args(
            "math",
            "pow",
            vec![serde_json::json!(2.0), serde_json::json!(10.0)],
        )
        .unwrap();
        assert_eq!(result, serde_json::json!(1024.0));
    }

    #[test]
    fn test_reset_clears_modules() {
        // Hold the lock for the entire test so no concurrent test can
        // re-insert "math" between reset() and the final assertion.
        let mut bridge = lock_bridge();
        bridge.ensure_module("math").unwrap();
        assert!(bridge.modules.contains_key("math"));
        bridge.reset();
        assert!(!bridge.modules.contains_key("math"));
    }
}