codineer-runtime 0.6.8

Core runtime engine for Codineer: session, config, MCP, prompt, sandbox
Documentation
use std::collections::BTreeMap;
use std::io;

use futures_util::{SinkExt, StreamExt};
use reqwest::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, CONTENT_TYPE};
use serde::de::DeserializeOwned;
use serde::Serialize;
use tokio_tungstenite::tungstenite::Message as WsMessage;

use crate::mcp_client::{
    McpClientAuth, McpClientBootstrap, McpClientTransport, McpRemoteTransport,
};
use crate::mcp_stdio::types::{
    JsonRpcId, JsonRpcRequest, JsonRpcResponse, McpInitializeParams, McpInitializeResult,
    McpListToolsParams, McpListToolsResult, McpToolCallParams, McpToolCallResult,
};

#[derive(Debug)]
enum RemoteTransport {
    Http {
        client: reqwest::Client,
        url: String,
        headers: HeaderMap,
    },
    WebSocket {
        ws: Box<
            tokio_tungstenite::WebSocketStream<
                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
            >,
        >,
    },
}

#[derive(Debug)]
pub struct McpRemoteClient {
    transport: RemoteTransport,
}

impl McpRemoteClient {
    pub async fn connect(bootstrap: &McpClientBootstrap) -> io::Result<Self> {
        match &bootstrap.transport {
            McpClientTransport::Sse(remote) | McpClientTransport::Http(remote) => {
                Self::connect_http(remote).await
            }
            McpClientTransport::WebSocket(remote) => Self::connect_ws(remote).await,
            other => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!(
                    "MCP bootstrap transport for {} is not remote: {other:?}",
                    bootstrap.server_name
                ),
            )),
        }
    }

    async fn connect_http(remote: &McpRemoteTransport) -> io::Result<Self> {
        let mut headers = build_headers(&remote.headers);
        if let McpClientAuth::OAuth(ref oauth) = remote.auth {
            if let Some(ref client_id) = oauth.client_id {
                let _ = headers.insert(
                    HeaderName::from_static("x-oauth-client-id"),
                    HeaderValue::from_str(client_id)
                        .unwrap_or_else(|_| HeaderValue::from_static("")),
                );
            }
        }

        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(120))
            .build()
            .map_err(io::Error::other)?;

        Ok(Self {
            transport: RemoteTransport::Http {
                client,
                url: remote.url.clone(),
                headers,
            },
        })
    }

    async fn connect_ws(remote: &McpRemoteTransport) -> io::Result<Self> {
        let mut request = tokio_tungstenite::tungstenite::http::Request::builder()
            .uri(&remote.url)
            .header("Sec-WebSocket-Protocol", "mcp");

        for (key, value) in &remote.headers {
            request = request.header(key.as_str(), value.as_str());
        }

        let request = request
            .body(())
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;

        let (ws, _response) = tokio_tungstenite::connect_async(request)
            .await
            .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e))?;

        Ok(Self {
            transport: RemoteTransport::WebSocket { ws: Box::new(ws) },
        })
    }

    async fn request<TParams: Serialize, TResult: DeserializeOwned>(
        &mut self,
        id: JsonRpcId,
        method: impl Into<String>,
        params: Option<TParams>,
    ) -> io::Result<JsonRpcResponse<TResult>> {
        let method = method.into();
        let request = JsonRpcRequest::new(id.clone(), method.clone(), params);

        match &mut self.transport {
            RemoteTransport::Http {
                client,
                url,
                headers,
            } => {
                let body = serde_json::to_vec(&request)
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;

                let response = client
                    .post(url.as_str())
                    .headers(headers.clone())
                    .header(CONTENT_TYPE, "application/json")
                    .header(ACCEPT, "application/json, text/event-stream")
                    .body(body)
                    .send()
                    .await
                    .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e))?;

                if !response.status().is_success() {
                    return Err(io::Error::other(format!(
                        "MCP HTTP request {method} failed with status {}",
                        response.status()
                    )));
                }

                let content_type = response
                    .headers()
                    .get(CONTENT_TYPE)
                    .and_then(|v| v.to_str().ok())
                    .unwrap_or("")
                    .to_string();

                if content_type.contains("text/event-stream") {
                    let text = response
                        .text()
                        .await
                        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
                    parse_sse_jsonrpc(&text, &id)
                } else {
                    let bytes = response
                        .bytes()
                        .await
                        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
                    serde_json::from_slice(&bytes)
                        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
                }
            }
            RemoteTransport::WebSocket { ws } => {
                let body = serde_json::to_string(&request)
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
                ws.send(WsMessage::Text(body.into()))
                    .await
                    .map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))?;

                let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_secs(120);
                loop {
                    let msg = tokio::time::timeout_at(deadline, ws.next())
                        .await
                        .map_err(|_| {
                            io::Error::new(
                                io::ErrorKind::TimedOut,
                                "WebSocket RPC response timed out after 120s",
                            )
                        })?
                        .ok_or_else(|| {
                            io::Error::new(
                                io::ErrorKind::UnexpectedEof,
                                "WebSocket stream closed while waiting for response",
                            )
                        })?
                        .map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))?;

                    match msg {
                        WsMessage::Text(text) => {
                            let response: JsonRpcResponse<TResult> = serde_json::from_str(&text)
                                .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
                            if response.id == id {
                                return Ok(response);
                            }
                        }
                        WsMessage::Ping(data) => {
                            let _ = ws.send(WsMessage::Pong(data)).await;
                        }
                        WsMessage::Close(_) => {
                            return Err(io::Error::new(
                                io::ErrorKind::ConnectionAborted,
                                "WebSocket connection closed by server",
                            ));
                        }
                        _ => {}
                    }
                }
            }
        }
    }

    pub async fn initialize(
        &mut self,
        id: JsonRpcId,
        params: McpInitializeParams,
    ) -> io::Result<JsonRpcResponse<McpInitializeResult>> {
        self.request(id, "initialize", Some(params)).await
    }

    pub async fn list_tools(
        &mut self,
        id: JsonRpcId,
        params: Option<McpListToolsParams>,
    ) -> io::Result<JsonRpcResponse<McpListToolsResult>> {
        self.request(id, "tools/list", params).await
    }

    pub async fn call_tool(
        &mut self,
        id: JsonRpcId,
        params: McpToolCallParams,
    ) -> io::Result<JsonRpcResponse<McpToolCallResult>> {
        self.request(id, "tools/call", Some(params)).await
    }

    pub async fn shutdown(&mut self) -> io::Result<()> {
        match &mut self.transport {
            RemoteTransport::Http { .. } => Ok(()),
            RemoteTransport::WebSocket { ws } => {
                let _ = ws.close(None).await;
                Ok(())
            }
        }
    }
}

