1use crate::{encoding::types, PartyNumber, Result, TAGLEN};
2use http::StatusCode;
3use serde::{Deserialize, Serialize};
4use snow::{HandshakeState, TransportState};
5use std::{
6 collections::{HashMap, HashSet},
7 time::{Duration, SystemTime},
8};
9
/// Identifier for a session.
///
/// New identifiers are generated with `SessionId::new_v4()` (random UUIDs)
/// when the session manager creates a session.
pub type SessionId = uuid::Uuid;
12
13#[derive(
16 Debug, Clone, Copy, Hash, Eq, PartialEq, Serialize, Deserialize,
17)]
18pub struct UserId([u8; 32]);
19
20impl AsRef<[u8; 32]> for UserId {
21 fn as_ref(&self) -> &[u8; 32] {
22 &self.0
23 }
24}
25
26impl From<[u8; 32]> for UserId {
27 fn from(value: [u8; 32]) -> Self {
28 Self(value)
29 }
30}
31
32#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
34pub struct Parameters {
35 pub parties: u16,
37 pub threshold: u16,
42}
43
44impl Default for Parameters {
45 fn default() -> Self {
46 Self {
47 parties: 3,
48 threshold: 1,
49 }
50 }
51}
52
/// Noise protocol state of a connection: either still performing the
/// handshake or already switched to transport (encrypted) mode.
pub enum ProtocolState {
    /// Handshake in progress. Boxed, presumably to keep the enum small
    /// because `HandshakeState` is large — TODO confirm.
    Handshake(Box<HandshakeState>),
    /// Handshake finished; messages are encrypted with this transport state.
    Transport(TransportState),
}
60
61#[derive(Default, Debug)]
63pub enum HandshakeMessage {
64 #[default]
65 #[doc(hidden)]
66 Noop,
67 Initiator(usize, Vec<u8>),
69 Responder(usize, Vec<u8>),
71}
72
73impl From<&HandshakeMessage> for u8 {
74 fn from(value: &HandshakeMessage) -> Self {
75 match value {
76 HandshakeMessage::Noop => types::NOOP,
77 HandshakeMessage::Initiator(_, _) => {
78 types::HANDSHAKE_INITIATOR
79 }
80 HandshakeMessage::Responder(_, _) => {
81 types::HANDSHAKE_RESPONDER
82 }
83 }
84 }
85}
86
87#[derive(Default, Debug)]
89pub enum TransparentMessage {
90 #[default]
91 #[doc(hidden)]
92 Noop,
93 Error(StatusCode, String),
95 ServerHandshake(HandshakeMessage),
97 PeerHandshake {
99 public_key: Vec<u8>,
101 message: HandshakeMessage,
103 },
104}
105
106impl From<&TransparentMessage> for u8 {
107 fn from(value: &TransparentMessage) -> Self {
108 match value {
109 TransparentMessage::Noop => types::NOOP,
110 TransparentMessage::Error(_, _) => types::ERROR,
111 TransparentMessage::ServerHandshake(_) => {
112 types::HANDSHAKE_SERVER
113 }
114 TransparentMessage::PeerHandshake { .. } => {
115 types::HANDSHAKE_PEER
116 }
117 }
118 }
119}
120
121#[derive(Default, Debug)]
123pub enum ServerMessage {
124 #[default]
125 #[doc(hidden)]
126 Noop,
127 Error(StatusCode, String),
129 NewSession(SessionRequest),
131 SessionConnection {
133 session_id: SessionId,
135 peer_key: Vec<u8>,
137 },
138 SessionCreated(SessionState),
140 SessionReady(SessionState),
144 SessionActive(SessionState),
148 SessionTimeout(SessionId),
152 CloseSession(SessionId),
154 SessionFinished(SessionId),
156}
157
158impl From<&ServerMessage> for u8 {
159 fn from(value: &ServerMessage) -> Self {
160 match value {
161 ServerMessage::Noop => types::NOOP,
162 ServerMessage::Error(_, _) => types::ERROR,
163 ServerMessage::NewSession(_) => types::SESSION_NEW,
164 ServerMessage::SessionConnection { .. } => {
165 types::SESSION_CONNECTION
166 }
167 ServerMessage::SessionCreated(_) => {
168 types::SESSION_CREATED
169 }
170 ServerMessage::SessionReady(_) => types::SESSION_READY,
171 ServerMessage::SessionActive(_) => types::SESSION_ACTIVE,
172 ServerMessage::SessionTimeout(_) => {
173 types::SESSION_TIMEOUT
174 }
175 ServerMessage::CloseSession(_) => types::SESSION_CLOSE,
176 ServerMessage::SessionFinished(_) => {
177 types::SESSION_FINISHED
178 }
179 }
180 }
181}
182
183#[derive(Default, Debug)]
185pub enum OpaqueMessage {
186 #[default]
187 #[doc(hidden)]
188 Noop,
189
190 ServerMessage(SealedEnvelope),
194
195 PeerMessage {
197 public_key: Vec<u8>,
199 session_id: Option<SessionId>,
201 envelope: SealedEnvelope,
203 },
204}
205
206impl From<&OpaqueMessage> for u8 {
207 fn from(value: &OpaqueMessage) -> Self {
208 match value {
209 OpaqueMessage::Noop => types::NOOP,
210 OpaqueMessage::ServerMessage(_) => types::OPAQUE_SERVER,
211 OpaqueMessage::PeerMessage { .. } => types::OPAQUE_PEER,
212 }
213 }
214}
215
216#[derive(Default, Debug)]
218pub enum RequestMessage {
219 #[default]
220 #[doc(hidden)]
221 Noop,
222
223 Transparent(TransparentMessage),
225
226 Opaque(OpaqueMessage),
228}
229
230impl From<&RequestMessage> for u8 {
231 fn from(value: &RequestMessage) -> Self {
232 match value {
233 RequestMessage::Noop => types::NOOP,
234 RequestMessage::Transparent(_) => types::TRANSPARENT,
235 RequestMessage::Opaque(_) => types::OPAQUE,
236 }
237 }
238}
239
240#[derive(Default, Debug)]
242pub enum ResponseMessage {
243 #[default]
244 #[doc(hidden)]
245 Noop,
246
247 Transparent(TransparentMessage),
249
250 Opaque(OpaqueMessage),
252}
253
254impl From<&ResponseMessage> for u8 {
255 fn from(value: &ResponseMessage) -> Self {
256 match value {
257 ResponseMessage::Noop => types::NOOP,
258 ResponseMessage::Transparent(_) => types::TRANSPARENT,
259 ResponseMessage::Opaque(_) => types::OPAQUE,
260 }
261 }
262}
263
264#[derive(Default, Clone, Copy, Debug)]
266pub enum Encoding {
267 #[default]
268 #[doc(hidden)]
269 Noop,
270 Blob,
272 Json,
274}
275
276impl From<Encoding> for u8 {
277 fn from(value: Encoding) -> Self {
278 match value {
279 Encoding::Noop => types::NOOP,
280 Encoding::Blob => types::ENCODING_BLOB,
281 Encoding::Json => types::ENCODING_JSON,
282 }
283 }
284}
285
/// One encrypted piece of a payload, as produced by `Chunk::split` and
/// consumed by `Chunk::join`.
#[derive(Default, Debug)]
pub struct Chunk {
    /// Number of valid ciphertext bytes at the start of `contents`.
    pub length: usize,
    /// Ciphertext buffer: the encrypted plaintext slice plus `TAGLEN` bytes
    /// of authentication tag.
    pub contents: Vec<u8>,
}
299
300impl Chunk {
301 const CHUNK_SIZE: usize = 65535 - TAGLEN;
302
303 pub fn split(
305 payload: &[u8],
306 transport: &mut TransportState,
307 ) -> Result<Vec<Chunk>> {
308 let mut chunks = Vec::new();
309 for chunk in payload.chunks(Self::CHUNK_SIZE) {
310 let mut contents = vec![0; chunk.len() + TAGLEN];
311 let length =
312 transport.write_message(chunk, &mut contents)?;
313 chunks.push(Chunk { length, contents });
314 }
315 Ok(chunks)
316 }
317
318 pub fn join(
320 chunks: Vec<Chunk>,
321 transport: &mut TransportState,
322 ) -> Result<Vec<u8>> {
323 let mut payload = Vec::new();
324 for chunk in chunks {
325 let mut contents = vec![0; chunk.length];
326 transport.read_message(
327 &chunk.contents[..chunk.length],
328 &mut contents,
329 )?;
330 let new_length = contents.len() - TAGLEN;
331 contents.truncate(new_length);
332 payload.extend_from_slice(contents.as_slice());
333 }
334 Ok(payload)
335 }
336}
337
/// Encrypted payload: ciphertext chunks plus metadata about the plaintext.
#[derive(Default, Debug)]
pub struct SealedEnvelope {
    /// Encoding of the plaintext payload.
    pub encoding: Encoding,
    /// Encrypted chunks; decrypt and reassemble them with `Chunk::join`.
    pub chunks: Vec<Chunk>,
    /// Broadcast flag (presumably: deliver to every session participant
    /// rather than a single peer — TODO confirm).
    pub broadcast: bool,
}
351
/// Server-side record of a session: its owner, the invited participants and
/// the peer connections reported so far.
pub struct Session {
    /// Public key of the participant that created the session.
    owner_key: Vec<u8>,
    /// Public keys of the other participants.
    participant_keys: HashSet<Vec<u8>>,
    /// Registered `(peer, other)` connections; either orientation counts as
    /// a link between the two keys.
    connections: HashSet<(Vec<u8>, Vec<u8>)>,
    /// Time the session was created or last touched.
    last_access: SystemTime,
}

