Skip to main content

agy_bridge/runtime/
mod.rs

1//! Python runtime manager: owns a dedicated Python thread with an asyncio event loop.
2//!
3//! The `PythonRuntime` struct bridges Rust's tokio async world with Python's asyncio
4//! by running a command dispatch loop in a dedicated thread. Rust sends `PyCommand`
5//! messages via an `mpsc` channel, and receives results via per-command `oneshot` channels.
6//!
7//! # Threading architecture
8//!
9//! - **One Python thread**: All GIL acquisition is confined to a single dedicated thread
10//!   (`agy-bridge-python-runtime`). This thread runs an asyncio event loop via
11//!   `pyo3_async_runtimes::tokio::run_until_complete`.
12//!
13//! - **Concurrent command processing**: Commands received from the `mpsc` channel are
14//!   **not** serialized. Each command spawns a future into a `FuturesUnordered` task set,
15//!   and `tokio::select!` drives both incoming commands and in-flight task completions.
16//!   Multiple chats/operations run concurrently through the Python asyncio event loop.
17//!
18//! - **Rust tool dispatch**: When the Python SDK invokes a Rust tool, `dispatch_rust_tool`
19//!   reads tool state from `BRIDGE_STATE`, then uses `future_into_py` to run the async
20//!   tool on the tokio runtime — keeping the Python thread unblocked for other coroutines.
21//!
22//! - **Hook/policy dispatch**: Similarly, `dispatch_rust_hook` and `dispatch_rust_policy_confirm`
23//!   use `spawn_blocking` to run synchronous hook callbacks without holding the GIL.
24//!
25//! # Why global state (`BRIDGE_STATE`)?
26//!
27//! The Python SDK's tool/hook/policy callbacks are dispatched via PyO3 `#[pyfunction]`
28//! entries (e.g. `dispatch_rust_tool`, `dispatch_rust_hook`). These functions are
29//! registered as plain Python callables and receive **only** the arguments the SDK
30//! passes (agent ID + serialized context). There is no way to thread a Rust reference
31//! or `Arc` through the Python call boundary.
32//!
33//! Therefore per-agent state (tool registries, hook runners, policy sets) is stored in
34//! a global `RwLock<HashMap<AgentId, AgentBridgeState>>`. The agent ID is used as a
35//! lookup key, and the lock is held only for brief `HashMap` operations (never across
36//! `.await` points). This is the standard pattern for PyO3 FFI bridges that need to
37//! associate Rust state with Python-side identifiers.
38
39use std::{sync::Arc, time::Duration};
40
41use pyo3::prelude::*;
42use tokio::sync::{mpsc, oneshot};
43
44use crate::{error::Error, quota::QuotaState};
45
46pub(crate) mod bridge_state;
47pub(crate) mod command_loop;
48pub(crate) mod ffi_dispatch;
49mod handlers;
50pub(crate) mod py_scripts;
51pub(crate) mod streaming;
52pub(crate) mod venv;
53
54// Re-export items used by sibling modules and external crate consumers.
55pub(crate) use bridge_state::{AgentBridgeState, AgentId, bridge_state};
56pub(crate) use ffi_dispatch::{
57    CREATE_AGENT_HOOK_GUARD, INITIALIZING_HOOK_RUNNER, PENDING_CONVERSATION_IDS,
58    dispatch_rust_hook, dispatch_rust_policy_confirm, dispatch_rust_tool,
59};
60
/// Safety-net timeout for a single `send_command` round-trip.
///
/// This is the *outer* Rust-side timeout that wraps all commands sent to the
/// Python thread (chat, `create_agent`, cancel, `get_history`, …).  The Python
/// side applies its own, tighter timeouts (`chat_timeout`, `HANDLER_TIMEOUT`),
/// so this value should only fire if the Python thread is completely stuck.
///
/// Defaults to `chat_timeout + 2 minutes` to give inner timeouts room to
/// fire first.
///
/// The addition saturates at [`Duration::MAX`] instead of panicking. An
/// absurdly large `chat_timeout` (e.g. from a huge `AGI_CHAT_TIMEOUT_SECS`)
/// therefore yields an effectively unbounded timeout rather than a crash
/// inside `RuntimeConfig::default()`.
#[must_use]
pub const fn default_operation_timeout(chat_timeout: Duration) -> Duration {
    // 2-minute head-room so the inner (Python-side) timeouts fire first.
    chat_timeout.saturating_add(Duration::from_secs(2 * 60))
}
69/// fire first.
70#[must_use]
71pub fn default_operation_timeout(chat_timeout: Duration) -> Duration {
72    chat_timeout + Duration::from_mins(2)
73}
74/// Default timeout (seconds) for a single `agent.chat()` round-trip.
75/// 120s (2 min) is generous for a normal turn while detecting stalls quickly.
76pub const DEFAULT_CHAT_TIMEOUT_SECS: u64 = 120;
77
78/// Default delay between successive chat commands to prevent burst requests.
79pub const DEFAULT_INTER_AGENT_DELAY: Duration = Duration::from_millis(500);
80
81/// Default command channel buffer size.
82const DEFAULT_CHANNEL_CAPACITY: usize = 64;
83
84/// Default timeout for joining the Python thread on shutdown.
85const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
86
87/// Returns the default chat round-trip timeout, configurable via
88/// `AGI_CHAT_TIMEOUT_SECS` (defaults to 120 s).
89#[must_use]
90pub fn default_chat_timeout() -> Duration {
91    let secs = std::env::var("AGI_CHAT_TIMEOUT_SECS").map_or(DEFAULT_CHAT_TIMEOUT_SECS, |val| {
92        val.parse::<u64>().unwrap_or_else(|e| {
93            tracing::warn!(
94                value = %val,
95                error = %e,
96                "Invalid AGI_CHAT_TIMEOUT_SECS, using default {DEFAULT_CHAT_TIMEOUT_SECS}s"
97            );
98            DEFAULT_CHAT_TIMEOUT_SECS
99        })
100    });
101    Duration::from_secs(secs)
102}
103
104/// Commands sent from Rust to the Python thread.
105///
106/// Each variant is constructed in `impl Runtime for PythonRuntime` and
107/// dispatched in `command_loop::run_async_command_loop`.
108pub(crate) enum PyCommand {
109    /// Create a new agent with the given configuration dict as JSON.
110    CreateAgent {
111        config_json: String,
112        reply: oneshot::Sender<Result<AgentId, Error>>,
113    },
114    /// Send a chat message to an agent.
115    Chat {
116        agent_id: AgentId,
117        prompt: String,
118        reply: oneshot::Sender<Result<crate::streaming::ChatResponseHandle, Error>>,
119    },
120    /// Shut down a specific agent.
121    ShutdownAgent {
122        agent_id: AgentId,
123        reply: oneshot::Sender<Result<(), Error>>,
124    },
125    /// Cancel active execution on the agent.
126    Cancel {
127        agent_id: AgentId,
128        reply: oneshot::Sender<Result<(), Error>>,
129    },
130    /// Wait for the agent to stabilize/become idle.
131    WaitForIdle {
132        agent_id: AgentId,
133        reply: oneshot::Sender<Result<(), Error>>,
134    },
135    /// Send a message without waiting for completion (fire-and-forget).
136    Send {
137        agent_id: AgentId,
138        prompt: String,
139        reply: oneshot::Sender<Result<(), Error>>,
140    },
141    /// Signal that the agent is idle.
142    SignalIdle {
143        agent_id: AgentId,
144        reply: oneshot::Sender<Result<(), Error>>,
145    },
146    /// Wait for the agent to wake up; returns true if woken, false on timeout.
147    WaitForWakeup {
148        agent_id: AgentId,
149        timeout_secs: f64,
150        reply: oneshot::Sender<Result<bool, Error>>,
151    },
152    /// Shut down the entire Python runtime.
153    Shutdown,
154    /// Retrieve the conversation's message history.
155    GetHistory {
156        agent_id: AgentId,
157        reply: oneshot::Sender<Result<Vec<crate::types::ConversationMessage>, Error>>,
158    },
159    /// Return the number of completed turns.
160    GetTurnCount {
161        agent_id: AgentId,
162        reply: oneshot::Sender<Result<u32, Error>>,
163    },
164    /// Return cumulative token usage across all turns.
165    GetTotalUsage {
166        agent_id: AgentId,
167        reply: oneshot::Sender<Result<crate::types::UsageMetadata, Error>>,
168    },
169    /// Return token usage from the most recent turn.
170    GetLastTurnUsage {
171        agent_id: AgentId,
172        reply: oneshot::Sender<Result<crate::types::UsageMetadata, Error>>,
173    },
174    /// Clear the conversation history.
175    ClearHistory {
176        agent_id: AgentId,
177        reply: oneshot::Sender<Result<(), Error>>,
178    },
179    /// Return step indices where compaction occurred.
180    GetCompactionIndices {
181        agent_id: AgentId,
182        reply: oneshot::Sender<Result<Vec<u32>, Error>>,
183    },
184    /// Return the text of the last model response.
185    GetLastResponse {
186        agent_id: AgentId,
187        reply: oneshot::Sender<Result<Option<String>, Error>>,
188    },
189    /// Delete the conversation and all associated state.
190    ///
191    /// Constructed by `impl Runtime for PythonRuntime::delete()` — only
192    /// reachable when an external consumer calls `AgentHandle::delete()`.
193    Delete {
194        agent_id: AgentId,
195        reply: oneshot::Sender<Result<(), Error>>,
196    },
197    /// Disconnect from the agent without deleting state.
198    ///
199    /// Constructed by `impl Runtime for PythonRuntime::disconnect()`.
200    Disconnect {
201        agent_id: AgentId,
202        reply: oneshot::Sender<Result<(), Error>>,
203    },
204    /// Check whether the agent is currently idle.
205    ///
206    /// Constructed by `impl Runtime for PythonRuntime::is_idle()`.
207    IsIdle {
208        agent_id: AgentId,
209        reply: oneshot::Sender<Result<bool, Error>>,
210    },
211}
212
213/// Configuration for the bridge runtime.
214#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
215#[serde(default)]
216pub struct RuntimeConfig {
217    /// Channel buffer size for the command channel.
218    pub channel_capacity: usize,
219    /// Timeout for individual runtime operations.
220    pub operation_timeout: Duration,
221    /// Timeout for joining the Python thread on shutdown.
222    pub shutdown_timeout: Duration,
223    /// Timeout for a single `agent.chat()` round-trip.
224    ///
225    /// Defaults to the value of `AGI_CHAT_TIMEOUT_SECS` (env var), or 120 s.
226    pub chat_timeout: Duration,
227    /// Delay injected between successive chat commands to prevent burst requests.
228    pub inter_agent_delay: Duration,
229}
230
231impl Default for RuntimeConfig {
232    fn default() -> Self {
233        let chat_timeout = default_chat_timeout();
234        Self {
235            channel_capacity: DEFAULT_CHANNEL_CAPACITY,
236            operation_timeout: default_operation_timeout(chat_timeout),
237            shutdown_timeout: DEFAULT_SHUTDOWN_TIMEOUT,
238            chat_timeout,
239            inter_agent_delay: DEFAULT_INTER_AGENT_DELAY,
240        }
241    }
242}
243
244/// Manages a dedicated Python thread with an asyncio event loop.
245///
246/// All Python/SDK interactions go through the command channel. This isolates
247/// GIL acquisition to the Python thread and keeps the tokio runtime responsive.
248pub struct PythonRuntime {
249    cmd_tx: mpsc::Sender<PyCommand>,
250    thread: Option<std::thread::JoinHandle<()>>,
251    config: RuntimeConfig,
252    /// Per-runtime quota registry. Each API key gets its own [`QuotaState`],
253    /// and different `PythonRuntime` instances are fully independent.
254    quota_registry: crate::quota::QuotaRegistry,
255    /// Default quota state used by `send_command` for runtime-level backoff.
256    quota_state: Arc<QuotaState>,
257}
258
259impl std::fmt::Debug for PythonRuntime {
260    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261        f.debug_struct("PythonRuntime")
262            .field("config", &self.config)
263            .field(
264                "thread_running",
265                &self.thread.as_ref().is_some_and(|t| !t.is_finished()),
266            )
267            .finish_non_exhaustive()
268    }
269}
270
271impl PythonRuntime {
272    /// Spawn a new Python runtime on a dedicated thread.
273    ///
274    /// Creates an asyncio event loop in the thread and starts the command
275    /// dispatch loop.
276    ///
277    /// # Errors
278    ///
279    /// Returns `Error::BackendError` if the thread fails to spawn or
280    /// Python initialization fails.
281    pub fn new(config: RuntimeConfig) -> Result<Self, Error> {
282        let (cmd_tx, cmd_rx) = mpsc::channel(config.channel_capacity);
283
284        let thread_config = config.clone();
285        let thread = std::thread::Builder::new()
286            .name("agy-bridge-python-runtime".into())
287            .spawn(move || {
288                python_thread_main(cmd_rx, &thread_config);
289            })
290            .map_err(|e| Error::BackendError {
291                message: format!("Failed to spawn Python runtime thread: {e}"),
292            })?;
293
294        let quota_registry = crate::quota::QuotaRegistry::new();
295        let quota_state = quota_registry.state_for_key("");
296        Ok(Self {
297            cmd_tx,
298            thread: Some(thread),
299            config,
300            quota_registry,
301            quota_state,
302        })
303    }
304
305    /// Send a command to the Python thread and await the result.
306    ///
307    /// This is the primary interface for all Python interactions. It checks
308    /// quota state before sending and applies a configurable timeout.
309    ///
310    /// # Errors
311    ///
312    /// Returns `Error::ChannelClosed` if the Python thread has exited,
313    /// `Error::Timeout` if the operation exceeds the configured timeout.
314    async fn send_command<T>(
315        &self,
316        operation: &str,
317        is_llm_op: bool,
318        build_cmd: impl FnOnce(oneshot::Sender<Result<T, Error>>) -> PyCommand,
319    ) -> Result<T, Error> {
320        let (reply_tx, reply_rx) = oneshot::channel();
321        let cmd = build_cmd(reply_tx);
322
323        self.cmd_tx
324            .send(cmd)
325            .await
326            .map_err(|e| Error::ChannelClosed {
327                message: format!("Python runtime thread has exited (sending {operation}): {e}"),
328            })?;
329
330        let result = crate::error::with_timeout(self.config.operation_timeout, operation, async {
331            reply_rx.await.map_err(|e| Error::ChannelClosed {
332                message: format!("Reply channel dropped for {operation}: {e}"),
333            })?
334        })
335        .await?;
336
337        // Only reset quota backoff for LLM operations (e.g. chat); non-LLM
338        // ops succeeding should not clear a 429 backoff.
339        if is_llm_op {
340            self.quota_state.record_success();
341        }
342
343        Ok(result)
344    }
345
346    /// Graceful shutdown: send `Shutdown` command, then join the thread.
347    ///
348    /// # Errors
349    ///
350    /// Returns `Error::Timeout` if the thread doesn't join within the
351    /// configured shutdown timeout, or `Error::BackendError` if the
352    /// thread panicked.
353    pub async fn shutdown(mut self) -> Result<(), Error> {
354        // Signal the command loop to exit.
355        // Ignoring send error: if the receiver is already gone the thread
356        // is already exiting, which is the outcome we want.
357        if let Err(e) = self.cmd_tx.send(PyCommand::Shutdown).await {
358            tracing::warn!("Shutdown command send failed (thread may already be exiting): {e}");
359        }
360
361        // Take the JoinHandle so Drop doesn't fire the "dropped without
362        // shutdown" warning.
363        let Some(thread) = self.thread.take() else {
364            tracing::warn!("PythonRuntime::shutdown() called but thread handle already taken");
365            return Ok(());
366        };
367
368        let shutdown_timeout = self.config.shutdown_timeout;
369        let join_result = tokio::time::timeout(
370            shutdown_timeout,
371            tokio::task::spawn_blocking(move || thread.join()),
372        )
373        .await;
374
375        match join_result {
376            Ok(Ok(Ok(()))) => {
377                tracing::info!("Python runtime thread joined successfully");
378                Ok(())
379            }
380            Ok(Ok(Err(panic_payload))) => {
381                let panic_msg = panic_payload.downcast_ref::<&str>().map_or_else(
382                    || {
383                        panic_payload
384                            .downcast_ref::<String>()
385                            .map_or_else(|| format!("{panic_payload:?}"), Clone::clone)
386                    },
387                    |s| (*s).to_string(),
388                );
389                tracing::error!(
390                    panic_message = %panic_msg,
391                    "Python runtime thread panicked during shutdown"
392                );
393                Err(Error::BackendError {
394                    message: format!("Python runtime thread panicked during shutdown: {panic_msg}"),
395                })
396            }
397            Ok(Err(join_err)) => {
398                tracing::error!("spawn_blocking join error: {join_err}");
399                Err(Error::BackendError {
400                    message: format!("Failed to join Python thread: {join_err}"),
401                })
402            }
403            Err(_elapsed) => {
404                tracing::error!(
405                    timeout_secs = shutdown_timeout.as_secs(),
406                    "Python runtime thread did not exit within shutdown timeout"
407                );
408                Err(Error::Timeout {
409                    duration: shutdown_timeout,
410                    operation: "PythonRuntime::shutdown (thread join)".to_string(),
411                })
412            }
413        }
414    }
415
416    /// Access the shared quota state.
417    #[must_use]
418    pub const fn quota_state(&self) -> &Arc<QuotaState> {
419        &self.quota_state
420    }
421}
422
423impl Drop for PythonRuntime {
424    fn drop(&mut self) {
425        if self.thread.is_some() {
426            tracing::warn!(
427                "PythonRuntime dropped without calling shutdown() — \
428                 Python thread may still be running"
429            );
430        }
431    }
432}
433
434/// Entry point for the dedicated Python thread.
435fn python_thread_main(cmd_rx: mpsc::Receiver<PyCommand>, config: &RuntimeConfig) {
436    Python::initialize();
437
438    // Environment variables are already loaded by load_dotenv() at bridge
439    // construction time, before any threads are spawned.
440
441    // Configure sys.path so the venv's site-packages are importable.
442    Python::attach(|py| {
443        if let Err(e) = venv::configure_python_sys_path(py) {
444            tracing::error!(
445                error = %e,
446                "Failed to configure Python sys.path in runtime thread — \
447                 venv imports will likely fail"
448            );
449        }
450    });
451
452    if let Err(e) = run_live_thread(cmd_rx, config) {
453        tracing::error!(error = %e, "Python runtime thread failed");
454    }
455
456    tracing::info!("Python runtime thread exiting");
457}
458
459/// Live SDK thread: creates an asyncio event loop and dispatches commands
460/// to the real Antigravity SDK via `pyo3_async_runtimes`.
461fn run_live_thread(cmd_rx: mpsc::Receiver<PyCommand>, config: &RuntimeConfig) -> Result<(), Error> {
462    Python::attach(|py| {
463        let asyncio = py.import("asyncio").map_err(|e| Error::BackendError {
464            message: format!("Failed to import asyncio: {e}"),
465        })?;
466        let event_loop =
467            asyncio
468                .call_method0("new_event_loop")
469                .map_err(|e| Error::BackendError {
470                    message: format!("Failed to create new asyncio event loop: {e}"),
471                })?;
472        asyncio
473            .call_method1("set_event_loop", (&event_loop,))
474            .map_err(|e| Error::BackendError {
475                message: format!("Failed to set asyncio event loop: {e}"),
476            })?;
477
478        // Register event_loop in globals for access from any thread
479        let sys = py.import("sys").map_err(|e| Error::BackendError {
480            message: format!("Failed to import sys: {e}"),
481        })?;
482        let sys_modules = sys.getattr("modules").map_err(|e| Error::BackendError {
483            message: format!("Failed to get sys.modules: {e}"),
484        })?;
485        let globals_mod = if sys_modules
486            .contains(command_loop::AGY_BRIDGE_GLOBALS_MODULE)
487            .map_err(|e| Error::BackendError {
488                message: format!("Failed to check sys.modules: {e}"),
489            })? {
490            sys_modules
491                .get_item(command_loop::AGY_BRIDGE_GLOBALS_MODULE)
492                .map_err(|e| Error::BackendError {
493                    message: format!("Failed to get _agy_bridge_globals: {e}"),
494                })?
495        } else {
496            let types = py.import("types").map_err(|e| Error::BackendError {
497                message: format!("Failed to import types: {e}"),
498            })?;
499            let module = types
500                .getattr("ModuleType")
501                .map_err(|e| Error::BackendError {
502                    message: format!("Failed to get ModuleType: {e}"),
503                })?
504                .call1((command_loop::AGY_BRIDGE_GLOBALS_MODULE,))
505                .map_err(|e| Error::BackendError {
506                    message: format!("Failed to create ModuleType: {e}"),
507                })?;
508            sys_modules
509                .set_item(command_loop::AGY_BRIDGE_GLOBALS_MODULE, &module)
510                .map_err(|e| Error::BackendError {
511                    message: format!("Failed to register _agy_bridge_globals: {e}"),
512                })?;
513            module
514        };
515        globals_mod
516            .setattr("EVENT_LOOP", &event_loop)
517            .map_err(|e| Error::BackendError {
518                message: format!("Failed to set EVENT_LOOP in globals: {e}"),
519            })?;
520
521        tracing::info!("Python asyncio event loop created on runtime thread");
522
523        let chat_timeout = config.chat_timeout;
524        let inter_agent_delay = config.inter_agent_delay;
525        let event_loop_obj = event_loop.clone().unbind();
526        let run_fut =
527            pyo3_async_runtimes::tokio::run_until_complete(event_loop.clone(), async move {
528                command_loop::run_async_command_loop(
529                    event_loop_obj,
530                    cmd_rx,
531                    chat_timeout,
532                    inter_agent_delay,
533                )
534                .await
535            });
536
537        if let Err(e) = run_fut {
538            // Close the event loop best-effort before propagating.
539            if let Err(close_err) = event_loop.call_method0("close") {
540                tracing::warn!("Failed to close asyncio event loop: {close_err}");
541            }
542            return Err(Error::BackendError {
543                message: format!("Python runtime command loop failed: {e}"),
544            });
545        }
546
547        if let Err(e) = event_loop.call_method0("close") {
548            tracing::warn!("Failed to close asyncio event loop: {e}");
549        }
550
551        Ok(())
552    })
553}
554
555impl crate::agent::Runtime for PythonRuntime {
556    async fn create_agent(
557        &self,
558        config: crate::config::AgentConfig,
559    ) -> Result<crate::agent::AgentId, Error> {
560        // Report all available tools as requested by the user.
561        let mut all_tools = config.custom_tool_names();
562        if let Some(ref caps) = config.capabilities {
563            if let Some(ref builtins) = caps.enabled_tools {
564                all_tools.extend(builtins.iter().map(|b| b.as_sdk_name().to_string()));
565            } else if caps.disabled_tools.is_none() {
566                // Default is all tools
567                all_tools.extend(
568                    crate::config::capabilities::BuiltinTools::all_tools()
569                        .iter()
570                        .map(|b| b.as_sdk_name().to_string()),
571                );
572            }
573        } else {
574            all_tools.extend(
575                crate::config::capabilities::BuiltinTools::all_tools()
576                    .iter()
577                    .map(|b| b.as_sdk_name().to_string()),
578            );
579        }
580        tracing::info!(
581            "Agent starting with {} available tools: {:?}",
582            all_tools.len(),
583            all_tools
584        );
585
586        let config_json = serde_json::to_string(&config).map_err(|e| Error::BackendError {
587            message: format!("Failed to serialize AgentConfig: {e}"),
588        })?;
589
590        let raw_id = self
591            .send_command("create_agent", false, |reply| PyCommand::CreateAgent {
592                config_json,
593                reply,
594            })
595            .await?;
596
597        Ok(raw_id.0)
598    }
599
600    async fn chat(
601        &self,
602        agent_id: crate::agent::AgentId,
603        content: &crate::content::Content,
604    ) -> Result<crate::streaming::ChatResponseHandle, Error> {
605        let prompt = match content {
606            crate::content::Content::Text { text } => text.clone(),
607            other => crate::content::content_to_json(other)?,
608        };
609        self.send_command("chat", true, |reply| PyCommand::Chat {
610            agent_id: AgentId(agent_id),
611            prompt,
612            reply,
613        })
614        .await
615    }
616
617    async fn shutdown_agent(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
618        self.send_command("shutdown_agent", false, |reply| PyCommand::ShutdownAgent {
619            agent_id: AgentId(agent_id),
620            reply,
621        })
622        .await
623    }
624
625    fn try_shutdown_agent(&self, agent_id: crate::agent::AgentId) {
626        // Fire-and-forget: create a oneshot whose receiver we drop immediately.
627        // The Python thread will still process the shutdown; we just don't wait
628        // for the result.
629        let (reply, _) = oneshot::channel();
630        if let Err(e) = self.cmd_tx.try_send(PyCommand::ShutdownAgent {
631            agent_id: AgentId(agent_id),
632            reply,
633        }) {
634            tracing::debug!(
635                agent_id = agent_id,
636                error = %e,
637                "try_shutdown_agent: channel send failed (runtime may already be gone)"
638            );
639        }
640    }
641
642    async fn cancel(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
643        self.send_command("cancel", false, |reply| PyCommand::Cancel {
644            agent_id: AgentId(agent_id),
645            reply,
646        })
647        .await
648    }
649
650    async fn wait_for_idle(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
651        self.send_command("wait_for_idle", false, |reply| PyCommand::WaitForIdle {
652            agent_id: AgentId(agent_id),
653            reply,
654        })
655        .await
656    }
657
658    async fn send(
659        &self,
660        agent_id: crate::agent::AgentId,
661        content: &crate::content::Content,
662    ) -> Result<(), Error> {
663        let prompt = match content {
664            crate::content::Content::Text { text } => text.clone(),
665            other => crate::content::content_to_json(other)?,
666        };
667        self.send_command("send", false, |reply| PyCommand::Send {
668            agent_id: AgentId(agent_id),
669            prompt,
670            reply,
671        })
672        .await
673    }
674
675    async fn signal_idle(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
676        self.send_command("signal_idle", false, |reply| PyCommand::SignalIdle {
677            agent_id: AgentId(agent_id),
678            reply,
679        })
680        .await
681    }
682
683    async fn wait_for_wakeup(
684        &self,
685        agent_id: crate::agent::AgentId,
686        timeout: std::time::Duration,
687    ) -> Result<bool, Error> {
688        self.send_command("wait_for_wakeup", false, |reply| PyCommand::WaitForWakeup {
689            agent_id: AgentId(agent_id),
690            timeout_secs: timeout.as_secs_f64(),
691            reply,
692        })
693        .await
694    }
695
696    async fn wait_for_quota(&self) {
697        self.quota_state.wait_for_quota().await;
698    }
699
700    async fn record_quota_hit(&self, retry_after: std::time::Duration) {
701        self.quota_state.record_quota_hit(retry_after);
702    }
703
704    fn quota_registry(&self) -> &crate::quota::QuotaRegistry {
705        &self.quota_registry
706    }
707
708    async fn history(
709        &self,
710        agent_id: crate::agent::AgentId,
711    ) -> Result<Vec<crate::types::ConversationMessage>, Error> {
712        self.send_command("get_history", false, |reply| PyCommand::GetHistory {
713            agent_id: AgentId(agent_id),
714            reply,
715        })
716        .await
717    }
718
719    async fn turn_count(&self, agent_id: crate::agent::AgentId) -> Result<u32, Error> {
720        self.send_command("get_turn_count", false, |reply| PyCommand::GetTurnCount {
721            agent_id: AgentId(agent_id),
722            reply,
723        })
724        .await
725    }
726
727    async fn total_usage(
728        &self,
729        agent_id: crate::agent::AgentId,
730    ) -> Result<crate::types::UsageMetadata, Error> {
731        self.send_command("get_total_usage", false, |reply| PyCommand::GetTotalUsage {
732            agent_id: AgentId(agent_id),
733            reply,
734        })
735        .await
736    }
737
738    async fn last_turn_usage(
739        &self,
740        agent_id: crate::agent::AgentId,
741    ) -> Result<crate::types::UsageMetadata, Error> {
742        self.send_command("get_last_turn_usage", false, |reply| {
743            PyCommand::GetLastTurnUsage {
744                agent_id: AgentId(agent_id),
745                reply,
746            }
747        })
748        .await
749    }
750
751    async fn clear_history(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
752        self.send_command("clear_history", false, |reply| PyCommand::ClearHistory {
753            agent_id: AgentId(agent_id),
754            reply,
755        })
756        .await
757    }
758
759    async fn compaction_indices(&self, agent_id: crate::agent::AgentId) -> Result<Vec<u32>, Error> {
760        self.send_command("compaction_indices", false, |reply| {
761            PyCommand::GetCompactionIndices {
762                agent_id: AgentId(agent_id),
763                reply,
764            }
765        })
766        .await
767    }
768
769    async fn last_response(
770        &self,
771        agent_id: crate::agent::AgentId,
772    ) -> Result<Option<String>, Error> {
773        self.send_command("last_response", false, |reply| PyCommand::GetLastResponse {
774            agent_id: AgentId(agent_id),
775            reply,
776        })
777        .await
778    }
779
780    async fn delete(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
781        self.send_command("delete", false, |reply| PyCommand::Delete {
782            agent_id: AgentId(agent_id),
783            reply,
784        })
785        .await
786    }
787
788    async fn disconnect(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
789        self.send_command("disconnect", false, |reply| PyCommand::Disconnect {
790            agent_id: AgentId(agent_id),
791            reply,
792        })
793        .await
794    }
795
796    async fn is_idle(&self, agent_id: crate::agent::AgentId) -> Result<bool, Error> {
797        self.send_command("is_idle", false, |reply| PyCommand::IsIdle {
798            agent_id: AgentId(agent_id),
799            reply,
800        })
801        .await
802    }
803}
804
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::{ffi_dispatch::check_tool_execution_allowed, *};

    /// Small, fast runtime configuration shared by the tests below.
    fn test_config() -> RuntimeConfig {
        RuntimeConfig {
            channel_capacity: 16,
            operation_timeout: Duration::from_secs(10),
            shutdown_timeout: Duration::from_secs(5),
            chat_timeout: Duration::from_mins(1),
            inter_agent_delay: Duration::from_millis(100),
        }
    }

    /// Removes an agent's entry from the global `bridge_state()` map when dropped.
    ///
    /// All tests in this binary share one process, and therefore one bridge-state
    /// map. A trailing `remove` call is skipped when an earlier assertion panics,
    /// which leaks the entry into later tests. A drop guard runs on every exit path.
    struct BridgeStateCleanup(u64);

    impl Drop for BridgeStateCleanup {
        fn drop(&mut self) {
            // This can run while unwinding from a failed assertion. Recover a
            // poisoned lock instead of panicking again, which would abort the process.
            bridge_state()
                .write()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .remove(&self.0);
        }
    }

    /// Runs `src` in a fresh globals dict and classifies the exception it binds to `err`.
    ///
    /// Each snippet defines a local exception class whose *name* matches an SDK
    /// exception. The callers use this to check that `classify_py_error` does not
    /// map errors on the class name alone.
    fn classify_snippet_err(src: &std::ffi::CStr) -> crate::error::Error {
        Python::initialize();
        Python::attach(|py| {
            let globals = pyo3::types::PyDict::new(py);
            py.run(src, Some(&globals), None)
                .expect("test snippet should execute");

            let err_obj = globals
                .get_item("err")
                .expect("looking up `err` should not raise")
                .expect("test snippet must bind `err`");

            crate::error::classify_py_error(py, &PyErr::from_value(err_obj))
        })
    }

    #[tokio::test]
    async fn test_runtime_creation_and_shutdown() {
        // Shutdown should complete cleanly.
        PythonRuntime::new(test_config())
            .expect("Failed to create runtime")
            .shutdown()
            .await
            .expect("Shutdown failed");
    }

    #[test]
    fn runtime_config_serde_roundtrip() {
        let config = test_config();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: RuntimeConfig = serde_json::from_str(&json).unwrap();
        // Compare against the source config rather than repeating its literals.
        assert_eq!(parsed.channel_capacity, config.channel_capacity);
        assert_eq!(parsed.operation_timeout, config.operation_timeout);
        assert_eq!(parsed.shutdown_timeout, config.shutdown_timeout);
        assert_eq!(parsed.chat_timeout, config.chat_timeout);
        assert_eq!(parsed.inter_agent_delay, config.inter_agent_delay);
    }

    #[test]
    fn default_operation_timeout_is_chat_plus_margin() {
        let config = RuntimeConfig::default();
        let expected = config.chat_timeout + Duration::from_mins(2);
        assert_eq!(
            config.operation_timeout, expected,
            "operation_timeout should be chat_timeout + 2min safety margin"
        );
    }

    #[test]
    fn safety_error_structural() {
        let mapped = classify_snippet_err(
            c"
class StopCandidateException(Exception):
    pass
err = StopCandidateException(\"dummy\")
",
        );

        assert!(
            !matches!(mapped, crate::error::Error::Safety),
            "Failed: matched Error::Safety based purely on the string name StopCandidateException!"
        );
    }

    #[test]
    fn maxtokens_error_structural() {
        let mapped = classify_snippet_err(
            c"
class MaxTokensException(Exception):
    pass
err = MaxTokensException(\"dummy\")
",
        );

        assert!(
            !matches!(mapped, crate::error::Error::MaxTokens),
            "Failed: matched Error::MaxTokens based purely on the string name MaxTokensException!"
        );
    }

    /// `AskUserHandler` whose answer is toggled at runtime through `should_allow`.
    struct MockAskUserHandler {
        should_allow: AtomicBool,
    }

    impl crate::policies::AskUserHandler for MockAskUserHandler {
        fn confirm(&self, _tool_name: &str, _tool_args: &serde_json::Value) -> bool {
            self.should_allow.load(Ordering::SeqCst)
        }
    }

    #[test]
    fn test_ask_user_policy_custom_tool_gating() {
        let agent_id: u64 = 999;

        // 1. Setup the PolicySet with an AskUser rule for "dangerous_tool"
        let mut policies = crate::policies::PolicySet::new();
        policies
            .push(crate::policies::PolicyRule::AskUser {
                tool: "dangerous_tool".to_owned(),
                handler_id: "confirm_handler".to_owned(),
            })
            .unwrap();

        // 2. Setup mock handler
        let handler = Arc::new(MockAskUserHandler {
            should_allow: AtomicBool::new(true),
        });

        // 3. Mock the tool registry
        let mut registry = crate::tools::ToolRegistry::new();

        /// A dangerous tool.
        #[crate::llm_tool]
        fn dangerous_tool() -> Result<String, String> {
            Ok("Executed dangerous action!".to_owned())
        }
        registry.register(DangerousTool);

        // 4. Register all state in a single bridge_state() insertion
        bridge_state().write().unwrap().insert(
            agent_id,
            AgentBridgeState {
                registry: Some(Arc::new(registry)),
                hook_runner: None,
                policies,
                policy_handler: Some(
                    Arc::clone(&handler) as Arc<dyn crate::policies::AskUserHandler>
                ),
                tool_state: Arc::new(std::sync::RwLock::new(HashMap::new())),
                conversation_id: Arc::new(std::sync::Mutex::new(None)),
            },
        );
        // Remove the entry on every exit path, including a failed assertion below.
        let _cleanup = BridgeStateCleanup(agent_id);

        // 5. Simulate check_tool_execution_allowed when the AskUserHandler allows it (returns true)
        handler.should_allow.store(true, Ordering::SeqCst);
        let allowed = check_tool_execution_allowed(agent_id, "dangerous_tool", "{}")
            .expect("Check should succeed");
        assert!(
            allowed,
            "Should allow tool execution when handler returns true"
        );

        // 6. Simulate check_tool_execution_allowed when the AskUserHandler denies it (returns false)
        handler.should_allow.store(false, Ordering::SeqCst);
        let allowed = check_tool_execution_allowed(agent_id, "dangerous_tool", "{}")
            .expect("Check should succeed");
        assert!(
            !allowed,
            "Should block tool execution when handler returns false"
        );
    }
}