use crate::{encoding::types, PartyNumber, Result, TAGLEN};
use http::StatusCode;
use serde::{Deserialize, Serialize};
use snow::{HandshakeState, TransportState};
use std::{
collections::{HashMap, HashSet},
time::{Duration, SystemTime},
};
/// Identifier of a session; new sessions get a random (v4) UUID.
pub type SessionId = uuid::Uuid;
/// Opaque 32-byte user identifier.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq, Serialize, Deserialize)]
pub struct UserId([u8; 32]);

impl AsRef<[u8; 32]> for UserId {
    /// Borrows the raw 32 identifier bytes.
    fn as_ref(&self) -> &[u8; 32] {
        let UserId(bytes) = self;
        bytes
    }
}

impl From<[u8; 32]> for UserId {
    /// Wraps raw bytes as a `UserId` without any validation.
    fn from(bytes: [u8; 32]) -> Self {
        UserId(bytes)
    }
}
/// Party-count / threshold configuration.
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
pub struct Parameters {
    /// Total number of parties.
    pub parties: u16,
    /// Threshold parameter; whether this means `t` or `t + 1` required
    /// signers depends on the consuming protocol — TODO confirm.
    pub threshold: u16,
}

impl Default for Parameters {
    /// Three parties with a threshold of one.
    fn default() -> Self {
        Parameters {
            threshold: 1,
            parties: 3,
        }
    }
}
/// State of a Noise channel: still handshaking, or switched to transport
/// mode once the handshake completed.
pub enum ProtocolState {
    /// Handshake in progress. Boxed, presumably because `HandshakeState` is
    /// much larger than `TransportState` — confirm.
    Handshake(Box<HandshakeState>),
    /// Handshake finished; used to encrypt/decrypt transport messages.
    Transport(TransportState),
}
/// A Noise handshake message sent by one side of the handshake.
///
/// NOTE(review): the `usize` is presumably the number of valid bytes in the
/// accompanying buffer — confirm against the encoder.
#[derive(Default, Debug)]
pub enum HandshakeMessage {
    #[default]
    #[doc(hidden)]
    Noop,
    /// Message written by the handshake initiator.
    Initiator(usize, Vec<u8>),
    /// Message written by the handshake responder.
    Responder(usize, Vec<u8>),
}

impl From<&HandshakeMessage> for u8 {
    /// Returns the wire type tag (from `types`) for the variant.
    fn from(value: &HandshakeMessage) -> Self {
        use self::HandshakeMessage as Msg;
        match *value {
            Msg::Noop => types::NOOP,
            Msg::Initiator(..) => types::HANDSHAKE_INITIATOR,
            Msg::Responder(..) => types::HANDSHAKE_RESPONDER,
        }
    }
}
#[derive(Default, Debug)]
pub enum TransparentMessage {
#[default]
#[doc(hidden)]
Noop,
Error(StatusCode, String),
ServerHandshake(HandshakeMessage),
PeerHandshake {
public_key: Vec<u8>,
message: HandshakeMessage,
},
}
impl From<&TransparentMessage> for u8 {
fn from(value: &TransparentMessage) -> Self {
match value {
TransparentMessage::Noop => types::NOOP,
TransparentMessage::Error(_, _) => types::ERROR,
TransparentMessage::ServerHandshake(_) => {
types::HANDSHAKE_SERVER
}
TransparentMessage::PeerHandshake { .. } => {
types::HANDSHAKE_PEER
}
}
}
}
#[derive(Default, Debug)]
pub enum ServerMessage {
#[default]
#[doc(hidden)]
Noop,
Error(StatusCode, String),
NewSession(SessionRequest),
SessionConnection {
session_id: SessionId,
peer_key: Vec<u8>,
},
SessionCreated(SessionState),
SessionReady(SessionState),
SessionActive(SessionState),
SessionTimeout(SessionId),
CloseSession(SessionId),
SessionFinished(SessionId),
}
impl From<&ServerMessage> for u8 {
fn from(value: &ServerMessage) -> Self {
match value {
ServerMessage::Noop => types::NOOP,
ServerMessage::Error(_, _) => types::ERROR,
ServerMessage::NewSession(_) => types::SESSION_NEW,
ServerMessage::SessionConnection { .. } => {
types::SESSION_CONNECTION
}
ServerMessage::SessionCreated(_) => {
types::SESSION_CREATED
}
ServerMessage::SessionReady(_) => types::SESSION_READY,
ServerMessage::SessionActive(_) => types::SESSION_ACTIVE,
ServerMessage::SessionTimeout(_) => {
types::SESSION_TIMEOUT
}
ServerMessage::CloseSession(_) => types::SESSION_CLOSE,
ServerMessage::SessionFinished(_) => {
types::SESSION_FINISHED
}
}
}
}
/// Messages whose payload is carried in an encrypted [`SealedEnvelope`].
#[derive(Default, Debug)]
pub enum OpaqueMessage {
    #[default]
    #[doc(hidden)]
    Noop,
    /// Envelope addressed to the server.
    ServerMessage(SealedEnvelope),
    /// Envelope relayed to/from the peer identified by `public_key`,
    /// optionally scoped to a session.
    PeerMessage {
        public_key: Vec<u8>,
        session_id: Option<SessionId>,
        envelope: SealedEnvelope,
    },
}

