mpc_manager/
state.rs

1//! State module.
2//!
3//! This module contains the state of the server and the different types used
4//! to represent it.
5
6use uuid::Uuid;
7
8#[cfg(feature = "server")]
9use self::{
10    group::{Group, GroupId},
11    parameters::Parameters,
12    session::{Session, SessionId, SessionKind, SessionPartyNumber, SessionValue},
13};
14#[cfg(feature = "server")]
15use anyhow::Result;
16#[cfg(feature = "server")]
17use std::collections::HashMap;
18#[cfg(feature = "server")]
19use thiserror::Error;
20#[cfg(feature = "server")]
21use tokio::sync::{mpsc::UnboundedSender, RwLock};
22
23pub mod group;
24pub mod parameters;
25pub mod session;
26
/// Unique ID of a client.
///
/// This is a plain [`Uuid`]. On the server, fresh IDs are produced by
/// `State::new_client_id`, which generates a random (v4) UUID.
pub type ClientId = Uuid;
29
/// Error type for state operations.
///
/// Each variant carries the identifier(s) that could not be resolved so the
/// rendered message pinpoints the failing lookup.
#[cfg(feature = "server")]
#[derive(Debug, Error)]
pub enum StateError {
    /// Error generated when a group was not found.
    #[error("group `{0}` not found")]
    GroupNotFound(GroupId),
    /// Error generated when a session was not found within the given group.
    #[error("session `{0}` for group `{1}` not found")]
    SessionNotFound(SessionId, GroupId),
    /// Error generated when a group is already full.
    #[error("group `{0}` is full")]
    GroupIsFull(GroupId),
    /// Error generated when a party number was not found.
    #[error("party `{0}` not found")]
    PartyNotFound(SessionPartyNumber),
    /// Error generated when a client was not found.
    #[error("client id `{0}` not found")]
    ClientNotFound(ClientId),
}
50
/// Server-wide shared state: the connected clients and the groups that
/// currently exist.
///
/// Each collection sits behind its own async `RwLock`, so readers of one
/// collection do not block writers of the other.
#[cfg(feature = "server")]
#[derive(Debug, Default)]
pub struct State {
    /// Sender handle registered for every connected client, keyed by client id.
    clients: RwLock<HashMap<ClientId, UnboundedSender<String>>>,
    /// All live groups, keyed by their UUID.
    groups: RwLock<HashMap<GroupId, Group>>,
}
60
#[cfg(feature = "server")]
impl State {
    /// Returns a new, empty state. Should only be called once.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a fresh, randomly generated (UUID v4) client id.
    pub fn new_client_id(&self) -> ClientId {
        Uuid::new_v4()
    }

    /// Adds a new client together with its sender handle.
    ///
    /// If a client with the same `id` is already registered, its sender is
    /// replaced.
    pub async fn add_client(&self, id: ClientId, tx: UnboundedSender<String>) {
        self.clients.write().await.insert(id, tx);
    }

    /// Returns a clone of the sender registered for client `id`, or `None`
    /// if the client is unknown.
    pub async fn get_client(&self, id: &ClientId) -> Option<UnboundedSender<String>> {
        self.clients.read().await.get(id).cloned()
    }

    /// Drops a client, performing all necessary cleanup to preserve
    /// security.
    ///
    /// The client is removed from every group, any group left empty is
    /// discarded, and finally the client's own entry is removed.
    pub async fn drop_client(&self, id: ClientId) {
        // Remove the client from every group and discard groups that end up
        // empty, in a single pass. The guard lives in its own scope so the
        // `groups` lock is released before the `clients` lock is acquired.
        {
            let mut groups = self.groups.write().await;
            groups.retain(|group_id, group| {
                group.drop_client(id);
                if group.is_empty() {
                    tracing::info!(group_id = group_id.to_string(), "Removing empty group");
                    false
                } else {
                    true
                }
            });
        }

        // TODO: remove from sessions?

        // Remove client
        self.clients.write().await.remove(&id);
    }

