use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::error::{Error, Result};
/// Transport used to talk to an MCP server.
///
/// Serialized with `snake_case` variant names (`"stdio"`, `"sse"`,
/// `"http_stream"`, `"web_socket"`). Defaults to [`TransportType::Stdio`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TransportType {
    /// Local process transport, configured through `command` and `args`.
    #[default]
    Stdio,
    /// Server-Sent Events transport.
    Sse,
    /// HTTP streaming transport.
    HttpStream,
    /// WebSocket transport.
    WebSocket,
}
/// Connection parameters for an MCP server.
///
/// The constructors set `transport_type` together with the field that
/// transport needs: `command`/`args` for [`TransportConfig::stdio`], and
/// `url` for [`TransportConfig::http`], [`TransportConfig::sse`] and
/// [`TransportConfig::websocket`]. Everything else starts from [`Default`]
/// and can be adjusted with the chained setters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransportConfig {
    /// Transport kind to use.
    pub transport_type: TransportType,
    /// Executable to launch for the stdio transport.
    pub command: Option<String>,
    /// Arguments passed to `command`.
    pub args: Vec<String>,
    /// Endpoint for URL-based transports.
    pub url: Option<String>,
    /// Extra headers. NOTE(review): presumably only used by network transports.
    pub headers: HashMap<String, String>,
    /// Extra environment variables. NOTE(review): presumably handed to the launched process.
    pub env: HashMap<String, String>,
    /// Timeout, `30` by default. NOTE(review): the unit is not fixed here; presumably seconds.
    pub timeout: u32,
}

impl Default for TransportConfig {
    /// Stdio transport with no command or URL, empty args/headers/env and a timeout of 30.
    fn default() -> Self {
        Self {
            transport_type: TransportType::Stdio,
            command: None,
            args: Vec::new(),
            url: None,
            headers: HashMap::new(),
            env: HashMap::new(),
            timeout: 30,
        }
    }
}

impl TransportConfig {
    /// Stdio transport that runs `command` with `args`.
    pub fn stdio(command: impl Into<String>, args: &[&str]) -> Self {
        Self {
            transport_type: TransportType::Stdio,
            command: Some(command.into()),
            args: args.iter().map(|s| s.to_string()).collect(),
            ..Default::default()
        }
    }

    /// Streamable-HTTP transport targeting `url`.
    pub fn http(url: impl Into<String>) -> Self {
        Self::with_url(TransportType::HttpStream, url)
    }

    /// Server-Sent Events transport targeting `url`.
    pub fn sse(url: impl Into<String>) -> Self {
        Self::with_url(TransportType::Sse, url)
    }

    /// WebSocket transport targeting `url`.
    pub fn websocket(url: impl Into<String>) -> Self {
        Self::with_url(TransportType::WebSocket, url)
    }

    /// Shared constructor for the URL-based transports; other fields use defaults.
    fn with_url(transport_type: TransportType, url: impl Into<String>) -> Self {
        Self {
            transport_type,
            url: Some(url.into()),
            ..Default::default()
        }
    }

    /// Appends one argument to the stdio command line.
    pub fn arg(mut self, arg: impl Into<String>) -> Self {
        self.args.push(arg.into());
        self
    }

    /// Sets (or overwrites) environment variable `key` to `value`.
    pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.env.insert(key.into(), value.into());
        self
    }

    /// Sets (or overwrites) header `key` to `value`.
    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(key.into(), value.into());
        self
    }

    /// Replaces the timeout value.
    pub fn timeout(mut self, timeout: u32) -> Self {
        self.timeout = timeout;
        self
    }
}
/// Capability switches and allow-lists for an MCP connection.
///
/// The default value is also returned by [`SecurityConfig::restrictive`]. It
/// has every `allow_*` flag set to `false` and empty allow-lists.
/// NOTE(review): nothing in this module reads these fields; enforcement, if
/// any, lives elsewhere.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SecurityConfig {
    /// Whether filesystem access is permitted.
    pub allow_fs: bool,
    /// Whether network access is permitted.
    pub allow_network: bool,
    /// Whether environment-variable access is permitted.
    pub allow_env: bool,
    /// Allowed host entries.
    pub allowed_hosts: Vec<String>,
    /// Allowed path entries.
    pub allowed_paths: Vec<String>,
}

