Skip to main content

pondsocket_client/
typed.rs

1use std::marker::PhantomData;
2use std::time::Duration;
3
4use pondsocket_common::{
5    ChannelEvent, PondEvent, PondSchema, PresenceMessage, ServerMessage, from_pond_map,
6    from_pond_value, to_pond_map,
7};
8use tokio::sync::broadcast;
9
10use crate::{Channel, ClientError, PondClient};
11
12#[derive(Clone)]
13pub struct TypedChannel<S> {
14    raw: Channel,
15    _schema: PhantomData<S>,
16}
17
18impl<S> TypedChannel<S>
19where
20    S: PondSchema,
21{
22    pub fn new(raw: Channel) -> Self {
23        Self {
24            raw,
25            _schema: PhantomData,
26        }
27    }
28
29    pub fn raw(&self) -> &Channel {
30        &self.raw
31    }
32
33    pub fn name(&self) -> &str {
34        self.raw.name()
35    }
36
37    pub fn state(&self) -> pondsocket_common::ChannelState {
38        self.raw.state()
39    }
40
41    pub fn subscribe_events(&self) -> broadcast::Receiver<ChannelEvent> {
42        self.raw.subscribe_events()
43    }
44
45    pub fn subscribe_state(&self) -> tokio::sync::watch::Receiver<pondsocket_common::ChannelState> {
46        self.raw.subscribe_state()
47    }
48
49    pub async fn presence(&self) -> Result<Vec<S::Presence>, ClientError> {
50        self.raw
51            .presence()
52            .await
53            .into_iter()
54            .map(from_pond_map)
55            .collect::<serde_json::Result<Vec<_>>>()
56            .map_err(ClientError::Serialization)
57    }
58
59    pub async fn join(&self) {
60        self.raw.join().await;
61    }
62
63    pub async fn leave(&self) {
64        self.raw.leave().await;
65    }
66
67    pub async fn send<E>(&self, payload: &E::Payload) -> Result<(), ClientError>
68    where
69        E: PondEvent,
70    {
71        self.raw
72            .send_message(E::NAME, Some(to_pond_map(payload)?))
73            .await;
74        Ok(())
75    }
76
77    pub async fn request<E>(
78        &self,
79        payload: &E::Payload,
80        timeout: Option<Duration>,
81    ) -> Result<E::Response, ClientError>
82    where
83        E: PondEvent,
84    {
85        let response = self
86            .raw
87            .send_for_response(E::NAME, Some(to_pond_map(payload)?), timeout)
88            .await?;
89        from_pond_map(response).map_err(ClientError::Serialization)
90    }
91
92    pub fn decode_message<E>(
93        &self,
94        message: &ServerMessage,
95    ) -> Result<Option<E::Payload>, ClientError>
96    where
97        E: PondEvent,
98    {
99        if message.event != E::NAME {
100            return Ok(None);
101        }
102        from_pond_map(message.payload.clone())
103            .map(Some)
104            .map_err(ClientError::Serialization)
105    }
106
107    pub fn decode_presence(
108        &self,
109        message: &PresenceMessage,
110    ) -> Result<(S::Presence, Vec<S::Presence>), ClientError> {
111        let changed =
112            from_pond_map(message.payload.changed.clone()).map_err(ClientError::Serialization)?;
113        let presence = message
114            .payload
115            .presence
116            .iter()
117            .cloned()
118            .map(from_pond_map)
119            .collect::<serde_json::Result<Vec<_>>>()
120            .map_err(ClientError::Serialization)?;
121        Ok((changed, presence))
122    }
123
124    pub fn decode_event<E>(&self, event: ChannelEvent) -> Result<Option<E::Payload>, ClientError>
125    where
126        E: PondEvent,
127    {
128        match event {
129            ChannelEvent::Message(message) => self.decode_message::<E>(&message),
130            ChannelEvent::Presence(_) => Ok(None),
131        }
132    }
133}
134
135impl PondClient {
136    pub async fn create_typed_channel<S>(
137        &self,
138        name: impl Into<String>,
139        params: Option<&S::JoinParams>,
140    ) -> Result<TypedChannel<S>, ClientError>
141    where
142        S: PondSchema,
143    {
144        let params = params.map(to_pond_map).transpose()?;
145        let channel = self.create_channel(name, params).await;
146        Ok(TypedChannel::new(channel))
147    }
148}
149
150pub fn decode_payload<E>(message: ServerMessage) -> Result<E::Payload, ClientError>
151where
152    E: PondEvent,
153{
154    from_pond_map(message.payload).map_err(ClientError::Serialization)
155}
156
157pub fn decode_presence_value<S>(value: serde_json::Value) -> Result<S::Presence, ClientError>
158where
159    S: PondSchema,
160{
161    from_pond_value(value).map_err(ClientError::Serialization)
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use pondsocket_common::{ChannelState, ServerAction};
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use crate::PondClient;

    /// Payload of the `chat` test event.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct ChatPayload {
        text: String,
    }

    /// Response type paired with the `chat` test event.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct AckPayload {
        ok: bool,
    }

    /// Presence entry used by `RoomSchema`.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Member {
        user_id: String,
    }

    /// Assigns / join-params type used by `RoomSchema`.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Role {
        role: String,
    }

    /// Marker event named `chat`.
    struct Chat;

    /// Marker schema for the `room` test channel.
    struct RoomSchema;

    impl PondEvent for Chat {
        type Payload = ChatPayload;
        type Response = AckPayload;

        const NAME: &'static str = "chat";
    }

    impl PondSchema for RoomSchema {
        type Presence = Member;
        type Assigns = Role;
        type JoinParams = Role;
    }

    /// Builds a broadcast `ServerMessage` on channel `room` with the given
    /// event name and payload.
    fn server_message(event: &str, payload: serde_json::Value) -> ServerMessage {
        let payload = serde_json::from_value(payload).unwrap();
        ServerMessage {
            action: ServerAction::Broadcast,
            event: event.to_owned(),
            channel_name: "room".to_owned(),
            request_id: "r1".to_owned(),
            payload,
        }
    }

    #[test]
    fn decode_payload_reads_typed_event() {
        let message = server_message("chat", json!({ "text": "hi" }));
        let expected = ChatPayload {
            text: "hi".to_owned(),
        };
        assert_eq!(decode_payload::<Chat>(message).unwrap(), expected);
    }

    #[test]
    fn decode_presence_value_reads_single_member() {
        let value = json!({ "user_id": "u1" });
        let expected = Member {
            user_id: "u1".to_owned(),
        };
        assert_eq!(decode_presence_value::<RoomSchema>(value).unwrap(), expected);
    }

    #[tokio::test]
    async fn decode_event_decodes_matching_message_and_ignores_others() {
        let client = PondClient::new("ws://example.com/socket", None).unwrap();
        let channel = client
            .create_typed_channel::<RoomSchema>("room", None::<&Role>)
            .await
            .unwrap();

        // An event whose name equals `Chat::NAME` is decoded...
        let chat_event = ChannelEvent::Message(server_message("chat", json!({ "text": "hey" })));
        let decoded = channel.decode_event::<Chat>(chat_event).unwrap();
        assert_eq!(
            decoded,
            Some(ChatPayload {
                text: "hey".to_owned()
            })
        );

        // ...while any other event name yields `None`, even with a compatible payload.
        let unrelated = ChannelEvent::Message(server_message("other", json!({ "text": "hey" })));
        assert_eq!(channel.decode_event::<Chat>(unrelated).unwrap(), None);
    }

    #[tokio::test]
    async fn subscribe_state_reports_current_channel_state() {
        let client = PondClient::new("ws://example.com/socket", None).unwrap();
        let channel = client
            .create_typed_channel::<RoomSchema>("room", None::<&Role>)
            .await
            .unwrap();

        // A freshly created, never-joined channel is expected to be idle.
        let receiver = channel.subscribe_state();
        assert_eq!(*receiver.borrow(), ChannelState::Idle);
    }
}