1use uuid::Uuid;
7
8#[cfg(feature = "server")]
9use self::{
10 group::{Group, GroupId},
11 parameters::Parameters,
12 session::{Session, SessionId, SessionKind, SessionPartyNumber, SessionValue},
13};
14#[cfg(feature = "server")]
15use anyhow::Result;
16#[cfg(feature = "server")]
17use std::collections::HashMap;
18#[cfg(feature = "server")]
19use thiserror::Error;
20#[cfg(feature = "server")]
21use tokio::sync::{mpsc::UnboundedSender, RwLock};
22
23pub mod group;
24pub mod parameters;
25pub mod session;
26
/// Identifier of a connected client. The server generates these as random
/// (v4) UUIDs.
pub type ClientId = Uuid;
29
/// Errors produced by [`State`] lookups and mutations.
#[cfg(feature = "server")]
#[derive(Debug, Error)]
pub enum StateError {
    /// No group is registered under the given id.
    #[error("group `{0}` not found")]
    GroupNotFound(GroupId),
    /// The group exists but contains no session with the given id.
    #[error("session `{0}` from group `{1}` not found")]
    SessionNotFound(SessionId, GroupId),
    /// The group reported itself full when a client tried to join.
    #[error("group `{0}` is full")]
    GroupIsFull(GroupId),
    /// No client is associated with the given party number in the session.
    #[error("party `{0}` not found")]
    PartyNotFound(SessionPartyNumber),
    /// The client has no party number in the session.
    #[error("client id `{0}` not found")]
    ClientNotFound(ClientId),
}
50
/// In-memory state: the message channel of every connected client plus every
/// group created so far, each behind its own async read/write lock.
#[cfg(feature = "server")]
#[derive(Default, Debug)]
pub struct State {
    /// Sender half of each connected client's message channel, keyed by client id.
    clients: RwLock<HashMap<ClientId, UnboundedSender<String>>>,
    /// Groups keyed by their id.
    groups: RwLock<HashMap<GroupId, Group>>,
}
60
#[cfg(feature = "server")]
impl State {
    /// Creates an empty state with no clients and no groups.
    pub fn new() -> Self {
        Self::default()
    }

    /// Generates a fresh random (v4) client id.
    ///
    /// The id is not registered; use [`State::add_client`] for that.
    pub fn new_client_id(&self) -> ClientId {
        Uuid::new_v4()
    }

    /// Registers the message sender for `id`, replacing any sender that was
    /// previously stored under the same id.
    pub async fn add_client(&self, id: ClientId, tx: UnboundedSender<String>) {
        self.clients.write().await.insert(id, tx);
    }

    /// Returns a clone of the sender registered for `id`, if any.
    pub async fn get_client(&self, id: &ClientId) -> Option<UnboundedSender<String>> {
        self.clients.read().await.get(id).cloned()
    }

    /// Removes the client from every group, deletes every group that is empty
    /// afterwards, and finally forgets the client's sender.
    pub async fn drop_client(&self, id: ClientId) {
        {
            let mut groups = self.groups.write().await;
            groups.retain(|group_id, group| {
                group.drop_client(id);
                let keep = !group.is_empty();
                if !keep {
                    tracing::info!(group_id = group_id.to_string(), "Removing empty group");
                }
                keep
            });
            // The groups guard is released here so it is never held while
            // waiting on the clients lock.
        }

        self.clients.write().await.remove(&id);
    }

    /// Creates a group with a fresh random id and returns a copy of it.
    pub async fn add_group(&self, params: Parameters) -> Group {
        let uuid = Uuid::new_v4();
        let group = Group::new(uuid, params);
        let group_c = group.clone();
        self.groups.write().await.insert(uuid, group);
        group_c
    }

    /// Adds `client_id` to the group and returns a copy of the updated group.
    ///
    /// # Errors
    /// [`StateError::GroupNotFound`], [`StateError::GroupIsFull`], or any error
    /// returned by `Group::add_client`.
    pub async fn join_group(&self, group_id: GroupId, client_id: ClientId) -> Result<Group> {
        // A single write guard covers both validation and mutation. Checking
        // under a read guard and then awaiting `write()` while that guard is
        // still alive would deadlock: tokio's `RwLock` waits for every reader,
        // including the one held by this task.
        let mut groups = self.groups.write().await;
        let group = groups
            .get_mut(&group_id)
            .ok_or(StateError::GroupNotFound(group_id))?;
        if group.is_full() {
            return Err(StateError::GroupIsFull(group_id).into());
        }
        group.add_client(client_id)?;
        Ok(group.clone())
    }

    /// Creates a new session in the group and returns copies of the updated
    /// group and the new session.
    ///
    /// # Errors
    /// [`StateError::GroupNotFound`] if the group does not exist.
    pub async fn add_session(
        &self,
        group_id: GroupId,
        kind: SessionKind,
        value: SessionValue,
    ) -> Result<(Group, Session)> {
        // One write guard only; see `join_group` for why read-then-write deadlocks.
        let mut groups = self.groups.write().await;
        let group = groups
            .get_mut(&group_id)
            .ok_or(StateError::GroupNotFound(group_id))?;
        let session = group.add_session(kind, value);
        Ok((group.clone(), session))
    }