fn build_headers(headers: &BTreeMap<String, String>) -> HeaderMap {
    let mut header_map = HeaderMap::new();
    for (key, value) in headers {
        if let (Ok(name), Ok(val)) = (
            HeaderName::from_bytes(key.as_bytes()),
            HeaderValue::from_str(value),
        ) {
            header_map.insert(name, val);
        }
    }
    header_map
}

fn parse_sse_jsonrpc<T: DeserializeOwned>(
    sse_text: &str,
    expected_id: &JsonRpcId,
) -> io::Result<JsonRpcResponse<T>> {
    for line in sse_text.lines() {
        let data = match line.strip_prefix("data: ") {
            Some(d) if !d.is_empty() => d.trim(),
            _ => continue,
        };
        if let Ok(response) = serde_json::from_str::<JsonRpcResponse<T>>(data) {
            if &response.id == expected_id {
                return Ok(response);
            }
        }
    }
    Err(io::Error::new(
        io::ErrorKind::InvalidData,
        "no matching JSON-RPC response found in SSE stream",
    ))
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value as JsonValue;

    #[test]
    fn build_headers_creates_header_map() {
        let mut input = BTreeMap::new();
        input.insert("X-Custom".to_string(), "value1".to_string());
        input.insert("Authorization".to_string(), "Bearer tok".to_string());

        let headers = build_headers(&input);
        assert_eq!(headers.get("x-custom").unwrap(), "value1");
        assert_eq!(headers.get("authorization").unwrap(), "Bearer tok");
    }

    #[test]
    fn parse_sse_extracts_matching_response() {
        let sse_text =
            "event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}\n\n";
        let result: JsonRpcResponse<JsonValue> =
            parse_sse_jsonrpc(sse_text, &JsonRpcId::Number(1)).unwrap();
        assert_eq!(result.id, JsonRpcId::Number(1));
        assert!(result.result.is_some());
    }

    #[test]
    fn parse_sse_rejects_when_no_match() {
        let sse_text = "data: {\"jsonrpc\":\"2.0\",\"id\":99,\"result\":null}\n\n";
        let result: io::Result<JsonRpcResponse<JsonValue>> =
            parse_sse_jsonrpc(sse_text, &JsonRpcId::Number(1));
        assert!(result.is_err());
    }

    #[test]
    fn parse_sse_skips_non_data_lines() {
        let sse_text = ": comment\nevent: ping\nretry: 3000\ndata: {\"jsonrpc\":\"2.0\",\"id\":5,\"result\":42}\n\n";
        let result: JsonRpcResponse<JsonValue> =
            parse_sse_jsonrpc(sse_text, &JsonRpcId::Number(5)).unwrap();
        assert_eq!(result.id, JsonRpcId::Number(5));
    }
}