1use crate::{encoding::types, PartyNumber, Result, TAGLEN};
2use http::StatusCode;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use snow::{HandshakeState, TransportState};
6use std::{
7 collections::{HashMap, HashSet},
8 time::{Duration, SystemTime},
9};
10
/// Identifier of a meeting, generated with `MeetingId::new_v4()` by
/// `MeetingManager::new_meeting`.
pub type MeetingId = uuid::Uuid;

/// Identifier of a session, generated with `SessionId::new_v4()` by
/// `SessionManager::new_session`.
pub type SessionId = uuid::Uuid;
16
/// A 32-byte user identifier.
///
/// The bytes are private; build one with `UserId::from([u8; 32])` and read
/// them back through `AsRef<[u8; 32]>`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct UserId([u8; 32]);

impl AsRef<[u8; 32]> for UserId {
    /// Borrows the raw identifier bytes.
    fn as_ref(&self) -> &[u8; 32] {
        let UserId(bytes) = self;
        bytes
    }
}

impl From<[u8; 32]> for UserId {
    /// Wraps the given bytes unchanged; no validation is performed.
    fn from(bytes: [u8; 32]) -> Self {
        UserId(bytes)
    }
}
33
34#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
36pub struct Parameters {
37 pub parties: u16,
39 pub threshold: u16,
44}
45
46impl Default for Parameters {
47 fn default() -> Self {
48 Self {
49 parties: 3,
50 threshold: 1,
51 }
52 }
53}
54
/// Noise protocol state of a connection: either still handshaking or
/// already switched to transport mode.
pub enum ProtocolState {
    /// Handshake in progress (boxed, keeping the enum itself small).
    Handshake(Box<HandshakeState>),
    /// Handshake complete; messages go through the transport state.
    Transport(TransportState),
}
62
63#[derive(Default, Debug)]
65pub enum HandshakeMessage {
66 #[default]
67 #[doc(hidden)]
68 Noop,
69 Initiator(usize, Vec<u8>),
71 Responder(usize, Vec<u8>),
73}
74
75impl From<&HandshakeMessage> for u8 {
76 fn from(value: &HandshakeMessage) -> Self {
77 match value {
78 HandshakeMessage::Noop => types::NOOP,
79 HandshakeMessage::Initiator(_, _) => {
80 types::HANDSHAKE_INITIATOR
81 }
82 HandshakeMessage::Responder(_, _) => {
83 types::HANDSHAKE_RESPONDER
84 }
85 }
86 }
87}
88
89#[derive(Default, Debug)]
91pub enum TransparentMessage {
92 #[default]
93 #[doc(hidden)]
94 Noop,
95 Error(StatusCode, String),
97 ServerHandshake(HandshakeMessage),
99 PeerHandshake {
101 public_key: Vec<u8>,
103 message: HandshakeMessage,
105 },
106}
107
108impl From<&TransparentMessage> for u8 {
109 fn from(value: &TransparentMessage) -> Self {
110 match value {
111 TransparentMessage::Noop => types::NOOP,
112 TransparentMessage::Error(_, _) => types::ERROR,
113 TransparentMessage::ServerHandshake(_) => {
114 types::HANDSHAKE_SERVER
115 }
116 TransparentMessage::PeerHandshake { .. } => {
117 types::HANDSHAKE_PEER
118 }
119 }
120 }
121}
122
123#[derive(Default, Debug)]
125pub enum ServerMessage {
126 #[default]
127 #[doc(hidden)]
128 Noop,
129 Error(StatusCode, String),
131 NewMeeting {
133 owner_id: UserId,
137 slots: HashSet<UserId>,
139 data: Value,
141 },
142 MeetingCreated(MeetingState),
144 JoinMeeting(MeetingId, UserId),
146 MeetingReady(MeetingState),
150 NewSession(SessionRequest),
152 SessionConnection {
154 session_id: SessionId,
156 peer_key: Vec<u8>,
158 },
159 SessionCreated(SessionState),
161 SessionReady(SessionState),
165 SessionActive(SessionState),
169 SessionTimeout(SessionId),
173 CloseSession(SessionId),
175 SessionFinished(SessionId),
177}
178
179impl From<&ServerMessage> for u8 {
180 fn from(value: &ServerMessage) -> Self {
181 match value {
182 ServerMessage::Noop => types::NOOP,
183 ServerMessage::Error(_, _) => types::ERROR,
184 ServerMessage::NewMeeting { .. } => types::MEETING_NEW,
185 ServerMessage::MeetingCreated(_) => {
186 types::MEETING_CREATED
187 }
188 ServerMessage::JoinMeeting(_, _) => types::MEETING_JOIN,
189 ServerMessage::MeetingReady(_) => types::MEETING_READY,
190 ServerMessage::NewSession(_) => types::SESSION_NEW,
191 ServerMessage::SessionConnection { .. } => {
192 types::SESSION_CONNECTION
193 }
194 ServerMessage::SessionCreated(_) => {
195 types::SESSION_CREATED
196 }
197 ServerMessage::SessionReady(_) => types::SESSION_READY,
198 ServerMessage::SessionActive(_) => types::SESSION_ACTIVE,
199 ServerMessage::SessionTimeout(_) => {
200 types::SESSION_TIMEOUT
201 }
202 ServerMessage::CloseSession(_) => types::SESSION_CLOSE,
203 ServerMessage::SessionFinished(_) => {
204 types::SESSION_FINISHED
205 }
206 }
207 }
208}
209
210#[derive(Default, Debug)]
212pub enum OpaqueMessage {
213 #[default]
214 #[doc(hidden)]
215 Noop,
216
217 ServerMessage(SealedEnvelope),
221
222 PeerMessage {
224 public_key: Vec<u8>,
226 session_id: Option<SessionId>,
228 envelope: SealedEnvelope,
230 },
231}
232
233impl From<&OpaqueMessage> for u8 {
234 fn from(value: &OpaqueMessage) -> Self {
235 match value {
236 OpaqueMessage::Noop => types::NOOP,
237 OpaqueMessage::ServerMessage(_) => types::OPAQUE_SERVER,
238 OpaqueMessage::PeerMessage { .. } => types::OPAQUE_PEER,
239 }
240 }
241}
242
243#[derive(Default, Debug)]
245pub enum RequestMessage {
246 #[default]
247 #[doc(hidden)]
248 Noop,
249
250 Transparent(TransparentMessage),
252
253 Opaque(OpaqueMessage),
255}
256
257impl From<&RequestMessage> for u8 {
258 fn from(value: &RequestMessage) -> Self {
259 match value {
260 RequestMessage::Noop => types::NOOP,
261 RequestMessage::Transparent(_) => types::TRANSPARENT,
262 RequestMessage::Opaque(_) => types::OPAQUE,
263 }
264 }
265}
266
267#[derive(Default, Debug)]
269pub enum ResponseMessage {
270 #[default]
271 #[doc(hidden)]
272 Noop,
273
274 Transparent(TransparentMessage),
276
277 Opaque(OpaqueMessage),
279}
280
281impl From<&ResponseMessage> for u8 {
282 fn from(value: &ResponseMessage) -> Self {
283 match value {
284 ResponseMessage::Noop => types::NOOP,
285 ResponseMessage::Transparent(_) => types::TRANSPARENT,
286 ResponseMessage::Opaque(_) => types::OPAQUE,
287 }
288 }
289}
290
291#[derive(Default, Clone, Copy, Debug)]
293pub enum Encoding {
294 #[default]
295 #[doc(hidden)]
296 Noop,
297 Blob,
299 Json,
301}
302
303impl From<Encoding> for u8 {
304 fn from(value: Encoding) -> Self {
305 match value {
306 Encoding::Noop => types::NOOP,
307 Encoding::Blob => types::ENCODING_BLOB,
308 Encoding::Json => types::ENCODING_JSON,
309 }
310 }
311}
312
313#[derive(Default, Debug)]
320pub struct Chunk {
321 pub length: usize,
323 pub contents: Vec<u8>,
325}
326
327impl Chunk {
328 const CHUNK_SIZE: usize = 65535 - TAGLEN;
329
330 pub fn split(
332 payload: &[u8],
333 transport: &mut TransportState,
334 ) -> Result<Vec<Chunk>> {
335 let mut chunks = Vec::new();
336 for chunk in payload.chunks(Self::CHUNK_SIZE) {
337 let mut contents = vec![0; chunk.len() + TAGLEN];
338 let length =
339 transport.write_message(chunk, &mut contents)?;
340 chunks.push(Chunk { length, contents });
341 }
342 Ok(chunks)
343 }
344
345 pub fn join(
347 chunks: Vec<Chunk>,
348 transport: &mut TransportState,
349 ) -> Result<Vec<u8>> {
350 let mut payload = Vec::new();
351 for chunk in chunks {
352 let mut contents = vec![0; chunk.length];
353 transport.read_message(
354 &chunk.contents[..chunk.length],
355 &mut contents,
356 )?;
357 let new_length = contents.len() - TAGLEN;
358 contents.truncate(new_length);
359 payload.extend_from_slice(contents.as_slice());
360 }
361 Ok(payload)
362 }
363}
364
/// An encrypted payload split into [`Chunk`]s, together with its encoding
/// and a broadcast flag.
#[derive(Default, Debug)]
pub struct SealedEnvelope {
    /// Encoding of the sealed payload (presumably of the plaintext before
    /// encryption — confirm against the sender).
    pub encoding: Encoding,
    /// Encrypted chunks making up the payload, in order.
    pub chunks: Vec<Chunk>,
    /// Broadcast flag. NOTE(review): semantics are defined by the code that
    /// sets/reads it, which is not visible here — confirm.
    pub broadcast: bool,
}
378
/// Participants of a session and the peer connections they have reported.
pub struct Session {
    /// Public key of the session owner.
    owner_key: Vec<u8>,
    /// Public keys of the remaining participants, as supplied at creation.
    participant_keys: HashSet<Vec<u8>>,
    /// Reported connections stored as `(peer, other)` in registration order;
    /// lookups therefore check both orientations.
    connections: HashSet<(Vec<u8>, Vec<u8>)>,
    /// Time of the last access; refreshed by `SessionManager::touch_session`.
    last_access: SystemTime,
}

