strands_agents/multiagent/
swarm.rs

//! Swarm-based multi-agent orchestration.
//!
//! Provides a collaborative agent orchestration system where agents work
//! together as a team to solve complex tasks, with shared context and
//! autonomous coordination.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use async_trait::async_trait;
12use futures::StreamExt;
13use serde::{Deserialize, Serialize};
14use serde_json::json;
15use tokio::sync::RwLock;
16
17use super::base::{
18    Interrupt, InterruptState, InvocationState, MultiAgentBase, MultiAgentEvent,
19    MultiAgentEventStream, MultiAgentInput, MultiAgentResult, NodeResult, NodeResultValue, Status,
20};
21use crate::agent::Agent;
22use crate::hooks::{
23    AfterInvocationEvent, AfterToolCallEvent, BeforeInvocationEvent, BeforeToolCallEvent,
24    HookEvent, HookRegistry,
25};
26use crate::types::tools::{ToolResult as ToolResultType, ToolUse};
27use crate::tools::{AgentTool, ToolContext, ToolResult2};
28use crate::types::tools::ToolSpec;
29use crate::types::errors::{Result, StrandsError};
30use crate::types::streaming::{Metrics, Usage};
31use crate::types::tools::ToolResultStatus;
32
33/// Shared context between swarm nodes.
34#[derive(Debug, Clone, Default, Serialize, Deserialize)]
35pub struct SharedContext {
36    context: HashMap<String, HashMap<String, serde_json::Value>>,
37}
38
39impl SharedContext {
40    pub fn new() -> Self {
41        Self::default()
42    }
43
44    /// Adds context for a node.
45    pub fn add_context(
46        &mut self,
47        node_id: &str,
48        key: impl Into<String>,
49        value: impl Serialize,
50    ) -> Result<()> {
51        let key = key.into();
52        if key.is_empty() {
53            return Err(StrandsError::ConfigurationError {
54                message: "Key cannot be empty".to_string(),
55            });
56        }
57
58        let value = serde_json::to_value(value).map_err(|e| StrandsError::ConfigurationError {
59            message: format!("Value is not JSON serializable: {e}"),
60        })?;
61
62        self.context
63            .entry(node_id.to_string())
64            .or_default()
65            .insert(key, value);
66
67        Ok(())
68    }
69
70    /// Gets context for a node.
71    pub fn get_context(&self, node_id: &str) -> Option<&HashMap<String, serde_json::Value>> {
72        self.context.get(node_id)
73    }
74
75    /// Gets all context.
76    pub fn all(&self) -> &HashMap<String, HashMap<String, serde_json::Value>> {
77        &self.context
78    }
79}
80
81/// Represents a node (agent) in the swarm.
82pub struct SwarmNode {
83    pub node_id: String,
84    pub agent: Agent,
85    initial_messages: Vec<crate::types::content::Message>,
86}
87
88impl SwarmNode {
89    pub fn new(node_id: impl Into<String>, agent: Agent) -> Self {
90        let initial_messages = agent.messages().to_vec();
91        Self {
92            node_id: node_id.into(),
93            agent,
94            initial_messages,
95        }
96    }
97
98    /// Resets the node state to initial state.
99    pub fn reset(&mut self) {
100        self.agent.clear_messages();
101        for msg in &self.initial_messages {
102            self.agent.add_message(msg.clone());
103        }
104    }
105}
106
107impl std::hash::Hash for SwarmNode {
108    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
109        self.node_id.hash(state);
110    }
111}
112
113impl PartialEq for SwarmNode {
114    fn eq(&self, other: &Self) -> bool {
115        self.node_id == other.node_id
116    }
117}
118
119impl Eq for SwarmNode {}
120
121/// Result from a swarm node execution.
122#[derive(Debug, Clone)]
123pub struct SwarmNodeResult {
124    pub node_id: String,
125    pub result: NodeResult,
126}
127
128/// Current state of swarm execution.
129pub struct SwarmState {
130    pub current_node_id: Option<String>,
131    pub task: String,
132    pub status: Status,
133    pub shared_context: SharedContext,
134    pub node_history: Vec<String>,
135    pub start_time: Instant,
136    pub results: HashMap<String, NodeResult>,
137    pub accumulated_usage: Usage,
138    pub accumulated_metrics: Metrics,
139    pub execution_time_ms: u64,
140    pub handoff_node_id: Option<String>,
141    pub handoff_message: Option<String>,
142}
143
144impl Default for SwarmState {
145    fn default() -> Self {
146        Self {
147            current_node_id: None,
148            task: String::new(),
149            status: Status::Pending,
150            shared_context: SharedContext::new(),
151            node_history: Vec::new(),
152            start_time: Instant::now(),
153            results: HashMap::new(),
154            accumulated_usage: Usage::default(),
155            accumulated_metrics: Metrics::default(),
156            execution_time_ms: 0,
157            handoff_node_id: None,
158            handoff_message: None,
159        }
160    }
161}
162
163impl SwarmState {
164    /// Check if the swarm should continue execution.
165    pub fn should_continue(&self, config: &SwarmConfig) -> (bool, &'static str) {
166        if self.node_history.len() >= config.max_handoffs {
167            return (false, "Max handoffs reached");
168        }
169
170        if self.node_history.len() >= config.max_iterations {
171            return (false, "Max iterations reached");
172        }
173
174        if let Some(timeout) = config.execution_timeout {
175            if self.start_time.elapsed() > timeout {
176                return (false, "Execution timed out");
177            }
178        }
179
180        if config.repetitive_handoff_detection_window > 0
181            && self.node_history.len() >= config.repetitive_handoff_detection_window
182        {
183            let recent: Vec<_> = self
184                .node_history
185                .iter()
186                .rev()
187                .take(config.repetitive_handoff_detection_window)
188                .collect();
189            let unique: std::collections::HashSet<_> = recent.iter().collect();
190            if unique.len() < config.repetitive_handoff_min_unique_agents {
191                return (false, "Repetitive handoff detected");
192            }
193        }
194
195        (true, "Continuing")
196    }
197}
198
199/// Result from swarm execution.
200#[derive(Debug, Clone)]
201pub struct SwarmResult {
202    pub status: Status,
203    pub results: HashMap<String, NodeResult>,
204    pub node_history: Vec<String>,
205    pub accumulated_usage: Usage,
206    pub accumulated_metrics: Metrics,
207    pub execution_time_ms: u64,
208    pub interrupts: Vec<Interrupt>,
209}
210
211impl From<SwarmResult> for MultiAgentResult {
212    fn from(sr: SwarmResult) -> Self {
213        MultiAgentResult {
214            status: sr.status,
215            results: sr.results,
216            accumulated_usage: sr.accumulated_usage,
217            accumulated_metrics: sr.accumulated_metrics,
218            execution_count: sr.node_history.len() as u32,
219            execution_time_ms: sr.execution_time_ms,
220            interrupts: sr.interrupts,
221        }
222    }
223}
224
/// Limits and safety settings that bound a swarm run.
#[derive(Debug, Clone)]
pub struct SwarmConfig {
    /// The run stops once this many nodes have executed.
    pub max_handoffs: usize,
    /// The run stops once this many nodes have executed (checked after `max_handoffs`).
    pub max_iterations: usize,
    /// Wall-clock budget for the whole run; `None` disables the check.
    pub execution_timeout: Option<Duration>,
    /// Time budget for a single node.
    /// NOTE(review): no code in this file enforces it — confirm where it is applied.
    pub node_timeout: Option<Duration>,
    /// Number of most recent executions inspected for repetition; `0` disables detection.
    pub repetitive_handoff_detection_window: usize,
    /// Minimum number of distinct nodes required within the detection window.
    pub repetitive_handoff_min_unique_agents: usize,
}

