tokio_ipc 0.1.0

Multi-protocol RPC framework built on top of tokio
Documentation
// tokio_socket/src/rpc2/rpc_packet.rs - Extended RPC packet with protocol ID support

use std::sync;
use anyhow::{Result, anyhow, bail};
use lockshed::{PromiseId, PromiseStore};
use serde::{Deserialize, Serialize};

/// Numeric handler identifier (not referenced elsewhere in this file).
pub type HandlerId = u64;
/// Identifier correlating a request packet with its response packet. It is the
/// `lockshed` promise id, so it doubles as the lookup key into the response
/// [`PromiseStore`].
pub type RequestId = PromiseId;
/// Owned byte buffer.
pub type Buffer = Vec<u8>;

/// Discriminates the two kinds of [`RpcPacket`] frame.
///
/// The explicit `#[repr(u8)]` discriminants are the byte written as the first
/// byte of every encoded packet, so changing them breaks wire compatibility.
#[derive(Clone, Copy, PartialEq, Serialize, Deserialize)]
#[repr(u8)]
pub enum RpcPacketType {
    /// An incoming call; the receiver may or may not answer it.
    Request = 1,
    /// An answer to an earlier request carrying the same packet id.
    Response = 2,
}

impl RpcPacketType {
    /// Returns the single-byte wire discriminant of this packet type.
    pub fn to_u8(&self) -> u8 {
        let this: Self = *self;
        this as u8
    }

    /// Decodes a wire discriminant previously produced by [`RpcPacketType::to_u8`].
    ///
    /// # Errors
    /// Returns an error for any byte that does not name a known packet type.
    pub fn from_u8(val: u8) -> Result<Self> {
        // Derive the pattern values from the enum itself so they can never drift
        // from the `#[repr(u8)]` discriminants.
        const REQUEST: u8 = RpcPacketType::Request as u8;
        const RESPONSE: u8 = RpcPacketType::Response as u8;

        match val {
            REQUEST => Ok(Self::Request),
            RESPONSE => Ok(Self::Response),
            _ => bail!("Unknown message type {val}"),
        }
    }

    /// Human-readable name of the variant (also what `Debug` prints).
    pub fn to_string(&self) -> &'static str {
        match *self {
            Self::Request => "Request",
            Self::Response => "Response",
        }
    }
}

/// Prints just the variant name, e.g. `Request`.
impl std::fmt::Debug for RpcPacketType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name: &'static str = self.to_string();
        f.write_str(name)
    }
}

pub type ResponsePromiseStore = PromiseStore<Vec<u8>>;

/// A single RPC frame exchanged between peers.
///
/// Every frame carries a `protocol_id` so several RPC protocols can be
/// multiplexed over one connection, and a `packet_id` that ties a response
/// back to the request it answers.
#[derive(Debug, Clone)]
pub struct RpcPacket {
    /// Whether this frame is a request or a response.
    pub packet_type: RpcPacketType,
    /// Request id; a response echoes the id of the request it answers.
    pub packet_id: RequestId,
    /// Identifier of the protocol this frame belongs to (used for multiplexing).
    pub protocol_id: u64,
    /// Opaque protocol payload.
    pub buf: Vec<u8>,
}

impl RpcPacket {
    /// Assembles a packet from its raw parts.
    fn new(packet_type: RpcPacketType, packet_id: RequestId, protocol_id: u64, buf: Vec<u8>) -> Self {
        Self { packet_type, packet_id, protocol_id, buf }
    }

    /// Builds a `Request` packet with id `request_id` for protocol `protocol_id`.
    pub fn new_request(request_id: RequestId, protocol_id: u64, buf: Vec<u8>) -> Self {
        let kind = RpcPacketType::Request;
        Self::new(kind, request_id, protocol_id, buf)
    }

    /// Builds a `Response` packet answering request `request_id` on protocol `protocol_id`.
    pub fn new_response(request_id: RequestId, protocol_id: u64, buf: Vec<u8>) -> Self {
        let kind = RpcPacketType::Response;
        Self::new(kind, request_id, protocol_id, buf)
    }
}

/// Length in bytes of the fixed [`RpcPacket`] wire header:
/// packet type (1) + packet id (8) + protocol id (8) + payload length (8).
const RPC_PACKET_HEADER_LEN: usize = 1 + 8 + 8 + 8;

