Skip to main content

mofa_runtime/agent/
execution.rs

//! Execution engine.
//!
//! Provides agent execution, workflow orchestration, and error handling.
4
5use crate::agent::context::{AgentContext, AgentEvent};
6use crate::agent::core::MoFAAgent;
7use crate::agent::error::{AgentError, AgentResult};
8use crate::agent::plugins::{PluginExecutor, PluginRegistry, SimplePluginRegistry};
9use crate::agent::registry::AgentRegistry;
10use crate::agent::types::{AgentInput, AgentOutput, AgentState};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::sync::RwLock;
16use tokio::time::timeout;
17
18/// 执行选项
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ExecutionOptions {
21    /// 超时时间 (毫秒)
22    #[serde(default)]
23    pub timeout_ms: Option<u64>,
24
25    /// 是否启用追踪
26    #[serde(default = "default_tracing")]
27    pub tracing_enabled: bool,
28
29    /// 重试次数
30    #[serde(default)]
31    pub max_retries: usize,
32
33    /// 重试延迟 (毫秒)
34    #[serde(default = "default_retry_delay")]
35    pub retry_delay_ms: u64,
36
37    /// 自定义参数
38    #[serde(default)]
39    pub custom: HashMap<String, serde_json::Value>,
40}
41
/// Serde default for `ExecutionOptions::tracing_enabled`: tracing stays on unless disabled.
const fn default_tracing() -> bool {
    true
}
45
/// Serde default for `ExecutionOptions::retry_delay_ms`, in milliseconds.
///
/// Must equal the value used by `impl Default for ExecutionOptions` (1000 ms).
/// Otherwise deserializing a document that omits the field and calling
/// `ExecutionOptions::default()` would produce different retry delays.
fn default_retry_delay() -> u64 {
    1000
}
49
50impl Default for ExecutionOptions {
51    fn default() -> Self {
52        Self {
53            timeout_ms: None,
54            tracing_enabled: true,
55            max_retries: 0,
56            retry_delay_ms: 1000,
57            custom: HashMap::new(),
58        }
59    }
60}
61
62impl ExecutionOptions {
63    /// 创建新的执行选项
64    pub fn new() -> Self {
65        Self::default()
66    }
67
68    /// 设置超时
69    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
70        self.timeout_ms = Some(timeout_ms);
71        self
72    }
73
74    /// 设置重试
75    pub fn with_retry(mut self, max_retries: usize, retry_delay_ms: u64) -> Self {
76        self.max_retries = max_retries;
77        self.retry_delay_ms = retry_delay_ms;
78        self
79    }
80
81    /// 禁用追踪
82    pub fn without_tracing(mut self) -> Self {
83        self.tracing_enabled = false;
84        self
85    }
86}
87
88/// 执行状态
89#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
90pub enum ExecutionStatus {
91    /// 待执行
92    Pending,
93    /// 执行中
94    Running,
95    /// 成功
96    Success,
97    /// 失败
98    Failed,
99    /// 超时
100    Timeout,
101    /// 中断
102    Interrupted,
103    /// 重试中
104    Retrying { attempt: usize },
105}
106
107/// 执行结果
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct ExecutionResult {
110    /// 执行 ID
111    pub execution_id: String,
112    /// Agent ID
113    pub agent_id: String,
114    /// 状态
115    pub status: ExecutionStatus,
116    /// 输出
117    pub output: Option<AgentOutput>,
118    /// 错误信息
119    pub error: Option<String>,
120    /// 执行时间 (毫秒)
121    pub duration_ms: u64,
122    /// 重试次数
123    pub retries: usize,
124    /// 元数据
125    pub metadata: HashMap<String, serde_json::Value>,
126}
127
128impl ExecutionResult {
129    /// 创建成功结果
130    pub fn success(
131        execution_id: String,
132        agent_id: String,
133        output: AgentOutput,
134        duration_ms: u64,
135    ) -> Self {
136        Self {
137            execution_id,
138            agent_id,
139            status: ExecutionStatus::Success,
140            output: Some(output),
141            error: None,
142            duration_ms,
143            retries: 0,
144            metadata: HashMap::new(),
145        }
146    }
147
148    /// 创建失败结果
149    pub fn failure(
150        execution_id: String,
151        agent_id: String,
152        error: String,
153        duration_ms: u64,
154    ) -> Self {
155        Self {
156            execution_id,
157            agent_id,
158            status: ExecutionStatus::Failed,
159            output: None,
160            error: Some(error),
161            duration_ms,
162            retries: 0,
163            metadata: HashMap::new(),
164        }
165    }
166
167    /// 是否成功
168    pub fn is_success(&self) -> bool {
169        self.status == ExecutionStatus::Success
170    }
171
172    /// 添加元数据
173    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
174        self.metadata.insert(key.into(), value);
175        self
176    }
177}
178
179/// 执行引擎
180///
181/// 提供 Agent 执行、工作流编排等功能
182///
183/// # 示例
184///
185/// ```rust,ignore
186/// use mofa_runtime::agent::execution::{ExecutionEngine, ExecutionOptions};
187///
188/// let registry = AgentRegistry::new();
189/// // ... 注册 Agent ...
190///
191/// let engine = ExecutionEngine::new(registry);
192///
193/// let result = engine.execute(
194///     "my-agent",
195///     AgentInput::text("Hello"),
196///     ExecutionOptions::default(),
197/// ).await?;
198///
199/// if result.is_success() {
200///     info!("Output: {:?}", result.output);
201/// }
202/// ```
203pub struct ExecutionEngine {
204    /// Agent 注册中心
205    registry: Arc<AgentRegistry>,
206    /// 插件执行器
207    plugin_executor: PluginExecutor,
208}
209
210impl ExecutionEngine {
211    /// 创建新的执行引擎
212    pub fn new(registry: Arc<AgentRegistry>) -> Self {
213        Self {
214            registry,
215            plugin_executor: PluginExecutor::new(Arc::new(SimplePluginRegistry::new())),
216        }
217    }
218
219    /// 创建带有自定义插件注册中心的执行引擎
220    pub fn with_plugin_registry(
221        registry: Arc<AgentRegistry>,
222        plugin_registry: Arc<dyn PluginRegistry>,
223    ) -> Self {
224        Self {
225            registry,
226            plugin_executor: PluginExecutor::new(plugin_registry),
227        }
228    }
229
230    /// 执行 Agent
231    pub async fn execute(
232        &self,
233        agent_id: &str,
234        input: AgentInput,
235        options: ExecutionOptions,
236    ) -> AgentResult<ExecutionResult> {
237        let execution_id = uuid::Uuid::now_v7().to_string();
238        let start_time = std::time::Instant::now();
239
240        // 获取 Agent
241        let agent = self
242            .registry
243            .get(agent_id)
244            .await
245            .ok_or_else(|| AgentError::NotFound(format!("Agent not found: {}", agent_id)))?;
246
247        // 创建上下文
248        let ctx = AgentContext::new(&execution_id);
249
250        // 发送开始事件
251        if options.tracing_enabled {
252            ctx.emit_event(AgentEvent::new(
253                "execution_started",
254                serde_json::json!({
255                    "agent_id": agent_id,
256                    "execution_id": execution_id,
257                }),
258            ))
259            .await;
260        }
261
262        // 插件执行阶段1: 请求处理前 - 数据处理
263        let processed_input = self
264            .plugin_executor
265            .execute_pre_request(input, &ctx)
266            .await?;
267
268        // 插件执行阶段2: 上下文组装前
269        self.plugin_executor
270            .execute_stage(crate::agent::plugins::PluginStage::PreContext, &ctx)
271            .await?;
272
273        // 执行 (带超时和重试)
274        let result = self
275            .execute_with_options(&agent, processed_input, &ctx, &options)
276            .await;
277
278        let duration_ms = start_time.elapsed().as_millis() as u64;
279
280        // 构建结果
281        let execution_result = match result {
282            Ok(output) => {
283                // 插件执行阶段3: LLM响应后
284                let processed_output = self
285                    .plugin_executor
286                    .execute_post_response(output, &ctx)
287                    .await?;
288
289                // 插件执行阶段4: 整个流程完成后
290                self.plugin_executor
291                    .execute_stage(crate::agent::plugins::PluginStage::PostProcess, &ctx)
292                    .await?;
293
294                if options.tracing_enabled {
295                    ctx.emit_event(AgentEvent::new(
296                        "execution_completed",
297                        serde_json::json!({
298                            "agent_id": agent_id,
299                            "execution_id": execution_id,
300                            "duration_ms": duration_ms,
301                        }),
302                    ))
303                    .await;
304                }
305
306                ExecutionResult::success(
307                    execution_id,
308                    agent_id.to_string(),
309                    processed_output,
310                    duration_ms,
311                )
312            }
313            Err(e) => {
314                let status = match &e {
315                    AgentError::Timeout { .. } => ExecutionStatus::Timeout,
316                    AgentError::Interrupted => ExecutionStatus::Interrupted,
317                    _ => ExecutionStatus::Failed,
318                };
319
320                if options.tracing_enabled {
321                    ctx.emit_event(AgentEvent::new(
322                        "execution_failed",
323                        serde_json::json!({
324                            "agent_id": agent_id,
325                            "execution_id": execution_id,
326                            "error": e.to_string(),
327                            "duration_ms": duration_ms,
328                        }),
329                    ))
330                    .await;
331                }
332
333                ExecutionResult {
334                    execution_id,
335                    agent_id: agent_id.to_string(),
336                    status,
337                    output: None,
338                    error: Some(e.to_string()),
339                    duration_ms,
340                    retries: 0,
341                    metadata: HashMap::new(),
342                }
343            }
344        };
345
346        Ok(execution_result)
347    }
348
349    /// 带选项执行
350    async fn execute_with_options(
351        &self,
352        agent: &Arc<RwLock<dyn MoFAAgent>>,
353        input: AgentInput,
354        ctx: &AgentContext,
355        options: &ExecutionOptions,
356    ) -> AgentResult<AgentOutput> {
357        let mut last_error = None;
358        let max_attempts = options.max_retries + 1;
359
360        for attempt in 0..max_attempts {
361            if attempt > 0 {
362                // 重试延迟
363                tokio::time::sleep(Duration::from_millis(options.retry_delay_ms)).await;
364            }
365
366            let result = self.execute_once(agent, input.clone(), ctx, options).await;
367
368            match result {
369                Ok(output) => return Ok(output),
370                Err(e) => {
371                    // 某些错误不应该重试
372                    if matches!(e, AgentError::Interrupted | AgentError::ConfigError(_)) {
373                        return Err(e);
374                    }
375                    last_error = Some(e);
376                }
377            }
378        }
379
380        Err(last_error.unwrap_or_else(|| AgentError::ExecutionFailed("Unknown error".to_string())))
381    }
382
383    /// 单次执行
384    async fn execute_once(
385        &self,
386        agent: &Arc<RwLock<dyn MoFAAgent>>,
387        input: AgentInput,
388        ctx: &AgentContext,
389        options: &ExecutionOptions,
390    ) -> AgentResult<AgentOutput> {
391        let mut agent_guard = agent.write().await;
392
393        // 确保 Agent 已初始化
394        if agent_guard.state() == AgentState::Created {
395            agent_guard.initialize(ctx).await?;
396        }
397
398        // 检查状态
399        if agent_guard.state() != AgentState::Ready {
400            return Err(AgentError::invalid_state_transition(
401                agent_guard.state(),
402                &AgentState::Executing,
403            ));
404        }
405
406        // 执行 (带超时)
407        if let Some(timeout_ms) = options.timeout_ms {
408            let duration = Duration::from_millis(timeout_ms);
409            match timeout(duration, agent_guard.execute(input, ctx)).await {
410                Ok(result) => result,
411                Err(_) => Err(AgentError::timeout(timeout_ms)),
412            }
413        } else {
414            agent_guard.execute(input, ctx).await
415        }
416    }
417
418    /// 批量执行
419    pub async fn execute_batch(
420        &self,
421        executions: Vec<(String, AgentInput)>,
422        options: ExecutionOptions,
423    ) -> Vec<AgentResult<ExecutionResult>> {
424        let mut results = Vec::new();
425
426        for (agent_id, input) in executions {
427            let result = self.execute(&agent_id, input, options.clone()).await;
428            results.push(result);
429        }
430
431        results
432    }
433
434    /// 并行执行多个 Agent
435    pub async fn execute_parallel(
436        &self,
437        executions: Vec<(String, AgentInput)>,
438        options: ExecutionOptions,
439    ) -> Vec<AgentResult<ExecutionResult>> {
440        let mut handles = Vec::new();
441
442        for (agent_id, input) in executions {
443            let registry = self.registry.clone();
444            let opts = options.clone();
445
446            let handle = tokio::spawn(async move {
447                let engine = ExecutionEngine::new(registry);
448                engine.execute(&agent_id, input, opts).await
449            });
450
451            handles.push(handle);
452        }
453
454        let mut results = Vec::new();
455        for handle in handles {
456            match handle.await {
457                Ok(result) => results.push(result),
458                Err(e) => results.push(Err(AgentError::ExecutionFailed(e.to_string()))),
459            }
460        }
461
462        results
463    }
464
465    /// 中断执行
466    pub async fn interrupt(&self, agent_id: &str) -> AgentResult<()> {
467        let agent = self
468            .registry
469            .get(agent_id)
470            .await
471            .ok_or_else(|| AgentError::NotFound(format!("Agent not found: {}", agent_id)))?;
472
473        let mut agent_guard = agent.write().await;
474        agent_guard.interrupt().await?;
475
476        Ok(())
477    }
478
479    /// 中断所有执行中的 Agent
480    pub async fn interrupt_all(&self) -> AgentResult<Vec<String>> {
481        let executing = self.registry.find_by_state(AgentState::Executing).await;
482
483        let mut interrupted = Vec::new();
484        for metadata in executing {
485            if self.interrupt(&metadata.id).await.is_ok() {
486                interrupted.push(metadata.id);
487            }
488        }
489
490        Ok(interrupted)
491    }
492
493    /// 注册插件
494    pub fn register_plugin(
495        &self,
496        plugin: Arc<dyn crate::agent::plugins::Plugin>,
497    ) -> AgentResult<()> {
498        // 现在 PluginRegistry 支持 &self 注册,因为使用了内部可变性
499        self.plugin_executor.registry.register(plugin)
500    }
501
502    /// 移除插件
503    pub fn unregister_plugin(&self, name: &str) -> AgentResult<bool> {
504        // 现在 PluginRegistry 支持 &self 注销,因为使用了内部可变性
505        self.plugin_executor.registry.unregister(name)
506    }
507
508    /// 列出所有插件
509    pub fn list_plugins(&self) -> Vec<Arc<dyn crate::agent::plugins::Plugin>> {
510        self.plugin_executor.registry.list()
511    }
512
513    /// 插件数量
514    pub fn plugin_count(&self) -> usize {
515        self.plugin_executor.registry.count()
516    }
517}
518
// ============================================================================
// Workflow execution
// ============================================================================
522
523/// 工作流步骤
524#[derive(Debug, Clone, Serialize, Deserialize)]
525pub struct WorkflowStep {
526    /// 步骤 ID
527    pub id: String,
528    /// Agent ID
529    pub agent_id: String,
530    /// 输入转换
531    #[serde(default)]
532    pub input_transform: Option<String>,
533    /// 依赖的步骤
534    #[serde(default)]
535    pub depends_on: Vec<String>,
536}
537
538/// 工作流定义
539#[derive(Debug, Clone, Serialize, Deserialize)]
540pub struct Workflow {
541    /// 工作流 ID
542    pub id: String,
543    /// 工作流名称
544    pub name: String,
545    /// 步骤列表
546    pub steps: Vec<WorkflowStep>,
547}
548
549impl ExecutionEngine {
550    /// 执行工作流
551    pub async fn execute_workflow(
552        &self,
553        workflow: &Workflow,
554        initial_input: AgentInput,
555        options: ExecutionOptions,
556    ) -> AgentResult<HashMap<String, ExecutionResult>> {
557        let mut results: HashMap<String, ExecutionResult> = HashMap::new();
558        let mut completed: Vec<String> = Vec::new();
559
560        // 简单的拓扑排序执行
561        while completed.len() < workflow.steps.len() {
562            let mut executed_any = false;
563
564            for step in &workflow.steps {
565                // 跳过已完成的步骤
566                if completed.contains(&step.id) {
567                    continue;
568                }
569
570                // 检查依赖
571                let deps_satisfied = step.depends_on.iter().all(|dep| completed.contains(dep));
572                if !deps_satisfied {
573                    continue;
574                }
575
576                // 准备输入
577                let input = if step.depends_on.is_empty() {
578                    initial_input.clone()
579                } else {
580                    // 使用前一个步骤的输出作为输入
581                    let prev_step = step.depends_on.last().unwrap();
582                    if let Some(prev_result) = results.get(prev_step) {
583                        if let Some(output) = &prev_result.output {
584                            AgentInput::text(output.to_text())
585                        } else {
586                            initial_input.clone()
587                        }
588                    } else {
589                        initial_input.clone()
590                    }
591                };
592
593                // 执行步骤
594                let result = self.execute(&step.agent_id, input, options.clone()).await?;
595                results.insert(step.id.clone(), result);
596                completed.push(step.id.clone());
597                executed_any = true;
598            }
599
600            if !executed_any && completed.len() < workflow.steps.len() {
601                return Err(AgentError::ExecutionFailed(
602                    "Workflow has circular dependencies".to_string(),
603                ));
604            }
605        }
606
607        Ok(results)
608    }
609}
610
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::capabilities::AgentCapabilities;
    use crate::agent::context::AgentContext;
    use crate::agent::core::MoFAAgent;
    use crate::agent::types::AgentState;

    /// In-test agent (self-contained, does not depend on `BaseAgent`).
    ///
    /// Returns a fixed response, optionally after an artificial delay.
    struct TestAgent {
        id: String,
        response: String,
        capabilities: AgentCapabilities,
        state: AgentState,
        /// Latency added inside `execute`; used to force timeouts deterministically.
        delay_ms: u64,
    }

    impl TestAgent {
        fn new(id: &str, response: &str) -> Self {
            Self {
                id: id.to_string(),
                response: response.to_string(),
                capabilities: AgentCapabilities::default(),
                state: AgentState::Created,
                delay_ms: 0,
            }
        }

        fn with_delay(mut self, delay_ms: u64) -> Self {
            self.delay_ms = delay_ms;
            self
        }
    }

    #[async_trait::async_trait]
    impl MoFAAgent for TestAgent {
        fn id(&self) -> &str {
            &self.id
        }

        fn name(&self) -> &str {
            &self.id
        }

        fn capabilities(&self) -> &AgentCapabilities {
            &self.capabilities
        }

        fn state(&self) -> AgentState {
            self.state.clone()
        }

        async fn initialize(&mut self, _ctx: &AgentContext) -> AgentResult<()> {
            self.state = AgentState::Ready;
            Ok(())
        }

        async fn execute(
            &mut self,
            _input: AgentInput,
            _ctx: &AgentContext,
        ) -> AgentResult<AgentOutput> {
            if self.delay_ms > 0 {
                tokio::time::sleep(Duration::from_millis(self.delay_ms)).await;
            }
            Ok(AgentOutput::text(&self.response))
        }

        async fn shutdown(&mut self) -> AgentResult<()> {
            self.state = AgentState::Shutdown;
            Ok(())
        }
    }

    /// Builds a workflow step with no input transform.
    fn step(id: &str, agent_id: &str, depends_on: &[&str]) -> WorkflowStep {
        WorkflowStep {
            id: id.to_string(),
            agent_id: agent_id.to_string(),
            input_transform: None,
            depends_on: depends_on.iter().map(|dep| dep.to_string()).collect(),
        }
    }

    #[tokio::test]
    async fn test_execution_engine_basic() {
        let registry = Arc::new(AgentRegistry::new());

        let agent = Arc::new(RwLock::new(TestAgent::new("test-agent", "Hello, World!")));
        registry.register(agent).await.unwrap();

        let engine = ExecutionEngine::new(registry);
        let result = engine
            .execute(
                "test-agent",
                AgentInput::text("input"),
                ExecutionOptions::default(),
            )
            .await
            .unwrap();

        assert!(result.is_success());
        assert_eq!(result.output.unwrap().to_text(), "Hello, World!");
    }

    #[tokio::test]
    async fn test_execution_timeout() {
        let registry = Arc::new(AgentRegistry::new());
        // The agent sleeps far longer than the timeout, so the outcome is deterministic.
        let agent = Arc::new(RwLock::new(
            TestAgent::new("slow-agent", "response").with_delay(500),
        ));
        registry.register(agent).await.unwrap();

        let engine = ExecutionEngine::new(registry);
        let result = engine
            .execute(
                "slow-agent",
                AgentInput::text("input"),
                ExecutionOptions::default().with_timeout(10),
            )
            .await
            .unwrap();

        assert_eq!(result.status, ExecutionStatus::Timeout);
        assert!(result.output.is_none());
    }

    #[tokio::test]
    async fn test_workflow_runs_dependencies_first() {
        let registry = Arc::new(AgentRegistry::new());
        let agent_a = Arc::new(RwLock::new(TestAgent::new("agent-a", "from a")));
        registry.register(agent_a).await.unwrap();
        let agent_b = Arc::new(RwLock::new(TestAgent::new("agent-b", "from b")));
        registry.register(agent_b).await.unwrap();

        // Declared out of order: "second" depends on "first".
        let workflow = Workflow {
            id: "wf".to_string(),
            name: "chain".to_string(),
            steps: vec![
                step("second", "agent-b", &["first"]),
                step("first", "agent-a", &[]),
            ],
        };

        let engine = ExecutionEngine::new(registry);
        let results = engine
            .execute_workflow(
                &workflow,
                AgentInput::text("start"),
                ExecutionOptions::default(),
            )
            .await
            .unwrap();

        assert_eq!(results.len(), 2);
        assert!(results["first"].is_success());
        assert_eq!(
            results["second"].output.as_ref().unwrap().to_text(),
            "from b"
        );
    }

    #[tokio::test]
    async fn test_workflow_circular_dependency() {
        let registry = Arc::new(AgentRegistry::new());
        let agent = Arc::new(RwLock::new(TestAgent::new("agent-a", "a")));
        registry.register(agent).await.unwrap();

        let workflow = Workflow {
            id: "wf".to_string(),
            name: "cycle".to_string(),
            steps: vec![step("x", "agent-a", &["y"]), step("y", "agent-a", &["x"])],
        };

        let engine = ExecutionEngine::new(registry);
        let result = engine
            .execute_workflow(
                &workflow,
                AgentInput::text("start"),
                ExecutionOptions::default(),
            )
            .await;

        assert!(result.is_err());
    }

    #[test]
    fn test_execution_options() {
        let options = ExecutionOptions::new()
            .with_timeout(5000)
            .with_retry(3, 500)
            .without_tracing();

        assert_eq!(options.timeout_ms, Some(5000));
        assert_eq!(options.max_retries, 3);
        assert_eq!(options.retry_delay_ms, 500);
        assert!(!options.tracing_enabled);
    }
}