mpc_manager/service/
session_service.rs

1//! # Session service
2//!
3//! This module contains the session service that handles incoming requests
4//! for session management.
5
6use crate::state::{
7    group::{Group, GroupId},
8    session::{Session, SessionId, SessionKind, SessionPartyNumber},
9};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use strum::{Display, EnumString};
13
14#[cfg(feature = "server")]
15use super::{notification::Notification, Service, ServiceResponse};
16#[cfg(feature = "server")]
17use crate::state::{ClientId, State};
18#[cfg(feature = "server")]
19use json_rpc2::{Error, Request};
20#[cfg(feature = "server")]
21use std::str::FromStr;
22#[cfg(feature = "server")]
23use std::sync::Arc;
24#[cfg(feature = "server")]
25use tokio::sync::Mutex;
26
/// Route prefix under which the session service is exposed.
///
/// Every method name in [`SessionMethod`] begins with this prefix followed by
/// an underscore (for example `session_create`).
pub const ROUTE_PREFIX: &str = "session";
29
30/// Available session methods.
31#[derive(Debug, Display, EnumString)]
32pub enum SessionMethod {
33    /// Create a new session.
34    #[strum(serialize = "session_create")]
35    SessionCreate,
36    /// Signup for a session.
37    #[strum(serialize = "session_signup")]
38    SessionSignup,
39    /// Login to a session.
40    #[strum(serialize = "session_login")]
41    SessionLogin,
42    /// Send a message to a session.
43    #[strum(serialize = "session_message")]
44    SessionMessage,
45}
46
47/// Available session events.
48#[derive(Debug, Display, EnumString)]
49pub enum SessionEvent {
50    /// A session was created.
51    #[strum(serialize = "session_created")]
52    SessionCreated,
53    /// A session has enough participants.
54    #[strum(serialize = "session_ready")]
55    SessionReady,
56    /// A session received a message.
57    #[strum(serialize = "session_message")]
58    SessionMessage,
59}
60
61/// Session create request.
62#[derive(Deserialize, Serialize)]
63pub struct SessionCreateRequest {
64    #[serde(rename = "groupId")]
65    pub group_id: GroupId,
66    pub kind: SessionKind,
67    #[serde(skip_serializing_if = "Option::is_none")]
68    pub value: Option<Value>,
69}
70
71/// Session create response.
72#[derive(Serialize)]
73pub struct SessionCreateResponse {
74    session: Session,
75}
76
77/// Session created notification.
78#[derive(Deserialize, Serialize)]
79pub struct SessionCreatedNotification {
80    group: Group,
81    session: Session,
82}
83
84/// Session signup request.
85#[derive(Deserialize, Serialize)]
86pub struct SessionSignupRequest {
87    #[serde(rename = "groupId")]
88    pub group_id: GroupId,
89    #[serde(rename = "sessionId")]
90    pub session_id: SessionId,
91}
92
93/// Session signup response.
94#[derive(Serialize)]
95pub struct SessionSignupResponse {
96    session: Session,
97    #[serde(rename = "partyNumber")]
98    party_number: SessionPartyNumber,
99}
100
101/// Session login request.
102#[derive(Deserialize, Serialize)]
103pub struct SessionLoginRequest {
104    #[serde(rename = "groupId")]
105    pub group_id: GroupId,
106    #[serde(rename = "sessionId")]
107    pub session_id: SessionId,
108    #[serde(rename = "partyNumber")]
109    pub party_number: SessionPartyNumber,
110}
111
112/// Session login response.
113#[derive(Serialize)]
114pub struct SessionLoginResponse {
115    session: Session,
116}
117
118/// Session ready notification.
119#[derive(Deserialize, Serialize)]
120pub struct SessionReadyNotification {
121    group: Group,
122    session: Session,
123}
124
125/// Session message request.
126#[derive(Debug, Deserialize, Serialize)]
127pub struct SessionMessageRequest<T: Serialize = Value> {
128    #[serde(rename = "groupId")]
129    pub group_id: GroupId,
130    #[serde(rename = "sessionId")]
131    pub session_id: SessionId,
132    pub receiver: Option<SessionPartyNumber>,
133    pub message: T,
134}
135
136/// Session message notification.
137#[derive(Debug, Clone, Deserialize, Serialize)]
138pub struct SessionMessageNotification<T: Serialize = Value> {
139    #[serde(rename = "groupId")]
140    pub group_id: GroupId,
141    #[serde(rename = "sessionId")]
142    pub session_id: SessionId,
143    pub sender: SessionPartyNumber,
144    pub message: T,
145}
146
/// Stateless service that dispatches incoming `session_*` JSON-RPC requests
/// to their handlers.
///
/// Only compiled when the `server` feature is enabled.
#[derive(Debug)]
#[cfg(feature = "server")]
pub struct SessionService;
152
#[axum::async_trait]
#[cfg(feature = "server")]
impl Service for SessionService {
    /// Parses the request's method name into a [`SessionMethod`] and forwards
    /// the request to the matching handler.
    ///
    /// # Errors
    ///
    /// Returns `MethodNotFound` (carrying the method name and request id) when
    /// the method string is not a known [`SessionMethod`]; otherwise returns
    /// whatever the selected handler returns.
    async fn handle(
        &self,
        req: &Request,
        ctx: (Arc<State>, Arc<Mutex<Vec<Notification>>>),
        client_id: ClientId,
    ) -> ServiceResponse {
        let parsed = SessionMethod::from_str(req.method());
        let method = parsed.map_err(|_| Error::MethodNotFound {
            name: req.method().to_string(),
            id: req.id().clone(),
        })?;

        // Each handler already yields a `ServiceResponse`, so hand it back as-is.
        match method {
            SessionMethod::SessionCreate => self.session_create(req, ctx, client_id).await,
            SessionMethod::SessionSignup => self.session_signup(req, ctx, client_id).await,
            SessionMethod::SessionLogin => self.session_login(req, ctx, client_id).await,
            SessionMethod::SessionMessage => self.session_message(req, ctx, client_id).await,
        }
    }
}
177
#[cfg(feature = "server")]
impl SessionService {
    /// Builds the JSON-RPC `InvalidParams` error for `req`, carrying the
    /// textual form of `err` as its data.
    fn invalid_params(req: &Request, err: impl ToString) -> Error {
        Error::InvalidParams {
            id: req.id().clone(),
            data: err.to_string(),
        }
    }

