use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{braced, Ident, Token};
/// Top-level result of parsing the macro input.
///
/// Produced by `<ParsedTopology as Parse>::parse`, which requires both the
/// `react = ...` and the `graph = { ... }` fields to be present exactly once.
#[derive(Debug)]
pub struct ParsedTopology {
    /// Criteria parsed from the `react = when_any(...)` / `react = when_all(...)` field.
    pub react: ReactionCriteria,
    /// Edges parsed from the `graph = { ... }` block, in the order they were written.
    pub edges: Vec<ParsedEdge>,
}
/// Reaction criteria parsed from the `react` field, e.g. `when_any(alpha, beta)`.
#[derive(Debug, Clone)]
pub struct ReactionCriteria {
    /// Which keyword was written: `when_any` or `when_all`.
    pub mode: ReactionMode,
    /// Identifiers listed inside the parentheses, in source order.
    /// The parser accepts an empty list (`when_any()`), so this may be empty.
    pub accumulators: Vec<Ident>,
}
/// The keyword used in the `react = ...` field.
///
/// This is a plain fieldless tag, so it is `Copy` and fully comparable/hashable.
/// Any `PartialEq` comparison that compiled before still compiles.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReactionMode {
    /// Written as `when_any(...)`.
    WhenAny,
    /// Written as `when_all(...)`.
    WhenAll,
}
/// One edge of the `graph = { ... }` block, as produced by `parse_edge`.
#[derive(Debug)]
pub enum ParsedEdge {
    /// `from(inputs...) -> to` — a single unconditional successor.
    Linear {
        /// Source node name.
        from: Ident,
        /// Identifiers from the optional parenthesized list after `from`;
        /// empty when the list was omitted.
        from_inputs: Vec<Ident>,
        /// Target node name (right of `->`).
        to: Ident,
    },
    /// `from(inputs...) => { Variant -> target, ... }` — one target per variant.
    Routing {
        /// Source node name.
        from: Ident,
        /// Identifiers from the optional parenthesized list after `from`;
        /// empty when the list was omitted.
        from_inputs: Vec<Ident>,
        /// The `Variant -> target` arms, in source order. The parser rejects
        /// an empty arm list.
        variants: Vec<RoutingVariant>,
    },
}
/// A single `Variant -> target` arm inside a routing edge's braces.
#[derive(Debug)]
pub struct RoutingVariant {
    /// Identifier left of `->`, e.g. `Signal`.
    /// NOTE(review): presumably names a variant of the source node's output
    /// enum — confirm against the code generator.
    pub variant_name: Ident,
    /// Node the arm routes to (identifier right of `->`).
    pub target: Ident,
}
impl ParsedEdge {
pub fn from_name(&self) -> &Ident {
match self {
ParsedEdge::Linear { from, .. } => from,
ParsedEdge::Routing { from, .. } => from,
}
}
pub fn from_inputs(&self) -> &[Ident] {
match self {
ParsedEdge::Linear { from_inputs, .. } => from_inputs,
ParsedEdge::Routing { from_inputs, .. } => from_inputs,
}
}
}
impl Parse for ParsedTopology {
    /// Parses the macro's top-level `key = value` fields into a [`ParsedTopology`].
    ///
    /// Recognised fields may appear in any order. Each must appear exactly once:
    /// - `react = when_any(...)` or `react = when_all(...)`, parsed by
    ///   `<ReactionCriteria as Parse>::parse`;
    /// - `graph = { ... }`, parsed by `parse_graph_block`.
    ///
    /// A comma after each field is accepted but not required, so a trailing
    /// comma is fine.
    ///
    /// # Errors
    /// - The field name is not `react` or `graph`. This is reported before `=`
    ///   is required, so `bogus something` still names the bad field.
    /// - A field appears twice.
    /// - `react` or `graph` is missing (reported at the macro call site).
    /// - `=` is missing after a valid field name, or a field value fails to parse.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let mut react: Option<ReactionCriteria> = None;
        let mut edges: Option<Vec<ParsedEdge>> = None;

