enact-mcp 0.0.2

MCP (Model Context Protocol) client for Enact
Documentation
//! MCP Configuration
//!
//! Configuration loading for the MCP client using the unified config resolution.

use anyhow::Result;
use serde::{Deserialize, Serialize};

/// MCP client configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpClientConfig {
    /// Protocol version for MCP communication
    pub protocol_version: String,
    /// Client name
    pub name: String,
    /// Client version
    pub version: String,
}

impl Default for McpClientConfig {
    fn default() -> Self {
        Self {
            protocol_version: "2024-11-05".to_string(),
            name: "enact-mcp".to_string(),
            version: "0.1.0".to_string(),
        }
    }
}

/// Settings describing how to reach a single MCP server.
///
/// When deserializing, only `name` is mandatory. `transport` falls back to
/// `"stdio"`; `command`, `args`, `url` and `env` fall back to empty / `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerConfig {
    /// Identifier of this server.
    pub name: String,

    /// Transport kind: `"stdio"` (default) or `"http"`.
    #[serde(default = "default_transport")]
    pub transport: String,

    /// Executable to launch for the server.
    #[serde(default)]
    pub command: String,

    /// Command-line arguments passed to `command`.
    #[serde(default)]
    pub args: Vec<String>,

    /// HTTP endpoint URL; required when `transport` is `"http"`
    /// (this struct itself does not enforce that).
    #[serde(default)]
    pub url: Option<String>,

    /// Extra environment variables for the server.
    #[serde(default)]
    pub env: std::collections::HashMap<String, String>,
}

/// Serde default for `McpServerConfig::transport`: the `"stdio"` transport.
fn default_transport() -> String {
    String::from("stdio")
}

impl Default for McpServerConfig {
    /// An unnamed server on the default transport, with no command,
    /// arguments, URL, or extra environment variables.
    fn default() -> Self {
        Self {
            transport: default_transport(),
            url: None,
            name: String::default(),
            command: String::default(),
            args: Vec::default(),
            env: Default::default(),
        }
    }
}

/// Top-level MCP configuration, as read from `mcp.yaml`.
///
/// Both sections are optional in the file: an absent section takes its value
/// from the derived `Default` (the default client, and no servers).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct McpConfig {
    /// Client identity/protocol settings.
    pub client: McpClientConfig,
    /// MCP servers to connect to.
    pub servers: Vec<McpServerConfig>,
}

