use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::{
CalledFunction, ChatCompletionChunkResponse, ChatCompletionResponse, ChunkChoice, Delta, Model,
RequestBuilder, Response, TextMessageRole, Tool, ToolCallResponse, ToolCallback, ToolChoice,
};
/// Signature of an asynchronous tool callback.
///
/// The callback receives the [`CalledFunction`] by value and returns a boxed,
/// pinned, `Send` future that resolves to the tool's textual output or an
/// `anyhow` error. The callback itself must be `Send + Sync`.
pub type AsyncToolCallback = dyn Fn(CalledFunction) -> Pin<Box<dyn Future<Output = anyhow::Result<String>> + Send>>
+ Send
+ Sync;
/// A tool implementation registered with an [`Agent`].
///
/// Cloning is cheap: both variants hold the callback behind an [`Arc`].
#[derive(Clone)]
pub enum ToolCallbackType {
/// Synchronous callback; the agent runs it through `tokio::task::spawn_blocking`.
Sync(Arc<ToolCallback>),
/// Asynchronous callback; the agent awaits the returned future directly.
Async(Arc<AsyncToolCallback>),
}
/// Settings that control how an [`Agent`] drives the model/tool loop.
#[derive(Clone, Debug)]
pub struct AgentConfig {
/// Upper bound on model turns; when it is reached while the model is still
/// requesting tools, the run stops with `AgentStopReason::MaxIterations`.
pub max_iterations: usize,
/// Tool-choice setting attached to every request sent to the model.
pub tool_choice: ToolChoice,
/// Optional system prompt; when present it is added as the first message.
pub system_prompt: Option<String>,
/// When `true`, `Agent::run` executes the tool calls of one turn concurrently;
/// otherwise it awaits them one after another. The streaming path
/// (`AgentStream::next`) runs tools one at a time regardless of this flag.
pub parallel_tool_execution: bool,
}
impl Default for AgentConfig {
fn default() -> Self {
Self {
max_iterations: 10,
tool_choice: ToolChoice::Auto,
system_prompt: None,
parallel_tool_execution: true,
}
}
}
/// Record of one model turn within an agent run.
#[derive(Debug, Clone)]
pub struct AgentStep {
/// Model response for this turn. In the streaming path, when the tool calls
/// were gathered from chunks, this is a placeholder with empty/zeroed fields.
pub response: ChatCompletionResponse,
/// Tool calls the model requested in this turn (empty for a plain answer).
pub tool_calls: Vec<ToolCallResponse>,
/// Outcome of each executed tool call.
pub tool_results: Vec<ToolResult>,
}
/// Outcome of executing a single tool call.
#[derive(Debug, Clone)]
pub struct ToolResult {
/// Identifier copied from the tool call this result answers.
pub tool_call_id: String,
/// Name of the function the model asked for.
pub tool_name: String,
/// `Ok` with the tool's output, or `Err` with a message describing the
/// failure (callback error, blocking-task join error, or unknown tool name).
pub result: Result<String, String>,
}
/// Aggregated result of an agent run.
#[derive(Debug, Clone)]
pub struct AgentResponse {
/// Steps recorded during the run, in order.
pub steps: Vec<AgentStep>,
/// Final text produced by the model, if the run ended with one.
pub final_response: Option<String>,
/// Number of model turns counted for the run.
pub iterations: usize,
/// Why the run stopped.
pub stop_reason: AgentStopReason,
}
/// Why an agent run stopped.
///
/// Every payload is `Eq`, so the type derives `Eq` alongside `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentStopReason {
    /// The model answered with text and requested no tools.
    TextResponse,
    /// The configured iteration limit was reached while the model was still
    /// requesting tools.
    MaxIterations,
    /// The model requested no tools and produced no text.
    NoAction,
    /// An error prevented the run from continuing; the payload is the error
    /// message (for example, a failed follow-up request in the streaming path).
    Error(String),
}
/// Event yielded by `AgentStream::next` while a streamed run progresses.
#[derive(Debug, Clone)]
pub enum AgentEvent {
/// A fragment of text taken from a streamed model chunk.
TextDelta(String),
/// The model requested these tool calls; execution is about to begin.
ToolCallsStart(Vec<ToolCallResponse>),
/// One tool call finished (successfully or not).
ToolResult(ToolResult),
/// All requested tools have run and the follow-up model request was started.
ToolCallsComplete,
/// Terminal event carrying the aggregated [`AgentResponse`].
Complete(AgentResponse),
}
/// Internal state machine driven by `AgentStream::next`.
enum AgentStreamState {
/// A model response is being consumed from the model stream.
Streaming {
/// Conversation that was sent to the model for this turn.
messages: RequestBuilder,
/// Zero-based index of the current turn.
iteration: usize,
/// Text gathered from the chunks of this turn so far.
accumulated_content: String,
/// Tool calls gathered from the chunks of this turn so far.
accumulated_tool_calls: Vec<ToolCallResponse>,
/// Steps completed in earlier turns.
steps: Vec<AgentStep>,
},
/// The tools requested by the model are being executed one at a time.
ExecutingTools {
/// Conversation that was sent to the model for this turn.
messages: RequestBuilder,
/// Zero-based index of the current turn.
iteration: usize,
/// Response that requested the tools (a zeroed placeholder when the calls
/// were gathered from chunks).
response: ChatCompletionResponse,
/// Tool calls requested in this turn.
tool_calls: Vec<ToolCallResponse>,
/// Results collected so far, in execution order.
tool_results: Vec<ToolResult>,
/// Indices into `tool_calls` that still have to run, consumed front to back.
pending_indices: Vec<usize>,
/// Steps completed in earlier turns.
steps: Vec<AgentStep>,
},
/// The run is over; `next` returns `None` from now on.
Done,
}
/// Handle to a streamed agent run, created by `Agent::run_stream`.
///
/// Call `next` repeatedly to receive [`AgentEvent`]s until it returns `None`.
pub struct AgentStream<'a> {
/// Agent whose model, tools and configuration drive the run.
agent: &'a Agent,
/// Current position in the run's state machine.
state: AgentStreamState,
/// Model stream currently being read, if any; cleared once the model's turn
/// has finished.
model_stream: Option<crate::model::Stream<'a>>,
}
impl<'a> AgentStream<'a> {
pub async fn next(&mut self) -> Option<AgentEvent> {
loop {
match &mut self.state {
AgentStreamState::Done => return None,
AgentStreamState::Streaming {
messages,
iteration,
accumulated_content,
accumulated_tool_calls,
steps,
} => {
if let Some(ref mut stream) = self.model_stream {
if let Some(response) = stream.next().await {
match response {
Response::Chunk(ChatCompletionChunkResponse {
choices, ..
}) => {
if let Some(ChunkChoice {
delta:
Delta {
content,
tool_calls,
..
},
finish_reason,
..
}) = choices.first()
{
if let Some(text) = content {
accumulated_content.push_str(text);
return Some(AgentEvent::TextDelta(text.clone()));
}
if let Some(calls) = tool_calls {
accumulated_tool_calls.extend(calls.clone());
}
if finish_reason.is_some() {
self.model_stream = None;
if accumulated_tool_calls.is_empty() {
let final_response =
if accumulated_content.is_empty() {
None
} else {
Some(accumulated_content.clone())
};
let stop_reason = if final_response.is_some() {
AgentStopReason::TextResponse
} else {
AgentStopReason::NoAction
};
let response = AgentResponse {
steps: steps.clone(),
final_response,
iterations: *iteration + 1,
stop_reason,
};
self.state = AgentStreamState::Done;
return Some(AgentEvent::Complete(response));
} else {
let tool_calls = accumulated_tool_calls.clone();
let event =
AgentEvent::ToolCallsStart(tool_calls.clone());
let placeholder_response = ChatCompletionResponse {
id: String::new(),
choices: vec![],
created: 0,
model: String::new(),
system_fingerprint: String::new(),
object: String::new(),
usage: crate::Usage {
completion_tokens: 0,
prompt_tokens: 0,
total_tokens: 0,
avg_tok_per_sec: 0.0,
avg_prompt_tok_per_sec: 0.0,
avg_compl_tok_per_sec: 0.0,
total_time_sec: 0.0,
total_prompt_time_sec: 0.0,
total_completion_time_sec: 0.0,
},
};
self.state = AgentStreamState::ExecutingTools {
messages: messages.clone(),
iteration: *iteration,
response: placeholder_response,
tool_calls: tool_calls.clone(),
tool_results: Vec::new(),
pending_indices: (0..tool_calls.len())
.collect(),
steps: steps.clone(),
};
return Some(event);
}
}
}
}
Response::Done(response) => {
self.model_stream = None;
let tool_calls = response
.choices
.first()
.and_then(|c| c.message.tool_calls.clone())
.unwrap_or_default();
if tool_calls.is_empty() {
let final_response = response
.choices
.first()
.and_then(|c| c.message.content.clone());
let stop_reason = if final_response.is_some() {
AgentStopReason::TextResponse
} else {
AgentStopReason::NoAction
};
let agent_response = AgentResponse {
steps: steps.clone(),
final_response,
iterations: *iteration + 1,
stop_reason,
};
self.state = AgentStreamState::Done;
return Some(AgentEvent::Complete(agent_response));
} else {
let event = AgentEvent::ToolCallsStart(tool_calls.clone());
self.state = AgentStreamState::ExecutingTools {
messages: messages.clone(),
iteration: *iteration,
response: response.clone(),
tool_calls: tool_calls.clone(),
tool_results: Vec::new(),
pending_indices: (0..tool_calls.len()).collect(),
steps: steps.clone(),
};
return Some(event);
}
}
_ => continue,
}
}
}
self.state = AgentStreamState::Done;
return None;
}
AgentStreamState::ExecutingTools {
messages,
iteration,
response,
tool_calls,
tool_results,
pending_indices,
steps,
} => {
if pending_indices.is_empty() {
let mut new_messages = messages.clone();
new_messages = new_messages.add_message_with_tool_call(
TextMessageRole::Assistant,
response
.choices
.first()
.and_then(|c| c.message.content.clone())
.unwrap_or_default(),
tool_calls.clone(),
);
for result in tool_results.iter() {
let result_str = match &result.result {
Ok(s) => s.clone(),
Err(e) => format!("Error: {}", e),
};
new_messages =
new_messages.add_tool_message(&result_str, &result.tool_call_id);
}
let step = AgentStep {
response: response.clone(),
tool_calls: tool_calls.clone(),
tool_results: tool_results.clone(),
};
let mut new_steps = steps.clone();
new_steps.push(step);
let new_iteration = *iteration + 1;
if new_iteration >= self.agent.config.max_iterations {
let agent_response = AgentResponse {
steps: new_steps,
final_response: None,
iterations: new_iteration,
stop_reason: AgentStopReason::MaxIterations,
};
self.state = AgentStreamState::Done;
return Some(AgentEvent::Complete(agent_response));
}
let request = new_messages
.clone()
.set_tools(self.agent.tools.clone())
.set_tool_choice(self.agent.config.tool_choice.clone());
match self.agent.model.stream_chat_request(request).await {
Ok(stream) => {
self.model_stream = Some(stream);
self.state = AgentStreamState::Streaming {
messages: new_messages,
iteration: new_iteration,
accumulated_content: String::new(),
accumulated_tool_calls: Vec::new(),
steps: new_steps,
};
return Some(AgentEvent::ToolCallsComplete);
}
Err(e) => {
let agent_response = AgentResponse {
steps: new_steps,
final_response: None,
iterations: new_iteration,
stop_reason: AgentStopReason::Error(e.to_string()),
};
self.state = AgentStreamState::Done;
return Some(AgentEvent::Complete(agent_response));
}
}
}
let idx = pending_indices.remove(0);
let tool_call = &tool_calls[idx];
let result = self.agent.execute_tool_async(tool_call).await;
let event = AgentEvent::ToolResult(result.clone());
tool_results.push(result);
return Some(event);
}
}
}
}
}
/// An agent that lets a model call registered tools until it produces an answer.
pub struct Agent {
/// Model that receives the chat requests.
model: Model,
/// Tool definitions advertised to the model with every request.
tools: Vec<Tool>,
/// Tool implementations, keyed by function name.
callbacks: HashMap<String, ToolCallbackType>,
/// Settings controlling the loop.
config: AgentConfig,
}
impl Agent {
pub fn new(model: Model, config: AgentConfig) -> Self {
Self {
model,
tools: Vec::new(),
callbacks: HashMap::new(),
config,
}
}
pub fn with_tool(mut self, tool: Tool, callback: ToolCallbackType) -> Self {
let name = tool.function.name.clone();
self.tools.push(tool);
self.callbacks.insert(name, callback);
self
}
pub async fn run(&self, user_message: impl ToString) -> anyhow::Result<AgentResponse> {
let mut steps = Vec::new();
let mut messages = RequestBuilder::new();
if let Some(ref system) = self.config.system_prompt {
messages = messages.add_message(TextMessageRole::System, system);
}
messages = messages.add_message(TextMessageRole::User, user_message.to_string());
for iteration in 0..self.config.max_iterations {
let request = messages
.clone()
.set_tools(self.tools.clone())
.set_tool_choice(self.config.tool_choice.clone());
let response = self.model.send_chat_request(request).await?;
let choice = response
.choices
.first()
.ok_or_else(|| anyhow::anyhow!("No choices in response"))?;
let tool_calls = choice.message.tool_calls.clone().unwrap_or_default();
if tool_calls.is_empty() {
let final_text = choice.message.content.clone();
steps.push(AgentStep {
response: response.clone(),
tool_calls: vec![],
tool_results: vec![],
});
let stop_reason = if final_text.is_some() {
AgentStopReason::TextResponse
} else {
AgentStopReason::NoAction
};
return Ok(AgentResponse {
steps,
final_response: final_text,
iterations: iteration + 1,
stop_reason,
});
}
let tool_results = if self.config.parallel_tool_execution {
self.execute_tools_parallel(&tool_calls).await
} else {
let mut results = Vec::new();
for tool_call in &tool_calls {
results.push(self.execute_tool_async(tool_call).await);
}
results
};
messages = messages.add_message_with_tool_call(
TextMessageRole::Assistant,
choice.message.content.clone().unwrap_or_default(),
tool_calls.clone(),
);
for result in &tool_results {
let result_str = match &result.result {
Ok(s) => s.clone(),
Err(e) => format!("Error: {}", e),
};
messages = messages.add_tool_message(&result_str, &result.tool_call_id);
}
steps.push(AgentStep {
response: response.clone(),
tool_calls: tool_calls.clone(),
tool_results,
});
}
Ok(AgentResponse {
steps,
final_response: None,
iterations: self.config.max_iterations,
stop_reason: AgentStopReason::MaxIterations,
})
}
pub async fn run_stream(&self, user_message: impl ToString) -> anyhow::Result<AgentStream<'_>> {
let mut messages = RequestBuilder::new();
if let Some(ref system) = self.config.system_prompt {
messages = messages.add_message(TextMessageRole::System, system);
}
messages = messages.add_message(TextMessageRole::User, user_message.to_string());
let request = messages
.clone()
.set_tools(self.tools.clone())
.set_tool_choice(self.config.tool_choice.clone());
let stream = self.model.stream_chat_request(request).await?;
Ok(AgentStream {
agent: self,
state: AgentStreamState::Streaming {
messages,
iteration: 0,
accumulated_content: String::new(),
accumulated_tool_calls: Vec::new(),
steps: Vec::new(),
},
model_stream: Some(stream),
})
}
async fn execute_tools_parallel(&self, tool_calls: &[ToolCallResponse]) -> Vec<ToolResult> {
let futures: Vec<_> = tool_calls
.iter()
.map(|tc| self.execute_tool_async(tc))
.collect();
futures::future::join_all(futures).await
}
async fn execute_tool_async(&self, tool_call: &ToolCallResponse) -> ToolResult {
let tool_name = &tool_call.function.name;
let result = match self.callbacks.get(tool_name) {
Some(ToolCallbackType::Sync(callback)) => {
let callback = Arc::clone(callback);
let function = tool_call.function.clone();
tokio::task::spawn_blocking(move || callback(&function))
.await
.map_err(|e| anyhow::anyhow!("Task join error: {}", e))
.and_then(|r| r)
.map_err(|e| e.to_string())
}
Some(ToolCallbackType::Async(callback)) => {
let function = tool_call.function.clone();
callback(function).await.map_err(|e| e.to_string())
}
None => Err(format!("Unknown tool: {}", tool_name)),
};
ToolResult {
tool_call_id: tool_call.id.clone(),
tool_name: tool_name.clone(),
result,
}
}
pub fn model(&self) -> &Model {
&self.model
}
pub fn tools(&self) -> &[Tool] {
&self.tools
}
pub fn config(&self) -> &AgentConfig {
&self.config
}
}
/// Builder that assembles an [`Agent`] step by step.
pub struct AgentBuilder {
/// Model the finished agent will use.
model: Model,
/// Tool definitions registered so far.
tools: Vec<Tool>,
/// Tool implementations registered so far, keyed by function name.
callbacks: HashMap<String, ToolCallbackType>,
/// Configuration being assembled; starts from `AgentConfig::default()`.
config: AgentConfig,
}
impl AgentBuilder {
pub fn new(model: Model) -> Self {
Self {
model,
tools: Vec::new(),
callbacks: HashMap::new(),
config: AgentConfig::default(),
}
}
pub fn with_max_iterations(mut self, max: usize) -> Self {
self.config.max_iterations = max;
self
}
pub fn with_system_prompt(mut self, prompt: impl ToString) -> Self {
self.config.system_prompt = Some(prompt.to_string());
self
}
pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
self.config.tool_choice = choice;
self
}
pub fn with_parallel_tool_execution(mut self, enabled: bool) -> Self {
self.config.parallel_tool_execution = enabled;
self
}
pub fn with_sync_tool(mut self, tool: Tool, callback: Arc<ToolCallback>) -> Self {
let name = tool.function.name.clone();
self.tools.push(tool);
self.callbacks
.insert(name, ToolCallbackType::Sync(callback));
self
}
pub fn with_async_tool(mut self, tool: Tool, callback: Arc<AsyncToolCallback>) -> Self {
let name = tool.function.name.clone();
self.tools.push(tool);
self.callbacks
.insert(name, ToolCallbackType::Async(callback));
self
}
pub fn register_tool(mut self, (tool, callback): (Tool, ToolCallbackType)) -> Self {
let name = tool.function.name.clone();
self.tools.push(tool);
self.callbacks.insert(name, callback);
self
}
pub fn build(self) -> Agent {
Agent {
model: self.model,
tools: self.tools,
callbacks: self.callbacks,
config: self.config,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration allows 10 iterations, has no system prompt
    /// and runs tools in parallel.
    #[test]
    fn test_agent_config_default() {
        let AgentConfig {
            max_iterations,
            system_prompt,
            parallel_tool_execution,
            ..
        } = AgentConfig::default();
        assert_eq!(max_iterations, 10);
        assert!(system_prompt.is_none());
        assert!(parallel_tool_execution);
    }

    /// Stop reasons compare equal to themselves and unequal to other variants.
    #[test]
    fn test_agent_stop_reason_equality() {
        let text = AgentStopReason::TextResponse;
        let limit = AgentStopReason::MaxIterations;
        assert_eq!(text, text.clone());
        assert_eq!(limit, limit.clone());
        assert_ne!(text, limit);
    }

    /// A callback wrapper can be cloned.
    #[test]
    fn test_tool_callback_type_clone() {
        let sync_cb: Arc<ToolCallback> = Arc::new(|_| Ok("test".to_string()));
        let wrapped = ToolCallbackType::Sync(sync_cb);
        let _copy = wrapped.clone();
    }
}