use super::context::AgentContext;
use super::step::{AgentStep, StepResult};
use crate::error::ForgeError;
use crate::hooks::{HookContext, HookManager};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::sync::Arc;
/// Errors produced by [`Agent`] operations; see [`AgentResult`].
#[derive(Debug)]
pub enum AgentError {
    /// Failure reported by the `ForgeError` layer; displayed as "LLM error: ...".
    LlmError(ForgeError),
    /// A tool failed: `tool_name` identifies the tool, `message` describes the failure.
    ToolError { tool_name: String, message: String },
    /// The step budget (`max_steps`) was exhausted before the agent finished.
    MaxStepsExceeded { max_steps: usize },
    /// The agent was stopped.
    Stopped,
    /// A configuration problem, described by the contained message.
    ConfigError(String),
    /// Any other failure; the contained message is displayed verbatim.
    Other(String),
}
impl fmt::Display for AgentError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
AgentError::LlmError(e) => write!(f, "LLM error: {}", e),
AgentError::ToolError { tool_name, message } => {
write!(f, "Tool '{}' error: {}", tool_name, message)
}
AgentError::MaxStepsExceeded { max_steps } => {
write!(f, "Maximum steps ({}) exceeded", max_steps)
}
AgentError::Stopped => write!(f, "Agent was stopped"),
AgentError::ConfigError(msg) => write!(f, "Configuration error: {}", msg),
AgentError::Other(msg) => write!(f, "{}", msg),
}
}
}
impl std::error::Error for AgentError {
    /// Exposes the wrapped `ForgeError` for [`AgentError::LlmError`]; all other variants
    /// report no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let AgentError::LlmError(inner) = self {
            let cause: &(dyn std::error::Error + 'static) = inner;
            return Some(cause);
        }
        None
    }
}
impl From<ForgeError> for AgentError {
fn from(e: ForgeError) -> Self {
AgentError::LlmError(e)
}
}
/// Convenience alias for results whose error type is [`AgentError`].
pub type AgentResult<T> = Result<T, AgentError>;
/// Declarative configuration for an agent, (de)serializable with serde.
///
/// When deserializing, only `name` is required; every other field carries a serde
/// default, noted on the field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
    /// Agent name. Required when deserializing (no serde default).
    pub name: String,
    /// Optional system prompt. Defaults to `None`.
    #[serde(default)]
    pub system_prompt: Option<String>,
    /// Optional model identifier. Defaults to `None`.
    #[serde(default)]
    pub model: Option<String>,
    /// Step budget. Defaults to `default_max_steps()` (10).
    #[serde(default = "default_max_steps")]
    pub max_steps: usize,
    /// Optional sampling temperature. Defaults to `None`; no range check is applied here.
    #[serde(default)]
    pub temperature: Option<f32>,
    /// Optional token limit (presumably a completion-token cap — confirm with the LLM
    /// layer). Defaults to `None`.
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// Streaming flag. Defaults to `false`; how it is honoured is not visible here.
    #[serde(default)]
    pub streaming: bool,
    /// Names of the tools configured for the agent. Defaults to empty.
    #[serde(default)]
    pub tools: Vec<String>,
    /// Free-form JSON metadata. Defaults to empty.
    #[serde(default)]
    pub metadata: std::collections::HashMap<String, serde_json::Value>,
}
/// Step budget used when `max_steps` is not specified.
///
/// Used both as the serde default for `AgentConfig::max_steps` and by
/// `AgentConfig::default()`.
fn default_max_steps() -> usize {
    const DEFAULT_MAX_STEPS: usize = 10;
    DEFAULT_MAX_STEPS
}
impl Default for AgentConfig {
    /// Builds a configuration named `"agent"` with `default_max_steps()` steps.
    ///
    /// Every optional field is `None`, streaming is off, and the tool list and
    /// metadata map are empty.
    fn default() -> Self {
        let name = String::from("agent");
        let max_steps = default_max_steps();
        Self {
            name,
            system_prompt: None,
            model: None,
            max_steps,
            temperature: None,
            max_tokens: None,
            streaming: false,
            tools: Vec::default(),
            metadata: Default::default(),
        }
    }
}
/// Consuming builder methods for [`AgentConfig`].
///
/// Each `with_*` method takes `self` by value and returns the updated config. All are
/// `#[must_use]` so that calling one as a statement (which would discard the change)
/// produces a warning.
impl AgentConfig {
    /// Creates a configuration named `name`.
    ///
    /// Every other field is taken from `AgentConfig::default()`.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            ..Default::default()
        }
    }

    /// Sets the system prompt, replacing any previous one.
    #[must_use]
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(prompt.into());
        self
    }

    /// Sets the model identifier, replacing any previous one.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets the step budget.
    ///
    /// No validation is performed; `0` is stored as-is.
    #[must_use]
    pub fn with_max_steps(mut self, max_steps: usize) -> Self {
        self.max_steps = max_steps;
        self
    }

    /// Sets the sampling temperature.
    ///
    /// The value is not range-checked here.
    #[must_use]
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the token limit.
    #[must_use]
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Enables or disables streaming.
    #[must_use]
    pub fn with_streaming(mut self, streaming: bool) -> Self {
        self.streaming = streaming;
        self
    }

    /// Appends a single tool name to `tools`.
    #[must_use]
    pub fn with_tool(mut self, tool_name: impl Into<String>) -> Self {
        self.tools.push(tool_name.into());
        self
    }

    /// Appends every tool name yielded by `tools`, in order.
    ///
    /// Tools already configured are kept.
    #[must_use]
    pub fn with_tools(mut self, tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.tools.extend(tools.into_iter().map(|t| t.into()));
        self
    }
}
#[async_trait]
pub trait Agent: Send + Sync {
fn name(&self) -> &str;
fn config(&self) -> &AgentConfig;
fn context_mut(&mut self) -> &mut AgentContext;
fn context(&self) -> &AgentContext;
fn hooks(&self) -> Option<&Arc<HookManager>> {
None
}
async fn run(&mut self, input: &str) -> AgentResult<String> {
let hooks = self.hooks().cloned();
let agent_name = self.name().to_string();
if let Some(h) = &hooks {
let r = h.run(&HookContext::agent_start(&agent_name));
if r.is_abort() {
return Err(AgentError::Other(
r.error_message().unwrap_or("aborted by hook").to_string(),
));
}
}
self.context_mut()
.memory
.add_message(crate::types::Message::user(input));
let outcome: AgentResult<String> = loop {
if let Some(h) = &hooks {
let r = h.run(&HookContext::before_step(self.context().current_step));
if r.is_abort() {
break Err(AgentError::Other(
r.error_message().unwrap_or("aborted by hook").to_string(),
));
}
}
if !self.context().can_continue() {
if self.context().current_step >= self.context().max_steps {
break Err(AgentError::MaxStepsExceeded {
max_steps: self.context().max_steps,
});
}
break Err(AgentError::Other(
"Agent did not produce a response".to_string(),
));
}
let step = match self.step().await {
Ok(s) => s,
Err(e) => break Err(e),
};
if let Some(h) = &hooks {
let _ = h.run(&HookContext::after_step(
self.context().current_step,
&step.step_type.to_string(),
));
}
match &step.result {
StepResult::Done { response } => {
break Ok(response.clone());
}
StepResult::Error { message } => {
break Err(AgentError::Other(message.clone()));
}
StepResult::WaitForHuman { .. } => {
break Err(AgentError::Other(
"Human input required but not supported".to_string(),
));
}
StepResult::Continue | StepResult::ToolCalls { .. } => {
self.context_mut().increment_step();
}
}
};
if let Some(h) = &hooks {
let final_answer: String = match &outcome {
Ok(s) => s.clone(),
Err(e) => e.to_string(),
};
let total_steps = self.context().current_step;
let _ = h.run(&HookContext::agent_end(
&agent_name,
&final_answer,
total_steps,
));
}
outcome
}
async fn step(&mut self) -> AgentResult<AgentStep>;
fn stop(&mut self) {
self.context_mut().state = super::context::AgentState::Stopped;
}
fn reset(&mut self) {
self.context_mut().reset();
}
async fn run_with_history(
&mut self,
input: &str,
history: Vec<crate::types::Message>,
) -> AgentResult<String> {
for msg in history {
self.context_mut().memory.add_message(msg);
}
self.run(input).await
}
fn load_history(&mut self, history: Vec<crate::types::Message>) {
for msg in history {
self.context_mut().memory.add_message(msg);
}
}
fn conversation_messages(&self) -> Vec<crate::types::Message> {
self.context().memory.short_term.clone()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_agent_config_builder() {
        let config = AgentConfig::new("test-agent")
            .with_system_prompt("You are helpful")
            .with_model("gpt-4")
            .with_max_steps(5)
            .with_temperature(0.7)
            .with_tool("calculator")
            .with_tool("search");
        assert_eq!(config.name, "test-agent");
        assert_eq!(config.system_prompt, Some("You are helpful".to_string()));
        assert_eq!(config.model, Some("gpt-4".to_string()));
        assert_eq!(config.max_steps, 5);
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.tools, vec!["calculator", "search"]);
    }

