use crate::core::types::{
Content, GenerateOptions, GenerateResult, Message, Prompt, Role, StreamPart, ToolDefinition,
};
use crate::core::{LanguageModel, Result};
use crate::core::error::ProviderError;
use futures::stream::BoxStream;
use futures_util::StreamExt;
use serde::{Deserialize, Serialize};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// One event emitted by the stream returned from [`Agent::run_stream`].
///
/// Serialized as an internally tagged object: the variant name, lowercased
/// (e.g. `"textdelta"`, `"toolcall"`, `"toolresult"`, `"error"`), is stored
/// under the `"type"` key next to the variant's fields.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
#[serde(tag = "type")]
#[serde(rename_all = "lowercase")]
pub enum AgentStreamPart {
    /// A fragment of model text, forwarded as soon as the model produces it.
    TextDelta { delta: String },
    /// A tool invocation requested by the model, emitted before its handler runs.
    ///
    /// NOTE(review): `run_stream` currently fills `id` with the tool name.
    ToolCall {
        id: String,
        name: String,
        arguments: serde_json::Value,
    },
    /// Output of the tool handler, or an `{"error": "..."}` object if it failed.
    ToolResult {
        name: String,
        result: serde_json::Value,
    },
    /// An error from the model stream, or from starting a generation step.
    Error { message: String },
}
/// Record of one model round trip performed by [`Agent::run`].
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct AgentStep {
    /// Zero-based index of this step within the run.
    pub step: usize,
    /// Text returned by the model in this step (may be empty).
    pub text: String,
    /// Tools executed during this step, in execution order.
    pub tool_calls: Vec<AgentToolCall>,
}
/// A single tool invocation made during an agent step, together with its outcome.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct AgentToolCall {
    /// Name of the tool the model asked for.
    pub name: String,
    /// JSON arguments supplied by the model.
    pub arguments: serde_json::Value,
    /// Handler output. `Agent::run` always sets this to `Some`, using an
    /// `{"error": "..."}` object when the handler failed.
    pub result: Option<serde_json::Value>,
}
/// Final outcome of [`Agent::run`].
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct AgentResult {
    /// Text of the last model response.
    pub text: String,
    /// Every step that was executed, in order.
    pub steps: Vec<AgentStep>,
    /// Number of entries in `steps`.
    pub total_steps: usize,
    /// Finish reason reported by the model for the last response.
    pub finish_reason: String,
}
pub type ToolHandlerFn = Arc<
dyn Fn(
String,
serde_json::Value,
) -> Pin<Box<dyn Future<Output = anyhow::Result<serde_json::Value>> + Send>>
+ Send
+ Sync,
>;
/// A tool-using agent that alternates model calls with tool executions.
///
/// Construct one with [`Agent::builder`].
pub struct Agent {
    /// Backend used for every generation request.
    model: Box<dyn LanguageModel>,
    /// Tool definitions advertised to the model; omitted from requests when empty.
    tools: Vec<ToolDefinition>,
    /// Executes the tool calls requested by the model.
    tool_handler: ToolHandlerFn,
    /// Upper bound on model round trips per run.
    max_steps: usize,
    /// Optional system prompt placed first in the conversation.
    system: Option<String>,
    /// Model identifier forwarded in every `GenerateOptions`.
    model_id: String,
    /// Sampling temperature forwarded to the model, when set.
    temperature: Option<f32>,
    /// Response token limit forwarded to the model, when set.
    max_tokens: Option<u32>,
}
/// Step-by-step configuration for an [`Agent`]; obtain one from [`Agent::builder`].
///
/// `model` and `tool_handler` must be set before calling `build`.
pub struct AgentBuilder {
    /// Required: the language model backend.
    model: Option<Box<dyn LanguageModel>>,
    /// Tool definitions offered to the model.
    tools: Vec<ToolDefinition>,
    /// Required: executor for tool calls.
    tool_handler: Option<ToolHandlerFn>,
    /// Maximum number of model round trips per run.
    max_steps: usize,
    /// Optional system prompt.
    system: Option<String>,
    /// Model identifier passed to the backend.
    model_id: String,
    /// Optional sampling temperature.
    temperature: Option<f32>,
    /// Optional response token limit.
    max_tokens: Option<u32>,
}
impl AgentBuilder {
#[must_use]
pub fn model(mut self, model: Box<dyn LanguageModel>) -> Self {
self.model = Some(model);
self
}
#[must_use]
pub fn tools(mut self, tools: Vec<ToolDefinition>) -> Self {
self.tools = tools;
self
}
#[must_use]
pub fn tool_handler<F, Fut>(mut self, handler: F) -> Self
where
F: Fn(String, serde_json::Value) -> Fut + Send + Sync + 'static,
Fut: Future<Output = anyhow::Result<serde_json::Value>> + Send + 'static,
{
self.tool_handler = Some(Arc::new(move |name, args| {
Box::pin(handler(name, args))
}));
self
}
#[must_use]
pub fn max_steps(mut self, max_steps: usize) -> Self {
self.max_steps = max_steps;
self
}
#[must_use]
pub fn system(mut self, system: impl Into<String>) -> Self {
self.system = Some(system.into());
self
}
#[must_use]
pub fn model_id(mut self, model_id: impl Into<String>) -> Self {
self.model_id = model_id.into();
self
}
#[must_use]
pub fn temperature(mut self, temperature: f32) -> Self {
self.temperature = Some(temperature);
self
}
#[must_use]
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}
pub fn build(self) -> std::result::Result<Agent, String> {
Ok(Agent {
model: self.model.ok_or("model is required")?,
tools: self.tools,
tool_handler: self.tool_handler.ok_or("tool_handler is required")?,
max_steps: self.max_steps,
system: self.system,
model_id: self.model_id,
temperature: self.temperature,
max_tokens: self.max_tokens,
})
}
}
impl Agent {
#[must_use]
pub fn builder() -> AgentBuilder {
AgentBuilder {
model: None,
tools: Vec::new(),
tool_handler: None,
max_steps: 10,
system: None,
model_id: String::new(),
temperature: None,
max_tokens: None,
}
}
pub async fn run(&self, prompt_text: &str) -> Result<AgentResult> {
let mut messages = Vec::new();
if let Some(ref sys) = self.system {
messages.push(Message {
role: Role::System,
content: vec![Content::Text {
text: sys.clone(),
}],
});
}
messages.push(Message {
role: Role::User,
content: vec![Content::Text {
text: prompt_text.to_string(),
}],
});
let mut steps = Vec::new();
let mut last_result: Option<GenerateResult> = None;
for step_idx in 0..self.max_steps {
let prompt = Prompt {
messages: messages.clone(),
};
let options = GenerateOptions {
model_id: self.model_id.clone(),
max_tokens: self.max_tokens,
temperature: self.temperature,
top_p: None,
stop_sequences: None,
tools: if self.tools.is_empty() {
None
} else {
Some(self.tools.clone())
},
response_format: None,
};
let result = self.model.generate(prompt, options).await?;
let mut step = AgentStep {
step: step_idx,
text: result.text.clone(),
tool_calls: Vec::new(),
};
if result.tool_calls.is_empty() {
steps.push(step);
last_result = Some(result);
break;
}
let mut assistant_content = Vec::new();
if !result.text.is_empty() {
assistant_content.push(Content::Text {
text: result.text.clone(),
});
}
for tc in &result.tool_calls {
assistant_content.push(Content::ToolCall {
id: tc.name.clone(),
name: tc.name.clone(),
arguments: tc.arguments.clone(),
});
}
messages.push(Message {
role: Role::Assistant,
content: assistant_content,
});
for tc in &result.tool_calls {
let tool_result = (self.tool_handler)(
tc.name.clone(),
tc.arguments.clone(),
)
.await;
let result_value = match tool_result {
Ok(v) => v,
Err(e) => serde_json::json!({ "error": e.to_string() }),
};
step.tool_calls.push(AgentToolCall {
name: tc.name.clone(),
arguments: tc.arguments.clone(),
result: Some(result_value.clone()),
});
messages.push(Message {
role: Role::Tool,
content: vec![Content::ToolResult {
id: tc.name.clone(),
result: result_value,
}],
});
}
steps.push(step);
last_result = Some(result);
}
let final_result = last_result.ok_or_else(|| {
ProviderError::InvalidResponse("Agent produced no results".to_string())
})?;
let total_steps = steps.len();
Ok(AgentResult {
text: final_result.text,
steps,
total_steps,
finish_reason: final_result.finish_reason,
})
}
pub async fn run_stream<'a>(
&'a mut self,
prompt_text: &str,
) -> Result<BoxStream<'a, AgentStreamPart>> {
let mut messages = vec![];
if let Some(sys) = &self.system {
messages.push(Message {
role: Role::System,
content: vec![Content::Text { text: sys.clone() }],
});
}
messages.push(Message {
role: Role::User,
content: vec![Content::Text {
text: prompt_text.to_string(),
}],
});
let stream = async_stream::stream! {
for _step in 0..self.max_steps {
let prompt = Prompt {
messages: messages.clone(),
};
let options = GenerateOptions {
model_id: self.model_id.clone(),
max_tokens: self.max_tokens,
temperature: self.temperature,
top_p: None,
stop_sequences: None,
tools: if self.tools.is_empty() {
None
} else {
Some(self.tools.clone())
},
response_format: None,
};
let mut inner_stream = match self.model.generate_stream(prompt, options).await {
Ok(s) => s,
Err(e) => {
yield AgentStreamPart::Error { message: e.to_string() };
break;
}
};
let mut tc_names = std::collections::HashMap::new();
let mut tc_args = std::collections::HashMap::new();
while let Some(part) = inner_stream.next().await {
match part {
StreamPart::TextDelta { delta } => {
yield AgentStreamPart::TextDelta { delta };
}
StreamPart::ToolCallDelta { index, name, arguments_delta, .. } => {
if let Some(n) = name {
tc_names.insert(index, n);
}
if let Some(d) = arguments_delta {
tc_args.entry(index).or_insert_with(String::new).push_str(&d);
}
}
StreamPart::Error { message } => {
yield AgentStreamPart::Error { message };
}
_ => {}
}
}
if tc_names.is_empty() && tc_args.is_empty() {
break;
}
let mut contents = vec![];
let mut tool_results_to_yield = vec![];
for (idx, name) in &tc_names {
let args_str = tc_args.get(idx).map(|s| s.as_str()).unwrap_or("{}");
let arguments: serde_json::Value = serde_json::from_str(args_str).unwrap_or(serde_json::Value::Null);
contents.push(Content::ToolCall {
id: name.clone(), name: name.clone(),
arguments: arguments.clone(),
});
yield AgentStreamPart::ToolCall {
id: name.clone(),
name: name.clone(),
arguments: arguments.clone(),
};
let handler = &self.tool_handler;
let result_val = match handler(name.clone(), arguments).await {
Ok(res) => res,
Err(e) => serde_json::json!({ "error": e.to_string() }),
};
yield AgentStreamPart::ToolResult {
name: name.clone(),
result: result_val.clone(),
};
tool_results_to_yield.push((name.clone(), result_val));
}
messages.push(Message {
role: Role::Assistant,
content: contents,
});
for (name, result_val) in tool_results_to_yield {
messages.push(Message {
role: Role::Tool,
content: vec![Content::ToolResult {
id: name,
result: result_val,
}],
});
}
}
};
Ok(Box::pin(stream))
}
}