impl Default for SwarmConfig {
    /// 20 handoffs, 20 iterations, a 900 s run budget, a 300 s node budget,
    /// and repetition detection turned off.
    fn default() -> Self {
        let run_budget = Duration::from_secs(15 * 60);
        let node_budget = Duration::from_secs(5 * 60);

        Self {
            max_handoffs: 20,
            max_iterations: 20,
            execution_timeout: Some(run_budget),
            node_timeout: Some(node_budget),
            repetitive_handoff_detection_window: 0,
            repetitive_handoff_min_unique_agents: 0,
        }
    }
}
248
249/// Handoff tool for coordinating agent-to-agent transfers.
250struct HandoffTool {
251    swarm_state: Arc<RwLock<HandoffState>>,
252    available_agents: Vec<String>,
253}
254
255#[derive(Default)]
256struct HandoffState {
257    target_node_id: Option<String>,
258    message: Option<String>,
259    context: HashMap<String, serde_json::Value>,
260}
261
262#[async_trait]
263impl AgentTool for HandoffTool {
264    fn name(&self) -> &str {
265        "handoff_to_agent"
266    }
267
268    fn description(&self) -> &str {
269        "Transfer control to another agent in the swarm for specialized help"
270    }
271
272    fn tool_spec(&self) -> ToolSpec {
273        ToolSpec::new(
274            "handoff_to_agent",
275            "Transfer control to another agent in the swarm for specialized help",
276        ).with_input_schema(json!({
277            "type": "object",
278            "properties": {
279                "agent_name": {
280                    "type": "string",
281                    "description": "Name of the agent to hand off to"
282                },
283                "message": {
284                    "type": "string",
285                    "description": "Message explaining what needs to be done and why you're handing off"
286                },
287                "context": {
288                    "type": "object",
289                    "description": "Additional context to share with the next agent",
290                    "additionalProperties": true
291                }
292            },
293            "required": ["agent_name", "message"]
294        }))
295    }
296
297    async fn invoke(
298        &self,
299        input: serde_json::Value,
300        _context: &ToolContext,
301    ) -> std::result::Result<ToolResult2, String> {
302        let agent_name = input
303            .get("agent_name")
304            .and_then(|v| v.as_str())
305            .ok_or("Missing agent_name")?;
306        let message = input
307            .get("message")
308            .and_then(|v| v.as_str())
309            .ok_or("Missing message")?;
310        let context = input
311            .get("context")
312            .and_then(|v| v.as_object())
313            .cloned()
314            .unwrap_or_default();
315
316        if !self.available_agents.contains(&agent_name.to_string()) {
317            return Ok(ToolResult2 {
318                status: ToolResultStatus::Error,
319                content: vec![crate::types::tools::ToolResultContent::text(format!(
320                    "Error: Agent '{}' not found in swarm. Available agents: {:?}",
321                    agent_name, self.available_agents
322                ))],
323            });
324        }
325
326        let mut state = self.swarm_state.write().await;
327        state.target_node_id = Some(agent_name.to_string());
328        state.message = Some(message.to_string());
329        state.context = context.into_iter().collect();
330
331        Ok(ToolResult2 {
332            status: ToolResultStatus::Success,
333            content: vec![crate::types::tools::ToolResultContent::text(format!(
334                "Handing off to {}: {}",
335                agent_name, message
336            ))],
337        })
338    }
339}
340
341/// Self-organizing collaborative agent teams with shared working memory.
342pub struct Swarm {
343    id: String,
344    nodes: HashMap<String, SwarmNode>,
345    entry_point_id: Option<String>,
346    config: SwarmConfig,
347    state: SwarmState,
348    hooks: HookRegistry,
349    interrupt_state: InterruptState,
350    handoff_state: Arc<RwLock<HandoffState>>,
351    resume_from_session: bool,
352}
353
354impl Swarm {
355    /// Creates a new swarm with the given agents.
356    pub fn new(
357        agents: Vec<Agent>,
358        entry_point: Option<&str>,
359        config: SwarmConfig,
360    ) -> Result<Self> {
361        if agents.is_empty() {
362            return Err(StrandsError::ConfigurationError {
363                message: "Swarm must have at least one agent".to_string(),
364            });
365        }
366
367        let mut nodes = HashMap::new();
368        let mut node_names: Vec<String> = Vec::new();
369
370        for (i, agent) in agents.into_iter().enumerate() {
371            let node_id = agent.name().cloned().unwrap_or_else(|| format!("node_{i}"));
372
373            if nodes.contains_key(&node_id) {
374                return Err(StrandsError::ConfigurationError {
375                    message: format!("Duplicate node ID: {node_id}"),
376                });
377            }
378
379            node_names.push(node_id.clone());
380            nodes.insert(node_id.clone(), SwarmNode::new(node_id, agent));
381        }
382
383        let entry_point_id = entry_point.map(|s| s.to_string()).or_else(|| {
384            nodes.keys().next().cloned()
385        });
386
387        if let Some(ref ep) = entry_point_id {
388            if !nodes.contains_key(ep) {
389                return Err(StrandsError::ConfigurationError {
390                    message: format!("Entry point '{ep}' not found in swarm nodes"),
391                });
392            }
393        }
394
395        let handoff_state = Arc::new(RwLock::new(HandoffState::default()));
396
397        let mut swarm = Self {
398            id: "default_swarm".to_string(),
399            nodes,
400            entry_point_id,
401            config,
402            state: SwarmState::default(),
403            hooks: HookRegistry::new(),
404            interrupt_state: InterruptState::new(),
405            handoff_state,
406            resume_from_session: false,
407        };
408
409        for node in swarm.nodes.values_mut() {
410            let tool = HandoffTool {
411                swarm_state: Arc::clone(&swarm.handoff_state),
412                available_agents: node_names.iter().filter(|n| *n != &node.node_id).cloned().collect(),
413            };
414            node.agent.tool_registry_mut().register(Box::new(tool));
415        }
416
417        Ok(swarm)
418    }
419
420    /// Sets the swarm ID.
421    pub fn with_id(mut self, id: impl Into<String>) -> Self {
422        self.id = id.into();
423        self
424    }
425
426    /// Sets the hook registry.
427    pub fn with_hooks(mut self, hooks: HookRegistry) -> Self {
428        self.hooks = hooks;
429        self
430    }
431
432    /// Returns the swarm ID.
433    pub fn swarm_id(&self) -> &str {
434        &self.id
435    }
436
437    /// Returns the current state.
438    pub fn state(&self) -> &SwarmState {
439        &self.state
440    }
441
442    /// Returns an iterator over node IDs.
443    pub fn node_ids(&self) -> impl Iterator<Item = &str> {
444        self.nodes.keys().map(|s| s.as_str())
445    }
446
447    /// Returns a reference to the interrupt state.
448    pub fn interrupt_state(&self) -> &InterruptState {
449        &self.interrupt_state
450    }
451
452    /// Returns a mutable reference to the interrupt state.
453    pub fn interrupt_state_mut(&mut self) -> &mut InterruptState {
454        &mut self.interrupt_state
455    }
456
457    /// Activates interrupt state for a node.
458    fn activate_interrupt(
459        &mut self,
460        node_id: &str,
461        interrupts: Vec<Interrupt>,
462    ) -> MultiAgentEvent {
463        tracing::debug!("node=<{}> | node interrupted", node_id);
464        self.state.status = Status::Interrupted;
465
466        self.interrupt_state.context.insert(
467            node_id.to_string(),
468            serde_json::json!({
469                "activated": true,
470            }),
471        );
472
473        for interrupt in &interrupts {
474            self.interrupt_state.add(interrupt.clone());
475        }
476
477        self.interrupt_state.activate();
478
479        MultiAgentEvent::node_interrupt(node_id, interrupts)
480    }
481
482    /// Invokes the swarm synchronously.
483    pub fn call(&mut self, task: impl Into<MultiAgentInput>) -> Result<SwarmResult> {
484        tokio::task::block_in_place(|| {
485            tokio::runtime::Handle::current().block_on(self.invoke_async(task.into(), None))
486        })
487    }
488
489    /// Invokes the swarm asynchronously and returns the result.
490    pub async fn invoke_async(
491        &mut self,
492        task: MultiAgentInput,
493        invocation_state: Option<&InvocationState>,
494    ) -> Result<SwarmResult> {
495        let mut stream = self.stream_async(task, invocation_state);
496        let mut final_result = None;
497
498        while let Some(event) = stream.next().await {
499            if let MultiAgentEvent::Result(result) = event {
500                final_result = Some(result);
501            }
502        }
503
504        drop(stream);
505
506        final_result
507            .map(|r| SwarmResult {
508                status: r.status,
509                results: r.results,
510                node_history: self.state.node_history.clone(),
511                accumulated_usage: r.accumulated_usage,
512                accumulated_metrics: r.accumulated_metrics,
513                execution_time_ms: r.execution_time_ms,
514                interrupts: r.interrupts,
515            })
516            .ok_or_else(|| StrandsError::MultiAgentError {
517                message: "Swarm execution completed without result".to_string(),
518            })
519    }
520
521    /// Streams events during swarm execution.
522    pub fn stream_async<'a>(
523        &'a mut self,
524        task: MultiAgentInput,
525        _invocation_state: Option<&'a InvocationState>,
526    ) -> MultiAgentEventStream<'a> {
527        let task_str = task.to_string_lossy();
528
529        Box::pin(async_stream::stream! {
530            self.hooks.invoke(&HookEvent::BeforeInvocation(BeforeInvocationEvent)).await;
531
532            if self.resume_from_session || self.interrupt_state.activated {
533                self.state.status = Status::Executing;
534                self.state.start_time = Instant::now();
535            } else {
536                self.state = SwarmState {
537                    current_node_id: self.entry_point_id.clone(),
538                    task: task_str.clone(),
539                    status: Status::Executing,
540                    start_time: Instant::now(),
541                    ..Default::default()
542                };
543
544                {
545                    let mut handoff = self.handoff_state.write().await;
546                    *handoff = HandoffState::default();
547                }
548            }
549
550            while self.state.status == Status::Executing {
551                let (should_continue, reason) = self.state.should_continue(&self.config);
552                if !should_continue {
553                    tracing::warn!("Swarm execution stopped: {reason}");
554                    self.state.status = Status::Failed;
555                    break;
556                }
557
558                let current_node_id = match &self.state.current_node_id {
559                    Some(id) => id.clone(),
560                    None => {
561                        self.state.status = Status::Failed;
562                        break;
563                    }
564                };
565
566                if !self.nodes.contains_key(&current_node_id) {
567                    tracing::error!("Node '{}' not found", current_node_id);
568                    self.state.status = Status::Failed;
569                    break;
570                }
571
572                yield MultiAgentEvent::node_start(&current_node_id, "agent");
573
574                self.hooks.invoke(&HookEvent::BeforeToolCall(BeforeToolCallEvent::new(
575                    ToolUse::new(&current_node_id, &current_node_id, serde_json::json!({}))
576                ))).await;
577
578                let result = self.execute_node(&current_node_id, &task_str).await;
579
580                match result {
581                    Ok(node_result) => {
582                        self.state.node_history.push(current_node_id.clone());
583                        self.state.accumulated_usage.add(&node_result.accumulated_usage);
584                        self.state.accumulated_metrics.latency_ms += node_result.accumulated_metrics.latency_ms;
585
586                        yield MultiAgentEvent::node_stop(&current_node_id, node_result.clone());
587
588                        if node_result.status == Status::Interrupted {
589                            let interrupt_event = self.activate_interrupt(&current_node_id, node_result.interrupts.clone());
590                            yield interrupt_event;
591                            break;
592                        }
593
594                        self.interrupt_state.deactivate();
595
596                        self.state.results.insert(current_node_id.clone(), node_result);
597
598                        let handoff = {
599                            let state = self.handoff_state.read().await;
600                            (state.target_node_id.clone(), state.message.clone())
601                        };
602
603                        if let (Some(target_id), message) = handoff {
604                            {
605                                let mut state = self.handoff_state.write().await;
606                                *state = HandoffState::default();
607                            }
608
609                            yield MultiAgentEvent::handoff(
610                                vec![current_node_id.clone()],
611                                vec![target_id.clone()],
612                                message,
613                            );
614
615                            self.state.current_node_id = Some(target_id);
616                        } else {
617                            self.state.status = Status::Completed;
618                        }
619                    }
620                    Err(e) => {
621                        tracing::error!("Node '{}' failed: {}", current_node_id, e);
622                        let error_result = NodeResult::from_error(e.to_string(), 0);
623                        yield MultiAgentEvent::node_stop(&current_node_id, error_result);
624                        self.state.status = Status::Failed;
625                    }
626                }
627
628                self.hooks.invoke(&HookEvent::AfterToolCall(AfterToolCallEvent::new(
629                    ToolUse::new(&current_node_id, &current_node_id, serde_json::json!({})),
630                    ToolResultType::success(&current_node_id, "completed")
631                ))).await;
632            }
633
634            self.state.execution_time_ms = self.state.start_time.elapsed().as_millis() as u64;
635
636            self.hooks.invoke(&HookEvent::AfterInvocation(AfterInvocationEvent::new(None))).await;
637
638            let result = MultiAgentResult {
639                status: self.state.status,
640                results: self.state.results.clone(),
641                accumulated_usage: self.state.accumulated_usage.clone(),
642                accumulated_metrics: self.state.accumulated_metrics.clone(),
643                execution_count: self.state.node_history.len() as u32,
644                execution_time_ms: self.state.execution_time_ms,
645                interrupts: Vec::new(),
646            };
647
648            yield MultiAgentEvent::result(result);
649        })
650    }
651
652    async fn execute_node(&mut self, node_id: &str, task: &str) -> Result<NodeResult> {
653        let start = Instant::now();
654
655        let input = self.build_node_input(node_id, task);
656
657        let node = self.nodes.get_mut(node_id).ok_or_else(|| StrandsError::InternalError {
658            message: format!("Node '{node_id}' not found"),
659        })?;
660
661        let agent_result = node.agent.invoke_async(input.as_str()).await?;
662        let execution_time_ms = start.elapsed().as_millis() as u64;
663
664        let usage = agent_result.usage.clone();
665
666        Ok(NodeResult {
667            result: NodeResultValue::Agent(agent_result),
668            execution_time_ms,
669            status: Status::Completed,
670            accumulated_usage: usage,
671            accumulated_metrics: Metrics { latency_ms: execution_time_ms, time_to_first_byte_ms: 0 },
672            execution_count: 1,
673            interrupts: Vec::new(),
674        })
675    }
676
677    fn build_node_input(&self, target_node_id: &str, task: &str) -> String {
678        let mut input = String::new();
679
680        if let Some(ref message) = self.state.handoff_message {
681            input.push_str(&format!("Handoff Message: {message}\n\n"));
682        }
683
684        input.push_str(&format!("User Request: {task}\n\n"));
685
686        if !self.state.node_history.is_empty() {
687            input.push_str(&format!(
688                "Previous agents who worked on this: {}\n\n",
689                self.state.node_history.join(" → ")
690            ));
691        }
692
693        if !self.state.shared_context.context.is_empty() {
694            input.push_str("Shared knowledge from previous agents:\n");
695            for (node_name, context) in &self.state.shared_context.context {
696                if !context.is_empty() {
697                    input.push_str(&format!("• {node_name}: {:?}\n", context));
698                }
699            }
700            input.push('\n');
701        }
702
703        let other_nodes: Vec<_> = self.nodes.keys()
704            .filter(|id| *id != target_node_id)
705            .collect();
706
707        if !other_nodes.is_empty() {
708            input.push_str("Other agents available for collaboration:\n");
709            for node_id in other_nodes {
710                input.push_str(&format!("Agent name: {node_id}.\n"));
711            }
712            input.push('\n');
713        }
714
715        input.push_str(
716            "You have access to swarm coordination tools if you need help from other agents. \
717             If you don't hand off to another agent, the swarm will consider the task complete."
718        );
719
720        input
721    }
722}
723
724#[async_trait]
725impl MultiAgentBase for Swarm {
726    fn id(&self) -> &str {
727        &self.id
728    }
729
730    async fn invoke_async(
731        &mut self,
732        task: MultiAgentInput,
733        invocation_state: Option<&InvocationState>,
734    ) -> Result<MultiAgentResult> {
735        self.invoke_async(task, invocation_state).await.map(Into::into)
736    }
737
738    fn stream_async<'a>(
739        &'a mut self,
740        task: MultiAgentInput,
741        invocation_state: Option<&'a InvocationState>,
742    ) -> MultiAgentEventStream<'a> {
743        self.stream_async(task, invocation_state)
744    }
745
746    fn serialize_state(&self) -> serde_json::Value {
747        json!({
748            "type": "swarm",
749            "id": self.id,
750            "status": format!("{:?}", self.state.status).to_lowercase(),
751            "node_history": self.state.node_history,
752            "current_node": self.state.current_node_id,
753            "current_task": self.state.task,
754            "shared_context": self.state.shared_context.context,
755            "interrupt_state": self.interrupt_state.to_dict(),
756        })
757    }
758
759    fn deserialize_state(&mut self, payload: &serde_json::Value) -> Result<()> {
760        if let Some(status_str) = payload.get("status").and_then(|v| v.as_str()) {
761            self.state.status = match status_str {
762                "pending" => Status::Pending,
763                "executing" => Status::Executing,
764                "completed" => Status::Completed,
765                "failed" => Status::Failed,
766                "interrupted" => Status::Interrupted,
767                _ => Status::Pending,
768            };
769        }
770
771        if let Some(history) = payload.get("node_history").and_then(|v| v.as_array()) {
772            self.state.node_history = history
773                .iter()
774                .filter_map(|v| v.as_str().map(|s| s.to_string()))
775                .collect();
776        }
777
778        if let Some(current) = payload.get("current_node").and_then(|v| v.as_str()) {
779            self.state.current_node_id = Some(current.to_string());
780        }
781
782        if let Some(task) = payload.get("current_task").and_then(|v| v.as_str()) {
783            self.state.task = task.to_string();
784        }
785
786        if let Some(interrupt_obj) = payload.get("interrupt_state").and_then(|v| v.as_object()) {
787            let interrupt_map: std::collections::HashMap<String, serde_json::Value> = interrupt_obj
788                .iter()
789                .map(|(k, v)| (k.clone(), v.clone()))
790                .collect();
791            self.interrupt_state = InterruptState::from_dict(interrupt_map);
792            self.resume_from_session = true;
793        }
794
795        Ok(())
796    }
797}
798
#[cfg(test)]
mod tests {
    use super::*;

