embers-protocol 0.1.0

FlatBuffers protocol, framing, and client transport primitives for Embers.
use std::path::Path;

use embers_core::RequestId;
use tokio::net::UnixStream;
use tokio::net::unix::{OwnedReadHalf, OwnedWriteHalf};
use tokio::sync::mpsc;
use tokio::task::AbortHandle;

use crate::codec::{ProtocolError, decode_server_envelope, encode_client_message};
use crate::framing::{FrameType, RawFrame, read_frame, write_frame};
use crate::types::{ClientMessage, ServerEnvelope, ServerResponse};

type ReaderItem = Result<Option<ServerEnvelope>, ProtocolError>;

#[derive(Debug)]
pub struct ProtocolClient {
    writer: OwnedWriteHalf,
    reader_rx: mpsc::Receiver<ReaderItem>,
    reader_reached_eof: bool,
    reader_abort_handle: AbortHandle,
}

impl ProtocolClient {
    const READER_CHANNEL_CAPACITY: usize = 64;

    pub async fn connect(path: impl AsRef<Path>) -> Result<Self, ProtocolError> {
        let stream = UnixStream::connect(path).await?;
        Ok(Self::from_stream(stream))
    }

    fn from_stream(stream: UnixStream) -> Self {
        let (reader, writer) = stream.into_split();
        let (reader_tx, reader_rx) = mpsc::channel(Self::READER_CHANNEL_CAPACITY);
        let reader_task = tokio::spawn(async move {
            Self::run_reader(reader, reader_tx).await;
        });
        let reader_abort_handle = reader_task.abort_handle();

        Self {
            writer,
            reader_rx,
            reader_reached_eof: false,
            reader_abort_handle,
        }
    }

    async fn run_reader(mut reader: OwnedReadHalf, reader_tx: mpsc::Sender<ReaderItem>) {
        loop {
            let next = match read_frame(&mut reader).await {
                Ok(Some(frame)) => Self::decode_frame(frame).map(Some),
                Ok(None) => Ok(None),
                Err(error) => Err(error),
            };
            let terminal = !matches!(next, Ok(Some(_)));
            if reader_tx.send(next).await.is_err() || terminal {
                break;
            }
        }
    }

    fn decode_frame(frame: RawFrame) -> Result<ServerEnvelope, ProtocolError> {
        let envelope = decode_server_envelope(&frame.payload)?;

        match (frame.frame_type, envelope) {
            (FrameType::Response, ServerEnvelope::Response(response)) => {
                let response_id = response.request_id().unwrap_or(RequestId(0));
                if response_id != frame.request_id {
                    return Err(ProtocolError::MismatchedRequestId {
                        expected: frame.request_id,
                        actual: response_id,
                    });
                }
                Ok(ServerEnvelope::Response(response))
            }
            (FrameType::Event, ServerEnvelope::Event(event)) => {
                if frame.request_id != RequestId(0) {
                    return Err(ProtocolError::MismatchedRequestId {
                        expected: RequestId(0),
                        actual: frame.request_id,
                    });
                }
                Ok(ServerEnvelope::Event(event))
            }
            (FrameType::Response, ServerEnvelope::Event(_)) => {
                Err(ProtocolError::UnexpectedFrameKind {
                    frame_type: FrameType::Response,
                    envelope_kind: "event",
                })
            }
            (FrameType::Event, ServerEnvelope::Response(_)) => {
                Err(ProtocolError::UnexpectedFrameKind {
                    frame_type: FrameType::Event,
                    envelope_kind: "response",
                })
            }
            (FrameType::Request, _) => Err(ProtocolError::UnexpectedFrameType(FrameType::Request)),
        }
    }

    pub async fn send(&mut self, message: &ClientMessage) -> Result<(), ProtocolError> {
        let payload = encode_client_message(message)?;
        let frame = RawFrame::new(FrameType::Request, message.request_id(), payload);
        write_frame(&mut self.writer, &frame).await
    }

    pub async fn recv(&mut self) -> Result<Option<ServerEnvelope>, ProtocolError> {
        match self.reader_rx.recv().await {
            Some(Ok(None)) => {
                self.reader_reached_eof = true;
                Ok(None)
            }
            Some(result) => result,
            None if self.reader_reached_eof => Ok(None),
            None => Err(ProtocolError::ReaderTaskExited),
        }
    }

    pub async fn request(
        &mut self,
        message: &ClientMessage,
    ) -> Result<ServerResponse, ProtocolError> {
        let request_id = message.request_id();
        self.send(message).await?;

        loop {
            match self.recv().await? {
                Some(ServerEnvelope::Response(response)) => match response.request_id() {
                    Some(response_id) if response_id != request_id => {
                        return Err(ProtocolError::MismatchedRequestId {
                            expected: request_id,
                            actual: response_id,
                        });
                    }
                    _ => {
                        return Ok(response);
                    }
                },
                Some(ServerEnvelope::Event(_)) => continue,
                None => {
                    return Err(ProtocolError::InvalidMessage(
                        "connection closed before response",
                    ));
                }
            }
        }
    }
}

