use crate::ast::Role;
use proc_macro2::TokenStream;
use quote::format_ident;
use std::collections::{HashMap, HashSet};
use super::declarations::enforce_same_line_equals;
use super::error::{ErrorSpan, ParseError};
use super::stmt_parsers::{
parse_authority_let_in_stmt, parse_authority_let_stmt, parse_await_stmt, parse_begin_stmt,
parse_broadcast_stmt, parse_call_stmt, parse_case_stmt, parse_choice_stmt, parse_continue_stmt,
parse_dependent_work_stmt, parse_handoff_stmt, parse_invalidate_stmt, parse_let_in_stmt,
parse_let_stmt, parse_loop_stmt, parse_materialize_stmt, parse_observe_let_in_stmt,
parse_observe_let_stmt, parse_par_stmt, parse_publish_authority_stmt, parse_publish_stmt,
parse_rec_stmt, parse_resolve_stmt, parse_send_stmt, parse_timeout_stmt,
};
use super::types::{MessageSpec, ParsedBody, Statement};
use super::Rule;
/// Parses the items of a protocol body (or block) into an optional `roles`
/// declaration plus the statements it contains.
///
/// A `protocol_body` pair wraps a single child whose children are the items. Every
/// other rule (including `block_protocol` and `block`) lists its items directly.
///
/// Statements are checked against `declared_roles_base` until a `roles` declaration is
/// seen. From then on, the roles declared there *replace* the base set (they are not
/// merged with it). Statements that appear before the declaration still see the base set.
///
/// # Errors
/// Returns `ParseError::Syntax` when a `roles` declaration appears while
/// `allow_roles_decl` is `false`, or when a second `roles` declaration is found.
/// Errors from role or statement parsing are propagated unchanged.
pub(crate) fn parse_protocol_body(
    pair: pest::iterators::Pair<Rule>,
    declared_roles_base: &HashSet<String>,
    input: &str,
    protocol_defs: &HashMap<String, Vec<Statement>>,
    allow_roles_decl: bool,
) -> std::result::Result<ParsedBody, ParseError> {
    use super::role::parse_roles_from_pair;

    // `protocol_body` has one wrapper child holding the items; an empty body yields no
    // items. All other rules expose their items as direct children.
    let items: Vec<pest::iterators::Pair<Rule>> = match pair.as_rule() {
        Rule::protocol_body => pair
            .into_inner()
            .next()
            .map(|wrapper| wrapper.into_inner().collect::<Vec<_>>())
            .unwrap_or_default(),
        _ => pair.into_inner().collect(),
    };

    let mut roles: Option<Vec<Role>> = None;
    let mut statements = Vec::with_capacity(items.len());
    let mut declared_roles = declared_roles_base.clone();
    for item in items {
        match item.as_rule() {
            Rule::roles_decl => {
                if !allow_roles_decl {
                    return Err(ParseError::Syntax {
                        span: ErrorSpan::from_pest_span(item.as_span(), input),
                        message: "roles declaration is not allowed here".to_string(),
                    });
                }
                if roles.is_some() {
                    return Err(ParseError::Syntax {
                        span: ErrorSpan::from_pest_span(item.as_span(), input),
                        message: "duplicate roles declaration".to_string(),
                    });
                }
                let parsed_roles = parse_roles_from_pair(item, input)?;
                // Later statements are validated against exactly the declared roles.
                declared_roles = parsed_roles.iter().map(|r| r.name().to_string()).collect();
                roles = Some(parsed_roles);
            }
            _ => {
                let statement = parse_statement(item, &declared_roles, input, protocol_defs)?;
                statements.push(statement);
            }
        }
    }
    Ok(ParsedBody { roles, statements })
}
/// Parses a nested block into its list of statements.
///
/// The block inherits `declared_roles` from the enclosing scope and may not introduce
/// its own `roles` declaration.
///
/// # Errors
/// Returns `ParseError::Syntax` (spanning the whole block) if the parsed body carries a
/// roles declaration. Errors from `parse_protocol_body` are propagated unchanged.
pub(crate) fn parse_block(
    pair: pest::iterators::Pair<Rule>,
    declared_roles: &HashSet<String>,
    input: &str,
    protocol_defs: &HashMap<String, Vec<Statement>>,
) -> std::result::Result<Vec<Statement>, ParseError> {
    let block_span = pair.as_span();
    let body = parse_protocol_body(pair, declared_roles, input, protocol_defs, false)?;
    match body.roles {
        // Defensive: with `allow_roles_decl = false` the body parser already rejects
        // a roles declaration, but keep the block-level guard explicit.
        Some(_) => Err(ParseError::Syntax {
            span: ErrorSpan::from_pest_span(block_span, input),
            message: "roles declaration is not allowed in this block".to_string(),
        }),
        None => Ok(body.statements),
    }
}
/// Parses one statement `pair` into a [`Statement`].
///
/// This is the crate-visible entry point; the per-rule dispatch lives in
/// `parse_statement_inner`.
///
/// # Errors
/// Propagates any `ParseError` produced while dispatching or parsing the statement.
pub(crate) fn parse_statement(
    pair: pest::iterators::Pair<Rule>,
    declared_roles: &HashSet<String>,
    input: &str,
    protocol_defs: &HashMap<String, Vec<Statement>>,
) -> std::result::Result<Statement, ParseError> {
    parse_statement_inner(
        pair,
        declared_roles,
        input,
        protocol_defs,
    )
}
/// Dispatches `pair` to the statement sub-parser registered for its grammar rule.
///
/// Arms are grouped by the context each sub-parser takes.
///
/// # Errors
/// Returns `ParseError::Syntax` naming the rule when `pair` is not a known statement
/// rule; otherwise returns whatever the selected sub-parser returns.
fn parse_statement_inner(
    pair: pest::iterators::Pair<Rule>,
    declared_roles: &HashSet<String>,
    input: &str,
    protocol_defs: &HashMap<String, Vec<Statement>>,
) -> std::result::Result<Statement, ParseError> {
    let rule = pair.as_rule();
    match rule {
        // Sub-parsers that take both the declared roles and the local protocol table.
        Rule::authority_let_in_stmt => {
            parse_authority_let_in_stmt(pair, declared_roles, input, protocol_defs)
        }
        Rule::observe_let_in_stmt => {
            parse_observe_let_in_stmt(pair, declared_roles, input, protocol_defs)
        }
        Rule::let_in_stmt => parse_let_in_stmt(pair, declared_roles, input, protocol_defs),
        Rule::case_stmt => parse_case_stmt(pair, declared_roles, input, protocol_defs),
        Rule::timeout_stmt => parse_timeout_stmt(pair, declared_roles, input, protocol_defs),
        Rule::choice_stmt => parse_choice_stmt(pair, declared_roles, input, protocol_defs),
        Rule::par_stmt => parse_par_stmt(pair, declared_roles, input, protocol_defs),
        Rule::loop_stmt => parse_loop_stmt(pair, declared_roles, input, protocol_defs),
        Rule::rec_stmt => parse_rec_stmt(pair, declared_roles, input, protocol_defs),

        // Sub-parsers that take the declared roles but not the protocol table.
        Rule::authority_let_stmt => parse_authority_let_stmt(pair, declared_roles, input),
        Rule::observe_let_stmt => parse_observe_let_stmt(pair, declared_roles, input),
        Rule::let_stmt => parse_let_stmt(pair, declared_roles, input),
        Rule::send_stmt => parse_send_stmt(pair, declared_roles, input),
        Rule::broadcast_stmt => parse_broadcast_stmt(pair, declared_roles, input),
        Rule::handoff_stmt => parse_handoff_stmt(pair, declared_roles, input),

        // Sub-parsers that take only the pair and the source text.
        Rule::begin_stmt => parse_begin_stmt(pair, input),
        Rule::await_stmt => parse_await_stmt(pair, input),
        Rule::resolve_stmt => parse_resolve_stmt(pair, input),
        Rule::invalidate_stmt => parse_invalidate_stmt(pair, input),
        Rule::continue_stmt => parse_continue_stmt(pair, input),
        Rule::call_stmt => parse_call_stmt(pair, input),
        Rule::publish_authority_stmt => parse_publish_authority_stmt(pair, input),
        Rule::publish_stmt => parse_publish_stmt(pair, input),
        Rule::materialize_stmt => parse_materialize_stmt(pair, input),
        Rule::dependent_work_stmt => parse_dependent_work_stmt(pair, input),

        unexpected => Err(ParseError::Syntax {
            span: ErrorSpan::from_pest_span(pair.as_span(), input),
            message: format!("Unexpected statement type: {:?}", unexpected),
        }),
    }
}
/// Parses a duration literal (an `integer` child plus a `time_unit` child) into
/// milliseconds.
///
/// Recognised units are `ms`, `s`, `m` and `h`. The conversion uses saturating
/// multiplication, so very large values clamp to `u64::MAX` instead of overflowing.
/// If several integers or units appear, the last of each wins.
///
/// # Errors
/// Returns `ParseError::Syntax` (spanning the whole literal) when the integer does not
/// fit a `u64`, when the numeric value or the unit is missing (checked in that order),
/// or when the unit is not one of the recognised ones.
pub(crate) fn parse_duration(
    pair: pest::iterators::Pair<Rule>,
    input: &str,
) -> std::result::Result<u64, ParseError> {
    let span = pair.as_span();
    let syntax_error = |message: String| ParseError::Syntax {
        span: ErrorSpan::from_pest_span(span, input),
        message,
    };

    let mut amount: Option<u64> = None;
    let mut suffix: Option<&str> = None;
    for child in pair.into_inner() {
        match child.as_rule() {
            Rule::integer => {
                let parsed = child
                    .as_str()
                    .parse::<u64>()
                    .map_err(|_| syntax_error("Invalid duration value".to_string()))?;
                amount = Some(parsed);
            }
            Rule::time_unit => suffix = Some(child.as_str()),
            _ => {}
        }
    }

    let amount =
        amount.ok_or_else(|| syntax_error("Duration missing numeric value".to_string()))?;
    let suffix = suffix
        .ok_or_else(|| syntax_error("Duration missing time unit (ms, s, m, h)".to_string()))?;

    let millis_per_unit: u64 = match suffix {
        "ms" => 1,
        "s" => 1_000,
        "m" => 60_000,
        "h" => 3_600_000,
        other => return Err(syntax_error(format!("Unknown time unit: {}", other))),
    };
    Ok(amount.saturating_mul(millis_per_unit))
}
/// Parses a message specification: a name, optionally followed by a payload given
/// either as an `of <Type>` clause or as a raw `payload`.
///
/// In an `of` clause, every `.` is rewritten to ` :: ` before tokenizing, so a dotted
/// DSL path becomes a Rust path. This parser never sets `type_annotation`. If several
/// payload parts are present, the last one wins.
///
/// NOTE(review): `format_ident!` panics if the name is not a valid identifier. The
/// grammar presumably guarantees this; confirm.
///
/// # Errors
/// Returns `ParseError::Syntax` if the name is missing, if an `of` clause has no type,
/// or if the payload text cannot be tokenized as Rust tokens. Previously a tokenization
/// failure silently produced a message with no payload.
pub(crate) fn parse_message(
    pair: pest::iterators::Pair<Rule>,
    input: &str,
) -> std::result::Result<MessageSpec, ParseError> {
    fn normalize_dsl_type_source(src: &str) -> String {
        src.replace('.', " :: ")
    }

    /// Tokenizes payload source, reporting a failure at `part_span` instead of
    /// discarding the payload.
    fn tokenize_payload(
        src: &str,
        part_span: pest::Span<'_>,
        input: &str,
    ) -> std::result::Result<TokenStream, ParseError> {
        syn::parse_str::<TokenStream>(src).map_err(|err| ParseError::Syntax {
            span: ErrorSpan::from_pest_span(part_span, input),
            message: format!("message payload could not be tokenized: {}", err),
        })
    }

    let span = pair.as_span();
    let mut inner = pair.into_inner();
    let name = format_ident!(
        "{}",
        inner
            .next()
            .ok_or_else(|| ParseError::Syntax {
                span: ErrorSpan::from_pest_span(span, input),
                message: "message is missing a name".to_string(),
            })?
            .as_str()
    );
    let type_annotation = None;
    let mut payload = None;
    for part in inner {
        let part_span = part.as_span();
        match part.as_rule() {
            Rule::message_of => {
                let type_pair = part.into_inner().next().ok_or_else(|| ParseError::Syntax {
                    span: ErrorSpan::from_pest_span(span, input),
                    message: "message `of` clause is missing a type".to_string(),
                })?;
                let payload_type = normalize_dsl_type_source(type_pair.as_str());
                payload = Some(tokenize_payload(&payload_type, part_span, input)?);
            }
            Rule::payload => {
                payload = Some(tokenize_payload(part.as_str(), part_span, input)?);
            }
            _ => {}
        }
    }
    Ok(MessageSpec {
        name,
        type_annotation,
        payload,
    })
}
/// Parses a local protocol declaration and registers its statements in
/// `protocol_defs` under the declared name.
///
/// The body is parsed against `declared_roles` and the definitions registered so far.
/// The new entry is inserted only after the body parses successfully, so the protocol
/// is not yet present in `protocol_defs` while its own body is being parsed.
///
/// # Errors
/// - Whatever `enforce_same_line_equals` reports for the declaration text.
/// - `ParseError::Syntax` if the name is missing.
/// - `ParseError::DuplicateProtocol` if the name is already in `protocol_defs`.
/// - `ParseError::Syntax` if header roles are declared or the body is missing.
/// - Errors from parsing the body are propagated unchanged.
pub(crate) fn parse_local_protocol_decl(
    pair: pest::iterators::Pair<Rule>,
    declared_roles: &HashSet<String>,
    input: &str,
    protocol_defs: &mut HashMap<String, Vec<Statement>>,
) -> std::result::Result<(), ParseError> {
    let decl_span = pair.as_span();
    enforce_same_line_equals(pair.as_str(), decl_span, input, "local protocol declaration")?;

    let mut children = pair.into_inner();
    let name_pair = match children.next() {
        Some(found) => found,
        None => {
            return Err(ParseError::Syntax {
                span: ErrorSpan::from_pest_span(decl_span, input),
                message: "local protocol declaration is missing a name".to_string(),
            })
        }
    };
    let name_span = name_pair.as_span();
    let proto_name = name_pair.as_str().to_string();
    if protocol_defs.contains_key(&proto_name) {
        return Err(ParseError::DuplicateProtocol {
            protocol: proto_name,
            span: ErrorSpan::from_pest_span(name_span, input),
        });
    }

    // Scan every remaining child: any header roles are rejected, and the last body wins.
    let mut body = None;
    for child in children {
        match child.as_rule() {
            Rule::header_roles => {
                return Err(ParseError::Syntax {
                    span: ErrorSpan::from_pest_span(child.as_span(), input),
                    message: "local protocols cannot declare roles".to_string(),
                })
            }
            Rule::protocol_body => body = Some(child),
            _ => (),
        }
    }
    let body = body.ok_or_else(|| ParseError::Syntax {
        span: ErrorSpan::from_pest_span(decl_span, input),
        message: "local protocol declaration is missing a body".to_string(),
    })?;

    let parsed = parse_protocol_body(body, declared_roles, input, protocol_defs, false)?;
    if parsed.roles.is_some() {
        return Err(ParseError::Syntax {
            span: ErrorSpan::from_pest_span(name_span, input),
            message: "local protocols cannot declare roles".to_string(),
        });
    }
    protocol_defs.insert(proto_name, parsed.statements);
    Ok(())
}