1use std::collections::HashMap;
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9
10use crate::agent::AgentResult;
11use crate::types::content::ContentBlock;
12use crate::types::errors::Result;
13use crate::types::streaming::{Metrics, Usage};
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
17#[serde(rename_all = "lowercase")]
18pub enum Status {
19 Pending,
20 Executing,
21 Completed,
22 Failed,
23 Interrupted,
24}
25
26impl Default for Status {
27 fn default() -> Self {
28 Self::Pending
29 }
30}
31
32use crate::types::interrupt::InterruptResponseContent;
33
/// Task input accepted by [`MultiAgentBase::invoke_async`] and
/// [`MultiAgentBase::stream_async`].
#[derive(Debug, Clone)]
pub enum MultiAgentInput {
    /// Plain text task.
    Text(String),
    /// Task expressed as a list of content blocks.
    ContentBlocks(Vec<ContentBlock>),
    /// Responses to previously raised interrupts.
    InterruptResponses(Vec<InterruptResponseContent>),
}
49
50impl From<&str> for MultiAgentInput {
51 fn from(s: &str) -> Self {
52 MultiAgentInput::Text(s.to_string())
53 }
54}
55
56impl From<String> for MultiAgentInput {
57 fn from(s: String) -> Self {
58 MultiAgentInput::Text(s)
59 }
60}
61
62impl From<Vec<ContentBlock>> for MultiAgentInput {
63 fn from(blocks: Vec<ContentBlock>) -> Self {
64 MultiAgentInput::ContentBlocks(blocks)
65 }
66}
67
68impl From<Vec<InterruptResponseContent>> for MultiAgentInput {
69 fn from(responses: Vec<InterruptResponseContent>) -> Self {
70 MultiAgentInput::InterruptResponses(responses)
71 }
72}
73
74impl MultiAgentInput {
75 pub fn as_text(&self) -> Option<&str> {
77 match self {
78 MultiAgentInput::Text(s) => Some(s),
79 _ => None,
80 }
81 }
82
83 pub fn as_content_blocks(&self) -> Option<&[ContentBlock]> {
85 match self {
86 MultiAgentInput::ContentBlocks(blocks) => Some(blocks),
87 _ => None,
88 }
89 }
90
91 pub fn as_interrupt_responses(&self) -> Option<&[InterruptResponseContent]> {
93 match self {
94 MultiAgentInput::InterruptResponses(responses) => Some(responses),
95 _ => None,
96 }
97 }
98
99 pub fn is_interrupt_response(&self) -> bool {
101 matches!(self, MultiAgentInput::InterruptResponses(_))
102 }
103
104 pub fn to_string_lossy(&self) -> String {
106 match self {
107 MultiAgentInput::Text(s) => s.clone(),
108 MultiAgentInput::ContentBlocks(blocks) => blocks
109 .iter()
110 .filter_map(|b| b.text.as_ref())
111 .cloned()
112 .collect::<Vec<_>>()
113 .join("\n"),
114 MultiAgentInput::InterruptResponses(responses) => responses
115 .iter()
116 .map(|r| {
117 format!(
118 "{}:{}",
119 r.interrupt_response.interrupt_id,
120 r.interrupt_response.response
121 )
122 })
123 .collect::<Vec<_>>()
124 .join("; "),
125 }
126 }
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct Interrupt {
132 pub id: String,
133 pub tool_name: String,
134 pub tool_use_id: String,
135 pub message: Option<String>,
136 #[serde(skip_serializing_if = "Option::is_none")]
137 pub response: Option<serde_json::Value>,
138}
139
140impl Interrupt {
141 pub fn new(id: impl Into<String>, tool_name: impl Into<String>, tool_use_id: impl Into<String>) -> Self {
142 Self {
143 id: id.into(),
144 tool_name: tool_name.into(),
145 tool_use_id: tool_use_id.into(),
146 message: None,
147 response: None,
148 }
149 }
150
151 pub fn with_message(mut self, message: impl Into<String>) -> Self {
152 self.message = Some(message.into());
153 self
154 }
155
156 pub fn with_response(mut self, response: serde_json::Value) -> Self {
157 self.response = Some(response);
158 self
159 }
160
161 pub fn has_response(&self) -> bool {
162 self.response.is_some()
163 }
164}
165
/// Result of executing a single node, as carried by [`MultiAgentEvent::NodeStop`] and
/// stored per node id in [`MultiAgentResult::results`].
#[derive(Debug, Clone)]
pub struct NodeResult {
    /// The node's payload: agent result, nested multi-agent result, or error message.
    pub result: NodeResultValue,
    /// Execution time; presumably milliseconds per the `_ms` suffix — TODO confirm.
    pub execution_time_ms: u64,
    /// `Completed` from `from_agent`, `Failed` from `from_error`.
    pub status: Status,
    /// Usage for this node; the constructors start it at `Usage::default()`.
    pub accumulated_usage: Usage,
    /// Metrics for this node; the constructors start it at `Metrics::default()`.
    pub accumulated_metrics: Metrics,
    /// Number of executions; the constructors set it to 1.
    pub execution_count: u32,
    /// Interrupts raised by this node; empty from the constructors.
    pub interrupts: Vec<Interrupt>,
}
177
/// Payload of a [`NodeResult`].
#[derive(Debug, Clone)]
pub enum NodeResultValue {
    /// Result produced by a single agent.
    Agent(AgentResult),
    /// Nested multi-agent result (boxed); walked recursively by
    /// [`NodeResult::get_agent_results`].
    MultiAgent(Box<MultiAgentResult>),
    /// Error message; built by [`NodeResult::from_error`].
    Error(String),
}
185
186impl NodeResult {
187 pub fn from_agent(result: AgentResult, execution_time_ms: u64) -> Self {
188 Self {
189 result: NodeResultValue::Agent(result),
190 execution_time_ms,
191 status: Status::Completed,
192 accumulated_usage: Usage::default(),
193 accumulated_metrics: Metrics::default(),
194 execution_count: 1,
195 interrupts: Vec::new(),
196 }
197 }
198
199 pub fn from_error(error: impl Into<String>, execution_time_ms: u64) -> Self {
200 Self {
201 result: NodeResultValue::Error(error.into()),
202 execution_time_ms,
203 status: Status::Failed,
204 accumulated_usage: Usage::default(),
205 accumulated_metrics: Metrics::default(),
206 execution_count: 1,
207 interrupts: Vec::new(),
208 }
209 }
210
211 pub fn get_agent_results(&self) -> Vec<&AgentResult> {
212 match &self.result {
213 NodeResultValue::Agent(r) => vec![r],
214 NodeResultValue::MultiAgent(m) => m
215 .results
216 .values()
217 .flat_map(|nr| nr.get_agent_results())
218 .collect(),
219 NodeResultValue::Error(_) => vec![],
220 }
221 }
222
223 pub fn is_error(&self) -> bool {
224 matches!(self.result, NodeResultValue::Error(_))
225 }
226
227 pub fn is_interrupted(&self) -> bool {
228 self.status == Status::Interrupted
229 }
230}
231
/// Aggregated result of a multi-agent execution.
#[derive(Debug, Clone, Default)]
pub struct MultiAgentResult {
    /// Overall status; `Status::Pending` by default.
    pub status: Status,
    /// Per-node results keyed by node id; populated by `add_node_result`.
    pub results: HashMap<String, NodeResult>,
    /// Each added node's usage is folded in via `Usage::add` in `add_node_result`.
    pub accumulated_usage: Usage,
    /// `add_node_result` accumulates only `latency_ms` here.
    pub accumulated_metrics: Metrics,
    /// Sum of the `execution_count` of every node result passed to `add_node_result`.
    pub execution_count: u32,
    /// Total execution time; not updated by `add_node_result`.
    pub execution_time_ms: u64,
    /// Interrupts for this execution; not updated by `add_node_result`.
    pub interrupts: Vec<Interrupt>,
}
243
244impl MultiAgentResult {
245 pub fn new() -> Self {
246 Self::default()
247 }
248
249 pub fn with_status(mut self, status: Status) -> Self {
250 self.status = status;
251 self
252 }
253
254 pub fn add_node_result(&mut self, node_id: impl Into<String>, result: NodeResult) {
255 self.accumulated_usage.add(&result.accumulated_usage);
256 self.accumulated_metrics.latency_ms += result.accumulated_metrics.latency_ms;
257 self.execution_count += result.execution_count;
258 self.results.insert(node_id.into(), result);
259 }
260}
261
/// Event yielded by a [`MultiAgentEventStream`].
#[derive(Debug, Clone)]
pub enum MultiAgentEvent {
    /// A node is starting.
    NodeStart {
        /// Id of the node.
        node_id: String,
        /// Free-form node kind (the tests in this module use `"agent"`).
        node_type: String,
    },
    /// A node has stopped, carrying its result.
    NodeStop {
        /// Id of the node.
        node_id: String,
        /// Result produced by the node.
        node_result: NodeResult,
    },
    /// Hand-off from one set of nodes to another.
    Handoff {
        /// Ids of the nodes handing off.
        from_node_ids: Vec<String>,
        /// Ids of the nodes receiving the hand-off.
        to_node_ids: Vec<String>,
        /// Optional accompanying message.
        message: Option<String>,
    },
    /// A streamed event from a node, as raw JSON.
    NodeStream {
        /// Id of the node.
        node_id: String,
        /// The node's event payload.
        event: serde_json::Value,
    },
    /// A node was cancelled.
    NodeCancel {
        /// Id of the node.
        node_id: String,
        /// Cancellation message.
        message: String,
    },
    /// A node raised interrupts.
    NodeInterrupt {
        /// Id of the node.
        node_id: String,
        /// The interrupts raised.
        interrupts: Vec<Interrupt>,
    },
    /// Carries a complete [`MultiAgentResult`]; detected by [`MultiAgentEvent::is_result`].
    Result(MultiAgentResult),
}
299
300impl MultiAgentEvent {
301 pub fn node_start(node_id: impl Into<String>, node_type: impl Into<String>) -> Self {
302 Self::NodeStart {
303 node_id: node_id.into(),
304 node_type: node_type.into(),
305 }
306 }
307
308 pub fn node_stop(node_id: impl Into<String>, node_result: NodeResult) -> Self {
309 Self::NodeStop {
310 node_id: node_id.into(),
311 node_result,
312 }
313 }
314
315 pub fn handoff(
316 from_node_ids: Vec<String>,
317 to_node_ids: Vec<String>,
318 message: Option<String>,
319 ) -> Self {
320 Self::Handoff {
321 from_node_ids,
322 to_node_ids,
323 message,
324 }
325 }
326
327 pub fn node_stream(node_id: impl Into<String>, event: serde_json::Value) -> Self {
328 Self::NodeStream {
329 node_id: node_id.into(),
330 event,
331 }
332 }
333
334 pub fn node_cancel(node_id: impl Into<String>, message: impl Into<String>) -> Self {
335 Self::NodeCancel {
336 node_id: node_id.into(),
337 message: message.into(),
338 }
339 }
340
341 pub fn node_interrupt(node_id: impl Into<String>, interrupts: Vec<Interrupt>) -> Self {
342 Self::NodeInterrupt {
343 node_id: node_id.into(),
344 interrupts,
345 }
346 }
347
348 pub fn result(result: MultiAgentResult) -> Self {
349 Self::Result(result)
350 }
351
352 pub fn is_result(&self) -> bool {
353 matches!(self, Self::Result(_))
354 }
355
356 pub fn as_result(&self) -> Option<&MultiAgentResult> {
357 match self {
358 Self::Result(r) => Some(r),
359 _ => None,
360 }
361 }
362}
363
/// Pinned, boxed, `Send` stream of [`MultiAgentEvent`]s that may borrow for `'a`;
/// returned by [`MultiAgentBase::stream_async`].
pub type MultiAgentEventStream<'a> =
    std::pin::Pin<Box<dyn futures::Stream<Item = MultiAgentEvent> + Send + 'a>>;
367
/// Key/value state passed to an invocation, stored as JSON values.
#[derive(Debug, Clone, Default)]
pub struct InvocationState {
    /// Raw JSON values by key; typed access goes through `get` / `set`.
    pub data: HashMap<String, serde_json::Value>,
}
373
374impl InvocationState {
375 pub fn new() -> Self {
376 Self::default()
377 }
378
379 pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
380 self.data.get(key).and_then(|v| serde_json::from_value(v.clone()).ok())
381 }
382
383 pub fn set(&mut self, key: impl Into<String>, value: impl serde::Serialize) {
384 if let Ok(v) = serde_json::to_value(value) {
385 self.data.insert(key.into(), v);
386 }
387 }
388}
389
/// Interface implemented by multi-agent orchestrators.
#[async_trait]
pub trait MultiAgentBase: Send + Sync {
    /// Identifier of this orchestrator.
    fn id(&self) -> &str;

