#![deny(unsafe_code)]
#![deny(missing_docs)]
#![deny(rustdoc::broken_intra_doc_links)]
#![allow(clippy::unused_unit)]
use msg_queue::{MessageId, MsgQueue};
use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng};
use reqwest::Response;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, fmt};
use tandem::{states::Msg, Circuit, CircuitBlake3Hash};
use tandem_garble_interop::{
check_program, compile_program, deserialize_output, parse_input, Role, TypedCircuit,
};
pub use tandem_garble_interop::{Literal, VariantLiteral};
use url::Url;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen::{prelude::wasm_bindgen, JsValue};
use self::ValidationError::*;
mod msg_queue;
/// A Garble program that has been type-checked and compiled into a circuit for one of its
/// functions.
///
/// Create it with [`MpcProgram::new`].
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
#[derive(Debug, Clone)]
pub struct MpcProgram {
    /// The Garble source code, with surrounding whitespace trimmed.
    source_code: String,
    /// Name of the function that was compiled into `circuit`.
    function_name: String,
    /// The type-checked program.
    ast: tandem_garble_interop::TypedProgram,
    /// The circuit compiled from the function `function_name`.
    circuit: tandem_garble_interop::TypedCircuit,
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
impl MpcProgram {
    /// Type-checks `source_code` and compiles its function `function_name` into a circuit.
    ///
    /// Leading and trailing whitespace is trimmed from `source_code` before it is checked.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ValidationError`] if the program does not type-check, if the function
    /// cannot be compiled, or if the function does not take exactly two parameters (one per
    /// party).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen(constructor))]
    pub fn new(source_code: String, function_name: String) -> Result<MpcProgram, Error> {
        let source_code = source_code.trim().to_string();
        let ast = check_program(&source_code).map_err(GarbleCompileTimeError)?;
        let circuit = compile_program(&ast, &function_name).map_err(GarbleCompileTimeError)?;
        if circuit.fn_def.params.len() != 2 {
            return Err(ValidationError::GarbleProgramIsNoTwoPartyFunction.into());
        }
        Ok(Self {
            source_code,
            function_name,
            ast,
            circuit,
        })
    }

