1use std::borrow::Cow;
21use std::sync::Arc;
22use std::time::Duration;
23
24use async_trait::async_trait;
25use rmcp::ServiceExt;
26use rmcp::model::{CallToolRequestParams, ClientNotification, CustomNotification};
27use rmcp::service::{Peer, RoleClient, RunningService};
28use rmcp::transport::TokioChildProcess;
29use serde_json::Value;
30use tracing::{debug, error, info, warn};
31
32use astrid_core::HookEvent;
33use astrid_mcp::{AstridClientHandler, CapabilitiesHandler, McpClient, ToolResult};
34
35use crate::context::{PluginContext, PluginToolContext};
36use crate::error::{PluginError, PluginResult};
37use crate::manifest::{PluginEntryPoint, PluginManifest};
38use crate::plugin::{Plugin, PluginId, PluginState};
39use crate::sandbox::SandboxProfile;
40use crate::tool::PluginTool;
41
/// Upper bound on how long `unload` waits for the plugin's MCP session to close
/// before giving up and dropping it.
const SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(5_000);
44
/// The running rmcp client session for one plugin's MCP server, driven by an
/// [`AstridClientHandler`].
type PluginMcpService = RunningService<RoleClient, AstridClientHandler>;
47
/// A plugin backed by an external MCP server process.
///
/// The server is spawned as a child process (via `TokioChildProcess`) when the
/// plugin is loaded, and every tool it advertises is exposed as a
/// [`PluginTool`].
pub struct McpPlugin {
    /// Plugin identifier, copied from the manifest at construction.
    id: PluginId,
    /// Manifest this plugin was created from; loading requires an `Mcp` entry point.
    manifest: PluginManifest,
    /// Current lifecycle state.
    state: PluginState,
    /// Tools discovered from the server by a successful `load`; cleared on
    /// `unload` and when a health check finds the session closed.
    tools: Vec<Box<dyn PluginTool>>,
    /// Name used for the MCP client handler, formatted as `plugin:{id}`.
    server_name: String,
    /// MCP client supplied at construction and exposed via `mcp_client()`.
    mcp_client: McpClient,
    /// Running MCP session; set by a successful `load`, cleared by `unload`.
    service: Option<PluginMcpService>,
    /// Peer handle cloned from `service`, used for notifications; cleared by
    /// `unload` and by a failed health check.
    peer: Option<Peer<RoleClient>>,
    /// Optional sandbox profile applied when the server process is spawned.
    sandbox: Option<SandboxProfile>,
}
75
76impl McpPlugin {
77 #[must_use]
82 pub fn new(manifest: PluginManifest, mcp_client: McpClient) -> Self {
83 let id = manifest.id.clone();
84 let server_name = format!("plugin:{id}");
85 Self {
86 id,
87 manifest,
88 state: PluginState::Unloaded,
89 tools: Vec::new(),
90 server_name,
91 mcp_client,
92 service: None,
93 peer: None,
94 sandbox: None,
95 }
96 }
97
98 #[must_use]
100 pub fn with_sandbox(mut self, profile: SandboxProfile) -> Self {
101 self.sandbox = Some(profile);
102 self
103 }
104
105 pub async fn send_hook_event(&self, event: HookEvent, data: Value) {
111 let Some(peer) = &self.peer else {
112 debug!(
113 plugin_id = %self.id,
114 "Cannot send hook event: no peer connection"
115 );
116 return;
117 };
118
119 let notification = CustomNotification::new(
120 "notifications/astrid.hookEvent",
121 Some(serde_json::json!({
122 "event": event.to_string(),
123 "data": data,
124 })),
125 );
126
127 if let Err(e) = peer
128 .send_notification(ClientNotification::CustomNotification(notification))
129 .await
130 {
131 warn!(
132 plugin_id = %self.id,
133 event = %event,
134 error = %e,
135 "Failed to send hook event to plugin MCP server"
136 );
137 }
138 }
139
140 #[must_use]
142 pub fn server_name(&self) -> &str {
143 &self.server_name
144 }
145
146 #[must_use]
148 pub fn mcp_client(&self) -> &McpClient {
149 &self.mcp_client
150 }
151
152 pub fn check_health(&mut self) -> bool {
158 if !matches!(self.state, PluginState::Ready) {
159 return false;
160 }
161
162 let alive = self.service.as_ref().is_some_and(|s| !s.is_closed());
163
164 if !alive {
165 let msg = "MCP server process exited unexpectedly".to_string();
166 warn!(plugin_id = %self.id, "{msg}");
167 self.state = PluginState::Failed(msg);
168 self.peer = None;
169 self.tools.clear();
170 }
171
172 alive
173 }
174
175 #[allow(unsafe_code)]
184 fn build_command(&self) -> PluginResult<tokio::process::Command> {
185 let PluginEntryPoint::Mcp {
186 command,
187 args,
188 env,
189 binary_hash: _,
190 } = &self.manifest.entry_point
191 else {
192 return Err(PluginError::UnsupportedEntryPoint(
193 "expected Mcp entry point".into(),
194 ));
195 };
196
197 let (final_cmd, final_args) = if let Some(sandbox) = &self.sandbox {
199 sandbox.wrap_command(command, args)?
200 } else {
201 (command.clone(), args.clone())
202 };
203
204 let mut cmd = tokio::process::Command::new(&final_cmd);
205 cmd.args(&final_args);
206
207 for (key, value) in env {
208 cmd.env(key, value);
209 }
210
211 #[cfg(target_os = "linux")]
215 if let Some(sandbox) = &self.sandbox {
216 let prepared = prepare_landlock_rules(&sandbox.landlock_rules());
217 let mut prepared = Some(prepared);
218 unsafe {
223 cmd.pre_exec(move || {
224 let rules = prepared.take().ok_or_else(|| {
225 std::io::Error::other("Landlock pre_exec called more than once")
226 })?;
227 enforce_landlock_rules(rules).map_err(|e| {
228 std::io::Error::new(std::io::ErrorKind::PermissionDenied, e.clone())
229 })
230 });
231 }
232 }
233
234 Ok(cmd)
235 }
236
237 fn verify_binary_hash(&self) -> PluginResult<()> {
239 let PluginEntryPoint::Mcp {
240 command,
241 binary_hash: Some(expected),
242 ..
243 } = &self.manifest.entry_point
244 else {
245 return Ok(());
246 };
247
248 let binary_path = which::which(command).map_err(|e| PluginError::McpServerFailed {
250 plugin_id: self.id.clone(),
251 message: format!("Cannot find binary {command}: {e}"),
252 })?;
253
254 let binary_data = std::fs::read(&binary_path)?;
255 let actual_hash = astrid_crypto::ContentHash::hash(&binary_data);
256 let actual_str = format!("sha256:{}", actual_hash.to_hex());
257
258 if expected != &actual_str {
259 return Err(PluginError::McpServerFailed {
260 plugin_id: self.id.clone(),
261 message: format!("Binary hash mismatch: expected {expected}, got {actual_str}"),
262 });
263 }
264
265 Ok(())
266 }
267}
268
269#[async_trait]
270impl Plugin for McpPlugin {
271 fn id(&self) -> &PluginId {
272 &self.id
273 }
274
275 fn manifest(&self) -> &PluginManifest {
276 &self.manifest
277 }
278
279 fn state(&self) -> PluginState {
280 self.state.clone()
281 }
282
283 async fn load(&mut self, _ctx: &PluginContext) -> PluginResult<()> {
284 self.state = PluginState::Loading;
285
286 if let Err(e) = self.verify_binary_hash() {
288 self.state = PluginState::Failed(e.to_string());
289 return Err(e);
290 }
291
292 let cmd = match self.build_command() {
294 Ok(cmd) => cmd,
295 Err(e) => {
296 self.state = PluginState::Failed(e.to_string());
297 return Err(e);
298 },
299 };
300
301 let transport = TokioChildProcess::new(cmd).map_err(|e| {
303 let err = PluginError::McpServerFailed {
304 plugin_id: self.id.clone(),
305 message: format!("Failed to spawn MCP server process: {e}"),
306 };
307 self.state = PluginState::Failed(err.to_string());
308 err
309 })?;
310
311 let handler = Arc::new(CapabilitiesHandler::new());
313 let client_handler = AstridClientHandler::new(&self.server_name, handler);
314
315 let service = client_handler.serve(transport).await.map_err(|e| {
316 let err = PluginError::McpServerFailed {
317 plugin_id: self.id.clone(),
318 message: format!("MCP handshake failed: {e}"),
319 };
320 self.state = PluginState::Failed(err.to_string());
321 err
322 })?;
323
324 let rmcp_tools = service.list_all_tools().await.map_err(|e| {
326 let err = PluginError::McpServerFailed {
327 plugin_id: self.id.clone(),
328 message: format!("Failed to list tools: {e}"),
329 };
330 self.state = PluginState::Failed(err.to_string());
331 err
332 })?;
333
334 let peer = service.peer().clone();
336
337 let tools: Vec<Box<dyn PluginTool>> = rmcp_tools
339 .iter()
340 .map(|t| {
341 let tool: Box<dyn PluginTool> = Box::new(McpPluginTool {
342 name: t.name.to_string(),
343 description: t.description.as_deref().unwrap_or("").to_string(),
344 input_schema: serde_json::to_value(&*t.input_schema)
345 .unwrap_or_else(|_| serde_json::json!({"type": "object"})),
346 server_name: self.server_name.clone(),
347 peer: peer.clone(),
348 });
349 tool
350 })
351 .collect();
352
353 info!(
354 plugin_id = %self.id,
355 server_name = %self.server_name,
356 tool_count = tools.len(),
357 "MCP plugin loaded successfully"
358 );
359
360 self.service = Some(service);
361 self.peer = Some(peer);
362 self.tools = tools;
363 self.state = PluginState::Ready;
364
365 Ok(())
366 }
367
368 async fn unload(&mut self) -> PluginResult<()> {
369 self.state = PluginState::Unloading;
370
371 self.peer = None;
373 self.tools.clear();
374
375 if let Some(ref mut service) = self.service {
377 match service.close_with_timeout(SHUTDOWN_TIMEOUT).await {
378 Ok(Some(reason)) => {
379 info!(
380 plugin_id = %self.id,
381 ?reason,
382 "Plugin MCP session closed gracefully"
383 );
384 },
385 Ok(None) => {
386 warn!(
387 plugin_id = %self.id,
388 "Plugin MCP session close timed out; dropping"
389 );
390 },
391 Err(e) => {
392 error!(
393 plugin_id = %self.id,
394 error = %e,
395 "Plugin MCP session close join error"
396 );
397 },
398 }
399 }
400
401 self.service = None;
402 self.state = PluginState::Unloaded;
403
404 info!(plugin_id = %self.id, "MCP plugin unloaded");
405
406 Ok(())
407 }
408
409 fn tools(&self) -> &[Box<dyn PluginTool>] {
410 &self.tools
411 }
412}
413
414struct McpPluginTool {
420 name: String,
421 description: String,
422 input_schema: Value,
423 #[allow(dead_code)]
424 server_name: String,
425 peer: Peer<RoleClient>,
426}
427
428#[async_trait]
429impl PluginTool for McpPluginTool {
430 fn name(&self) -> &str {
431 &self.name
432 }
433
434 fn description(&self) -> &str {
435 &self.description
436 }
437
438 fn input_schema(&self) -> Value {
439 self.input_schema.clone()
440 }
441
442 async fn execute(&self, args: Value, _ctx: &PluginToolContext) -> PluginResult<String> {
443 let arguments = match args {
444 Value::Object(map) => Some(map),
445 Value::Null => None,
446 other => {
447 let mut map = serde_json::Map::new();
448 map.insert("value".to_string(), other);
449 Some(map)
450 },
451 };
452
453 let params = CallToolRequestParams {
454 meta: None,
455 name: Cow::Owned(self.name.clone()),
456 arguments,
457 task: None,
458 };
459
460 let result = self
461 .peer
462 .call_tool(params)
463 .await
464 .map_err(|e| PluginError::ExecutionFailed(format!("MCP tool call failed: {e}")))?;
465
466 let tool_result = ToolResult::from(result);
468 if tool_result.is_error {
469 return Err(PluginError::ExecutionFailed(
470 tool_result
471 .error
472 .unwrap_or_else(|| "Unknown MCP tool error".into()),
473 ));
474 }
475
476 Ok(tool_result.text_content())
477 }
478}
479
480pub fn create_plugin(
490 manifest: PluginManifest,
491 mcp_client: Option<McpClient>,
492) -> PluginResult<Box<dyn Plugin>> {
493 match &manifest.entry_point {
494 PluginEntryPoint::Wasm { .. } => Err(PluginError::UnsupportedEntryPoint("wasm".into())),
495 PluginEntryPoint::Mcp { .. } => {
496 let client = mcp_client.ok_or(PluginError::McpClientRequired)?;
497 Ok(Box::new(McpPlugin::new(manifest, client)))
498 },
499 }
500}
501
#[cfg(target_os = "linux")]
/// Landlock rules whose paths were opened ahead of time.
///
/// Built in the parent by `prepare_landlock_rules`, so the child-side
/// `pre_exec` hook only has to apply them and never opens paths itself.
struct PreparedLandlockRules {
    /// One entry per usable rule: `(opened path, read allowed, write allowed)`.
    rules: Vec<(landlock::PathFd, bool, bool)>,
}
512
513#[cfg(target_os = "linux")]
518fn prepare_landlock_rules(rules: &[crate::sandbox::LandlockPathRule]) -> PreparedLandlockRules {
519 use landlock::PathFd;
520
521 let mut prepared = Vec::with_capacity(rules.len());
522
523 for rule in rules {
524 if !rule.read && !rule.write {
525 continue;
526 }
527
528 if let Ok(fd) = PathFd::new(&rule.path) {
530 prepared.push((fd, rule.read, rule.write));
531 }
532 }
533
534 PreparedLandlockRules { rules: prepared }
535}
536
#[cfg(target_os = "linux")]
/// Apply prepared Landlock rules to the calling process.
///
/// This is invoked from the child's `pre_exec` hook set up in `build_command`.
///
/// The ruleset handles every filesystem access right of Landlock ABI V5 and is
/// built with `CompatLevel::BestEffort`. Each prepared path is granted full,
/// read-only or write-only access, and the ruleset is then enforced on the
/// calling process with `restrict_self`.
///
/// # Errors
///
/// Returns a human-readable message if the ruleset cannot be created, a rule
/// cannot be added, or enforcement fails.
///
/// NOTE(review): this runs after `fork` in the child, and the error paths
/// allocate (`format!`), which is not async-signal-safe in a multi-threaded
/// parent.
fn enforce_landlock_rules(prepared: PreparedLandlockRules) -> Result<(), String> {
    use landlock::{
        ABI, Access, AccessFs, CompatLevel, Compatible, PathBeneath, Ruleset, RulesetAttr,
        RulesetCreatedAttr, RulesetStatus,
    };

    let abi = ABI::V5;

    // Handle (i.e. deny by default) every filesystem right known to this ABI;
    // only the paths added below are granted access.
    let mut ruleset = Ruleset::default()
        .set_compatibility(CompatLevel::BestEffort)
        .handle_access(AccessFs::from_all(abi))
        .map_err(|e| format!("failed to create Landlock ruleset: {e}"))?
        .create()
        .map_err(|e| format!("failed to create Landlock ruleset: {e}"))?;

    for (fd, read, write) in prepared.rules {
        let access = match (read, write) {
            (true, true) => AccessFs::from_all(abi),
            (true, false) => AccessFs::from_read(abi),
            (false, true) => AccessFs::from_write(abi),
            // Already filtered out by `prepare_landlock_rules`; kept for exhaustiveness.
            (false, false) => continue,
        };
        let path_beneath = PathBeneath::new(fd, access);
        ruleset = ruleset
            .add_rule(path_beneath)
            .map_err(|e| format!("failed to add Landlock rule: {e}"))?;
    }

    let status = ruleset
        .restrict_self()
        .map_err(|e| format!("failed to enforce Landlock ruleset: {e}"))?;

    match status.ruleset {
        RulesetStatus::FullyEnforced
        | RulesetStatus::PartiallyEnforced
        | RulesetStatus::NotEnforced => {
            // Every enforcement level is accepted, consistent with `BestEffort`
            // above, so `NotEnforced` does not fail the spawn. The match is
            // exhaustive (no wildcard), presumably so that a new `RulesetStatus`
            // variant forces a review here.
            // NOTE(review): nothing is reported from inside the child; if silent
            // non-enforcement matters, consider checking Landlock support in the parent.
        },
    }

    Ok(())
}
584
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Manifest for an MCP plugin launched as `node dist/index.js` with no
    /// pinned binary hash.
    fn mcp_manifest(id: &str) -> PluginManifest {
        let entry_point = PluginEntryPoint::Mcp {
            command: "node".into(),
            args: vec!["dist/index.js".into()],
            env: HashMap::new(),
            binary_hash: None,
        };
        PluginManifest {
            id: PluginId::from_static(id),
            name: format!("Test MCP Plugin {id}"),
            version: "0.1.0".into(),
            description: Some("Test MCP plugin".into()),
            author: None,
            entry_point,
            capabilities: vec![],
            config: HashMap::new(),
        }
    }

    /// Manifest for a WASM plugin (unsupported by `create_plugin`).
    fn wasm_manifest(id: &str) -> PluginManifest {
        let entry_point = PluginEntryPoint::Wasm {
            path: "plugin.wasm".into(),
            hash: None,
        };
        PluginManifest {
            id: PluginId::from_static(id),
            name: format!("Test WASM Plugin {id}"),
            version: "0.1.0".into(),
            description: None,
            author: None,
            entry_point,
            capabilities: vec![],
            config: HashMap::new(),
        }
    }

    /// An `McpClient` built from a default server configuration.
    fn test_mcp_client() -> McpClient {
        McpClient::with_config(astrid_mcp::ServersConfig::default())
    }

    /// A freshly constructed, unloaded, unsandboxed MCP plugin.
    fn unloaded_plugin(id: &str) -> McpPlugin {
        McpPlugin::new(mcp_manifest(id), test_mcp_client())
    }

    #[tokio::test]
    async fn test_mcp_plugin_creation() {
        let plugin = unloaded_plugin("test-mcp");

        assert_eq!(plugin.id().as_str(), "test-mcp");
        assert_eq!(plugin.state(), PluginState::Unloaded);
        assert!(plugin.tools().is_empty());
        assert_eq!(plugin.server_name(), "plugin:test-mcp");
    }

    #[tokio::test]
    async fn test_mcp_plugin_with_sandbox() {
        let manifest = mcp_manifest("test-mcp");
        let client = test_mcp_client();
        let sandbox = SandboxProfile::new("/workspace".into(), "/plugins/test".into());

        let plugin = McpPlugin::new(manifest, client).with_sandbox(sandbox);
        assert!(plugin.sandbox.is_some());
    }

    #[tokio::test]
    async fn test_create_plugin_mcp() {
        let created = create_plugin(mcp_manifest("test-mcp"), Some(test_mcp_client()));
        assert!(created.is_ok());
    }

    #[test]
    fn test_create_plugin_mcp_requires_client() {
        let created = create_plugin(mcp_manifest("test-mcp"), None);
        assert!(matches!(created, Err(PluginError::McpClientRequired)));
    }

    #[test]
    fn test_create_plugin_wasm_unsupported() {
        let created = create_plugin(wasm_manifest("test-wasm"), None);
        assert!(matches!(
            created,
            Err(PluginError::UnsupportedEntryPoint(_))
        ));
    }

    #[tokio::test]
    async fn test_server_name_format() {
        let plugin = unloaded_plugin("my-cool-plugin");
        assert_eq!(plugin.server_name(), "plugin:my-cool-plugin");
    }

    #[tokio::test]
    async fn test_health_check_unloaded_returns_false() {
        let mut plugin = unloaded_plugin("test-health");
        assert!(!plugin.check_health());
    }
}