    /// Invokes the orchestrator with `task` and returns its [`MultiAgentResult`].
    ///
    /// `invocation_state` is optional caller-supplied [`InvocationState`].
    async fn invoke_async(
        &mut self,
        task: MultiAgentInput,
        invocation_state: Option<&InvocationState>,
    ) -> Result<MultiAgentResult>;

    /// Like `invoke_async`, but returns a stream of [`MultiAgentEvent`]s borrowing
    /// `self` and `invocation_state` for `'a`.
    fn stream_async<'a>(
        &'a mut self,
        task: MultiAgentInput,
        invocation_state: Option<&'a InvocationState>,
    ) -> MultiAgentEventStream<'a>;

    /// Serializes the orchestrator's state to JSON.
    fn serialize_state(&self) -> serde_json::Value;

    /// Restores state from `payload`, presumably the output of `serialize_state` —
    /// TODO confirm against implementors.
    fn deserialize_state(&mut self, payload: &serde_json::Value) -> Result<()>;
}
416
/// Interrupt bookkeeping: the active flag, registered interrupts, free-form context and
/// any responses supplied via `resume`.
#[derive(Debug, Clone, Default)]
pub struct InterruptState {
    /// Set by `activate`, cleared by `deactivate`.
    pub activated: bool,
    /// Interrupts keyed by `Interrupt::id` (see `add`).
    pub interrupts: HashMap<String, Interrupt>,
    /// Free-form JSON context.
    pub context: HashMap<String, serde_json::Value>,
    /// Responses set by `resume`; cleared by `deactivate`.
    pub responses: Option<serde_json::Value>,
}
425
426impl InterruptState {
427 pub fn new() -> Self {
428 Self::default()
429 }
430
431 pub fn activate(&mut self) {
432 self.activated = true;
433 }
434
435 pub fn deactivate(&mut self) {
436 self.activated = false;
437 self.interrupts.clear();
438 self.context.clear();
439 self.responses = None;
440 }
441
442 pub fn resume(&mut self, responses: serde_json::Value) {
443 self.responses = Some(responses);
444 }
445
446 pub fn add(&mut self, interrupt: Interrupt) {
448 self.interrupts.insert(interrupt.id.clone(), interrupt);
449 }
450
451 pub fn to_dict(&self) -> HashMap<String, serde_json::Value> {
453 let mut dict = HashMap::new();
454 dict.insert("activated".to_string(), serde_json::json!(self.activated));
455 dict.insert(
456 "interrupts".to_string(),
457 serde_json::json!(self.interrupts
458 .iter()
459 .map(|(k, v)| (k.clone(), serde_json::json!({
460 "id": v.id,
461 "tool_name": v.tool_name,
462 "tool_use_id": v.tool_use_id,
463 "message": v.message,
464 "response": v.response,
465 })))
466 .collect::<HashMap<_, _>>()),
467 );
468 dict.insert("context".to_string(), serde_json::json!(self.context));
469 dict.insert("responses".to_string(), serde_json::json!(self.responses));
470 dict
471 }
472
473 pub fn from_dict(data: HashMap<String, serde_json::Value>) -> Self {
475 let activated = data
476 .get("activated")
477 .and_then(|v| v.as_bool())
478 .unwrap_or(false);
479
480 let interrupts = data
481 .get("interrupts")
482 .and_then(|v| v.as_object())
483 .map(|obj| {
484 obj.iter()
485 .filter_map(|(k, v)| {
486 let id = v.get("id")?.as_str()?.to_string();
487 let tool_name = v.get("tool_name")?.as_str()?.to_string();
488 let tool_use_id = v.get("tool_use_id")?.as_str()?.to_string();
489 let message = v.get("message").and_then(|m| m.as_str().map(|s| s.to_string()));
490 let response = v.get("response").cloned();
491 Some((k.clone(), Interrupt {
492 id,
493 tool_name,
494 tool_use_id,
495 message,
496 response,
497 }))
498 })
499 .collect()
500 })
501 .unwrap_or_default();
502
503 let context = data
504 .get("context")
505 .and_then(|v| v.as_object())
506 .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
507 .unwrap_or_default();
508
509 let responses = data.get("responses").cloned();
510
511 Self {
512 activated,
513 interrupts,
514 context,
515 responses,
516 }
517 }
518}
519
#[cfg(test)]
mod tests {
    use super::*;

    /// The default status is `Pending`.
    #[test]
    fn test_status_default() {
        let status: Status = Default::default();
        assert_eq!(status, Status::Pending);
    }

    /// `&str` converts into the `Text` variant.
    #[test]
    fn test_multi_agent_input_from_str() {
        let input: MultiAgentInput = "test task".into();
        assert_eq!(input.as_text(), Some("test task"));
    }

    /// The builder sets id and message.
    #[test]
    fn test_interrupt_creation() {
        let expected_message = "Please provide more info";
        let interrupt = Interrupt::new("int-1", "my_tool", "tu-1").with_message(expected_message);
        assert_eq!(interrupt.id, "int-1");
        assert_eq!(interrupt.message.as_deref(), Some(expected_message));
    }

    /// Only the `Result` variant reports `is_result`.
    #[test]
    fn test_multi_agent_event_variants() {
        let start = MultiAgentEvent::node_start("node1", "agent");
        assert!(!start.is_result());

        let finished = MultiAgentEvent::result(MultiAgentResult::new());
        assert!(finished.is_result());
    }
}
553