chabeau 0.7.1

A full-screen terminal chat interface that connects to various AI APIs for real-time conversations
Documentation
use super::{
    protocol, require_stdio_command, stdio_args, stdio_env, STDIO_REQUEST_TIMEOUT_SECONDS,
    STDIO_SAMPLING_TIMEOUT_MULTIPLIER,
};
use crate::core::config::data::McpServerConfig;
use crate::mcp::events::McpServerRequest;
use rust_mcp_schema::schema_utils::{
    ClientMessage, FromMessage, MessageFromClient, NotificationFromClient, RequestFromClient,
    ResultFromClient, ServerMessage,
};
use rust_mcp_schema::{InitializeRequestParams, InitializeResult, RequestId, RpcError};
use std::collections::HashMap;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{ChildStdin, Command};
use tokio::sync::{mpsc, oneshot, Mutex, Notify, RwLock};
use tracing::debug;

/// MCP client that talks to a server child process over its stdio pipes.
///
/// Outgoing messages are serialized to JSON and written to the child's stdin,
/// one message per line. Tasks spawned by `connect` read the child's stdout and
/// route each incoming message to the matching waiter or forwarding channel.
pub(crate) struct StdioClient {
    /// Write end of the child's stdin; the mutex serializes concurrent writers.
    stdin: Mutex<ChildStdin>,
    /// Response waiters keyed by request id; an entry is removed when its
    /// response or error arrives, or when the requester gives up.
    pending: Arc<Mutex<HashMap<RequestId, oneshot::Sender<ServerMessage>>>>,
    /// Source of integer request ids; starts at 0 and is bumped per request.
    next_request_id: AtomicI64,
    /// Result of the initialize handshake, `None` until `initialize` succeeds.
    server_details: RwLock<Option<rust_mcp_schema::InitializeResult>>,
    /// Identifier attached to every server-initiated request that is forwarded.
    server_id: String,
    /// Channel on which server-initiated requests are forwarded, if any.
    request_tx: Option<mpsc::UnboundedSender<McpServerRequest>>,
    /// Signalled when the server sends a request or notification and after a
    /// result or error reply has been written.
    activity_notify: Arc<Notify>,
    /// Count of server-initiated requests still awaiting our reply; while it is
    /// positive, `timeout_for_wait` applies the sampling timeout multiplier.
    inflight_server_requests: Arc<AtomicI64>,
}

impl StdioClient {
    /// Launches the MCP server process described by `config` and wires up a client for it.
    ///
    /// The child is started with piped stdin, stdout and stderr. Three background
    /// tasks are spawned: one reads and dispatches stdout messages, one drains
    /// stderr, and one waits for the child to exit and then drops every pending
    /// response sender so that waiting requests fail instead of hanging.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `require_stdio_command`, the stringified
    /// spawn error, or a message naming the stdio pipe that could not be taken.
    pub(crate) async fn connect(
        server_id: String,
        config: &McpServerConfig,
        request_tx: Option<mpsc::UnboundedSender<McpServerRequest>>,
    ) -> Result<Arc<Self>, String> {
        use std::process::Stdio;

        let command = require_stdio_command(config)?;
        let args = stdio_args(config);
        debug!(command = %command, args = ?args, "Starting MCP stdio server");

        let mut launcher = Command::new(command);
        launcher
            .args(args)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());
        if let Some(env) = stdio_env(config) {
            launcher.envs(env);
        }

        let mut child = launcher.spawn().map_err(|err| err.to_string())?;
        let Some(stdin) = child.stdin.take() else {
            return Err("Unable to retrieve stdin.".to_string());
        };
        let Some(stdout) = child.stdout.take() else {
            return Err("Unable to retrieve stdout.".to_string());
        };
        let Some(stderr) = child.stderr.take() else {
            return Err("Unable to retrieve stderr.".to_string());
        };

        // Shared state is created up front so the client and the reader task
        // each receive their own handle to the same underlying values.
        let pending: Arc<Mutex<HashMap<RequestId, oneshot::Sender<ServerMessage>>>> =
            Arc::new(Mutex::new(HashMap::new()));
        let activity_notify = Arc::new(Notify::new());
        let inflight_server_requests = Arc::new(AtomicI64::new(0));

        let client = Arc::new(Self {
            stdin: Mutex::new(stdin),
            pending: Arc::clone(&pending),
            next_request_id: AtomicI64::new(0),
            server_details: RwLock::new(None),
            server_id: server_id.clone(),
            request_tx: request_tx.clone(),
            activity_notify: Arc::clone(&activity_notify),
            inflight_server_requests: Arc::clone(&inflight_server_requests),
        });

