// tarpc-cat 0.1.0
//
// RPC framework built on comp-cat-rs: typed effects, no async, categorical foundations.
// Documentation:
//! RPC server that accepts connections and dispatches to a [`Serve`] handler.
//!
//! The server binds a TCP listener and spawns a thread per connection.
//! Each connection processes requests sequentially: read envelope,
//! dispatch to the service handler, write response envelope, repeat.
//!
//! The entire accept loop runs inside a single [`Io::suspend`] to
//! avoid stack growth from recursive [`flat_map`] chains.
//!
//! [`Io::suspend`]: comp_cat_rs::effect::io::Io::suspend
//! [`flat_map`]: comp_cat_rs::effect::io::Io::flat_map

use std::io::{BufReader, BufWriter};
use std::net::{SocketAddr, TcpListener, TcpStream};

use comp_cat_rs::effect::io::Io;

use crate::codec;
use crate::error::Error;
use crate::protocol::{Envelope, RequestId};
use crate::serve::Serve;

/// Socket address on which the server accepts incoming connections.
///
/// A thin newtype over [`SocketAddr`]; it is `Copy`, so it can be passed
/// around by value.
#[derive(Debug, Clone, Copy)]
pub struct ListenAddr(SocketAddr);

impl ListenAddr {
    /// Wrap a [`SocketAddr`] as a listen address.
    #[must_use]
    pub fn new(addr: SocketAddr) -> Self {
        Self(addr)
    }

    /// Return the wrapped socket address.
    #[must_use]
    pub fn addr(self) -> SocketAddr {
        let Self(inner) = self;
        inner
    }
}

/// Start an RPC server.
///
/// Binds to `addr` and serves incoming connections using `service`.
/// Each connection gets its own thread.  This [`Io`], when run,
/// blocks indefinitely accepting connections.
///
/// # Errors
///
/// Returns [`Error::Io`] if binding or accepting fails.
///
/// # Examples
///
/// ```rust,ignore
/// use tarpc_cat::server::{serve, ListenAddr};
///
/// let addr = ListenAddr::new("127.0.0.1:9000".parse().unwrap());
/// serve(addr, my_service).run()?;
/// ```
#[must_use]
pub fn serve<S: Serve>(addr: ListenAddr, service: S) -> Io<Error, core::convert::Infallible> {
    Io::suspend(move || {
        let listener = TcpListener::bind(addr.addr())?;
        listener
            .incoming()
            .try_for_each(|stream_result| {
                let stream = stream_result?;
                let svc = service.clone();
                std::thread::Builder::new()
                    .spawn(move || {
                        let _: Result<(), Error> = handle_connection(stream, &svc);
                    })
                    .map(|_handle| ())
            })
            .map_err(Error::from)
            .and(Err(Error::ConnectionClosed))
    })
}

/// Process requests on a single connection until it closes.
fn handle_connection<S: Serve>(stream: TcpStream, service: &S) -> Result<(), Error> {
    let read_stream = stream.try_clone()?;
    let mut reader = BufReader::new(read_stream);
    let mut writer = BufWriter::new(stream);

    std::iter::from_fn(|| Some(process_one_request(&mut reader, &mut writer, service)))
        .try_for_each(|result| match result {
            Ok(()) => Ok(()),
            Err(Error::ConnectionClosed) => Err(Error::ConnectionClosed),
            Err(e) => Err(e),
        })
        .or_else(|e| match e {
            Error::ConnectionClosed => Ok(()),
            other => Err(other),
        })
}