    /// Serializes `payload` to a JSON value, boxing a serializer failure into
    /// a JSON-RPC error.
    fn to_json<T: Serialize>(payload: T) -> Result<Value, Error> {
        serde_json::to_value(payload).map_err(|e| Error::from(Box::from(e)))
    }

    /// Handles `session_create`: adds a session to the requested group and
    /// queues a `session_created` notification for that group.
    ///
    /// Returns a response containing the created session.
    ///
    /// # Errors
    ///
    /// Fails if the parameters cannot be deserialized, with `InvalidParams`
    /// if the state rejects the new session, or if serialization fails.
    async fn session_create(
        &self,
        req: &Request,
        ctx: (Arc<State>, Arc<Mutex<Vec<Notification>>>),
        client_id: ClientId,
    ) -> ServiceResponse {
        let params: SessionCreateRequest = req.deserialize()?;
        tracing::info!(
            group_id = params.group_id.to_string(),
            "Creating a new session"
        );
        let (state, notifications) = ctx;
        let (group, session) = state
            .add_session(params.group_id, params.kind, params.value)
            .await
            .map_err(|e| Self::invalid_params(req, e))?;

        let res = Self::to_json(SessionCreateResponse {
            session: session.clone(),
        })?;
        let notification = Self::to_json(SessionCreatedNotification { group, session })?;

        notifications.lock().await.push(Notification::Group {
            group_id: params.group_id,
            // NOTE(review): presumably the filter lists clients to skip, so the
            // creator (who gets the direct response) is not notified twice —
            // confirm against the `Notification` handling.
            filter: vec![client_id],
            method: SessionEvent::SessionCreated.to_string(),
            message: notification,
        });
        Ok(Some((req, res).into()))
    }

    /// Handles `session_signup`: signs the calling client up to a session.
    ///
    /// Returns a response containing the session and the party number the
    /// state assigned. When the state reports that the threshold was reached,
    /// a `session_ready` notification is also queued for the group.
    ///
    /// # Errors
    ///
    /// Fails if the parameters cannot be deserialized, with `InvalidParams`
    /// if the state rejects the signup, or if serialization fails.
    async fn session_signup(
        &self,
        req: &Request,
        ctx: (Arc<State>, Arc<Mutex<Vec<Notification>>>),
        client_id: ClientId,
    ) -> ServiceResponse {
        let params: SessionSignupRequest = req.deserialize()?;
        tracing::info!(
            group_id = params.group_id.to_string(),
            session_id = params.session_id.to_string(),
            "Signing up client to a session"
        );
        let (state, notifications) = ctx;

        let (group, session, party_number, threshold) = state
            .signup_session(client_id, params.group_id, params.session_id)
            .await
            .map_err(|e| Self::invalid_params(req, e))?;

        let res = Self::to_json(SessionSignupResponse {
            session: session.clone(),
            party_number,
        })?;

        if threshold {
            let notification = Self::to_json(SessionReadyNotification { group, session })?;
            notifications.lock().await.push(Notification::Group {
                group_id: params.group_id,
                filter: vec![],
                method: SessionEvent::SessionReady.to_string(),
                message: notification,
            });
        }
        Ok(Some((req, res).into()))
    }

