Skip to main content

cortexai_crew/
handoff.rs

1//! Agent Handoffs
2//!
3//! LangGraph-style agent-to-agent handoffs with full context passing,
4//! supporting multi-hop chains, returns, and conditional routing.
5//!
6//! ## Features
7//!
8//! - Handoff with full conversation context
9//! - Return to caller after completion
10//! - Multi-agent handoff chains
11//! - Conditional handoff routing
12//! - Handoff history tracking
13//! - Parallel handoffs to multiple agents
14//!
15//! ## Example
16//!
17//! ```rust,ignore
18//! use cortexai_crew::handoff::{HandoffRouter, AgentNode, HandoffTrigger};
19//!
20//! // Create agent nodes
21//! let triage = AgentNode::new("triage", triage_agent)
22//!     .description("Routes requests to appropriate specialists");
23//!
24//! let sales = AgentNode::new("sales", sales_agent)
25//!     .description("Handles sales inquiries")
26//!     .can_return(); // Can return to caller
27//!
28//! let support = AgentNode::new("support", support_agent)
29//!     .description("Handles technical support");
30//!
31//! // Build handoff router
32//! let router = HandoffRouter::new()
33//!     .add_agent(triage)
34//!     .add_agent(sales)
35//!     .add_agent(support)
36//!     .add_handoff("triage", "sales", HandoffTrigger::keyword("buy"))
37//!     .add_handoff("triage", "support", HandoffTrigger::keyword("help"))
38//!     .set_entry("triage");
39//!
40//! // Execute conversation with automatic handoffs
41//! let result = router.run(conversation).await?;
42//! ```
43
44use chrono::{DateTime, Utc};
45use cortexai_core::errors::CrewError;
46use serde::{Deserialize, Serialize};
47use std::collections::HashMap;
48use std::sync::Arc;
49
50/// Conversation message in a handoff context
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct HandoffMessage {
53    /// Message role
54    pub role: MessageRole,
55    /// Message content
56    pub content: String,
57    /// Agent that produced this message (if assistant)
58    pub agent_id: Option<String>,
59    /// Timestamp
60    pub timestamp: DateTime<Utc>,
61    /// Additional metadata
62    pub metadata: HashMap<String, serde_json::Value>,
63}
64
65/// Message role
66#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
67#[serde(rename_all = "lowercase")]
68pub enum MessageRole {
69    System,
70    User,
71    Assistant,
72    Tool,
73}
74
75impl HandoffMessage {
76    /// Create a user message
77    pub fn user(content: impl Into<String>) -> Self {
78        Self {
79            role: MessageRole::User,
80            content: content.into(),
81            agent_id: None,
82            timestamp: Utc::now(),
83            metadata: HashMap::new(),
84        }
85    }
86
87    /// Create an assistant message
88    pub fn assistant(agent_id: impl Into<String>, content: impl Into<String>) -> Self {
89        Self {
90            role: MessageRole::Assistant,
91            content: content.into(),
92            agent_id: Some(agent_id.into()),
93            timestamp: Utc::now(),
94            metadata: HashMap::new(),
95        }
96    }
97
98    /// Create a system message
99    pub fn system(content: impl Into<String>) -> Self {
100        Self {
101            role: MessageRole::System,
102            content: content.into(),
103            agent_id: None,
104            timestamp: Utc::now(),
105            metadata: HashMap::new(),
106        }
107    }
108
109    /// Add metadata
110    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
111        self.metadata.insert(key.into(), value);
112        self
113    }
114}
115
116/// Conversation context passed during handoffs
117#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct HandoffContext {
119    /// Conversation ID
120    pub conversation_id: String,
121    /// Full message history
122    pub messages: Vec<HandoffMessage>,
123    /// Current agent handling the conversation
124    pub current_agent: String,
125    /// Stack of agents (for return handoffs)
126    pub agent_stack: Vec<String>,
127    /// Handoff history
128    pub handoff_history: Vec<HandoffRecord>,
129    /// Custom context data
130    pub data: HashMap<String, serde_json::Value>,
131    /// Created timestamp
132    pub created_at: DateTime<Utc>,
133    /// Last updated timestamp
134    pub updated_at: DateTime<Utc>,
135}
136
137impl HandoffContext {
138    /// Create a new context
139    pub fn new(conversation_id: impl Into<String>, entry_agent: impl Into<String>) -> Self {
140        let entry = entry_agent.into();
141        Self {
142            conversation_id: conversation_id.into(),
143            messages: Vec::new(),
144            current_agent: entry.clone(),
145            agent_stack: vec![entry],
146            handoff_history: Vec::new(),
147            data: HashMap::new(),
148            created_at: Utc::now(),
149            updated_at: Utc::now(),
150        }
151    }
152
153    /// Add a message to the conversation
154    pub fn add_message(&mut self, message: HandoffMessage) {
155        self.messages.push(message);
156        self.updated_at = Utc::now();
157    }
158
159    /// Add a user message
160    pub fn user_message(&mut self, content: impl Into<String>) {
161        self.add_message(HandoffMessage::user(content));
162    }
163
164    /// Add an assistant message from current agent
165    pub fn agent_message(&mut self, content: impl Into<String>) {
166        self.add_message(HandoffMessage::assistant(&self.current_agent, content));
167    }
168
169    /// Get last N messages
170    pub fn last_messages(&self, n: usize) -> &[HandoffMessage] {
171        let start = self.messages.len().saturating_sub(n);
172        &self.messages[start..]
173    }
174
175    /// Get messages from a specific agent
176    pub fn messages_from(&self, agent_id: &str) -> Vec<&HandoffMessage> {
177        self.messages
178            .iter()
179            .filter(|m| m.agent_id.as_deref() == Some(agent_id))
180            .collect()
181    }
182
183    /// Set custom data
184    pub fn set_data(&mut self, key: impl Into<String>, value: serde_json::Value) {
185        self.data.insert(key.into(), value);
186        self.updated_at = Utc::now();
187    }
188
189    /// Get custom data
190    pub fn get_data<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
191        self.data
192            .get(key)
193            .and_then(|v| serde_json::from_value(v.clone()).ok())
194    }
195
196    /// Get conversation summary
197    pub fn summary(&self) -> String {
198        format!(
199            "Conversation {} with {} messages, current agent: {}, {} handoffs",
200            self.conversation_id,
201            self.messages.len(),
202            self.current_agent,
203            self.handoff_history.len()
204        )
205    }
206}
207
208/// Record of a handoff event
209#[derive(Debug, Clone, Serialize, Deserialize)]
210pub struct HandoffRecord {
211    /// Source agent
212    pub from_agent: String,
213    /// Target agent
214    pub to_agent: String,
215    /// Reason for handoff
216    pub reason: String,
217    /// Whether this is a return handoff
218    pub is_return: bool,
219    /// Timestamp
220    pub timestamp: DateTime<Utc>,
221    /// Message index when handoff occurred
222    pub message_index: usize,
223}
224
225/// Trigger condition for a handoff
226#[derive(Clone)]
227pub enum HandoffTrigger {
228    /// Trigger on keyword in message
229    Keyword(String),
230    /// Trigger on multiple keywords (any match)
231    Keywords(Vec<String>),
232    /// Trigger on regex pattern
233    Pattern(String),
234    /// Trigger on custom condition
235    Custom(Arc<dyn Fn(&HandoffContext, &str) -> bool + Send + Sync>),
236    /// Always trigger (for explicit handoffs)
237    Always,
238    /// Trigger based on agent decision
239    AgentDecision,
240}
241
242impl std::fmt::Debug for HandoffTrigger {
243    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244        match self {
245            Self::Keyword(k) => write!(f, "Keyword({})", k),
246            Self::Keywords(ks) => write!(f, "Keywords({:?})", ks),
247            Self::Pattern(p) => write!(f, "Pattern({})", p),
248            Self::Custom(_) => write!(f, "Custom(<fn>)"),
249            Self::Always => write!(f, "Always"),
250            Self::AgentDecision => write!(f, "AgentDecision"),
251        }
252    }
253}
254
255impl HandoffTrigger {
256    /// Create a keyword trigger
257    pub fn keyword(keyword: impl Into<String>) -> Self {
258        Self::Keyword(keyword.into().to_lowercase())
259    }
260
261    /// Create a multi-keyword trigger
262    pub fn keywords(keywords: Vec<String>) -> Self {
263        Self::Keywords(keywords.into_iter().map(|k| k.to_lowercase()).collect())
264    }
265
266    /// Create a pattern trigger
267    pub fn pattern(pattern: impl Into<String>) -> Self {
268        Self::Pattern(pattern.into())
269    }
270
271    /// Create a custom trigger
272    pub fn custom<F>(f: F) -> Self
273    where
274        F: Fn(&HandoffContext, &str) -> bool + Send + Sync + 'static,
275    {
276        Self::Custom(Arc::new(f))
277    }
278
279    /// Check if trigger matches
280    pub fn matches(&self, context: &HandoffContext, message: &str) -> bool {
281        let lower = message.to_lowercase();
282        match self {
283            Self::Keyword(k) => lower.contains(k),
284            Self::Keywords(ks) => ks.iter().any(|k| lower.contains(k)),
285            Self::Pattern(p) => regex::Regex::new(p)
286                .map(|re| re.is_match(&lower))
287                .unwrap_or(false),
288            Self::Custom(f) => f(context, message),
289            Self::Always => true,
290            Self::AgentDecision => false, // Handled separately
291        }
292    }
293}
294
295/// Handoff instruction from an agent
296#[derive(Debug, Clone, Serialize, Deserialize)]
297pub struct HandoffInstruction {
298    /// Target agent ID
299    pub target_agent: String,
300    /// Reason for handoff
301    pub reason: String,
302    /// Whether to return after target completes
303    pub should_return: bool,
304    /// Additional data to pass
305    pub data: HashMap<String, serde_json::Value>,
306}
307
308impl HandoffInstruction {
309    /// Create a new handoff instruction
310    pub fn to(agent: impl Into<String>) -> Self {
311        Self {
312            target_agent: agent.into(),
313            reason: String::new(),
314            should_return: false,
315            data: HashMap::new(),
316        }
317    }
318
319    /// Set reason
320    pub fn because(mut self, reason: impl Into<String>) -> Self {
321        self.reason = reason.into();
322        self
323    }
324
325    /// Request return after completion
326    pub fn and_return(mut self) -> Self {
327        self.should_return = true;
328        self
329    }
330
331    /// Add data to pass
332    pub fn with_data(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
333        self.data.insert(key.into(), value);
334        self
335    }
336}
337
338/// Agent response with optional handoff
339#[derive(Debug, Clone)]
340pub enum AgentResponse {
341    /// Continue conversation with this message
342    Message(String),
343    /// Hand off to another agent
344    Handoff(HandoffInstruction),
345    /// Return to previous agent
346    Return(String),
347    /// End the conversation
348    End(String),
349}
350
351/// Agent executor function type
352pub type AgentExecutor = Arc<
353    dyn Fn(HandoffContext) -> futures::future::BoxFuture<'static, Result<AgentResponse, CrewError>>
354        + Send
355        + Sync,
356>;
357
358/// An agent node in the handoff graph
359pub struct AgentNode {
360    /// Agent ID
361    pub id: String,
362    /// Agent description
363    pub description: String,
364    /// Agent executor
365    executor: AgentExecutor,
366    /// System prompt for this agent
367    pub system_prompt: Option<String>,
368    /// Whether this agent can return to caller
369    pub can_return: bool,
370    /// Whether this agent can hand off to others
371    pub can_handoff: bool,
372    /// Allowed handoff targets (empty = all)
373    pub allowed_targets: Vec<String>,
374}
375
376impl std::fmt::Debug for AgentNode {
377    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
378        f.debug_struct("AgentNode")
379            .field("id", &self.id)
380            .field("description", &self.description)
381            .field("can_return", &self.can_return)
382            .field("can_handoff", &self.can_handoff)
383            .finish()
384    }
385}
386
387impl AgentNode {
388    /// Create a new agent node
389    pub fn new<F, Fut>(id: impl Into<String>, executor: F) -> Self
390    where
391        F: Fn(HandoffContext) -> Fut + Send + Sync + 'static,
392        Fut: std::future::Future<Output = Result<AgentResponse, CrewError>> + Send + 'static,
393    {
394        Self {
395            id: id.into(),
396            description: String::new(),
397            executor: Arc::new(move |ctx| Box::pin(executor(ctx))),
398            system_prompt: None,
399            can_return: false,
400            can_handoff: true,
401            allowed_targets: Vec::new(),
402        }
403    }
404
405    /// Set description
406    pub fn description(mut self, desc: impl Into<String>) -> Self {
407        self.description = desc.into();
408        self
409    }
410
411    /// Set system prompt
412    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
413        self.system_prompt = Some(prompt.into());
414        self
415    }
416
417    /// Allow this agent to return to caller
418    pub fn can_return(mut self) -> Self {
419        self.can_return = true;
420        self
421    }
422
423    /// Prevent this agent from handing off
424    pub fn no_handoff(mut self) -> Self {
425        self.can_handoff = false;
426        self
427    }
428
429    /// Restrict handoff targets
430    pub fn allowed_targets(mut self, targets: Vec<String>) -> Self {
431        self.allowed_targets = targets;
432        self
433    }
434
435    /// Execute the agent
436    pub async fn execute(&self, context: HandoffContext) -> Result<AgentResponse, CrewError> {
437        (self.executor)(context).await
438    }
439}
440
441/// Handoff rule between agents
442#[derive(Debug)]
443pub struct HandoffRule {
444    /// Source agent
445    pub from: String,
446    /// Target agent
447    pub to: String,
448    /// Trigger condition
449    pub trigger: HandoffTrigger,
450    /// Priority (higher = checked first)
451    pub priority: i32,
452}
453
454/// Handoff router - manages agent network and routes conversations
455pub struct HandoffRouter {
456    /// Registered agents
457    agents: HashMap<String, AgentNode>,
458    /// Handoff rules
459    rules: Vec<HandoffRule>,
460    /// Entry agent
461    entry_agent: Option<String>,
462    /// Maximum handoffs per conversation
463    max_handoffs: u32,
464    /// Maximum conversation turns
465    max_turns: u32,
466}
467
468impl Default for HandoffRouter {
469    fn default() -> Self {
470        Self::new()
471    }
472}
473
474impl HandoffRouter {
475    /// Create a new handoff router
476    pub fn new() -> Self {
477        Self {
478            agents: HashMap::new(),
479            rules: Vec::new(),
480            entry_agent: None,
481            max_handoffs: 10,
482            max_turns: 50,
483        }
484    }
485
486    /// Add an agent
487    pub fn add_agent(mut self, agent: AgentNode) -> Self {
488        self.agents.insert(agent.id.clone(), agent);
489        self
490    }
491
492    /// Add a handoff rule
493    pub fn add_handoff(
494        mut self,
495        from: impl Into<String>,
496        to: impl Into<String>,
497        trigger: HandoffTrigger,
498    ) -> Self {
499        self.rules.push(HandoffRule {
500            from: from.into(),
501            to: to.into(),
502            trigger,
503            priority: 0,
504        });
505        self
506    }
507
508    /// Add a prioritized handoff rule
509    pub fn add_handoff_priority(
510        mut self,
511        from: impl Into<String>,
512        to: impl Into<String>,
513        trigger: HandoffTrigger,
514        priority: i32,
515    ) -> Self {
516        self.rules.push(HandoffRule {
517            from: from.into(),
518            to: to.into(),
519            trigger,
520            priority,
521        });
522        self
523    }
524
525    /// Set entry agent
526    pub fn set_entry(mut self, agent_id: impl Into<String>) -> Self {
527        self.entry_agent = Some(agent_id.into());
528        self
529    }
530
531    /// Set maximum handoffs
532    pub fn max_handoffs(mut self, max: u32) -> Self {
533        self.max_handoffs = max;
534        self
535    }
536
537    /// Set maximum turns
538    pub fn max_turns(mut self, max: u32) -> Self {
539        self.max_turns = max;
540        self
541    }
542
543    /// Get agent by ID
544    pub fn get_agent(&self, id: &str) -> Option<&AgentNode> {
545        self.agents.get(id)
546    }
547
548    /// List all agents
549    pub fn list_agents(&self) -> Vec<&AgentNode> {
550        self.agents.values().collect()
551    }
552
553    /// Find matching handoff rules for current context
554    fn find_matching_rules(&self, context: &HandoffContext, message: &str) -> Vec<&HandoffRule> {
555        let mut matches: Vec<_> = self
556            .rules
557            .iter()
558            .filter(|r| r.from == context.current_agent && r.trigger.matches(context, message))
559            .collect();
560
561        // Sort by priority (descending)
562        matches.sort_by(|a, b| b.priority.cmp(&a.priority));
563        matches
564    }
565
566    /// Execute handoff to a new agent
567    fn execute_handoff(
568        &self,
569        context: &mut HandoffContext,
570        to: &str,
571        reason: &str,
572        is_return: bool,
573    ) {
574        let record = HandoffRecord {
575            from_agent: context.current_agent.clone(),
576            to_agent: to.to_string(),
577            reason: reason.to_string(),
578            is_return,
579            timestamp: Utc::now(),
580            message_index: context.messages.len(),
581        };
582
583        context.handoff_history.push(record);
584
585        if !is_return {
586            context.agent_stack.push(to.to_string());
587        } else {
588            context.agent_stack.pop();
589        }
590
591        context.current_agent = to.to_string();
592        context.updated_at = Utc::now();
593    }
594
595    /// Run a conversation with automatic handoffs
596    pub async fn run(&self, mut context: HandoffContext) -> Result<HandoffResult, CrewError> {
597        let entry = self.entry_agent.as_ref().ok_or_else(|| {
598            CrewError::InvalidConfiguration("No entry agent specified".to_string())
599        })?;
600
601        context.current_agent = entry.clone();
602        if context.agent_stack.is_empty() {
603            context.agent_stack.push(entry.clone());
604        }
605
606        let mut turns = 0;
607        let mut handoffs = 0;
608
609        loop {
610            // Check limits
611            if turns >= self.max_turns {
612                return Ok(HandoffResult {
613                    context,
614                    status: HandoffStatus::MaxTurnsReached,
615                    final_message: None,
616                });
617            }
618
619            if handoffs >= self.max_handoffs {
620                return Ok(HandoffResult {
621                    context,
622                    status: HandoffStatus::MaxHandoffsReached,
623                    final_message: None,
624                });
625            }
626
627            // Get current agent
628            let agent = self.agents.get(&context.current_agent).ok_or_else(|| {
629                CrewError::TaskNotFound(format!("Agent '{}' not found", context.current_agent))
630            })?;
631
632            // Execute agent
633            let response = agent.execute(context.clone()).await?;
634            turns += 1;
635
636            match response {
637                AgentResponse::Message(msg) => {
638                    context.agent_message(&msg);
639
640                    // Check for automatic handoff triggers
641                    let rules = self.find_matching_rules(&context, &msg);
642                    if let Some(rule) = rules.first() {
643                        self.execute_handoff(
644                            &mut context,
645                            &rule.to,
646                            &format!("Triggered by rule: {:?}", rule.trigger),
647                            false,
648                        );
649                        handoffs += 1;
650                    }
651                    // Continue with current agent if no trigger
652                }
653
654                AgentResponse::Handoff(instruction) => {
655                    // Validate handoff
656                    if !agent.can_handoff {
657                        return Err(CrewError::ExecutionFailed(format!(
658                            "Agent '{}' is not allowed to hand off",
659                            context.current_agent
660                        )));
661                    }
662
663                    if !agent.allowed_targets.is_empty()
664                        && !agent.allowed_targets.contains(&instruction.target_agent)
665                    {
666                        return Err(CrewError::ExecutionFailed(format!(
667                            "Agent '{}' cannot hand off to '{}'",
668                            context.current_agent, instruction.target_agent
669                        )));
670                    }
671
672                    if !self.agents.contains_key(&instruction.target_agent) {
673                        return Err(CrewError::TaskNotFound(format!(
674                            "Target agent '{}' not found",
675                            instruction.target_agent
676                        )));
677                    }
678
679                    // Merge any data from instruction
680                    for (key, value) in instruction.data {
681                        context.set_data(key, value);
682                    }
683
684                    self.execute_handoff(
685                        &mut context,
686                        &instruction.target_agent,
687                        &instruction.reason,
688                        false,
689                    );
690                    handoffs += 1;
691                }
692
693                AgentResponse::Return(msg) => {
694                    if !agent.can_return {
695                        return Err(CrewError::ExecutionFailed(format!(
696                            "Agent '{}' is not allowed to return",
697                            context.current_agent
698                        )));
699                    }
700
701                    context.agent_message(&msg);
702
703                    // Return to previous agent in stack
704                    if context.agent_stack.len() > 1 {
705                        let prev = context.agent_stack[context.agent_stack.len() - 2].clone();
706                        self.execute_handoff(&mut context, &prev, "Return to caller", true);
707                        handoffs += 1;
708                    } else {
709                        // No one to return to - end conversation
710                        return Ok(HandoffResult {
711                            context,
712                            status: HandoffStatus::Completed,
713                            final_message: Some(msg),
714                        });
715                    }
716                }
717
718                AgentResponse::End(msg) => {
719                    context.agent_message(&msg);
720                    return Ok(HandoffResult {
721                        context,
722                        status: HandoffStatus::Completed,
723                        final_message: Some(msg),
724                    });
725                }
726            }
727        }
728    }
729
730    /// Run with a new conversation starting from user message
731    pub async fn start(&self, user_message: impl Into<String>) -> Result<HandoffResult, CrewError> {
732        let mut context = HandoffContext::new(
733            uuid::Uuid::new_v4().to_string(),
734            self.entry_agent.as_deref().unwrap_or("default"),
735        );
736        context.user_message(user_message);
737        self.run(context).await
738    }
739}
740
741/// Result of a handoff conversation
742#[derive(Debug, Clone)]
743pub struct HandoffResult {
744    /// Final conversation context
745    pub context: HandoffContext,
746    /// Completion status
747    pub status: HandoffStatus,
748    /// Final message (if completed normally)
749    pub final_message: Option<String>,
750}
751
752/// Status of handoff completion
753#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
754pub enum HandoffStatus {
755    /// Conversation completed normally
756    Completed,
757    /// Hit maximum turns limit
758    MaxTurnsReached,
759    /// Hit maximum handoffs limit
760    MaxHandoffsReached,
761    /// Conversation was interrupted
762    Interrupted,
763}
764
765impl HandoffResult {
766    /// Get summary of the conversation
767    pub fn summary(&self) -> String {
768        format!(
769            "Status: {:?}, Agents: {:?}, Messages: {}, Handoffs: {}",
770            self.status,
771            self.context.agent_stack,
772            self.context.messages.len(),
773            self.context.handoff_history.len()
774        )
775    }
776
777    /// Get all handoffs that occurred
778    pub fn handoffs(&self) -> &[HandoffRecord] {
779        &self.context.handoff_history
780    }
781
782    /// Get messages from a specific agent
783    pub fn messages_from(&self, agent_id: &str) -> Vec<&HandoffMessage> {
784        self.context.messages_from(agent_id)
785    }
786}
787
/// Parallel handoff - hand off to multiple agents simultaneously.
///
/// A builder that records target agent IDs. `HandoffRouter` does not consume
/// it; callers read the configured targets back through [`ParallelHandoff::targets`].
#[derive(Debug, Clone)]
pub struct ParallelHandoff {
    /// Target agents, in the order they were added (duplicates kept).
    targets: Vec<String>,
    /// State mapping for each target.
    // Reserved for per-target state mapping: nothing populates or reads it yet,
    // hence the dead_code allowance.
    #[allow(dead_code)]
    mappings: HashMap<String, HashMap<String, String>>,
}

