1#![expect(clippy::useless_conversion)] use std::{collections::HashMap, sync::Arc, time::Duration};
41
42use pyo3::prelude::*;
43use tokio::sync::{mpsc, oneshot};
44
45use crate::{error::Error, quota::QuotaState};
46
47pub(crate) mod command_loop;
48mod handlers;
49pub(crate) mod py_scripts;
50pub(crate) mod streaming;
51pub(crate) mod venv;
52
/// Derives the default per-operation timeout from the chat timeout.
///
/// The result is `chat_timeout` plus a fixed two-minute safety margin.
/// The addition saturates at [`Duration::MAX`] instead of panicking, because
/// `chat_timeout` may originate from the `AGI_CHAT_TIMEOUT_SECS` environment
/// variable and can therefore be arbitrarily large.
#[must_use]
pub fn default_operation_timeout(chat_timeout: Duration) -> Duration {
    // Two minutes of headroom on top of the chat timeout.
    const SAFETY_MARGIN: Duration = Duration::from_secs(2 * 60);
    chat_timeout.saturating_add(SAFETY_MARGIN)
}
/// Fallback chat timeout in seconds; `default_chat_timeout` uses it when
/// `AGI_CHAT_TIMEOUT_SECS` is unset or does not parse as a `u64`.
pub const DEFAULT_CHAT_TIMEOUT_SECS: u64 = 120;

/// Value `RuntimeConfig::default` assigns to `inter_agent_delay`.
pub const DEFAULT_INTER_AGENT_DELAY: Duration = Duration::from_millis(500);

/// Value `RuntimeConfig::default` assigns to `channel_capacity`
/// (the capacity of the bounded command channel).
const DEFAULT_CHANNEL_CAPACITY: usize = 64;

/// Value `RuntimeConfig::default` assigns to `shutdown_timeout`
/// (the bound on joining the runtime thread during shutdown).
const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
78
79#[must_use]
82pub fn default_chat_timeout() -> Duration {
83 let secs = std::env::var("AGI_CHAT_TIMEOUT_SECS").map_or(DEFAULT_CHAT_TIMEOUT_SECS, |val| {
84 val.parse::<u64>().unwrap_or_else(|e| {
85 tracing::warn!(
86 value = %val,
87 error = %e,
88 "Invalid AGI_CHAT_TIMEOUT_SECS, using default {DEFAULT_CHAT_TIMEOUT_SECS}s"
89 );
90 DEFAULT_CHAT_TIMEOUT_SECS
91 })
92 });
93 Duration::from_secs(secs)
94}
95
/// Newtype around the raw `u64` identifier of an agent, used in the commands
/// exchanged with the Python runtime thread.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) struct AgentId(pub(crate) u64);

impl std::fmt::Display for AgentId {
    /// Renders the identifier as `agent-<n>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(raw) = self;
        write!(f, "agent-{raw}")
    }
}
105
/// Per-agent state consulted by the Python-callable dispatch functions in
/// this module. Instances live in the global map returned by `bridge_state()`,
/// keyed by the raw agent ID.
pub(crate) struct AgentBridgeState {
    /// Tool registry used by `dispatch_rust_tool`; `None` is reported there
    /// as "No active ToolRegistry".
    pub(crate) registry: Option<Arc<crate::tools::ToolRegistry>>,
    /// Hook runner used by `dispatch_rust_hook`; `None` is reported there as
    /// "No active Hooks".
    pub(crate) hook_runner: Option<Arc<crate::hooks::Hooks>>,
    /// Policy rules evaluated per tool name by `check_tool_execution_allowed`.
    pub(crate) policies: crate::policies::PolicySet,
    /// Handler asked to confirm a tool call when the policy decision is
    /// `NeedsConfirmation`; without one such calls are denied.
    pub(crate) policy_handler: Option<Arc<dyn crate::policies::AskUserHandler>>,
    /// Shared key/value state handed to each tool invocation through
    /// `ToolContext::with_shared_state`.
    pub(crate) tool_state: Arc<std::sync::RwLock<HashMap<String, serde_json::Value>>>,
}
123
124static BRIDGE_STATE: std::sync::OnceLock<
141 std::sync::RwLock<std::collections::HashMap<u64, AgentBridgeState>>,
142> = std::sync::OnceLock::new();
143
144pub(crate) fn bridge_state()
146-> &'static std::sync::RwLock<std::collections::HashMap<u64, AgentBridgeState>> {
147 BRIDGE_STATE.get_or_init(|| std::sync::RwLock::new(std::collections::HashMap::new()))
148}
149
/// Hook runner that `dispatch_rust_hook` falls back to when the requested
/// agent ID has no entry in the bridge-state map.
// NOTE(review): presumably populated while an agent is still being created and
// cleared afterwards — the writer is not in this part of the file; confirm.
pub(crate) static INITIALIZING_HOOK_RUNNER: std::sync::Mutex<Option<Arc<crate::hooks::Hooks>>> =
    std::sync::Mutex::new(None);

/// Async mutex guarding no data of its own.
// NOTE(review): presumably serialises agent creation so that only one agent
// uses `INITIALIZING_HOOK_RUNNER` at a time — confirm at the call sites.
pub(crate) static CREATE_AGENT_HOOK_GUARD: tokio::sync::Mutex<()> =
    tokio::sync::Mutex::const_new(());
