use std::{collections::HashMap, sync::Arc};
use tokio::sync::{mpsc, Mutex};
use crate::{
agent::models::{configs::{ModelConfig, PromptConfig}, error::AgentBuildError}, notifications::Notification, services::{llm::{ClientConfig, Provider}, mcp::mcp_tool_builder::McpServerType}, templates::Template, Agent, Flow, Tool
};
/// Fluent builder for [`Agent`].
///
/// Every setter consumes the builder and returns it, so configuration is
/// written as a method chain finished by [`AgentBuilder::build`] or
/// [`AgentBuilder::build_with_notification`]. Only the model is required;
/// every other option either falls back to a default chosen in `build` or is
/// forwarded to `Agent::try_new` exactly as stored here (possibly `None`).
#[derive(Debug, Default)]
#[must_use = "builder methods return a new builder; call `build()` to create the agent"]
pub struct AgentBuilder {
    /// Agent name; `build` falls back to `Agent-{model}` when unset.
    name: Option<String>,
    /// Model identifier. Required: `build` returns
    /// `AgentBuildError::ModelNotSet` when this is `None`.
    model: Option<String>,
    /// Provider copied into the `ClientConfig`; when unset, `build` keeps the
    /// provider from `ClientConfig::default()`.
    provider: Option<Provider>,
    /// Optional `ClientConfig::base_url` override.
    base_url: Option<String>,
    /// Optional `ClientConfig::api_key` override.
    api_key: Option<String>,
    /// Optional `ClientConfig::organization` override.
    organization: Option<String>,
    /// Optional `ClientConfig::extra_headers` override.
    extra_headers: Option<HashMap<String, String>>,
    /// Prompt template, stored in the `Arc<Mutex<_>>` created by `set_template`.
    template: Option<Arc<Mutex<Template>>>,
    /// System prompt; `build` falls back to `"You are a helpful agent."`.
    system_prompt: Option<String>,
    /// Local tools, appended in call order by `add_tool`.
    tools: Option<Vec<Tool>>,
    /// Raw JSON-schema text. `build` trims and parses it with `serde_json`,
    /// failing with `AgentBuildError::InvalidJsonSchema` if it is not valid JSON.
    response_format: Option<String>,
    /// MCP servers, appended in call order by `add_mcp_server`.
    mcp_servers: Option<Vec<McpServerType>>,
    /// Forwarded unchanged to `Agent::try_new`.
    stop_prompt: Option<String>,
    /// Forwarded unchanged to `Agent::try_new`.
    stopword: Option<String>,
    /// `build` falls back to `true` when unset.
    strip_thinking: Option<bool>,
    /// Forwarded unchanged (possibly `None`) to `Agent::try_new`.
    max_iterations: Option<usize>,
    /// `build` falls back to `false` when unset. The field name's spelling is
    /// historical; the public setter is `set_clear_history_on_invocation`.
    clear_histroy_on_invoke: Option<bool>,
    // Sampling/generation options: each is forwarded unchanged (possibly
    // `None`) to `Agent::try_new`.
    temperature: Option<f32>,
    top_p: Option<f32>,
    presence_penalty: Option<f32>,
    frequency_penalty: Option<f32>,
    num_ctx: Option<u32>,
    repeat_last_n: Option<i32>,
    repeat_penalty: Option<f32>,
    seed: Option<i32>,
    stop: Option<String>,
    num_predict: Option<i32>,
    top_k: Option<u32>,
    min_p: Option<f32>,
    /// Streaming flag; `build` falls back to `false`. Note that
    /// `import_prompt_config` always overwrites it.
    stream: Option<bool>,
    /// Populated by `build_with_notification`, which hands the matching
    /// receiver back to the caller.
    notification_channel: Option<mpsc::Sender<Notification>>,
    /// Flow to run; `build` falls back to `Flow::Default`.
    flow: Option<Flow>,
}
impl AgentBuilder {
    /// Capacity of the notification channel created by
    /// [`AgentBuilder::build_with_notification`].
    const NOTIFICATION_CHANNEL_CAPACITY: usize = 100;

