use async_trait::async_trait;
use std::path::PathBuf;
use tracing::{info, warn};
use super::ExecutionEngine;
use crate::context::CapsuleContext;
use crate::error::{CapsuleError, CapsuleResult};
use crate::manifest::{CapsuleManifest, McpServerDef};
use astrid_mcp::SecureMcpClient;
/// Execution engine that runs a capsule's MCP server as a host process and registers it
/// with the MCP client under the server id `capsule:<package name>`.
pub struct McpHostEngine {
/// Capsule manifest: supplies the package name, the `host_process` grants checked before
/// launch, and the `net`/`uplink` capabilities that decide whether network access is allowed.
manifest: CapsuleManifest,
/// MCP server definition providing the `command` to launch and the `args` passed to it.
server_def: McpServerDef,
/// Capsule root directory: relative commands are resolved (and confined) under it, and it
/// is used as the server process's working directory.
capsule_dir: PathBuf,
/// MCP client used to connect, disconnect, health-check, and call tools on the server.
mcp_client: SecureMcpClient,
}
impl McpHostEngine {
    /// Creates an engine for one capsule's MCP server.
    ///
    /// This only stores its arguments; no process is spawned and no command is validated
    /// until the engine is loaded through `ExecutionEngine::load`.
    pub fn new(
        manifest: CapsuleManifest,
        server_def: McpServerDef,
        capsule_dir: PathBuf,
        mcp_client: SecureMcpClient,
    ) -> Self {
        Self {
            mcp_client,
            capsule_dir,
            server_def,
            manifest,
        }
    }
}
#[async_trait]
impl ExecutionEngine for McpHostEngine {
    /// Authorizes, resolves, and registers the capsule's MCP server as a host process.
    ///
    /// 1. The server definition must contain a `command`.
    /// 2. `capabilities.host_process` must grant that command, either verbatim or as a
    ///    prefix immediately followed by a space (i.e. the granted command plus arguments).
    /// 3. The command is resolved against the capsule directory (local binaries and
    ///    fat-binary directories are confined to the capsule).
    /// 4. The environment is resolved and the server is connected through the MCP client
    ///    with `RestartPolicy::Always`, the capsule directory as working directory, and
    ///    network access enabled only when the manifest declares `net` entries or `uplink`.
    ///
    /// # Errors
    ///
    /// Returns `CapsuleError::UnsupportedEntryPoint` when the command is missing, not
    /// granted, fails path resolution, or the MCP client cannot connect. Errors from
    /// `super::resolve_env` are propagated unchanged.
    async fn load(&mut self, ctx: &CapsuleContext) -> CapsuleResult<()> {
        let original_command_str = self.server_def.command.clone().ok_or_else(|| {
            CapsuleError::UnsupportedEntryPoint("MCP server requires a 'command' field".into())
        })?;

        // A grant matches the exact command, or the command followed by space-separated
        // arguments. `strip_prefix` expresses this without formatting a string per grant.
        let is_granted = self.manifest.capabilities.host_process.iter().any(|cmd| {
            original_command_str
                .strip_prefix(cmd)
                .is_some_and(|rest| rest.is_empty() || rest.starts_with(' '))
        });
        if !is_granted {
            return Err(CapsuleError::UnsupportedEntryPoint(format!(
                "Security Check Failed: host_process capability for '{}' was not declared in the manifest.",
                original_command_str
            )));
        }

        let command_str = self.resolve_command(&original_command_str)?;

        info!(
            capsule = %self.manifest.package.name,
            original_command = %original_command_str,
            resolved_command = %command_str,
            "Registering legacy MCP host process dynamically (Airlock Override)"
        );

        let resolved_env = super::resolve_env(&self.manifest, ctx, &[], "mcp_host_engine").await?;
        let server_id = self.server_id();
        let allow_network =
            !self.manifest.capabilities.net.is_empty() || self.manifest.capabilities.uplink;

        let config = astrid_mcp::ServerConfig {
            name: server_id.clone(),
            command: Some(command_str),
            args: self.server_def.args.clone(),
            env: resolved_env,
            cwd: Some(self.capsule_dir.clone()),
            restart_policy: astrid_mcp::RestartPolicy::Always,
            allow_network,
            ..Default::default()
        };

        self.mcp_client
            .connect_dynamic(&server_id, config)
            .await
            .map_err(|e| {
                CapsuleError::UnsupportedEntryPoint(format!(
                    "Failed to connect MCP host engine: {e}"
                ))
            })?;
        Ok(())
    }

