#[path = "agent/checkpointing.rs"]
mod checkpointing;
#[path = "agent/control.rs"]
mod control;
#[path = "agent/events.rs"]
mod events;
#[path = "agent/invoke.rs"]
mod invoke;
#[path = "agent/mutation.rs"]
mod mutation;
#[path = "agent/queueing.rs"]
mod queueing;
#[path = "agent/state_updates.rs"]
mod state_updates;
#[path = "agent/structured_output.rs"]
mod structured_output;
use std::collections::{HashSet, VecDeque};
use std::sync::atomic::{AtomicBool, AtomicU64};
use std::sync::{Arc, Mutex};
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
use crate::agent_id::AgentId;
use crate::agent_options::{
ApproveToolArc, AsyncTransformContextArc, CheckpointStoreArc, ConvertToLlmFn, GetApiKeyArc,
TransformContextArc,
};
use crate::agent_subscriptions::ListenerRegistry;
use crate::error::AgentError;
use crate::message_provider::MessageProvider;
use crate::retry::RetryStrategy;
use crate::stream::{StreamFn, StreamOptions};
use crate::tool::{AgentTool, ApprovalMode};
use crate::types::{AgentMessage, LlmMessage, ModelSpec};
pub use crate::agent_options::{AgentOptions, DEFAULT_PLAN_MODE_ADDENDUM};
pub use crate::agent_subscriptions::SubscriptionId;
/// Steering-queue mode for an [`Agent`], configured through
/// `AgentOptions::steering_mode`.
///
/// NOTE(review): the code that drains the steering queue lives outside this
/// file; the variant descriptions below are inferred from their names — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum SteeringMode {
    /// Presumably takes every queued steering message at once.
    All,
    /// Presumably takes one queued steering message at a time.
    /// This is the `Default` variant.
    #[default]
    OneAtATime,
}
/// Follow-up-queue mode for an [`Agent`], configured through
/// `AgentOptions::follow_up_mode`.
///
/// NOTE(review): the code that drains the follow-up queue lives outside this
/// file; the variant descriptions below are inferred from their names — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum FollowUpMode {
    /// Presumably takes every queued follow-up message at once.
    All,
    /// Presumably takes one queued follow-up message at a time.
    /// This is the `Default` variant.
    #[default]
    OneAtATime,
}
/// Observable state of an [`Agent`], returned by [`Agent::state`].
///
/// The manual `Debug` impl for this type summarizes `tools` as a count
/// instead of printing each tool.
pub struct AgentState {
    /// System prompt in effect; `Agent::new` fills it from
    /// `AgentOptions::effective_system_prompt()`.
    pub system_prompt: String,
    /// Primary model specification.
    pub model: ModelSpec,
    /// Tools available to the agent. `Agent::new` panics if two of them share
    /// a name (checked after plugin tools are merged in, when that feature is on).
    pub tools: Vec<Arc<dyn AgentTool>>,
    /// Messages held by the agent; `Agent::new` starts this empty.
    pub messages: Vec<AgentMessage>,
    /// Running flag stored in the state snapshot.
    /// NOTE(review): `Agent::is_running()` reads a separate atomic flag
    /// (`loop_active`), not this field — confirm the two are kept in sync.
    pub is_running: bool,
    /// Presumably the message currently being streamed, if any — confirm.
    pub stream_message: Option<AgentMessage>,
    /// Presumably the IDs of tool calls that have not completed yet — confirm.
    pub pending_tool_calls: HashSet<String>,
    /// Error text, if any; `Agent::new` starts this as `None`.
    pub error: Option<String>,
    /// Selectable models: the primary model first, followed by the extra models
    /// from `AgentOptions::available_models` in their configured order.
    pub available_models: Vec<ModelSpec>,
}
impl std::fmt::Debug for AgentState {
    /// Writes every field in declaration order. `tools` is printed as a
    /// `[N tool(s)]` summary rather than element by element.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: a new `AgentState` field that is not handled
        // here becomes a compile error instead of silently missing from the output.
        let Self {
            system_prompt,
            model,
            tools,
            messages,
            is_running,
            stream_message,
            pending_tool_calls,
            error,
            available_models,
        } = self;