    /// Returns the compiled circuit's gate information (`info_about_gates`) as a string.
    pub fn report_gates(&self) -> String {
        self.circuit.info_about_gates.to_string()
    }
}
/// A Garble value used as the input or the output of an MPC computation.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MpcData {
    /// The wrapped Garble literal.
    literal: Literal,
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
impl MpcData {
pub fn from_string(program: &MpcProgram, input: String) -> Result<MpcData, Error> {
let literal = parse_input(
Role::Evaluator,
&program.ast,
&program.circuit.fn_def,
&input,
)
.map_err(GarbleCompileTimeError)?;
Ok(MpcData { literal })
}
#[cfg(not(target_arch = "wasm32"))]
pub fn from_literal(program: &MpcProgram, literal: Literal) -> Result<MpcData, Error> {
let expected_type =
tandem_garble_interop::input_type(Role::Evaluator, &program.circuit.fn_def);
if !literal.is_of_type(&program.ast, expected_type) {
return Err(Error::ValidationError(
ValidationError::GarbleCompileTimeError(format!(
"Input literal is not of the type {expected_type}"
)),
));
}
Ok(MpcData { literal })
}
#[cfg(target_arch = "wasm32")]
pub fn from_object(program: &MpcProgram, literal: JsValue) -> Result<MpcData, Error> {
let literal: Literal =
serde_wasm_bindgen::from_value(literal).map_err(|e| Error::JsonError(e.to_string()))?;
let expected_type =
tandem_garble_interop::input_type(Role::Evaluator, &program.circuit.fn_def);
if !literal.is_of_type(&program.ast, &expected_type) {
return Err(Error::ValidationError(
ValidationError::GarbleCompileTimeError(format!(
"Input literal is not of the type {expected_type}"
)),
));
}
Ok(MpcData { literal })
}
pub fn to_literal_string(&self) -> String {
format!("{}", self.literal)
}
#[cfg(target_arch = "wasm32")]
pub fn to_literal(&self) -> Result<JsValue, serde_wasm_bindgen::Error> {
serde_wasm_bindgen::to_value(&self.literal)
}
}
/// Runs `program` as a two-party computation with the Tandem server at `url`.
///
/// `input` is this client's (the evaluator's) input. `plaintext_metadata` is sent to the
/// server together with the program's source code and function name when the session is
/// created; unlike `input`, it is not fed into the MPC protocol. Returns the decoded output
/// of the computation.
///
/// # Errors
///
/// - [`Error::ParseError`] if `url` is not a valid URL or the session URL cannot be derived
///   from it,
/// - [`Error::ValidationError`] if `input` does not have as many bits as the circuit has
///   evaluator input gates, or if the output cannot be decoded,
/// - [`Error::ServerError`], [`Error::ReqwestError`], [`Error::BincodeError`],
///   [`Error::TandemError`] or [`Error::MessageOffsetMismatch`] if the exchange with the
///   server or the local protocol execution fails.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
pub async fn compute(
    url: String,
    plaintext_metadata: String,
    program: MpcProgram,
    input: MpcData,
) -> Result<MpcData, Error> {
    let url = Url::parse(&url)?;
    let my_input = input.literal.as_bits(&program.ast);
    // The circuit expects exactly one input bit per evaluator input gate.
    let expected_input_len = program
        .circuit
        .gates
        .gates()
        .iter()
        .filter(|&gate| gate == &tandem::Gate::InEval)
        .count();
    if expected_input_len != my_input.len() {
        return Err(ValidationError::InvalidInput.into());
    }
    let client = TandemClient::new(&url);
    // `program` is owned, so its fields can be moved out individually instead of cloned.
    let TypedCircuit { gates, fn_def, .. } = program.circuit;
    let session = client
        .new_session(
            &gates,
            program.source_code,
            program.function_name,
            plaintext_metadata,
        )
        .await?;
    let result = session.evaluate(gates, my_input).await?;
    let literal =
        deserialize_output(&program.ast, &fn_def, &result).map_err(GarbleCompileTimeError)?;
    Ok(MpcData { literal })
}
/// Messages received from the server, each paired with its server-assigned offset.
type MessageLog = Vec<(Msg, MessageId)>;
/// Client for a Tandem HTTP server; used to create new sessions.
#[derive(Debug)]
struct TandemClient {
    /// Base URL of the server; session creation requests are posted here.
    url: Url,
}
/// A session created on the server, through which MPC protocol messages are exchanged.
struct TandemSession {
    /// Session URL: the server's base URL joined (via `Url::join`) with the engine id.
    url: Url,
    /// Headers received at session creation; attached to every message request.
    request_headers: HashMap<String, String>,
}
/// Request body, serialized as JSON, that asks the server to create a new session.
#[derive(Serialize, Debug)]
struct NewSession {
    /// Caller-supplied metadata passed through unchanged from `compute`.
    plaintext_metadata: String,
    /// The Garble source code of the program.
    program: String,
    /// Name of the function to execute.
    function: String,
    /// BLAKE3 hash of the compiled circuit.
    circuit_hash: CircuitBlake3Hash,
    /// Version of this client crate (`CARGO_PKG_VERSION`).
    client_version: String,
}
/// Server response to a `NewSession` request, decoded from JSON.
#[derive(Deserialize, Debug, PartialEq, Eq)]
struct EngineCreationResult {
    /// Identifier of the created engine; joined onto the base URL to form the session URL.
    engine_id: String,
    /// Headers that the client attaches to every message request of this session.
    request_headers: HashMap<String, String>,
    /// Version reported by the server (currently not inspected by the client).
    server_version: String,
}
impl TandemClient {
    /// Creates a client for the server at `url`.
    fn new(url: &Url) -> Self {
        Self { url: url.clone() }
    }