    /// Disconnects the capsule's MCP server from the MCP client.
    ///
    /// Shutdown is best-effort: a failed disconnect is logged as a warning and does not
    /// turn into an error, so unloading always succeeds.
    async fn unload(&mut self) -> CapsuleResult<()> {
        info!(
            capsule = %self.manifest.package.name,
            "Shutting down MCP host process"
        );
        let server_id = self.server_id();
        if let Err(e) = self.mcp_client.disconnect(&server_id).await {
            // Previously discarded silently; keep it non-fatal but make it diagnosable.
            warn!(
                capsule = %self.manifest.package.name,
                error = %e,
                "Failed to disconnect MCP host process"
            );
        }
        Ok(())
    }

    /// Reports whether the capsule's MCP server is still running.
    ///
    /// This synchronous method bridges into async code with `block_in_place` +
    /// `Handle::block_on`. It therefore must run on a multi-threaded tokio runtime:
    /// `block_in_place` panics on a current-thread runtime, and `Handle::current` panics
    /// outside any runtime. The `debug_assert!` surfaces a misconfiguration early in debug
    /// builds.
    fn check_health(&self) -> crate::capsule::CapsuleState {
        let server_id = self.server_id();
        debug_assert!(
            tokio::runtime::Handle::try_current()
                .map(|h| h.runtime_flavor() == tokio::runtime::RuntimeFlavor::MultiThread)
                .unwrap_or(false),
            "check_health() with block_in_place requires multi-threaded tokio runtime"
        );
        let is_alive = tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current().block_on(async {
                let health = self
                    .mcp_client
                    .inner()
                    .server_manager()
                    .health_check()
                    .await;
                // A server missing from the health map is treated as not alive.
                health.get(&server_id).copied().unwrap_or(false)
            })
        });
        if is_alive {
            crate::capsule::CapsuleState::Ready
        } else {
            crate::capsule::CapsuleState::Failed(format!(
                "MCP server '{server_id}' is no longer running"
            ))
        }
    }

    /// Forwards an interceptor hook to the server's `astrid_hook_intercept` tool.
    ///
    /// `payload` must be JSON; it is sent as `{"hook": action, "payload": <payload>}`.
    /// The text content items of the tool result are concatenated and returned as the
    /// `Continue` payload. An empty or literal `null` response yields an empty payload.
    /// A failed tool call is logged and also answered with `Continue` and an empty payload,
    /// rather than an error.
    ///
    /// # Errors
    ///
    /// Returns `CapsuleError::ExecutionFailed` if `payload` is not valid JSON.
    async fn invoke_interceptor(
        &self,
        action: &str,
        payload: &[u8],
        _caller: Option<&astrid_events::ipc::IpcMessage>,
    ) -> CapsuleResult<crate::capsule::InterceptResult> {
        let server_id = self.server_id();
        let params: serde_json::Value = serde_json::from_slice(payload).map_err(|e| {
            CapsuleError::ExecutionFailed(format!("failed to deserialize interceptor payload: {e}"))
        })?;
        let tool_args = serde_json::json!({
            "hook": action,
            "payload": params,
        });

        let client = self.mcp_client.inner().clone();
        let result = client
            .call_tool(&server_id, "astrid_hook_intercept", tool_args)
            .await;

        match result {
            Ok(tool_result) => {
                // Concatenate only the text content items; other content kinds are ignored.
                let text: String = tool_result
                    .content
                    .iter()
                    .filter_map(|c| {
                        if let astrid_mcp::ToolContent::Text { text } = c {
                            Some(text.as_str())
                        } else {
                            None
                        }
                    })
                    .collect();
                if text.is_empty() || text == "null" {
                    Ok(crate::capsule::InterceptResult::Continue(Vec::new()))
                } else {
                    Ok(crate::capsule::InterceptResult::Continue(text.into_bytes()))
                }
            },
            Err(e) => {
                warn!(
                    capsule = %self.manifest.package.name,
                    hook = %action,
                    error = %e,
                    "Failed to invoke hook interceptor on MCP capsule"
                );
                Ok(crate::capsule::InterceptResult::Continue(Vec::new()))
            },
        }
    }
}

