#![expect(clippy::useless_conversion)]
use std::sync::Arc;
use pyo3::prelude::*;
use super::bridge_state::bridge_state;
/// Optional fallback hook runner shared process-wide.
///
/// `dispatch_rust_hook` reads this slot when the requested agent ID has no
/// entry in the bridge state, and fails if the slot is empty.
// NOTE(review): presumably populated while an agent is still being created
// (before its bridge-state entry exists) — confirm against the code that
// writes this slot, which is not visible from here.
pub(crate) static INITIALIZING_HOOK_RUNNER: std::sync::Mutex<Option<Arc<crate::hooks::Hooks>>> =
std::sync::Mutex::new(None);
/// Payload-free async mutex (`tokio::sync::Mutex<()>`), usable purely as a guard.
// NOTE(review): not referenced in this part of the file; presumably it
// serializes agent creation so that `INITIALIZING_HOOK_RUNNER` is only used by
// one creation at a time — TODO confirm at the call site.
pub(crate) static CREATE_AGENT_HOOK_GUARD: tokio::sync::Mutex<()> =
tokio::sync::Mutex::const_new(());
fn dispatch_hook_by_name(
hook_runner: &crate::hooks::Hooks,
hook_point: &str,
context_json: &str,
) -> Result<String, crate::error::Error> {
let mut result_json = String::new();
match hook_point {
"pre_turn" => {
let ctx = serde_json::from_str::<crate::hooks::PreTurnContext>(context_json).map_err(
|e| crate::error::Error::BackendError {
message: format!("Failed to deserialize PreTurnContext: {e}"),
},
)?;
hook_runner.run_pre_turn(&ctx);
}
"post_turn" => {
let ctx = serde_json::from_str::<crate::hooks::PostTurnContext>(context_json).map_err(
|e| crate::error::Error::BackendError {
message: format!("Failed to deserialize PostTurnContext: {e}"),
},
)?;
hook_runner.run_post_turn(&ctx);
}
"pre_tool_call_decide" => {
let ctx = serde_json::from_str::<crate::hooks::PreToolCallDecideContext>(context_json)
.map_err(|e| crate::error::Error::BackendError {
message: format!("Failed to deserialize PreToolCallDecideContext: {e} | JSON was: {context_json}"),
})?;
let hook_result = hook_runner.run_pre_tool_call_decide(&ctx);
result_json = serde_json::to_string(&hook_result).map_err(|e| {
crate::error::Error::BackendError {
message: format!("Failed to serialize PreToolCallDecide result: {e}"),
}
})?;
}
"post_tool_call" => {
let ctx = serde_json::from_str::<crate::hooks::PostToolCallContext>(context_json)
.map_err(|e| crate::error::Error::BackendError {
message: format!(
"Failed to deserialize PostToolCallContext: {e} | JSON was: {context_json}"
),
})?;
hook_runner.run_post_tool_call(&ctx);
}
"on_compaction" => {
let ctx = serde_json::from_str::<crate::hooks::OnCompactionContext>(context_json)
.map_err(|e| crate::error::Error::BackendError {
message: format!("Failed to deserialize OnCompactionContext: {e}"),
})?;
hook_runner.run_on_compaction(&ctx);
}
"on_session_start" => {
let ctx = serde_json::from_str::<crate::hooks::OnSessionStartContext>(context_json)
.map_err(|e| crate::error::Error::BackendError {
message: format!("Failed to deserialize OnSessionStartContext: {e}"),
})?;
hook_runner.run_on_session_start(&ctx);
}
"on_session_end" => {
let ctx = serde_json::from_str::<crate::hooks::OnSessionEndContext>(context_json)
.map_err(|e| crate::error::Error::BackendError {
message: format!("Failed to deserialize OnSessionEndContext: {e}"),
})?;
hook_runner.run_on_session_end(&ctx);
}
"on_tool_error" => {
let ctx = serde_json::from_str::<crate::hooks::OnToolErrorContext>(context_json)
.map_err(|e| crate::error::Error::BackendError {
message: format!("Failed to deserialize OnToolErrorContext: {e}"),
})?;
hook_runner.run_on_tool_error(&ctx);
}
"on_interaction" => {
let ctx = serde_json::from_str::<crate::hooks::OnInteractionContext>(context_json)
.map_err(|e| crate::error::Error::BackendError {
message: format!("Failed to deserialize OnInteractionContext: {e}"),
})?;
let hook_result = hook_runner.run_on_interaction(&ctx);
result_json = serde_json::to_string(&hook_result).map_err(|e| {
crate::error::Error::BackendError {
message: format!("Failed to serialize OnInteraction result: {e}"),
}
})?;
}
_ => {
tracing::warn!("Unknown hook point: {}", hook_point);
}
}
Ok(result_json)
}
// Python-callable entry point that runs one hook for `agent_id`.
//
// The hook runner is taken from the agent's bridge-state entry. When the agent
// has no entry at all, the runner stored in `INITIALIZING_HOOK_RUNNER` is used
// instead. The hook itself runs via `tokio::task::spawn_blocking`, and the
// returned Python awaitable resolves to the string produced by
// `dispatch_hook_by_name`.
//
// Raises `RuntimeError` when a lock cannot be acquired, when no hook runner is
// available, when the blocking task fails to join, or when the hook dispatch
// itself returns an error.
#[pyfunction]
pub(crate) fn dispatch_rust_hook(
    py: Python<'_>,
    agent_id: u64,
    hook_point: String,
    context_json: String,
) -> PyResult<Bound<'_, PyAny>> {
    tracing::debug!(agent_id, hook_point = %hook_point, "dispatch_rust_hook called from Python");

    // Resolve the runner inside a block so both guards are released before the
    // future is created.
    let hook_runner = {
        let state_map = bridge_state().read().map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to read BRIDGE_STATE: {e}"))
        })?;

        match state_map.get(&agent_id) {
            // The agent is registered: it must carry its own hook runner.
            Some(entry) => match entry.hook_runner.as_ref() {
                Some(runner) => Arc::clone(runner),
                None => {
                    return Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
                        "No active Hooks found for agent ID {agent_id}"
                    )));
                }
            },
            // The agent is not registered: fall back to the shared slot.
            None => {
                let pending = INITIALIZING_HOOK_RUNNER.lock().map_err(|e| {
                    pyo3::exceptions::PyRuntimeError::new_err(format!(
                        "Failed to lock INITIALIZING_HOOK_RUNNER: {e}"
                    ))
                })?;
                match &*pending {
                    Some(runner) => Arc::clone(runner),
                    None => {
                        return Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
                            "No active bridge state or initializing hook runner found for agent ID {agent_id}"
                        )));
                    }
                }
            }
        }
    };

    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        let joined = tokio::task::spawn_blocking(move || {
            dispatch_hook_by_name(&hook_runner, &hook_point, &context_json)
        })
        .await;

        match joined {
            Ok(Ok(result_json)) => Ok(result_json),
            // The hook ran but reported a failure.
            Ok(Err(hook_err)) => Err(pyo3::exceptions::PyRuntimeError::new_err(
                hook_err.to_string(),
            )),
            // The blocking task itself did not complete.
            Err(join_err) => Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
                "Hook execution failed: {join_err}"
            ))),
        }
    })
}
// Python-callable entry point that asks the agent's policy handler to confirm
// a tool call.
//
// The handler is looked up in the bridge state for `agent_id`. Inside the
// returned awaitable, `args_json` is parsed and `confirm` is invoked via
// `tokio::task::spawn_blocking`; the awaitable resolves to whatever `confirm`
// returns.
//
// Raises `RuntimeError` when the bridge state cannot be read, when the agent
// or its handler is missing, or when the blocking task fails to join. Raises
// `ValueError` (when awaited) if `args_json` is not valid JSON.
#[pyfunction]
pub(crate) fn dispatch_rust_policy_confirm(
    py: Python<'_>,
    agent_id: u64,
    tool_name: String,
    args_json: String,
) -> PyResult<Bound<'_, PyAny>> {
    tracing::info!(agent_id, tool = %tool_name, "dispatch_rust_policy_confirm called from Python");

    // Clone the handler out of the map so the read guard ends with this block.
    let policy_handler = {
        let state_map = bridge_state().read().map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to read BRIDGE_STATE: {e}"))
        })?;

        let Some(entry) = state_map.get(&agent_id) else {
            return Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
                "No active bridge state found for agent ID {agent_id}"
            )));
        };
        let Some(handler) = entry.policy_handler.as_ref() else {
            return Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
                "No active AskUserHandler found for agent ID {agent_id}"
            )));
        };

        Arc::clone(handler)
    };

    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        // Parsing happens inside the future, so a malformed payload is reported
        // when the awaitable is awaited rather than at call time.
        let args_val: serde_json::Value = serde_json::from_str(&args_json).map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Failed to parse policy args JSON: {e}"
            ))
        })?;

        let joined =
            tokio::task::spawn_blocking(move || policy_handler.confirm(&tool_name, &args_val))
                .await;

        match joined {
            Ok(decision) => Ok(decision),
            Err(join_err) => Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
                "Policy confirmation panicked: {join_err}"
            ))),
        }
    })
}
pub(crate) fn check_tool_execution_allowed(
agent_id: u64,
name: &str,
args_json: &str,
) -> Result<bool, crate::error::Error> {
let map = bridge_state()
.read()
.map_err(|e| crate::error::Error::BackendError {
message: format!("Failed to read BRIDGE_STATE: {e}"),
})?;
let Some(state) = map.get(&agent_id) else {
return Ok(false);
};
let (is_allowed, needs_confirm) = match state.policies.evaluate(name) {
crate::policies::PolicyDecision::Allow => (true, false),
crate::policies::PolicyDecision::Deny => (false, false),
crate::policies::PolicyDecision::NeedsConfirmation { .. } => (false, true),
};
if is_allowed {
return Ok(true);
}
if needs_confirm && let Some(ref handler) = state.policy_handler {
let handler = Arc::clone(handler);
drop(map);
let args_val: serde_json::Value =
serde_json::from_str(args_json).map_err(|e| crate::error::Error::BackendError {
message: format!("Failed to parse policy args JSON: {e}"),
})?;
return Ok(handler.confirm(name, &args_val));
}
Ok(false)
}
// Python-callable entry point that runs a registered tool for `agent_id`.
//
// Steps, in order:
//   1. `check_tool_execution_allowed` decides whether the call may proceed.
//   2. The agent's tool registry and shared tool state are cloned out of the
//      bridge state.
//   3. `args_json` is parsed.
//   4. The returned awaitable dispatches the tool through the registry and
//      resolves to `into_content()` of its output.
//
// Raises `PermissionError` when the policy check says no, `ValueError` when
// `args_json` is not valid JSON, and `RuntimeError` for a failed policy check,
// unreadable bridge state, a missing agent or registry, or a dispatch error.
//
// NOTE(review): step 1 runs synchronously in this function, before the future
// is created, and it may call the policy handler's `confirm`. If that call
// blocks, this function blocks with it — confirm that is intended.
#[pyfunction]
pub(crate) fn dispatch_rust_tool<'py>(
    py: Python<'py>,
    agent_id: u64,
    name: String,
    args_json: &str,
) -> PyResult<Bound<'py, PyAny>> {
    tracing::info!(agent_id, tool = %name, "dispatch_rust_tool called from Python (async)");

    match check_tool_execution_allowed(agent_id, &name, args_json) {
        Ok(true) => {}
        Ok(false) => {
            return Err(pyo3::exceptions::PyPermissionError::new_err(format!(
                "Tool '{name}' execution blocked by agent policy rules"
            )));
        }
        Err(check_err) => {
            return Err(pyo3::exceptions::PyRuntimeError::new_err(
                check_err.to_string(),
            ));
        }
    }

    // Take owned handles so the read guard is gone before the future exists.
    let (registry, tool_state) = {
        let state_map = bridge_state().read().map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to read BRIDGE_STATE: {e}"))
        })?;

        let Some(entry) = state_map.get(&agent_id) else {
            return Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
                "No active bridge state found for agent ID {agent_id}"
            )));
        };
        let Some(registry) = entry.registry.as_ref() else {
            return Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
                "No active ToolRegistry found for agent ID {agent_id}"
            )));
        };

        (Arc::clone(registry), Arc::clone(&entry.tool_state))
    };

    let args: serde_json::Value = serde_json::from_str(args_json).map_err(|e| {
        pyo3::exceptions::PyValueError::new_err(format!("Failed to parse tool arguments JSON: {e}"))
    })?;

    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        let ctx = crate::tools::ToolContext::with_shared_state(None, tool_state);
        let outcome = registry.dispatch(&name, args, &ctx).await;

        match outcome {
            Ok(output) => Ok(output.into_content()),
            Err(dispatch_err) => Err(pyo3::exceptions::PyRuntimeError::new_err(
                dispatch_err.to_string(),
            )),
        }
    })
}