systemprompt-mcp 0.1.3

Core MCP (Model Context Protocol) functionality for systemprompt.io OS
Documentation
use anyhow::Result;
use rmcp::model::{ClientCapabilities, ClientInfo, Implementation, ProtocolVersion};
use rmcp::transport::streamable_http_client::StreamableHttpClientTransport;
use rmcp::ServiceExt;
use std::time::Duration;
use tokio::time::timeout;

use super::types::{McpConnectionResult, McpProtocolInfo, ValidationResult};

/// Validates an MCP service by performing a full handshake against
/// `http://{host}:{port}/mcp`.
///
/// The handshake and tool probe are given at most 15 seconds. Every outcome
/// is reported through the returned [`McpConnectionResult`] rather than as an
/// `Err`:
///
/// * handshake completed — fields are copied from the validation result and
///   `server_info` is populated;
/// * handshake returned an error — `validation_type` is `"connection_failed"`
///   and `error_message` carries the error text;
/// * the 15 second budget elapsed — `validation_type` is `"timeout"`.
///
/// `connection_time_ms` is the wall-clock time spent waiting, in milliseconds.
pub async fn validate_connection(
    service_name: &str,
    host: &str,
    port: u16,
) -> Result<McpConnectionResult> {
    let started = std::time::Instant::now();
    let endpoint = format!("http://{host}:{port}/mcp");

    let handshake = connect_and_validate(&endpoint, service_name);
    let outcome = timeout(Duration::from_secs(15), handshake).await;

    let elapsed_ms = started.elapsed().as_millis() as u32;

    // Only these five fields differ between the three outcomes; resolve them
    // first so the result struct is assembled in a single place.
    let (success, error_message, server_info, tools_count, validation_type) = match outcome {
        Ok(Ok((info, validation))) => (
            validation.success,
            validation.error_message,
            Some(info),
            validation.tools_count,
            validation.validation_type,
        ),
        Ok(Err(error)) => (
            false,
            Some(error.to_string()),
            None,
            0,
            "connection_failed".to_string(),
        ),
        Err(_elapsed) => (
            false,
            Some("Connection timeout".to_string()),
            None,
            0,
            "timeout".to_string(),
        ),
    };

    Ok(McpConnectionResult {
        service_name: service_name.to_string(),
        success,
        error_message,
        connection_time_ms: elapsed_ms,
        server_info,
        tools_count,
        validation_type,
    })
}

pub async fn validate_connection_with_auth(
    service_name: &str,
    host: &str,
    port: u16,
    requires_oauth: bool,
) -> Result<McpConnectionResult> {
    if requires_oauth {
        Ok(validate_oauth_service(service_name, host, port))
    } else {
        validate_connection(service_name, host, port).await
    }
}

/// Checks that an OAuth-protected service is at least accepting TCP
/// connections on `host:port`.
///
/// No MCP handshake is performed, so on success the returned `server_info`
/// carries `"unknown"` for both version fields, `tools_count` is `0`, and
/// `validation_type` is `"auth_required"`. If the port cannot be reached the
/// result has `success == false` and `validation_type == "port_unavailable"`.
///
/// This function blocks the calling thread. Connection attempts are bounded
/// by a 15 second budget shared across every address the host resolves to;
/// name resolution itself is performed by the standard library and is not
/// covered by that budget.
fn validate_oauth_service(service_name: &str, host: &str, port: u16) -> McpConnectionResult {
    use std::net::{TcpStream, ToSocketAddrs};

    // Upper bound on the whole probe so an unresponsive host cannot block for
    // the operating system's default connect timeout.
    const PORT_CHECK_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(15);

    let connection_start = std::time::Instant::now();

    // Resolving `(host, port)` as a tuple (rather than a formatted
    // "host:port" string) also accepts bare IPv6 literals such as `::1`.
    let port_check = (host, port).to_socket_addrs().and_then(|addrs| {
        let mut last_error = std::io::Error::new(
            std::io::ErrorKind::AddrNotAvailable,
            "host did not resolve to any address",
        );

        for addr in addrs {
            // `connect_timeout` rejects a zero duration, so stop as soon as
            // the shared budget is exhausted.
            let Some(remaining) = PORT_CHECK_TIMEOUT
                .checked_sub(connection_start.elapsed())
                .filter(|left| !left.is_zero())
            else {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::TimedOut,
                    "port check timed out",
                ));
            };

            match TcpStream::connect_timeout(&addr, remaining) {
                Ok(stream) => return Ok(stream),
                Err(e) => last_error = e,
            }
        }

        Err(last_error)
    });

    // Saturate instead of silently wrapping if the value ever exceeds `u32`.
    let connection_time =
        u32::try_from(connection_start.elapsed().as_millis()).unwrap_or(u32::MAX);

    match port_check {
        // The stream is only a liveness probe; it is dropped (closed) here.
        Ok(_) => McpConnectionResult {
            service_name: service_name.to_string(),
            success: true,
            error_message: None,
            connection_time_ms: connection_time,
            server_info: Some(McpProtocolInfo {
                server_name: service_name.to_string(),
                version: "unknown".to_string(),
                protocol_version: "unknown".to_string(),
            }),
            tools_count: 0,
            validation_type: "auth_required".to_string(),
        },
        Err(e) => McpConnectionResult {
            service_name: service_name.to_string(),
            success: false,
            error_message: Some(format!("Port not responding: {e}")),
            connection_time_ms: connection_time,
            server_info: None,
            tools_count: 0,
            validation_type: "port_unavailable".to_string(),
        },
    }
}

