Skip to main content

mofa_foundation/react/
patterns.rs

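//! Agent execution patterns
//!
//! Provides Chain (sequential), Parallel (fan-out + aggregate) and MapReduce (split, process in parallel, reduce) execution support for agents.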
4//!
5//! # 架构
6//!
7//! ```text
8//! ┌─────────────────────────────────────────────────────────────────────────┐
9//! │                        Agent 执行模式                                    │
10//! ├─────────────────────────────────────────────────────────────────────────┤
11//! │                                                                         │
12//! │  Chain (链式模式)                                                        │
13//! │  ┌─────┐    ┌─────┐    ┌─────┐    ┌─────┐                              │
14//! │  │Agent│───▶│Agent│───▶│Agent│───▶│Agent│                              │
15//! │  │  1  │    │  2  │    │  3  │    │  N  │                              │
16//! │  └─────┘    └─────┘    └─────┘    └─────┘                              │
17//! │    input     output     output     output                               │
18//! │              =input     =input     =final                               │
19//! │                                                                         │
20//! │  Parallel (并行模式)                                                     │
21//! │              ┌─────┐                                                    │
22//! │           ┌─▶│Agent│──┐                                                │
23//! │           │  │  1  │  │                                                │
24//! │           │  └─────┘  │                                                │
//! │  ┌─────┐  │  ┌─────┐  │  ┌──────────┐    ┌──────┐                      │
//! │  │Input│──┼─▶│Agent│──┼─▶│Aggregator│───▶│Output│                      │
//! │  └─────┘  │  │  2  │  │  └──────────┘    └──────┘                      │
28//! │           │  └─────┘  │                                                │
29//! │           │  ┌─────┐  │                                                │
30//! │           └─▶│Agent│──┘                                                │
31//! │              │  N  │                                                    │
32//! │              └─────┘                                                    │
33//! │                                                                         │
34//! └─────────────────────────────────────────────────────────────────────────┘
35//! ```
36//!
37//! # 示例
38//!
39//! ## Chain 模式
40//!
41//! ```rust,ignore
42//! use mofa_foundation::react::{ChainAgent, ReActAgent};
43//!
44//! // 创建链式 Agent
45//! let chain = ChainAgent::new()
46//!     .add("researcher", researcher_agent)
47//!     .add("writer", writer_agent)
48//!     .add("editor", editor_agent)
49//!     .with_transform(|prev_output, _next_name| {
50//!         format!("Based on this: {}\n\nPlease continue.", prev_output)
51//!     });
52//!
53//! let result = chain.run("Write an article about Rust").await?;
54//! ```
55//!
56//! ## Parallel 模式
57//!
58//! ```rust,ignore
59//! use mofa_foundation::react::{ParallelAgent, AggregationStrategy};
60//!
61//! // 创建并行 Agent
62//! let parallel = ParallelAgent::new()
63//!     .add("analyst1", analyst1_agent)
64//!     .add("analyst2", analyst2_agent)
65//!     .add("analyst3", analyst3_agent)
66//!     .with_aggregation(AggregationStrategy::LLMSummarize(summarizer_agent));
67//!
68//! let result = parallel.run("Analyze market trends").await?;
69//! ```
70
71use super::core::ReActResult;
72use crate::llm::{LLMAgent, LLMError, LLMResult};
73use serde::{Deserialize, Serialize};
74use std::collections::HashMap;
75use std::sync::Arc;
76
77// ============================================================================
78// 通用类型
79// ============================================================================
80
/// A single executable agent, used as a building block by the chain,
/// parallel and MapReduce patterns in this module.
///
/// Wraps either a ReAct agent or a plain LLM agent behind the uniform
/// [`AgentUnit::run`] interface. Cloning only clones the inner `Arc`.
#[derive(Clone)]
pub enum AgentUnit {
    /// A ReAct agent, executed via its `run` method.
    ReAct(Arc<super::ReActAgent>),
    /// A plain LLM agent (simple question/answer via `ask`).
    LLM(Arc<LLMAgent>),
}
91
92impl AgentUnit {
93    /// 从 ReActAgent 创建
94    pub fn react(agent: Arc<super::ReActAgent>) -> Self {
95        Self::ReAct(agent)
96    }
97
98    /// 从 LLMAgent 创建
99    pub fn llm(agent: Arc<LLMAgent>) -> Self {
100        Self::LLM(agent)
101    }
102
103    /// 执行任务
104    pub async fn run(&self, task: impl Into<String>) -> LLMResult<AgentOutput> {
105        let task = task.into();
106        let start = std::time::Instant::now();
107
108        match self {
109            AgentUnit::ReAct(agent) => {
110                let result = agent.run(&task).await?;
111                Ok(AgentOutput {
112                    content: result.answer.clone(),
113                    task,
114                    success: result.success,
115                    error: result.error.clone(),
116                    duration_ms: result.duration_ms,
117                    metadata: Some(AgentOutputMetadata::ReAct(result)),
118                })
119            }
120            AgentUnit::LLM(agent) => {
121                let response = agent.ask(&task).await?;
122                Ok(AgentOutput {
123                    content: response,
124                    task,
125                    success: true,
126                    error: None,
127                    duration_ms: start.elapsed().as_millis() as u64,
128                    metadata: None,
129                })
130            }
131        }
132    }
133}
134
/// Normalized output of one [`AgentUnit::run`] call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentOutput {
    /// Text produced by the agent.
    pub content: String,
    /// The task (input text) the agent was given.
    pub task: String,
    /// Whether the agent reported success.
    pub success: bool,
    /// Error message, if any; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// Execution time in milliseconds.
    pub duration_ms: u64,
    /// Extra agent-specific details; never serialized, and `None` after deserializing.
    #[serde(skip)]
    pub metadata: Option<AgentOutputMetadata>,
}
153
/// Agent-specific details attached to an [`AgentOutput`].
#[derive(Debug, Clone)]
pub enum AgentOutputMetadata {
    /// The complete result of a ReAct run.
    ReAct(ReActResult),
}
160
161// ============================================================================
162// Chain Agent (链式模式)
163// ============================================================================
164
/// Chain execution pattern.
///
/// Agents run sequentially; each agent's output becomes the next agent's input
/// (optionally passed through a transform function).
///
/// # Example
///
/// ```rust,ignore
/// let chain = ChainAgent::new()
///     .add("step1", agent1)
///     .add("step2", agent2)
///     .add("step3", agent3);
///
/// let result = chain.run("Initial task").await?;
/// ```
pub struct ChainAgent {
    /// Agents in insertion order, which is also execution order.
    agents: Vec<(String, AgentUnit)>,
    /// Optional transform applied to an output before it becomes the next input.
    transform: Option<TransformFn>,
    /// Keep running after a failed step instead of returning early.
    continue_on_error: bool,
    /// Emit per-step `tracing` logs.
    verbose: bool,
}
189
/// Input transform: `(previous_output, next_agent_name) -> next_input`.
type TransformFn = Arc<dyn Fn(&str, &str) -> String + Send + Sync>;
192
193impl ChainAgent {
194    /// 创建新的链式 Agent
195    pub fn new() -> Self {
196        Self {
197            agents: Vec::new(),
198            transform: None,
199            continue_on_error: false,
200            verbose: true,
201        }
202    }
203
204    /// 添加 ReAct Agent 到链中
205    pub fn add(mut self, name: impl Into<String>, agent: Arc<super::ReActAgent>) -> Self {
206        self.agents.push((name.into(), AgentUnit::react(agent)));
207        self
208    }
209
210    /// 添加 LLM Agent 到链中
211    pub fn add_llm(mut self, name: impl Into<String>, agent: Arc<LLMAgent>) -> Self {
212        self.agents.push((name.into(), AgentUnit::llm(agent)));
213        self
214    }
215
216    /// 添加通用 AgentUnit
217    pub fn add_unit(mut self, name: impl Into<String>, unit: AgentUnit) -> Self {
218        self.agents.push((name.into(), unit));
219        self
220    }
221
222    /// 设置输入转换函数
223    ///
224    /// 转换函数接收前一个 Agent 的输出和下一个 Agent 的名称,返回转换后的输入
225    ///
226    /// # 示例
227    ///
228    /// ```rust,ignore
229    /// chain.with_transform(|prev_output, next_name| {
230    ///     format!("Previous result: {}\n\nTask for {}: continue the analysis", prev_output, next_name)
231    /// })
232    /// ```
233    pub fn with_transform<F>(mut self, f: F) -> Self
234    where
235        F: Fn(&str, &str) -> String + Send + Sync + 'static,
236    {
237        self.transform = Some(Arc::new(f));
238        self
239    }
240
241    /// 设置是否在失败时继续执行
242    pub fn with_continue_on_error(mut self, continue_on_error: bool) -> Self {
243        self.continue_on_error = continue_on_error;
244        self
245    }
246
247    /// 设置是否详细输出
248    pub fn with_verbose(mut self, verbose: bool) -> Self {
249        self.verbose = verbose;
250        self
251    }
252
253    /// 执行链式 Agent
254    pub async fn run(&self, initial_task: impl Into<String>) -> LLMResult<ChainResult> {
255        let initial_task = initial_task.into();
256        let start_time = std::time::Instant::now();
257        let chain_id = uuid::Uuid::now_v7().to_string();
258
259        let mut step_results = Vec::new();
260        let mut current_input = initial_task.clone();
261        let mut final_output = String::new();
262        let mut all_success = true;
263
264        for (idx, (name, agent)) in self.agents.iter().enumerate() {
265            if self.verbose {
266                tracing::info!("[Chain] Step {}: {} - Starting", idx + 1, name);
267            }
268
269            // 执行 Agent
270            let result = agent.run(&current_input).await;
271
272            match result {
273                Ok(output) => {
274                    if self.verbose {
275                        tracing::info!(
276                            "[Chain] Step {}: {} - Completed in {}ms",
277                            idx + 1,
278                            name,
279                            output.duration_ms
280                        );
281                    }
282
283                    step_results.push(ChainStepResult {
284                        step: idx + 1,
285                        agent_name: name.clone(),
286                        input: current_input.clone(),
287                        output: output.clone(),
288                        success: output.success,
289                    });
290
291                    if !output.success {
292                        all_success = false;
293                        if !self.continue_on_error {
294                            return Ok(ChainResult {
295                                chain_id,
296                                initial_task,
297                                final_output: output.content.clone(),
298                                steps: step_results,
299                                success: false,
300                                error: output.error,
301                                total_duration_ms: start_time.elapsed().as_millis() as u64,
302                            });
303                        }
304                    }
305
306                    final_output = output.content.clone();
307
308                    // 转换输入给下一个 Agent
309                    if idx < self.agents.len() - 1 {
310                        let next_name = &self.agents[idx + 1].0;
311                        current_input = if let Some(ref transform) = self.transform {
312                            transform(&output.content, next_name)
313                        } else {
314                            output.content.clone()
315                        };
316                    }
317                }
318                Err(e) => {
319                    all_success = false;
320                    step_results.push(ChainStepResult {
321                        step: idx + 1,
322                        agent_name: name.clone(),
323                        input: current_input.clone(),
324                        output: AgentOutput {
325                            content: String::new(),
326                            task: current_input.clone(),
327                            success: false,
328                            error: Some(e.to_string()),
329                            duration_ms: 0,
330                            metadata: None,
331                        },
332                        success: false,
333                    });
334
335                    if !self.continue_on_error {
336                        return Ok(ChainResult {
337                            chain_id,
338                            initial_task,
339                            final_output: String::new(),
340                            steps: step_results,
341                            success: false,
342                            error: Some(e.to_string()),
343                            total_duration_ms: start_time.elapsed().as_millis() as u64,
344                        });
345                    }
346                }
347            }
348        }
349
350        Ok(ChainResult {
351            chain_id,
352            initial_task,
353            final_output,
354            steps: step_results,
355            success: all_success,
356            error: None,
357            total_duration_ms: start_time.elapsed().as_millis() as u64,
358        })
359    }
360
361    /// 获取链中的 Agent 数量
362    pub fn len(&self) -> usize {
363        self.agents.len()
364    }
365
366    /// 检查链是否为空
367    pub fn is_empty(&self) -> bool {
368        self.agents.is_empty()
369    }
370}
371
372impl Default for ChainAgent {
373    fn default() -> Self {
374        Self::new()
375    }
376}
377
/// Outcome of a chain run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainResult {
    /// Unique ID of this run.
    pub chain_id: String,
    /// The task given to the first agent.
    pub initial_task: String,
    /// Output of the last step that produced one; empty if the chain stopped
    /// on an agent error.
    pub final_output: String,
    /// Results of the executed steps, in execution order.
    pub steps: Vec<ChainStepResult>,
    /// `true` only if every executed step succeeded.
    pub success: bool,
    /// Error message describing a failure, if any; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// Total wall-clock time in milliseconds.
    pub total_duration_ms: u64,
}
397
398impl ChainResult {
399    /// 获取指定步骤的结果
400    pub fn get_step(&self, step: usize) -> Option<&ChainStepResult> {
401        self.steps.get(step.saturating_sub(1))
402    }
403
404    /// 获取指定 Agent 的结果
405    pub fn get_by_name(&self, name: &str) -> Option<&ChainStepResult> {
406        self.steps.iter().find(|s| s.agent_name == name)
407    }
408}
409
/// Result of one step of a chain run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainStepResult {
    /// 1-based position of the step in the chain.
    pub step: usize,
    /// Name of the agent that ran this step.
    pub agent_name: String,
    /// Input given to the agent.
    pub input: String,
    /// Output produced by the agent (empty content with `error` set if the
    /// agent returned an error).
    pub output: AgentOutput,
    /// Whether the step succeeded.
    pub success: bool,
}
424
425// ============================================================================
426// Parallel Agent (并行模式)
427// ============================================================================
428
/// Parallel execution pattern.
///
/// Every agent runs the same task (optionally customized per agent through a
/// task template) concurrently, and the outputs are then combined with an
/// [`AggregationStrategy`].
///
/// # Example
///
/// ```rust,ignore
/// let parallel = ParallelAgent::new()
///     .add("expert1", agent1)
///     .add("expert2", agent2)
///     .add("expert3", agent3)
///     .with_aggregation(AggregationStrategy::Concatenate);
///
/// let result = parallel.run("Analyze this problem").await?;
/// ```
pub struct ParallelAgent {
    /// Agents in insertion order.
    agents: Vec<(String, AgentUnit)>,
    /// Strategy used to combine the individual outputs.
    aggregation: AggregationStrategy,
    /// Still aggregate when some agents failed.
    aggregate_on_partial_failure: bool,
    /// Optional time limit in milliseconds, set via `with_timeout_ms`.
    timeout_ms: Option<u64>,
    /// Emit per-agent `tracing` logs.
    verbose: bool,
    /// Per-agent task templates keyed by agent name; `{task}` is replaced with
    /// the original task.
    task_templates: HashMap<String, String>,
}
458
/// How a [`ParallelAgent`] combines the outputs of its agents.
///
/// Unless noted otherwise, only step results marked successful are used.
#[derive(Clone)]
pub enum AggregationStrategy {
    /// Join the outputs, each prefixed by `[agent_name]` on its own line,
    /// separated by a blank line.
    Concatenate,
    /// Like `Concatenate`, but joined with the given separator.
    ConcatenateWithSeparator(String),
    /// Return the output of the first successful agent, in insertion order.
    FirstSuccess,
    /// Return every result, including failures, as a pretty-printed JSON array
    /// of `{agent, success, output, error}` objects.
    CollectAll,
    /// Return the most frequent output, compared after trimming and
    /// lowercasing (intended for classification-style tasks).
    Vote,
    /// Ask the given LLM agent to synthesize the outputs into one summary.
    LLMSummarize(Arc<LLMAgent>),
    /// Custom aggregation function; receives all step results, including failed ones.
    Custom(Arc<dyn Fn(Vec<ParallelStepResult>) -> String + Send + Sync>),
}
477
478impl ParallelAgent {
479    /// 创建新的并行 Agent
480    pub fn new() -> Self {
481        Self {
482            agents: Vec::new(),
483            aggregation: AggregationStrategy::Concatenate,
484            aggregate_on_partial_failure: true,
485            timeout_ms: None,
486            verbose: true,
487            task_templates: HashMap::new(),
488        }
489    }
490
491    /// 添加 ReAct Agent
492    pub fn add(mut self, name: impl Into<String>, agent: Arc<super::ReActAgent>) -> Self {
493        self.agents.push((name.into(), AgentUnit::react(agent)));
494        self
495    }
496
497    /// 添加 LLM Agent
498    pub fn add_llm(mut self, name: impl Into<String>, agent: Arc<LLMAgent>) -> Self {
499        self.agents.push((name.into(), AgentUnit::llm(agent)));
500        self
501    }
502
503    /// 添加通用 AgentUnit
504    pub fn add_unit(mut self, name: impl Into<String>, unit: AgentUnit) -> Self {
505        self.agents.push((name.into(), unit));
506        self
507    }
508
509    /// 设置聚合策略
510    pub fn with_aggregation(mut self, strategy: AggregationStrategy) -> Self {
511        self.aggregation = strategy;
512        self
513    }
514
515    /// 设置是否在部分失败时仍聚合
516    pub fn with_aggregate_on_partial_failure(mut self, enabled: bool) -> Self {
517        self.aggregate_on_partial_failure = enabled;
518        self
519    }
520
521    /// 设置超时时间
522    pub fn with_timeout_ms(mut self, timeout_ms: u64) -> Self {
523        self.timeout_ms = Some(timeout_ms);
524        self
525    }
526
527    /// 设置是否详细输出
528    pub fn with_verbose(mut self, verbose: bool) -> Self {
529        self.verbose = verbose;
530        self
531    }
532
533    /// 设置特定 Agent 的任务模板
534    ///
535    /// 模板中可使用 `{task}` 占位符表示原始任务
536    ///
537    /// # 示例
538    ///
539    /// ```rust,ignore
540    /// parallel.with_task_template("analyst", "As a financial analyst, {task}");
541    /// ```
542    pub fn with_task_template(
543        mut self,
544        agent_name: impl Into<String>,
545        template: impl Into<String>,
546    ) -> Self {
547        self.task_templates
548            .insert(agent_name.into(), template.into());
549        self
550    }
551
552    /// 执行并行 Agent
553    pub async fn run(&self, task: impl Into<String>) -> LLMResult<ParallelResult> {
554        let task = task.into();
555        let start_time = std::time::Instant::now();
556        let parallel_id = uuid::Uuid::now_v7().to_string();
557
558        if self.verbose {
559            tracing::info!("[Parallel] Starting {} agents for task", self.agents.len());
560        }
561
562        // 准备所有任务
563        let mut handles = Vec::new();
564
565        for (name, agent) in &self.agents {
566            let name = name.clone();
567            let agent = agent.clone();
568            let task_input = self.prepare_task(&name, &task);
569            let verbose = self.verbose;
570
571            let handle = tokio::spawn(async move {
572                if verbose {
573                    tracing::info!("[Parallel] Agent '{}' starting", name);
574                }
575
576                let result = agent.run(&task_input).await;
577
578                if verbose {
579                    match &result {
580                        Ok(output) => {
581                            tracing::info!(
582                                "[Parallel] Agent '{}' completed in {}ms",
583                                name,
584                                output.duration_ms
585                            );
586                        }
587                        Err(e) => {
588                            tracing::warn!("[Parallel] Agent '{}' failed: {}", name, e);
589                        }
590                    }
591                }
592
593                (name, task_input, result)
594            });
595
596            handles.push(handle);
597        }
598
599        // 等待所有任务完成
600        let mut step_results = Vec::new();
601        let mut all_success = true;
602
603        for handle in handles {
604            match handle.await {
605                Ok((name, input, result)) => match result {
606                    Ok(output) => {
607                        if !output.success {
608                            all_success = false;
609                        }
610                        step_results.push(ParallelStepResult {
611                            agent_name: name,
612                            input,
613                            output,
614                            success: true,
615                        });
616                    }
617                    Err(e) => {
618                        all_success = false;
619                        step_results.push(ParallelStepResult {
620                            agent_name: name,
621                            input,
622                            output: AgentOutput {
623                                content: String::new(),
624                                task: task.clone(),
625                                success: false,
626                                error: Some(e.to_string()),
627                                duration_ms: 0,
628                                metadata: None,
629                            },
630                            success: false,
631                        });
632                    }
633                },
634                Err(e) => {
635                    all_success = false;
636                    step_results.push(ParallelStepResult {
637                        agent_name: "unknown".to_string(),
638                        input: task.clone(),
639                        output: AgentOutput {
640                            content: String::new(),
641                            task: task.clone(),
642                            success: false,
643                            error: Some(format!("Task join error: {}", e)),
644                            duration_ms: 0,
645                            metadata: None,
646                        },
647                        success: false,
648                    });
649                }
650            }
651        }
652
653        // 聚合结果
654        let aggregated_output = if all_success || self.aggregate_on_partial_failure {
655            self.aggregate(&step_results).await?
656        } else {
657            String::new()
658        };
659
660        Ok(ParallelResult {
661            parallel_id,
662            task,
663            aggregated_output,
664            individual_results: step_results,
665            success: all_success,
666            total_duration_ms: start_time.elapsed().as_millis() as u64,
667        })
668    }
669
670    /// 准备任务输入
671    fn prepare_task(&self, agent_name: &str, task: &str) -> String {
672        if let Some(template) = self.task_templates.get(agent_name) {
673            template.replace("{task}", task)
674        } else {
675            task.to_string()
676        }
677    }
678
679    /// 聚合结果
680    async fn aggregate(&self, results: &[ParallelStepResult]) -> LLMResult<String> {
681        let successful_results: Vec<&ParallelStepResult> =
682            results.iter().filter(|r| r.success).collect();
683
684        match &self.aggregation {
685            AggregationStrategy::Concatenate => {
686                let outputs: Vec<String> = successful_results
687                    .iter()
688                    .map(|r| format!("[{}]\n{}", r.agent_name, r.output.content))
689                    .collect();
690                Ok(outputs.join("\n\n"))
691            }
692
693            AggregationStrategy::ConcatenateWithSeparator(sep) => {
694                let outputs: Vec<String> = successful_results
695                    .iter()
696                    .map(|r| format!("[{}]\n{}", r.agent_name, r.output.content))
697                    .collect();
698                Ok(outputs.join(sep))
699            }
700
701            AggregationStrategy::FirstSuccess => Ok(successful_results
702                .first()
703                .map(|r| r.output.content.clone())
704                .unwrap_or_default()),
705
706            AggregationStrategy::CollectAll => {
707                let collected: Vec<serde_json::Value> = results
708                    .iter()
709                    .map(|r| {
710                        serde_json::json!({
711                            "agent": r.agent_name,
712                            "success": r.success,
713                            "output": r.output.content,
714                            "error": r.output.error,
715                        })
716                    })
717                    .collect();
718                Ok(serde_json::to_string_pretty(&collected).unwrap_or_else(|_| "[]".to_string()))
719            }
720
721            AggregationStrategy::Vote => {
722                // 简单投票:统计相同输出的数量
723                let mut votes: HashMap<String, usize> = HashMap::new();
724                for result in &successful_results {
725                    let content = result.output.content.trim().to_lowercase();
726                    *votes.entry(content).or_insert(0) += 1;
727                }
728
729                let winner = votes
730                    .into_iter()
731                    .max_by_key(|(_, count)| *count)
732                    .map(|(content, _)| content)
733                    .unwrap_or_default();
734
735                // 返回原始大小写版本
736                Ok(successful_results
737                    .iter()
738                    .find(|r| r.output.content.trim().to_lowercase() == winner)
739                    .map(|r| r.output.content.clone())
740                    .unwrap_or(winner))
741            }
742
743            AggregationStrategy::LLMSummarize(llm) => {
744                let outputs: Vec<String> = successful_results
745                    .iter()
746                    .map(|r| format!("Expert '{}' says:\n{}", r.agent_name, r.output.content))
747                    .collect();
748
749                let prompt = format!(
750                    r#"You are tasked with synthesizing multiple expert opinions into a coherent summary.
751
752Here are the expert opinions:
753
754{}
755
756Please provide a comprehensive synthesis that:
7571. Identifies common themes and agreements
7582. Notes any significant disagreements
7593. Provides a balanced conclusion
760
761Synthesized Summary:"#,
762                    outputs.join("\n\n---\n\n")
763                );
764
765                llm.ask(&prompt).await
766            }
767
768            AggregationStrategy::Custom(f) => Ok(f(results.to_vec())),
769        }
770    }
771
772    /// 获取 Agent 数量
773    pub fn len(&self) -> usize {
774        self.agents.len()
775    }
776
777    /// 检查是否为空
778    pub fn is_empty(&self) -> bool {
779        self.agents.is_empty()
780    }
781}
782
783impl Default for ParallelAgent {
784    fn default() -> Self {
785        Self::new()
786    }
787}
788
/// Outcome of a parallel run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelResult {
    /// Unique ID of this run.
    pub parallel_id: String,
    /// The original task, before any per-agent template was applied.
    pub task: String,
    /// Combined output from the aggregation strategy (empty when aggregation was skipped).
    pub aggregated_output: String,
    /// Per-agent results, in agent insertion order.
    pub individual_results: Vec<ParallelStepResult>,
    /// `true` only if every agent succeeded.
    pub success: bool,
    /// Total wall-clock time in milliseconds, including aggregation.
    pub total_duration_ms: u64,
}
805
806impl ParallelResult {
807    /// 获取成功的结果数量
808    pub fn success_count(&self) -> usize {
809        self.individual_results.iter().filter(|r| r.success).count()
810    }
811
812    /// 获取失败的结果数量
813    pub fn failure_count(&self) -> usize {
814        self.individual_results
815            .iter()
816            .filter(|r| !r.success)
817            .count()
818    }
819
820    /// 获取指定 Agent 的结果
821    pub fn get_by_name(&self, name: &str) -> Option<&ParallelStepResult> {
822        self.individual_results
823            .iter()
824            .find(|r| r.agent_name == name)
825    }
826}
827
/// Result of a single agent inside a parallel run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelStepResult {
    /// Name of the agent.
    pub agent_name: String,
    /// Input task for the agent (after applying any task template).
    pub input: String,
    /// The agent's output; for failed runs, empty content with `error` set.
    pub output: AgentOutput,
    /// Whether this step is counted as successful.
    pub success: bool,
}
840
841// ============================================================================
842// 便捷构建函数
843// ============================================================================
844
845/// 创建简单的链式 Agent
846///
847/// # 示例
848///
849/// ```rust,ignore
850/// let chain = chain_agents(vec![
851///     ("researcher", researcher_agent),
852///     ("writer", writer_agent),
853///     ("editor", editor_agent),
854/// ]);
855/// ```
856pub fn chain_agents(agents: Vec<(&str, Arc<super::ReActAgent>)>) -> ChainAgent {
857    let mut chain = ChainAgent::new();
858    for (name, agent) in agents {
859        chain = chain.add(name, agent);
860    }
861    chain
862}
863
864/// 创建简单的并行 Agent
865///
866/// # 示例
867///
868/// ```rust,ignore
869/// let parallel = parallel_agents(vec![
870///     ("analyst1", analyst1_agent),
871///     ("analyst2", analyst2_agent),
872/// ]);
873/// ```
874pub fn parallel_agents(agents: Vec<(&str, Arc<super::ReActAgent>)>) -> ParallelAgent {
875    let mut parallel = ParallelAgent::new();
876    for (name, agent) in agents {
877        parallel = parallel.add(name, agent);
878    }
879    parallel
880}
881
882/// 创建带 LLM 聚合的并行 Agent
883///
884/// # 示例
885///
886/// ```rust,ignore
887/// let parallel = parallel_agents_with_summarizer(
888///     vec![
889///         ("expert1", expert1_agent),
890///         ("expert2", expert2_agent),
891///     ],
892///     summarizer_llm,
893/// );
894/// ```
895pub fn parallel_agents_with_summarizer(
896    agents: Vec<(&str, Arc<super::ReActAgent>)>,
897    summarizer: Arc<LLMAgent>,
898) -> ParallelAgent {
899    parallel_agents(agents).with_aggregation(AggregationStrategy::LLMSummarize(summarizer))
900}
901
902// ============================================================================
903// MapReduce 模式
904// ============================================================================
905
/// MapReduce execution pattern.
///
/// Splits the input into sub-tasks with a mapper function, processes the
/// sub-tasks in parallel with a worker agent, then reduces the results with a
/// reducer agent.
///
/// # Example
///
/// ```rust,ignore
/// let map_reduce = MapReduceAgent::new()
///     .with_mapper(|task| {
///         // Split the task into sub-tasks
///         task.split('\n').map(|s| s.to_string()).collect()
///     })
///     .with_worker(worker_agent)
///     .with_reducer(reducer_agent);
///
/// let result = map_reduce.run("line1\nline2\nline3").await?;
/// ```
pub struct MapReduceAgent {
    /// Map function: splits the input into sub-tasks.
    mapper: Option<Arc<dyn Fn(&str) -> Vec<String> + Send + Sync>>,
    /// Agent that processes each sub-task.
    worker: Option<AgentUnit>,
    /// Agent that combines the worker results.
    reducer: Option<AgentUnit>,
    /// Maximum number of sub-tasks processed concurrently (`None` = no semaphore / no limit).
    concurrency_limit: Option<usize>,
    /// Emit `tracing` progress logs.
    verbose: bool,
}
935
936impl MapReduceAgent {
937    /// 创建新的 MapReduce Agent
938    pub fn new() -> Self {
939        Self {
940            mapper: None,
941            worker: None,
942            reducer: None,
943            concurrency_limit: None,
944            verbose: true,
945        }
946    }
947
948    /// 设置 Map 函数
949    pub fn with_mapper<F>(mut self, f: F) -> Self
950    where
951        F: Fn(&str) -> Vec<String> + Send + Sync + 'static,
952    {
953        self.mapper = Some(Arc::new(f));
954        self
955    }
956
957    /// 设置工作 Agent (ReAct)
958    pub fn with_worker(mut self, agent: Arc<super::ReActAgent>) -> Self {
959        self.worker = Some(AgentUnit::react(agent));
960        self
961    }
962
963    /// 设置工作 Agent (LLM)
964    pub fn with_worker_llm(mut self, agent: Arc<LLMAgent>) -> Self {
965        self.worker = Some(AgentUnit::llm(agent));
966        self
967    }
968
969    /// 设置 Reduce Agent (ReAct)
970    pub fn with_reducer(mut self, agent: Arc<super::ReActAgent>) -> Self {
971        self.reducer = Some(AgentUnit::react(agent));
972        self
973    }
974
975    /// 设置 Reduce Agent (LLM)
976    pub fn with_reducer_llm(mut self, agent: Arc<LLMAgent>) -> Self {
977        self.reducer = Some(AgentUnit::llm(agent));
978        self
979    }
980
981    /// 设置并行度限制
982    pub fn with_concurrency_limit(mut self, limit: usize) -> Self {
983        self.concurrency_limit = Some(limit);
984        self
985    }
986
987    /// 设置是否详细输出
988    pub fn with_verbose(mut self, verbose: bool) -> Self {
989        self.verbose = verbose;
990        self
991    }
992
993    /// 执行 MapReduce
994    pub async fn run(&self, input: impl Into<String>) -> LLMResult<MapReduceResult> {
995        let input = input.into();
996        let start_time = std::time::Instant::now();
997        let mr_id = uuid::Uuid::now_v7().to_string();
998
999        // Map 阶段
1000        let mapper = self
1001            .mapper
1002            .as_ref()
1003            .ok_or_else(|| LLMError::ConfigError("Mapper not set".to_string()))?;
1004
1005        let sub_tasks = mapper(&input);
1006
1007        if self.verbose {
1008            tracing::info!("[MapReduce] Mapped to {} sub-tasks", sub_tasks.len());
1009        }
1010
1011        // 并行处理阶段
1012        let worker = self
1013            .worker
1014            .as_ref()
1015            .ok_or_else(|| LLMError::ConfigError("Worker not set".to_string()))?;
1016
1017        let mut handles = Vec::new();
1018        let semaphore = self
1019            .concurrency_limit
1020            .map(|limit| Arc::new(tokio::sync::Semaphore::new(limit)));
1021
1022        for (idx, sub_task) in sub_tasks.into_iter().enumerate() {
1023            let worker = worker.clone();
1024            let semaphore = semaphore.clone();
1025            let verbose = self.verbose;
1026
1027            let handle = tokio::spawn(async move {
1028                let _permit = if let Some(ref sem) = semaphore {
1029                    Some(sem.acquire().await)
1030                } else {
1031                    None
1032                };
1033
1034                if verbose {
1035                    tracing::info!("[MapReduce] Processing sub-task {}", idx + 1);
1036                }
1037
1038                let result = worker.run(&sub_task).await;
1039
1040                if verbose {
1041                    match &result {
1042                        Ok(_) => tracing::info!("[MapReduce] Sub-task {} completed", idx + 1),
1043                        Err(e) => {
1044                            tracing::warn!("[MapReduce] Sub-task {} failed: {}", idx + 1, e)
1045                        }
1046                    }
1047                }
1048
1049                (idx, sub_task, result)
1050            });
1051
1052            handles.push(handle);
1053        }
1054
1055        // 收集结果
1056        let mut map_results = Vec::new();
1057        for handle in handles {
1058            match handle.await {
1059                Ok((idx, sub_task, result)) => {
1060                    map_results.push(MapStepResult {
1061                        index: idx,
1062                        input: sub_task,
1063                        output: result.ok(),
1064                    });
1065                }
1066                Err(e) => {
1067                    map_results.push(MapStepResult {
1068                        index: map_results.len(),
1069                        input: String::new(),
1070                        output: None,
1071                    });
1072                    tracing::error!("[MapReduce] Task join error: {}", e);
1073                }
1074            }
1075        }
1076
1077        // 按索引排序
1078        map_results.sort_by_key(|r| r.index);
1079
1080        // Reduce 阶段
1081        let reducer = self
1082            .reducer
1083            .as_ref()
1084            .ok_or_else(|| LLMError::ConfigError("Reducer not set".to_string()))?;
1085
1086        let map_outputs: Vec<String> = map_results
1087            .iter()
1088            .filter_map(|r| r.output.as_ref().map(|o| o.content.clone()))
1089            .collect();
1090
1091        let reduce_input = format!(
1092            "Please synthesize the following {} results:\n\n{}",
1093            map_outputs.len(),
1094            map_outputs
1095                .iter()
1096                .enumerate()
1097                .map(|(i, o)| format!("[Result {}]\n{}", i + 1, o))
1098                .collect::<Vec<_>>()
1099                .join("\n\n---\n\n")
1100        );
1101
1102        if self.verbose {
1103            tracing::info!("[MapReduce] Starting reduce phase");
1104        }
1105
1106        let reduce_output = reducer.run(&reduce_input).await?;
1107
1108        Ok(MapReduceResult {
1109            mr_id,
1110            input,
1111            map_results,
1112            reduce_output,
1113            total_duration_ms: start_time.elapsed().as_millis() as u64,
1114        })
1115    }
1116}
1117
1118impl Default for MapReduceAgent {
1119    fn default() -> Self {
1120        Self::new()
1121    }
1122}
1123
1124/// MapReduce 执行结果
1125#[derive(Debug, Clone, Serialize, Deserialize)]
1126pub struct MapReduceResult {
1127    /// MapReduce ID
1128    pub mr_id: String,
1129    /// 原始输入
1130    pub input: String,
1131    /// Map 阶段结果
1132    pub map_results: Vec<MapStepResult>,
1133    /// Reduce 阶段输出
1134    pub reduce_output: AgentOutput,
1135    /// 总耗时 (毫秒)
1136    pub total_duration_ms: u64,
1137}
1138
1139/// Map 步骤结果
1140#[derive(Debug, Clone, Serialize, Deserialize)]
1141pub struct MapStepResult {
1142    /// 索引
1143    pub index: usize,
1144    /// 输入
1145    pub input: String,
1146    /// 输出
1147    pub output: Option<AgentOutput>,
1148}
1149
1150// ============================================================================
1151// 测试
1152// ============================================================================
1153
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chain_agent_builder() {
        let chain = ChainAgent::new()
            .with_continue_on_error(true)
            .with_verbose(false);

        assert!(chain.is_empty());
        assert!(!chain.verbose);
        assert!(chain.continue_on_error);
    }

    #[test]
    fn test_parallel_agent_builder() {
        let parallel = ParallelAgent::new()
            .with_aggregation(AggregationStrategy::Concatenate)
            .with_timeout_ms(5000)
            .with_verbose(false);

        assert!(parallel.is_empty());
        assert!(!parallel.verbose);
    }

    #[test]
    fn test_map_reduce_builder() {
        let mr = MapReduceAgent::new()
            .with_mapper(|s| s.lines().map(|l| l.to_string()).collect())
            .with_concurrency_limit(4);

        assert!(mr.mapper.is_some());
        assert_eq!(mr.concurrency_limit, Some(4));
    }

    /// `new()` and `default()` must both produce an unconfigured, verbose agent.
    #[test]
    fn test_map_reduce_defaults() {
        let check = |mr: MapReduceAgent| {
            assert!(mr.mapper.is_none());
            assert!(mr.worker.is_none());
            assert!(mr.reducer.is_none());
            assert_eq!(mr.concurrency_limit, None);
            assert!(mr.verbose);
        };
        check(MapReduceAgent::new());
        check(MapReduceAgent::default());
    }

    /// The stored mapper must be the closure that was passed in, not just "some" closure.
    #[test]
    fn test_map_reduce_mapper_splits_input() {
        let mr = MapReduceAgent::new()
            .with_mapper(|s| s.lines().map(|l| l.to_string()).collect())
            .with_verbose(false);

        let mapper = mr.mapper.as_ref().expect("mapper was just set");
        assert_eq!(mapper("a\nb\nc"), vec!["a", "b", "c"]);
        assert!(mapper("").is_empty());
        assert!(!mr.verbose);
    }
}