1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
//! Session state
//!
//! This module defines [`Session`] and the logic for signing clients up to a
//! session and tracking the party numbers assigned to them.

use super::ClientId;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use strum::EnumString;
use thiserror::Error;
use uuid::Uuid;

/// Unique identifier of a session.
pub type SessionId = Uuid;
/// Number identifying a party (client) inside a session.
pub type SessionPartyNumber = u16;
/// Optional public JSON value attached to a session.
pub type SessionValue = Option<Value>;

/// Error type for session operations.
#[derive(Debug, Error)]
pub enum SessionError {
    #[error("party number `{0}` is already occupied by another party")]
    PartyNumberAlreadyOccupied(SessionPartyNumber),
    #[error("client `{0}` is already signed up")]
    ClientAlreadySignedUp(ClientId),
}

/// Kinds of session supported by this implementation.
///
/// Variant names are (de)serialized and parsed in lowercase:
/// `Keygen` <-> `"keygen"`, `Sign` <-> `"sign"`.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, EnumString)]
#[serde(rename_all = "lowercase")]
#[strum(serialize_all = "lowercase")]
pub enum SessionKind {
    /// Key generation session.
    Keygen,
    /// Signing session.
    Sign,
}

/// Session is a subgroup of clients intended to be used for a specific purpose.
///
/// Only `id`, `kind` and `value` are (de)serialized; the party bookkeeping
/// fields are marked `#[serde(skip)]` and come back empty after deserialization.
#[derive(Debug, Deserialize, Serialize)]
pub struct Session {
    /// Unique ID of the session.
    pub id: SessionId,
    /// Session kind.
    pub kind: SessionKind,
    /// Public value associated to this session.
    ///
    /// This value can be set at the moment of creation.
    /// It can be a message or transaction intended
    /// to be signed by the session.
    pub value: SessionValue,
    /// Maps party number to client id; party numbers start at 1.
    #[serde(skip)]
    pub party_signups: HashMap<SessionPartyNumber, ClientId>,
    /// Occupied party numbers, starting at 1, kept in ascending order.
    #[serde(skip)]
    pub occupied_party_numbers: Vec<SessionPartyNumber>,
    /// Party numbers of finished clients.
    #[serde(skip)]
    pub finished: HashSet<SessionPartyNumber>,
}

impl Session {
    /// Creates a new session with the given parameters.
    pub fn new(id: Uuid, kind: SessionKind, value: SessionValue) -> Self {
        Self {
            id,
            kind,
            value,
            party_signups: HashMap::new(),
            occupied_party_numbers: Vec::new(),
            finished: HashSet::new(),
        }
    }

    /// Registers a client in the session and returns its party number.
    #[cfg(feature = "server")]
    pub fn signup(&mut self, client_id: ClientId) -> anyhow::Result<SessionPartyNumber> {
        if self.is_client_in_session(&client_id) {
            return Err(SessionError::ClientAlreadySignedUp(client_id).into());
        }
        let party_number = self.get_next_party_number();
        self.add_party(client_id, party_number);
        Ok(party_number)
    }

    /// Signs in a client in the session with a given party number.
    #[cfg(feature = "server")]
    pub fn login(
        &mut self,
        client_id: ClientId,
        party_number: SessionPartyNumber,
    ) -> anyhow::Result<()> {
        if self.is_client_in_session(&client_id) {
            return Ok(()); //TODO: think of a better way to handle this (should we return an error?)
        }
        if self.occupied_party_numbers.contains(&party_number) {
            return Err(SessionError::PartyNumberAlreadyOccupied(party_number).into());
        }
        self.add_party(client_id, party_number);
        Ok(())
    }

    /// Adds new party assuming `party_number` doesn't exist already.
    #[cfg(feature = "server")]
    fn add_party(&mut self, client_id: ClientId, party_number: SessionPartyNumber) {
        self.occupied_party_numbers.push(party_number);
        self.occupied_party_numbers.sort();
        self.party_signups.insert(party_number, client_id);
    }

    /// Gets the party number of a client.
    #[cfg(feature = "server")]
    pub fn get_party_number(&self, client_id: &ClientId) -> Option<SessionPartyNumber> {
        self.party_signups
            .iter()
            .find(|(_, id)| id == &client_id)
            .map(|(party, _)| *party)
    }

    /// Returns boolean indicating if the client is already in this session.
    #[cfg(feature = "server")]
    pub fn is_client_in_session(&self, client_id: &ClientId) -> bool {
        self.party_signups.values().any(|id| id == client_id)
    }

    /// Returns the client id of a given party number.
    #[cfg(feature = "server")]
    pub fn get_client_id(&self, party_number: SessionPartyNumber) -> Option<ClientId> {
        self.party_signups
            .iter()
            .find(|(&pn, _)| pn == party_number)
            .map(|(_, id)| *id)
    }

    /// Returns all the client ids associated with the session.
    #[cfg(feature = "server")]
    pub fn get_all_client_ids(&self) -> Vec<ClientId> {
        self.party_signups.values().copied().collect()
    }

    /// Returns the number of clients associated with this session.
    #[cfg(feature = "server")]
    pub fn get_number_of_clients(&self) -> usize {
        self.party_signups.len()
    }

    /// Gets the next missing party number, assuming `occupied_party_numbers`
    /// is a sorted array.
    ///
    /// # Examples
    ///
    /// - if `[1,2,3,4]` it will return 5
    /// - if `[1,4,5,6]` it will return 2
    #[cfg(feature = "server")]
    fn get_next_party_number(&self) -> SessionPartyNumber {
        for (i, party) in self.occupied_party_numbers.iter().enumerate() {
            if (i + 1) != *party as usize {
                return (i + 1) as SessionPartyNumber;
            }
        }

        match self.occupied_party_numbers.last() {
            Some(party) => party + 1,
            None => 1,
        }
    }
}

impl Clone for Session {
    /// Clones session parameters, disregarding sensitive information.
    ///
    /// Should be used only for logging purposes.
    fn clone(&self) -> Self {
        Self {
            id: self.id,
            kind: self.kind,
            value: self.value.clone(),
            party_signups: HashMap::new(),
            occupied_party_numbers: Vec::new(),
            finished: HashSet::new(),
        }
    }
}