async fn connect_and_validate(
    url: &str,
    service_name: &str,
) -> Result<(McpProtocolInfo, ValidationResult)> {
    let transport = StreamableHttpClientTransport::from_uri(url);

    let client_info = ClientInfo {
        meta: None,
        protocol_version: ProtocolVersion::default(),
        capabilities: ClientCapabilities::default(),
        client_info: Implementation {
            name: format!("systemprompt.io MCP Validator for {service_name}"),
            title: None,
            version: "1.0.0".to_string(),
            website_url: None,
            icons: None,
        },
    };

    let client = client_info.serve(transport).await?;

    let peer_info = client
        .peer_info()
        .ok_or_else(|| anyhow::anyhow!("Failed to get peer info from MCP client"))?;

    let server_info = McpProtocolInfo {
        server_name: if peer_info.server_info.name.is_empty() {
            service_name.to_string()
        } else {
            peer_info.server_info.name.clone()
        },
        version: if peer_info.server_info.version.is_empty() {
            "1.0.0".to_string()
        } else {
            peer_info.server_info.version.clone()
        },
        protocol_version: peer_info.protocol_version.to_string(),
    };

    let validation_result = match client.list_tools(None).await {
        Ok(tools_response) => {
            let tools_count = tools_response.tools.len();

            if tools_count > 0 {
                ValidationResult {
                    success: true,
                    error_message: None,
                    tools_count,
                    validation_type: "mcp_validated".to_string(),
                }
            } else {
                ValidationResult {
                    success: false,
                    error_message: Some(
                        "No tools returned - service may require authentication".to_string(),
                    ),
                    tools_count: 0,
                    validation_type: "no_tools".to_string(),
                }
            }
        },
        Err(e) => ValidationResult {
            success: false,
            error_message: Some(format!("Tools request failed: {e}")),
            tools_count: 0,
            validation_type: "tools_request_failed".to_string(),
        },
    };

    client.cancel().await?;
    Ok((server_info, validation_result))
}

/// Rewrites a URL that begins with the configured external API URL so that it
/// begins with the internal API server URL instead.
///
/// Only the leading prefix is swapped; any later occurrence of the external
/// URL (for example inside a query parameter) is left untouched. The input is
/// returned unchanged when:
///
/// * the global configuration cannot be obtained,
/// * the configured external URL is empty, or
/// * `url` does not start with the external URL.
pub fn rewrite_url_for_internal_use(url: &str) -> String {
    use systemprompt_models::Config;

    let Ok(config) = Config::get() else {
        return url.to_string();
    };
    let external_url: &str = &config.api_external_url;
    let internal_url: &str = &config.api_server_url;

    // An empty prefix matches every string; without this guard every URL
    // would be "rewritten" even though no external URL is configured.
    if external_url.is_empty() {
        return url.to_string();
    }

    match url.strip_prefix(external_url) {
        Some(rest) => format!("{internal_url}{rest}"),
        None => url.to_string(),
    }
}