use crate::{
agent::models::{
configs::{ModelConfig, PromptConfig},
error::AgentBuildError,
},
notifications::Notification,
services::{
llm::{ClientBuilder, ClientConfig, Provider, ResponseFormatConfig, SchemaSpec},
mcp::mcp_tool_builder::McpServerType,
},
skills::{build_read_skill_tool, load_skill_sources},
templates::Template,
Agent, Flow, FlowFuture, Skill, Tool, ToolBuilderError, SKILL_SYSTEM_PROMPT_TEMPLATE,
};
use futures::future::join_all;
use rmcp::schemars::JsonSchema;
use std::{collections::HashMap, path::PathBuf, sync::Arc};
use tokio::sync::{mpsc, Mutex};
/// Fluent, consuming builder for [`Agent`].
///
/// Every setter takes `self` by value and returns it, so calls can be chained. Skills are
/// loaded, the response format is resolved and the inference client is constructed only when
/// `build` / `build_with_notification` is awaited; `Option` fields left as `None` fall back to
/// the defaults chosen in `build`.
#[derive(Debug, Default)]
pub struct AgentBuilder {
    /// Agent name; `build` falls back to `Agent-{model}` when unset.
    name: Option<String>,
    /// Settings used to construct the inference client (provider, base URL, API key,
    /// organization, extra headers).
    client_config: ClientConfig,
    /// Model identifier and the model parameters forwarded to the agent.
    model_config: ModelConfig,
    /// Optional prompt template, stored behind `Arc<Mutex<_>>`.
    template: Option<Arc<Mutex<Template>>>,
    /// System prompt; `build` defaults it to "You are a helpful agent." when unset.
    system_prompt: Option<String>,
    /// Locally executed tools; `None` until the first tool is added.
    tools: Option<Vec<Tool>>,
    /// Structured-output schema configuration, resolved during `build`.
    response_format: ResponseFormatConfig,
    /// MCP server definitions handed to the agent.
    mcp_servers: Option<Vec<McpServerType>>,
    /// Individual skill paths, loaded during `build`.
    skill_paths: Vec<PathBuf>,
    /// Paths loaded as skill collections during `build` (presumably directories holding
    /// several skills — confirm against `load_skill_sources`).
    skill_collection_paths: Vec<PathBuf>,
    /// Built-in skills registered by helper methods such as `add_bash`.
    builtin_skills: Vec<Skill>,
    /// Forwarded unchanged to the agent.
    stop_prompt: Option<String>,
    /// Forwarded unchanged to the agent.
    stopword: Option<String>,
    /// Whether the agent strips thinking output; `build` defaults it to `true`.
    strip_thinking: Option<bool>,
    /// Forwarded to the agent as an `Option`.
    max_iterations: Option<usize>,
    /// Whether history is cleared on each invocation; `build` defaults it to `false`.
    /// NOTE(review): "histroy" is a typo; renaming it requires updating the impl as well.
    clear_histroy_on_invoke: Option<bool>,
    /// Streaming toggle; `build` defaults it to `false`.
    stream: Option<bool>,
    /// Forwarded to the agent; presumably a model keep-alive duration string — confirm.
    keep_alive: Option<String>,
    /// Sender half of the notification channel; set by `build_with_notification`.
    notification_channel: Option<mpsc::Sender<Notification>>,
    /// Custom invocation flow; `build` defaults it to `Flow::Default`.
    flow: Option<Flow>,
}
impl AgentBuilder {
    /// Applies every client setting in `conf` that is `Some`.
    ///
    /// Each present field is routed through the matching `set_*` method; `None` fields leave
    /// whatever the builder already holds untouched.
    pub fn import_client_config(mut self, conf: ClientConfig) -> Self {
        if let Some(provider) = conf.provider {
            self = self.set_provider(provider);
        }
        if let Some(base_url) = conf.base_url {
            self = self.set_base_url(base_url);
        }
        if let Some(api_key) = conf.api_key {
            self = self.set_api_key(api_key);
        }
        if let Some(organization) = conf.organization {
            self = self.set_organization(organization);
        }
        if let Some(extra_headers) = conf.extra_headers {
            self = self.set_extra_headers(extra_headers);
        }
        self
    }