/// Wire format of an [`RpcPacket`] (all integers big-endian):
///
/// ```text
/// offset  size  field
///      0     1  packet_type
///      1     8  packet_id
///      9     8  protocol_id
///     17     8  payload length N
///     25     N  payload
/// ```
impl tokio_socket::PacketProtocol for RpcPacket {
    /// Encodes the packet using the layout documented on this impl.
    fn to_bytes(&self) -> Result<Vec<u8>> {
        let buf_len = self.buf.len();

        let mut bytes = Vec::with_capacity(RPC_PACKET_HEADER_LEN + buf_len);
        bytes.push(self.packet_type.to_u8());
        bytes.extend_from_slice(&self.packet_id.to_be_bytes());
        bytes.extend_from_slice(&self.protocol_id.to_be_bytes());
        // Lossless: `usize` is at most 64 bits on every supported target.
        bytes.extend_from_slice(&(buf_len as u64).to_be_bytes());
        bytes.extend_from_slice(&self.buf);

        Ok(bytes)
    }

    /// Decodes a packet produced by [`to_bytes`](Self::to_bytes).
    ///
    /// Bytes following the declared payload are ignored.
    ///
    /// # Errors
    /// Fails if the header is incomplete, the packet-type byte is unknown, or the
    /// declared payload length exceeds the bytes actually present. The length
    /// field comes from the remote peer, so it is validated without any
    /// arithmetic that could overflow, truncate, or panic.
    fn from_bytes(bytes: &[u8]) -> Result<Self> {
        if bytes.len() < RPC_PACKET_HEADER_LEN {
            bail!("Packet too small: {} bytes", bytes.len());
        }

        let packet_kind = RpcPacketType::from_u8(bytes[0])
            .map_err(|e| anyhow!("Invalid message type: {e}"))?;
        let packet_id = u64::from_be_bytes(
            bytes[1..9]
                .try_into()
                .map_err(|e| anyhow!("Invalid message_id: {e}"))?,
        );
        let protocol_id = u64::from_be_bytes(
            bytes[9..17]
                .try_into()
                .map_err(|e| anyhow!("Invalid protocol_id: {e}"))?,
        );
        let declared_len = u64::from_be_bytes(
            bytes[17..RPC_PACKET_HEADER_LEN]
                .try_into()
                .map_err(|e| anyhow!("Invalid buf_len: {e}"))?,
        );

        // Compare in u64 space *before* converting to usize: a hostile length such
        // as u64::MAX would otherwise truncate on 32-bit targets or overflow
        // `HEADER + len`, turning a malformed packet into a slice-index panic.
        let payload = &bytes[RPC_PACKET_HEADER_LEN..];
        if declared_len > payload.len() as u64 {
            bail!(
                "Packet truncated: expected {} bytes, got {}",
                (RPC_PACKET_HEADER_LEN as u64).saturating_add(declared_len),
                bytes.len()
            );
        }
        // In range: declared_len <= payload.len(), which is itself a usize.
        let buf = payload[..declared_len as usize].to_vec();

        Ok(RpcPacket::new(packet_kind, packet_id, protocol_id, buf))
    }
}

/// Base trait for protocol senders: provides access to the [`crate::RpcPeer`]
/// the sender operates on.
///
/// Implementors must be `Clone + Send + Sync + 'static`, so they can be cloned
/// and moved freely across threads and tasks.
pub trait RpcProtocolSender: Clone + Send + Sync + 'static {
    /// Returns the peer used by this sender.
    fn peer(&self) -> &crate::RpcPeer;
}

/// Trait for receiving RPC messages, protocol-aware.
///
/// Invoked once per incoming `Request` packet with that packet's protocol id,
/// so an implementation can presumably route the payload to a protocol-specific
/// handler (TODO confirm against implementors).
pub trait ReceiveRpcProtocol: Clone + Send + Sync + 'static {
    /// Handles one request payload.
    ///
    /// * `protocol_id` - protocol identifier carried by the request packet.
    /// * `peer` - the socket peer the request arrived from.
    /// * `buf` - the request payload (owned).
    ///
    /// Returns `Ok(Some(bytes))` to have `bytes` sent back as the response
    /// payload, or `Ok(None)` to send no response. An `Err` is propagated to the
    /// caller of the packet handler, and in that case no response is sent.
    fn handle_packet(
        &self,
        protocol_id: u64,
        peer: &tokio_socket::SocketPeer,
        buf: Vec<u8>,
    ) -> impl std::future::Future<Output = Result<Option<Vec<u8>>>> + Send;
}

