use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Top-level, serde-(de)serializable description of a single agent.
///
/// The variant-specific settings held in `agent_type` are flattened into this
/// same object and selected by its `"type"` field (see [`AgentType`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
    /// Agent identifier (required); `AgentConfig::validate` rejects an empty value.
    pub id: String,
    /// Display name (required); `AgentConfig::validate` rejects an empty value.
    pub name: String,
    /// Optional free-form description; `None` when omitted.
    #[serde(default)]
    pub description: Option<String>,
    /// Type-specific settings. Because of `flatten`, their fields and the
    /// `"type"` tag appear directly in this object, not under a nested key.
    #[serde(flatten)]
    pub agent_type: AgentType,
    /// Optional reasoner / memory / coordinator settings; all unset when omitted.
    #[serde(default)]
    pub components: ComponentsConfig,
    /// Capability metadata; empty when omitted.
    #[serde(default)]
    pub capabilities: CapabilitiesConfig,
    /// Arbitrary extra settings; typed access is available via `get_custom`.
    #[serde(default)]
    pub custom: HashMap<String, serde_json::Value>,
    /// String-to-string environment mappings; empty when omitted.
    // NOTE(review): which side is the env-var name and which the target is not
    // visible in this file — confirm with the code that consumes it.
    #[serde(default)]
    pub env_mappings: HashMap<String, String>,
    /// Whether the agent is enabled; `true` when omitted (`default_enabled`).
    #[serde(default = "default_enabled")]
    pub enabled: bool,
    /// Optional version string; `None` when omitted.
    #[serde(default)]
    pub version: Option<String>,
}
/// Serde default for the `enabled` flags of `AgentConfig` and `ToolConfig`:
/// anything not explicitly disabled is enabled.
fn default_enabled() -> bool {
    true
}