    /// Asks the server to create a new session for `circuit`.
    ///
    /// The request carries the program's `source_code`, the `function` name, the circuit's
    /// BLAKE3 hash, `plaintext_metadata` and this crate's version. The engine id returned by
    /// the server is resolved against the base URL (via `Url::join`) to obtain the session URL.
    ///
    /// # Errors
    ///
    /// Propagates request/response failures from `send_new_session`, and returns
    /// [`Error::ParseError`] if the engine id cannot be joined onto the base URL.
    async fn new_session(
        &self,
        circuit: &Circuit,
        source_code: String,
        function: String,
        plaintext_metadata: String,
    ) -> Result<TandemSession, Error> {
        let req = NewSession {
            plaintext_metadata,
            program: source_code,
            function,
            circuit_hash: circuit.blake3_hash(),
            client_version: env!("CARGO_PKG_VERSION").to_string(),
        };
        // The server's reported version is currently not checked.
        let EngineCreationResult {
            engine_id,
            request_headers,
            ..
        } = send_new_session(self.url.clone(), &req).await?;
        let url = self.url.join(&engine_id)?;
        Ok(TandemSession {
            url,
            request_headers,
        })
    }
}
impl TandemSession {
    /// Runs the evaluator side of the MPC protocol for `circuit` on `input`, exchanging
    /// messages with the server until the output can be computed.
    ///
    /// Each loop iteration sends all queued outgoing messages together with the offset of the
    /// last server message processed so far, then processes the server's reply in order.
    ///
    /// # Errors
    ///
    /// Returns [`Error::MessageOffsetMismatch`] if the server acknowledges a different offset
    /// than the last message sent, or if incoming server offsets are not contiguous starting
    /// at 0; [`Error::TandemError`] if the local evaluator fails; and any error from `dialog`.
    async fn evaluate(self, circuit: Circuit, input: Vec<bool>) -> Result<Vec<bool>, Error> {
        // Outgoing messages not yet acknowledged by the server.
        let mut context = MsgQueue::new();
        let mut evaluator =
            tandem::states::Evaluator::new(circuit, input, ChaCha20Rng::from_entropy())?;
        // Offset of the last server message that has been processed; `None` before the first.
        let mut last_durably_received_offset: Option<MessageId> = None;
        // Number of `run` steps left; once it reaches 0 the next server message is passed to
        // `output` instead.
        let mut steps_remaining = evaluator.steps();
        // NOTE(review): there is no delay or retry limit here; if the server keeps replying
        // without new messages this polls indefinitely — confirm this is intended.
        loop {
            let messages: Vec<(&Msg, MessageId)> = context.msgs_iter().collect();
            let (upstream_msgs, server_commited_offset) =
                self.dialog(last_durably_received_offset, &messages).await?;
            // The server must acknowledge exactly the last message just sent (`None` when
            // nothing was sent).
            if messages.last().map(|v| v.1) != server_commited_offset {
                return Err(Error::MessageOffsetMismatch);
            }
            // Caution: this binding shadows the outer `last_durably_received_offset`. Here it
            // is the server-acknowledged offset of *our* messages, handed to `flush_queue`
            // (presumably to drop acknowledged messages from the queue).
            if let Some(last_durably_received_offset) = server_commited_offset {
                context.flush_queue(last_durably_received_offset);
            }
            for (msg, server_offset) in &upstream_msgs {
                // Server messages must arrive contiguously, starting at offset 0.
                if *server_offset != last_durably_received_offset.map(|o| o + 1).unwrap_or(0) {
                    return Err(Error::MessageOffsetMismatch);
                }
                if steps_remaining > 0 {
                    // `run` consumes the evaluator and yields its next state plus a reply,
                    // which is queued for the next `dialog` round.
                    let (next_state, msg) = evaluator.run(msg)?;
                    evaluator = next_state;
                    steps_remaining -= 1;
                    context.send(msg);
                } else {
                    // All steps done: this server message yields the final output.
                    return Ok(evaluator.output(msg)?);
                }
                last_durably_received_offset = Some(*server_offset);
            }
        }
    }

