1use std::{sync::Arc, time::Duration};
40
41use pyo3::prelude::*;
42use tokio::sync::{mpsc, oneshot};
43
44use crate::{error::Error, quota::QuotaState};
45
46pub(crate) mod bridge_state;
47pub(crate) mod command_loop;
48pub(crate) mod ffi_dispatch;
49mod handlers;
50pub(crate) mod py_scripts;
51pub(crate) mod streaming;
52pub(crate) mod venv;
53
54pub(crate) use bridge_state::{AgentBridgeState, AgentId, bridge_state};
56pub(crate) use ffi_dispatch::{
57 CREATE_AGENT_HOOK_GUARD, INITIALIZING_HOOK_RUNNER, PENDING_CONVERSATION_IDS,
58 dispatch_rust_hook, dispatch_rust_policy_confirm, dispatch_rust_tool,
59};
60
/// Derives the default per-operation timeout from a chat timeout.
///
/// The result is `chat_timeout` plus a fixed two-minute safety margin, so the
/// operation timeout always exceeds the chat timeout it was derived from.
///
/// The addition saturates at [`Duration::MAX`] instead of panicking: the chat
/// timeout may come from the `AGI_CHAT_TIMEOUT_SECS` environment variable (see
/// `default_chat_timeout`) and can therefore be arbitrarily large.
#[must_use]
pub fn default_operation_timeout(chat_timeout: Duration) -> Duration {
    // Two-minute margin on top of the chat timeout.
    chat_timeout.saturating_add(Duration::from_secs(2 * 60))
}
/// Chat timeout, in whole seconds, used when `AGI_CHAT_TIMEOUT_SECS` is unset or invalid.
pub const DEFAULT_CHAT_TIMEOUT_SECS: u64 = 2 * 60;

/// Default value of `RuntimeConfig::inter_agent_delay`.
pub const DEFAULT_INTER_AGENT_DELAY: Duration = Duration::from_millis(500);

/// Default number of commands the bounded channel into the Python thread can queue.
const DEFAULT_CHANNEL_CAPACITY: usize = 64;