impl ParallelHandoff {
    /// Create an empty parallel handoff.
    pub fn new() -> Self {
        Self {
            targets: Vec::new(),
            mappings: HashMap::new(),
        }
    }

    /// Append one target agent.
    pub fn to(mut self, agent: impl Into<String>) -> Self {
        self.targets.push(agent.into());
        self
    }

    /// Append several target agents, preserving their order.
    pub fn to_all(mut self, agents: Vec<String>) -> Self {
        self.targets.extend(agents);
        self
    }

    /// Configured target agent IDs, in insertion order.
    pub fn targets(&self) -> &[String] {
        &self.targets
    }
}

impl Default for ParallelHandoff {
    fn default() -> Self {
        Self::new()
    }
}
824
/// Handoff chain - sequence of agents to pass through.
///
/// A builder that records an ordered list of agent IDs and whether control
/// should return to the origin afterwards. `HandoffRouter` does not consume
/// it; callers read the configuration back through [`HandoffChain::agents`]
/// and [`HandoffChain::returns_to_origin`].
#[derive(Debug, Clone)]
pub struct HandoffChain {
    /// Ordered list of agents.
    agents: Vec<String>,
    /// Whether to return to origin after the chain completes.
    return_to_origin: bool,
}

impl HandoffChain {
    /// Create an empty chain that does not return to its origin.
    pub fn new() -> Self {
        Self {
            agents: Vec::new(),
            return_to_origin: false,
        }
    }