159
160fn dispatch_hook_by_name(
164 hook_runner: &crate::hooks::Hooks,
165 hook_point: &str,
166 context_json: &str,
167) -> Result<String, crate::error::Error> {
168 let mut result_json = String::new();
169 match hook_point {
170 "pre_turn" => {
171 let ctx = serde_json::from_str::<crate::hooks::PreTurnContext>(context_json).map_err(
172 |e| crate::error::Error::BackendError {
173 message: format!("Failed to deserialize PreTurnContext: {e}"),
174 },
175 )?;
176 hook_runner.run_pre_turn(&ctx);
177 }
178 "post_turn" => {
179 let ctx = serde_json::from_str::<crate::hooks::PostTurnContext>(context_json).map_err(
180 |e| crate::error::Error::BackendError {
181 message: format!("Failed to deserialize PostTurnContext: {e}"),
182 },
183 )?;
184 hook_runner.run_post_turn(&ctx);
185 }
186 "pre_tool_call_decide" => {
187 let ctx = serde_json::from_str::<crate::hooks::PreToolCallDecideContext>(context_json)
188 .map_err(|e| crate::error::Error::BackendError {
189 message: format!("Failed to deserialize PreToolCallDecideContext: {e} | JSON was: {context_json}"),
190 })?;
191 let hook_result = hook_runner.run_pre_tool_call_decide(&ctx);
192 result_json = serde_json::to_string(&hook_result).map_err(|e| {
193 crate::error::Error::BackendError {
194 message: format!("Failed to serialize PreToolCallDecide result: {e}"),
195 }
196 })?;
197 }
198 "post_tool_call" => {
199 let ctx = serde_json::from_str::<crate::hooks::PostToolCallContext>(context_json)
200 .map_err(|e| crate::error::Error::BackendError {
201 message: format!(
202 "Failed to deserialize PostToolCallContext: {e} | JSON was: {context_json}"
203 ),
204 })?;
205 hook_runner.run_post_tool_call(&ctx);
206 }
207 "on_compaction" => {
208 let ctx = serde_json::from_str::<crate::hooks::OnCompactionContext>(context_json)
209 .map_err(|e| crate::error::Error::BackendError {
210 message: format!("Failed to deserialize OnCompactionContext: {e}"),
211 })?;
212 hook_runner.run_on_compaction(&ctx);
213 }
214 "on_session_start" => {
215 let ctx = serde_json::from_str::<crate::hooks::OnSessionStartContext>(context_json)
216 .map_err(|e| crate::error::Error::BackendError {
217 message: format!("Failed to deserialize OnSessionStartContext: {e}"),
218 })?;
219 hook_runner.run_on_session_start(&ctx);
220 }
221 "on_session_end" => {
222 let ctx = serde_json::from_str::<crate::hooks::OnSessionEndContext>(context_json)
223 .map_err(|e| crate::error::Error::BackendError {
224 message: format!("Failed to deserialize OnSessionEndContext: {e}"),
225 })?;
226 hook_runner.run_on_session_end(&ctx);
227 }
228 "on_tool_error" => {
229 let ctx = serde_json::from_str::<crate::hooks::OnToolErrorContext>(context_json)
230 .map_err(|e| crate::error::Error::BackendError {
231 message: format!("Failed to deserialize OnToolErrorContext: {e}"),
232 })?;
233 hook_runner.run_on_tool_error(&ctx);
234 }
235 "on_interaction" => {
236 let ctx = serde_json::from_str::<crate::hooks::OnInteractionContext>(context_json)
237 .map_err(|e| crate::error::Error::BackendError {
238 message: format!("Failed to deserialize OnInteractionContext: {e}"),
239 })?;
240 let hook_result = hook_runner.run_on_interaction(&ctx);
241 result_json = serde_json::to_string(&hook_result).map_err(|e| {
242 crate::error::Error::BackendError {
243 message: format!("Failed to serialize OnInteraction result: {e}"),
244 }
245 })?;
246 }
247 _ => {
248 tracing::warn!("Unknown hook point: {}", hook_point);
249 }
250 }
251 Ok(result_json)
252}
253
254#[pyfunction]
256pub(crate) fn dispatch_rust_hook(
257 py: Python<'_>,
258 agent_id: u64,
259 hook_point: String,
260 context_json: String,
261) -> PyResult<Bound<'_, PyAny>> {
262 tracing::debug!(agent_id, hook_point = %hook_point, "dispatch_rust_hook called from Python");
263 let hook_runner = {
264 let map = bridge_state().read().map_err(|e| {
265 pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to read BRIDGE_STATE: {e}"))
266 })?;
267 if let Some(entry) = map.get(&agent_id) {
268 let runner = entry.hook_runner.as_ref().ok_or_else(|| {
269 pyo3::exceptions::PyRuntimeError::new_err(format!(
270 "No active Hooks found for agent ID {agent_id}"
271 ))
272 })?;
273 Arc::clone(runner)
274 } else {
275 let opt = INITIALIZING_HOOK_RUNNER.lock().map_err(|e| {
276 pyo3::exceptions::PyRuntimeError::new_err(format!(
277 "Failed to lock INITIALIZING_HOOK_RUNNER: {e}"
278 ))
279 })?;
280 if let Some(ref runner) = *opt {
281 Arc::clone(runner)
282 } else {
283 return Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
284 "No active bridge state or initializing hook runner found for agent ID {agent_id}"
285 )));
286 }
287 }
288 };
289
290 pyo3_async_runtimes::tokio::future_into_py(py, async move {
291 let result = tokio::task::spawn_blocking(move || {
296 dispatch_hook_by_name(&hook_runner, &hook_point, &context_json)
297 })
298 .await
299 .map_err(|e| {
300 pyo3::exceptions::PyRuntimeError::new_err(format!("Hook execution failed: {e}"))
301 })?
302 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
303
304 Ok(result)
305 })
306}
307
308#[pyfunction]
309pub(crate) fn dispatch_rust_policy_confirm(
310 py: Python<'_>,
311 agent_id: u64,
312 tool_name: String,
313 args_json: String,
314) -> PyResult<Bound<'_, PyAny>> {
315 tracing::info!(agent_id, tool = %tool_name, "dispatch_rust_policy_confirm called from Python");
316 let policy_handler = {
317 let map = bridge_state().read().map_err(|e| {
318 pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to read BRIDGE_STATE: {e}"))
319 })?;
320 let entry = map.get(&agent_id).ok_or_else(|| {
321 pyo3::exceptions::PyRuntimeError::new_err(format!(
322 "No active bridge state found for agent ID {agent_id}"
323 ))
324 })?;
325 let handler = entry.policy_handler.as_ref().ok_or_else(|| {
326 pyo3::exceptions::PyRuntimeError::new_err(format!(
327 "No active AskUserHandler found for agent ID {agent_id}"
328 ))
329 })?;
330 Arc::clone(handler)
331 };
332
333 pyo3_async_runtimes::tokio::future_into_py(py, async move {
334 let args_val: serde_json::Value = serde_json::from_str(&args_json).map_err(|e| {
337 pyo3::exceptions::PyValueError::new_err(format!(
338 "Failed to parse policy args JSON: {e}"
339 ))
340 })?;
341 let result =
342 tokio::task::spawn_blocking(move || policy_handler.confirm(&tool_name, &args_val))
343 .await
344 .map_err(|e| {
345 pyo3::exceptions::PyRuntimeError::new_err(format!(
346 "Policy confirmation panicked: {e}"
347 ))
348 })?;
349
350 Ok(result)
351 })
352}
353
354pub(crate) fn check_tool_execution_allowed(
356 agent_id: u64,
357 name: &str,
358 args_json: &str,
359) -> Result<bool, crate::error::Error> {
360 let map = bridge_state()
361 .read()
362 .map_err(|e| crate::error::Error::BackendError {
363 message: format!("Failed to read BRIDGE_STATE: {e}"),
364 })?;
365
366 let Some(state) = map.get(&agent_id) else {
367 return Ok(false);
368 };
369
370 let (is_allowed, needs_confirm) = match state.policies.evaluate(name) {
371 crate::policies::PolicyDecision::Allow => (true, false),
372 crate::policies::PolicyDecision::Deny => (false, false),
373 crate::policies::PolicyDecision::NeedsConfirmation { .. } => (false, true),
374 };
375
376 if is_allowed {
377 return Ok(true);
378 }
379
380 if needs_confirm && let Some(ref handler) = state.policy_handler {
381 let handler = Arc::clone(handler);
382 drop(map);
384 let args_val: serde_json::Value =
385 serde_json::from_str(args_json).map_err(|e| crate::error::Error::BackendError {
386 message: format!("Failed to parse policy args JSON: {e}"),
387 })?;
388 return Ok(handler.confirm(name, &args_val));
389 }
390
391 Ok(false)
392}
393
394#[pyfunction]
400fn dispatch_rust_tool<'py>(
401 py: Python<'py>,
402 agent_id: u64,
403 name: String,
404 args_json: &str,
405) -> PyResult<Bound<'py, PyAny>> {
406 tracing::info!(agent_id, tool = %name, "dispatch_rust_tool called from Python (async)");
407
408 let is_allowed = check_tool_execution_allowed(agent_id, &name, args_json)
410 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
411
412 if !is_allowed {
413 return Err(pyo3::exceptions::PyPermissionError::new_err(format!(
414 "Tool '{name}' execution blocked by agent policy rules"
415 )));
416 }
417
418 let (registry, tool_state) = {
419 let map = bridge_state().read().map_err(|e| {
420 pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to read BRIDGE_STATE: {e}"))
421 })?;
422 let entry = map.get(&agent_id).ok_or_else(|| {
423 pyo3::exceptions::PyRuntimeError::new_err(format!(
424 "No active bridge state found for agent ID {agent_id}"
425 ))
426 })?;
427 let registry = entry.registry.as_ref().ok_or_else(|| {
428 pyo3::exceptions::PyRuntimeError::new_err(format!(
429 "No active ToolRegistry found for agent ID {agent_id}"
430 ))
431 })?;
432 (Arc::clone(registry), Arc::clone(&entry.tool_state))
433 };
434
435 let args: serde_json::Value = serde_json::from_str(args_json).map_err(|e| {
436 pyo3::exceptions::PyValueError::new_err(format!("Failed to parse tool arguments JSON: {e}"))
437 })?;
438
439 pyo3_async_runtimes::tokio::future_into_py(py, async move {
440 let ctx = crate::tools::ToolContext::with_shared_state(None, tool_state);
441 let output = registry
442 .dispatch(&name, args, &ctx)
443 .await
444 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
445 Ok(output.into_content())
447 })
448}
449
/// Commands sent over the `mpsc` channel from `PythonRuntime` to the Python
/// runtime thread.
///
/// Every variant except `Shutdown` carries a oneshot `reply` sender on which
/// the outcome of the command is delivered back to the caller.
pub(crate) enum PyCommand {
    /// Sent by `create_agent`; replies with the new agent's ID.
    CreateAgent {
        /// JSON-serialized `AgentConfig`.
        config_json: String,
        reply: oneshot::Sender<Result<AgentId, Error>>,
    },
    /// Sent by `chat`; replies with a handle to the chat response.
    Chat {
        agent_id: AgentId,
        /// Plain text for text content, otherwise the content rendered as JSON.
        prompt: String,
        reply: oneshot::Sender<Result<crate::streaming::ChatResponseHandle, Error>>,
    },
    /// Sent by `shutdown_agent` and (best effort) by `try_shutdown_agent`.
    ShutdownAgent {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<(), Error>>,
    },
    /// Sent by `cancel`.
    Cancel {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<(), Error>>,
    },
    /// Sent by `wait_for_idle`.
    WaitForIdle {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<(), Error>>,
    },
    /// Sent by `send`.
    Send {
        agent_id: AgentId,
        /// Plain text for text content, otherwise the content rendered as JSON.
        prompt: String,
        reply: oneshot::Sender<Result<(), Error>>,
    },
    /// Sent by `signal_idle`.
    SignalIdle {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<(), Error>>,
    },
    /// Sent by `wait_for_wakeup`; replies with a boolean.
    WaitForWakeup {
        agent_id: AgentId,
        /// Caller-supplied timeout converted with `Duration::as_secs_f64`.
        timeout_secs: f64,
        reply: oneshot::Sender<Result<bool, Error>>,
    },
    /// Sent by `PythonRuntime::shutdown`; carries no reply channel.
    Shutdown,
    /// Sent by `history`; replies with the conversation messages.
    GetHistory {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<Vec<crate::types::ConversationMessage>, Error>>,
    },
    /// Sent by `turn_count`.
    GetTurnCount {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<u32, Error>>,
    },
    /// Sent by `total_usage`.
    GetTotalUsage {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<crate::types::UsageMetadata, Error>>,
    },
    /// Sent by `last_turn_usage`.
    GetLastTurnUsage {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<crate::types::UsageMetadata, Error>>,
    },
    /// Sent by `clear_history`.
    ClearHistory {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<(), Error>>,
    },
    /// Sent by `compaction_indices`.
    GetCompactionIndices {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<Vec<u32>, Error>>,
    },
    /// Sent by `last_response`; replies with `None` or the response text.
    GetLastResponse {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<Option<String>, Error>>,
    },
    /// Sent by `delete`.
    Delete {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<(), Error>>,
    },
    /// Sent by `disconnect`.
    Disconnect {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<(), Error>>,
    },
    /// Sent by `is_idle`; replies with a boolean.
    IsIdle {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<bool, Error>>,
    },
}
558
/// Tunable settings for `PythonRuntime`.
///
/// Because of `#[serde(default)]`, fields missing from a serialized config
/// are filled in from `RuntimeConfig::default()`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(default)]
pub struct RuntimeConfig {
    /// Capacity of the bounded command channel created in `PythonRuntime::new`.
    pub channel_capacity: usize,
    /// Upper bound on waiting for a command's reply.
    pub operation_timeout: Duration,
    /// Upper bound on joining the runtime thread in `PythonRuntime::shutdown`.
    pub shutdown_timeout: Duration,
    /// Chat timeout forwarded to the async command loop.
    pub chat_timeout: Duration,
    /// Inter-agent delay forwarded to the async command loop.
    pub inter_agent_delay: Duration,
}
576
577impl Default for RuntimeConfig {
578 fn default() -> Self {
579 let chat_timeout = default_chat_timeout();
580 Self {
581 channel_capacity: DEFAULT_CHANNEL_CAPACITY,
582 operation_timeout: default_operation_timeout(chat_timeout),
583 shutdown_timeout: DEFAULT_SHUTDOWN_TIMEOUT,
584 chat_timeout,
585 inter_agent_delay: DEFAULT_INTER_AGENT_DELAY,
586 }
587 }
588}
589
/// Handle to a dedicated OS thread that runs the Python side of the bridge.
///
/// Commands are sent to that thread over a bounded channel. Call
/// `shutdown()` to stop and join it.
pub struct PythonRuntime {
    /// Sending half of the command channel to the runtime thread.
    cmd_tx: mpsc::Sender<PyCommand>,
    /// Join handle of the runtime thread; `None` once `shutdown()` has taken it.
    thread: Option<std::thread::JoinHandle<()>>,
    /// Configuration this runtime was created with.
    config: RuntimeConfig,
    /// Quota registry exposed through `Runtime::quota_registry`.
    quota_registry: crate::quota::QuotaRegistry,
    /// Quota state obtained from the registry for the empty key (`""`).
    quota_state: Arc<QuotaState>,
}
604
605impl std::fmt::Debug for PythonRuntime {
606 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
607 f.debug_struct("PythonRuntime")
608 .field("config", &self.config)
609 .field(
610 "thread_running",
611 &self.thread.as_ref().is_some_and(|t| !t.is_finished()),
612 )
613 .finish_non_exhaustive()
614 }
615}
616
617impl PythonRuntime {
618 pub fn new(config: RuntimeConfig) -> Result<Self, Error> {
628 let (cmd_tx, cmd_rx) = mpsc::channel(config.channel_capacity);
629
630 let thread_config = config.clone();
631 let thread = std::thread::Builder::new()
632 .name("agy-bridge-python-runtime".into())
633 .spawn(move || {
634 python_thread_main(cmd_rx, &thread_config);
635 })
636 .map_err(|e| Error::BackendError {
637 message: format!("Failed to spawn Python runtime thread: {e}"),
638 })?;
639
640 let quota_registry = crate::quota::QuotaRegistry::new();
641 let quota_state = quota_registry.state_for_key("");
642 Ok(Self {
643 cmd_tx,
644 thread: Some(thread),
645 config,
646 quota_registry,
647 quota_state,
648 })
649 }
650
651 async fn send_command<T>(
661 &self,
662 operation: &str,
663 is_llm_op: bool,
664 build_cmd: impl FnOnce(oneshot::Sender<Result<T, Error>>) -> PyCommand,
665 ) -> Result<T, Error> {
666 self.quota_state.wait_for_quota().await;
667
668 let (reply_tx, reply_rx) = oneshot::channel();
669 let cmd = build_cmd(reply_tx);
670
671 self.cmd_tx
672 .send(cmd)
673 .await
674 .map_err(|e| Error::ChannelClosed {
675 message: format!("Python runtime thread has exited (sending {operation}): {e}"),
676 })?;
677
678 let result = crate::error::with_timeout(self.config.operation_timeout, operation, async {
679 reply_rx.await.map_err(|e| Error::ChannelClosed {
680 message: format!("Reply channel dropped for {operation}: {e}"),
681 })?
682 })
683 .await?;
684
685 if is_llm_op {
688 self.quota_state.record_success();
689 }
690
691 Ok(result)
692 }
693
694 pub async fn shutdown(mut self) -> Result<(), Error> {
702 if let Err(e) = self.cmd_tx.send(PyCommand::Shutdown).await {
706 tracing::warn!("Shutdown command send failed (thread may already be exiting): {e}");
707 }
708
709 let Some(thread) = self.thread.take() else {
712 tracing::warn!("PythonRuntime::shutdown() called but thread handle already taken");
713 return Ok(());
714 };
715
716 let shutdown_timeout = self.config.shutdown_timeout;
717 let join_result = tokio::time::timeout(
718 shutdown_timeout,
719 tokio::task::spawn_blocking(move || thread.join()),
720 )
721 .await;
722
723 match join_result {
724 Ok(Ok(Ok(()))) => {
725 tracing::info!("Python runtime thread joined successfully");
726 Ok(())
727 }
728 Ok(Ok(Err(panic_payload))) => {
729 let panic_msg = panic_payload.downcast_ref::<&str>().map_or_else(
730 || {
731 panic_payload
732 .downcast_ref::<String>()
733 .map_or_else(|| format!("{panic_payload:?}"), Clone::clone)
734 },
735 |s| (*s).to_string(),
736 );
737 tracing::error!(
738 panic_message = %panic_msg,
739 "Python runtime thread panicked during shutdown"
740 );
741 Err(Error::BackendError {
742 message: format!("Python runtime thread panicked during shutdown: {panic_msg}"),
743 })
744 }
745 Ok(Err(join_err)) => {
746 tracing::error!("spawn_blocking join error: {join_err}");
747 Err(Error::BackendError {
748 message: format!("Failed to join Python thread: {join_err}"),
749 })
750 }
751 Err(_elapsed) => {
752 tracing::error!(
753 timeout_secs = shutdown_timeout.as_secs(),
754 "Python runtime thread did not exit within shutdown timeout"
755 );
756 Err(Error::Timeout {
757 duration: shutdown_timeout,
758 operation: "PythonRuntime::shutdown (thread join)".to_string(),
759 })
760 }
761 }
762 }
763
764 #[must_use]
766 pub const fn quota_state(&self) -> &Arc<QuotaState> {
767 &self.quota_state
768 }
769}
770
771impl Drop for PythonRuntime {
772 fn drop(&mut self) {
773 if self.thread.is_some() {
774 tracing::warn!(
775 "PythonRuntime dropped without calling shutdown() — \
776 Python thread may still be running"
777 );
778 }
779 }
780}
781
782fn python_thread_main(cmd_rx: mpsc::Receiver<PyCommand>, config: &RuntimeConfig) {
784 pyo3::prepare_freethreaded_python();
785
786 Python::with_gil(|py| {
791 if let Err(e) = venv::configure_python_sys_path(py) {
792 tracing::warn!("Failed to configure Python sys.path in runtime thread: {e}");
793 }
794 });
795
796 if let Err(e) = run_live_thread(cmd_rx, config) {
797 tracing::error!(error = %e, "Python runtime thread failed");
798 }
799
800 tracing::info!("Python runtime thread exiting");
801}
802
803fn run_live_thread(cmd_rx: mpsc::Receiver<PyCommand>, config: &RuntimeConfig) -> Result<(), Error> {
806 Python::with_gil(|py| {
807 let asyncio = py
808 .import_bound("asyncio")
809 .map_err(|e| Error::BackendError {
810 message: format!("Failed to import asyncio: {e}"),
811 })?;
812 let event_loop =
813 asyncio
814 .call_method0("new_event_loop")
815 .map_err(|e| Error::BackendError {
816 message: format!("Failed to create new asyncio event loop: {e}"),
817 })?;
818 asyncio
819 .call_method1("set_event_loop", (&event_loop,))
820 .map_err(|e| Error::BackendError {
821 message: format!("Failed to set asyncio event loop: {e}"),
822 })?;
823
824 let sys = py.import_bound("sys").map_err(|e| Error::BackendError {
826 message: format!("Failed to import sys: {e}"),
827 })?;
828 let sys_modules = sys.getattr("modules").map_err(|e| Error::BackendError {
829 message: format!("Failed to get sys.modules: {e}"),
830 })?;
831 let globals_mod = if sys_modules
832 .contains(command_loop::AGY_BRIDGE_GLOBALS_MODULE)
833 .map_err(|e| Error::BackendError {
834 message: format!("Failed to check sys.modules: {e}"),
835 })? {
836 sys_modules
837 .get_item(command_loop::AGY_BRIDGE_GLOBALS_MODULE)
838 .map_err(|e| Error::BackendError {
839 message: format!("Failed to get _agy_bridge_globals: {e}"),
840 })?
841 } else {
842 let types = py.import_bound("types").map_err(|e| Error::BackendError {
843 message: format!("Failed to import types: {e}"),
844 })?;
845 let module = types
846 .getattr("ModuleType")
847 .map_err(|e| Error::BackendError {
848 message: format!("Failed to get ModuleType: {e}"),
849 })?
850 .call1((command_loop::AGY_BRIDGE_GLOBALS_MODULE,))
851 .map_err(|e| Error::BackendError {
852 message: format!("Failed to create ModuleType: {e}"),
853 })?;
854 sys_modules
855 .set_item(command_loop::AGY_BRIDGE_GLOBALS_MODULE, &module)
856 .map_err(|e| Error::BackendError {
857 message: format!("Failed to register _agy_bridge_globals: {e}"),
858 })?;
859 module
860 };
861 globals_mod
862 .setattr("EVENT_LOOP", &event_loop)
863 .map_err(|e| Error::BackendError {
864 message: format!("Failed to set EVENT_LOOP in globals: {e}"),
865 })?;
866
867 tracing::info!("Python asyncio event loop created on runtime thread");
868
869 let chat_timeout = config.chat_timeout;
870 let inter_agent_delay = config.inter_agent_delay;
871 let run_fut =
872 pyo3_async_runtimes::tokio::run_until_complete(event_loop.clone(), async move {
873 command_loop::run_async_command_loop(cmd_rx, chat_timeout, inter_agent_delay).await
874 });
875
876 if let Err(e) = run_fut {
877 if let Err(close_err) = event_loop.call_method0("close") {
879 tracing::warn!("Failed to close asyncio event loop: {close_err}");
880 }
881 return Err(Error::BackendError {
882 message: format!("Python runtime command loop failed: {e}"),
883 });
884 }
885
886 if let Err(e) = event_loop.call_method0("close") {
887 tracing::warn!("Failed to close asyncio event loop: {e}");
888 }
889
890 Ok(())
891 })
892}
893
/// `Runtime` implementation that forwards each operation as a `PyCommand` to
/// the Python runtime thread via `send_command`. Only `chat` is flagged as an
/// LLM operation (`is_llm_op = true`).
impl crate::agent::Runtime for PythonRuntime {
    /// Creates an agent from `config` and returns its raw ID.
    ///
    /// The tool list computed here is used only for the startup log line.
    /// The agent itself is configured from the JSON-serialized `config`.
    async fn create_agent(
        &self,
        config: crate::config::AgentConfig,
    ) -> Result<crate::agent::AgentId, Error> {
        let mut all_tools = config.custom_tool_names();
        if let Some(ref caps) = config.capabilities {
            if let Some(ref builtins) = caps.enabled_tools {
                all_tools.extend(builtins.iter().map(|b| b.as_sdk_name().to_string()));
            } else if caps.disabled_tools.is_none() {
                // Neither list given: every built-in tool is logged.
                all_tools.extend(
                    crate::config::capabilities::BuiltinTools::all_tools()
                        .iter()
                        .map(|b| b.as_sdk_name().to_string()),
                );
            }
            // With only `disabled_tools` set, no built-in names are added to
            // the logged list.
        } else {
            all_tools.extend(
                crate::config::capabilities::BuiltinTools::all_tools()
                    .iter()
                    .map(|b| b.as_sdk_name().to_string()),
            );
        }
        tracing::info!(
            "Agent starting with {} available tools: {:?}",
            all_tools.len(),
            all_tools
        );

        let config_json = serde_json::to_string(&config).map_err(|e| Error::BackendError {
            message: format!("Failed to serialize AgentConfig: {e}"),
        })?;

        let raw_id = self
            .send_command("create_agent", false, |reply| PyCommand::CreateAgent {
                config_json,
                reply,
            })
            .await?;

        Ok(raw_id.0)
    }

