1use crate::error::AgentError;
10use crate::mcp::McpConnection;
11use crate::types::{McpConnectionStatus, McpServerConfig, McpSseConfig, McpStdioConfig};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::path::Path;
15use std::process::Stdio;
16use std::sync::Arc;
17use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
18use tokio::process::Command;
19use tokio::sync::RwLock;
20
21#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
23#[serde(rename_all = "lowercase")]
24pub enum PluginMcpTransport {
25 Stdio,
26 Sse,
27 Http,
28 #[serde(other)]
29 Unknown,
30}
31
32impl Default for PluginMcpTransport {
33 fn default() -> Self {
34 PluginMcpTransport::Stdio
35 }
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
40#[serde(rename_all = "lowercase")]
41pub enum PluginMcpServerStatus {
42 Stopped,
44 Starting,
46 Running,
48 Error,
50 Disabled,
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56#[serde(rename_all = "camelCase")]
57pub struct PluginMcpServerConfig {
58 pub transport_type: Option<PluginMcpTransport>,
60 pub command: Option<String>,
62 pub args: Option<Vec<String>>,
64 pub env: Option<HashMap<String, String>>,
66 pub url: Option<String>,
68 pub headers: Option<HashMap<String, String>>,
70 pub scope: Option<String>,
72 pub plugin_source: Option<String>,
74}
75
76impl PluginMcpServerConfig {
77 pub fn to_mcp_config(&self) -> Option<McpServerConfig> {
79 let transport = self
80 .transport_type
81 .as_ref()
82 .unwrap_or(&PluginMcpTransport::Stdio);
83
84 match transport {
85 PluginMcpTransport::Stdio => {
86 let command = self.command.as_ref()?;
87 Some(McpServerConfig::Stdio(McpStdioConfig {
88 transport_type: Some("stdio".to_string()),
89 command: command.clone(),
90 args: self.args.clone(),
91 env: self.env.clone(),
92 }))
93 }
94 PluginMcpTransport::Sse => {
95 let url = self.url.as_ref()?;
96 Some(McpServerConfig::Sse(McpSseConfig {
97 transport_type: "sse".to_string(),
98 url: url.clone(),
99 headers: self.headers.clone(),
100 }))
101 }
102 PluginMcpTransport::Http => {
103 let url = self.url.as_ref()?;
104 Some(McpServerConfig::Http(crate::types::McpHttpConfig {
105 transport_type: "http".to_string(),
106 url: url.clone(),
107 headers: self.headers.clone(),
108 }))
109 }
110 PluginMcpTransport::Unknown => None,
111 }
112 }
113}
114
115#[derive(Debug)]
117pub struct PluginMcpServer {
118 pub name: String,
120 pub config: PluginMcpServerConfig,
122 pub status: PluginMcpServerStatus,
124 child: Option<tokio::process::Child>,
126 connection: Option<McpConnection>,
128 plugin_path: String,
130 _plugin_source: String,
132}
133
134impl PluginMcpServer {
135 pub fn new(
137 name: String,
138 config: PluginMcpServerConfig,
139 plugin_path: String,
140 plugin_source: String,
141 ) -> Self {
142 Self {
143 name,
144 config,
145 status: PluginMcpServerStatus::Stopped,
146 child: None,
147 connection: None,
148 plugin_path,
149 _plugin_source: plugin_source,
150 }
151 }
152
153 pub async fn start(&mut self) -> Result<(), AgentError> {
155 if self.status == PluginMcpServerStatus::Running {
156 return Ok(());
157 }
158
159 self.status = PluginMcpServerStatus::Starting;
160
161 let mcp_config = self.config.to_mcp_config().ok_or_else(|| {
162 AgentError::Mcp(format!("Invalid MCP config for server {}", self.name))
163 })?;
164
165 let resolved_config = self.resolve_environment(&mcp_config);
167
168 match resolved_config {
169 McpServerConfig::Stdio(stdio_config) => {
170 self.start_stdio(stdio_config).await?;
171 }
172 McpServerConfig::Sse(sse_config) => {
173 self.start_sse(sse_config).await?;
174 }
175 McpServerConfig::Http(http_config) => {
176 self.start_http(http_config).await?;
177 }
178 }
179
180 self.status = PluginMcpServerStatus::Running;
181 Ok(())
182 }
183
184 async fn start_stdio(&mut self, config: McpStdioConfig) -> Result<(), AgentError> {
186 let mut env_vars: HashMap<String, String> = std::env::vars().collect();
187
188 env_vars.insert("AI_PLUGIN_ROOT".to_string(), self.plugin_path.clone());
190
191 if let Some(custom_env) = &config.env {
193 for (key, value) in custom_env {
194 env_vars.insert(key.clone(), value.clone());
195 }
196 }
197
198 let command = config.command.clone();
199 let args = config.args.unwrap_or_default();
200
201 let mut child = Command::new(&command)
202 .args(&args)
203 .envs(&env_vars)
204 .kill_on_drop(true)
205 .stdout(Stdio::piped())
206 .stderr(Stdio::piped())
207 .stdin(Stdio::piped())
208 .spawn()
209 .map_err(|e| {
210 AgentError::Mcp(format!("Failed to spawn MCP server '{}': {}", command, e))
211 })?;
212
213 let stdout = child
214 .stdout
215 .take()
216 .ok_or_else(|| AgentError::Mcp("Failed to take stdout from MCP server".to_string()))?;
217
218 let mut stdin = child
219 .stdin
220 .take()
221 .ok_or_else(|| AgentError::Mcp("Failed to take stdin from MCP server".to_string()))?;
222
223 let mut stdout_reader = BufReader::new(stdout).lines();
224
225 let initialize_request = serde_json::json!({
227 "jsonrpc": "2.0",
228 "id": 1,
229 "method": "initialize",
230 "params": {
231 "protocolVersion": "2024-11-05",
232 "capabilities": {},
233 "clientInfo": {
234 "name": format!("agent-sdk-plugin-{}", self.name),
235 "version": "1.0.0"
236 }
237 }
238 });
239
240 stdin
241 .write_all(format!("{initialize_request}\n").as_bytes())
242 .await
243 .map_err(|e| {
244 AgentError::Io(std::io::Error::new(
245 std::io::ErrorKind::Other,
246 e.to_string(),
247 ))
248 })?;
249 stdin.flush().await.map_err(|e| {
250 AgentError::Io(std::io::Error::new(
251 std::io::ErrorKind::Other,
252 e.to_string(),
253 ))
254 })?;
255
256 let _ = stdout_reader.next_line().await;
258
259 let list_tools_request = serde_json::json!({
261 "jsonrpc": "2.0",
262 "id": 2,
263 "method": "tools/list"
264 });
265
266 stdin
267 .write_all(format!("{list_tools_request}\n").as_bytes())
268 .await
269 .map_err(|e| {
270 AgentError::Io(std::io::Error::new(
271 std::io::ErrorKind::Other,
272 e.to_string(),
273 ))
274 })?;
275 stdin.flush().await.map_err(|e| {
276 AgentError::Io(std::io::Error::new(
277 std::io::ErrorKind::Other,
278 e.to_string(),
279 ))
280 })?;
281
282 let mut tools = vec![];
284 if let Ok(Some(response)) = stdout_reader.next_line().await {
285 if let Ok(resp) = serde_json::from_str::<serde_json::Value>(&response) {
286 if let Some(result) = resp.get("result") {
287 if let Some(tools_array) = result.get("tools").and_then(|t| t.as_array()) {
288 for tool_val in tools_array {
289 if let Ok(mcp_tool) =
290 serde_json::from_value::<crate::types::McpTool>(tool_val.clone())
291 {
292 let tool_def = create_mcp_tool_definition(&self.name, &mcp_tool);
293 tools.push(tool_def);
294 }
295 }
296 }
297 }
298 }
299 }
300
301 drop(stdin);
303
304 self.child = Some(child);
305 self.connection = Some(McpConnection {
306 name: self.name.clone(),
307 status: McpConnectionStatus::Connected,
308 tools,
309 });
310
311 Ok(())
312 }
313
314 async fn start_sse(&mut self, _config: McpSseConfig) -> Result<(), AgentError> {
316 self.connection = Some(McpConnection {
319 name: self.name.clone(),
320 status: McpConnectionStatus::Connected,
321 tools: vec![],
322 });
323 Ok(())
324 }
325
326 async fn start_http(&mut self, _config: crate::types::McpHttpConfig) -> Result<(), AgentError> {
328 self.connection = Some(McpConnection {
331 name: self.name.clone(),
332 status: McpConnectionStatus::Connected,
333 tools: vec![],
334 });
335 Ok(())
336 }
337
338 pub async fn stop(&mut self) -> Result<(), AgentError> {
340 if self.status == PluginMcpServerStatus::Stopped {
341 return Ok(());
342 }
343
344 if let Some(mut conn) = self.connection.take() {
346 conn.close().await;
347 }
348
349 if let Some(mut child) = self.child.take() {
351 let _ = child.kill().await;
352 }
353
354 self.status = PluginMcpServerStatus::Stopped;
355 Ok(())
356 }
357
358 pub fn is_running(&self) -> bool {
360 self.status == PluginMcpServerStatus::Running
361 }
362
363 pub fn get_status(&self) -> &PluginMcpServerStatus {
365 &self.status
366 }
367
368 pub fn get_connection(&self) -> Option<&McpConnection> {
370 self.connection.as_ref()
371 }
372
373 fn resolve_environment(&self, config: &McpServerConfig) -> McpServerConfig {
375 match config {
376 McpServerConfig::Stdio(stdio_config) => {
377 let mut resolved_env = std::env::vars().collect::<HashMap<_, _>>();
378
379 resolved_env.insert("AI_PLUGIN_ROOT".to_string(), self.plugin_path.clone());
381
382 if let Some(custom_env) = &stdio_config.env {
383 for (key, value) in custom_env {
384 let resolved = self.substitute_variables(value);
385 resolved_env.insert(key.clone(), resolved);
386 }
387 }
388
389 McpServerConfig::Stdio(McpStdioConfig {
390 transport_type: stdio_config.transport_type.clone(),
391 command: self.substitute_variables(&stdio_config.command),
392 args: stdio_config
393 .args
394 .as_ref()
395 .map(|args| args.iter().map(|a| self.substitute_variables(a)).collect()),
396 env: Some(resolved_env),
397 })
398 }
399 McpServerConfig::Sse(sse_config) => {
400 let resolved_url = self.substitute_variables(&sse_config.url);
401 let resolved_headers = sse_config.headers.as_ref().map(|headers| {
402 headers
403 .iter()
404 .map(|(k, v)| (k.clone(), self.substitute_variables(v)))
405 .collect()
406 });
407
408 McpServerConfig::Sse(McpSseConfig {
409 transport_type: sse_config.transport_type.clone(),
410 url: resolved_url,
411 headers: resolved_headers,
412 })
413 }
414 McpServerConfig::Http(http_config) => {
415 let resolved_url = self.substitute_variables(&http_config.url);
416 let resolved_headers = http_config.headers.as_ref().map(|headers| {
417 headers
418 .iter()
419 .map(|(k, v)| (k.clone(), self.substitute_variables(v)))
420 .collect()
421 });
422
423 McpServerConfig::Http(crate::types::McpHttpConfig {
424 transport_type: http_config.transport_type.clone(),
425 url: resolved_url,
426 headers: resolved_headers,
427 })
428 }
429 }
430 }
431
432 fn substitute_variables(&self, value: &str) -> String {
434 let mut result = value.to_string();
435
436 result = result.replace("${AI_PLUGIN_ROOT}", &self.plugin_path);
438 result = result.replace("$AI_PLUGIN_ROOT", &self.plugin_path);
439
440 for (key, val) in std::env::vars() {
442 let pattern = format!("${{{}}}", key);
443 let pattern_dollar = format!("${}", key);
444 result = result.replace(&pattern, &val);
445 result = result.replace(&pattern_dollar, &val);
446 }
447
448 result
449 }
450}
451
452fn create_mcp_tool_definition(
454 server_name: &str,
455 mcp_tool: &crate::types::McpTool,
456) -> crate::types::ToolDefinition {
457 let tool_name = format!("mcp__{}__{}", server_name, mcp_tool.name);
458
459 let input_schema = mcp_tool.input_schema.clone().unwrap_or_else(|| {
460 serde_json::json!({
461 "type": "object",
462 "properties": {}
463 })
464 });
465
466 crate::types::ToolDefinition {
467 name: tool_name,
468 description: mcp_tool
469 .description
470 .clone()
471 .unwrap_or_else(|| format!("MCP tool: {}", mcp_tool.name)),
472 input_schema: crate::types::ToolInputSchema {
473 schema_type: input_schema
474 .get("type")
475 .and_then(|t| t.as_str())
476 .unwrap_or("object")
477 .to_string(),
478 properties: input_schema
479 .get("properties")
480 .cloned()
481 .unwrap_or(serde_json::json!({})),
482 required: input_schema
483 .get("required")
484 .and_then(|r| r.as_array())
485 .map(|arr| {
486 arr.iter()
487 .filter_map(|s| s.as_str().map(String::from))
488 .collect()
489 }),
490 },
491 annotations: None,
492 should_defer: None,
493 always_load: None,
494 is_mcp: None,
495 search_hint: None,
496 aliases: None,
497 user_facing_name: None,
498 interrupt_behavior: None,
499 }
500}
501
502pub struct PluginMcpServerManager {
504 servers: RwLock<HashMap<String, Arc<RwLock<PluginMcpServer>>>>,
506}
507
508impl Default for PluginMcpServerManager {
509 fn default() -> Self {
510 Self::new()
511 }
512}
513
514impl PluginMcpServerManager {
515 pub fn new() -> Self {
517 Self {
518 servers: RwLock::new(HashMap::new()),
519 }
520 }
521
522 pub async fn add_server(&self, server: PluginMcpServer) {
524 let name = server.name.clone();
525 let server = Arc::new(RwLock::new(server));
526 self.servers.write().await.insert(name, server);
527 }
528
529 pub async fn get_server(&self, name: &str) -> Option<Arc<RwLock<PluginMcpServer>>> {
531 self.servers.read().await.get(name).cloned()
532 }
533
534 pub async fn remove_server(&self, name: &str) {
536 if let Some(server) = self.servers.write().await.remove(name) {
537 let mut server = server.write().await;
538 let _ = server.stop().await;
539 }
540 }
541
542 pub async fn start_server(&self, name: &str) -> Result<(), AgentError> {
544 if let Some(server) = self.servers.read().await.get(name) {
545 let mut server = server.write().await;
546 server.start().await
547 } else {
548 Err(AgentError::Mcp(format!("Server '{}' not found", name)))
549 }
550 }
551
552 pub async fn stop_server(&self, name: &str) -> Result<(), AgentError> {
554 if let Some(server) = self.servers.read().await.get(name) {
555 let mut server = server.write().await;
556 server.stop().await
557 } else {
558 Err(AgentError::Mcp(format!("Server '{}' not found", name)))
559 }
560 }
561
562 pub async fn start_all(&self) -> Vec<(String, Result<(), AgentError>)> {
564 let mut results = Vec::new();
565 let servers = self.servers.read().await;
566
567 for (name, server) in servers.iter() {
568 let mut server = server.write().await;
569 results.push((name.clone(), server.start().await));
570 }
571
572 results
573 }
574
575 pub async fn stop_all(&self) {
577 let mut servers = self.servers.write().await;
578
579 for (_, server) in servers.iter() {
580 let mut server = server.write().await;
581 let _ = server.stop().await;
582 }
583
584 servers.clear();
585 }
586
587 pub async fn list_servers(&self) -> Vec<String> {
589 self.servers.read().await.keys().cloned().collect()
590 }
591
592 pub async fn get_all_status(&self) -> HashMap<String, PluginMcpServerStatus> {
594 let servers = self.servers.read().await;
595 let mut result = HashMap::new();
596
597 for (name, server) in servers.iter() {
598 let status = server.read().await.status.clone();
599 result.insert(name.clone(), status);
600 }
601
602 result
603 }
604}
605
606pub async fn load_mcp_servers_from_file(
608 plugin_path: &str,
609 filename: &str,
610) -> Result<HashMap<String, PluginMcpServerConfig>, AgentError> {
611 let path = Path::new(plugin_path).join(filename);
612
613 if !path.exists() {
614 return Ok(HashMap::new());
615 }
616
617 let content = tokio::fs::read_to_string(&path).await.map_err(|e| {
618 AgentError::Io(std::io::Error::new(
619 std::io::ErrorKind::Other,
620 format!("Failed to read MCP config from {}: {}", path.display(), e),
621 ))
622 })?;
623
624 let parsed: serde_json::Value = serde_json::from_str(&content)
625 .map_err(|e| AgentError::Mcp(format!("Failed to parse MCP config: {}", e)))?;
626
627 let mcp_servers = if let Some(servers) = parsed.get("mcpServers") {
629 servers.clone()
630 } else {
631 parsed
632 };
633
634 let mut configs = HashMap::new();
635
636 if let Some(obj) = mcp_servers.as_object() {
637 for (name, config_val) in obj {
638 let config = parse_mcp_server_config(config_val);
639 if config.is_some() {
640 configs.insert(name.clone(), config.unwrap());
641 }
642 }
643 }
644
645 Ok(configs)
646}
647
648fn parse_mcp_server_config(value: &serde_json::Value) -> Option<PluginMcpServerConfig> {
650 let obj = value.as_object()?;
651
652 let transport_type = obj
654 .get("type")
655 .and_then(|t| t.as_str())
656 .map(|t| match t {
657 "stdio" => PluginMcpTransport::Stdio,
658 "sse" => PluginMcpTransport::Sse,
659 "http" => PluginMcpTransport::Http,
660 _ => PluginMcpTransport::Unknown,
661 })
662 .unwrap_or(PluginMcpTransport::Stdio);
663
664 let command = obj
666 .get("command")
667 .and_then(|v| v.as_str())
668 .map(String::from);
669 let args = obj.get("args").and_then(|v| v.as_array()).map(|arr| {
670 arr.iter()
671 .filter_map(|s| s.as_str().map(String::from))
672 .collect()
673 });
674
675 let env = obj.get("env").and_then(|v| v.as_object()).map(|obj| {
676 obj.iter()
677 .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
678 .collect()
679 });
680
681 let url = obj.get("url").and_then(|v| v.as_str()).map(String::from);
683 let headers = obj.get("headers").and_then(|v| v.as_object()).map(|obj| {
684 obj.iter()
685 .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
686 .collect()
687 });
688
689 Some(PluginMcpServerConfig {
690 transport_type: Some(transport_type),
691 command,
692 args,
693 env,
694 url,
695 headers,
696 scope: None,
697 plugin_source: None,
698 })
699}
700
701pub async fn load_plugin_mcp_servers(
703 plugin_path: &str,
704 mcp_servers_spec: &serde_json::Value,
705) -> Result<HashMap<String, PluginMcpServerConfig>, AgentError> {
706 let mut servers = HashMap::new();
707
708 match mcp_servers_spec {
709 serde_json::Value::String(path) => {
711 if path.ends_with(".mcpb") {
712 eprintln!("MCPB file loading not yet implemented: {}", path);
715 } else {
716 let loaded = load_mcp_servers_from_file(plugin_path, path).await?;
718 servers.extend(loaded);
719 }
720 }
721 serde_json::Value::Array(arr) => {
723 for spec in arr {
724 match spec {
725 serde_json::Value::String(path) => {
726 if path.ends_with(".mcpb") {
727 eprintln!("MCPB file loading not yet implemented: {}", path);
728 } else {
729 let loaded = load_mcp_servers_from_file(plugin_path, path).await?;
730 servers.extend(loaded);
731 }
732 }
733 _ => {
734 if let Some(config) = parse_mcp_server_config(spec) {
736 let name = format!("inline_{}", servers.len());
738 servers.insert(name, config);
739 }
740 }
741 }
742 }
743 }
744 serde_json::Value::Object(_) => {
746 if let Some(config) = parse_mcp_server_config(mcp_servers_spec) {
747 let name = format!("inline_{}", servers.len());
748 servers.insert(name, config);
749 }
750 }
751 _ => {}
752 }
753
754 Ok(servers)
755}
756
757pub fn add_plugin_scope_to_servers(
759 servers: HashMap<String, PluginMcpServerConfig>,
760 plugin_name: &str,
761 plugin_source: &str,
762) -> HashMap<String, PluginMcpServerConfig> {
763 servers
764 .into_iter()
765 .map(|(name, mut config)| {
766 let scoped_name = format!("plugin:{}:{}", plugin_name, name);
767 config.plugin_source = Some(plugin_source.to_string());
768 (scoped_name, config)
769 })
770 .collect()
771}
772
#[cfg(test)]
mod tests {
    use super::*;

    /// Stdio config that runs `echo`; shared by the server and manager tests.
    fn echo_config() -> PluginMcpServerConfig {
        PluginMcpServerConfig {
            transport_type: Some(PluginMcpTransport::Stdio),
            command: Some("echo".to_string()),
            args: None,
            env: None,
            url: None,
            headers: None,
            scope: None,
            plugin_source: None,
        }
    }

    #[test]
    fn test_transport_type_parsing() {
        let json = serde_json::json!({
            "type": "stdio",
            "command": "npx",
            "args": ["-y", "some-server"]
        });

        let config = parse_mcp_server_config(&json).unwrap();
        assert_eq!(config.transport_type, Some(PluginMcpTransport::Stdio));
        assert_eq!(config.command, Some("npx".to_string()));
    }

    #[test]
    fn test_sse_config_parsing() {
        let json = serde_json::json!({
            "type": "sse",
            "url": "http://localhost:3000/sse"
        });

        let config = parse_mcp_server_config(&json).unwrap();
        assert_eq!(config.transport_type, Some(PluginMcpTransport::Sse));
        assert_eq!(config.url, Some("http://localhost:3000/sse".to_string()));
    }

    #[test]
    fn test_http_config_parsing_drops_non_string_headers() {
        let json = serde_json::json!({
            "type": "http",
            "url": "https://example.com/mcp",
            "headers": { "Authorization": "Bearer x", "X-Ignored": 1 }
        });

        let config = parse_mcp_server_config(&json).unwrap();
        assert_eq!(config.transport_type, Some(PluginMcpTransport::Http));
        let headers = config.headers.unwrap();
        assert_eq!(headers.get("Authorization").map(String::as_str), Some("Bearer x"));
        assert!(!headers.contains_key("X-Ignored"));
    }

    #[test]
    fn test_missing_type_defaults_to_stdio() {
        let config = parse_mcp_server_config(&serde_json::json!({ "command": "node" })).unwrap();
        assert_eq!(config.transport_type, Some(PluginMcpTransport::Stdio));
        assert!(matches!(config.to_mcp_config(), Some(McpServerConfig::Stdio(_))));
    }

    #[test]
    fn test_unknown_transport_has_no_mcp_config() {
        let json = serde_json::json!({ "type": "websocket", "url": "ws://localhost:1" });
        let config = parse_mcp_server_config(&json).unwrap();
        assert_eq!(config.transport_type, Some(PluginMcpTransport::Unknown));
        assert!(config.to_mcp_config().is_none());
    }

    #[test]
    fn test_stdio_without_command_has_no_mcp_config() {
        let json = serde_json::json!({ "type": "stdio", "args": ["x"] });
        let config = parse_mcp_server_config(&json).unwrap();
        assert!(config.to_mcp_config().is_none());
    }

    #[test]
    fn test_non_object_config_is_rejected() {
        assert!(parse_mcp_server_config(&serde_json::json!("npx")).is_none());
    }

    #[test]
    fn test_add_plugin_scope_to_servers() {
        let mut servers = HashMap::new();
        servers.insert("db".to_string(), echo_config());

        let scoped = add_plugin_scope_to_servers(servers, "my-plugin", "marketplace");
        assert_eq!(scoped.len(), 1);
        let config = scoped.get("plugin:my-plugin:db").expect("scoped name present");
        assert_eq!(config.plugin_source.as_deref(), Some("marketplace"));
    }

    #[test]
    fn test_inline_specs_are_numbered_in_order() {
        let spec = serde_json::json!([{ "command": "a" }, { "command": "b" }, 42]);
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let servers = match runtime.block_on(load_plugin_mcp_servers("/nonexistent", &spec)) {
            Ok(servers) => servers,
            Err(_) => panic!("inline specs must load without touching the filesystem"),
        };

        assert_eq!(servers.len(), 2);
        assert_eq!(servers["inline_0"].command.as_deref(), Some("a"));
        assert_eq!(servers["inline_1"].command.as_deref(), Some("b"));
    }

    #[test]
    fn test_server_status() {
        let server = PluginMcpServer::new(
            "test".to_string(),
            echo_config(),
            "/tmp/plugin".to_string(),
            "test-plugin".to_string(),
        );

        assert_eq!(server.get_status(), &PluginMcpServerStatus::Stopped);
        assert!(!server.is_running());
        assert!(server.get_connection().is_none());
    }

    #[test]
    fn test_manager() {
        let manager = PluginMcpServerManager::new();

        let server = PluginMcpServer::new(
            "test".to_string(),
            echo_config(),
            "/tmp/plugin".to_string(),
            "test-plugin".to_string(),
        );

        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(async {
            manager.add_server(server).await;
            let servers = manager.list_servers().await;
            assert_eq!(servers.len(), 1);
            assert!(servers.contains(&"test".to_string()));

            let statuses = manager.get_all_status().await;
            assert_eq!(statuses.get("test"), Some(&PluginMcpServerStatus::Stopped));

            assert!(manager.start_server("missing").await.is_err());
            assert!(manager.stop_server("missing").await.is_err());

            manager.remove_server("test").await;
            assert!(manager.list_servers().await.is_empty());
        });
    }
}