    /// Applies the prompt-level settings in `conf`.
    ///
    /// `Some` fields are routed through the matching setter. Tools and MCP servers are
    /// *appended* to those already registered rather than replacing them. `conf.stream` is
    /// always applied, overriding any earlier `set_stream` call.
    pub fn import_prompt_config(mut self, conf: PromptConfig) -> Self {
        if let Some(template) = conf.template {
            self = self.set_template(template);
        }
        if let Some(system_prompt) = conf.system_prompt {
            self = self.set_system_prompt(system_prompt);
        }
        if let Some(tools) = conf.tools {
            for tool in tools {
                self = self.add_tool(tool);
            }
        }
        if let Some(response_format) = conf.response_format {
            self = self.set_response_format_spec(response_format);
        }
        if let Some(mcp_servers) = conf.mcp_servers {
            for mcp in mcp_servers {
                self = self.add_mcp_server(mcp);
            }
        }
        if let Some(stop_prompt) = conf.stop_prompt {
            self = self.set_stop_prompt(stop_prompt);
        }
        if let Some(stopword) = conf.stopword {
            self = self.set_stopword(stopword);
        }
        if let Some(strip_thinking) = conf.strip_thinking {
            self = self.strip_thinking(strip_thinking);
        }
        if let Some(max_iterations) = conf.max_iterations {
            self = self.set_max_iterations(max_iterations);
        }
        if let Some(clear_histroy_on_invoke) = conf.clear_histroy_on_invoke {
            self = self.set_clear_history_on_invocation(clear_histroy_on_invoke);
        }
        self = self.set_stream(conf.stream);
        self
    }

    /// Applies every model parameter in `conf` that is `Some`.
    ///
    /// `None` fields leave the builder's current value untouched.
    pub fn import_model_config(mut self, conf: ModelConfig) -> Self {
        if let Some(model) = conf.model {
            self = self.set_model(model);
        }
        if let Some(temperature) = conf.temperature {
            self = self.set_temperature(temperature);
        }
        if let Some(top_p) = conf.top_p {
            self = self.set_top_p(top_p);
        }
        if let Some(presence_penalty) = conf.presence_penalty {
            self = self.set_presence_penalty(presence_penalty);
        }
        if let Some(frequency_penalty) = conf.frequency_penalty {
            self = self.set_frequency_penalty(frequency_penalty);
        }
        if let Some(num_ctx) = conf.num_ctx {
            self = self.set_num_ctx(num_ctx);
        }
        if let Some(repeat_last_n) = conf.repeat_last_n {
            self = self.set_repeat_last_n(repeat_last_n);
        }
        if let Some(repeat_penalty) = conf.repeat_penalty {
            self = self.set_repeat_penalty(repeat_penalty);
        }
        if let Some(seed) = conf.seed {
            self = self.set_seed(seed);
        }
        if let Some(stop) = conf.stop {
            self = self.set_stop(stop);
        }
        if let Some(num_predict) = conf.num_predict {
            self = self.set_num_predict(num_predict);
        }
        if let Some(top_k) = conf.top_k {
            self = self.set_top_k(top_k);
        }
        if let Some(min_p) = conf.min_p {
            self = self.set_min_p(min_p);
        }
        self
    }

    /// Sets the agent's name. When unset, `build` names the agent `Agent-{model}`.
    pub fn set_name<T>(mut self, name: T) -> Self
    where
        T: Into<String>,
    {
        self.name = Some(name.into());
        self
    }

    /// Sets the provider stored in the client configuration.
    pub fn set_provider(mut self, provider: Provider) -> Self {
        self.client_config = self.client_config.provider(Some(provider));
        self
    }

    /// Sets the base URL stored in the client configuration.
    pub fn set_base_url<T>(mut self, base_url: T) -> Self
    where
        T: Into<String>,
    {
        self.client_config = self.client_config.base_url(Some(base_url));
        self
    }

    /// Sets the API key stored in the client configuration.
    pub fn set_api_key<T>(mut self, api_key: T) -> Self
    where
        T: Into<String>,
    {
        self.client_config = self.client_config.api_key(Some(api_key));
        self
    }

    /// Sets the organization stored in the client configuration.
    pub fn set_organization<T>(mut self, organization: T) -> Self
    where
        T: Into<String>,
    {
        self.client_config = self.client_config.organization(Some(organization));
        self
    }