impl From<&OpaqueMessage> for u8 {
    /// Returns the wire type tag (from `types`) for the variant.
    fn from(value: &OpaqueMessage) -> Self {
        use self::OpaqueMessage as Msg;
        match *value {
            Msg::Noop => types::NOOP,
            Msg::ServerMessage(..) => types::OPAQUE_SERVER,
            Msg::PeerMessage { .. } => types::OPAQUE_PEER,
        }
    }
}
/// Top-level request: either transparent or opaque (encrypted).
#[derive(Default, Debug)]
pub enum RequestMessage {
    #[default]
    #[doc(hidden)]
    Noop,
    /// Unencrypted message.
    Transparent(TransparentMessage),
    /// Encrypted message.
    Opaque(OpaqueMessage),
}

impl From<&RequestMessage> for u8 {
    /// Returns the wire type tag (from `types`) for the variant.
    fn from(value: &RequestMessage) -> Self {
        use self::RequestMessage as Msg;
        match *value {
            Msg::Noop => types::NOOP,
            Msg::Transparent(..) => types::TRANSPARENT,
            Msg::Opaque(..) => types::OPAQUE,
        }
    }
}
/// Top-level response: either transparent or opaque (encrypted).
#[derive(Default, Debug)]
pub enum ResponseMessage {
    #[default]
    #[doc(hidden)]
    Noop,
    /// Unencrypted message.
    Transparent(TransparentMessage),
    /// Encrypted message.
    Opaque(OpaqueMessage),
}

impl From<&ResponseMessage> for u8 {
    /// Returns the wire type tag (from `types`) for the variant.
    fn from(value: &ResponseMessage) -> Self {
        use self::ResponseMessage as Msg;
        match *value {
            Msg::Noop => types::NOOP,
            Msg::Transparent(..) => types::TRANSPARENT,
            Msg::Opaque(..) => types::OPAQUE,
        }
    }
}
/// Encoding of the plaintext carried inside a [`SealedEnvelope`].
#[derive(Default, Clone, Copy, Debug)]
pub enum Encoding {
    #[default]
    #[doc(hidden)]
    Noop,
    /// Raw binary blob.
    Blob,
    /// JSON document.
    Json,
}

impl From<Encoding> for u8 {
    /// Returns the wire type tag (from `types`) for the encoding.
    fn from(value: Encoding) -> Self {
        use self::Encoding as Enc;
        match value {
            Enc::Noop => types::NOOP,
            Enc::Blob => types::ENCODING_BLOB,
            Enc::Json => types::ENCODING_JSON,
        }
    }
}
#[derive(Default, Debug)]
pub struct Chunk {
pub length: usize,
pub contents: Vec<u8>,
}
impl Chunk {
const CHUNK_SIZE: usize = 65535 - TAGLEN;
pub fn split(
payload: &[u8],
transport: &mut TransportState,
) -> Result<Vec<Chunk>> {
let mut chunks = Vec::new();
for chunk in payload.chunks(Self::CHUNK_SIZE) {
let mut contents = vec![0; chunk.len() + TAGLEN];
let length =
transport.write_message(chunk, &mut contents)?;
chunks.push(Chunk { length, contents });
}
Ok(chunks)
}
pub fn join(
chunks: Vec<Chunk>,
transport: &mut TransportState,
) -> Result<Vec<u8>> {
let mut payload = Vec::new();
for chunk in chunks {
let mut contents = vec![0; chunk.length];
transport.read_message(
&chunk.contents[..chunk.length],
&mut contents,
)?;
let new_length = contents.len() - TAGLEN;
contents.truncate(new_length);
payload.extend_from_slice(contents.as_slice());
}
Ok(payload)
}
}
/// Encrypted payload: the plaintext's encoding plus the encrypted chunks
/// that carry it.
#[derive(Default, Debug)]
pub struct SealedEnvelope {
    /// Encoding of the plaintext once the chunks are joined.
    pub encoding: Encoding,
    /// Encrypted fragments of the payload, in order.
    pub chunks: Vec<Chunk>,
    /// NOTE(review): presumably marks a message sent to all session
    /// participants rather than a single peer — confirm with the relay code.
    pub broadcast: bool,
}
/// Server-side record of one session: its owner, the other participants,
/// the peer connections reported so far and when it was last accessed.
pub struct Session {
    owner_key: Vec<u8>,
    participant_keys: HashSet<Vec<u8>>,
    /// Reported connections as `(peer, other)` key pairs; either ordering
    /// of a pair counts as a connection.
    connections: HashSet<(Vec<u8>, Vec<u8>)>,
    last_access: SystemTime,
}