    /// Sends `content` to the agent as a chat prompt and returns the response
    /// handle. Text content is sent verbatim; other content is sent as JSON.
    async fn chat(
        &self,
        agent_id: crate::agent::AgentId,
        content: &crate::content::Content,
    ) -> Result<crate::streaming::ChatResponseHandle, Error> {
        let prompt = match content {
            crate::content::Content::Text { text } => text.clone(),
            other => crate::content::content_to_json(other)?,
        };
        self.send_command("chat", true, |reply| PyCommand::Chat {
            agent_id: AgentId(agent_id),
            prompt,
            reply,
        })
        .await
    }

    /// Shuts down a single agent and waits for the reply.
    async fn shutdown_agent(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
        self.send_command("shutdown_agent", false, |reply| PyCommand::ShutdownAgent {
            agent_id: AgentId(agent_id),
            reply,
        })
        .await
    }

    /// Best-effort, non-blocking variant of `shutdown_agent`.
    ///
    /// Uses `try_send` and discards the reply receiver, so the outcome is
    /// never observed. A failed send is only logged at debug level.
    fn try_shutdown_agent(&self, agent_id: crate::agent::AgentId) {
        let (reply, _) = oneshot::channel();
        if let Err(e) = self.cmd_tx.try_send(PyCommand::ShutdownAgent {
            agent_id: AgentId(agent_id),
            reply,
        }) {
            tracing::debug!(
                agent_id = agent_id,
                error = %e,
                "try_shutdown_agent: channel send failed (runtime may already be gone)"
            );
        }
    }