        let mut builder = f.debug_struct("AgentState");
        builder.field("system_prompt", system_prompt);
        builder.field("model", model);
        // `format_args!` (not `format!`) so the summary prints without quotes.
        builder.field("tools", &format_args!("[{} tool(s)]", tools.len()));
        builder.field("messages", messages);
        builder.field("is_running", is_running);
        builder.field("stream_message", stream_message);
        builder.field("pending_tool_calls", pending_tool_calls);
        builder.field("error", error);
        builder.field("available_models", available_models);
        builder.finish()
    }
}
/// Default agent-to-LLM message conversion.
///
/// Returns a clone of the inner message for `AgentMessage::Llm`, and `None`
/// for `AgentMessage::Custom`, so custom messages are not forwarded by this
/// converter.
pub fn default_convert(msg: &AgentMessage) -> Option<LlmMessage> {
    // Kept as an exhaustive match so a new `AgentMessage` variant forces a
    // decision here rather than being dropped by a catch-all.
    let llm = match msg {
        AgentMessage::Llm(inner) => inner,
        AgentMessage::Custom(_) => return None,
    };
    Some(llm.clone())
}
/// Ordered `(model, stream function)` pairs an agent can dispatch to.
type ModelStreamRegistry = Vec<(ModelSpec, Arc<dyn StreamFn>)>;

/// Builds the selectable model list and the matching model → stream-function
/// registry from `options`.
///
/// Both outputs start with the primary model (`options.model` paired with
/// `options.stream_fn`), followed by the entries of
/// `options.available_models` in their configured order. No deduplication
/// is performed.
fn available_models_and_stream_fns(
    options: &AgentOptions,
) -> (Vec<ModelSpec>, ModelStreamRegistry) {
    let mut models = Vec::new();
    let mut registry: ModelStreamRegistry = Vec::new();

    // Primary model always comes first.
    models.push(options.model.clone());
    registry.push((options.model.clone(), Arc::clone(&options.stream_fn)));

    // Extra models follow, in the order they were configured.
    for (extra_model, extra_stream_fn) in options.available_models.iter() {
        models.push(extra_model.clone());
        registry.push((extra_model.clone(), Arc::clone(extra_stream_fn)));
    }

    (models, registry)
}
/// Ensures no two tools in `tools` share a name.
///
/// # Panics
///
/// Panics if any name occurs more than once. The message lists each
/// duplicated name once, sorted and comma-separated.
fn assert_unique_tool_names(tools: &[Arc<dyn AgentTool>]) {
    let mut seen = HashSet::with_capacity(tools.len());
    // A BTreeSet yields the duplicated names already sorted and deduplicated.
    let mut repeated = std::collections::BTreeSet::new();

    for tool in tools {
        let name = tool.name().to_owned();
        if seen.contains(&name) {
            repeated.insert(name);
        } else {
            seen.insert(name);
        }
    }

    if repeated.is_empty() {
        return;
    }

    let listing: Vec<String> = repeated.into_iter().collect();
    panic!(
        "duplicate tool names are not allowed after composition: {}",
        listing.join(", ")
    );
}
/// Calls every plugin's `on_init` hook with the freshly built `agent`.
///
/// Plugins run in `agent.plugins` order. `merge_plugin_contributions` sorts
/// that list by descending priority during construction. A panic in one hook
/// is caught and logged via `tracing::warn!`, so it does not abort
/// construction or skip the remaining plugins.
#[cfg(feature = "plugins")]
fn dispatch_plugin_on_init(agent: &Agent) {
    for plugin in agent.plugins.iter() {
        // The name is captured before the hook runs.
        let plugin_name = plugin.name().to_owned();
        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            plugin.on_init(agent);
        }));

        match outcome {
            Ok(()) => {}
            Err(payload) => {
                // `panic!("literal")` carries a `&str`; formatted panics carry a `String`.
                let reason = if let Some(text) = payload.downcast_ref::<&str>() {
                    (*text).to_owned()
                } else if let Some(text) = payload.downcast_ref::<String>() {
                    text.clone()
                } else {
                    "unknown panic".to_owned()
                };
                tracing::warn!(plugin = %plugin_name, error = %reason, "plugin on_init panicked");
            }
        }
    }
}