    /// Copies client connection settings from `conf` into the builder.
    ///
    /// The provider is always replaced. `base_url`, `api_key`, `organization`
    /// and `extra_headers` replace the builder's current values only when they
    /// are `Some` in `conf`.
    pub fn import_client_config(mut self, conf: ClientConfig) -> Self {
        self = self.set_provider(conf.provider);
        if let Some(base_url) = conf.base_url {
            self = self.set_base_url(base_url);
        }
        if let Some(api_key) = conf.api_key {
            self = self.set_api_key(api_key);
        }
        if let Some(organization) = conf.organization {
            self = self.set_organization(organization);
        }
        if let Some(extra_headers) = conf.extra_headers {
            self = self.set_extra_headers(extra_headers);
        }
        self
    }

    /// Copies prompt-related settings from `conf` into the builder.
    ///
    /// Optional fields are applied only when they are `Some`. Tools and MCP
    /// servers are *appended* to any already registered rather than replacing
    /// them. `conf.stream` is a plain `bool`, so it always overwrites the
    /// builder's current stream setting.
    pub fn import_prompt_config(mut self, conf: PromptConfig) -> Self {
        if let Some(template) = conf.template {
            self = self.set_template(template);
        }
        if let Some(system_prompt) = conf.system_prompt {
            self = self.set_system_prompt(system_prompt);
        }
        if let Some(tools) = conf.tools {
            // Appended one by one (not `extend`ed) so an empty list leaves an
            // unset tool list as `None`.
            for tool in tools {
                self = self.add_tool(tool);
            }
        }
        if let Some(response_format) = conf.response_format {
            self = self.set_response_format(response_format);
        }
        if let Some(mcp_servers) = conf.mcp_servers {
            for mcp in mcp_servers {
                self = self.add_mcp_server(mcp);
            }
        }
        if let Some(stop_prompt) = conf.stop_prompt {
            self = self.set_stop_prompt(stop_prompt);
        }
        if let Some(stopword) = conf.stopword {
            self = self.set_stopword(stopword);
        }
        if let Some(strip_thinking) = conf.strip_thinking {
            self = self.strip_thinking(strip_thinking);
        }
        if let Some(max_iterations) = conf.max_iterations {
            self = self.set_max_iterations(max_iterations);
        }
        if let Some(clear_histroy_on_invoke) = conf.clear_histroy_on_invoke {
            self = self.set_clear_history_on_invocation(clear_histroy_on_invoke);
        }
        self = self.set_stream(conf.stream);
        self
    }

    /// Copies the model name and sampling options from `conf` into the
    /// builder. Each field is applied only when it is `Some`.
    pub fn import_model_config(mut self, conf: ModelConfig) -> Self {
        if let Some(model) = conf.model {
            self = self.set_model(model);
        }
        if let Some(temperature) = conf.temperature {
            self = self.set_temperature(temperature);
        }
        if let Some(top_p) = conf.top_p {
            self = self.set_top_p(top_p);
        }
        if let Some(presence_penalty) = conf.presence_penalty {
            self = self.set_presence_penalty(presence_penalty);
        }
        if let Some(frequency_penalty) = conf.frequency_penalty {
            self = self.set_frequency_penalty(frequency_penalty);
        }
        if let Some(num_ctx) = conf.num_ctx {
            self = self.set_num_ctx(num_ctx);
        }
        if let Some(repeat_last_n) = conf.repeat_last_n {
            self = self.set_repeat_last_n(repeat_last_n);
        }
        if let Some(repeat_penalty) = conf.repeat_penalty {
            self = self.set_repeat_penalty(repeat_penalty);
        }
        if let Some(seed) = conf.seed {
            self = self.set_seed(seed);
        }
        if let Some(stop) = conf.stop {
            self = self.set_stop(stop);
        }
        if let Some(num_predict) = conf.num_predict {
            self = self.set_num_predict(num_predict);
        }
        if let Some(top_k) = conf.top_k {
            self = self.set_top_k(top_k);
        }
        if let Some(min_p) = conf.min_p {
            self = self.set_min_p(min_p);
        }
        self
    }