    /// Adds a new group to the state, returning a clone without
    /// sensitive information for logging purposes.
    pub async fn add_group(&self, params: Parameters) -> Group {
        let uuid = Uuid::new_v4();
        let group = Group::new(uuid, params);
        let group_c = group.clone();
        self.groups.write().await.insert(uuid, group);
        group_c
    }

    /// Joins a client to a group, returning a clone without
    /// sensitive information for logging purposes.
    ///
    /// # Errors
    ///
    /// Fails with [`StateError::GroupNotFound`] if the group does not exist,
    /// with [`StateError::GroupIsFull`] if it is already full, or with any
    /// error returned by `Group::add_client`.
    pub async fn join_group(&self, group_id: GroupId, client_id: ClientId) -> Result<Group> {
        // Validate and mutate under one write lock. Requesting the write lock
        // while a read guard on the same `RwLock` is still alive never
        // completes, and releasing in between would let the group disappear
        // between the check and the update.
        let mut groups = self.groups.write().await;
        let group = groups
            .get_mut(&group_id)
            .ok_or(StateError::GroupNotFound(group_id))?;
        if group.is_full() {
            return Err(StateError::GroupIsFull(group_id).into());
        }
        group.add_client(client_id)?;
        Ok(group.clone())
    }

    /// Adds a new session, returning a clone without sensitive information
    /// for logging purposes.
    ///
    /// # Errors
    ///
    /// Fails with [`StateError::GroupNotFound`] if the group does not exist.
    pub async fn add_session(
        &self,
        group_id: GroupId,
        kind: SessionKind,
        value: SessionValue,
    ) -> Result<(Group, Session)> {
        // Single write lock for both the existence check and the insertion.
        let mut groups = self.groups.write().await;
        let group = groups
            .get_mut(&group_id)
            .ok_or(StateError::GroupNotFound(group_id))?;
        let session = group.add_session(kind, value);
        Ok((group.clone(), session))
    }

    /// Registers a client to a given session and returns
    /// a group clone, a session clone, the session party number and a
    /// boolean indicating if the threshold has been reached.
    ///
    /// # Errors
    ///
    /// Fails with [`StateError::GroupNotFound`] / [`StateError::SessionNotFound`]
    /// if the group or session does not exist, or with any error returned by
    /// `Session::signup`.
    pub async fn signup_session(
        &self,
        client_id: ClientId,
        group_id: GroupId,
        session_id: SessionId,
    ) -> Result<(Group, Session, SessionPartyNumber, bool)> {
        // Single write lock for validation and signup (see `join_group`).
        let mut groups = self.groups.write().await;
        let group = groups
            .get_mut(&group_id)
            .ok_or(StateError::GroupNotFound(group_id))?;
        let session = group
            .get_session_mut(&session_id)
            .ok_or(StateError::SessionNotFound(session_id, group_id))?;
        let party_index = session.signup(client_id)?;

        let parties = session.get_number_of_clients();
        let session_c = session.clone();
        let threshold = group.params.threshold_reached(session_c.kind, parties);
        Ok((group.clone(), session_c, party_index, threshold))
    }

    /// Logs in a client with a given party number to a session and returns
    /// a group clone, a session clone and a boolean indicating if the
    /// threshold has been reached.
    ///
    /// # Errors
    ///
    /// Fails with [`StateError::GroupNotFound`] / [`StateError::SessionNotFound`]
    /// if the group or session does not exist, or with any error returned by
    /// `Session::login`.
    pub async fn login_session(
        &self,
        client_id: ClientId,
        group_id: GroupId,
        session_id: SessionId,
        party_number: SessionPartyNumber,
    ) -> Result<(Group, Session, bool)> {
        // Single write lock for validation and login (see `join_group`).
        let mut groups = self.groups.write().await;
        let group = groups
            .get_mut(&group_id)
            .ok_or(StateError::GroupNotFound(group_id))?;
        let session = group
            .get_session_mut(&session_id)
            .ok_or(StateError::SessionNotFound(session_id, group_id))?;
        session.login(client_id, party_number)?;
        let session_c = session.clone();
        // NOTE(review): `signup_session` counts parties with
        // `get_number_of_clients()`, while this uses `party_signups.len()`;
        // confirm the two counts are meant to differ.
        let parties = session.party_signups.len();
        let threshold = group.params.threshold_reached(session_c.kind, parties);
        Ok((group.clone(), session_c, threshold))
    }

