use std::fmt;
use serde::{Deserialize, Serialize};
use crate::generation::GenerationOptions;
use crate::tool::ToolChoice;
use crate::types::{ModelId, ProviderId};
/// Name identifying an agent definition; a thin newtype wrapper over `String`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub struct AgentName(String);
impl AgentName {
pub fn new(name: impl Into<String>) -> Self {
Self(name.into())
}
pub fn as_str(&self) -> &str {
&self.0
}
}
impl fmt::Display for AgentName {
    /// Writes the bare name, honouring width/alignment/precision flags exactly as `str` does.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only the `std::fmt` module is imported in this file, not the `Display` trait,
        // so method-call syntax (`self.0.fmt(f)`) cannot resolve a trait method here.
        // `String` also implements several `fmt` traits (Debug, Display, ...).
        // The fully qualified call names the intended trait explicitly.
        fmt::Display::fmt(&self.0, f)
    }
}
/// The kind of agent a definition describes.
///
/// Serialized in `snake_case`: `"primary"`, `"sub_agent"`, `"internal"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AgentRole {
    /// Serialized as `"primary"`.
    Primary,
    /// Serialized as `"sub_agent"`.
    SubAgent,
    /// Serialized as `"internal"`.
    Internal,
}
impl AgentRole {
pub fn is_primary(&self) -> bool {
matches!(self, Self::Primary)
}
pub fn is_sub_agent(&self) -> bool {
matches!(self, Self::SubAgent)
}
pub fn is_internal(&self) -> bool {
matches!(self, Self::Internal)
}
}
/// How an agent selects its model.
///
/// Serialized as an internally tagged object whose `"type"` field is
/// `"inherit"`, `"by_id"` or `"by_alias"`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AgentModelRef {
    /// No explicit model. NOTE(review): presumably resolved from the invoking
    /// context by the model resolver — confirm at the resolution site.
    Inherit,
    /// An explicit model on an explicit provider.
    ById {
        model_id: ModelId,
        provider_id: ProviderId,
    },
    /// A model referred to by alias; the alias is resolved outside this type.
    ByAlias {
        alias: String,
    },
}
impl AgentModelRef {
pub fn by_id(model_id: ModelId, provider_id: ProviderId) -> Self {
Self::ById {
model_id,
provider_id,
}
}
pub fn by_alias(alias: impl Into<String>) -> Self {
Self::ByAlias {
alias: alias.into(),
}
}
}
/// Restricts which tools an agent may call; see [`ToolFilter::is_allowed`].
///
/// Serialized as an internally tagged object, e.g.
/// `{"type":"allow_list","tools":["read_file"]}`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolFilter {
    /// Every tool is allowed. This is the default.
    #[default]
    AllowAll,
    /// Only tools whose name appears in `tools` are allowed.
    AllowList { tools: Vec<String> },
    /// Every tool except those named in `tools` is allowed.
    DenyList { tools: Vec<String> },
    /// No tool is allowed.
    None,
}
impl ToolFilter {
pub fn allow_list(tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
Self::AllowList {
tools: tools.into_iter().map(Into::into).collect(),
}
}
pub fn deny_list(tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
Self::DenyList {
tools: tools.into_iter().map(Into::into).collect(),
}
}
pub fn is_allowed(&self, tool_name: &str) -> bool {
match self {
Self::AllowAll => true,
Self::AllowList { tools } => tools.iter().any(|t| t == tool_name),
Self::DenyList { tools } => !tools.iter().any(|t| t == tool_name),
Self::None => false,
}
}
}
/// Declarative description of an agent: identity, prompt, model, tool access and limits.
///
/// When deserializing, every field except `name` and `role` falls back to its
/// default if absent. When serializing, `None` options and an empty
/// `sub_agents` list are omitted; `description`, `system_prompt` and
/// `tool_filter` are always written.
///
/// `PartialEq` is derived so whole definitions can be compared, e.g. after a
/// serde round trip. Every field type already implements `PartialEq`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AgentDefinition {
    /// Name of this agent.
    pub name: AgentName,
    /// Which kind of agent this is.
    pub role: AgentRole,
    /// Free-form description; empty by default.
    #[serde(default)]
    pub description: String,
    /// System-prompt fragments, in order (`joined_system_prompt` concatenates them).
    #[serde(default)]
    pub system_prompt: Vec<String>,
    /// Model selection; `None` when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model: Option<AgentModelRef>,
    /// Tool restriction; defaults to `ToolFilter::AllowAll`.
    #[serde(default)]
    pub tool_filter: ToolFilter,
    /// Optional tool-choice setting.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Optional generation settings.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub generation: Option<GenerationOptions>,
    /// Optional step limit. NOTE(review): what counts as a "step" is decided by the
    /// runtime that executes the agent — not visible here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_steps: Option<u32>,
    /// Names of other agents associated with this one as sub-agents.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub sub_agents: Vec<AgentName>,
    /// Optional output schema as raw JSON (presumably a JSON Schema — confirm with consumers).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_schema: Option<serde_json::Value>,
    /// Optional provider-specific options as raw JSON.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<serde_json::Value>,
}
// Builder methods take `self` by value and return the updated definition.
// Ignoring their result therefore silently discards both the change and the
// definition itself, so they are marked `#[must_use]`.
impl AgentDefinition {
    /// Creates a definition with the given name and role.
    ///
    /// Every other field starts empty or unset; `tool_filter` starts at
    /// `ToolFilter::default()`.
    #[must_use]
    pub fn new(name: impl Into<String>, role: AgentRole) -> Self {
        Self {
            name: AgentName::new(name),
            role,
            description: String::new(),
            system_prompt: Vec::new(),
            model: None,
            tool_filter: ToolFilter::default(),
            tool_choice: None,
            generation: None,
            max_steps: None,
            sub_agents: Vec::new(),
            output_schema: None,
            provider_options: None,
        }
    }