impl SecurityConfig {
    /// Enables every capability and uses `"*"` as the only host and path entry.
    /// NOTE(review): `"*"` is presumably a wildcard for whatever enforces this.
    pub fn permissive() -> Self {
        Self {
            allow_fs: true,
            allow_network: true,
            allow_env: true,
            allowed_hosts: vec!["*".to_string()],
            allowed_paths: vec!["*".to_string()],
        }
    }

    /// Everything disabled; identical to [`SecurityConfig::default`].
    pub fn restrictive() -> Self {
        Self::default()
    }

    /// Sets the filesystem flag.
    pub fn allow_fs(mut self, allow: bool) -> Self {
        self.allow_fs = allow;
        self
    }

    /// Sets the network flag.
    pub fn allow_network(mut self, allow: bool) -> Self {
        self.allow_network = allow;
        self
    }

    /// Sets the environment-variable flag.
    pub fn allow_env(mut self, allow: bool) -> Self {
        self.allow_env = allow;
        self
    }

    /// Appends one allowed host entry.
    pub fn allowed_host(mut self, host: impl Into<String>) -> Self {
        self.allowed_hosts.push(host.into());
        self
    }

    /// Appends one allowed path entry.
    pub fn allowed_path(mut self, path: impl Into<String>) -> Self {
        self.allowed_paths.push(path.into());
        self
    }
}
/// Top-level configuration for an MCP client.
///
/// The defaults are:
/// - name `"mcp-server"`
/// - [`TransportConfig::default`] and [`SecurityConfig::default`]
/// - `auto_reconnect = true`, `max_reconnect_attempts = 3`
/// - `debug = false`
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPConfig {
    /// Server name.
    pub name: String,
    /// How to reach the server.
    pub transport: TransportConfig,
    /// Capability policy for the connection.
    pub security: SecurityConfig,
    /// Whether to reconnect automatically. NOTE(review): not read in this module.
    pub auto_reconnect: bool,
    /// Reconnect attempt limit. NOTE(review): not read in this module.
    pub max_reconnect_attempts: u32,
    /// Debug flag.
    pub debug: bool,
}

impl Default for MCPConfig {
    fn default() -> Self {
        Self {
            name: "mcp-server".to_string(),
            transport: TransportConfig::default(),
            security: SecurityConfig::default(),
            auto_reconnect: true,
            max_reconnect_attempts: 3,
            debug: false,
        }
    }
}

impl MCPConfig {
    /// Default configuration with the given `name`.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            ..Default::default()
        }
    }

    /// Replaces the transport configuration.
    pub fn transport(mut self, transport: TransportConfig) -> Self {
        self.transport = transport;
        self
    }

    /// Replaces the security configuration.
    pub fn security(mut self, security: SecurityConfig) -> Self {
        self.security = security;
        self
    }

    /// Sets the debug flag.
    pub fn debug(mut self, debug: bool) -> Self {
        self.debug = debug;
        self
    }

    /// Sets whether to reconnect automatically.
    pub fn auto_reconnect(mut self, auto_reconnect: bool) -> Self {
        self.auto_reconnect = auto_reconnect;
        self
    }

    /// Sets the maximum number of reconnect attempts.
    pub fn max_reconnect_attempts(mut self, attempts: u32) -> Self {
        self.max_reconnect_attempts = attempts;
        self
    }
}
/// A tool description: its name, a human-readable description and a JSON
/// schema for its input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPTool {
    /// Tool name; matched against [`MCPCall::name`] by callers.
    pub name: String,
    /// What the tool does.
    pub description: String,
    /// JSON schema describing the tool's arguments.
    pub input_schema: serde_json::Value,
}

impl MCPTool {
    /// Creates a tool whose schema is an empty object:
    /// `{"type": "object", "properties": {}}`.
    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
        let name = name.into();
        let description = description.into();
        let input_schema = serde_json::json!({"type": "object", "properties": {}});
        Self {
            name,
            description,
            input_schema,
        }
    }

    /// Returns the tool with its input schema replaced by `schema`.
    pub fn input_schema(self, schema: serde_json::Value) -> Self {
        Self {
            input_schema: schema,
            ..self
        }
    }
}
/// A request to invoke a tool by name with JSON arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPCall {
    /// Name of the tool to invoke.
    pub name: String,
    /// JSON arguments for the tool.
    pub arguments: serde_json::Value,
    /// Request id; [`MCPCall::new`] fills it with a random UUID v4 string.
    pub id: String,
}