impl Session {
    /// Public key of the session owner.
    pub fn owner_key(&self) -> &[u8] {
        &self.owner_key
    }

    /// All public keys: the owner first, then the participants in `HashSet`
    /// iteration order (unspecified).
    pub fn public_keys(&self) -> Vec<&[u8]> {
        std::iter::once(self.owner_key.as_slice())
            .chain(self.participant_keys.iter().map(Vec::as_slice))
            .collect()
    }

    /// Records that `peer` has connected to `other`.
    pub fn register_connection(
        &mut self,
        peer: Vec<u8>,
        other: Vec<u8>,
    ) {
        self.connections.insert((peer, other));
    }

    /// Returns `true` when every pair of distinct public keys has a
    /// registered connection in either direction (a full mesh).
    ///
    /// A session with a single key is trivially active.
    pub fn is_active(&self) -> bool {
        let keys = self.public_keys();
        // Connections are stored directionally but treated as symmetric.
        let connected = |a: &[u8], b: &[u8]| {
            self.connections.contains(&(a.to_vec(), b.to_vec()))
                || self.connections.contains(&(b.to_vec(), a.to_vec()))
        };
        // Because `connected` is symmetric, checking each unordered index
        // pair once is enough. Equal keys, e.g. an owner also listed as a
        // participant, need no connection to themselves.
        keys.iter().enumerate().all(|(i, &a)| {
            keys[i + 1..].iter().all(|&b| a == b || connected(a, b))
        })
    }
}
474
475#[derive(Debug)]
477pub struct Meeting {
478 slots: HashMap<UserId, Option<Vec<u8>>>,
480
481 last_access: SystemTime,
484
485 data: Value,
487}
488
489impl Meeting {
490 pub fn join(&mut self, user_id: UserId, public_key: Vec<u8>) {
492 self.slots.insert(user_id, Some(public_key));
493 self.last_access = SystemTime::now();
494 }
495
496 pub fn is_full(&self) -> bool {
498 self.slots.values().all(|s| s.is_some())
499 }
500
501 pub fn participants(&self) -> Vec<Vec<u8>> {
503 self.slots
504 .values()
505 .filter(|s| s.is_some())
506 .map(|s| s.as_ref().unwrap().to_owned())
507 .collect()
508 }
509
510 pub fn data(&self) -> &Value {
512 &self.data
513 }
514}
515
516#[derive(Default)]
518pub struct MeetingManager {
519 meetings: HashMap<MeetingId, Meeting>,
520}
521
522impl MeetingManager {
523 pub fn new_meeting(
525 &mut self,
526 owner_key: Vec<u8>,
527 owner_id: UserId,
528 slots: HashSet<UserId>,
529 data: Value,
530 ) -> MeetingId {
531 let meeting_id = MeetingId::new_v4();
532 let slots: HashMap<UserId, Option<Vec<u8>>> =
533 slots.into_iter().map(|id| (id, None)).collect();
534
535 let mut meeting = Meeting {
536 slots,
537 last_access: SystemTime::now(),
538 data,
539 };
540 meeting.join(owner_id, owner_key);
541
542 self.meetings.insert(meeting_id, meeting);
543 meeting_id
544 }
545
546 pub fn remove_meeting(
548 &mut self,
549 id: &MeetingId,
550 ) -> Option<Meeting> {
551 self.meetings.remove(id)
552 }
553
554 pub fn get_meeting(&self, id: &MeetingId) -> Option<&Meeting> {
556 self.meetings.get(id)
557 }
558
559 pub fn get_meeting_mut(
561 &mut self,
562 id: &MeetingId,
563 ) -> Option<&mut Meeting> {
564 self.meetings.get_mut(id)
565 }
566
567 pub fn expired_keys(&self, timeout: u64) -> Vec<MeetingId> {
569 self.meetings
570 .iter()
571 .filter(|(_, v)| {
572 let now = SystemTime::now();
573 let ttl = Duration::from_millis(timeout * 1000);
574 if let Some(current) = v.last_access.checked_add(ttl)
575 {
576 current < now
577 } else {
578 false
579 }
580 })
581 .map(|(k, _)| *k)
582 .collect::<Vec<_>>()
583 }
584}
585
586#[derive(Default)]
588pub struct SessionManager {
589 sessions: HashMap<SessionId, Session>,
590}
591
592impl SessionManager {
593 pub fn new_session(
595 &mut self,
596 owner_key: Vec<u8>,
597 participant_keys: Vec<Vec<u8>>,
598 ) -> SessionId {
599 let session_id = SessionId::new_v4();
600 let session = Session {
601 owner_key,
602 participant_keys: participant_keys.into_iter().collect(),
603 connections: Default::default(),
604 last_access: SystemTime::now(),
605 };
606 self.sessions.insert(session_id, session);
607 session_id
608 }
609
610 pub fn get_session(&self, id: &SessionId) -> Option<&Session> {
612 self.sessions.get(id)
613 }
614
615 pub fn get_session_mut(
617 &mut self,
618 id: &SessionId,
619 ) -> Option<&mut Session> {
620 self.sessions.get_mut(id)
621 }
622
623 pub fn remove_session(
625 &mut self,
626 id: &SessionId,
627 ) -> Option<Session> {
628 self.sessions.remove(id)
629 }
630
631 pub fn touch_session(
633 &mut self,
634 id: &SessionId,
635 ) -> Option<&Session> {
636 if let Some(session) = self.sessions.get_mut(id) {
637 session.last_access = SystemTime::now();
638 Some(&*session)
639 } else {
640 None
641 }
642 }
643
644 pub fn expired_keys(&self, timeout: u64) -> Vec<SessionId> {
646 self.sessions
647 .iter()
648 .filter(|(_, v)| {
649 let now = SystemTime::now();
650 let ttl = Duration::from_millis(timeout * 1000);
651 if let Some(current) = v.last_access.checked_add(ttl)
652 {
653 current < now
654 } else {
655 false
656 }
657 })
658 .map(|(k, _)| *k)
659 .collect::<Vec<_>>()
660 }
661}
662
/// Snapshot of a meeting: its id, the public keys registered so far and
/// its attached JSON data.
#[derive(Default, Debug, Clone)]
pub struct MeetingState {
    /// Meeting identifier.
    pub meeting_id: MeetingId,
    /// Public keys of the participants that have registered.
    pub registered_participants: Vec<Vec<u8>>,
    /// JSON data attached to the meeting.
    pub data: Value,
}
673
/// Request to create a session.
#[derive(Default, Debug)]
pub struct SessionRequest {
    /// Public keys of the requested participants. NOTE(review): whether the
    /// requester's own key is included is not visible here — confirm.
    pub participant_keys: Vec<Vec<u8>>,
}
683
684#[derive(Default, Debug, Clone)]
686pub struct SessionState {
687 pub session_id: SessionId,
689 pub all_participants: Vec<Vec<u8>>,
691}
692
693impl SessionState {
694 pub fn len(&self) -> usize {
696 self.all_participants.len()
697 }
698
699 pub fn party_number(
701 &self,
702 public_key: impl AsRef<[u8]>,
703 ) -> Option<PartyNumber> {
704 self.all_participants
705 .iter()
706 .position(|k| k == public_key.as_ref())
707 .map(|pos| PartyNumber::new((pos + 1) as u16).unwrap())
708 }
709
710 pub fn peer_key(
712 &self,
713 party_number: PartyNumber,
714 ) -> Option<&[u8]> {
715 for (index, key) in self.all_participants.iter().enumerate() {
716 if index + 1 == party_number.get() as usize {
717 return Some(key.as_slice());
718 }
719 }
720 None
721 }
722
723 pub fn connections(&self, own_key: &[u8]) -> &[Vec<u8>] {
725 if self.all_participants.is_empty() {
726 return &[];
727 }
728
729 if let Some(position) =
730 self.all_participants.iter().position(|k| k == own_key)
731 {
732 if position < self.all_participants.len() - 1 {
733 &self.all_participants[position + 1..]
734 } else {
735 &[]
736 }
737 } else {
738 &[]
739 }
740 }
741
742 pub fn recipients(&self, own_key: &[u8]) -> Vec<Vec<u8>> {
744 self.all_participants
745 .iter()
746 .filter(|&k| k != own_key)
747 .map(|k| k.to_vec())
748 .collect()
749 }
750}
751
#[cfg(test)]
mod tests {
    use super::Chunk;
    use crate::PATTERN;
    use anyhow::Result;

