use crate::ast::{Branch, Choreography, LocalType, Protocol, Role};
use crate::topology::{Location, Topology, TopologyConstraint, TopologyMode};
use proc_macro2::{Ident, TokenStream};
use quote::{format_ident, quote};
use super::generate_choreography_code;
/// A named topology that the code generator renders into constructor functions.
///
/// For each value, `generate_topology_constants` emits a function named after the
/// lowercased `name` that rebuilds `topology` through `TopologyBuilder`, plus a
/// matching `<name>_handler` function.
#[derive(Debug, Clone)]
pub struct InlineTopology {
/// Topology name; it is lowercased to form the generated function identifiers.
pub name: String,
/// Compile-time topology description that is rendered into builder calls.
pub topology: Topology,
}
/// Compile-time description of one `BranchRequirement::new(sender, receiver, label_count)`
/// expression emitted into the generated `with_topology` function.
#[derive(Debug, Clone)]
struct BranchRequirementSpec {
/// Name of the role that sends the first message of a choice branch.
sender: String,
/// Name of the role that receives that message.
receiver: String,
/// Number of branches in the enclosing choice (saturated to `u32::MAX` by the collector).
label_count: u32,
}
/// Generates the `pub mod topology` module that accompanies a choreography.
///
/// The emitted module contains:
/// - `handler(role)`, produced by `generate_handler_method`;
/// - `with_topology(topology, role)`, which validates a topology against the
///   choreography's roles and the branch requirements collected from its protocol;
/// - a nested `topologies` module with one constructor per entry of
///   `inline_topologies` (omitted entirely when the slice is empty).
///
/// # Parameters
/// - `choreography`: source of the role names and of the protocol that is scanned
///   for branch requirements.
/// - `inline_topologies`: named topologies to render as constructor functions.
///
/// # Returns
/// The token stream of the generated module.
#[must_use]
pub fn generate_topology_integration(
    choreography: &Choreography,
    inline_topologies: &[InlineTopology],
) -> TokenStream {
    // Role names are interpolated as string literals, so render each identifier once.
    let role_name_strs: Vec<String> = choreography
        .roles
        .iter()
        .map(|role| {
            let ident: &Ident = role.name();
            ident.to_string()
        })
        .collect();

    let handler_method = generate_handler_method();
    let branch_requirements = collect_branch_requirements(&choreography.protocol);
    let with_topology_method = generate_with_topology_method(&role_name_strs, &branch_requirements);
    let topology_constants = generate_topology_constants(inline_topologies, &role_name_strs);

    quote! {
        pub mod topology {
            use super::*;
            use ::telltale_runtime::topology::{
                BranchRequirement, Location, Topology, TopologyBuilder, TopologyHandler,
                TopologyMode,
            };
            use ::telltale_runtime::{
                ChannelCapacity, Region, RoleFamilyConstraint, RoleName, TopologyEndpoint,
            };
            #handler_method
            #with_topology_method
            #topology_constants
        }
    }
}
/// Walks `protocol` and returns every branch requirement found in it, in visit order.
fn collect_branch_requirements(protocol: &Protocol) -> Vec<BranchRequirementSpec> {
    let mut collected: Vec<BranchRequirementSpec> = Vec::new();
    collect_branch_requirements_from_protocol(protocol, &mut collected);
    collected
}
/// Recursively visits `protocol` and appends branch requirements to `requirements`.
///
/// Only `Protocol::Choice` contributes requirements directly: each of its branches is
/// handed to `collect_branch_requirement_from_branch` together with the number of
/// branches in that choice. Every other variant is merely traversed so that nested
/// choices are found.
fn collect_branch_requirements_from_protocol(
    protocol: &Protocol,
    requirements: &mut Vec<BranchRequirementSpec>,
) {
    match protocol {
        // Leaves: nothing to traverse.
        Protocol::End | Protocol::Var(_) => {}

        Protocol::Choice { branches, .. } => {
            // Saturate rather than fail if the branch count does not fit in a `u32`.
            let label_count = match u32::try_from(branches.len()) {
                Ok(count) => count,
                Err(_) => u32::MAX,
            };
            for arm in branches {
                collect_branch_requirement_from_branch(arm, label_count, requirements);
                collect_branch_requirements_from_protocol(&arm.protocol, requirements);
            }
        }

        // `Case` branches are traversed but do not themselves produce a requirement.
        Protocol::Case { branches, .. } => {
            for arm in branches {
                collect_branch_requirements_from_protocol(&arm.protocol, requirements);
            }
        }

        Protocol::Parallel { protocols } => {
            for sub_protocol in protocols {
                collect_branch_requirements_from_protocol(sub_protocol, requirements);
            }
        }

        Protocol::Timeout {
            body,
            on_timeout,
            on_cancel,
            ..
        } => {
            collect_branch_requirements_from_protocol(body, requirements);
            collect_branch_requirements_from_protocol(on_timeout, requirements);
            // The cancellation handler is optional.
            if let Some(cancel_handler) = on_cancel.as_deref() {
                collect_branch_requirements_from_protocol(cancel_handler, requirements);
            }
        }

        Protocol::Loop { body, .. } => {
            collect_branch_requirements_from_protocol(body, requirements);
        }
        Protocol::Rec { body, .. } => {
            collect_branch_requirements_from_protocol(body, requirements);
        }

        Protocol::Send { continuation, .. } => {
            collect_branch_requirements_from_protocol(continuation, requirements);
        }
        Protocol::Broadcast { continuation, .. } => {
            collect_branch_requirements_from_protocol(continuation, requirements);
        }

        // Remaining single-continuation variants: just follow the continuation.
        Protocol::Begin { continuation, .. }
        | Protocol::Await { continuation, .. }
        | Protocol::Resolve { continuation, .. }
        | Protocol::Invalidate { continuation, .. }
        | Protocol::Extension { continuation, .. }
        | Protocol::Let { continuation, .. }
        | Protocol::Publish { continuation, .. }
        | Protocol::PublishAuthority { continuation, .. }
        | Protocol::Materialize { continuation, .. }
        | Protocol::Handoff { continuation, .. }
        | Protocol::DependentWork { continuation, .. } => {
            collect_branch_requirements_from_protocol(continuation, requirements);
        }
    }
}
/// Records the requirement(s) implied by the first step of a choice branch.
///
/// A branch whose protocol starts with `Send` yields one requirement for that
/// sender/receiver pair; one starting with `Broadcast` yields one requirement per
/// recipient. Branches starting with anything else contribute nothing.
///
/// `label_count` is the number of branches in the enclosing choice and is copied
/// into every requirement pushed onto `requirements`.
fn collect_branch_requirement_from_branch(
    branch: &Branch,
    label_count: u32,
    requirements: &mut Vec<BranchRequirementSpec>,
) {
    match &branch.protocol {
        Protocol::Broadcast { from, to_all, .. } => {
            for recipient in to_all {
                let sender = from.name().to_string();
                let receiver = recipient.name().to_string();
                requirements.push(BranchRequirementSpec {
                    sender,
                    receiver,
                    label_count,
                });
            }
        }
        Protocol::Send { from, to, .. } => {
            let sender = from.name().to_string();
            let receiver = to.name().to_string();
            requirements.push(BranchRequirementSpec {
                sender,
                receiver,
                label_count,
            });
        }
        // Any other leading step places no requirement on a channel.
        _ => {}
    }
}
/// Generates `handler(role)`, which returns `TopologyHandler::local` for the
/// role's name.
fn generate_handler_method() -> TokenStream {
    let handler_body = quote! { TopologyHandler::local(role.role_name()) };
    quote! {
        pub fn handler(role: Role) -> TopologyHandler {
            #handler_body
        }
    }
}
fn generate_with_topology_method(
role_names: &[String],
branch_requirements: &[BranchRequirementSpec],
) -> TokenStream {
let role_name_literals: Vec<TokenStream> = role_names
.iter()
.map(|role| quote! { RoleName::from_static(#role) })
.collect();
let branch_requirement_literals: Vec<TokenStream> = branch_requirements
.iter()
.map(|req| {
let sender = &req.sender;
let receiver = &req.receiver;
let label_count = req.label_count;
quote! {
BranchRequirement::new(
RoleName::from_static(#sender),
RoleName::from_static(#receiver),
#label_count
)
}
})
.collect();
quote! {
pub fn with_topology(
topology: Topology,
role: Role,
) -> Result<TopologyHandler, String> {
let roles = [#(#role_name_literals),*];
let branch_requirements: &[BranchRequirement] = &[#(#branch_requirement_literals),*];
let validation = topology.validate_with_branches(&roles, &branch_requirements);
if !validation.is_valid() {
return Err(format!("Topology validation failed: {:?}", validation));
}
Ok(TopologyHandler::new(topology, role.role_name()))
}
}
}
/// Generates the nested `topologies` module with one constructor pair per inline
/// topology.
///
/// For a topology named `Name` the module contains:
/// - `name() -> Topology`, which rebuilds the topology through `TopologyBuilder`;
/// - `name_handler(role) -> Result<TopologyHandler, String>`, which forwards to
///   `with_topology`.
///
/// Returns an empty token stream when `inline_topologies` is empty, so no empty
/// module is emitted.
///
/// # Panics
/// `format_ident!` panics if a lowercased topology name is not a valid identifier.
fn generate_topology_constants(
    inline_topologies: &[InlineTopology],
    role_names: &[String],
) -> TokenStream {
    if inline_topologies.is_empty() {
        return quote! {};
    }

    let constants: Vec<TokenStream> = inline_topologies
        .iter()
        .map(|topo| {
            // Both generated identifiers derive from the lowercased name; compute it once.
            let lower_name = topo.name.to_lowercase();
            let fn_name = format_ident!("{}", &lower_name);
            let handler_fn_name = format_ident!("{}_handler", &lower_name);
            let builder_calls = generate_topology_builder(&topo.topology, role_names);
            quote! {
                pub fn #fn_name() -> Topology {
                    #builder_calls
                }
                pub fn #handler_fn_name(role: Role) -> Result<TopologyHandler, String> {
                    with_topology(#fn_name(), role)
                }
            }
        })
        .collect();

    quote! {
        pub mod topologies {
            use super::*;
            #(#constants)*
        }
    }
}
fn generate_topology_builder(topology: &Topology, _role_names: &[String]) -> TokenStream {
let mut builder_calls = Vec::new();
if let Some(ref mode) = topology.mode {
builder_calls.push(generate_mode_builder_call(mode));
}
for (role, location) in &topology.locations {
builder_calls.push(generate_location_builder_call(role, location));
}
for constraint in &topology.constraints {
builder_calls.push(generate_constraint_builder_call(constraint));
}
for ((sender, receiver), capacity) in &topology.channel_capacities {
builder_calls.push(generate_channel_capacity_builder_call(
sender, receiver, capacity,
));
}
for (family, constraint) in &topology.role_constraints {
builder_calls.push(generate_role_family_constraint_builder_call(
family, constraint,
));
}
if builder_calls.is_empty() {
quote! {
TopologyBuilder::new().build()
}
} else {
quote! {
TopologyBuilder::new()
#(#builder_calls)*
.build()
}
}
}
/// Renders the `.mode(..)` builder call for `mode`.
///
/// The match is kept exhaustive (no wildcard) so that adding a `TopologyMode`
/// variant forces this function to be updated.
fn generate_mode_builder_call(mode: &TopologyMode) -> TokenStream {
    let mode_expr = match mode {
        TopologyMode::Local => quote! { TopologyMode::Local },
    };
    quote! { .mode(#mode_expr) }
}
/// Renders the builder call that assigns `location` to `role`.
///
/// - `Location::Local` becomes `.local_role(role)`.
/// - `Location::Remote(endpoint)` becomes `.remote_role(role, TopologyEndpoint::new(..))`.
/// - `Location::Colocated(peer)` becomes `.colocated_role(role, peer)`.
///
/// The generated endpoint construction uses `.expect(..)` with an invariant message,
/// matching the style of the generated channel-capacity code, so a failure points at
/// the generated topology rather than at an anonymous `unwrap`.
fn generate_location_builder_call(
    role: &crate::identifiers::RoleName,
    location: &Location,
) -> TokenStream {
    let role_literal = role.as_str();
    match location {
        Location::Local => quote! { .local_role(RoleName::from_static(#role_literal)) },
        Location::Remote(endpoint) => {
            let endpoint_literal = endpoint.as_str();
            // NOTE(review): the endpoint is presumably already validated when the
            // compile-time `Location` is built, so this `expect` should be unreachable.
            quote! {
                .remote_role(
                    RoleName::from_static(#role_literal),
                    TopologyEndpoint::new(#endpoint_literal)
                        .expect("generated topology endpoint must be valid")
                )
            }
        }
        Location::Colocated(peer) => {
            let peer_literal = peer.as_str();
            quote! {
                .colocated_role(
                    RoleName::from_static(#role_literal),
                    RoleName::from_static(#peer_literal)
                )
            }
        }
    }
}
/// Renders `location` as a runtime `Location` expression, for use as the second
/// argument of a generated `.pinned(..)` call.
///
/// The generated endpoint construction uses `.expect(..)` with an invariant message,
/// matching the style of the generated channel-capacity code.
fn generate_pinned_location_expr(location: &Location) -> TokenStream {
    match location {
        Location::Local => quote! { Location::Local },
        Location::Remote(endpoint) => {
            let endpoint_literal = endpoint.as_str();
            // NOTE(review): the endpoint is presumably already validated when the
            // compile-time `Location` is built, so this `expect` should be unreachable.
            quote! {
                Location::Remote(
                    TopologyEndpoint::new(#endpoint_literal)
                        .expect("generated topology endpoint must be valid")
                )
            }
        }
        Location::Colocated(peer) => {
            let peer_literal = peer.as_str();
            quote! { Location::Colocated(RoleName::from_static(#peer_literal)) }
        }
    }
}
/// Renders the builder call for a single topology constraint.
///
/// - `Colocated(a, b)` becomes `.colocated(a, b)`.
/// - `Separated(a, b)` becomes `.separated(a, b)`.
/// - `Pinned(role, location)` becomes `.pinned(role, <location expr>)`.
/// - `Region(role, region)` becomes `.region(role, Region::new(..))`.
///
/// The generated region construction uses `.expect(..)` with an invariant message,
/// matching the style of the generated channel-capacity code.
fn generate_constraint_builder_call(constraint: &TopologyConstraint) -> TokenStream {
    match constraint {
        TopologyConstraint::Colocated(r1, r2) => {
            let (first_literal, second_literal) = (r1.as_str(), r2.as_str());
            quote! {
                .colocated(
                    RoleName::from_static(#first_literal),
                    RoleName::from_static(#second_literal)
                )
            }
        }
        TopologyConstraint::Separated(r1, r2) => {
            let (first_literal, second_literal) = (r1.as_str(), r2.as_str());
            quote! {
                .separated(
                    RoleName::from_static(#first_literal),
                    RoleName::from_static(#second_literal)
                )
            }
        }
        TopologyConstraint::Pinned(role, location) => {
            let role_literal = role.as_str();
            let location_expr = generate_pinned_location_expr(location);
            quote! { .pinned(RoleName::from_static(#role_literal), #location_expr) }
        }
        TopologyConstraint::Region(role, region) => {
            let role_literal = role.as_str();
            let region_literal = region.as_str();
            // NOTE(review): the region is presumably already validated when the
            // compile-time constraint is built, so this `expect` should be unreachable.
            quote! {
                .region(
                    RoleName::from_static(#role_literal),
                    Region::new(#region_literal)
                        .expect("generated region must be valid")
                )
            }
        }
    }
}
/// Renders the `.channel_capacity(sender, receiver, capacity)` builder call.
///
/// The capacity value read from `capacity.get()` is re-wrapped in the generated code
/// through `ChannelCapacity::try_new(..)`, with an `expect` stating the invariant.
fn generate_channel_capacity_builder_call(
    sender: &crate::identifiers::RoleName,
    receiver: &crate::identifiers::RoleName,
    capacity: &crate::ChannelCapacity,
) -> TokenStream {
    let capacity_value = capacity.get();
    let from_literal = sender.as_str();
    let to_literal = receiver.as_str();

    let sender_expr = quote! { RoleName::from_static(#from_literal) };
    let receiver_expr = quote! { RoleName::from_static(#to_literal) };

    quote! {
        .channel_capacity(
            #sender_expr,
            #receiver_expr,
            ChannelCapacity::try_new(#capacity_value)
                .expect("generated channel capacity must be within declared bounds")
        )
    }
}
/// Renders the `.role_family_constraint(family, ..)` builder call.
///
/// A constraint with an upper bound is emitted as `RoleFamilyConstraint::bounded(min, max)`;
/// one without is emitted as `RoleFamilyConstraint::min_only(min)`.
fn generate_role_family_constraint_builder_call(
    family: &str,
    constraint: &crate::topology::RoleFamilyConstraint,
) -> TokenStream {
    let min = constraint.min;
    let constraint_expr = match constraint.max {
        None => quote! { RoleFamilyConstraint::min_only(#min) },
        Some(max) => quote! { RoleFamilyConstraint::bounded(#min, #max) },
    };
    quote! {
        .role_family_constraint(#family, #constraint_expr)
    }
}
/// Generates the base choreography code followed by its topology integration module.
///
/// # Parameters
/// - `choreography`: supplies the protocol name and roles for the base code, and is
///   passed whole to `generate_topology_integration`.
/// - `local_types`: per-role local types forwarded to `generate_choreography_code`.
/// - `inline_topologies`: named topologies forwarded to `generate_topology_integration`.
///
/// # Returns
/// The base code tokens immediately followed by the `topology` module tokens.
#[must_use]
pub fn generate_choreography_code_with_topology(
    choreography: &Choreography,
    local_types: &[(Role, LocalType)],
    inline_topologies: &[InlineTopology],
) -> TokenStream {
    let base_code = {
        let protocol_name = choreography.name.to_string();
        generate_choreography_code(&protocol_name, &choreography.roles, local_types)
    };
    let topology_code = generate_topology_integration(choreography, inline_topologies);

    quote! {
        #base_code
        #topology_code
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Protocol;
    use crate::identifiers::RoleName;

    /// Shorthand for building a `RoleName` fixture.
    fn role(name: &'static str) -> RoleName {
        RoleName::from_static(name)
    }

    /// Role-name strings for the two-role fixture choreography.
    fn alice_and_bob() -> Vec<String> {
        vec!["Alice".to_string(), "Bob".to_string()]
    }

    /// Minimal two-role choreography whose protocol is just `End`.
    fn create_test_choreography() -> Choreography {
        use quote::format_ident;
        let roles = vec![
            Role::new(format_ident!("Alice")).unwrap(),
            Role::new(format_ident!("Bob")).unwrap(),
        ];
        Choreography {
            name: format_ident!("TestProtocol"),
            namespace: None,
            roles,
            protocol: Protocol::End,
            attrs: std::collections::HashMap::new(),
        }
    }

    /// Without inline topologies the module still exposes `handler` and `with_topology`.
    #[test]
    fn test_generate_topology_integration_basic() {
        let code = generate_topology_integration(&create_test_choreography(), &[]).to_string();

        assert!(code.contains("pub mod topology"));
        assert!(code.contains("RoleName :: from_static"));
        assert!(code.contains("pub fn handler"));
        assert!(code.contains("pub fn with_topology"));
    }

    /// Each inline topology yields a constructor and a `<name>_handler` function.
    #[test]
    fn test_generate_topology_integration_with_inline_topologies() {
        let dev = InlineTopology {
            name: "Dev".to_string(),
            topology: Topology::builder()
                .mode(TopologyMode::Local)
                .local_role(role("Alice"))
                .local_role(role("Bob"))
                .build(),
        };
        let prod = InlineTopology {
            name: "Prod".to_string(),
            topology: Topology::builder()
                .remote_role(
                    role("Alice"),
                    crate::identifiers::Endpoint::new("alice.prod:8080").unwrap(),
                )
                .remote_role(
                    role("Bob"),
                    crate::identifiers::Endpoint::new("bob.prod:8081").unwrap(),
                )
                .build(),
        };
        let inline_topologies = vec![dev, prod];

        let code =
            generate_topology_integration(&create_test_choreography(), &inline_topologies)
                .to_string();

        assert!(code.contains("pub mod topologies"));
        assert!(code.contains("pub fn dev"));
        assert!(code.contains("pub fn prod"));
        assert!(code.contains("dev_handler"));
        assert!(code.contains("prod_handler"));
    }

    /// `handler` delegates to `TopologyHandler::local` with the role's name.
    #[test]
    fn test_generate_handler_method() {
        let code = generate_handler_method().to_string();

        assert!(code.contains("pub fn handler"));
        assert!(code.contains("TopologyHandler :: local"));
        assert!(code.contains("role_name"));
    }

    /// `with_topology` validates before constructing the handler.
    #[test]
    fn test_generate_with_topology_method() {
        let code = generate_with_topology_method(&alice_and_bob(), &[]).to_string();

        assert!(code.contains("pub fn with_topology"));
        assert!(code.contains("TopologyHandler :: new"));
        assert!(code.contains("topology . validate_with_branches"));
    }

    /// A topology with an explicit mode emits the matching `.mode(..)` call.
    #[test]
    fn test_generate_topology_builder_local_mode() {
        let topology = Topology::builder().mode(TopologyMode::Local).build();

        let code = generate_topology_builder(&topology, &alice_and_bob()).to_string();

        assert!(code.contains("TopologyMode :: Local"));
    }

    /// The generated helpers must not mention deployment-backend names.
    #[test]
    fn test_generated_topology_helpers_do_not_import_deployment_backends() {
        let code = generate_topology_integration(&create_test_choreography(), &[]).to_string();

        assert!(!code.contains("Datacenter"));
        assert!(!code.contains("Kubernetes"));
        assert!(!code.contains("Consul"));
    }

    /// Local and remote role placements are both rendered, including the endpoint text.
    #[test]
    fn test_generate_topology_builder_with_roles() {
        let topology = Topology::builder()
            .local_role(role("Alice"))
            .remote_role(
                role("Bob"),
                crate::identifiers::Endpoint::new("localhost:8080").unwrap(),
            )
            .build();

        let code = generate_topology_builder(&topology, &alice_and_bob()).to_string();

        assert!(code.contains("local_role"));
        assert!(code.contains("remote_role"));
        assert!(code.contains("localhost:8080"));
    }

    /// Colocation, separation and role-family constraints are all rendered.
    #[test]
    fn test_generate_topology_builder_with_constraints() {
        let topology = Topology::builder()
            .local_role(role("Alice"))
            .local_role(role("Bob"))
            .colocated(role("Alice"), role("Bob"))
            .separated(role("Alice"), role("Carol"))
            .role_family_constraint(
                "Witness",
                crate::topology::RoleFamilyConstraint::bounded(2, 5),
            )
            .build();
        let role_names = vec!["Alice".to_string(), "Bob".to_string(), "Carol".to_string()];

        let code = generate_topology_builder(&topology, &role_names).to_string();

        assert!(code.contains("colocated"));
        assert!(code.contains("separated"));
        assert!(code.contains("role_family_constraint"));
        assert!(code.contains("RoleFamilyConstraint :: bounded"));
    }

    /// The combined generator emits the base code and the topology module together.
    #[test]
    fn test_generate_choreography_code_with_topology() {
        let choreography = create_test_choreography();
        let local_types = vec![
            (
                Role::new(format_ident!("Alice")).unwrap(),
                crate::ast::LocalType::End,
            ),
            (
                Role::new(format_ident!("Bob")).unwrap(),
                crate::ast::LocalType::End,
            ),
        ];

        let code =
            generate_choreography_code_with_topology(&choreography, &local_types, &[]).to_string();

        assert!(code.contains("Alice") || code.contains("Roles"));
        assert!(code.contains("pub mod topology"));
    }
}