use std::collections::HashMap;
use quote::ToTokens;
use syn::{visit::Visit, Expr, ExprField, Ident, Member};
use twine_core::graph::{ComponentGraph, Target};
/// Builds the [`ComponentGraph`] described by each component's input expression.
///
/// A struct-literal expression is walked field by field, with the current
/// target set to `(component, field name)` before each field's value is
/// visited; unnamed (tuple-style) members use their numeric index as the
/// name. Any other expression is walked once with an empty input name.
pub(crate) fn build_graph(input_exprs: &HashMap<Ident, Expr>) -> ComponentGraph {
    let mut builder = GraphBuilder::new();

    for (component, expr) in input_exprs {
        if let Expr::Struct(struct_expr) = expr {
            for field in struct_expr.fields.iter() {
                let input_name = match &field.member {
                    Member::Named(name) => name.to_string(),
                    Member::Unnamed(position) => position.index.to_string(),
                };
                let targeted = builder.with_target(component, &input_name);
                targeted.visit_expr(&field.expr);
            }
        } else {
            // Not a struct literal: the whole expression feeds the component
            // and there is no specific input field to name.
            builder.with_target(component, "").visit_expr(expr);
        }
    }

    builder.graph
}
/// Visitor state used while building a [`ComponentGraph`]: the graph being
/// assembled and the target that discovered connections are attached to.
struct GraphBuilder {
    /// Graph accumulated across every visited expression.
    graph: ComponentGraph,
    /// Target set by the most recent call to `with_target`.
    current_target: Target,
}

impl GraphBuilder {
    /// Creates a builder with a new graph and a target whose component and
    /// input names are both empty strings.
    fn new() -> Self {
        let graph = ComponentGraph::new();
        let current_target = Target {
            component: String::new(),
            input: String::new(),
        };
        Self {
            graph,
            current_target,
        }
    }

    /// Replaces the current target with one built from `component` and
    /// `input` (both converted via `to_string`), returning `self` so a visit
    /// call can be chained directly.
    fn with_target<T: ToString + ?Sized, U: ToString + ?Sized>(
        &mut self,
        component: &T,
        input: &U,
    ) -> &mut Self {
        let component_name = component.to_string();
        let input_name = input.to_string();
        self.current_target = Target::new(component_name, input_name);
        self
    }
}
/// Records a connection into the current target for every field-access
/// expression rooted at `input` or `output`.
impl Visit<'_> for GraphBuilder {
    /// Inspects one field-access expression.
    ///
    /// * `input.<path>` connects `("__input", "<path>")` to the current target.
    /// * `output.<component>` connects `("<component>", "")`, and
    ///   `output.<component>.<path>` connects `("<component>", "<path>")`.
    /// * Anything else is not a connection itself, so its base expression is
    ///   walked to pick up references nested inside it.
    ///
    /// A matched access is not descended into, so `output.a.b` yields exactly
    /// one connection rather than one per prefix.
    fn visit_expr_field(&mut self, node: &ExprField) {
        // Render the access as text and collapse ` . ` to `.` so it can be
        // matched as a plain dotted path however the tokens were spaced.
        let full_path = node.to_token_stream().to_string().replace(" . ", ".");

        // The trailing dot matters: without it identifiers that merely start
        // with `input` (`inputs.x`, `input_cfg.y`) would be treated as the
        // graph input.
        if let Some(field) = full_path.strip_prefix("input.") {
            self.graph
                .connect(("__input", field), self.current_target.clone());
        } else if let Some(stripped) = full_path.strip_prefix("output.") {
            // First segment is the source component, the rest (if any) is the
            // path within that component's output.
            let source = stripped.split_once('.').unwrap_or((stripped, ""));
            self.graph.connect(source, self.current_target.clone());
        } else {
            // e.g. `helper(output.a.b).c`: the reference lives inside the base.
            self.visit_expr(&node.base);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;
    use twine_core::graph::Source;

    /// Incoming connections of `component`, ordered by their `Debug` text so
    /// assertions do not depend on the graph's iteration order.
    fn sorted_incoming(graph: &ComponentGraph, component: &str) -> Vec<Source> {
        let mut sources: Vec<Source> = graph.incoming_connections(component).collect();
        sources.sort_by_key(|source| format!("{source:?}"));
        sources
    }

    /// Outgoing connections of `component`, ordered by their `Debug` text so
    /// assertions do not depend on the graph's iteration order.
    fn sorted_outgoing(graph: &ComponentGraph, component: &str) -> Vec<Target> {
        let mut targets: Vec<Target> = graph.outgoing_connections(component).collect();
        targets.sort_by_key(|target| format!("{target:?}"));
        targets
    }

    #[test]
    fn build_component_graph_works() {
        let mut input_exprs: HashMap<Ident, Expr> = HashMap::new();
        input_exprs.insert(
            parse_quote!(first),
            parse_quote!(FirstInput {
                x: input.value,
                y: output.second
            }),
        );
        input_exprs.insert(parse_quote!(second), parse_quote!(output.first.value));
        input_exprs.insert(
            parse_quote!(third),
            parse_quote!(ThirdInput {
                input_1: output.first.nested.output,
                input_2: output.second,
            }),
        );

        let graph = build_graph(&input_exprs);

        let expected_incoming = [
            (
                "first",
                vec![Source::new("__input", "value"), Source::new("second", "")],
            ),
            ("second", vec![Source::new("first", "value")]),
            (
                "third",
                vec![
                    Source::new("first", "nested.output"),
                    Source::new("second", ""),
                ],
            ),
        ];
        for (component, expected) in expected_incoming {
            assert_eq!(sorted_incoming(&graph, component), expected);
        }

        let expected_outgoing = [
            (
                "first",
                vec![Target::new("second", ""), Target::new("third", "input_1")],
            ),
            (
                "second",
                vec![Target::new("first", "y"), Target::new("third", "input_2")],
            ),
            ("third", vec![]),
        ];
        for (component, expected) in expected_outgoing {
            assert_eq!(sorted_outgoing(&graph, component), expected);
        }
    }
}