impl Session {
    /// Public key of the session owner.
    pub fn owner_key(&self) -> &[u8] {
        &self.owner_key
    }

    /// Every public key in the session: the owner first, then the
    /// participants in the (unspecified) iteration order of the set.
    pub fn public_keys(&self) -> Vec<&[u8]> {
        std::iter::once(self.owner_key.as_slice())
            .chain(self.participant_keys.iter().map(Vec::as_slice))
            .collect()
    }

    /// Record a connection between `peer` and `other`.
    pub fn register_connection(
        &mut self,
        peer: Vec<u8>,
        other: Vec<u8>,
    ) {
        self.connections.insert((peer, other));
    }

    /// Whether every pair of distinct keys in the session has a registered
    /// connection (in either orientation).
    pub fn is_active(&self) -> bool {
        let keys = self.public_keys();

        // Connections are symmetric for this check, so look both ways.
        let linked = |a: &[u8], b: &[u8]| {
            self.connections.contains(&(a.to_vec(), b.to_vec()))
                || self.connections.contains(&(b.to_vec(), a.to_vec()))
        };

        // Visit each unordered pair once; equal keys never need a link.
        keys.iter().enumerate().all(|(index, &first)| {
            keys[index + 1..]
                .iter()
                .all(|&second| first == second || linked(first, second))
        })
    }
}
447
448#[derive(Default)]
450pub struct SessionManager {
451 sessions: HashMap<SessionId, Session>,
452}
453
454impl SessionManager {
455 pub fn new_session(
457 &mut self,
458 owner_key: Vec<u8>,
459 participant_keys: Vec<Vec<u8>>,
460 ) -> SessionId {
461 let session_id = SessionId::new_v4();
462 let session = Session {
463 owner_key,
464 participant_keys: participant_keys.into_iter().collect(),
465 connections: Default::default(),
466 last_access: SystemTime::now(),
467 };
468 self.sessions.insert(session_id, session);
469 session_id
470 }
471
472 pub fn get_session(&self, id: &SessionId) -> Option<&Session> {
474 self.sessions.get(id)
475 }
476
477 pub fn get_session_mut(
479 &mut self,
480 id: &SessionId,
481 ) -> Option<&mut Session> {
482 self.sessions.get_mut(id)
483 }
484
485 pub fn remove_session(
487 &mut self,
488 id: &SessionId,
489 ) -> Option<Session> {
490 self.sessions.remove(id)
491 }
492
493 pub fn touch_session(
495 &mut self,
496 id: &SessionId,
497 ) -> Option<&Session> {
498 if let Some(session) = self.sessions.get_mut(id) {
499 session.last_access = SystemTime::now();
500 Some(&*session)
501 } else {
502 None
503 }
504 }
505
506 pub fn expired_keys(&self, timeout: u64) -> Vec<SessionId> {
508 self.sessions
509 .iter()
510 .filter(|(_, v)| {
511 let now = SystemTime::now();
512 let ttl = Duration::from_millis(timeout * 1000);
513 if let Some(current) = v.last_access.checked_add(ttl)
514 {
515 current < now
516 } else {
517 false
518 }
519 })
520 .map(|(k, _)| *k)
521 .collect::<Vec<_>>()
522 }
523}
524
/// Request to create a new session.
#[derive(Default, Debug)]
pub struct SessionRequest {
    /// Public keys of the participants to include (the requester presumably
    /// becomes the owner — TODO confirm against the server handler).
    pub participant_keys: Vec<Vec<u8>>,
}
534
535#[derive(Default, Debug, Clone)]
537pub struct SessionState {
538 pub session_id: SessionId,
540 pub all_participants: Vec<Vec<u8>>,
542}
543
544impl SessionState {
545 pub fn len(&self) -> usize {
547 self.all_participants.len()
548 }
549
550 pub fn party_number(
552 &self,
553 public_key: impl AsRef<[u8]>,
554 ) -> Option<PartyNumber> {
555 self.all_participants
556 .iter()
557 .position(|k| k == public_key.as_ref())
558 .map(|pos| PartyNumber::new((pos + 1) as u16).unwrap())
559 }
560
561 pub fn peer_key(
563 &self,
564 party_number: PartyNumber,
565 ) -> Option<&[u8]> {
566 for (index, key) in self.all_participants.iter().enumerate() {
567 if index + 1 == party_number.get() as usize {
568 return Some(key.as_slice());
569 }
570 }
571 None
572 }
573
574 pub fn connections(&self, own_key: &[u8]) -> &[Vec<u8>] {
576 if self.all_participants.is_empty() {
577 return &[];
578 }
579
580 if let Some(position) =
581 self.all_participants.iter().position(|k| k == own_key)
582 {
583 if position < self.all_participants.len() - 1 {
584 &self.all_participants[position + 1..]
585 } else {
586 &[]
587 }
588 } else {
589 &[]
590 }
591 }
592
593 pub fn recipients(&self, own_key: &[u8]) -> Vec<Vec<u8>> {
595 self.all_participants
596 .iter()
597 .filter(|&k| k != own_key)
598 .map(|k| k.to_vec())
599 .collect()
600 }
601}
602
#[cfg(test)]
mod tests {
    use super::Chunk;
    use crate::PATTERN;
    use anyhow::Result;

