#[cfg(feature = "engine")]
use std::sync::Arc;
#[cfg(feature = "engine")]
use std::time::Duration;
#[cfg(feature = "engine")]
use async_trait::async_trait;
#[cfg(feature = "engine")]
use cortexai_agents::{AgentEngine, ExecutionTrace};
#[cfg(feature = "engine")]
use cortexai_core::tool::ToolRegistry;
#[cfg(feature = "engine")]
use cortexai_core::message::{Content, Message};
use cortexai_core::types::{AgentConfig, AgentId, AgentRole, AgentStatus};
#[cfg(feature = "engine")]
use cortexai_providers::LLMBackend;
#[cfg(feature = "engine")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "engine")]
use serde_json::json;
#[cfg(feature = "engine")]
use tracing::{debug, info};
#[cfg(feature = "engine")]
use crate::error::McpError;
#[cfg(feature = "engine")]
use crate::protocol::{CallToolResult, McpTool, ToolContent};
#[cfg(feature = "engine")]
use crate::server::ToolHandler;
/// Iteration cap given to a spawned agent when the caller omits `max_iterations`.
#[cfg(feature = "engine")]
const DEFAULT_MAX_ITERATIONS: usize = 10;

/// Seconds used both as the handler's default wait for a reply and as the
/// timeout written into every spawned agent's configuration.
#[cfg(feature = "engine")]
const DEFAULT_TIMEOUT_SECS: u64 = 60;
/// Arguments accepted by the `run_agent` MCP tool.
///
/// `agent_name`, `system_prompt` and `message` are required; every other field
/// falls back to its empty / `None` / `false` value when omitted.
#[cfg(feature = "engine")]
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct RunAgentInput {
    /// Name given to the temporary agent.
    pub agent_name: String,
    /// System prompt / instructions for the agent.
    pub system_prompt: String,
    /// Task or query sent to the agent.
    pub message: String,
    /// Tool names requested by the caller; empty when omitted.
    #[serde(default)]
    pub tools: Vec<String>,
    /// Maximum ReAct loop iterations; `None` when omitted.
    #[serde(default)]
    pub max_iterations: Option<usize>,
    /// Requested model override; `None` when omitted.
    #[serde(default)]
    pub model: Option<String>,
    /// Whether the response should carry an execution trace; `false` when omitted.
    #[serde(default)]
    pub trace: bool,
}
/// MCP tool handler behind `run_agent`.
///
/// Each call spawns a temporary agent on the shared [`AgentEngine`], forwards a
/// single message to it and returns the agent's text reply.
#[cfg(feature = "engine")]
pub struct EngineHandler {
    /// Engine that temporary agents are spawned on and stopped from.
    engine: Arc<AgentEngine>,
    /// LLM backend handed to every spawned agent.
    backend: Arc<dyn LLMBackend>,
    /// Tool registry handed to every spawned agent.
    tool_registry: Arc<ToolRegistry>,
    /// How long a call waits for the spawned agent's reply.
    timeout: Duration,
}
#[cfg(feature = "engine")]
impl EngineHandler {
    /// Creates a handler that spawns agents on `engine`, using `backend` and
    /// `tool_registry`, and waits up to `DEFAULT_TIMEOUT_SECS` seconds for a reply.
    pub fn new(
        engine: Arc<AgentEngine>,
        backend: Arc<dyn LLMBackend>,
        tool_registry: Arc<ToolRegistry>,
    ) -> Self {
        Self {
            engine,
            backend,
            tool_registry,
            timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
        }
    }

    /// Overrides how long a `run_agent` call waits for the spawned agent's reply.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Builds the configuration for the temporary agent spawned for one call.
    ///
    /// Uses `agent_name`, `system_prompt` and `max_iterations` from `input`
    /// (falling back to `DEFAULT_MAX_ITERATIONS`), always with the `Executor` role.
    ///
    /// NOTE(review): the agent-level timeout is always `DEFAULT_TIMEOUT_SECS`, not
    /// the handler timeout set via `with_timeout`, and `input.tools` /
    /// `input.model` are not applied here. Confirm whether that is intended.
    fn build_agent_config(input: &RunAgentInput) -> AgentConfig {
        let max_iter = input.max_iterations.unwrap_or(DEFAULT_MAX_ITERATIONS);
        AgentConfig::new(&input.agent_name, AgentRole::Executor)
            .with_system_prompt(&input.system_prompt)
            .with_max_iterations(max_iter)
            .with_timeout(DEFAULT_TIMEOUT_SECS)
    }

