use pyo3::prelude::*;
use pyo3::types::IntoPyDict;
use rust_mcp_sdk::{
macros::{mcp_tool, JsonSchema},
schema::{schema_utils::CallToolError, CallToolResult, TextContent},
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, error};
#[mcp_tool(
name = "angreal_command",
description = "Execute an angreal command with arguments",
idempotent_hint = false,
destructive_hint = false,
open_world_hint = true,
read_only_hint = false
)]
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct AngrealCommandTool {
pub command_path: String,
pub args: Option<HashMap<String, serde_json::Value>>,
}
impl AngrealCommandTool {
pub async fn call_tool(&self) -> Result<CallToolResult, CallToolError> {
debug!("Executing angreal command: {}", self.command_path);
angreal::initialize_python_tasks().map_err(|e| {
CallToolError::new(std::io::Error::new(
std::io::ErrorKind::Other,
format!("Failed to initialize angreal tasks: {}", e),
))
})?;
let tasks = angreal::task::ANGREAL_TASKS.lock().map_err(|e| {
CallToolError::new(std::io::Error::new(
std::io::ErrorKind::Other,
format!("Failed to lock ANGREAL_TASKS: {}", e),
))
})?;
let command = tasks.get(&self.command_path).ok_or_else(|| {
CallToolError::new(std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("Command '{}' not found", self.command_path),
))
})?;
debug!(
"Found command '{}', executing with args: {:?}",
command.name, self.args
);
let args_registry = angreal::task::ANGREAL_ARGS.lock().map_err(|e| {
CallToolError::new(std::io::Error::new(
std::io::ErrorKind::Other,
format!("Failed to lock ANGREAL_ARGS: {}", e),
))
})?;
let command_args = args_registry
.get(&self.command_path)
.cloned()
.unwrap_or_default();
let result = Python::with_gil(|py| -> PyResult<(String, String, String)> {
debug!("Starting Python execution for command: {}", command.name);
let sys = py.import("sys")?;
let io = py.import("io")?;
let stdout_capture = io.call_method0("StringIO")?;
let stderr_capture = io.call_method0("StringIO")?;
let original_stdout = sys.getattr("stdout")?;
let original_stderr = sys.getattr("stderr")?;
sys.setattr("stdout", &stdout_capture)?;
sys.setattr("stderr", &stderr_capture)?;
let mut kwargs: Vec<(&str, PyObject)> = Vec::new();
if let Some(provided_args) = &self.args {
for arg in command_args.iter() {
let arg_name = &arg.name;
if let Some(value) = provided_args.get(arg_name) {
let python_type = arg.python_type.as_deref().unwrap_or("str");
let py_value = match python_type {
"str" => {
if let Some(s) = value.as_str() {
s.to_object(py)
} else {
value.to_string().to_object(py)
}
}
"int" => {
if let Some(i) = value.as_i64() {
i.to_object(py)
} else if let Some(s) = value.as_str() {
s.parse::<i64>().unwrap_or(0).to_object(py)
} else {
0.to_object(py)
}
}
"float" => {
if let Some(f) = value.as_f64() {
f.to_object(py)
} else if let Some(s) = value.as_str() {
s.parse::<f64>().unwrap_or(0.0).to_object(py)
} else {
0.0.to_object(py)
}
}
"bool" => {
if let Some(b) = value.as_bool() {
b.to_object(py)
} else if let Some(s) = value.as_str() {
s.parse::<bool>().unwrap_or(false).to_object(py)
} else {
false.to_object(py)
}
}
_ => value.to_string().to_object(py),
};
kwargs.push((Box::leak(Box::new(arg_name.clone())).as_str(), py_value));
} else if arg.is_flag.unwrap_or(false) {
kwargs.push((
Box::leak(Box::new(arg_name.clone())).as_str(),
false.to_object(py),
));
}
}
}
debug!("Calling Python function with {} arguments", kwargs.len());
let result = command.func.call(py, (), Some(kwargs.into_py_dict(py)));
sys.setattr("stdout", original_stdout)?;
sys.setattr("stderr", original_stderr)?;
let stdout_output = stdout_capture.call_method0("getvalue")?.to_string();
let stderr_output = stderr_capture.call_method0("getvalue")?.to_string();
let result_str = match result {
Ok(result_obj) => {
if result_obj.is_none(py) {
format!("Command '{}' executed successfully", command.name)
} else {
result_obj.to_string()
}
}
Err(e) => {
return Err(e);
}
};
Ok((result_str, stdout_output, stderr_output))
});
match result {
Ok((return_value, stdout, stderr)) => {
debug!("Successfully executed command: {}", self.command_path);
let response = serde_json::json!({
"command": self.command_path,
"result": "success",
"return_value": return_value,
"stdout": stdout,
"stderr": stderr,
"timestamp": chrono::Utc::now().to_rfc3339()
});
Ok(CallToolResult::text_content(vec![TextContent::from(
serde_json::to_string_pretty(&response).map_err(CallToolError::new)?,
)]))
}
Err(err) => {
error!("Failed to execute command '{}': {}", self.command_path, err);
let error_response = serde_json::json!({
"command": self.command_path,
"result": "error",
"error": err.to_string(),
"timestamp": chrono::Utc::now().to_rfc3339()
});
Ok(CallToolResult::text_content(vec![TextContent::from(
serde_json::to_string_pretty(&error_response).map_err(CallToolError::new)?,
)]))
}
}
}
}