// mcpway 0.2.0
//
// Run MCP stdio servers over SSE, WebSocket, Streamable HTTP, and gRPC transports.
// Documentation
use std::collections::{BTreeMap, HashMap};
use std::path::PathBuf;
use std::sync::{Arc, OnceLock};
use std::time::Duration;

use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tokio::sync::Mutex;

use crate::types::HeadersMap;

/// The most recent successful connection recorded for one cache key.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct WarmHint {
    /// Name of the transport that succeeded (as passed to `mark_success`).
    transport: String,
    /// Seconds since the Unix epoch at which the success was recorded.
    last_success_utc: u64,
}

/// Serialized shape of the warm-hint cache file: hints indexed by cache key.
///
/// A `BTreeMap` is used so the persisted JSON lists keys in sorted order.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
struct WarmHintStore {
    records: BTreeMap<String, WarmHint>,
}

/// Shared cache of HTTP clients plus warm-transport hints persisted to disk.
///
/// Obtain the process-wide instance with `global_pool()`. Note that the derived
/// `TransportPool::default()` starts with an empty warm-hint store and does NOT read the
/// on-disk cache; only the private `new` constructor loads it.
#[derive(Default)]
pub struct TransportPool {
    // Built clients, keyed by the caller-supplied key string.
    http_clients: Mutex<HashMap<String, Arc<reqwest::Client>>>,
    // Warm hints; updated by `mark_success` and written back to the cache file.
    warm_hints: Mutex<WarmHintStore>,
}

/// Lazily-initialized process-wide pool shared by every caller of [`global_pool`].
static GLOBAL_POOL: OnceLock<Arc<TransportPool>> = OnceLock::new();

/// Returns a handle to the process-wide [`TransportPool`], creating it on first use.
///
/// The first call constructs the pool via `TransportPool::new`, which loads the warm-hint
/// cache from disk; later calls just clone the shared `Arc`.
pub fn global_pool() -> Arc<TransportPool> {
    let shared = GLOBAL_POOL.get_or_init(|| Arc::new(TransportPool::new()));
    Arc::clone(shared)
}

impl TransportPool {
    fn new() -> Self {
        Self {
            http_clients: Mutex::new(HashMap::new()),
            warm_hints: Mutex::new(load_warm_hint_store().unwrap_or_default()),
        }
    }

    pub async fn http_client(
        &self,
        key: &str,
        connect_timeout: Duration,
        request_timeout: Option<Duration>,
    ) -> Result<Arc<reqwest::Client>, String> {
        {
            let clients = self.http_clients.lock().await;
            if let Some(existing) = clients.get(key) {
                return Ok(existing.clone());
            }
        }

        {
            let hints = self.warm_hints.lock().await;
            if hints.records.contains_key(key) {
                tracing::debug!("Using warm transport hint for key={key}");
            }
        }

        let mut builder = reqwest::Client::builder().connect_timeout(connect_timeout);
        if let Some(timeout) = request_timeout {
            builder = builder.timeout(timeout);
        }

        let client = builder
            .build()
            .map_err(|err| format!("Failed to build HTTP client: {err}"))?;
        let client = Arc::new(client);

        let mut clients = self.http_clients.lock().await;
        clients.insert(key.to_string(), client.clone());
        Ok(client)
    }

    pub async fn mark_success(&self, key: &str, transport: &str) {
        let mut hints = self.warm_hints.lock().await;
        hints.records.insert(
            key.to_string(),
            WarmHint {
                transport: transport.to_string(),
                last_success_utc: unix_timestamp_secs(),
            },
        );

        if let Err(err) = persist_warm_hint_store(&hints) {
            tracing::warn!("Failed to persist transport warm cache: {err}");
        }
    }
}

/// Computes a stable SHA-256 fingerprint (lowercase hex) identifying a transport target.
///
/// The fingerprint covers the transport name, the endpoint or command, the protocol version,
/// and every header. Header names are lowercased and the headers are sorted, so header-name
/// case and iteration order do not affect the result.
///
/// Every component is length-prefixed before hashing. Values that contain separator-like
/// characters (`|`, `;`, `=`) therefore cannot make two different inputs encode to the same
/// payload. For example, one header `a: "1;b=2"` no longer collides with two headers `a: 1`
/// and `b: 2`.
pub fn transport_fingerprint(
    transport: &str,
    endpoint_or_command: &str,
    headers: &HeadersMap,
    protocol_version: &str,
) -> String {
    let mut pairs = headers
        .iter()
        .map(|(name, value)| (name.to_ascii_lowercase(), value.to_string()))
        .collect::<Vec<_>>();
    pairs.sort();

    let mut payload = String::new();
    push_fingerprint_field(&mut payload, transport);
    push_fingerprint_field(&mut payload, endpoint_or_command);
    push_fingerprint_field(&mut payload, protocol_version);
    // The header count fixes how many name/value fields follow.
    push_fingerprint_field(&mut payload, &pairs.len().to_string());
    for (name, value) in &pairs {
        push_fingerprint_field(&mut payload, name);
        push_fingerprint_field(&mut payload, value);
    }

    sha256_hex(payload.as_bytes())
}

/// Appends `field` to `payload` as `<byte length>:<field>`, making field boundaries unambiguous.
fn push_fingerprint_field(payload: &mut String, field: &str) {
    payload.push_str(&field.len().to_string());
    payload.push(':');
    payload.push_str(field);
}