    /// Sets the agent name (defaults to `Agent-{model}`).
    pub fn set_name<T>(mut self, name: T) -> Self
    where
        T: Into<String>,
    {
        self.name = Some(name.into());
        self
    }

    /// Sets the provider used for the client configuration.
    pub fn set_provider(mut self, provider: Provider) -> Self {
        self.provider = Some(provider);
        self
    }

    /// Overrides the client base URL.
    pub fn set_base_url<T>(mut self, base_url: T) -> Self
    where
        T: Into<String>,
    {
        self.base_url = Some(base_url.into());
        self
    }

    /// Sets the client API key.
    pub fn set_api_key<T>(mut self, api_key: T) -> Self
    where
        T: Into<String>,
    {
        self.api_key = Some(api_key.into());
        self
    }

    /// Sets the client organization.
    pub fn set_organization<T>(mut self, organization: T) -> Self
    where
        T: Into<String>,
    {
        self.organization = Some(organization.into());
        self
    }

    /// Replaces the extra client headers (the map is stored as given, not merged).
    pub fn set_extra_headers(mut self, extra_headers: HashMap<String, String>) -> Self {
        self.extra_headers = Some(extra_headers);
        self
    }

    /// Sets the streaming flag (defaults to `false`).
    pub fn set_stream(mut self, set: bool) -> Self {
        self.stream = Some(set);
        self
    }

    /// Sets the `temperature` option forwarded to the agent.
    pub fn set_temperature(mut self, v: f32) -> Self {
        self.temperature = Some(v);
        self
    }

    /// Sets the `top_p` option forwarded to the agent.
    pub fn set_top_p(mut self, v: f32) -> Self {
        self.top_p = Some(v);
        self
    }

    /// Sets the `presence_penalty` option forwarded to the agent.
    pub fn set_presence_penalty(mut self, v: f32) -> Self {
        self.presence_penalty = Some(v);
        self
    }

    /// Sets the `frequency_penalty` option forwarded to the agent.
    pub fn set_frequency_penalty(mut self, v: f32) -> Self {
        self.frequency_penalty = Some(v);
        self
    }

    /// Sets the `num_ctx` option forwarded to the agent.
    pub fn set_num_ctx(mut self, v: u32) -> Self {
        self.num_ctx = Some(v);
        self
    }

    /// Sets the `repeat_last_n` option forwarded to the agent.
    pub fn set_repeat_last_n(mut self, v: i32) -> Self {
        self.repeat_last_n = Some(v);
        self
    }

    /// Sets the `repeat_penalty` option forwarded to the agent.
    pub fn set_repeat_penalty(mut self, v: f32) -> Self {
        self.repeat_penalty = Some(v);
        self
    }

    /// Sets the `seed` option forwarded to the agent.
    pub fn set_seed(mut self, v: i32) -> Self {
        self.seed = Some(v);
        self
    }

    /// Sets the `stop` option forwarded to the agent.
    pub fn set_stop<T: Into<String>>(mut self, v: T) -> Self {
        self.stop = Some(v.into());
        self
    }

    /// Sets the `num_predict` option forwarded to the agent.
    pub fn set_num_predict(mut self, v: i32) -> Self {
        self.num_predict = Some(v);
        self
    }

    /// Sets the `top_k` option forwarded to the agent.
    pub fn set_top_k(mut self, v: u32) -> Self {
        self.top_k = Some(v);
        self
    }

    /// Sets the `min_p` option forwarded to the agent.
    pub fn set_min_p(mut self, v: f32) -> Self {
        self.min_p = Some(v);
        self
    }