impl MCPCall {
    /// Builds a call to `name` with `arguments` and a freshly generated UUID v4 id.
    pub fn new(name: impl Into<String>, arguments: serde_json::Value) -> Self {
        let id = uuid::Uuid::new_v4().to_string();
        Self {
            name: name.into(),
            arguments,
            id,
        }
    }
}
/// Outcome of an [`MCPCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPCallResult {
    /// Id of the call this result answers.
    pub id: String,
    /// Content items produced by the tool.
    pub content: Vec<MCPContent>,
    /// Whether the result represents an error.
    pub is_error: bool,
}
/// A single piece of content: text, an image, or a resource reference.
///
/// This is serialized as an internally tagged object. Its `type` field holds
/// the snake_case variant name, e.g. `{"type": "text", "text": "..."}`.
/// Field names inside variants are not renamed.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MCPContent {
    /// Plain text.
    Text { text: String },
    /// Image payload plus its MIME type.
    /// NOTE(review): the encoding of `data` (presumably base64) is not fixed here.
    Image { data: String, mime_type: String },
    /// Resource identified by `uri`, optionally with its text.
    Resource { uri: String, text: Option<String> },
}

impl MCPContent {
    /// Builds [`MCPContent::Text`].
    pub fn text(text: impl Into<String>) -> Self {
        Self::Text { text: text.into() }
    }

    /// Builds [`MCPContent::Image`].
    pub fn image(data: impl Into<String>, mime_type: impl Into<String>) -> Self {
        Self::Image {
            data: data.into(),
            mime_type: mime_type.into(),
        }
    }

    /// Builds [`MCPContent::Resource`] for `uri` with optional inline `text`.
    pub fn resource(uri: impl Into<String>, text: Option<String>) -> Self {
        Self::Resource {
            uri: uri.into(),
            text,
        }
    }
}
/// A resource addressed by URI, with optional description and MIME type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPResource {
    /// Resource URI.
    pub uri: String,
    /// Display name.
    pub name: String,
    /// Optional description.
    pub description: Option<String>,
    /// Optional MIME type of the resource content.
    pub mime_type: Option<String>,
}

impl MCPResource {
    /// Creates a resource with no description and no MIME type.
    pub fn new(uri: impl Into<String>, name: impl Into<String>) -> Self {
        Self {
            uri: uri.into(),
            name: name.into(),
            description: None,
            mime_type: None,
        }
    }

    /// Sets the description.
    pub fn description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }

    /// Sets the MIME type.
    pub fn mime_type(mut self, mime_type: impl Into<String>) -> Self {
        self.mime_type = Some(mime_type.into());
        self
    }
}
/// A named prompt template with an optional description and argument list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPPrompt {
    /// Prompt name.
    pub name: String,
    /// Optional description.
    pub description: Option<String>,
    /// Declared arguments, in order.
    pub arguments: Vec<MCPPromptArgument>,
}

impl MCPPrompt {
    /// Creates a prompt with no description and no arguments.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            description: None,
            arguments: Vec::new(),
        }
    }

    /// Sets the description.
    pub fn description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }

    /// Appends one argument declaration.
    pub fn argument(mut self, argument: MCPPromptArgument) -> Self {
        self.arguments.push(argument);
        self
    }
}
/// One argument accepted by an [`MCPPrompt`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MCPPromptArgument {
    /// Argument name.
    pub name: String,
    /// Optional description.
    pub description: Option<String>,
    /// Whether the argument must be supplied.
    pub required: bool,
}