impl Session {
    /// Public key of the session owner.
    pub fn owner_key(&self) -> &[u8] {
        self.owner_key.as_slice()
    }

    /// All public keys in the session: the owner first, then the
    /// participant keys in unspecified (hash-set) order.
    ///
    /// The owner is not de-duplicated; if it was also registered as a
    /// participant it appears twice.
    pub fn public_keys(&self) -> Vec<&[u8]> {
        std::iter::once(self.owner_key.as_slice())
            .chain(self.participant_keys.iter().map(Vec::as_slice))
            .collect()
    }

    /// Records a connection between `peer` and `other`.
    pub fn register_connection(
        &mut self,
        peer: Vec<u8>,
        other: Vec<u8>,
    ) {
        self.connections.insert((peer, other));
    }

    /// Returns `true` when every pair of distinct session keys has a
    /// registered connection in either direction. A session whose keys are
    /// all identical (e.g. owner only) is trivially active.
    pub fn is_active(&self) -> bool {
        let keys = self.public_keys();
        // A connection in either order counts, so each unordered pair
        // only needs to be checked once.
        keys.iter().enumerate().all(|(index, first)| {
            keys[index + 1..].iter().all(|second| {
                first == second || self.is_connected(first, second)
            })
        })
    }

    /// Whether a connection between `a` and `b` was registered in either
    /// order.
    fn is_connected(&self, a: &[u8], b: &[u8]) -> bool {
        let mut pair = (a.to_vec(), b.to_vec());
        if self.connections.contains(&pair) {
            return true;
        }
        // Reuse the owned keys for the reverse lookup.
        std::mem::swap(&mut pair.0, &mut pair.1);
        self.connections.contains(&pair)
    }
}
/// In-memory registry of sessions keyed by [`SessionId`].
#[derive(Default)]
pub struct SessionManager {
    sessions: HashMap<SessionId, Session>,
}

impl SessionManager {
    /// Creates a session owned by `owner_key` with the given participants
    /// (duplicates collapse, since they are stored in a set), stamps its
    /// last-access time with now and returns its new random v4 id.
    pub fn new_session(
        &mut self,
        owner_key: Vec<u8>,
        participant_keys: Vec<Vec<u8>>,
    ) -> SessionId {
        let session_id = SessionId::new_v4();
        let session = Session {
            owner_key,
            participant_keys: participant_keys.into_iter().collect(),
            connections: Default::default(),
            last_access: SystemTime::now(),
        };
        self.sessions.insert(session_id, session);
        session_id
    }

    /// Looks up a session by id.
    pub fn get_session(&self, id: &SessionId) -> Option<&Session> {
        self.sessions.get(id)
    }

    /// Looks up a session by id for mutation.
    pub fn get_session_mut(
        &mut self,
        id: &SessionId,
    ) -> Option<&mut Session> {
        self.sessions.get_mut(id)
    }

    /// Removes a session, returning it if it existed.
    pub fn remove_session(
        &mut self,
        id: &SessionId,
    ) -> Option<Session> {
        self.sessions.remove(id)
    }

    /// Refreshes a session's last-access time to now and returns it, or
    /// `None` if no session has that id.
    pub fn touch_session(
        &mut self,
        id: &SessionId,
    ) -> Option<&Session> {
        let session = self.sessions.get_mut(id)?;
        session.last_access = SystemTime::now();
        Some(&*session)
    }

    /// Ids of sessions not accessed within the last `timeout` seconds.
    ///
    /// A session whose `last_access + timeout` is not representable as a
    /// `SystemTime` never expires.
    pub fn expired_keys(&self, timeout: u64) -> Vec<SessionId> {
        // Judge every session against the same instant and TTL.
        let now = SystemTime::now();
        // `from_secs` avoids the u64 overflow of `timeout * 1000` millis.
        let ttl = Duration::from_secs(timeout);
        self.sessions
            .iter()
            .filter(|(_, session)| {
                matches!(
                    session.last_access.checked_add(ttl),
                    Some(deadline) if deadline < now
                )
            })
            .map(|(id, _)| *id)
            .collect()
    }
}
/// Payload of a request to create a new session.
#[derive(Default, Debug)]
pub struct SessionRequest {
    /// Public keys of the participants to include in the session.
    pub participant_keys: Vec<Vec<u8>>,
}
/// Snapshot of a session: its id and the ordered participant keys.
#[derive(Default, Debug, Clone)]
pub struct SessionState {
    /// Identifier of the session.
    pub session_id: SessionId,
    /// Participant public keys; a key's 1-based position in this list is
    /// its party number.
    pub all_participants: Vec<Vec<u8>>,
}

