use std::collections::{HashMap, HashSet};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{Ident, ItemFn, ItemMod};
use super::graph_ir::{GraphEdge, GraphIR, GraphNode};
use super::parser::ReactionMode;
/// Converts a `snake_case` identifier into a `PascalCase` one (e.g. `my_graph` -> `MyGraph`).
///
/// The identifier is split on `_`. Empty segments (from leading, trailing or doubled
/// underscores) contribute nothing. For every other segment, only the first character is
/// uppercased and the remainder is copied verbatim.
fn pascal_case_ident(ident: &Ident) -> Ident {
    let source = ident.to_string();
    let mut pascal = String::with_capacity(source.len());
    for segment in source.split('_') {
        let mut chars = segment.chars();
        if let Some(first) = chars.next() {
            pascal.extend(first.to_uppercase());
            pascal.push_str(chars.as_str());
        }
    }
    format_ident!("{}", pascal)
}
pub fn generate(ir: &GraphIR, module: &ItemMod) -> syn::Result<TokenStream> {
let functions = extract_functions(module)?;
let function_names: HashSet<String> = functions.keys().cloned().collect();
let node_names: HashSet<String> = ir.nodes.keys().cloned().collect();
for node_name in &node_names {
if !function_names.contains(node_name) {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
format!(
"node '{}' is referenced in the graph topology but no function with that name exists in the module",
node_name
),
));
}
}
for fn_name in &function_names {
if !node_names.contains(fn_name) {
if let Some(func) = functions.get(fn_name) {
return Err(syn::Error::new(
func.sig.ident.span(),
format!(
"function '{}' exists in the module but is not referenced in the graph topology. \
All functions in a computation_graph module must appear in the graph declaration.",
fn_name
),
));
}
}
}
let blocking_nodes: HashSet<String> = functions
.iter()
.filter(|(_, f)| has_blocking_attr(f))
.map(|(name, _)| name.clone())
.collect();
let is_cloacina_crate_early = std::env::var("CARGO_CRATE_NAME")
.map(|n| n == "cloacina")
.unwrap_or(false);
let compiled_fn =
generate_compiled_function(ir, &functions, &blocking_nodes, is_cloacina_crate_early)?;
let mod_name = &module.ident;
let vis = &module.vis;
let mod_attrs = &module.attrs;
let content = module
.content
.as_ref()
.map(|(_, items)| items.clone())
.unwrap_or_default();
let compiled_fn_name = format_ident!("{}_compiled", mod_name);
let routing_use_stmts = generate_routing_use_stmts(ir, &functions, mod_name);
let mod_name_str = mod_name.to_string();
let auto_register_name = format_ident!("_auto_register_graph_{}", mod_name);
let accumulator_names: Vec<String> = ir
.react
.accumulators
.iter()
.map(|a| a.to_string())
.collect();
let reaction_mode_str = match ir.react.mode {
ReactionMode::WhenAny => "when_any",
ReactionMode::WhenAll => "when_all",
};
let ffi_plugin_name = format_ident!("_GraphPlugin{}", pascal_case_ident(mod_name));
let packaged_ffi = quote! {
#[cfg(feature = "packaged")]
pub mod _ffi {
use cloacina_workflow_plugin::__fidius_CloacinaPlugin;
use cloacina_workflow_plugin::CloacinaPlugin as _;
pub struct #ffi_plugin_name;
#[cloacina_workflow_plugin::plugin_impl(CloacinaPlugin, crate = "cloacina_workflow_plugin")]
impl cloacina_workflow_plugin::CloacinaPlugin for #ffi_plugin_name {
fn get_task_metadata(&self) -> Result<cloacina_workflow_plugin::PackageTasksMetadata, cloacina_workflow_plugin::PluginError> {
Ok(cloacina_workflow_plugin::PackageTasksMetadata {
workflow_name: String::new(),
package_name: env!("CARGO_PKG_NAME").to_string(),
package_description: None,
package_author: None,
workflow_fingerprint: None,
graph_data_json: None,
tasks: vec![],
})
}
fn execute_task(&self, _request: cloacina_workflow_plugin::TaskExecutionRequest) -> Result<cloacina_workflow_plugin::TaskExecutionResult, cloacina_workflow_plugin::PluginError> {
Err(cloacina_workflow_plugin::PluginError {
code: "NOT_SUPPORTED".to_string(),
message: "This is a computation graph package, not a workflow package".to_string(),
details: None,
})
}
fn get_graph_metadata(&self) -> Result<cloacina_workflow_plugin::GraphPackageMetadata, cloacina_workflow_plugin::PluginError> {
Ok(cloacina_workflow_plugin::GraphPackageMetadata {
graph_name: #mod_name_str.to_string(),
package_name: env!("CARGO_PKG_NAME").to_string(),
reaction_mode: #reaction_mode_str.to_string(),
input_strategy: "latest".to_string(),
accumulators: vec![
#(
cloacina_workflow_plugin::AccumulatorDeclarationEntry {
name: #accumulator_names.to_string(),
accumulator_type: "passthrough".to_string(),
config: std::collections::HashMap::new(),
}
),*
],
})
}
fn execute_graph(&self, request: cloacina_workflow_plugin::GraphExecutionRequest) -> Result<cloacina_workflow_plugin::GraphExecutionResult, cloacina_workflow_plugin::PluginError> {
static CDYLIB_RUNTIME: std::sync::OnceLock<tokio::runtime::Runtime> = std::sync::OnceLock::new();
let rt = CDYLIB_RUNTIME.get_or_init(|| {
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.worker_threads(2)
.thread_name("cg-cdylib-worker")
.build()
.expect("Failed to create cdylib tokio runtime for computation graph")
});
let mut cache = cloacina_computation_graph::InputCache::new();
for (source_name, json_str) in &request.cache {
let value: serde_json::Value = serde_json::from_str(json_str)
.map_err(|e| cloacina_workflow_plugin::PluginError {
code: "DESERIALIZATION_ERROR".to_string(),
message: format!("Failed to parse cache entry '{}': {}", source_name, e),
details: None,
})?;
let bytes = cloacina_computation_graph::serialize(&value)
.map_err(|e| cloacina_workflow_plugin::PluginError {
code: "SERIALIZATION_ERROR".to_string(),
message: format!("Failed to serialize cache entry '{}': {}", source_name, e),
details: None,
})?;
cache.update(
cloacina_computation_graph::SourceName::new(source_name),
bytes,
);
}
let result = rt.block_on(async {
super::#compiled_fn_name(&cache).await
});
match result {
cloacina_computation_graph::GraphResult::Completed { outputs } => {
let terminal_json: Vec<String> = outputs
.iter()
.filter_map(|o| {
if let Some(val) = o.downcast_ref::<serde_json::Value>() {
Some(serde_json::to_string(val).unwrap_or_default())
} else {
None
}
})
.collect();
Ok(cloacina_workflow_plugin::GraphExecutionResult {
success: true,
terminal_outputs_json: if terminal_json.is_empty() { None } else { Some(terminal_json) },
error: None,
})
}
cloacina_computation_graph::GraphResult::Error(e) => {
Ok(cloacina_workflow_plugin::GraphExecutionResult {
success: false,
terminal_outputs_json: None,
error: Some(format!("{}", e)),
})
}
}
}
}
cloacina_workflow_plugin::fidius_plugin_registry!();
}
};
let (compiled_fn_body, ctor_body) = if is_cloacina_crate_early {
let fn_body = quote! {
#vis async fn #compiled_fn_name(
cache: &crate::computation_graph::InputCache,
) -> crate::computation_graph::GraphResult {
#[allow(unused_imports)]
use #mod_name::*;
#(#routing_use_stmts)*
#compiled_fn
}
};
let ctor = quote! {
#[cfg(not(test))]
#[cfg(not(feature = "packaged"))]
#[ctor::ctor]
fn #auto_register_name() {
crate::register_computation_graph_constructor(
#mod_name_str.to_string(),
|| {
crate::ComputationGraphRegistration {
graph_fn: std::sync::Arc::new(|cache: crate::computation_graph::InputCache| {
Box::pin(async move {
#compiled_fn_name(&cache).await
})
}),
accumulator_names: vec![#(#accumulator_names.to_string()),*],
reaction_mode: #reaction_mode_str.to_string(),
}
},
);
}
};
(fn_body, ctor)
} else {
let fn_body = quote! {
#vis async fn #compiled_fn_name(
cache: &cloacina_computation_graph::InputCache,
) -> cloacina_computation_graph::GraphResult {
#[allow(unused_imports)]
use #mod_name::*;
#(#routing_use_stmts)*
#compiled_fn
}
};
let ctor = quote! {
#[cfg(not(test))]
#[cfg(not(feature = "packaged"))]
#[ctor::ctor]
fn #auto_register_name() {
cloacina_computation_graph::register_computation_graph_constructor(
#mod_name_str.to_string(),
|| {
cloacina_computation_graph::ComputationGraphRegistration {
graph_fn: std::sync::Arc::new(|cache: cloacina_computation_graph::InputCache| {
Box::pin(async move {
#compiled_fn_name(&cache).await
})
}),
accumulator_names: vec![#(#accumulator_names.to_string()),*],
reaction_mode: #reaction_mode_str.to_string(),
}
},
);
}
};
(fn_body, ctor)
};
Ok(quote! {
#(#mod_attrs)*
#vis mod #mod_name {
#(#content)*
}
#compiled_fn_body
#ctor_body
#packaged_ffi
})
}
/// Collects the module's top-level `fn` items, keyed by function name.
///
/// Only direct children of the module are considered; functions nested inside other items
/// (impls, inner modules) are ignored.
///
/// # Errors
///
/// Fails when the module is declared without an inline body (`mod name;`), since its
/// items are then not visible to the macro.
fn extract_functions(module: &ItemMod) -> syn::Result<HashMap<String, ItemFn>> {
    let items = match &module.content {
        Some((_, items)) => items,
        None => {
            return Err(syn::Error::new(
                module.ident.span(),
                "computation_graph module must have inline content (use `mod name { ... }`, not `mod name;`)",
            ))
        }
    };
    let functions = items
        .iter()
        .filter_map(|item| match item {
            syn::Item::Fn(func) => Some((func.sig.ident.to_string(), func.clone())),
            _ => None,
        })
        .collect();
    Ok(functions)
}
/// Reports whether `func` carries a `#[node(blocking)]` attribute.
///
/// Only `node` attributes whose argument list parses as exactly one identifier are
/// considered. Any other shape (no arguments, several arguments, a literal) is treated as
/// "not blocking".
fn has_blocking_attr(func: &ItemFn) -> bool {
    func.attrs
        .iter()
        .filter(|attr| attr.path().is_ident("node"))
        .filter_map(|attr| attr.parse_args::<Ident>().ok())
        .any(|arg| arg == "blocking")
}
fn generate_compiled_function(
ir: &GraphIR,
functions: &HashMap<String, ItemFn>,
blocking_nodes: &HashSet<String>,
is_cloacina_crate: bool,
) -> syn::Result<TokenStream> {
let entry_nodes = ir.entry_nodes();
if entry_nodes.is_empty() {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"computation graph has no entry nodes (all nodes have incoming edges — possible cycle)",
));
}
let cache_reads = generate_cache_reads(ir);
let mut exec_stmts = Vec::new();
let mut generated_nodes: HashSet<String> = HashSet::new();
for node_name in &ir.sorted_nodes {
if generated_nodes.contains(node_name) {
continue;
}
let node = ir.get_node(node_name).unwrap();
let stmt = generate_node_execution(
ir,
node,
functions,
blocking_nodes,
&mut generated_nodes,
is_cloacina_crate,
)?;
exec_stmts.push(stmt);
}
let graph_result_completed = if is_cloacina_crate {
quote! { crate::computation_graph::GraphResult::completed(__terminal_results) }
} else {
quote! { cloacina_computation_graph::GraphResult::completed(__terminal_results) }
};
Ok(quote! {
let mut __terminal_results: Vec<Box<dyn std::any::Any + Send>> = Vec::new();
#cache_reads
#(#exec_stmts)*
#graph_result_completed
})
}
/// Emits one `let __cache_<input> = cache.get("<input>");` binding per distinct cache input.
///
/// Inputs are visited node by node in `ir.nodes` iteration order. Each input name is bound
/// only the first time it is seen, so nodes sharing an input reuse the same binding.
fn generate_cache_reads(ir: &GraphIR) -> TokenStream {
    let mut already_bound: HashSet<&str> = HashSet::new();
    let bindings: Vec<TokenStream> = ir
        .nodes
        .values()
        .flat_map(|node| &node.cache_inputs)
        .filter(|input| already_bound.insert(input.as_str()))
        .map(|input| {
            let binding = format_ident!("__cache_{}", input);
            let key = input.as_str();
            quote! {
                let #binding = cache.get(#key);
            }
        })
        .collect();
    quote! { #(#bindings)* }
}
/// Generates the statements that execute `node` and route or collect its result.
///
/// - The call binds `__result_<node>`. Blocking nodes (`#[node(blocking)]`) run on
///   `tokio::task::spawn_blocking`. If that task fails to join, the enclosing compiled
///   function returns early with `GraphResult::Error(GraphError::NodeExecution(..))`.
/// - Terminal nodes (no outgoing edges) push their boxed result into `__terminal_results`.
/// - A single routing edge is followed by a `match` over the result, with downstream nodes
///   generated inside the arms.
/// - Other shapes (a single linear edge, or several edges) emit only the call.
///
/// Nodes already present in `generated` produce no tokens. The node is marked as generated
/// before any recursion into its routing targets.
fn generate_node_execution(
    ir: &GraphIR,
    node: &GraphNode,
    functions: &HashMap<String, ItemFn>,
    blocking_nodes: &HashSet<String>,
    generated: &mut HashSet<String>,
    is_cloacina_crate: bool,
) -> syn::Result<TokenStream> {
    if generated.contains(&node.name) {
        return Ok(quote! {});
    }
    generated.insert(node.name.clone());
    let fn_ident = format_ident!("{}", node.name);
    let result_var = format_ident!("__result_{}", node.name);
    let is_blocking = blocking_nodes.contains(&node.name);
    let args = generate_call_args(ir, node);
    let call = if is_blocking {
        let cg_path = if is_cloacina_crate {
            quote! { crate::computation_graph }
        } else {
            quote! { cloacina_computation_graph }
        };
        // The enclosing compiled fn returns `GraphResult`, an enum, so `?` cannot be used
        // on stable Rust. A join failure is turned into an explicit early return instead.
        // NOTE(review): the `move` closure takes ownership of every upstream value named in
        // `#args`. An upstream result consumed here cannot also feed a later sibling node —
        // confirm that the parser rules out that fan-out for blocking nodes.
        quote! {
            let #result_var = match tokio::task::spawn_blocking(move || {
                tokio::runtime::Handle::current().block_on(async {
                    #fn_ident(#args).await
                })
            }).await {
                Ok(__node_output) => __node_output,
                Err(__join_error) => {
                    return #cg_path::GraphResult::Error(#cg_path::GraphError::NodeExecution(
                        format!("blocking node '{}' panicked: {}", stringify!(#fn_ident), __join_error)
                    ));
                }
            };
        }
    } else {
        quote! {
            let #result_var = #fn_ident(#args).await;
        }
    };
    if node.edges_out.is_empty() {
        Ok(quote! {
            #call
            __terminal_results.push(Box::new(#result_var) as Box<dyn std::any::Any + Send>);
        })
    } else if node.edges_out.len() == 1 {
        match &node.edges_out[0] {
            GraphEdge::Linear { .. } => Ok(call),
            GraphEdge::Routing { variants } => {
                let match_arms = generate_routing_match(
                    ir,
                    &node.name,
                    variants,
                    functions,
                    blocking_nodes,
                    generated,
                    is_cloacina_crate,
                )?;
                Ok(quote! {
                    #call
                    #match_arms
                })
            }
        }
    } else {
        // Fan-out: downstream nodes are emitted later in topological order and read
        // `__result_<node>` by reference.
        Ok(call)
    }
}
/// Builds the comma-separated argument list for calling `node`'s function.
///
/// Cache inputs come first, in `node.cache_inputs` order. Each is passed as an
/// `Option<&T>`: `None` when the entry is absent or the cached `Result` is an `Err`.
/// Upstream results follow, in `node.edges_in` order, each passed by reference. An edge
/// that arrives through a routing variant refers to that variant's payload binding
/// (`__variant_<from>_<variant>_<node>`); any other edge refers to `__result_<from>`.
///
/// The graph IR parameter is currently unused. It is kept (as `_ir`) so the call sites
/// and signature stay stable.
fn generate_call_args(_ir: &GraphIR, node: &GraphNode) -> TokenStream {
    let mut args = Vec::new();
    for input in &node.cache_inputs {
        let var_name = format_ident!("__cache_{}", input);
        args.push(quote! { #var_name.as_ref().and_then(|r| r.as_ref().ok()) });
    }
    for incoming in &node.edges_in {
        if let Some(variant) = &incoming.variant {
            // Must stay in sync with the binding name chosen in `generate_routing_match`.
            let variant_var = format_ident!(
                "__variant_{}_{}_{}",
                incoming.from,
                variant,
                node.name
            );
            args.push(quote! { &#variant_var });
        } else {
            let from_var = format_ident!("__result_{}", incoming.from);
            args.push(quote! { &#from_var });
        }
    }
    quote! { #(#args),* }
}
fn generate_routing_match(
ir: &GraphIR,
from_name: &str,
variants: &[super::graph_ir::GraphRoutingVariant],
functions: &HashMap<String, ItemFn>,
blocking_nodes: &HashSet<String>,
generated: &mut HashSet<String>,
is_cloacina_crate: bool,
) -> syn::Result<TokenStream> {
let result_var = format_ident!("__result_{}", from_name);
let mut arms = Vec::new();
for variant in variants {
let variant_ident = format_ident!("{}", variant.variant_name);
let variant_var = format_ident!(
"__variant_{}_{}_{}",
from_name,
variant.variant_name,
variant.target
);
let target_node = ir.get_node(&variant.target).ok_or_else(|| {
syn::Error::new(
proc_macro2::Span::call_site(),
format!("routing target '{}' not found in graph", variant.target),
)
})?;
let downstream = generate_node_execution(
ir,
target_node,
functions,
blocking_nodes,
generated,
is_cloacina_crate,
)?;
arms.push(quote! {
#variant_ident(#variant_var) => {
#downstream
}
});
}
Ok(quote! {
match #result_var {
#(#arms)*
}
})
}
/// Emits `use <mod_name>::<ReturnType>::*;` for every node that has a routing edge.
///
/// These glob imports bring the routing enum's variants into scope, so the unqualified
/// match arms produced by `generate_routing_match` resolve. Nodes with no matching module
/// function, or whose function has no declared return type, are skipped.
fn generate_routing_use_stmts(
    ir: &GraphIR,
    functions: &HashMap<String, ItemFn>,
    mod_name: &Ident,
) -> Vec<TokenStream> {
    ir.nodes
        .values()
        .filter(|node| {
            node.edges_out
                .iter()
                .any(|edge| matches!(edge, GraphEdge::Routing { .. }))
        })
        .filter_map(|node| functions.get(&node.name))
        .filter_map(|func| match &func.sig.output {
            syn::ReturnType::Type(_, ty) => Some(quote! {
                #[allow(unused_imports)]
                use #mod_name::#ty::*;
            }),
            syn::ReturnType::Default => None,
        })
        .collect()
}