    /// Signs `client_id` up to a session.
    ///
    /// Returns copies of the group and session, the assigned party number,
    /// and whether the group's parameters consider the threshold reached for
    /// the session's current number of clients.
    ///
    /// # Errors
    /// [`StateError::GroupNotFound`], [`StateError::SessionNotFound`], or any
    /// error returned by `Session::signup`.
    pub async fn signup_session(
        &self,
        client_id: ClientId,
        group_id: GroupId,
        session_id: SessionId,
    ) -> Result<(Group, Session, SessionPartyNumber, bool)> {
        // One write guard only; see `join_group` for why read-then-write deadlocks.
        let mut groups = self.groups.write().await;
        let group = groups
            .get_mut(&group_id)
            .ok_or(StateError::GroupNotFound(group_id))?;
        let session = group
            .get_session_mut(&session_id)
            .ok_or(StateError::SessionNotFound(session_id, group_id))?;
        let party_index = session.signup(client_id)?;

        let parties = session.get_number_of_clients();
        let session_c = session.clone();
        let threshold = group.params.threshold_reached(session_c.kind, parties);
        Ok((group.clone(), session_c, party_index, threshold))
    }

    /// Logs `client_id` into a session as `party_number`.
    ///
    /// Returns copies of the group and session and whether the threshold is
    /// reached. Here the party count is `session.party_signups.len()`.
    ///
    /// # Errors
    /// [`StateError::GroupNotFound`], [`StateError::SessionNotFound`], or any
    /// error returned by `Session::login`.
    pub async fn login_session(
        &self,
        client_id: ClientId,
        group_id: GroupId,
        session_id: SessionId,
        party_number: SessionPartyNumber,
    ) -> Result<(Group, Session, bool)> {
        // One write guard only; see `join_group` for why read-then-write deadlocks.
        let mut groups = self.groups.write().await;
        let group = groups
            .get_mut(&group_id)
            .ok_or(StateError::GroupNotFound(group_id))?;
        let session = group
            .get_session_mut(&session_id)
            .ok_or(StateError::SessionNotFound(session_id, group_id))?;
        session.login(client_id, party_number)?;
        let session_c = session.clone();
        // NOTE(review): `signup_session` counts parties via
        // `get_number_of_clients()`; confirm both counts are meant to differ.
        let parties = session.party_signups.len();
        let threshold = group.params.threshold_reached(session_c.kind, parties);
        Ok((group.clone(), session_c, threshold))
    }

    /// Returns the ids of every client in the group.
    ///
    /// # Errors
    /// [`StateError::GroupNotFound`] if the group does not exist.
    pub async fn get_client_ids_from_group(&self, group_id: &GroupId) -> Result<Vec<ClientId>> {
        let groups = self.groups.read().await;
        let group = groups
            .get(group_id)
            .ok_or(StateError::GroupNotFound(*group_id))?;
        let client_ids: Vec<ClientId> = group.clients().iter().copied().collect();
        Ok(client_ids)
    }

    /// Returns the client ids reported by `Session::get_all_client_ids`.
    ///
    /// # Errors
    /// [`StateError::GroupNotFound`] or [`StateError::SessionNotFound`].
    pub async fn get_client_ids_from_session(
        &self,
        group_id: &GroupId,
        session_id: &SessionId,
    ) -> Result<Vec<ClientId>> {
        let groups = self.groups.read().await;
        let group = groups
            .get(group_id)
            .ok_or(StateError::GroupNotFound(*group_id))?;
        let session = group
            .get_session(session_id)
            .ok_or(StateError::SessionNotFound(*session_id, *group_id))?;
        let client_ids = session.get_all_client_ids();
        Ok(client_ids)
    }

    /// Looks up the client id registered for `party_number` in a session.
    ///
    /// # Errors
    /// [`StateError::GroupNotFound`], [`StateError::SessionNotFound`], or
    /// [`StateError::PartyNotFound`].
    pub async fn get_client_id_from_party_number(
        &self,
        group_id: GroupId,
        session_id: SessionId,
        party_number: SessionPartyNumber,
    ) -> Result<ClientId> {
        let groups = self.groups.read().await;
        let group = groups
            .get(&group_id)
            .ok_or(StateError::GroupNotFound(group_id))?;
        let session = group
            .get_session(&session_id)
            .ok_or(StateError::SessionNotFound(session_id, group_id))?;

        let client_id = session
            .get_client_id(party_number)
            .ok_or(StateError::PartyNotFound(party_number))?;
        Ok(client_id)
    }

    /// Looks up the party number registered for `client_id` in a session.
    ///
    /// # Errors
    /// [`StateError::GroupNotFound`], [`StateError::SessionNotFound`], or
    /// [`StateError::ClientNotFound`].
    pub async fn get_party_number_from_client_id(
        &self,
        group_id: GroupId,
        session_id: SessionId,
        client_id: ClientId,
    ) -> Result<SessionPartyNumber> {
        let groups = self.groups.read().await;
        let group = groups
            .get(&group_id)
            .ok_or(StateError::GroupNotFound(group_id))?;
        let session = group
            .get_session(&session_id)
            .ok_or(StateError::SessionNotFound(session_id, group_id))?;
        let party_number = session
            .get_party_number(&client_id)
            .ok_or(StateError::ClientNotFound(client_id))?;
        Ok(party_number)
    }

    /// Checks that the group exists and contains the session.
    ///
    /// # Errors
    /// [`StateError::GroupNotFound`] or [`StateError::SessionNotFound`].
    pub async fn validate_group_and_session(
        &self,
        group_id: GroupId,
        session_id: SessionId,
    ) -> Result<()> {
        let groups = self.groups.read().await;
        let group = groups
            .get(&group_id)
            .ok_or(StateError::GroupNotFound(group_id))?;
        group
            .get_session(&session_id)
            .ok_or(StateError::SessionNotFound(session_id, group_id))?;
        Ok(())
    }
}