        while !input.is_empty() {
            let key: Ident = input.parse()?;
            // Validate the name (and reject duplicates) *before* demanding `=`.
            // Otherwise a misspelled field written without `=` would surface
            // only as a confusing "expected `=`".
            match key.to_string().as_str() {
                "react" => {
                    if react.is_some() {
                        return Err(syn::Error::new(key.span(), "duplicate 'react' field"));
                    }
                    input.parse::<Token![=]>()?;
                    react = Some(input.parse()?);
                }
                "graph" => {
                    if edges.is_some() {
                        return Err(syn::Error::new(key.span(), "duplicate 'graph' field"));
                    }
                    input.parse::<Token![=]>()?;
                    edges = Some(parse_graph_block(input)?);
                }
                other => {
                    return Err(syn::Error::new(
                        key.span(),
                        format!("unknown field '{}', expected 'react' or 'graph'", other),
                    ));
                }
            }
            // Optional field separator: consumed only when the next token is `,`.
            input.parse::<Option<Token![,]>>()?;
        }

        let react = react.ok_or_else(|| {
            syn::Error::new(proc_macro2::Span::call_site(), "missing 'react' field")
        })?;
        let edges = edges.ok_or_else(|| {
            syn::Error::new(proc_macro2::Span::call_site(), "missing 'graph' field")
        })?;
        Ok(ParsedTopology { react, edges })
    }
}
impl Parse for ReactionCriteria {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mode_ident: Ident = input.parse()?;
let mode = match mode_ident.to_string().as_str() {
"when_any" => ReactionMode::WhenAny,
"when_all" => ReactionMode::WhenAll,
other => {
return Err(syn::Error::new(
mode_ident.span(),
format!(
"unknown reaction mode '{}', expected 'when_any' or 'when_all'",
other
),
));
}
};
let content;
syn::parenthesized!(content in input);
let accumulators: Punctuated<Ident, Token![,]> =
content.parse_terminated(Ident::parse, Token![,])?;
Ok(ReactionCriteria {
mode,
accumulators: accumulators.into_iter().collect(),
})
}
}
/// Parses the brace-delimited body of the `graph` field into its edges,
/// in source order.
///
/// Each edge is parsed by `parse_edge`. A comma after an edge is consumed
/// when present but is not required. An empty `{}` yields an empty `Vec`.
fn parse_graph_block(input: ParseStream) -> syn::Result<Vec<ParsedEdge>> {
    let body;
    braced!(body in input);

    let mut collected = Vec::new();
    loop {
        if body.is_empty() {
            break Ok(collected);
        }
        collected.push(parse_edge(&body)?);
        let _separator: Option<Token![,]> = body.parse()?;
    }
}
fn parse_edge(input: ParseStream) -> syn::Result<ParsedEdge> {
let from: Ident = input.parse()?;
let from_inputs = if input.peek(syn::token::Paren) {
let content;
syn::parenthesized!(content in input);
let inputs: Punctuated<Ident, Token![,]> =
content.parse_terminated(Ident::parse, Token![,])?;
inputs.into_iter().collect()
} else {
Vec::new()
};
if input.peek(Token![=>]) {
input.parse::<Token![=>]>()?;
let variants_content;
braced!(variants_content in input);
let mut variants = Vec::new();
while !variants_content.is_empty() {
let variant_name: Ident = variants_content.parse()?;
variants_content.parse::<Token![->]>()?;
let target: Ident = variants_content.parse()?;
variants.push(RoutingVariant {
variant_name,
target,
});
let _ = variants_content.parse::<Token![,]>();
}
if variants.is_empty() {
return Err(syn::Error::new(
from.span(),
"routing edge must have at least one variant",
));
}
Ok(ParsedEdge::Routing {
from,
from_inputs,
variants,
})
} else if input.peek(Token![->]) {
input.parse::<Token![->]>()?;
let to: Ident = input.parse()?;
Ok(ParsedEdge::Linear {
from,
from_inputs,
to,
})
} else {
Err(syn::Error::new(
from.span(),
format!(
"expected '->' or '=>' after node '{}'. Terminal nodes are detected automatically \
from the graph — they don't need explicit declaration.",
from
),
))
}
}
#[cfg(test)]
mod tests {
    // Parser tests driven through `syn::parse2` on `quote!`-built token streams.
    use super::*;
    use quote::quote;

    fn parse_topology(tokens: proc_macro2::TokenStream) -> syn::Result<ParsedTopology> {
        syn::parse2::<ParsedTopology>(tokens)
    }

