1use crate::state::{
7 group::{Group, GroupId},
8 session::{Session, SessionId, SessionKind, SessionPartyNumber},
9};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use strum::{Display, EnumString};
13
14#[cfg(feature = "server")]
15use super::{notification::Notification, Service, ServiceResponse};
16#[cfg(feature = "server")]
17use crate::state::{ClientId, State};
18#[cfg(feature = "server")]
19use json_rpc2::{Error, Request};
20#[cfg(feature = "server")]
21use std::str::FromStr;
22#[cfg(feature = "server")]
23use std::sync::Arc;
24#[cfg(feature = "server")]
25use tokio::sync::Mutex;
26
/// Route prefix for this service: every method and event name defined in this
/// module starts with `session`. NOTE(review): presumably matched against
/// incoming JSON-RPC method names by the router elsewhere — confirm.
pub const ROUTE_PREFIX: &str = "session";
29
30#[derive(Debug, Display, EnumString)]
32pub enum SessionMethod {
33 #[strum(serialize = "session_create")]
35 SessionCreate,
36 #[strum(serialize = "session_signup")]
38 SessionSignup,
39 #[strum(serialize = "session_login")]
41 SessionLogin,
42 #[strum(serialize = "session_message")]
44 SessionMessage,
45}
46
47#[derive(Debug, Display, EnumString)]
49pub enum SessionEvent {
50 #[strum(serialize = "session_created")]
52 SessionCreated,
53 #[strum(serialize = "session_ready")]
55 SessionReady,
56 #[strum(serialize = "session_message")]
58 SessionMessage,
59}
60
61#[derive(Deserialize, Serialize)]
63pub struct SessionCreateRequest {
64 #[serde(rename = "groupId")]
65 pub group_id: GroupId,
66 pub kind: SessionKind,
67 #[serde(skip_serializing_if = "Option::is_none")]
68 pub value: Option<Value>,
69}
70
71#[derive(Serialize)]
73pub struct SessionCreateResponse {
74 session: Session,
75}
76
77#[derive(Deserialize, Serialize)]
79pub struct SessionCreatedNotification {
80 group: Group,
81 session: Session,
82}
83
84#[derive(Deserialize, Serialize)]
86pub struct SessionSignupRequest {
87 #[serde(rename = "groupId")]
88 pub group_id: GroupId,
89 #[serde(rename = "sessionId")]
90 pub session_id: SessionId,
91}
92
93#[derive(Serialize)]
95pub struct SessionSignupResponse {
96 session: Session,
97 #[serde(rename = "partyNumber")]
98 party_number: SessionPartyNumber,
99}
100
101#[derive(Deserialize, Serialize)]
103pub struct SessionLoginRequest {
104 #[serde(rename = "groupId")]
105 pub group_id: GroupId,
106 #[serde(rename = "sessionId")]
107 pub session_id: SessionId,
108 #[serde(rename = "partyNumber")]
109 pub party_number: SessionPartyNumber,
110}
111
112#[derive(Serialize)]
114pub struct SessionLoginResponse {
115 session: Session,
116}
117
118#[derive(Deserialize, Serialize)]
120pub struct SessionReadyNotification {
121 group: Group,
122 session: Session,
123}
124
125#[derive(Debug, Deserialize, Serialize)]
127pub struct SessionMessageRequest<T: Serialize = Value> {
128 #[serde(rename = "groupId")]
129 pub group_id: GroupId,
130 #[serde(rename = "sessionId")]
131 pub session_id: SessionId,
132 pub receiver: Option<SessionPartyNumber>,
133 pub message: T,
134}
135
136#[derive(Debug, Clone, Deserialize, Serialize)]
138pub struct SessionMessageNotification<T: Serialize = Value> {
139 #[serde(rename = "groupId")]
140 pub group_id: GroupId,
141 #[serde(rename = "sessionId")]
142 pub session_id: SessionId,
143 pub sender: SessionPartyNumber,
144 pub message: T,
145}
146
/// Stateless service type that handles the `session_*` JSON-RPC methods.
#[cfg(feature = "server")]
#[derive(Debug)]
pub struct SessionService;
152
#[cfg(feature = "server")]
#[axum::async_trait]
impl Service for SessionService {
    /// Routes a `session_*` request to the handler for its method.
    ///
    /// # Errors
    ///
    /// Fails with `MethodNotFound` when the method name does not parse as a
    /// `SessionMethod`. Otherwise returns whatever the selected handler returns.
    async fn handle(
        &self,
        req: &Request,
        ctx: (Arc<State>, Arc<Mutex<Vec<Notification>>>),
        client_id: ClientId,
    ) -> ServiceResponse {
        let name = req.method();
        let method = SessionMethod::from_str(name).map_err(|_| Error::MethodNotFound {
            name: name.to_string(),
            id: req.id().clone(),
        })?;

        // Every handler already produces a `ServiceResponse`, so hand it back as-is.
        match method {
            SessionMethod::SessionCreate => self.session_create(req, ctx, client_id).await,
            SessionMethod::SessionSignup => self.session_signup(req, ctx, client_id).await,
            SessionMethod::SessionLogin => self.session_login(req, ctx, client_id).await,
            SessionMethod::SessionMessage => self.session_message(req, ctx, client_id).await,
        }
    }
}
177
#[cfg(feature = "server")]
impl SessionService {
    /// Wraps an error reported by the state layer as a JSON-RPC
    /// `InvalidParams` error carrying the id of `req`.
    fn invalid_params(req: &Request, err: impl ToString) -> Error {
        Error::InvalidParams {
            id: req.id().clone(),
            data: err.to_string(),
        }
    }

