use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::error::{Error, Result};
/// How much of the source agent's conversation is forwarded on a handoff.
///
/// Serialized with `rename_all = "lowercase"`, i.e. as `"full"`, `"summary"`,
/// `"none"` and `"lastn"`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ContextPolicy {
    /// Forward every message unchanged.
    Full,
    /// Forward only a short tail of the conversation (plus system messages when
    /// `preserve_system` is set). This is the default policy.
    #[default]
    Summary,
    /// Forward no messages at all.
    None,
    /// Forward the last `max_context_messages` messages (plus system messages
    /// when `preserve_system` is set).
    LastN,
}
/// Tunables controlling how a handoff forwards context and guards against
/// runaway delegation.
///
/// Every field has a serde default, so deserializing `{}` yields a complete
/// configuration.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct HandoffConfig {
    /// Which messages are forwarded to the target agent.
    #[serde(default)]
    pub context_policy: ContextPolicy,
    /// Token budget for forwarded context.
    /// NOTE(review): not consulted by the message filtering in this module.
    #[serde(default = "default_max_context_tokens")]
    pub max_context_tokens: usize,
    /// Number of trailing messages kept under [`ContextPolicy::LastN`].
    #[serde(default = "default_max_context_messages")]
    pub max_context_messages: usize,
    /// Keep every `"system"` message when trimming history.
    #[serde(default = "default_true")]
    pub preserve_system: bool,
    /// Handoff timeout in seconds (not enforced by code in this module).
    #[serde(default = "default_timeout")]
    pub timeout_seconds: f64,
    /// Presumably a cap on concurrent handoffs; not read by this module.
    #[serde(default = "default_max_concurrent")]
    pub max_concurrent: usize,
    /// Reject handoffs to an agent that is already in the chain.
    #[serde(default = "default_true")]
    pub detect_cycles: bool,
    /// Maximum number of agents allowed in the handoff chain.
    #[serde(default = "default_max_depth")]
    pub max_depth: usize,
    /// Async execution flag; not read by this module.
    #[serde(default)]
    pub async_mode: bool,
}
/// Serde default for `HandoffConfig::max_context_tokens`.
fn default_max_context_tokens() -> usize {
    4_000_usize
}
/// Serde default for `HandoffConfig::max_context_messages`.
fn default_max_context_messages() -> usize {
    10_usize
}
/// Serde default for flags that are on unless explicitly disabled
/// (`preserve_system`, `detect_cycles`).
fn default_true() -> bool {
    true
}
/// Serde default for `HandoffConfig::timeout_seconds` (five minutes).
fn default_timeout() -> f64 {
    300.0_f64
}
/// Serde default for `HandoffConfig::max_concurrent`.
fn default_max_concurrent() -> usize {
    3_usize
}
/// Serde default for `HandoffConfig::max_depth`.
fn default_max_depth() -> usize {
    10_usize
}
impl Default for HandoffConfig {
fn default() -> Self {
Self {
context_policy: ContextPolicy::Summary,
max_context_tokens: 4000,
max_context_messages: 10,
preserve_system: true,
timeout_seconds: 300.0,
max_concurrent: 3,
detect_cycles: true,
max_depth: 10,
async_mode: false,
}
}
}
impl HandoffConfig {
pub fn new() -> Self {
Self::default()
}
pub fn context_policy(mut self, policy: ContextPolicy) -> Self {
self.context_policy = policy;
self
}
pub fn max_context_tokens(mut self, tokens: usize) -> Self {
self.max_context_tokens = tokens;
self
}
pub fn max_context_messages(mut self, messages: usize) -> Self {
self.max_context_messages = messages;
self
}
pub fn preserve_system(mut self, preserve: bool) -> Self {
self.preserve_system = preserve;
self
}
pub fn timeout_seconds(mut self, timeout: f64) -> Self {
self.timeout_seconds = timeout;
self
}
pub fn max_concurrent(mut self, max: usize) -> Self {
self.max_concurrent = max;
self
}
pub fn detect_cycles(mut self, detect: bool) -> Self {
self.detect_cycles = detect;
self
}
pub fn max_depth(mut self, depth: usize) -> Self {
self.max_depth = depth;
self
}
pub fn async_mode(mut self) -> Self {
self.async_mode = true;
self
}
pub fn to_map(&self) -> HashMap<String, serde_json::Value> {
let mut map = HashMap::new();
map.insert(
"context_policy".to_string(),
serde_json::to_value(&self.context_policy).unwrap_or_default(),
);
map.insert(
"max_context_tokens".to_string(),
serde_json::Value::Number(self.max_context_tokens.into()),
);
map.insert(
"max_context_messages".to_string(),
serde_json::Value::Number(self.max_context_messages.into()),
);
map.insert(
"preserve_system".to_string(),
serde_json::Value::Bool(self.preserve_system),
);
map.insert(
"timeout_seconds".to_string(),
serde_json::json!(self.timeout_seconds),
);
map.insert(
"max_concurrent".to_string(),
serde_json::Value::Number(self.max_concurrent.into()),
);
map.insert(
"detect_cycles".to_string(),
serde_json::Value::Bool(self.detect_cycles),
);
map.insert(
"max_depth".to_string(),
serde_json::Value::Number(self.max_depth.into()),
);
map.insert(
"async_mode".to_string(),
serde_json::Value::Bool(self.async_mode),
);
map
}
}
/// Everything the target agent receives when a handoff happens.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct HandoffInputData {
    /// Conversation messages forwarded to the target (raw JSON values).
    pub messages: Vec<serde_json::Value>,
    /// Extra key/value context accompanying the messages.
    pub context: HashMap<String, serde_json::Value>,
    /// Name of the agent that initiated the handoff, if known.
    pub source_agent: Option<String>,
    /// Length of the handoff chain at the time the data was prepared.
    pub handoff_depth: usize,
    /// Agents in the handoff chain, oldest first.
    pub handoff_chain: Vec<String>,
}
impl HandoffInputData {
    /// Creates empty input data (no messages, context, source or chain).
    pub fn new() -> Self {
        Default::default()
    }

