use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::sync::atomic::AtomicBool;
use std::sync::{Arc, Mutex};
use serde_json::Value;
use crate::error::{AgenticError, Result};
use crate::persistence::session::SessionStore;
use crate::provider::model::Model;
use crate::provider::Provider;
use crate::tools::{SpawnAgentTool, ToolRegistry, Toolable};
use crate::util::generate_agent_name;
use super::event::Event;
use super::output::{AgentOutput, OutputSchema};
use super::queue::CommandQueue;
use super::r#loop::{run_loop, LoopRuntime, LoopState};
use super::spec::{build_context_prompt, AgentSpec};
/// An agent definition (the shared [`AgentSpec`]) plus optional per-run
/// overrides, configured through chained builder methods and executed with
/// [`Agent::run`].
///
/// The spec lives behind an `Arc`, so clones share it until a builder method
/// mutates one of them.
#[derive(Clone)]
pub struct Agent {
    // Model, prompts, limits, tools and sub-agents; builders mutate it via `Arc::make_mut`.
    pub(crate) spec: Arc<AgentSpec>,
    // Required for a root agent; a child without one uses its parent's provider.
    pub(crate) provider: Option<Arc<dyn Provider>>,
    // Task text; `{key}` placeholders are filled from `template_variables` when a run starts.
    pub(crate) instruction_prompt: String,
    // Values substituted for `{key}` placeholders; also copied into the loop runtime.
    pub(crate) template_variables: HashMap<String, Value>,
    // Root default: the process's current directory (or "."); child default: the parent's.
    pub(crate) working_directory: Option<PathBuf>,
    // Root default: `Event::default_logger`; child default: the parent's handler.
    pub(crate) event_handler: Option<Arc<dyn Fn(Event) + Send + Sync>>,
    // Root default: a fresh `false` flag; child default: the parent's flag.
    pub(crate) cancel_signal: Option<Arc<AtomicBool>>,
    // Only consulted for a root agent; children always share the parent's queue.
    pub(crate) command_queue: Option<Arc<CommandQueue>>,
    // Only consulted for a root agent; children always share the parent's session store.
    pub(crate) session_dir: Option<PathBuf>,
}
impl Default for Agent {
fn default() -> Self {
Self {
spec: Arc::new(AgentSpec::default()),
provider: None,
instruction_prompt: String::new(),
template_variables: HashMap::new(),
working_directory: None,
event_handler: None,
cancel_signal: None,
command_queue: None,
session_dir: None,
}
}
}
/// Reads the whole file at `path` as UTF-8 text.
///
/// # Panics
///
/// Panics with `failed to read prompt file <path>: <io error>` if the file
/// cannot be read. The builder methods that call this return `Self`, so they
/// have no error channel.
fn load_prompt_file(path: PathBuf) -> String {
    match std::fs::read_to_string(&path) {
        Ok(contents) => contents,
        Err(e) => panic!("failed to read prompt file {}: {e}", path.display()),
    }
}
/// Reads the file at `path` and parses its contents as JSON.
///
/// # Panics
///
/// Panics if the file cannot be read (see `load_prompt_file`) or with
/// `invalid JSON in <path>: <parse error>` if the contents are not valid JSON.
fn load_json_file(path: PathBuf) -> Value {
    // Render the path up front: `load_prompt_file` takes ownership of it.
    let shown = path.display().to_string();
    let raw = load_prompt_file(path);
    serde_json::from_str(&raw).unwrap_or_else(|e| panic!("invalid JSON in {shown}: {e}"))
}
impl Agent {
pub const DEFAULT_MAX_REQUEST_RETRIES: u32 = AgentSpec::DEFAULT_MAX_REQUEST_RETRIES;
pub const DEFAULT_BACKOFF_MS: u64 = AgentSpec::DEFAULT_BACKOFF_MS;
pub fn new() -> Self {
Self::default()
}
fn with_spec<F: FnOnce(&mut AgentSpec)>(mut self, f: F) -> Self {
f(Arc::make_mut(&mut self.spec));
self
}
pub(crate) fn interpolate(&self, template: &str) -> String {
let mut result = template.to_string();
for (key, value) in &self.template_variables {
let replacement = match value {
Value::String(s) => s.clone(),
other => other.to_string(),
};
result = result.replace(&format!("{{{key}}}"), &replacement);
}
result
}
pub fn name(self, n: impl Into<String>) -> Self {
self.with_spec(|c| c.name = n.into())
}
pub fn model_name(self, name: impl Into<String>) -> Self {
self.with_spec(|c| c.model = Some(Model::from_name(name)))
}
pub fn model(self, model: Model) -> Self {
self.with_spec(|c| c.model = Some(model))
}
pub fn identity_prompt(self, p: impl Into<String>) -> Self {
self.with_spec(|c| c.identity_prompt = p.into())
}
pub fn identity_prompt_file(self, path: impl Into<PathBuf>) -> Self {
let s = load_prompt_file(path.into());
self.with_spec(|c| c.identity_prompt = s)
}
pub fn max_request_tokens(self, n: u32) -> Self {
self.with_spec(|c| c.max_request_tokens = Some(n))
}
pub fn max_turns(self, n: u32) -> Self {
self.with_spec(|c| c.max_turns = Some(n))
}
pub fn max_input_tokens(self, n: u64) -> Self {
self.with_spec(|c| c.max_input_tokens = Some(n))
}
pub fn max_output_tokens(self, n: u64) -> Self {
self.with_spec(|c| c.max_output_tokens = Some(n))
}
pub fn tool(self, tool: impl Toolable + 'static) -> Self {
self.with_spec(|c| c.tools.register(tool))
}
pub fn output_schema(self, schema: Value) -> Self {
let schema =
OutputSchema::new(schema).unwrap_or_else(|e| panic!("invalid output schema: {e}"));
self.with_spec(|c| c.output_schema = Some(schema))
}
pub fn output_schema_file(self, path: impl Into<PathBuf>) -> Self {
self.output_schema(load_json_file(path.into()))
}
pub fn max_schema_retries(self, n: u32) -> Self {
self.with_spec(|c| c.max_schema_retries = Some(n))
}
pub fn max_request_retries(self, n: u32) -> Self {
self.with_spec(|c| c.max_request_retries = n)
}
pub fn request_retry_delay(self, ms: u64) -> Self {
self.with_spec(|c| c.request_retry_delay = ms)
}
pub fn keep_alive(self) -> Self {
self.with_spec(|c| c.keep_alive = true)
}
pub fn behavior_prompt(self, content: impl Into<String>) -> Self {
let content = content.into();
self.with_spec(|c| c.behavior_prompt = content)
}
pub fn behavior_prompt_file(self, path: impl Into<PathBuf>) -> Self {
let content = load_prompt_file(path.into());
self.with_spec(|c| c.behavior_prompt = content)
}
pub fn context_prompt(self, content: impl Into<String>) -> Self {
let content = content.into();
self.with_spec(|c| c.context_prompts.push(content))
}
pub fn context_prompt_file(self, path: impl Into<PathBuf>) -> Self {
let content = load_prompt_file(path.into());
self.with_spec(|c| c.context_prompts.push(content))
}
pub fn sub_agents(self, agents: impl IntoIterator<Item = Agent>) -> Self {
let agents: Vec<_> = agents.into_iter().collect();
self.with_spec(|c| c.sub_agents.extend(agents))
}
pub fn provider(mut self, p: Arc<dyn Provider>) -> Self {
self.provider = Some(p);
self
}
pub fn provider_from_env(self) -> Result<Self> {
let (provider, model) = crate::provider::from_env()?;
Ok(self.provider(provider).model_name(model))
}
pub fn instruction_prompt(mut self, p: impl Into<String>) -> Self {
self.instruction_prompt = p.into();
self
}
pub fn instruction_prompt_file(mut self, path: impl Into<PathBuf>) -> Self {
self.instruction_prompt = load_prompt_file(path.into());
self
}
pub fn template_variable(mut self, key: impl Into<String>, value: Value) -> Self {
self.template_variables.insert(key.into(), value);
self
}
pub fn working_directory(mut self, d: impl Into<PathBuf>) -> Self {
self.working_directory = Some(d.into());
self
}
pub fn event_handler(mut self, h: Arc<dyn Fn(Event) + Send + Sync>) -> Self {
self.event_handler = Some(h);
self
}
pub fn silent(mut self) -> Self {
self.event_handler = Some(Arc::new(|_| {}));
self
}
pub fn cancel_signal(mut self, s: Arc<AtomicBool>) -> Self {
self.cancel_signal = Some(s);
self
}
pub(crate) fn command_queue(mut self, q: Arc<CommandQueue>) -> Self {
self.command_queue = Some(q);
self
}
pub fn session_dir(mut self, d: impl Into<PathBuf>) -> Self {
self.session_dir = Some(d.into());
self
}
pub fn name_ref(&self) -> &str {
&self.spec.name
}
pub async fn run(&self) -> Result<AgentOutput> {
let (spec, runtime) = self.compile(None)?;
let runtime = Arc::new(runtime);
let instruction = self.interpolate(&self.instruction_prompt);
let context_prompt =
build_context_prompt(&spec.context_prompts, runtime.metadata.as_deref());
let state = LoopState::initial(context_prompt, instruction);
run_loop(runtime, spec, state, None).await
}
pub(crate) async fn run_child(
&self,
parent_spec: &AgentSpec,
parent_runtime: &LoopRuntime,
description: Option<String>,
) -> Result<AgentOutput> {
let (spec, runtime) = self.compile(Some((parent_spec, parent_runtime)))?;
let runtime = Arc::new(runtime);
let instruction = self.interpolate(&self.instruction_prompt);
let context_prompt =
build_context_prompt(&spec.context_prompts, runtime.metadata.as_deref());
let state = LoopState::initial(context_prompt, instruction);
run_loop(runtime, spec, state, description).await
}
pub(crate) fn apply_overrides(mut self, overrides: &Value) -> Self {
if let Some(m) = overrides.get("model").and_then(Value::as_str) {
self = self.model_name(m);
}
if let Some(i) = overrides.get("identity").and_then(Value::as_str) {
self = self.identity_prompt(i);
}
if let Some(t) = overrides.get("max_request_tokens").and_then(Value::as_u64) {
self = self.max_request_tokens(t as u32);
}
if let Some(t) = overrides.get("max_input_tokens").and_then(Value::as_u64) {
self = self.max_input_tokens(t);
}
if let Some(t) = overrides.get("max_output_tokens").and_then(Value::as_u64) {
self = self.max_output_tokens(t);
}
if let Some(mt) = overrides.get("max_turns").and_then(Value::as_u64) {
self = self.max_turns(mt as u32);
}
if let Some(sr) = overrides.get("max_schema_retries").and_then(Value::as_u64) {
self = self.max_schema_retries(sr as u32);
}
if let Some(rr) = overrides.get("max_request_retries").and_then(Value::as_u64) {
self = self.max_request_retries(rr as u32);
}
if let Some(bo) = overrides.get("request_retry_delay").and_then(Value::as_u64) {
self = self.request_retry_delay(bo);
}
if let Some(schema) = overrides.get("output_schema").cloned() {
self = self.output_schema(schema);
}
self
}
pub(crate) fn compile(
&self,
parent: Option<(&AgentSpec, &LoopRuntime)>,
) -> Result<(Arc<AgentSpec>, LoopRuntime)> {
let resolved_model = match (self.spec.model.as_ref(), parent) {
(Some(m), _) => m.clone(),
(None, Some((parent_spec, _))) => parent_spec.model().clone(),
(None, None) => {
return Err(AgenticError::Other(
"root agent requires an explicit .model() / .model_name() (or must be spawned as a child)"
.into(),
));
}
};
let mut spec = Arc::clone(&self.spec);
Arc::make_mut(&mut spec).model = Some(resolved_model);
let runtime = match parent {
Some((_, parent_runtime)) => self.inherit_runtime(parent_runtime, &spec),
None => self.build_runtime(&spec)?,
};
Ok((spec, runtime))
}
fn build_runtime(&self, spec: &AgentSpec) -> Result<LoopRuntime> {
let provider = self
.provider
.clone()
.ok_or_else(|| AgenticError::Other("Agent::run() requires a provider".into()))?;
let working_directory = self
.working_directory
.clone()
.unwrap_or_else(|| std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")));
let event_handler: Arc<dyn Fn(Event) + Send + Sync> = self
.event_handler
.clone()
.unwrap_or_else(Event::default_logger);
let cancel_signal = self
.cancel_signal
.clone()
.unwrap_or_else(|| Arc::new(AtomicBool::new(false)));
let command_queue = Some(
self.command_queue
.clone()
.unwrap_or_else(|| Arc::new(CommandQueue::new())),
);
let session_store = self.session_dir.as_ref().map(|dir| {
let store = SessionStore::new(dir, &generate_agent_name("session"));
Arc::new(Mutex::new(store))
});
let metadata = Some(LoopRuntime::environment(&working_directory));
Ok(LoopRuntime {
provider,
event_handler,
cancel_signal,
working_directory,
command_queue,
session_store,
metadata,
discovered_tools: Arc::new(Mutex::new(HashSet::new())),
tools: build_tools(spec),
template_variables: self.template_variables.clone(),
})
}
fn inherit_runtime(&self, parent: &LoopRuntime, spec: &AgentSpec) -> LoopRuntime {
LoopRuntime {
provider: self
.provider
.clone()
.unwrap_or_else(|| parent.provider.clone()),
event_handler: self
.event_handler
.clone()
.unwrap_or_else(|| parent.event_handler.clone()),
cancel_signal: self
.cancel_signal
.clone()
.unwrap_or_else(|| parent.cancel_signal.clone()),
working_directory: self
.working_directory
.clone()
.unwrap_or_else(|| parent.working_directory.clone()),
command_queue: parent.command_queue.clone(),
session_store: parent.session_store.clone(),
metadata: parent.metadata.clone(),
discovered_tools: parent.discovered_tools.clone(),
tools: build_tools(spec),
template_variables: self.template_variables.clone(),
}
}
}
/// Returns a copy of the spec's tool registry for a run.
///
/// When the spec has sub-agents and no tool named `spawn_agent` is
/// registered, `SpawnAgentTool` is added so the agent can spawn them. A
/// user-registered `spawn_agent` is never replaced.
fn build_tools(spec: &AgentSpec) -> Arc<ToolRegistry> {
    let mut registry = spec.tools.clone();
    // Nothing to add: no sub-agents, or the caller already supplied a spawner.
    if spec.sub_agents.is_empty() || registry.get("spawn_agent").is_some() {
        return Arc::new(registry);
    }
    registry.register(SpawnAgentTool);
    Arc::new(registry)
}
#[cfg(test)]
mod tests {
    use super::super::event::EventKind;
    use super::super::output::AgentStatus;
    // The glob also brings in the parent's private imports (`AgenticError`, `Arc`, ...).
    use super::*;

