Skip to main content

mofa_foundation/llm/
agent_workflow.rs

1//! LLM Agent 工作流编排
2//!
3//! 将 LLMAgent 与 WorkflowGraph 结合,提供高级的多 Agent 工作流编排能力
4//!
5//! # 功能特性
6//!
7//! - **Agent 节点**: 将 LLMAgent 封装为工作流节点
8//! - **条件路由**: 基于 LLM 输出进行条件分支
9//! - **并行执行**: 多个 Agent 并行处理
10//! - **流式响应**: 支持工作流中的流式输出
11//! - **会话共享**: 工作流节点间共享会话上下文
12//!
13//! # 示例
14//!
15//! ```rust,ignore
16//! use mofa_foundation::llm::{AgentWorkflow, LLMAgent};
17//! use std::sync::Arc;
18//!
19//! // 创建 Agent 工作流
20//! let workflow = AgentWorkflow::new("content-pipeline")
21//!     .add_agent("researcher", researcher_agent)
22//!     .add_agent("writer", writer_agent)
23//!     .add_agent("editor", editor_agent)
24//!     .chain(["researcher", "writer", "editor"])
25//!     .build();
26//!
27//! // 执行工作流
28//! let result = workflow.run("Write an article about Rust").await?;
29//! ```
30
31use super::agent::LLMAgent;
32use super::types::{LLMError, LLMResult};
33use std::collections::HashMap;
34use std::future::Future;
35use std::pin::Pin;
36use std::sync::Arc;
37use tokio::sync::RwLock;
38
/// Kind of a node in an [`AgentWorkflow`] graph.
///
/// The kind determines how the workflow executes the node and how it chooses
/// the next node to visit. It is a plain fieldless tag, so it is `Copy` and
/// can be compared or hashed directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AgentNodeType {
    /// Entry node; passes its input through unchanged.
    Start,
    /// Terminal node; passes its input through and ends execution.
    End,
    /// Node backed by an `LLMAgent`.
    Agent,
    /// Conditional routing node; selects an outgoing edge by route name.
    Router,
    /// Parallel dispatch (fan-out) node.
    Parallel,
    /// Aggregation node that collects the outputs of other nodes.
    Join,
    /// Node that maps its input through a user-supplied async function.
    Transform,
}
57
58/// 节点输入/输出值
59#[derive(Debug, Clone)]
60pub enum AgentValue {
61    /// 空值
62    Null,
63    /// 文本
64    Text(String),
65    /// 多个文本(用于并行结果)
66    Texts(Vec<String>),
67    /// 键值对
68    Map(HashMap<String, String>),
69    /// JSON 值
70    Json(serde_json::Value),
71}
72
73impl AgentValue {
74    pub fn as_text(&self) -> Option<&str> {
75        match self {
76            AgentValue::Text(s) => Some(s),
77            _ => None,
78        }
79    }
80
81    pub fn into_text(self) -> String {
82        match self {
83            AgentValue::Text(s) => s,
84            AgentValue::Texts(v) => v.join("\n"),
85            AgentValue::Map(m) => serde_json::to_string(&m).unwrap_or_default(),
86            AgentValue::Json(j) => j.to_string(),
87            AgentValue::Null => String::new(),
88        }
89    }
90
91    pub fn as_texts(&self) -> Option<&Vec<String>> {
92        match self {
93            AgentValue::Texts(v) => Some(v),
94            _ => None,
95        }
96    }
97
98    pub fn as_map(&self) -> Option<&HashMap<String, String>> {
99        match self {
100            AgentValue::Map(m) => Some(m),
101            _ => None,
102        }
103    }
104}
105
106impl From<String> for AgentValue {
107    fn from(s: String) -> Self {
108        AgentValue::Text(s)
109    }
110}
111
112impl From<&str> for AgentValue {
113    fn from(s: &str) -> Self {
114        AgentValue::Text(s.to_string())
115    }
116}
117
118impl From<Vec<String>> for AgentValue {
119    fn from(v: Vec<String>) -> Self {
120        AgentValue::Texts(v)
121    }
122}
123
124impl From<HashMap<String, String>> for AgentValue {
125    fn from(m: HashMap<String, String>) -> Self {
126        AgentValue::Map(m)
127    }
128}
129
130/// 路由决策函数类型
131pub type RouterFn =
132    Arc<dyn Fn(AgentValue) -> Pin<Box<dyn Future<Output = String> + Send>> + Send + Sync>;
133
134/// 转换函数类型
135pub type TransformFn =
136    Arc<dyn Fn(AgentValue) -> Pin<Box<dyn Future<Output = AgentValue> + Send>> + Send + Sync>;
137
138/// 聚合函数类型
139pub type JoinFn = Arc<
140    dyn Fn(HashMap<String, AgentValue>) -> Pin<Box<dyn Future<Output = AgentValue> + Send>>
141        + Send
142        + Sync,
143>;
144
145/// Agent 工作流节点
146pub struct AgentNode {
147    /// 节点 ID
148    pub id: String,
149    /// 节点名称
150    pub name: String,
151    /// 节点类型
152    pub node_type: AgentNodeType,
153    /// Agent 引用(仅 Agent 节点)
154    agent: Option<Arc<LLMAgent>>,
155    /// 路由函数(仅 Router 节点)
156    router: Option<RouterFn>,
157    /// 转换函数(仅 Transform 节点)
158    transform: Option<TransformFn>,
159    /// 聚合函数(仅 Join 节点)
160    join_fn: Option<JoinFn>,
161    /// 等待的节点列表(仅 Join 节点)
162    wait_for: Vec<String>,
163    /// 提示词模板(Agent 节点使用)
164    prompt_template: Option<String>,
165    /// 会话 ID(用于多轮对话)
166    session_id: Option<String>,
167}
168
169impl AgentNode {
170    /// 创建开始节点
171    pub fn start() -> Self {
172        Self {
173            id: "start".to_string(),
174            name: "Start".to_string(),
175            node_type: AgentNodeType::Start,
176            agent: None,
177            router: None,
178            transform: None,
179            join_fn: None,
180            wait_for: Vec::new(),
181            prompt_template: None,
182            session_id: None,
183        }
184    }
185
186    /// 创建结束节点
187    pub fn end() -> Self {
188        Self {
189            id: "end".to_string(),
190            name: "End".to_string(),
191            node_type: AgentNodeType::End,
192            agent: None,
193            router: None,
194            transform: None,
195            join_fn: None,
196            wait_for: Vec::new(),
197            prompt_template: None,
198            session_id: None,
199        }
200    }
201
202    /// 创建 Agent 节点
203    pub fn agent(id: impl Into<String>, agent: Arc<LLMAgent>) -> Self {
204        let id = id.into();
205        Self {
206            name: id.clone(),
207            id,
208            node_type: AgentNodeType::Agent,
209            agent: Some(agent),
210            router: None,
211            transform: None,
212            join_fn: None,
213            wait_for: Vec::new(),
214            prompt_template: None,
215            session_id: None,
216        }
217    }
218
219    /// 创建路由节点
220    pub fn router<F, Fut>(id: impl Into<String>, router_fn: F) -> Self
221    where
222        F: Fn(AgentValue) -> Fut + Send + Sync + 'static,
223        Fut: Future<Output = String> + Send + 'static,
224    {
225        let id = id.into();
226        Self {
227            name: id.clone(),
228            id,
229            node_type: AgentNodeType::Router,
230            agent: None,
231            router: Some(Arc::new(move |input| Box::pin(router_fn(input)))),
232            transform: None,
233            join_fn: None,
234            wait_for: Vec::new(),
235            prompt_template: None,
236            session_id: None,
237        }
238    }
239
240    /// 创建并行节点
241    pub fn parallel(id: impl Into<String>) -> Self {
242        let id = id.into();
243        Self {
244            name: id.clone(),
245            id,
246            node_type: AgentNodeType::Parallel,
247            agent: None,
248            router: None,
249            transform: None,
250            join_fn: None,
251            wait_for: Vec::new(),
252            prompt_template: None,
253            session_id: None,
254        }
255    }
256
257    /// 创建聚合节点
258    pub fn join(id: impl Into<String>, wait_for: Vec<String>) -> Self {
259        let id = id.into();
260        Self {
261            name: id.clone(),
262            id,
263            node_type: AgentNodeType::Join,
264            agent: None,
265            router: None,
266            transform: None,
267            join_fn: None,
268            wait_for,
269            prompt_template: None,
270            session_id: None,
271        }
272    }
273
274    /// 创建带自定义聚合函数的聚合节点
275    pub fn join_with<F, Fut>(id: impl Into<String>, wait_for: Vec<String>, join_fn: F) -> Self
276    where
277        F: Fn(HashMap<String, AgentValue>) -> Fut + Send + Sync + 'static,
278        Fut: Future<Output = AgentValue> + Send + 'static,
279    {
280        let id = id.into();
281        Self {
282            name: id.clone(),
283            id,
284            node_type: AgentNodeType::Join,
285            agent: None,
286            router: None,
287            transform: None,
288            join_fn: Some(Arc::new(move |inputs| Box::pin(join_fn(inputs)))),
289            wait_for,
290            prompt_template: None,
291            session_id: None,
292        }
293    }
294
295    /// 创建转换节点
296    pub fn transform<F, Fut>(id: impl Into<String>, transform_fn: F) -> Self
297    where
298        F: Fn(AgentValue) -> Fut + Send + Sync + 'static,
299        Fut: Future<Output = AgentValue> + Send + 'static,
300    {
301        let id = id.into();
302        Self {
303            name: id.clone(),
304            id,
305            node_type: AgentNodeType::Transform,
306            agent: None,
307            router: None,
308            transform: Some(Arc::new(move |input| Box::pin(transform_fn(input)))),
309            join_fn: None,
310            wait_for: Vec::new(),
311            prompt_template: None,
312            session_id: None,
313        }
314    }
315
316    /// 设置名称
317    pub fn with_name(mut self, name: impl Into<String>) -> Self {
318        self.name = name.into();
319        self
320    }
321
322    /// 设置提示词模板
323    ///
324    /// 模板中可以使用 `{input}` 占位符
325    pub fn with_prompt_template(mut self, template: impl Into<String>) -> Self {
326        self.prompt_template = Some(template.into());
327        self
328    }
329
330    /// 设置会话 ID
331    pub fn with_session(mut self, session_id: impl Into<String>) -> Self {
332        self.session_id = Some(session_id.into());
333        self
334    }
335}
336
/// Directed edge between two workflow nodes.
///
/// Router nodes select the outgoing edge whose `condition` equals the route
/// name returned by their routing callback.
#[derive(Debug, Clone)]
pub struct AgentEdge {
    /// ID of the node the edge leaves.
    pub from: String,
    /// ID of the node the edge enters.
    pub to: String,
    /// Route name guarding the edge; `None` for an unconditional edge.
    pub condition: Option<String>,
}