    /// Sets the description, replacing any previous one.
    #[must_use]
    pub fn with_description(mut self, description: impl Into<String>) -> Self {
        self.description = description.into();
        self
    }

    /// Replaces all system-prompt fragments with the single fragment `prompt`.
    #[must_use]
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = vec![prompt.into()];
        self
    }

    /// Appends `prompt` after any existing system-prompt fragments.
    #[must_use]
    pub fn append_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt.push(prompt.into());
        self
    }

    /// Replaces all system-prompt fragments with `prompts`.
    #[must_use]
    pub fn with_system_prompts(mut self, prompts: Vec<String>) -> Self {
        self.system_prompt = prompts;
        self
    }

    /// Sets the model reference.
    #[must_use]
    pub fn with_model(mut self, model: AgentModelRef) -> Self {
        self.model = Some(model);
        self
    }

    /// Replaces the tool filter.
    #[must_use]
    pub fn with_tool_filter(mut self, filter: ToolFilter) -> Self {
        self.tool_filter = filter;
        self
    }

    /// Sets the tool-choice setting.
    #[must_use]
    pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self {
        self.tool_choice = Some(choice);
        self
    }

    /// Sets the generation options.
    #[must_use]
    pub fn with_generation(mut self, generation: GenerationOptions) -> Self {
        self.generation = Some(generation);
        self
    }

    /// Sets the step limit.
    #[must_use]
    pub fn with_max_steps(mut self, steps: u32) -> Self {
        self.max_steps = Some(steps);
        self
    }

    /// Replaces the sub-agent list with `agents`.
    #[must_use]
    pub fn with_sub_agents(mut self, agents: Vec<AgentName>) -> Self {
        self.sub_agents = agents;
        self
    }

    /// Appends one sub-agent name to the existing list.
    #[must_use]
    pub fn add_sub_agent(mut self, agent: impl Into<String>) -> Self {
        self.sub_agents.push(AgentName::new(agent));
        self
    }

    /// Sets the output schema (raw JSON).
    #[must_use]
    pub fn with_output_schema(mut self, schema: serde_json::Value) -> Self {
        self.output_schema = Some(schema);
        self
    }

    /// Sets the provider-specific options (raw JSON).
    #[must_use]
    pub fn with_provider_options(mut self, options: serde_json::Value) -> Self {
        self.provider_options = Some(options);
        self
    }

    /// Concatenates the system-prompt fragments with a blank line (`"\n\n"`)
    /// between them.
    ///
    /// Returns an empty string when there are no fragments.
    #[must_use]
    pub fn joined_system_prompt(&self) -> String {
        self.system_prompt.join("\n\n")
    }
}
#[cfg(test)]
mod tests {
    // `use super::*` also brings in the parent's imports (ToolChoice,
    // GenerationOptions, ModelId, ...), so individual tests need no extra `use`.
    use super::*;