    /// Replaces the forwarded messages.
    pub fn messages(self, messages: Vec<serde_json::Value>) -> Self {
        Self { messages, ..self }
    }

    /// Inserts (or overwrites) one context entry.
    pub fn context(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let key: String = key.into();
        self.context.insert(key, value);
        self
    }

    /// Records the agent that initiated the handoff.
    pub fn source_agent(self, agent: impl Into<String>) -> Self {
        Self {
            source_agent: Some(agent.into()),
            ..self
        }
    }
}
/// Outcome of a single handoff attempt.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct HandoffResult {
    /// Whether the handoff completed successfully.
    pub success: bool,
    /// Response text produced on success.
    pub response: Option<String>,
    /// Agent the task was handed to.
    pub target_agent: Option<String>,
    /// Agent the task was handed from.
    pub source_agent: Option<String>,
    /// Wall-clock duration of the handoff, in seconds.
    pub duration_seconds: f64,
    /// Error description on failure.
    pub error: Option<String>,
    /// Chain depth at which the handoff happened.
    pub handoff_depth: usize,
}
impl HandoffResult {
    /// A successful result carrying `response`; every other field is defaulted.
    pub fn success(response: impl Into<String>) -> Self {
        Self {
            success: true,
            response: Some(response.into()),
            ..Self::default()
        }
    }

    /// A failed result carrying `error`; every other field is defaulted.
    pub fn failure(error: impl Into<String>) -> Self {
        Self {
            error: Some(error.into()),
            ..Self::default()
        }
    }

    /// Records the target agent.
    pub fn with_target(self, agent: impl Into<String>) -> Self {
        Self {
            target_agent: Some(agent.into()),
            ..self
        }
    }

    /// Records the source agent.
    pub fn with_source(self, agent: impl Into<String>) -> Self {
        Self {
            source_agent: Some(agent.into()),
            ..self
        }
    }

    /// Records the handoff duration in seconds.
    pub fn with_duration(self, seconds: f64) -> Self {
        Self {
            duration_seconds: seconds,
            ..self
        }
    }