impl Default for AgentConfig {
    /// An enabled agent with empty `id`/`name`, the default [`AgentType`], and
    /// empty components, capabilities and maps.
    ///
    /// The empty `id` and `name` deliberately do not pass `AgentConfig::validate`.
    fn default() -> Self {
        Self {
            id: String::new(),
            name: String::new(),
            description: None,
            agent_type: AgentType::default(),
            components: ComponentsConfig::default(),
            capabilities: CapabilitiesConfig::default(),
            custom: HashMap::new(),
            env_mappings: HashMap::new(),
            // Share the serde default so `AgentConfig::default()` and a deserialized
            // config that omits `enabled` can never disagree.
            enabled: default_enabled(),
            version: None,
        }
    }
}
impl AgentConfig {
    /// Creates a config with the given identity; every other field comes from
    /// `AgentConfig::default`.
    #[must_use]
    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            name: name.into(),
            ..Default::default()
        }
    }

    /// Builder: sets `description`.
    #[must_use]
    pub fn with_description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }

    /// Builder: replaces the agent type together with its type-specific settings.
    #[must_use]
    pub fn with_type(mut self, agent_type: AgentType) -> Self {
        self.agent_type = agent_type;
        self
    }

    /// Builder: inserts (or overwrites) the custom value stored under `key`.
    #[must_use]
    pub fn with_custom(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.custom.insert(key.into(), value);
        self
    }

    /// Looks up `key` in `custom` and deserializes the stored value as `T`.
    ///
    /// Returns `None` when the key is absent or the stored value does not
    /// deserialize into `T`.
    pub fn get_custom<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
        // `&serde_json::Value` implements `Deserializer`, so the value is read in
        // place instead of being deep-cloned first.
        self.custom
            .get(key)
            .and_then(|value| serde::Deserialize::deserialize(value).ok())
    }

    /// Validates the identity fields and the type-specific settings.
    ///
    /// # Errors
    ///
    /// Returns every problem found, not just the first: an empty `id`, an empty
    /// `name`, and any messages produced by `AgentType::validate`.
    pub fn validate(&self) -> Result<(), Vec<String>> {
        let mut errors = Vec::new();
        if self.id.is_empty() {
            errors.push("Agent ID cannot be empty".to_string());
        }
        if self.name.is_empty() {
            errors.push("Agent name cannot be empty".to_string());
        }
        if let Err(type_errors) = self.agent_type.validate() {
            errors.extend(type_errors);
        }
        if errors.is_empty() {
            Ok(())
        } else {
            Err(errors)
        }
    }
}
/// The kind of agent, together with its kind-specific settings.
///
/// Internally tagged: the variant is selected by a `"type"` field whose value is
/// `"llm"`, `"react"`, `"workflow"`, `"team"` or `"custom"`. The variant's own
/// fields sit next to the tag in the same object.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AgentType {
    /// LLM agent configured by [`LlmAgentConfig`].
    Llm(LlmAgentConfig),
    /// ReAct agent configured by [`ReActAgentConfig`]. It is renamed explicitly
    /// because `snake_case` alone would produce the tag `"re_act"`.
    #[serde(rename = "react")]
    ReAct(ReActAgentConfig),
    /// Workflow agent configured by [`WorkflowAgentConfig`].
    Workflow(WorkflowAgentConfig),
    /// Team agent configured by [`TeamAgentConfig`].
    Team(TeamAgentConfig),
    /// Custom agent identified by `class_path`.
    // NOTE(review): how `class_path` is resolved to an implementation is not
    // visible in this file — confirm with the loader.
    Custom {
        /// Implementation identifier; `AgentType::validate` rejects an empty value.
        class_path: String,
        /// Free-form settings for the custom agent; empty when omitted.
        #[serde(default)]
        config: HashMap<String, serde_json::Value>,
    },
}
impl Default for AgentType {
fn default() -> Self {
Self::Llm(LlmAgentConfig::default())
}
}
impl AgentType {
    /// Returns the serde `"type"` tag of this variant: `"llm"`, `"react"`,
    /// `"workflow"`, `"team"` or `"custom"`.
    ///
    /// The result is `'static` (it is always a string literal), so callers may keep
    /// it after the `AgentType` is dropped.
    #[must_use]
    pub fn type_name(&self) -> &'static str {
        match self {
            Self::Llm(_) => "llm",
            Self::ReAct(_) => "react",
            Self::Workflow(_) => "workflow",
            Self::Team(_) => "team",
            Self::Custom { .. } => "custom",
        }
    }

    /// Validates the settings of whichever variant this is.
    ///
    /// # Errors
    ///
    /// Delegates to the variant config's `validate`. A `Custom` agent fails only
    /// when its `class_path` is empty.
    pub fn validate(&self) -> Result<(), Vec<String>> {
        match self {
            Self::Llm(config) => config.validate(),
            Self::ReAct(config) => config.validate(),
            Self::Workflow(config) => config.validate(),
            Self::Team(config) => config.validate(),
            Self::Custom { class_path, .. } if class_path.is_empty() => {
                Err(vec!["Custom agent class_path cannot be empty".to_string()])
            }
            Self::Custom { .. } => Ok(()),
        }
    }
}
/// Settings for an LLM-backed agent. They are also embedded as the `llm` field of
/// [`ReActAgentConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmAgentConfig {
    /// Model identifier. Required on input (no serde default); `validate`
    /// rejects an empty value.
    pub model: String,
    /// Optional system prompt; `None` when omitted.
    #[serde(default)]
    pub system_prompt: Option<String>,
    /// Temperature: 0.7 when omitted (`default_temperature`), and `validate`
    /// requires it to lie in `[0.0, 2.0]`.
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// Optional token limit; `None` when omitted.
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// Optional `top_p`; when set, `validate` requires it to lie in `[0.0, 1.0]`.
    #[serde(default)]
    pub top_p: Option<f32>,
    /// Stop sequences; empty when omitted.
    #[serde(default)]
    pub stop_sequences: Vec<String>,
    /// Streaming flag; `false` when omitted.
    #[serde(default)]
    pub streaming: bool,
    /// Presumably the name of an environment variable holding the API key —
    /// confirm with the client code.
    #[serde(default)]
    pub api_key_env: Option<String>,
    /// Optional base URL override; `None` when omitted.
    #[serde(default)]
    pub base_url: Option<String>,
    /// Additional free-form parameters; empty when omitted.
    #[serde(default)]
    pub extra: HashMap<String, serde_json::Value>,
}
/// Serde default for `LlmAgentConfig::temperature` when the field is omitted.
fn default_temperature() -> f32 {
    0.7
}

