multilink 1.0.1

IPC library for communicating with local or remote processes, over stdio or HTTP
Documentation
use std::collections::HashMap;

use futures::StreamExt;
use serde_json::Value;
use tokio::{
    io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
    process::{ChildStdin, ChildStdout},
    sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{error, warn};

use crate::{
    jsonrpc::{JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse},
    stdio::StdioError,
    ServiceResponse,
};

use super::{
    serialize_payload, ClientNotificationLink, ClientRequestTrx, RequestJsonRpcConvert,
    ResponseJsonRpcConvert,
};

/// Background task state for a stdio client: owns the child's stdin/stdout handles and
/// the bookkeeping that matches messages read from stdout to the requests that were sent.
pub(super) struct StdioClientCommTask<Request, Response>
where
    Request: RequestJsonRpcConvert<Request> + Send + 'static,
    Response: ResponseJsonRpcConvert<Request, Response> + Send + 'static,
{
    /// Child stdin; every outgoing JSON-RPC message is written here.
    stdin: ChildStdin,
    /// Buffered child stdout; read one line at a time by `run`.
    stdout: BufReader<ChildStdout>,
    /// Requests that were sent and are still waiting for their first reply, keyed by the
    /// numeric id assigned in `handle_outgoing_request`.
    pending_reqs: HashMap<u64, ClientRequestTrx<Request, Response>>,
    /// Open notification streams, keyed by the id of the request that started them.
    /// An entry is created on the first notification for a pending request and removed
    /// when a notification without params arrives.
    notification_links: HashMap<u64, ClientNotificationLink<Request, Response>>,
    /// Receiving half of the channel on which callers submit requests to this task.
    to_child_rx: UnboundedReceiver<ClientRequestTrx<Request, Response>>,
    /// Sending half of the same channel; held here until `start` takes it out and
    /// returns it to the caller.
    to_child_tx: Option<UnboundedSender<ClientRequestTrx<Request, Response>>>,
    /// Id given to the most recently sent request (0 before any request is sent, so the
    /// first request gets id 1).
    last_req_id: u64,
}

impl<Request, Response> StdioClientCommTask<Request, Response>
where
    Request: RequestJsonRpcConvert<Request> + Send + 'static,
    Response: ResponseJsonRpcConvert<Request, Response> + Send + 'static,
{
    /// Creates a task around the child's `stdin`/`stdout` handles.
    ///
    /// The request channel is created here; its sending half is kept inside the task
    /// until [`start`](Self::start) hands it to the caller.
    pub(super) fn new(stdin: ChildStdin, stdout: BufReader<ChildStdout>) -> Self {
        // The element type is fixed by the struct's field types, so no turbofish is needed.
        let (tx, rx) = mpsc::unbounded_channel();
        Self {
            stdin,
            stdout,
            pending_reqs: HashMap::new(),
            notification_links: HashMap::new(),
            to_child_rx: rx,
            to_child_tx: Some(tx),
            last_req_id: 0,
        }
    }

    async fn output_message(&mut self, message: JsonRpcMessage) {
        let serialized_response = serialize_payload(&message);
        self.stdin
            .write_all(serialized_response.as_bytes())
            .await
            .ok();
    }

    async fn handle_outgoing_request(&mut self, req_trx: ClientRequestTrx<Request, Response>) {
        let mut jsonrpc_request = req_trx.request.into_jsonrpc_request();
        let id = self.last_req_id + 1;
        jsonrpc_request.id = serde_json::to_value(id).unwrap();

        self.last_req_id = id;
        self.pending_reqs.insert(id, req_trx);

        self.output_message(jsonrpc_request.into()).await;
    }

    /// Answers a request coming *from* the child.
    ///
    /// The reply is always an error response carrying
    /// `StdioError::ClientRequestUnsupported` and the incoming request's id.
    async fn handle_incoming_request(&mut self, request: JsonRpcRequest) {
        let rejection = JsonRpcResponse::new(
            Err(StdioError::ClientRequestUnsupported.into()),
            request.id,
        );
        self.output_message(rejection.into()).await
    }

    fn handle_response(&mut self, response: JsonRpcResponse) {
        match self
            .pending_reqs
            .remove(&serde_json::from_value(response.id.clone()).unwrap_or_default())
        {
            None => {
                warn!("received response with unknown id, ignoring {:?}", response)
            }
            Some(trx) => {
                let result = match Response::from_jsonrpc_message(response.into(), &trx.request) {
                    Ok(response) => match response {
                        None => {
                            error!("unknown json rpc notification type received");
                            return;
                        }
                        Some(response) => Ok(ServiceResponse::Single(response)),
                    },
                    Err(e) => Err(e.into()),
                };
                trx.response_tx.send(result).ok();
            }
        }
    }

    fn handle_notification(&mut self, notification: JsonRpcNotification) {
        let id = notification.method.parse::<u64>().unwrap_or_default();
        if let Some(trx) = self.pending_reqs.remove(&id) {
            let (notification_tx, notification_rx) = mpsc::unbounded_channel();
            trx.response_tx
                .send(Ok(ServiceResponse::Multiple(
                    UnboundedReceiverStream::new(notification_rx).boxed(),
                )))
                .ok();
            self.notification_links.insert(
                id,
                ClientNotificationLink {
                    request: trx.request,
                    notification_tx,
                },
            );
        }
        match self.notification_links.get(&id) {
            None => warn!("received notification with unknown id, ignoring"),
            Some(link) => match notification.params.is_some() {
                true => {
                    let result =
                        match Response::from_jsonrpc_message(notification.into(), &link.request) {
                            Ok(notification) => match notification {
                                None => {
                                    error!("unknown json rpc notification type received");
                                    return;
                                }
                                Some(notification) => Ok(notification),
                            },
                            Err(e) => Err(e.into()),
                        };
                    link.notification_tx.send(result).ok();
                }
                false => {
                    self.notification_links.remove(&id);
                    self.pending_reqs.remove(&id);
                }
            },
        }
    }

    async fn run(mut self) {
        loop {
            let mut stdout_message = String::new();
            tokio::select! {
                req_trx = self.to_child_rx.recv() => if let Some(req_trx) = req_trx {
                    self.handle_outgoing_request(req_trx).await;
                },
                result = self.stdout.read_line(&mut stdout_message) => match result {
                    Err(e) => error!("StdioClient i/o error reading line from stdout: {}" ,e),
                    Ok(bytes_read) => {
                        if bytes_read == 0 {
                            return;
                        }
                        match JsonRpcMessage::try_from(serde_json::from_str::<Value>(&stdout_message).unwrap_or_default()) {
                            Err(e) => error!("failed to parse message from server: {}", e),
                            Ok(message) => match message {
                                JsonRpcMessage::Request(request) => self.handle_incoming_request(request).await,
                                JsonRpcMessage::Response(response) => self.handle_response(response),
                                JsonRpcMessage::Notification(notification) => self.handle_notification(notification)
                            }
                        }
                    }
                }
            }
        }
    }

    /// Spawns the task on the tokio runtime and returns the sender used to submit
    /// requests to it.
    ///
    /// # Panics
    ///
    /// Panics if the sender has already been taken out of the task.
    pub(super) fn start(mut self) -> UnboundedSender<ClientRequestTrx<Request, Response>> {
        let request_sender = self.to_child_tx.take().unwrap();
        // `run` consumes the task, so its future owns all state and can be spawned as is.
        tokio::spawn(self.run());
        request_sender
    }
}