use crate::agent_loop::{
agent_loop, agent_loop_continue, AfterTurnFn, AgentLoopConfig, BeforeTurnFn, OnErrorFn,
};
use crate::context::{CompactionStrategy, ContextConfig, ExecutionLimits};
use crate::mcp::{McpClient, McpError, McpToolAdapter};
use crate::provider::{ModelConfig, StreamProvider};
use crate::types::*;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueueMode {
OneAtATime,
All,
}
pub struct Agent {
pub system_prompt: String,
pub model: String,
pub api_key: String,
pub thinking_level: ThinkingLevel,
pub max_tokens: Option<u32>,
pub temperature: Option<f32>,
model_config: Option<ModelConfig>,
messages: Vec<AgentMessage>,
tools: Vec<Box<dyn AgentTool>>,
provider: Arc<dyn StreamProvider>,
steering_queue: Arc<Mutex<Vec<AgentMessage>>>,
follow_up_queue: Arc<Mutex<Vec<AgentMessage>>>,
steering_mode: QueueMode,
follow_up_mode: QueueMode,
pub context_config: Option<ContextConfig>,
context_management_disabled: bool,
pub execution_limits: Option<ExecutionLimits>,
pub cache_config: CacheConfig,
pub tool_execution: ToolExecutionStrategy,
pub retry_config: crate::retry::RetryConfig,
before_turn: Option<BeforeTurnFn>,
after_turn: Option<AfterTurnFn>,
on_error: Option<OnErrorFn>,
input_filters: Vec<Arc<dyn InputFilter>>,
compaction_strategy: Option<Arc<dyn CompactionStrategy>>,
cancel: Option<CancellationToken>,
is_streaming: bool,
#[allow(clippy::type_complexity)]
pending_completion: Option<JoinHandle<(Vec<Box<dyn AgentTool>>, Vec<AgentMessage>)>>,
}
impl Agent {
pub fn new(provider: impl StreamProvider + 'static) -> Self {
Self {
system_prompt: String::new(),
model: String::new(),
api_key: String::new(),
thinking_level: ThinkingLevel::Off,
max_tokens: None,
temperature: None,
model_config: None,
messages: Vec::new(),
tools: Vec::new(),
provider: Arc::new(provider),
steering_queue: Arc::new(Mutex::new(Vec::new())),
follow_up_queue: Arc::new(Mutex::new(Vec::new())),
steering_mode: QueueMode::OneAtATime,
follow_up_mode: QueueMode::OneAtATime,
context_config: None,
context_management_disabled: false,
execution_limits: Some(ExecutionLimits::default()),
cache_config: CacheConfig::default(),
tool_execution: ToolExecutionStrategy::default(),
retry_config: crate::retry::RetryConfig::default(),
before_turn: None,
after_turn: None,
on_error: None,
input_filters: Vec::new(),
compaction_strategy: None,
cancel: None,
is_streaming: false,
pending_completion: None,
}
}
pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
self.system_prompt = prompt.into();
self
}
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
self.api_key = key.into();
self
}
pub fn with_thinking(mut self, level: ThinkingLevel) -> Self {
self.thinking_level = level;
self
}
pub fn with_tools(mut self, tools: Vec<Box<dyn AgentTool>>) -> Self {
self.tools = tools;
self
}
pub fn with_model_config(mut self, config: ModelConfig) -> Self {
self.model_config = Some(config);
self
}
pub fn with_max_tokens(mut self, max: u32) -> Self {
self.max_tokens = Some(max);
self
}
pub fn with_context_config(mut self, config: ContextConfig) -> Self {
self.context_config = Some(config);
self
}
pub fn with_cache_config(mut self, config: CacheConfig) -> Self {
self.cache_config = config;
self
}
pub fn with_tool_execution(mut self, strategy: ToolExecutionStrategy) -> Self {
self.tool_execution = strategy;
self
}
pub fn with_retry_config(mut self, config: crate::retry::RetryConfig) -> Self {
self.retry_config = config;
self
}
pub fn with_skills(mut self, skills: crate::skills::SkillSet) -> Self {
let prompt_fragment = skills.format_for_prompt();
if !prompt_fragment.is_empty() {
if self.system_prompt.is_empty() {
self.system_prompt = prompt_fragment;
} else {
self.system_prompt = format!("{}\n\n{}", self.system_prompt, prompt_fragment);
}
}
self
}
pub fn with_execution_limits(mut self, limits: ExecutionLimits) -> Self {
self.execution_limits = Some(limits);
self
}
pub fn with_messages(mut self, msgs: Vec<AgentMessage>) -> Self {
self.messages = msgs;
self
}
pub fn on_before_turn(
mut self,
f: impl Fn(&[AgentMessage], usize) -> bool + Send + Sync + 'static,
) -> Self {
self.before_turn = Some(Arc::new(f));
self
}
pub fn on_after_turn(
mut self,
f: impl Fn(&[AgentMessage], &Usage) + Send + Sync + 'static,
) -> Self {
self.after_turn = Some(Arc::new(f));
self
}
pub fn on_error(mut self, f: impl Fn(&str) + Send + Sync + 'static) -> Self {
self.on_error = Some(Arc::new(f));
self
}
pub fn with_input_filter(mut self, filter: impl InputFilter + 'static) -> Self {
self.input_filters.push(Arc::new(filter));
self
}
pub fn with_compaction_strategy(mut self, strategy: impl CompactionStrategy + 'static) -> Self {
self.compaction_strategy = Some(Arc::new(strategy));
self
}
pub fn with_sub_agent(mut self, sub: crate::sub_agent::SubAgentTool) -> Self {
self.tools.push(Box::new(sub));
self
}
pub fn without_context_management(mut self) -> Self {
self.context_config = None;
self.context_management_disabled = true;
self.execution_limits = None;
self
}
#[cfg(feature = "openapi")]
pub async fn with_openapi_file(
mut self,
path: impl AsRef<std::path::Path>,
config: crate::openapi::OpenApiConfig,
filter: &crate::openapi::OperationFilter,
) -> Result<Self, crate::openapi::OpenApiError> {
let adapters = crate::openapi::OpenApiToolAdapter::from_file(path, config, filter).await?;
for adapter in adapters {
self.tools.push(Box::new(adapter));
}
Ok(self)
}
#[cfg(feature = "openapi")]
pub async fn with_openapi_url(
mut self,
url: &str,
config: crate::openapi::OpenApiConfig,
filter: &crate::openapi::OperationFilter,
) -> Result<Self, crate::openapi::OpenApiError> {
let adapters = crate::openapi::OpenApiToolAdapter::from_url(url, config, filter).await?;
for adapter in adapters {
self.tools.push(Box::new(adapter));
}
Ok(self)
}
#[cfg(feature = "openapi")]
pub fn with_openapi_spec(
mut self,
spec_str: &str,
config: crate::openapi::OpenApiConfig,
filter: &crate::openapi::OperationFilter,
) -> Result<Self, crate::openapi::OpenApiError> {
let adapters = crate::openapi::OpenApiToolAdapter::from_str(spec_str, config, filter)?;
for adapter in adapters {
self.tools.push(Box::new(adapter));
}
Ok(self)
}
pub async fn with_mcp_server_stdio(
mut self,
command: &str,
args: &[&str],
env: Option<HashMap<String, String>>,
) -> Result<Self, McpError> {
let client = McpClient::connect_stdio(command, args, env).await?;
let client = Arc::new(tokio::sync::Mutex::new(client));
let adapters = McpToolAdapter::from_client(client).await?;
for adapter in adapters {
self.tools.push(Box::new(adapter));
}
Ok(self)
}
pub async fn with_mcp_server_http(mut self, url: &str) -> Result<Self, McpError> {
let client = McpClient::connect_http(url).await?;
let client = Arc::new(tokio::sync::Mutex::new(client));
let adapters = McpToolAdapter::from_client(client).await?;
for adapter in adapters {
self.tools.push(Box::new(adapter));
}
Ok(self)
}
pub fn messages(&self) -> &[AgentMessage] {
&self.messages
}
pub fn is_streaming(&self) -> bool {
self.is_streaming
}
pub fn set_tools(&mut self, tools: Vec<Box<dyn AgentTool>>) {
self.tools = tools;
}
pub fn clear_messages(&mut self) {
self.messages.clear();
}
pub fn append_message(&mut self, msg: AgentMessage) {
self.messages.push(msg);
}
pub fn replace_messages(&mut self, msgs: Vec<AgentMessage>) {
self.messages = msgs;
}
pub fn save_messages(&self) -> Result<String, serde_json::Error> {
serde_json::to_string(&self.messages)
}
pub fn restore_messages(&mut self, json: &str) -> Result<(), serde_json::Error> {
let msgs: Vec<AgentMessage> = serde_json::from_str(json)?;
self.messages = msgs;
Ok(())
}
pub fn steer(&self, msg: AgentMessage) {
self.steering_queue.lock().unwrap().push(msg);
}
pub fn follow_up(&self, msg: AgentMessage) {
self.follow_up_queue.lock().unwrap().push(msg);
}
pub fn clear_steering_queue(&self) {
self.steering_queue.lock().unwrap().clear();
}
pub fn clear_follow_up_queue(&self) {
self.follow_up_queue.lock().unwrap().clear();
}
pub fn clear_all_queues(&self) {
self.clear_steering_queue();
self.clear_follow_up_queue();
}
pub fn set_steering_mode(&mut self, mode: QueueMode) {
self.steering_mode = mode;
}
pub fn set_follow_up_mode(&mut self, mode: QueueMode) {
self.follow_up_mode = mode;
}
pub fn abort(&self) {
if let Some(ref cancel) = self.cancel {
cancel.cancel();
}
}
pub async fn reset(&mut self) {
if let Some(ref cancel) = self.cancel {
cancel.cancel();
}
if let Some(handle) = self.pending_completion.take() {
if let Ok((tools, _messages)) = handle.await {
self.tools = tools;
}
}
self.messages.clear();
self.clear_all_queues();
self.is_streaming = false;
self.cancel = None;
}
pub async fn prompt(&mut self, text: impl Into<String>) -> mpsc::UnboundedReceiver<AgentEvent> {
let msg = AgentMessage::Llm(Message::user(text));
self.prompt_messages(vec![msg]).await
}
pub async fn prompt_messages(
&mut self,
messages: Vec<AgentMessage>,
) -> mpsc::UnboundedReceiver<AgentEvent> {
self.finish().await;
assert!(
!self.is_streaming,
"Agent is already streaming. Use steer() or follow_up()."
);
let cancel = CancellationToken::new();
self.cancel = Some(cancel.clone());
self.is_streaming = true;
let (tx, rx) = mpsc::unbounded_channel();
let mut context = AgentContext {
system_prompt: self.system_prompt.clone(),
messages: self.messages.clone(),
tools: std::mem::take(&mut self.tools),
};
let config = self.build_config();
let handle = tokio::spawn(async move {
let _new_messages = agent_loop(messages, &mut context, &config, tx, cancel).await;
(context.tools, context.messages)
});
self.pending_completion = Some(handle);
rx
}
pub async fn prompt_with_sender(
&mut self,
text: impl Into<String>,
tx: mpsc::UnboundedSender<AgentEvent>,
) {
let msg = AgentMessage::Llm(Message::user(text));
self.prompt_messages_with_sender(vec![msg], tx).await;
}
pub async fn prompt_messages_with_sender(
&mut self,
messages: Vec<AgentMessage>,
tx: mpsc::UnboundedSender<AgentEvent>,
) {
self.finish().await;
assert!(
!self.is_streaming,
"Agent is already streaming. Use steer() or follow_up()."
);
let cancel = CancellationToken::new();
self.cancel = Some(cancel.clone());
self.is_streaming = true;
let mut context = AgentContext {
system_prompt: self.system_prompt.clone(),
messages: self.messages.clone(),
tools: std::mem::take(&mut self.tools),
};
let config = self.build_config();
let _new_messages = agent_loop(messages, &mut context, &config, tx, cancel).await;
self.tools = context.tools;
self.messages = context.messages;
self.is_streaming = false;
self.cancel = None;
}
pub async fn continue_loop(&mut self) -> mpsc::UnboundedReceiver<AgentEvent> {
self.finish().await;
assert!(!self.is_streaming, "Agent is already streaming.");
assert!(!self.messages.is_empty(), "No messages to continue from.");
let cancel = CancellationToken::new();
self.cancel = Some(cancel.clone());
self.is_streaming = true;
let (tx, rx) = mpsc::unbounded_channel();
let mut context = AgentContext {
system_prompt: self.system_prompt.clone(),
messages: self.messages.clone(),
tools: std::mem::take(&mut self.tools),
};
let config = self.build_config();
let handle = tokio::spawn(async move {
let _new_messages = agent_loop_continue(&mut context, &config, tx, cancel).await;
(context.tools, context.messages)
});
self.pending_completion = Some(handle);
rx
}
pub async fn continue_loop_with_sender(&mut self, tx: mpsc::UnboundedSender<AgentEvent>) {
self.finish().await;
assert!(!self.is_streaming, "Agent is already streaming.");
assert!(!self.messages.is_empty(), "No messages to continue from.");
let cancel = CancellationToken::new();
self.cancel = Some(cancel.clone());
self.is_streaming = true;
let mut context = AgentContext {
system_prompt: self.system_prompt.clone(),
messages: self.messages.clone(),
tools: std::mem::take(&mut self.tools),
};
let config = self.build_config();
let _new_messages = agent_loop_continue(&mut context, &config, tx, cancel).await;
self.tools = context.tools;
self.messages = context.messages;
self.is_streaming = false;
self.cancel = None;
}
pub async fn finish(&mut self) {
if let Some(handle) = self.pending_completion.take() {
match handle.await {
Ok((tools, messages)) => {
self.tools = tools;
self.messages = messages;
}
Err(e) => {
tracing::error!("Agent loop task failed: {}", e);
}
}
self.is_streaming = false;
self.cancel = None;
}
}
fn build_config(&self) -> AgentLoopConfig {
let steering_queue = self.steering_queue.clone();
let steering_mode = self.steering_mode;
let follow_up_queue = self.follow_up_queue.clone();
let follow_up_mode = self.follow_up_mode;
AgentLoopConfig {
provider: self.provider.clone(),
model: self.model.clone(),
api_key: self.api_key.clone(),
thinking_level: self.thinking_level,
max_tokens: self.max_tokens,
temperature: self.temperature,
model_config: self.model_config.clone(),
convert_to_llm: None,
transform_context: None,
get_steering_messages: Some(Box::new(move || {
let mut queue = steering_queue.lock().unwrap();
match steering_mode {
QueueMode::OneAtATime => {
if queue.is_empty() {
vec![]
} else {
vec![queue.remove(0)]
}
}
QueueMode::All => queue.drain(..).collect(),
}
})),
context_config: if self.context_management_disabled {
None
} else {
Some(self.context_config.clone().unwrap_or_else(|| {
self.model_config
.as_ref()
.map(|m| ContextConfig::from_context_window(m.context_window))
.unwrap_or_default()
}))
},
compaction_strategy: self.compaction_strategy.clone(),
execution_limits: self.execution_limits.clone(),
cache_config: self.cache_config.clone(),
tool_execution: self.tool_execution.clone(),
retry_config: self.retry_config.clone(),
get_follow_up_messages: Some(Box::new(move || {
let mut queue = follow_up_queue.lock().unwrap();
match follow_up_mode {
QueueMode::OneAtATime => {
if queue.is_empty() {
vec![]
} else {
vec![queue.remove(0)]
}
}
QueueMode::All => queue.drain(..).collect(),
}
})),
before_turn: self.before_turn.clone(),
after_turn: self.after_turn.clone(),
on_error: self.on_error.clone(),
input_filters: self.input_filters.clone(),
turn_delay: None,
}
}
}