    /// Append the next agent in the chain.
    pub fn then(mut self, agent: impl Into<String>) -> Self {
        self.agents.push(agent.into());
        self
    }

    /// Request a return to the origin after the last agent.
    pub fn and_return(mut self) -> Self {
        self.return_to_origin = true;
        self
    }

    /// Agents in the chain, in order.
    pub fn agents(&self) -> &[String] {
        &self.agents
    }

    /// Whether the chain should return to its origin when done.
    pub fn returns_to_origin(&self) -> bool {
        self.return_to_origin
    }
}

impl Default for HandoffChain {
    fn default() -> Self {
        Self::new()
    }
}
860
861#[cfg(test)]
862mod tests {
863    use super::*;
864
865    #[allow(dead_code)]
866    fn create_test_agent(id: &str, response: &'static str) -> AgentNode {
867        AgentNode::new(id, move |_ctx| async move {
868            Ok(AgentResponse::Message(response.to_string()))
869        })
870        .description(format!("Test agent {}", id))
871    }
872
873    fn create_handoff_agent(id: &str, target: &'static str) -> AgentNode {
874        AgentNode::new(id, move |_ctx| async move {
875            Ok(AgentResponse::Handoff(
876                HandoffInstruction::to(target).because("Need specialist"),
877            ))
878        })
879    }
880
881    fn create_ending_agent(id: &str, response: &'static str) -> AgentNode {
882        AgentNode::new(id, move |_ctx| async move {
883            Ok(AgentResponse::End(response.to_string()))
884        })
885    }
886
887    #[allow(dead_code)]
888    fn create_returning_agent(id: &str, response: &'static str) -> AgentNode {
889        AgentNode::new(id, move |_ctx| async move {
890            Ok(AgentResponse::Return(response.to_string()))
891        })
892        .can_return()
893    }
894
895    #[tokio::test]
896    async fn test_simple_conversation() {
897        let router = HandoffRouter::new()
898            .add_agent(create_ending_agent("greeter", "Hello! How can I help?"))
899            .set_entry("greeter");
900
901        let result = router.start("Hi there").await.unwrap();
902
903        assert_eq!(result.status, HandoffStatus::Completed);
904        assert_eq!(
905            result.final_message.as_deref(),
906            Some("Hello! How can I help?")
907        );
908        assert_eq!(result.context.messages.len(), 2); // user + agent
909    }
910
911    #[tokio::test]
912    async fn test_agent_handoff() {
913        let router = HandoffRouter::new()
914            .add_agent(create_handoff_agent("triage", "sales"))
915            .add_agent(create_ending_agent(
916                "sales",
917                "I can help with your purchase!",
918            ))
919            .set_entry("triage");
920
921        let result = router.start("I want to buy something").await.unwrap();
922
923        assert_eq!(result.status, HandoffStatus::Completed);
924        assert_eq!(result.context.handoff_history.len(), 1);
925        assert_eq!(result.context.handoff_history[0].from_agent, "triage");
926        assert_eq!(result.context.handoff_history[0].to_agent, "sales");
927    }
928
929    #[tokio::test]
930    async fn test_trigger_based_handoff() {
931        // Agent that echoes the user's intent, triggering the keyword
932        let triage = AgentNode::new("triage", |_ctx| async move {
933            Ok(AgentResponse::Message(
934                "I understand you want to buy something. Let me help with that.".to_string(),
935            ))
936        });
937
938        let router = HandoffRouter::new()
939            .add_agent(triage)
940            .add_agent(create_ending_agent("sales", "Sales here!"))
941            .add_handoff("triage", "sales", HandoffTrigger::keyword("buy"))
942            .set_entry("triage");
943
944        // This should trigger handoff to sales based on agent's response
945        let mut context = HandoffContext::new("test", "triage");
946        context.user_message("I want to purchase something");
947
948        let result = router.run(context).await.unwrap();
949
950        assert_eq!(result.context.handoff_history.len(), 1);
951        assert_eq!(result.context.handoff_history[0].to_agent, "sales");
952    }
953
954    #[tokio::test]
955    async fn test_return_handoff() {
956        let triage = AgentNode::new("triage", |ctx| async move {
957            // First time: hand off to specialist
958            // After return: end conversation
959            if ctx.handoff_history.is_empty() {
960                Ok(AgentResponse::Handoff(
961                    HandoffInstruction::to("specialist").because("Need expert"),
962                ))
963            } else {
964                Ok(AgentResponse::End("Thanks for your patience!".to_string()))
965            }
966        });
967
968        let specialist = AgentNode::new("specialist", |_ctx| async move {
969            Ok(AgentResponse::Return("Here's my analysis".to_string()))
970        })
971        .can_return();
972
973        let router = HandoffRouter::new()
974            .add_agent(triage)
975            .add_agent(specialist)
976            .set_entry("triage");
977
978        let result = router.start("Need help").await.unwrap();
979
980        assert_eq!(result.status, HandoffStatus::Completed);
981        assert_eq!(result.context.handoff_history.len(), 2); // forward + return
982        assert!(result.context.handoff_history[1].is_return);
983    }
984
985    #[tokio::test]
986    async fn test_max_handoffs_limit() {
987        // Create circular handoff
988        let agent_a = create_handoff_agent("a", "b");
989        let agent_b = create_handoff_agent("b", "a");
990
991        let router = HandoffRouter::new()
992            .add_agent(agent_a)
993            .add_agent(agent_b)
994            .set_entry("a")
995            .max_handoffs(5);
996
997        let result = router.start("Start").await.unwrap();
998
999        assert_eq!(result.status, HandoffStatus::MaxHandoffsReached);
1000        assert!(result.context.handoff_history.len() <= 5);
1001    }
1002
1003    #[tokio::test]
1004    async fn test_handoff_context_data() {
1005        let mut context = HandoffContext::new("test-123", "agent1");
1006
1007        context.set_data("user_id", serde_json::json!("user-456"));
1008        context.set_data("priority", serde_json::json!(5));
1009
1010        assert_eq!(
1011            context.get_data::<String>("user_id"),
1012            Some("user-456".to_string())
1013        );
1014        assert_eq!(context.get_data::<i32>("priority"), Some(5));
1015    }
1016
1017    #[tokio::test]
1018    async fn test_handoff_trigger_keywords() {
1019        let context = HandoffContext::new("test", "agent");
1020
1021        let trigger = HandoffTrigger::keywords(vec!["buy".to_string(), "purchase".to_string()]);
1022
1023        assert!(trigger.matches(&context, "I want to BUY something"));
1024        assert!(trigger.matches(&context, "Can I purchase this?"));
1025        assert!(!trigger.matches(&context, "Just browsing"));
1026    }
1027
1028    #[tokio::test]
1029    async fn test_handoff_trigger_pattern() {
1030        let context = HandoffContext::new("test", "agent");
1031
1032        let trigger = HandoffTrigger::pattern(r"order\s*#?\d+");
1033
1034        assert!(trigger.matches(&context, "Check order #12345"));
1035        assert!(trigger.matches(&context, "order 67890 status"));
1036        assert!(!trigger.matches(&context, "I want to order"));
1037    }
1038
1039    #[tokio::test]
1040    async fn test_handoff_instruction_builder() {
1041        let instruction = HandoffInstruction::to("support")
1042            .because("Technical issue")
1043            .and_return()
1044            .with_data("ticket_id", serde_json::json!("TKT-123"));
1045
1046        assert_eq!(instruction.target_agent, "support");
1047        assert_eq!(instruction.reason, "Technical issue");
1048        assert!(instruction.should_return);
1049        assert_eq!(
1050            instruction.data.get("ticket_id"),
1051            Some(&serde_json::json!("TKT-123"))
1052        );
1053    }
1054
1055    #[tokio::test]
1056    async fn test_handoff_history_tracking() {
1057        let router = HandoffRouter::new()
1058            .add_agent(create_handoff_agent("a", "b"))
1059            .add_agent(create_handoff_agent("b", "c"))
1060            .add_agent(create_ending_agent("c", "Done!"))
1061            .set_entry("a");
1062
1063        let result = router.start("Go").await.unwrap();
1064
1065        assert_eq!(result.context.handoff_history.len(), 2);
1066        assert_eq!(result.context.handoff_history[0].from_agent, "a");
1067        assert_eq!(result.context.handoff_history[0].to_agent, "b");
1068        assert_eq!(result.context.handoff_history[1].from_agent, "b");
1069        assert_eq!(result.context.handoff_history[1].to_agent, "c");
1070    }
1071
1072    #[tokio::test]
1073    async fn test_message_history() {
1074        let router = HandoffRouter::new()
1075            .add_agent(create_ending_agent("agent", "Response"))
1076            .set_entry("agent");
1077
1078        let result = router.start("User message").await.unwrap();
1079
1080        assert_eq!(result.context.messages.len(), 2);
1081        assert_eq!(result.context.messages[0].role, MessageRole::User);
1082        assert_eq!(result.context.messages[0].content, "User message");
1083        assert_eq!(result.context.messages[1].role, MessageRole::Assistant);
1084        assert_eq!(
1085            result.context.messages[1].agent_id,
1086            Some("agent".to_string())
1087        );
1088    }
1089}