// pondsocket-client 0.1.3
//
// Async WebSocket client for PondSocket.
// Documentation
use std::marker::PhantomData;
use std::time::Duration;

use pondsocket_common::{
    ChannelEvent, PondEvent, PondSchema, PresenceMessage, ServerMessage, from_pond_map,
    from_pond_value, to_pond_map,
};
use tokio::sync::broadcast;

use crate::{Channel, ClientError, PondClient};

/// A [`Channel`] handle whose payloads and presence entries are typed by the schema `S`.
///
/// `S` is purely a compile-time marker: no value of `S` is stored, only a zero-sized
/// [`PhantomData`].
pub struct TypedChannel<S> {
    /// Untyped channel wrapped by this handle.
    raw: Channel,
    /// Ties the handle to its schema without storing a value of `S`.
    _schema: PhantomData<S>,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive adds an `S: Clone` bound,
// which would make the handle non-cloneable for marker schemas that do not implement
// `Clone`, even though only the wrapped channel is actually copied.
impl<S> Clone for TypedChannel<S> {
    fn clone(&self) -> Self {
        Self {
            raw: self.raw.clone(),
            _schema: PhantomData,
        }
    }
}

impl<S> TypedChannel<S>
where
    S: PondSchema,
{
    /// Wraps an untyped [`Channel`] in a handle typed by the schema `S`.
    pub fn new(raw: Channel) -> Self {
        Self {
            _schema: PhantomData,
            raw,
        }
    }

    /// Borrows the wrapped untyped channel.
    pub fn raw(&self) -> &Channel {
        let Self { raw, .. } = self;
        raw
    }

    /// Returns the name reported by the wrapped channel.
    pub fn name(&self) -> &str {
        let Self { raw, .. } = self;
        raw.name()
    }

    /// Returns the state reported by the wrapped channel.
    pub fn state(&self) -> pondsocket_common::ChannelState {
        let Self { raw, .. } = self;
        raw.state()
    }

    /// Returns the wrapped channel's receiver of untyped [`ChannelEvent`]s.
    pub fn subscribe_events(&self) -> broadcast::Receiver<ChannelEvent> {
        let Self { raw, .. } = self;
        raw.subscribe_events()
    }

    /// Returns the wrapped channel's watch receiver of state values.
    pub fn subscribe_state(&self) -> tokio::sync::watch::Receiver<pondsocket_common::ChannelState> {
        let Self { raw, .. } = self;
        raw.subscribe_state()
    }

    /// Fetches the wrapped channel's presence entries and decodes each as `S::Presence`.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::Serialization`] if an entry cannot be decoded.
    pub async fn presence(&self) -> Result<Vec<S::Presence>, ClientError> {
        let entries = self.raw.presence().await;
        let decoded: serde_json::Result<Vec<S::Presence>> =
            entries.into_iter().map(from_pond_map).collect();
        decoded.map_err(ClientError::Serialization)
    }

    /// Awaits the wrapped channel's `join`.
    pub async fn join(&self) {
        let Self { raw, .. } = self;
        raw.join().await;
    }

    /// Awaits the wrapped channel's `leave`.
    pub async fn leave(&self) {
        let Self { raw, .. } = self;
        raw.leave().await;
    }

    /// Encodes `payload` and sends it on the wrapped channel under the event name `E::NAME`.
    ///
    /// # Errors
    ///
    /// Returns an error if `payload` cannot be encoded.
    pub async fn send<E>(&self, payload: &E::Payload) -> Result<(), ClientError>
    where
        E: PondEvent,
    {
        let body = to_pond_map(payload)?;
        self.raw.send_message(E::NAME, Some(body)).await;
        Ok(())
    }

    /// Encodes `payload`, sends it as `E::NAME`, and decodes the reply as `E::Response`.
    ///
    /// `timeout` is passed through unchanged to the wrapped channel.
    ///
    /// # Errors
    ///
    /// Returns an error if `payload` cannot be encoded, if the wrapped channel's
    /// `send_for_response` fails, or ([`ClientError::Serialization`]) if the reply cannot
    /// be decoded.
    pub async fn request<E>(
        &self,
        payload: &E::Payload,
        timeout: Option<Duration>,
    ) -> Result<E::Response, ClientError>
    where
        E: PondEvent,
    {
        let body = to_pond_map(payload)?;
        let reply = self
            .raw
            .send_for_response(E::NAME, Some(body), timeout)
            .await?;
        match from_pond_map(reply) {
            Ok(decoded) => Ok(decoded),
            Err(source) => Err(ClientError::Serialization(source)),
        }
    }

    /// Decodes `message` as `E::Payload` when its event name equals `E::NAME`.
    ///
    /// Returns `Ok(None)` for any other event name, without looking at the payload.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::Serialization`] if the event name matches but the payload
    /// cannot be decoded.
    pub fn decode_message<E>(
        &self,
        message: &ServerMessage,
    ) -> Result<Option<E::Payload>, ClientError>
    where
        E: PondEvent,
    {
        if message.event == E::NAME {
            // Decoding consumes its input, so work on a copy of the borrowed payload.
            match from_pond_map(message.payload.clone()) {
                Ok(decoded) => Ok(Some(decoded)),
                Err(source) => Err(ClientError::Serialization(source)),
            }
        } else {
            Ok(None)
        }
    }

    /// Decodes a presence message into `(changed, presence)`.
    ///
    /// `changed` is the decoded `payload.changed` entry; `presence` holds every decoded
    /// entry of `payload.presence`, in order.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::Serialization`] for the first entry that cannot be decoded;
    /// `changed` is decoded before the list.
    pub fn decode_presence(
        &self,
        message: &PresenceMessage,
    ) -> Result<(S::Presence, Vec<S::Presence>), ClientError> {
        let body = &message.payload;

        let changed: S::Presence =
            from_pond_map(body.changed.clone()).map_err(ClientError::Serialization)?;

        let mut everyone: Vec<S::Presence> = Vec::new();
        for entry in body.presence.iter() {
            let member = from_pond_map(entry.clone()).map_err(ClientError::Serialization)?;
            everyone.push(member);
        }

        Ok((changed, everyone))
    }

    /// Decodes a [`ChannelEvent`] as `E::Payload`.
    ///
    /// Message events are handled by [`Self::decode_message`]; presence events always
    /// yield `Ok(None)`.
    ///
    /// # Errors
    ///
    /// Returns whatever [`Self::decode_message`] returns for a message event.
    pub fn decode_event<E>(&self, event: ChannelEvent) -> Result<Option<E::Payload>, ClientError>
    where
        E: PondEvent,
    {
        let message = match event {
            ChannelEvent::Presence(_) => return Ok(None),
            ChannelEvent::Message(message) => message,
        };
        self.decode_message::<E>(&message)
    }
}

impl PondClient {
    pub async fn create_typed_channel<S>(
        &self,
        name: impl Into<String>,
        params: Option<&S::JoinParams>,
    ) -> Result<TypedChannel<S>, ClientError>
    where
        S: PondSchema,
    {
        let params = params.map(to_pond_map).transpose()?;
        let channel = self.create_channel(name, params).await;
        Ok(TypedChannel::new(channel))
    }
}

pub fn decode_payload<E>(message: ServerMessage) -> Result<E::Payload, ClientError>
where
    E: PondEvent,
{
    from_pond_map(message.payload).map_err(ClientError::Serialization)
}

pub fn decode_presence_value<S>(value: serde_json::Value) -> Result<S::Presence, ClientError>
where
    S: PondSchema,
{
    from_pond_value(value).map_err(ClientError::Serialization)
}

#[cfg(test)]
mod tests {
    use super::*;
    use pondsocket_common::{ChannelState, ServerAction};
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use crate::PondClient;