    /// Creates a scratch directory unique to this test tag and process, so
    /// concurrent test runs cannot clobber each other's files.
    fn scratch_dir(tag: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "agentwerk_test_werk_{tag}_{}",
            std::process::id()
        ));
        std::fs::create_dir_all(&dir).unwrap();
        dir
    }

    #[test]
    fn silent_sets_a_no_op_handler() {
        let agent = Agent::new().silent();
        let handler = agent
            .event_handler
            .as_ref()
            .expect(".silent() must install a handler")
            .clone();
        handler(Event::new(
            "t",
            EventKind::AgentFinished {
                turns: 1,
                status: AgentStatus::Completed,
            },
        ));
    }

    #[test]
    fn default_logger_is_used_when_no_handler_is_set() {
        let agent = Agent::new()
            .name("t")
            .model_name("mock")
            .identity_prompt("")
            .provider(Arc::new(crate::testutil::MockProvider::text("ok")));
        assert!(agent.event_handler.is_none());
        let _ = agent.compile(None).expect("compile with default logger");
    }

    #[test]
    fn identity_prompt_file_loads_content() {
        let dir = scratch_dir("identity");
        let path = dir.join("identity.txt");
        std::fs::write(&path, "You are a test agent").unwrap();
        let agent = Agent::new().identity_prompt_file(&path);
        assert_eq!(agent.spec.identity_prompt, "You are a test agent");
        std::fs::remove_file(&path).ok();
        std::fs::remove_dir(&dir).ok();
    }

    #[test]
    #[should_panic(expected = "failed to read prompt file")]
    fn missing_prompt_file_panics() {
        let _ = Agent::new().identity_prompt_file("/nonexistent/xxx.txt");
    }

    #[test]
    fn output_schema_file_loads_valid_schema() {
        let dir = scratch_dir("schema");
        let path = dir.join("schema.json");
        std::fs::write(
            &path,
            r#"{"type":"object","properties":{"answer":{"type":"string"}}}"#,
        )
        .unwrap();
        let agent = Agent::new().output_schema_file(&path);
        assert!(agent.spec.output_schema.is_some());
        std::fs::remove_file(&path).ok();
        std::fs::remove_dir(&dir).ok();
    }

    #[test]
    #[should_panic(expected = "failed to read prompt file")]
    fn output_schema_file_missing_file_panics() {
        let _ = Agent::new().output_schema_file("/nonexistent/schema.json");
    }

    #[test]
    #[should_panic(expected = "invalid output schema")]
    fn invalid_output_schema_panics() {
        let _ = Agent::new()
            .name("test")
            .identity_prompt("")
            .output_schema(serde_json::json!({"type": "string"}));
    }

    #[test]
    fn interpolate_substitutes_known_variables() {
        let agent = Agent::new()
            .template_variable("name", serde_json::json!("world"))
            .template_variable("n", serde_json::json!(3));
        // Strings are inserted raw, other JSON values in serialized form, and
        // unknown placeholders are left as-is.
        assert_eq!(
            agent.interpolate("hello {name}, n={n}, {missing}"),
            "hello world, n=3, {missing}"
        );
    }

    // Synchronous: `apply_overrides` never awaits, so no async runtime is needed.
    #[test]
    fn apply_overrides_applies_json_fields() {
        let base = Agent::new().name("x").model_name("original").max_turns(3);
        let applied = base.apply_overrides(&serde_json::json!({
            "model": "overridden",
            "max_turns": 7,
            "max_request_tokens": 256,
            "max_input_tokens": 4000,
            "max_output_tokens": 5000
        }));
        assert_eq!(applied.spec.max_turns, Some(7));
        assert_eq!(applied.spec.max_request_tokens, Some(256));
        assert_eq!(applied.spec.max_input_tokens, Some(4000));
        assert_eq!(applied.spec.max_output_tokens, Some(5000));
        match &applied.spec.model {
            Some(m) => assert_eq!(m.name, "overridden"),
            None => panic!("expected a resolved model"),
        }
    }

    #[test]
    fn apply_overrides_applies_retry_fields() {
        let applied = Agent::new().apply_overrides(&serde_json::json!({
            "max_schema_retries": 2,
            "max_request_retries": 4,
            "request_retry_delay": 50
        }));
        assert_eq!(applied.spec.max_schema_retries, Some(2));
        assert_eq!(applied.spec.max_request_retries, 4);
        assert_eq!(applied.spec.request_retry_delay, 50);
    }

    #[tokio::test]
    async fn missing_provider_fails_run() {
        let agent = Agent::new()
            .name("test")
            .model_name("mock")
            .identity_prompt("x")
            .instruction_prompt("do");
        let err = agent.run().await.unwrap_err();
        match err {
            AgenticError::Other(msg) => assert!(msg.contains("provider"), "got: {msg}"),
            other => panic!("expected Other, got {other:?}"),
        }
    }
}