    /// Records the chain depth of the handoff.
    pub fn with_depth(self, depth: usize) -> Self {
        Self {
            handoff_depth: depth,
            ..self
        }
    }
}
/// An unsuccessful result with no response, error, agents, duration or depth.
impl Default for HandoffResult {
    fn default() -> Self {
        Self {
            success: bool::default(),
            handoff_depth: usize::default(),
            duration_seconds: f64::default(),
            response: Option::default(),
            error: Option::default(),
            target_agent: Option::default(),
            source_agent: Option::default(),
        }
    }
}
/// Error describing a handoff that would revisit an agent already in the chain.
#[derive(Clone, Debug)]
pub struct HandoffCycleError {
    /// The agents involved, in handoff order.
    pub chain: Vec<String>,
}
impl std::fmt::Display for HandoffCycleError {
    /// Renders as `Handoff cycle detected: a -> b -> a`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Handoff cycle detected: ")?;
        let mut agents = self.chain.iter();
        if let Some(first) = agents.next() {
            f.write_str(first)?;
            for agent in agents {
                f.write_str(" -> ")?;
                f.write_str(agent)?;
            }
        }
        Ok(())
    }
}
impl std::error::Error for HandoffCycleError {}
/// Error reporting that a handoff chain grew beyond its allowed depth.
#[derive(Clone, Debug)]
pub struct HandoffDepthError {
    /// Depth the chain would have reached.
    pub depth: usize,
    /// Configured maximum depth.
    pub max_depth: usize,
}
impl std::fmt::Display for HandoffDepthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { depth, max_depth } = self;
        write!(f, "Max handoff depth exceeded: {depth} > {max_depth}")
    }
}
impl std::error::Error for HandoffDepthError {}
/// Error reporting that a handoff to an agent did not finish in time.
#[derive(Clone, Debug)]
pub struct HandoffTimeoutError {
    /// Timeout that elapsed, in seconds.
    pub timeout: f64,
    /// Agent that failed to respond in time.
    pub agent_name: String,
}
impl std::fmt::Display for HandoffTimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            timeout,
            agent_name,
        } = self;
        write!(f, "Handoff to {agent_name} timed out after {timeout}s")
    }
}
impl std::error::Error for HandoffTimeoutError {}
/// Thread-safe stack of agent names for the handoffs currently in progress,
/// oldest first.
///
/// Every operation takes the internal lock for a single `Vec` operation, so a
/// poisoned lock cannot hide a half-updated chain; poisoning is therefore
/// recovered from instead of propagated as a panic.
#[derive(Debug, Default)]
pub struct HandoffChain {
    chain: std::sync::RwLock<Vec<String>>,
}
impl HandoffChain {
    /// Creates an empty chain.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a snapshot (clone) of the chain, oldest agent first.
    pub fn get(&self) -> Vec<String> {
        self.read_guard().clone()
    }

    /// Number of agents currently in the chain.
    pub fn depth(&self) -> usize {
        self.read_guard().len()
    }

    /// Appends `agent_name` to the end of the chain.
    pub fn push(&self, agent_name: impl Into<String>) {
        // Convert before locking: a panicking `Into` impl must not run while the
        // write guard is held, or it would poison the lock for every later caller.
        let agent_name = agent_name.into();
        self.write_guard().push(agent_name);
    }

    /// Removes and returns the most recently pushed agent, if any.
    pub fn pop(&self) -> Option<String> {
        self.write_guard().pop()
    }

    /// Whether `agent_name` appears anywhere in the chain.
    pub fn contains(&self, agent_name: &str) -> bool {
        self.read_guard().iter().any(|a| a == agent_name)
    }

    /// Removes every agent from the chain.
    pub fn clear(&self) {
        self.write_guard().clear();
    }

    /// Shared access to the chain, recovering the data if the lock was poisoned.
    fn read_guard(&self) -> std::sync::RwLockReadGuard<'_, Vec<String>> {
        self.chain
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Exclusive access to the chain, recovering the data if the lock was poisoned.
    fn write_guard(&self) -> std::sync::RwLockWriteGuard<'_, Vec<String>> {
        self.chain
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
/// A handoff from the current agent to `target_agent_name`, exposed to the
/// model as a tool.
#[derive(Clone, Debug)]
pub struct Handoff {
    /// Name of the agent that receives the task.
    pub target_agent_name: String,
    /// Custom tool name; `None` means `transfer_to_<agent>` is generated.
    pub tool_name_override: Option<String>,
    /// Custom tool description; `None` means a generated description.
    pub tool_description_override: Option<String>,
    /// Context and safety settings for this handoff.
    pub config: HandoffConfig,
}
impl Handoff {
    /// Number of trailing non-system messages forwarded under [`ContextPolicy::Summary`].
    const SUMMARY_TAIL_MESSAGES: usize = 3;

