use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use std::collections::{HashMap, HashSet};
use syn::{
Ident, Result,
parse::{Parse, ParseStream},
punctuated::Punctuated,
token,
};
/// One entry of the `stages: { ... }` block: a stage name plus the names of
/// the stages it depends on, written as `<name> => [<dep>, <dep>, ...]`.
struct Stage {
    /// The stage's own identifier.
    name: Ident,
    /// Dependencies in the order they were written inside the brackets.
    deps: Vec<Ident>,
}

impl Parse for Stage {
    /// Parses `<name> => [<dep>, ...]`; the bracketed list may be empty and
    /// may end with a trailing comma.
    fn parse(input: ParseStream) -> Result<Self> {
        let stage_name = input.parse::<Ident>()?;
        let _arrow: token::FatArrow = input.parse()?;

        let inner;
        syn::bracketed!(inner in input);
        let dep_list: Punctuated<Ident, token::Comma> =
            inner.parse_terminated(Ident::parse, token::Comma)?;

        Ok(Stage {
            name: stage_name,
            deps: dep_list.into_iter().collect(),
        })
    }
}
/// Parsed input of the pipeline macro.
///
/// Expected syntax, with the two fields in this exact order:
///
/// ```text
/// name: MyPipeline,
/// stages: {
///     fetch => [],
///     build => [fetch],
/// },
/// ```
///
/// Trailing commas are accepted both inside the stage list and after it.
struct PipelineInput {
    /// Identifier used for the generated unit struct.
    name: Ident,
    /// Stages in declaration order.
    stages: Vec<Stage>,
}