    /// Serializes a response or notification payload into a JSON value.
    fn to_json(payload: impl Serialize) -> Result<Value, Error> {
        serde_json::to_value(payload).map_err(|e| Error::from(Box::from(e)))
    }

    /// Builds the `session_ready` notification for `group_id`, using an
    /// empty `filter`.
    fn session_ready(
        group_id: GroupId,
        group: Group,
        session: Session,
    ) -> Result<Notification, Error> {
        let message = Self::to_json(SessionReadyNotification { group, session })?;
        Ok(Notification::Group {
            group_id,
            filter: vec![],
            method: SessionEvent::SessionReady.to_string(),
            message,
        })
    }

    /// Handles `session_create`: adds a session to a group.
    ///
    /// Responds with the new session. Queues a `session_created` group
    /// notification whose `filter` holds the calling client. NOTE(review):
    /// presumably this excludes the caller from delivery — confirm against
    /// `Notification::Group`.
    ///
    /// # Errors
    ///
    /// Returns an error when the params do not deserialize. Returns
    /// `InvalidParams` when the state rejects the session. Returns a wrapped
    /// serialization error when a payload cannot be converted to JSON.
    async fn session_create(
        &self,
        req: &Request,
        ctx: (Arc<State>, Arc<Mutex<Vec<Notification>>>),
        client_id: ClientId,
    ) -> ServiceResponse {
        let params: SessionCreateRequest = req.deserialize()?;
        tracing::info!(
            group_id = params.group_id.to_string(),
            "Creating a new session"
        );
        let (state, notifications) = ctx;
        let (group, session) = state
            .add_session(params.group_id, params.kind, params.value)
            .await
            .map_err(|e| Self::invalid_params(req, e))?;

        let res = Self::to_json(SessionCreateResponse {
            session: session.clone(),
        })?;
        let notification = Self::to_json(SessionCreatedNotification { group, session })?;

        notifications.lock().await.push(Notification::Group {
            group_id: params.group_id,
            filter: vec![client_id],
            method: SessionEvent::SessionCreated.to_string(),
            message: notification,
        });
        Ok(Some((req, res).into()))
    }

    /// Handles `session_signup`: registers the calling client in a session.
    ///
    /// Responds with the session and the party number returned by the state.
    /// When the state reports the threshold as reached, also queues a
    /// `session_ready` notification for the group.
    ///
    /// # Errors
    ///
    /// Returns an error when the params do not deserialize. Returns
    /// `InvalidParams` when the state rejects the signup. Returns a wrapped
    /// serialization error when a payload cannot be converted to JSON.
    async fn session_signup(
        &self,
        req: &Request,
        ctx: (Arc<State>, Arc<Mutex<Vec<Notification>>>),
        client_id: ClientId,
    ) -> ServiceResponse {
        let params: SessionSignupRequest = req.deserialize()?;
        tracing::info!(
            group_id = params.group_id.to_string(),
            session_id = params.session_id.to_string(),
            "Signing up client to a session"
        );
        let (state, notifications) = ctx;

        let (group, session, party_number, threshold) = state
            .signup_session(client_id, params.group_id, params.session_id)
            .await
            .map_err(|e| Self::invalid_params(req, e))?;

        let res = Self::to_json(SessionSignupResponse {
            session: session.clone(),
            party_number,
        })?;

        if threshold {
            let ready = Self::session_ready(params.group_id, group, session)?;
            notifications.lock().await.push(ready);
        }
        Ok(Some((req, res).into()))
    }

