mcpway 0.2.0

Run MCP stdio servers over SSE, WebSocket, Streamable HTTP, and gRPC transports.
Documentation
use std::collections::HashMap;
use std::time::Duration;

use crate::tool_api::client::ToolClient;
use crate::tool_api::error::ToolCallError;
use crate::tool_api::transport::{Transport, TransportClient, TransportOptions};

// Defaults applied by `ToolClientBuilder::new`; each one can be overridden
// through the matching builder method.

/// Protocol version string a new builder starts with.
// NOTE(review): presumably an MCP protocol revision date — confirm against the spec.
const DEFAULT_PROTOCOL_VERSION: &str = "2024-11-05";
/// Connect timeout a new builder starts with (1.5 seconds).
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(1_500);
/// Request timeout a new builder starts with (30 seconds).
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);

/// Consuming builder that collects the settings for a [`ToolClient`].
///
/// Create one with [`ToolClientBuilder::new`], adjust it with the chained
/// setter methods, and finish with [`ToolClientBuilder::build`], which
/// validates the configuration and constructs the client.
pub struct ToolClientBuilder {
    /// Endpoint string; `build` requires it to be non-blank for every
    /// transport except `Transport::Stdio`.
    endpoint: String,
    /// Transport kind; decides which validation `build` performs and is
    /// handed to `TransportClient::new`.
    transport: Transport,
    /// Protocol version passed to `ToolClient::from_parts`.
    protocol_version: String,
    /// Key/value headers forwarded in `TransportOptions`.
    // NOTE(review): presumably request headers for the network transports — confirm.
    headers: HashMap<String, String>,
    /// Connect timeout forwarded in `TransportOptions`.
    connect_timeout: Duration,
    /// Request timeout forwarded in `TransportOptions`.
    // NOTE(review): `None` presumably disables the timeout — confirm in the transport layer.
    request_timeout: Option<Duration>,
    /// Command for the stdio transport; `build` requires it to be set and
    /// non-blank when the transport is `Transport::Stdio`.
    stdio_command: Option<String>,
    /// Arguments forwarded alongside the stdio command.
    stdio_args: Vec<String>,
    /// Environment map forwarded for the stdio transport.
    stdio_env: HashMap<String, String>,
    /// Optional working directory forwarded for the stdio transport.
    stdio_cwd: Option<String>,
}

impl ToolClientBuilder {
    /// Starts a builder for `endpoint` reached over `transport`.
    ///
    /// All remaining settings begin at their defaults: the default protocol
    /// version, no headers, the default connect and request timeouts, and no
    /// stdio command, arguments, environment, or working directory.
    pub fn new(endpoint: impl Into<String>, transport: Transport) -> Self {
        let endpoint = endpoint.into();
        let protocol_version = String::from(DEFAULT_PROTOCOL_VERSION);

        Self {
            endpoint,
            transport,
            protocol_version,
            headers: HashMap::default(),
            connect_timeout: DEFAULT_CONNECT_TIMEOUT,
            request_timeout: Some(DEFAULT_REQUEST_TIMEOUT),
            stdio_command: None,
            stdio_args: Vec::default(),
            stdio_env: HashMap::default(),
            stdio_cwd: None,
        }
    }

    /// Replaces the protocol version handed to the client.
    pub fn protocol_version(self, protocol_version: impl Into<String>) -> Self {
        Self {
            protocol_version: protocol_version.into(),
            ..self
        }
    }