impl McpHostEngine {
    /// Returns the id under which this capsule's MCP server is registered.
    fn server_id(&self) -> String {
        format!("capsule:{}", self.manifest.package.name)
    }

    /// Resolves a granted command string to the program path handed to the MCP client.
    ///
    /// - Absolute commands are returned unchanged.
    /// - Relative commands that cannot be canonicalized under the capsule directory (for
    ///   example bare names such as `python3` that do not exist there), or when the capsule
    ///   directory itself cannot be canonicalized, are returned unchanged.
    /// - Otherwise the canonical path must stay inside the canonical capsule directory.
    ///   A directory is treated as a fat binary, and a regular file is returned as its
    ///   canonical absolute path.
    fn resolve_command(&self, command_str: &str) -> CapsuleResult<String> {
        if std::path::Path::new(command_str).is_absolute() {
            return Ok(command_str.to_owned());
        }

        let local_cmd_path = self.capsule_dir.join(command_str);
        let Ok(canonical_cmd) = local_cmd_path.canonicalize() else {
            return Ok(command_str.to_owned());
        };
        let Ok(canonical_capsule_dir) = self.capsule_dir.canonicalize() else {
            return Ok(command_str.to_owned());
        };

        // Canonicalization resolves `..` and symlinks, so a prefix check on the canonical
        // paths detects commands that escape the capsule.
        if !canonical_cmd.starts_with(&canonical_capsule_dir) {
            return Err(CapsuleError::UnsupportedEntryPoint(format!(
                "Path traversal detected: command '{}' escapes the capsule directory.",
                command_str
            )));
        }

        if canonical_cmd.is_dir() {
            Self::resolve_fat_binary_slice(command_str, &canonical_cmd, &canonical_capsule_dir)
        } else if canonical_cmd.is_file() {
            Self::path_to_command(&canonical_cmd)
        } else {
            Ok(command_str.to_owned())
        }
    }

    /// Selects the slice for the compile-time `TARGET` triple inside a fat-binary directory.
    ///
    /// The slice must be a regular file whose canonical path stays inside the capsule.
    fn resolve_fat_binary_slice(
        command_str: &str,
        canonical_cmd: &std::path::Path,
        canonical_capsule_dir: &std::path::Path,
    ) -> CapsuleResult<String> {
        let host_triple = env!("TARGET");
        let arch_slice = canonical_cmd.join(host_triple);

        if !arch_slice.is_file() {
            return Err(CapsuleError::UnsupportedEntryPoint(format!(
                "Fat binary directory '{}' does not contain a valid slice for the current architecture: {}",
                command_str, host_triple
            )));
        }

        let Ok(canon_slice) = arch_slice.canonicalize() else {
            return Err(CapsuleError::UnsupportedEntryPoint(format!(
                "Failed to resolve fat binary slice for the current architecture: {}",
                host_triple
            )));
        };

        // The slice may itself be a symlink, so re-check the boundary after canonicalizing.
        if !canon_slice.starts_with(canonical_capsule_dir) {
            return Err(CapsuleError::UnsupportedEntryPoint(format!(
                "Fat binary slice '{}' resolves outside the capsule boundary.",
                host_triple
            )));
        }

        info!(
            "Fat binary resolved: using {} slice for {}",
            host_triple, command_str
        );
        Self::path_to_command(&canon_slice)
    }

    /// Converts a validated path into the command string, rejecting non-UTF-8 paths.
    ///
    /// A lossy conversion would replace invalid bytes with U+FFFD. The result would then
    /// name a different path than the one that was canonicalized and boundary-checked.
    fn path_to_command(path: &std::path::Path) -> CapsuleResult<String> {
        path.to_str().map(str::to_owned).ok_or_else(|| {
            CapsuleError::UnsupportedEntryPoint(format!(
                "Resolved command path '{}' is not valid UTF-8.",
                path.display()
            ))
        })
    }
}