    /// Adding an entry makes the node's context retrievable.
    #[test]
    fn test_shared_context() {
        let mut shared = SharedContext::new();
        let added = shared.add_context("node1", "key1", "value1");
        assert!(added.is_ok());
        assert!(shared.get_context("node1").is_some());
    }

    /// An empty key is rejected.
    #[test]
    fn test_shared_context_empty_key() {
        let mut shared = SharedContext::new();
        assert!(shared.add_context("node1", "", "value").is_err());
    }

    /// A fresh state under the default configuration may continue.
    #[test]
    fn test_swarm_state_should_continue() {
        let fresh = SwarmState::default();
        let (may_continue, _reason) = fresh.should_continue(&SwarmConfig::default());
        assert!(may_continue);
    }

    /// Reaching `max_handoffs` executed nodes stops the run with the matching reason.
    #[test]
    fn test_swarm_state_max_handoffs() {
        let limited = SwarmConfig {
            max_handoffs: 2,
            ..Default::default()
        };
        let exhausted = SwarmState {
            node_history: vec!["a".to_string(), "b".to_string()],
            ..Default::default()
        };

        let (may_continue, reason) = exhausted.should_continue(&limited);
        assert!(!may_continue);
        assert_eq!(reason, "Max handoffs reached");
    }
}