impl MCPPromptArgument {
    /// Creates an optional argument (`required = false`) with no description.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            description: None,
            required: false,
        }
    }

    /// Sets the description.
    pub fn description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }

    /// Sets whether the argument is required.
    pub fn required(mut self, required: bool) -> Self {
        self.required = required;
        self
    }
}
/// Lifecycle state of an [`MCP`] client connection.
///
/// The default is [`ConnectionStatus::Disconnected`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ConnectionStatus {
    /// Not connected; the initial state.
    #[default]
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// Connected and ready for calls.
    Connected,
    /// The connection failed. NOTE(review): nothing in this module sets this variant.
    Error,
}
#[derive(Debug)]
pub struct MCP {
pub config: MCPConfig,
pub status: ConnectionStatus,
tools: Vec<MCPTool>,
resources: Vec<MCPResource>,
prompts: Vec<MCPPrompt>,
}
impl Default for MCP {
fn default() -> Self {
Self {
config: MCPConfig::default(),
status: ConnectionStatus::Disconnected,
tools: Vec::new(),
resources: Vec::new(),
prompts: Vec::new(),
}
}
}
impl MCP {
pub fn new() -> MCPBuilder {
MCPBuilder::default()
}
pub fn status(&self) -> ConnectionStatus {
self.status
}
pub fn is_connected(&self) -> bool {
self.status == ConnectionStatus::Connected
}
pub async fn connect(&mut self) -> Result<()> {
self.status = ConnectionStatus::Connecting;
self.status = ConnectionStatus::Connected;
Ok(())
}
pub async fn disconnect(&mut self) -> Result<()> {
self.status = ConnectionStatus::Disconnected;
Ok(())
}
pub async fn list_tools(&self) -> Result<Vec<MCPTool>> {
Ok(self.tools.clone())
}
pub async fn list_resources(&self) -> Result<Vec<MCPResource>> {
Ok(self.resources.clone())
}
pub async fn list_prompts(&self) -> Result<Vec<MCPPrompt>> {
Ok(self.prompts.clone())
}
pub async fn call_tool(&self, call: MCPCall) -> Result<MCPCallResult> {
if !self.is_connected() {
return Err(Error::workflow("MCP not connected"));
}
Ok(MCPCallResult {
id: call.id,
content: vec![MCPContent::text(format!(
"Result of calling {} with {:?}",
call.name, call.arguments
))],
is_error: false,
})
}
pub async fn read_resource(&self, uri: &str) -> Result<MCPContent> {
if !self.is_connected() {
return Err(Error::workflow("MCP not connected"));
}
Ok(MCPContent::Resource {
uri: uri.to_string(),
text: Some(format!("Content of resource: {}", uri)),
})
}
pub async fn get_prompt(&self, name: &str, args: HashMap<String, String>) -> Result<String> {
if !self.is_connected() {
return Err(Error::workflow("MCP not connected"));
}
Ok(format!("Prompt '{}' with args: {:?}", name, args))
}
}
#[derive(Debug, Default)]
pub struct MCPBuilder {
config: MCPConfig,
tools: Vec<MCPTool>,
}
impl MCPBuilder {
pub fn name(mut self, name: impl Into<String>) -> Self {
self.config.name = name.into();
self
}
pub fn server(mut self, command: impl Into<String>, args: &[&str]) -> Self {
self.config.transport = TransportConfig::stdio(command, args);
self
}
pub fn http(mut self, url: impl Into<String>) -> Self {
self.config.transport = TransportConfig::http(url);
self
}
pub fn websocket(mut self, url: impl Into<String>) -> Self {
self.config.transport = TransportConfig::websocket(url);
self
}
pub fn config(mut self, config: MCPConfig) -> Self {
self.config = config;
self
}
pub fn security(mut self, security: SecurityConfig) -> Self {
self.config.security = security;
self
}
pub fn tool(mut self, tool: MCPTool) -> Self {
self.tools.push(tool);
self
}
pub fn build(self) -> Result<MCP> {
Ok(MCP {
config: self.config,
status: ConnectionStatus::Disconnected,
tools: self.tools,
resources: Vec::new(),
prompts: Vec::new(),
})
}
}
/// In-process registry of tools and resources, plus a running flag.
///
/// Placeholder: [`MCPServer::start`] and [`MCPServer::stop`] only toggle
/// the flag. No transport is opened here.
#[derive(Debug)]
pub struct MCPServer {
    /// Server name.
    pub name: String,
    tools: Vec<MCPTool>,
    resources: Vec<MCPResource>,
    running: bool,
}

impl Default for MCPServer {
    /// A stopped server named `"praisonai-mcp-server"` with nothing registered.
    fn default() -> Self {
        Self::new("praisonai-mcp-server")
    }
}