/// Read one request, dispatch it to the service, and write the response.
///
/// Decodes a single [`Envelope`] from `reader`.  A request envelope is
/// passed to `deserialize_and_handle`; the envelope it returns is encoded
/// to `writer`, and `writer` is then flushed.
///
/// # Errors
///
/// Propagates any error from decoding, encoding, or flushing.  Returns
/// [`Error::Server`] if the peer sends a response or error envelope
/// instead of a request.
fn process_one_request<S: Serve>(
    reader: &mut impl std::io::Read,
    writer: &mut impl std::io::Write,
    service: &S,
) -> Result<(), Error> {
    let envelope: Envelope = codec::decode(reader)?;
    match envelope {
        Envelope::Request { id, payload } => {
            let response_envelope = deserialize_and_handle(id, &payload, service)?;
            // Reborrow so `writer` is still usable for the flush below.
            codec::encode(&mut *writer, &response_envelope)?;
            // The caller may hand us a buffered writer, and the next step is
            // to block reading the following request.  Flush explicitly so
            // the response cannot be left sitting in a buffer.
            // NOTE(review): `codec::encode` may already flush; this is then a no-op.
            writer.flush()?;
            Ok(())
        }
        Envelope::Response { .. } | Envelope::Error { .. } => Err(Error::Server {
            message: "unexpected non-request envelope from client".to_owned(),
        }),
    }
}

/// Deserialize the request payload, run the handler, and build the response envelope.
fn deserialize_and_handle<S: Serve>(
    id: RequestId,
    payload: &str,
    service: &S,
) -> Result<Envelope, Error> {
    serde_json::from_str::<S::Request>(payload)
        .map_err(Error::from_deserialize)
        .and_then(|request| service.handle(request).run())
        .and_then(|response| {
            serde_json::to_string(&response)
                .map_err(Error::from_serialize)
                .map(|resp_payload| Envelope::Response {
                    id,
                    payload: resp_payload,
                })
        })
        .or_else(|e| {
            Ok(Envelope::Error {
                id,
                message: e.to_string(),
            })
        })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::client;
    use crate::serve::Serve;
    use serde::{Deserialize, Serialize};

    /// Stateless test service that returns the request message unchanged.
    #[derive(Clone)]
    struct EchoService;

    /// Request carrying the text to echo.
    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct EchoRequest {
        message: String,
    }

    /// Response carrying the echoed text.
    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct EchoResponse {
        echo: String,
    }

    impl Serve for EchoService {
        type Request = EchoRequest;
        type Response = EchoResponse;

        fn handle(&self, request: EchoRequest) -> Io<Error, EchoResponse> {
            Io::pure(EchoResponse {
                echo: request.message,
            })
        }
    }

    /// Poll `addr` until a TCP connection succeeds.
    ///
    /// Makes a bounded number of connection attempts, pausing between
    /// failures, so the test waits exactly as long as the server needs to
    /// come up instead of relying on a fixed sleep.  The probe connection
    /// is dropped as soon as it is established.
    fn wait_until_listening(addr: SocketAddr) -> Result<(), Error> {
        const MAX_ATTEMPTS: u32 = 100;
        let pause = std::time::Duration::from_millis(20);

        let ready = (0..MAX_ATTEMPTS).any(|_| {
            // A per-attempt timeout keeps each probe short even on platforms
            // where a refused connection is slow to report.
            let connected = TcpStream::connect_timeout(&addr, pause).is_ok();
            if !connected {
                std::thread::sleep(pause);
            }
            connected
        });

        if ready {
            Ok(())
        } else {
            Err(Error::Server {
                message: "test server did not start listening in time".to_owned(),
            })
        }
    }

    #[test]
    fn end_to_end_echo() -> Result<(), Error> {
        // Ask the OS for a free port, then release it so the server can bind it.
        let listener = TcpListener::bind("127.0.0.1:0")?;
        let addr = listener.local_addr()?;
        drop(listener);

        let listen_addr = ListenAddr::new(addr);
        let server_addr = client::ServerAddr::new(addr);

        let server_handle = std::thread::spawn(move || {
            let _: Result<core::convert::Infallible, Error> = serve(listen_addr, EchoService).run();
        });

        // Wait for the server to accept connections before issuing calls.
        wait_until_listening(addr)?;

        let request = EchoRequest {
            message: "hello".to_owned(),
        };
        let response: EchoResponse = client::call(server_addr, request).run()?;
        assert_eq!(response.echo, "hello");

        // Second call to verify the server handles multiple connections.
        let request2 = EchoRequest {
            message: "world".to_owned(),
        };
        let response2: EchoResponse = client::call(server_addr, request2).run()?;
        assert_eq!(response2.echo, "world");

        // Dropping the join handle detaches the server thread; it is never joined.
        drop(server_handle);
        Ok(())
    }
}