    /// Forwards a `Cancel` command for the agent.
    async fn cancel(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
        self.send_command("cancel", false, |reply| PyCommand::Cancel {
            agent_id: AgentId(agent_id),
            reply,
        })
        .await
    }

    /// Forwards a `WaitForIdle` command; the wait is bounded by the
    /// operation timeout applied in `send_command`.
    async fn wait_for_idle(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
        self.send_command("wait_for_idle", false, |reply| PyCommand::WaitForIdle {
            agent_id: AgentId(agent_id),
            reply,
        })
        .await
    }

    /// Sends `content` to the agent without returning a response handle.
    /// Text content is sent verbatim; other content is sent as JSON.
    async fn send(
        &self,
        agent_id: crate::agent::AgentId,
        content: &crate::content::Content,
    ) -> Result<(), Error> {
        let prompt = match content {
            crate::content::Content::Text { text } => text.clone(),
            other => crate::content::content_to_json(other)?,
        };
        self.send_command("send", false, |reply| PyCommand::Send {
            agent_id: AgentId(agent_id),
            prompt,
            reply,
        })
        .await
    }

    /// Forwards a `SignalIdle` command for the agent.
    async fn signal_idle(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
        self.send_command("signal_idle", false, |reply| PyCommand::SignalIdle {
            agent_id: AgentId(agent_id),
            reply,
        })
        .await
    }

