use std::{borrow::Borrow, io::Write, net::SocketAddr, process::exit, thread, time::Duration};
use clap::Parser;
use colored::Colorize;
use dns_lookup::lookup_host;
use env_logger::Builder;
use futures::{SinkExt, StreamExt};
use helpers::PartyIDBeaverSource;
use mpc_stark::{
network::{NetworkOutbound, NetworkPayload, QuicTwoPartyNet},
MpcFabric, PARTY0,
};
use tokio::runtime::{Builder as RuntimeBuilder, Handle};
use tracing::log::{self, LevelFilter};
mod authenticated_scalar;
mod authenticated_stark_point;
mod circuits;
mod fabric;
mod helpers;
mod mpc_scalar;
mod mpc_stark_point;
/// Milliseconds to sleep after the last test has run, before the fabric is shut down.
const SHUTDOWN_TIMEOUT_MS: u64 = 3000;
/// Arguments handed to every integration test function.
#[derive(Debug, Clone)]
struct IntegrationTestArgs {
    /// This process's party ID, as supplied on the command line via `--party`.
    party_id: u64,
    /// Handle to the MPC fabric the test runs its computation on.
    fabric: MpcFabric,
}
/// A single named integration test, discovered at runtime through `inventory`.
#[derive(Clone)]
struct IntegrationTest {
    /// Human-readable test name; printed when the test runs and compared against `--test`.
    pub name: &'static str,
    /// The test body. Returns `Err(message)` to report a failure.
    pub test_fn: fn(&IntegrationTestArgs) -> Result<(), String>,
}

// Declares the `inventory` collection of `IntegrationTest`s that `main` iterates over.
// NOTE(review): tests are presumably registered with `inventory::submit!` in the
// test modules declared above — confirm there.
inventory::collect!(IntegrationTest);
// Command-line arguments for one party of the integration-test run.
// NOTE: plain `//` comments are used deliberately; clap's derive turns `///`
// doc comments into `--help` text.
#[derive(Debug, Clone, Parser)]
struct Args {
    // This process's party ID
    #[clap(long, value_parser)]
    party: u64,
    // Local port this party binds its socket to
    #[clap(long = "port1", value_parser)]
    port1: u64,
    // Port on which the peer is reached
    #[clap(long = "port2", value_parser)]
    port2: u64,
    // When set, only the test with this exact name is run
    #[clap(short, long, value_parser)]
    test: Option<String>,
    // Resolve the peer by its docker hostname (`party0` / `party1`) instead of 127.0.0.1
    #[clap(long, takes_value = false, value_parser)]
    docker: bool,
}
#[allow(unused_doc_comments, clippy::await_holding_refcell_ref)]
fn main() {
init_logger();
let args = Args::parse();
let args_clone = args.clone();
let runtime = RuntimeBuilder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let result = runtime.spawn_blocking(move || {
let local_addr: SocketAddr = format!("0.0.0.0:{}", args.port1).parse().unwrap();
let peer_addr: SocketAddr = {
if args.docker {
let other_host_alias = format!("party{}", if args.party == 1 { 0 } else { 1 });
let hosts = lookup_host(other_host_alias.as_str()).unwrap();
println!(
"Lookup successful for {}... found hosts: {:?}",
other_host_alias, hosts
);
format!("{}:{}", hosts[0], args.port2).parse().unwrap()
} else {
format!("{}:{}", "127.0.0.1", args.port2).parse().unwrap()
}
};
println!("Lookup successful, found peer at {:?}", peer_addr);
let mut net = QuicTwoPartyNet::new(args.party, local_addr, peer_addr);
Handle::current().block_on(net.connect()).unwrap();
if args.party == 0 {
Handle::current()
.block_on(net.send(NetworkOutbound {
result_id: 1,
payload: NetworkPayload::Bytes(vec![1u8]),
}))
.unwrap();
} else {
let _recv_bytes = Handle::current().block_on(net.next()).unwrap();
}
let beaver_source = PartyIDBeaverSource::new(args.party);
let fabric =
MpcFabric::new_with_size_hint(10_000_000 , net, beaver_source);
if args.party == 0 {
println!("\n\n{}\n", "Running integration tests...".blue());
}
let test_args = IntegrationTestArgs {
party_id: args.party,
fabric: fabric.clone(),
};
let mut all_success = true;
for test in inventory::iter::<IntegrationTest> {
if args.borrow().test.is_some() && args.borrow().test.as_deref().unwrap() != test.name {
continue;
}
if args.party == 0 {
print!("Running {}... ", test.name);
}
let test_clone = test.clone();
let res = (test_clone.test_fn)(&test_args);
all_success &= validate_success(res, args.party);
}
if test_args.party_id == PARTY0 {
log::info!("Tearing down fabric...");
}
thread::sleep(Duration::from_millis(SHUTDOWN_TIMEOUT_MS));
fabric.shutdown();
all_success
});
let all_success = runtime.block_on(result).unwrap();
if all_success {
if args_clone.party == 0 {
log::info!("{}", "Integration tests successful!".green(),);
}
exit(0);
}
exit(-1);
}
/// Installs the global `env_logger` logger.
///
/// Records at `Info` level and above are printed as `[LEVEL] - message`.
/// `Builder::new()` does not read `RUST_LOG`, so the level is fixed here.
///
/// # Panics
/// Per `env_logger::Builder::init`, panics if a global logger is already set.
fn init_logger() {
    let mut logger = Builder::new();
    logger.filter(None, LevelFilter::Info);
    logger.format(|out, rec| writeln!(out, "[{}] - {}", rec.level(), rec.args()));
    logger.init();
}
#[inline]
fn validate_success(res: Result<(), String>, party_id: u64) -> bool {
if res.is_ok() {
if party_id == 0 {
println!("{}", "Success!".green());
}
true
} else {
println!("{}\n\t{}", "Failure...".red(), res.err().unwrap());
false
}
}