use std::sync::{
Arc, Mutex,
atomic::{AtomicBool, Ordering},
};
use crate::{
config::AgentConfig,
content::Content,
error::Error,
streaming::{ChatResponseHandle, ChatResponseSharedState},
types::{ConversationMessage, UsageMetadata},
};
/// Backoff recorded for quota errors that carry no explicit retry hint.
const DEFAULT_QUOTA_BACKOFF: std::time::Duration = std::time::Duration::from_secs(2);
/// Retry hint attached to the fallback `QuotaExceeded` error once all retries
/// are spent: two minutes, written in seconds so no recent-toolchain-only
/// `Duration` constructor is required.
const QUOTA_EXHAUSTED_RETRY_AFTER: std::time::Duration = std::time::Duration::from_secs(120);
/// Identifier that a [`Runtime`] hands out from `create_agent` and that every
/// other per-agent operation is addressed by.
pub type AgentId = u64;

// Test-only module (compiled only under `cfg(test)`); its contents live in the
// sibling `mock` source file — presumably a mock `Runtime`, confirm there.
#[cfg(test)]
pub(crate) mod mock;
#[expect(
async_fn_in_trait,
reason = "Runtime is not object-safe by design; callers always know the concrete type"
)]
pub trait Runtime: Send + Sync {
async fn create_agent(&self, config: AgentConfig) -> Result<AgentId, Error>;
async fn chat(&self, agent_id: AgentId, content: &Content)
-> Result<ChatResponseHandle, Error>;
async fn shutdown_agent(&self, agent_id: AgentId) -> Result<(), Error>;
async fn cancel(&self, agent_id: AgentId) -> Result<(), Error>;
async fn wait_for_idle(&self, agent_id: AgentId) -> Result<(), Error>;
async fn send(&self, agent_id: AgentId, content: &Content) -> Result<(), Error>;
async fn signal_idle(&self, agent_id: AgentId) -> Result<(), Error>;
async fn wait_for_wakeup(
&self,
agent_id: AgentId,
timeout: std::time::Duration,
) -> Result<bool, Error>;
async fn wait_for_quota(&self);
async fn record_quota_hit(&self, retry_after: std::time::Duration);
fn quota_registry(&self) -> &crate::quota::QuotaRegistry;
async fn history(&self, agent_id: AgentId) -> Result<Vec<ConversationMessage>, Error>;
async fn turn_count(&self, agent_id: AgentId) -> Result<u32, Error>;
async fn total_usage(&self, agent_id: AgentId) -> Result<UsageMetadata, Error>;
async fn last_turn_usage(&self, agent_id: AgentId) -> Result<UsageMetadata, Error>;
async fn clear_history(&self, agent_id: AgentId) -> Result<(), Error>;
async fn last_response(&self, _agent_id: AgentId) -> Result<Option<String>, Error> {
Ok(None)
}
async fn compaction_indices(&self, _agent_id: AgentId) -> Result<Vec<u32>, Error> {
Ok(Vec::new())
}
async fn delete(&self, _agent_id: AgentId) -> Result<(), Error> {
Ok(())
}
async fn disconnect(&self, _agent_id: AgentId) -> Result<(), Error> {
Ok(())
}
async fn is_idle(&self, _agent_id: AgentId) -> Result<bool, Error> {
Ok(true)
}
fn try_shutdown_agent(&self, _agent_id: AgentId) {}
}
/// Handle to a single agent hosted by a [`Runtime`].
///
/// All operations take `&self`; mutable bookkeeping lives behind atomics and
/// mutexes. Dropping the handle without calling `shutdown` asks the runtime
/// for a best-effort shutdown and removes the agent's bridge-state entry
/// (see the `Drop` impl in this file).
///
/// NOTE: fields are dropped in declaration order after `Drop::drop` runs;
/// keep the order stable when editing.
pub struct AgentHandle<R: Runtime + 'static> {
    // Identifier returned by `Runtime::create_agent`.
    id: AgentId,
    // Shared backend; also reused for sub-agents created by `spawn_subagent`.
    runtime: Arc<R>,
    // Configuration the agent was created with (e.g. read for `max_quota_retries`).
    config: AgentConfig,
    // Quota state obtained from the runtime's registry for this agent's API key.
    quota_state: Arc<crate::quota::QuotaState>,
    // Tool registry held for the handle's lifetime; not read by the handle itself.
    _registry: Option<Arc<crate::tools::ToolRegistry>>,
    // Ask-user policy handler, forwarded to sub-agents by `spawn_subagent`.
    policy_handler: Option<Arc<dyn crate::policies::AskUserHandler>>,
    // Locally cached conversation id; a mutex so it can be replaced through `&self`.
    conversation_id: Mutex<Option<String>>,
    // Set to `true` at construction; nothing in this file clears it.
    is_started: AtomicBool,
    // Set by `shutdown`, `delete` and `disconnect`; once set, `chat`/`send` are rejected.
    is_shutdown: AtomicBool,
    // Shared state of the most recent successful `chat` response, read by the
    // `get_last_*` accessors.
    last_shared_state: Mutex<Option<Arc<Mutex<ChatResponseSharedState>>>>,
}
impl<R: Runtime> AgentHandle<R> {
pub async fn new(
runtime: Arc<R>,
config: AgentConfig,
registry: Option<Arc<crate::tools::ToolRegistry>>,
hook_runner: Option<Arc<crate::hooks::Hooks>>,
policy_handler: Option<Arc<dyn crate::policies::AskUserHandler>>,
) -> Result<Self, Error> {
let quota_key = config.effective_api_key().unwrap_or_default();
let quota_state = runtime.quota_registry().state_for_key("a_key);
let id = if hook_runner.is_some() {
let _guard = crate::runtime::CREATE_AGENT_HOOK_GUARD.lock().await;
if let Ok(mut opt) = crate::runtime::INITIALIZING_HOOK_RUNNER.lock() {
*opt = hook_runner.as_ref().map(Arc::clone);
} else {
tracing::error!("INITIALIZING_HOOK_RUNNER mutex poisoned — hook may not fire");
}
let result = runtime.create_agent(config.clone()).await;
if let Ok(mut opt) = crate::runtime::INITIALIZING_HOOK_RUNNER.lock() {
*opt = None;
} else {
tracing::error!("INITIALIZING_HOOK_RUNNER mutex poisoned — stale hook may persist");
}
result?
} else {
runtime.create_agent(config.clone()).await?
};
tracing::info!(agent_id = id, "Agent created successfully");
let policies_set = crate::policies::PolicySet::validated_from(config.policies.clone())?;
let bridge_entry = crate::runtime::AgentBridgeState {
registry: registry.as_ref().map(Arc::clone),
hook_runner: hook_runner.as_ref().map(Arc::clone),
policies: policies_set,
policy_handler: policy_handler.as_ref().map(Arc::clone),
tool_state: Arc::new(std::sync::RwLock::new(std::collections::HashMap::new())),
};
if let Ok(mut map) = crate::runtime::bridge_state().write() {
map.insert(id, bridge_entry);
} else {
tracing::error!(
agent_id = id,
"Failed to acquire write lock on BRIDGE_STATE"
);
}
let conversation_id = Mutex::new(config.conversation_id.clone());
Ok(Self {
id,
runtime,
config,
quota_state,
_registry: registry,
policy_handler,
conversation_id,
is_started: AtomicBool::new(true),
is_shutdown: AtomicBool::new(false),
last_shared_state: Mutex::new(None),
})
}
pub async fn chat(&self, content: impl Into<Content>) -> Result<ChatResponseHandle, Error> {
if !self.is_started() {
return Err(Error::AgentNotStarted);
}
let content = content.into();
let max_retries = self.config.max_quota_retries.unwrap_or(0);
let handle = 'retry: {
for attempt in 0..=max_retries {
if attempt > 0 {
self.quota_state.wait_for_quota().await;
}
match self.runtime.chat(self.id, &content).await {
Ok(h) => break 'retry h,
Err(Error::QuotaExceeded { retry_after }) => {
self.handle_quota_error("chat", attempt, max_retries, retry_after)?;
}
Err(ref e) if e.is_quota_error() => {
self.handle_quota_error(
"chat",
attempt,
max_retries,
DEFAULT_QUOTA_BACKOFF,
)?;
}
Err(e) => return Err(e),
}
}
return Err(Error::QuotaExceeded {
retry_after: QUOTA_EXHAUSTED_RETRY_AFTER,
});
};
if let Ok(mut guard) = self.last_shared_state.lock() {
*guard = Some(Arc::clone(&handle.shared_state));
} else {
tracing::error!("last_shared_state mutex poisoned — streaming metadata may be stale");
}
Ok(handle)
}
pub async fn chat_text(&self, message: impl Into<Content>) -> Result<String, Error> {
let response = self.chat(message.into()).await?;
let text = response.text().await.map_err(|e| {
let converted = Error::from(e);
if matches!(converted, Error::Safety) {
converted
} else {
Error::BackendError {
message: format!("Failed to read response text: {converted}"),
}
}
})?;
Ok(text.into_string())
}
#[must_use]
pub fn conversation_id(&self) -> Option<String> {
self.conversation_id
.lock()
.inspect_err(|e| {
tracing::error!(
agent_id = self.id,
error = %e,
"conversation_id mutex poisoned"
);
})
.ok()
.and_then(|guard| guard.clone())
}
pub fn set_conversation_id(&self, id: String) {
if let Ok(mut guard) = self.conversation_id.lock() {
*guard = Some(id);
} else {
tracing::error!("Failed to acquire lock on conversation_id");
}
}
#[must_use]
pub fn is_started(&self) -> bool {
self.is_started.load(Ordering::SeqCst) && !self.is_shutdown.load(Ordering::SeqCst)
}
#[must_use]
pub const fn id(&self) -> AgentId {
self.id
}
#[must_use]
pub const fn config(&self) -> &AgentConfig {
&self.config
}
pub async fn cancel(&self) -> Result<(), Error> {
self.runtime.cancel(self.id).await
}
pub async fn wait_for_idle(&self) -> Result<(), Error> {
self.runtime.wait_for_idle(self.id).await
}
pub async fn history(&self) -> Result<Vec<ConversationMessage>, Error> {
self.runtime.history(self.id).await
}
pub async fn turn_count(&self) -> Result<u32, Error> {
self.runtime.turn_count(self.id).await
}
pub async fn total_usage(&self) -> Result<UsageMetadata, Error> {
self.runtime.total_usage(self.id).await
}
pub async fn last_turn_usage(&self) -> Result<UsageMetadata, Error> {
self.runtime.last_turn_usage(self.id).await
}
pub async fn clear_history(&self) -> Result<(), Error> {
self.runtime.clear_history(self.id).await
}
pub async fn last_response(&self) -> Result<Option<String>, Error> {
self.runtime.last_response(self.id).await
}
pub async fn compaction_indices(&self) -> Result<Vec<u32>, Error> {
self.runtime.compaction_indices(self.id).await
}
pub async fn delete(&self) -> Result<(), Error> {
let result = self.runtime.delete(self.id).await;
self.is_shutdown.store(true, Ordering::SeqCst);
result
}
pub async fn disconnect(&self) -> Result<(), Error> {
let result = self.runtime.disconnect(self.id).await;
self.is_shutdown.store(true, Ordering::SeqCst);
result
}
pub async fn is_idle(&self) -> Result<bool, Error> {
self.runtime.is_idle(self.id).await
}
#[must_use]
pub fn get_last_structured_output(&self) -> Option<serde_json::Value> {
let guard = self
.last_shared_state
.lock()
.inspect_err(|e| {
tracing::error!(
agent_id = self.id,
error = %e,
"last_shared_state mutex poisoned in get_last_structured_output"
);
})
.ok()?;
let state = guard
.as_ref()?
.lock()
.inspect_err(|e| {
tracing::error!(
agent_id = self.id,
error = %e,
"ChatResponseSharedState mutex poisoned in get_last_structured_output"
);
})
.ok()?;
state.structured_output.clone()
}
#[must_use]
pub fn get_last_usage(&self) -> Option<UsageMetadata> {
let guard = self
.last_shared_state
.lock()
.inspect_err(|e| {
tracing::error!(
agent_id = self.id,
error = %e,
"last_shared_state mutex poisoned in get_last_usage"
);
})
.ok()?;
let state = guard
.as_ref()?
.lock()
.inspect_err(|e| {
tracing::error!(
agent_id = self.id,
error = %e,
"ChatResponseSharedState mutex poisoned in get_last_usage"
);
})
.ok()?;
state.usage.clone()
}
pub async fn send(&self, content: impl Into<Content>) -> Result<(), Error> {
if !self.is_started() {
return Err(Error::AgentNotStarted);
}
let content = content.into();
let max_retries = self.config.max_quota_retries.unwrap_or(0);
for attempt in 0..=max_retries {
if attempt > 0 {
self.quota_state.wait_for_quota().await;
}
match self.runtime.send(self.id, &content).await {
Ok(()) => return Ok(()),
Err(Error::QuotaExceeded { retry_after }) => {
self.handle_quota_error("send", attempt, max_retries, retry_after)?;
}
Err(ref e) if e.is_quota_error() => {
self.handle_quota_error("send", attempt, max_retries, DEFAULT_QUOTA_BACKOFF)?;
}
Err(e) => return Err(e),
}
}
Err(Error::QuotaExceeded {
retry_after: QUOTA_EXHAUSTED_RETRY_AFTER,
})
}
pub async fn signal_idle(&self) -> Result<(), Error> {
self.runtime.signal_idle(self.id).await
}
pub async fn wait_for_wakeup(&self, timeout: std::time::Duration) -> Result<bool, Error> {
self.runtime.wait_for_wakeup(self.id, timeout).await
}
fn handle_quota_error(
&self,
operation: &str,
attempt: u32,
max_retries: u32,
retry_after: std::time::Duration,
) -> Result<(), Error> {
if attempt >= max_retries {
return Err(Error::QuotaExceeded { retry_after });
}
tracing::warn!(
agent_id = self.id,
attempt = attempt + 1,
max = max_retries,
retry_after_ms = u64::try_from(retry_after.as_millis()).unwrap_or_else(|e| {
tracing::warn!("Int conversion failed: {e}");
u64::MAX
}),
"Quota exceeded on {operation} — recording hit and retrying"
);
self.quota_state.record_quota_hit(retry_after);
Ok(())
}
pub async fn shutdown(&self) -> Result<(), Error> {
if self.is_shutdown.load(Ordering::SeqCst) {
tracing::debug!(agent_id = self.id, "Agent already shut down");
return Ok(());
}
tracing::info!(agent_id = self.id, "Shutting down agent");
let result = self.runtime.shutdown_agent(self.id).await;
self.is_shutdown.store(true, Ordering::SeqCst);
match result {
Ok(()) => {
tracing::info!(agent_id = self.id, "Agent shut down successfully");
}
Err(ref e) => {
tracing::error!(agent_id = self.id, error = ?e, "Agent shutdown failed");
}
}
result
}
pub async fn spawn_subagent(
&self,
mut config: AgentConfig,
registry: impl Into<Option<crate::tools::ToolRegistry>>,
) -> Result<Self, Error> {
let opt_registry = registry.into();
if let Some(disp) = &opt_registry
&& config.tools.is_empty()
{
config.tools = disp.definitions();
}
let arc_registry = opt_registry.map(Arc::new);
Self::new(
Arc::clone(&self.runtime),
config,
arc_registry,
None,
self.policy_handler.clone(),
)
.await
}
}
impl<R: Runtime> Drop for AgentHandle<R> {
    /// Best-effort cleanup for handles that were never shut down explicitly.
    ///
    /// If the handle is still live (started and not shut down), the runtime is
    /// asked to stop the agent through the synchronous
    /// [`Runtime::try_shutdown_agent`]. In every case, the agent's entry is
    /// then removed from the global bridge state.
    ///
    /// A poisoned lock is logged instead of unwrapped. A panic inside `drop`
    /// during unwinding would abort the process.
    fn drop(&mut self) {
        let still_live =
            self.is_started.load(Ordering::SeqCst) && !self.is_shutdown.load(Ordering::SeqCst);
        if still_live {
            tracing::debug!(
                agent_id = self.id,
                "AgentHandle dropped without explicit shutdown() — \
                 sending best-effort shutdown signal"
            );
            self.runtime.try_shutdown_agent(self.id);
        }

        // The removed entry is dropped inside the closure, while the write
        // guard is still held. On poisoning, the guard carried by the error is
        // released at the end of this statement, before the log call below.
        let entry_removed = crate::runtime::bridge_state()
            .write()
            .map(|mut entries| {
                entries.remove(&self.id);
            })
            .is_ok();
        if !entry_removed {
            tracing::error!(
                agent_id = self.id,
                "BRIDGE_STATE RwLock poisoned during Drop — \
                 bridge state entry for this agent may leak"
            );
        }
    }
}