    /// Creates a handoff to `target_agent_name` using [`HandoffConfig::default`].
    pub fn new(target_agent_name: impl Into<String>) -> Self {
        Self {
            target_agent_name: target_agent_name.into(),
            tool_name_override: None,
            tool_description_override: None,
            config: HandoffConfig::default(),
        }
    }

    /// Overrides the tool name returned by [`Handoff::get_tool_name`].
    pub fn tool_name(mut self, name: impl Into<String>) -> Self {
        self.tool_name_override = Some(name.into());
        self
    }

    /// Overrides the tool description returned by [`Handoff::get_tool_description`].
    pub fn tool_description(mut self, description: impl Into<String>) -> Self {
        self.tool_description_override = Some(description.into());
        self
    }

    /// Replaces the handoff configuration.
    pub fn config(mut self, config: HandoffConfig) -> Self {
        self.config = config;
        self
    }

    /// Returns the override if one was set, otherwise `transfer_to_<agent>`
    /// derived from the target agent's name.
    pub fn get_tool_name(&self) -> String {
        self.tool_name_override
            .clone()
            .unwrap_or_else(|| self.default_tool_name())
    }

    /// Returns the override if one was set, otherwise `Transfer task to <agent>`.
    pub fn get_tool_description(&self) -> String {
        self.tool_description_override
            .clone()
            .unwrap_or_else(|| self.default_tool_description())
    }

    /// `transfer_to_` followed by the target name lowercased, with spaces and
    /// hyphens replaced by underscores (other characters are kept as-is).
    fn default_tool_name(&self) -> String {
        let agent_name = self
            .target_agent_name
            .to_lowercase()
            .replace(' ', "_")
            .replace('-', "_");
        format!("transfer_to_{}", agent_name)
    }

    fn default_tool_description(&self) -> String {
        format!("Transfer task to {}", self.target_agent_name)
    }

    /// Checks whether handing off to the target agent is allowed given the
    /// agents already in `chain`. `_source_agent_name` is currently unused.
    ///
    /// # Errors
    ///
    /// Returns an [`Error::handoff`] error when
    /// - `detect_cycles` is enabled and the target agent is already in `chain`, or
    /// - `chain` already holds `max_depth` or more agents.
    pub fn check_safety(
        &self,
        _source_agent_name: &str,
        chain: &HandoffChain,
    ) -> Result<()> {
        if self.config.detect_cycles && chain.contains(&self.target_agent_name) {
            let mut cycle_chain = chain.get();
            cycle_chain.push(self.target_agent_name.clone());
            return Err(Error::handoff(format!(
                "Cycle detected: {}",
                cycle_chain.join(" -> ")
            )));
        }
        let current_depth = chain.depth();
        // The handoff would make the chain one longer than it is now.
        if current_depth >= self.config.max_depth {
            return Err(Error::handoff(format!(
                "Max depth exceeded: {} > {}",
                current_depth + 1,
                self.config.max_depth
            )));
        }
        Ok(())
    }

    /// Builds the input handed to the target agent.
    ///
    /// `messages` are filtered according to `config.context_policy`:
    /// - `Full`: all messages;
    /// - `None`: no messages;
    /// - `LastN`: the last `max_context_messages` messages;
    /// - `Summary`: the last three messages.
    ///
    /// For `LastN` and `Summary`, when `preserve_system` is set every
    /// `"system"` message is kept (placed first, in original order) and the tail
    /// is taken from the remaining messages only.
    ///
    /// `extra_context` is returned with a `"source_agent"` entry added
    /// (overwriting any existing one), and the chain is snapshotted once so
    /// `handoff_depth` always equals `handoff_chain.len()`.
    pub fn prepare_context(
        &self,
        messages: Vec<serde_json::Value>,
        source_agent: &str,
        chain: &HandoffChain,
        extra_context: HashMap<String, serde_json::Value>,
    ) -> HandoffInputData {
        let preserve_system = self.config.preserve_system;
        let filtered_messages = match self.config.context_policy {
            ContextPolicy::None => Vec::new(),
            ContextPolicy::LastN => Self::tail_messages(
                messages,
                self.config.max_context_messages,
                preserve_system,
            ),
            ContextPolicy::Summary => {
                Self::tail_messages(messages, Self::SUMMARY_TAIL_MESSAGES, preserve_system)
            }
            ContextPolicy::Full => messages,
        };
        let mut context = extra_context;
        context.insert(
            "source_agent".to_string(),
            serde_json::Value::String(source_agent.to_string()),
        );
        let handoff_chain = chain.get();
        HandoffInputData {
            messages: filtered_messages,
            context,
            source_agent: Some(source_agent.to_string()),
            handoff_depth: handoff_chain.len(),
            handoff_chain,
        }
    }

