use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
/// Tunable limits and behaviour for a [`HandoffProtocol`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandoffConfig {
    /// Maximum number of handoffs the protocol will record.
    ///
    /// `initiate_handoff` rejects new requests once the history holds this many records.
    pub max_handoff_depth: u32,
    /// How much context is meant to travel with a handoff.
    ///
    /// The protocol in this file only stores it.
    /// NOTE(review): confirm where it is consumed.
    pub context_transfer: ContextTransferMode,
    /// Whether a receiving agent may hand work back.
    ///
    /// When `false`, `handback` fails, and handoffs to an agent already present in the
    /// history are rejected as circular.
    pub allow_handback: bool,
    /// Time budget for a handoff, serialized as a whole number of seconds.
    ///
    /// NOTE(review): not enforced inside this file; presumably callers compare against it
    /// before calling `mark_timeout`.
    #[serde(with = "duration_secs")]
    pub timeout: Duration,
}
impl Default for HandoffConfig {
    /// Up to 5 handoffs, summary context transfer, handback allowed and a 60-second timeout.
    fn default() -> Self {
        const DEFAULT_TIMEOUT_SECS: u64 = 60;
        Self {
            allow_handback: true,
            context_transfer: ContextTransferMode::Summary,
            max_handoff_depth: 5,
            timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
        }
    }
}
/// How much of the sender's context is meant to accompany a handoff.
///
/// Serialized in `snake_case` (e.g. `"summary"`). Nothing in this file interprets the
/// variants, so the per-variant notes below are inferred from their names only. Confirm
/// them against the consumer.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContextTransferMode {
    /// Presumably: pass the complete context along.
    Full,
    /// Presumably: pass a condensed summary. Used by `HandoffConfig::default()`.
    Summary,
    /// Presumably: pass a caller-chosen subset.
    Selective,
    /// Presumably: pass as little context as possible.
    Minimal,
}
/// A request from one agent to delegate a task to another.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandoffRequest {
    /// Identifier of the agent giving up the task.
    pub from_agent: String,
    /// Identifier of the agent receiving the task. A request where this equals
    /// `from_agent` is rejected by the protocol.
    pub to_agent: String,
    /// Free-form explanation of why the handoff happens.
    pub reason: String,
    /// Description of the delegated work.
    pub task: String,
    /// Conversation and tool state carried along with the request.
    pub context: HandoffContext,
    /// Arbitrary extra key/value data. An absent key in serialized input yields an empty map.
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
}
/// Context snapshot attached to a [`HandoffRequest`].
///
/// `HandoffProtocol` never reads these fields; it stores and returns them unchanged.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct HandoffContext {
    /// Messages from the sender's conversation that are passed along.
    pub messages: Vec<ContextMessage>,
    /// Condensed outcomes of tool calls made before the handoff.
    pub tool_results: Vec<ToolResultSummary>,
    /// Token count accumulated so far.
    /// NOTE(review): scope (sender-only vs. whole chain) is not established here — confirm.
    pub accumulated_tokens: usize,
    /// Agents that preceded this request, as supplied by the caller.
    /// It is not derived by the protocol.
    pub parent_chain: Vec<String>,
}
/// One conversation message forwarded inside a [`HandoffContext`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextMessage {
    /// Speaker role as a free-form string (e.g. `"user"`, `"assistant"`).
    pub role: String,
    /// Message text.
    pub content: String,
    /// When the message was produced, in UTC.
    pub timestamp: DateTime<Utc>,
}
/// Condensed record of a single tool invocation, forwarded inside a [`HandoffContext`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultSummary {
    /// Name of the tool that was invoked.
    pub tool_name: String,
    /// Whether the invocation succeeded.
    pub success: bool,
    /// Short human-readable description of the outcome.
    pub summary: String,
}
/// Outcome recorded when a handoff finishes, whether completed, handed back, timed out or
/// failed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandoffResult {
    /// Copied from the originating request.
    pub from_agent: String,
    /// Copied from the originating request.
    pub to_agent: String,
    /// Receiver's answer. Empty for non-`Completed` outcomes recorded by `HandoffProtocol`.
    pub response: String,
    /// Tokens spent by the receiver. `HandoffProtocol` records 0 for timeouts and failures.
    pub tokens_used: usize,
    /// Wall-clock time spent, as reported by the caller.
    ///
    /// Serialized as whole seconds, so any sub-second part is dropped.
    #[serde(with = "duration_secs")]
    pub duration: Duration,
    /// Agent chain of the protocol's whole history at the moment the outcome was recorded.
    pub handoff_chain: Vec<String>,
    /// How the handoff ended.
    pub status: HandoffStatus,
}
/// Final state of a handoff.
///
/// Serialized as an internally tagged object: the variant name goes in a `"type"` key in
/// `snake_case`, next to any variant fields (e.g. `{"type":"handed_back","reason":"..."}`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum HandoffStatus {
    /// The receiver finished the task.
    Completed,
    /// The receiver returned the task to the sender.
    HandedBack {
        /// Why the task was handed back.
        reason: String,
    },
    /// The handoff exceeded its time budget.
    TimedOut,
    /// Depth limit reached.
    ///
    /// Not produced by `HandoffProtocol` in this file, which reports depth violations as
    /// `HandoffError::DepthExceeded` instead.
    DepthExceeded,
    /// The handoff failed.
    Failed {
        /// Description of the failure.
        reason: String,
    },
}
/// A single entry in the protocol history: the request plus its outcome, once known.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandoffRecord {
    /// The request exactly as it was accepted by `initiate_handoff`.
    pub request: HandoffRequest,
    /// `None` while the handoff is pending; set exactly once when it finishes.
    pub result: Option<HandoffResult>,
    /// When the request was recorded, in UTC.
    pub timestamp: DateTime<Utc>,
}
/// Records a sequence of agent-to-agent handoffs and their outcomes, enforcing the limits
/// in its [`HandoffConfig`].
///
/// Every method that changes state takes `&mut self`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandoffProtocol {
    /// Limits and behaviour switches applied to every request.
    config: HandoffConfig,
    /// All handoffs since creation or the last reset, oldest first.
    history: Vec<HandoffRecord>,
}
impl HandoffProtocol {
pub fn new(config: HandoffConfig) -> Self {
Self {
config,
history: Vec::new(),
}
}
pub fn with_defaults() -> Self {
Self::new(HandoffConfig::default())
}
pub fn config(&self) -> &HandoffConfig {
&self.config
}
pub fn history(&self) -> &[HandoffRecord] {
&self.history
}
pub fn current_depth(&self) -> u32 {
self.history.len() as u32
}
pub fn initiate_handoff(&mut self, request: HandoffRequest) -> Result<usize, HandoffError> {
if self.current_depth() >= self.config.max_handoff_depth {
return Err(HandoffError::DepthExceeded {
max: self.config.max_handoff_depth,
current: self.current_depth(),
});
}
if !self.config.allow_handback && self.is_circular(&request.to_agent) {
return Err(HandoffError::CircularHandoff {
agent: request.to_agent.clone(),
chain: self.current_chain(),
});
}
if request.from_agent == request.to_agent {
return Err(HandoffError::SelfHandoff {
agent: request.from_agent.clone(),
});
}
let record = HandoffRecord {
request,
result: None,
timestamp: Utc::now(),
};
self.history.push(record);
Ok(self.history.len() - 1)
}
pub fn accept_handoff(&self, record_index: usize) -> Result<&HandoffRequest, HandoffError> {
let record = self
.history
.get(record_index)
.ok_or(HandoffError::RecordNotFound {
index: record_index,
})?;
if record.result.is_some() {
return Err(HandoffError::AlreadyCompleted {
index: record_index,
});
}
Ok(&record.request)
}
pub fn complete_handoff(
&mut self,
record_index: usize,
response: String,
tokens_used: usize,
duration: Duration,
) -> Result<&HandoffResult, HandoffError> {
let chain = self.current_chain();
let record = self
.history
.get_mut(record_index)
.ok_or(HandoffError::RecordNotFound {
index: record_index,
})?;
if record.result.is_some() {
return Err(HandoffError::AlreadyCompleted {
index: record_index,
});
}
let result = HandoffResult {
from_agent: record.request.from_agent.clone(),
to_agent: record.request.to_agent.clone(),
response,
tokens_used,
duration,
handoff_chain: chain,
status: HandoffStatus::Completed,
};
record.result = Some(result);
#[allow(clippy::expect_used)]
Ok(record.result.as_ref().expect("just inserted"))
}
pub fn handback(
&mut self,
record_index: usize,
reason: String,
tokens_used: usize,
duration: Duration,
) -> Result<&HandoffResult, HandoffError> {
if !self.config.allow_handback {
return Err(HandoffError::HandbackNotAllowed);
}
let chain = self.current_chain();
let record = self
.history
.get_mut(record_index)
.ok_or(HandoffError::RecordNotFound {
index: record_index,
})?;
if record.result.is_some() {
return Err(HandoffError::AlreadyCompleted {
index: record_index,
});
}
let result = HandoffResult {
from_agent: record.request.from_agent.clone(),
to_agent: record.request.to_agent.clone(),
response: String::new(),
tokens_used,
duration,
handoff_chain: chain,
status: HandoffStatus::HandedBack {
reason: reason.clone(),
},
};
record.result = Some(result);
#[allow(clippy::expect_used)]
Ok(record.result.as_ref().expect("just inserted"))
}
pub fn mark_timeout(
&mut self,
record_index: usize,
duration: Duration,
) -> Result<(), HandoffError> {
let chain = self.current_chain();
let record = self
.history
.get_mut(record_index)
.ok_or(HandoffError::RecordNotFound {
index: record_index,
})?;
if record.result.is_some() {
return Err(HandoffError::AlreadyCompleted {
index: record_index,
});
}
record.result = Some(HandoffResult {
from_agent: record.request.from_agent.clone(),
to_agent: record.request.to_agent.clone(),
response: String::new(),
tokens_used: 0,
duration,
handoff_chain: chain,
status: HandoffStatus::TimedOut,
});
Ok(())
}
pub fn mark_failed(
&mut self,
record_index: usize,
reason: String,
duration: Duration,
) -> Result<(), HandoffError> {
let chain = self.current_chain();
let record = self
.history
.get_mut(record_index)
.ok_or(HandoffError::RecordNotFound {
index: record_index,
})?;
if record.result.is_some() {
return Err(HandoffError::AlreadyCompleted {
index: record_index,
});
}
record.result = Some(HandoffResult {
from_agent: record.request.from_agent.clone(),
to_agent: record.request.to_agent.clone(),
response: String::new(),
tokens_used: 0,
duration,
handoff_chain: chain,
status: HandoffStatus::Failed {
reason: reason.clone(),
},
});
Ok(())
}
pub fn current_chain(&self) -> Vec<String> {
let mut chain = Vec::new();
for record in &self.history {
if chain
.last()
.map_or(true, |last: &String| *last != record.request.from_agent)
{
chain.push(record.request.from_agent.clone());
}
chain.push(record.request.to_agent.clone());
}
chain.dedup();
chain
}
pub fn is_circular(&self, agent: &str) -> bool {
self.history
.iter()
.any(|r| r.request.from_agent == agent || r.request.to_agent == agent)
}
pub fn last_result(&self) -> Option<&HandoffResult> {
self.history.iter().rev().find_map(|r| r.result.as_ref())
}
pub fn completed_count(&self) -> usize {
self.history.iter().filter(|r| r.result.is_some()).count()
}
pub fn pending_count(&self) -> usize {
self.history.iter().filter(|r| r.result.is_none()).count()
}
pub fn total_tokens(&self) -> usize {
self.history
.iter()
.filter_map(|r| r.result.as_ref())
.map(|r| r.tokens_used)
.sum()
}
pub fn reset(&mut self) {
self.history.clear();
}
}
/// Reasons a [`HandoffProtocol`] operation can be rejected.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum HandoffError {
    /// The history already holds `max` records.
    DepthExceeded {
        /// Configured `max_handoff_depth`.
        max: u32,
        /// History length at the time of the request.
        current: u32,
    },
    /// Handback is disabled and the target agent already appears in the history.
    CircularHandoff {
        /// The rejected target agent.
        agent: String,
        /// Agent chain at the time of the request.
        chain: Vec<String>,
    },
    /// The request's sender and receiver are the same agent.
    SelfHandoff {
        /// The offending agent.
        agent: String,
    },
    /// No history entry exists at the given index.
    RecordNotFound {
        /// The index that was looked up.
        index: usize,
    },
    /// The history entry already has an outcome.
    AlreadyCompleted {
        /// Index of the finished record.
        index: usize,
    },
    /// `handback` was called while `allow_handback` is `false`.
    HandbackNotAllowed,
}
impl std::fmt::Display for HandoffError {
    /// Renders a single-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DepthExceeded { max, current } => write!(
                f,
                "Handoff depth exceeded: current {current}, max {max}"
            ),
            Self::CircularHandoff { agent, chain } => write!(
                f,
                "Circular handoff to '{agent}' detected in chain: {chain:?}"
            ),
            Self::SelfHandoff { agent } => {
                write!(f, "Agent '{agent}' cannot hand off to itself")
            }
            Self::RecordNotFound { index } => {
                write!(f, "Handoff record not found at index {index}")
            }
            Self::AlreadyCompleted { index } => {
                write!(f, "Handoff at index {index} is already completed")
            }
            Self::HandbackNotAllowed => f.write_str("Handback is not allowed by configuration"),
        }
    }
}

