// session_rs/session.rs

1use std::hash::Hash;
2use std::{collections::HashMap, sync::Arc};
3
4use serde::{Deserialize, Serialize};
5use tokio::sync::Mutex;
6use tokio::sync::broadcast;
7use tokio::time::timeout;
8
9use crate::BoxFuture;
10use crate::{GenericMethod, Method, MethodHandler, ws::WebSocket};
11
12#[derive(Debug, Serialize, Deserialize)]
13#[serde(rename_all = "lowercase", tag = "type")]
14pub enum Message<M: Method> {
15    Request {
16        id: u32,
17        method: String,
18        data: M::Request,
19    },
20    Response {
21        id: u32,
22        result: M::Response,
23    },
24    ErrorResponse {
25        id: u32,
26        error: M::Error,
27    },
28    Notification {
29        method: String,
30        data: M::Request,
31    },
32}
33
34pub struct Session {
35    pub ws: WebSocket,
36    id: Arc<Mutex<u32>>,
37    methods: Arc<Mutex<HashMap<String, MethodHandler>>>,
38    on_close_fn:
39        Arc<Mutex<Option<Box<dyn Fn() -> BoxFuture<'static, Result<(), String>> + Send + Sync>>>>,
40    tx: broadcast::Sender<(u32, bool, serde_json::Value)>,
41    pong_tx: broadcast::Sender<()>,
42}
43
44impl Session {
45    pub fn clone(&self) -> Self {
46        Self {
47            ws: self.ws.clone(),
48            id: self.id.clone(),
49            methods: self.methods.clone(),
50            on_close_fn: self.on_close_fn.clone(),
51            tx: self.tx.clone(),
52            pong_tx: self.pong_tx.clone(),
53        }
54    }
55}
56
57impl Session {
58    pub fn from_ws(ws: WebSocket) -> Self {
59        let (tx, _) = broadcast::channel(8192);
60        let (pong_tx, _) = broadcast::channel(16);
61
62        Self {
63            ws,
64            id: Arc::new(Mutex::new(0)),
65            methods: Arc::new(Mutex::new(HashMap::new())),
66            on_close_fn: Arc::new(Mutex::new(None)),
67            tx,
68            pong_tx,
69        }
70    }
71
72    pub async fn connect(addr: &str, path: &str) -> crate::Result<Self> {
73        Ok(Self::from_ws(WebSocket::connect(addr, path).await?))
74    }
75}
76
77impl Session {
78    pub fn start_receiver(&self) {
79        let s = self.clone();
80        tokio::spawn(async move {
81            loop {
82                match s.ws.read().await {
83                    Ok(crate::ws::Frame::Text(text)) => {
84                        let Ok(msg) = serde_json::from_str::<Message<GenericMethod>>(&text) else {
85                            continue;
86                        };
87
88                        match msg {
89                            Message::Request { id, method, data } => {
90                                if let Some(m) = s.methods.lock().await.get(&method) {
91                                    if let Some((err, res)) = (m)(id, data).await {
92                                        if err {
93                                            s.respond_error(id, res)
94                                                .await
95                                                .expect("Failed to respond");
96                                        } else {
97                                            s.respond(id, res).await.expect("Failed to respond");
98                                        }
99                                    }
100                                }
101                            }
102                            Message::Response { id, result } => {
103                                s.tx.send((id, false, result)).unwrap();
104                            }
105                            Message::ErrorResponse { id, error } => {
106                                s.tx.send((id, true, error)).unwrap();
107                            }
108                            _ => {}
109                        }
110                    }
111                    Ok(crate::ws::Frame::Pong) => {
112                        let _ = s.pong_tx.send(());
113                    }
114                    Ok(_) => {}
115                    Err(_) => {
116                        s.trigger_close().await;
117                        break;
118                    }
119                }
120            }
121        });
122    }
123    pub fn start_ping(&self, interval: tokio::time::Duration, timeout_dur: tokio::time::Duration) {
124        let s = self.clone();
125
126        tokio::spawn(async move {
127            let mut pong_rx = s.pong_tx.subscribe();
128
129            loop {
130                tokio::time::sleep(interval).await;
131
132                if s.ws.send_ping().await.is_err() {
133                    s.trigger_close().await;
134                    break;
135                }
136
137                let result = timeout(timeout_dur, pong_rx.recv()).await;
138
139                if result.is_err() {
140                    // timeout expired
141                    let _ = s.close().await;
142                    s.trigger_close().await;
143                    break;
144                }
145            }
146        });
147    }
148
149    pub async fn on_request<
150        M: Method,
151        Fut: Future<Output = Result<M::Response, M::Error>> + Send + 'static,
152    >(
153        &self,
154        handler: impl Fn(u32, M::Request) -> Fut + Send + Sync + 'static,
155    ) {
156        let handler = Arc::new(handler);
157
158        self.methods.lock().await.insert(
159            M::NAME.to_string(),
160            Box::new(move |id, value| {
161                let handler = Arc::clone(&handler);
162
163                Box::pin(async move {
164                    Some(
165                        match handler(id, serde_json::from_value(value).ok()?).await {
166                            Ok(v) => (false, serde_json::to_value(v).ok()?),
167                            Err(v) => (true, serde_json::to_value(v).ok()?),
168                        },
169                    )
170                })
171            }),
172        );
173    }
174
175    pub async fn on_close<Fut>(&self, handler: impl Fn() -> Fut + Send + Sync + 'static)
176    where
177        Fut: Future<Output = Result<(), String>> + Send + 'static,
178    {
179        let handler = Arc::new(handler);
180
181        *self.on_close_fn.lock().await = Some(Box::new(move || {
182            let handler = handler.clone();
183            Box::pin(async move { handler().await })
184        }));
185    }
186}
187
188impl Session {
189    pub async fn send<M: Method>(&self, data: &Message<M>) -> crate::Result<()> {
190        self.ws
191            .send_text_payload(&serde_json::to_vec(&data)?)
192            .await?;
193        Ok(())
194    }
195
196    pub async fn use_id(&self) -> u32 {
197        let mut id = self.id.lock().await;
198        *id += 1;
199        *id
200    }
201
202    pub async fn request<M: Method>(
203        &self,
204        req: M::Request,
205    ) -> crate::Result<std::result::Result<M::Response, M::Error>> {
206        let id = self.use_id().await;
207
208        self.send::<M>(&Message::Request {
209            id,
210            method: M::NAME.to_string(),
211            data: req,
212        })
213        .await?;
214
215        let mut rx = self.tx.subscribe();
216
217        loop {
218            let r = rx.recv().await?;
219
220            if r.0 == id {
221                break Ok(if r.1 {
222                    Err(serde_json::from_value(r.2)?)
223                } else {
224                    Ok(serde_json::from_value(r.2)?)
225                });
226            }
227        }
228    }
229
230    pub async fn respond(&self, to: u32, val: serde_json::Value) -> crate::Result<()> {
231        self.send::<GenericMethod>(&Message::Response {
232            id: to,
233            result: val,
234        })
235        .await
236    }
237
238    pub async fn respond_error(&self, to: u32, val: serde_json::Value) -> crate::Result<()> {
239        self.send::<GenericMethod>(&Message::ErrorResponse { id: to, error: val })
240            .await
241    }
242
243    pub async fn notify<M: Method>(&self, data: M::Request) -> crate::Result<()> {
244        self.send::<M>(&Message::Notification {
245            method: M::NAME.to_string(),
246            data,
247        })
248        .await
249    }
250
251    async fn trigger_close(&self) {
252        if let Some(handler) = self.on_close_fn.lock().await.as_ref() {
253            let _ = handler().await;
254        }
255    }
256
257    pub async fn close(&self) -> crate::Result<()> {
258        let res = self.ws.close().await;
259        self.trigger_close().await;
260        Ok(res?)
261    }
262}
263
264impl Hash for Session {
265    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
266        self.ws.id.hash(state);
267    }
268}
269
270impl PartialEq for Session {
271    fn eq(&self, other: &Self) -> bool {
272        self.ws.id == other.ws.id
273    }
274}
275
276impl Eq for Session {}