    /// Forwards a `WaitForWakeup` command with `timeout` expressed in
    /// fractional seconds and returns the boolean from the reply.
    // NOTE(review): the reply is also bounded by `operation_timeout` in
    // `send_command`, so a `timeout` longer than that cannot be honoured.
    async fn wait_for_wakeup(
        &self,
        agent_id: crate::agent::AgentId,
        timeout: std::time::Duration,
    ) -> Result<bool, Error> {
        self.send_command("wait_for_wakeup", false, |reply| PyCommand::WaitForWakeup {
            agent_id: AgentId(agent_id),
            timeout_secs: timeout.as_secs_f64(),
            reply,
        })
        .await
    }

    /// Waits on this runtime's quota state.
    async fn wait_for_quota(&self) {
        self.quota_state.wait_for_quota().await;
    }

    /// Records a quota hit with the given retry delay on the quota state.
    async fn record_quota_hit(&self, retry_after: std::time::Duration) {
        self.quota_state.record_quota_hit(retry_after);
    }

    /// Returns the quota registry owned by this runtime.
    fn quota_registry(&self) -> &crate::quota::QuotaRegistry {
        &self.quota_registry
    }

    /// Fetches the agent's conversation history.
    async fn history(
        &self,
        agent_id: crate::agent::AgentId,
    ) -> Result<Vec<crate::types::ConversationMessage>, Error> {
        self.send_command("get_history", false, |reply| PyCommand::GetHistory {
            agent_id: AgentId(agent_id),
            reply,
        })
        .await
    }

