Skip to main content

agy_bridge/runtime/
mod.rs

1//! Python runtime manager: owns a dedicated Python thread with an asyncio event loop.
2//!
3//! The `PythonRuntime` struct bridges Rust's tokio async world with Python's asyncio
4//! by running a command dispatch loop in a dedicated thread. Rust sends `PyCommand`
5//! messages via an `mpsc` channel, and receives results via per-command `oneshot` channels.
6//!
7//! # Threading architecture
8//!
9//! - **One Python thread**: All GIL acquisition is confined to a single dedicated thread
10//!   (`agy-bridge-python-runtime`). This thread runs an asyncio event loop via
11//!   `pyo3_async_runtimes::tokio::run_until_complete`.
12//!
13//! - **Concurrent command processing**: Commands received from the `mpsc` channel are
14//!   **not** serialized. Each command spawns a future into a `FuturesUnordered` task set,
15//!   and `tokio::select!` drives both incoming commands and in-flight task completions.
16//!   Multiple chats/operations run concurrently through the Python asyncio event loop.
17//!
18//! - **Rust tool dispatch**: When the Python SDK invokes a Rust tool, `dispatch_rust_tool`
19//!   reads tool state from `BRIDGE_STATE`, then uses `future_into_py` to run the async
20//!   tool on the tokio runtime — keeping the Python thread unblocked for other coroutines.
21//!
22//! - **Hook/policy dispatch**: Similarly, `dispatch_rust_hook` and `dispatch_rust_policy_confirm`
23//!   use `spawn_blocking` to run synchronous hook callbacks without holding the GIL.
24//!
25//! # Why global state (`BRIDGE_STATE`)?
26//!
27//! The Python SDK's tool/hook/policy callbacks are dispatched via PyO3 `#[pyfunction]`
28//! entries (e.g. `dispatch_rust_tool`, `dispatch_rust_hook`). These functions are
29//! registered as plain Python callables and receive **only** the arguments the SDK
30//! passes (agent ID + serialized context). There is no way to thread a Rust reference
31//! or `Arc` through the Python call boundary.
32//!
//! Therefore per-agent state (tool registries, hook runners, policy sets) is stored in
//! a global `RwLock<HashMap<u64, AgentBridgeState>>`, keyed by the raw numeric agent ID
//! (the same `u64` the Python side passes to the dispatch functions). The agent ID is
//! used as a lookup key, and the lock is held only for brief `HashMap` operations (never
//! across `.await` points). This is the standard pattern for PyO3 FFI bridges that need
//! to associate Rust state with Python-side identifiers.
38
39#![expect(clippy::useless_conversion)] // PyO3 #[pyfunction] wrapper generates .into() on PyErr
40use std::{collections::HashMap, sync::Arc, time::Duration};
41
42use pyo3::prelude::*;
43use tokio::sync::{mpsc, oneshot};
44
45use crate::{error::Error, quota::QuotaState};
46
47pub(crate) mod command_loop;
48mod handlers;
49pub(crate) mod py_scripts;
50pub(crate) mod streaming;
51pub(crate) mod venv;
52
/// Safety-net timeout for a single `send_command` round-trip.
///
/// This is the *outer* Rust-side timeout that wraps all commands sent to the
/// Python thread (chat, `create_agent`, cancel, `get_history`, …).  The Python
/// side applies its own, tighter timeouts (`chat_timeout`, `HANDLER_TIMEOUT`),
/// so this value should only fire if the Python thread is completely stuck.
///
/// Defaults to `chat_timeout + 2 minutes` to give inner timeouts room to
/// fire first.
///
/// The addition saturates at [`Duration::MAX`] instead of panicking, because
/// `chat_timeout` can originate from an environment variable and may be
/// arbitrarily large.
#[must_use]
pub fn default_operation_timeout(chat_timeout: Duration) -> Duration {
    // Head-room granted to the inner (Python-side) timeouts.
    const INNER_TIMEOUT_HEADROOM: Duration = Duration::from_secs(2 * 60);
    chat_timeout.saturating_add(INNER_TIMEOUT_HEADROOM)
}
/// Default budget, in seconds, for one `agent.chat()` round-trip.
///
/// Two minutes is generous for an ordinary turn while still surfacing a
/// stalled turn quickly.
pub const DEFAULT_CHAT_TIMEOUT_SECS: u64 = 2 * 60;

/// Default pause inserted between successive chat commands so that requests
/// are not issued in bursts.
pub const DEFAULT_INTER_AGENT_DELAY: Duration = Duration::from_millis(500);

/// Default buffer size of the command channel.
const DEFAULT_CHANNEL_CAPACITY: usize = 64;

/// Default time allowed for joining the Python thread during shutdown.
const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
78
79/// Returns the default chat round-trip timeout, configurable via
80/// `AGI_CHAT_TIMEOUT_SECS` (defaults to 120 s).
81#[must_use]
82pub fn default_chat_timeout() -> Duration {
83    let secs = std::env::var("AGI_CHAT_TIMEOUT_SECS").map_or(DEFAULT_CHAT_TIMEOUT_SECS, |val| {
84        val.parse::<u64>().unwrap_or_else(|e| {
85            tracing::warn!(
86                value = %val,
87                error = %e,
88                "Invalid AGI_CHAT_TIMEOUT_SECS, using default {DEFAULT_CHAT_TIMEOUT_SECS}s"
89            );
90            DEFAULT_CHAT_TIMEOUT_SECS
91        })
92    });
93    Duration::from_secs(secs)
94}
95
/// Opaque identifier for an agent, handed out by the runtime.
///
/// A thin newtype over `u64`; cheap to copy and usable as a hash-map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) struct AgentId(pub(crate) u64);