    /// Payload carried by the `chat` event.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct ChatPayload {
        text: String,
    }

    /// Response type declared for the `chat` event.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct AckPayload {
        ok: bool,
    }

    /// Presence entry used by [`RoomSchema`].
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Member {
        user_id: String,
    }

    /// Used as both the assigns and the join parameters of [`RoomSchema`].
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Role {
        role: String,
    }

    /// Marker type for the `chat` event.
    struct Chat;
    /// Marker type for the schema under test.
    struct RoomSchema;

    impl PondEvent for Chat {
        type Payload = ChatPayload;
        type Response = AckPayload;

        const NAME: &'static str = "chat";
    }

    impl PondSchema for RoomSchema {
        type Presence = Member;
        type Assigns = Role;
        type JoinParams = Role;
    }

    /// Builds a broadcast [`ServerMessage`] on channel `room` with the given event and payload.
    fn server_message(event: &str, payload: serde_json::Value) -> ServerMessage {
        let payload = serde_json::from_value(payload).unwrap();
        ServerMessage {
            action: ServerAction::Broadcast,
            event: String::from(event),
            channel_name: String::from("room"),
            request_id: String::from("r1"),
            payload,
        }
    }

    #[test]
    fn decode_payload_reads_typed_event() {
        let incoming = server_message("chat", json!({ "text": "hi" }));
        let expected = ChatPayload {
            text: String::from("hi"),
        };

        let decoded = decode_payload::<Chat>(incoming).unwrap();

        assert_eq!(decoded, expected);
    }

    #[test]
    fn decode_presence_value_reads_single_member() {
        let expected = Member {
            user_id: String::from("u1"),
        };

        let decoded = decode_presence_value::<RoomSchema>(json!({ "user_id": "u1" })).unwrap();

        assert_eq!(decoded, expected);
    }

    #[tokio::test]
    async fn decode_event_decodes_matching_message_and_ignores_others() {
        let client = PondClient::new("ws://example.com/socket", None).unwrap();
        let typed = client
            .create_typed_channel::<RoomSchema>("room", None::<&Role>)
            .await
            .unwrap();

        // An event named `chat` is decoded into the typed payload.
        let chat_event = ChannelEvent::Message(server_message("chat", json!({ "text": "hey" })));
        let decoded = typed.decode_event::<Chat>(chat_event).unwrap();
        let expected = Some(ChatPayload {
            text: String::from("hey"),
        });
        assert_eq!(decoded, expected);

        // Any other event name is skipped rather than decoded.
        let unrelated = ChannelEvent::Message(server_message("other", json!({ "text": "hey" })));
        let skipped = typed.decode_event::<Chat>(unrelated).unwrap();
        assert_eq!(skipped, None);
    }

    #[tokio::test]
    async fn subscribe_state_reports_current_channel_state() {
        let client = PondClient::new("ws://example.com/socket", None).unwrap();
        let typed = client
            .create_typed_channel::<RoomSchema>("room", None::<&Role>)
            .await
            .unwrap();

        let watcher = typed.subscribe_state();

        assert_eq!(*watcher.borrow(), ChannelState::Idle);
    }
}