use crate::agent::context::{AgentContext, AgentEvent};
use crate::agent::core::MoFAAgent;
use crate::agent::error::{AgentError, AgentResult};
use crate::agent::plugins::{PluginExecutor, PluginRegistry, SimplePluginRegistry};
use crate::agent::registry::AgentRegistry;
use crate::agent::types::{AgentInput, AgentOutput, AgentState};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time::timeout;
/// Options controlling how a single agent execution is run.
///
/// Every field has a serde default, so fields missing from a serialized form
/// are filled in during deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionOptions {
/// Time limit, in milliseconds, applied to each execution attempt;
/// `None` means no limit is applied.
#[serde(default)]
pub timeout_ms: Option<u64>,
/// When `true`, `ExecutionEngine::execute` emits `execution_started`,
/// `execution_completed` and `execution_failed` events on the context.
#[serde(default = "default_tracing")]
pub tracing_enabled: bool,
/// Number of additional attempts made after the first attempt fails.
#[serde(default)]
pub max_retries: usize,
/// Pause, in milliseconds, before each retry attempt.
#[serde(default = "default_retry_delay")]
pub retry_delay_ms: u64,
/// Free-form extra options. Not read anywhere in this file — presumably
/// consumed by agents or plugins; confirm against callers.
#[serde(default)]
pub custom: HashMap<String, serde_json::Value>,
}
/// Serde default for `ExecutionOptions::tracing_enabled`: tracing is on
/// unless the serialized form disables it explicitly.
fn default_tracing() -> bool {
true
}
/// Serde default for `ExecutionOptions::retry_delay_ms`, in milliseconds.
///
/// This must stay equal to the delay used by `ExecutionOptions::default()`
/// (1000 ms) so that deserializing a document that omits `retry_delay_ms`
/// yields the same delay as constructing the options in code.
fn default_retry_delay() -> u64 {
    1000
}
impl Default for ExecutionOptions {
    /// Baseline options: no timeout, tracing enabled, no retries, a 1000 ms
    /// retry delay and an empty `custom` map.
    fn default() -> Self {
        let custom = HashMap::new();
        Self {
            custom,
            retry_delay_ms: 1000,
            max_retries: 0,
            tracing_enabled: true,
            timeout_ms: None,
        }
    }
}
impl ExecutionOptions {
    /// Creates options populated with the `Default` values.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the options with the per-attempt timeout set to `timeout_ms`
    /// milliseconds.
    pub fn with_timeout(self, timeout_ms: u64) -> Self {
        Self {
            timeout_ms: Some(timeout_ms),
            ..self
        }
    }

    /// Returns the options with the retry count and the delay between
    /// retries (in milliseconds) replaced.
    pub fn with_retry(self, max_retries: usize, retry_delay_ms: u64) -> Self {
        Self {
            max_retries,
            retry_delay_ms,
            ..self
        }
    }

    /// Returns the options with tracing events switched off.
    pub fn without_tracing(self) -> Self {
        Self {
            tracing_enabled: false,
            ..self
        }
    }
}
/// Status recorded for an execution.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExecutionStatus {
/// Not assigned anywhere in this file — presumably "queued, not yet
/// started"; confirm against callers.
Pending,
/// Not assigned anywhere in this file — presumably "currently executing";
/// confirm against callers.
Running,
/// The agent produced an output.
Success,
/// The agent returned an error other than a timeout or an interruption.
Failed,
/// The agent returned `AgentError::Timeout`.
Timeout,
/// The agent returned `AgentError::Interrupted`.
Interrupted,
/// Not assigned anywhere in this file — presumably marks an in-progress
/// retry, with `attempt` identifying which one; confirm against callers.
Retrying { attempt: usize },
}
/// Outcome of one call to `ExecutionEngine::execute`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionResult {
/// Identifier generated for this execution.
pub execution_id: String,
/// Identifier of the agent that was asked to run.
pub agent_id: String,
/// Final status of the execution.
pub status: ExecutionStatus,
/// Agent output; `Some` only when the execution succeeded.
pub output: Option<AgentOutput>,
/// Error message; `Some` only when the execution did not succeed.
pub error: Option<String>,
/// Elapsed time in milliseconds as measured by the caller that built
/// this result.
pub duration_ms: u64,
/// Retry count. NOTE(review): every constructor in this file sets this to
/// 0, even when retries actually happened.
pub retries: usize,
/// Arbitrary extra data attached via `with_metadata`.
pub metadata: HashMap<String, serde_json::Value>,
}
impl ExecutionResult {
    /// Builds a result with status `Success` carrying `output`; `error` is
    /// `None`, `retries` is 0 and `metadata` starts empty.
    pub fn success(
        execution_id: String,
        agent_id: String,
        output: AgentOutput,
        duration_ms: u64,
    ) -> Self {
        let status = ExecutionStatus::Success;
        Self {
            status,
            output: Some(output),
            error: None,
            execution_id,
            agent_id,
            duration_ms,
            retries: 0,
            metadata: HashMap::new(),
        }
    }