    /// Keeps the last `n` messages in their original order. With
    /// `preserve_system`, all system messages are kept first and the last `n`
    /// are chosen among the non-system messages only.
    fn tail_messages(
        messages: Vec<serde_json::Value>,
        n: usize,
        preserve_system: bool,
    ) -> Vec<serde_json::Value> {
        if preserve_system {
            // `partition` preserves relative order within each group and moves
            // (rather than clones) every message.
            let (mut kept, others): (Vec<_>, Vec<_>) =
                messages.into_iter().partition(Self::is_system_message);
            let skip = others.len().saturating_sub(n);
            kept.extend(others.into_iter().skip(skip));
            kept
        } else {
            let skip = messages.len().saturating_sub(n);
            messages.into_iter().skip(skip).collect()
        }
    }

    /// True when the message has a string `"role"` equal to `"system"`.
    fn is_system_message(message: &serde_json::Value) -> bool {
        message.get("role").and_then(serde_json::Value::as_str) == Some("system")
    }
}
/// Reusable transformations applied to [`HandoffInputData`] before it reaches
/// the target agent.
pub struct HandoffFilters;
impl HandoffFilters {
    /// Drops tool traffic: messages with role `"tool"` and messages carrying a
    /// non-null `"tool_calls"` field (an empty array still counts as tool traffic).
    ///
    /// A `"tool_calls": null` field is how many serializers represent "no tool
    /// calls" on an ordinary assistant reply, so such messages are kept.
    pub fn remove_all_tools(mut data: HandoffInputData) -> HandoffInputData {
        data.messages.retain(|msg| {
            let has_tool_calls = matches!(msg.get("tool_calls"), Some(calls) if !calls.is_null());
            let is_tool_role = msg.get("role").and_then(|r| r.as_str()) == Some("tool");
            !has_tool_calls && !is_tool_role
        });
        data
    }

    /// Returns a filter keeping only the last `n` messages (all of them when
    /// there are `n` or fewer).
    pub fn keep_last_n(n: usize) -> impl Fn(HandoffInputData) -> HandoffInputData {
        move |mut data: HandoffInputData| {
            // Remove the leading excess in place instead of rebuilding the Vec.
            let excess = data.messages.len().saturating_sub(n);
            data.messages.drain(..excess);
            data
        }
    }