    /// Returns client ids associated with a given group, if it exists.
    ///
    /// # Errors
    ///
    /// Fails with [`StateError::GroupNotFound`] if the group does not exist.
    pub async fn get_client_ids_from_group(&self, group_id: &GroupId) -> Result<Vec<ClientId>> {
        let groups = self.groups.read().await;
        let group = groups
            .get(group_id)
            .ok_or(StateError::GroupNotFound(*group_id))?;
        let client_ids: Vec<ClientId> = group.clients().iter().copied().collect();
        Ok(client_ids)
    }

    /// Returns client ids associated with a given session, if it exists.
    ///
    /// # Errors
    ///
    /// Fails with [`StateError::GroupNotFound`] / [`StateError::SessionNotFound`]
    /// if the group or session does not exist.
    pub async fn get_client_ids_from_session(
        &self,
        group_id: &GroupId,
        session_id: &SessionId,
    ) -> Result<Vec<ClientId>> {
        let groups = self.groups.read().await;
        let group = groups
            .get(group_id)
            .ok_or(StateError::GroupNotFound(*group_id))?;
        let session = group
            .get_session(session_id)
            .ok_or(StateError::SessionNotFound(*session_id, *group_id))?;
        let client_ids = session.get_all_client_ids();
        Ok(client_ids)
    }

    /// Returns client id associated with a given session and party number.
    ///
    /// # Errors
    ///
    /// Fails with [`StateError::GroupNotFound`], [`StateError::SessionNotFound`]
    /// or [`StateError::PartyNotFound`] if the respective lookup fails.
    pub async fn get_client_id_from_party_number(
        &self,
        group_id: GroupId,
        session_id: SessionId,
        party_number: SessionPartyNumber,
    ) -> Result<ClientId> {
        // Validate group, session and party number exist.
        let groups = self.groups.read().await;
        let group = groups
            .get(&group_id)
            .ok_or(StateError::GroupNotFound(group_id))?;
        let session = group
            .get_session(&session_id)
            .ok_or(StateError::SessionNotFound(session_id, group_id))?;

        // Get client id
        let client_id = session
            .get_client_id(party_number)
            .ok_or(StateError::PartyNotFound(party_number))?;
        Ok(client_id)
    }

    /// Returns the party number of a given client id.
    ///
    /// # Errors
    ///
    /// Fails with [`StateError::GroupNotFound`], [`StateError::SessionNotFound`]
    /// or [`StateError::ClientNotFound`] if the respective lookup fails.
    pub async fn get_party_number_from_client_id(
        &self,
        group_id: GroupId,
        session_id: SessionId,
        client_id: ClientId,
    ) -> Result<SessionPartyNumber> {
        let groups = self.groups.read().await;
        let group = groups
            .get(&group_id)
            .ok_or(StateError::GroupNotFound(group_id))?;
        let session = group
            .get_session(&session_id)
            .ok_or(StateError::SessionNotFound(session_id, group_id))?;
        let party_number = session
            .get_party_number(&client_id)
            .ok_or(StateError::ClientNotFound(client_id))?;
        Ok(party_number)
    }

    /// Helper function that validates that the group exists and contains the
    /// given session.
    ///
    /// # Errors
    ///
    /// Fails with [`StateError::GroupNotFound`] / [`StateError::SessionNotFound`]
    /// if the group or session does not exist.
    pub async fn validate_group_and_session(
        &self,
        group_id: GroupId,
        session_id: SessionId,
    ) -> Result<()> {
        let groups = self.groups.read().await;
        let group = groups
            .get(&group_id)
            .ok_or(StateError::GroupNotFound(group_id))?;
        group
            .get_session(&session_id)
            .ok_or(StateError::SessionNotFound(session_id, group_id))?;
        Ok(())
    }
}