use std::sync::Arc;
use crate::agent::Agent;
use crate::checkpoint::ErasedCheckpoint;
use crate::cost::CostTracker;
use crate::error::{DaimonError, Result};
use crate::guardrails::{InputGuardrail, OutputGuardrail};
use crate::hooks::{AgentHook, ErasedAgentHook};
use crate::memory::{Memory, SharedMemory, SlidingWindowMemory};
use crate::middleware::{Middleware, MiddlewareStack};
use crate::model::{Model, SharedModel};
use crate::tool::{Tool, ToolRegistry, ToolRetryPolicy};
impl Agent {
pub fn fork(&self) -> Agent {
let cost_tracker = self.cost_tracker.as_ref().map(|t| {
CostTracker::new(Arc::clone(&t.cost_model))
});
Agent {
model: self.model.clone(),
system_prompt: self.system_prompt.clone(),
tools: self.tools.clone(),
memory: Arc::new(SlidingWindowMemory::default()),
hooks: self.hooks.clone(),
middleware: self.middleware.clone(),
input_guardrails: self.input_guardrails.clone(),
output_guardrails: self.output_guardrails.clone(),
max_iterations: self.max_iterations,
temperature: self.temperature,
max_tokens: self.max_tokens,
validate_tool_inputs: self.validate_tool_inputs,
cost_tracker,
max_budget: self.max_budget,
tool_retry_policy: self.tool_retry_policy.clone(),
}
}
pub async fn fork_from_checkpoint(
&self,
run_id: &str,
checkpoint: &Arc<dyn ErasedCheckpoint>,
) -> Result<Agent> {
let state = checkpoint
.load_erased(run_id)
.await?
.ok_or_else(|| DaimonError::Other(format!("no checkpoint for run '{run_id}'")))?;
let memory = SlidingWindowMemory::new(state.messages.len() + 50);
for msg in &state.messages {
memory.add_message(msg.clone()).await?;
}
let cost_tracker = self.cost_tracker.as_ref().map(|t| {
CostTracker::new(Arc::clone(&t.cost_model))
});
Ok(Agent {
model: self.model.clone(),
system_prompt: self.system_prompt.clone(),
tools: self.tools.clone(),
memory: Arc::new(memory),
hooks: self.hooks.clone(),
middleware: self.middleware.clone(),
input_guardrails: self.input_guardrails.clone(),
output_guardrails: self.output_guardrails.clone(),
max_iterations: self.max_iterations,
temperature: self.temperature,
max_tokens: self.max_tokens,
validate_tool_inputs: self.validate_tool_inputs,
cost_tracker,
max_budget: self.max_budget,
tool_retry_policy: self.tool_retry_policy.clone(),
})
}
pub fn fork_with_memory<M: Memory + 'static>(&self, memory: M) -> Agent {
let cost_tracker = self.cost_tracker.as_ref().map(|t| {
CostTracker::new(Arc::clone(&t.cost_model))
});
Agent {
model: self.model.clone(),
system_prompt: self.system_prompt.clone(),
tools: self.tools.clone(),
memory: Arc::new(memory),
hooks: self.hooks.clone(),
middleware: self.middleware.clone(),
input_guardrails: self.input_guardrails.clone(),
output_guardrails: self.output_guardrails.clone(),
max_iterations: self.max_iterations,
temperature: self.temperature,
max_tokens: self.max_tokens,
validate_tool_inputs: self.validate_tool_inputs,
cost_tracker,
max_budget: self.max_budget,
tool_retry_policy: self.tool_retry_policy.clone(),
}
}
pub fn fork_builder(&self) -> ForkBuilder {
ForkBuilder {
model: self.model.clone(),
system_prompt: self.system_prompt.clone(),
tools: self.tools.clone(),
memory: None,
hooks: Some(self.hooks.clone()),
middleware: self.middleware.clone(),
input_guardrails: self.input_guardrails.clone(),
output_guardrails: self.output_guardrails.clone(),
max_iterations: self.max_iterations,
temperature: self.temperature,
max_tokens: self.max_tokens,
validate_tool_inputs: self.validate_tool_inputs,
cost_model: self.cost_tracker.as_ref().map(|t| Arc::clone(&t.cost_model)),
max_budget: self.max_budget,
tool_retry_policy: self.tool_retry_policy.clone(),
}
}
}
/// Builder for deriving a new [`Agent`] from an existing one while overriding
/// selected settings.
///
/// Obtained from `Agent::fork_builder`, which pre-fills every field from the
/// source agent except memory. Call `build()` to produce the forked agent.
#[must_use = "a ForkBuilder does nothing until `.build()` is called"]
pub struct ForkBuilder {
    /// Model the forked agent will call.
    model: SharedModel,
    /// System prompt; `None` means the fork has no system prompt.
    system_prompt: Option<String>,
    /// Tools available to the forked agent.
    tools: ToolRegistry,
    /// Memory for the fork; `None` makes `build()` use
    /// `SlidingWindowMemory::default()`.
    memory: Option<SharedMemory>,
    /// Hooks for the fork; `None` makes `build()` use `NoOpHook`.
    hooks: Option<Arc<dyn ErasedAgentHook>>,
    /// Middleware stack applied by the forked agent.
    middleware: MiddlewareStack,
    /// Guardrails applied to input.
    input_guardrails: Vec<Arc<dyn crate::guardrails::ErasedInputGuardrail>>,
    /// Guardrails applied to output.
    output_guardrails: Vec<Arc<dyn crate::guardrails::ErasedOutputGuardrail>>,
    /// Maximum number of agent loop iterations.
    max_iterations: usize,
    /// Sampling temperature, if set.
    temperature: Option<f32>,
    /// Maximum output tokens, if set.
    max_tokens: Option<u32>,
    /// Whether tool inputs are validated before execution.
    validate_tool_inputs: bool,
    /// Cost model; when `Some`, `build()` creates a new `CostTracker` from it.
    cost_model: Option<Arc<dyn crate::cost::CostModel>>,
    /// Spending limit, if set.
    max_budget: Option<f64>,
    /// Retry policy for tool execution, if set.
    tool_retry_policy: Option<ToolRetryPolicy>,
}
impl ForkBuilder {
pub fn model<M: Model + 'static>(mut self, model: M) -> Self {
self.model = Arc::new(model);
self
}
pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
self.system_prompt = Some(prompt.into());
self
}
pub fn no_system_prompt(mut self) -> Self {
self.system_prompt = None;
self
}
pub fn tool<T: Tool + 'static>(mut self, tool: T) -> Self {
let _ = self.tools.register(tool);
self
}
pub fn remove_tool(mut self, name: &str) -> Self {
self.tools.unregister(name);
self
}
pub fn memory<M: Memory + 'static>(mut self, memory: M) -> Self {
self.memory = Some(Arc::new(memory));
self
}
pub fn hooks<H: AgentHook + 'static>(mut self, hooks: H) -> Self {
self.hooks = Some(Arc::new(hooks));
self
}
pub fn middleware<M: Middleware + 'static>(mut self, mw: M) -> Self {
self.middleware.push(mw);
self
}
pub fn input_guardrail<G: InputGuardrail + 'static>(mut self, guard: G) -> Self {
self.input_guardrails.push(Arc::new(guard));
self
}
pub fn output_guardrail<G: OutputGuardrail + 'static>(mut self, guard: G) -> Self {
self.output_guardrails.push(Arc::new(guard));
self
}
pub fn max_iterations(mut self, max: usize) -> Self {
self.max_iterations = max;
self
}
pub fn temperature(mut self, temp: f32) -> Self {
self.temperature = Some(temp);
self
}
pub fn max_tokens(mut self, tokens: u32) -> Self {
self.max_tokens = Some(tokens);
self
}
pub fn validate_tool_inputs(mut self, enabled: bool) -> Self {
self.validate_tool_inputs = enabled;
self
}
pub fn tool_retry_policy(mut self, policy: ToolRetryPolicy) -> Self {
self.tool_retry_policy = Some(policy);
self
}
pub fn build(mut self) -> Agent {
let memory = self
.memory
.unwrap_or_else(|| Arc::new(SlidingWindowMemory::default()));
let hooks = self
.hooks
.unwrap_or_else(|| Arc::new(crate::hooks::NoOpHook));
self.tools.warm_cache();
let cost_tracker = self.cost_model.map(CostTracker::new);
Agent {
model: self.model,
system_prompt: self.system_prompt,
tools: self.tools,
memory,
hooks,
middleware: self.middleware,
input_guardrails: self.input_guardrails,
output_guardrails: self.output_guardrails,
max_iterations: self.max_iterations,
temperature: self.temperature,
max_tokens: self.max_tokens,
validate_tool_inputs: self.validate_tool_inputs,
cost_tracker,
max_budget: self.max_budget,
tool_retry_policy: self.tool_retry_policy,
}
}
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use crate::agent::Agent;
    use crate::checkpoint::{Checkpoint, CheckpointState, InMemoryCheckpoint};
    use crate::error::Result;
    use crate::memory::SlidingWindowMemory;
    use crate::model::Model;
    use crate::model::types::*;
    use crate::stream::ResponseStream;
    use crate::tool::{Tool, ToolOutput};

    /// Test model that replies with `Echo: <content of the last message>`.
    struct EchoModel;
    impl Model for EchoModel {
        async fn generate(&self, request: &ChatRequest) -> Result<ChatResponse> {
            let last = request
                .messages
                .last()
                .and_then(|m| m.content.as_deref())
                .unwrap_or("none");
            Ok(ChatResponse {
                message: Message::assistant(format!("Echo: {last}")),
                stop_reason: StopReason::EndTurn,
                usage: Some(Usage::default()),
            })
        }
        async fn generate_stream(&self, _request: &ChatRequest) -> Result<ResponseStream> {
            Ok(Box::pin(futures::stream::empty()))
        }
    }

    /// Minimal tool with a configurable name, used to exercise tool add/remove.
    struct DummyTool {
        tool_name: &'static str,
    }
    impl Tool for DummyTool {
        fn name(&self) -> &str {
            self.tool_name
        }
        fn description(&self) -> &str {
            "A dummy tool"
        }
        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }
        async fn execute(&self, _input: &serde_json::Value) -> Result<ToolOutput> {
            Ok(ToolOutput::text("ok"))
        }
    }

    #[tokio::test]
    async fn test_fork_has_independent_memory() {
        let agent = Agent::builder().model(EchoModel).build().unwrap();
        agent.prompt("hello").await.unwrap();
        let forked = agent.fork();
        let original_msgs = agent.memory.get_messages_erased().await.unwrap();
        let forked_msgs = forked.memory.get_messages_erased().await.unwrap();
        assert_eq!(original_msgs.len(), 2);
        assert_eq!(forked_msgs.len(), 0);
    }

    #[tokio::test]
    async fn test_fork_preserves_config() {
        let agent = Agent::builder()
            .model(EchoModel)
            .system_prompt("Be helpful")
            .max_iterations(10)
            .temperature(0.5)
            .build()
            .unwrap();
        let forked = agent.fork();
        assert_eq!(forked.system_prompt.as_deref(), Some("Be helpful"));
        assert_eq!(forked.max_iterations, 10);
        assert_eq!(forked.temperature, Some(0.5));
    }

    #[tokio::test]
    async fn test_fork_from_checkpoint() {
        let agent = Agent::builder().model(EchoModel).build().unwrap();
        let cp = Arc::new(InMemoryCheckpoint::new());
        let state = CheckpointState::new(
            "run-1",
            vec![Message::user("hi"), Message::assistant("hello")],
            1,
        );
        cp.save(&state).await.unwrap();
        let forked = agent.fork_from_checkpoint("run-1", &(cp as Arc<_>)).await.unwrap();
        let msgs = forked.memory.get_messages_erased().await.unwrap();
        assert_eq!(msgs.len(), 2);
        // Restored history must keep checkpoint order and content.
        assert_eq!(msgs[0].content.as_deref(), Some("hi"));
        assert_eq!(msgs[1].content.as_deref(), Some("hello"));
        // Restoring into the fork must not touch the source agent's memory.
        let original_msgs = agent.memory.get_messages_erased().await.unwrap();
        assert!(original_msgs.is_empty());
    }

    #[tokio::test]
    async fn test_fork_from_checkpoint_missing_run() {
        let agent = Agent::builder().model(EchoModel).build().unwrap();
        let cp: Arc<dyn crate::checkpoint::ErasedCheckpoint> =
            Arc::new(InMemoryCheckpoint::new());
        let result = agent.fork_from_checkpoint("missing", &cp).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_fork_with_memory() {
        let agent = Agent::builder().model(EchoModel).build().unwrap();
        agent.prompt("original").await.unwrap();
        let custom_mem = SlidingWindowMemory::new(5);
        let forked = agent.fork_with_memory(custom_mem);
        let original_msgs = agent.memory.get_messages_erased().await.unwrap();
        let forked_msgs = forked.memory.get_messages_erased().await.unwrap();
        assert_eq!(original_msgs.len(), 2);
        assert_eq!(forked_msgs.len(), 0);
    }

    #[tokio::test]
    async fn test_forked_agents_run_independently() {
        let agent = Agent::builder().model(EchoModel).build().unwrap();
        let forked = agent.fork();
        agent.prompt("msg1").await.unwrap();
        forked.prompt("msg2").await.unwrap();
        let a_msgs = agent.memory.get_messages_erased().await.unwrap();
        let f_msgs = forked.memory.get_messages_erased().await.unwrap();
        assert_eq!(a_msgs.len(), 2);
        assert_eq!(f_msgs.len(), 2);
        assert!(a_msgs[0].content.as_deref().unwrap().contains("msg1"));
        assert!(f_msgs[0].content.as_deref().unwrap().contains("msg2"));
    }

    #[tokio::test]
    async fn test_fork_builder_changes_system_prompt() {
        let agent = Agent::builder()
            .model(EchoModel)
            .system_prompt("Original prompt")
            .build()
            .unwrap();
        let forked = agent
            .fork_builder()
            .system_prompt("New prompt")
            .build();
        assert_eq!(forked.system_prompt.as_deref(), Some("New prompt"));
    }

    #[tokio::test]
    async fn test_fork_builder_clears_system_prompt() {
        let agent = Agent::builder()
            .model(EchoModel)
            .system_prompt("Original")
            .build()
            .unwrap();
        let forked = agent.fork_builder().no_system_prompt().build();
        assert!(forked.system_prompt.is_none());
    }

    #[tokio::test]
    async fn test_fork_builder_adds_and_removes_tools() {
        let agent = Agent::builder()
            .model(EchoModel)
            .tool(DummyTool { tool_name: "alpha" })
            .tool(DummyTool { tool_name: "beta" })
            .build()
            .unwrap();
        assert_eq!(agent.tools.len(), 2);
        let forked = agent
            .fork_builder()
            .remove_tool("alpha")
            .tool(DummyTool { tool_name: "gamma" })
            .build();
        assert!(forked.tools.get("alpha").is_none());
        assert!(forked.tools.get("beta").is_some());
        assert!(forked.tools.get("gamma").is_some());
        assert_eq!(forked.tools.len(), 2);
    }

    #[tokio::test]
    async fn test_fork_builder_overrides_iterations_and_temp() {
        let agent = Agent::builder()
            .model(EchoModel)
            .max_iterations(5)
            .temperature(0.7)
            .build()
            .unwrap();
        let forked = agent
            .fork_builder()
            .max_iterations(20)
            .temperature(0.1)
            .build();
        assert_eq!(forked.max_iterations, 20);
        assert_eq!(forked.temperature, Some(0.1));
    }

    #[tokio::test]
    async fn test_fork_builder_preserves_unchanged() {
        let agent = Agent::builder()
            .model(EchoModel)
            .system_prompt("Keep me")
            .max_iterations(8)
            .build()
            .unwrap();
        let forked = agent.fork_builder().build();
        assert_eq!(forked.system_prompt.as_deref(), Some("Keep me"));
        assert_eq!(forked.max_iterations, 8);
    }

    #[tokio::test]
    async fn test_fork_builder_independent_memory() {
        let agent = Agent::builder().model(EchoModel).build().unwrap();
        agent.prompt("hello").await.unwrap();
        let forked = agent.fork_builder().build();
        let original_msgs = agent.memory.get_messages_erased().await.unwrap();
        let forked_msgs = forked.memory.get_messages_erased().await.unwrap();
        assert_eq!(original_msgs.len(), 2);
        assert_eq!(forked_msgs.len(), 0);
    }

    #[tokio::test]
    async fn test_fork_builder_custom_memory() {
        let agent = Agent::builder().model(EchoModel).build().unwrap();
        let mem = SlidingWindowMemory::new(3);
        let forked = agent.fork_builder().memory(mem).build();
        forked.prompt("a").await.unwrap();
        let msgs = forked.memory.get_messages_erased().await.unwrap();
        assert_eq!(msgs.len(), 2);
    }

    #[tokio::test]
    async fn test_fork_builder_replace_model() {
        struct AltModel;
        impl Model for AltModel {
            async fn generate(&self, _req: &ChatRequest) -> Result<ChatResponse> {
                Ok(ChatResponse {
                    message: Message::assistant("alt".to_string()),
                    stop_reason: StopReason::EndTurn,
                    usage: Some(Usage::default()),
                })
            }
            async fn generate_stream(&self, _req: &ChatRequest) -> Result<ResponseStream> {
                Ok(Box::pin(futures::stream::empty()))
            }
        }
        let agent = Agent::builder().model(EchoModel).build().unwrap();
        let forked = agent.fork_builder().model(AltModel).build();
        let resp = forked.prompt("test").await.unwrap();
        assert_eq!(resp.text(), "alt");
        // Replacing the fork's model must leave the source agent's model alone.
        let original_resp = agent.prompt("test").await.unwrap();
        assert!(original_resp.text().starts_with("Echo"));
    }
}