1use std::marker::PhantomData;
2use std::time::Duration;
3
4use pondsocket_common::{
5 ChannelEvent, PondEvent, PondSchema, PresenceMessage, ServerMessage, from_pond_map,
6 from_pond_value, to_pond_map,
7};
8use tokio::sync::broadcast;
9
10use crate::{Channel, ClientError, PondClient};
11
12#[derive(Clone)]
13pub struct TypedChannel<S> {
14 raw: Channel,
15 _schema: PhantomData<S>,
16}
17
18impl<S> TypedChannel<S>
19where
20 S: PondSchema,
21{
22 pub fn new(raw: Channel) -> Self {
23 Self {
24 raw,
25 _schema: PhantomData,
26 }
27 }
28
29 pub fn raw(&self) -> &Channel {
30 &self.raw
31 }
32
33 pub fn name(&self) -> &str {
34 self.raw.name()
35 }
36
37 pub fn state(&self) -> pondsocket_common::ChannelState {
38 self.raw.state()
39 }
40
41 pub fn subscribe_events(&self) -> broadcast::Receiver<ChannelEvent> {
42 self.raw.subscribe_events()
43 }
44
45 pub fn subscribe_state(&self) -> tokio::sync::watch::Receiver<pondsocket_common::ChannelState> {
46 self.raw.subscribe_state()
47 }
48
49 pub async fn presence(&self) -> Result<Vec<S::Presence>, ClientError> {
50 self.raw
51 .presence()
52 .await
53 .into_iter()
54 .map(from_pond_map)
55 .collect::<serde_json::Result<Vec<_>>>()
56 .map_err(ClientError::Serialization)
57 }
58
59 pub async fn join(&self) {
60 self.raw.join().await;
61 }
62
63 pub async fn leave(&self) {
64 self.raw.leave().await;
65 }
66
67 pub async fn send<E>(&self, payload: &E::Payload) -> Result<(), ClientError>
68 where
69 E: PondEvent,
70 {
71 self.raw
72 .send_message(E::NAME, Some(to_pond_map(payload)?))
73 .await;
74 Ok(())
75 }
76
77 pub async fn request<E>(
78 &self,
79 payload: &E::Payload,
80 timeout: Option<Duration>,
81 ) -> Result<E::Response, ClientError>
82 where
83 E: PondEvent,
84 {
85 let response = self
86 .raw
87 .send_for_response(E::NAME, Some(to_pond_map(payload)?), timeout)
88 .await?;
89 from_pond_map(response).map_err(ClientError::Serialization)
90 }
91
92 pub fn decode_message<E>(
93 &self,
94 message: &ServerMessage,
95 ) -> Result<Option<E::Payload>, ClientError>
96 where
97 E: PondEvent,
98 {
99 if message.event != E::NAME {
100 return Ok(None);
101 }
102 from_pond_map(message.payload.clone())
103 .map(Some)
104 .map_err(ClientError::Serialization)
105 }
106
107 pub fn decode_presence(
108 &self,
109 message: &PresenceMessage,
110 ) -> Result<(S::Presence, Vec<S::Presence>), ClientError> {
111 let changed =
112 from_pond_map(message.payload.changed.clone()).map_err(ClientError::Serialization)?;
113 let presence = message
114 .payload
115 .presence
116 .iter()
117 .cloned()
118 .map(from_pond_map)
119 .collect::<serde_json::Result<Vec<_>>>()
120 .map_err(ClientError::Serialization)?;
121 Ok((changed, presence))
122 }
123
124 pub fn decode_event<E>(&self, event: ChannelEvent) -> Result<Option<E::Payload>, ClientError>
125 where
126 E: PondEvent,
127 {
128 match event {
129 ChannelEvent::Message(message) => self.decode_message::<E>(&message),
130 ChannelEvent::Presence(_) => Ok(None),
131 }
132 }
133}
134
135impl PondClient {
136 pub async fn create_typed_channel<S>(
137 &self,
138 name: impl Into<String>,
139 params: Option<&S::JoinParams>,
140 ) -> Result<TypedChannel<S>, ClientError>
141 where
142 S: PondSchema,
143 {
144 let params = params.map(to_pond_map).transpose()?;
145 let channel = self.create_channel(name, params).await;
146 Ok(TypedChannel::new(channel))
147 }
148}
149
150pub fn decode_payload<E>(message: ServerMessage) -> Result<E::Payload, ClientError>
151where
152 E: PondEvent,
153{
154 from_pond_map(message.payload).map_err(ClientError::Serialization)
155}
156
157pub fn decode_presence_value<S>(value: serde_json::Value) -> Result<S::Presence, ClientError>
158where
159 S: PondSchema,
160{
161 from_pond_value(value).map_err(ClientError::Serialization)
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use pondsocket_common::{ChannelState, ServerAction};
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use crate::PondClient;

    /// Payload of the `chat` test event.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct ChatPayload {
        text: String,
    }

    /// Response type declared for the `chat` test event.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct AckPayload {
        ok: bool,
    }

    /// Presence entry used by the test schema.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Member {
        user_id: String,
    }

    /// Assigns / join-params type used by the test schema.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct Role {
        role: String,
    }

    /// Marker event named `chat`.
    struct Chat;

    impl PondEvent for Chat {
        type Payload = ChatPayload;
        type Response = AckPayload;

        const NAME: &'static str = "chat";
    }

    /// Marker schema for the `room` test channel.
    struct RoomSchema;

    impl PondSchema for RoomSchema {
        type Presence = Member;
        type Assigns = Role;
        type JoinParams = Role;
    }

    /// Builds a `Broadcast` message on channel `room` (request id `r1`) whose
    /// payload is converted from the given JSON value.
    fn server_message(event: &str, payload: serde_json::Value) -> ServerMessage {
        let converted = serde_json::from_value(payload).unwrap();
        ServerMessage {
            action: ServerAction::Broadcast,
            event: event.to_owned(),
            channel_name: "room".to_owned(),
            request_id: "r1".to_owned(),
            payload: converted,
        }
    }

    /// Creates a client for `ws://example.com/socket`.
    fn test_client() -> PondClient {
        PondClient::new("ws://example.com/socket", None).unwrap()
    }

    /// Creates the typed `room` channel on `client` without join params.
    async fn room_channel(client: &PondClient) -> TypedChannel<RoomSchema> {
        client
            .create_typed_channel::<RoomSchema>("room", None::<&Role>)
            .await
            .unwrap()
    }

    #[test]
    fn decode_payload_reads_typed_event() {
        let decoded = decode_payload::<Chat>(server_message("chat", json!({ "text": "hi" })));
        assert_eq!(
            decoded.unwrap(),
            ChatPayload {
                text: "hi".to_owned()
            }
        );
    }

    #[test]
    fn decode_presence_value_reads_single_member() {
        let expected = Member {
            user_id: "u1".to_owned(),
        };
        let member = decode_presence_value::<RoomSchema>(json!({ "user_id": "u1" })).unwrap();
        assert_eq!(member, expected);
    }

    #[tokio::test]
    async fn decode_event_decodes_matching_message_and_ignores_others() {
        let client = test_client();
        let channel = room_channel(&client).await;

        let matching = ChannelEvent::Message(server_message("chat", json!({ "text": "hey" })));
        let decoded = channel.decode_event::<Chat>(matching).unwrap();
        assert_eq!(
            decoded,
            Some(ChatPayload {
                text: "hey".to_owned()
            })
        );

        let unrelated = ChannelEvent::Message(server_message("other", json!({ "text": "hey" })));
        assert!(channel.decode_event::<Chat>(unrelated).unwrap().is_none());
    }

    #[tokio::test]
    async fn subscribe_state_reports_current_channel_state() {
        let client = test_client();
        let channel = room_channel(&client).await;

        let receiver = channel.subscribe_state();
        assert_eq!(*receiver.borrow(), ChannelState::Idle);
    }
}