use std::{sync::Arc, time::Duration};
use pyo3::prelude::*;
use tokio::sync::{mpsc, oneshot};
use crate::{error::Error, quota::QuotaState};
pub(crate) mod bridge_state;
pub(crate) mod command_loop;
pub(crate) mod ffi_dispatch;
mod handlers;
pub(crate) mod py_scripts;
pub(crate) mod streaming;
pub(crate) mod venv;
pub(crate) use bridge_state::{AgentBridgeState, AgentId, bridge_state};
pub(crate) use ffi_dispatch::{
CREATE_AGENT_HOOK_GUARD, INITIALIZING_HOOK_RUNNER, dispatch_rust_hook,
dispatch_rust_policy_confirm, dispatch_rust_tool,
};
/// Derives the default per-operation timeout from the chat timeout.
///
/// The result is `chat_timeout` plus a fixed two-minute safety margin, so the
/// Rust-side wait for a reply outlasts the chat timeout handed to the command
/// loop.
///
/// The addition saturates at [`Duration::MAX`] instead of panicking, because
/// `chat_timeout` can come from the `AGI_CHAT_TIMEOUT_SECS` environment
/// variable and may be arbitrarily large.
#[must_use]
pub fn default_operation_timeout(chat_timeout: Duration) -> Duration {
    /// Extra head-room granted on top of the chat timeout (two minutes).
    const SAFETY_MARGIN: Duration = Duration::from_secs(2 * 60);

    chat_timeout.saturating_add(SAFETY_MARGIN)
}
/// Chat timeout, in seconds, that [`default_chat_timeout`] falls back to when
/// `AGI_CHAT_TIMEOUT_SECS` cannot be read or does not parse as a `u64`.
pub const DEFAULT_CHAT_TIMEOUT_SECS: u64 = 120;
/// Default value for [`RuntimeConfig::inter_agent_delay`].
pub const DEFAULT_INTER_AGENT_DELAY: Duration = Duration::from_millis(500);
/// Default value for [`RuntimeConfig::channel_capacity`].
const DEFAULT_CHANNEL_CAPACITY: usize = 64;
/// Default value for [`RuntimeConfig::shutdown_timeout`].
const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
/// Resolves the chat timeout from the environment.
///
/// Reads `AGI_CHAT_TIMEOUT_SECS` and interprets it as a whole number of
/// seconds. When the variable cannot be read, the built-in default
/// ([`DEFAULT_CHAT_TIMEOUT_SECS`]) is used silently; when it is present but is
/// not a valid `u64`, a warning is logged and the default is used.
#[must_use]
pub fn default_chat_timeout() -> Duration {
    // Variable missing (or unreadable): nothing to report, just use the default.
    let Ok(val) = std::env::var("AGI_CHAT_TIMEOUT_SECS") else {
        return Duration::from_secs(DEFAULT_CHAT_TIMEOUT_SECS);
    };

    let secs = match val.parse::<u64>() {
        Ok(parsed) => parsed,
        Err(e) => {
            tracing::warn!(
                value = %val,
                error = %e,
                "Invalid AGI_CHAT_TIMEOUT_SECS, using default {DEFAULT_CHAT_TIMEOUT_SECS}s"
            );
            DEFAULT_CHAT_TIMEOUT_SECS
        }
    };

    Duration::from_secs(secs)
}
/// Commands sent over the bounded mpsc channel from [`PythonRuntime`] to the
/// dedicated Python runtime thread.
///
/// Every variant except [`PyCommand::Shutdown`] carries a `reply` oneshot
/// sender through which the outcome is reported back to the caller awaiting
/// it in `PythonRuntime::send_command`.
// NOTE(review): the handlers that consume these commands are not visible in
// this file; the per-variant notes below describe only who issues each
// command and what the reply type is.
pub(crate) enum PyCommand {
/// Issued by `Runtime::create_agent`; replies with the bridge-level id of
/// the new agent.
CreateAgent {
/// The agent configuration serialized with `serde_json::to_string`.
config_json: String,
reply: oneshot::Sender<Result<AgentId, Error>>,
},
/// Issued by `Runtime::chat`; replies with a `ChatResponseHandle`.
Chat {
agent_id: AgentId,
/// Plain text for text content, otherwise the JSON encoding of the content.
prompt: String,
reply: oneshot::Sender<Result<crate::streaming::ChatResponseHandle, Error>>,
},
/// Issued by `Runtime::shutdown_agent` and by the fire-and-forget
/// `Runtime::try_shutdown_agent` (whose reply receiver is already dropped).
ShutdownAgent {
agent_id: AgentId,
reply: oneshot::Sender<Result<(), Error>>,
},
/// Issued by `Runtime::cancel`.
Cancel {
agent_id: AgentId,
reply: oneshot::Sender<Result<(), Error>>,
},
/// Issued by `Runtime::wait_for_idle`.
WaitForIdle {
agent_id: AgentId,
reply: oneshot::Sender<Result<(), Error>>,
},
/// Issued by `Runtime::send`; unlike [`PyCommand::Chat`] the reply carries
/// no response handle.
Send {
agent_id: AgentId,
/// Plain text for text content, otherwise the JSON encoding of the content.
prompt: String,
reply: oneshot::Sender<Result<(), Error>>,
},
/// Issued by `Runtime::signal_idle`.
SignalIdle {
agent_id: AgentId,
reply: oneshot::Sender<Result<(), Error>>,
},
/// Issued by `Runtime::wait_for_wakeup`.
WaitForWakeup {
agent_id: AgentId,
/// The caller's timeout converted with `Duration::as_secs_f64`.
timeout_secs: f64,
// NOTE(review): presumably `true` means a wakeup arrived before the
// timeout elapsed; confirm against the handler.
reply: oneshot::Sender<Result<bool, Error>>,
},
/// Sent by `PythonRuntime::shutdown` before joining the runtime thread;
/// carries no reply channel.
Shutdown,
/// Issued by `Runtime::history`.
GetHistory {
agent_id: AgentId,
reply: oneshot::Sender<Result<Vec<crate::types::ConversationMessage>, Error>>,
},
/// Issued by `Runtime::turn_count`.
GetTurnCount {
agent_id: AgentId,
reply: oneshot::Sender<Result<u32, Error>>,
},
/// Issued by `Runtime::total_usage`.
GetTotalUsage {
agent_id: AgentId,
reply: oneshot::Sender<Result<crate::types::UsageMetadata, Error>>,
},
/// Issued by `Runtime::last_turn_usage`.
GetLastTurnUsage {
agent_id: AgentId,
reply: oneshot::Sender<Result<crate::types::UsageMetadata, Error>>,
},
/// Issued by `Runtime::clear_history`.
ClearHistory {
agent_id: AgentId,
reply: oneshot::Sender<Result<(), Error>>,
},
/// Issued by `Runtime::compaction_indices`.
GetCompactionIndices {
agent_id: AgentId,
reply: oneshot::Sender<Result<Vec<u32>, Error>>,
},
/// Issued by `Runtime::last_response`; the reply is `None` when there is
/// no response to report.
GetLastResponse {
agent_id: AgentId,
reply: oneshot::Sender<Result<Option<String>, Error>>,
},
/// Issued by `Runtime::delete`.
Delete {
agent_id: AgentId,
reply: oneshot::Sender<Result<(), Error>>,
},
/// Issued by `Runtime::disconnect`.
Disconnect {
agent_id: AgentId,
reply: oneshot::Sender<Result<(), Error>>,
},
/// Issued by `Runtime::is_idle`.
IsIdle {
agent_id: AgentId,
reply: oneshot::Sender<Result<bool, Error>>,
},
}
/// Tunable settings for [`PythonRuntime`].
///
/// Because of `#[serde(default)]`, any field missing from serialized input is
/// filled in from the [`Default`] implementation.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(default)]
pub struct RuntimeConfig {
/// Capacity of the bounded command channel created in `PythonRuntime::new`.
pub channel_capacity: usize,
/// Timeout applied to each command issued through
/// `PythonRuntime::send_command`.
pub operation_timeout: Duration,
/// Timeout used by `PythonRuntime::shutdown` while waiting for the runtime
/// thread to exit.
pub shutdown_timeout: Duration,
/// Chat timeout handed to the command loop running on the Python thread.
pub chat_timeout: Duration,
/// Inter-agent delay handed to the command loop running on the Python thread.
pub inter_agent_delay: Duration,
}
impl Default for RuntimeConfig {
    /// Builds the default configuration.
    ///
    /// The chat timeout is resolved first (it may come from the environment)
    /// because the operation timeout is derived from it.
    fn default() -> Self {
        let chat_timeout = default_chat_timeout();
        let operation_timeout = default_operation_timeout(chat_timeout);

        Self {
            chat_timeout,
            operation_timeout,
            channel_capacity: DEFAULT_CHANNEL_CAPACITY,
            shutdown_timeout: DEFAULT_SHUTDOWN_TIMEOUT,
            inter_agent_delay: DEFAULT_INTER_AGENT_DELAY,
        }
    }
}
/// Handle to the dedicated OS thread that hosts the Python interpreter and
/// its asyncio event loop.
///
/// Commands are forwarded to that thread over a bounded channel; call
/// `shutdown` to stop and join it (dropping the handle without doing so only
/// logs a warning).
pub struct PythonRuntime {
/// Sending half of the command channel consumed by the runtime thread.
cmd_tx: mpsc::Sender<PyCommand>,
/// Join handle of the runtime thread; taken (set to `None`) by `shutdown`.
thread: Option<std::thread::JoinHandle<()>>,
/// Configuration the runtime was created with.
config: RuntimeConfig,
/// Registry handed out through `Runtime::quota_registry`.
quota_registry: crate::quota::QuotaRegistry,
/// State obtained from `quota_registry` for the empty key `""`.
quota_state: Arc<QuotaState>,
}
impl std::fmt::Debug for PythonRuntime {
    /// Prints the configuration and whether the runtime thread is still alive;
    /// the remaining fields are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // No handle (already taken by `shutdown`) counts as "not running".
        let thread_running = match &self.thread {
            Some(handle) => !handle.is_finished(),
            None => false,
        };