impl AgentEdge {
    /// Creates an unconditional edge `from -> to`.
    pub fn new(from: impl Into<String>, to: impl Into<String>) -> Self {
        let source = from.into();
        let target = to.into();
        Self {
            condition: None,
            from: source,
            to: target,
        }
    }

    /// Creates an edge `from -> to` guarded by the route name `condition`.
    pub fn conditional(
        from: impl Into<String>,
        to: impl Into<String>,
        condition: impl Into<String>,
    ) -> Self {
        let edge = Self::new(from, to);
        Self {
            condition: Some(condition.into()),
            ..edge
        }
    }
}
369
370/// Agent 工作流执行上下文
371pub struct AgentWorkflowContext {
372    /// 工作流 ID
373    pub workflow_id: String,
374    /// 执行 ID
375    pub execution_id: String,
376    /// 节点输出
377    node_outputs: Arc<RwLock<HashMap<String, AgentValue>>>,
378    /// 共享会话 ID(用于多 Agent 共享上下文)
379    shared_session_id: Option<String>,
380    /// 变量存储
381    variables: Arc<RwLock<HashMap<String, String>>>,
382}
383
384impl AgentWorkflowContext {
385    pub fn new(workflow_id: impl Into<String>) -> Self {
386        Self {
387            workflow_id: workflow_id.into(),
388            execution_id: uuid::Uuid::now_v7().to_string(),
389            node_outputs: Arc::new(RwLock::new(HashMap::new())),
390            shared_session_id: None,
391            variables: Arc::new(RwLock::new(HashMap::new())),
392        }
393    }
394
395    pub fn with_shared_session(mut self, session_id: impl Into<String>) -> Self {
396        self.shared_session_id = Some(session_id.into());
397        self
398    }
399
400    pub async fn set_output(&self, node_id: &str, value: AgentValue) {
401        let mut outputs = self.node_outputs.write().await;
402        outputs.insert(node_id.to_string(), value);
403    }
404
405    pub async fn get_output(&self, node_id: &str) -> Option<AgentValue> {
406        let outputs = self.node_outputs.read().await;
407        outputs.get(node_id).cloned()
408    }
409
410    pub async fn get_outputs(&self, node_ids: &[String]) -> HashMap<String, AgentValue> {
411        let outputs = self.node_outputs.read().await;
412        node_ids
413            .iter()
414            .filter_map(|id| outputs.get(id).map(|v| (id.clone(), v.clone())))
415            .collect()
416    }
417
418    pub async fn set_variable(&self, key: &str, value: &str) {
419        let mut vars = self.variables.write().await;
420        vars.insert(key.to_string(), value.to_string());
421    }
422
423    pub async fn get_variable(&self, key: &str) -> Option<String> {
424        let vars = self.variables.read().await;
425        vars.get(key).cloned()
426    }
427}
428
429impl Clone for AgentWorkflowContext {
430    fn clone(&self) -> Self {
431        Self {
432            workflow_id: self.workflow_id.clone(),
433            execution_id: self.execution_id.clone(),
434            node_outputs: self.node_outputs.clone(),
435            shared_session_id: self.shared_session_id.clone(),
436            variables: self.variables.clone(),
437        }
438    }
439}
440
441/// Agent 工作流
442pub struct AgentWorkflow {
443    /// 工作流 ID
444    pub id: String,
445    /// 工作流名称
446    pub name: String,
447    /// 节点映射
448    nodes: HashMap<String, AgentNode>,
449    /// 边列表
450    edges: Vec<AgentEdge>,
451    /// 邻接表(源节点 -> 边列表)
452    adjacency: HashMap<String, Vec<AgentEdge>>,
453}
454
455impl AgentWorkflow {
456    /// 创建新的工作流
457    pub fn new(id: impl Into<String>) -> AgentWorkflowBuilder {
458        AgentWorkflowBuilder::new(id)
459    }
460
461    /// 执行工作流
462    pub async fn run(&self, input: impl Into<AgentValue>) -> LLMResult<AgentValue> {
463        let ctx = AgentWorkflowContext::new(&self.id);
464        self.run_with_context(&ctx, input).await
465    }
466
467    /// 使用指定上下文执行工作流
468    pub async fn run_with_context(
469        &self,
470        ctx: &AgentWorkflowContext,
471        input: impl Into<AgentValue>,
472    ) -> LLMResult<AgentValue> {
473        let input = input.into();
474        let mut current_node_id = "start".to_string();
475        let mut current_input = input;
476
477        loop {
478            let node = self
479                .nodes
480                .get(&current_node_id)
481                .ok_or_else(|| LLMError::Other(format!("Node '{}' not found", current_node_id)))?;
482
483            // 执行节点
484            let output = self.execute_node(ctx, node, current_input.clone()).await?;
485
486            // 保存输出
487            ctx.set_output(&current_node_id, output.clone()).await;
488
489            // 确定下一个节点
490            match self.get_next_node(&current_node_id, &output).await {
491                Some(next_id) => {
492                    current_node_id = next_id;
493                    current_input = output;
494                }
495                None => {
496                    // 工作流结束
497                    return Ok(output);
498                }
499            }
500        }
501    }
502
503    /// 执行单个节点
504    async fn execute_node(
505        &self,
506        ctx: &AgentWorkflowContext,
507        node: &AgentNode,
508        input: AgentValue,
509    ) -> LLMResult<AgentValue> {
510        match node.node_type {
511            AgentNodeType::Start | AgentNodeType::End => Ok(input),
512
513            AgentNodeType::Agent => {
514                let agent = node
515                    .agent
516                    .as_ref()
517                    .ok_or_else(|| LLMError::Other("Agent not set".to_string()))?;
518
519                // 构建提示词
520                let prompt = if let Some(ref template) = node.prompt_template {
521                    template.replace("{input}", &input.clone().into_text())
522                } else {
523                    input.clone().into_text()
524                };
525
526                // 确定会话 ID
527                let session_id = node
528                    .session_id
529                    .clone()
530                    .or_else(|| ctx.shared_session_id.clone());
531
532                // 发送消息
533                let response = if let Some(sid) = session_id {
534                    // 确保会话存在
535                    let _ = agent.get_or_create_session(&sid).await;
536                    agent.chat_with_session(&sid, &prompt).await?
537                } else {
538                    agent.ask(&prompt).await?
539                };
540
541                Ok(AgentValue::Text(response))
542            }
543
544            AgentNodeType::Router => {
545                let router = node
546                    .router
547                    .as_ref()
548                    .ok_or_else(|| LLMError::Other("Router function not set".to_string()))?;
549                let _route = router(input.clone()).await;
550                // 路由节点返回原输入,路由决策在 get_next_node 中使用
551                Ok(input)
552            }
553
554            AgentNodeType::Parallel => {
555                // 并行节点直接传递输入,实际并行执行在工作流执行逻辑中处理
556                Ok(input)
557            }
558
559            AgentNodeType::Join => {
560                // 收集所有前置节点的输出
561                let outputs = ctx.get_outputs(&node.wait_for).await;
562
563                if let Some(ref join_fn) = node.join_fn {
564                    Ok(join_fn(outputs).await)
565                } else {
566                    // 默认聚合:合并所有文本输出
567                    let texts: Vec<String> = outputs.into_values().map(|v| v.into_text()).collect();
568                    Ok(AgentValue::Texts(texts))
569                }
570            }
571
572            AgentNodeType::Transform => {
573                let transform = node
574                    .transform
575                    .as_ref()
576                    .ok_or_else(|| LLMError::Other("Transform function not set".to_string()))?;
577                Ok(transform(input).await)
578            }
579        }
580    }
581
582    /// 获取下一个节点
583    async fn get_next_node(&self, current_id: &str, output: &AgentValue) -> Option<String> {
584        let node = self.nodes.get(current_id)?;
585
586        // 结束节点没有后续
587        if matches!(node.node_type, AgentNodeType::End) {
588            return None;
589        }
590
591        let edges = self.adjacency.get(current_id)?;
592
593        // 路由节点:根据路由函数结果选择边
594        if matches!(node.node_type, AgentNodeType::Router) {
595            if let Some(ref router) = node.router {
596                let route = router(output.clone()).await;
597                for edge in edges {
598                    if edge.condition.as_ref() == Some(&route) {
599                        return Some(edge.to.clone());
600                    }
601                }
602            }
603            // 如果没有匹配的条件边,使用默认边(无条件)
604            for edge in edges {
605                if edge.condition.is_none() {
606                    return Some(edge.to.clone());
607                }
608            }
609            return None;
610        }
611
612        // 非路由节点:使用第一条边
613        edges.first().map(|e| e.to.clone())
614    }
615
616    /// 获取节点
617    pub fn get_node(&self, id: &str) -> Option<&AgentNode> {
618        self.nodes.get(id)
619    }
620
621    /// 获取所有节点 ID
622    pub fn node_ids(&self) -> Vec<&str> {
623        self.nodes.keys().map(|s| s.as_str()).collect()
624    }
625}
626
627/// Agent 工作流构建器
628pub struct AgentWorkflowBuilder {
629    id: String,
630    name: String,
631    nodes: HashMap<String, AgentNode>,
632    edges: Vec<AgentEdge>,
633    current_node: Option<String>,
634}
635
636impl AgentWorkflowBuilder {
637    /// 创建新的构建器
638    pub fn new(id: impl Into<String>) -> Self {
639        let id = id.into();
640        let mut builder = Self {
641            name: id.clone(),
642            id,
643            nodes: HashMap::new(),
644            edges: Vec::new(),
645            current_node: None,
646        };
647        // 自动添加 start 节点
648        builder
649            .nodes
650            .insert("start".to_string(), AgentNode::start());
651        builder.current_node = Some("start".to_string());
652        builder
653    }
654
655    /// 设置名称
656    pub fn with_name(mut self, name: impl Into<String>) -> Self {
657        self.name = name.into();
658        self
659    }
660
661    /// 添加 Agent 节点
662    pub fn add_agent(mut self, id: impl Into<String>, agent: Arc<LLMAgent>) -> Self {
663        let id = id.into();
664        let node = AgentNode::agent(&id, agent);
665        self.nodes.insert(id, node);
666        self
667    }
668
669    /// 添加带提示词模板的 Agent 节点
670    pub fn add_agent_with_template(
671        mut self,
672        id: impl Into<String>,
673        agent: Arc<LLMAgent>,
674        template: impl Into<String>,
675    ) -> Self {
676        let id = id.into();
677        let node = AgentNode::agent(&id, agent).with_prompt_template(template);
678        self.nodes.insert(id, node);
679        self
680    }
681
682    /// 添加路由节点
683    pub fn add_router<F, Fut>(mut self, id: impl Into<String>, router_fn: F) -> Self
684    where
685        F: Fn(AgentValue) -> Fut + Send + Sync + 'static,
686        Fut: Future<Output = String> + Send + 'static,
687    {
688        let id = id.into();
689        let node = AgentNode::router(&id, router_fn);
690        self.nodes.insert(id, node);
691        self
692    }
693
694    /// 添加基于 LLM 的智能路由节点
695    pub fn add_llm_router(
696        mut self,
697        id: impl Into<String>,
698        router_agent: Arc<LLMAgent>,
699        routes: Vec<String>,
700    ) -> Self {
701        let id = id.into();
702        let routes_str = routes.join(", ");
703        let prompt = format!(
704            "Based on the following input, choose the most appropriate route. \
705            Available routes: {}. \
706            Respond with ONLY the route name, nothing else.\n\nInput: {{input}}",
707            routes_str
708        );
709
710        let routes_clone = routes.clone();
711        let router_fn = move |input: AgentValue| {
712            let agent = router_agent.clone();
713            let prompt = prompt.replace("{input}", &input.into_text());
714            let valid_routes = routes_clone.clone();
715            async move {
716                match agent.ask(&prompt).await {
717                    Ok(response) => {
718                        let route = response.trim().to_string();
719                        if valid_routes.contains(&route) {
720                            route
721                        } else {
722                            valid_routes.first().cloned().unwrap_or_default()
723                        }
724                    }
725                    Err(_) => valid_routes.first().cloned().unwrap_or_default(),
726                }
727            }
728        };
729
730        let node = AgentNode::router(&id, router_fn);
731        self.nodes.insert(id, node);
732        self
733    }
734
735    /// 添加转换节点
736    pub fn add_transform<F, Fut>(mut self, id: impl Into<String>, transform_fn: F) -> Self
737    where
738        F: Fn(AgentValue) -> Fut + Send + Sync + 'static,
739        Fut: Future<Output = AgentValue> + Send + 'static,
740    {
741        let id = id.into();
742        let node = AgentNode::transform(&id, transform_fn);
743        self.nodes.insert(id, node);
744        self
745    }
746
747    /// 添加并行节点
748    pub fn add_parallel(mut self, id: impl Into<String>) -> Self {
749        let id = id.into();
750        let node = AgentNode::parallel(&id);
751        self.nodes.insert(id, node);
752        self
753    }
754
755    /// 添加聚合节点
756    pub fn add_join(mut self, id: impl Into<String>, wait_for: Vec<&str>) -> Self {
757        let id = id.into();
758        let wait_for: Vec<String> = wait_for.into_iter().map(|s| s.to_string()).collect();
759        let node = AgentNode::join(&id, wait_for);
760        self.nodes.insert(id, node);
761        self
762    }
763
764    /// 添加带自定义函数的聚合节点
765    pub fn add_join_with<F, Fut>(
766        mut self,
767        id: impl Into<String>,
768        wait_for: Vec<&str>,
769        join_fn: F,
770    ) -> Self
771    where
772        F: Fn(HashMap<String, AgentValue>) -> Fut + Send + Sync + 'static,
773        Fut: Future<Output = AgentValue> + Send + 'static,
774    {
775        let id = id.into();
776        let wait_for: Vec<String> = wait_for.into_iter().map(|s| s.to_string()).collect();
777        let node = AgentNode::join_with(&id, wait_for, join_fn);
778        self.nodes.insert(id, node);
779        self
780    }
781
782    /// 添加边
783    pub fn connect(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
784        self.edges.push(AgentEdge::new(from, to));
785        self
786    }
787
788    /// 添加条件边
789    pub fn connect_on(
790        mut self,
791        from: impl Into<String>,
792        to: impl Into<String>,
793        condition: impl Into<String>,
794    ) -> Self {
795        self.edges.push(AgentEdge::conditional(from, to, condition));
796        self
797    }
798
799    /// 链式连接多个节点
800    ///
801    /// 自动将 start 连接到第一个节点,并将最后一个节点连接到 end
802    pub fn chain<S: Into<String> + Clone>(mut self, node_ids: impl IntoIterator<Item = S>) -> Self {
803        let ids: Vec<String> = node_ids.into_iter().map(|s| s.into()).collect();
804
805        if ids.is_empty() {
806            return self;
807        }
808
809        // 连接 start 到第一个节点
810        self.edges.push(AgentEdge::new("start", &ids[0]));
811
812        // 连接中间节点
813        for i in 0..ids.len() - 1 {
814            self.edges.push(AgentEdge::new(&ids[i], &ids[i + 1]));
815        }
816
817        // 添加 end 节点并连接
818        self.nodes.insert("end".to_string(), AgentNode::end());
819        self.edges.push(AgentEdge::new(ids.last().unwrap(), "end"));
820
821        self
822    }
823
824    /// 配置并行执行
825    ///
826    /// 从 parallel_node 分发到多个 Agent,然后在 join_node 聚合
827    pub fn parallel_agents(
828        mut self,
829        parallel_id: impl Into<String>,
830        agent_ids: Vec<&str>,
831        join_id: impl Into<String>,
832    ) -> Self {
833        let parallel_id = parallel_id.into();
834        let join_id = join_id.into();
835
836        // 添加并行节点
837        self.nodes
838            .insert(parallel_id.clone(), AgentNode::parallel(&parallel_id));
839
840        // 添加聚合节点
841        let wait_for: Vec<String> = agent_ids.iter().map(|s| s.to_string()).collect();
842        self.nodes
843            .insert(join_id.clone(), AgentNode::join(&join_id, wait_for));
844
845        // 连接并行节点到各个 Agent
846        for agent_id in &agent_ids {
847            self.edges.push(AgentEdge::new(&parallel_id, *agent_id));
848            self.edges.push(AgentEdge::new(*agent_id, &join_id));
849        }
850
851        self
852    }
853
854    /// 构建工作流
855    pub fn build(self) -> AgentWorkflow {
856        // 构建邻接表
857        let mut adjacency: HashMap<String, Vec<AgentEdge>> = HashMap::new();
858        for edge in &self.edges {
859            adjacency
860                .entry(edge.from.clone())
861                .or_default()
862                .push(edge.clone());
863        }
864
865        AgentWorkflow {
866            id: self.id,
867            name: self.name,
868            nodes: self.nodes,
869            edges: self.edges,
870            adjacency,
871        }
872    }
873}
874
875// ============================================================================
876// 便捷函数
877// ============================================================================
878
879/// 创建简单的顺序 Agent 工作流
880///
881/// # 示例
882///
883/// ```rust,ignore
884/// let workflow = agent_chain("my-pipeline", vec![
885///     ("researcher", researcher_agent),
886///     ("writer", writer_agent),
887/// ]);
888/// ```
889pub fn agent_chain<S: Into<String>>(
890    id: S,
891    agents: Vec<(impl Into<String>, Arc<LLMAgent>)>,
892) -> AgentWorkflow {
893    let mut builder = AgentWorkflowBuilder::new(id);
894    let mut ids = Vec::new();
895
896    for (agent_id, agent) in agents {
897        let agent_id = agent_id.into();
898        ids.push(agent_id.clone());
899        builder = builder.add_agent(agent_id, agent);
900    }
901
902    builder.chain(ids).build()
903}
904
905/// 创建并行 Agent 工作流
906///
907/// 所有 Agent 同时处理输入,结果合并后返回
908pub fn agent_parallel<S: Into<String>>(
909    id: S,
910    agents: Vec<(impl Into<String>, Arc<LLMAgent>)>,
911) -> AgentWorkflow {
912    let mut builder = AgentWorkflowBuilder::new(id);
913    let mut ids = Vec::new();
914
915    for (agent_id, agent) in agents {
916        let agent_id = agent_id.into();
917        ids.push(agent_id.clone());
918        builder = builder.add_agent(agent_id, agent);
919    }
920
921    let ids_ref: Vec<&str> = ids.iter().map(|s| s.as_str()).collect();
922
923    builder
924        .parallel_agents("parallel", ids_ref, "join")
925        .connect("start", "parallel")
926        .connect("join", "end")
927        .build()
928}
929
930/// 创建带路由的 Agent 工作流
931///
932/// 根据路由 Agent 的决策选择执行哪个 Agent
933pub fn agent_router<S: Into<String>>(
934    id: S,
935    router_agent: Arc<LLMAgent>,
936    routes: Vec<(impl Into<String>, Arc<LLMAgent>)>,
937) -> AgentWorkflow {
938    let mut builder = AgentWorkflowBuilder::new(id);
939    let mut route_names = Vec::new();
940
941    for (route_id, agent) in routes {
942        let route_id = route_id.into();
943        route_names.push(route_id.clone());
944        builder = builder.add_agent(&route_id, agent);
945    }
946
947    // 添加 LLM 路由器
948    builder = builder.add_llm_router("router", router_agent, route_names.clone());
949
950    // 连接 start -> router
951    builder = builder.connect("start", "router");
952
953    // 添加 end 节点
954    builder.nodes.insert("end".to_string(), AgentNode::end());
955
956    // 连接 router -> 各个 route -> end
957    for route_name in &route_names {
958        builder = builder.connect_on("router", route_name, route_name);
959        builder = builder.connect(route_name, "end");
960    }
961
962    builder.build()
963}
964
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_agent_value_conversions() {
        let v: AgentValue = "hello".into();
        assert_eq!(v.as_text(), Some("hello"));

        let v: AgentValue = "world".to_string().into();
        assert_eq!(v.as_text(), Some("world"));

        let v: AgentValue = vec!["a".to_string(), "b".to_string()].into();
        assert_eq!(v.as_texts().map(|v| v.len()), Some(2));
    }

    #[test]
    fn test_into_text_variants() {
        assert_eq!(AgentValue::Null.into_text(), "");
        assert_eq!(AgentValue::Text("x".to_string()).into_text(), "x");
        assert_eq!(
            AgentValue::Texts(vec!["a".to_string(), "b".to_string()]).into_text(),
            "a\nb"
        );
        assert_eq!(
            AgentValue::Json(serde_json::json!({ "k": 1 })).into_text(),
            r#"{"k":1}"#
        );
    }

    #[test]
    fn test_edge_constructors() {
        let edge = AgentEdge::conditional("router", "writer", "write");
        assert_eq!(edge.from, "router");
        assert_eq!(edge.to, "writer");
        assert_eq!(edge.condition.as_deref(), Some("write"));
        assert!(AgentEdge::new("a", "b").condition.is_none());
    }

    #[test]
    fn test_workflow_builder() {
        let workflow = AgentWorkflowBuilder::new("test")
            .with_name("Test Workflow")
            .add_transform("uppercase", |input: AgentValue| async move {
                AgentValue::Text(input.into_text().to_uppercase())
            })
            .chain(["uppercase"])
            .build();

        assert_eq!(workflow.node_ids().len(), 3); // start, uppercase, end
    }

    #[test]
    fn test_chain_builder() {
        let workflow = AgentWorkflowBuilder::new("chain-test")
            .add_transform("step1", |input| async move { input })
            .add_transform("step2", |input| async move { input })
            .add_transform("step3", |input| async move { input })
            .chain(["step1", "step2", "step3"])
            .build();

        assert!(workflow.get_node("start").is_some());
        assert!(workflow.get_node("end").is_some());
        assert!(workflow.get_node("step1").is_some());
        assert!(workflow.get_node("step2").is_some());
        assert!(workflow.get_node("step3").is_some());
    }

    #[test]
    fn test_chain_wires_start_and_end() {
        let workflow = AgentWorkflowBuilder::new("wiring")
            .add_transform("a", |input| async move { input })
            .add_transform("b", |input| async move { input })
            .chain(["a", "b"])
            .build();

        let targets = |from: &str| -> Vec<String> {
            workflow
                .adjacency
                .get(from)
                .map(|edges| edges.iter().map(|e| e.to.clone()).collect())
                .unwrap_or_default()
        };

        assert_eq!(targets("start"), vec!["a"]);
        assert_eq!(targets("a"), vec!["b"]);
        assert_eq!(targets("b"), vec!["end"]);
        assert!(targets("end").is_empty());
    }

    #[test]
    fn test_parallel_agents_wiring() {
        let workflow = AgentWorkflowBuilder::new("fan-out")
            .add_transform("x", |input| async move { input })
            .add_transform("y", |input| async move { input })
            .parallel_agents("fan", vec!["x", "y"], "merge")
            .build();

        let fan_targets: Vec<&str> = workflow.adjacency["fan"]
            .iter()
            .map(|e| e.to.as_str())
            .collect();
        assert_eq!(fan_targets, vec!["x", "y"]);
        assert_eq!(workflow.adjacency["x"][0].to, "merge");
        assert_eq!(workflow.adjacency["y"][0].to, "merge");

        let merge = workflow.get_node("merge").expect("join node exists");
        assert!(matches!(merge.node_type, AgentNodeType::Join));
        assert_eq!(merge.wait_for, vec!["x".to_string(), "y".to_string()]);
    }
}