Skip to main content

session_rs/
session.rs

1use std::{collections::HashMap, sync::Arc};
2
3use serde::{Deserialize, Serialize};
4use tokio::sync::Mutex;
5use tokio::sync::broadcast;
6use tokio::time::timeout;
7
8use crate::BoxFuture;
9use crate::{GenericMethod, Method, MethodHandler, ws::WebSocket};
10
11#[derive(Debug, Serialize, Deserialize)]
12#[serde(rename_all = "lowercase", tag = "type")]
13pub enum Message<M: Method> {
14    Request {
15        id: u32,
16        method: String,
17        data: M::Request,
18    },
19    Response {
20        id: u32,
21        result: M::Response,
22    },
23    ErrorResponse {
24        id: u32,
25        error: M::Error,
26    },
27    Notification {
28        method: String,
29        data: M::Request,
30    },
31}
32
33pub struct Session {
34    pub ws: WebSocket,
35    id: Arc<Mutex<u32>>,
36    methods: Arc<Mutex<HashMap<String, MethodHandler>>>,
37    on_close_fn:
38        Arc<Mutex<Option<Box<dyn Fn() -> BoxFuture<'static, Result<(), String>> + Send + Sync>>>>,
39    tx: broadcast::Sender<(u32, bool, serde_json::Value)>,
40    pong_tx: broadcast::Sender<()>,
41}
42
43impl Session {
44    pub fn clone(&self) -> Self {
45        Self {
46            ws: self.ws.clone(),
47            id: self.id.clone(),
48            methods: self.methods.clone(),
49            on_close_fn: self.on_close_fn.clone(),
50            tx: self.tx.clone(),
51            pong_tx: self.pong_tx.clone(),
52        }
53    }
54}
55
56impl Session {
57    pub fn from_ws(ws: WebSocket) -> Self {
58        let (tx, _) = broadcast::channel(8192);
59        let (pong_tx, _) = broadcast::channel(16);
60
61        Self {
62            ws,
63            id: Arc::new(Mutex::new(0)),
64            methods: Arc::new(Mutex::new(HashMap::new())),
65            on_close_fn: Arc::new(Mutex::new(None)),
66            tx,
67            pong_tx,
68        }
69    }
70
71    pub async fn connect(addr: &str, path: &str) -> crate::Result<Self> {
72        Ok(Self::from_ws(WebSocket::connect(addr, path).await?))
73    }
74}
75
76impl Session {
77    pub fn start_receiver(&self) {
78        let s = self.clone();
79        tokio::spawn(async move {
80            loop {
81                match s.ws.read().await {
82                    Ok(crate::ws::Frame::Text(text)) => {
83                        let Ok(msg) = serde_json::from_str::<Message<GenericMethod>>(&text) else {
84                            continue;
85                        };
86
87                        match msg {
88                            Message::Request { id, method, data } => {
89                                if let Some(m) = s.methods.lock().await.get(&method) {
90                                    if let Some((err, res)) = (m)(id, data).await {
91                                        if err {
92                                            s.respond_error(id, res)
93                                                .await
94                                                .expect("Failed to respond");
95                                        } else {
96                                            s.respond(id, res).await.expect("Failed to respond");
97                                        }
98                                    }
99                                }
100                            }
101                            Message::Response { id, result } => {
102                                s.tx.send((id, false, result)).unwrap();
103                            }
104                            Message::ErrorResponse { id, error } => {
105                                s.tx.send((id, true, error)).unwrap();
106                            }
107                            _ => {}
108                        }
109                    }
110                    Ok(crate::ws::Frame::Pong) => {
111                        let _ = s.pong_tx.send(());
112                    }
113                    Ok(_) => {}
114                    Err(_) => {
115                        s.trigger_close().await;
116                        break;
117                    }
118                }
119            }
120        });
121    }
122    pub fn start_ping(&self, interval: tokio::time::Duration, timeout_dur: tokio::time::Duration) {
123        let s = self.clone();
124
125        tokio::spawn(async move {
126            let mut pong_rx = s.pong_tx.subscribe();
127
128            loop {
129                tokio::time::sleep(interval).await;
130
131                if s.ws.send_ping().await.is_err() {
132                    s.trigger_close().await;
133                    break;
134                }
135
136                let result = timeout(timeout_dur, pong_rx.recv()).await;
137
138                if result.is_err() {
139                    // timeout expired
140                    let _ = s.close().await;
141                    s.trigger_close().await;
142                    break;
143                }
144            }
145        });
146    }
147
148    pub async fn on_request<
149        M: Method,
150        Fut: Future<Output = Result<M::Response, M::Error>> + Send + 'static,
151    >(
152        &self,
153        handler: impl Fn(u32, M::Request) -> Fut + Send + Sync + 'static,
154    ) {
155        let handler = Arc::new(handler);
156
157        self.methods.lock().await.insert(
158            M::NAME.to_string(),
159            Box::new(move |id, value| {
160                let handler = Arc::clone(&handler);
161
162                Box::pin(async move {
163                    Some(
164                        match handler(id, serde_json::from_value(value).ok()?).await {
165                            Ok(v) => (false, serde_json::to_value(v).ok()?),
166                            Err(v) => (true, serde_json::to_value(v).ok()?),
167                        },
168                    )
169                })
170            }),
171        );
172    }
173
174    pub async fn on_close<Fut>(&self, handler: impl Fn() -> Fut + Send + Sync + 'static)
175    where
176        Fut: Future<Output = Result<(), String>> + Send + 'static,
177    {
178        let handler = Arc::new(handler);
179
180        *self.on_close_fn.lock().await = Some(Box::new(move || {
181            let handler = handler.clone();
182            Box::pin(async move { handler().await })
183        }));
184    }
185}
186
187impl Session {
188    pub async fn send<M: Method>(&self, data: &Message<M>) -> crate::Result<()> {
189        self.ws
190            .send_text_payload(&serde_json::to_vec(&data)?)
191            .await?;
192        Ok(())
193    }
194
195    pub async fn use_id(&self) -> u32 {
196        let mut id = self.id.lock().await;
197        *id += 1;
198        *id
199    }
200
201    pub async fn request<M: Method>(
202        &self,
203        req: M::Request,
204    ) -> crate::Result<std::result::Result<M::Response, M::Error>> {
205        let id = self.use_id().await;
206
207        self.send::<M>(&Message::Request {
208            id,
209            method: M::NAME.to_string(),
210            data: req,
211        })
212        .await?;
213
214        let mut rx = self.tx.subscribe();
215
216        loop {
217            let r = rx.recv().await?;
218
219            if r.0 == id {
220                break Ok(if r.1 {
221                    Err(serde_json::from_value(r.2)?)
222                } else {
223                    Ok(serde_json::from_value(r.2)?)
224                });
225            }
226        }
227    }
228
229    pub async fn respond(&self, to: u32, val: serde_json::Value) -> crate::Result<()> {
230        self.send::<GenericMethod>(&Message::Response {
231            id: to,
232            result: val,
233        })
234        .await
235    }
236
237    pub async fn respond_error(&self, to: u32, val: serde_json::Value) -> crate::Result<()> {
238        self.send::<GenericMethod>(&Message::ErrorResponse { id: to, error: val })
239            .await
240    }
241
242    pub async fn notify<M: Method>(&self, data: M::Request) -> crate::Result<()> {
243        self.send::<M>(&Message::Notification {
244            method: M::NAME.to_string(),
245            data,
246        })
247        .await
248    }
249
250    async fn trigger_close(&self) {
251        if let Some(handler) = self.on_close_fn.lock().await.as_ref() {
252            let _ = handler().await;
253        }
254    }
255
256    pub async fn close(&self) -> crate::Result<()> {
257        let res = self.ws.close().await;
258        self.trigger_close().await;
259        Ok(res?)
260    }
261}