    /// Replaces the extra headers stored in the client configuration.
    pub fn set_extra_headers(mut self, extra_headers: HashMap<String, String>) -> Self {
        self.client_config = self.client_config.extra_headers(Some(extra_headers));
        self
    }

    /// Enables or disables streaming; `build` defaults it to `false` when never set.
    pub fn set_stream(mut self, set: bool) -> Self {
        self.stream = Some(set);
        self
    }

    /// Sets the `temperature` model parameter.
    pub fn set_temperature(mut self, v: f32) -> Self {
        self.model_config.temperature = Some(v);
        self
    }

    /// Sets the `top_p` model parameter.
    pub fn set_top_p(mut self, v: f32) -> Self {
        self.model_config.top_p = Some(v);
        self
    }

    /// Sets the `presence_penalty` model parameter.
    pub fn set_presence_penalty(mut self, v: f32) -> Self {
        self.model_config.presence_penalty = Some(v);
        self
    }

    /// Sets the `frequency_penalty` model parameter.
    pub fn set_frequency_penalty(mut self, v: f32) -> Self {
        self.model_config.frequency_penalty = Some(v);
        self
    }

    /// Sets the `num_ctx` model parameter.
    pub fn set_num_ctx(mut self, v: u32) -> Self {
        self.model_config.num_ctx = Some(v);
        self
    }

    /// Sets the `repeat_last_n` model parameter.
    pub fn set_repeat_last_n(mut self, v: i32) -> Self {
        self.model_config.repeat_last_n = Some(v);
        self
    }

    /// Sets the `keep_alive` value forwarded to the agent.
    pub fn set_keep_alive(mut self, v: String) -> Self {
        self.keep_alive = Some(v);
        self
    }

    /// Sets the `repeat_penalty` model parameter.
    pub fn set_repeat_penalty(mut self, v: f32) -> Self {
        self.model_config.repeat_penalty = Some(v);
        self
    }

    /// Sets the `seed` model parameter.
    pub fn set_seed(mut self, v: i32) -> Self {
        self.model_config.seed = Some(v);
        self
    }

    /// Sets the `stop` model parameter.
    pub fn set_stop<T: Into<String>>(mut self, v: T) -> Self {
        self.model_config.stop = Some(v.into());
        self
    }

    /// Sets the `num_predict` model parameter.
    pub fn set_num_predict(mut self, v: i32) -> Self {
        self.model_config.num_predict = Some(v);
        self
    }

    /// Sets the `top_k` model parameter.
    pub fn set_top_k(mut self, v: u32) -> Self {
        self.model_config.top_k = Some(v);
        self
    }

    /// Sets the `min_p` model parameter.
    pub fn set_min_p(mut self, v: f32) -> Self {
        self.model_config.min_p = Some(v);
        self
    }

    /// Sets the model name. Required: `build` fails with
    /// [`AgentBuildError::ModelNotSet`] when it is missing.
    pub fn set_model<T: Into<String>>(mut self, model: T) -> Self {
        self.model_config.model = Some(model.into());
        self
    }

    /// Sets the system prompt; `build` defaults it to "You are a helpful agent.".
    pub fn set_system_prompt<T: Into<String>>(mut self, prompt: T) -> Self {
        self.system_prompt = Some(prompt.into());
        self
    }

    /// Sets the stop prompt, forwarded unchanged to the agent.
    pub fn set_stop_prompt<T: Into<String>>(mut self, stop_prompt: T) -> Self {
        self.stop_prompt = Some(stop_prompt.into());
        self
    }

    /// Sets the stopword, forwarded unchanged to the agent.
    pub fn set_stopword<T: Into<String>>(mut self, stopword: T) -> Self {
        self.stopword = Some(stopword.into());
        self
    }

    /// Sets the agent's `strip_thinking` option; `build` defaults it to `true`.
    pub fn strip_thinking(mut self, strip: bool) -> Self {
        self.strip_thinking = Some(strip);
        self
    }

    /// Uses `flow` as the agent's flow instead of [`Flow::Default`].
    pub fn set_flow_fn(mut self, flow: Flow) -> Self {
        self.flow = Some(flow);
        self
    }