        Self::spawn_stdout_reader(
            Arc::clone(&pending),
            stdout,
            server_id,
            request_tx,
            activity_notify,
            inflight_server_requests,
        );
        Self::spawn_stderr_drain(stderr);

        // Once the child exits no response can arrive any more; dropping the
        // senders wakes every waiter with a closed-channel error.
        tokio::spawn(async move {
            let _ = child.wait().await;
            pending.lock().await.clear();
        });

        Ok(client)
    }

    /// Performs the MCP initialize handshake with the server.
    ///
    /// Sends the initialize request, parses the reply, stores the parsed result
    /// in `server_details`, and then sends the `initialized` notification.
    ///
    /// # Errors
    ///
    /// Propagates any error from sending the request, parsing the reply, or
    /// sending the notification.
    pub(crate) async fn initialize(
        &self,
        details: InitializeRequestParams,
    ) -> Result<InitializeResult, String> {
        let handshake = RequestFromClient::InitializeRequest(details);
        let reply = self.send_request(handshake).await?;
        let server_info = protocol::parse_initialize_result(reply)?;

        {
            let mut slot = self.server_details.write().await;
            *slot = Some(server_info.clone());
        }

        let ready = NotificationFromClient::InitializedNotification(None);
        self.send_notification(ready).await?;
        Ok(server_info)
    }

    /// Sends `request` to the server and waits for the matching response.
    ///
    /// A fresh request id is allocated and a oneshot waiter is registered under
    /// it before the message is written, so a fast reply cannot be missed. The
    /// wait is bounded by `timeout_for_wait`.
    ///
    /// # Errors
    ///
    /// Returns an error if the message cannot be built or written, if the
    /// waiter's channel is closed before a reply arrives, or if the wait times
    /// out. On every error path the waiter is removed from `pending`.
    pub(crate) async fn send_request(
        &self,
        request: RequestFromClient,
    ) -> Result<ServerMessage, String> {
        let request_id = self.next_request_id();
        debug!(request_id = ?request_id, "Sending MCP stdio request");
        let message = ClientMessage::from_message(
            MessageFromClient::RequestFromClient(request),
            Some(request_id.clone()),
        )
        .map_err(|err| err.to_string())?;

        let (tx, rx) = oneshot::channel();
        self.pending.lock().await.insert(request_id.clone(), tx);

        if let Err(err) = self.send_client_message(&message).await {
            // The request never reached the server, so no reply can ever claim
            // this waiter; remove it instead of leaking the map entry.
            self.pending.lock().await.remove(&request_id);
            return Err(err);
        }

        let wait_timeout = self.timeout_for_wait();
        let failure = match tokio::time::timeout(wait_timeout, rx).await {
            Ok(Ok(message)) => return Ok(message),
            Ok(Err(_)) => "MCP stdio response channel closed.",
            Err(_) => "Timed out waiting for MCP stdio response.",
        };
        // Nobody is listening any more; make sure a late reply finds no waiter.
        self.pending.lock().await.remove(&request_id);
        Err(failure.to_string())
    }

    /// Replies to a server-initiated request with a successful `result`.
    ///
    /// When the reply has been written, the in-flight server request counter is
    /// decremented and `activity_notify` waiters are woken. If writing fails,
    /// neither happens and the write error is returned.
    pub(crate) async fn send_result(
        &self,
        request_id: RequestId,
        result: ResultFromClient,
    ) -> Result<(), String> {
        debug!(
            server_id = %self.server_id,
            request_id = ?request_id,
            "Preparing MCP stdio result"
        );
        let envelope = MessageFromClient::ResultFromClient(result);
        let message = ClientMessage::from_message(envelope, Some(request_id.clone()))
            .map_err(|err| err.to_string())?;

        // Only a reply that actually went out settles the server's request.
        self.send_client_message(&message).await?;

        let inflight = self.decrement_inflight();
        debug!(
            request_id = ?request_id,
            inflight_server_requests = inflight,
            "Sent MCP stdio result"
        );
        self.activity_notify.notify_waiters();
        Ok(())
    }

    pub(crate) async fn send_error(
        &self,
        request_id: RequestId,
        error: RpcError,
    ) -> Result<(), String> {
        debug!(
            server_id = %self.server_id,
            request_id = ?request_id,
            "Preparing MCP stdio error response"
        );
        let message =
            ClientMessage::from_message(MessageFromClient::Error(error), Some(request_id.clone()))
                .map_err(|err| err.to_string())?;
        let result = self.send_client_message(&message).await;
        if result.is_ok() {
            let inflight = self.decrement_inflight();
            debug!(
                request_id = ?request_id,
                inflight_server_requests = inflight,
                "Sent MCP stdio error response"
            );
            self.activity_notify.notify_waiters();
        }
        result
    }

    /// Sends a notification to the server; no request id is attached and no
    /// reply is awaited.
    async fn send_notification(&self, notification: NotificationFromClient) -> Result<(), String> {
        let envelope = MessageFromClient::NotificationFromClient(notification);
        let message =
            ClientMessage::from_message(envelope, None).map_err(|err| err.to_string())?;
        self.send_client_message(&message).await
    }

    /// Serializes `message` and writes it to the child's stdin as one line.
    ///
    /// The JSON payload and its terminating newline are written as a single
    /// buffer, so a timeout cannot fall between the payload and its terminator
    /// and leave an unterminated frame that the next message would be appended to.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization fails, if the stdin lock cannot be
    /// acquired within 10 seconds, or if writing or flushing fails or exceeds
    /// 10 seconds each.
    async fn send_client_message(&self, message: &ClientMessage) -> Result<(), String> {
        let lock_timeout = tokio::time::Duration::from_secs(10);
        let write_timeout = tokio::time::Duration::from_secs(10);

        // Build the complete frame before taking the lock to keep the critical
        // section as short as possible.
        let mut frame = serde_json::to_string(message).map_err(|err| err.to_string())?;
        frame.push('\n');

        let mut stdin = tokio::time::timeout(lock_timeout, self.stdin.lock())
            .await
            .map_err(|_| "Timed out waiting for MCP stdio stdin lock.".to_string())?;

        tokio::time::timeout(write_timeout, stdin.write_all(frame.as_bytes()))
            .await
            .map_err(|_| "Timed out writing MCP stdio client message.".to_string())?
            .map_err(|err| err.to_string())?;
        tokio::time::timeout(write_timeout, stdin.flush())
            .await
            .map_err(|_| "Timed out flushing MCP stdio client message.".to_string())?
            .map_err(|err| err.to_string())?;
        Ok(())
    }

    /// Returns how long to wait for a response to a client request.
    ///
    /// The base is `STDIO_REQUEST_TIMEOUT_SECONDS`; while at least one
    /// server-initiated request is still in flight it is scaled by
    /// `STDIO_SAMPLING_TIMEOUT_MULTIPLIER`.
    fn timeout_for_wait(&self) -> tokio::time::Duration {
        let multiplier = match self.inflight_server_requests.load(Ordering::SeqCst) {
            inflight if inflight > 0 => STDIO_SAMPLING_TIMEOUT_MULTIPLIER,
            _ => 1,
        };
        tokio::time::Duration::from_secs(STDIO_REQUEST_TIMEOUT_SECONDS * multiplier)
    }

    /// Atomically lowers the in-flight server request counter by one, never
    /// taking it below its current value when that value is zero or negative.
    ///
    /// Returns the counter value after the call: the decremented value when a
    /// decrement happened, otherwise the unchanged value.
    fn decrement_inflight(&self) -> i64 {
        let outcome = self.inflight_server_requests.fetch_update(
            Ordering::SeqCst,
            Ordering::SeqCst,
            // `None` leaves the counter untouched when there is nothing to settle.
            |count| (count > 0).then(|| count - 1),
        );
        match outcome {
            Ok(previous) => previous - 1,
            Err(unchanged) => unchanged,
        }
    }

    /// Allocates the next integer request id by post-incrementing the counter.
    fn next_request_id(&self) -> RequestId {
        RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::SeqCst))
    }

    /// Spawns the task that reads the child's stdout line by line and dispatches
    /// every server message it can parse.
    ///
    /// Each line is expected to hold one JSON value: either a single message or
    /// an array of messages. Lines that are not valid JSON, and values that do
    /// not deserialize into a `ServerMessage`, are skipped. The task ends at end
    /// of stream or on the first read error.
    fn spawn_stdout_reader(
        pending: Arc<Mutex<HashMap<RequestId, oneshot::Sender<ServerMessage>>>>,
        stdout: tokio::process::ChildStdout,
        server_id: String,
        request_tx: Option<mpsc::UnboundedSender<McpServerRequest>>,
        activity_notify: Arc<Notify>,
        inflight_server_requests: Arc<AtomicI64>,
    ) {
        tokio::spawn(async move {
            let mut lines = BufReader::new(stdout).lines();
            while let Ok(Some(line)) = lines.next_line().await {
                let Ok(value) = serde_json::from_str::<serde_json::Value>(&line) else {
                    continue;
                };

                // Treat a lone message as a batch of one. Taking ownership of the
                // array lets each item be deserialized without cloning it first.
                let items = match value {
                    serde_json::Value::Array(items) => items,
                    single => vec![single],
                };

                for item in items {
                    let Ok(message) = serde_json::from_value::<ServerMessage>(item) else {
                        continue;
                    };
                    Self::dispatch_message(
                        &pending,
                        message,
                        &server_id,
                        request_tx.as_ref(),
                        &activity_notify,
                        &inflight_server_requests,
                    )
                    .await;
                }
            }
        });
    }

    /// Spawns a task that reads and discards everything the child writes to
    /// stderr, keeping the pipe open and empty for as long as the child uses it.
    ///
    /// The bytes are copied raw into a sink rather than decoded as text lines:
    /// line decoding fails on output that is not valid UTF-8, which would end
    /// the drain early and close the read end of the pipe under the child.
    fn spawn_stderr_drain(stderr: tokio::process::ChildStderr) {
        tokio::spawn(async move {
            let mut stderr = stderr;
            let mut discard = tokio::io::sink();
            // Best effort: the task simply ends at end of stream or on an I/O error.
            let _ = tokio::io::copy(&mut stderr, &mut discard).await;
        });
    }

    /// Routes one message received from the server.
    ///
    /// * Responses, and errors that carry an id, are handed to the waiter
    ///   registered under that id in `pending`, if there is one.
    /// * Requests bump the in-flight counter, wake `activity_notify` waiters and
    ///   are forwarded on `request_tx` tagged with `server_id`. If the request
    ///   cannot be forwarded (no channel, or the receiver is gone) the counter
    ///   bump is undone, because nobody will ever reply to settle it.
    /// * Notifications only wake `activity_notify` waiters.
    async fn dispatch_message(
        pending: &Arc<Mutex<HashMap<RequestId, oneshot::Sender<ServerMessage>>>>,
        message: ServerMessage,
        server_id: &str,
        request_tx: Option<&mpsc::UnboundedSender<McpServerRequest>>,
        activity_notify: &Notify,
        inflight_server_requests: &AtomicI64,
    ) {
        match &message {
            ServerMessage::Response(response) => {
                let waiter = pending.lock().await.remove(&response.id);
                if let Some(tx) = waiter {
                    let _ = tx.send(message);
                }
            }
            ServerMessage::Error(error) => {
                if let Some(id) = error.id.as_ref() {
                    let waiter = pending.lock().await.remove(id);
                    if let Some(tx) = waiter {
                        let _ = tx.send(message);
                    }
                }
            }
            ServerMessage::Request(request) => {
                // Count the request before handing it over so that a handler
                // replying immediately always finds the counter already raised.
                inflight_server_requests.fetch_add(1, Ordering::SeqCst);
                activity_notify.notify_waiters();

                let delivered = match request_tx {
                    Some(tx) => tx
                        .send(McpServerRequest {
                            server_id: server_id.to_string(),
                            request: request.clone(),
                        })
                        .is_ok(),
                    None => false,
                };
                if !delivered {
                    // No handler received the request, so no reply will ever
                    // decrement the counter; undo the bump to avoid leaving the
                    // extended request timeout switched on permanently.
                    inflight_server_requests.fetch_sub(1, Ordering::SeqCst);
                    activity_notify.notify_waiters();
                }
            }
            ServerMessage::Notification(_) => {
                activity_notify.notify_waiters();
            }
        }
    }
}