    /// Round-trips a payload larger than one chunk through `Chunk::split`
    /// and `Chunk::join` over a freshly negotiated Noise transport.
    #[test]
    fn chunks_split_join() -> Result<()> {
        let builder_1 = snow::Builder::new(PATTERN.parse()?);
        let builder_2 = snow::Builder::new(PATTERN.parse()?);

        let keypair1 = builder_1.generate_keypair()?;
        let keypair2 = builder_2.generate_keypair()?;

        // Each side is configured with the other's static public key up
        // front.
        let mut initiator = builder_1
            .local_private_key(&keypair1.private)
            .remote_public_key(&keypair2.public)
            .build_initiator()?;

        let mut responder = builder_2
            .local_private_key(&keypair2.private)
            .remote_public_key(&keypair1.public)
            .build_responder()?;

        let (mut read_buf, mut first_msg, mut second_msg) =
            ([0u8; 1024], [0u8; 1024], [0u8; 1024]);

        // Handshake message 1: initiator -> responder (empty payload).
        let len = initiator.write_message(&[], &mut first_msg)?;

        responder.read_message(&first_msg[..len], &mut read_buf)?;

        // Handshake message 2: responder -> initiator (empty payload).
        let len = responder.write_message(&[], &mut second_msg)?;

        initiator.read_message(&second_msg[..len], &mut read_buf)?;

        // Two messages complete the handshake for PATTERN; switch both
        // sides to transport mode.
        let mut initiator = initiator.into_transport_mode()?;
        let mut responder = responder.into_transport_mode()?;

        // Larger than one chunk's plaintext capacity, so it must split.
        let mock_payload = vec![0; 76893];

        let chunks = Chunk::split(&mock_payload, &mut initiator)?;
        assert_eq!(2, chunks.len());

        let decrypted_payload = Chunk::join(chunks, &mut responder)?;
        assert_eq!(mock_payload, decrypted_payload);

        Ok(())
    }
}