chabeau 0.7.3

A full-screen terminal chat interface that connects to various AI APIs for real-time conversations
Documentation
use super::{
    apply_streamable_http_client_post_headers, apply_streamable_http_protocol_version_header,
    client_details_for, protocol, require_http_base_url, McpServerRequestContext,
};
use crate::core::config::data::McpServerConfig;
use crate::mcp::events::McpServerRequest;
use crate::mcp::transport::streamable_http::{
    is_event_stream_content_type, next_sse_server_message, sse_data_payload, SseLineBuffer,
};
use futures_util::StreamExt;
use rust_mcp_schema::schema_utils::{
    ClientMessage, FromMessage, MessageFromClient, NotificationFromClient, RequestFromClient,
    ServerMessage,
};
use rust_mcp_schema::RequestId;
use tokio::sync::mpsc;
use tracing::debug;

/// Adds every user-configured header from `headers` to `request`.
///
/// When `headers` is `None` the builder is returned unchanged.
fn apply_mcp_http_headers(
    request: reqwest::RequestBuilder,
    headers: Option<&std::collections::HashMap<String, String>>,
) -> reqwest::RequestBuilder {
    headers
        .into_iter()
        .flatten()
        .fold(request, |builder, (name, value)| builder.header(name, value))
}

/// Connection state used by the streamable-HTTP send helpers in this module.
///
/// Implementors expose three things. First, the server configuration. Second, the
/// credentials and session used to decorate each POST. Third, the mutable pieces the
/// helpers update as responses arrive: the session id, the request-id counter and the
/// negotiated protocol version.
pub(crate) trait StreamableHttpContext {
    /// Server configuration; its `id`, `headers` and base URL are read when building requests.
    fn config(&self) -> &McpServerConfig;
    /// HTTP client used for POSTs; `None` makes sends fail with "MCP HTTP client not connected.".
    fn http_client(&self) -> Option<&reqwest::Client>;
    /// Value sent as the `Authorization` header, if any.
    fn auth_header(&self) -> Option<&String>;

    /// Current session id, sent as `mcp-session-id` when present.
    fn session_id(&self) -> Option<&String>;
    /// Records (or clears) the session id.
    fn set_session_id(&mut self, session_id: Option<String>);

    /// Returns the integer JSON-RPC id to attach to the next outgoing request.
    fn next_request_id(&mut self) -> i64;

    /// Protocol version taken from the server's initialize result, if recorded.
    fn negotiated_protocol_version(&self) -> Option<&str>;
    /// Stores the protocol version taken from the server's initialize result.
    fn set_negotiated_protocol_version(&mut self, protocol_version: Option<String>);

    /// Protocol version to advertise in request headers.
    ///
    /// Resolved by `protocol::effective_protocol_version` from the configuration and the
    /// negotiated version. Presumably the negotiated value wins when present; confirm in
    /// `protocol`.
    fn effective_protocol_version(&self) -> String {
        let negotiated = self.negotiated_protocol_version();
        let config = self.config();
        protocol::effective_protocol_version(config, negotiated)
    }
}

/// Owned inputs for the background SSE listener spawned by `spawn_streamable_http_listener`.
///
/// Every field is owned so that the spawned task can take the whole value with it.
pub(crate) struct StreamableHttpListenerConfig {
    /// Server identifier copied into each forwarded `McpServerRequest`.
    pub server_id: String,
    /// Endpoint for the GET stream; when `None`, no listener task is started.
    pub base_url: Option<String>,
    /// Session id sent as the `mcp-session-id` header, if any.
    pub session_id: Option<String>,
    /// Protocol version passed to the protocol-version header helper, if any.
    pub protocol_version: Option<String>,
    /// Value sent as the `Authorization` header, if any.
    pub auth_header: Option<String>,
    /// Additional user-configured headers applied to the GET request.
    pub custom_headers: Option<std::collections::HashMap<String, String>>,
}

pub(crate) async fn ensure_session_context<C: StreamableHttpContext>(
    context: &mut C,
) -> Result<(), String> {
    if context.session_id().is_some() {
        return Ok(());
    }

    let client_details = client_details_for(context.config());
    let response = send_request_with_context(
        context,
        RequestFromClient::InitializeRequest(client_details),
        None,
    )
    .await?;
    let initialize = super::protocol::parse_initialize_result(response)?;
    context.set_negotiated_protocol_version(Some(initialize.protocol_version));

    send_notification(
        context,
        NotificationFromClient::InitializedNotification(None),
    )
    .await
}

/// Wraps `request` in a JSON-RPC envelope with a fresh integer id and sends it.
///
/// Server-initiated requests seen on an SSE reply are forwarded to `request_tx`, when
/// one is given, by `send_message`.
///
/// # Errors
///
/// Returns a string error if the message cannot be built or the exchange fails.
pub(crate) async fn send_request_with_context<C: StreamableHttpContext>(
    context: &mut C,
    request: RequestFromClient,
    request_tx: Option<mpsc::UnboundedSender<McpServerRequest>>,
) -> Result<ServerMessage, String> {
    let id = RequestId::Integer(context.next_request_id());
    let envelope = MessageFromClient::RequestFromClient(request);
    let message = ClientMessage::from_message(envelope, Some(id)).map_err(|e| e.to_string())?;
    send_message(context, message, request_tx).await
}

