use std::collections::HashSet;
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
visit::Visit,
Attribute,
Expr,
ExprMacro,
FnArg,
GenericParam,
Generics,
Ident,
ItemFn,
LitStr,
Local,
Pat,
PatIdent,
PatType,
ReturnType,
Stmt,
Token,
Type,
WhereClause,
};
struct TaskAttr {
task_name: Ident,
}
impl Parse for TaskAttr {
fn parse(input: ParseStream) -> syn::Result<Self> {
let task_name: Ident = input.parse()?;
Ok(TaskAttr { task_name })
}
}
#[derive(Debug, Clone)]
struct Requirement {
condition: Expr,
message: LitStr,
}
impl Parse for Requirement {
fn parse(input: ParseStream) -> syn::Result<Self> {
let condition: Expr = input.parse()?;
input.parse::<Token![,]>()?;
let message: LitStr = input.parse()?;
Ok(Requirement { condition, message })
}
}
fn parse_requirements(attrs: &[Attribute]) -> syn::Result<Vec<Requirement>> {
let mut requirements = Vec::new();
for attr in attrs {
if attr.path().is_ident("require") {
let requirement: Requirement = attr.parse_args()?;
requirements.push(requirement);
}
}
Ok(requirements)
}
pub fn task_transform(attr: TokenStream, item: TokenStream) -> TokenStream {
let task_attr = parse_macro_input!(attr as TaskAttr);
let input_fn = parse_macro_input!(item as ItemFn);
match transform_task_fn(task_attr.task_name, input_fn) {
Ok(output) => output.into(),
Err(err) => err.to_compile_error().into(),
}
}
fn transform_task_fn(task_name: Ident, input_fn: ItemFn) -> syn::Result<TokenStream2> {
let _vis = &input_fn.vis;
let generics = &input_fn.sig.generics;
let where_clause = &input_fn.sig.generics.where_clause;
let requirements = parse_requirements(&input_fn.attrs)?;
let params: Vec<_> = input_fn
.sig
.inputs
.iter()
.filter_map(|arg| {
if let FnArg::Typed(pat_type) = arg {
Some(pat_type.clone())
} else {
None
}
})
.collect();
let output_type = match &input_fn.sig.output {
ReturnType::Type(_, ty) => ty.as_ref().clone(),
ReturnType::Default => {
return Err(syn::Error::new_spanned(
&input_fn.sig,
"Task function must have a return type",
))
}
};
let field_type_param = find_field_extension_param(generics);
let segments = split_at_round_boundaries(&input_fn.block.stmts)?;
if segments.is_empty() {
return Err(syn::Error::new_spanned(
&input_fn.block,
"Task function body cannot be empty",
));
}
let round_analyses = analyze_rounds(&segments, ¶ms)?;
let round_structs = generate_round_structs(&task_name, &round_analyses, generics);
let task_enum = generate_task_enum(&task_name, &round_analyses, generics);
let task_impl = generate_task_impl(
&task_name,
&round_analyses,
generics,
where_clause,
¶ms,
&output_type,
field_type_param.as_ref(),
&requirements,
);
let task_trait_impl = generate_task_trait_impl(
&task_name,
&round_analyses,
generics,
where_clause,
¶ms,
&output_type,
field_type_param.as_ref(),
);
Ok(quote! {
#(#round_structs)*
#task_enum
#task_impl
#task_trait_impl
})
}
fn find_field_extension_param(generics: &Generics) -> Option<Ident> {
for param in &generics.params {
if let GenericParam::Type(type_param) = param {
for bound in &type_param.bounds {
if let syn::TypeParamBound::Trait(trait_bound) = bound {
if let Some(segment) = trait_bound.path.segments.last() {
if segment.ident == "FieldExtension" {
return Some(type_param.ident.clone());
}
}
}
}
}
}
None
}
#[derive(Debug)]
struct Segment {
stmts: Vec<Stmt>,
boundary: Option<RoundBoundary>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
enum OpeningKind {
Field,
ToOne,
Binary,
}
#[derive(Debug)]
struct OpenCall {
bindings: Vec<BindingInfo>,
exprs: Vec<Expr>,
kind: OpeningKind,
target: Option<Expr>,
}
#[derive(Debug)]
struct BindingInfo {
name: Ident,
ty: Type,
}
#[derive(Debug)]
struct SendOperation {
target: Expr,
message: Expr,
}
impl Clone for SendOperation {
fn clone(&self) -> Self {
SendOperation {
target: self.target.clone(),
message: self.message.clone(),
}
}
}
#[derive(Debug)]
struct SendAllOperation {
message: Expr,
}
impl Clone for SendAllOperation {
fn clone(&self) -> Self {
SendAllOperation {
message: self.message.clone(),
}
}
}
#[derive(Debug)]
struct ReceiveOperation {
binding: BindingInfo,
source: Expr,
}
impl Clone for ReceiveOperation {
fn clone(&self) -> Self {
ReceiveOperation {
binding: self.binding.clone(),
source: self.source.clone(),
}
}
}
#[derive(Debug)]
struct PrivateOpenOperation {
binding: BindingInfo,
share_expr: Expr,
}
impl Clone for PrivateOpenOperation {
fn clone(&self) -> Self {
PrivateOpenOperation {
binding: self.binding.clone(),
share_expr: self.share_expr.clone(),
}
}
}
#[derive(Debug, Default)]
struct RoundBoundary {
opens: Vec<OpenCall>,
sends: Vec<SendOperation>,
send_alls: Vec<SendAllOperation>,
receives: Vec<ReceiveOperation>,
private_opens: Vec<PrivateOpenOperation>,
}
impl Clone for RoundBoundary {
fn clone(&self) -> Self {
RoundBoundary {
opens: self.opens.clone(),
sends: self.sends.clone(),
send_alls: self.send_alls.clone(),
receives: self.receives.clone(),
private_opens: self.private_opens.clone(),
}
}
}
enum BoundaryOperation {
Open(OpenCall),
Send(SendOperation),
SendAll(SendAllOperation),
Receive(ReceiveOperation),
PrivateOpen(PrivateOpenOperation),
SameRound(Vec<BoundaryOperation>),
}
fn add_op_to_boundary(boundary: &mut RoundBoundary, op: BoundaryOperation) {
match op {
BoundaryOperation::Open(open_call) => boundary.opens.push(open_call),
BoundaryOperation::Send(send_op) => boundary.sends.push(send_op),
BoundaryOperation::SendAll(send_all_op) => boundary.send_alls.push(send_all_op),
BoundaryOperation::Receive(recv_op) => boundary.receives.push(recv_op),
BoundaryOperation::PrivateOpen(private_open_op) => {
boundary.private_opens.push(private_open_op)
}
BoundaryOperation::SameRound(grouped_ops) => {
for grouped_op in grouped_ops {
add_op_to_boundary(boundary, grouped_op);
}
}
}
}
fn split_at_round_boundaries(stmts: &[Stmt]) -> syn::Result<Vec<Segment>> {
let mut segments = Vec::new();
let mut current_stmts = Vec::new();
for stmt in stmts {
if let Some(op) = extract_boundary_operation(stmt)? {
let mut boundary = RoundBoundary::default();
add_op_to_boundary(&mut boundary, op);
segments.push(Segment {
stmts: current_stmts,
boundary: Some(boundary),
});
current_stmts = Vec::new();
} else {
current_stmts.push(stmt.clone());
}
}
segments.push(Segment {
stmts: current_stmts,
boundary: None,
});
Ok(segments)
}
fn extract_boundary_operation(stmt: &Stmt) -> syn::Result<Option<BoundaryOperation>> {
match stmt {
Stmt::Local(local) => {
if let Some(init) = &local.init {
if let Expr::Macro(expr_macro) = init.expr.as_ref() {
match get_round_boundary_macro_kind(&expr_macro.mac) {
Some(RoundBoundaryMacroKind::OpenField)
| Some(RoundBoundaryMacroKind::OpenBinary)
| Some(RoundBoundaryMacroKind::OpenTo) => {
let open_call = parse_open_call(local, expr_macro)?;
return Ok(Some(BoundaryOperation::Open(open_call)));
}
Some(RoundBoundaryMacroKind::ReceiveFrom) => {
let recv_op = parse_receive_from(local, expr_macro)?;
return Ok(Some(BoundaryOperation::Receive(recv_op)));
}
Some(RoundBoundaryMacroKind::PrivateOpen) => {
let private_open_op = parse_private_open(local, expr_macro)?;
return Ok(Some(BoundaryOperation::PrivateOpen(private_open_op)));
}
Some(RoundBoundaryMacroKind::SendTo) => {
return Err(syn::Error::new_spanned(
&expr_macro.mac,
"send_to! should be used as an expression statement, not in a let binding",
));
}
Some(RoundBoundaryMacroKind::SendAll) => {
return Err(syn::Error::new_spanned(
&expr_macro.mac,
"send_all! should be used as an expression statement, not in a let binding",
));
}
Some(RoundBoundaryMacroKind::SameRound) => {
return Err(syn::Error::new_spanned(
&expr_macro.mac,
"same_round! should be used as a statement, not in a let binding",
));
}
None => {}
}
}
}
}
Stmt::Expr(Expr::Macro(expr_macro), _semi) => {
match get_round_boundary_macro_kind(&expr_macro.mac) {
Some(RoundBoundaryMacroKind::SendTo) => {
let send_op = parse_send_to(expr_macro)?;
return Ok(Some(BoundaryOperation::Send(send_op)));
}
Some(RoundBoundaryMacroKind::SendAll) => {
let send_all_op = parse_send_all(expr_macro)?;
return Ok(Some(BoundaryOperation::SendAll(send_all_op)));
}
Some(RoundBoundaryMacroKind::SameRound) => {
let grouped_ops = parse_same_round(&expr_macro.mac)?;
return Ok(Some(BoundaryOperation::SameRound(grouped_ops)));
}
Some(RoundBoundaryMacroKind::OpenField)
| Some(RoundBoundaryMacroKind::OpenBinary)
| Some(RoundBoundaryMacroKind::OpenTo) => {
return Err(syn::Error::new_spanned(
&expr_macro.mac,
"open_field!, open_binary!, and open_to! must be used in a let binding: let x: Type = open_field!(...)",
));
}
Some(RoundBoundaryMacroKind::ReceiveFrom) => {
return Err(syn::Error::new_spanned(
&expr_macro.mac,
"receive_from! must be used in a let binding: let x: Type = receive_from!(peer)",
));
}
Some(RoundBoundaryMacroKind::PrivateOpen) => {
return Err(syn::Error::new_spanned(
&expr_macro.mac,
"private_open! must be used in a let binding: let x: Type = private_open!(share)",
));
}
None => {}
}
}
Stmt::Macro(stmt_macro) => {
match get_round_boundary_macro_kind(&stmt_macro.mac) {
Some(RoundBoundaryMacroKind::SendTo) => {
let expr_macro = ExprMacro {
attrs: stmt_macro.attrs.clone(),
mac: stmt_macro.mac.clone(),
};
let send_op = parse_send_to(&expr_macro)?;
return Ok(Some(BoundaryOperation::Send(send_op)));
}
Some(RoundBoundaryMacroKind::SendAll) => {
let expr_macro = ExprMacro {
attrs: stmt_macro.attrs.clone(),
mac: stmt_macro.mac.clone(),
};
let send_all_op = parse_send_all(&expr_macro)?;
return Ok(Some(BoundaryOperation::SendAll(send_all_op)));
}
Some(RoundBoundaryMacroKind::SameRound) => {
let grouped_ops = parse_same_round(&stmt_macro.mac)?;
return Ok(Some(BoundaryOperation::SameRound(grouped_ops)));
}
Some(RoundBoundaryMacroKind::OpenField)
| Some(RoundBoundaryMacroKind::OpenBinary)
| Some(RoundBoundaryMacroKind::OpenTo) => {
return Err(syn::Error::new_spanned(
&stmt_macro.mac,
"open_field!, open_binary!, and open_to! must be used in a let binding: let x: Type = open_field!(...)",
));
}
Some(RoundBoundaryMacroKind::ReceiveFrom) => {
return Err(syn::Error::new_spanned(
&stmt_macro.mac,
"receive_from! must be used in a let binding: let x: Type = receive_from!(peer)",
));
}
Some(RoundBoundaryMacroKind::PrivateOpen) => {
return Err(syn::Error::new_spanned(
&stmt_macro.mac,
"private_open! must be used in a let binding: let x: Type = private_open!(share)",
));
}
None => {}
}
}
_ => {}
}
Ok(None)
}
#[derive(Debug, Clone, PartialEq, Eq)]
enum RoundBoundaryMacroKind {
OpenField,
OpenBinary,
OpenTo,
SendTo,
SendAll,
ReceiveFrom,
PrivateOpen,
SameRound,
}
fn get_round_boundary_macro_kind(mac: &syn::Macro) -> Option<RoundBoundaryMacroKind> {
mac.path.get_ident().and_then(|ident| {
if ident == "open_field" {
Some(RoundBoundaryMacroKind::OpenField)
} else if ident == "open_binary" {
Some(RoundBoundaryMacroKind::OpenBinary)
} else if ident == "open_to" {
Some(RoundBoundaryMacroKind::OpenTo)
} else if ident == "send_to" {
Some(RoundBoundaryMacroKind::SendTo)
} else if ident == "send_all" {
Some(RoundBoundaryMacroKind::SendAll)
} else if ident == "receive_from" {
Some(RoundBoundaryMacroKind::ReceiveFrom)
} else if ident == "private_open" {
Some(RoundBoundaryMacroKind::PrivateOpen)
} else if ident == "same_round" {
Some(RoundBoundaryMacroKind::SameRound)
} else {
None
}
})
}
fn get_open_macro_kind(mac: &syn::Macro) -> Option<OpeningKind> {
mac.path.get_ident().and_then(|ident| {
if ident == "open_field" {
Some(OpeningKind::Field)
} else if ident == "open_binary" {
Some(OpeningKind::Binary)
} else if ident == "open_to" {
Some(OpeningKind::ToOne)
} else {
None
}
})
}
fn parse_open_call(local: &Local, expr_macro: &ExprMacro) -> syn::Result<OpenCall> {
let bindings = extract_bindings(&local.pat)?;
let kind = get_open_macro_kind(&expr_macro.mac).ok_or_else(|| {
syn::Error::new_spanned(
&expr_macro.mac,
"Expected open_field!, open_binary!, or open_to!",
)
})?;
let all_exprs = parse_open_exprs(&expr_macro.mac.tokens)?;
let (target, exprs) = match kind {
OpeningKind::Field | OpeningKind::Binary => (None, all_exprs),
OpeningKind::ToOne => {
if all_exprs.is_empty() {
return Err(syn::Error::new_spanned(
&expr_macro.mac,
"open_to! requires at least a target and one share: open_to!(target, share)",
));
}
let mut exprs = all_exprs;
let target = exprs.remove(0);
if exprs.is_empty() {
return Err(syn::Error::new_spanned(
&expr_macro.mac,
"open_to! requires at least one share after the target: open_to!(target, share)",
));
}
(Some(target), exprs)
}
};
Ok(OpenCall {
bindings,
exprs,
kind,
target,
})
}
fn extract_bindings(pat: &Pat) -> syn::Result<Vec<BindingInfo>> {
match pat {
Pat::Ident(PatIdent { .. }) => {
Err(syn::Error::new_spanned(
pat,
"open_field!/open_binary! bindings require explicit type annotation: let x: Type = open_field!(...)",
))
}
Pat::Type(PatType { pat, ty, .. }) => {
extract_bindings_with_type(pat, ty)
}
Pat::Tuple(_) => {
Err(syn::Error::new_spanned(
pat,
"open_field!/open_binary! tuple bindings require explicit type annotation: let (a, b): (TypeA, TypeB) = open_field!(...)",
))
}
_ => Err(syn::Error::new_spanned(
pat,
"Unsupported pattern in open! binding",
)),
}
}
fn extract_bindings_with_type(pat: &Pat, ty: &Type) -> syn::Result<Vec<BindingInfo>> {
match (pat, ty) {
(Pat::Ident(PatIdent { ident, .. }), ty) => Ok(vec![BindingInfo {
name: ident.clone(),
ty: ty.clone(),
}]),
(Pat::Tuple(pat_tuple), Type::Tuple(ty_tuple)) => {
if pat_tuple.elems.len() != ty_tuple.elems.len() {
return Err(syn::Error::new_spanned(
pat,
"Tuple pattern and type must have the same number of elements",
));
}
let mut bindings = Vec::new();
for (p, t) in pat_tuple.elems.iter().zip(ty_tuple.elems.iter()) {
bindings.extend(extract_bindings_with_type(p, t)?);
}
Ok(bindings)
}
_ => Err(syn::Error::new_spanned(
pat,
"Unsupported pattern/type combination in open! binding",
)),
}
}
fn parse_open_exprs(tokens: &TokenStream2) -> syn::Result<Vec<Expr>> {
struct OpenArgs {
exprs: Punctuated<Expr, Token![,]>,
}
impl Parse for OpenArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
Ok(OpenArgs {
exprs: Punctuated::parse_terminated(input)?,
})
}
}
let args: OpenArgs = syn::parse2(tokens.clone())?;
Ok(args.exprs.into_iter().collect())
}
fn parse_send_to(expr_macro: &ExprMacro) -> syn::Result<SendOperation> {
let exprs = parse_open_exprs(&expr_macro.mac.tokens)?;
if exprs.len() != 2 {
return Err(syn::Error::new_spanned(
&expr_macro.mac,
"send_to! requires exactly 2 arguments: send_to!(peer, message)",
));
}
let mut iter = exprs.into_iter();
let target = iter.next().unwrap();
let message = iter.next().unwrap();
Ok(SendOperation { target, message })
}
fn parse_send_all(expr_macro: &ExprMacro) -> syn::Result<SendAllOperation> {
let exprs = parse_open_exprs(&expr_macro.mac.tokens)?;
if exprs.len() != 1 {
return Err(syn::Error::new_spanned(
&expr_macro.mac,
"send_all! requires exactly 1 argument: send_all!(message)",
));
}
let message = exprs.into_iter().next().unwrap();
Ok(SendAllOperation { message })
}
fn parse_receive_from(local: &Local, expr_macro: &ExprMacro) -> syn::Result<ReceiveOperation> {
let bindings = extract_bindings(&local.pat)?;
if bindings.len() != 1 {
return Err(syn::Error::new_spanned(
&local.pat,
"receive_from! must bind to a single variable: let x: Type = receive_from!(peer)",
));
}
let binding = bindings.into_iter().next().unwrap();
let exprs = parse_open_exprs(&expr_macro.mac.tokens)?;
if exprs.len() != 1 {
return Err(syn::Error::new_spanned(
&expr_macro.mac,
"receive_from! requires exactly 1 argument: receive_from!(peer)",
));
}
let source = exprs.into_iter().next().unwrap();
Ok(ReceiveOperation { binding, source })
}
fn parse_private_open(local: &Local, expr_macro: &ExprMacro) -> syn::Result<PrivateOpenOperation> {
let bindings = extract_bindings(&local.pat)?;
if bindings.len() != 1 {
return Err(syn::Error::new_spanned(
&local.pat,
"private_open! must bind to a single variable: let x: Type = private_open!(share)",
));
}
let binding = bindings.into_iter().next().unwrap();
let exprs = parse_open_exprs(&expr_macro.mac.tokens)?;
if exprs.len() != 1 {
return Err(syn::Error::new_spanned(
&expr_macro.mac,
"private_open! requires exactly 1 argument: private_open!(share)",
));
}
let share_expr = exprs.into_iter().next().unwrap();
Ok(PrivateOpenOperation {
binding,
share_expr,
})
}
struct StmtList {
stmts: Vec<Stmt>,
}
impl Parse for StmtList {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut stmts = Vec::new();
while !input.is_empty() {
stmts.push(input.parse()?);
}
Ok(StmtList { stmts })
}
}
fn parse_same_round(mac: &syn::Macro) -> syn::Result<Vec<BoundaryOperation>> {
let stmt_list: StmtList = syn::parse2(mac.tokens.clone())?;
let mut ops = Vec::new();
for stmt in stmt_list.stmts {
if let Some(op) = extract_boundary_operation(&stmt)? {
if matches!(op, BoundaryOperation::SameRound(_)) {
return Err(syn::Error::new_spanned(
&stmt,
"Nested same_round! is not allowed",
));
}
ops.push(op);
} else {
return Err(syn::Error::new_spanned(
&stmt,
"same_round! can only contain communication macros (send_to!, send_all!, receive_from!, open!, open_to!, private_open!)",
));
}
}
if ops.is_empty() {
return Err(syn::Error::new_spanned(
mac,
"same_round! must contain at least one communication macro",
));
}
Ok(ops)
}
#[derive(Debug)]
struct RoundAnalysis {
round_num: usize,
stmts: Vec<Stmt>,
state_vars: Vec<StateVar>,
boundary: Option<RoundBoundary>,
opened_vars: Option<Vec<BindingInfo>>,
received_vars: Option<Vec<ReceiveOperation>>,
private_open_vars: Option<Vec<PrivateOpenOperation>>,
}
#[derive(Debug, Clone)]
struct StateVar {
name: Ident,
ty: Type,
}
fn analyze_rounds(segments: &[Segment], params: &[PatType]) -> syn::Result<Vec<RoundAnalysis>> {
let mut analyses = Vec::new();
let mut opened_vars: Option<Vec<BindingInfo>> = None;
let mut received_vars: Option<Vec<ReceiveOperation>> = None;
let mut private_open_vars: Option<Vec<PrivateOpenOperation>> = None;
for (i, segment) in segments.iter().enumerate() {
let round_num = i + 1;
let _is_final = segment.boundary.is_none();
let mut vars_used_in_current_and_later = collect_vars_used_in_segments(&segments[i..]);
if let Some(ref recv_vars) = received_vars {
let mut visitor = VarUsageVisitor {
used_vars: HashSet::new(),
};
for recv in recv_vars {
syn::visit::visit_expr(&mut visitor, &recv.source);
}
vars_used_in_current_and_later.extend(visitor.used_vars);
}
if let Some(ref po_vars) = private_open_vars {
let mut visitor = VarUsageVisitor {
used_vars: HashSet::new(),
};
for po in po_vars {
syn::visit::visit_expr(&mut visitor, &po.share_expr);
}
vars_used_in_current_and_later.extend(visitor.used_vars);
}
let vars_defined_before = if i == 0 {
collect_params_as_state_vars(params)
} else {
collect_vars_defined_in_segments(&segments[..i], params)
};
let mut shadowed_names: HashSet<String> = opened_vars
.as_ref()
.map(|vars| vars.iter().map(|v| v.name.to_string()).collect())
.unwrap_or_default();
if let Some(ref recv_vars) = received_vars {
for recv in recv_vars {
shadowed_names.insert(recv.binding.name.to_string());
}
}
if let Some(ref po_vars) = private_open_vars {
for po in po_vars {
shadowed_names.insert(po.binding.name.to_string());
}
}
let state_vars: Vec<StateVar> = vars_defined_before
.into_iter()
.filter(|v| {
let name = v.name.to_string();
vars_used_in_current_and_later.contains(&name) && !shadowed_names.contains(&name)
})
.collect();
analyses.push(RoundAnalysis {
round_num,
stmts: segment.stmts.clone(),
state_vars,
boundary: segment.boundary.clone(),
opened_vars: opened_vars.clone(),
received_vars: received_vars.clone(),
private_open_vars: private_open_vars.clone(),
});
if let Some(ref boundary) = segment.boundary {
let all_open_bindings: Vec<BindingInfo> = boundary
.opens
.iter()
.flat_map(|oc| oc.bindings.clone())
.collect();
opened_vars = if all_open_bindings.is_empty() {
None
} else {
Some(all_open_bindings)
};
received_vars = if boundary.receives.is_empty() {
None
} else {
Some(boundary.receives.clone())
};
private_open_vars = if boundary.private_opens.is_empty() {
None
} else {
Some(boundary.private_opens.clone())
};
} else {
opened_vars = None;
received_vars = None;
private_open_vars = None;
}
}
Ok(analyses)
}
fn collect_params_as_state_vars(params: &[PatType]) -> Vec<StateVar> {
params
.iter()
.filter_map(|param| {
if let Pat::Ident(PatIdent { ident, .. }) = param.pat.as_ref() {
Some(StateVar {
name: ident.clone(),
ty: param.ty.as_ref().clone(),
})
} else {
None
}
})
.collect()
}
struct VarUsageVisitor {
used_vars: HashSet<String>,
}
impl<'ast> Visit<'ast> for VarUsageVisitor {
fn visit_expr_path(&mut self, expr_path: &'ast syn::ExprPath) {
if expr_path.qself.is_none() && expr_path.path.segments.len() == 1 {
if let Some(segment) = expr_path.path.segments.first() {
self.used_vars.insert(segment.ident.to_string());
}
}
syn::visit::visit_expr_path(self, expr_path);
}
fn visit_expr_macro(&mut self, expr_macro: &'ast ExprMacro) {
extract_idents_from_tokens(&expr_macro.mac.tokens, &mut self.used_vars);
syn::visit::visit_expr_macro(self, expr_macro);
}
fn visit_ident(&mut self, ident: &'ast Ident) {
self.used_vars.insert(ident.to_string());
}
}
fn extract_idents_from_tokens(tokens: &TokenStream2, used_vars: &mut HashSet<String>) {
use proc_macro2::TokenTree;
for token in tokens.clone() {
match token {
TokenTree::Ident(ident) => {
let name = ident.to_string();
if !is_keyword(&name) {
used_vars.insert(name);
}
}
TokenTree::Group(group) => {
extract_idents_from_tokens(&group.stream(), used_vars);
}
_ => {}
}
}
}
fn is_keyword(name: &str) -> bool {
matches!(
name,
"as" | "break"
| "const"
| "continue"
| "crate"
| "else"
| "enum"
| "extern"
| "false"
| "fn"
| "for"
| "if"
| "impl"
| "in"
| "let"
| "loop"
| "match"
| "mod"
| "move"
| "mut"
| "pub"
| "ref"
| "return"
| "self"
| "Self"
| "static"
| "struct"
| "super"
| "trait"
| "true"
| "type"
| "unsafe"
| "use"
| "where"
| "while"
| "async"
| "await"
| "dyn"
)
}
fn collect_vars_used_in_segments(segments: &[Segment]) -> HashSet<String> {
let mut visitor = VarUsageVisitor {
used_vars: HashSet::new(),
};
for segment in segments {
for stmt in &segment.stmts {
visitor.visit_stmt(stmt);
}
if let Some(ref boundary) = segment.boundary {
for open_call in &boundary.opens {
for expr in &open_call.exprs {
visitor.visit_expr(expr);
}
if let Some(target) = &open_call.target {
visitor.visit_expr(target);
}
}
for send_op in &boundary.sends {
visitor.visit_expr(&send_op.target);
visitor.visit_expr(&send_op.message);
}
for send_all_op in &boundary.send_alls {
visitor.visit_expr(&send_all_op.message);
}
for recv_op in &boundary.receives {
visitor.visit_expr(&recv_op.source);
}
for po_op in &boundary.private_opens {
visitor.visit_expr(&po_op.share_expr);
}
}
}
visitor.used_vars
}
fn collect_vars_defined_in_segments(segments: &[Segment], params: &[PatType]) -> Vec<StateVar> {
let mut vars = Vec::new();
for param in params {
if let Pat::Ident(PatIdent { ident, .. }) = param.pat.as_ref() {
vars.push(StateVar {
name: ident.clone(),
ty: param.ty.as_ref().clone(),
});
}
}
for segment in segments {
for stmt in &segment.stmts {
if let Stmt::Local(local) = stmt {
collect_vars_from_pattern(&local.pat, &mut vars);
}
}
}
vars
}
fn collect_vars_from_pattern(pat: &Pat, vars: &mut Vec<StateVar>) {
match pat {
Pat::Ident(PatIdent { .. }) => {
}
Pat::Type(PatType { pat, ty, .. }) => {
collect_vars_from_pattern_with_type(pat, ty, vars);
}
Pat::Tuple(tuple) => {
for elem in &tuple.elems {
collect_vars_from_pattern(elem, vars);
}
}
Pat::Struct(pat_struct) => {
for field in &pat_struct.fields {
collect_vars_from_pattern(&field.pat, vars);
}
}
_ => {}
}
}
fn collect_vars_from_pattern_with_type(pat: &Pat, ty: &Type, vars: &mut Vec<StateVar>) {
match (pat, ty) {
(Pat::Ident(PatIdent { ident, .. }), ty) => {
vars.push(StateVar {
name: ident.clone(),
ty: ty.clone(),
});
}
(Pat::Tuple(pat_tuple), Type::Tuple(ty_tuple)) => {
for (p, t) in pat_tuple.elems.iter().zip(ty_tuple.elems.iter()) {
collect_vars_from_pattern_with_type(p, t, vars);
}
}
_ => {}
}
}
fn generate_round_structs(
task_name: &Ident,
analyses: &[RoundAnalysis],
generics: &Generics,
) -> Vec<TokenStream2> {
let (impl_generics, _ty_generics, where_clause) = generics.split_for_impl();
let type_params: Vec<_> = generics
.params
.iter()
.filter_map(|p| {
if let GenericParam::Type(tp) = p {
Some(&tp.ident)
} else {
None
}
})
.collect();
analyses
.iter()
.map(|analysis| {
let round_name = format_ident!("{}Round{}", task_name, analysis.round_num);
let fields: Vec<_> = analysis
.state_vars
.iter()
.map(|v| {
let name = &v.name;
let ty = &v.ty;
quote! { #name: #ty }
})
.collect();
let fields_types_str: String = analysis
.state_vars
.iter()
.map(|v| {
let ty = &v.ty;
quote! { #ty }.to_string()
})
.collect();
let unused_type_params: Vec<_> = type_params
.iter()
.filter(|tp| !fields_types_str.contains(&tp.to_string()))
.collect();
let phantom = if !unused_type_params.is_empty() {
let phantom_types = unused_type_params.iter().map(|tp| quote! { #tp });
Some(quote! { _phantom: ::std::marker::PhantomData<(#(#phantom_types),*)> })
} else {
None
};
let all_fields: Vec<_> = fields.into_iter().chain(phantom).collect();
quote! {
#[derive(Debug)]
pub(crate) struct #round_name #impl_generics #where_clause {
#(#all_fields),*
}
}
})
.collect()
}
fn generate_task_enum(
task_name: &Ident,
analyses: &[RoundAnalysis],
generics: &Generics,
) -> TokenStream2 {
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let type_params: Vec<_> = generics
.params
.iter()
.filter_map(|p| {
if let GenericParam::Type(tp) = p {
Some(&tp.ident)
} else {
None
}
})
.collect();
let variants: Vec<_> = analyses
.iter()
.map(|analysis| {
let round_name = format_ident!("{}Round{}", task_name, analysis.round_num);
quote! { #round_name(#round_name #ty_generics) }
})
.collect();
let resolving_variant = if type_params.is_empty() {
quote! { Resolving }
} else {
let phantom_types = type_params.iter().map(|tp| quote! { #tp });
quote! { Resolving(::std::marker::PhantomData<(#(#phantom_types),*)>) }
};
quote! {
#[derive(Debug)]
pub enum #task_name #impl_generics #where_clause {
#(#variants,)*
#resolving_variant,
}
}
}
#[allow(clippy::too_many_arguments)] fn generate_task_impl(
task_name: &Ident,
analyses: &[RoundAnalysis],
generics: &Generics,
where_clause: &Option<WhereClause>,
params: &[PatType],
output_type: &Type,
field_type_param: Option<&Ident>,
requirements: &[Requirement],
) -> TokenStream2 {
let (impl_generics, ty_generics, _) = generics.split_for_impl();
let constructor = generate_constructor(task_name, analyses, params, generics, requirements);
let state_transitions = generate_state_transitions(task_name, analyses, generics);
let round_methods =
generate_round_methods(task_name, analyses, output_type, generics, field_type_param);
let has_open_calls = analyses
.iter()
.any(|a| a.boundary.as_ref().is_some_and(|b| !b.opens.is_empty()));
let extra_where_bounds = if let Some(f) = field_type_param.filter(|_| has_open_calls) {
let param_types: std::collections::HashMap<String, &Type> = params
.iter()
.filter_map(|p| {
if let Pat::Ident(PatIdent { ident, .. }) = p.pat.as_ref() {
Some((ident.to_string(), p.ty.as_ref()))
} else {
None
}
})
.collect();
let mut all_bounds: Vec<TokenStream2> = Vec::new();
for analysis in analyses {
if let Some(boundary) = &analysis.boundary {
for open_call in &boundary.opens {
if open_call.kind == OpeningKind::Binary {
continue;
}
for expr in &open_call.exprs {
if let Expr::Path(expr_path) = expr {
if expr_path.qself.is_none() && expr_path.path.segments.len() == 1 {
let ident = &expr_path.path.segments[0].ident;
if let Some(ty) = param_types.get(&ident.to_string()) {
all_bounds.push(quote! { #ty: Into<crate::Share<#f>> });
}
}
}
}
for binding in &open_call.bindings {
let ty = if open_call.kind == OpeningKind::ToOne {
unwrap_option_type(&binding.ty)
} else {
binding.ty.clone()
};
all_bounds.push(quote! { crate::Secret<#f>: Into<#ty> });
}
}
}
}
if !all_bounds.is_empty() {
quote! { #(#all_bounds),* }
} else {
quote! {}
}
} else {
quote! {}
};
let has_extra_bounds = !extra_where_bounds.is_empty();
let full_where_clause = match (where_clause, has_extra_bounds) {
(Some(wc), true) => {
let predicates: Vec<_> = wc.predicates.iter().collect();
quote! { where #(#predicates),*, #extra_where_bounds }
}
(Some(wc), false) => {
let predicates: Vec<_> = wc.predicates.iter().collect();
quote! { where #(#predicates),* }
}
(None, true) => {
quote! { where #extra_where_bounds }
}
(None, false) => {
quote! {}
}
};
quote! {
impl #impl_generics #task_name #ty_generics
#full_where_clause
{
#constructor
#(#state_transitions)*
#(#round_methods)*
}
}
}
fn is_phantom_data_type(ty: &Type) -> bool {
if let Type::Path(type_path) = ty {
if let Some(segment) = type_path.path.segments.last() {
return segment.ident == "PhantomData";
}
}
false
}
fn unwrap_option_type(ty: &Type) -> Type {
if let Type::Path(type_path) = ty {
if let Some(segment) = type_path.path.segments.last() {
if segment.ident == "Option" {
if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
return inner_ty.clone();
}
}
}
}
}
ty.clone()
}
fn generate_constructor(
task_name: &Ident,
analyses: &[RoundAnalysis],
params: &[PatType],
generics: &Generics,
requirements: &[Requirement],
) -> TokenStream2 {
let (param_names, param_types): (Vec<_>, Vec<_>) = params
.iter()
.filter(|p| !is_phantom_data_type(&p.ty))
.filter_map(|p| {
if let Pat::Ident(PatIdent { ident, .. }) = p.pat.as_ref() {
Some((ident, &p.ty))
} else {
None
}
})
.unzip();
let round1_name = format_ident!("{}Round1", task_name);
let state_var_names: Vec<_> = analyses[0].state_vars.iter().map(|v| &v.name).collect();
let type_params: Vec<_> = generics
.params
.iter()
.filter_map(|p| {
if let GenericParam::Type(tp) = p {
Some(&tp.ident)
} else {
None
}
})
.collect();
let state_vars_types_str: String = analyses[0]
.state_vars
.iter()
.map(|v| {
let ty = &v.ty;
quote! { #ty }.to_string()
})
.collect();
let has_unused_type_params = type_params
.iter()
.any(|tp| !state_vars_types_str.contains(&tp.to_string()));
let phantom_init = if has_unused_type_params && !type_params.is_empty() {
quote! { _phantom: ::std::marker::PhantomData, }
} else {
quote! {}
};
let validation_checks: Vec<TokenStream2> = requirements
.iter()
.map(|req| {
let condition = &req.condition;
let message = &req.message;
quote! {
if !(#condition) {
return Err(crate::abort::Abort::StaticError(#message));
}
}
})
.collect();
let construct_self = quote! {
Self::#round1_name(#round1_name {
#(#state_var_names,)*
#phantom_init
})
};
if !requirements.is_empty() {
quote! {
pub fn new(#(#param_names: #param_types),*) -> Result<Self, crate::abort::Abort> {
#(#validation_checks)*
Ok(#construct_self)
}
}
} else {
quote! {
pub fn new(#(#param_names: #param_types),*) -> Self {
#construct_self
}
}
}
}
fn generate_state_transitions(
task_name: &Ident,
analyses: &[RoundAnalysis],
generics: &Generics,
) -> Vec<TokenStream2> {
let (_impl_generics, ty_generics, _where_clause) = generics.split_for_impl();
let has_type_params = generics
.params
.iter()
.any(|p| matches!(p, GenericParam::Type(_)));
let resolving_value = if has_type_params {
quote! { #task_name::Resolving(::std::marker::PhantomData) }
} else {
quote! { #task_name::Resolving }
};
analyses
.iter()
.map(|analysis| {
let method_name = format_ident!("try_into_round{}", analysis.round_num);
let round_name = format_ident!("{}Round{}", task_name, analysis.round_num);
let task_name_str = task_name.to_string();
let round_num_str = analysis.round_num.to_string();
let resolving = resolving_value.clone();
quote! {
fn #method_name(&mut self) -> Result<#round_name #ty_generics, crate::abort::Abort> {
match std::mem::replace(self, #resolving) {
#task_name::#round_name(state) => Ok(state),
other => {
*self = other;
Err(crate::abort::Abort::StaticError(concat!(
"Incorrect state transition in ",
#task_name_str,
" round ",
#round_num_str
)))
}
}
}
}
})
.collect()
}
fn generate_round_methods(
task_name: &Ident,
analyses: &[RoundAnalysis],
output_type: &Type,
generics: &Generics,
field_type_param: Option<&Ident>,
) -> Vec<TokenStream2> {
let type_params: Vec<_> = generics
.params
.iter()
.filter_map(|p| {
if let GenericParam::Type(tp) = p {
Some(&tp.ident)
} else {
None
}
})
.collect();
let has_open_calls = analyses
.iter()
.any(|a| a.boundary.as_ref().is_some_and(|b| !b.opens.is_empty()));
let (method_generics, round_output_type) =
if let Some(f) = field_type_param.filter(|_| has_open_calls) {
(quote! {}, quote! { crate::RoundOutput<#output_type, #f> })
} else {
(
quote! { <__ProtocolF: primitives::algebra::field::FieldExtension> },
quote! { crate::RoundOutput<#output_type, __ProtocolF> },
)
};
analyses
.iter()
.map(|analysis| {
let method_name = format_ident!("round{}", analysis.round_num);
let is_final = analysis.boundary.is_none();
let mut params: Vec<TokenStream2> = Vec::new();
if let Some(ref opened_vars) = analysis.opened_vars {
for v in opened_vars {
let name = &v.name;
let ty = &v.ty;
params.push(quote! { #name: #ty });
}
}
let has_receives = analysis.received_vars.as_ref().is_some_and(|r| !r.is_empty());
let has_private_opens = analysis.private_open_vars.as_ref().is_some_and(|p| !p.is_empty());
if has_receives || has_private_opens {
params.push(quote! { __messages: &Vec<Vec<crate::net::Bytes>> });
}
let round_num = analysis.round_num;
let try_into = format_ident!("try_into_round{}", round_num);
let state_extraction = quote! { let state = self.#try_into()?; };
let state_vars_to_extract = analysis.state_vars.clone();
let state_destructure: Vec<_> = state_vars_to_extract
.iter()
.map(|v| {
let name = &v.name;
quote! { let #name = state.#name; }
})
.collect();
let receive_extractions: Vec<TokenStream2> = if let Some(ref received_vars) = analysis.received_vars {
received_vars
.iter()
.enumerate()
.map(|(recv_idx, recv_op)| {
let name = &recv_op.binding.name;
let ty = &recv_op.binding.ty;
let source = &recv_op.source;
quote! {
let #name: #ty = {
let __peer_idx = #source as usize;
let __peer_msgs = __messages.get(__peer_idx)
.ok_or(crate::abort::Abort::StaticError("Missing messages from peer"))?;
let __msg_bytes = __peer_msgs.get(#recv_idx)
.ok_or(crate::abort::Abort::StaticError("Missing message at index"))?;
wincode::deserialize(__msg_bytes)
.map_err(|_| crate::abort::Abort::StaticError("Failed to deserialize message"))?
};
}
})
.collect()
} else {
Vec::new()
};
let private_open_extractions: Vec<TokenStream2> = if let Some(ref po_vars) = analysis.private_open_vars {
po_vars
.iter()
.map(|po_op| {
let name = &po_op.binding.name;
let ty = &po_op.binding.ty;
let share_expr = &po_op.share_expr;
quote! {
let #name: #ty = {
use primitives::sharing::Reconstructible;
let __share = #share_expr;
let mut __openings = Vec::new();
for __peer_msgs in __messages.iter() {
if let Some(__msg) = __peer_msgs.first() {
__openings.push(
wincode::deserialize(__msg)
.map_err(|_| crate::abort::Abort::StaticError("Failed to deserialize opening"))?
);
}
}
__share.reconstruct(__openings)
.map_err(|e| crate::abort::Abort::DynamicError(format!("Reconstruction failed: {:?}", e)))?
};
}
})
.collect()
} else {
Vec::new()
};
let stmts = &analysis.stmts;
if is_final {
let final_expr = if analysis.stmts.is_empty() {
quote! { () }
} else if let Some(Stmt::Expr(expr, None)) = analysis.stmts.last() {
quote! { #expr }
} else {
quote! { () }
};
let stmts_without_last: Vec<_> = if !analysis.stmts.is_empty() {
analysis.stmts[..analysis.stmts.len() - 1].to_vec()
} else {
Vec::new()
};
let method_gen = method_generics.clone();
quote! {
fn #method_name #method_gen(&mut self, #(#params),*) -> Result<#round_output_type, crate::abort::Abort> {
#state_extraction
#(#state_destructure)*
#(#receive_extractions)*
#(#private_open_extractions)*
#(#stmts_without_last)*
Ok(crate::RoundOutput::Finished {
value: #final_expr,
send: crate::net::SendMessages::new(),
})
}
}
} else {
let boundary = analysis.boundary.as_ref().unwrap();
let shares_to_open: Vec<_> = boundary
.opens
.iter()
.flat_map(|open_call| {
match &open_call.kind {
OpeningKind::Field => {
open_call
.exprs
.iter()
.map(|e| {
quote! {
crate::PendingOpening {
share: (#e).into(),
target: crate::OpeningTarget::All,
}
}
})
.collect::<Vec<_>>()
}
OpeningKind::Binary => {
open_call
.exprs
.iter()
.map(|e| {
quote! {
crate::PendingOpening {
share: crate::Share::Binary((#e).to_vec()),
target: crate::OpeningTarget::All,
}
}
})
.collect::<Vec<_>>()
}
OpeningKind::ToOne => {
let target_expr = open_call.target.as_ref().unwrap();
open_call
.exprs
.iter()
.map(|e| {
quote! {
crate::PendingOpening {
share: (#e).into(),
target: crate::OpeningTarget::ToOne(#target_expr),
}
}
})
.collect::<Vec<_>>()
}
}
})
.collect();
let send_messages_code: Vec<_> = boundary
.sends
.iter()
.map(|send_op| {
let target = &send_op.target;
let message = &send_op.message;
quote! {
__send = __send.to(
#target,
wincode::serialize(&#message).map_err(|_| crate::abort::Abort::StaticError("Failed to serialize message"))?
);
}
})
.collect();
let send_all_messages_code: Vec<_> = boundary
.send_alls
.iter()
.map(|send_all_op| {
let message = &send_all_op.message;
quote! {
__send = __send.broadcast(
wincode::serialize(&#message).map_err(|_| crate::abort::Abort::StaticError("Failed to serialize broadcast message"))?
);
}
})
.collect();
let expect_peers_code: Vec<_> = boundary
.receives
.iter()
.map(|recv_op| {
let source = &recv_op.source;
quote! { #source }
})
.collect();
let next_round_num = analysis.round_num + 1;
let next_round_name = format_ident!("{}Round{}", task_name, next_round_num);
let next_state_vars: Vec<_> = if next_round_num <= analyses.len() {
analyses[next_round_num - 1]
.state_vars
.iter()
.map(|v| &v.name)
.collect()
} else {
Vec::new()
};
let next_state_vars_types_str: String = if next_round_num <= analyses.len() {
analyses[next_round_num - 1]
.state_vars
.iter()
.map(|v| {
let ty = &v.ty;
quote! { #ty }.to_string()
})
.collect()
} else {
String::new()
};
let has_unused_type_params = type_params.iter().any(|tp| !next_state_vars_types_str.contains(&tp.to_string()));
let phantom_init = if has_unused_type_params && !type_params.is_empty() {
quote! { _phantom: ::std::marker::PhantomData }
} else {
quote! {}
};
let shares_to_open_expr = if shares_to_open.is_empty() {
quote! { None }
} else {
quote! { Some(vec![#(#shares_to_open),*]) }
};
let has_private_opens_in_boundary = !boundary.private_opens.is_empty();
let expect_expr = if has_private_opens_in_boundary {
quote! { crate::net::ExpectMessages::FromAll }
} else if expect_peers_code.is_empty() {
quote! { crate::net::ExpectMessages::None }
} else {
quote! { crate::net::ExpectMessages::From(vec![#(#expect_peers_code),*]) }
};
let method_gen = method_generics.clone();
quote! {
fn #method_name #method_gen(&mut self, #(#params),*) -> Result<#round_output_type, crate::abort::Abort> {
#state_extraction
#(#state_destructure)*
#(#receive_extractions)*
#(#private_open_extractions)*
#(#stmts)*
let mut __send = crate::net::SendMessages::new();
#(#send_messages_code)*
#(#send_all_messages_code)*
*self = #task_name::#next_round_name(#next_round_name {
#(#next_state_vars,)*
#phantom_init
});
Ok(crate::RoundOutput::Continue {
send: __send,
expect: #expect_expr,
shares_to_open: #shares_to_open_expr,
})
}
}
}
})
.collect()
}
fn generate_task_trait_impl(
task_name: &Ident,
analyses: &[RoundAnalysis],
generics: &Generics,
where_clause: &Option<WhereClause>,
fn_params: &[PatType],
output_type: &Type,
field_type_param: Option<&Ident>,
) -> TokenStream2 {
let (_impl_generics, ty_generics, _) = generics.split_for_impl();
let has_open_calls = analyses
.iter()
.any(|a| a.boundary.as_ref().is_some_and(|b| !b.opens.is_empty()));
let param_types: std::collections::HashMap<String, &Type> = fn_params
.iter()
.filter_map(|p| {
if let Pat::Ident(PatIdent { ident, .. }) = p.pat.as_ref() {
Some((ident.to_string(), p.ty.as_ref()))
} else {
None
}
})
.collect();
let (
protocol_field_type,
impl_generics_with_field,
extra_where_bounds,
has_extra_bounds,
needs_turbofish,
) = if let Some(f) = field_type_param.filter(|_| has_open_calls) {
let gen_params = &generics.params;
let input_bound_types: Vec<_> = analyses
.iter()
.filter_map(|a| a.boundary.as_ref())
.flat_map(|boundary| boundary.opens.iter())
.filter(|open_call| open_call.kind != OpeningKind::Binary)
.flat_map(|open_call| open_call.exprs.iter())
.filter_map(|expr| {
if let Expr::Path(expr_path) = expr {
if expr_path.qself.is_none() && expr_path.path.segments.len() == 1 {
let ident = &expr_path.path.segments[0].ident;
return param_types.get(&ident.to_string()).copied();
}
}
None
})
.collect();
let output_bound_types: Vec<_> = analyses
.iter()
.filter_map(|a| a.boundary.as_ref())
.flat_map(|boundary| boundary.opens.iter())
.filter(|open_call| open_call.kind != OpeningKind::Binary)
.flat_map(|open_call| {
open_call.bindings.iter().map(move |b| {
if open_call.kind == OpeningKind::ToOne {
unwrap_option_type(&b.ty)
} else {
b.ty.clone()
}
})
})
.collect();
let mut all_bounds: Vec<TokenStream2> = Vec::new();
for ty in &input_bound_types {
all_bounds.push(quote! { #ty: Into<crate::Share<#f>> });
}
for ty in output_bound_types {
all_bounds.push(quote! { crate::Secret<#f>: Into<#ty> });
}
let has_bounds = !all_bounds.is_empty();
let extra_bounds = if has_bounds {
quote! { #(#all_bounds),* }
} else {
quote! {}
};
(
quote! { #f },
quote! { <#gen_params> },
extra_bounds,
has_bounds,
false, )
} else {
let params = &generics.params;
let impl_gen = if params.is_empty() {
quote! { <__ProtocolF: primitives::algebra::field::FieldExtension> }
} else {
quote! { <#params, __ProtocolF: primitives::algebra::field::FieldExtension> }
};
let binding_types: Vec<_> = analyses
.iter()
.filter_map(|a| a.boundary.as_ref())
.flat_map(|boundary| boundary.opens.iter())
.filter(|open_call| open_call.kind != OpeningKind::Binary)
.flat_map(|open_call| {
open_call.bindings.iter().map(move |b| {
if open_call.kind == OpeningKind::ToOne {
unwrap_option_type(&b.ty)
} else {
b.ty.clone()
}
})
})
.collect();
let has_binding = !binding_types.is_empty();
let extra_bounds = if has_binding {
quote! {
#(crate::Secret<__ProtocolF>: Into<#binding_types>),*
}
} else {
quote! {}
};
(
quote! { __ProtocolF },
impl_gen,
extra_bounds,
has_binding,
true,
)
};
let method_turbofish = if needs_turbofish {
quote! { ::<#protocol_field_type> }
} else {
quote! {}
};
let match_arms: Vec<_> = analyses
.iter()
.map(|analysis| {
let variant_name = format_ident!("{}Round{}", task_name, analysis.round_num);
let method_name = format_ident!("round{}", analysis.round_num);
let turbofish = method_turbofish.clone();
if analysis.round_num == 1 {
quote! {
#task_name::#variant_name(_) => self.#method_name #turbofish()
}
} else {
let prev_boundary = &analyses[analysis.round_num - 2].boundary;
let mut extractions: Vec<TokenStream2> = Vec::new();
let mut arg_names: Vec<Ident> = Vec::new();
let mut total_open_count = 0usize;
if let Some(boundary) = prev_boundary {
for open_call in &boundary.opens {
total_open_count += open_call.bindings.len();
match &open_call.kind {
OpeningKind::Field => {
for binding in &open_call.bindings {
let name = &binding.name;
let ty = &binding.ty;
extractions.push(quote! {
let #name: #ty = match __opened_iter.next().ok_or(crate::abort::Abort::StaticError("Insufficient openings received"))? {
crate::OpenedValue::Value(s) => s.into(),
crate::OpenedValue::NotTarget => return Err(crate::abort::Abort::StaticError("Expected Value but got NotTarget for open_field! binding")),
};
});
arg_names.push(name.clone());
}
}
OpeningKind::Binary => {
for binding in &open_call.bindings {
let name = &binding.name;
let ty = &binding.ty;
extractions.push(quote! {
let #name: #ty = match __opened_iter.next().ok_or(crate::abort::Abort::StaticError("Insufficient openings received"))? {
crate::OpenedValue::Value(crate::Secret::Binary(v)) => v.into(),
crate::OpenedValue::Value(_) => return Err(crate::abort::Abort::StaticError("Expected Binary secret for open_binary! binding")),
crate::OpenedValue::NotTarget => return Err(crate::abort::Abort::StaticError("Expected Value but got NotTarget for open_binary! binding")),
};
});
arg_names.push(name.clone());
}
}
OpeningKind::ToOne => {
for binding in &open_call.bindings {
let name = &binding.name;
let ty = &binding.ty;
extractions.push(quote! {
let #name: #ty = match __opened_iter.next().ok_or(crate::abort::Abort::StaticError("Insufficient openings received"))? {
crate::OpenedValue::Value(s) => Some(s.into()),
crate::OpenedValue::NotTarget => None,
};
});
arg_names.push(name.clone());
}
}
}
}
}
let has_receives = prev_boundary.as_ref().is_some_and(|b| !b.receives.is_empty());
let has_private_opens = prev_boundary.as_ref().is_some_and(|b| !b.private_opens.is_empty());
let needs_messages = has_receives || has_private_opens;
if !extractions.is_empty() || total_open_count > 0 || needs_messages {
let length_error_msg = format!(
"Not enough opened values in {} round {}: expected {}",
task_name,
analysis.round_num,
total_open_count
);
let open_check = if total_open_count > 0 {
quote! {
if input.opened_values.len() < #total_open_count {
return Err(crate::abort::Abort::StaticError(#length_error_msg));
}
let mut __opened_iter = input.opened_values.drain(..);
}
} else {
quote! {}
};
let (call_with_messages, extra_input_arg) = if needs_messages {
(
quote! { self.#method_name #turbofish(#(#arg_names,)* &input.messages) },
true,
)
} else {
(
quote! { self.#method_name #turbofish(#(#arg_names),*) },
false,
)
};
let _ = extra_input_arg;
quote! {
#task_name::#variant_name(_) => {
#open_check
#(#extractions)*
#call_with_messages
}
}
} else {
quote! {
#task_name::#variant_name(_) => self.#method_name #turbofish()
}
}
}
})
.collect();
let has_type_params = !generics.params.is_empty();
let resolving_pattern = if has_type_params {
quote! { #task_name::Resolving(_) }
} else {
quote! { #task_name::Resolving }
};
let full_where_clause = match (where_clause, has_extra_bounds) {
(Some(wc), true) => {
let predicates: Vec<_> = wc.predicates.iter().collect();
quote! { where #(#predicates),*, #extra_where_bounds }
}
(Some(wc), false) => {
let predicates: Vec<_> = wc.predicates.iter().collect();
quote! { where #(#predicates),* }
}
(None, true) => {
quote! { where #extra_where_bounds }
}
(None, false) => {
quote! {}
}
};
quote! {
impl #impl_generics_with_field crate::Task<#protocol_field_type> for #task_name #ty_generics
#full_where_clause
{
type Output = #output_type;
fn next_round(&mut self, mut input: crate::TaskInputs<#protocol_field_type>) -> Result<crate::RoundOutput<Self::Output, #protocol_field_type>, crate::abort::Abort> {
match self {
#(#match_arms,)*
#resolving_pattern => Err(crate::abort::Abort::StaticError(concat!("Incorrect state: ", stringify!(#task_name), " is resolving"))),
}
}
}
}
}
impl Clone for OpenCall {
fn clone(&self) -> Self {
OpenCall {
bindings: self.bindings.clone(),
exprs: self.exprs.clone(),
kind: self.kind.clone(),
target: self.target.clone(),
}
}
}
impl Clone for BindingInfo {
fn clone(&self) -> Self {
BindingInfo {
name: self.name.clone(),
ty: self.ty.clone(),
}
}
}