1use std::{sync::Arc, time::Duration};
40
41use pyo3::prelude::*;
42use tokio::sync::{mpsc, oneshot};
43
44use crate::{error::Error, quota::QuotaState};
45
46pub(crate) mod bridge_state;
47pub(crate) mod command_loop;
48pub(crate) mod ffi_dispatch;
49mod handlers;
50pub(crate) mod py_scripts;
51pub(crate) mod streaming;
52pub(crate) mod venv;
53
54pub(crate) use bridge_state::{AgentBridgeState, AgentId, bridge_state};
56pub(crate) use ffi_dispatch::{
57 CREATE_AGENT_HOOK_GUARD, INITIALIZING_HOOK_RUNNER, dispatch_rust_hook,
58 dispatch_rust_policy_confirm, dispatch_rust_tool,
59};
60
/// Derives the per-operation timeout from the chat timeout.
///
/// The result is `chat_timeout` plus a fixed two-minute safety margin, so that a
/// command round trip outlives the chat it wraps. The addition saturates at
/// [`Duration::MAX`] instead of panicking. `chat_timeout` can come from the
/// `AGI_CHAT_TIMEOUT_SECS` environment variable, so arbitrarily large values must
/// not abort the process.
#[must_use]
pub const fn default_operation_timeout(chat_timeout: Duration) -> Duration {
    // Two minutes of head-room on top of the chat timeout.
    const MARGIN: Duration = Duration::from_secs(2 * 60);
    chat_timeout.saturating_add(MARGIN)
}
/// Chat timeout, in seconds, used when `AGI_CHAT_TIMEOUT_SECS` is unset or invalid.
pub const DEFAULT_CHAT_TIMEOUT_SECS: u64 = 2 * 60;

/// Default value of `RuntimeConfig::inter_agent_delay`.
pub const DEFAULT_INTER_AGENT_DELAY: Duration = Duration::from_millis(500);

/// Default capacity of the bounded command channel that feeds the runtime thread.
const DEFAULT_CHANNEL_CAPACITY: usize = 64;