    /// Sends `messages` plus the offset of the last processed server message to this
    /// session's URL and returns the server's new messages together with the offset of our
    /// messages that the server reports as committed.
    async fn dialog(
        &self,
        last_durably_received_offset: Option<u32>,
        messages: &[(&Msg, MessageId)],
    ) -> Result<(MessageLog, Option<MessageId>), Error> {
        send_msgs(
            self.url.clone(),
            &self.request_headers,
            last_durably_received_offset,
            messages,
        )
        .await
    }
}
/// Posts `session` as JSON to `url` and decodes the server's [`EngineCreationResult`].
///
/// Non-success responses are turned into errors by `resp_or_err`.
async fn send_new_session(url: Url, session: &NewSession) -> Result<EngineCreationResult, Error> {
    let http = reqwest::Client::new();
    let raw_response = http.post(url).json(session).send().await?;
    let checked_response = resp_or_err(raw_response).await?;
    let created: EngineCreationResult = checked_response.json().await?;
    Ok(created)
}
/// Posts the bincode-encoded pair `(last_durably_received_offset, msgs)` to `url` with every
/// entry of `request_headers` attached, and bincode-decodes the response body into the
/// server's messages plus the offset of our messages the server has committed.
async fn send_msgs(
    url: Url,
    request_headers: &HashMap<String, String>,
    last_durably_received_offset: Option<u32>,
    msgs: &[(&Msg, MessageId)],
) -> Result<(MessageLog, Option<MessageId>), Error> {
    let http = reqwest::Client::new();
    let payload = bincode::serialize(&(last_durably_received_offset, msgs))?;
    let request = request_headers
        .iter()
        .fold(http.post(url).body(payload), |builder, (name, value)| {
            builder.header(name, value)
        });
    let checked_response = resp_or_err(request.send().await?).await?;
    let body_bytes = checked_response.bytes().await?;
    Ok(bincode::deserialize(&body_bytes)?)
}
async fn resp_or_err(resp: Response) -> Result<Response, Error> {
if resp.status().is_success() {
Ok(resp)
} else {
let e = resp.text().await?;
let e = match serde_json::from_str::<ErrorJson>(&e) {
Ok(ErrorJson { error, args }) => format!("{error}: {args}"),
Err(_) => e,
};
Err(Error::ServerError(e))
}
}
/// Structured JSON error body that the server may return for a failed request.
///
/// `resp_or_err` formats it as `"{error}: {args}"`; bodies that do not parse as this struct
/// are reported verbatim.
#[derive(Deserialize)]
struct ErrorJson {
    /// The `error` field of the server's error body.
    error: String,
    /// The `args` field of the server's error body.
    args: String,
}
/// Errors that can occur while compiling a program, preparing input, or running the MPC
/// protocol against a Tandem server.
#[derive(Debug)]
pub enum Error {
    /// The server answered with a non-success status; contains the server's error message.
    ServerError(String),
    /// An HTTP request failed, or its response body could not be read or decoded.
    ReqwestError(reqwest::Error),
    /// A value could not be converted into a Garble literal; contains the conversion error.
    JsonError(String),
    /// The server URL is invalid, or the session URL could not be derived from it.
    ParseError(url::ParseError),
    /// The program or the input is invalid.
    ValidationError(ValidationError),
    /// The client's local execution of the MPC protocol failed.
    TandemError(tandem::Error),
    /// A protocol message could not be serialized or deserialized with bincode.
    BincodeError,
    /// A message offset reported by the server did not match the one the client expected.
    MessageOffsetMismatch,
}
impl From<bincode::Error> for Error {
fn from(_: bincode::Error) -> Self {
Self::BincodeError
}
}
impl From<reqwest::Error> for Error {
fn from(e: reqwest::Error) -> Self {
Self::ReqwestError(e)
}
}
impl From<url::ParseError> for Error {
fn from(e: url::ParseError) -> Self {
Self::ParseError(e)
}
}
impl From<ValidationError> for Error {
fn from(e: ValidationError) -> Self {
Self::ValidationError(e)
}
}
impl From<tandem::Error> for Error {
fn from(e: tandem::Error) -> Self {
Self::TandemError(e)
}
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::ValidationError(e) => write!(f, "The MPC program or the input is invalid: {e}"),
Error::ServerError(e) => write!(f, "An error occurred on the server side: {e}"),
Error::ReqwestError(e) => write!(
f,
"An error occurred while trying to send a request to the server: {e}"
),
Error::JsonError(e) => {
write!(f, "The provided JSON is not a valid Garble literal: {e}")
}
Error::ParseError(e) => write!(f, "The provided URL is invalid: {e}"),
Error::TandemError(e) => write!(
f,
"An error occurred during the client's execution of the MPC protocol: {e}"
),
Error::BincodeError => write!(f, "A message could not be serialized/deserialized."),
Error::MessageOffsetMismatch => write!(
f,
"The client's message id did not match the server's message id."
),
}
}
}
impl std::error::Error for Error {}
/// Exposes errors to JavaScript as a string holding the error's `Display` message.
#[cfg(target_arch = "wasm32")]
impl From<Error> for JsValue {
    fn from(e: Error) -> Self {
        let message = e.to_string();
        JsValue::from_str(&message)
    }
}
/// Reasons why an MPC program or its input was rejected.
#[derive(Debug, PartialEq, Eq)]
pub enum ValidationError {
    /// The input's bit length does not match the number of evaluator input gates.
    InvalidInput,
    /// The Garble tooling reported an error (type checking, compilation, input parsing or
    /// output decoding), or a literal did not have the expected type; contains the message.
    GarbleCompileTimeError(String),
    /// The selected Garble function does not take exactly two parameters.
    GarbleProgramIsNoTwoPartyFunction,
}

impl fmt::Display for ValidationError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Explicit `Self::` paths keep these patterns unambiguous: a bare unit-variant name
        // that is not in scope would silently become a catch-all binding.
        match self {
            Self::InvalidInput => {
                write!(f, "The input does not match the circuit's expected input.")
            }
            Self::GarbleCompileTimeError(e) => write!(f, "Garble compile time error: {e}"),
            Self::GarbleProgramIsNoTwoPartyFunction => write!(
                f,
                "The Garble program has more or fewer than two parameters and thus is not a 2-Party program."
            ),
        }
    }
}