Skip to main content

astrid_plugins/
mcp_plugin.rs

//! MCP-backed plugin implementation.
//!
//! [`McpPlugin`] wraps an MCP server child process as a [`Plugin`], exposing
//! the server's tools as [`PluginTool`] instances. The child process is
//! spawned and connected during [`Plugin::load()`], and gracefully shut down
//! during [`Plugin::unload()`].
//!
//! Security is enforced at the runtime layer: the
//! [`SecurityInterceptor`](astrid_approval::SecurityInterceptor) checks
//! capability tokens for `SensitiveAction::McpToolCall` before the tool's
//! `execute()` is ever called. OS-level sandboxing is handled by
//! [`SandboxProfile`](crate::sandbox::SandboxProfile).
//!
//! # Hook Event Forwarding
//!
//! The runtime can push hook events to the plugin's MCP server via
//! [`McpPlugin::send_hook_event()`], which sends a custom notification
//! (`notifications/astrid.hookEvent`) over the MCP connection.

20use std::borrow::Cow;
21use std::sync::Arc;
22use std::time::Duration;
23
24use async_trait::async_trait;
25use rmcp::ServiceExt;
26use rmcp::model::{CallToolRequestParams, ClientNotification, CustomNotification};
27use rmcp::service::{Peer, RoleClient, RunningService};
28use rmcp::transport::TokioChildProcess;
29use serde_json::Value;
30use tracing::{debug, error, info, warn};
31
32use astrid_core::HookEvent;
33use astrid_mcp::{AstridClientHandler, CapabilitiesHandler, McpClient, ToolResult};
34
35use crate::context::{PluginContext, PluginToolContext};
36use crate::error::{PluginError, PluginResult};
37use crate::manifest::{PluginEntryPoint, PluginManifest};
38use crate::plugin::{Plugin, PluginId, PluginState};
39use crate::sandbox::SandboxProfile;
40use crate::tool::PluginTool;
41
/// Upper bound on how long closing a plugin's MCP session may take before
/// the session is abandoned (five seconds).
const SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(5_000);
44
/// Type alias for the running MCP service backing a plugin: the client-side
/// rmcp service (`RoleClient`) driven by an [`AstridClientHandler`].
type PluginMcpService = RunningService<RoleClient, AstridClientHandler>;
47
/// A plugin backed by an MCP server child process.
///
/// The MCP server is spawned during `load()` and shut down during
/// `unload()`. Tool calls are forwarded over the MCP connection via
/// the stored [`Peer`] handle.
///
/// # Dependency Injection
///
/// `McpPlugin::new()` receives an [`McpClient`] at construction time.
/// The caller (plugin manager / runtime) decides which `Plugin` impl to
/// create based on the manifest's [`PluginEntryPoint`].
// NOTE: struct fields are dropped in declaration order, so `service` is
// dropped before `peer`; keep that in mind before reordering fields.
pub struct McpPlugin {
    /// Plugin identifier, copied from `manifest.id` at construction.
    id: PluginId,
    /// Manifest this plugin was created from (entry point, env, hash, ...).
    manifest: PluginManifest,
    /// Current lifecycle state (`Unloaded`, `Loading`, `Ready`,
    /// `Unloading`, or `Failed`).
    state: PluginState,
    /// Tool wrappers discovered during `load()`; cleared on unload and when
    /// `check_health` detects a dead server.
    tools: Vec<Box<dyn PluginTool>>,
    /// MCP server name (format: `"plugin:{plugin_id}"`).
    server_name: String,
    /// Injected at construction. Within this module it is only exposed via
    /// the `mcp_client()` getter; hook forwarding goes through `peer`.
    mcp_client: McpClient,
    /// Running MCP service (owns the child process). Set by `load()`,
    /// cleared by `unload()`.
    service: Option<PluginMcpService>,
    /// Lightweight, cloneable RPC handle for tool calls + notifications.
    /// `None` while not loaded or after a failed health check.
    peer: Option<Peer<RoleClient>>,
    /// Optional sandbox profile applied to the child process when it is
    /// spawned.
    sandbox: Option<SandboxProfile>,
}
75
76impl McpPlugin {
77    /// Create a new MCP plugin.
78    ///
79    /// The plugin starts in `Unloaded` state. Call [`Plugin::load()`] to
80    /// spawn the MCP server and discover tools.
81    #[must_use]
82    pub fn new(manifest: PluginManifest, mcp_client: McpClient) -> Self {
83        let id = manifest.id.clone();
84        let server_name = format!("plugin:{id}");
85        Self {
86            id,
87            manifest,
88            state: PluginState::Unloaded,
89            tools: Vec::new(),
90            server_name,
91            mcp_client,
92            service: None,
93            peer: None,
94            sandbox: None,
95        }
96    }
97
98    /// Set an OS sandbox profile for the child process.
99    #[must_use]
100    pub fn with_sandbox(mut self, profile: SandboxProfile) -> Self {
101        self.sandbox = Some(profile);
102        self
103    }
104
105    /// Send a hook event notification to the MCP server.
106    ///
107    /// Sends a custom MCP notification with method
108    /// `notifications/astrid.hookEvent`. This is fire-and-forget;
109    /// errors are logged but do not propagate.
110    pub async fn send_hook_event(&self, event: HookEvent, data: Value) {
111        let Some(peer) = &self.peer else {
112            debug!(
113                plugin_id = %self.id,
114                "Cannot send hook event: no peer connection"
115            );
116            return;
117        };
118
119        let notification = CustomNotification::new(
120            "notifications/astrid.hookEvent",
121            Some(serde_json::json!({
122                "event": event.to_string(),
123                "data": data,
124            })),
125        );
126
127        if let Err(e) = peer
128            .send_notification(ClientNotification::CustomNotification(notification))
129            .await
130        {
131            warn!(
132                plugin_id = %self.id,
133                event = %event,
134                error = %e,
135                "Failed to send hook event to plugin MCP server"
136            );
137        }
138    }
139
140    /// Get the MCP server name for this plugin.
141    #[must_use]
142    pub fn server_name(&self) -> &str {
143        &self.server_name
144    }
145
146    /// Get a reference to the injected [`McpClient`].
147    #[must_use]
148    pub fn mcp_client(&self) -> &McpClient {
149        &self.mcp_client
150    }
151
152    /// Check if the MCP server child process is still running.
153    ///
154    /// Returns `true` if the plugin is loaded and the underlying MCP service
155    /// reports it is still alive. If the process has crashed, transitions the
156    /// plugin state to `Failed` and returns `false`.
157    pub fn check_health(&mut self) -> bool {
158        if !matches!(self.state, PluginState::Ready) {
159            return false;
160        }
161
162        let alive = self.service.as_ref().is_some_and(|s| !s.is_closed());
163
164        if !alive {
165            let msg = "MCP server process exited unexpectedly".to_string();
166            warn!(plugin_id = %self.id, "{msg}");
167            self.state = PluginState::Failed(msg);
168            self.peer = None;
169            self.tools.clear();
170        }
171
172        alive
173    }
174
175    /// Build the `tokio::process::Command` from the manifest entry point,
176    /// optionally applying sandbox wrapping.
177    ///
178    /// On Linux with a sandbox profile, this applies Landlock rules via a
179    /// `pre_exec` hook. The `unsafe` is required by POSIX: `pre_exec` runs
180    /// between `fork()` and `exec()` where only async-signal-safe operations
181    /// are permitted. The Landlock syscalls used here are safe in practice
182    /// (they are simple kernel calls that don't allocate or lock).
183    #[allow(unsafe_code)]
184    fn build_command(&self) -> PluginResult<tokio::process::Command> {
185        let PluginEntryPoint::Mcp {
186            command,
187            args,
188            env,
189            binary_hash: _,
190        } = &self.manifest.entry_point
191        else {
192            return Err(PluginError::UnsupportedEntryPoint(
193                "expected Mcp entry point".into(),
194            ));
195        };
196
197        // Optionally wrap with sandbox
198        let (final_cmd, final_args) = if let Some(sandbox) = &self.sandbox {
199            sandbox.wrap_command(command, args)?
200        } else {
201            (command.clone(), args.clone())
202        };
203
204        let mut cmd = tokio::process::Command::new(&final_cmd);
205        cmd.args(&final_args);
206
207        for (key, value) in env {
208            cmd.env(key, value);
209        }
210
211        // On Linux, apply Landlock rules via pre_exec hook.
212        // PathFds are opened HERE (before fork) where heap allocation is safe.
213        // Only raw Landlock syscalls run inside the pre_exec closure.
214        #[cfg(target_os = "linux")]
215        if let Some(sandbox) = &self.sandbox {
216            let prepared = prepare_landlock_rules(&sandbox.landlock_rules());
217            let mut prepared = Some(prepared);
218            // SAFETY: pre_exec runs between fork() and exec(). The closure
219            // only invokes Landlock syscalls (landlock_create_ruleset,
220            // landlock_add_rule, landlock_restrict_self) using pre-opened
221            // file descriptors. No heap allocation occurs inside the closure.
222            unsafe {
223                cmd.pre_exec(move || {
224                    let rules = prepared.take().ok_or_else(|| {
225                        std::io::Error::other("Landlock pre_exec called more than once")
226                    })?;
227                    enforce_landlock_rules(rules).map_err(|e| {
228                        std::io::Error::new(std::io::ErrorKind::PermissionDenied, e.clone())
229                    })
230                });
231            }
232        }
233
234        Ok(cmd)
235    }
236
237    /// Verify the binary hash if configured in the manifest.
238    fn verify_binary_hash(&self) -> PluginResult<()> {
239        let PluginEntryPoint::Mcp {
240            command,
241            binary_hash: Some(expected),
242            ..
243        } = &self.manifest.entry_point
244        else {
245            return Ok(());
246        };
247
248        // Use the same verification logic as ServerConfig
249        let binary_path = which::which(command).map_err(|e| PluginError::McpServerFailed {
250            plugin_id: self.id.clone(),
251            message: format!("Cannot find binary {command}: {e}"),
252        })?;
253
254        let binary_data = std::fs::read(&binary_path)?;
255        let actual_hash = astrid_crypto::ContentHash::hash(&binary_data);
256        let actual_str = format!("sha256:{}", actual_hash.to_hex());
257
258        if expected != &actual_str {
259            return Err(PluginError::McpServerFailed {
260                plugin_id: self.id.clone(),
261                message: format!("Binary hash mismatch: expected {expected}, got {actual_str}"),
262            });
263        }
264
265        Ok(())
266    }
267}
268
269#[async_trait]
270impl Plugin for McpPlugin {
271    fn id(&self) -> &PluginId {
272        &self.id
273    }
274
275    fn manifest(&self) -> &PluginManifest {
276        &self.manifest
277    }
278
279    fn state(&self) -> PluginState {
280        self.state.clone()
281    }
282
283    async fn load(&mut self, _ctx: &PluginContext) -> PluginResult<()> {
284        self.state = PluginState::Loading;
285
286        // 1. Verify binary hash if configured
287        if let Err(e) = self.verify_binary_hash() {
288            self.state = PluginState::Failed(e.to_string());
289            return Err(e);
290        }
291
292        // 2. Build the command
293        let cmd = match self.build_command() {
294            Ok(cmd) => cmd,
295            Err(e) => {
296                self.state = PluginState::Failed(e.to_string());
297                return Err(e);
298            },
299        };
300
301        // 3. Create transport (spawns the child process)
302        let transport = TokioChildProcess::new(cmd).map_err(|e| {
303            let err = PluginError::McpServerFailed {
304                plugin_id: self.id.clone(),
305                message: format!("Failed to spawn MCP server process: {e}"),
306            };
307            self.state = PluginState::Failed(err.to_string());
308            err
309        })?;
310
311        // 4. MCP handshake
312        let handler = Arc::new(CapabilitiesHandler::new());
313        let client_handler = AstridClientHandler::new(&self.server_name, handler);
314
315        let service = client_handler.serve(transport).await.map_err(|e| {
316            let err = PluginError::McpServerFailed {
317                plugin_id: self.id.clone(),
318                message: format!("MCP handshake failed: {e}"),
319            };
320            self.state = PluginState::Failed(err.to_string());
321            err
322        })?;
323
324        // 5. Discover tools
325        let rmcp_tools = service.list_all_tools().await.map_err(|e| {
326            let err = PluginError::McpServerFailed {
327                plugin_id: self.id.clone(),
328                message: format!("Failed to list tools: {e}"),
329            };
330            self.state = PluginState::Failed(err.to_string());
331            err
332        })?;
333
334        // 6. Extract peer handle
335        let peer = service.peer().clone();
336
337        // 7. Create McpPluginTool wrappers
338        let tools: Vec<Box<dyn PluginTool>> = rmcp_tools
339            .iter()
340            .map(|t| {
341                let tool: Box<dyn PluginTool> = Box::new(McpPluginTool {
342                    name: t.name.to_string(),
343                    description: t.description.as_deref().unwrap_or("").to_string(),
344                    input_schema: serde_json::to_value(&*t.input_schema)
345                        .unwrap_or_else(|_| serde_json::json!({"type": "object"})),
346                    server_name: self.server_name.clone(),
347                    peer: peer.clone(),
348                });
349                tool
350            })
351            .collect();
352
353        info!(
354            plugin_id = %self.id,
355            server_name = %self.server_name,
356            tool_count = tools.len(),
357            "MCP plugin loaded successfully"
358        );
359
360        self.service = Some(service);
361        self.peer = Some(peer);
362        self.tools = tools;
363        self.state = PluginState::Ready;
364
365        Ok(())
366    }
367
368    async fn unload(&mut self) -> PluginResult<()> {
369        self.state = PluginState::Unloading;
370
371        // Drop the peer handle first
372        self.peer = None;
373        self.tools.clear();
374
375        // Gracefully close the MCP session
376        if let Some(ref mut service) = self.service {
377            match service.close_with_timeout(SHUTDOWN_TIMEOUT).await {
378                Ok(Some(reason)) => {
379                    info!(
380                        plugin_id = %self.id,
381                        ?reason,
382                        "Plugin MCP session closed gracefully"
383                    );
384                },
385                Ok(None) => {
386                    warn!(
387                        plugin_id = %self.id,
388                        "Plugin MCP session close timed out; dropping"
389                    );
390                },
391                Err(e) => {
392                    error!(
393                        plugin_id = %self.id,
394                        error = %e,
395                        "Plugin MCP session close join error"
396                    );
397                },
398            }
399        }
400
401        self.service = None;
402        self.state = PluginState::Unloaded;
403
404        info!(plugin_id = %self.id, "MCP plugin unloaded");
405
406        Ok(())
407    }
408
409    fn tools(&self) -> &[Box<dyn PluginTool>] {
410        &self.tools
411    }
412}
413
/// A tool provided by an MCP server, wrapped as a [`PluginTool`].
///
/// Tool calls are forwarded directly to the MCP server via the stored
/// [`Peer`] handle. Security is enforced at the runtime layer (before
/// `execute()` is called), not here.
struct McpPluginTool {
    /// Tool name as reported by the server; also sent as the call name.
    name: String,
    /// Server-provided description, or empty if the server gave none.
    description: String,
    /// Argument schema as reported by the server (falls back to
    /// `{"type": "object"}` if it could not be serialized).
    input_schema: Value,
    /// Name of the owning MCP server (`"plugin:{id}"`).
    // Not read anywhere in this module; the allow silences the unused-field
    // lint. NOTE(review): either consume it (e.g. in logging) or remove it.
    #[allow(dead_code)]
    server_name: String,
    /// RPC handle used to forward tool calls to the server.
    peer: Peer<RoleClient>,
}
427
428#[async_trait]
429impl PluginTool for McpPluginTool {
430    fn name(&self) -> &str {
431        &self.name
432    }
433
434    fn description(&self) -> &str {
435        &self.description
436    }
437
438    fn input_schema(&self) -> Value {
439        self.input_schema.clone()
440    }
441
442    async fn execute(&self, args: Value, _ctx: &PluginToolContext) -> PluginResult<String> {
443        let arguments = match args {
444            Value::Object(map) => Some(map),
445            Value::Null => None,
446            other => {
447                let mut map = serde_json::Map::new();
448                map.insert("value".to_string(), other);
449                Some(map)
450            },
451        };
452
453        let params = CallToolRequestParams {
454            meta: None,
455            name: Cow::Owned(self.name.clone()),
456            arguments,
457            task: None,
458        };
459
460        let result = self
461            .peer
462            .call_tool(params)
463            .await
464            .map_err(|e| PluginError::ExecutionFailed(format!("MCP tool call failed: {e}")))?;
465
466        // Convert to our ToolResult and extract text content
467        let tool_result = ToolResult::from(result);
468        if tool_result.is_error {
469            return Err(PluginError::ExecutionFailed(
470                tool_result
471                    .error
472                    .unwrap_or_else(|| "Unknown MCP tool error".into()),
473            ));
474        }
475
476        Ok(tool_result.text_content())
477    }
478}
479
480/// Create a plugin from a manifest, choosing the appropriate implementation
481/// based on the entry point type.
482///
483/// # Errors
484///
485/// - [`PluginError::McpClientRequired`] if the entry point is `Mcp` but
486///   no `McpClient` was provided.
487/// - [`PluginError::UnsupportedEntryPoint`] if the entry point type is
488///   not supported (e.g. `Wasm` — handled by a different subsystem).
489pub fn create_plugin(
490    manifest: PluginManifest,
491    mcp_client: Option<McpClient>,
492) -> PluginResult<Box<dyn Plugin>> {
493    match &manifest.entry_point {
494        PluginEntryPoint::Wasm { .. } => Err(PluginError::UnsupportedEntryPoint("wasm".into())),
495        PluginEntryPoint::Mcp { .. } => {
496            let client = mcp_client.ok_or(PluginError::McpClientRequired)?;
497            Ok(Box::new(McpPlugin::new(manifest, client)))
498        },
499    }
500}
501
/// A set of pre-opened Landlock rules ready for enforcement inside `pre_exec`.
///
/// File descriptors are opened in the parent process (where allocation is
/// safe) by `prepare_landlock_rules` and consumed inside the `pre_exec`
/// closure by `enforce_landlock_rules` (where only async-signal-safe
/// operations are permitted).
#[cfg(target_os = "linux")]
struct PreparedLandlockRules {
    /// Pre-opened `(PathFd, read, write)` tuples; entries with neither flag
    /// set are filtered out during preparation.
    rules: Vec<(landlock::PathFd, bool, bool)>,
}
512
513/// Phase 1 (parent process): open file descriptors and compute access flags.
514///
515/// This runs before `fork()`, so heap allocation and filesystem access are
516/// safe. Paths that don't exist are silently skipped.
517#[cfg(target_os = "linux")]
518fn prepare_landlock_rules(rules: &[crate::sandbox::LandlockPathRule]) -> PreparedLandlockRules {
519    use landlock::PathFd;
520
521    let mut prepared = Vec::with_capacity(rules.len());
522
523    for rule in rules {
524        if !rule.read && !rule.write {
525            continue;
526        }
527
528        // Open the path FD now (heap allocation happens here, safely)
529        if let Ok(fd) = PathFd::new(&rule.path) {
530            prepared.push((fd, rule.read, rule.write));
531        }
532    }
533
534    PreparedLandlockRules { rules: prepared }
535}
536
537/// Phase 2 (child process, inside `pre_exec`): create ruleset and enforce.
538///
539/// Only Landlock syscalls are invoked here — no heap allocation, no
540/// filesystem access. All file descriptors were pre-opened in phase 1.
541#[cfg(target_os = "linux")]
542fn enforce_landlock_rules(prepared: PreparedLandlockRules) -> Result<(), String> {
543    use landlock::{
544        ABI, Access, AccessFs, CompatLevel, Compatible, PathBeneath, Ruleset, RulesetAttr,
545        RulesetCreatedAttr, RulesetStatus,
546    };
547
548    let abi = ABI::V5;
549
550    let mut ruleset = Ruleset::default()
551        .set_compatibility(CompatLevel::BestEffort)
552        .handle_access(AccessFs::from_all(abi))
553        .map_err(|e| format!("failed to create Landlock ruleset: {e}"))?
554        .create()
555        .map_err(|e| format!("failed to create Landlock ruleset: {e}"))?;
556
557    for (fd, read, write) in prepared.rules {
558        let access = match (read, write) {
559            (true, true) => AccessFs::from_all(abi),
560            (true, false) => AccessFs::from_read(abi),
561            (false, true) => AccessFs::from_write(abi),
562            (false, false) => continue,
563        };
564        let path_beneath = PathBeneath::new(fd, access);
565        ruleset = ruleset
566            .add_rule(path_beneath)
567            .map_err(|e| format!("failed to add Landlock rule: {e}"))?;
568    }
569
570    let status = ruleset
571        .restrict_self()
572        .map_err(|e| format!("failed to enforce Landlock ruleset: {e}"))?;
573
574    match status.ruleset {
575        RulesetStatus::FullyEnforced
576        | RulesetStatus::PartiallyEnforced
577        | RulesetStatus::NotEnforced => {
578            // NotEnforced: kernel doesn't support Landlock — not a fatal error
579        },
580    }
581
582    Ok(())
583}
584
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Manifest for an MCP plugin launched as `node dist/index.js`.
    fn mcp_manifest(id: &str) -> PluginManifest {
        let entry_point = PluginEntryPoint::Mcp {
            command: "node".into(),
            args: vec!["dist/index.js".into()],
            env: HashMap::new(),
            binary_hash: None,
        };
        PluginManifest {
            id: PluginId::from_static(id),
            name: format!("Test MCP Plugin {id}"),
            version: "0.1.0".into(),
            description: Some("Test MCP plugin".into()),
            author: None,
            entry_point,
            capabilities: vec![],
            config: HashMap::new(),
        }
    }

    /// Manifest for a WASM plugin; shares the remaining fields with the MCP one.
    fn wasm_manifest(id: &str) -> PluginManifest {
        PluginManifest {
            name: format!("Test WASM Plugin {id}"),
            description: None,
            entry_point: PluginEntryPoint::Wasm {
                path: "plugin.wasm".into(),
                hash: None,
            },
            ..mcp_manifest(id)
        }
    }

    fn test_mcp_client() -> McpClient {
        McpClient::with_config(astrid_mcp::ServersConfig::default())
    }

    #[tokio::test]
    async fn test_mcp_plugin_creation() {
        let plugin = McpPlugin::new(mcp_manifest("test-mcp"), test_mcp_client());

        assert_eq!(plugin.id().as_str(), "test-mcp");
        assert_eq!(plugin.state(), PluginState::Unloaded);
        assert!(plugin.tools().is_empty());
        assert_eq!(plugin.server_name(), "plugin:test-mcp");
    }

    #[tokio::test]
    async fn test_mcp_plugin_with_sandbox() {
        let sandbox = SandboxProfile::new("/workspace".into(), "/plugins/test".into());
        let plugin =
            McpPlugin::new(mcp_manifest("test-mcp"), test_mcp_client()).with_sandbox(sandbox);

        assert!(plugin.sandbox.is_some());
    }

    #[tokio::test]
    async fn test_create_plugin_mcp() {
        let created = create_plugin(mcp_manifest("test-mcp"), Some(test_mcp_client()));
        assert!(created.is_ok());
    }

    #[test]
    fn test_create_plugin_mcp_requires_client() {
        let outcome = create_plugin(mcp_manifest("test-mcp"), None);
        assert!(matches!(outcome, Err(PluginError::McpClientRequired)));
    }

    #[test]
    fn test_create_plugin_wasm_unsupported() {
        let outcome = create_plugin(wasm_manifest("test-wasm"), None);
        assert!(matches!(outcome, Err(PluginError::UnsupportedEntryPoint(_))));
    }

    #[tokio::test]
    async fn test_server_name_format() {
        let plugin = McpPlugin::new(mcp_manifest("my-cool-plugin"), test_mcp_client());
        assert_eq!(plugin.server_name(), "plugin:my-cool-plugin");
    }

    #[tokio::test]
    async fn test_health_check_unloaded_returns_false() {
        let mut plugin = McpPlugin::new(mcp_manifest("test-health"), test_mcp_client());
        assert!(!plugin.check_health());
    }
}