use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
/// Static description of one agent: identity, base prompt, backing model,
/// the tools it may call and the agents it may hand off to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
    /// Identifier of this agent (e.g. `"security-agent"`).
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Base system prompt; `AgentRuntime::new` appends a tools section to it
    /// when `tools` is non-empty.
    pub system_prompt: String,
    /// LLM backend this agent uses.
    pub model: LlmProvider,
    /// Tools advertised to the model; empty when omitted from serialized input.
    #[serde(default)]
    pub tools: Vec<ToolDefinition>,
    /// Agents this one may hand off to (presumably referenced by their `id`);
    /// empty when omitted from serialized input.
    #[serde(default)]
    pub handoff_targets: Vec<String>,
    /// Step budget; falls back to `default_max_steps()` when omitted.
    #[serde(default = "default_max_steps")]
    pub max_steps: usize,
}
/// Step budget applied when a serialized `AgentConfig` omits `max_steps`.
fn default_max_steps() -> usize {
    const DEFAULT_STEP_BUDGET: usize = 10;
    DEFAULT_STEP_BUDGET
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "provider", rename_all = "snake_case")]
pub enum LlmProvider {
Ollama {
model: String,
#[serde(default = "default_ollama_url")]
base_url: String,
},
OpenAI { model: String, api_key: String },
Gemini { model: String, api_key: String },
Anthropic {
model: String,
api_key: String,
#[serde(default = "default_anthropic_max_tokens")]
max_tokens: usize,
},
OpenAICompatible {
model: String,
api_key: String,
base_url: String,
#[serde(default)]
provider_name: Option<String>,
},
}
/// Ollama endpoint used when an `LlmProvider::Ollama` config omits `base_url`.
fn default_ollama_url() -> String {
    String::from("http://localhost:11434")
}
/// Value for `LlmProvider::Anthropic::max_tokens` when a config omits it.
fn default_anthropic_max_tokens() -> usize {
    const FALLBACK_MAX_TOKENS: usize = 4_096;
    FALLBACK_MAX_TOKENS
}
impl fmt::Display for LlmProvider {
    /// Renders `"<provider>/<model>"`. For `OpenAICompatible` the provider part
    /// is `provider_name`, or `"custom"` when it is `None`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (prefix, model) = match self {
            LlmProvider::Ollama { model, .. } => ("ollama", model),
            LlmProvider::OpenAI { model, .. } => ("openai", model),
            LlmProvider::Gemini { model, .. } => ("gemini", model),
            LlmProvider::Anthropic { model, .. } => ("anthropic", model),
            LlmProvider::OpenAICompatible {
                model,
                provider_name,
                ..
            } => (provider_name.as_deref().unwrap_or("custom"), model),
        };
        write!(f, "{}/{}", prefix, model)
    }
}
/// A tool an agent may call, as advertised in its system prompt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name shown to the model.
    pub name: String,
    /// Short description shown next to the name.
    pub description: String,
    /// Parameter description, pretty-printed into the system prompt
    /// (presumably a JSON Schema object — confirm with callers).
    pub parameters_schema: serde_json::Value,
}
/// One entry in an agent's conversation history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Who produced the message.
    pub role: Role,
    /// Message text.
    pub content: String,
    /// Creation time; `ChatMessage::new` stores the current UTC time as RFC 3339.
    pub timestamp: String,
    /// Free-form string annotations; empty when absent from serialized input.
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}
impl ChatMessage {
    /// Creates a message with the given role and content, timestamped with the
    /// current UTC time in RFC 3339 form and carrying no metadata.
    pub fn new(role: Role, content: impl Into<String>) -> Self {
        let content = content.into();
        let timestamp = chrono::Utc::now().to_rfc3339();
        Self {
            role,
            content,
            timestamp,
            metadata: HashMap::default(),
        }
    }