/// Load the default MCP configuration using unified config resolution.
///
/// Resolution order:
/// 1. `ENACT_MCP_CONFIG_PATH` environment variable
/// 2. `./mcp.yaml` in current working directory
/// 3. `~/.enact/mcp.yaml`
/// 4. Hardcoded defaults if no file found
///
/// # Example
///
/// ```rust,no_run
/// use enact_mcp::config::load_default_mcp_config;
///
/// let config = load_default_mcp_config().expect("Failed to load MCP config");
/// println!("Protocol version: {}", config.client.protocol_version);
/// ```
pub fn load_default_mcp_config() -> Result<McpConfig> {
    match enact_config::resolve_config_file("mcp.yaml", "ENACT_MCP_CONFIG_PATH") {
        Some(path) => {
            let content = std::fs::read_to_string(&path)
                .map_err(|e| anyhow::anyhow!("Failed to read MCP config from {:?}: {}", path, e))?;
            let config: McpConfig = serde_yaml::from_str(&content).map_err(|e| {
                anyhow::anyhow!("Failed to parse MCP config from {:?}: {}", path, e)
            })?;
            Ok(config)
        }
        None => {
            // No config file found, use hardcoded defaults
            Ok(McpConfig::default())
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::TempDir;

    /// Environment variable consulted by `load_default_mcp_config`.
    const CONFIG_ENV_VAR: &str = "ENACT_MCP_CONFIG_PATH";

    /// Serializes every test that reads or writes `CONFIG_ENV_VAR`.
    ///
    /// The test harness runs tests on parallel threads, but the process
    /// environment is shared. Without this lock, one test can remove the
    /// variable while another is between setting it and loading the config.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, recovering it if a previous holder panicked.
    /// The lock guards no data, so a poisoned lock is still safe to use.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Unsets `CONFIG_ENV_VAR` when dropped, so a failing assertion cannot
    /// leak the variable into tests that run afterwards.
    struct UnsetOnDrop;

    impl Drop for UnsetOnDrop {
        fn drop(&mut self) {
            std::env::remove_var(CONFIG_ENV_VAR);
        }
    }

    #[test]
    fn test_default_config() {
        let config = McpConfig::default();
        assert_eq!(config.client.protocol_version, "2024-11-05");
        assert_eq!(config.client.name, "enact-mcp");
        assert_eq!(config.client.version, "0.1.0");
        assert!(config.servers.is_empty());
    }

    #[test]
    fn test_load_default_mcp_config_falls_back_to_defaults() {
        let _lock = lock_env();
        std::env::remove_var(CONFIG_ENV_VAR);
        // NOTE: still assumes no mcp.yaml exists in the cwd or ~/.enact on the
        // machine running the tests.
        let config = load_default_mcp_config().unwrap();
        assert_eq!(config.client.name, "enact-mcp");
    }

    #[test]
    fn test_load_config_from_env_var() {
        let _lock = lock_env();
        let temp_dir = TempDir::new().unwrap();
        let config_path = temp_dir.path().join("custom_mcp.yaml");

        let yaml_content = r#"
client:
  protocol_version: "2025-01-01"
  name: "custom-mcp"
  version: "2.0.0"
servers:
  - name: "test-server"
    command: "/usr/bin/test-mcp"
    args: ["--port", "8080"]
"#;
        std::fs::write(&config_path, yaml_content).unwrap();

        std::env::set_var(CONFIG_ENV_VAR, &config_path);
        // Dropped before `_lock`, so the variable is cleared while still locked.
        let _unset = UnsetOnDrop;
        let config = load_default_mcp_config().unwrap();

        assert_eq!(config.client.protocol_version, "2025-01-01");
        assert_eq!(config.client.name, "custom-mcp");
        assert_eq!(config.client.version, "2.0.0");
        assert_eq!(config.servers.len(), 1);
        assert_eq!(config.servers[0].name, "test-server");
        assert_eq!(config.servers[0].command, "/usr/bin/test-mcp");
        assert_eq!(config.servers[0].args, vec!["--port", "8080"]);
        // Fields omitted from the YAML take their serde defaults.
        assert_eq!(config.servers[0].transport, "stdio");
        assert!(config.servers[0].url.is_none());
        assert!(config.servers[0].env.is_empty());
    }

    #[test]
    fn test_config_serialization() {
        let config = McpConfig {
            client: McpClientConfig {
                protocol_version: "2024-11-05".to_string(),
                name: "test".to_string(),
                version: "1.0.0".to_string(),
            },
            servers: vec![McpServerConfig {
                name: "server1".to_string(),
                transport: "stdio".to_string(),
                command: "mcp-server".to_string(),
                args: vec!["--flag".to_string()],
                url: None,
                env: [("KEY".to_string(), "value".to_string())]
                    .into_iter()
                    .collect(),
            }],
        };

        let yaml = serde_yaml::to_string(&config).unwrap();
        let deserialized: McpConfig = serde_yaml::from_str(&yaml).unwrap();

        assert_eq!(deserialized.client.name, "test");
        assert_eq!(deserialized.servers.len(), 1);
        let server = &deserialized.servers[0];
        assert_eq!(server.name, "server1");
        assert_eq!(server.transport, "stdio");
        assert_eq!(server.command, "mcp-server");
        assert_eq!(server.args, vec!["--flag"]);
        assert!(server.url.is_none());
        assert_eq!(server.env.get("KEY").map(String::as_str), Some("value"));
    }
}