    /// Fetches the agent's turn count.
    async fn turn_count(&self, agent_id: crate::agent::AgentId) -> Result<u32, Error> {
        self.send_command("get_turn_count", false, |reply| PyCommand::GetTurnCount {
            agent_id: AgentId(agent_id),
            reply,
        })
        .await
    }

    /// Fetches the agent's total usage metadata.
    async fn total_usage(
        &self,
        agent_id: crate::agent::AgentId,
    ) -> Result<crate::types::UsageMetadata, Error> {
        self.send_command("get_total_usage", false, |reply| PyCommand::GetTotalUsage {
            agent_id: AgentId(agent_id),
            reply,
        })
        .await
    }

    /// Fetches the usage metadata of the agent's last turn.
    async fn last_turn_usage(
        &self,
        agent_id: crate::agent::AgentId,
    ) -> Result<crate::types::UsageMetadata, Error> {
        self.send_command("get_last_turn_usage", false, |reply| {
            PyCommand::GetLastTurnUsage {
                agent_id: AgentId(agent_id),
                reply,
            }
        })
        .await
    }

    /// Forwards a `ClearHistory` command for the agent.
    async fn clear_history(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
        self.send_command("clear_history", false, |reply| PyCommand::ClearHistory {
            agent_id: AgentId(agent_id),
            reply,
        })
        .await
    }

