use crate::client::{Client, ClientError};
use crate::model::{FinishReason, Message, Part, Response, Usage};
use serde_json::json;
use std::collections::HashMap;
use tracing::{debug, info, warn};
use crate::mcp::MCPServer;
/// Drives a model client through a request / tool-execution loop.
///
/// The agent repeatedly sends the conversation to the client, executes any
/// function calls found in the reply through the optional MCP server, appends
/// the tool results to the conversation, and asks again until the model stops
/// requesting tools or the iteration budget is exhausted.
pub struct Agent<C: Client> {
/// Client used to issue every model request.
client: C,
/// Maximum number of request rounds attempted before the loop gives up.
max_iterations: usize,
/// Optional MCP server used to list the available tools and execute tool calls.
server: Option<Box<dyn MCPServer>>,
}
impl<C: Client> Agent<C> {
pub fn new(client: C) -> Self {
Self {
client,
max_iterations: 10,
server: None,
}
}
/// Attaches the MCP server used to list and execute tools.
///
/// The server is boxed as a trait object; any previously attached server is
/// replaced. Returns the agent so calls can be chained.
pub fn with_server<S: MCPServer + 'static>(mut self, server: S) -> Self {
    let boxed: Box<dyn MCPServer> = Box::new(server);
    self.server = Some(boxed);
    self
}
/// Sets the maximum number of request rounds the agent loop may perform.
///
/// Returns the agent so calls can be chained.
pub fn with_max_iterations(self, max: usize) -> Self {
    let mut agent = self;
    agent.max_iterations = max;
    agent
}
pub async fn chat(&self, mut messages: Vec<Message>) -> Result<Response, ClientError> {
debug!(
"Starting agent chat loop with {} initial messages",
messages.len()
);
let mut current_response = Response {
data: Vec::new(),
usage: Usage::default(),
finish: FinishReason::Unfinished,
};
let (tools, tool_map) = if let Some(server) = &self.server {
match server.list_tools().await {
Ok(tools) => {
let map: HashMap<String, Option<String>> = tools
.iter()
.map(|t| (t.value.name.to_string(), t.server_id.clone()))
.collect();
(tools.into_iter().map(|t| t.value).collect(), map)
}
Err(e) => {
return Err(ClientError::ProviderError(format!(
"Failed to list tools from MCP server: {}",
e
)));
}
}
} else {
(Vec::new(), HashMap::new())
};
for iteration in 0..self.max_iterations {
debug!("Agent iteration {}/{}", iteration + 1, self.max_iterations);
let response = self.client.request(messages.clone(), tools.clone()).await?;
current_response.usage += response.usage;
current_response.finish = response.finish.clone();
let mut tool_calls_executed = false;
for msg in response.data {
messages.push(msg.clone());
current_response.data.push(msg.clone());
for part in msg.parts() {
if let Part::FunctionCall {
id,
name,
arguments,
..
} = part
{
tool_calls_executed = true;
info!("Tool call requested: {}", name);
debug!("Tool arguments: {}", arguments);
let server = self.server.as_ref().ok_or_else(|| {
ClientError::Config("No MCP server configured".to_string())
})?;
let server_id = tool_map.get(name).cloned().flatten();
let result = server
.call_tool(name.clone(), arguments.clone(), server_id)
.await;
let response_part = match result {
Ok(mut part) => {
info!("Tool {} executed successfully", name);
debug!("Tool result: {:?}", part);
if let Part::FunctionResponse {
id: ref mut pid, ..
} = part
{
*pid = id.clone();
}
part
}
Err(e) => {
warn!("Tool {} execution failed: {}", name, e);
Part::FunctionResponse {
id: id.clone(),
name: name.clone(),
response: json!({ "error": format!("Error: {}", e) }),
parts: vec![],
finished: true,
}
}
};
let response_msg = Message::User(vec![response_part]);
messages.push(response_msg.clone());
current_response.data.push(response_msg);
}
}
}
if !tool_calls_executed {
debug!("No more function calls, agent loop complete");
return Ok(current_response);
}
}
warn!(
"Max iterations ({}) reached in agent loop",
self.max_iterations
);
Err(ClientError::Config(
"Max iterations reached in agent loop".to_string(),
))
}
/// Streaming variant of the agent loop.
///
/// Returns a stream of cumulative snapshots: every yielded `Response`
/// contains all messages accumulated so far (earlier rounds plus the current
/// partial reply), the summed usage, and the finish reason of the most
/// recent chunk. While a round is streaming, each chunk replaces the
/// previous partial reply of that round rather than being appended to it.
///
/// After a round finishes, every *finished* function call found in the
/// messages streamed during that round is executed through the MCP server.
/// The results are gathered into a single user message, which is appended to
/// the conversation and yielded as one more snapshot. The stream ends when a
/// round produces no finished function call.
///
/// If the MCP server fails to list its tools, the failure is logged and the
/// loop continues without tools (best effort). A failing tool call is logged
/// and reported back to the model as a function response carrying an
/// `"error"` entry.
///
/// # Errors
///
/// The stream yields an error and terminates when:
/// * the client fails to open the stream or a chunk is an error,
/// * a function call is requested but no MCP server is configured
///   (`ClientError::Config`),
/// * the iteration budget is exhausted (`ClientError::Config`).
pub fn chat_stream<'a>(
    &'a self,
    mut messages: Vec<Message>,
) -> std::pin::Pin<Box<dyn futures::Stream<Item = Result<Response, ClientError>> + Send + 'a>>
where
    C: crate::client::StreamingClient,
{
    Box::pin(async_stream::try_stream! {
        debug!("Starting agent streaming chat loop");
        use futures::StreamExt;

        let mut current_response = Response {
            data: Vec::new(),
            usage: Usage::default(),
            finish: FinishReason::Unfinished,
        };

        // Tools are discovered once; the map remembers which backing server
        // id (if any) each tool name belongs to.
        let (tools, tool_map) = if let Some(server) = &self.server {
            match server.list_tools().await {
                Ok(tools) => {
                    let map: HashMap<String, Option<String>> = tools
                        .iter()
                        .map(|t| (t.value.name.to_string(), t.server_id.clone()))
                        .collect();
                    (tools.into_iter().map(|t| t.value).collect(), map)
                }
                Err(e) => {
                    // Best effort: keep going without tools.
                    warn!("Failed to list tools from MCP server: {}", e);
                    (Vec::new(), HashMap::new())
                }
            }
        } else {
            (Vec::new(), HashMap::new())
        };

        for iteration in 0..self.max_iterations {
            debug!(
                "Agent streaming iteration {}/{}",
                iteration + 1,
                self.max_iterations
            );

            let mut stream = self.client.request_stream(messages.clone(), tools.clone()).await?;

            // Everything before `base_data_len` / `base_usage` belongs to
            // earlier rounds and must survive each partial update.
            let base_data_len = current_response.data.len();
            let base_usage = current_response.usage.clone();

            while let Some(response_result) = stream.next().await {
                let response = response_result?;
                // Each chunk is a full snapshot of this round: replace the
                // previous partial reply instead of appending to it.
                current_response.data.truncate(base_data_len);
                current_response.data.extend(response.data);
                current_response.usage = base_usage.clone();
                current_response.usage += response.usage;
                current_response.finish = response.finish;
                yield current_response.clone();
            }

            // Commit the messages streamed in this round to the conversation.
            messages.extend(current_response.data[base_data_len..].iter().cloned());

            // Execute every finished function call from this round, not only
            // those in the last message, so no requested call goes unanswered.
            let mut tool_responses = Vec::new();
            for msg in &current_response.data[base_data_len..] {
                for part in msg.parts() {
                    if let Part::FunctionCall { id, name, arguments, finished, .. } = part {
                        if *finished {
                            info!("Executing tool: {}", name);
                            let server = self.server.as_ref().ok_or_else(|| {
                                ClientError::Config("No MCP server configured".to_string())
                            })?;
                            let server_id = tool_map.get(name).cloned().flatten();
                            let result = server
                                .call_tool(name.clone(), arguments.clone(), server_id)
                                .await;
                            let response_part = match result {
                                Ok(mut part) => {
                                    // Tie the result to the call that requested it.
                                    if let Part::FunctionResponse { id: ref mut pid, .. } = part {
                                        *pid = id.clone();
                                    }
                                    part
                                }
                                Err(e) => {
                                    warn!("Tool {} execution failed: {}", name, e);
                                    // Surface the failure to the model instead of aborting.
                                    Part::FunctionResponse {
                                        id: id.clone(),
                                        name: name.clone(),
                                        response: json!({ "error": format!("Error: {}", e) }),
                                        parts: vec![],
                                        finished: true,
                                    }
                                }
                            };
                            tool_responses.push(response_part);
                        }
                    }
                }
            }

            // No tool was executed: the model is done and so is the stream.
            if tool_responses.is_empty() {
                return;
            }

            let tool_msg = Message::User(tool_responses);
            messages.push(tool_msg.clone());
            current_response.data.push(tool_msg);
            yield current_response.clone();
        }

        warn!(
            "Max iterations ({}) reached in streaming agent loop",
            self.max_iterations
        );
        Err(ClientError::Config(
            "Max iterations reached in agent loop".to_string(),
        ))?;
    })
}
}