mpc_client/protocols/
bridge.rsuse std::num::NonZeroU16;
use crate::{
protocols::Driver, EventStream, NetworkTransport, Result,
Transport,
};
use futures::StreamExt;
use mpc_protocol::{Event, SessionId, SessionState};
use mpc_driver::{Error, ProtocolDriver, Round};
use super::public_key_to_str;
pub(crate) struct Bridge<D: ProtocolDriver> {
pub(crate) transport: Transport,
pub(crate) driver: Option<D>,
pub(crate) session: SessionState,
pub(crate) party_number: NonZeroU16,
}
impl<D: ProtocolDriver> Bridge<D> {
pub async fn handle_event(
&mut self,
event: Event,
) -> Result<Option<D::Output>> {
if let Event::JsonMessage {
message,
session_id,
..
} = event
{
if let Some(session_id) = &session_id {
if session_id != &self.session.session_id {
return Err(Error::SessionIdMismatch.into());
}
} else {
return Err(Error::SessionIdRequired.into());
}
let message: D::Message = message.deserialize()?;
let driver = self.driver.as_mut().unwrap();
let round_info =
driver.round_info().map_err(Box::from)?;
if !round_info.can_finalize {
driver.handle_incoming(message).map_err(Box::from)?;
let round_info =
driver.round_info().map_err(Box::from)?;
if round_info.can_finalize {
if let Some(result) = driver
.try_finalize_round()
.map_err(Box::from)?
{
return Ok(Some(result));
}
let messages =
driver.proceed().map_err(Box::from)?;
self.dispatch_round_messages(messages).await?;
}
}
}
Ok(None)
}
pub async fn execute(&mut self) -> Result<()> {
let driver = self.driver.as_mut().unwrap();
let messages = driver.proceed().map_err(Box::from)?;
self.dispatch_round_messages(messages).await?;
Ok(())
}
async fn dispatch_round_messages(
&mut self,
messages: Vec<D::Message>,
) -> Result<()> {
for message in messages {
let party_number = message.receiver();
let owner_key =
self.session.peer_key(self.party_number).unwrap();
let peer_key =
self.session.peer_key(*party_number).unwrap();
tracing::info!(
to = public_key_to_str(peer_key),
from = public_key_to_str(owner_key),
"dispatch_message"
);
self.transport
.send_json(
peer_key,
&message,
Some(self.session.session_id),
)
.await?;
}
Ok(())
}
}
pub async fn wait_for_driver<D>(
stream: &mut EventStream,
mut driver: D,
) -> Result<(Transport, D::Output)>
where
D: Driver + Into<Transport>,
{
driver.execute().await?;
#[allow(unused_assignments)]
let mut output: Option<D::Output> = None;
while let Some(event) = stream.next().await {
let event = event?;
if let Some(result) = driver.handle_event(event).await? {
output = Some(result);
break;
}
}
Ok((driver.into(), output.take().unwrap()))
}
pub async fn wait_for_close(
stream: &mut EventStream,
) -> crate::Result<()> {
while let Some(event) = stream.next().await {
let event = event?;
if let Event::Close = event {
break;
}
}
Ok(())
}
pub async fn wait_for_session_finish(
stream: &mut EventStream,
session_id: SessionId,
) -> crate::Result<()> {
while let Some(event) = stream.next().await {
let event = event?;
if let Event::SessionFinished(id) = event {
if session_id == id {
break;
}
}
}
Ok(())
}