    /// Fetches the agent's compaction indices.
    async fn compaction_indices(&self, agent_id: crate::agent::AgentId) -> Result<Vec<u32>, Error> {
        self.send_command("compaction_indices", false, |reply| {
            PyCommand::GetCompactionIndices {
                agent_id: AgentId(agent_id),
                reply,
            }
        })
        .await
    }

    /// Fetches the agent's last response text, if any.
    async fn last_response(
        &self,
        agent_id: crate::agent::AgentId,
    ) -> Result<Option<String>, Error> {
        self.send_command("last_response", false, |reply| PyCommand::GetLastResponse {
            agent_id: AgentId(agent_id),
            reply,
        })
        .await
    }

    /// Forwards a `Delete` command for the agent.
    async fn delete(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
        self.send_command("delete", false, |reply| PyCommand::Delete {
            agent_id: AgentId(agent_id),
            reply,
        })
        .await
    }

    /// Forwards a `Disconnect` command for the agent.
    async fn disconnect(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
        self.send_command("disconnect", false, |reply| PyCommand::Disconnect {
            agent_id: AgentId(agent_id),
            reply,
        })
        .await
    }

    /// Asks whether the agent is idle.
    async fn is_idle(&self, agent_id: crate::agent::AgentId) -> Result<bool, Error> {
        self.send_command("is_idle", false, |reply| PyCommand::IsIdle {
            agent_id: AgentId(agent_id),
            reply,
        })
        .await
    }
}
1143
/// Unit tests for the runtime configuration, runtime lifecycle, Python error
/// classification, and policy-gated tool execution.
#[cfg(test)]
mod tests {
    use super::*;