impl Parse for PipelineInput {
    /// Parses `name: <Ident>, stages: { <Stage>, ... }` with an optional
    /// trailing comma.
    ///
    /// # Errors
    ///
    /// Fails if the first field keyword is not `name`, the second is not
    /// `stages`, or any punctuation/stage entry is malformed.
    fn parse(input: ParseStream) -> Result<Self> {
        let name_kw: Ident = input.parse()?;
        if name_kw != "name" {
            return Err(syn::Error::new(
                name_kw.span(),
                "expected `name: <Ident>` as the first field",
            ));
        }
        input.parse::<token::Colon>()?;
        let name: Ident = input.parse()?;
        input.parse::<token::Comma>()?;

        let stages_kw: Ident = input.parse()?;
        if stages_kw != "stages" {
            return Err(syn::Error::new(
                stages_kw.span(),
                "expected `stages: { ... }` as the second field",
            ));
        }
        input.parse::<token::Colon>()?;
        let braced;
        syn::braced!(braced in input);
        let parsed: Punctuated<Stage, token::Comma> =
            braced.parse_terminated(Stage::parse, token::Comma)?;

        // Tolerate a trailing comma after the stage list. `syn::parse2`
        // requires the whole input to be consumed, so without this an
        // idiomatic trailing `,` is reported as an unexpected token.
        input.parse::<Option<token::Comma>>()?;

        Ok(Self {
            name,
            stages: parsed.into_iter().collect(),
        })
    }
}
/// Computes an execution order for `stages` in which every stage comes after
/// all of the stages it depends on (Kahn's algorithm).
///
/// The result is deterministic for a given input. Stages with no
/// dependencies are seeded in declaration order, and ready stages are popped
/// from the end of a stack. Repeated expansions of the same macro invocation
/// therefore produce identical output.
///
/// # Errors
///
/// Returns a message suitable for `compile_error!` when:
/// - the same stage name is declared more than once;
/// - a stage lists a dependency that is not a declared stage; or
/// - the dependency graph contains a cycle. The message names, in
///   declaration order, every stage that could not be scheduled: the
///   members of the cycle plus any stage that transitively depends on one.
fn topo_sort(stages: &[Stage]) -> std::result::Result<Vec<String>, String> {
    let names: Vec<String> = stages.iter().map(|s| s.name.to_string()).collect();

    // Reject duplicates up front. Otherwise the maps below silently merge
    // them, and the final length comparison misreports the problem as a cycle.
    let mut name_set: HashSet<&str> = HashSet::with_capacity(names.len());
    for n in &names {
        if !name_set.insert(n.as_str()) {
            return Err(format!("stage `{}` is declared more than once", n));
        }
    }

    for s in stages {
        for d in &s.deps {
            if !name_set.contains(d.to_string().as_str()) {
                return Err(format!(
                    "stage `{}` depends on `{}` which is not declared",
                    s.name, d
                ));
            }
        }
    }

    // indeg[x]: number of not-yet-scheduled dependency edges into stage x.
    // adj[d]:   stages that list `d` as a dependency (edge d -> dependent).
    let mut indeg: HashMap<String, usize> = names.iter().map(|n| (n.clone(), 0)).collect();
    let mut adj: HashMap<String, Vec<String>> = HashMap::new();
    for s in stages {
        let me = s.name.to_string();
        for d in &s.deps {
            adj.entry(d.to_string()).or_default().push(me.clone());
            *indeg
                .get_mut(&me)
                .expect("every stage name was inserted into indeg") += 1;
        }
    }

    let mut order: Vec<String> = Vec::with_capacity(stages.len());
    let mut ready: Vec<String> = names
        .iter()
        .filter(|n| indeg[n.as_str()] == 0)
        .cloned()
        .collect();
    while let Some(n) = ready.pop() {
        if let Some(succs) = adj.get(&n) {
            for s in succs {
                let d = indeg
                    .get_mut(s)
                    .expect("dependents were validated as declared stages");
                *d -= 1;
                if *d == 0 {
                    ready.push(s.clone());
                }
            }
        }
        order.push(n);
    }

    if order.len() != stages.len() {
        // Walk `names` (declaration order) rather than the HashMap, whose
        // iteration order varies between runs, so the diagnostic is stable.
        let unresolved: Vec<&str> = names
            .iter()
            .map(String::as_str)
            .filter(|n| indeg[*n] > 0)
            .collect();
        return Err(format!(
            "pipeline has a dependency cycle through: {}",
            unresolved.join(", ")
        ));
    }
    Ok(order)
}
/// Expands the pipeline macro.
///
/// It parses `input`, schedules its stages, and emits a unit struct that
/// carries the schedule as associated constants:
///
/// ```text
/// pub struct <name>;
/// impl <name> {
///     pub const ORDER: &'static [&'static str] = &[/* stages in execution order */];
///     pub const DEPS: &'static [(&'static str, &'static [&'static str])] =
///         &[/* (stage, its declared deps), in declaration order */];
/// }
/// ```
///
/// Parse failures are returned as syn's compile errors. Scheduling failures
/// (duplicate or undeclared stage, cycle) become a single `compile_error!`.
pub fn pipeline_schedule_impl(input: TokenStream2) -> TokenStream2 {
    let pipeline = match syn::parse2::<PipelineInput>(input) {
        Ok(pipeline) => pipeline,
        Err(parse_err) => return parse_err.to_compile_error(),
    };

    let schedule: Vec<String> = match topo_sort(&pipeline.stages) {
        Ok(schedule) => schedule,
        Err(reason) => return quote! { compile_error!(#reason); },
    };

    // One `("stage", &["dep", ...])` tuple per stage, in declaration order.
    // A stage without dependencies expands to `&[]` via the empty repetition.
    let mut dep_entries: Vec<TokenStream2> = Vec::with_capacity(pipeline.stages.len());
    for stage in &pipeline.stages {
        let stage_label = stage.name.to_string();
        let dep_labels: Vec<String> = stage.deps.iter().map(|dep| dep.to_string()).collect();
        dep_entries.push(quote! { (#stage_label, &[#(#dep_labels),*]) });
    }

    let struct_ident = pipeline.name;
    quote! {
        pub struct #struct_ident;
        impl #struct_ident {
            pub const ORDER: &'static [&'static str] = &[
                #(#schedule),*
            ];
            pub const DEPS: &'static [(&'static str, &'static [&'static str])] = &[
                #(#dep_entries),*
            ];
        }
    }
}