/// `HandoffError` wraps no underlying cause, so the default `source()` (which returns
/// `None`) applies.
impl std::error::Error for HandoffError {}
/// Serde adapter, used via `#[serde(with = "duration_secs")]`, that stores a [`Duration`]
/// as an unsigned count of whole seconds.
///
/// Serialization uses `Duration::as_secs`, which truncates. A value such as 1.5 s
/// therefore round-trips as 1 s.
mod duration_secs {
    use serde::{Deserialize, Deserializer, Serialize, Serializer};
    use std::time::Duration;

    /// Writes `d` as a `u64` number of whole seconds, dropping any fractional part.
    pub fn serialize<S: Serializer>(d: &Duration, s: S) -> Result<S::Ok, S::Error> {
        let whole_secs: u64 = d.as_secs();
        Serialize::serialize(&whole_secs, s)
    }

    /// Reads a `u64` number of seconds back into a `Duration`.
    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Duration, D::Error> {
        u64::deserialize(d).map(Duration::from_secs)
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Builds a request from `from` to `to` with an empty context and no metadata.
    fn make_request(from: &str, to: &str) -> HandoffRequest {
        HandoffRequest {
            from_agent: from.to_string(),
            to_agent: to.to_string(),
            reason: format!("{from} needs {to}"),
            task: format!("Task for {to}"),
            context: HandoffContext::default(),
            metadata: HashMap::new(),
        }
    }

    /// Builds a request carrying `messages` synthetic user messages and 500 accumulated tokens.
    fn make_request_with_context(from: &str, to: &str, messages: usize) -> HandoffRequest {
        let msgs: Vec<ContextMessage> = (0..messages)
            .map(|i| ContextMessage {
                role: "user".to_string(),
                content: format!("message {i}"),
                timestamp: Utc::now(),
            })
            .collect();
        HandoffRequest {
            from_agent: from.to_string(),
            to_agent: to.to_string(),
            reason: "needs help".to_string(),
            task: "do something".to_string(),
            context: HandoffContext {
                messages: msgs,
                tool_results: vec![],
                accumulated_tokens: 500,
                parent_chain: vec![from.to_string()],
            },
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_default_config() {
        let cfg = HandoffConfig::default();
        assert_eq!(cfg.max_handoff_depth, 5);
        assert_eq!(cfg.context_transfer, ContextTransferMode::Summary);
        assert!(cfg.allow_handback);
        assert_eq!(cfg.timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_with_defaults() {
        let proto = HandoffProtocol::with_defaults();
        assert_eq!(proto.current_depth(), 0);
        assert!(proto.history().is_empty());
    }

    #[test]
    fn test_initiate_handoff() {
        let mut proto = HandoffProtocol::with_defaults();
        let idx = proto.initiate_handoff(make_request("A", "B")).unwrap();
        assert_eq!(idx, 0);
        assert_eq!(proto.current_depth(), 1);
        assert_eq!(proto.pending_count(), 1);
    }

    #[test]
    fn test_accept_handoff() {
        let mut proto = HandoffProtocol::with_defaults();
        let idx = proto.initiate_handoff(make_request("A", "B")).unwrap();
        let req = proto.accept_handoff(idx).unwrap();
        assert_eq!(req.from_agent, "A");
        assert_eq!(req.to_agent, "B");
    }

    #[test]
    fn test_complete_handoff() {
        let mut proto = HandoffProtocol::with_defaults();
        let idx = proto.initiate_handoff(make_request("A", "B")).unwrap();
        let result = proto
            .complete_handoff(idx, "done".to_string(), 100, Duration::from_secs(5))
            .unwrap();
        assert_eq!(result.status, HandoffStatus::Completed);
        assert_eq!(result.tokens_used, 100);
        assert_eq!(result.response, "done");
        assert_eq!(proto.completed_count(), 1);
        assert_eq!(proto.pending_count(), 0);
    }

    #[test]
    fn test_handback() {
        let mut proto = HandoffProtocol::with_defaults();
        let idx = proto.initiate_handoff(make_request("A", "B")).unwrap();
        let result = proto
            .handback(
                idx,
                "need more info".to_string(),
                50,
                Duration::from_secs(2),
            )
            .unwrap();
        assert_eq!(
            result.status,
            HandoffStatus::HandedBack {
                reason: "need more info".to_string()
            }
        );
    }

    #[test]
    fn test_handback_not_allowed() {
        let cfg = HandoffConfig {
            allow_handback: false,
            ..Default::default()
        };
        let mut proto = HandoffProtocol::new(cfg);
        let idx = proto.initiate_handoff(make_request("A", "B")).unwrap();
        let err = proto
            .handback(idx, "reason".to_string(), 0, Duration::from_secs(1))
            .unwrap_err();
        assert_eq!(err, HandoffError::HandbackNotAllowed);
    }

    #[test]
    fn test_handback_unknown_index() {
        let mut proto = HandoffProtocol::with_defaults();
        let err = proto
            .handback(7, "reason".to_string(), 0, Duration::from_secs(1))
            .unwrap_err();
        assert_eq!(err, HandoffError::RecordNotFound { index: 7 });
    }

    #[test]
    fn test_depth_exceeded() {
        let cfg = HandoffConfig {
            max_handoff_depth: 2,
            ..Default::default()
        };
        let mut proto = HandoffProtocol::new(cfg);
        proto.initiate_handoff(make_request("A", "B")).unwrap();
        proto.initiate_handoff(make_request("B", "C")).unwrap();
        let err = proto.initiate_handoff(make_request("C", "D")).unwrap_err();
        assert!(matches!(err, HandoffError::DepthExceeded { max: 2, .. }));
    }

    #[test]
    fn test_self_handoff_rejected() {
        let mut proto = HandoffProtocol::with_defaults();
        let err = proto.initiate_handoff(make_request("A", "A")).unwrap_err();
        assert!(matches!(err, HandoffError::SelfHandoff { .. }));
    }

    #[test]
    fn test_circular_handoff_detected() {
        let cfg = HandoffConfig {
            allow_handback: false,
            ..Default::default()
        };
        let mut proto = HandoffProtocol::new(cfg);
        proto.initiate_handoff(make_request("A", "B")).unwrap();
        let err = proto.initiate_handoff(make_request("B", "A")).unwrap_err();
        assert!(matches!(err, HandoffError::CircularHandoff { .. }));
    }

    #[test]
    fn test_circular_detects_prior_sender() {
        let cfg = HandoffConfig {
            allow_handback: false,
            ..Default::default()
        };
        let mut proto = HandoffProtocol::new(cfg);
        proto.initiate_handoff(make_request("A", "B")).unwrap();
        let err = proto.initiate_handoff(make_request("C", "A")).unwrap_err();
        assert_eq!(
            err,
            HandoffError::CircularHandoff {
                agent: "A".to_string(),
                chain: vec!["A".to_string(), "B".to_string()],
            }
        );
    }

    #[test]
    fn test_circular_allowed_with_handback() {
        let mut proto = HandoffProtocol::with_defaults();
        proto.initiate_handoff(make_request("A", "B")).unwrap();
        let result = proto.initiate_handoff(make_request("B", "A"));
        assert!(result.is_ok());
    }

    #[test]
    fn test_record_not_found() {
        let proto = HandoffProtocol::with_defaults();
        let err = proto.accept_handoff(99).unwrap_err();
        assert!(matches!(err, HandoffError::RecordNotFound { index: 99 }));
    }

    #[test]
    fn test_already_completed_on_complete() {
        let mut proto = HandoffProtocol::with_defaults();
        let idx = proto.initiate_handoff(make_request("A", "B")).unwrap();
        proto
            .complete_handoff(idx, "ok".to_string(), 10, Duration::from_secs(1))
            .unwrap();
        let err = proto
            .complete_handoff(idx, "again".to_string(), 10, Duration::from_secs(1))
            .unwrap_err();
        assert!(matches!(err, HandoffError::AlreadyCompleted { .. }));
    }

    #[test]
    fn test_already_completed_on_accept() {
        let mut proto = HandoffProtocol::with_defaults();
        let idx = proto.initiate_handoff(make_request("A", "B")).unwrap();
        proto
            .complete_handoff(idx, "ok".to_string(), 10, Duration::from_secs(1))
            .unwrap();
        let err = proto.accept_handoff(idx).unwrap_err();
        assert!(matches!(err, HandoffError::AlreadyCompleted { .. }));
    }

    #[test]
    fn test_mark_timeout() {
        let mut proto = HandoffProtocol::with_defaults();
        let idx = proto.initiate_handoff(make_request("A", "B")).unwrap();
        proto.mark_timeout(idx, Duration::from_secs(60)).unwrap();
        let result = proto.history()[idx].result.as_ref().unwrap();
        assert_eq!(result.status, HandoffStatus::TimedOut);
        assert_eq!(result.tokens_used, 0);
        assert!(result.response.is_empty());
        assert_eq!(proto.total_tokens(), 0);
        assert_eq!(proto.pending_count(), 0);
    }

    #[test]
    fn test_mark_failed() {
        let mut proto = HandoffProtocol::with_defaults();
        let idx = proto.initiate_handoff(make_request("A", "B")).unwrap();
        proto
            .mark_failed(idx, "crash".to_string(), Duration::from_secs(3))
            .unwrap();
        let result = proto.history()[idx].result.as_ref().unwrap();
        assert_eq!(
            result.status,
            HandoffStatus::Failed {
                reason: "crash".to_string()
            }
        );
    }

    #[test]
    fn test_chain_tracking() {
        let mut proto = HandoffProtocol::with_defaults();
        proto.initiate_handoff(make_request("A", "B")).unwrap();
        proto.initiate_handoff(make_request("B", "C")).unwrap();
        proto.initiate_handoff(make_request("C", "D")).unwrap();
        let chain = proto.current_chain();
        assert_eq!(chain, vec!["A", "B", "C", "D"]);
    }

    #[test]
    fn test_chain_with_fan_out() {
        let mut proto = HandoffProtocol::with_defaults();
        proto.initiate_handoff(make_request("A", "B")).unwrap();
        proto.initiate_handoff(make_request("A", "C")).unwrap();
        assert_eq!(proto.current_chain(), vec!["A", "B", "A", "C"]);
    }

    #[test]
    fn test_is_circular() {
        let mut proto = HandoffProtocol::with_defaults();
        proto.initiate_handoff(make_request("A", "B")).unwrap();
        assert!(proto.is_circular("A"));
        assert!(proto.is_circular("B"));
        assert!(!proto.is_circular("C"));
    }

    #[test]
    fn test_total_tokens() {
        let mut proto = HandoffProtocol::with_defaults();
        let idx0 = proto.initiate_handoff(make_request("A", "B")).unwrap();
        let idx1 = proto.initiate_handoff(make_request("B", "C")).unwrap();
        proto
            .complete_handoff(idx0, "r1".to_string(), 100, Duration::from_secs(1))
            .unwrap();
        proto
            .complete_handoff(idx1, "r2".to_string(), 200, Duration::from_secs(2))
            .unwrap();
        assert_eq!(proto.total_tokens(), 300);
    }

    #[test]
    fn test_total_tokens_ignores_pending() {
        let mut proto = HandoffProtocol::with_defaults();
        let idx0 = proto.initiate_handoff(make_request("A", "B")).unwrap();
        proto.initiate_handoff(make_request("B", "C")).unwrap();
        proto
            .complete_handoff(idx0, "r1".to_string(), 100, Duration::from_secs(1))
            .unwrap();
        assert_eq!(proto.total_tokens(), 100);
        assert_eq!(proto.pending_count(), 1);
    }

    #[test]
    fn test_last_result() {
        let mut proto = HandoffProtocol::with_defaults();
        assert!(proto.last_result().is_none());
        let idx = proto.initiate_handoff(make_request("A", "B")).unwrap();
        proto
            .complete_handoff(idx, "final".to_string(), 50, Duration::from_secs(1))
            .unwrap();
        let last = proto.last_result().unwrap();
        assert_eq!(last.response, "final");
    }

    #[test]
    fn test_reset() {
        let mut proto = HandoffProtocol::with_defaults();
        proto.initiate_handoff(make_request("A", "B")).unwrap();
        proto.reset();
        assert_eq!(proto.current_depth(), 0);
        assert!(proto.history().is_empty());
    }

    #[test]
    fn test_config_serialization() {
        let cfg = HandoffConfig::default();
        let json = serde_json::to_string(&cfg).unwrap();
        // Pin the wire format: snake_case enum names and whole-second timeouts.
        assert!(json.contains(r#""context_transfer":"summary""#));
        assert!(json.contains(r#""timeout":60"#));
        let restored: HandoffConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.max_handoff_depth, 5);
        assert!(restored.allow_handback);
        assert_eq!(restored.context_transfer, ContextTransferMode::Summary);
        assert_eq!(restored.timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_request_serialization() {
        let req = make_request_with_context("A", "B", 3);
        let json = serde_json::to_string(&req).unwrap();
        let restored: HandoffRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.from_agent, "A");
        assert_eq!(restored.context.messages.len(), 3);
        assert_eq!(restored.context.accumulated_tokens, 500);
    }

    #[test]
    fn test_result_serialization() {
        let result = HandoffResult {
            from_agent: "A".to_string(),
            to_agent: "B".to_string(),
            response: "done".to_string(),
            tokens_used: 42,
            duration: Duration::from_secs(10),
            handoff_chain: vec!["A".to_string(), "B".to_string()],
            status: HandoffStatus::Completed,
        };
        let json = serde_json::to_string(&result).unwrap();
        let restored: HandoffResult = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.tokens_used, 42);
        assert_eq!(restored.status, HandoffStatus::Completed);
    }

    #[test]
    fn test_duration_truncated_to_whole_seconds() {
        let result = HandoffResult {
            from_agent: "A".to_string(),
            to_agent: "B".to_string(),
            response: String::new(),
            tokens_used: 0,
            duration: Duration::from_millis(1_500),
            handoff_chain: vec![],
            status: HandoffStatus::TimedOut,
        };
        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains(r#""duration":1,"#));
        let restored: HandoffResult = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.duration, Duration::from_secs(1));
    }

    #[test]
    fn test_status_serialization_variants() {
        let statuses = vec![
            HandoffStatus::Completed,
            HandoffStatus::HandedBack {
                reason: "oops".to_string(),
            },
            HandoffStatus::TimedOut,
            HandoffStatus::DepthExceeded,
            HandoffStatus::Failed {
                reason: "boom".to_string(),
            },
        ];
        for status in statuses {
            let json = serde_json::to_string(&status).unwrap();
            let restored: HandoffStatus = serde_json::from_str(&json).unwrap();
            assert_eq!(restored, status);
        }
    }

    #[test]
    fn test_status_tagged_representation() {
        let handed_back = HandoffStatus::HandedBack {
            reason: "oops".to_string(),
        };
        assert_eq!(
            serde_json::to_string(&handed_back).unwrap(),
            r#"{"type":"handed_back","reason":"oops"}"#
        );
        assert_eq!(
            serde_json::to_string(&HandoffStatus::TimedOut).unwrap(),
            r#"{"type":"timed_out"}"#
        );
    }

    #[test]
    fn test_context_message_fields() {
        let msg = ContextMessage {
            role: "assistant".to_string(),
            content: "Hello there".to_string(),
            timestamp: Utc::now(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let restored: ContextMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.role, "assistant");
        assert_eq!(restored.content, "Hello there");
    }

    #[test]
    fn test_tool_result_summary() {
        let summary = ToolResultSummary {
            tool_name: "file_read".to_string(),
            success: true,
            summary: "Read 42 lines".to_string(),
        };
        let json = serde_json::to_string(&summary).unwrap();
        let restored: ToolResultSummary = serde_json::from_str(&json).unwrap();
        assert!(restored.success);
        assert_eq!(restored.tool_name, "file_read");
    }

    #[test]
    fn test_handoff_with_metadata() {
        let mut req = make_request("A", "B");
        req.metadata
            .insert("priority".to_string(), serde_json::json!("high"));
        let mut proto = HandoffProtocol::with_defaults();
        let idx = proto.initiate_handoff(req).unwrap();
        let stored = proto.accept_handoff(idx).unwrap();
        assert_eq!(stored.metadata["priority"], "high");
    }

    #[test]
    fn test_error_display() {
        let cases = vec![
            (
                HandoffError::DepthExceeded { max: 5, current: 5 },
                "Handoff depth exceeded: current 5, max 5",
            ),
            (
                HandoffError::CircularHandoff {
                    agent: "B".to_string(),
                    chain: vec!["A".to_string(), "B".to_string()],
                },
                r#"Circular handoff to 'B' detected in chain: ["A", "B"]"#,
            ),
            (
                HandoffError::SelfHandoff {
                    agent: "A".to_string(),
                },
                "Agent 'A' cannot hand off to itself",
            ),
            (
                HandoffError::RecordNotFound { index: 42 },
                "Handoff record not found at index 42",
            ),
            (
                HandoffError::AlreadyCompleted { index: 0 },
                "Handoff at index 0 is already completed",
            ),
            (
                HandoffError::HandbackNotAllowed,
                "Handback is not allowed by configuration",
            ),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }

    #[test]
    fn test_sequential_handoffs() {
        let mut proto = HandoffProtocol::with_defaults();
        let idx0 = proto.initiate_handoff(make_request("A", "B")).unwrap();
        proto
            .complete_handoff(idx0, "B done".to_string(), 100, Duration::from_secs(5))
            .unwrap();
        let idx1 = proto.initiate_handoff(make_request("B", "C")).unwrap();
        proto
            .complete_handoff(idx1, "C done".to_string(), 200, Duration::from_secs(3))
            .unwrap();
        let idx2 = proto.initiate_handoff(make_request("C", "D")).unwrap();
        proto
            .complete_handoff(idx2, "D done".to_string(), 150, Duration::from_secs(4))
            .unwrap();
        assert_eq!(proto.completed_count(), 3);
        assert_eq!(proto.total_tokens(), 450);
        assert_eq!(proto.last_result().unwrap().response, "D done");
    }

    #[test]
    fn test_protocol_serialization() {
        let mut proto = HandoffProtocol::with_defaults();
        let idx = proto.initiate_handoff(make_request("A", "B")).unwrap();
        proto
            .complete_handoff(idx, "ok".to_string(), 10, Duration::from_secs(1))
            .unwrap();
        let json = serde_json::to_string(&proto).unwrap();
        let restored: HandoffProtocol = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.history().len(), 1);
        assert_eq!(restored.completed_count(), 1);
    }

    #[test]
    fn test_timeout_on_completed_fails() {
        let mut proto = HandoffProtocol::with_defaults();
        let idx = proto.initiate_handoff(make_request("A", "B")).unwrap();
        proto
            .complete_handoff(idx, "ok".to_string(), 10, Duration::from_secs(1))
            .unwrap();
        let err = proto
            .mark_timeout(idx, Duration::from_secs(60))
            .unwrap_err();
        assert!(matches!(err, HandoffError::AlreadyCompleted { .. }));
    }

    #[test]
    fn test_failed_on_completed_fails() {
        let mut proto = HandoffProtocol::with_defaults();
        let idx = proto.initiate_handoff(make_request("A", "B")).unwrap();
        proto
            .complete_handoff(idx, "ok".to_string(), 10, Duration::from_secs(1))
            .unwrap();
        let err = proto
            .mark_failed(idx, "crash".to_string(), Duration::from_secs(1))
            .unwrap_err();
        assert!(matches!(err, HandoffError::AlreadyCompleted { .. }));
    }
}