        let mut out = f.debug_struct("PythonRuntime");
        out.field("config", &self.config);
        out.field("thread_running", &thread_running);
        out.finish_non_exhaustive()
    }
}
impl PythonRuntime {
pub fn new(config: RuntimeConfig) -> Result<Self, Error> {
let (cmd_tx, cmd_rx) = mpsc::channel(config.channel_capacity);
let thread_config = config.clone();
let thread = std::thread::Builder::new()
.name("agy-bridge-python-runtime".into())
.spawn(move || {
python_thread_main(cmd_rx, &thread_config);
})
.map_err(|e| Error::BackendError {
message: format!("Failed to spawn Python runtime thread: {e}"),
})?;
let quota_registry = crate::quota::QuotaRegistry::new();
let quota_state = quota_registry.state_for_key("");
Ok(Self {
cmd_tx,
thread: Some(thread),
config,
quota_registry,
quota_state,
})
}
async fn send_command<T>(
&self,
operation: &str,
is_llm_op: bool,
build_cmd: impl FnOnce(oneshot::Sender<Result<T, Error>>) -> PyCommand,
) -> Result<T, Error> {
let (reply_tx, reply_rx) = oneshot::channel();
let cmd = build_cmd(reply_tx);
self.cmd_tx
.send(cmd)
.await
.map_err(|e| Error::ChannelClosed {
message: format!("Python runtime thread has exited (sending {operation}): {e}"),
})?;
let result = crate::error::with_timeout(self.config.operation_timeout, operation, async {
reply_rx.await.map_err(|e| Error::ChannelClosed {
message: format!("Reply channel dropped for {operation}: {e}"),
})?
})
.await?;
if is_llm_op {
self.quota_state.record_success();
}
Ok(result)
}
pub async fn shutdown(mut self) -> Result<(), Error> {
if let Err(e) = self.cmd_tx.send(PyCommand::Shutdown).await {
tracing::warn!("Shutdown command send failed (thread may already be exiting): {e}");
}
let Some(thread) = self.thread.take() else {
tracing::warn!("PythonRuntime::shutdown() called but thread handle already taken");
return Ok(());
};
let shutdown_timeout = self.config.shutdown_timeout;
let join_result = tokio::time::timeout(
shutdown_timeout,
tokio::task::spawn_blocking(move || thread.join()),
)
.await;
match join_result {
Ok(Ok(Ok(()))) => {
tracing::info!("Python runtime thread joined successfully");
Ok(())
}
Ok(Ok(Err(panic_payload))) => {
let panic_msg = panic_payload.downcast_ref::<&str>().map_or_else(
|| {
panic_payload
.downcast_ref::<String>()
.map_or_else(|| format!("{panic_payload:?}"), Clone::clone)
},
|s| (*s).to_string(),
);
tracing::error!(
panic_message = %panic_msg,
"Python runtime thread panicked during shutdown"
);
Err(Error::BackendError {
message: format!("Python runtime thread panicked during shutdown: {panic_msg}"),
})
}
Ok(Err(join_err)) => {
tracing::error!("spawn_blocking join error: {join_err}");
Err(Error::BackendError {
message: format!("Failed to join Python thread: {join_err}"),
})
}
Err(_elapsed) => {
tracing::error!(
timeout_secs = shutdown_timeout.as_secs(),
"Python runtime thread did not exit within shutdown timeout"
);
Err(Error::Timeout {
duration: shutdown_timeout,
operation: "PythonRuntime::shutdown (thread join)".to_string(),
})
}
}
}
#[must_use]
pub const fn quota_state(&self) -> &Arc<QuotaState> {
&self.quota_state
}
}
impl Drop for PythonRuntime {
    /// Warns when the handle is dropped while it still owns the thread's join
    /// handle, i.e. `shutdown()` never took it.
    fn drop(&mut self) {
        if self.thread.is_none() {
            return;
        }
        tracing::warn!(
            "PythonRuntime dropped without calling shutdown() — \
             Python thread may still be running"
        );
    }
}
/// Entry point of the dedicated Python runtime thread.
///
/// Initializes the interpreter, makes a best-effort attempt to configure
/// `sys.path` (a failure is only logged), then runs the command loop until it
/// finishes. Errors from the loop are logged rather than propagated because
/// this function is the thread's top level.
fn python_thread_main(cmd_rx: mpsc::Receiver<PyCommand>, config: &RuntimeConfig) {
    pyo3::prepare_freethreaded_python();

    Python::with_gil(|py| {
        if let Some(e) = venv::configure_python_sys_path(py).err() {
            tracing::warn!("Failed to configure Python sys.path in runtime thread: {e}");
        }
    });

    match run_live_thread(cmd_rx, config) {
        Ok(()) => {}
        Err(e) => {
            tracing::error!(error = %e, "Python runtime thread failed");
        }
    }

    tracing::info!("Python runtime thread exiting");
}
fn run_live_thread(cmd_rx: mpsc::Receiver<PyCommand>, config: &RuntimeConfig) -> Result<(), Error> {
Python::with_gil(|py| {
let asyncio = py
.import_bound("asyncio")
.map_err(|e| Error::BackendError {
message: format!("Failed to import asyncio: {e}"),
})?;
let event_loop =
asyncio
.call_method0("new_event_loop")
.map_err(|e| Error::BackendError {
message: format!("Failed to create new asyncio event loop: {e}"),
})?;
asyncio
.call_method1("set_event_loop", (&event_loop,))
.map_err(|e| Error::BackendError {
message: format!("Failed to set asyncio event loop: {e}"),
})?;
let sys = py.import_bound("sys").map_err(|e| Error::BackendError {
message: format!("Failed to import sys: {e}"),
})?;
let sys_modules = sys.getattr("modules").map_err(|e| Error::BackendError {
message: format!("Failed to get sys.modules: {e}"),
})?;
let globals_mod = if sys_modules
.contains(command_loop::AGY_BRIDGE_GLOBALS_MODULE)
.map_err(|e| Error::BackendError {
message: format!("Failed to check sys.modules: {e}"),
})? {
sys_modules
.get_item(command_loop::AGY_BRIDGE_GLOBALS_MODULE)
.map_err(|e| Error::BackendError {
message: format!("Failed to get _agy_bridge_globals: {e}"),
})?
} else {
let types = py.import_bound("types").map_err(|e| Error::BackendError {
message: format!("Failed to import types: {e}"),
})?;
let module = types
.getattr("ModuleType")
.map_err(|e| Error::BackendError {
message: format!("Failed to get ModuleType: {e}"),
})?
.call1((command_loop::AGY_BRIDGE_GLOBALS_MODULE,))
.map_err(|e| Error::BackendError {
message: format!("Failed to create ModuleType: {e}"),
})?;
sys_modules
.set_item(command_loop::AGY_BRIDGE_GLOBALS_MODULE, &module)
.map_err(|e| Error::BackendError {
message: format!("Failed to register _agy_bridge_globals: {e}"),
})?;
module
};
globals_mod
.setattr("EVENT_LOOP", &event_loop)
.map_err(|e| Error::BackendError {
message: format!("Failed to set EVENT_LOOP in globals: {e}"),
})?;
tracing::info!("Python asyncio event loop created on runtime thread");
let chat_timeout = config.chat_timeout;
let inter_agent_delay = config.inter_agent_delay;
let run_fut =
pyo3_async_runtimes::tokio::run_until_complete(event_loop.clone(), async move {
command_loop::run_async_command_loop(cmd_rx, chat_timeout, inter_agent_delay).await
});
if let Err(e) = run_fut {
if let Err(close_err) = event_loop.call_method0("close") {
tracing::warn!("Failed to close asyncio event loop: {close_err}");
}
return Err(Error::BackendError {
message: format!("Python runtime command loop failed: {e}"),
});
}
if let Err(e) = event_loop.call_method0("close") {
tracing::warn!("Failed to close asyncio event loop: {e}");
}
Ok(())
})
}
impl crate::agent::Runtime for PythonRuntime {
async fn create_agent(
&self,
config: crate::config::AgentConfig,
) -> Result<crate::agent::AgentId, Error> {
let mut all_tools = config.custom_tool_names();
if let Some(ref caps) = config.capabilities {
if let Some(ref builtins) = caps.enabled_tools {
all_tools.extend(builtins.iter().map(|b| b.as_sdk_name().to_string()));
} else if caps.disabled_tools.is_none() {
all_tools.extend(
crate::config::capabilities::BuiltinTools::all_tools()
.iter()
.map(|b| b.as_sdk_name().to_string()),
);
}
} else {
all_tools.extend(
crate::config::capabilities::BuiltinTools::all_tools()
.iter()
.map(|b| b.as_sdk_name().to_string()),
);
}
tracing::info!(
"Agent starting with {} available tools: {:?}",
all_tools.len(),
all_tools
);
let config_json = serde_json::to_string(&config).map_err(|e| Error::BackendError {
message: format!("Failed to serialize AgentConfig: {e}"),
})?;
let raw_id = self
.send_command("create_agent", false, |reply| PyCommand::CreateAgent {
config_json,
reply,
})
.await?;
Ok(raw_id.0)
}
async fn chat(
&self,
agent_id: crate::agent::AgentId,
content: &crate::content::Content,
) -> Result<crate::streaming::ChatResponseHandle, Error> {
let prompt = match content {
crate::content::Content::Text { text } => text.clone(),
other => crate::content::content_to_json(other)?,
};
self.send_command("chat", true, |reply| PyCommand::Chat {
agent_id: AgentId(agent_id),
prompt,
reply,
})
.await
}
async fn shutdown_agent(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
self.send_command("shutdown_agent", false, |reply| PyCommand::ShutdownAgent {
agent_id: AgentId(agent_id),
reply,
})
.await
}
fn try_shutdown_agent(&self, agent_id: crate::agent::AgentId) {
let (reply, _) = oneshot::channel();
if let Err(e) = self.cmd_tx.try_send(PyCommand::ShutdownAgent {
agent_id: AgentId(agent_id),
reply,
}) {
tracing::debug!(
agent_id = agent_id,
error = %e,
"try_shutdown_agent: channel send failed (runtime may already be gone)"
);
}
}
async fn cancel(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
self.send_command("cancel", false, |reply| PyCommand::Cancel {
agent_id: AgentId(agent_id),
reply,
})
.await
}
async fn wait_for_idle(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
self.send_command("wait_for_idle", false, |reply| PyCommand::WaitForIdle {
agent_id: AgentId(agent_id),
reply,
})
.await
}
async fn send(
&self,
agent_id: crate::agent::AgentId,
content: &crate::content::Content,
) -> Result<(), Error> {
let prompt = match content {
crate::content::Content::Text { text } => text.clone(),
other => crate::content::content_to_json(other)?,
};
self.send_command("send", false, |reply| PyCommand::Send {
agent_id: AgentId(agent_id),
prompt,
reply,
})
.await
}
async fn signal_idle(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
self.send_command("signal_idle", false, |reply| PyCommand::SignalIdle {
agent_id: AgentId(agent_id),
reply,
})
.await
}
async fn wait_for_wakeup(
&self,
agent_id: crate::agent::AgentId,
timeout: std::time::Duration,
) -> Result<bool, Error> {
self.send_command("wait_for_wakeup", false, |reply| PyCommand::WaitForWakeup {
agent_id: AgentId(agent_id),
timeout_secs: timeout.as_secs_f64(),
reply,
})
.await
}
async fn wait_for_quota(&self) {
self.quota_state.wait_for_quota().await;
}
async fn record_quota_hit(&self, retry_after: std::time::Duration) {
self.quota_state.record_quota_hit(retry_after);
}
fn quota_registry(&self) -> &crate::quota::QuotaRegistry {
&self.quota_registry
}
async fn history(
&self,
agent_id: crate::agent::AgentId,
) -> Result<Vec<crate::types::ConversationMessage>, Error> {
self.send_command("get_history", false, |reply| PyCommand::GetHistory {
agent_id: AgentId(agent_id),
reply,
})
.await
}
async fn turn_count(&self, agent_id: crate::agent::AgentId) -> Result<u32, Error> {
self.send_command("get_turn_count", false, |reply| PyCommand::GetTurnCount {
agent_id: AgentId(agent_id),
reply,
})
.await
}
async fn total_usage(
&self,
agent_id: crate::agent::AgentId,
) -> Result<crate::types::UsageMetadata, Error> {
self.send_command("get_total_usage", false, |reply| PyCommand::GetTotalUsage {
agent_id: AgentId(agent_id),
reply,
})
.await
}
async fn last_turn_usage(
&self,
agent_id: crate::agent::AgentId,
) -> Result<crate::types::UsageMetadata, Error> {
self.send_command("get_last_turn_usage", false, |reply| {
PyCommand::GetLastTurnUsage {
agent_id: AgentId(agent_id),
reply,
}
})
.await
}
async fn clear_history(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
self.send_command("clear_history", false, |reply| PyCommand::ClearHistory {
agent_id: AgentId(agent_id),
reply,
})
.await
}
async fn compaction_indices(&self, agent_id: crate::agent::AgentId) -> Result<Vec<u32>, Error> {
self.send_command("compaction_indices", false, |reply| {
PyCommand::GetCompactionIndices {
agent_id: AgentId(agent_id),
reply,
}
})
.await
}
async fn last_response(
&self,
agent_id: crate::agent::AgentId,
) -> Result<Option<String>, Error> {
self.send_command("last_response", false, |reply| PyCommand::GetLastResponse {
agent_id: AgentId(agent_id),
reply,
})
.await
}
async fn delete(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
self.send_command("delete", false, |reply| PyCommand::Delete {
agent_id: AgentId(agent_id),
reply,
})
.await
}
async fn disconnect(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
self.send_command("disconnect", false, |reply| PyCommand::Disconnect {
agent_id: AgentId(agent_id),
reply,
})
.await
}
async fn is_idle(&self, agent_id: crate::agent::AgentId) -> Result<bool, Error> {
self.send_command("is_idle", false, |reply| PyCommand::IsIdle {
agent_id: AgentId(agent_id),
reply,
})
.await
}
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::{ffi_dispatch::check_tool_execution_allowed, *};