impl std::fmt::Display for AgentId {
    /// Formats the identifier as `agent-<n>`, where `<n>` is the wrapped number.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(raw) = *self;
        f.write_fmt(format_args!("agent-{raw}"))
    }
}
105
/// Per-agent state stored in the global [`BRIDGE_STATE`] registry.
///
/// Bundles all sidecar data that FFI callbacks need to look up by agent ID.
/// Consolidating into one struct means a single lock acquisition covers all
/// lookups/insertions/removals, preventing inconsistent partial state.
pub(crate) struct AgentBridgeState {
    /// Custom Rust tools registered for this agent.
    ///
    /// When `None`, `dispatch_rust_tool` fails the call with a
    /// "No active ToolRegistry" error.
    pub(crate) registry: Option<Arc<crate::tools::ToolRegistry>>,
    /// Lifecycle hooks for pre/post turn, tool-call gating, etc.
    ///
    /// When `None`, `dispatch_rust_hook` fails the call with a
    /// "No active Hooks" error.
    pub(crate) hook_runner: Option<Arc<crate::hooks::Hooks>>,
    /// Policy rules governing tool-call permissions.
    ///
    /// Evaluated by `check_tool_execution_allowed` before every Rust tool
    /// dispatch.
    pub(crate) policies: crate::policies::PolicySet,
    /// Interactive confirmation handler for `NeedsConfirmation` policies.
    ///
    /// When `None`, `check_tool_execution_allowed` treats a
    /// `NeedsConfirmation` decision as "not allowed".
    pub(crate) policy_handler: Option<Arc<dyn crate::policies::AskUserHandler>>,
    /// Shared key-value state persisted across tool calls for this agent.
    ///
    /// The `Arc` is cloned into each tool invocation's `ToolContext` by
    /// `dispatch_rust_tool`.
    pub(crate) tool_state: Arc<std::sync::RwLock<HashMap<String, serde_json::Value>>>,
}
123
124/// Single global registry of per-agent bridge state, keyed by agent ID.
125///
126/// # Lock choice
127///
128/// Uses `std::sync::RwLock` (not `tokio::sync::RwLock`) because the lock is
129/// held only for brief `HashMap` insert/remove/lookup operations and is never
130/// held across an `.await` point. This avoids the overhead of an async lock
131/// and is safe from deadlocks.
132///
133/// # Scalability
134///
135/// For typical agent counts (< ~100), `RwLock<HashMap>` provides sufficient
136/// throughput.  Read-side contention is bounded by the microsecond-scale lock
137/// duration.  If the bridge ever needs to support thousands of concurrent
138/// agents, replacing this with a `DashMap` would eliminate read-lock overhead
139/// entirely — but is unnecessary for current workloads.
140static BRIDGE_STATE: std::sync::OnceLock<
141    std::sync::RwLock<std::collections::HashMap<u64, AgentBridgeState>>,
142> = std::sync::OnceLock::new();
143
144/// Access the global per-agent bridge state registry.
145pub(crate) fn bridge_state()
146-> &'static std::sync::RwLock<std::collections::HashMap<u64, AgentBridgeState>> {
147    BRIDGE_STATE.get_or_init(|| std::sync::RwLock::new(std::collections::HashMap::new()))
148}
149
/// Fallback `Hooks` registry used during `create_agent` when the permanent entry is not yet registered.
///
/// `dispatch_rust_hook` consults this slot only when [`BRIDGE_STATE`] has no
/// entry for the calling agent ID. If the slot is empty as well, the hook call
/// fails with a Python `RuntimeError`.
pub(crate) static INITIALIZING_HOOK_RUNNER: std::sync::Mutex<Option<Arc<crate::hooks::Hooks>>> =
    std::sync::Mutex::new(None);

/// Serializes `create_agent` calls that install a temporary hook runner in
/// [`INITIALIZING_HOOK_RUNNER`], preventing concurrent creates from
/// overwriting each other's fallback runner.
///
/// This is a `tokio::sync::Mutex` (unlike the `std` locks used for the other
/// globals here) — presumably so the guard can be held across `.await` points
/// while an agent is being created; the acquiring code is not in this section.
pub(crate) static CREATE_AGENT_HOOK_GUARD: tokio::sync::Mutex<()> =
    tokio::sync::Mutex::const_new(());
