use std::ffi::CString;
use std::path::{Path, PathBuf};
use std::sync::Once;
use std::time::Duration;
use anyhow::{anyhow, Context, Result};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyModule};
use serde_json::Value;
use tracing::{debug, error, warn};
use super::Tool;
use crate::llm::ToolSpec;
/// Guards the one-time start-up of the embedded Python interpreter.
static PYTHON_INIT: Once = Once::new();

/// Starts the embedded Python interpreter on the first call; later calls do nothing.
fn ensure_python_initialized() {
    PYTHON_INIT.call_once(start_python_interpreter);
}

/// Performs the actual interpreter start-up; only ever run through `PYTHON_INIT`.
fn start_python_interpreter() {
    Python::initialize();
    debug!("Python interpreter initialized");
}
/// A [`Tool`] whose behaviour is provided by the `execute` function of a Python script.
pub struct PythonTool {
    /// Tool name, reported through [`Tool::name`].
    name: String,
    /// Tool specification, handed out (cloned) by [`Tool::spec`].
    spec: ToolSpec,
    /// Location of the Python script on disk.
    script_path: PathBuf,
    /// Optional per-call timeout.
    /// NOTE(review): nothing in this file enforces this value yet.
    timeout: Option<Duration>,
}
impl PythonTool {
pub fn new(name: String, script_path: impl AsRef<Path>, spec: ToolSpec) -> Result<Self> {
ensure_python_initialized();
let script_path = script_path.as_ref().to_path_buf();
if !script_path.exists() {
return Err(anyhow!(
"Python script not found: {}",
script_path.display()
));
}
Ok(Self {
name,
script_path,
spec,
timeout: Some(Duration::from_secs(60)),
})
}
pub fn with_timeout(mut self, timeout: Option<Duration>) -> Self {
self.timeout = timeout;
self
}
fn execute_python(&self, args: Value) -> Result<Value> {
Python::attach(|py| {
let script_source = std::fs::read_to_string(&self.script_path).with_context(|| {
format!(
"Failed to read Python script: {}",
self.script_path.display()
)
})?;
let code = CString::new(script_source)
.context("Python script source contains an embedded NUL byte")?;
let file_name = CString::new(self.script_path.to_string_lossy().as_ref())
.context("Python script path contains an embedded NUL byte")?;
let module_name = CString::new(self.name.as_str())
.context("Python module name contains an embedded NUL byte")?;
let module = PyModule::from_code(
py,
code.as_c_str(),
file_name.as_c_str(),
module_name.as_c_str(),
)
.with_context(|| "Failed to load Python module")?;
let execute_fn = module
.getattr("execute")
.with_context(|| "Python script must define an 'execute' function")?;
let py_args = Self::json_to_python(py, &args)?;
let py_result = execute_fn
.call1((py_args,))
.with_context(|| "Python execute function failed")?;
let result = Self::python_to_json(&py_result)?;
Ok(result)
})
}
fn json_to_python<'py>(py: Python<'py>, value: &Value) -> Result<Bound<'py, PyAny>> {
use pyo3::IntoPyObjectExt;
match value {
Value::Null => Ok(py.None().into_bound(py)),
Value::Bool(b) => Ok((*b).into_bound_py_any(py)?),
Value::Number(n) => {
if let Some(i) = n.as_i64() {
Ok(i.into_bound_py_any(py)?)
} else if let Some(u) = n.as_u64() {
Ok(u.into_bound_py_any(py)?)
} else if let Some(f) = n.as_f64() {
Ok(f.into_bound_py_any(py)?)
} else {
Err(anyhow!("Unsupported number type"))
}
}
Value::String(s) => Ok(s.as_str().into_bound_py_any(py)?),
Value::Array(arr) => {
let py_list = pyo3::types::PyList::empty(py);
for item in arr {
let py_item = Self::json_to_python(py, item)?;
py_list.append(py_item)?;
}
Ok(py_list.into_any())
}
Value::Object(obj) => {
let py_dict = PyDict::new(py);
for (key, value) in obj {
let py_value = Self::json_to_python(py, value)?;
py_dict.set_item(key, py_value)?;
}
Ok(py_dict.into_any())
}
}
}
fn python_to_json(obj: &Bound<'_, PyAny>) -> Result<Value> {
if obj.is_none() {
Ok(Value::Null)
} else if let Ok(b) = obj.extract::<bool>() {
Ok(Value::Bool(b))
} else if let Ok(i) = obj.extract::<i64>() {
Ok(serde_json::to_value(i)?)
} else if let Ok(f) = obj.extract::<f64>() {
Ok(serde_json::to_value(f)?)
} else if let Ok(s) = obj.extract::<String>() {
Ok(Value::String(s))
} else if let Ok(list) = obj.cast::<pyo3::types::PyList>() {
let mut arr = Vec::new();
for item in list.iter() {
arr.push(Self::python_to_json(&item)?);
}
Ok(Value::Array(arr))
} else if let Ok(dict) = obj.cast::<pyo3::types::PyDict>() {
let mut map = serde_json::Map::new();
for (key, value) in dict.iter() {
let key_str = key.extract::<String>()?;
let value_json = Self::python_to_json(&value)?;
map.insert(key_str, value_json);
}
Ok(Value::Object(map))
} else {
warn!("Unsupported Python type, converting to string");
let s = obj.str()?.extract::<String>()?;
Ok(Value::String(s))
}
}
}
impl Tool for PythonTool {
    fn name(&self) -> &str {
        &self.name
    }