/// No-op stand-in used when the `plugins` feature is disabled.
#[cfg(not(feature = "plugins"))]
const fn dispatch_plugin_on_init(_agent: &Agent) {}
/// A configured agent: the observable [`AgentState`] plus the callbacks,
/// policies, queues, and shared handles used to run it.
///
/// Built by [`Agent::new`] from [`AgentOptions`]; most fields are moved in from
/// the options unchanged. The run/queue/event logic is implemented in the
/// `agent/*.rs` submodules declared at the top of this file.
pub struct Agent {
    /// Identifier assigned via `AgentId::next()` at construction.
    id: AgentId,
    /// State exposed read-only through [`Agent::state`].
    state: AgentState,
    /// Shared, mutex-guarded queue of steering messages (starts empty).
    /// NOTE(review): presumably drained per `steering_mode` — consumers live
    /// outside this file.
    steering_queue: Arc<Mutex<VecDeque<AgentMessage>>>,
    /// Shared, mutex-guarded queue of follow-up messages (starts empty).
    /// NOTE(review): presumably drained per `follow_up_mode` — consumers live
    /// outside this file.
    follow_up_queue: Arc<Mutex<VecDeque<AgentMessage>>>,
    /// Registered event listeners; only their count appears in `Debug` output.
    listeners: ListenerRegistry,
    /// Cancellation token for an active run, if any; `None` after construction.
    abort_controller: Option<CancellationToken>,
    /// Steering mode taken from the options.
    steering_mode: SteeringMode,
    /// Follow-up mode taken from the options.
    follow_up_mode: FollowUpMode,
    /// Stream function for the primary model.
    stream_fn: Arc<dyn StreamFn>,
    /// Message-conversion callback taken from the options.
    convert_to_llm: ConvertToLlmFn,
    /// Explicit transform from the options, or a default sliding-window
    /// transformer when only a token counter was supplied.
    transform_context: Option<TransformContextArc>,
    /// Optional API-key callback taken from the options.
    get_api_key: Option<GetApiKeyArc>,
    /// Retry strategy, converted with `Arc::from` from the options' value.
    retry_strategy: Arc<dyn RetryStrategy>,
    /// Stream options taken from the options.
    stream_options: StreamOptions,
    /// Retry limit taken from `AgentOptions::structured_output_max_retries`.
    structured_output_max_retries: usize,
    /// Shared `tokio::sync::Notify`.
    /// NOTE(review): presumably signalled when the agent becomes idle — confirm.
    idle_notify: Arc<Notify>,
    /// `None` at construction. NOTE(review): presumably set while a run is in flight.
    in_flight_llm_messages: Option<Vec<AgentMessage>>,
    /// `None` at construction. NOTE(review): presumably set while a run is in flight.
    in_flight_messages: Option<Vec<AgentMessage>>,
    /// Default-initialized pause-state snapshot, shared via `Arc`.
    pending_message_snapshot: Arc<crate::pause_state::PendingMessageSnapshot>,
    /// Default-initialized pause-state snapshot, shared via `Arc`.
    loop_context_snapshot: Arc<crate::pause_state::LoopContextSnapshot>,
    /// Optional tool-approval callback taken from the options.
    approve_tool: Option<ApproveToolArc>,
    /// Approval mode taken from the options.
    approval_mode: ApprovalMode,
    /// Pre-turn policies; with `plugins`, plugin policies come before caller ones.
    pre_turn_policies: Vec<Arc<dyn crate::policy::PreTurnPolicy>>,
    /// Pre-dispatch policies; with `plugins`, plugin policies come before caller ones.
    pre_dispatch_policies: Vec<Arc<dyn crate::policy::PreDispatchPolicy>>,
    /// Post-turn policies; with `plugins`, plugin policies come before caller ones.
    post_turn_policies: Vec<Arc<dyn crate::policy::PostTurnPolicy>>,
    /// Post-loop policies; with `plugins`, plugin policies come before caller ones.
    post_loop_policies: Vec<Arc<dyn crate::policy::PostLoopPolicy>>,
    /// `(model, stream_fn)` pairs with the primary model first; the same shape
    /// as `ModelStreamRegistry`.
    model_stream_fns: Vec<(ModelSpec, Arc<dyn StreamFn>)>,
    /// Event forwarders; with `plugins`, one forwarder per plugin (calling its
    /// `on_event`) comes before the caller-supplied forwarders.
    event_forwarders: Vec<crate::event_forwarder::EventForwarderFn>,
    /// Optional async context transform taken from the options.
    async_transform_context: Option<AsyncTransformContextArc>,
    /// Optional checkpoint store taken from the options.
    checkpoint_store: Option<CheckpointStoreArc>,
    /// Optional custom-message registry, exposed via
    /// [`Agent::custom_message_registry`].
    pub(crate) custom_message_registry: Option<Arc<crate::types::CustomMessageRegistry>>,
    /// Optional metrics collector taken from the options.
    metrics_collector: Option<Arc<dyn crate::metrics::MetricsCollector>>,
    /// Optional model-fallback configuration taken from the options.
    fallback: Option<crate::fallback::ModelFallback>,
    /// Optional external message provider taken from the options.
    external_message_provider: Option<Arc<dyn MessageProvider>>,
    /// Tool execution policy taken from the options.
    tool_execution_policy: crate::tool_execution_policy::ToolExecutionPolicy,
    /// Plan-mode addendum text taken from the options.
    plan_mode_addendum: Option<String>,
    /// Shared session state (options value or `SessionState::default()`),
    /// exposed via [`Agent::session_state`].
    session_state: Arc<std::sync::RwLock<crate::SessionState>>,
    /// Optional credential resolver taken from the options.
    credential_resolver: Option<Arc<dyn crate::credential::CredentialResolver>>,
    /// Optional context-cache configuration taken from the options.
    cache_config: Option<crate::context_cache::CacheConfig>,
    /// Optional system-prompt generator, converted to a shared `Arc`.
    dynamic_system_prompt: Option<Arc<dyn Fn() -> String + Send + Sync>>,
    /// Shared flag read by [`Agent::is_running`] with `Acquire` ordering;
    /// starts `false`.
    loop_active: Arc<AtomicBool>,
    /// Shared counter starting at 0.
    /// NOTE(review): presumably incremented per loop run — confirm.
    loop_generation: Arc<AtomicU64>,
    /// Plugins sorted by descending priority during construction.
    #[cfg(feature = "plugins")]
    plugins: Vec<Arc<dyn crate::plugin::Plugin>>,
    /// Optional human-readable agent name taken from the options.
    // `clippy::struct_field_names` flags fields whose name starts with the
    // struct's name (`agent_` on `Agent`); the prefix is kept deliberately here.
    #[allow(clippy::struct_field_names)]
    agent_name: Option<String>,
    /// Optional transfer chain taken from the options.
    transfer_chain: Option<crate::transfer::TransferChain>,
}
impl Agent {
#[must_use]
pub fn new(options: AgentOptions) -> Self {
#[cfg(feature = "plugins")]
let options = merge_plugin_contributions(options);
assert_unique_tool_names(&options.tools);
let effective_prompt = options.effective_system_prompt().to_owned();
let (available_models, model_stream_fns) = available_models_and_stream_fns(&options);
let transform_context = match (options.token_counter, options.transform_context) {
(Some(counter), None) => Some(Arc::new(
crate::context_transformer::SlidingWindowTransformer::new(100_000, 50_000, 2)
.with_token_counter(counter),
) as TransformContextArc),
(_, tc) => tc,
};
let agent = Self {
id: AgentId::next(),
state: AgentState {
system_prompt: effective_prompt,
model: options.model,
tools: options.tools,
messages: Vec::new(),
is_running: false,
stream_message: None,
pending_tool_calls: HashSet::new(),
error: None,
available_models,
},
steering_queue: Arc::new(Mutex::new(VecDeque::new())),
follow_up_queue: Arc::new(Mutex::new(VecDeque::new())),
listeners: ListenerRegistry::new(),
abort_controller: None,
steering_mode: options.steering_mode,
follow_up_mode: options.follow_up_mode,
stream_fn: options.stream_fn,
convert_to_llm: options.convert_to_llm,
transform_context,
get_api_key: options.get_api_key,
retry_strategy: Arc::from(options.retry_strategy),
stream_options: options.stream_options,
structured_output_max_retries: options.structured_output_max_retries,
idle_notify: Arc::new(Notify::new()),
in_flight_llm_messages: None,
in_flight_messages: None,
pending_message_snapshot: Arc::new(
crate::pause_state::PendingMessageSnapshot::default(),
),
loop_context_snapshot: Arc::new(crate::pause_state::LoopContextSnapshot::default()),
approve_tool: options.approve_tool,
approval_mode: options.approval_mode,
pre_turn_policies: options.pre_turn_policies,
pre_dispatch_policies: options.pre_dispatch_policies,
post_turn_policies: options.post_turn_policies,
post_loop_policies: options.post_loop_policies,
model_stream_fns,
event_forwarders: options.event_forwarders,
async_transform_context: options.async_transform_context,
checkpoint_store: options.checkpoint_store,
custom_message_registry: options.custom_message_registry,
metrics_collector: options.metrics_collector,
fallback: options.fallback,
external_message_provider: options.external_message_provider,
tool_execution_policy: options.tool_execution_policy,
plan_mode_addendum: options.plan_mode_addendum,
session_state: Arc::new(std::sync::RwLock::new(
options.session_state.unwrap_or_default(),
)),
credential_resolver: options.credential_resolver,
cache_config: options.cache_config,
dynamic_system_prompt: options.dynamic_system_prompt.map(Arc::from),
loop_active: Arc::new(AtomicBool::new(false)),
loop_generation: Arc::new(AtomicU64::new(0)),
#[cfg(feature = "plugins")]
plugins: options.plugins,
agent_name: options.agent_name,
transfer_chain: options.transfer_chain,
};
dispatch_plugin_on_init(&agent);
agent
}
#[must_use]
pub const fn id(&self) -> AgentId {
self.id
}
#[must_use]
pub const fn state(&self) -> &AgentState {
&self.state
}
#[must_use]
pub fn is_running(&self) -> bool {
self.loop_active.load(std::sync::atomic::Ordering::Acquire)
}
#[must_use]
pub const fn session_state(&self) -> &Arc<std::sync::RwLock<crate::SessionState>> {
&self.session_state
}
#[must_use]
pub fn custom_message_registry(&self) -> Option<&crate::types::CustomMessageRegistry> {
self.custom_message_registry.as_deref()
}
#[cfg(feature = "plugins")]
#[must_use]
pub fn plugins(&self) -> &[Arc<dyn crate::plugin::Plugin>] {
&self.plugins
}
#[cfg(feature = "plugins")]
#[must_use]
pub fn plugin(&self, name: &str) -> Option<&Arc<dyn crate::plugin::Plugin>> {
self.plugins.iter().find(|p| p.name() == name)
}
}
impl std::fmt::Debug for Agent {
    /// Prints the observable state, both queue modes, the listener count, and
    /// whether an abort token is held. All other fields are elided via
    /// `finish_non_exhaustive`, which renders a trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Agent");
        out.field("state", &self.state)
            .field("steering_mode", &self.steering_mode)
            .field("follow_up_mode", &self.follow_up_mode);
        // `format_args!` so the summary prints without surrounding quotes.
        out.field(
            "listeners",
            &format_args!("{} listener(s)", self.listeners.len()),
        );
        out.field("is_abort_active", &self.abort_controller.is_some());
        out.finish_non_exhaustive()
    }
}
/// Folds plugin contributions into `options` before the agent is built.
///
/// Plugins are first sorted by descending `priority()`. The sort is stable,
/// so equal priorities keep their registration order. Then, in that order:
/// - each plugin's pre-turn, pre-dispatch, post-turn, and post-loop policies
///   are placed *ahead of* the caller-supplied policies of the same kind;
/// - each plugin tool is wrapped in `NamespacedTool` keyed by the plugin's
///   name and appended *after* the caller-supplied tools;
/// - a forwarder that calls the plugin's `on_event` is placed *ahead of* the
///   caller-supplied event forwarders.
///
/// NOTE(review): unlike `on_init`, `on_event` is not wrapped in
/// `catch_unwind` here; confirm whether forwarder dispatch isolates panics.
#[cfg(feature = "plugins")]
fn merge_plugin_contributions(mut options: AgentOptions) -> AgentOptions {
    // Makes `front` the new prefix of `target`, keeping the relative order of both.
    fn prepend<T>(target: &mut Vec<T>, mut front: Vec<T>) {
        front.append(target);
        *target = front;
    }

    options
        .plugins
        .sort_by_key(|plugin| std::cmp::Reverse(plugin.priority()));

    let mut pre_turn: Vec<Arc<dyn crate::policy::PreTurnPolicy>> = Vec::new();
    let mut pre_dispatch: Vec<Arc<dyn crate::policy::PreDispatchPolicy>> = Vec::new();
    let mut post_turn: Vec<Arc<dyn crate::policy::PostTurnPolicy>> = Vec::new();
    let mut post_loop: Vec<Arc<dyn crate::policy::PostLoopPolicy>> = Vec::new();
    let mut namespaced_tools: Vec<Arc<dyn AgentTool>> = Vec::new();
    let mut forwarders: Vec<crate::event_forwarder::EventForwarderFn> = Vec::new();

    // A single pass preserves the per-plugin order in which contribution hooks are called.
    for plugin in &options.plugins {
        pre_turn.extend(plugin.pre_turn_policies());
        pre_dispatch.extend(plugin.pre_dispatch_policies());
        post_turn.extend(plugin.post_turn_policies());
        post_loop.extend(plugin.post_loop_policies());

        let namespace = plugin.name().to_owned();
        namespaced_tools.extend(plugin.tools().into_iter().map(|tool| {
            Arc::new(crate::plugin::NamespacedTool::new(&namespace, tool)) as Arc<dyn AgentTool>
        }));

        let hook = Arc::clone(plugin);
        forwarders.push(Arc::new(move |event: crate::loop_::AgentEvent| {
            hook.on_event(&event);
        }));
    }

    prepend(&mut options.pre_turn_policies, pre_turn);
    prepend(&mut options.pre_dispatch_policies, pre_dispatch);
    prepend(&mut options.post_turn_policies, post_turn);
    prepend(&mut options.post_loop_policies, post_loop);
    options.tools.extend(namespaced_tools);
    prepend(&mut options.event_forwarders, forwarders);
    options
}
/// Forwarding adapter that implements [`RetryStrategy`] on top of a shared
/// `Arc<dyn RetryStrategy>`.
///
/// NOTE(review): only `should_retry`, `delay`, and `as_any` are implemented
/// here. If the trait has other defaulted methods, this wrapper uses the
/// defaults rather than the inner strategy's overrides; confirm against the
/// trait definition.
struct SharedRetryStrategy(Arc<dyn RetryStrategy>);