    #[test]
    fn test_agent_config_builder_remaining_setters() {
        let config = AgentConfig::new("streamer")
            .with_tool("first")
            .with_tools(vec!["second", "third"])
            .with_max_tokens(256)
            .with_streaming(true);
        // `with_tools` appends after tools already added.
        assert_eq!(config.tools, vec!["first", "second", "third"]);
        assert_eq!(config.max_tokens, Some(256));
        assert!(config.streaming);
        // Fields that were not set keep their defaults.
        assert_eq!(config.max_steps, 10);
        assert!(config.system_prompt.is_none());
    }

    #[test]
    fn test_agent_error_display() {
        let err = AgentError::MaxStepsExceeded { max_steps: 10 };
        assert_eq!(err.to_string(), "Maximum steps (10) exceeded");
        let err = AgentError::ToolError {
            tool_name: "calc".to_string(),
            message: "division by zero".to_string(),
        };
        assert_eq!(err.to_string(), "Tool 'calc' error: division by zero");
    }

    #[test]
    fn test_agent_error_display_remaining_variants() {
        assert_eq!(AgentError::Stopped.to_string(), "Agent was stopped");
        assert_eq!(
            AgentError::ConfigError("bad key".to_string()).to_string(),
            "Configuration error: bad key"
        );
        assert_eq!(AgentError::Other("plain".to_string()).to_string(), "plain");
    }