/// Default upper bound on how long `PythonRuntime::shutdown` waits for the thread.
const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
86
87#[must_use]
90pub fn default_chat_timeout() -> Duration {
91 let secs = std::env::var("AGI_CHAT_TIMEOUT_SECS").map_or(DEFAULT_CHAT_TIMEOUT_SECS, |val| {
92 val.parse::<u64>().unwrap_or_else(|e| {
93 tracing::warn!(
94 value = %val,
95 error = %e,
96 "Invalid AGI_CHAT_TIMEOUT_SECS, using default {DEFAULT_CHAT_TIMEOUT_SECS}s"
97 );
98 DEFAULT_CHAT_TIMEOUT_SECS
99 })
100 });
101 Duration::from_secs(secs)
102}
103
104pub(crate) enum PyCommand {
109 CreateAgent {
111 config_json: String,
112 reply: oneshot::Sender<Result<AgentId, Error>>,
113 },
114 Chat {
116 agent_id: AgentId,
117 prompt: String,
118 reply: oneshot::Sender<Result<crate::streaming::ChatResponseHandle, Error>>,
119 },
120 ShutdownAgent {
122 agent_id: AgentId,
123 reply: oneshot::Sender<Result<(), Error>>,
124 },
125 Cancel {
127 agent_id: AgentId,
128 reply: oneshot::Sender<Result<(), Error>>,
129 },
130 WaitForIdle {
132 agent_id: AgentId,
133 reply: oneshot::Sender<Result<(), Error>>,
134 },
135 Send {
137 agent_id: AgentId,
138 prompt: String,
139 reply: oneshot::Sender<Result<(), Error>>,
140 },
141 SignalIdle {
143 agent_id: AgentId,
144 reply: oneshot::Sender<Result<(), Error>>,
145 },
146 WaitForWakeup {
148 agent_id: AgentId,
149 timeout_secs: f64,
150 reply: oneshot::Sender<Result<bool, Error>>,
151 },
152 Shutdown,
154 GetHistory {
156 agent_id: AgentId,
157 reply: oneshot::Sender<Result<Vec<crate::types::ConversationMessage>, Error>>,
158 },
159 GetTurnCount {
161 agent_id: AgentId,
162 reply: oneshot::Sender<Result<u32, Error>>,
163 },
164 GetTotalUsage {
166 agent_id: AgentId,
167 reply: oneshot::Sender<Result<crate::types::UsageMetadata, Error>>,
168 },
169 GetLastTurnUsage {
171 agent_id: AgentId,
172 reply: oneshot::Sender<Result<crate::types::UsageMetadata, Error>>,
173 },
174 ClearHistory {
176 agent_id: AgentId,
177 reply: oneshot::Sender<Result<(), Error>>,
178 },
179 GetCompactionIndices {
181 agent_id: AgentId,
182 reply: oneshot::Sender<Result<Vec<u32>, Error>>,
183 },
184 GetLastResponse {
186 agent_id: AgentId,
187 reply: oneshot::Sender<Result<Option<String>, Error>>,
188 },
189 Delete {
194 agent_id: AgentId,
195 reply: oneshot::Sender<Result<(), Error>>,
196 },
197 Disconnect {
201 agent_id: AgentId,
202 reply: oneshot::Sender<Result<(), Error>>,
203 },
204 IsIdle {
208 agent_id: AgentId,
209 reply: oneshot::Sender<Result<bool, Error>>,
210 },
211}
212
213#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
215#[serde(default)]
216pub struct RuntimeConfig {
217 pub channel_capacity: usize,
219 pub operation_timeout: Duration,
221 pub shutdown_timeout: Duration,
223 pub chat_timeout: Duration,
227 pub inter_agent_delay: Duration,
229}
230
231impl Default for RuntimeConfig {
232 fn default() -> Self {
233 let chat_timeout = default_chat_timeout();
234 Self {
235 channel_capacity: DEFAULT_CHANNEL_CAPACITY,
236 operation_timeout: default_operation_timeout(chat_timeout),
237 shutdown_timeout: DEFAULT_SHUTDOWN_TIMEOUT,
238 chat_timeout,
239 inter_agent_delay: DEFAULT_INTER_AGENT_DELAY,
240 }
241 }
242}
243
244pub struct PythonRuntime {
249 cmd_tx: mpsc::Sender<PyCommand>,
250 thread: Option<std::thread::JoinHandle<()>>,
251 config: RuntimeConfig,
252 quota_registry: crate::quota::QuotaRegistry,
255 quota_state: Arc<QuotaState>,
257}
258
259impl std::fmt::Debug for PythonRuntime {
260 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261 f.debug_struct("PythonRuntime")
262 .field("config", &self.config)
263 .field(
264 "thread_running",
265 &self.thread.as_ref().is_some_and(|t| !t.is_finished()),
266 )
267 .finish_non_exhaustive()
268 }
269}
270
271impl PythonRuntime {
272 pub fn new(config: RuntimeConfig) -> Result<Self, Error> {
282 let (cmd_tx, cmd_rx) = mpsc::channel(config.channel_capacity);
283
284 let thread_config = config.clone();
285 let thread = std::thread::Builder::new()
286 .name("agy-bridge-python-runtime".into())
287 .spawn(move || {
288 python_thread_main(cmd_rx, &thread_config);
289 })
290 .map_err(|e| Error::BackendError {
291 message: format!("Failed to spawn Python runtime thread: {e}"),
292 })?;
293
294 let quota_registry = crate::quota::QuotaRegistry::new();
295 let quota_state = quota_registry.state_for_key("");
296 Ok(Self {
297 cmd_tx,
298 thread: Some(thread),
299 config,
300 quota_registry,
301 quota_state,
302 })
303 }
304
305 async fn send_command<T>(
315 &self,
316 operation: &str,
317 is_llm_op: bool,
318 build_cmd: impl FnOnce(oneshot::Sender<Result<T, Error>>) -> PyCommand,
319 ) -> Result<T, Error> {
320 let (reply_tx, reply_rx) = oneshot::channel();
321 let cmd = build_cmd(reply_tx);
322
323 self.cmd_tx
324 .send(cmd)
325 .await
326 .map_err(|e| Error::ChannelClosed {
327 message: format!("Python runtime thread has exited (sending {operation}): {e}"),
328 })?;
329
330 let result = crate::error::with_timeout(self.config.operation_timeout, operation, async {
331 reply_rx.await.map_err(|e| Error::ChannelClosed {
332 message: format!("Reply channel dropped for {operation}: {e}"),
333 })?
334 })
335 .await?;
336
337 if is_llm_op {
340 self.quota_state.record_success();
341 }
342
343 Ok(result)
344 }
345
346 pub async fn shutdown(mut self) -> Result<(), Error> {
354 if let Err(e) = self.cmd_tx.send(PyCommand::Shutdown).await {
358 tracing::warn!("Shutdown command send failed (thread may already be exiting): {e}");
359 }
360
361 let Some(thread) = self.thread.take() else {
364 tracing::warn!("PythonRuntime::shutdown() called but thread handle already taken");
365 return Ok(());
366 };
367
368 let shutdown_timeout = self.config.shutdown_timeout;
369 let join_result = tokio::time::timeout(
370 shutdown_timeout,
371 tokio::task::spawn_blocking(move || thread.join()),
372 )
373 .await;
374
375 match join_result {
376 Ok(Ok(Ok(()))) => {
377 tracing::info!("Python runtime thread joined successfully");
378 Ok(())
379 }
380 Ok(Ok(Err(panic_payload))) => {
381 let panic_msg = panic_payload.downcast_ref::<&str>().map_or_else(
382 || {
383 panic_payload
384 .downcast_ref::<String>()
385 .map_or_else(|| format!("{panic_payload:?}"), Clone::clone)
386 },
387 |s| (*s).to_string(),
388 );
389 tracing::error!(
390 panic_message = %panic_msg,
391 "Python runtime thread panicked during shutdown"
392 );
393 Err(Error::BackendError {
394 message: format!("Python runtime thread panicked during shutdown: {panic_msg}"),
395 })
396 }
397 Ok(Err(join_err)) => {
398 tracing::error!("spawn_blocking join error: {join_err}");
399 Err(Error::BackendError {
400 message: format!("Failed to join Python thread: {join_err}"),
401 })
402 }
403 Err(_elapsed) => {
404 tracing::error!(
405 timeout_secs = shutdown_timeout.as_secs(),
406 "Python runtime thread did not exit within shutdown timeout"
407 );
408 Err(Error::Timeout {
409 duration: shutdown_timeout,
410 operation: "PythonRuntime::shutdown (thread join)".to_string(),
411 })
412 }
413 }
414 }
415
416 #[must_use]
418 pub const fn quota_state(&self) -> &Arc<QuotaState> {
419 &self.quota_state
420 }
421}
422
423impl Drop for PythonRuntime {
424 fn drop(&mut self) {
425 if self.thread.is_some() {
426 tracing::warn!(
427 "PythonRuntime dropped without calling shutdown() — \
428 Python thread may still be running"
429 );
430 }
431 }
432}
433
434fn python_thread_main(cmd_rx: mpsc::Receiver<PyCommand>, config: &RuntimeConfig) {
436 Python::initialize();
437
438 Python::attach(|py| {
443 if let Err(e) = venv::configure_python_sys_path(py) {
444 tracing::error!(
445 error = %e,
446 "Failed to configure Python sys.path in runtime thread — \
447 venv imports will likely fail"
448 );
449 }
450 });
451
452 if let Err(e) = run_live_thread(cmd_rx, config) {
453 tracing::error!(error = %e, "Python runtime thread failed");
454 }
455
456 tracing::info!("Python runtime thread exiting");
457}
458
459fn run_live_thread(cmd_rx: mpsc::Receiver<PyCommand>, config: &RuntimeConfig) -> Result<(), Error> {
462 Python::attach(|py| {
463 let asyncio = py.import("asyncio").map_err(|e| Error::BackendError {
464 message: format!("Failed to import asyncio: {e}"),
465 })?;
466 let event_loop =
467 asyncio
468 .call_method0("new_event_loop")
469 .map_err(|e| Error::BackendError {
470 message: format!("Failed to create new asyncio event loop: {e}"),
471 })?;
472 asyncio
473 .call_method1("set_event_loop", (&event_loop,))
474 .map_err(|e| Error::BackendError {
475 message: format!("Failed to set asyncio event loop: {e}"),
476 })?;
477
478 let sys = py.import("sys").map_err(|e| Error::BackendError {
480 message: format!("Failed to import sys: {e}"),
481 })?;
482 let sys_modules = sys.getattr("modules").map_err(|e| Error::BackendError {
483 message: format!("Failed to get sys.modules: {e}"),
484 })?;
485 let globals_mod = if sys_modules
486 .contains(command_loop::AGY_BRIDGE_GLOBALS_MODULE)
487 .map_err(|e| Error::BackendError {
488 message: format!("Failed to check sys.modules: {e}"),
489 })? {
490 sys_modules
491 .get_item(command_loop::AGY_BRIDGE_GLOBALS_MODULE)
492 .map_err(|e| Error::BackendError {
493 message: format!("Failed to get _agy_bridge_globals: {e}"),
494 })?
495 } else {
496 let types = py.import("types").map_err(|e| Error::BackendError {
497 message: format!("Failed to import types: {e}"),
498 })?;
499 let module = types
500 .getattr("ModuleType")
501 .map_err(|e| Error::BackendError {
502 message: format!("Failed to get ModuleType: {e}"),
503 })?
504 .call1((command_loop::AGY_BRIDGE_GLOBALS_MODULE,))
505 .map_err(|e| Error::BackendError {
506 message: format!("Failed to create ModuleType: {e}"),
507 })?;
508 sys_modules
509 .set_item(command_loop::AGY_BRIDGE_GLOBALS_MODULE, &module)
510 .map_err(|e| Error::BackendError {
511 message: format!("Failed to register _agy_bridge_globals: {e}"),
512 })?;
513 module
514 };
515 globals_mod
516 .setattr("EVENT_LOOP", &event_loop)
517 .map_err(|e| Error::BackendError {
518 message: format!("Failed to set EVENT_LOOP in globals: {e}"),
519 })?;
520
521 tracing::info!("Python asyncio event loop created on runtime thread");
522
523 let chat_timeout = config.chat_timeout;
524 let inter_agent_delay = config.inter_agent_delay;
525 let event_loop_obj = event_loop.clone().unbind();
526 let run_fut =
527 pyo3_async_runtimes::tokio::run_until_complete(event_loop.clone(), async move {
528 command_loop::run_async_command_loop(
529 event_loop_obj,
530 cmd_rx,
531 chat_timeout,
532 inter_agent_delay,
533 )
534 .await
535 });
536
537 if let Err(e) = run_fut {
538 if let Err(close_err) = event_loop.call_method0("close") {
540 tracing::warn!("Failed to close asyncio event loop: {close_err}");
541 }
542 return Err(Error::BackendError {
543 message: format!("Python runtime command loop failed: {e}"),
544 });
545 }
546
547 if let Err(e) = event_loop.call_method0("close") {
548 tracing::warn!("Failed to close asyncio event loop: {e}");
549 }
550
551 Ok(())
552 })
553}
554
555impl crate::agent::Runtime for PythonRuntime {
556 async fn create_agent(
557 &self,
558 config: crate::config::AgentConfig,
559 ) -> Result<crate::agent::AgentId, Error> {
560 let mut all_tools = config.custom_tool_names();
562 if let Some(ref caps) = config.capabilities {
563 if let Some(ref builtins) = caps.enabled_tools {
564 all_tools.extend(builtins.iter().map(|b| b.as_sdk_name().to_string()));
565 } else if caps.disabled_tools.is_none() {
566 all_tools.extend(
568 crate::config::capabilities::BuiltinTools::all_tools()
569 .iter()
570 .map(|b| b.as_sdk_name().to_string()),
571 );
572 }
573 } else {
574 all_tools.extend(
575 crate::config::capabilities::BuiltinTools::all_tools()
576 .iter()
577 .map(|b| b.as_sdk_name().to_string()),
578 );
579 }
580 tracing::info!(
581 "Agent starting with {} available tools: {:?}",
582 all_tools.len(),
583 all_tools
584 );
585
586 let config_json = serde_json::to_string(&config).map_err(|e| Error::BackendError {
587 message: format!("Failed to serialize AgentConfig: {e}"),
588 })?;
589
590 let raw_id = self
591 .send_command("create_agent", false, |reply| PyCommand::CreateAgent {
592 config_json,
593 reply,
594 })
595 .await?;
596
597 Ok(raw_id.0)
598 }
599
600 async fn chat(
601 &self,
602 agent_id: crate::agent::AgentId,
603 content: &crate::content::Content,
604 ) -> Result<crate::streaming::ChatResponseHandle, Error> {
605 let prompt = match content {
606 crate::content::Content::Text { text } => text.clone(),
607 other => crate::content::content_to_json(other)?,
608 };
609 self.send_command("chat", true, |reply| PyCommand::Chat {
610 agent_id: AgentId(agent_id),
611 prompt,
612 reply,
613 })
614 .await
615 }
616
617 async fn shutdown_agent(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
618 self.send_command("shutdown_agent", false, |reply| PyCommand::ShutdownAgent {
619 agent_id: AgentId(agent_id),
620 reply,
621 })
622 .await
623 }
624
625 fn try_shutdown_agent(&self, agent_id: crate::agent::AgentId) {
626 let (reply, _) = oneshot::channel();
630 if let Err(e) = self.cmd_tx.try_send(PyCommand::ShutdownAgent {
631 agent_id: AgentId(agent_id),
632 reply,
633 }) {
634 tracing::debug!(
635 agent_id = agent_id,
636 error = %e,
637 "try_shutdown_agent: channel send failed (runtime may already be gone)"
638 );
639 }
640 }
641
642 async fn cancel(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
643 self.send_command("cancel", false, |reply| PyCommand::Cancel {
644 agent_id: AgentId(agent_id),
645 reply,
646 })
647 .await
648 }
649
650 async fn wait_for_idle(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
651 self.send_command("wait_for_idle", false, |reply| PyCommand::WaitForIdle {
652 agent_id: AgentId(agent_id),
653 reply,
654 })
655 .await
656 }
657
658 async fn send(
659 &self,
660 agent_id: crate::agent::AgentId,
661 content: &crate::content::Content,
662 ) -> Result<(), Error> {
663 let prompt = match content {
664 crate::content::Content::Text { text } => text.clone(),
665 other => crate::content::content_to_json(other)?,
666 };
667 self.send_command("send", false, |reply| PyCommand::Send {
668 agent_id: AgentId(agent_id),
669 prompt,
670 reply,
671 })
672 .await
673 }
674
675 async fn signal_idle(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
676 self.send_command("signal_idle", false, |reply| PyCommand::SignalIdle {
677 agent_id: AgentId(agent_id),
678 reply,
679 })
680 .await
681 }
682
683 async fn wait_for_wakeup(
684 &self,
685 agent_id: crate::agent::AgentId,
686 timeout: std::time::Duration,
687 ) -> Result<bool, Error> {
688 self.send_command("wait_for_wakeup", false, |reply| PyCommand::WaitForWakeup {
689 agent_id: AgentId(agent_id),
690 timeout_secs: timeout.as_secs_f64(),
691 reply,
692 })
693 .await
694 }
695
696 async fn wait_for_quota(&self) {
697 self.quota_state.wait_for_quota().await;
698 }
699
700 async fn record_quota_hit(&self, retry_after: std::time::Duration) {
701 self.quota_state.record_quota_hit(retry_after);
702 }
703
704 fn quota_registry(&self) -> &crate::quota::QuotaRegistry {
705 &self.quota_registry
706 }
707
708 async fn history(
709 &self,
710 agent_id: crate::agent::AgentId,
711 ) -> Result<Vec<crate::types::ConversationMessage>, Error> {
712 self.send_command("get_history", false, |reply| PyCommand::GetHistory {
713 agent_id: AgentId(agent_id),
714 reply,
715 })
716 .await
717 }
718
719 async fn turn_count(&self, agent_id: crate::agent::AgentId) -> Result<u32, Error> {
720 self.send_command("get_turn_count", false, |reply| PyCommand::GetTurnCount {
721 agent_id: AgentId(agent_id),
722 reply,
723 })
724 .await
725 }
726
727 async fn total_usage(
728 &self,
729 agent_id: crate::agent::AgentId,
730 ) -> Result<crate::types::UsageMetadata, Error> {
731 self.send_command("get_total_usage", false, |reply| PyCommand::GetTotalUsage {
732 agent_id: AgentId(agent_id),
733 reply,
734 })
735 .await
736 }
737
738 async fn last_turn_usage(
739 &self,
740 agent_id: crate::agent::AgentId,
741 ) -> Result<crate::types::UsageMetadata, Error> {
742 self.send_command("get_last_turn_usage", false, |reply| {
743 PyCommand::GetLastTurnUsage {
744 agent_id: AgentId(agent_id),
745 reply,
746 }
747 })
748 .await
749 }
750
751 async fn clear_history(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
752 self.send_command("clear_history", false, |reply| PyCommand::ClearHistory {
753 agent_id: AgentId(agent_id),
754 reply,
755 })
756 .await
757 }
758
759 async fn compaction_indices(&self, agent_id: crate::agent::AgentId) -> Result<Vec<u32>, Error> {
760 self.send_command("compaction_indices", false, |reply| {
761 PyCommand::GetCompactionIndices {
762 agent_id: AgentId(agent_id),
763 reply,
764 }
765 })
766 .await
767 }
768
769 async fn last_response(
770 &self,
771 agent_id: crate::agent::AgentId,
772 ) -> Result<Option<String>, Error> {
773 self.send_command("last_response", false, |reply| PyCommand::GetLastResponse {
774 agent_id: AgentId(agent_id),
775 reply,
776 })
777 .await
778 }
779
780 async fn delete(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
781 self.send_command("delete", false, |reply| PyCommand::Delete {
782 agent_id: AgentId(agent_id),
783 reply,
784 })
785 .await
786 }
787
788 async fn disconnect(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
789 self.send_command("disconnect", false, |reply| PyCommand::Disconnect {
790 agent_id: AgentId(agent_id),
791 reply,
792 })
793 .await
794 }
795
796 async fn is_idle(&self, agent_id: crate::agent::AgentId) -> Result<bool, Error> {
797 self.send_command("is_idle", false, |reply| PyCommand::IsIdle {
798 agent_id: AgentId(agent_id),
799 reply,
800 })
801 .await
802 }
803}
804
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::{ffi_dispatch::check_tool_execution_allowed, *};

    /// Small, fast configuration shared by the runtime tests.
    fn test_config() -> RuntimeConfig {
        RuntimeConfig {
            channel_capacity: 16,
            operation_timeout: Duration::from_secs(10),
            shutdown_timeout: Duration::from_secs(5),
            chat_timeout: Duration::from_mins(1),
            inter_agent_delay: Duration::from_millis(100),
        }
    }

    /// Removes an agent's entry from the shared `bridge_state()` map when dropped.
    ///
    /// Tests that insert into the map hold one of these, so the entry is cleaned
    /// up even when an assertion panics instead of leaking into other tests.
    struct BridgeStateEntryGuard(u64);

    impl Drop for BridgeStateEntryGuard {
        fn drop(&mut self) {
            // Recover from poisoning: cleanup must still run if a test panicked
            // while holding the lock.
            bridge_state()
                .write()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .remove(&self.0);
        }
    }

    #[tokio::test]
    async fn test_runtime_creation_and_shutdown() {
        PythonRuntime::new(test_config())
            .expect("Failed to create runtime")
            .shutdown()
            .await
            .expect("Shutdown failed");
    }

    #[test]
    fn runtime_config_serde_roundtrip() {
        let config = test_config();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: RuntimeConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.channel_capacity, 16);
        assert_eq!(parsed.operation_timeout, Duration::from_secs(10));
        assert_eq!(parsed.shutdown_timeout, Duration::from_secs(5));
        assert_eq!(parsed.chat_timeout, Duration::from_mins(1));
        assert_eq!(parsed.inter_agent_delay, Duration::from_millis(100));
    }

    #[test]
    fn default_operation_timeout_is_chat_plus_margin() {
        let config = RuntimeConfig::default();
        let expected = config.chat_timeout + Duration::from_mins(2);
        assert_eq!(
            config.operation_timeout, expected,
            "operation_timeout should be chat_timeout + 2min safety margin"
        );
    }

    /// An exception that merely shares its class name with a safety exception
    /// must not be classified as `Error::Safety`.
    #[test]
    fn safety_error_structural() {
        Python::initialize();
        Python::attach(|py| {
            let globals = pyo3::types::PyDict::new(py);
            py.run(
                c"
class StopCandidateException(Exception):
    pass
err = StopCandidateException(\"dummy\")
",
                Some(&globals),
                None,
            )
            .unwrap();

            let err_obj = globals.get_item("err").unwrap().unwrap();
            let err = PyErr::from_value(err_obj);

            let mapped = crate::error::classify_py_error(py, &err);

            assert!(
                !matches!(mapped, crate::error::Error::Safety),
                "Failed: matched Error::Safety based purely on the string name StopCandidateException!"
            );
        });
    }

    /// An exception that merely shares its class name with a max-tokens
    /// exception must not be classified as `Error::MaxTokens`.
    #[test]
    fn maxtokens_error_structural() {
        Python::initialize();
        Python::attach(|py| {
            let globals = pyo3::types::PyDict::new(py);
            py.run(
                c"
class MaxTokensException(Exception):
    pass
err = MaxTokensException(\"dummy\")
",
                Some(&globals),
                None,
            )
            .unwrap();

            let err_obj = globals.get_item("err").unwrap().unwrap();
            let err = PyErr::from_value(err_obj);

            let mapped = crate::error::classify_py_error(py, &err);

            assert!(
                !matches!(mapped, crate::error::Error::MaxTokens),
                "Failed: matched Error::MaxTokens based purely on the string name MaxTokensException!"
            );
        });
    }

    /// Ask-user handler whose answer can be flipped between checks.
    struct MockAskUserHandler {
        should_allow: std::sync::atomic::AtomicBool,
    }

    impl crate::policies::AskUserHandler for MockAskUserHandler {
        fn confirm(&self, _tool_name: &str, _tool_args: &serde_json::Value) -> bool {
            self.should_allow.load(std::sync::atomic::Ordering::SeqCst)
        }
    }

    #[test]
    fn test_ask_user_policy_custom_tool_gating() {
        let agent_id: u64 = 999;

        let mut policies = crate::policies::PolicySet::new();
        policies
            .push(crate::policies::PolicyRule::AskUser {
                tool: "dangerous_tool".to_owned(),
                handler_id: "confirm_handler".to_owned(),
            })
            .unwrap();

        let handler = Arc::new(MockAskUserHandler {
            should_allow: std::sync::atomic::AtomicBool::new(true),
        });

        let mut registry = crate::tools::ToolRegistry::new();

        #[crate::llm_tool]
        fn dangerous_tool() -> Result<String, String> {
            Ok("Executed dangerous action!".to_owned())
        }
        registry.register(DangerousTool);

        bridge_state().write().unwrap().insert(
            agent_id,
            AgentBridgeState {
                registry: Some(Arc::new(registry)),
                hook_runner: None,
                policies,
                policy_handler: Some(
                    Arc::clone(&handler) as Arc<dyn crate::policies::AskUserHandler>
                ),
                tool_state: Arc::new(std::sync::RwLock::new(HashMap::new())),
                conversation_id: Arc::new(std::sync::Mutex::new(None)),
            },
        );
        // Removes the entry on scope exit, including when an assertion fails.
        let _cleanup = BridgeStateEntryGuard(agent_id);

        handler
            .should_allow
            .store(true, std::sync::atomic::Ordering::SeqCst);
        let allowed = check_tool_execution_allowed(agent_id, "dangerous_tool", "{}")
            .expect("Check should succeed");
        assert!(
            allowed,
            "Should allow tool execution when handler returns true"
        );

        handler
            .should_allow
            .store(false, std::sync::atomic::Ordering::SeqCst);
        let allowed = check_tool_execution_allowed(agent_id, "dangerous_tool", "{}")
            .expect("Check should succeed");
        assert!(
            !allowed,
            "Should block tool execution when handler returns false"
        );
    }
}