    /// Handles `session_login`: logs the calling client in to a session under
    /// the party number supplied in the request.
    ///
    /// Returns a response containing the session. When the state reports that
    /// the threshold was reached, a `session_ready` notification is also
    /// queued for the group.
    ///
    /// # Errors
    ///
    /// Fails if the parameters cannot be deserialized, with `InvalidParams`
    /// if the state rejects the login, or if serialization fails.
    async fn session_login(
        &self,
        req: &Request,
        ctx: (Arc<State>, Arc<Mutex<Vec<Notification>>>),
        client_id: ClientId,
    ) -> ServiceResponse {
        let params: SessionLoginRequest = req.deserialize()?;
        // NOTE(review): the log message below contains a typo ("Loggin"); it is
        // left untouched in case log consumers match on the exact text.
        tracing::info!(
            group_id = params.group_id.to_string(),
            session_id = params.session_id.to_string(),
            "Loggin in client to a session"
        );
        let (state, notifications) = ctx;
        let (group, session, threshold) = state
            .login_session(
                client_id,
                params.group_id,
                params.session_id,
                params.party_number,
            )
            .await
            .map_err(|e| Self::invalid_params(req, e))?;

        let res = Self::to_json(SessionLoginResponse {
            session: session.clone(),
        })?;

        if threshold {
            let notification = Self::to_json(SessionReadyNotification { group, session })?;
            notifications.lock().await.push(Notification::Group {
                group_id: params.group_id,
                filter: vec![],
                method: SessionEvent::SessionReady.to_string(),
                message: notification,
            });
        }
        Ok(Some((req, res).into()))
    }

    /// Handles `session_message`: forwards a message from the calling client
    /// to one party of the session or to the whole session.
    ///
    /// The sender's party number is looked up from `client_id`. If the request
    /// names a `receiver`, the message is relayed to that party's client only;
    /// otherwise it is queued as a session-wide notification. No direct
    /// response is produced (`Ok(None)`).
    ///
    /// # Errors
    ///
    /// Fails if the parameters cannot be deserialized, with `InvalidParams`
    /// if the sender lookup, the group/session validation or the receiver
    /// lookup is rejected by the state, or if serialization fails.
    async fn session_message(
        &self,
        req: &Request,
        ctx: (Arc<State>, Arc<Mutex<Vec<Notification>>>),
        client_id: ClientId,
    ) -> ServiceResponse {
        let params: SessionMessageRequest = req.deserialize()?;
        tracing::info!(
            group_id = params.group_id.to_string(),
            session_id = params.session_id.to_string(),
            "Sending message to session"
        );
        let (state, notifications) = ctx;

        let self_party_number = state
            .get_party_number_from_client_id(params.group_id, params.session_id, client_id)
            .await
            .map_err(|e| Self::invalid_params(req, e))?;
        state
            .validate_group_and_session(params.group_id, params.session_id)
            .await
            .map_err(|e| Self::invalid_params(req, e))?;

        let payload = Self::to_json(SessionMessageNotification {
            group_id: params.group_id,
            session_id: params.session_id,
            message: params.message,
            sender: self_party_number,
        })?;

        // Build the notification (including the awaited receiver lookup) before
        // taking the queue lock, so the mutex is held only for the push itself
        // and never across a call into the shared state.
        let notification = match params.receiver {
            Some(party_number) => {
                let receiver_client_id = state
                    .get_client_id_from_party_number(
                        params.group_id,
                        params.session_id,
                        party_number,
                    )
                    .await
                    .map_err(|e| Self::invalid_params(req, e))?;
                Notification::Relay {
                    method: SessionEvent::SessionMessage.to_string(),
                    messages: vec![(receiver_client_id, payload)],
                }
            }
            None => Notification::Session {
                method: SessionEvent::SessionMessage.to_string(),
                group_id: params.group_id,
                session_id: params.session_id,
                filter: vec![],
                message: payload,
            },
        };
        notifications.lock().await.push(notification);

        Ok(None)
    }
}