    #[test]
    fn test_agent_error_source_is_none_for_non_llm_variants() {
        use std::error::Error;
        assert!(AgentError::Stopped.source().is_none());
        assert!(AgentError::Other("x".to_string()).source().is_none());
        assert!(AgentError::MaxStepsExceeded { max_steps: 1 }.source().is_none());
    }

    #[test]
    fn test_agent_config_default() {
        let config = AgentConfig::default();
        assert_eq!(config.max_steps, 10);
        assert!(config.tools.is_empty());
        assert!(!config.streaming);
    }

    #[test]
    fn test_agent_config_deserialize_fills_defaults() {
        let config: AgentConfig =
            serde_json::from_str(r#"{"name": "minimal"}"#).expect("minimal config should parse");
        assert_eq!(config.name, "minimal");
        assert_eq!(config.max_steps, 10);
        assert!(config.system_prompt.is_none());
        assert!(config.model.is_none());
        assert!(config.temperature.is_none());
        assert!(config.max_tokens.is_none());
        assert!(!config.streaming);
        assert!(config.tools.is_empty());
        assert!(config.metadata.is_empty());
    }

    #[test]
    fn test_agent_config_deserialize_requires_name() {
        assert!(serde_json::from_str::<AgentConfig>("{}").is_err());
    }
}