use anyhow::{anyhow, ensure};
use entropy_protocol::{
execute_protocol::{execute_dkg, execute_reshare, execute_signing_protocol, Channels},
protocol_transport::{
errors::WsError,
noise::{noise_handshake_initiator, noise_handshake_responder},
ws_to_channels, SubscribeMessage, WsChannels,
},
KeyParams, KeyShareWithAuxInfo, Listener, PartyId, RecoverableSignature, SessionId,
ValidatorInfo,
};
use entropy_shared::X25519PublicKey;
use futures::future;
use sp_core::{sr25519, Pair};
use std::{
collections::BTreeSet,
fmt,
sync::{Arc, Mutex},
time::Duration,
};
use subxt::utils::AccountId32;
use synedrion::{AuxInfo, KeyResharingInputs, KeyShare, NewHolder, OldHolder, ThresholdKeyShare};
use tokio::{
net::{TcpListener, TcpStream},
time::timeout,
};
use tokio_tungstenite::connect_async;
use x25519_dalek::StaticSecret;
/// State shared by the inbound accept loop in `server` and by `open_protocol_connections`.
#[derive(Clone)]
struct ServerState {
    /// The session's pending listener, shared by inbound and outbound connection handling.
    /// `server` puts a single listener here; `get_ws_channels` pops it once `subscribe`
    /// reports a final subscription.
    listener: Arc<Mutex<Vec<Listener>>>,
    /// This party's X25519 secret, used as the Noise responder key in `handle_connection`.
    x25519_secret_key: StaticSecret,
}
/// What one party's run of `server` produced, tagged by the kind of session executed.
#[derive(Clone)]
pub enum ProtocolOutput {
    /// Signature and recovery id produced by a signing session.
    Sign(RecoverableSignature),
    /// Threshold key share returned (as the first tuple element) by `execute_reshare`.
    Reshare(ThresholdKeyShare<KeyParams, PartyId>),
    /// Key share with aux info returned by `execute_dkg`.
    Dkg(KeyShareWithAuxInfo),
}
impl fmt::Debug for ProtocolOutput {
    /// Emits the fixed placeholder `Success` for every variant; the payload is never formatted.
    /// NOTE(review): presumably this keeps key material out of test logs — confirm intent.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Success")
    }
}
pub async fn server(
socket: TcpListener,
validators_info: Vec<ValidatorInfo>,
pair: sr25519::Pair,
x25519_secret_key: StaticSecret,
session_id: SessionId,
keyshare: Option<KeyShare<KeyParams, PartyId>>,
threshold_keyshare: Option<ThresholdKeyShare<KeyParams, PartyId>>,
aux_info: Option<AuxInfo<KeyParams, PartyId>>,
threshold: usize,
) -> anyhow::Result<ProtocolOutput> {
let account_id = AccountId32(pair.public().0);
let (rx_ready, rx_from_others, listener) = Listener::new(validators_info.clone(), &account_id);
let state = ServerState {
listener: Arc::new(Mutex::new(vec![listener])),
x25519_secret_key: x25519_secret_key.clone(),
};
let state_clone = state.clone();
tokio::spawn(async move {
while let Ok((stream, _address)) = socket.accept().await {
let state_clone2 = state_clone.clone();
tokio::spawn(async move {
if let Err(e) = handle_connection(state_clone2, stream).await {
tracing::warn!("Error when handling ws connection {}", e);
};
});
}
});
open_protocol_connections(&validators_info, &session_id, &pair, &x25519_secret_key, &state)
.await?;
let channels = {
let ready = timeout(Duration::from_secs(10), rx_ready).await?;
let broadcast_out = ready??;
Channels(broadcast_out, rx_from_others)
};
let tss_accounts: Vec<AccountId32> =
validators_info.iter().map(|validator_info| validator_info.tss_account.clone()).collect();
match session_id.clone() {
SessionId::Sign(session_info) => {
let rsig = execute_signing_protocol(
session_id,
channels,
&keyshare.unwrap(),
&aux_info.unwrap(),
&session_info.message_hash,
&pair,
tss_accounts,
)
.await?;
let (signature, recovery_id) = rsig.to_backend();
Ok(ProtocolOutput::Sign(RecoverableSignature { signature, recovery_id }))
},
SessionId::Reshare { .. } => {
let old_key = threshold_keyshare.unwrap();
let party_ids: BTreeSet<PartyId> =
tss_accounts.iter().cloned().map(PartyId::new).collect();
let inputs = KeyResharingInputs {
old_holder: Some(OldHolder { key_share: old_key.clone() }),
new_holder: Some(NewHolder {
verifying_key: old_key.verifying_key(),
old_threshold: party_ids.len(),
old_holders: party_ids.clone(),
}),
new_holders: party_ids.clone(),
new_threshold: old_key.threshold(),
};
let new_keyshare =
execute_reshare(session_id, channels, &pair, inputs, &party_ids, None).await?;
Ok(ProtocolOutput::Reshare(new_keyshare.0))
},
SessionId::Dkg { .. } => {
let keyshare_and_aux_info =
execute_dkg(session_id, channels, &pair, tss_accounts, threshold).await?;
Ok(ProtocolOutput::Dkg(keyshare_and_aux_info))
},
}
}
/// Serves one inbound peer connection accepted by the loop in `server`.
///
/// Upgrades the TCP stream to a websocket and runs the Noise handshake as responder, which also
/// yields the peer's serialized subscribe message. It then validates that message and replies
/// with a bincode-encoded `Result<(), String>`. On success, the encrypted connection is bridged
/// to the protocol channels until `ws_to_channels` returns.
///
/// # Errors
/// Propagates any websocket, handshake, serialization or transport error. If the subscription
/// was rejected, returns `WsError::BadSubscribeMessage` after the rejection reason has been sent
/// to the peer.
async fn handle_connection(state: ServerState, raw_stream: TcpStream) -> anyhow::Result<()> {
    let ws_stream = tokio_tungstenite::accept_async(raw_stream).await?;
    let (mut encrypted_connection, serialized_signed_message) =
        noise_handshake_responder(ws_stream, &state.x25519_secret_key).await?;
    let remote_public_key = encrypted_connection.remote_public_key()?;

    let outcome =
        handle_initial_incoming_ws_message(serialized_signed_message, remote_public_key, state)
            .await;

    // Tell the peer how its subscription went before acting on it, so a rejected peer
    // receives the reason.
    let reply: Result<(), String> = match &outcome {
        Ok(_) => Ok(()),
        Err(err) => Err(format!("{err:?}")),
    };
    encrypted_connection.send(bincode::serialize(&reply)?).await?;

    let (ws_channels, remote_party_id) = outcome.ok().ok_or(WsError::BadSubscribeMessage)?;
    ws_to_channels(encrypted_connection, ws_channels, remote_party_id).await?;
    Ok(())
}
async fn handle_initial_incoming_ws_message(
serialized_subscribe_message: Vec<u8>,
_remote_public_key: X25519PublicKey,
state: ServerState,
) -> anyhow::Result<(WsChannels, PartyId)> {
let msg: SubscribeMessage = bincode::deserialize(&serialized_subscribe_message)?;
tracing::info!("Got ws connection, with subscribe message: {msg:?}");
ensure!(msg.verify()?, "Invalid signature");
let ws_channels = get_ws_channels(&msg.session_id, &msg.account_id(), &state)?;
Ok((ws_channels, PartyId::new(msg.account_id())))
}
fn get_ws_channels(
_session_id: &SessionId,
tss_account: &AccountId32,
state: &ServerState,
) -> anyhow::Result<WsChannels> {
let mut listeners = state.listener.lock().unwrap();
let listener = listeners.get_mut(0).ok_or(anyhow::anyhow!("No listener"))?;
let ws_channels = listener.subscribe(tss_account)?;
if ws_channels.is_final {
let listener = listeners.pop().ok_or(anyhow::anyhow!("No listener"))?;
let (tx, broadcaster) = listener.into_broadcaster();
let _ = tx.send(Ok(broadcaster));
};
Ok(ws_channels)
}
async fn open_protocol_connections(
validators_info: &[ValidatorInfo],
session_id: &SessionId,
signer: &sr25519::Pair,
x25519_secret_key: &x25519_dalek::StaticSecret,
state: &ServerState,
) -> anyhow::Result<()> {
let connect_to_validators = validators_info
.iter()
.filter(|validator_info| {
signer.public().0 > validator_info.tss_account.0
})
.map(|validator_info| async move {
let ws_endpoint = format!("ws://{}/v1/ws", validator_info.ip_address);
let (ws_stream, _response) = connect_async(ws_endpoint).await?;
let subscribe_message_vec =
bincode::serialize(&SubscribeMessage::new(session_id.clone(), signer)?)?;
let mut encrypted_connection = noise_handshake_initiator(
ws_stream,
x25519_secret_key,
validator_info.x25519_public_key,
subscribe_message_vec,
)
.await?;
let response_message = encrypted_connection.recv().await?;
let subscribe_response: Result<(), String> = bincode::deserialize(&response_message)?;
if let Err(error_message) = subscribe_response {
return Err(anyhow!(error_message));
}
let ws_channels = get_ws_channels(session_id, &validator_info.tss_account, state)?;
let remote_party_id = PartyId::new(validator_info.tss_account.clone());
tokio::spawn(async move {
if let Err(err) =
ws_to_channels(encrypted_connection, ws_channels, remote_party_id).await
{
tracing::warn!("{:?}", err);
};
});
Ok(())
})
.collect::<Vec<_>>();
future::try_join_all(connect_to_validators).await?;
Ok(())
}