    /// Handles `session_login`: logs the calling client in to a session under
    /// the given party number.
    ///
    /// Responds with the session. When the state reports the threshold as
    /// reached, also queues a `session_ready` notification for the group.
    ///
    /// # Errors
    ///
    /// Returns an error when the params do not deserialize. Returns
    /// `InvalidParams` when the state rejects the login. Returns a wrapped
    /// serialization error when a payload cannot be converted to JSON.
    async fn session_login(
        &self,
        req: &Request,
        ctx: (Arc<State>, Arc<Mutex<Vec<Notification>>>),
        client_id: ClientId,
    ) -> ServiceResponse {
        let params: SessionLoginRequest = req.deserialize()?;
        tracing::info!(
            group_id = params.group_id.to_string(),
            session_id = params.session_id.to_string(),
            "Logging in client to a session"
        );
        let (state, notifications) = ctx;
        let (group, session, threshold) = state
            .login_session(
                client_id,
                params.group_id,
                params.session_id,
                params.party_number,
            )
            .await
            .map_err(|e| Self::invalid_params(req, e))?;

        let res = Self::to_json(SessionLoginResponse {
            session: session.clone(),
        })?;

        if threshold {
            let ready = Self::session_ready(params.group_id, group, session)?;
            notifications.lock().await.push(ready);
        }
        Ok(Some((req, res).into()))
    }

    /// Handles `session_message`: relays a message from the calling client.
    ///
    /// The sender's party number is looked up from `client_id`. With a
    /// `receiver`, the message is queued as a relay to that party's client
    /// only. Without one, it is queued as a session notification with an
    /// empty `filter`. Returns `Ok(None)` instead of a response value.
    ///
    /// # Errors
    ///
    /// Returns an error when the params do not deserialize. Returns
    /// `InvalidParams` when the caller, group, session or receiver cannot be
    /// resolved. Returns a wrapped serialization error when the payload cannot
    /// be converted to JSON.
    async fn session_message(
        &self,
        req: &Request,
        ctx: (Arc<State>, Arc<Mutex<Vec<Notification>>>),
        client_id: ClientId,
    ) -> ServiceResponse {
        let params: SessionMessageRequest = req.deserialize()?;
        tracing::info!(
            group_id = params.group_id.to_string(),
            session_id = params.session_id.to_string(),
            "Sending message to session"
        );
        let (state, notifications) = ctx;

        let self_party_number = state
            .get_party_number_from_client_id(params.group_id, params.session_id, client_id)
            .await
            .map_err(|e| Self::invalid_params(req, e))?;
        state
            .validate_group_and_session(params.group_id, params.session_id)
            .await
            .map_err(|e| Self::invalid_params(req, e))?;

        let res = Self::to_json(SessionMessageNotification {
            group_id: params.group_id,
            session_id: params.session_id,
            message: params.message,
            sender: self_party_number,
        })?;

        // Build the notification (including the async receiver lookup) before
        // taking the notifications lock, so the lock is not held across an await.
        let notification = match params.receiver {
            Some(party_number) => {
                let receiver_client_id = state
                    .get_client_id_from_party_number(
                        params.group_id,
                        params.session_id,
                        party_number,
                    )
                    .await
                    .map_err(|e| Self::invalid_params(req, e))?;
                Notification::Relay {
                    method: SessionEvent::SessionMessage.to_string(),
                    messages: vec![(receiver_client_id, res)],
                }
            }
            None => Notification::Session {
                method: SessionEvent::SessionMessage.to_string(),
                group_id: params.group_id,
                session_id: params.session_id,
                filter: vec![],
                message: res,
            },
        };
        notifications.lock().await.push(notification);

        Ok(None)
    }
}