//! embers-test-support 0.1.0
//!
//! Shared integration-test harnesses and helpers for Embers crates.
use std::path::Path;
use std::time::Duration;

use embers_core::{BufferId, MuxError, Result, SessionId, new_request_id};
use embers_protocol::{
    BufferRequest, ClientMessage, PingRequest, ProtocolClient, ScrollbackSliceResponse,
    ServerEnvelope, ServerEvent, ServerResponse, SessionRequest, SessionSnapshot, SnapshotResponse,
    SubscribeRequest, UnsubscribeRequest, VisibleSnapshotResponse,
};

#[derive(Debug)]
pub struct TestConnection {
    client: ProtocolClient,
}

impl TestConnection {
    pub async fn connect(path: impl AsRef<Path>) -> Result<Self> {
        let client = ProtocolClient::connect(path)
            .await
            .map_err(|error| MuxError::transport(error.to_string()))?;
        Ok(Self { client })
    }

    pub async fn send(&mut self, message: &ClientMessage) -> Result<()> {
        self.client
            .send(message)
            .await
            .map_err(|error| MuxError::transport(error.to_string()))
    }

    pub async fn recv(&mut self) -> Result<Option<ServerEnvelope>> {
        self.client
            .recv()
            .await
            .map_err(|error| MuxError::transport(error.to_string()))
    }

    pub async fn request(&mut self, message: &ClientMessage) -> Result<ServerResponse> {
        self.client
            .request(message)
            .await
            .map_err(|error| MuxError::transport(error.to_string()))
    }

    pub async fn recv_event(&mut self) -> Result<ServerEvent> {
        match self.recv().await? {
            Some(ServerEnvelope::Event(event)) => Ok(event),
            Some(ServerEnvelope::Response(response)) => Err(MuxError::protocol(format!(
                "expected event, got response: {response:?}"
            ))),
            None => Err(MuxError::transport(
                "connection closed while waiting for an event",
            )),
        }
    }

    pub async fn wait_for_event<F>(
        &mut self,
        timeout: Duration,
        mut predicate: F,
    ) -> Result<ServerEvent>
    where
        F: FnMut(&ServerEvent) -> bool,
    {
        tokio::time::timeout(timeout, async {
            loop {
                let event = self.recv_event().await?;
                if predicate(&event) {
                    return Ok(event);
                }
            }
        })
        .await
        .map_err(|_| MuxError::timeout(format!("timed out waiting for event after {timeout:?}")))?
    }

    pub async fn subscribe(&mut self, session_id: Option<SessionId>) -> Result<u64> {
        let response = self
            .request(&ClientMessage::Subscribe(SubscribeRequest {
                request_id: new_request_id(),
                session_id,
            }))
            .await?;

        match response {
            ServerResponse::SubscriptionAck(ack) => Ok(ack.subscription_id),
            ServerResponse::Error(error) => Err(error.error.into()),
            other => Err(MuxError::protocol(format!(
                "unexpected response to subscribe request: {other:?}"
            ))),
        }
    }

    pub async fn unsubscribe(&mut self, subscription_id: u64) -> Result<()> {
        let response = self
            .request(&ClientMessage::Unsubscribe(UnsubscribeRequest {
                request_id: new_request_id(),
                subscription_id,
            }))
            .await?;

        match response {
            ServerResponse::Ok(_) => Ok(()),
            ServerResponse::Error(error) => Err(error.error.into()),
            other => Err(MuxError::protocol(format!(
                "unexpected response to unsubscribe request: {other:?}"
            ))),
        }
    }

    pub async fn ping(&mut self, payload: impl Into<String>) -> Result<String> {
        let response = self
            .request(&ClientMessage::Ping(PingRequest {
                request_id: new_request_id(),
                payload: payload.into(),
            }))
            .await?;

        match response {
            ServerResponse::Pong(pong) => Ok(pong.payload),
            ServerResponse::Error(error) => Err(error.error.into()),
            other => Err(MuxError::protocol(format!(
                "unexpected response to ping request: {other:?}"
            ))),
        }
    }

    pub async fn session_snapshot(&mut self, session_id: SessionId) -> Result<SessionSnapshot> {
        let response = self
            .request(&ClientMessage::Session(SessionRequest::Get {
                request_id: new_request_id(),
                session_id,
            }))
            .await?;

        match response {
            ServerResponse::SessionSnapshot(response) => Ok(response.snapshot),
            ServerResponse::Error(error) => Err(error.error.into()),
            other => Err(MuxError::protocol(format!(
                "unexpected response to session snapshot request: {other:?}"
            ))),
        }
    }

    pub async fn capture_buffer(&mut self, buffer_id: BufferId) -> Result<SnapshotResponse> {
        let response = self
            .request(&ClientMessage::Buffer(BufferRequest::Capture {
                request_id: new_request_id(),
                buffer_id,
            }))
            .await?;

        match response {
            ServerResponse::Snapshot(snapshot) => Ok(snapshot),
            ServerResponse::Error(error) => Err(error.error.into()),
            other => Err(MuxError::protocol(format!(
                "unexpected response to capture request: {other:?}"
            ))),
        }
    }

    pub async fn capture_visible_buffer(
        &mut self,
        buffer_id: BufferId,
    ) -> Result<VisibleSnapshotResponse> {
        let response = self
            .request(&ClientMessage::Buffer(BufferRequest::CaptureVisible {
                request_id: new_request_id(),
                buffer_id,
            }))
            .await?;

        match response {
            ServerResponse::VisibleSnapshot(snapshot) => Ok(snapshot),
            ServerResponse::Error(error) => Err(error.error.into()),
            other => Err(MuxError::protocol(format!(
                "unexpected response to visible capture request: {other:?}"
            ))),
        }
    }

    pub async fn capture_scrollback_slice(
        &mut self,
        buffer_id: BufferId,
        start_line: u64,
        line_count: u32,
    ) -> Result<ScrollbackSliceResponse> {
        let response = self
            .request(&ClientMessage::Buffer(BufferRequest::ScrollbackSlice {
                request_id: new_request_id(),
                buffer_id,
                start_line,
                line_count,
            }))
            .await?;

        match response {
            ServerResponse::ScrollbackSlice(snapshot) => Ok(snapshot),
            ServerResponse::Error(error) => Err(error.error.into()),
            other => Err(MuxError::protocol(format!(
                "unexpected response to scrollback slice request: {other:?}"
            ))),
        }
    }

    pub async fn wait_for_capture_contains(
        &mut self,
        buffer_id: BufferId,
        needle: &str,
        timeout: Duration,
    ) -> Result<SnapshotResponse> {
        let deadline = tokio::time::Instant::now() + timeout;

        loop {
            let snapshot = self.capture_buffer(buffer_id).await?;
            let capture = snapshot.lines.join("\n");
            if capture.contains(needle) {
                return Ok(snapshot);
            }

            if tokio::time::Instant::now() >= deadline {
                return Err(MuxError::timeout(format!(
                    "timed out waiting for buffer {buffer_id} to contain {needle:?}; last capture: {:?}",
                    capture
                )));
            }

            tokio::time::sleep(Duration::from_millis(50)).await;
        }
    }
}