    /// Polls agent `agent_id` until it has handled its message and returns the
    /// most recent text message it authored.
    ///
    /// The agent counts as finished when its status is `Idle` or
    /// `StoppedByStopWord`; every other status is treated as still working. An
    /// idle agent with no reply only counts as finished once it has been observed
    /// busy, because the message may not have been picked up yet.
    ///
    /// # Errors
    ///
    /// - `McpError::Internal` if the engine no longer knows `agent_id`, or if the
    ///   agent was seen busy, is idle again, and no text message from it can be
    ///   read from its history.
    /// - `McpError::Timeout` if no result is available within `timeout`. The
    ///   deadline bounds the whole wait, including lock acquisition and history
    ///   reads, not just the pauses between polls.
    async fn wait_for_response(
        &self,
        agent_id: &AgentId,
        timeout: Duration,
    ) -> Result<String, McpError> {
        const POLL_INTERVAL: Duration = Duration::from_millis(50);

        let runtime = self
            .engine
            .get_agent(agent_id)
            .ok_or_else(|| McpError::Internal("Agent disappeared".to_string()))?;

        let poll = async {
            let mut saw_busy = false;
            loop {
                // Scope the read guard so it is released before any further await.
                let is_idle = {
                    let state = runtime.state.read().await;
                    matches!(
                        state.status,
                        AgentStatus::Idle | AgentStatus::StoppedByStopWord
                    )
                };

                if is_idle {
                    if let Ok(history) = runtime.memory.get_history().await {
                        let reply = history.iter().rev().find_map(|msg| match &msg.content {
                            Content::Text(text) if msg.from == *agent_id => Some(text.clone()),
                            _ => None,
                        });
                        if let Some(text) = reply {
                            return Ok(text);
                        }
                    }
                    if saw_busy {
                        return Err(McpError::Internal(
                            "Agent finished but produced no text response".to_string(),
                        ));
                    }
                } else {
                    saw_busy = true;
                }

                tokio::time::sleep(POLL_INTERVAL).await;
            }
        };

        // Bounding the entire future (rather than checking elapsed time between
        // polls) also covers an await that blocks, e.g. a contended state lock.
        tokio::time::timeout(timeout, poll)
            .await
            .unwrap_or_else(|_elapsed| Err(McpError::Timeout))
    }
}
#[cfg(feature = "engine")]
#[async_trait]
impl ToolHandler for EngineHandler {
    /// Describes the `run_agent` tool and its JSON input schema.
    ///
    /// NOTE(review): `tools` and `model` are advertised in the schema, but nothing
    /// in this module reads `RunAgentInput::tools` or `RunAgentInput::model` yet.
    fn definition(&self) -> McpTool {
        McpTool {
            name: "run_agent".to_string(),
            description: Some(
                "Run a Cortex agent with a given system prompt and message. \
Creates a temporary agent, executes the task, and returns the result."
                    .to_string(),
            ),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "agent_name": {
                        "type": "string",
                        "description": "Name for the agent"
                    },
                    "system_prompt": {
                        "type": "string",
                        "description": "Agent system prompt / instructions"
                    },
                    "message": {
                        "type": "string",
                        "description": "The task or query to send to the agent"
                    },
                    "tools": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Optional list of tool names to enable"
                    },
                    "max_iterations": {
                        "type": "integer",
                        // Derived from the constant so the advertised default cannot drift.
                        "description": format!(
                            "Max ReACT loop iterations (default {})",
                            DEFAULT_MAX_ITERATIONS
                        )
                    },
                    "model": {
                        "type": "string",
                        "description": "Optional model override"
                    },
                    "trace": {
                        "type": "boolean",
                        "description": "When true, include an execution trace in the response (default false)"
                    }
                },
                "required": ["agent_name", "system_prompt", "message"]
            }),
        }
    }

    /// Runs one `run_agent` call end to end.
    ///
    /// Spawns a temporary agent from `arguments` and sends it `message` as
    /// `mcp-caller`. It then waits up to the handler timeout for the agent's text
    /// reply and stops the agent again, including on the failure paths after
    /// spawning. When `trace` is true, a second content block holds a
    /// pretty-printed execution trace.
    ///
    /// # Errors
    ///
    /// - `McpError::InvalidParams` if `arguments` does not deserialize into
    ///   `RunAgentInput`.
    /// - `McpError::Internal` if the agent cannot be spawned or the message
    ///   cannot be delivered.
    ///
    /// Failures while waiting for the reply, including timeouts, are returned as
    /// `Ok` with `is_error: true`.
    async fn execute(&self, arguments: serde_json::Value) -> Result<CallToolResult, McpError> {
        let input: RunAgentInput = serde_json::from_value(arguments)
            .map_err(|e| McpError::InvalidParams(format!("Invalid input: {}", e)))?;
        info!(agent_name = %input.agent_name, trace = input.trace, "Running agent via MCP");

        // NOTE(review): the collector is never handed to the engine in this
        // function, so the trace only carries what `finalize` records. Confirm
        // this is intended.
        let trace_collector = input.trace.then(|| Arc::new(ExecutionTrace::new()));

        let config = Self::build_agent_config(&input);
        let spawned_id = self
            .engine
            .spawn_agent(config, self.tool_registry.clone(), self.backend.clone())
            .await
            .map_err(|e| McpError::Internal(format!("Failed to spawn agent: {}", e)))?;

        let message = Message::new(
            AgentId::new("mcp-caller"),
            spawned_id.clone(),
            Content::Text(input.message),
        );
        if let Err(e) = self.engine.send_message(message) {
            // The agent already exists in the engine: stop it (best effort) rather
            // than leaking it when the message could not be delivered.
            let _ = self.engine.stop_agent(&spawned_id).await;
            return Err(McpError::Internal(format!("Failed to send message: {}", e)));
        }

        let response = self.wait_for_response(&spawned_id, self.timeout).await;
        // Best effort: the temporary agent is no longer needed whatever the outcome.
        let _ = self.engine.stop_agent(&spawned_id).await;

        match response {
            Ok(text) => {
                debug!(agent_name = %input.agent_name, "Agent completed successfully");
                let content = match trace_collector {
                    Some(trace) => {
                        let finalized = trace.finalize(text.clone());
                        let trace_json = serde_json::to_string_pretty(&finalized.to_json())
                            .unwrap_or_else(|_| "{}".to_string());
                        vec![ToolContent::text(text), ToolContent::text(trace_json)]
                    }
                    None => vec![ToolContent::text(text)],
                };
                Ok(CallToolResult {
                    content,
                    is_error: false,
                })
            }
            Err(e) => {
                debug!(agent_name = %input.agent_name, error = %e, "Agent execution failed");
                Ok(CallToolResult {
                    content: vec![ToolContent::text(format!("Agent execution error: {}", e))],
                    is_error: true,
                })
            }
        }
    }
}
#[cfg(all(test, feature = "engine"))]
mod tests {
    use std::sync::Arc;