    /// Small, fast configuration shared by the runtime tests.
    fn test_config() -> RuntimeConfig {
        RuntimeConfig {
            channel_capacity: 16,
            operation_timeout: Duration::from_secs(10),
            shutdown_timeout: Duration::from_secs(5),
            chat_timeout: Duration::from_mins(1),
            inter_agent_delay: Duration::from_millis(100),
        }
    }

    /// The runtime thread starts and can be shut down cleanly.
    #[tokio::test]
    async fn test_runtime_creation_and_shutdown() {
        PythonRuntime::new(test_config())
            .expect("Failed to create runtime")
            .shutdown()
            .await
            .expect("Shutdown failed");
    }

    /// Every field survives a JSON serialize/deserialize round trip.
    #[test]
    fn runtime_config_serde_roundtrip() {
        let config = test_config();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: RuntimeConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.channel_capacity, 16);
        assert_eq!(parsed.operation_timeout, Duration::from_secs(10));
        assert_eq!(parsed.shutdown_timeout, Duration::from_secs(5));
        assert_eq!(parsed.chat_timeout, Duration::from_mins(1));
        assert_eq!(parsed.inter_agent_delay, Duration::from_millis(100));
    }

    /// The default operation timeout is the chat timeout plus two minutes.
    #[test]
    fn default_operation_timeout_is_chat_plus_margin() {
        let config = RuntimeConfig::default();
        let expected = config.chat_timeout + Duration::from_mins(2);
        assert_eq!(
            config.operation_timeout, expected,
            "operation_timeout should be chat_timeout + 2min safety margin"
        );
    }

    /// A locally defined exception that merely shares the name
    /// `StopCandidateException` must not be classified as `Error::Safety`.
    #[test]
    fn safety_error_structural() {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let globals = pyo3::types::PyDict::new_bound(py);
            // Spelled with explicit `\n` escapes so the class body keeps the
            // indentation Python requires; an unindented `pass` raises
            // IndentationError and the `unwrap` below would panic.
            py.run_bound(
                concat!(
                    "class StopCandidateException(Exception):\n",
                    "    pass\n",
                    "err = StopCandidateException(\"dummy\")\n",
                ),
                Some(&globals),
                None,
            )
            .unwrap();
            let err_obj = globals.get_item("err").unwrap().unwrap();
            let err = PyErr::from_value_bound(err_obj);
            let mapped = crate::error::classify_py_error(py, &err);
            assert!(
                !matches!(mapped, crate::error::Error::Safety),
                "Failed: matched Error::Safety based purely on the string name StopCandidateException!"
            );
        });
    }

    /// A locally defined exception that merely shares the name
    /// `MaxTokensException` must not be classified as `Error::MaxTokens`.
    #[test]
    fn maxtokens_error_structural() {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let globals = pyo3::types::PyDict::new_bound(py);
            // See `safety_error_structural` for why the snippet is written
            // with explicit newlines and indentation.
            py.run_bound(
                concat!(
                    "class MaxTokensException(Exception):\n",
                    "    pass\n",
                    "err = MaxTokensException(\"dummy\")\n",
                ),
                Some(&globals),
                None,
            )
            .unwrap();
            let err_obj = globals.get_item("err").unwrap().unwrap();
            let err = PyErr::from_value_bound(err_obj);
            let mapped = crate::error::classify_py_error(py, &err);
            assert!(
                !matches!(mapped, crate::error::Error::MaxTokens),
                "Failed: matched Error::MaxTokens based purely on the string name MaxTokensException!"
            );
        });
    }

    /// Ask-user handler whose answer can be flipped by the test.
    struct MockAskUserHandler {
        should_allow: std::sync::atomic::AtomicBool,
    }

    impl crate::policies::AskUserHandler for MockAskUserHandler {
        fn confirm(&self, _tool_name: &str, _tool_args: &serde_json::Value) -> bool {
            self.should_allow.load(std::sync::atomic::Ordering::SeqCst)
        }
    }

    /// An `AskUser` policy on a custom tool allows or blocks execution
    /// according to what the registered handler answers.
    #[test]
    fn test_ask_user_policy_custom_tool_gating() {
        let agent_id: u64 = 999;

        let mut policies = crate::policies::PolicySet::new();
        policies
            .push(crate::policies::PolicyRule::AskUser {
                tool: "dangerous_tool".to_owned(),
                handler_id: "confirm_handler".to_owned(),
            })
            .unwrap();

        let handler = Arc::new(MockAskUserHandler {
            should_allow: std::sync::atomic::AtomicBool::new(true),
        });

        let mut registry = crate::tools::ToolRegistry::new();
        #[crate::llm_tool]
        fn dangerous_tool() -> Result<String, String> {
            Ok("Executed dangerous action!".to_owned())
        }
        registry.register(DangerousTool);

        // Register bridge state for the synthetic agent id used by this test.
        bridge_state().write().unwrap().insert(
            agent_id,
            AgentBridgeState {
                registry: Some(Arc::new(registry)),
                hook_runner: None,
                policies,
                policy_handler: Some(
                    Arc::clone(&handler) as Arc<dyn crate::policies::AskUserHandler>
                ),
                tool_state: Arc::new(std::sync::RwLock::new(HashMap::new())),
            },
        );

        // Handler approves: execution is allowed.
        handler
            .should_allow
            .store(true, std::sync::atomic::Ordering::SeqCst);
        let res = check_tool_execution_allowed(agent_id, "dangerous_tool", "{}");
        assert!(res.is_ok(), "Check should succeed");
        assert!(
            res.unwrap(),
            "Should allow tool execution when handler returns true"
        );

        // Handler refuses: execution is blocked.
        handler
            .should_allow
            .store(false, std::sync::atomic::Ordering::SeqCst);
        let res = check_tool_execution_allowed(agent_id, "dangerous_tool", "{}");
        assert!(res.is_ok(), "Check should succeed");
        assert!(
            !res.unwrap(),
            "Should block tool execution when handler returns false"
        );

        // Remove the synthetic agent so other tests see clean global state.
        bridge_state().write().unwrap().remove(&agent_id);
    }
}