    /// Sets the model identifier. This is the only required setting.
    pub fn set_model<T: Into<String>>(mut self, model: T) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets the system prompt (defaults to `"You are a helpful agent."`).
    pub fn set_system_prompt<T: Into<String>>(mut self, prompt: T) -> Self {
        self.system_prompt = Some(prompt.into());
        self
    }

    /// Stores raw JSON-schema text; it is validated only when `build` runs.
    pub fn set_response_format<T: Into<String>>(mut self, format: T) -> Self {
        self.response_format = Some(format.into());
        self
    }

    /// Sets the `stop_prompt` value forwarded to the agent.
    pub fn set_stop_prompt<T: Into<String>>(mut self, stop_prompt: T) -> Self {
        self.stop_prompt = Some(stop_prompt.into());
        self
    }

    /// Sets the `stopword` value forwarded to the agent.
    pub fn set_stopword<T: Into<String>>(mut self, stopword: T) -> Self {
        self.stopword = Some(stopword.into());
        self
    }

    /// Sets the `strip_thinking` flag forwarded to the agent (defaults to `true`).
    pub fn strip_thinking(mut self, strip: bool) -> Self {
        self.strip_thinking = Some(strip);
        self
    }

    /// Sets the flow the agent runs (defaults to `Flow::Default`).
    pub fn set_flow(mut self, flow: Flow) -> Self {
        self.flow = Some(flow);
        self
    }

    /// Appends a local tool, keeping registration order.
    pub fn add_tool(mut self, tool: Tool) -> Self {
        self.tools.get_or_insert_with(Vec::new).push(tool);
        self
    }

    /// Appends an MCP server, keeping registration order.
    pub fn add_mcp_server(mut self, server: McpServerType) -> Self {
        self.mcp_servers.get_or_insert_with(Vec::new).push(server);
        self
    }

    /// Sets the prompt template, wrapping it in an `Arc<Mutex<_>>`.
    pub fn set_template(mut self, template: Template) -> Self {
        self.template = Some(Arc::new(Mutex::new(template)));
        self
    }

    /// Sets `max_iterations`; it is forwarded as `None` when never set.
    pub fn set_max_iterations(mut self, max_iterations: usize) -> Self {
        self.max_iterations = Some(max_iterations);
        self
    }

    /// Sets whether history is cleared on each invocation (defaults to `false`).
    pub fn set_clear_history_on_invocation(mut self, clear: bool) -> Self {
        self.clear_histroy_on_invoke = Some(clear);
        self
    }

    /// Builds the agent together with a notification channel of the default
    /// capacity (100) and returns the receiving end.
    ///
    /// # Errors
    ///
    /// Same as [`AgentBuilder::build`].
    pub async fn build_with_notification(
        self,
    ) -> Result<(Agent, mpsc::Receiver<Notification>), AgentBuildError> {
        self.build_with_notification_capacity(Self::NOTIFICATION_CHANNEL_CAPACITY)
            .await
    }

    /// Builds the agent together with a bounded notification channel holding
    /// up to `capacity` pending notifications, and returns the receiving end.
    ///
    /// # Errors
    ///
    /// Same as [`AgentBuilder::build`].
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero (a requirement of [`mpsc::channel`]).
    pub async fn build_with_notification_capacity(
        mut self,
        capacity: usize,
    ) -> Result<(Agent, mpsc::Receiver<Notification>), AgentBuildError> {
        let (sender, receiver) = mpsc::channel(capacity);
        self.notification_channel = Some(sender);
        let agent = self.build().await?;
        Ok((agent, receiver))
    }