    use cortexai_agents::AgentEngine;
    use cortexai_core::tool::ToolRegistry;
    use cortexai_providers::{LLMBackend, MockBackend, MockResponse};
    use serde_json::json;

    use super::*;
    use crate::server::ToolHandler;

    /// Per-test wait, so a hung agent fails the test instead of stalling it.
    const TEST_TIMEOUT: Duration = Duration::from_secs(10);

    /// Builds a handler on a fresh engine whose mock backend answers with `reply`.
    ///
    /// The engine is returned too, so tests can assert that temporary agents were
    /// stopped afterwards.
    fn handler_with_reply(reply: &'static str) -> (Arc<AgentEngine>, EngineHandler) {
        let engine = Arc::new(AgentEngine::new());
        let backend: Arc<dyn LLMBackend> =
            Arc::new(MockBackend::new().with_response(MockResponse::text(reply)));
        let registry = Arc::new(ToolRegistry::new());
        let handler =
            EngineHandler::new(engine.clone(), backend, registry).with_timeout(TEST_TIMEOUT);
        (engine, handler)
    }

    #[tokio::test]
    async fn test_engine_handler_definition() {
        let (_engine, handler) = handler_with_reply("Hello");
        let def = handler.definition();
        assert_eq!(def.name, "run_agent");
        assert!(def.description.is_some());
        let schema = &def.input_schema;
        assert!(schema["properties"]["agent_name"].is_object());
        assert!(schema["properties"]["system_prompt"].is_object());
        assert!(schema["properties"]["message"].is_object());
        assert_eq!(
            schema["required"],
            json!(["agent_name", "system_prompt", "message"])
        );
    }