159
160/// Execute a hook by name, deserializing the context JSON and calling the
161/// appropriate method on the runner. Returns the serialized result (empty
162/// string for void hooks).
163fn dispatch_hook_by_name(
164    hook_runner: &crate::hooks::Hooks,
165    hook_point: &str,
166    context_json: &str,
167) -> Result<String, crate::error::Error> {
168    let mut result_json = String::new();
169    match hook_point {
170        "pre_turn" => {
171            let ctx = serde_json::from_str::<crate::hooks::PreTurnContext>(context_json).map_err(
172                |e| crate::error::Error::BackendError {
173                    message: format!("Failed to deserialize PreTurnContext: {e}"),
174                },
175            )?;
176            hook_runner.run_pre_turn(&ctx);
177        }
178        "post_turn" => {
179            let ctx = serde_json::from_str::<crate::hooks::PostTurnContext>(context_json).map_err(
180                |e| crate::error::Error::BackendError {
181                    message: format!("Failed to deserialize PostTurnContext: {e}"),
182                },
183            )?;
184            hook_runner.run_post_turn(&ctx);
185        }
186        "pre_tool_call_decide" => {
187            let ctx = serde_json::from_str::<crate::hooks::PreToolCallDecideContext>(context_json)
188                .map_err(|e| crate::error::Error::BackendError {
189                    message: format!("Failed to deserialize PreToolCallDecideContext: {e} | JSON was: {context_json}"),
190                })?;
191            let hook_result = hook_runner.run_pre_tool_call_decide(&ctx);
192            result_json = serde_json::to_string(&hook_result).map_err(|e| {
193                crate::error::Error::BackendError {
194                    message: format!("Failed to serialize PreToolCallDecide result: {e}"),
195                }
196            })?;
197        }
198        "post_tool_call" => {
199            let ctx = serde_json::from_str::<crate::hooks::PostToolCallContext>(context_json)
200                .map_err(|e| crate::error::Error::BackendError {
201                    message: format!(
202                        "Failed to deserialize PostToolCallContext: {e} | JSON was: {context_json}"
203                    ),
204                })?;
205            hook_runner.run_post_tool_call(&ctx);
206        }
207        "on_compaction" => {
208            let ctx = serde_json::from_str::<crate::hooks::OnCompactionContext>(context_json)
209                .map_err(|e| crate::error::Error::BackendError {
210                    message: format!("Failed to deserialize OnCompactionContext: {e}"),
211                })?;
212            hook_runner.run_on_compaction(&ctx);
213        }
214        "on_session_start" => {
215            let ctx = serde_json::from_str::<crate::hooks::OnSessionStartContext>(context_json)
216                .map_err(|e| crate::error::Error::BackendError {
217                    message: format!("Failed to deserialize OnSessionStartContext: {e}"),
218                })?;
219            hook_runner.run_on_session_start(&ctx);
220        }
221        "on_session_end" => {
222            let ctx = serde_json::from_str::<crate::hooks::OnSessionEndContext>(context_json)
223                .map_err(|e| crate::error::Error::BackendError {
224                    message: format!("Failed to deserialize OnSessionEndContext: {e}"),
225                })?;
226            hook_runner.run_on_session_end(&ctx);
227        }
228        "on_tool_error" => {
229            let ctx = serde_json::from_str::<crate::hooks::OnToolErrorContext>(context_json)
230                .map_err(|e| crate::error::Error::BackendError {
231                    message: format!("Failed to deserialize OnToolErrorContext: {e}"),
232                })?;
233            hook_runner.run_on_tool_error(&ctx);
234        }
235        "on_interaction" => {
236            let ctx = serde_json::from_str::<crate::hooks::OnInteractionContext>(context_json)
237                .map_err(|e| crate::error::Error::BackendError {
238                    message: format!("Failed to deserialize OnInteractionContext: {e}"),
239                })?;
240            let hook_result = hook_runner.run_on_interaction(&ctx);
241            result_json = serde_json::to_string(&hook_result).map_err(|e| {
242                crate::error::Error::BackendError {
243                    message: format!("Failed to serialize OnInteraction result: {e}"),
244                }
245            })?;
246        }
247        _ => {
248            tracing::warn!("Unknown hook point: {}", hook_point);
249        }
250    }
251    Ok(result_json)
252}
253
254/// Dispatches a Rust hook call from the Python thread.
255#[pyfunction]
256pub(crate) fn dispatch_rust_hook(
257    py: Python<'_>,
258    agent_id: u64,
259    hook_point: String,
260    context_json: String,
261) -> PyResult<Bound<'_, PyAny>> {
262    tracing::debug!(agent_id, hook_point = %hook_point, "dispatch_rust_hook called from Python");
263    let hook_runner = {
264        let map = bridge_state().read().map_err(|e| {
265            pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to read BRIDGE_STATE: {e}"))
266        })?;
267        if let Some(entry) = map.get(&agent_id) {
268            let runner = entry.hook_runner.as_ref().ok_or_else(|| {
269                pyo3::exceptions::PyRuntimeError::new_err(format!(
270                    "No active Hooks found for agent ID {agent_id}"
271                ))
272            })?;
273            Arc::clone(runner)
274        } else {
275            let opt = INITIALIZING_HOOK_RUNNER.lock().map_err(|e| {
276                pyo3::exceptions::PyRuntimeError::new_err(format!(
277                    "Failed to lock INITIALIZING_HOOK_RUNNER: {e}"
278                ))
279            })?;
280            if let Some(ref runner) = *opt {
281                Arc::clone(runner)
282            } else {
283                return Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
284                    "No active bridge state or initializing hook runner found for agent ID {agent_id}"
285                )));
286            }
287        }
288    };
289
290    pyo3_async_runtimes::tokio::future_into_py(py, async move {
291        // SAFETY CONSTRAINT: Hooks dispatched here MUST NOT acquire the Python
292        // GIL. The Python thread (which holds the GIL) is blocked waiting for
293        // this future to complete via `future_into_py`. Acquiring the GIL from
294        // a blocking thread would deadlock.
295        let result = tokio::task::spawn_blocking(move || {
296            dispatch_hook_by_name(&hook_runner, &hook_point, &context_json)
297        })
298        .await
299        .map_err(|e| {
300            pyo3::exceptions::PyRuntimeError::new_err(format!("Hook execution failed: {e}"))
301        })?
302        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
303
304        Ok(result)
305    })
306}
307
308#[pyfunction]
309pub(crate) fn dispatch_rust_policy_confirm(
310    py: Python<'_>,
311    agent_id: u64,
312    tool_name: String,
313    args_json: String,
314) -> PyResult<Bound<'_, PyAny>> {
315    tracing::info!(agent_id, tool = %tool_name, "dispatch_rust_policy_confirm called from Python");
316    let policy_handler = {
317        let map = bridge_state().read().map_err(|e| {
318            pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to read BRIDGE_STATE: {e}"))
319        })?;
320        let entry = map.get(&agent_id).ok_or_else(|| {
321            pyo3::exceptions::PyRuntimeError::new_err(format!(
322                "No active bridge state found for agent ID {agent_id}"
323            ))
324        })?;
325        let handler = entry.policy_handler.as_ref().ok_or_else(|| {
326            pyo3::exceptions::PyRuntimeError::new_err(format!(
327                "No active AskUserHandler found for agent ID {agent_id}"
328            ))
329        })?;
330        Arc::clone(handler)
331    };
332
333    pyo3_async_runtimes::tokio::future_into_py(py, async move {
334        // SAFETY CONSTRAINT: Handlers dispatched here MUST NOT acquire the Python
335        // GIL. The Python thread is blocked waiting for this future.
336        let args_val: serde_json::Value = serde_json::from_str(&args_json).map_err(|e| {
337            pyo3::exceptions::PyValueError::new_err(format!(
338                "Failed to parse policy args JSON: {e}"
339            ))
340        })?;
341        let result =
342            tokio::task::spawn_blocking(move || policy_handler.confirm(&tool_name, &args_val))
343                .await
344                .map_err(|e| {
345                    pyo3::exceptions::PyRuntimeError::new_err(format!(
346                        "Policy confirmation panicked: {e}"
347                    ))
348                })?;
349
350        Ok(result)
351    })
352}
353
354/// Evaluates policies and registered handlers to check if a tool execution is allowed.
355pub(crate) fn check_tool_execution_allowed(
356    agent_id: u64,
357    name: &str,
358    args_json: &str,
359) -> Result<bool, crate::error::Error> {
360    let map = bridge_state()
361        .read()
362        .map_err(|e| crate::error::Error::BackendError {
363            message: format!("Failed to read BRIDGE_STATE: {e}"),
364        })?;
365
366    let Some(state) = map.get(&agent_id) else {
367        return Ok(false);
368    };
369
370    let (is_allowed, needs_confirm) = match state.policies.evaluate(name) {
371        crate::policies::PolicyDecision::Allow => (true, false),
372        crate::policies::PolicyDecision::Deny => (false, false),
373        crate::policies::PolicyDecision::NeedsConfirmation { .. } => (false, true),
374    };
375
376    if is_allowed {
377        return Ok(true);
378    }
379
380    if needs_confirm && let Some(ref handler) = state.policy_handler {
381        let handler = Arc::clone(handler);
382        // Drop the lock before calling the handler (it may block).
383        drop(map);
384        let args_val: serde_json::Value =
385            serde_json::from_str(args_json).map_err(|e| crate::error::Error::BackendError {
386                message: format!("Failed to parse policy args JSON: {e}"),
387            })?;
388        return Ok(handler.confirm(name, &args_val));
389    }
390
391    Ok(false)
392}
393
394/// Dispatches a Rust tool call from the Python thread.
395///
396/// Called by `AsyncRustProxy.__call__` in the Python SDK. Uses the stored
397/// tokio `Handle` to `block_on` the async `ToolRegistry::dispatch`, which
398/// is safe because this function runs on the Python thread (not a tokio worker).
399#[pyfunction]
400fn dispatch_rust_tool<'py>(
401    py: Python<'py>,
402    agent_id: u64,
403    name: String,
404    args_json: &str,
405) -> PyResult<Bound<'py, PyAny>> {
406    tracing::info!(agent_id, tool = %name, "dispatch_rust_tool called from Python (async)");
407
408    // Evaluate policies before tool dispatch
409    let is_allowed = check_tool_execution_allowed(agent_id, &name, args_json)
410        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
411
412    if !is_allowed {
413        return Err(pyo3::exceptions::PyPermissionError::new_err(format!(
414            "Tool '{name}' execution blocked by agent policy rules"
415        )));
416    }
417
418    let (registry, tool_state) = {
419        let map = bridge_state().read().map_err(|e| {
420            pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to read BRIDGE_STATE: {e}"))
421        })?;
422        let entry = map.get(&agent_id).ok_or_else(|| {
423            pyo3::exceptions::PyRuntimeError::new_err(format!(
424                "No active bridge state found for agent ID {agent_id}"
425            ))
426        })?;
427        let registry = entry.registry.as_ref().ok_or_else(|| {
428            pyo3::exceptions::PyRuntimeError::new_err(format!(
429                "No active ToolRegistry found for agent ID {agent_id}"
430            ))
431        })?;
432        (Arc::clone(registry), Arc::clone(&entry.tool_state))
433    };
434
435    let args: serde_json::Value = serde_json::from_str(args_json).map_err(|e| {
436        pyo3::exceptions::PyValueError::new_err(format!("Failed to parse tool arguments JSON: {e}"))
437    })?;
438
439    pyo3_async_runtimes::tokio::future_into_py(py, async move {
440        let ctx = crate::tools::ToolContext::with_shared_state(None, tool_state);
441        let output = registry
442            .dispatch(&name, args, &ctx)
443            .await
444            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
445        // Extract the text content for the Python SDK — metadata stays Rust-side.
446        Ok(output.into_content())
447    })
448}
449
/// Commands sent from Rust to the Python thread.
///
/// Each variant is constructed in `impl Runtime for PythonRuntime` and
/// dispatched in `command_loop::run_async_command_loop`.
///
/// Every variant except [`PyCommand::Shutdown`] carries a `reply` oneshot
/// sender; the receiving half is awaited by `PythonRuntime::send_command`,
/// which is how the outcome of the command travels back to the caller.
pub(crate) enum PyCommand {
    /// Create a new agent with the given configuration dict as JSON.
    ///
    /// Replies with the identifier of the newly created agent.
    CreateAgent {
        config_json: String,
        reply: oneshot::Sender<Result<AgentId, Error>>,
    },
    /// Send a chat message to an agent.
    ///
    /// Replies with a handle to the chat response.
    Chat {
        agent_id: AgentId,
        prompt: String,
        reply: oneshot::Sender<Result<crate::streaming::ChatResponseHandle, Error>>,
    },
    /// Shut down a specific agent.
    ShutdownAgent {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<(), Error>>,
    },
    /// Cancel active execution on the agent.
    Cancel {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<(), Error>>,
    },
    /// Wait for the agent to stabilize/become idle.
    WaitForIdle {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<(), Error>>,
    },
    /// Send a message without waiting for completion (fire-and-forget).
    Send {
        agent_id: AgentId,
        prompt: String,
        reply: oneshot::Sender<Result<(), Error>>,
    },
    /// Signal that the agent is idle.
    SignalIdle {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<(), Error>>,
    },
    /// Wait for the agent to wake up; returns true if woken, false on timeout.
    ///
    /// `timeout_secs` is the maximum wait, expressed in (fractional) seconds.
    WaitForWakeup {
        agent_id: AgentId,
        timeout_secs: f64,
        reply: oneshot::Sender<Result<bool, Error>>,
    },
    /// Shut down the entire Python runtime.
    ///
    /// The only variant without a `reply` channel; `PythonRuntime::shutdown`
    /// observes completion by joining the Python thread instead.
    Shutdown,
    /// Retrieve the conversation's message history.
    GetHistory {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<Vec<crate::types::ConversationMessage>, Error>>,
    },
    /// Return the number of completed turns.
    GetTurnCount {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<u32, Error>>,
    },
    /// Return cumulative token usage across all turns.
    GetTotalUsage {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<crate::types::UsageMetadata, Error>>,
    },
    /// Return token usage from the most recent turn.
    GetLastTurnUsage {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<crate::types::UsageMetadata, Error>>,
    },
    /// Clear the conversation history.
    ClearHistory {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<(), Error>>,
    },
    /// Return step indices where compaction occurred.
    GetCompactionIndices {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<Vec<u32>, Error>>,
    },
    /// Return the text of the last model response.
    ///
    /// Replies with `None` when there is no such response.
    GetLastResponse {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<Option<String>, Error>>,
    },
    /// Delete the conversation and all associated state.
    ///
    /// Constructed by `impl Runtime for PythonRuntime::delete()` — only
    /// reachable when an external consumer calls `AgentHandle::delete()`.
    Delete {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<(), Error>>,
    },
    /// Disconnect from the agent without deleting state.
    ///
    /// Constructed by `impl Runtime for PythonRuntime::disconnect()`.
    Disconnect {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<(), Error>>,
    },
    /// Check whether the agent is currently idle.
    ///
    /// Constructed by `impl Runtime for PythonRuntime::is_idle()`.
    IsIdle {
        agent_id: AgentId,
        reply: oneshot::Sender<Result<bool, Error>>,
    },
}
558
/// Configuration for the bridge runtime.
///
/// Because of `#[serde(default)]`, any field missing from deserialized input
/// takes its value from the [`Default`] implementation.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(default)]
pub struct RuntimeConfig {
    /// Channel buffer size for the command channel.
    ///
    /// Must be greater than zero; tokio's bounded `mpsc::channel` does not
    /// accept a zero capacity.
    pub channel_capacity: usize,
    /// Timeout for individual runtime operations.
    ///
    /// Defaults to [`default_operation_timeout`] applied to `chat_timeout`.
    pub operation_timeout: Duration,
    /// Timeout for joining the Python thread on shutdown.
    pub shutdown_timeout: Duration,
    /// Timeout for a single `agent.chat()` round-trip.
    ///
    /// Defaults to the value of `AGI_CHAT_TIMEOUT_SECS` (env var), or
    /// [`DEFAULT_CHAT_TIMEOUT_SECS`] (120 s) when the variable is unset or
    /// invalid.
    pub chat_timeout: Duration,
    /// Delay injected between successive chat commands to prevent burst requests.
    pub inter_agent_delay: Duration,
}
576
577impl Default for RuntimeConfig {
578    fn default() -> Self {
579        let chat_timeout = default_chat_timeout();
580        Self {
581            channel_capacity: DEFAULT_CHANNEL_CAPACITY,
582            operation_timeout: default_operation_timeout(chat_timeout),
583            shutdown_timeout: DEFAULT_SHUTDOWN_TIMEOUT,
584            chat_timeout,
585            inter_agent_delay: DEFAULT_INTER_AGENT_DELAY,
586        }
587    }
588}
589
/// Manages a dedicated Python thread with an asyncio event loop.
///
/// All Python/SDK interactions go through the command channel. This isolates
/// GIL acquisition to the Python thread and keeps the tokio runtime responsive.
///
/// Call [`PythonRuntime::shutdown`] before dropping the runtime; dropping it
/// while the thread handle is still present only logs a warning.
pub struct PythonRuntime {
    /// Sending half of the command channel consumed by the Python thread.
    cmd_tx: mpsc::Sender<PyCommand>,
    /// Join handle of the Python thread; taken (set to `None`) by `shutdown`.
    thread: Option<std::thread::JoinHandle<()>>,
    /// Configuration this runtime was created with.
    config: RuntimeConfig,
    /// Per-runtime quota registry. Each API key gets its own [`QuotaState`],
    /// and different `PythonRuntime` instances are fully independent.
    quota_registry: crate::quota::QuotaRegistry,
    /// Default quota state used by `send_command` for runtime-level backoff.
    ///
    /// Obtained from `quota_registry` under the empty key (`""`).
    quota_state: Arc<QuotaState>,
}
604
605impl std::fmt::Debug for PythonRuntime {
606    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
607        f.debug_struct("PythonRuntime")
608            .field("config", &self.config)
609            .field(
610                "thread_running",
611                &self.thread.as_ref().is_some_and(|t| !t.is_finished()),
612            )
613            .finish_non_exhaustive()
614    }
615}
616
617impl PythonRuntime {
618    /// Spawn a new Python runtime on a dedicated thread.
619    ///
620    /// Creates an asyncio event loop in the thread and starts the command
621    /// dispatch loop.
622    ///
623    /// # Errors
624    ///
625    /// Returns `Error::BackendError` if the thread fails to spawn or
626    /// Python initialization fails.
627    pub fn new(config: RuntimeConfig) -> Result<Self, Error> {
628        let (cmd_tx, cmd_rx) = mpsc::channel(config.channel_capacity);
629
630        let thread_config = config.clone();
631        let thread = std::thread::Builder::new()
632            .name("agy-bridge-python-runtime".into())
633            .spawn(move || {
634                python_thread_main(cmd_rx, &thread_config);
635            })
636            .map_err(|e| Error::BackendError {
637                message: format!("Failed to spawn Python runtime thread: {e}"),
638            })?;
639
640        let quota_registry = crate::quota::QuotaRegistry::new();
641        let quota_state = quota_registry.state_for_key("");
642        Ok(Self {
643            cmd_tx,
644            thread: Some(thread),
645            config,
646            quota_registry,
647            quota_state,
648        })
649    }
650
651    /// Send a command to the Python thread and await the result.
652    ///
653    /// This is the primary interface for all Python interactions. It checks
654    /// quota state before sending and applies a configurable timeout.
655    ///
656    /// # Errors
657    ///
658    /// Returns `Error::ChannelClosed` if the Python thread has exited,
659    /// `Error::Timeout` if the operation exceeds the configured timeout.
660    async fn send_command<T>(
661        &self,
662        operation: &str,
663        is_llm_op: bool,
664        build_cmd: impl FnOnce(oneshot::Sender<Result<T, Error>>) -> PyCommand,
665    ) -> Result<T, Error> {
666        self.quota_state.wait_for_quota().await;
667
668        let (reply_tx, reply_rx) = oneshot::channel();
669        let cmd = build_cmd(reply_tx);
670
671        self.cmd_tx
672            .send(cmd)
673            .await
674            .map_err(|e| Error::ChannelClosed {
675                message: format!("Python runtime thread has exited (sending {operation}): {e}"),
676            })?;
677
678        let result = crate::error::with_timeout(self.config.operation_timeout, operation, async {
679            reply_rx.await.map_err(|e| Error::ChannelClosed {
680                message: format!("Reply channel dropped for {operation}: {e}"),
681            })?
682        })
683        .await?;
684
685        // Only reset quota backoff for LLM operations (e.g. chat); non-LLM
686        // ops succeeding should not clear a 429 backoff.
687        if is_llm_op {
688            self.quota_state.record_success();
689        }
690
691        Ok(result)
692    }
693
694    /// Graceful shutdown: send `Shutdown` command, then join the thread.
695    ///
696    /// # Errors
697    ///
698    /// Returns `Error::Timeout` if the thread doesn't join within the
699    /// configured shutdown timeout, or `Error::BackendError` if the
700    /// thread panicked.
701    pub async fn shutdown(mut self) -> Result<(), Error> {
702        // Signal the command loop to exit.
703        // Ignoring send error: if the receiver is already gone the thread
704        // is already exiting, which is the outcome we want.
705        if let Err(e) = self.cmd_tx.send(PyCommand::Shutdown).await {
706            tracing::warn!("Shutdown command send failed (thread may already be exiting): {e}");
707        }
708
709        // Take the JoinHandle so Drop doesn't fire the "dropped without
710        // shutdown" warning.
711        let Some(thread) = self.thread.take() else {
712            tracing::warn!("PythonRuntime::shutdown() called but thread handle already taken");
713            return Ok(());
714        };
715
716        let shutdown_timeout = self.config.shutdown_timeout;
717        let join_result = tokio::time::timeout(
718            shutdown_timeout,
719            tokio::task::spawn_blocking(move || thread.join()),
720        )
721        .await;
722
723        match join_result {
724            Ok(Ok(Ok(()))) => {
725                tracing::info!("Python runtime thread joined successfully");
726                Ok(())
727            }
728            Ok(Ok(Err(panic_payload))) => {
729                let panic_msg = panic_payload.downcast_ref::<&str>().map_or_else(
730                    || {
731                        panic_payload
732                            .downcast_ref::<String>()
733                            .map_or_else(|| format!("{panic_payload:?}"), Clone::clone)
734                    },
735                    |s| (*s).to_string(),
736                );
737                tracing::error!(
738                    panic_message = %panic_msg,
739                    "Python runtime thread panicked during shutdown"
740                );
741                Err(Error::BackendError {
742                    message: format!("Python runtime thread panicked during shutdown: {panic_msg}"),
743                })
744            }
745            Ok(Err(join_err)) => {
746                tracing::error!("spawn_blocking join error: {join_err}");
747                Err(Error::BackendError {
748                    message: format!("Failed to join Python thread: {join_err}"),
749                })
750            }
751            Err(_elapsed) => {
752                tracing::error!(
753                    timeout_secs = shutdown_timeout.as_secs(),
754                    "Python runtime thread did not exit within shutdown timeout"
755                );
756                Err(Error::Timeout {
757                    duration: shutdown_timeout,
758                    operation: "PythonRuntime::shutdown (thread join)".to_string(),
759                })
760            }
761        }
762    }
763
764    /// Access the shared quota state.
765    #[must_use]
766    pub const fn quota_state(&self) -> &Arc<QuotaState> {
767        &self.quota_state
768    }
769}
770
771impl Drop for PythonRuntime {
772    fn drop(&mut self) {
773        if self.thread.is_some() {
774            tracing::warn!(
775                "PythonRuntime dropped without calling shutdown() — \
776                 Python thread may still be running"
777            );
778        }
779    }
780}
781
782/// Entry point for the dedicated Python thread.
783fn python_thread_main(cmd_rx: mpsc::Receiver<PyCommand>, config: &RuntimeConfig) {
784    pyo3::prepare_freethreaded_python();
785
786    // Environment variables are already loaded by load_dotenv() at bridge
787    // construction time, before any threads are spawned.
788
789    // Configure sys.path so the venv's site-packages are importable.
790    Python::with_gil(|py| {
791        if let Err(e) = venv::configure_python_sys_path(py) {
792            tracing::warn!("Failed to configure Python sys.path in runtime thread: {e}");
793        }
794    });
795
796    if let Err(e) = run_live_thread(cmd_rx, config) {
797        tracing::error!(error = %e, "Python runtime thread failed");
798    }
799
800    tracing::info!("Python runtime thread exiting");
801}
802
803/// Live SDK thread: creates an asyncio event loop and dispatches commands
804/// to the real Antigravity SDK via `pyo3_async_runtimes`.
805fn run_live_thread(cmd_rx: mpsc::Receiver<PyCommand>, config: &RuntimeConfig) -> Result<(), Error> {
806    Python::with_gil(|py| {
807        let asyncio = py
808            .import_bound("asyncio")
809            .map_err(|e| Error::BackendError {
810                message: format!("Failed to import asyncio: {e}"),
811            })?;
812        let event_loop =
813            asyncio
814                .call_method0("new_event_loop")
815                .map_err(|e| Error::BackendError {
816                    message: format!("Failed to create new asyncio event loop: {e}"),
817                })?;
818        asyncio
819            .call_method1("set_event_loop", (&event_loop,))
820            .map_err(|e| Error::BackendError {
821                message: format!("Failed to set asyncio event loop: {e}"),
822            })?;
823
824        // Register event_loop in globals for access from any thread
825        let sys = py.import_bound("sys").map_err(|e| Error::BackendError {
826            message: format!("Failed to import sys: {e}"),
827        })?;
828        let sys_modules = sys.getattr("modules").map_err(|e| Error::BackendError {
829            message: format!("Failed to get sys.modules: {e}"),
830        })?;
831        let globals_mod = if sys_modules
832            .contains(command_loop::AGY_BRIDGE_GLOBALS_MODULE)
833            .map_err(|e| Error::BackendError {
834                message: format!("Failed to check sys.modules: {e}"),
835            })? {
836            sys_modules
837                .get_item(command_loop::AGY_BRIDGE_GLOBALS_MODULE)
838                .map_err(|e| Error::BackendError {
839                    message: format!("Failed to get _agy_bridge_globals: {e}"),
840                })?
841        } else {
842            let types = py.import_bound("types").map_err(|e| Error::BackendError {
843                message: format!("Failed to import types: {e}"),
844            })?;
845            let module = types
846                .getattr("ModuleType")
847                .map_err(|e| Error::BackendError {
848                    message: format!("Failed to get ModuleType: {e}"),
849                })?
850                .call1((command_loop::AGY_BRIDGE_GLOBALS_MODULE,))
851                .map_err(|e| Error::BackendError {
852                    message: format!("Failed to create ModuleType: {e}"),
853                })?;
854            sys_modules
855                .set_item(command_loop::AGY_BRIDGE_GLOBALS_MODULE, &module)
856                .map_err(|e| Error::BackendError {
857                    message: format!("Failed to register _agy_bridge_globals: {e}"),
858                })?;
859            module
860        };
861        globals_mod
862            .setattr("EVENT_LOOP", &event_loop)
863            .map_err(|e| Error::BackendError {
864                message: format!("Failed to set EVENT_LOOP in globals: {e}"),
865            })?;
866
867        tracing::info!("Python asyncio event loop created on runtime thread");
868
869        let chat_timeout = config.chat_timeout;
870        let inter_agent_delay = config.inter_agent_delay;
871        let run_fut =
872            pyo3_async_runtimes::tokio::run_until_complete(event_loop.clone(), async move {
873                command_loop::run_async_command_loop(cmd_rx, chat_timeout, inter_agent_delay).await
874            });
875
876        if let Err(e) = run_fut {
877            // Close the event loop best-effort before propagating.
878            if let Err(close_err) = event_loop.call_method0("close") {
879                tracing::warn!("Failed to close asyncio event loop: {close_err}");
880            }
881            return Err(Error::BackendError {
882                message: format!("Python runtime command loop failed: {e}"),
883            });
884        }
885
886        if let Err(e) = event_loop.call_method0("close") {
887            tracing::warn!("Failed to close asyncio event loop: {e}");
888        }
889
890        Ok(())
891    })
892}
893
894impl crate::agent::Runtime for PythonRuntime {
895    async fn create_agent(
896        &self,
897        config: crate::config::AgentConfig,
898    ) -> Result<crate::agent::AgentId, Error> {
899        // Report all available tools as requested by the user.
900        let mut all_tools = config.custom_tool_names();
901        if let Some(ref caps) = config.capabilities {
902            if let Some(ref builtins) = caps.enabled_tools {
903                all_tools.extend(builtins.iter().map(|b| b.as_sdk_name().to_string()));
904            } else if caps.disabled_tools.is_none() {
905                // Default is all tools
906                all_tools.extend(
907                    crate::config::capabilities::BuiltinTools::all_tools()
908                        .iter()
909                        .map(|b| b.as_sdk_name().to_string()),
910                );
911            }
912        } else {
913            all_tools.extend(
914                crate::config::capabilities::BuiltinTools::all_tools()
915                    .iter()
916                    .map(|b| b.as_sdk_name().to_string()),
917            );
918        }
919        tracing::info!(
920            "Agent starting with {} available tools: {:?}",
921            all_tools.len(),
922            all_tools
923        );
924
925        let config_json = serde_json::to_string(&config).map_err(|e| Error::BackendError {
926            message: format!("Failed to serialize AgentConfig: {e}"),
927        })?;
928
929        let raw_id = self
930            .send_command("create_agent", false, |reply| PyCommand::CreateAgent {
931                config_json,
932                reply,
933            })
934            .await?;
935
936        Ok(raw_id.0)
937    }
938
939    async fn chat(
940        &self,
941        agent_id: crate::agent::AgentId,
942        content: &crate::content::Content,
943    ) -> Result<crate::streaming::ChatResponseHandle, Error> {
944        let prompt = match content {
945            crate::content::Content::Text { text } => text.clone(),
946            other => crate::content::content_to_json(other)?,
947        };
948        self.send_command("chat", true, |reply| PyCommand::Chat {
949            agent_id: AgentId(agent_id),
950            prompt,
951            reply,
952        })
953        .await
954    }
955
956    async fn shutdown_agent(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
957        self.send_command("shutdown_agent", false, |reply| PyCommand::ShutdownAgent {
958            agent_id: AgentId(agent_id),
959            reply,
960        })
961        .await
962    }
963
964    fn try_shutdown_agent(&self, agent_id: crate::agent::AgentId) {
965        // Fire-and-forget: create a oneshot whose receiver we drop immediately.
966        // The Python thread will still process the shutdown; we just don't wait
967        // for the result.
968        let (reply, _) = oneshot::channel();
969        if let Err(e) = self.cmd_tx.try_send(PyCommand::ShutdownAgent {
970            agent_id: AgentId(agent_id),
971            reply,
972        }) {
973            tracing::debug!(
974                agent_id = agent_id,
975                error = %e,
976                "try_shutdown_agent: channel send failed (runtime may already be gone)"
977            );
978        }
979    }
980
981    async fn cancel(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
982        self.send_command("cancel", false, |reply| PyCommand::Cancel {
983            agent_id: AgentId(agent_id),
984            reply,
985        })
986        .await
987    }
988
989    async fn wait_for_idle(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
990        self.send_command("wait_for_idle", false, |reply| PyCommand::WaitForIdle {
991            agent_id: AgentId(agent_id),
992            reply,
993        })
994        .await
995    }
996
997    async fn send(
998        &self,
999        agent_id: crate::agent::AgentId,
1000        content: &crate::content::Content,
1001    ) -> Result<(), Error> {
1002        let prompt = match content {
1003            crate::content::Content::Text { text } => text.clone(),
1004            other => crate::content::content_to_json(other)?,
1005        };
1006        self.send_command("send", false, |reply| PyCommand::Send {
1007            agent_id: AgentId(agent_id),
1008            prompt,
1009            reply,
1010        })
1011        .await
1012    }
1013
1014    async fn signal_idle(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
1015        self.send_command("signal_idle", false, |reply| PyCommand::SignalIdle {
1016            agent_id: AgentId(agent_id),
1017            reply,
1018        })
1019        .await
1020    }
1021
1022    async fn wait_for_wakeup(
1023        &self,
1024        agent_id: crate::agent::AgentId,
1025        timeout: std::time::Duration,
1026    ) -> Result<bool, Error> {
1027        self.send_command("wait_for_wakeup", false, |reply| PyCommand::WaitForWakeup {
1028            agent_id: AgentId(agent_id),
1029            timeout_secs: timeout.as_secs_f64(),
1030            reply,
1031        })
1032        .await
1033    }
1034
1035    async fn wait_for_quota(&self) {
1036        self.quota_state.wait_for_quota().await;
1037    }
1038
1039    async fn record_quota_hit(&self, retry_after: std::time::Duration) {
1040        self.quota_state.record_quota_hit(retry_after);
1041    }
1042
1043    fn quota_registry(&self) -> &crate::quota::QuotaRegistry {
1044        &self.quota_registry
1045    }
1046
1047    async fn history(
1048        &self,
1049        agent_id: crate::agent::AgentId,
1050    ) -> Result<Vec<crate::types::ConversationMessage>, Error> {
1051        self.send_command("get_history", false, |reply| PyCommand::GetHistory {
1052            agent_id: AgentId(agent_id),
1053            reply,
1054        })
1055        .await
1056    }
1057
1058    async fn turn_count(&self, agent_id: crate::agent::AgentId) -> Result<u32, Error> {
1059        self.send_command("get_turn_count", false, |reply| PyCommand::GetTurnCount {
1060            agent_id: AgentId(agent_id),
1061            reply,
1062        })
1063        .await
1064    }
1065
1066    async fn total_usage(
1067        &self,
1068        agent_id: crate::agent::AgentId,
1069    ) -> Result<crate::types::UsageMetadata, Error> {
1070        self.send_command("get_total_usage", false, |reply| PyCommand::GetTotalUsage {
1071            agent_id: AgentId(agent_id),
1072            reply,
1073        })
1074        .await
1075    }
1076
1077    async fn last_turn_usage(
1078        &self,
1079        agent_id: crate::agent::AgentId,
1080    ) -> Result<crate::types::UsageMetadata, Error> {
1081        self.send_command("get_last_turn_usage", false, |reply| {
1082            PyCommand::GetLastTurnUsage {
1083                agent_id: AgentId(agent_id),
1084                reply,
1085            }
1086        })
1087        .await
1088    }
1089
1090    async fn clear_history(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
1091        self.send_command("clear_history", false, |reply| PyCommand::ClearHistory {
1092            agent_id: AgentId(agent_id),
1093            reply,
1094        })
1095        .await
1096    }
1097
1098    async fn compaction_indices(&self, agent_id: crate::agent::AgentId) -> Result<Vec<u32>, Error> {
1099        self.send_command("compaction_indices", false, |reply| {
1100            PyCommand::GetCompactionIndices {
1101                agent_id: AgentId(agent_id),
1102                reply,
1103            }
1104        })
1105        .await
1106    }
1107
1108    async fn last_response(
1109        &self,
1110        agent_id: crate::agent::AgentId,
1111    ) -> Result<Option<String>, Error> {
1112        self.send_command("last_response", false, |reply| PyCommand::GetLastResponse {
1113            agent_id: AgentId(agent_id),
1114            reply,
1115        })
1116        .await
1117    }
1118
1119    async fn delete(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
1120        self.send_command("delete", false, |reply| PyCommand::Delete {
1121            agent_id: AgentId(agent_id),
1122            reply,
1123        })
1124        .await
1125    }
1126
1127    async fn disconnect(&self, agent_id: crate::agent::AgentId) -> Result<(), Error> {
1128        self.send_command("disconnect", false, |reply| PyCommand::Disconnect {
1129            agent_id: AgentId(agent_id),
1130            reply,
1131        })
1132        .await
1133    }
1134
1135    async fn is_idle(&self, agent_id: crate::agent::AgentId) -> Result<bool, Error> {
1136        self.send_command("is_idle", false, |reply| PyCommand::IsIdle {
1137            agent_id: AgentId(agent_id),
1138            reply,
1139        })
1140        .await
1141    }
1142}
1143
#[cfg(test)]
mod tests {
    use super::*;

