1use std::{collections::HashMap, sync::Arc};
2
3use serde::{Deserialize, Serialize};
4use tokio::sync::Mutex;
5use tokio::sync::broadcast;
6
7use crate::{GenericMethod, Method, MethodHandler, ws::WebSocket};
8
9#[derive(Debug, Serialize, Deserialize)]
10#[serde(rename_all = "lowercase", tag = "type")]
11pub enum Message<M: Method> {
12 Request {
13 id: u32,
14 method: String,
15 data: M::Request,
16 },
17 Response {
18 id: u32,
19 result: M::Response,
20 },
21 ErrorResponse {
22 id: u32,
23 error: M::Error,
24 },
25 Notification {
26 method: String,
27 data: M::Request,
28 },
29}
30
31pub struct Session {
32 pub ws: WebSocket,
33 id: Arc<Mutex<u32>>,
34 methods: Arc<Mutex<HashMap<String, MethodHandler>>>,
35 tx: broadcast::Sender<(u32, bool, serde_json::Value)>,
36}
37
38impl Session {
39 pub fn clone(&self) -> Self {
40 Self {
41 ws: self.ws.clone(),
42 id: self.id.clone(),
43 methods: self.methods.clone(),
44 tx: self.tx.clone(),
45 }
46 }
47}
48
49impl Session {
50 pub fn from_ws(ws: WebSocket) -> Self {
51 Self {
52 ws,
53 id: Arc::new(Mutex::new(0)),
54 methods: Arc::new(Mutex::new(HashMap::new())),
55 tx: broadcast::channel(8192).0,
56 }
57 }
58
59 pub async fn connect(addr: &str, path: &str) -> crate::Result<Self> {
60 Ok(Self::from_ws(WebSocket::connect(addr, path).await?))
61 }
62}
63
64impl Session {
65 pub fn start_receiver(&self) {
66 let s = self.clone();
67 tokio::spawn(async move {
68 loop {
69 match s.ws.read().await {
70 Ok(crate::ws::Frame::Text(text)) => {
71 let Ok(msg) = serde_json::from_str::<Message<GenericMethod>>(&text) else {
72 continue;
73 };
74
75 match msg {
76 Message::Request { id, method, data } => {
77 if let Some(m) = s.methods.lock().await.get(&method) {
78 if let Some((err, res)) = (m)(id, data).await {
79 if err {
80 s.respond_error(id, res)
81 .await
82 .expect("Failed to respond");
83 } else {
84 s.respond(id, res).await.expect("Failed to respond");
85 }
86 }
87 }
88 }
89 Message::Response { id, result } => {
90 s.tx.send((id, false, result)).unwrap();
91 }
92 Message::ErrorResponse { id, error } => {
93 s.tx.send((id, true, error)).unwrap();
94 }
95 _ => {}
96 }
97 }
98 Ok(_) => {}
99 Err(_) => break,
100 }
101 }
102 });
103 }
104
105 pub async fn on<
106 M: Method,
107 Fut: Future<Output = Result<M::Response, M::Error>> + Send + 'static,
108 >(
109 &self,
110 handler: impl Fn(u32, M::Request) -> Fut + Send + Sync + 'static,
111 ) {
112 let handler = Arc::new(handler);
113
114 self.methods.lock().await.insert(
115 M::NAME.to_string(),
116 Box::new(move |id, value| {
117 let handler = Arc::clone(&handler);
118
119 Box::pin(async move {
120 Some(
121 match handler(id, serde_json::from_value(value).ok()?).await {
122 Ok(v) => (false, serde_json::to_value(v).ok()?),
123 Err(v) => (true, serde_json::to_value(v).ok()?),
124 },
125 )
126 })
127 }),
128 );
129 }
130}
131
132impl Session {
133 pub async fn send<M: Method>(&self, data: &Message<M>) -> crate::Result<()> {
134 self.ws
135 .send_text_payload(&serde_json::to_vec(&data)?)
136 .await?;
137 Ok(())
138 }
139
140 pub async fn use_id(&self) -> u32 {
141 let mut id = self.id.lock().await;
142 *id += 1;
143 *id
144 }
145
146 pub async fn request<M: Method>(
147 &self,
148 req: M::Request,
149 ) -> crate::Result<std::result::Result<M::Response, M::Error>> {
150 let id = self.use_id().await;
151
152 self.send::<M>(&Message::Request {
153 id,
154 method: M::NAME.to_string(),
155 data: req,
156 })
157 .await?;
158
159 let mut rx = self.tx.subscribe();
160
161 loop {
162 let r = rx.recv().await?;
163
164 if r.0 == id {
165 break Ok(if r.1 {
166 Err(serde_json::from_value(r.2)?)
167 } else {
168 Ok(serde_json::from_value(r.2)?)
169 });
170 }
171 }
172 }
173
174 pub async fn respond(&self, to: u32, val: serde_json::Value) -> crate::Result<()> {
175 self.send::<GenericMethod>(&Message::Response {
176 id: to,
177 result: val,
178 })
179 .await
180 }
181
182 pub async fn respond_error(&self, to: u32, val: serde_json::Value) -> crate::Result<()> {
183 self.send::<GenericMethod>(&Message::ErrorResponse { id: to, error: val })
184 .await
185 }
186
187 pub async fn notify<M: Method>(&self, data: M::Request) -> crate::Result<()> {
188 self.send::<M>(&Message::Notification {
189 method: M::NAME.to_string(),
190 data,
191 })
192 .await
193 }
194
195 pub async fn close(&self) -> crate::Result<()> {
196 Ok(self.ws.close().await?)
197 }
198}