use async_trait::async_trait;
use json_rpc2::{futures::*, Error, Request, Response, Result, RpcError};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::{Mutex, RwLock};
use uuid::Uuid;
use super::server::{
Group, Notification, Parameters, Session, SessionKind, State,
};
#[derive(Debug, Error)]
pub enum ServiceError {
#[error("parties must be greater than one")]
PartiesTooSmall,
#[error("threshold must be greater than zero")]
ThresholdTooSmall,
#[error("threshold must be less than parties")]
ThresholdRange,
#[error("group {0} is full, cannot accept new connections")]
GroupFull(Uuid),
#[error("group {0} does not exist")]
GroupDoesNotExist(Uuid),
#[error("group {0} does not exist")]
SessionDoesNotExist(Uuid),
#[error("party {0} does not exist")]
PartyDoesNotExist(u16),
#[error("party {0} is not valid in this context")]
BadParty(u16),
#[error("receiver {0} for peer to peer message does not exist")]
BadPeerReceiver(u16),
#[error("client {0} does not belong to the group {1}")]
BadConnection(usize, Uuid),
}
/// String attached to the RPC error that `Group.join` returns when the
/// group is full. How it is acted upon is decided outside this file.
pub const CLOSE_CONNECTION: &str = "close-connection";
// RPC method names matched against `req.method()` by the handler.
/// Method name: create a new group.
pub const GROUP_CREATE: &str = "Group.create";
/// Method name: join an existing group.
pub const GROUP_JOIN: &str = "Group.join";
/// Method name: create a session inside a group.
pub const SESSION_CREATE: &str = "Session.create";
/// Method name: look up an existing session in a group.
pub const SESSION_JOIN: &str = "Session.join";
/// Method name: sign up to a session and receive a party number.
pub const SESSION_SIGNUP: &str = "Session.signup";
/// Method name: register with a session using a caller supplied party number.
pub const SESSION_LOAD: &str = "Session.load";
/// Method name: send a message to one party or to the whole session.
pub const SESSION_MESSAGE: &str = "Session.message";
/// Method name: mark a party as finished with a session.
pub const SESSION_FINISH: &str = "Session.finish";
/// Method name: announce a proposal to a group.
pub const NOTIFY_PROPOSAL: &str = "Notify.proposal";
/// Method name: announce a signed result to a group.
pub const NOTIFY_SIGNED: &str = "Notify.signed";
// Event names; each is serialized as the first element of the tuple
// payload carried by a notification.
/// Event name used when a key generation session has been created.
pub const SESSION_CREATE_EVENT: &str = "sessionCreate";
/// Event name used when sign ups reach the required count.
pub const SESSION_SIGNUP_EVENT: &str = "sessionSignup";
/// Event name used when loaded parties reach the required count.
pub const SESSION_LOAD_EVENT: &str = "sessionLoad";
/// Event name used when a session message is forwarded.
pub const SESSION_MESSAGE_EVENT: &str = "sessionMessage";
/// Event name used when every signed up party has finished.
pub const SESSION_CLOSED_EVENT: &str = "sessionClosed";
/// Event name used when a proposal is announced.
pub const NOTIFY_PROPOSAL_EVENT: &str = "notifyProposal";
/// Event name used when a signed result is announced.
pub const NOTIFY_SIGNED_EVENT: &str = "notifySigned";
// Positional parameter tuples deserialized from each request; element
// meanings follow the way the handler destructures them.
/// `Group.create`: (label, parameters).
type GroupCreateParams = (String, Parameters);
/// `Session.create`: (group id, session kind, optional value passed to the new session).
type SessionCreateParams = (Uuid, SessionKind, Option<Value>);
/// `Session.join`: (group id, session id, session kind); the kind is not used by the handler.
type SessionJoinParams = (Uuid, Uuid, SessionKind);
/// `Session.signup`: (group id, session id, session kind).
type SessionSignupParams = (Uuid, Uuid, SessionKind);
/// `Session.load`: (group id, session id, session kind, party number).
type SessionLoadParams = (Uuid, Uuid, SessionKind, u16);
/// `Session.message`: (group id, session id, session kind, message); the kind is not used by the handler.
type SessionMessageParams = (Uuid, Uuid, SessionKind, Message);
/// `Session.finish`: (group id, session id, party number).
type SessionFinishParams = (Uuid, Uuid, u16);
/// `Notify.proposal`: (group id, session id, proposal id, message).
type NotifyProposalParams = (Uuid, Uuid, String, String);
/// `Notify.signed`: (group id, session id, value to forward).
type NotifySignedParams = (Uuid, Uuid, Value);
/// Message carried by a `Session.message` request.
///
/// The handler deserializes it from the request and serializes it again
/// into the `sessionMessage` event payload; only `receiver` is inspected
/// in this file.
#[derive(Serialize, Deserialize)]
struct Message {
/// Not inspected here; presumably the protocol round number — confirm with the client.
round: u16,
/// Not inspected here; presumably the party number of the sender — confirm with the client.
sender: u16,
/// Party number of the single recipient. When `None` the message is
/// sent as a session wide notification instead of a relay.
receiver: Option<u16>,
/// Not inspected here; opaque identifier supplied by the client.
uuid: String,
/// Not inspected here; opaque message content.
body: serde_json::Value,
}
/// Payload of the `notifyProposal` event sent to a group.
///
/// Field names are serialized in camel case, so the wire keys are
/// `sessionId`, `proposalId` and `message`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct Proposal {
    /// Session the proposal refers to.
    session_id: Uuid,
    /// Identifier of the proposal as supplied by the client.
    proposal_id: String,
    /// Message text supplied by the client.
    message: String,
}
/// Service implementation for the group, session and notify RPC methods.
///
/// It is a unit struct: everything it works on arrives through the
/// per-request context passed to `handle`.
pub struct ServiceHandler;
// Dispatches the group, session and notify RPC methods.
#[async_trait]
impl Service for ServiceHandler {
/// Per-request context: the id of the calling connection, the shared
/// server state, and a slot in which the handler leaves at most one
/// notification. The slot is only written here; whoever drains it lives
/// outside this file.
type Data = (usize, Arc<RwLock<State>>, Arc<Mutex<Option<Notification>>>);
/// Handles a single request by matching on its method name.
///
/// Returns `Ok(None)` when the method is not one of the names handled
/// here, `Ok(Some(response))` on success, and `Err` when parameters fail
/// to deserialize or a validation check fails. Arms that need to inform
/// other clients overwrite the notification slot with a new value.
async fn handle(
&self,
req: &Request,
ctx: &Self::Data,
) -> Result<Option<Response>> {
let response = match req.method() {
// Validate the parameters, register a new group and reply with its id.
GROUP_CREATE => {
let (conn_id, state, _) = ctx;
let params: GroupCreateParams = req.deserialize()?;
let (label, parameters) = params;
// Checks run in order: parties > 1, threshold > 0, threshold < parties.
if parameters.parties <= 1 {
return Err(Error::from(Box::from(
ServiceError::PartiesTooSmall,
)));
} else if parameters.threshold == 0 {
return Err(Error::from(Box::from(
ServiceError::ThresholdTooSmall,
)));
} else if parameters.threshold >= parameters.parties {
return Err(Error::from(Box::from(
ServiceError::ThresholdRange,
)));
}
// What `Group::new` does with the connection id is defined in the
// server module, not here.
let group =
Group::new(*conn_id, parameters.clone(), label.clone());
let res = serde_json::to_value(&group.uuid).unwrap();
let mut writer = state.write().await;
writer.groups.insert(group.uuid.clone(), group);
Some((req, res).into())
}
// Add the caller to a group and reply with the serialized group.
GROUP_JOIN => {
let (conn_id, state, _) = ctx;
let group_id: Uuid = req.deserialize()?;
let mut writer = state.write().await;
if let Some(group) = writer.groups.get_mut(&group_id) {
// A full group is reported as an RPC error response (not `Err`)
// carrying the close-connection marker.
// NOTE(review): this check runs before the membership check below,
// so a connection that is already a member of a full group is also
// rejected — confirm that is intended.
if group.clients.len() == group.params.parties as usize {
let error = ServiceError::GroupFull(group_id);
let err = RpcError::new(
error.to_string(),
Some(CLOSE_CONNECTION.to_string()),
);
Some((req, err).into())
} else {
// Joining twice does not add the connection a second time.
if let None =
group.clients.iter().find(|c| *c == conn_id)
{
group.clients.push(*conn_id);
}
let res = serde_json::to_value(group).unwrap();
Some((req, res).into())
}
} else {
return Err(Error::from(Box::from(
ServiceError::GroupDoesNotExist(group_id),
)));
}
}
// Create a session in a group the caller belongs to and reply with it.
SESSION_CREATE => {
let (conn_id, state, notification) = ctx;
let params: SessionCreateParams = req.deserialize()?;
let (group_id, kind, value) = params;
let mut writer = state.write().await;
let group =
get_group_mut(&conn_id, &group_id, &mut writer.groups)?;
let session = Session::from((kind.clone(), value));
let key = session.uuid.clone();
group.sessions.insert(key, session.clone());
// Only key generation sessions are announced to the group.
if let SessionKind::Keygen = kind {
let value =
serde_json::to_value((SESSION_CREATE_EVENT, &session))
.unwrap();
let response: Response = value.into();
// This `ctx` shadows the handler argument. `filter` names the caller;
// presumably so the sender is skipped — confirm in the server module.
let ctx = Notification::Group {
group_id,
filter: Some(vec![*conn_id]),
response,
};
// Shadows the state guard, which stays held until the arm ends.
let mut writer = notification.lock().await;
*writer = Some(ctx);
}
let res = serde_json::to_value(&session).unwrap();
Some((req, res).into())
}
// Reply with an existing session; nothing is modified or announced.
SESSION_JOIN => {
let (conn_id, state, _) = ctx;
let params: SessionJoinParams = req.deserialize()?;
let (group_id, session_id, _kind) = params;
let mut writer = state.write().await;
let group =
get_group_mut(&conn_id, &group_id, &mut writer.groups)?;
if let Some(session) = group.sessions.get_mut(&session_id) {
let res = serde_json::to_value(&session).unwrap();
Some((req, res).into())
} else {
return Err(Error::from(Box::from(
ServiceError::SessionDoesNotExist(session_id),
)));
}
}
// Sign the caller up to a session and reply with the party number
// returned by `Session::signup`.
SESSION_SIGNUP => {
let (conn_id, state, notification) = ctx;
let params: SessionSignupParams = req.deserialize()?;
let (group_id, session_id, kind) = params;
let mut writer = state.write().await;
let group =
get_group_mut(&conn_id, &group_id, &mut writer.groups)?;
if let Some(session) = group.sessions.get_mut(&session_id) {
let party_number = session.signup(*conn_id);
tracing::info!(party_number, "session signup {}", conn_id);
// Announce to the session once the sign up count equals the number
// required for this kind of session.
if threshold(
&kind,
&group.params,
session.party_signups.len(),
) {
let value = serde_json::to_value((
SESSION_SIGNUP_EVENT,
&session_id,
))
.unwrap();
let response: Response = value.into();
let ctx = Notification::Session {
group_id,
session_id,
filter: None,
response,
};
let mut writer = notification.lock().await;
*writer = Some(ctx);
}
let res = serde_json::to_value(&party_number).unwrap();
Some((req, res).into())
} else {
return Err(Error::from(Box::from(
ServiceError::SessionDoesNotExist(session_id),
)));
}
}
// Register the caller under a party number it supplies and reply with
// that same number.
SESSION_LOAD => {
let (conn_id, state, notification) = ctx;
let params: SessionLoadParams = req.deserialize()?;
let (group_id, session_id, kind, party_number) = params;
let mut writer = state.write().await;
let group =
get_group_mut(&conn_id, &group_id, &mut writer.groups)?;
if let Some(session) = group.sessions.get_mut(&session_id) {
let res = serde_json::to_value(&party_number).unwrap();
// Validation of the party number is done by `Session::load`, which is
// defined in the server module.
match session.load(&group.params, *conn_id, party_number) {
Ok(_) => {
if threshold(
&kind,
&group.params,
session.party_signups.len(),
) {
let value = serde_json::to_value((
SESSION_LOAD_EVENT,
&session_id,
))
.unwrap();
let response: Response = value.into();
let ctx = Notification::Session {
group_id,
session_id,
filter: None,
response,
};
let mut writer = notification.lock().await;
*writer = Some(ctx);
}
Some((req, res).into())
}
Err(err) => return Err(Error::from(Box::from(err))),
}
} else {
return Err(Error::from(Box::from(
ServiceError::SessionDoesNotExist(session_id),
)));
}
}
// Mark a party as finished; announce closure once every signed up
// party has finished.
SESSION_FINISH => {
let (conn_id, state, notification) = ctx;
let params: SessionFinishParams = req.deserialize()?;
let (group_id, session_id, party_number) = params;
let mut writer = state.write().await;
let group =
get_group_mut(&conn_id, &group_id, &mut writer.groups)?;
if let Some(session) = group.sessions.get_mut(&session_id) {
let existing_signup = session
.party_signups
.iter()
.find(|(s, _)| s == &party_number);
if let Some((_, conn)) = existing_signup {
// The party number must be registered to the calling connection.
if conn != conn_id {
return Err(Error::from(Box::from(
ServiceError::BadParty(party_number),
)));
}
session.finished.insert(party_number);
let mut signups = session
.party_signups
.iter()
.map(|(n, _)| n.clone())
.collect::<Vec<u16>>();
let mut completed = session
.finished
.iter()
.cloned()
.collect::<Vec<u16>>();
// Sort both lists so the comparison does not depend on order.
signups.sort();
completed.sort();
if signups == completed {
let value = serde_json::to_value((
SESSION_CLOSED_EVENT,
completed,
))
.unwrap();
let response: Response = value.into();
let ctx = Notification::Session {
group_id,
session_id,
filter: None,
response,
};
let mut writer = notification.lock().await;
*writer = Some(ctx);
}
// The reply is built from the request alone; no result value is
// supplied here.
Some(req.into())
} else {
return Err(Error::from(Box::from(
ServiceError::PartyDoesNotExist(party_number),
)));
}
} else {
return Err(Error::from(Box::from(
ServiceError::SessionDoesNotExist(session_id),
)));
}
}
// Forward a message to one party or to the whole session. Only a read
// lock on the state is taken.
SESSION_MESSAGE => {
let (conn_id, state, notification) = ctx;
let params: SessionMessageParams = req.deserialize()?;
let (group_id, session_id, _kind, msg) = params;
let reader = state.read().await;
let (_group, session) = get_group_session(
&conn_id,
&group_id,
&session_id,
&reader.groups,
)?;
if let Some(receiver) = &msg.receiver {
// Peer to peer: find the connection signed up under the receiver's
// party number and relay to it alone.
if let Some(s) =
session.party_signups.iter().find(|s| s.0 == *receiver)
{
let value =
serde_json::to_value((SESSION_MESSAGE_EVENT, msg))
.unwrap();
let response: Response = value.into();
let message = (s.1, response);
let ctx = Notification::Relay {
messages: vec![message],
};
let mut writer = notification.lock().await;
*writer = Some(ctx);
} else {
return Err(Error::from(Box::from(
ServiceError::BadPeerReceiver(*receiver),
)));
}
} else {
// No receiver: notify the session. `filter` names the caller;
// presumably so the sender is skipped — confirm in the server module.
let value =
serde_json::to_value((SESSION_MESSAGE_EVENT, msg))
.unwrap();
let response: Response = value.clone().into();
let ctx = Notification::Session {
group_id,
session_id,
filter: Some(vec![*conn_id]),
response,
};
let mut writer = notification.lock().await;
*writer = Some(ctx);
}
Some(req.into())
}
// Announce a proposal to a group.
// NOTE(review): the state is not consulted, so neither the existence
// of the group or session nor the caller's membership is checked in
// this arm — confirm that is intended.
NOTIFY_PROPOSAL => {
let (conn_id, _state, notification) = ctx;
let params: NotifyProposalParams = req.deserialize()?;
let (group_id, session_id, proposal_id, message) = params;
let proposal = Proposal {
session_id,
proposal_id,
message,
};
let value =
serde_json::to_value((NOTIFY_PROPOSAL_EVENT, &proposal))
.unwrap();
let response: Response = value.into();
let ctx = Notification::Group {
group_id,
filter: Some(vec![*conn_id]),
response,
};
let mut writer = notification.lock().await;
*writer = Some(ctx);
Some(req.into())
}
// Announce a signed value to a group the caller belongs to.
NOTIFY_SIGNED => {
let (conn_id, state, notification) = ctx;
let params: NotifySignedParams = req.deserialize()?;
let (group_id, session_id, value) = params;
let reader = state.read().await;
let (_group, session) = get_group_session(
&conn_id,
&group_id,
&session_id,
&reader.groups,
)?;
// Connection ids of every party signed up to the session; passed as
// the notification filter. Presumably these connections are skipped —
// confirm in the server module.
let participants = session
.party_signups
.iter()
.map(|(_, c)| c.clone())
.collect::<Vec<usize>>();
let value =
serde_json::to_value((NOTIFY_SIGNED_EVENT, value)).unwrap();
let response: Response = value.into();
let ctx = Notification::Group {
group_id,
filter: Some(participants),
response,
};
let mut writer = notification.lock().await;
*writer = Some(ctx);
Some(req.into())
}
// Any other method name is left unanswered by this service.
_ => None,
};
Ok(response)
}
}
fn get_group_mut<'a>(
conn_id: &usize,
group_id: &Uuid,
groups: &'a mut HashMap<Uuid, Group>,
) -> Result<&'a mut Group> {
if let Some(group) = groups.get_mut(group_id) {
if let Some(_) = group.clients.iter().find(|c| *c == conn_id) {
Ok(group)
} else {
return Err(Error::from(Box::from(ServiceError::BadConnection(
*conn_id,
group_id.clone(),
))));
}
} else {
return Err(Error::from(Box::from(ServiceError::GroupDoesNotExist(
group_id.clone(),
))));
}
}
fn get_group<'a>(
conn_id: &usize,
group_id: &Uuid,
groups: &'a HashMap<Uuid, Group>,
) -> Result<&'a Group> {
if let Some(group) = groups.get(group_id) {
if let Some(_) = group.clients.iter().find(|c| *c == conn_id) {
Ok(group)
} else {
return Err(Error::from(Box::from(ServiceError::BadConnection(
*conn_id,
group_id.clone(),
))));
}
} else {
return Err(Error::from(Box::from(ServiceError::GroupDoesNotExist(
group_id.clone(),
))));
}
}
fn get_group_session<'a>(
conn_id: &usize,
group_id: &Uuid,
session_id: &Uuid,
groups: &'a HashMap<Uuid, Group>,
) -> Result<(&'a Group, &'a Session)> {
let group = get_group(conn_id, group_id, groups)?;
if let Some(session) = group.sessions.get(session_id) {
Ok((group, session))
} else {
return Err(Error::from(Box::from(ServiceError::SessionDoesNotExist(
session_id.clone(),
))));
}
}
/// Reports whether `num_entries` is exactly the number of entries a
/// session of the given kind requires.
///
/// Key generation requires one entry per party; signing requires
/// `threshold + 1` entries. The comparison is for equality, so counts
/// above the requirement yield `false` as well.
fn threshold(
    kind: &SessionKind,
    params: &Parameters,
    num_entries: usize,
) -> bool {
    let required = match kind {
        SessionKind::Keygen => params.parties as usize,
        SessionKind::Sign => params.threshold as usize + 1,
    };
    required == num_entries
}