/// Posts a client-side reply (for example, a result for a server request) using the
/// request context's connection details. No response body is read.
pub(crate) async fn send_server_result_message(
    context: &mut McpServerRequestContext,
    message: ClientMessage,
) -> Result<(), String> {
    send_client_message_with_context(context, message).await?;
    Ok(())
}

pub(crate) fn spawn_streamable_http_listener(
    client: reqwest::Client,
    request_tx: mpsc::UnboundedSender<McpServerRequest>,
    listener: StreamableHttpListenerConfig,
) {
    let StreamableHttpListenerConfig {
        server_id,
        base_url,
        auth_header,
        custom_headers,
        session_id,
        protocol_version,
    } = listener;

    let Some(base_url) = base_url else {
        return;
    };

    tokio::spawn(async move {
        let mut request = apply_streamable_http_protocol_version_header(
            client.get(&base_url).header("Accept", "text/event-stream"),
            protocol_version.as_deref(),
        );

        request = apply_mcp_http_headers(request, custom_headers.as_ref());
        if let Some(auth) = auth_header {
            request = request.header("Authorization", auth);
        }
        if let Some(session_id) = session_id {
            request = request.header("mcp-session-id", session_id);
        }

        let response = match request.send().await {
            Ok(response) => response,
            Err(_) => return,
        };

        if !response.status().is_success() {
            return;
        }

        let content_type = response
            .headers()
            .get(reqwest::header::CONTENT_TYPE)
            .and_then(|value| value.to_str().ok())
            .unwrap_or("");
        if !is_event_stream_content_type(content_type) {
            return;
        }

        let mut stream = response.bytes_stream();
        let mut buffer = SseLineBuffer::default();

        while let Some(chunk) = stream.next().await {
            let chunk = match chunk {
                Ok(chunk) => chunk,
                Err(_) => return,
            };
            for line in buffer.push(&chunk) {
                let Some(payload) = sse_data_payload(&line) else {
                    continue;
                };
                if payload.is_empty() {
                    continue;
                }
                if let Ok(ServerMessage::Request(request)) =
                    serde_json::from_str::<ServerMessage>(payload)
                {
                    let _ = request_tx.send(McpServerRequest {
                        server_id: server_id.clone(),
                        request,
                    });
                }
            }
        }

        for line in buffer.finish() {
            let Some(payload) = sse_data_payload(&line) else {
                continue;
            };
            if payload.is_empty() {
                continue;
            }
            if let Ok(ServerMessage::Request(request)) =
                serde_json::from_str::<ServerMessage>(payload)
            {
                let _ = request_tx.send(McpServerRequest {
                    server_id: server_id.clone(),
                    request,
                });
            }
        }
    });
}

/// Sends a JSON-RPC notification (no id, so no response body is expected or read).
async fn send_notification<C: StreamableHttpContext>(
    context: &mut C,
    notification: NotificationFromClient,
) -> Result<(), String> {
    let envelope = MessageFromClient::NotificationFromClient(notification);
    let message = match ClientMessage::from_message(envelope, None) {
        Ok(message) => message,
        Err(err) => return Err(err.to_string()),
    };
    send_client_message_with_context(context, message).await
}

/// Serializes `message` and builds the POST to the server's MCP endpoint.
///
/// The request is decorated with the standard client POST headers, the protocol-version
/// header, user-configured headers, `Authorization`, and `mcp-session-id` when present.
///
/// # Errors
///
/// Fails if serialization fails, no HTTP client is connected, or the base URL is missing.
fn build_post_request<C: StreamableHttpContext>(
    context: &C,
    message: &ClientMessage,
) -> Result<reqwest::RequestBuilder, String> {
    let payload = serde_json::to_string(message).map_err(|err| err.to_string())?;
    let client = context
        .http_client()
        .ok_or_else(|| "MCP HTTP client not connected.".to_string())?;
    let base_url = require_http_base_url(context.config())?;
    debug!(server_id = %context.config().id, url = %base_url, "Sending MCP HTTP request");
    let protocol_version = context.effective_protocol_version();
    let mut request = apply_streamable_http_protocol_version_header(
        apply_streamable_http_client_post_headers(client.post(base_url)),
        Some(protocol_version.as_str()),
    )
    .body(payload);

    request = apply_mcp_http_headers(request, context.config().headers.as_ref());
    if let Some(auth) = context.auth_header() {
        request = request.header("Authorization", auth);
    }
    if let Some(session_id) = context.session_id() {
        request = request.header("mcp-session-id", session_id);
    }
    Ok(request)
}