    /// Small, fast configuration used by the tests below.
    fn test_config() -> RuntimeConfig {
        RuntimeConfig {
            channel_capacity: 16,
            operation_timeout: Duration::from_secs(10),
            shutdown_timeout: Duration::from_secs(5),
            chat_timeout: Duration::from_mins(1),
            inter_agent_delay: Duration::from_millis(100),
        }
    }

    /// A runtime can be created and then shut down cleanly.
    #[tokio::test]
    async fn test_runtime_creation_and_shutdown() {
        PythonRuntime::new(test_config())
            .expect("Failed to create runtime")
            .shutdown()
            .await
            .expect("Shutdown failed");
    }

    /// Every `RuntimeConfig` field survives a JSON round trip.
    #[test]
    fn runtime_config_serde_roundtrip() {
        let config = test_config();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: RuntimeConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.channel_capacity, 16);
        assert_eq!(parsed.operation_timeout, Duration::from_secs(10));
        assert_eq!(parsed.shutdown_timeout, Duration::from_secs(5));
        assert_eq!(parsed.chat_timeout, Duration::from_mins(1));
        assert_eq!(parsed.inter_agent_delay, Duration::from_millis(100));
    }

    /// The default operation timeout is the chat timeout plus two minutes.
    #[test]
    fn default_operation_timeout_is_chat_plus_margin() {
        let config = RuntimeConfig::default();
        let expected = config.chat_timeout + Duration::from_mins(2);
        assert_eq!(
            config.operation_timeout, expected,
            "operation_timeout should be chat_timeout + 2min safety margin"
        );
    }

    /// An unrelated Python exception that merely shares the name
    /// `StopCandidateException` must not be classified as `Error::Safety`.
    #[test]
    fn safety_error_structural() {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let globals = pyo3::types::PyDict::new_bound(py);
            py.run_bound(
                r#"
class StopCandidateException(Exception):
 pass
err = StopCandidateException("dummy")
"#,
                Some(&globals),
                None,
            )
            .unwrap();

            let err_obj = globals.get_item("err").unwrap().unwrap();
            let err = PyErr::from_value_bound(err_obj);

            let mapped = crate::error::classify_py_error(py, &err);

            assert!(
                !matches!(mapped, crate::error::Error::Safety),
                "Failed: matched Error::Safety based purely on the string name StopCandidateException!"
            );
        });
    }

    /// An unrelated Python exception that merely shares the name
    /// `MaxTokensException` must not be classified as `Error::MaxTokens`.
    #[test]
    fn maxtokens_error_structural() {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let globals = pyo3::types::PyDict::new_bound(py);
            py.run_bound(
                r#"
class MaxTokensException(Exception):
 pass
err = MaxTokensException("dummy")
"#,
                Some(&globals),
                None,
            )
            .unwrap();

            let err_obj = globals.get_item("err").unwrap().unwrap();
            let err = PyErr::from_value_bound(err_obj);

            let mapped = crate::error::classify_py_error(py, &err);

            assert!(
                !matches!(mapped, crate::error::Error::MaxTokens),
                "Failed: matched Error::MaxTokens based purely on the string name MaxTokensException!"
            );
        });
    }

    /// `AskUserHandler` whose answer can be flipped at runtime by the test.
    struct MockAskUserHandler {
        should_allow: std::sync::atomic::AtomicBool,
    }

    impl crate::policies::AskUserHandler for MockAskUserHandler {
        fn confirm(&self, _tool_name: &str, _tool_args: &serde_json::Value) -> bool {
            self.should_allow.load(std::sync::atomic::Ordering::SeqCst)
        }
    }

    /// With an `AskUser` rule on a tool, `check_tool_execution_allowed`
    /// follows the handler's answer: allowed on `true`, blocked on `false`.
    #[test]
    fn test_ask_user_policy_custom_tool_gating() {
        // Fixed ID used as this test's key in the global bridge-state map.
        let agent_id: u64 = 999;

        let mut policies = crate::policies::PolicySet::new();
        policies
            .push(crate::policies::PolicyRule::AskUser {
                tool: "dangerous_tool".to_owned(),
                handler_id: "confirm_handler".to_owned(),
            })
            .unwrap();

        let handler = Arc::new(MockAskUserHandler {
            should_allow: std::sync::atomic::AtomicBool::new(true),
        });

        let mut registry = crate::tools::ToolRegistry::new();

        #[crate::llm_tool]
        fn dangerous_tool() -> Result<String, String> {
            Ok("Executed dangerous action!".to_owned())
        }
        registry.register(DangerousTool);

        bridge_state().write().unwrap().insert(
            agent_id,
            AgentBridgeState {
                registry: Some(Arc::new(registry)),
                hook_runner: None,
                policies,
                policy_handler: Some(
                    Arc::clone(&handler) as Arc<dyn crate::policies::AskUserHandler>
                ),
                tool_state: Arc::new(std::sync::RwLock::new(HashMap::new())),
            },
        );

        handler
            .should_allow
            .store(true, std::sync::atomic::Ordering::SeqCst);
        let res = check_tool_execution_allowed(agent_id, "dangerous_tool", "{}");
        assert!(res.is_ok(), "Check should succeed");
        assert!(
            res.unwrap(),
            "Should allow tool execution when handler returns true"
        );

        handler
            .should_allow
            .store(false, std::sync::atomic::Ordering::SeqCst);
        let res = check_tool_execution_allowed(agent_id, "dangerous_tool", "{}");
        assert!(res.is_ok(), "Check should succeed");
        assert!(
            !res.unwrap(),
            "Should block tool execution when handler returns false"
        );

        // Remove the entry so the global map is left clean for other tests.
        bridge_state().write().unwrap().remove(&agent_id);
    }
}