    #[test]
    fn test_agent_name_new() {
        let name = AgentName::new("explore");
        assert_eq!(name.as_str(), "explore");
        assert_eq!(name.to_string(), "explore");
    }

    #[test]
    fn test_agent_name_serde_roundtrip() {
        let name = AgentName::new("build");
        let json = serde_json::to_string(&name).unwrap();
        // Newtype structs serialize as their inner value.
        assert_eq!(json, r#""build""#);
        let restored: AgentName = serde_json::from_str(&json).unwrap();
        assert_eq!(name, restored);
    }

    #[test]
    fn test_agent_role_predicates() {
        assert!(AgentRole::Primary.is_primary());
        assert!(!AgentRole::Primary.is_sub_agent());
        assert!(!AgentRole::Primary.is_internal());
        assert!(!AgentRole::SubAgent.is_primary());
        assert!(AgentRole::SubAgent.is_sub_agent());
        assert!(!AgentRole::SubAgent.is_internal());
        assert!(!AgentRole::Internal.is_primary());
        assert!(!AgentRole::Internal.is_sub_agent());
        assert!(AgentRole::Internal.is_internal());
    }

    #[test]
    fn test_agent_role_serde() {
        let json = serde_json::to_string(&AgentRole::SubAgent).unwrap();
        assert_eq!(json, r#""sub_agent""#);
        let restored: AgentRole = serde_json::from_str(&json).unwrap();
        assert_eq!(restored, AgentRole::SubAgent);
    }

    #[test]
    fn test_agent_model_ref_inherit() {
        let json = serde_json::to_string(&AgentModelRef::Inherit).unwrap();
        // Internally tagged unit variant: only the tag is emitted.
        assert_eq!(json, r#"{"type":"inherit"}"#);
    }

    #[test]
    fn test_agent_model_ref_by_id() {
        let r = AgentModelRef::by_id(ModelId::new("gpt-4o"), ProviderId::new("openai"));
        if let AgentModelRef::ById { model_id, provider_id } = &r {
            assert_eq!(model_id.as_str(), "gpt-4o");
            assert_eq!(provider_id.as_str(), "openai");
        } else {
            panic!("expected ById");
        }
    }

    #[test]
    fn test_agent_model_ref_by_alias() {
        let r = AgentModelRef::by_alias("fast");
        if let AgentModelRef::ByAlias { alias } = &r {
            assert_eq!(alias, "fast");
        } else {
            panic!("expected ByAlias");
        }
    }

    #[test]
    fn test_agent_model_ref_serde_roundtrip() {
        let refs = vec![
            AgentModelRef::Inherit,
            AgentModelRef::by_id(
                ModelId::new("claude-sonnet-4-20250514"),
                ProviderId::new("anthropic"),
            ),
            AgentModelRef::by_alias("cheap"),
        ];
        for r in refs {
            let json = serde_json::to_string(&r).unwrap();
            let restored: AgentModelRef = serde_json::from_str(&json).unwrap();
            assert_eq!(r, restored);
        }
    }

    #[test]
    fn test_tool_filter_allow_all() {
        let f = ToolFilter::AllowAll;
        assert!(f.is_allowed("anything"));
    }

    #[test]
    fn test_tool_filter_allow_list() {
        let f = ToolFilter::allow_list(["read_file", "grep"]);
        assert!(f.is_allowed("read_file"));
        assert!(f.is_allowed("grep"));
        assert!(!f.is_allowed("bash"));
    }

    #[test]
    fn test_tool_filter_deny_list() {
        let f = ToolFilter::deny_list(["bash", "write_file"]);
        assert!(f.is_allowed("read_file"));
        assert!(!f.is_allowed("bash"));
        assert!(!f.is_allowed("write_file"));
    }

    #[test]
    fn test_tool_filter_none() {
        let f = ToolFilter::None;
        assert!(!f.is_allowed("anything"));
    }

    #[test]
    fn test_tool_filter_default_is_allow_all() {
        assert_eq!(ToolFilter::default(), ToolFilter::AllowAll);
    }

    #[test]
    fn test_tool_filter_serde_roundtrip() {
        let filters = vec![
            ToolFilter::AllowAll,
            ToolFilter::allow_list(["read_file"]),
            ToolFilter::deny_list(["bash"]),
            ToolFilter::None,
        ];
        for f in filters {
            let json = serde_json::to_string(&f).unwrap();
            let restored: ToolFilter = serde_json::from_str(&json).unwrap();
            assert_eq!(f, restored);
        }
    }

    #[test]
    fn test_tool_filter_json_shape() {
        assert_eq!(
            serde_json::to_value(ToolFilter::allow_list(["read_file"])).unwrap(),
            serde_json::json!({"type": "allow_list", "tools": ["read_file"]})
        );
        assert_eq!(
            serde_json::to_value(ToolFilter::None).unwrap(),
            serde_json::json!({"type": "none"})
        );
    }

    #[test]
    fn test_agent_definition_minimal() {
        let agent = AgentDefinition::new("test", AgentRole::Primary);
        assert_eq!(agent.name.as_str(), "test");
        assert_eq!(agent.role, AgentRole::Primary);
        assert!(agent.description.is_empty());
        assert!(agent.system_prompt.is_empty());
        assert!(agent.model.is_none());
        assert_eq!(agent.tool_filter, ToolFilter::AllowAll);
        assert!(agent.tool_choice.is_none());
        assert!(agent.generation.is_none());
        assert!(agent.max_steps.is_none());
        assert!(agent.sub_agents.is_empty());
        assert!(agent.output_schema.is_none());
        assert!(agent.provider_options.is_none());
    }

    #[test]
    fn test_agent_definition_builder() {
        let agent = AgentDefinition::new("explore", AgentRole::SubAgent)
            .with_description("Search agent")
            .with_system_prompt("You are a search specialist.")
            .append_system_prompt("Be thorough.")
            .with_model(AgentModelRef::by_alias("fast"))
            .with_tool_filter(ToolFilter::allow_list(["read_file", "grep"]))
            .with_tool_choice(ToolChoice::Auto)
            .with_generation(GenerationOptions::new().with_temperature(0.3))
            .with_max_steps(10)
            .add_sub_agent("deep_search");
        assert_eq!(agent.name.as_str(), "explore");
        assert_eq!(agent.role, AgentRole::SubAgent);
        assert_eq!(agent.description, "Search agent");
        assert_eq!(agent.system_prompt.len(), 2);
        assert_eq!(
            agent.joined_system_prompt(),
            "You are a search specialist.\n\nBe thorough."
        );
        assert!(agent.model.is_some());
        assert!(agent.tool_filter.is_allowed("read_file"));
        assert!(!agent.tool_filter.is_allowed("bash"));
        assert_eq!(agent.tool_choice, Some(ToolChoice::Auto));
        assert_eq!(agent.generation.as_ref().unwrap().temperature, Some(0.3));
        assert_eq!(agent.max_steps, Some(10));
        assert_eq!(agent.sub_agents.len(), 1);
        assert_eq!(agent.sub_agents[0].as_str(), "deep_search");
    }

    #[test]
    fn test_agent_definition_serde_roundtrip() {
        let agent = AgentDefinition::new("build", AgentRole::Primary)
            .with_description("Coding agent")
            .with_system_prompt("Help with code.")
            .with_model(AgentModelRef::by_id(
                ModelId::new("gpt-4o"),
                ProviderId::new("openai"),
            ))
            .with_tool_filter(ToolFilter::deny_list(["dangerous_tool"]))
            .with_tool_choice(ToolChoice::Required)
            .with_generation(
                GenerationOptions::new()
                    .with_temperature(0.5)
                    .with_max_tokens(4096),
            )
            .with_max_steps(50)
            .add_sub_agent("explore")
            .add_sub_agent("title")
            .with_output_schema(serde_json::json!({"type": "object"}))
            .with_provider_options(serde_json::json!({"service_tier": "default"}));
        let json = serde_json::to_string_pretty(&agent).unwrap();
        let restored: AgentDefinition = serde_json::from_str(&json).unwrap();
        assert_eq!(agent.name, restored.name);
        assert_eq!(agent.role, restored.role);
        assert_eq!(agent.description, restored.description);
        assert_eq!(agent.system_prompt, restored.system_prompt);
        assert_eq!(agent.model, restored.model);
        assert_eq!(agent.tool_filter, restored.tool_filter);
        assert_eq!(agent.tool_choice, restored.tool_choice);
        assert_eq!(agent.generation, restored.generation);
        assert_eq!(agent.max_steps, restored.max_steps);
        assert_eq!(agent.sub_agents, restored.sub_agents);
        assert_eq!(agent.output_schema, restored.output_schema);
        assert_eq!(agent.provider_options, restored.provider_options);
    }

    #[test]
    fn test_agent_definition_deserialize_minimal_uses_defaults() {
        let agent: AgentDefinition =
            serde_json::from_str(r#"{"name":"explore","role":"sub_agent"}"#).unwrap();
        assert_eq!(agent.name.as_str(), "explore");
        assert_eq!(agent.role, AgentRole::SubAgent);
        assert!(agent.description.is_empty());
        assert!(agent.system_prompt.is_empty());
        assert!(agent.model.is_none());
        assert_eq!(agent.tool_filter, ToolFilter::AllowAll);
        assert!(agent.tool_choice.is_none());
        assert!(agent.generation.is_none());
        assert!(agent.max_steps.is_none());
        assert!(agent.sub_agents.is_empty());
        assert!(agent.output_schema.is_none());
        assert!(agent.provider_options.is_none());
    }

    #[test]
    fn test_agent_definition_serialize_skips_unset_fields() {
        let value =
            serde_json::to_value(AgentDefinition::new("t", AgentRole::Primary)).unwrap();
        let obj = value
            .as_object()
            .expect("AgentDefinition serializes to a JSON object");
        for key in ["name", "role", "description", "system_prompt", "tool_filter"] {
            assert!(obj.contains_key(key), "missing key {}", key);
        }
        for key in [
            "model",
            "tool_choice",
            "generation",
            "max_steps",
            "sub_agents",
            "output_schema",
            "provider_options",
        ] {
            assert!(!obj.contains_key(key), "unexpected key {}", key);
        }
    }

    #[test]
    fn test_agent_definition_joined_prompt_empty() {
        let agent = AgentDefinition::new("empty", AgentRole::Internal);
        assert_eq!(agent.joined_system_prompt(), "");
    }

    #[test]
    fn test_agent_definition_joined_prompt_single() {
        let agent = AgentDefinition::new("t", AgentRole::Internal).with_system_prompt("Hello");
        assert_eq!(agent.joined_system_prompt(), "Hello");
    }
}