    /// Runs the two-message Noise handshake between freshly generated
    /// keypairs and returns the `(initiator, responder)` transport states.
    fn transport_pair(
    ) -> Result<(snow::TransportState, snow::TransportState)> {
        let init_builder = snow::Builder::new(PATTERN.parse()?);
        let resp_builder = snow::Builder::new(PATTERN.parse()?);

        let init_keys = init_builder.generate_keypair()?;
        let resp_keys = resp_builder.generate_keypair()?;

        let mut initiator = init_builder
            .local_private_key(&init_keys.private)
            .remote_public_key(&resp_keys.public)
            .build_initiator()?;
        let mut responder = resp_builder
            .local_private_key(&resp_keys.private)
            .remote_public_key(&init_keys.public)
            .build_responder()?;

        let mut wire = [0u8; 1024];
        let mut scratch = [0u8; 1024];

        // Message 1: initiator -> responder.
        let sent = initiator.write_message(&[], &mut wire)?;
        responder.read_message(&wire[..sent], &mut scratch)?;

        // Message 2: responder -> initiator.
        let sent = responder.write_message(&[], &mut wire)?;
        initiator.read_message(&wire[..sent], &mut scratch)?;

        Ok((
            initiator.into_transport_mode()?,
            responder.into_transport_mode()?,
        ))
    }

    #[test]
    fn chunks_split_join() -> Result<()> {
        let (mut sender, mut receiver) = transport_pair()?;

        // Larger than one chunk, so the payload must be split in two.
        let payload = vec![0; 76893];

        let sealed = Chunk::split(&payload, &mut sender)?;
        assert_eq!(2, sealed.len());

        let opened = Chunk::join(sealed, &mut receiver)?;
        assert_eq!(payload, opened);

        Ok(())
    }
}
807}