    #[test]
    fn test_parse_when_any() {
        let tokens = quote! {
            react = when_any(alpha, beta, gamma),
            graph = {
                entry(alpha, beta) -> output,
            }
        };
        let topology = parse_topology(tokens).unwrap();
        assert_eq!(topology.react.mode, ReactionMode::WhenAny);
        assert_eq!(topology.react.accumulators.len(), 3);
        assert_eq!(topology.react.accumulators[0].to_string(), "alpha");
        assert_eq!(topology.react.accumulators[1].to_string(), "beta");
        assert_eq!(topology.react.accumulators[2].to_string(), "gamma");
    }

    #[test]
    fn test_parse_when_all() {
        let tokens = quote! {
            react = when_all(a, b),
            graph = {
                entry(a, b) -> output,
            }
        };
        let topology = parse_topology(tokens).unwrap();
        assert_eq!(topology.react.mode, ReactionMode::WhenAll);
        assert_eq!(topology.react.accumulators.len(), 2);
    }

    #[test]
    fn test_parse_linear_edge() {
        let tokens = quote! {
            react = when_any(alpha),
            graph = {
                entry(alpha) -> middle,
                middle -> output,
            }
        };
        let topology = parse_topology(tokens).unwrap();
        assert_eq!(topology.edges.len(), 2);
        match &topology.edges[0] {
            ParsedEdge::Linear {
                from,
                from_inputs,
                to,
            } => {
                assert_eq!(from.to_string(), "entry");
                assert_eq!(from_inputs.len(), 1);
                assert_eq!(from_inputs[0].to_string(), "alpha");
                assert_eq!(to.to_string(), "middle");
            }
            _ => panic!("expected linear edge"),
        }
        match &topology.edges[1] {
            ParsedEdge::Linear {
                from,
                from_inputs,
                to,
            } => {
                assert_eq!(from.to_string(), "middle");
                assert!(from_inputs.is_empty());
                assert_eq!(to.to_string(), "output");
            }
            _ => panic!("expected linear edge"),
        }
    }

    #[test]
    fn test_parse_routing_edge() {
        let tokens = quote! {
            react = when_any(alpha),
            graph = {
                decision(alpha) => {
                    Signal -> handler_a,
                    NoAction -> handler_b,
                },
            }
        };
        let topology = parse_topology(tokens).unwrap();
        assert_eq!(topology.edges.len(), 1);
        match &topology.edges[0] {
            ParsedEdge::Routing {
                from,
                from_inputs,
                variants,
            } => {
                assert_eq!(from.to_string(), "decision");
                assert_eq!(from_inputs.len(), 1);
                assert_eq!(variants.len(), 2);
                assert_eq!(variants[0].variant_name.to_string(), "Signal");
                assert_eq!(variants[0].target.to_string(), "handler_a");
                assert_eq!(variants[1].variant_name.to_string(), "NoAction");
                assert_eq!(variants[1].target.to_string(), "handler_b");
            }
            _ => panic!("expected routing edge"),
        }
    }

    #[test]
    fn test_parse_mixed_edges() {
        let tokens = quote! {
            react = when_any(alpha, beta, gamma),
            graph = {
                decision_engine(alpha, beta, gamma) => {
                    Signal -> risk_check,
                    NoAction -> audit_logger,
                },
                risk_check(gamma) => {
                    Approved -> output_handler,
                    Blocked -> alert_handler,
                },
            }
        };
        let topology = parse_topology(tokens).unwrap();
        assert_eq!(topology.edges.len(), 2);
        match &topology.edges[0] {
            ParsedEdge::Routing {
                from, from_inputs, ..
            } => {
                assert_eq!(from.to_string(), "decision_engine");
                assert_eq!(from_inputs.len(), 3);
            }
            _ => panic!("expected routing edge"),
        }
        match &topology.edges[1] {
            ParsedEdge::Routing {
                from,
                from_inputs,
                variants,
            } => {
                assert_eq!(from.to_string(), "risk_check");
                assert_eq!(from_inputs.len(), 1);
                assert_eq!(from_inputs[0].to_string(), "gamma");
                assert_eq!(variants.len(), 2);
            }
            _ => panic!("expected routing edge"),
        }
    }