    /// Adds a single header, overwriting any existing value for `key`.
    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (name, contents) = (key.into(), value.into());
        self.headers.insert(name, contents);
        self
    }

    /// Merges `headers` into the headers collected so far.
    ///
    /// Existing entries are kept unless `headers` contains the same key, in
    /// which case the incoming value wins.
    pub fn headers(mut self, headers: HashMap<String, String>) -> Self {
        for (name, contents) in headers {
            self.headers.insert(name, contents);
        }
        self
    }

    /// Replaces the connect timeout.
    pub fn connect_timeout(self, timeout: Duration) -> Self {
        Self {
            connect_timeout: timeout,
            ..self
        }
    }

    /// Replaces the request timeout; `None` is stored and forwarded as-is.
    pub fn request_timeout(self, timeout: Option<Duration>) -> Self {
        Self {
            request_timeout: timeout,
            ..self
        }
    }

    /// Sets the command used by the stdio transport.
    pub fn stdio_command(self, command: impl Into<String>) -> Self {
        Self {
            stdio_command: Some(command.into()),
            ..self
        }
    }

    /// Replaces (does not append to) the stdio argument list.
    pub fn stdio_args(self, args: Vec<String>) -> Self {
        Self {
            stdio_args: args,
            ..self
        }
    }

    /// Replaces (does not merge into) the stdio environment map.
    pub fn stdio_env(self, env: HashMap<String, String>) -> Self {
        Self {
            stdio_env: env,
            ..self
        }
    }

    /// Replaces the stdio working directory; `None` clears it.
    pub fn stdio_cwd(self, cwd: Option<String>) -> Self {
        Self {
            stdio_cwd: cwd,
            ..self
        }
    }

    /// Validates the collected settings and constructs the [`ToolClient`].
    ///
    /// # Errors
    ///
    /// * `ToolCallError::InvalidArguments` when the transport is
    ///   `Transport::Stdio` and the stdio command is unset or blank.
    /// * `ToolCallError::InvalidEndpoint` when any other transport is
    ///   selected and the endpoint is blank.
    /// * Any error produced by `TransportClient::new`, propagated with `?`.
    pub fn build(self) -> Result<ToolClient, ToolCallError> {
        let Self {
            endpoint,
            transport,
            protocol_version,
            headers,
            connect_timeout,
            request_timeout,
            stdio_command,
            stdio_args,
            stdio_env,
            stdio_cwd,
        } = self;

        // Whitespace is trimmed only for these checks; the stored values are
        // forwarded to the transport layer unmodified.
        let precheck = match transport {
            Transport::Stdio => {
                let command_missing = match stdio_command.as_deref() {
                    Some(command) => command.trim().is_empty(),
                    None => true,
                };
                if command_missing {
                    Err(ToolCallError::InvalidArguments(
                        "stdio command cannot be empty".to_string(),
                    ))
                } else {
                    Ok(())
                }
            }
            Transport::StreamableHttp | Transport::Sse | Transport::Ws | Transport::Grpc => {
                if endpoint.trim().is_empty() {
                    Err(ToolCallError::InvalidEndpoint(
                        "endpoint cannot be empty".to_string(),
                    ))
                } else {
                    Ok(())
                }
            }
        };
        precheck?;

        let options = TransportOptions {
            endpoint,
            headers,
            connect_timeout,
            request_timeout,
            stdio_command,
            stdio_args,
            stdio_env,
            stdio_cwd,
        };
        let client = TransportClient::new(transport, options)?;

        Ok(ToolClient::from_parts(client, protocol_version))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A whitespace-only endpoint is treated as empty for a network transport.
    #[test]
    fn build_rejects_empty_endpoint() {
        let builder = ToolClientBuilder::new("  ", Transport::StreamableHttp);
        assert!(matches!(
            builder.build(),
            Err(ToolCallError::InvalidEndpoint(_))
        ));
    }

    /// The stdio transport needs a command; the endpoint is not consulted.
    #[test]
    fn build_rejects_stdio_without_command() {
        let builder = ToolClientBuilder::new("", Transport::Stdio);
        assert!(matches!(
            builder.build(),
            Err(ToolCallError::InvalidArguments(_))
        ));
    }

    /// With a command set, the stdio transport builds even with an empty endpoint.
    #[test]
    fn build_accepts_stdio_with_command() {
        let arguments = vec!["--version".to_string()];
        let builder = ToolClientBuilder::new("", Transport::Stdio)
            .stdio_command("node")
            .stdio_args(arguments);
        assert!(builder.build().is_ok());
    }
}