    /// Builds a result with status `Failed` carrying the `error` message;
    /// `output` is `None`, `retries` is 0 and `metadata` starts empty.
    pub fn failure(
        execution_id: String,
        agent_id: String,
        error: String,
        duration_ms: u64,
    ) -> Self {
        let status = ExecutionStatus::Failed;
        Self {
            status,
            output: None,
            error: Some(error),
            execution_id,
            agent_id,
            duration_ms,
            retries: 0,
            metadata: HashMap::new(),
        }
    }

    /// Returns `true` only when the status is `ExecutionStatus::Success`.
    pub fn is_success(&self) -> bool {
        matches!(self.status, ExecutionStatus::Success)
    }

    /// Returns the result with `value` stored under `key` in `metadata`,
    /// replacing any previous value for that key.
    pub fn with_metadata(self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let mut updated = self;
        updated.metadata.insert(key.into(), value);
        updated
    }
}
/// Runs agents looked up from an `AgentRegistry`, invoking plugin hooks
/// around each execution.
pub struct ExecutionEngine {
/// Registry used to resolve agent ids to agent instances.
registry: Arc<AgentRegistry>,
/// Executor that runs the plugin hooks before and after each execution.
plugin_executor: PluginExecutor,
}
impl ExecutionEngine {
/// Creates an engine over `registry` whose plugin executor is backed by a
/// freshly constructed `SimplePluginRegistry`.
pub fn new(registry: Arc<AgentRegistry>) -> Self {
    let plugins = Arc::new(SimplePluginRegistry::new());
    let plugin_executor = PluginExecutor::new(plugins);
    Self {
        registry,
        plugin_executor,
    }
}
/// Creates an engine over `registry` whose plugin executor uses the supplied
/// `plugin_registry`.
pub fn with_plugin_registry(
    registry: Arc<AgentRegistry>,
    plugin_registry: Arc<dyn PluginRegistry>,
) -> Self {
    let plugin_executor = PluginExecutor::new(plugin_registry);
    Self {
        registry,
        plugin_executor,
    }
}
/// Runs the agent registered under `agent_id` through the full pipeline:
/// pre-request plugins, the `PreContext` stage, the agent itself (with the
/// timeout and retry settings from `options`), then post-response plugins and
/// the `PostProcess` stage.
///
/// When `options.tracing_enabled` is set, `execution_started` and either
/// `execution_completed` or `execution_failed` events are emitted on the
/// execution context.
///
/// # Errors
///
/// Returns `AgentError::NotFound` when no agent is registered under
/// `agent_id`, and propagates any error raised by a plugin hook. A failure of
/// the agent itself is NOT an `Err`: it is reported as `Ok` with a result
/// whose status is `Failed`, `Timeout` or `Interrupted`.
pub async fn execute(
    &self,
    agent_id: &str,
    input: AgentInput,
    options: ExecutionOptions,
) -> AgentResult<ExecutionResult> {
    use crate::agent::plugins::PluginStage;

    let execution_id = uuid::Uuid::now_v7().to_string();
    let started = std::time::Instant::now();

    let agent = match self.registry.get(agent_id).await {
        Some(found) => found,
        None => {
            return Err(AgentError::NotFound(format!(
                "Agent not found: {}",
                agent_id
            )));
        }
    };

    let ctx = AgentContext::new(&execution_id);
    let tracing = options.tracing_enabled;
    if tracing {
        ctx.emit_event(AgentEvent::new(
            "execution_started",
            serde_json::json!({
                "agent_id": agent_id,
                "execution_id": execution_id,
            }),
        ))
        .await;
    }

    // Plugin hooks that run before the agent sees the input.
    let prepared_input = self
        .plugin_executor
        .execute_pre_request(input, &ctx)
        .await?;
    self.plugin_executor
        .execute_stage(PluginStage::PreContext, &ctx)
        .await?;

    let outcome = self
        .execute_with_options(&agent, prepared_input, &ctx, &options)
        .await;
    // The duration is captured here, so it excludes the post-processing hooks.
    let duration_ms = started.elapsed().as_millis() as u64;

    let output = match outcome {
        Ok(output) => output,
        Err(e) => {
            // Agent failures are reported through the result, not as `Err`.
            let status = match &e {
                AgentError::Timeout { .. } => ExecutionStatus::Timeout,
                AgentError::Interrupted => ExecutionStatus::Interrupted,
                _ => ExecutionStatus::Failed,
            };
            if tracing {
                ctx.emit_event(AgentEvent::new(
                    "execution_failed",
                    serde_json::json!({
                        "agent_id": agent_id,
                        "execution_id": execution_id,
                        "error": e.to_string(),
                        "duration_ms": duration_ms,
                    }),
                ))
                .await;
            }
            return Ok(ExecutionResult {
                execution_id,
                agent_id: agent_id.to_string(),
                status,
                output: None,
                error: Some(e.to_string()),
                duration_ms,
                retries: 0,
                metadata: HashMap::new(),
            });
        }
    };

    // Plugin hooks that run after the agent produced an output.
    let final_output = self
        .plugin_executor
        .execute_post_response(output, &ctx)
        .await?;
    self.plugin_executor
        .execute_stage(PluginStage::PostProcess, &ctx)
        .await?;

    if tracing {
        ctx.emit_event(AgentEvent::new(
            "execution_completed",
            serde_json::json!({
                "agent_id": agent_id,
                "execution_id": execution_id,
                "duration_ms": duration_ms,
            }),
        ))
        .await;
    }

    Ok(ExecutionResult::success(
        execution_id,
        agent_id.to_string(),
        final_output,
        duration_ms,
    ))
}
/// Runs the agent, retrying failed attempts according to `options`.
///
/// Up to `options.max_retries + 1` attempts are made, sleeping
/// `options.retry_delay_ms` milliseconds before every attempt after the
/// first. `AgentError::Interrupted` and `AgentError::ConfigError` are
/// returned immediately without retrying.
///
/// # Errors
///
/// Returns the error of the last failed attempt, or the first
/// non-retryable error encountered.
async fn execute_with_options(
    &self,
    agent: &Arc<RwLock<dyn MoFAAgent>>,
    input: AgentInput,
    ctx: &AgentContext,
    options: &ExecutionOptions,
) -> AgentResult<AgentOutput> {
    // `max_retries` may come from deserialized input. Saturate so that
    // `usize::MAX` cannot overflow: a plain `+ 1` panics in debug builds and
    // wraps to zero attempts in release builds.
    let max_attempts = options.max_retries.saturating_add(1);
    let mut last_error = None;
    for attempt in 0..max_attempts {
        if attempt > 0 {
            tokio::time::sleep(Duration::from_millis(options.retry_delay_ms)).await;
        }
        match self.execute_once(agent, input.clone(), ctx, options).await {
            Ok(output) => return Ok(output),
            // These errors are not retried.
            Err(e) if matches!(e, AgentError::Interrupted | AgentError::ConfigError(_)) => {
                return Err(e);
            }
            Err(e) => last_error = Some(e),
        }
    }
    Err(last_error.unwrap_or_else(|| AgentError::ExecutionFailed("Unknown error".to_string())))
}
/// Performs a single execution attempt while holding the agent's write lock.
///
/// An agent still in the `Created` state is initialized first. The attempt
/// is rejected with an invalid-state-transition error unless the agent is
/// then `Ready`. When `options.timeout_ms` is set, the agent call is bounded
/// by that many milliseconds and an elapsed limit is reported via
/// `AgentError::timeout`.
async fn execute_once(
    &self,
    agent: &Arc<RwLock<dyn MoFAAgent>>,
    input: AgentInput,
    ctx: &AgentContext,
    options: &ExecutionOptions,
) -> AgentResult<AgentOutput> {
    let mut guard = agent.write().await;

    // Lazily initialize agents that have never been started.
    if guard.state() == AgentState::Created {
        guard.initialize(ctx).await?;
    }

    if guard.state() != AgentState::Ready {
        return Err(AgentError::invalid_state_transition(
            guard.state(),
            &AgentState::Executing,
        ));
    }

    match options.timeout_ms {
        None => guard.execute(input, ctx).await,
        Some(limit_ms) => {
            let limit = Duration::from_millis(limit_ms);
            timeout(limit, guard.execute(input, ctx))
                .await
                .unwrap_or_else(|_| Err(AgentError::timeout(limit_ms)))
        }
    }
}
/// Executes the given `(agent_id, input)` pairs one after another, each with
/// a clone of `options`, and returns the outcomes in input order.
pub async fn execute_batch(
    &self,
    executions: Vec<(String, AgentInput)>,
    options: ExecutionOptions,
) -> Vec<AgentResult<ExecutionResult>> {
    let mut outcomes = Vec::with_capacity(executions.len());
    for (agent_id, input) in executions {
        outcomes.push(self.execute(&agent_id, input, options.clone()).await);
    }
    outcomes
}
/// Executes the given `(agent_id, input)` pairs concurrently, one spawned
/// task per pair, and returns the outcomes in input order.
///
/// Each task runs through an engine that shares this engine's agent registry
/// and plugin registry, so registered plugins apply here exactly as they do
/// in `execute` and `execute_batch`.
///
/// A task that fails to join (for example because it panicked) is reported
/// as `AgentError::ExecutionFailed` carrying the join error's message.
pub async fn execute_parallel(
    &self,
    executions: Vec<(String, AgentInput)>,
    options: ExecutionOptions,
) -> Vec<AgentResult<ExecutionResult>> {
    let mut handles = Vec::with_capacity(executions.len());
    for (agent_id, input) in executions {
        let registry = self.registry.clone();
        // Hand the task this engine's plugin registry. Building the task's
        // engine with `ExecutionEngine::new` would give it a fresh
        // `SimplePluginRegistry`, silently skipping the registered plugins.
        let plugin_registry = self.plugin_executor.registry.clone();
        let opts = options.clone();
        handles.push(tokio::spawn(async move {
            let engine = ExecutionEngine::with_plugin_registry(registry, plugin_registry);
            engine.execute(&agent_id, input, opts).await
        }));
    }

    let mut results = Vec::with_capacity(handles.len());
    for handle in handles {
        results.push(match handle.await {
            Ok(result) => result,
            Err(join_err) => Err(AgentError::ExecutionFailed(join_err.to_string())),
        });
    }
    results
}
/// Interrupts the agent registered under `agent_id`.
///
/// The agent's write lock is acquired before `interrupt` is called on it.
///
/// # Errors
///
/// Returns `AgentError::NotFound` when no such agent is registered, and
/// propagates any error returned by the agent's own `interrupt`.
pub async fn interrupt(&self, agent_id: &str) -> AgentResult<()> {
    let agent = match self.registry.get(agent_id).await {
        Some(found) => found,
        None => {
            return Err(AgentError::NotFound(format!(
                "Agent not found: {}",
                agent_id
            )));
        }
    };
    let mut guard = agent.write().await;
    guard.interrupt().await?;
    Ok(())
}
/// Attempts to interrupt every agent the registry reports as `Executing`.
///
/// Returns the ids of the agents whose interruption succeeded; agents whose
/// interruption failed are left out. As written, this never returns `Err`.
pub async fn interrupt_all(&self) -> AgentResult<Vec<String>> {
    let running = self.registry.find_by_state(AgentState::Executing).await;
    let mut stopped = Vec::new();
    for entry in running {
        let succeeded = self.interrupt(&entry.id).await.is_ok();
        if succeeded {
            stopped.push(entry.id);
        }
    }
    Ok(stopped)
}
/// Registers `plugin` with the registry held by this engine's plugin
/// executor, returning whatever that registry's `register` returns.
pub fn register_plugin(
&self,
plugin: Arc<dyn crate::agent::plugins::Plugin>,
) -> AgentResult<()> {
self.plugin_executor.registry.register(plugin)
}
/// Unregisters the plugin called `name` from the engine's plugin registry.
///
/// The returned `bool` comes straight from the registry — presumably it
/// reports whether a plugin with that name existed; confirm against the
/// `PluginRegistry` implementation.
pub fn unregister_plugin(&self, name: &str) -> AgentResult<bool> {
self.plugin_executor.registry.unregister(name)
}
/// Returns the plugins listed by the engine's plugin registry.
pub fn list_plugins(&self) -> Vec<Arc<dyn crate::agent::plugins::Plugin>> {
self.plugin_executor.registry.list()
}
/// Returns the plugin count reported by the engine's plugin registry.
pub fn plugin_count(&self) -> usize {
self.plugin_executor.registry.count()
}
}
/// One step of a `Workflow`: a single agent invocation with optional
/// dependencies on other steps.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowStep {
/// Identifier of the step; other steps refer to it in `depends_on`.
pub id: String,
/// Identifier of the agent executed for this step.
pub agent_id: String,
/// NOTE(review): not read anywhere in this file — presumably names a
/// transformation applied to the step input; confirm before relying on it.
#[serde(default)]
pub input_transform: Option<String>,
/// Ids of the steps that must have run before this one. The output of the
/// LAST listed dependency becomes this step's input.
#[serde(default)]
pub depends_on: Vec<String>,
}
/// A named collection of steps executed by `ExecutionEngine::execute_workflow`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Workflow {
/// Identifier of the workflow (not read by the engine in this file).
pub id: String,
/// Name of the workflow (not read by the engine in this file).
pub name: String,
/// Steps of the workflow; run order is derived from each step's `depends_on`.
pub steps: Vec<WorkflowStep>,
}
impl ExecutionEngine {
    /// Executes every step of `workflow` in dependency order and returns the
    /// results keyed by step id.
    ///
    /// The step list is swept repeatedly; on each sweep every step that has
    /// not run yet and whose dependencies have all run is executed. A step
    /// without dependencies receives `initial_input`. A step with
    /// dependencies receives the text output of its LAST listed dependency,
    /// falling back to `initial_input` when that dependency produced no
    /// output.
    ///
    /// # Errors
    ///
    /// Propagates any `Err` returned by `execute`. Returns
    /// `AgentError::ExecutionFailed` when a sweep makes no progress while
    /// steps remain (for example with circular dependencies).
    pub async fn execute_workflow(
        &self,
        workflow: &Workflow,
        initial_input: AgentInput,
        options: ExecutionOptions,
    ) -> AgentResult<HashMap<String, ExecutionResult>> {
        let total_steps = workflow.steps.len();
        // A step id is inserted here the moment the step has run, so the map
        // doubles as the set of completed steps.
        let mut results: HashMap<String, ExecutionResult> = HashMap::new();

        while results.len() < total_steps {
            let mut progressed = false;

            for step in &workflow.steps {
                let not_runnable = results.contains_key(&step.id)
                    || step
                        .depends_on
                        .iter()
                        .any(|dep| !results.contains_key(dep));
                if not_runnable {
                    continue;
                }

                let input = step
                    .depends_on
                    .last()
                    .and_then(|prev_step| results.get(prev_step))
                    .and_then(|prev_result| prev_result.output.as_ref())
                    .map(|output| AgentInput::text(output.to_text()))
                    .unwrap_or_else(|| initial_input.clone());

                let result = self.execute(&step.agent_id, input, options.clone()).await?;
                results.insert(step.id.clone(), result);
                progressed = true;
            }

            if !progressed && results.len() < total_steps {
                return Err(AgentError::ExecutionFailed(
                    "Workflow has circular dependencies".to_string(),
                ));
            }
        }

        Ok(results)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::capabilities::AgentCapabilities;
    use crate::agent::context::AgentContext;
    use crate::agent::core::MoFAAgent;
    use crate::agent::types::AgentState;

    /// Minimal agent that always answers with a fixed response.
    struct TestAgent {
        id: String,
        response: String,
        capabilities: AgentCapabilities,
        state: AgentState,
    }

    impl TestAgent {
        fn new(id: &str, response: &str) -> Self {
            Self {
                id: id.to_string(),
                response: response.to_string(),
                capabilities: AgentCapabilities::default(),
                state: AgentState::Created,
            }
        }
    }

    #[async_trait::async_trait]
    impl MoFAAgent for TestAgent {
        fn id(&self) -> &str {
            &self.id
        }

        fn name(&self) -> &str {
            &self.id
        }

        fn capabilities(&self) -> &AgentCapabilities {
            &self.capabilities
        }

        fn state(&self) -> AgentState {
            self.state.clone()
        }

        async fn initialize(&mut self, _ctx: &AgentContext) -> AgentResult<()> {
            self.state = AgentState::Ready;
            Ok(())
        }

        async fn execute(
            &mut self,
            _input: AgentInput,
            _ctx: &AgentContext,
        ) -> AgentResult<AgentOutput> {
            Ok(AgentOutput::text(&self.response))
        }

        async fn shutdown(&mut self) -> AgentResult<()> {
            self.state = AgentState::Shutdown;
            Ok(())
        }
    }

    /// Builds an engine whose registry holds a single `TestAgent`.
    async fn engine_with_agent(id: &str, response: &str) -> ExecutionEngine {
        let registry = Arc::new(AgentRegistry::new());
        let agent = Arc::new(RwLock::new(TestAgent::new(id, response)));
        registry.register(agent).await.unwrap();
        ExecutionEngine::new(registry)
    }

    #[tokio::test]
    async fn test_execution_engine_basic() {
        let engine = engine_with_agent("test-agent", "Hello, World!").await;
        let outcome = engine
            .execute(
                "test-agent",
                AgentInput::text("input"),
                ExecutionOptions::default(),
            )
            .await
            .unwrap();

        assert!(outcome.is_success());
        assert_eq!(outcome.output.unwrap().to_text(), "Hello, World!");
    }

    #[tokio::test]
    async fn test_execution_timeout() {
        let engine = engine_with_agent("slow-agent", "response").await;
        // A 1 ms limit may or may not elapse before the agent answers, so
        // both outcomes are accepted.
        let options = ExecutionOptions::default().with_timeout(1);
        let outcome = engine
            .execute("slow-agent", AgentInput::text("input"), options)
            .await
            .unwrap();

        assert!(matches!(
            outcome.status,
            ExecutionStatus::Success | ExecutionStatus::Timeout
        ));
    }

    #[test]
    fn test_execution_options() {
        let options = ExecutionOptions::new()
            .with_timeout(5000)
            .with_retry(3, 500)
            .without_tracing();

        assert_eq!(options.timeout_ms, Some(5000));
        assert_eq!(options.max_retries, 3);
        assert_eq!(options.retry_delay_ms, 500);
        assert!(!options.tracing_enabled);
    }
}