impl SessionState {
    /// Number of participants.
    pub fn len(&self) -> usize {
        self.all_participants.len()
    }

    /// Whether there are no participants.
    pub fn is_empty(&self) -> bool {
        self.all_participants.is_empty()
    }

    /// Party number of `public_key`: its 1-based position in
    /// `all_participants` (first occurrence).
    ///
    /// Returns `None` if the key is not a participant or its position does
    /// not fit in a `u16`.
    pub fn party_number(
        &self,
        public_key: impl AsRef<[u8]>,
    ) -> Option<PartyNumber> {
        let public_key = public_key.as_ref();
        let position = self
            .all_participants
            .iter()
            .position(|k| k.as_slice() == public_key)?;
        let number = position + 1;
        // A bare `as u16` would silently wrap and alias another party (or
        // produce 0), so reject positions that do not fit.
        if number > usize::from(u16::MAX) {
            return None;
        }
        // `number` is in 1..=u16::MAX here; assumes `PartyNumber::new` only
        // rejects 0 — confirm.
        Some(PartyNumber::new(number as u16).unwrap())
    }

    /// Public key of the participant with the given 1-based
    /// `party_number`, if there is one.
    pub fn peer_key(
        &self,
        party_number: PartyNumber,
    ) -> Option<&[u8]> {
        let index = (party_number.get() as usize).checked_sub(1)?;
        self.all_participants.get(index).map(Vec::as_slice)
    }

    /// Participants listed after `own_key` in `all_participants`.
    ///
    /// Empty when `own_key` is absent or last. Presumably the peers this
    /// party initiates connections to, so each pair connects once —
    /// confirm with callers.
    pub fn connections(&self, own_key: &[u8]) -> &[Vec<u8>] {
        match self
            .all_participants
            .iter()
            .position(|k| k.as_slice() == own_key)
        {
            // `position < len`, so `position + 1` is a valid range start
            // (yielding an empty slice for the last participant).
            Some(position) => &self.all_participants[position + 1..],
            None => &[],
        }
    }

    /// Clones of every participant key not equal to `own_key`.
    pub fn recipients(&self, own_key: &[u8]) -> Vec<Vec<u8>> {
        self.all_participants
            .iter()
            .filter(|k| k.as_slice() != own_key)
            .cloned()
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::Chunk;
    use crate::PATTERN;
    use anyhow::Result;

    /// Round-trips a payload that needs two Noise messages through
    /// `Chunk::split` / `Chunk::join` over a freshly completed handshake.
    #[test]
    fn chunks_split_join() -> Result<()> {
        let initiator_builder = snow::Builder::new(PATTERN.parse()?);
        let responder_builder = snow::Builder::new(PATTERN.parse()?);
        let initiator_keys = initiator_builder.generate_keypair()?;
        let responder_keys = responder_builder.generate_keypair()?;

        let mut initiator = initiator_builder
            .local_private_key(&initiator_keys.private)
            .remote_public_key(&responder_keys.public)
            .build_initiator()?;
        let mut responder = responder_builder
            .local_private_key(&responder_keys.private)
            .remote_public_key(&initiator_keys.public)
            .build_responder()?;

        // Two-message handshake: initiator -> responder, then back.
        let mut scratch = [0u8; 1024];
        let mut outbound = [0u8; 1024];
        let mut inbound = [0u8; 1024];
        let written = initiator.write_message(&[], &mut outbound)?;
        responder.read_message(&outbound[..written], &mut scratch)?;
        let written = responder.write_message(&[], &mut inbound)?;
        initiator.read_message(&inbound[..written], &mut scratch)?;

        let mut sender = initiator.into_transport_mode()?;
        let mut receiver = responder.into_transport_mode()?;

        // Larger than one chunk but smaller than two.
        let payload = vec![0; 76893];
        let chunks = Chunk::split(&payload, &mut sender)?;
        assert_eq!(2, chunks.len());
        let round_trip = Chunk::join(chunks, &mut receiver)?;
        assert_eq!(payload, round_trip);
        Ok(())
    }
}