    /// Convenience wrapper: wraps `f` with [`Flow::from_fn`] and installs it as the flow.
    pub fn set_flow<F>(self, f: F) -> Self
    where
        F: for<'a> Fn(&'a mut Agent, String) -> FlowFuture<'a> + Send + Sync + 'static,
    {
        self.set_flow_fn(Flow::from_fn(f))
    }

    /// Appends `tool` to the local tool list.
    pub fn add_tool(mut self, tool: Tool) -> Self {
        self.tools.get_or_insert_with(Vec::new).push(tool);
        self
    }

    /// Replaces the local tool list with `tools`.
    pub fn set_tools(mut self, tools: Vec<Tool>) -> Self {
        self.tools = Some(tools);
        self
    }

    /// Removes every previously registered local tool.
    pub fn remove_tools(mut self) -> Self {
        self.tools = None;
        self
    }

    /// Appends an MCP server definition.
    pub fn add_mcp_server(mut self, server: McpServerType) -> Self {
        self.mcp_servers.get_or_insert_with(Vec::new).push(server);
        self
    }

    /// Registers a single skill path; skills are loaded during `build`.
    pub fn add_skill(mut self, path: impl Into<PathBuf>) -> Self {
        self.skill_paths.push(path.into());
        self
    }

    /// Registers a path that `build` loads as a skill collection.
    pub fn add_skill_collection(mut self, path: impl Into<PathBuf>) -> Self {
        self.skill_collection_paths.push(path.into());
        self
    }

    /// Adds the prebuilt bash tool (built with default options) and its companion built-in
    /// skill.
    ///
    /// # Errors
    ///
    /// Returns the [`ToolBuilderError`] raised while building the bash tool.
    pub fn add_bash(mut self) -> Result<Self, ToolBuilderError> {
        let bash_tool = crate::tools::prebuilt::bash::build_bash_tool(Default::default())?;
        self = self.add_tool(bash_tool);
        self = self.add_bash_skill();
        Ok(self)
    }

    /// Registers the built-in bash skill.
    fn add_bash_skill(mut self) -> Self {
        self.builtin_skills.push(crate::skills::bash_skill());
        self
    }

    /// Sets the prompt template; it is stored behind `Arc<Mutex<_>>` for the agent.
    pub fn set_template(mut self, template: Template) -> Self {
        self.template = Some(Arc::new(Mutex::new(template)));
        self
    }

    /// Sets `max_iterations`, forwarded to the agent.
    pub fn set_max_iterations(mut self, max_iterations: usize) -> Self {
        self.max_iterations = Some(max_iterations);
        self
    }

    /// Sets whether history is cleared on each invocation; `build` defaults it to `false`.
    pub fn set_clear_history_on_invocation(mut self, clear: bool) -> Self {
        self.clear_histroy_on_invoke = Some(clear);
        self
    }

    /// Sets the response-format schema from raw JSON text.
    ///
    /// The text is resolved during `build`, which fails with
    /// [`AgentBuildError::InvalidJsonSchema`] if it cannot be resolved.
    pub fn set_response_format_str(mut self, schema_json: &str) -> Self {
        self.response_format.set_raw(schema_json);
        self
    }

    /// Sets the response-format schema from an already-parsed JSON value.
    pub fn set_response_format_value(mut self, schema: serde_json::Value) -> Self {
        self.response_format.set_value(schema);
        self
    }

    /// Sets the response-format schema from the [`JsonSchema`] implementation of `T`.
    pub fn set_response_format_from<T: JsonSchema>(mut self) -> Self {
        self.response_format.set_type::<T>();
        self
    }

    /// Sets the response-format schema from a [`SchemaSpec`].
    pub fn set_response_format_spec(mut self, schema: SchemaSpec) -> Self {
        self.response_format.set_spec(schema);
        self
    }

    /// Sets the name attached to the response-format schema.
    pub fn set_schema_name(mut self, name: impl Into<String>) -> Self {
        self.response_format.set_name(name);
        self
    }

    /// Sets the strictness flag of the response-format schema.
    pub fn set_schema_strict(mut self, strict: bool) -> Self {
        self.response_format.set_strict(strict);
        self
    }

    /// Builds the agent together with a notification channel and returns its receiver.
    ///
    /// The channel is a bounded `tokio` mpsc channel with capacity 100.
    ///
    /// # Errors
    ///
    /// Same as [`AgentBuilder::build`].
    pub async fn build_with_notification(
        mut self,
    ) -> Result<(Agent, mpsc::Receiver<Notification>), AgentBuildError> {
        let (sender, receiver) = mpsc::channel(100);
        self.notification_channel = Some(sender);
        let agent = self.build().await?;
        Ok((agent, receiver))
    }

    /// Consumes the builder and constructs the [`Agent`].
    ///
    /// Defaults for options that were never set: system prompt `"You are a helpful agent."`,
    /// name `Agent-{model}`, `strip_thinking = true`, `stream = false`, clear-history `false`,
    /// and [`Flow::Default`].
    ///
    /// When any skill is configured (via `add_skill`, `add_skill_collection`, or a built-in
    /// such as the one registered by `add_bash`), the system prompt is compiled through the
    /// skill system-prompt template together with every skill's discovery description, and a
    /// `read_skill` tool is appended to the local tools.
    ///
    /// # Errors
    ///
    /// - [`AgentBuildError::ModelNotSet`] if no model was configured.
    /// - [`AgentBuildError::ReservedToolName`] if skills are configured and a user tool is
    ///   already named `read_skill`.
    /// - [`AgentBuildError::InvalidJsonSchema`] if the response format cannot be resolved.
    /// - Errors from skill loading, `read_skill` tool construction, client construction,
    ///   structured-output conversion and `Agent::try_new` are propagated.
    pub async fn build(self) -> Result<Agent, AgentBuildError> {
        let model_config = self.model_config;
        let model = model_config
            .model
            .clone()
            .ok_or(AgentBuildError::ModelNotSet)?;
        let mut system_prompt = self
            .system_prompt
            .unwrap_or_else(|| "You are a helpful agent.".into());

        let mut skills = load_skill_sources(&self.skill_paths, &self.skill_collection_paths)?;
        skills.extend(self.builtin_skills);

        // `self` is owned, so the tool list can be moved out instead of cloned.
        let mut tools = self.tools;
        if !skills.is_empty() {
            // Reject a user tool that would collide with the injected `read_skill` tool before
            // spending any work on skill discovery descriptions.
            if tools
                .as_ref()
                .is_some_and(|tools| tools.iter().any(|tool| tool.name() == "read_skill"))
            {
                return Err(AgentBuildError::ReservedToolName("read_skill".into()));
            }

            let skill_descriptions = join_all(
                skills
                    .iter()
                    .map(|s| async move { s.discovery_description().await }),
            )
            .await;
            let data = HashMap::from([
                ("system_prompt", system_prompt),
                ("skills_discovery", skill_descriptions.join("\n\n---\n\n")),
            ]);
            let skill_template = Template::simple(SKILL_SYSTEM_PROMPT_TEMPLATE);
            system_prompt = skill_template.compile(&data).await;

            let read_skill_tool = build_read_skill_tool(&skills)?;
            tools.get_or_insert_with(Vec::new).push(read_skill_tool);
        }

        let strip_thinking = self.strip_thinking.unwrap_or(true);
        let clear_history_on_invoke = self.clear_histroy_on_invoke.unwrap_or(false);
        let flow = self.flow.unwrap_or(Flow::Default);
        let name = self.name.unwrap_or_else(|| format!("Agent-{model}"));
        let stream = self.stream.unwrap_or(false);

        let inference_client = self.client_config.build()?;
        let response_format = self
            .response_format
            .resolve()
            .map_err(AgentBuildError::InvalidJsonSchema)?;
        // Translate the resolved schema into the provider-specific structured-output format.
        let response_format = match response_format {
            Some(format) => Some(inference_client.structured_output_format(&format)?),
            None => None,
        };

        Agent::try_new(
            name,
            &model,
            inference_client,
            &system_prompt,
            tools,
            response_format,
            self.stop_prompt,
            self.stopword,
            strip_thinking,
            model_config.temperature,
            model_config.top_p,
            model_config.presence_penalty,
            model_config.frequency_penalty,
            model_config.num_ctx,
            model_config.repeat_last_n,
            model_config.repeat_penalty,
            model_config.seed,
            model_config.stop,
            model_config.num_predict,
            stream,
            model_config.top_k,
            model_config.min_p,
            self.keep_alive,
            self.notification_channel,
            self.mcp_servers,
            flow,
            self.template,
            skills,
            self.max_iterations,
            clear_history_on_invoke,
        )
        .await
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use serde_json::Value;

    use super::*;
    use crate::{
        notifications::NotificationContent, Agent, AsyncToolFn, FlowFuture, Message, ToolBuilder,
    };

    /// Builds a tool called `name` whose executor ignores its arguments and returns a fixed
    /// JSON string.
    fn make_tool(name: &str) -> Tool {
        let weather_exec: AsyncToolFn = {
            Arc::new(move |_model_args_json: Value| {
                Box::pin(async move {
                    Ok(r#"
{
"type":"object",
"properties":{
"windy":{"type":"boolean"},
"temperature":{"type":"integer"},
"description":{"type":"string"}
},
"required":["windy","temperature","description"]
}
"#
                    .into())
                })
            })
        };
        ToolBuilder::new()
            .function_name(name)
            .function_description("Returns a weather forecast for a given location")
            .add_required_property("location", "string", "City name")
            .executor(weather_exec)
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn defaults_fail_without_model() {
        let err = AgentBuilder::default().build().await.unwrap_err();
        assert!(matches!(err, AgentBuildError::ModelNotSet));
    }

    #[tokio::test]
    async fn build_minimal_succeeds() {
        let agent = AgentBuilder::default()
            .set_model("test-model")
            .build()
            .await
            .expect("build should succeed");
        assert_eq!(agent.model, "test-model");
        assert_eq!(
            agent.history.len(),
            1,
            "history should contain exactly the system prompt"
        );
    }

    #[tokio::test]
    async fn custom_system_prompt_and_response_format() {
        let json = r#"{"type":"object"}"#;
        let agent = AgentBuilder::default()
            .set_model("m")
            .set_system_prompt("Hello world")
            .set_response_format_str(json)
            .build()
            .await
            .unwrap();
        assert_eq!(agent.history[0].content.as_ref().unwrap(), "Hello world");
        let format = agent
            .response_format
            .as_ref()
            .expect("response format should be set");
        assert_eq!(format.get("type").and_then(|v| v.as_str()), Some("object"));
    }

    #[tokio::test]
    async fn invalid_json_schema_errors() {
        let bad = "not json";
        let err = AgentBuilder::default()
            .set_model("m")
            .set_response_format_str(bad)
            .build()
            .await
            .unwrap_err();
        assert!(matches!(err, AgentBuildError::InvalidJsonSchema(_)));
    }

    #[tokio::test]
    async fn add_tools() {
        let weather_tool = make_tool("get_current_weather");
        let agent = AgentBuilder::default()
            .set_model("x")
            .add_tool(weather_tool.clone())
            .build()
            .await
            .unwrap();
        assert_eq!(agent.local_tools.unwrap()[0].name(), weather_tool.name());
    }

    #[tokio::test]
    async fn user_read_skill_tool_is_rejected_when_skills_present() {
        // `add_bash` registers a built-in skill, so the builder must inject its own
        // `read_skill` tool and refuse a user tool with the same name.
        let err = AgentBuilder::default()
            .set_model("m")
            .add_bash()
            .expect("bash tool should build")
            .add_tool(make_tool("read_skill"))
            .build()
            .await
            .unwrap_err();
        assert!(matches!(err, AgentBuildError::ReservedToolName(_)));
    }

    #[tokio::test]
    async fn build_with_notification_channel() {
        let (agent, mut rx) = AgentBuilder::default()
            .set_model("foo")
            .build_with_notification()
            .await
            .unwrap();
        agent
            .notification_channel
            .as_ref()
            .unwrap()
            .send(Notification::new(
                "test".to_string(),
                NotificationContent::Done(false, None),
            ))
            .await
            .unwrap();
        let notified = rx.recv().await.unwrap();
        assert!(matches!(
            notified.content,
            NotificationContent::Done(false, None)
        ));
    }

    #[tokio::test]
    async fn custom_flow_invocation() {
        fn echo_flow<'a>(_agent: &'a mut Agent, prompt: String) -> FlowFuture<'a> {
            Box::pin(async move { Ok(Message::system(format!("ECHO: {prompt}"))) })
        }
        let agent = AgentBuilder::default()
            .set_model("m")
            .set_flow(echo_flow)
            .build()
            .await
            .unwrap();
        // Invoke through a clone so the custom flow is also shown to survive `Agent::clone`.
        let mut a = agent.clone();
        let resp = a.invoke_flow("abc").await.unwrap();
        assert_eq!(resp.content.unwrap(), "ECHO: abc");
    }
}