/// Default time `PythonRuntime::shutdown` allows for the runtime thread to exit.
const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
86
87#[must_use]
90pub fn default_chat_timeout() -> Duration {
91 let secs = std::env::var("AGI_CHAT_TIMEOUT_SECS").map_or(DEFAULT_CHAT_TIMEOUT_SECS, |val| {
92 val.parse::<u64>().unwrap_or_else(|e| {
93 tracing::warn!(
94 value = %val,
95 error = %e,
96 "Invalid AGI_CHAT_TIMEOUT_SECS, using default {DEFAULT_CHAT_TIMEOUT_SECS}s"
97 );
98 DEFAULT_CHAT_TIMEOUT_SECS
99 })
100 });
101 Duration::from_secs(secs)
102}
103
104pub(crate) enum PyCommand {
109 CreateAgent {
111 config_json: String,
112 reply: oneshot::Sender<Result<AgentId, Error>>,
113 },
114 Chat {
116 agent_id: AgentId,
117 prompt: String,
118 reply: oneshot::Sender<Result<crate::streaming::ChatResponseHandle, Error>>,
119 },
120 ShutdownAgent {
122 agent_id: AgentId,
123 reply: oneshot::Sender<Result<(), Error>>,
124 },
125 Cancel {
127 agent_id: AgentId,
128 reply: oneshot::Sender<Result<(), Error>>,
129 },
130 WaitForIdle {
132 agent_id: AgentId,
133 reply: oneshot::Sender<Result<(), Error>>,
134 },
135 Send {
137 agent_id: AgentId,
138 prompt: String,
139 reply: oneshot::Sender<Result<(), Error>>,
140 },
141 SignalIdle {
143 agent_id: AgentId,
144 reply: oneshot::Sender<Result<(), Error>>,
145 },
146 WaitForWakeup {
148 agent_id: AgentId,
149 timeout_secs: f64,
150 reply: oneshot::Sender<Result<bool, Error>>,
151 },
152 Shutdown,
154 GetHistory {
156 agent_id: AgentId,
157 reply: oneshot::Sender<Result<Vec<crate::types::ConversationMessage>, Error>>,
158 },
159 GetTurnCount {
161 agent_id: AgentId,
162 reply: oneshot::Sender<Result<u32, Error>>,
163 },
164 GetTotalUsage {
166 agent_id: AgentId,
167 reply: oneshot::Sender<Result<crate::types::UsageMetadata, Error>>,
168 },
169 GetLastTurnUsage {
171 agent_id: AgentId,
172 reply: oneshot::Sender<Result<crate::types::UsageMetadata, Error>>,
173 },
174 ClearHistory {
176 agent_id: AgentId,
177 reply: oneshot::Sender<Result<(), Error>>,
178 },
179 GetCompactionIndices {
181 agent_id: AgentId,
182 reply: oneshot::Sender<Result<Vec<u32>, Error>>,
183 },
184 GetLastResponse {
186 agent_id: AgentId,
187 reply: oneshot::Sender<Result<Option<String>, Error>>,
188 },
189 Delete {
194 agent_id: AgentId,
195 reply: oneshot::Sender<Result<(), Error>>,
196 },
197 Disconnect {
201 agent_id: AgentId,
202 reply: oneshot::Sender<Result<(), Error>>,
203 },
204 IsIdle {
208 agent_id: AgentId,
209 reply: oneshot::Sender<Result<bool, Error>>,
210 },
211}
212
213#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
215#[serde(default)]
216pub struct RuntimeConfig {
217 pub channel_capacity: usize,
219 pub operation_timeout: Duration,
221 pub shutdown_timeout: Duration,
223 pub chat_timeout: Duration,
227 pub inter_agent_delay: Duration,
229}
230
231impl Default for RuntimeConfig {
232 fn default() -> Self {
233 let chat_timeout = default_chat_timeout();
234 Self {
235 channel_capacity: DEFAULT_CHANNEL_CAPACITY,
236 operation_timeout: default_operation_timeout(chat_timeout),
237 shutdown_timeout: DEFAULT_SHUTDOWN_TIMEOUT,
238 chat_timeout,
239 inter_agent_delay: DEFAULT_INTER_AGENT_DELAY,
240 }
241 }
242}
243
244pub struct PythonRuntime {
249 cmd_tx: mpsc::Sender<PyCommand>,
250 thread: Option<std::thread::JoinHandle<()>>,
251 config: RuntimeConfig,
252 quota_registry: crate::quota::QuotaRegistry,
255 quota_state: Arc<QuotaState>,
257}
258
259impl std::fmt::Debug for PythonRuntime {
260 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261 f.debug_struct("PythonRuntime")
262 .field("config", &self.config)
263 .field(
264 "thread_running",
265 &self.thread.as_ref().is_some_and(|t| !t.is_finished()),
266 )
267 .finish_non_exhaustive()
268 }
269}
270
271impl PythonRuntime {
272 pub fn new(config: RuntimeConfig) -> Result<Self, Error> {
282 let (cmd_tx, cmd_rx) = mpsc::channel(config.channel_capacity);
283
284 let thread_config = config.clone();
285 let thread = std::thread::Builder::new()
286 .name("agy-bridge-python-runtime".into())
287 .spawn(move || {
288 python_thread_main(cmd_rx, &thread_config);
289 })
290 .map_err(|e| Error::BackendError {
291 message: format!("Failed to spawn Python runtime thread: {e}"),
292 })?;
293
294 let quota_registry = crate::quota::QuotaRegistry::new();
295 let quota_state = quota_registry.state_for_key("");
296 Ok(Self {
297 cmd_tx,
298 thread: Some(thread),
299 config,
300 quota_registry,
301 quota_state,
302 })
303 }
304
305 async fn send_command<T>(
315 &self,
316 operation: &str,
317 is_llm_op: bool,
318 build_cmd: impl FnOnce(oneshot::Sender<Result<T, Error>>) -> PyCommand,
319 ) -> Result<T, Error> {
320 let (reply_tx, reply_rx) = oneshot::channel();
321 let cmd = build_cmd(reply_tx);
322
323 self.cmd_tx
324 .send(cmd)
325 .await
326 .map_err(|e| Error::ChannelClosed {
327 message: format!("Python runtime thread has exited (sending {operation}): {e}"),
328 })?;
329
330 let result = crate::error::with_timeout(self.config.operation_timeout, operation, async {
331 reply_rx.await.map_err(|e| Error::ChannelClosed {
332 message: format!("Reply channel dropped for {operation}: {e}"),
333 })?
334 })
335 .await?;
336
337 if is_llm_op {
340 self.quota_state.record_success();
341 }
342
343 Ok(result)
344 }
345
346 pub async fn shutdown(mut self) -> Result<(), Error> {
354 if let Err(e) = self.cmd_tx.send(PyCommand::Shutdown).await {
358 tracing::warn!("Shutdown command send failed (thread may already be exiting): {e}");
359 }
360
361 let Some(thread) = self.thread.take() else {
364 tracing::warn!("PythonRuntime::shutdown() called but thread handle already taken");
365 return Ok(());
366 };
367
368 let shutdown_timeout = self.config.shutdown_timeout;
369 let join_result = tokio::time::timeout(
370 shutdown_timeout,
371 tokio::task::spawn_blocking(move || thread.join()),
372 )
373 .await;
374
375 match join_result {
376 Ok(Ok(Ok(()))) => {
377 tracing::info!("Python runtime thread joined successfully");
378 Ok(())
379 }
380 Ok(Ok(Err(panic_payload))) => {
381 let panic_msg = panic_payload.downcast_ref::<&str>().map_or_else(
382 || {
383 panic_payload
384 .downcast_ref::<String>()
385 .map_or_else(|| format!("{panic_payload:?}"), Clone::clone)
386 },
387 |s| (*s).to_string(),
388 );
389 tracing::error!(
390 panic_message = %panic_msg,
391 "Python runtime thread panicked during shutdown"
392 );
393 Err(Error::BackendError {
394 message: format!("Python runtime thread panicked during shutdown: {panic_msg}"),
395 })
396 }
397 Ok(Err(join_err)) => {
398 tracing::error!("spawn_blocking join error: {join_err}");
399 Err(Error::BackendError {
400 message: format!("Failed to join Python thread: {join_err}"),
401 })
402 }
403 Err(_elapsed) => {
404 tracing::error!(
405 timeout_secs = shutdown_timeout.as_secs(),
406 "Python runtime thread did not exit within shutdown timeout"
407 );
408 Err(Error::Timeout {
409 duration: shutdown_timeout,
410 operation: "PythonRuntime::shutdown (thread join)".to_string(),
411 })
412 }
413 }
414 }
415
416 #[must_use]
418 pub const fn quota_state(&self) -> &Arc<QuotaState> {
419 &self.quota_state
420 }
421}
422
423impl Drop for PythonRuntime {
424 fn drop(&mut self) {
425 if self.thread.is_some() {
426 tracing::warn!(
427 "PythonRuntime dropped without calling shutdown() — \
428 Python thread may still be running"
429 );
430 }
431 }
432}
433
434fn python_thread_main(cmd_rx: mpsc::Receiver<PyCommand>, config: &RuntimeConfig) {
436 pyo3::prepare_freethreaded_python();
437
438 Python::with_gil(|py| {
443 if let Err(e) = venv::configure_python_sys_path(py) {
444 tracing::warn!("Failed to configure Python sys.path in runtime thread: {e}");
445 }
446 });
447
448 if let Err(e) = run_live_thread(cmd_rx, config) {
449 tracing::error!(error = %e, "Python runtime thread failed");
450 }
451
452 tracing::info!("Python runtime thread exiting");
453}
454
455fn run_live_thread(cmd_rx: mpsc::Receiver<PyCommand>, config: &RuntimeConfig) -> Result<(), Error> {
458 Python::with_gil(|py| {
459 let asyncio = py
460 .import_bound("asyncio")
461 .map_err(|e| Error::BackendError {
462 message: format!("Failed to import asyncio: {e}"),
463 })?;
464 let event_loop =
465 asyncio
466 .call_method0("new_event_loop")
467 .map_err(|e| Error::BackendError {
468 message: format!("Failed to create new asyncio event loop: {e}"),
469 })?;
470 asyncio
471 .call_method1("set_event_loop", (&event_loop,))
472 .map_err(|e| Error::BackendError {
473 message: format!("Failed to set asyncio event loop: {e}"),
474 })?;
475
476 let sys = py.import_bound("sys").map_err(|e| Error::BackendError {
478 message: format!("Failed to import sys: {e}"),
479 })?;
480 let sys_modules = sys.getattr("modules").map_err(|e| Error::BackendError {
481 message: format!("Failed to get sys.modules: {e}"),
482 })?;
483 let globals_mod = if sys_modules
484 .contains(command_loop::AGY_BRIDGE_GLOBALS_MODULE)
485 .map_err(|e| Error::BackendError {
486 message: format!("Failed to check sys.modules: {e}"),
487 })? {
488 sys_modules
489 .get_item(command_loop::AGY_BRIDGE_GLOBALS_MODULE)
490 .map_err(|e| Error::BackendError {
491 message: format!("Failed to get _agy_bridge_globals: {e}"),
492 })?
493 } else {
494 let types = py.import_bound("types").map_err(|e| Error::BackendError {
495 message: format!("Failed to import types: {e}"),
496 })?;
497 let module = types
498 .getattr("ModuleType")
499 .map_err(|e| Error::BackendError {
500 message: format!("Failed to get ModuleType: {e}"),
501 })?
502 .call1((command_loop::AGY_BRIDGE_GLOBALS_MODULE,))
503 .map_err(|e| Error::BackendError {
504 message: format!("Failed to create ModuleType: {e}"),
505 })?;
506 sys_modules
507 .set_item(command_loop::AGY_BRIDGE_GLOBALS_MODULE, &module)
508 .map_err(|e| Error::BackendError {
509 message: format!("Failed to register _agy_bridge_globals: {e}"),
510 })?;
511 module
512 };
513 globals_mod
514 .setattr("EVENT_LOOP", &event_loop)
515 .map_err(|e| Error::BackendError {
516 message: format!("Failed to set EVENT_LOOP in globals: {e}"),
517 })?;
518
519 tracing::info!("Python asyncio event loop created on runtime thread");
520
521 let chat_timeout = config.chat_timeout;
522 let inter_agent_delay = config.inter_agent_delay;
523 let run_fut =
524 pyo3_async_runtimes::tokio::run_until_complete(event_loop.clone(), async move {
525 command_loop::run_async_command_loop(cmd_rx, chat_timeout, inter_agent_delay).await
526 });
527
528 if let Err(e) = run_fut {
529 if let Err(close_err) = event_loop.call_method0("close") {
531 tracing::warn!("Failed to close asyncio event loop: {close_err}");
532 }
533 return Err(Error::BackendError {
534 message: format!("Python runtime command loop failed: {e}"),
535 });
536 }
537
538 if let Err(e) = event_loop.call_method0("close") {
539 tracing::warn!("Failed to close asyncio event loop: {e}");
540 }
541
542 Ok(())
543 })
544}
545
546impl crate::agent::Runtime for PythonRuntime {
547 async fn create_agent(
548 &self,
549 config: crate::config::AgentConfig,
550 ) -> Result<crate::agent::AgentId, Error> {
551 let mut all_tools = config.custom_tool_names();
553 if let Some(ref caps) = config.capabilities {
554 if let Some(ref builtins) = caps.enabled_tools {
555 all_tools.extend(builtins.iter().map(|b| b.as_sdk_name().to_string()));
556 } else if caps.disabled_tools.is_none() {
557 all_tools.extend(
559 crate::config::capabilities::BuiltinTools::all_tools()
560 .iter()
561 .map(|b| b.as_sdk_name().to_string()),
562 );
563 }
564 } else {
565 all_tools.extend(
566 crate::config::capabilities::BuiltinTools::all_tools()
567 .iter()
568 .map(|b| b.as_sdk_name().to_string()),
569 );
570 }
571 tracing::info!(
572 "Agent starting with {} available tools: {:?}",
573 all_tools.len(),
574 all_tools
575 );
576
577 let config_json = serde_json::to_string(&config).map_err(|e| Error::BackendError {
578 message: format!("Failed to serialize AgentConfig: {e}"),
579 })?;
580
581 let raw_id = self
582 .send_command("create_agent", false, |reply| PyCommand::CreateAgent {
583 config_json,
584 reply,
585 })
586 .await?;
587
588 Ok(raw_id.0)
589 }
590
591 async fn chat(
592 &self,
593 agent_id: crate::agent::AgentId,
594 content: &crate::content::Content,
595 ) -> Result<crate::streaming::ChatResponseHandle, Error> {
596 let prompt = match content {
597 crate::content::Content::Text { text } => text.clone(),
598 other => crate::content::content_to_json(other)?,
599 };
600 self.send_command("chat", true, |reply| PyCommand::Chat {
601 agent_id: AgentId(agent_id),
602 prompt,
603 reply,
604 })
605 .await
606 }
607
608 async fn shutdown_agent(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
609 self.send_command("shutdown_agent", false, |reply| PyCommand::ShutdownAgent {
610 agent_id: AgentId(agent_id),
611 reply,
612 })
613 .await
614 }
615
616 fn try_shutdown_agent(&self, agent_id: crate::agent::AgentId) {
617 let (reply, _) = oneshot::channel();
621 if let Err(e) = self.cmd_tx.try_send(PyCommand::ShutdownAgent {
622 agent_id: AgentId(agent_id),
623 reply,
624 }) {
625 tracing::debug!(
626 agent_id = agent_id,
627 error = %e,
628 "try_shutdown_agent: channel send failed (runtime may already be gone)"
629 );
630 }
631 }
632
633 async fn cancel(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
634 self.send_command("cancel", false, |reply| PyCommand::Cancel {
635 agent_id: AgentId(agent_id),
636 reply,
637 })
638 .await
639 }
640
641 async fn wait_for_idle(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
642 self.send_command("wait_for_idle", false, |reply| PyCommand::WaitForIdle {
643 agent_id: AgentId(agent_id),
644 reply,
645 })
646 .await
647 }
648
649 async fn send(
650 &self,
651 agent_id: crate::agent::AgentId,
652 content: &crate::content::Content,
653 ) -> Result<(), Error> {
654 let prompt = match content {
655 crate::content::Content::Text { text } => text.clone(),
656 other => crate::content::content_to_json(other)?,
657 };
658 self.send_command("send", false, |reply| PyCommand::Send {
659 agent_id: AgentId(agent_id),
660 prompt,
661 reply,
662 })
663 .await
664 }
665
666 async fn signal_idle(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
667 self.send_command("signal_idle", false, |reply| PyCommand::SignalIdle {
668 agent_id: AgentId(agent_id),
669 reply,
670 })
671 .await
672 }
673
674 async fn wait_for_wakeup(
675 &self,
676 agent_id: crate::agent::AgentId,
677 timeout: std::time::Duration,
678 ) -> Result<bool, Error> {
679 self.send_command("wait_for_wakeup", false, |reply| PyCommand::WaitForWakeup {
680 agent_id: AgentId(agent_id),
681 timeout_secs: timeout.as_secs_f64(),
682 reply,
683 })
684 .await
685 }
686
687 async fn wait_for_quota(&self) {
688 self.quota_state.wait_for_quota().await;
689 }
690
691 async fn record_quota_hit(&self, retry_after: std::time::Duration) {
692 self.quota_state.record_quota_hit(retry_after);
693 }
694
695 fn quota_registry(&self) -> &crate::quota::QuotaRegistry {
696 &self.quota_registry
697 }
698
699 async fn history(
700 &self,
701 agent_id: crate::agent::AgentId,
702 ) -> Result<Vec<crate::types::ConversationMessage>, Error> {
703 self.send_command("get_history", false, |reply| PyCommand::GetHistory {
704 agent_id: AgentId(agent_id),
705 reply,
706 })
707 .await
708 }
709
710 async fn turn_count(&self, agent_id: crate::agent::AgentId) -> Result<u32, Error> {
711 self.send_command("get_turn_count", false, |reply| PyCommand::GetTurnCount {
712 agent_id: AgentId(agent_id),
713 reply,
714 })
715 .await
716 }
717
718 async fn total_usage(
719 &self,
720 agent_id: crate::agent::AgentId,
721 ) -> Result<crate::types::UsageMetadata, Error> {
722 self.send_command("get_total_usage", false, |reply| PyCommand::GetTotalUsage {
723 agent_id: AgentId(agent_id),
724 reply,
725 })
726 .await
727 }
728
729 async fn last_turn_usage(
730 &self,
731 agent_id: crate::agent::AgentId,
732 ) -> Result<crate::types::UsageMetadata, Error> {
733 self.send_command("get_last_turn_usage", false, |reply| {
734 PyCommand::GetLastTurnUsage {
735 agent_id: AgentId(agent_id),
736 reply,
737 }
738 })
739 .await
740 }
741
742 async fn clear_history(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
743 self.send_command("clear_history", false, |reply| PyCommand::ClearHistory {
744 agent_id: AgentId(agent_id),
745 reply,
746 })
747 .await
748 }
749
750 async fn compaction_indices(&self, agent_id: crate::agent::AgentId) -> Result<Vec<u32>, Error> {
751 self.send_command("compaction_indices", false, |reply| {
752 PyCommand::GetCompactionIndices {
753 agent_id: AgentId(agent_id),
754 reply,
755 }
756 })
757 .await
758 }
759
760 async fn last_response(
761 &self,
762 agent_id: crate::agent::AgentId,
763 ) -> Result<Option<String>, Error> {
764 self.send_command("last_response", false, |reply| PyCommand::GetLastResponse {
765 agent_id: AgentId(agent_id),
766 reply,
767 })
768 .await
769 }
770
771 async fn delete(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
772 self.send_command("delete", false, |reply| PyCommand::Delete {
773 agent_id: AgentId(agent_id),
774 reply,
775 })
776 .await
777 }
778
779 async fn disconnect(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
780 self.send_command("disconnect", false, |reply| PyCommand::Disconnect {
781 agent_id: AgentId(agent_id),
782 reply,
783 })
784 .await
785 }
786
787 async fn is_idle(&self, agent_id: crate::agent::AgentId) -> Result<bool, Error> {
788 self.send_command("is_idle", false, |reply| PyCommand::IsIdle {
789 agent_id: AgentId(agent_id),
790 reply,
791 })
792 .await
793 }
794}
795
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::{ffi_dispatch::check_tool_execution_allowed, *};

    /// Small, fast configuration used by the runtime tests.
    fn test_config() -> RuntimeConfig {
        RuntimeConfig {
            channel_capacity: 16,
            operation_timeout: Duration::from_secs(10),
            shutdown_timeout: Duration::from_secs(5),
            chat_timeout: Duration::from_mins(1),
            inter_agent_delay: Duration::from_millis(100),
        }
    }

    /// Removes an agent's entry from the global bridge state when dropped.
    /// A failing assertion therefore cannot leave stale state behind for other
    /// tests that share the registry.
    struct BridgeStateEntryGuard(u64);

    impl Drop for BridgeStateEntryGuard {
        fn drop(&mut self) {
            // May run while unwinding, so recover the map even if the lock is poisoned.
            bridge_state()
                .write()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .remove(&self.0);
        }
    }

    /// Defines a plain Python exception class named `class_name` that only
    /// shares a name with SDK exceptions. Returns how `classify_py_error` maps
    /// an instance of it.
    fn classify_lookalike_exception(class_name: &str) -> crate::error::Error {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let globals = pyo3::types::PyDict::new_bound(py);
            let code = format!(
                "class {class_name}(Exception):\n    pass\nerr = {class_name}(\"dummy\")\n"
            );
            py.run_bound(&code, Some(&globals), None).unwrap();

            let err_obj = globals.get_item("err").unwrap().unwrap();
            let err = PyErr::from_value_bound(err_obj);
            crate::error::classify_py_error(py, &err)
        })
    }

    #[tokio::test]
    async fn test_runtime_creation_and_shutdown() {
        PythonRuntime::new(test_config())
            .expect("Failed to create runtime")
            .shutdown()
            .await
            .expect("Shutdown failed");
    }

    #[test]
    fn runtime_config_serde_roundtrip() {
        let config = test_config();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: RuntimeConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.channel_capacity, 16);
        assert_eq!(parsed.operation_timeout, Duration::from_secs(10));
        assert_eq!(parsed.shutdown_timeout, Duration::from_secs(5));
        assert_eq!(parsed.chat_timeout, Duration::from_mins(1));
        assert_eq!(parsed.inter_agent_delay, Duration::from_millis(100));
    }

    #[test]
    fn default_operation_timeout_is_chat_plus_margin() {
        let config = RuntimeConfig::default();
        let expected = config.chat_timeout + Duration::from_mins(2);
        assert_eq!(
            config.operation_timeout, expected,
            "operation_timeout should be chat_timeout + 2min safety margin"
        );
    }

    #[test]
    fn safety_error_structural() {
        let mapped = classify_lookalike_exception("StopCandidateException");
        assert!(
            !matches!(mapped, crate::error::Error::Safety),
            "Failed: matched Error::Safety based purely on the string name StopCandidateException!"
        );
    }

    #[test]
    fn maxtokens_error_structural() {
        let mapped = classify_lookalike_exception("MaxTokensException");
        assert!(
            !matches!(mapped, crate::error::Error::MaxTokens),
            "Failed: matched Error::MaxTokens based purely on the string name MaxTokensException!"
        );
    }

    /// Ask-user handler whose answer can be flipped between checks.
    struct MockAskUserHandler {
        should_allow: std::sync::atomic::AtomicBool,
    }

    impl crate::policies::AskUserHandler for MockAskUserHandler {
        fn confirm(&self, _tool_name: &str, _tool_args: &serde_json::Value) -> bool {
            self.should_allow.load(std::sync::atomic::Ordering::SeqCst)
        }
    }

    #[test]
    fn test_ask_user_policy_custom_tool_gating() {
        let agent_id: u64 = 999;

        let mut policies = crate::policies::PolicySet::new();
        policies
            .push(crate::policies::PolicyRule::AskUser {
                tool: "dangerous_tool".to_owned(),
                handler_id: "confirm_handler".to_owned(),
            })
            .unwrap();

        let handler = Arc::new(MockAskUserHandler {
            should_allow: std::sync::atomic::AtomicBool::new(true),
        });

        let mut registry = crate::tools::ToolRegistry::new();

        #[crate::llm_tool]
        fn dangerous_tool() -> Result<String, String> {
            Ok("Executed dangerous action!".to_owned())
        }
        registry.register(DangerousTool);

        bridge_state().write().unwrap().insert(
            agent_id,
            AgentBridgeState {
                registry: Some(Arc::new(registry)),
                hook_runner: None,
                policies,
                policy_handler: Some(
                    Arc::clone(&handler) as Arc<dyn crate::policies::AskUserHandler>
                ),
                tool_state: Arc::new(std::sync::RwLock::new(HashMap::new())),
            },
        );
        // Removes the entry at scope exit, even if an assertion below panics.
        let _cleanup = BridgeStateEntryGuard(agent_id);

        handler
            .should_allow
            .store(true, std::sync::atomic::Ordering::SeqCst);
        let allowed = check_tool_execution_allowed(agent_id, "dangerous_tool", "{}")
            .expect("Check should succeed");
        assert!(
            allowed,
            "Should allow tool execution when handler returns true"
        );

        handler
            .should_allow
            .store(false, std::sync::atomic::Ordering::SeqCst);
        let allowed = check_tool_execution_allowed(agent_id, "dangerous_tool", "{}")
            .expect("Check should succeed");
        assert!(
            !allowed,
            "Should block tool execution when handler returns false"
        );
    }
}