    /// Builder-style helper: stores `key = value` in `metadata`, replacing any
    /// previous value for `key`, and returns the message.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, value) = (key.into(), value.into());
        self.metadata.insert(key, value);
        self
    }
}
/// Author of a chat message. Serialized in snake_case: `"system"`, `"user"`,
/// `"assistant"`, `"tool"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Role {
    /// Role of the prompt that `AgentRuntime::new` seeds the history with.
    System,
    User,
    Assistant,
    Tool,
}
impl fmt::Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Role::System => write!(f, "system"),
Role::User => write!(f, "user"),
Role::Assistant => write!(f, "assistant"),
Role::Tool => write!(f, "tool"),
}
}
}
/// Per-run state of one agent: its configuration, the conversation so far and
/// the number of steps taken.
///
/// Derives `Debug` like the other public types in this module so a runtime can
/// be inspected in logs and assertion messages.
#[derive(Debug)]
pub struct AgentRuntime {
    /// Configuration the runtime was created from.
    pub config: AgentConfig,
    /// Conversation history; `AgentRuntime::new` seeds it with a system message.
    pub messages: Vec<ChatMessage>,
    /// Steps taken so far; incremented by `advance_step`.
    pub step: usize,
}
impl AgentRuntime {
    /// Builds a runtime at step 0 whose history holds exactly one system message.
    ///
    /// The system message is `config.system_prompt`. When the config declares
    /// any tools, it is followed by an "# Available Tools" section explaining
    /// the `<tool_call>` JSON convention and listing each tool's name,
    /// description and pretty-printed parameter schema.
    pub fn new(config: AgentConfig) -> Self {
        const TOOL_PREAMBLE: &str = "\n\n# Available Tools\nYou have access to the following tools. To call a tool, output a JSON block wrapped in <tool_call> and </tool_call> tags. Example:\n<tool_call>\n{\"tool_name\": \"search\", \"arguments\": {\"query\": \"rust\"}}\n</tool_call>\n\nTools:\n";

        let mut prompt = config.system_prompt.clone();
        if !config.tools.is_empty() {
            prompt.push_str(TOOL_PREAMBLE);
            let listing: String = config
                .tools
                .iter()
                .map(|tool| {
                    // A schema that fails to serialize is rendered as an empty string.
                    let schema = serde_json::to_string_pretty(&tool.parameters_schema)
                        .unwrap_or_default();
                    format!(
                        "- **{}**: {}\n Schema: {}\n",
                        tool.name, tool.description, schema
                    )
                })
                .collect();
            prompt.push_str(&listing);
        }

        Self {
            messages: vec![ChatMessage::new(Role::System, prompt)],
            config,
            step: 0,
        }
    }

    /// Appends `msg` to the end of the conversation history.
    pub fn push_message(&mut self, msg: ChatMessage) {
        self.messages.push(msg);
    }

    /// Counts one more step and reports whether the budget is now spent
    /// (the same condition as `is_exhausted`).
    pub fn advance_step(&mut self) -> bool {
        self.step += 1;
        self.is_exhausted()
    }

    /// True once `step` has reached `config.max_steps`; immediately true when
    /// the budget is 0.
    pub fn is_exhausted(&self) -> bool {
        self.config.max_steps <= self.step
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Ollama-backed config with no tools, one handoff target and a 5-step budget.
    fn sample_config() -> AgentConfig {
        AgentConfig {
            id: "security-agent".into(),
            name: "Security Reviewer".into(),
            system_prompt: "You are a security expert. Analyze code for vulnerabilities.".into(),
            model: LlmProvider::Ollama {
                model: "llama3".into(),
                base_url: default_ollama_url(),
            },
            tools: vec![],
            handoff_targets: vec!["perf-agent".into()],
            max_steps: 5,
        }
    }

    #[test]
    fn test_agent_config_serialization() {
        let config = sample_config();
        let json = serde_json::to_string_pretty(&config).unwrap();
        let deser: AgentConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deser.id, "security-agent");
        assert_eq!(deser.name, "Security Reviewer");
        assert_eq!(deser.max_steps, 5);
        assert_eq!(deser.handoff_targets, vec!["perf-agent".to_string()]);
        assert_eq!(deser.model.to_string(), "ollama/llama3");
    }