    fn spec(&self) -> Result<ToolSpec> {
        Ok(self.spec.clone())
    }

    /// Runs the script's `execute` function with `args` and logs the outcome.
    ///
    /// Failures are logged with the full `anyhow` context chain (`{:#}`), so the
    /// underlying Python exception appears in the log and not just the outermost
    /// context message. The error itself is returned unchanged.
    fn execute(&self, args: Value) -> Result<Value> {
        debug!(
            tool = %self.name,
            script = %self.script_path.display(),
            "Executing Python tool"
        );

        match self.execute_python(args) {
            Ok(value) => {
                debug!(tool = %self.name, "Python tool succeeded");
                Ok(value)
            }
            Err(e) => {
                // Plain `%e` would print only the outermost context; `{:#}` walks
                // the whole chain down to the Python error.
                let chain = format!("{e:#}");
                error!(
                    tool = %self.name,
                    error = %chain,
                    "Python tool execution failed"
                );
                Err(e)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `code` to a new temporary file; the file is deleted when the handle drops.
    fn create_test_script(code: &str) -> NamedTempFile {
        let mut script = NamedTempFile::new().unwrap();
        script.write_all(code.as_bytes()).unwrap();
        script.flush().unwrap();
        script
    }

    /// Builds a function-type `ToolSpec` with the given parameter schema.
    fn function_spec(name: &str, description: &str, parameters: Value) -> ToolSpec {
        serde_json::from_value(json!({
            "type": "function",
            "name": name,
            "description": description,
            "parameters": parameters,
        }))
        .unwrap()
    }

    /// Creates a `PythonTool` from inline Python source.
    ///
    /// The temp file is returned alongside the tool because the script is read
    /// at execution time; callers must keep it alive (bind it to `_script`, not `_`).
    fn tool_from_source(
        name: &str,
        description: &str,
        parameters: Value,
        code: &str,
    ) -> (NamedTempFile, PythonTool) {
        let script = create_test_script(code);
        let spec = function_spec(name, description, parameters);
        let tool = PythonTool::new(name.to_string(), script.path(), spec).unwrap();
        (script, tool)
    }

    #[test]
    fn test_python_tool_basic() {
        let (_script, tool) = tool_from_source(
            "echo",
            "Echo tool",
            json!({
                "type": "object",
                "properties": {
                    "message": {"type": "string"}
                }
            }),
            r#"
def execute(args):
    message = args.get("message", "")
    return {"output": f"Echo: {message}"}
"#,
        );

        let output = tool.execute(json!({"message": "hello"})).unwrap();
        assert_eq!(output["output"], "Echo: hello");
    }

    #[test]
    fn test_python_tool_error_handling() {
        let (_script, tool) = tool_from_source(
            "error",
            "Error tool",
            json!({"type": "object", "properties": {}}),
            r#"
def execute(args):
    raise ValueError("Test error")
"#,
        );

        assert!(tool.execute(json!({})).is_err());
    }

    #[test]
    fn test_json_python_conversion() {
        let (_script, tool) = tool_from_source(
            "passthrough",
            "Passthrough tool",
            json!({"type": "object"}),
            r#"
def execute(args):
    return args
"#,
        );

        let expected = json!({
            "string": "hello",
            "number": 42,
            "bool": true,
            "null": null,
            "array": [1, 2, 3],
            "object": {"nested": "value"}
        });
        let round_tripped = tool.execute(expected.clone()).unwrap();
        assert_eq!(round_tripped, expected);
    }
}