embers-client 0.1.0

Client rendering, input handling, configuration, and scripting support for Embers.
use std::cell::Cell;
use std::collections::BTreeSet;
use std::path::Path;

use embers_core::{BufferId, IdAllocator, MuxError, RequestId, Result, SessionId};
use embers_protocol::{
    BufferRequest, ClientMessage, ClientRecord, ClientRequest, ScrollbackSliceResponse,
    ServerEvent, ServerResponse, SessionRequest, SnapshotResponse, SubscribeRequest,
};

use crate::socket_transport::SocketTransport;
use crate::state::ClientState;
use crate::transport::Transport;

/// Protocol client generic over its [`Transport`].
///
/// Sends requests, receives [`ServerEvent`]s, and keeps a local [`ClientState`]
/// mirror that is updated from events and from explicit resync requests.
#[derive(Debug)]
pub struct MuxClient<T> {
    /// Used both for request/response round trips and for reading server events.
    transport: T,
    /// Source of the request ids attached to every outgoing request.
    request_ids: IdAllocator<RequestId>,
    /// Cached id of this client, filled in whenever a client-record response arrives.
    /// Wrapped in a `Cell` so `&self` methods can update it.
    client_id: Cell<Option<u64>>,
    /// Local mirror of server state (sessions, buffers) maintained by this client.
    state: ClientState,
}