    #[test]
    fn test_parse_fan_in() {
        let tokens = quote! {
            react = when_any(a, b),
            graph = {
                validate_a(a) -> merge,
                validate_b(b) -> merge,
            }
        };
        let topology = parse_topology(tokens).unwrap();
        assert_eq!(topology.edges.len(), 2);
        match &topology.edges[0] {
            ParsedEdge::Linear { to, .. } => assert_eq!(to.to_string(), "merge"),
            _ => panic!("expected linear edge"),
        }
        match &topology.edges[1] {
            ParsedEdge::Linear { to, .. } => assert_eq!(to.to_string(), "merge"),
            _ => panic!("expected linear edge"),
        }
    }

    #[test]
    fn test_parse_fan_out() {
        let tokens = quote! {
            react = when_any(a),
            graph = {
                compute(a) -> output_handler,
                compute(a) -> audit_logger,
            }
        };
        let topology = parse_topology(tokens).unwrap();
        assert_eq!(topology.edges.len(), 2);
        match &topology.edges[0] {
            ParsedEdge::Linear { from, to, .. } => {
                assert_eq!(from.to_string(), "compute");
                assert_eq!(to.to_string(), "output_handler");
            }
            _ => panic!("expected linear edge"),
        }
        match &topology.edges[1] {
            ParsedEdge::Linear { from, to, .. } => {
                assert_eq!(from.to_string(), "compute");
                assert_eq!(to.to_string(), "audit_logger");
            }
            _ => panic!("expected linear edge"),
        }
    }

    #[test]
    fn test_edge_accessors() {
        // `from_name` / `from_inputs` must work uniformly across edge kinds.
        let tokens = quote! {
            react = when_any(a),
            graph = {
                entry(a) -> mid,
                mid => { Go -> out },
            }
        };
        let topology = parse_topology(tokens).unwrap();
        assert_eq!(topology.edges.len(), 2);
        assert_eq!(topology.edges[0].from_name().to_string(), "entry");
        assert_eq!(topology.edges[0].from_inputs().len(), 1);
        assert_eq!(topology.edges[0].from_inputs()[0].to_string(), "a");
        assert_eq!(topology.edges[1].from_name().to_string(), "mid");
        assert!(topology.edges[1].from_inputs().is_empty());
    }

    #[test]
    fn test_error_missing_react() {
        let tokens = quote! {
            graph = {
                a -> b,
            }
        };
        let result = parse_topology(tokens);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("missing 'react' field"), "got: {}", err);
    }

    #[test]
    fn test_error_missing_graph() {
        let tokens = quote! {
            react = when_any(a)
        };
        let result = parse_topology(tokens);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("missing 'graph' field"), "got: {}", err);
    }

    #[test]
    fn test_error_unknown_field() {
        let tokens = quote! {
            react = when_any(a),
            graph = { a -> b },
            bogus = something,
        };
        let result = parse_topology(tokens);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("unknown field 'bogus'"), "got: {}", err);
    }

    #[test]
    fn test_error_unknown_reaction_mode() {
        let tokens = quote! {
            react = when_sometimes(a),
            graph = { a -> b },
        };
        let result = parse_topology(tokens);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("unknown reaction mode"), "got: {}", err);
    }

    #[test]
    fn test_error_empty_routing() {
        let tokens = quote! {
            react = when_any(a),
            graph = {
                a(a) => {},
            }
        };
        let result = parse_topology(tokens);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("at least one variant"), "got: {}", err);
    }

    #[test]
    fn test_error_missing_arrow() {
        // A node with no `->` / `=>` is rejected with a hint about terminals.
        let tokens = quote! {
            react = when_any(a),
            graph = { a(a) },
        };
        let result = parse_topology(tokens);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("expected '->' or '=>' after node 'a'"),
            "got: {}",
            err
        );
    }

    #[test]
    fn test_error_duplicate_react() {
        let tokens = quote! {
            react = when_any(a),
            react = when_all(b),
            graph = { a -> b },
        };
        let result = parse_topology(tokens);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("duplicate 'react'"), "got: {}", err);
    }

    #[test]
    fn test_error_duplicate_graph() {
        let tokens = quote! {
            react = when_any(a),
            graph = { a -> b },
            graph = { b -> c },
        };
        let result = parse_topology(tokens);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("duplicate 'graph'"), "got: {}", err);
    }
}