/// Sends `message` and returns the response if the HTTP status is a success.
///
/// A 404 reply to a request that carried a session id means the server no longer
/// recognizes that session. The MCP transport then requires starting a new session, so
/// the stale id is cleared; this lets `ensure_session_context` re-initialize next time
/// instead of reusing a dead session forever.
///
/// # Errors
///
/// Returns the transport error, or `"HTTP error: <status>"` for non-success statuses.
async fn post_client_message<C: StreamableHttpContext>(
    context: &mut C,
    message: &ClientMessage,
) -> Result<reqwest::Response, String> {
    let request = build_post_request(&*context, message)?;
    let sent_session_id = context.session_id().is_some();
    let response = request.send().await.map_err(|err| err.to_string())?;
    let status = response.status();
    if !status.is_success() {
        if sent_session_id && status == reqwest::StatusCode::NOT_FOUND {
            context.set_session_id(None);
        }
        return Err(format!("HTTP error: {}", status));
    }
    Ok(response)
}

/// Returns the `mcp-session-id` response header, if present and convertible to a string.
fn response_session_id(response: &reqwest::Response) -> Option<String> {
    response
        .headers()
        .get("mcp-session-id")
        .and_then(|value| value.to_str().ok())
        .map(str::to_string)
}

/// Posts a message whose response body is not needed (notifications, replies to the
/// server). Records any session id the server returns.
async fn send_client_message_with_context<C: StreamableHttpContext>(
    context: &mut C,
    message: ClientMessage,
) -> Result<(), String> {
    let response = post_client_message(context, &message).await?;
    if let Some(session_id) = response_session_id(&response) {
        context.set_session_id(Some(session_id));
    }
    Ok(())
}

/// Posts a request and returns the server's reply.
///
/// The reply is read either from a JSON body or from an SSE stream, chosen by
/// `Content-Type`. While reading an SSE stream, server-initiated requests are forwarded
/// to `request_tx` when one is given. The returned session id is recorded only after the
/// reply has been read successfully.
async fn send_message<C: StreamableHttpContext>(
    context: &mut C,
    message: ClientMessage,
    request_tx: Option<mpsc::UnboundedSender<McpServerRequest>>,
) -> Result<ServerMessage, String> {
    let response = post_client_message(context, &message).await?;

    let session_id = response_session_id(&response);
    let content_type = response
        .headers()
        .get(reqwest::header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .unwrap_or("")
        .to_string();

    let server_message = if is_event_stream_content_type(&content_type) {
        let server_id = context.config().id.clone();
        next_sse_server_message(response, move |message| {
            if let ServerMessage::Request(request) = message {
                if let Some(tx) = request_tx.as_ref() {
                    let _ = tx.send(McpServerRequest {
                        server_id: server_id.clone(),
                        request: request.clone(),
                    });
                }
            }
        })
        .await?
    } else {
        let body = response.bytes().await.map_err(|err| err.to_string())?;
        serde_json::from_slice::<ServerMessage>(&body).map_err(|err| err.to_string())?
    };

    if let Some(session_id) = session_id {
        context.set_session_id(Some(session_id));
    }
    Ok(server_message)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Test context that never has an HTTP client, so any network step fails immediately.
    struct OfflineContext {
        config: McpServerConfig,
        session: Option<String>,
    }

    impl StreamableHttpContext for OfflineContext {
        fn config(&self) -> &McpServerConfig {
            &self.config
        }
        fn http_client(&self) -> Option<&reqwest::Client> {
            None
        }
        fn auth_header(&self) -> Option<&String> {
            None
        }
        fn session_id(&self) -> Option<&String> {
            self.session.as_ref()
        }
        fn set_session_id(&mut self, session_id: Option<String>) {
            self.session = session_id;
        }
        fn next_request_id(&mut self) -> i64 {
            0
        }
        fn negotiated_protocol_version(&self) -> Option<&str> {
            None
        }
        fn set_negotiated_protocol_version(&mut self, _protocol_version: Option<String>) {}
    }

    /// Enabled streamable-HTTP server config pointing at a placeholder URL.
    fn alpha_config() -> McpServerConfig {
        McpServerConfig {
            id: "alpha".to_string(),
            display_name: "Alpha".to_string(),
            base_url: Some("https://example.com".to_string()),
            command: None,
            args: None,
            env: None,
            headers: None,
            transport: Some("streamable-http".to_string()),
            allowed_tools: None,
            protocol_version: None,
            enabled: Some(true),
            tool_payloads: None,
            tool_payload_window: None,
            yolo: None,
        }
    }

    #[test]
    fn ensure_session_requires_http_client() {
        let mut context = OfflineContext {
            config: alpha_config(),
            session: None,
        };
        let runtime = tokio::runtime::Runtime::new().expect("runtime");
        let error = runtime
            .block_on(ensure_session_context(&mut context))
            .expect_err("expected error");
        assert_eq!(error, "MCP HTTP client not connected.");
    }
}