/// Resolves the location of the transport warm-hint cache file.
///
/// Resolution order:
/// 1. `MCPWAY_WARM_CACHE_PATH`, when set to a non-empty value;
/// 2. `<home>/.mcpway/transport-warm-cache.json`, when a home directory is known;
/// 3. `.mcpway/transport-warm-cache.json`, relative to the current working directory.
fn warm_cache_path() -> PathBuf {
    // Treat an empty override as unset. An empty path has no parent directory, so
    // persisting would otherwise fail on every update.
    if let Some(path) = std::env::var_os("MCPWAY_WARM_CACHE_PATH").filter(|p| !p.is_empty()) {
        return PathBuf::from(path);
    }

    if let Some(home) = crate::discovery::user_home_dir() {
        return home.join(".mcpway").join("transport-warm-cache.json");
    }

    PathBuf::from(".mcpway/transport-warm-cache.json")
}

/// Loads the warm-hint store from the cache file returned by `warm_cache_path()`.
///
/// A missing cache file is not an error; it yields an empty store.
///
/// # Errors
/// Returns a message when the file exists but cannot be read, or does not contain valid
/// warm-cache JSON.
fn load_warm_hint_store() -> Result<WarmHintStore, String> {
    let path = warm_cache_path();

    // Read directly and map NotFound to "no cache" rather than probing with `exists()` first.
    // This avoids a check-then-use race and keeps permission errors (for which `exists()`
    // returns false) from being mistaken for a missing file.
    let body = match std::fs::read_to_string(&path) {
        Ok(body) => body,
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
            return Ok(WarmHintStore::default());
        }
        Err(err) => return Err(format!("Failed to read {}: {err}", path.display())),
    };

    serde_json::from_str::<WarmHintStore>(&body)
        .map_err(|err| format!("Invalid warm cache JSON in {}: {err}", path.display()))
}

fn persist_warm_hint_store(store: &WarmHintStore) -> Result<(), String> {
    let path = warm_cache_path();
    let parent = path
        .parent()
        .ok_or_else(|| format!("Invalid warm cache path: {}", path.display()))?;

    std::fs::create_dir_all(parent)
        .map_err(|err| format!("Failed to create {}: {err}", parent.display()))?;

    let body = serde_json::to_string_pretty(store)
        .map_err(|err| format!("Failed to serialize warm cache: {err}"))?;
    std::fs::write(&path, body).map_err(|err| format!("Failed to write {}: {err}", path.display()))
}

/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time earlier than the epoch.
fn unix_timestamp_secs() -> u64 {
    match std::time::UNIX_EPOCH.elapsed() {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}

/// Hashes `content` with SHA-256 and renders the 32-byte digest as lowercase hex.
fn sha256_hex(content: &[u8]) -> String {
    format!("{:x}", Sha256::digest(content))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fingerprint_is_stable_and_sensitive_to_inputs() {
        let mut headers = HeadersMap::new();
        headers.insert("Authorization".to_string(), "Bearer token-a".to_string());

        let first = transport_fingerprint(
            "streamable-http",
            "https://example.com/mcp",
            &headers,
            "2024-11-05",
        );
        let second = transport_fingerprint(
            "streamable-http",
            "https://example.com/mcp",
            &headers,
            "2024-11-05",
        );
        assert_eq!(first, second);
        // SHA-256 rendered as lowercase hex is 64 characters.
        assert_eq!(first.len(), 64);
        assert!(first
            .chars()
            .all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));

        // Header names are case-insensitive.
        let mut lower = HeadersMap::new();
        lower.insert("authorization".to_string(), "Bearer token-a".to_string());
        let lower_fp = transport_fingerprint(
            "streamable-http",
            "https://example.com/mcp",
            &lower,
            "2024-11-05",
        );
        assert_eq!(first, lower_fp);

        // The protocol version participates in the fingerprint.
        let other_version = transport_fingerprint(
            "streamable-http",
            "https://example.com/mcp",
            &headers,
            "2025-03-26",
        );
        assert_ne!(first, other_version);

        headers.insert("X-Test".to_string(), "changed".to_string());
        let third = transport_fingerprint(
            "streamable-http",
            "https://example.com/mcp",
            &headers,
            "2024-11-05",
        );
        assert_ne!(first, third);
    }

    #[tokio::test]
    async fn http_client_is_cached_per_key() {
        // `default()` avoids reading the user's on-disk warm cache.
        let pool = TransportPool::default();
        let timeout = Duration::from_secs(5);

        let first = pool
            .http_client("key-a", timeout, None)
            .await
            .expect("client should build");
        let again = pool
            .http_client("key-a", timeout, Some(timeout))
            .await
            .expect("client should build");
        let other = pool
            .http_client("key-b", timeout, None)
            .await
            .expect("client should build");

        assert!(Arc::ptr_eq(&first, &again));
        assert!(!Arc::ptr_eq(&first, &other));
    }

    #[tokio::test]
    async fn warm_cache_persists_only_hashed_keys() {
        let tmp_path =
            std::env::temp_dir().join(format!("mcpway-warm-cache-{}.json", std::process::id()));
        let _ = std::fs::remove_file(&tmp_path);
        std::env::set_var("MCPWAY_WARM_CACHE_PATH", &tmp_path);

        // Use a real fingerprint so the test proves the endpoint and secrets stay hashed.
        let mut headers = HeadersMap::new();
        headers.insert(
            "Authorization".to_string(),
            "Bearer secret-token".to_string(),
        );
        let key = transport_fingerprint(
            "streamable-http",
            "https://example.com/mcp",
            &headers,
            "2024-11-05",
        );

        let pool = TransportPool::new();
        pool.mark_success(&key, "streamable-http").await;

        let body = std::fs::read_to_string(&tmp_path);
        // Clean up before asserting so a failure does not leak the file or env override.
        let _ = std::fs::remove_file(&tmp_path);
        std::env::remove_var("MCPWAY_WARM_CACHE_PATH");

        let body = body.expect("warm cache file missing");
        assert!(body.contains(&key));
        assert!(!body.contains("https://example.com"));
        assert!(!body.contains("secret-token"));
    }
}