impl RetryStrategy for SharedRetryStrategy {
    fn should_retry(&self, error: &AgentError, attempt: u32) -> bool {
        let Self(inner) = self;
        inner.should_retry(error, attempt)
    }

    fn delay(&self, attempt: u32) -> std::time::Duration {
        let Self(inner) = self;
        inner.delay(attempt)
    }

    fn as_any(&self) -> &dyn std::any::Any {
        // NOTE(review): this exposes the wrapper itself, not the inner strategy,
        // so downcasting to the wrapped concrete type through it will not succeed.
        self
    }
}
#[cfg(all(test, feature = "plugins"))]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::testing::{MockPlugin, MockTool, SimpleMockStreamFn};

    /// A caller tool named `my_web_search` must clash with the `search` tool
    /// contributed by plugin `my-web` once plugin tools are namespaced. The
    /// uniqueness check runs after composition, so construction panics.
    #[test]
    #[should_panic(expected = "duplicate tool names are not allowed after composition")]
    fn agent_new_rejects_duplicate_names_after_plugin_composition() {
        let caller_tool: Arc<dyn AgentTool> = Arc::new(MockTool::new("my_web_search"));
        let plugin = Arc::new(MockPlugin::new("my-web").with_tools(&["search"]));
        let stream_fn = Arc::new(SimpleMockStreamFn::from_text("ok"));

        let options = AgentOptions::new(
            "test",
            crate::testing::default_model(),
            stream_fn,
            crate::testing::default_convert,
        )
        .with_tools(vec![caller_tool])
        .with_plugin(plugin);

        let _agent = Agent::new(options);
    }
}