/// Sends `request` through `client`, failing immediately when no client is
/// connected.
pub(crate) async fn send_request(
    client: Option<Arc<StdioClient>>,
    request: RequestFromClient,
) -> Result<ServerMessage, String> {
    match client {
        Some(connected) => connected.send_request(request).await,
        None => Err("MCP client not connected.".to_string()),
    }
}

/// Sends a successful reply for `request_id` through `client`, failing
/// immediately when no client is connected.
pub(crate) async fn send_result(
    client: Option<Arc<StdioClient>>,
    request_id: RequestId,
    result: ResultFromClient,
) -> Result<(), String> {
    match client {
        Some(connected) => connected.send_result(request_id, result).await,
        None => Err("MCP client not connected.".to_string()),
    }
}

/// Sends an error reply for `request_id` through `client`, failing immediately
/// when no client is connected.
pub(crate) async fn send_error(
    client: Option<Arc<StdioClient>>,
    request_id: RequestId,
    error: RpcError,
) -> Result<(), String> {
    match client {
        Some(connected) => connected.send_error(request_id, error).await,
        None => Err("MCP client not connected.".to_string()),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Without a connected client the free `send_request` fails right away
    /// with the "not connected" message.
    #[tokio::test]
    async fn stdio_requires_connected_client() {
        let outcome = send_request(None, RequestFromClient::PingRequest(None)).await;
        let message = outcome.expect_err("expected missing client error");
        assert_eq!(message, "MCP client not connected.");
    }
}