impl Default for LlmAgentConfig {
    /// `"gpt-4"` with the default temperature and every optional setting unset.
    fn default() -> Self {
        Self {
            model: "gpt-4".to_string(),
            system_prompt: None,
            // Reuse the serde default so `default()` and a deserialized config that
            // omits `temperature` always agree.
            temperature: default_temperature(),
            max_tokens: None,
            top_p: None,
            stop_sequences: Vec::new(),
            streaming: false,
            api_key_env: None,
            base_url: None,
            extra: HashMap::new(),
        }
    }
}
impl LlmAgentConfig {
pub fn validate(&self) -> Result<(), Vec<String>> {
let mut errors = Vec::new();
if self.model.is_empty() {
errors.push("LLM model cannot be empty".to_string());
}
if self.temperature < 0.0 || self.temperature > 2.0 {
errors.push("Temperature must be between 0.0 and 2.0".to_string());
}
if let Some(top_p) = self.top_p
&& (!(0.0..=1.0).contains(&top_p))
{
errors.push("Top P must be between 0.0 and 1.0".to_string());
}
if errors.is_empty() {
Ok(())
} else {
Err(errors)
}
}
}
/// Settings for a ReAct agent: an LLM configuration plus step and tool settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReActAgentConfig {
    /// Underlying LLM settings (required on input); validated as part of
    /// `ReActAgentConfig::validate`.
    pub llm: LlmAgentConfig,
    /// Step limit: 10 when omitted (`default_max_steps`), and `validate`
    /// requires it to be greater than 0.
    #[serde(default = "default_max_steps")]
    pub max_steps: usize,
    /// Tools configured for this agent; empty when omitted.
    #[serde(default)]
    pub tools: Vec<ToolConfig>,
    /// Parallel tool-call flag; `false` when omitted.
    #[serde(default)]
    pub parallel_tool_calls: bool,
    /// Optional thought-format string; its interpretation is defined by the
    /// consumer (not visible here).
    #[serde(default)]
    pub thought_format: Option<String>,
}
/// Serde default for `ReActAgentConfig::max_steps` when the field is omitted.
fn default_max_steps() -> usize {
    10
}