    /// Drops every message whose `"role"` is the string `"system"`; messages
    /// without a string role are kept.
    pub fn remove_system_messages(mut data: HandoffInputData) -> HandoffInputData {
        data.messages
            .retain(|msg| msg.get("role").and_then(|r| r.as_str()) != Some("system"));
        data
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a chat-style message with the given role and content.
    fn msg(role: &str, content: &str) -> serde_json::Value {
        serde_json::json!({"role": role, "content": content})
    }

    #[test]
    fn test_handoff_config_defaults() {
        let config = HandoffConfig::new();
        assert_eq!(config.context_policy, ContextPolicy::Summary);
        assert_eq!(config.max_context_tokens, 4000);
        assert_eq!(config.max_depth, 10);
        assert!(config.detect_cycles);
    }

    #[test]
    fn test_handoff_config_builder() {
        let config = HandoffConfig::new()
            .context_policy(ContextPolicy::Full)
            .timeout_seconds(60.0)
            .max_depth(5)
            .detect_cycles(false);
        assert_eq!(config.context_policy, ContextPolicy::Full);
        assert_eq!(config.timeout_seconds, 60.0);
        assert_eq!(config.max_depth, 5);
        assert!(!config.detect_cycles);
    }

    #[test]
    fn test_handoff_result_success() {
        let result = HandoffResult::success("Task completed")
            .with_target("billing_agent")
            .with_source("triage_agent")
            .with_duration(1.5);
        assert!(result.success);
        assert_eq!(result.response, Some("Task completed".to_string()));
        assert_eq!(result.target_agent, Some("billing_agent".to_string()));
        assert_eq!(result.source_agent, Some("triage_agent".to_string()));
        assert_eq!(result.duration_seconds, 1.5);
    }

    #[test]
    fn test_handoff_result_failure() {
        let result = HandoffResult::failure("Connection timeout").with_target("billing_agent");
        assert!(!result.success);
        assert_eq!(result.error, Some("Connection timeout".to_string()));
        assert!(result.response.is_none());
    }

    #[test]
    fn test_handoff_tool_name() {
        let handoff = Handoff::new("Billing Agent");
        assert_eq!(handoff.get_tool_name(), "transfer_to_billing_agent");
        let handoff_custom = Handoff::new("Billing Agent").tool_name("custom_transfer");
        assert_eq!(handoff_custom.get_tool_name(), "custom_transfer");
    }

    #[test]
    fn test_handoff_chain() {
        let chain = HandoffChain::new();
        assert_eq!(chain.depth(), 0);
        chain.push("agent_a");
        chain.push("agent_b");
        assert_eq!(chain.depth(), 2);
        assert!(chain.contains("agent_a"));
        assert!(chain.contains("agent_b"));
        assert!(!chain.contains("agent_c"));
        let popped = chain.pop();
        assert_eq!(popped, Some("agent_b".to_string()));
        assert_eq!(chain.depth(), 1);
    }

    #[test]
    fn test_handoff_cycle_detection() {
        let handoff = Handoff::new("agent_a").config(HandoffConfig::new().detect_cycles(true));
        let chain = HandoffChain::new();
        chain.push("agent_a");
        let result = handoff.check_safety("agent_b", &chain);
        assert!(result.is_err());
    }

    #[test]
    fn test_handoff_depth_check() {
        let handoff = Handoff::new("agent_c").config(HandoffConfig::new().max_depth(2));
        let chain = HandoffChain::new();
        chain.push("agent_a");
        chain.push("agent_b");
        let result = handoff.check_safety("agent_b", &chain);
        assert!(result.is_err());
    }

    #[test]
    fn test_check_safety_allows_fresh_target() {
        let handoff = Handoff::new("agent_c");
        let chain = HandoffChain::new();
        chain.push("agent_a");
        assert!(handoff.check_safety("agent_a", &chain).is_ok());
    }

    #[test]
    fn test_context_policy_none() {
        let handoff = Handoff::new("target").config(
            HandoffConfig::new().context_policy(ContextPolicy::None),
        );
        let messages = vec![
            serde_json::json!({"role": "system", "content": "You are helpful"}),
            serde_json::json!({"role": "user", "content": "Hello"}),
        ];
        let chain = HandoffChain::new();
        let data = handoff.prepare_context(messages, "source", &chain, HashMap::new());
        assert!(data.messages.is_empty());
    }

    #[test]
    fn test_context_policy_last_n() {
        let handoff = Handoff::new("target").config(
            HandoffConfig::new()
                .context_policy(ContextPolicy::LastN)
                .max_context_messages(2)
                .preserve_system(false),
        );
        let messages = vec![
            serde_json::json!({"role": "user", "content": "msg1"}),
            serde_json::json!({"role": "assistant", "content": "msg2"}),
            serde_json::json!({"role": "user", "content": "msg3"}),
            serde_json::json!({"role": "assistant", "content": "msg4"}),
        ];
        let chain = HandoffChain::new();
        let data = handoff.prepare_context(messages, "source", &chain, HashMap::new());
        assert_eq!(data.messages.len(), 2);
        // The *last* two messages must survive, in their original order.
        assert_eq!(data.messages[0]["content"], "msg3");
        assert_eq!(data.messages[1]["content"], "msg4");
    }

    #[test]
    fn test_context_policy_last_n_preserves_system() {
        let handoff = Handoff::new("target").config(
            HandoffConfig::new()
                .context_policy(ContextPolicy::LastN)
                .max_context_messages(1),
        );
        let messages = vec![
            msg("system", "rules"),
            msg("user", "first"),
            msg("assistant", "second"),
        ];
        let data =
            handoff.prepare_context(messages, "source", &HandoffChain::new(), HashMap::new());
        assert_eq!(data.messages.len(), 2);
        assert_eq!(data.messages[0]["role"], "system");
        assert_eq!(data.messages[1]["content"], "second");
    }

    #[test]
    fn test_context_policy_summary_keeps_system_and_last_three() {
        let handoff = Handoff::new("target");
        let messages = vec![
            msg("system", "rules"),
            msg("user", "m1"),
            msg("user", "m2"),
            msg("user", "m3"),
            msg("user", "m4"),
        ];
        let data =
            handoff.prepare_context(messages, "source", &HandoffChain::new(), HashMap::new());
        let contents: Vec<&str> = data
            .messages
            .iter()
            .filter_map(|m| m["content"].as_str())
            .collect();
        assert_eq!(contents, vec!["rules", "m2", "m3", "m4"]);
    }

    #[test]
    fn test_context_policy_full_keeps_everything() {
        let handoff = Handoff::new("target").config(
            HandoffConfig::new().context_policy(ContextPolicy::Full),
        );
        let messages: Vec<_> = (0..5).map(|i| msg("user", &i.to_string())).collect();
        let data = handoff.prepare_context(
            messages.clone(),
            "source",
            &HandoffChain::new(),
            HashMap::new(),
        );
        assert_eq!(data.messages, messages);
    }

    #[test]
    fn test_prepare_context_records_chain_and_source() {
        let handoff = Handoff::new("target");
        let chain = HandoffChain::new();
        chain.push("agent_a");
        chain.push("agent_b");
        let mut extra = HashMap::new();
        extra.insert("ticket".to_string(), serde_json::json!(42));
        let data = handoff.prepare_context(vec![], "agent_b", &chain, extra);
        assert_eq!(data.handoff_depth, 2);
        assert_eq!(
            data.handoff_chain,
            vec!["agent_a".to_string(), "agent_b".to_string()]
        );
        assert_eq!(data.source_agent.as_deref(), Some("agent_b"));
        assert_eq!(data.context["source_agent"], "agent_b");
        assert_eq!(data.context["ticket"], 42);
    }

    #[test]
    fn test_handoff_filters_remove_tools() {
        let data = HandoffInputData {
            messages: vec![
                serde_json::json!({"role": "user", "content": "Hello"}),
                serde_json::json!({"role": "assistant", "tool_calls": []}),
                serde_json::json!({"role": "tool", "content": "result"}),
                serde_json::json!({"role": "assistant", "content": "Done"}),
            ],
            ..Default::default()
        };
        let filtered = HandoffFilters::remove_all_tools(data);
        assert_eq!(filtered.messages.len(), 2);
    }

    #[test]
    fn test_handoff_filters_keep_last_n() {
        let keep_two = HandoffFilters::keep_last_n(2);
        let data = HandoffInputData::new().messages(vec![
            msg("user", "a"),
            msg("user", "b"),
            msg("user", "c"),
        ]);
        let trimmed = keep_two(data);
        assert_eq!(trimmed.messages.len(), 2);
        assert_eq!(trimmed.messages[0]["content"], "b");
        assert_eq!(trimmed.messages[1]["content"], "c");

        let short = HandoffInputData::new().messages(vec![msg("user", "only")]);
        assert_eq!(HandoffFilters::keep_last_n(5)(short).messages.len(), 1);
    }

    #[test]
    fn test_handoff_filters_remove_system_messages() {
        let data = HandoffInputData::new().messages(vec![
            msg("system", "rules"),
            msg("user", "hi"),
            serde_json::json!({"content": "no role"}),
        ]);
        let filtered = HandoffFilters::remove_system_messages(data);
        assert_eq!(filtered.messages.len(), 2);
        assert_eq!(filtered.messages[0]["content"], "hi");
    }

    #[test]
    fn test_handoff_input_data_builder() {
        let data = HandoffInputData::new()
            .source_agent("source_agent")
            .context("key", serde_json::json!("value"));
        assert_eq!(data.source_agent, Some("source_agent".to_string()));
        assert!(data.context.contains_key("key"));
    }
}