impl<T> MuxClient<T>
where
    T: Transport,
{
    /// Creates a client over `transport` with empty local state.
    ///
    /// The request-id allocator is seeded with `1`. No client id is cached yet; it is
    /// learned from the first client-record response (see [`Self::current_client`]).
    pub fn new(transport: T) -> Self {
        Self {
            transport,
            request_ids: IdAllocator::new(1),
            client_id: Cell::new(None),
            state: ClientState::default(),
        }
    }

    /// Returns the next request id from this client's allocator.
    pub fn next_request_id(&self) -> RequestId {
        self.request_ids.next()
    }

    /// Returns the locally mirrored client state.
    pub fn state(&self) -> &ClientState {
        &self.state
    }

    /// Returns mutable access to the locally mirrored client state.
    pub fn state_mut(&mut self) -> &mut ClientState {
        &mut self.state
    }

    /// Returns the underlying transport.
    pub fn transport(&self) -> &T {
        &self.transport
    }

    /// Sends `message` and waits for its response.
    ///
    /// # Errors
    /// Returns the transport error if the round trip fails, or the server-reported
    /// error when the response is `ServerResponse::Error`.
    pub async fn request_message(&self, message: ClientMessage) -> Result<ServerResponse> {
        let response = self.transport.request(message).await?;
        expect_response(response)
    }

    /// Subscribes to server events, optionally scoped to `session_id`.
    ///
    /// Returns the subscription id from the server's acknowledgement.
    ///
    /// # Errors
    /// Fails on transport/server errors, or with a protocol error if the server
    /// answers with anything other than a subscription ack.
    pub async fn subscribe(&self, session_id: Option<SessionId>) -> Result<u64> {
        let response = self
            .request_message(ClientMessage::Subscribe(SubscribeRequest {
                request_id: self.next_request_id(),
                session_id,
            }))
            .await?;
        match response {
            ServerResponse::SubscriptionAck(response) => Ok(response.subscription_id),
            other => Err(MuxError::protocol(format!(
                "expected subscription ack response, got {other:?}"
            ))),
        }
    }

    /// Fetches this client's record from the server and caches its id.
    ///
    /// # Errors
    /// Fails on transport/server errors, or with a protocol error if the response
    /// is not a client record.
    pub async fn current_client(&self) -> Result<ClientRecord> {
        let response = self
            .request_message(ClientMessage::Client(ClientRequest::Get {
                request_id: self.next_request_id(),
                client_id: None,
            }))
            .await?;
        self.remember_client(response)
    }

    /// Switches this client to `session_id`, returning the updated client record
    /// and caching its id.
    ///
    /// # Errors
    /// Fails on transport/server errors, or with a protocol error if the response
    /// is not a client record.
    pub async fn switch_current_session(&self, session_id: SessionId) -> Result<ClientRecord> {
        let response = self
            .request_message(ClientMessage::Client(ClientRequest::Switch {
                request_id: self.next_request_id(),
                client_id: None,
                session_id,
            }))
            .await?;
        self.remember_client(response)
    }

    /// Unwraps a `ServerResponse::Client`, caching the contained client id.
    fn remember_client(&self, response: ServerResponse) -> Result<ClientRecord> {
        match response {
            ServerResponse::Client(response) => {
                self.client_id.set(Some(response.client.id));
                Ok(response.client)
            }
            other => Err(MuxError::protocol(format!(
                "expected client response, got {other:?}"
            ))),
        }
    }

    /// Waits for the next server event, applies it via [`Self::handle_event`], and
    /// returns it.
    pub async fn process_next_event(&mut self) -> Result<ServerEvent> {
        let event = self.next_event().await?;
        self.handle_event(&event).await?;
        Ok(event)
    }

    /// Returns this client's id, asking the server for it when not yet cached.
    async fn own_client_id(&self) -> Result<u64> {
        match self.client_id.get() {
            Some(client_id) => Ok(client_id),
            None => Ok(self.current_client().await?.id),
        }
    }

    /// Like [`Self::process_next_event`], but returns `Ok(None)` if no event arrives
    /// within `timeout`.
    ///
    /// NOTE(review): on timeout the pending `next_event` future is dropped; whether a
    /// partially received event can be lost depends on the transport's cancellation
    /// safety — confirm for each `Transport` implementation.
    pub async fn process_next_event_timeout(
        &mut self,
        timeout: std::time::Duration,
    ) -> Result<Option<ServerEvent>> {
        let event = match tokio::time::timeout(timeout, self.next_event()).await {
            Ok(result) => result?,
            Err(_) => return Ok(None),
        };
        self.handle_event(&event).await?;
        Ok(Some(event))
    }

    /// Reads the next event from the transport without applying it.
    pub async fn next_event(&mut self) -> Result<ServerEvent> {
        self.transport.next_event().await
    }

    /// Applies `event` to local state, then performs any follow-up resync requests
    /// the event calls for.
    pub async fn handle_event(&mut self, event: &ServerEvent) -> Result<()> {
        self.state.apply_event(event);
        self.resync_for_event(event).await
    }

    /// Fetches a snapshot of `session_id` and applies it to local state.
    ///
    /// # Errors
    /// Fails on transport/server errors, or with a protocol error if the response
    /// is not a session snapshot.
    pub async fn resync_session(&mut self, session_id: SessionId) -> Result<()> {
        // The `&mut self` methods call the transport directly instead of
        // `request_message`: that keeps only `&T` borrowed across the await. Borrowing
        // `&Self` there would make the future `!Send`, because `&Self` is not `Send`
        // while `Self` holds a `Cell`.
        let response = self
            .transport
            .request(ClientMessage::Session(SessionRequest::Get {
                request_id: self.next_request_id(),
                session_id,
            }))
            .await?;

        match expect_response(response)? {
            ServerResponse::SessionSnapshot(response) => {
                self.state.apply_session_snapshot(response.snapshot);
                Ok(())
            }
            other => Err(MuxError::protocol(format!(
                "expected session snapshot response, got {other:?}"
            ))),
        }
    }

    /// Requests the visible snapshot of `buffer_id` and applies it to local state.
    ///
    /// # Errors
    /// Fails on transport/server errors, or with a protocol error if the response
    /// is not a visible snapshot.
    pub async fn refresh_buffer_snapshot(&mut self, buffer_id: BufferId) -> Result<()> {
        let response = self
            .transport
            .request(ClientMessage::Buffer(BufferRequest::CaptureVisible {
                request_id: self.next_request_id(),
                buffer_id,
            }))
            .await?;

        match expect_response(response)? {
            ServerResponse::VisibleSnapshot(snapshot) => {
                self.state.apply_buffer_snapshot(snapshot);
                Ok(())
            }
            other => Err(MuxError::protocol(format!(
                "expected visible snapshot response, got {other:?}"
            ))),
        }
    }

    /// Fetches the record for `buffer_id` and applies it to local state.
    ///
    /// # Errors
    /// Fails on transport/server errors, or with a protocol error if the response
    /// is not a buffer record.
    pub async fn refresh_buffer_record(&mut self, buffer_id: BufferId) -> Result<()> {
        let response = self
            .transport
            .request(ClientMessage::Buffer(BufferRequest::Get {
                request_id: self.next_request_id(),
                buffer_id,
            }))
            .await?;

        match expect_response(response)? {
            ServerResponse::Buffer(response) => {
                self.state.apply_buffer_record(response.buffer);
                Ok(())
            }
            other => Err(MuxError::protocol(format!(
                "expected buffer response, got {other:?}"
            ))),
        }
    }

    /// Requests a full capture of `buffer_id` and returns it without touching local
    /// state.
    ///
    /// # Errors
    /// Fails on transport/server errors, or with a protocol error if the response
    /// is not a snapshot.
    pub async fn capture_buffer(&self, buffer_id: BufferId) -> Result<SnapshotResponse> {
        let response = self
            .transport
            .request(ClientMessage::Buffer(BufferRequest::Capture {
                request_id: self.next_request_id(),
                buffer_id,
            }))
            .await?;

        match expect_response(response)? {
            ServerResponse::Snapshot(snapshot) => Ok(snapshot),
            other => Err(MuxError::protocol(format!(
                "expected snapshot response, got {other:?}"
            ))),
        }
    }

    /// Requests `line_count` scrollback lines of `buffer_id` starting at `start_line`,
    /// without touching local state.
    ///
    /// # Errors
    /// Fails on transport/server errors, or with a protocol error if the response
    /// is not a scrollback slice.
    pub async fn capture_scrollback_slice(
        &self,
        buffer_id: BufferId,
        start_line: u64,
        line_count: u32,
    ) -> Result<ScrollbackSliceResponse> {
        let response = self
            .transport
            .request(ClientMessage::Buffer(BufferRequest::ScrollbackSlice {
                request_id: self.next_request_id(),
                buffer_id,
                start_line,
                line_count,
            }))
            .await?;

        match expect_response(response)? {
            ServerResponse::ScrollbackSlice(snapshot) => Ok(snapshot),
            other => Err(MuxError::protocol(format!(
                "expected scrollback slice response, got {other:?}"
            ))),
        }
    }

    /// Resyncs every session currently listed in `state.dirty_sessions`.
    ///
    /// The ids are copied out first because each resync mutably borrows `self`.
    /// Stops at the first failing session.
    pub async fn resync_dirty_sessions(&mut self) -> Result<()> {
        let session_ids = self
            .state
            .dirty_sessions
            .iter()
            .copied()
            .collect::<Vec<_>>();
        for session_id in session_ids {
            self.resync_session(session_id).await?;
        }
        Ok(())
    }

    /// Rebuilds local session state from the server's session list.
    ///
    /// The steps are:
    /// 1. Remove locally known sessions the server no longer reports.
    /// 2. Mark each live session dirty and resync it.
    /// 3. Refresh the detached-buffer list.
    pub async fn resync_all_sessions(&mut self) -> Result<()> {
        let response = self
            .transport
            .request(ClientMessage::Session(SessionRequest::List {
                request_id: self.next_request_id(),
            }))
            .await?;

        let sessions = match expect_response(response)? {
            ServerResponse::Sessions(response) => response.sessions,
            other => {
                return Err(MuxError::protocol(format!(
                    "expected sessions response, got {other:?}"
                )));
            }
        };

        let live_sessions = sessions
            .iter()
            .map(|session| session.id)
            .collect::<BTreeSet<_>>();
        // Collected up front: removing sessions needs `&mut self.state`.
        let stale_sessions = self
            .state
            .sessions
            .keys()
            .copied()
            .filter(|session_id| !live_sessions.contains(session_id))
            .collect::<Vec<_>>();
        for session_id in stale_sessions {
            self.state.remove_session(session_id);
        }

        for session in sessions {
            self.state.dirty_sessions.insert(session.id);
            self.resync_session(session.id).await?;
        }

        self.resync_detached_buffers().await
    }

    /// Issues the follow-up requests needed after `event` has already been applied
    /// to local state by [`Self::handle_event`].
    async fn resync_for_event(&mut self, event: &ServerEvent) -> Result<()> {
        match event {
            ServerEvent::SessionCreated(event) => self.resync_session(event.session.id).await,
            ServerEvent::NodeChanged(event) => {
                self.resync_session(event.session_id).await?;
                self.resync_detached_buffers().await
            }
            ServerEvent::FloatingChanged(event) => {
                self.resync_session(event.session_id).await?;
                self.resync_detached_buffers().await
            }
            ServerEvent::SessionClosed(_) => self.resync_detached_buffers().await,
            // `handle_event` has already applied the rename to local state, so only the
            // session snapshot needs refreshing here.
            ServerEvent::SessionRenamed(event) => self.resync_session(event.session_id).await,
            ServerEvent::ClientChanged(event) => {
                // Only changes to this client's own record require a resync.
                if event.client.id != self.own_client_id().await? {
                    return Ok(());
                }
                if let Some(session_id) = event.client.current_session_id {
                    self.resync_session(session_id).await?;
                }
                Ok(())
            }
            ServerEvent::RenderInvalidated(event) => {
                self.refresh_buffer_record(event.buffer_id).await
            }
            ServerEvent::BufferPipeChanged(_) => Ok(()),
            ServerEvent::BufferCreated(_)
            | ServerEvent::BufferDetached(_)
            | ServerEvent::FocusChanged(_) => Ok(()),
        }
    }

    /// Fetches all detached buffers (across every session) and applies them to local
    /// state.
    async fn resync_detached_buffers(&mut self) -> Result<()> {
        let response = self
            .transport
            .request(ClientMessage::Buffer(BufferRequest::List {
                request_id: self.next_request_id(),
                session_id: None,
                attached_only: false,
                detached_only: true,
            }))
            .await?;

        match expect_response(response)? {
            ServerResponse::Buffers(response) => {
                self.state.apply_detached_buffers(response.buffers);
                Ok(())
            }
            other => Err(MuxError::protocol(format!(
                "expected buffers response, got {other:?}"
            ))),
        }
    }
}

impl MuxClient<SocketTransport> {
    /// Opens a socket connection at `path` and builds a client on top of it.
    ///
    /// # Errors
    /// Propagates any error returned by [`SocketTransport::connect`].
    pub async fn connect(path: impl AsRef<Path>) -> Result<Self> {
        Ok(Self::new(SocketTransport::connect(path).await?))
    }
}

fn expect_response(response: ServerResponse) -> Result<ServerResponse> {
    match response {
        ServerResponse::Error(error) => Err(error.error.into()),
        other => Ok(other),
    }
}