1use std::{collections::HashMap, sync::Arc};
2
3use serde::{Deserialize, Serialize};
4use tokio::sync::Mutex;
5use tokio::sync::broadcast;
6use tokio::time::timeout;
7
8use crate::BoxFuture;
9use crate::{GenericMethod, Method, MethodHandler, ws::WebSocket};
10
/// Wire-level envelope for every message exchanged over a [`Session`].
///
/// Serialized as an internally tagged JSON object: the variant name, lowercased, is stored
/// under the `"type"` key (`"request"`, `"response"`, `"errorresponse"`, `"notification"`),
/// with the variant's fields alongside it.
///
/// `M` selects the payload types. The receive loop decodes with [`GenericMethod`] because the
/// concrete method is not known until the `method` field has been read.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase", tag = "type")]
pub enum Message<M: Method> {
    /// A call that expects a reply carrying the same `id`.
    Request {
        /// Correlation id; the peer echoes it in the matching `Response`/`ErrorResponse`.
        id: u32,
        /// Name of the method to invoke on the peer.
        method: String,
        /// Method arguments.
        data: M::Request,
    },
    /// Successful reply to the `Request` with the same `id`.
    Response {
        id: u32,
        result: M::Response,
    },
    /// Failed reply to the `Request` with the same `id`.
    ErrorResponse {
        id: u32,
        error: M::Error,
    },
    /// One-way call: carries no `id`, so no reply can be correlated with it.
    Notification {
        method: String,
        data: M::Request,
    },
}
32
33pub struct Session {
34 pub ws: WebSocket,
35 id: Arc<Mutex<u32>>,
36 methods: Arc<Mutex<HashMap<String, MethodHandler>>>,
37 on_close_fn:
38 Arc<Mutex<Option<Box<dyn Fn() -> BoxFuture<'static, Result<(), String>> + Send + Sync>>>>,
39 tx: broadcast::Sender<(u32, bool, serde_json::Value)>,
40 pong_tx: broadcast::Sender<()>,
41}
42
43impl Session {
44 pub fn clone(&self) -> Self {
45 Self {
46 ws: self.ws.clone(),
47 id: self.id.clone(),
48 methods: self.methods.clone(),
49 on_close_fn: self.on_close_fn.clone(),
50 tx: self.tx.clone(),
51 pong_tx: self.pong_tx.clone(),
52 }
53 }
54}
55
56impl Session {
57 pub fn from_ws(ws: WebSocket) -> Self {
58 let (tx, _) = broadcast::channel(8192);
59 let (pong_tx, _) = broadcast::channel(16);
60
61 Self {
62 ws,
63 id: Arc::new(Mutex::new(0)),
64 methods: Arc::new(Mutex::new(HashMap::new())),
65 on_close_fn: Arc::new(Mutex::new(None)),
66 tx,
67 pong_tx,
68 }
69 }
70
71 pub async fn connect(addr: &str, path: &str) -> crate::Result<Self> {
72 Ok(Self::from_ws(WebSocket::connect(addr, path).await?))
73 }
74}
75
76impl Session {
77 pub fn start_receiver(&self) {
78 let s = self.clone();
79 tokio::spawn(async move {
80 loop {
81 match s.ws.read().await {
82 Ok(crate::ws::Frame::Text(text)) => {
83 let Ok(msg) = serde_json::from_str::<Message<GenericMethod>>(&text) else {
84 continue;
85 };
86
87 match msg {
88 Message::Request { id, method, data } => {
89 if let Some(m) = s.methods.lock().await.get(&method) {
90 if let Some((err, res)) = (m)(id, data).await {
91 if err {
92 s.respond_error(id, res)
93 .await
94 .expect("Failed to respond");
95 } else {
96 s.respond(id, res).await.expect("Failed to respond");
97 }
98 }
99 }
100 }
101 Message::Response { id, result } => {
102 s.tx.send((id, false, result)).unwrap();
103 }
104 Message::ErrorResponse { id, error } => {
105 s.tx.send((id, true, error)).unwrap();
106 }
107 _ => {}
108 }
109 }
110 Ok(crate::ws::Frame::Pong) => {
111 let _ = s.pong_tx.send(());
112 }
113 Ok(_) => {}
114 Err(_) => {
115 s.trigger_close().await;
116 break;
117 }
118 }
119 }
120 });
121 }
122 pub fn start_ping(&self, interval: tokio::time::Duration, timeout_dur: tokio::time::Duration) {
123 let s = self.clone();
124
125 tokio::spawn(async move {
126 let mut pong_rx = s.pong_tx.subscribe();
127
128 loop {
129 tokio::time::sleep(interval).await;
130
131 if s.ws.send_ping().await.is_err() {
132 s.trigger_close().await;
133 break;
134 }
135
136 let result = timeout(timeout_dur, pong_rx.recv()).await;
137
138 if result.is_err() {
139 let _ = s.close().await;
141 s.trigger_close().await;
142 break;
143 }
144 }
145 });
146 }
147
148 pub async fn on_request<
149 M: Method,
150 Fut: Future<Output = Result<M::Response, M::Error>> + Send + 'static,
151 >(
152 &self,
153 handler: impl Fn(u32, M::Request) -> Fut + Send + Sync + 'static,
154 ) {
155 let handler = Arc::new(handler);
156
157 self.methods.lock().await.insert(
158 M::NAME.to_string(),
159 Box::new(move |id, value| {
160 let handler = Arc::clone(&handler);
161
162 Box::pin(async move {
163 Some(
164 match handler(id, serde_json::from_value(value).ok()?).await {
165 Ok(v) => (false, serde_json::to_value(v).ok()?),
166 Err(v) => (true, serde_json::to_value(v).ok()?),
167 },
168 )
169 })
170 }),
171 );
172 }
173
174 pub async fn on_close<Fut>(&self, handler: impl Fn() -> Fut + Send + Sync + 'static)
175 where
176 Fut: Future<Output = Result<(), String>> + Send + 'static,
177 {
178 let handler = Arc::new(handler);
179
180 *self.on_close_fn.lock().await = Some(Box::new(move || {
181 let handler = handler.clone();
182 Box::pin(async move { handler().await })
183 }));
184 }
185}
186
187impl Session {
188 pub async fn send<M: Method>(&self, data: &Message<M>) -> crate::Result<()> {
189 self.ws
190 .send_text_payload(&serde_json::to_vec(&data)?)
191 .await?;
192 Ok(())
193 }
194
195 pub async fn use_id(&self) -> u32 {
196 let mut id = self.id.lock().await;
197 *id += 1;
198 *id
199 }
200
201 pub async fn request<M: Method>(
202 &self,
203 req: M::Request,
204 ) -> crate::Result<std::result::Result<M::Response, M::Error>> {
205 let id = self.use_id().await;
206
207 self.send::<M>(&Message::Request {
208 id,
209 method: M::NAME.to_string(),
210 data: req,
211 })
212 .await?;
213
214 let mut rx = self.tx.subscribe();
215
216 loop {
217 let r = rx.recv().await?;
218
219 if r.0 == id {
220 break Ok(if r.1 {
221 Err(serde_json::from_value(r.2)?)
222 } else {
223 Ok(serde_json::from_value(r.2)?)
224 });
225 }
226 }
227 }
228
229 pub async fn respond(&self, to: u32, val: serde_json::Value) -> crate::Result<()> {
230 self.send::<GenericMethod>(&Message::Response {
231 id: to,
232 result: val,
233 })
234 .await
235 }
236
237 pub async fn respond_error(&self, to: u32, val: serde_json::Value) -> crate::Result<()> {
238 self.send::<GenericMethod>(&Message::ErrorResponse { id: to, error: val })
239 .await
240 }
241
242 pub async fn notify<M: Method>(&self, data: M::Request) -> crate::Result<()> {
243 self.send::<M>(&Message::Notification {
244 method: M::NAME.to_string(),
245 data,
246 })
247 .await
248 }
249
250 async fn trigger_close(&self) {
251 if let Some(handler) = self.on_close_fn.lock().await.as_ref() {
252 let _ = handler().await;
253 }
254 }
255
256 pub async fn close(&self) -> crate::Result<()> {
257 let res = self.ws.close().await;
258 self.trigger_close().await;
259 Ok(res?)
260 }
261}