impl Default for ReActAgentConfig {
    /// Default LLM settings, the default step limit, no tools, and sequential tool calls.
    fn default() -> Self {
        Self {
            llm: LlmAgentConfig::default(),
            // Reuse the serde default so both construction paths agree.
            max_steps: default_max_steps(),
            tools: Vec::new(),
            parallel_tool_calls: false,
            thought_format: None,
        }
    }
}
impl ReActAgentConfig {
    /// Checks the embedded LLM settings and the step limit.
    ///
    /// # Errors
    ///
    /// Returns the LLM errors (if any), followed by an error when `max_steps` is 0.
    pub fn validate(&self) -> Result<(), Vec<String>> {
        // Start from the LLM problems (or nothing) and append our own checks.
        let mut problems = self.llm.validate().err().unwrap_or_default();
        if self.max_steps == 0 {
            problems.push("ReAct max_steps must be greater than 0".to_string());
        }
        if problems.is_empty() {
            return Ok(());
        }
        Err(problems)
    }
}
/// Configuration of a single tool available to an agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolConfig {
    /// Tool name (required on input).
    pub name: String,
    /// Tool kind; [`ToolType::Builtin`] when omitted.
    #[serde(default)]
    pub tool_type: ToolType,
    /// Tool-specific settings; empty when omitted.
    #[serde(default)]
    pub config: HashMap<String, serde_json::Value>,
    /// Whether the tool is enabled; `true` when omitted (`default_enabled`).
    #[serde(default = "default_enabled")]
    pub enabled: bool,
}
/// Kind of a tool. Serialized in `snake_case`: `"builtin"`, `"mcp"`, `"custom"`
/// or `"plugin"`. Defaults to `Builtin`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolType {
    #[default]
    Builtin,
    Mcp,
    Custom,
    Plugin,
}
/// Settings for a workflow agent made of steps, each referring to an agent by id.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct WorkflowAgentConfig {
    /// Workflow steps (required on input); `validate` rejects an empty list.
    pub steps: Vec<WorkflowStep>,
    /// Parallel flag; `false` when omitted.
    // NOTE(review): exact execution semantics are defined by the workflow
    // executor, which is not in this file.
    #[serde(default)]
    pub parallel: bool,
    /// Error-handling policy; [`ErrorStrategy::FailFast`] when omitted.
    #[serde(default)]
    pub error_strategy: ErrorStrategy,
}
impl WorkflowAgentConfig {
    /// Checks that the workflow has steps and that each step names an agent.
    ///
    /// # Errors
    ///
    /// Returns one message when there are no steps, plus one message (carrying the
    /// zero-based step index) for each step whose `agent_id` is empty.
    pub fn validate(&self) -> Result<(), Vec<String>> {
        let missing_steps = self
            .steps
            .is_empty()
            .then(|| "Workflow steps cannot be empty".to_string());
        let unnamed_agents = self
            .steps
            .iter()
            .enumerate()
            .filter(|(_, step)| step.agent_id.is_empty())
            .map(|(index, _)| format!("Workflow step {} agent_id cannot be empty", index));
        let errors: Vec<String> = missing_steps.into_iter().chain(unnamed_agents).collect();
        if errors.is_empty() {
            Ok(())
        } else {
            Err(errors)
        }
    }
}
/// One step of a workflow.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowStep {
    /// Step identifier (required on input).
    pub id: String,
    /// Id of the agent used for this step (required on input);
    /// `WorkflowAgentConfig::validate` rejects an empty value.
    pub agent_id: String,
    /// String-to-string input mapping; empty when omitted.
    // NOTE(review): the key/value direction of both mappings is defined by the
    // executor — confirm there.
    #[serde(default)]
    pub input_mapping: HashMap<String, String>,
    /// String-to-string output mapping; empty when omitted.
    #[serde(default)]
    pub output_mapping: HashMap<String, String>,
    /// Optional condition string; its syntax is not defined in this file.
    #[serde(default)]
    pub condition: Option<String>,
    /// Optional timeout, in milliseconds per the `_ms` suffix.
    #[serde(default)]
    pub timeout_ms: Option<u64>,
}
/// Error-handling policy of a workflow. Defaults to `FailFast`.
///
/// Externally tagged with `snake_case` names. Unit variants serialize as a plain
/// string (`"fail_fast"`, `"continue"`). Struct variants serialize as an object
/// keyed by `"retry"` or `"fallback"`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ErrorStrategy {
    #[default]
    FailFast,
    Continue,
    /// Retry policy: `max_retries` count and `delay_ms` (milliseconds, per the suffix).
    Retry { max_retries: usize, delay_ms: u64 },
    /// Fallback to the agent whose id is `fallback_agent_id`.
    Fallback { fallback_agent_id: String },
}
/// Settings for a team agent made of member agents.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TeamAgentConfig {
    /// Team members (required on input); `validate` rejects an empty list.
    pub members: Vec<TeamMember>,
    /// Coordination mode; [`CoordinationMode::Sequential`] when omitted.
    #[serde(default)]
    pub coordination: CoordinationMode,
    /// Leader agent id; `validate` requires it when `coordination` is
    /// [`CoordinationMode::Hierarchical`].
    #[serde(default)]
    pub leader_id: Option<String>,
    /// Dispatch strategy; [`DispatchStrategy::Broadcast`] when omitted.
    #[serde(default)]
    pub dispatch_strategy: DispatchStrategy,
}
impl TeamAgentConfig {
    /// Checks the team's structural requirements.
    ///
    /// # Errors
    ///
    /// Returns every problem found:
    /// - the team has no members;
    /// - `coordination` is [`CoordinationMode::Hierarchical`] without a non-empty
    ///   `leader_id`;
    /// - one message per member whose `agent_id` is empty.
    pub fn validate(&self) -> Result<(), Vec<String>> {
        let mut errors = Vec::new();
        if self.members.is_empty() {
            errors.push("Team members cannot be empty".to_string());
        }
        // `Some("")` names no agent, so treat it as missing. This is consistent with
        // the empty-`agent_id` check applied to members below.
        let has_leader = self
            .leader_id
            .as_deref()
            .is_some_and(|leader| !leader.is_empty());
        if matches!(self.coordination, CoordinationMode::Hierarchical) && !has_leader {
            errors.push("Hierarchical coordination requires leader_id".to_string());
        }
        errors.extend(
            self.members
                .iter()
                .filter(|member| member.agent_id.is_empty())
                .map(|_| "Team member agent_id cannot be empty".to_string()),
        );
        if errors.is_empty() {
            Ok(())
        } else {
            Err(errors)
        }
    }
}
/// A single member of a team.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TeamMember {
    /// Member agent id (required on input); `TeamAgentConfig::validate` rejects
    /// an empty value.
    pub agent_id: String,
    /// Optional role label; `None` when omitted.
    #[serde(default)]
    pub role: Option<String>,
    /// Member weight; 1.0 when omitted (`default_weight`).
    // NOTE(review): how the weight is used (e.g. by voting or load balancing)
    // is not defined in this file.
    #[serde(default = "default_weight")]
    pub weight: f32,
    /// Optional-member flag; `false` when omitted.
    #[serde(default)]
    pub optional: bool,
}
/// Serde default for `TeamMember::weight` when the field is omitted.
fn default_weight() -> f32 {
    const UNIT_WEIGHT: f32 = 1.0;
    UNIT_WEIGHT
}
/// Coordination mode of a team, also used as the `pattern` of a [`CoordinatorConfig`].
///
/// Serialized in `snake_case` (`"sequential"`, `"parallel"`, `"hierarchical"`,
/// `"consensus"`, `"voting"`, `"debate"`). Defaults to `Sequential`.
/// `Hierarchical` requires a leader in [`TeamAgentConfig`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CoordinationMode {
    #[default]
    Sequential,
    Parallel,
    Hierarchical,
    Consensus,
    Voting,
    Debate,
}
/// Dispatch strategy of a team.
///
/// Serialized in `snake_case` (`"broadcast"`, `"round_robin"`, `"random"`,
/// `"load_balanced"`, `"capability_based"`). Defaults to `Broadcast`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DispatchStrategy {
    #[default]
    Broadcast,
    RoundRobin,
    Random,
    LoadBalanced,
    CapabilityBased,
}
/// Optional component settings of an agent; each component is `None` when omitted.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ComponentsConfig {
    /// Reasoner settings.
    #[serde(default)]
    pub reasoner: Option<ReasonerConfig>,
    /// Memory settings.
    #[serde(default)]
    pub memory: Option<MemoryConfig>,
    /// Coordinator settings.
    #[serde(default)]
    pub coordinator: Option<CoordinatorConfig>,
}
/// Reasoner component settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasonerConfig {
    /// Reasoning strategy; [`ReasonerStrategy::Direct`] when omitted.
    #[serde(default)]
    pub strategy: ReasonerStrategy,
    /// Strategy-specific settings; empty when omitted.
    #[serde(default)]
    pub config: HashMap<String, serde_json::Value>,
}
/// Reasoning strategy of a [`ReasonerConfig`], serialized in `snake_case`
/// (`"direct"`, `"chain_of_thought"`, `"tree_of_thought"`, `"react"`, `"custom"`).
/// Defaults to `Direct`.
///
/// `ReAct` is renamed explicitly to `"react"`, matching the tag used for
/// `AgentType::ReAct`; plain `snake_case` would give `"re_act"`. The old
/// spelling is still accepted on input via `alias`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ReasonerStrategy {
    #[default]
    Direct,
    ChainOfThought,
    TreeOfThought,
    #[serde(rename = "react", alias = "re_act")]
    ReAct,
    Custom,
}
/// Memory component settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// Memory backend kind; [`MemoryType::InMemory`] when omitted.
    #[serde(default)]
    pub memory_type: MemoryType,
    /// Optional item limit; `None` when omitted.
    #[serde(default)]
    pub max_items: Option<usize>,
    /// Optional vector-database settings; `None` when omitted.
    #[serde(default)]
    pub vector_db: Option<VectorDbConfig>,
}
/// Memory backend kind.
///
/// Serialized in `snake_case` (`"in_memory"`, `"redis"`, `"sqlite"`,
/// `"vector_db"`, `"custom"`). Defaults to `InMemory`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MemoryType {
    #[default]
    InMemory,
    Redis,
    Sqlite,
    VectorDb,
    Custom,
}
/// Vector-database connection settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorDbConfig {
    /// Database kind as a free-form string (required on input).
    pub db_type: String,
    /// Connection URL (required on input).
    pub url: String,
    /// Optional collection name; `None` when omitted.
    #[serde(default)]
    pub collection: Option<String>,
}
/// Coordinator component settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoordinatorConfig {
    /// Coordination pattern; [`CoordinationMode::Sequential`] when omitted.
    #[serde(default)]
    pub pattern: CoordinationMode,
    /// Optional timeout, in milliseconds per the `_ms` suffix.
    #[serde(default)]
    pub timeout_ms: Option<u64>,
    /// Pattern-specific settings; empty when omitted.
    #[serde(default)]
    pub config: HashMap<String, serde_json::Value>,
}
/// Capability metadata of an agent; every field is empty or `false` when omitted.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CapabilitiesConfig {
    /// Free-form tags.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Input type names (free-form strings).
    #[serde(default)]
    pub input_types: Vec<String>,
    /// Output type names (free-form strings).
    #[serde(default)]
    pub output_types: Vec<String>,
    /// Streaming-support flag.
    #[serde(default)]
    pub supports_streaming: bool,
    /// Tool-support flag.
    #[serde(default)]
    pub supports_tools: bool,
    /// Coordination-support flag.
    #[serde(default)]
    pub supports_coordination: bool,
    /// Reasoning strategy names as free-form strings (not [`ReasonerStrategy`] values).
    #[serde(default)]
    pub reasoning_strategies: Vec<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an otherwise-default `AgentConfig` with the given identity and type.
    fn config_with(id: &str, name: &str, agent_type: AgentType) -> AgentConfig {
        AgentConfig {
            id: id.to_string(),
            name: name.to_string(),
            agent_type,
            ..Default::default()
        }
    }

    #[test]
    fn test_agent_config_validation() {
        let llm = AgentType::Llm(LlmAgentConfig::default());
        let config = AgentConfig::new("test-agent", "Test Agent").with_type(llm);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_empty_config_validation() {
        // The default config has an empty id and name, so validation must fail.
        assert!(AgentConfig::default().validate().is_err());
    }

    #[test]
    fn test_llm_config_serialization() {
        let llm = LlmAgentConfig {
            model: "gpt-4".to_string(),
            temperature: 0.8,
            ..Default::default()
        };
        let config = config_with("llm-agent", "LLM Agent", AgentType::Llm(llm));
        let json = serde_json::to_string_pretty(&config).unwrap();
        assert!(json.contains("gpt-4"));
        assert!(json.contains("0.8"));
    }

    #[test]
    fn test_react_config_serialization() {
        let react = ReActAgentConfig {
            max_steps: 15,
            ..Default::default()
        };
        let config = config_with("react-agent", "ReAct Agent", AgentType::ReAct(react));
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("react"));
        assert!(json.contains("15"));
    }

    #[test]
    fn test_team_config_validation() {
        let worker = TeamMember {
            agent_id: "agent-1".to_string(),
            role: Some("worker".to_string()),
            weight: 1.0,
            optional: false,
        };
        // Hierarchical coordination without a leader must be rejected.
        let config = TeamAgentConfig {
            members: vec![worker],
            coordination: CoordinationMode::Hierarchical,
            leader_id: None,
            dispatch_strategy: DispatchStrategy::Broadcast,
        };
        assert!(config.validate().is_err());
    }
}