    #[tokio::test]
    async fn test_engine_handler_runs_agent_and_returns_response() {
        let (engine, handler) = handler_with_reply("Agent response here");
        let result = handler
            .execute(json!({
                "agent_name": "test-agent",
                "system_prompt": "You are a helpful assistant",
                "message": "Say hello"
            }))
            .await
            .unwrap();
        assert!(!result.is_error);
        let text = result.content[0].as_text().unwrap();
        assert!(text.contains("Agent response here"));
        assert_eq!(engine.agent_count(), 0);
    }

    #[tokio::test]
    async fn test_engine_handler_trace_false_returns_normal_response() {
        let (engine, handler) = handler_with_reply("Normal response");
        let result = handler
            .execute(json!({
                "agent_name": "test-agent",
                "system_prompt": "You are helpful",
                "message": "Hello",
                "trace": false
            }))
            .await
            .unwrap();
        assert!(!result.is_error);
        assert_eq!(result.content.len(), 1);
        let text = result.content[0].as_text().unwrap();
        assert!(text.contains("Normal response"));
        assert_eq!(engine.agent_count(), 0);
    }

    #[tokio::test]
    async fn test_engine_handler_trace_true_returns_response_and_trace() {
        let (engine, handler) = handler_with_reply("Traced response");
        let result = handler
            .execute(json!({
                "agent_name": "trace-agent",
                "system_prompt": "You are helpful",
                "message": "Hello",
                "trace": true
            }))
            .await
            .unwrap();
        assert!(!result.is_error);
        assert_eq!(result.content.len(), 2, "Expected 2 content blocks (response + trace)");
        let text = result.content[0].as_text().unwrap();
        assert!(text.contains("Traced response"));
        let trace_text = result.content[1].as_text().unwrap();
        let trace_json: serde_json::Value = serde_json::from_str(trace_text)
            .expect("trace content should be valid JSON");
        assert!(trace_json["trace_id"].is_string());
        assert!(trace_json["tool_calls"].is_array());
        assert!(trace_json["llm_calls"].is_array());
        assert!(trace_json["total_duration_ms"].is_u64());
        assert_eq!(engine.agent_count(), 0);
    }

    #[tokio::test]
    async fn test_engine_handler_missing_required_field() {
        let (engine, handler) = handler_with_reply("Hello");
        let result = handler
            .execute(json!({
                "agent_name": "test"
            }))
            .await;
        // Must be rejected as bad input, not fail for some unrelated reason.
        assert!(matches!(result, Err(McpError::InvalidParams(_))));
        assert_eq!(engine.agent_count(), 0);
    }

    #[tokio::test]
    async fn test_engine_handler_rejects_wrongly_typed_field() {
        let (engine, handler) = handler_with_reply("Hello");
        let result = handler
            .execute(json!({
                "agent_name": "test",
                "system_prompt": "You are helpful",
                "message": "Hello",
                "trace": "yes"
            }))
            .await;
        assert!(matches!(result, Err(McpError::InvalidParams(_))));
        assert_eq!(engine.agent_count(), 0);
    }
}