    #[test]
    fn test_agent_config_defaults() {
        // Only required fields; everything with a `#[serde(default ...)]` is omitted.
        let json = r#"{
            "id": "a",
            "name": "A",
            "system_prompt": "p",
            "model": { "provider": "ollama", "model": "llama3" }
        }"#;
        let config: AgentConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.max_steps, 10);
        assert!(config.tools.is_empty());
        assert!(config.handoff_targets.is_empty());
        match config.model {
            LlmProvider::Ollama { base_url, .. } => {
                assert_eq!(base_url, "http://localhost:11434")
            }
            other => panic!("expected Ollama provider, got {}", other),
        }

        let anthropic: LlmProvider = serde_json::from_str(
            r#"{ "provider": "anthropic", "model": "m", "api_key": "k" }"#,
        )
        .unwrap();
        match anthropic {
            LlmProvider::Anthropic { max_tokens, .. } => assert_eq!(max_tokens, 4096),
            other => panic!("expected Anthropic provider, got {}", other),
        }
    }

    #[test]
    fn test_llm_provider_display() {
        let ollama = LlmProvider::Ollama {
            model: "llama3".into(),
            base_url: default_ollama_url(),
        };
        assert_eq!(format!("{}", ollama), "ollama/llama3");
        let openai = LlmProvider::OpenAI {
            model: "gpt-4o".into(),
            api_key: "sk-test".into(),
        };
        assert_eq!(format!("{}", openai), "openai/gpt-4o");
        let gemini = LlmProvider::Gemini {
            model: "gemini-1.5-pro".into(),
            api_key: "g-test".into(),
        };
        assert_eq!(format!("{}", gemini), "gemini/gemini-1.5-pro");
        let anthropic = LlmProvider::Anthropic {
            model: "claude-sonnet-4-20250514".into(),
            api_key: "sk-test".into(),
            max_tokens: 4096,
        };
        assert_eq!(
            format!("{}", anthropic),
            "anthropic/claude-sonnet-4-20250514"
        );
        let groq = LlmProvider::OpenAICompatible {
            model: "llama-3.3-70b-versatile".into(),
            api_key: "gsk-test".into(),
            base_url: "https://api.groq.com/openai/v1".into(),
            provider_name: Some("groq".into()),
        };
        assert_eq!(format!("{}", groq), "groq/llama-3.3-70b-versatile");
        let custom = LlmProvider::OpenAICompatible {
            model: "my-model".into(),
            api_key: "key".into(),
            base_url: "http://localhost:8080/v1".into(),
            provider_name: None,
        };
        assert_eq!(format!("{}", custom), "custom/my-model");
    }

    #[test]
    fn test_chat_message_builder() {
        let msg = ChatMessage::new(Role::User, "Hello")
            .with_metadata("agent_id", "sec-01")
            .with_metadata("run_id", "run-abc");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "Hello");
        assert_eq!(msg.metadata.get("agent_id").unwrap(), "sec-01");
        assert_eq!(msg.metadata.get("run_id").map(String::as_str), Some("run-abc"));
        assert!(chrono::DateTime::parse_from_rfc3339(&msg.timestamp).is_ok());

        // A repeated key replaces the earlier value rather than adding an entry.
        let msg = msg.with_metadata("agent_id", "sec-02");
        assert_eq!(msg.metadata.get("agent_id").map(String::as_str), Some("sec-02"));
        assert_eq!(msg.metadata.len(), 2);
    }

    #[test]
    fn test_agent_runtime_lifecycle() {
        let config = sample_config();
        let mut rt = AgentRuntime::new(config);
        assert_eq!(rt.step, 0);
        assert_eq!(rt.messages.len(), 1);
        assert_eq!(rt.messages[0].role, Role::System);
        // Without tools the system prompt is passed through unchanged.
        assert_eq!(
            rt.messages[0].content,
            "You are a security expert. Analyze code for vulnerabilities."
        );
        rt.push_message(ChatMessage::new(Role::User, "Review this code"));
        assert_eq!(rt.messages.len(), 2);
        for _ in 0..4 {
            assert!(!rt.advance_step());
        }
        assert!(!rt.is_exhausted());
        assert!(rt.advance_step());
        assert!(rt.is_exhausted());
    }

    #[test]
    fn test_agent_runtime_tool_prompt() {
        let mut config = sample_config();
        config.tools.push(ToolDefinition {
            name: "search".into(),
            description: "Search the web".into(),
            parameters_schema: serde_json::json!({ "type": "object" }),
        });
        let rt = AgentRuntime::new(config);
        let prompt = &rt.messages[0].content;
        assert!(prompt.starts_with("You are a security expert."));
        assert!(prompt.contains("# Available Tools"));
        assert!(prompt.contains("<tool_call>"));
        assert!(prompt.contains("- **search**: Search the web\n Schema: "));
        assert!(prompt.contains("\"type\": \"object\""));
    }

    #[test]
    fn test_zero_step_budget_is_exhausted_immediately() {
        let mut config = sample_config();
        config.max_steps = 0;
        let rt = AgentRuntime::new(config);
        assert!(rt.is_exhausted());
    }

    #[test]
    fn test_role_display() {
        assert_eq!(format!("{}", Role::System), "system");
        assert_eq!(format!("{}", Role::User), "user");
        assert_eq!(format!("{}", Role::Assistant), "assistant");
        assert_eq!(format!("{}", Role::Tool), "tool");
    }

    #[test]
    fn test_role_serde_names() {
        assert_eq!(serde_json::to_string(&Role::Assistant).unwrap(), "\"assistant\"");
        let parsed: Role = serde_json::from_str("\"tool\"").unwrap();
        assert_eq!(parsed, Role::Tool);
    }
}