/// Trait for sending RPC messages; implementors are constructed from an
/// [`crate::RpcPeer`].
pub trait SendRpcProtocol: Send + Sync + 'static {
    /// Builds the sender around `peer`.
    ///
    /// The `where Self: Sized` bound exempts this by-value constructor from
    /// dyn-compatibility rules, so the trait itself can still be used as a trait object.
    fn new(peer: crate::RpcPeer) -> Self
    where
        Self: Sized;
}

#[derive(Clone)]
pub struct RpcMessageState {
    pub response_promise_store: ResponsePromiseStore,
    request_id_counter: sync::Arc<sync::atomic::AtomicU64>,
}

impl RpcMessageState {
    /// Creates fresh state: the request-id counter starts at 0 and the promise
    /// store is newly constructed.
    pub fn new() -> Self {
        let request_id_counter = sync::Arc::new(sync::atomic::AtomicU64::new(0));
        let response_promise_store = ResponsePromiseStore::new();
        Self {
            response_promise_store,
            request_id_counter,
        }
    }
}

impl Default for RpcMessageState {
    fn default() -> Self {
        Self::new()
    }
}

impl RpcMessageState {
    /// Allocates the next request id.
    ///
    /// Returns the counter value from *before* the increment, so the first call
    /// on fresh state yields 0. The increment is atomic, so concurrent callers
    /// (including clones sharing the counter) never receive the same id until
    /// the counter wraps around on overflow.
    pub fn next_request_id(&self) -> u64 {
        let counter = &self.request_id_counter;
        counter.fetch_add(1, sync::atomic::Ordering::SeqCst)
    }
}

#[derive(Clone)]
pub struct RpcPacketHandler<H: ReceiveRpcProtocol> {
    packet_handler: H,
    state: RpcMessageState,
}

/// Opaque `Debug` output: prints only the type name. Neither the handler `H`
/// nor `RpcMessageState` is required to implement `Debug`.
impl<H: ReceiveRpcProtocol> std::fmt::Debug for RpcPacketHandler<H> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("RpcPacketHandler")
    }
}

impl<H: ReceiveRpcProtocol> RpcPacketHandler<H> {
    pub fn new(packet_handler: H, state: RpcMessageState) -> Self {
        Self {
            packet_handler,
            state,
        }
    }
}

impl<H: ReceiveRpcProtocol + 'static> tokio_socket::HandlePacket for RpcPacketHandler<H> {
    type Packet = RpcPacket;

    /// Dispatches one decoded [`RpcPacket`] received from `peer`.
    ///
    /// * `Request`: the payload and protocol id go to the wrapped
    ///   [`ReceiveRpcProtocol`] handler. If it returns `Some(buf)`, `buf` is
    ///   written back to `peer` as a `Response` with the same packet id and
    ///   protocol id. If it returns `None`, nothing is sent.
    /// * `Response`: the resolver registered under the packet id in the shared
    ///   response promise store is fetched and resolved with the payload.
    ///
    /// # Errors
    /// Propagates errors from the request handler, from writing the response,
    /// from looking up the resolver, and from resolving it. NOTE(review): when the
    /// request handler fails, this code sends no reply to the remote caller.
    async fn on_packet(&self, peer: tokio_socket::SocketPeer, packet: Self::Packet) -> Result<()> {
        match packet.packet_type {
            RpcPacketType::Request => {
                tracing::trace!("recv request from {peer} protocol={} {packet:?}", packet.protocol_id);
                let response = self
                    .packet_handler
                    .handle_packet(packet.protocol_id, &peer, packet.buf)
                    .await?;
                match response {
                    Some(buf) => {
                        let reply = RpcPacket::new_response(packet.packet_id, packet.protocol_id, buf);
                        // Only log "send" when a response is actually written.
                        tracing::trace!("send response to {peer} {reply:?}");
                        tokio_socket::PacketWriter::write_packet(&peer, &reply).await
                    }
                    None => {
                        tracing::trace!(
                            "no response for request from {peer} protocol={}",
                            packet.protocol_id
                        );
                        Ok(())
                    }
                }
            }
            RpcPacketType::Response => {
                tracing::trace!("recv response from {peer} {packet:?}");
                let resolver = self
                    .state
                    .response_promise_store
                    .get_resolver(packet.packet_id)
                    .await?;
                tracing::trace!("found response resolver for request from {peer} {packet:?}");
                resolver.resolve(packet.buf).await
            }
        }
    }
}