    /// Consumes the builder and constructs the [`Agent`].
    ///
    /// Defaults applied for unset options: name `Agent-{model}`, system prompt
    /// `"You are a helpful agent."`, `strip_thinking = true`,
    /// clear-history-on-invoke `false`, `stream = false`, flow `Flow::Default`.
    /// Everything else is forwarded to `Agent::try_new` as stored.
    ///
    /// # Errors
    ///
    /// * `AgentBuildError::ModelNotSet` if no model was set.
    /// * `AgentBuildError::InvalidJsonSchema` if the response format is not
    ///   valid JSON after trimming surrounding whitespace.
    /// * Any error returned by `Agent::try_new`.
    pub async fn build(self) -> Result<Agent, AgentBuildError> {
        let model = self.model.ok_or(AgentBuildError::ModelNotSet)?;
        let system_prompt = self
            .system_prompt
            .unwrap_or_else(|| "You are a helpful agent.".into());
        let strip_thinking = self.strip_thinking.unwrap_or(true);
        let clear_histroy_on_invoke = self.clear_histroy_on_invoke.unwrap_or(false);

        // Validate the schema now so malformed JSON surfaces as a build error.
        let response_format = match self.response_format {
            Some(schema) => {
                let trimmed = schema.trim();
                match serde_json::from_str(trimmed) {
                    Ok(value) => Some(value),
                    Err(e) => {
                        return Err(AgentBuildError::InvalidJsonSchema(format!(
                            "Failed to parse JSON schema `{trimmed}`: {e}"
                        )));
                    }
                }
            }
            None => None,
        };

        let flow = self.flow.unwrap_or(Flow::Default);
        let name = self.name.unwrap_or_else(|| format!("Agent-{model}"));
        let stream = self.stream.unwrap_or(false);

        // Start from the provider defaults and apply only the overrides that were set.
        let mut client_config = ClientConfig::default();
        if let Some(provider) = self.provider {
            client_config.provider = provider;
        }
        if let Some(base_url) = self.base_url {
            client_config.base_url = Some(base_url);
        }
        if let Some(api_key) = self.api_key {
            client_config.api_key = Some(api_key);
        }
        if let Some(organization) = self.organization {
            client_config.organization = Some(organization);
        }
        if let Some(extra_headers) = self.extra_headers {
            client_config.extra_headers = Some(extra_headers);
        }

        Agent::try_new(
            name,
            &model,
            client_config,
            &system_prompt,
            // `self` is owned and consumed here, so the tool list is moved
            // instead of deep-cloned.
            self.tools,
            response_format,
            self.stop_prompt,
            self.stopword,
            strip_thinking,
            self.temperature,
            self.top_p,
            self.presence_penalty,
            self.frequency_penalty,
            self.num_ctx,
            self.repeat_last_n,
            self.repeat_penalty,
            self.seed,
            self.stop,
            self.num_predict,
            stream,
            self.top_k,
            self.min_p,
            self.notification_channel,
            self.mcp_servers,
            flow.into(),
            self.template,
            self.max_iterations,
            clear_histroy_on_invoke,
        )
        .await
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use serde_json::Value;

    use super::*;
    use crate::{
        notifications::NotificationContent, Agent, AsyncToolFn, FlowFuture, Message, ToolBuilder,
    };

    /// `build` must refuse to run when no model has been set.
    #[tokio::test]
    async fn defaults_fail_without_model() {
        let err = AgentBuilder::default().build().await.unwrap_err();
        assert!(matches!(err, AgentBuildError::ModelNotSet));
    }

    /// Only the model is required; the default system prompt is applied.
    #[tokio::test]
    async fn build_minimal_succeeds() {
        let agent = AgentBuilder::default()
            .set_model("test-model")
            .build()
            .await
            .expect("build should succeed");
        assert_eq!(agent.model, "test-model");
        assert_eq!(
            agent.history.len(),
            1,
            "history should contain exactly the system prompt"
        );
        assert_eq!(
            agent.history[0].content.as_ref().unwrap(),
            "You are a helpful agent.",
            "an unset system prompt should fall back to the builder default"
        );
    }

    /// A custom system prompt lands in history and the schema is parsed.
    #[tokio::test]
    async fn custom_system_prompt_and_response_format() {
        let json = r#"{"type":"object"}"#;
        let agent = AgentBuilder::default()
            .set_model("m")
            .set_system_prompt("Hello world")
            .set_response_format(json)
            .build()
            .await
            .unwrap();
        assert_eq!(agent.history[0].content.as_ref().unwrap(), "Hello world");
        let schema = agent
            .response_format
            .as_ref()
            .expect("response_format should be set");
        assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
    }

    /// Surrounding whitespace in the schema text is tolerated.
    #[tokio::test]
    async fn response_format_is_trimmed_before_parsing() {
        let agent = AgentBuilder::default()
            .set_model("m")
            .set_response_format("  \n{\"type\":\"object\"}\n  ")
            .build()
            .await
            .unwrap();
        let schema = agent
            .response_format
            .as_ref()
            .expect("response_format should be set");
        assert_eq!(schema.get("type").unwrap().as_str().unwrap(), "object");
    }

    /// Non-JSON schema text is rejected at build time.
    #[tokio::test]
    async fn invalid_json_schema_errors() {
        let bad = "not json";
        let err = AgentBuilder::default()
            .set_model("m")
            .set_response_format(bad)
            .build()
            .await
            .unwrap_err();
        assert!(matches!(err, AgentBuildError::InvalidJsonSchema(_)));
    }

    /// A whitespace-only schema trims to "" and is rejected, not ignored.
    #[tokio::test]
    async fn blank_json_schema_errors() {
        let err = AgentBuilder::default()
            .set_model("m")
            .set_response_format("   ")
            .build()
            .await
            .unwrap_err();
        assert!(matches!(err, AgentBuildError::InvalidJsonSchema(_)));
    }

    /// Tools registered through the builder reach the agent.
    #[tokio::test]
    async fn add_tools() {
        let weather_exec: AsyncToolFn = {
            Arc::new(move |_model_args_json: Value| {
                Box::pin(async move {
                    Ok(r#"
                    {
                        "type":"object",
                        "properties":{
                            "windy":{"type":"boolean"},
                            "temperature":{"type":"integer"},
                            "description":{"type":"string"}
                        },
                        "required":["windy","temperature","description"]
                    }
                    "#
                    .into())
                })
            })
        };
        let weather_tool = ToolBuilder::new()
            .function_name("get_current_weather")
            .function_description("Returns a weather forecast for a given location")
            .add_required_property("location", "string", "City name")
            .executor(weather_exec)
            .build()
            .unwrap();
        let agent = AgentBuilder::default()
            .set_model("x")
            .add_tool(weather_tool.clone())
            .build()
            .await
            .unwrap();
        assert_eq!(agent.local_tools.unwrap()[0].name(), weather_tool.name());
    }

    /// The receiver returned by `build_with_notification` is wired to the agent.
    #[tokio::test]
    async fn build_with_notification_channel() {
        let (agent, mut rx) = AgentBuilder::default()
            .set_model("foo")
            .build_with_notification()
            .await
            .unwrap();
        agent
            .notification_channel
            .as_ref()
            .unwrap()
            .send(Notification::new(
                "test".to_string(),
                NotificationContent::Done(false, None),
            ))
            .await
            .unwrap();
        let notified = rx.recv().await.unwrap();
        assert!(matches!(
            notified.content,
            NotificationContent::Done(false, None)
        ));
    }

    /// A custom flow survives cloning the agent and is used by `invoke_flow`.
    #[tokio::test]
    async fn custom_flow_invocation() {
        fn echo_flow<'a>(_agent: &'a mut Agent, prompt: String) -> FlowFuture<'a> {
            Box::pin(async move { Ok(Message::system(format!("ECHO: {prompt}"))) })
        }
        let agent = AgentBuilder::default()
            .set_model("m")
            .set_flow(Flow::Custom(echo_flow))
            .build()
            .await
            .unwrap();
        let mut a = agent.clone();
        let resp = a.invoke_flow("abc").await.unwrap();
        assert_eq!(resp.content.unwrap(), "ECHO: abc");
    }
}