use std::sync::Arc;
use tokio::sync::RwLock;
use crate::agent::{Agent, AgentResponse};
use crate::error::Result;
use crate::guardrails::{InputGuardrail, OutputGuardrail};
use crate::hooks::AgentHook;
use crate::memory::Memory;
use crate::middleware::Middleware;
use crate::model::{Model, SharedModel};
use crate::stream::ResponseStream;
use crate::tool::{Tool, ToolRetryPolicy};
pub struct HotSwapAgent {
inner: Arc<RwLock<Agent>>,
}
impl HotSwapAgent {
pub fn new(agent: Agent) -> Self {
Self {
inner: Arc::new(RwLock::new(agent)),
}
}
pub async fn prompt(&self, input: &str) -> Result<AgentResponse> {
let agent = self.inner.read().await;
agent.prompt(input).await
}
pub async fn prompt_stream(&self, input: &str) -> Result<ResponseStream> {
let agent = self.inner.read().await;
agent.prompt_stream(input).await
}
pub async fn swap_model<M: Model + 'static>(&self, model: M) {
let mut agent = self.inner.write().await;
agent.model = Arc::new(model);
}
pub async fn swap_shared_model(&self, model: SharedModel) {
let mut agent = self.inner.write().await;
agent.model = model;
}
pub async fn swap_system_prompt(&self, prompt: Option<String>) {
let mut agent = self.inner.write().await;
agent.system_prompt = prompt;
}
pub async fn add_tool<T: Tool + 'static>(&self, tool: T) -> bool {
let mut agent = self.inner.write().await;
agent.tools.register(tool).is_ok()
}
pub async fn remove_tool(&self, name: &str) -> bool {
let mut agent = self.inner.write().await;
agent.tools.unregister(name)
}
pub async fn swap_memory<M: Memory + 'static>(&self, memory: M) {
let mut agent = self.inner.write().await;
agent.memory = Arc::new(memory);
}
pub async fn swap_hooks<H: AgentHook + 'static>(&self, hooks: H) {
let mut agent = self.inner.write().await;
agent.hooks = Arc::new(hooks);
}
pub async fn swap_middleware(&self, stack: crate::middleware::MiddlewareStack) {
let mut agent = self.inner.write().await;
agent.middleware = stack;
}
pub async fn add_middleware<M: Middleware + 'static>(&self, mw: M) {
let mut agent = self.inner.write().await;
agent.middleware.push(mw);
}
pub async fn add_input_guardrail<G: InputGuardrail + 'static>(&self, guard: G) {
let mut agent = self.inner.write().await;
agent.input_guardrails.push(Arc::new(guard));
}
pub async fn add_output_guardrail<G: OutputGuardrail + 'static>(&self, guard: G) {
let mut agent = self.inner.write().await;
agent.output_guardrails.push(Arc::new(guard));
}
pub async fn clear_input_guardrails(&self) {
let mut agent = self.inner.write().await;
agent.input_guardrails.clear();
}
pub async fn clear_output_guardrails(&self) {
let mut agent = self.inner.write().await;
agent.output_guardrails.clear();
}
pub async fn set_max_iterations(&self, max: usize) {
let mut agent = self.inner.write().await;
agent.max_iterations = max;
}
pub async fn set_temperature(&self, temp: Option<f32>) {
let mut agent = self.inner.write().await;
agent.temperature = temp;
}
pub async fn set_max_tokens(&self, tokens: Option<u32>) {
let mut agent = self.inner.write().await;
agent.max_tokens = tokens;
}
pub async fn set_validate_tool_inputs(&self, enabled: bool) {
let mut agent = self.inner.write().await;
agent.validate_tool_inputs = enabled;
}
pub async fn set_tool_retry_policy(&self, policy: Option<ToolRetryPolicy>) {
let mut agent = self.inner.write().await;
agent.tool_retry_policy = policy;
}
pub async fn replace(&self, agent: Agent) {
let mut inner = self.inner.write().await;
*inner = agent;
}
pub async fn system_prompt(&self) -> Option<String> {
let agent = self.inner.read().await;
agent.system_prompt.clone()
}
pub async fn tool_count(&self) -> usize {
let agent = self.inner.read().await;
agent.tools.len()
}
pub async fn tool_names(&self) -> Vec<String> {
let agent = self.inner.read().await;
agent.tools.list().into_iter().map(String::from).collect()
}
}
impl Clone for HotSwapAgent {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::error::Result as DResult;
use crate::model::Model;
use crate::model::types::*;
use crate::stream::ResponseStream;
use crate::tool::ToolOutput;
struct ModelA;
struct ModelB;
impl Model for ModelA {
async fn generate(&self, _request: &ChatRequest) -> DResult<ChatResponse> {
Ok(ChatResponse {
message: Message::assistant("from-A"),
stop_reason: StopReason::EndTurn,
usage: Some(Usage::default()),
})
}
async fn generate_stream(&self, _request: &ChatRequest) -> DResult<ResponseStream> {
Ok(Box::pin(futures::stream::empty()))
}
}
impl Model for ModelB {
async fn generate(&self, _request: &ChatRequest) -> DResult<ChatResponse> {
Ok(ChatResponse {
message: Message::assistant("from-B"),
stop_reason: StopReason::EndTurn,
usage: Some(Usage::default()),
})
}
async fn generate_stream(&self, _request: &ChatRequest) -> DResult<ResponseStream> {
Ok(Box::pin(futures::stream::empty()))
}
}
struct TestTool {
tool_name: String,
}
impl crate::tool::Tool for TestTool {
fn name(&self) -> &str {
&self.tool_name
}
fn description(&self) -> &str {
"test"
}
fn parameters_schema(&self) -> serde_json::Value {
serde_json::json!({"type": "object"})
}
async fn execute(&self, _input: &serde_json::Value) -> DResult<ToolOutput> {
Ok(ToolOutput::text("ok"))
}
}
#[tokio::test]
async fn test_swap_model() {
let agent = Agent::builder().model(ModelA).build().unwrap();
let hot = HotSwapAgent::new(agent);
let r1 = hot.prompt("hi").await.unwrap();
assert_eq!(r1.final_text, "from-A");
hot.swap_model(ModelB).await;
let r2 = hot.prompt("hi").await.unwrap();
assert_eq!(r2.final_text, "from-B");
}
#[tokio::test]
async fn test_swap_system_prompt() {
let agent = Agent::builder()
.model(ModelA)
.system_prompt("original")
.build()
.unwrap();
let hot = HotSwapAgent::new(agent);
assert_eq!(hot.system_prompt().await.as_deref(), Some("original"));
hot.swap_system_prompt(Some("updated".into())).await;
assert_eq!(hot.system_prompt().await.as_deref(), Some("updated"));
hot.swap_system_prompt(None).await;
assert!(hot.system_prompt().await.is_none());
}
#[tokio::test]
async fn test_add_remove_tools() {
let agent = Agent::builder().model(ModelA).build().unwrap();
let hot = HotSwapAgent::new(agent);
assert_eq!(hot.tool_count().await, 0);
hot.add_tool(TestTool {
tool_name: "alpha".into(),
})
.await;
assert_eq!(hot.tool_count().await, 1);
hot.add_tool(TestTool {
tool_name: "beta".into(),
})
.await;
assert_eq!(hot.tool_count().await, 2);
assert!(hot.remove_tool("alpha").await);
assert_eq!(hot.tool_count().await, 1);
assert!(!hot.remove_tool("nonexistent").await);
}
#[tokio::test]
async fn test_set_iterations_and_temp() {
let agent = Agent::builder().model(ModelA).build().unwrap();
let hot = HotSwapAgent::new(agent);
hot.set_max_iterations(10).await;
hot.set_temperature(Some(0.5)).await;
hot.set_max_tokens(Some(100)).await;
let inner = hot.inner.read().await;
assert_eq!(inner.max_iterations, 10);
assert_eq!(inner.temperature, Some(0.5));
assert_eq!(inner.max_tokens, Some(100));
}
#[tokio::test]
async fn test_clone_shares_state() {
let agent = Agent::builder()
.model(ModelA)
.system_prompt("shared")
.build()
.unwrap();
let hot = HotSwapAgent::new(agent);
let clone = hot.clone();
clone
.swap_system_prompt(Some("from-clone".into()))
.await;
assert_eq!(hot.system_prompt().await.as_deref(), Some("from-clone"));
}
#[tokio::test]
async fn test_replace_agent() {
let agent_a = Agent::builder()
.model(ModelA)
.system_prompt("A")
.build()
.unwrap();
let hot = HotSwapAgent::new(agent_a);
assert_eq!(hot.system_prompt().await.as_deref(), Some("A"));
let agent_b = Agent::builder()
.model(ModelB)
.system_prompt("B")
.build()
.unwrap();
hot.replace(agent_b).await;
assert_eq!(hot.system_prompt().await.as_deref(), Some("B"));
let r = hot.prompt("test").await.unwrap();
assert_eq!(r.final_text, "from-B");
}
#[tokio::test]
async fn test_tool_names() {
let agent = Agent::builder().model(ModelA).build().unwrap();
let hot = HotSwapAgent::new(agent);
hot.add_tool(TestTool {
tool_name: "foo".into(),
})
.await;
hot.add_tool(TestTool {
tool_name: "bar".into(),
})
.await;
let mut names = hot.tool_names().await;
names.sort();
assert_eq!(names, vec!["bar", "foo"]);
}
}