    /// Small, fast configuration shared by the tests in this module.
    fn test_config() -> RuntimeConfig {
        RuntimeConfig {
            channel_capacity: 16,
            operation_timeout: Duration::from_secs(10),
            shutdown_timeout: Duration::from_secs(5),
            chat_timeout: Duration::from_mins(1),
            inter_agent_delay: Duration::from_millis(100),
        }
    }

    /// Defines a bare Python exception class named `class_name` (no SDK
    /// ancestry), instantiates it, and returns how `classify_py_error` maps it.
    fn classify_bare_exception(class_name: &str) -> crate::error::Error {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let globals = pyo3::types::PyDict::new_bound(py);
            let code = format!(
                "\nclass {class_name}(Exception):\n    pass\nerr = {class_name}(\"dummy\")\n"
            );
            py.run_bound(&code, Some(&globals), None).unwrap();

            let err_obj = globals.get_item("err").unwrap().unwrap();
            let err = PyErr::from_value_bound(err_obj);

            crate::error::classify_py_error(py, &err)
        })
    }

    /// Removes an agent's entry from the global bridge state when dropped, so
    /// a failing assertion cannot leak state into other tests.
    struct BridgeStateCleanup(u64);

    impl Drop for BridgeStateCleanup {
        fn drop(&mut self) {
            // Skip silently if the lock is poisoned; panicking in Drop during
            // an unwind would abort the test process.
            if let Ok(mut state) = bridge_state().write() {
                state.remove(&self.0);
            }
        }
    }

    #[tokio::test]
    async fn test_runtime_creation_and_shutdown() {
        // Shutdown should complete cleanly.
        PythonRuntime::new(test_config())
            .expect("Failed to create runtime")
            .shutdown()
            .await
            .expect("Shutdown failed");
    }

    #[test]
    fn runtime_config_serde_roundtrip() {
        let config = test_config();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: RuntimeConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.channel_capacity, 16);
        assert_eq!(parsed.operation_timeout, Duration::from_secs(10));
        assert_eq!(parsed.shutdown_timeout, Duration::from_secs(5));
        assert_eq!(parsed.chat_timeout, Duration::from_mins(1));
        assert_eq!(parsed.inter_agent_delay, Duration::from_millis(100));
    }

    #[test]
    fn default_operation_timeout_is_chat_plus_margin() {
        let config = RuntimeConfig::default();
        let expected = config.chat_timeout + Duration::from_mins(2);
        assert_eq!(
            config.operation_timeout, expected,
            "operation_timeout should be chat_timeout + 2min safety margin"
        );
    }

    /// A class that merely shares the name `StopCandidateException` must not
    /// be classified as a safety error.
    #[test]
    fn safety_error_structural() {
        let mapped = classify_bare_exception("StopCandidateException");

        assert!(
            !matches!(mapped, crate::error::Error::Safety),
            "Failed: matched Error::Safety based purely on the string name StopCandidateException!"
        );
    }

    /// A class that merely shares the name `MaxTokensException` must not be
    /// classified as a max-tokens error.
    #[test]
    fn maxtokens_error_structural() {
        let mapped = classify_bare_exception("MaxTokensException");

        assert!(
            !matches!(mapped, crate::error::Error::MaxTokens),
            "Failed: matched Error::MaxTokens based purely on the string name MaxTokensException!"
        );
    }

    /// Ask-user handler whose answer can be flipped between calls.
    struct MockAskUserHandler {
        should_allow: std::sync::atomic::AtomicBool,
    }

    impl crate::policies::AskUserHandler for MockAskUserHandler {
        fn confirm(&self, _tool_name: &str, _tool_args: &serde_json::Value) -> bool {
            self.should_allow.load(std::sync::atomic::Ordering::SeqCst)
        }
    }

    #[test]
    fn test_ask_user_policy_custom_tool_gating() {
        let agent_id: u64 = 999;
        // Declared first so it is dropped last: the bridge-state entry is
        // removed even when an assertion below panics.
        let _cleanup = BridgeStateCleanup(agent_id);

        // 1. Setup the PolicySet with an AskUser rule for "dangerous_tool"
        let mut policies = crate::policies::PolicySet::new();
        policies
            .push(crate::policies::PolicyRule::AskUser {
                tool: "dangerous_tool".to_owned(),
                handler_id: "confirm_handler".to_owned(),
            })
            .unwrap();

        // 2. Setup mock handler
        let handler = Arc::new(MockAskUserHandler {
            should_allow: std::sync::atomic::AtomicBool::new(true),
        });

        // 3. Mock the tool registry
        let mut registry = crate::tools::ToolRegistry::new();

        /// A dangerous tool.
        #[crate::llm_tool]
        fn dangerous_tool() -> Result<String, String> {
            Ok("Executed dangerous action!".to_owned())
        }
        registry.register(DangerousTool);

        // 4. Register all state in a single bridge_state() insertion
        bridge_state().write().unwrap().insert(
            agent_id,
            AgentBridgeState {
                registry: Some(Arc::new(registry)),
                hook_runner: None,
                policies,
                policy_handler: Some(
                    Arc::clone(&handler) as Arc<dyn crate::policies::AskUserHandler>
                ),
                tool_state: Arc::new(std::sync::RwLock::new(HashMap::new())),
            },
        );

        // 5. Simulate check_tool_execution_allowed when the AskUserHandler allows it (returns true)
        handler
            .should_allow
            .store(true, std::sync::atomic::Ordering::SeqCst);
        let res = check_tool_execution_allowed(agent_id, "dangerous_tool", "{}");
        assert!(res.is_ok(), "Check should succeed");
        assert!(
            res.unwrap(),
            "Should allow tool execution when handler returns true"
        );

        // 6. Simulate check_tool_execution_allowed when the AskUserHandler denies it (returns false)
        handler
            .should_allow
            .store(false, std::sync::atomic::Ordering::SeqCst);
        let res = check_tool_execution_allowed(agent_id, "dangerous_tool", "{}");
        assert!(res.is_ok(), "Check should succeed");
        assert!(
            !res.unwrap(),
            "Should block tool execution when handler returns false"
        );

        // Cleanup happens in `BridgeStateCleanup::drop`.
    }
}