impl Drop for ProtocolClient {
    fn drop(&mut self) {
        self.reader_abort_handle.abort();
    }
}

#[cfg(test)]
impl ProtocolClient {
    fn abort_reader_task(&self) {
        self.reader_abort_handle.abort();
    }

    fn drain_recv_buffer(&mut self) {
        while let Ok(item) = self.reader_rx.try_recv() {
            if matches!(item, Ok(None)) {
                self.reader_reached_eof = true;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::ProtocolClient;
    use embers_core::{ErrorCode, RequestId, SessionId, WireError};
    use tokio::io::AsyncWriteExt;
    use tokio::net::UnixStream;
    use tokio::time::{Duration, timeout};

    use crate::codec::{ProtocolError, encode_server_envelope};
    use crate::framing::{FrameType, RawFrame, read_frame, write_frame};
    use crate::types::{
        ClientMessage, ErrorResponse, PingRequest, ServerEnvelope, ServerEvent, ServerResponse,
        SessionClosedEvent,
    };

    #[tokio::test]
    async fn request_accepts_unscoped_error_response() {
        let (mut server, client_stream) = UnixStream::pair().expect("create unix stream pair");
        let mut client = ProtocolClient::from_stream(client_stream);

        let request = ClientMessage::Ping(PingRequest {
            request_id: RequestId(7),
            payload: "phase2".to_owned(),
        });

        let server_task = tokio::spawn(async move {
            let frame = read_frame(&mut server)
                .await
                .expect("read request frame")
                .expect("request frame");
            assert_eq!(frame.frame_type, FrameType::Request);
            assert_eq!(frame.request_id, RequestId(7));

            let payload = encode_server_envelope(&ServerEnvelope::Response(ServerResponse::Error(
                ErrorResponse {
                    request_id: None,
                    error: WireError::new(ErrorCode::ProtocolViolation, "bad request"),
                },
            )))
            .expect("encode error response");
            let frame = RawFrame::new(FrameType::Response, RequestId(0), payload);
            write_frame(&mut server, &frame)
                .await
                .expect("write response frame");
        });

        let response = client.request(&request).await.expect("receive response");
        match response {
            ServerResponse::Error(response) => {
                assert_eq!(response.request_id, None);
                assert_eq!(response.error.code, ErrorCode::ProtocolViolation);
                assert_eq!(response.error.message, "bad request");
            }
            other => panic!("expected error response, got {other:?}"),
        }

        server_task.await.expect("server task joins");
    }

    #[tokio::test]
    async fn recv_timeout_does_not_cancel_in_progress_frame_read() {
        let (mut server, client_stream) = UnixStream::pair().expect("create unix stream pair");
        let mut client = ProtocolClient::from_stream(client_stream);

        let payload = encode_server_envelope(&ServerEnvelope::Event(ServerEvent::SessionClosed(
            SessionClosedEvent {
                session_id: SessionId(9),
            },
        )))
        .expect("encode event");
        let mut frame_bytes = Vec::with_capacity(13 + payload.len());
        frame_bytes.extend_from_slice(&(payload.len() as u32).to_le_bytes());
        frame_bytes.push(FrameType::Event as u8);
        frame_bytes.extend_from_slice(&0_u64.to_le_bytes());
        frame_bytes.extend_from_slice(&payload);

        server
            .write_all(&frame_bytes[..5])
            .await
            .expect("write partial frame");

        let timed_out = timeout(Duration::from_millis(20), client.recv()).await;
        assert!(timed_out.is_err(), "partial frame should keep recv pending");

        server
            .write_all(&frame_bytes[5..])
            .await
            .expect("write remainder");

        let envelope = timeout(Duration::from_secs(1), client.recv())
            .await
            .expect("recv finishes after remainder arrives")
            .expect("recv succeeds")
            .expect("connection remains open");
        assert!(matches!(
            envelope,
            ServerEnvelope::Event(ServerEvent::SessionClosed(SessionClosedEvent {
                session_id
            })) if session_id == SessionId(9)
        ));
    }

    #[tokio::test]
    async fn recv_reports_reader_task_exit_when_channel_closes_without_eof() {
        let (_server, client_stream) = UnixStream::pair().expect("create unix stream pair");
        let mut client = ProtocolClient::from_stream(client_stream);

        client.abort_reader_task();
        client.drain_recv_buffer();

        let error = timeout(Duration::from_secs(1), client.recv())
            .await
            .expect("recv returns after reader abort")
            .expect_err("closed reader channel should error");
        assert!(matches!(error, ProtocolError::ReaderTaskExited));
    }
}