impl MCPServer {
    /// Creates a stopped server called `name` with no tools or resources.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            tools: Vec::new(),
            resources: Vec::new(),
            running: false,
        }
    }

    /// Appends a tool; duplicate names are not rejected.
    pub fn register_tool(&mut self, tool: MCPTool) {
        self.tools.push(tool);
    }

    /// Appends a resource; duplicate URIs are not rejected.
    pub fn register_resource(&mut self, resource: MCPResource) {
        self.resources.push(resource);
    }

    /// Registered tools, in registration order.
    pub fn tools(&self) -> &[MCPTool] {
        self.tools.as_slice()
    }

    /// Registered resources, in registration order.
    pub fn resources(&self) -> &[MCPResource] {
        self.resources.as_slice()
    }

    /// Whether [`MCPServer::start`] has been called more recently than [`MCPServer::stop`].
    pub fn is_running(&self) -> bool {
        self.running
    }

    /// Sets the running flag. Always returns `Ok(())`.
    pub async fn start(&mut self) -> Result<()> {
        self.running = true;
        Ok(())
    }

    /// Clears the running flag. Always returns `Ok(())`.
    pub async fn stop(&mut self) -> Result<()> {
        self.running = false;
        Ok(())
    }
}
/// Unit tests for the configuration builders, content serialization and the
/// placeholder client/server behavior defined in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transport_config_stdio() {
        let config = TransportConfig::stdio("npx", &["-y", "@anthropic/mcp-server"]);
        assert_eq!(config.transport_type, TransportType::Stdio);
        assert_eq!(config.command, Some("npx".to_string()));
        assert_eq!(config.args.len(), 2);
    }

    #[test]
    fn test_transport_config_http() {
        let config = TransportConfig::http("http://localhost:8080")
            .header("Authorization", "Bearer token")
            .timeout(60);
        assert_eq!(config.transport_type, TransportType::HttpStream);
        assert_eq!(config.url, Some("http://localhost:8080".to_string()));
        assert_eq!(config.timeout, 60);
        assert_eq!(
            config.headers.get("Authorization").map(String::as_str),
            Some("Bearer token")
        );
    }

    #[test]
    fn test_transport_config_websocket_and_env() {
        let config = TransportConfig::websocket("ws://localhost:9000").env("KEY", "value");
        assert_eq!(config.transport_type, TransportType::WebSocket);
        assert_eq!(config.url.as_deref(), Some("ws://localhost:9000"));
        assert_eq!(config.env.get("KEY").map(String::as_str), Some("value"));
        assert!(config.command.is_none());
        // Untouched fields keep their defaults.
        assert_eq!(config.timeout, 30);
    }

    #[test]
    fn test_transport_type_serde_names() {
        assert_eq!(
            serde_json::to_string(&TransportType::HttpStream).unwrap(),
            "\"http_stream\""
        );
        assert_eq!(
            serde_json::to_string(&TransportType::WebSocket).unwrap(),
            "\"web_socket\""
        );
        let parsed: TransportType = serde_json::from_str("\"sse\"").unwrap();
        assert_eq!(parsed, TransportType::Sse);
    }

    #[test]
    fn test_security_config() {
        let config = SecurityConfig::permissive();
        assert!(config.allow_fs);
        assert!(config.allow_network);
        assert!(config.allow_env);
        let restrictive = SecurityConfig::restrictive();
        assert!(!restrictive.allow_fs);
        assert!(!restrictive.allow_network);
        assert!(!restrictive.allow_env);
        assert!(restrictive.allowed_hosts.is_empty());
        assert!(restrictive.allowed_paths.is_empty());
    }

    #[test]
    fn test_security_config_builders() {
        let config = SecurityConfig::restrictive()
            .allow_network(true)
            .allowed_host("example.com")
            .allowed_path("/tmp");
        assert!(config.allow_network);
        assert!(!config.allow_fs);
        assert_eq!(config.allowed_hosts, vec!["example.com".to_string()]);
        assert_eq!(config.allowed_paths, vec!["/tmp".to_string()]);
    }

    #[test]
    fn test_mcp_config() {
        let config = MCPConfig::new("test-server")
            .transport(TransportConfig::stdio("node", &["server.js"]))
            .debug(true);
        assert_eq!(config.name, "test-server");
        assert!(config.debug);
    }

    #[test]
    fn test_mcp_config_defaults() {
        let config = MCPConfig::default();
        assert_eq!(config.name, "mcp-server");
        assert!(config.auto_reconnect);
        assert_eq!(config.max_reconnect_attempts, 3);
        assert!(!config.debug);
        assert_eq!(config.transport.transport_type, TransportType::Stdio);
    }

    #[test]
    fn test_mcp_tool() {
        let tool = MCPTool::new("search", "Search the web").input_schema(serde_json::json!({
            "type": "object",
            "properties": {
                "query": {"type": "string"}
            }
        }));
        assert_eq!(tool.name, "search");
        assert!(tool.input_schema["properties"]["query"].is_object());
    }

    #[test]
    fn test_mcp_call() {
        let call = MCPCall::new("search", serde_json::json!({"query": "test"}));
        assert_eq!(call.name, "search");
        assert!(!call.id.is_empty());
    }

    #[test]
    fn test_mcp_content() {
        let text = MCPContent::text("Hello");
        match text {
            MCPContent::Text { text } => assert_eq!(text, "Hello"),
            _ => panic!("Expected text content"),
        }
    }

    #[test]
    fn test_mcp_content_serialization_tag() {
        let value = serde_json::to_value(MCPContent::text("Hello")).unwrap();
        assert_eq!(value, serde_json::json!({"type": "text", "text": "Hello"}));
        let image = serde_json::to_value(MCPContent::image("abc", "image/png")).unwrap();
        assert_eq!(image["type"], "image");
        assert_eq!(image["mime_type"], "image/png");
    }

    #[test]
    fn test_mcp_builder() {
        let mcp = MCP::new()
            .name("test")
            .server("npx", &["-y", "@test/server"])
            .build()
            .unwrap();
        assert_eq!(mcp.config.name, "test");
        assert_eq!(mcp.status(), ConnectionStatus::Disconnected);
    }

    #[test]
    fn test_mcp_server() {
        let mut server = MCPServer::new("test-server");
        server.register_tool(MCPTool::new("tool1", "Description"));
        server.register_resource(MCPResource::new("file://test", "test"));
        assert_eq!(server.tools().len(), 1);
        assert_eq!(server.resources().len(), 1);
        assert!(!server.is_running());
    }

    #[test]
    fn test_mcp_server_default_name() {
        let server = MCPServer::default();
        assert_eq!(server.name, "praisonai-mcp-server");
        assert!(server.tools().is_empty());
        assert!(!server.is_running());
    }

    #[tokio::test]
    async fn test_mcp_connect() {
        let mut mcp = MCP::new().name("test").build().unwrap();
        assert!(!mcp.is_connected());
        mcp.connect().await.unwrap();
        assert!(mcp.is_connected());
        mcp.disconnect().await.unwrap();
        assert!(!mcp.is_connected());
    }

    #[tokio::test]
    async fn test_mcp_call_tool() {
        let mut mcp = MCP::new()
            .tool(MCPTool::new("test_tool", "Test"))
            .build()
            .unwrap();
        mcp.connect().await.unwrap();
        let call = MCPCall::new("test_tool", serde_json::json!({}));
        let call_id = call.id.clone();
        let result = mcp.call_tool(call).await.unwrap();
        assert_eq!(result.id, call_id);
        assert!(!result.is_error);
        assert!(!result.content.is_empty());
    }

    #[tokio::test]
    async fn test_mcp_requires_connection() {
        let mcp = MCP::new().build().unwrap();
        let call = MCPCall::new("anything", serde_json::json!({}));
        assert!(mcp.call_tool(call).await.is_err());
        assert!(mcp.read_resource("file://x").await.is_err());
        assert!(mcp
            .get_prompt("p", std::collections::HashMap::new())
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_mcp_read_resource_and_prompt() {
        let mut mcp = MCP::new().build().unwrap();
        mcp.connect().await.unwrap();
        match mcp.read_resource("file://a").await.unwrap() {
            MCPContent::Resource { uri, text } => {
                assert_eq!(uri, "file://a");
                assert!(text.is_some());
            }
            other => panic!("Expected resource content, got {:?}", other),
        }
        let prompt = mcp
            .get_prompt("greet", std::collections::HashMap::new())
            .await
            .unwrap();
        assert!(prompt.contains("greet"));
    }

    #[tokio::test]
    async fn test_mcp_list_tools() {
        let mcp = MCP::new().tool(MCPTool::new("a", "A")).build().unwrap();
        let tools = mcp.list_tools().await.unwrap();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "a");
        assert!(mcp.list_resources().await.unwrap().is_empty());
        assert!(mcp.list_prompts().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_mcp_server_lifecycle() {
        let mut server = MCPServer::new("test");
        assert!(!server.is_running());
        server.start().await.unwrap();
        assert!(server.is_running());
        server.stop().await.unwrap();
        assert!(!server.is_running());
    }
}