Skip to main content

fsmentry_core/
lib.rs

1//! A code generator for state machines with an entry API.
2//!
3//! See the [`fsmentry` crate](https://docs.rs/fsmentry).
4
5mod args;
6mod dsl;
7mod graph;
8
9use std::{
10    collections::{BTreeMap, BTreeSet},
11    fmt::Write as _,
12    iter,
13};
14
15use args::*;
16use proc_macro2::{Span, TokenStream};
17use quote::quote;
18use quote::ToTokens;
19use syn::{
20    parse::{Parse, ParseStream},
21    parse_quote,
22    punctuated::Punctuated,
23    spanned::Spanned as _,
24    Arm, Attribute, Expr, Generics, Ident, ImplGenerics, ItemImpl, ItemStruct, Lifetime, Token,
25    Type, TypeGenerics, Variant, Visibility, WhereClause,
26};
27
28use crate::dsl::*;
29use crate::graph::*;
30
/// Early-return `Err(syn::Error)` located at `$at`, with a [`format!`]-style message.
///
/// The message literal may use inline captures (`"{name}"`) as well as positional arguments.
macro_rules! bail_at {
    ($at:expr, $msg:literal $(, $args:expr)* $(,)?) => {
        return Err(syn::Error::new($at, format!($msg $(, $args)*)))
    };
}
36
/// Renderer for mermaid diagrams.
///
/// Receives the mermaid source of a diagram and decides what, if anything, to embed in the docs.
pub trait Renderer {
    /// Return [`None`] to skip rendering.
    fn render(&self, diagram: &str) -> Option<String>;
}

/// Skip rendering entirely.
impl Renderer for () {
    fn render(&self, _diagram: &str) -> Option<String> {
        None
    }
}

/// Forward to the inner [`Renderer`], if present; render nothing otherwise.
impl<T: Renderer> Renderer for Option<T> {
    fn render(&self, diagram: &str) -> Option<String> {
        match self {
            Some(inner) => inner.render(diagram),
            None => None,
        }
    }
}

/// Call the provided function with the diagram source.
impl<F: Fn(&str) -> Option<String>> Renderer for F {
    fn render(&self, diagram: &str) -> Option<String> {
        (self)(diagram)
    }
}
63
/// A [`Renderer`] which embeds a script to load `mermaidjs` into the docs.
///
/// The diagram is placed in a `<pre class="mermaid">` element, followed by a module script
/// that imports mermaid from the contained URL.
pub struct Mermaid(
    /// The URL to import mermaid from.
    pub String,
);

/// Imports mermaid v11 (ES module build) from the jsDelivr CDN.
impl Default for Mermaid {
    fn default() -> Self {
        const CDN_URL: &str = "https://cdn.jsdelivr.net/npm/mermaid@11/dist/mermaid.esm.min.mjs";
        Mermaid(CDN_URL.to_owned())
    }
}
77
78impl Renderer for Mermaid {
79    fn render(&self, diagram: &str) -> Option<String> {
80        Some(format!(
81            "\
82<pre class=\"mermaid\">
83{diagram}
84</pre>
85<script type=\"module\">
86  import mermaid from \"{}\";
87  var doc_theme = localStorage.getItem(\"rustdoc-theme\");
88  if (doc_theme === \"dark\" || doc_theme === \"ayu\") mermaid.initialize({{theme: \"dark\"}});
89</script>",
90            self.0
91        ))
92    }
93}
94
95/// A [`Parse`]-able and [printable](ToTokens) representation of a state machine.
96pub struct FsmEntry<MermaidR = ()> {
97    state_attrs: Vec<Attribute>,
98    state_vis: Visibility,
99    state_ident: Ident,
100    state_generics: Generics,
101
102    r#unsafe: bool,
103    path_to_core: ModulePath,
104
105    entry_vis: Visibility,
106    entry_ident: Ident,
107    entry_lifetime: Lifetime,
108
109    graph: Graph,
110
111    render_mermaid: bool,
112    mermaid_renderer: MermaidR,
113}
114
115impl<MermaidR> FsmEntry<MermaidR> {
116    /// Change the mermaid renderer.
117    pub fn map_mermaid<F, MermaidR2>(self, f: F) -> FsmEntry<MermaidR2>
118    where
119        F: FnOnce(MermaidR) -> MermaidR2,
120    {
121        let Self {
122            state_attrs,
123            state_vis,
124            state_ident,
125            state_generics,
126            r#unsafe,
127            path_to_core,
128            entry_vis,
129            entry_ident,
130            entry_lifetime,
131            graph,
132            render_mermaid,
133            mermaid_renderer,
134        } = self;
135        FsmEntry {
136            state_attrs,
137            state_vis,
138            state_ident,
139            state_generics,
140            r#unsafe,
141            path_to_core,
142            entry_vis,
143            entry_ident,
144            entry_lifetime,
145            graph,
146            render_mermaid,
147            mermaid_renderer: f(mermaid_renderer),
148        }
149    }
150    fn nodes(&self) -> impl Iterator<Item = &Ident> {
151        self.graph.nodes.keys().map(|NodeId(ident)| ident)
152    }
153    fn edges(&self) -> impl Iterator<Item = (&Ident, &Ident)> {
154        self.graph.edges.keys().map(|(NodeId(f), NodeId(t))| (f, t))
155    }
156    pub fn dot(&self) -> String {
157        let mut s = format!("digraph {}{{\n", self.state_ident);
158        for draw in self.draw() {
159            match draw {
160                Draw::Edge(l, r) => s.write_fmt(format_args!("  {l} -> {r};\n")),
161                Draw::Node(it) => s.write_fmt(format_args!("  {it};\n")),
162            }
163            .unwrap();
164        }
165        s.push_str("}\n");
166        s
167    }
168    pub fn mermaid(&self) -> String {
169        let mut s = String::from("graph LR\n");
170        for draw in self.draw() {
171            match draw {
172                Draw::Edge(l, r) => s.write_fmt(format_args!("  {l} --> {r};\n")),
173                Draw::Node(it) => s.write_fmt(format_args!("  {it};\n")),
174            }
175            .unwrap()
176        }
177        s
178    }
179    fn draw(&self) -> impl Iterator<Item = Draw<'_>> {
180        let mut nodes = self.nodes().collect::<BTreeSet<_>>();
181        let edges = self
182            .edges()
183            .map(|(l, r)| {
184                nodes.remove(l);
185                nodes.remove(r);
186                Draw::Edge(l, r)
187            })
188            .collect::<Vec<_>>();
189        edges.into_iter().chain(nodes.into_iter().map(Draw::Node))
190    }
191}
192enum Draw<'a> {
193    Edge(&'a Ident, &'a Ident),
194    Node(&'a Ident),
195}
196
197impl<MermaidR: Renderer> ToTokens for FsmEntry<MermaidR> {
198    fn to_tokens(&self, tokens: &mut TokenStream) {
199        let Self {
200            state_attrs,
201            state_vis,
202            state_ident,
203            state_generics,
204            r#unsafe,
205            path_to_core,
206            entry_vis,
207            entry_ident,
208            entry_lifetime,
209            graph,
210            mermaid_renderer,
211            render_mermaid,
212        } = self;
213        let mut state_variants: Vec<Variant> = vec![];
214        let mut entry_variants: Vec<Variant> = vec![];
215        let mut entry_structs: Vec<ItemStruct> = vec![];
216        let mut match_ctor: Vec<Arm> = vec![];
217        let mut as_ref_as_mut: Vec<ItemImpl> = vec![];
218        let mut transition: Vec<ItemImpl> = vec![];
219
220        let replace: ModulePath = parse_quote!(#path_to_core::mem::replace);
221        let panik: &Expr = &match r#unsafe {
222            true => parse_quote!(unsafe { #path_to_core::hint::unreachable_unchecked() }),
223            false => {
224                parse_quote!(#path_to_core::panic!("entry struct was instantiated with a mismatched state"))
225            }
226        };
227
228        let entry_generics = {
229            let mut it = state_generics.clone();
230            it.params.insert(0, parse_quote!(#entry_lifetime));
231            it
232        };
233        let (state_impl_generics, state_type_generics, _) = state_generics.split_for_impl();
234        let (entry_impl_generics, entry_type_generics, where_clause) =
235            entry_generics.split_for_impl();
236
237        for (node, NodeData { doc, ty }, ref kind) in graph.nodes() {
238            state_variants.push(match ty {
239                Some(ty) => parse_quote!(#(#doc)* #node(#ty)),
240                None => parse_quote!(#(#doc)* #node),
241            });
242            match_ctor.push(match (ty, kind) {
243                (Some(_), Kind::Isolate | Kind::Sink(_)) => {
244                    parse_quote!(#state_ident::#node(it) => #entry_ident::#node(it))
245                }
246                (None, Kind::Isolate | Kind::Sink(_)) => {
247                    parse_quote!(#state_ident::#node     => #entry_ident::#node)
248                }
249                (Some(_), Kind::NonTerminal { .. } | Kind::Source(_)) => {
250                    parse_quote!(#state_ident::#node(_)  => #entry_ident::#node(#node(value)))
251                }
252                (None, Kind::NonTerminal { .. } | Kind::Source(_)) => {
253                    parse_quote!(#state_ident::#node     => #entry_ident::#node(#node(value)))
254                }
255            });
256            let reachability = reachability_docs(&node.0, state_ident, kind);
257            entry_variants.push(match kind {
258                Kind::Isolate | Kind::Sink(_) => match ty {
259                    Some(ty) => parse_quote!(#(#reachability)* #node(&#entry_lifetime mut #ty)),
260                    None => parse_quote!(#(#reachability)* #node),
261                },
262                Kind::Source(_) | Kind::NonTerminal { .. } => {
263                    parse_quote!(#(#reachability)* #node(#node #entry_type_generics))
264                }
265            });
266            if let Kind::Source(outgoing) | Kind::NonTerminal { outgoing, .. } = kind {
267                let outer_doc = format!(" See [`{entry_ident}::{node}`]");
268                let field_doc = format!(" MUST match [`{entry_ident}::{node}`]");
269                entry_structs.push(parse_quote! {
270                    #[doc = #outer_doc]
271                    #entry_vis struct #node #entry_type_generics(
272                        #[doc = #field_doc]
273                        & #entry_lifetime mut #state_ident #state_type_generics
274                    )
275                    #where_clause;
276                });
277                for (dst, NodeData { ty: dst_ty, .. }, EdgeData { method_name, doc }) in outgoing {
278                    let body = make_body(
279                        state_ident,
280                        node,
281                        ty.as_ref(),
282                        dst,
283                        dst_ty.as_ref(),
284                        method_name,
285                        &replace,
286                        panik,
287                    );
288                    let pointer = DocAttr::new(
289                        &format!(" Transition to [`{state_ident}::{}`]", dst.0),
290                        Span::call_site(),
291                    );
292                    let pointer = match doc.is_empty() {
293                        true => vec![pointer],
294                        false => vec![DocAttr::empty(), pointer],
295                    };
296                    transition.push(parse_quote! {
297                        #[allow(clippy::needless_lifetimes)]
298                        impl #entry_impl_generics #node #entry_type_generics
299                        #where_clause
300                        {
301                            #(#doc)*
302                            #(#pointer)*
303                            #body
304                        }
305                    });
306                }
307
308                if let Some(ty) = ty {
309                    as_ref_as_mut.extend(make_as_ref_mut(
310                        &entry_impl_generics,
311                        path_to_core,
312                        ty,
313                        state_ident,
314                        &node.0,
315                        &entry_type_generics,
316                        where_clause,
317                        panik,
318                    ));
319                }
320            }
321        }
322
323        let mut entry_attrs: Vec<Attribute> = vec![{
324            let doc = format!(" Progress through variants of [`{state_ident}`], created by its [`entry`]({state_ident}::entry) method.");
325            parse_quote!(#[doc = #doc])
326        }];
327
328        if *render_mermaid {
329            if let Some(rendered) = mermaid_renderer.render(&self.mermaid()) {
330                if !entry_attrs.is_empty() {
331                    entry_attrs.push(parse_quote!(#[doc = ""]));
332                }
333                entry_attrs.push(parse_quote!(#[doc = #rendered]));
334            }
335        }
336
337        tokens.extend(quote! {
338            #(#state_attrs)*
339            #state_vis enum #state_ident #state_generics #where_clause {
340                #(#state_variants),*
341            }
342            #(#entry_attrs)*
343            #entry_vis enum #entry_ident #entry_generics #where_clause {
344                #(#entry_variants),*
345            }
346            impl #entry_impl_generics
347                #path_to_core::convert::From<& #entry_lifetime mut #state_ident #state_generics>
348            for #entry_ident #entry_type_generics
349            #where_clause {
350                fn from(value: & #entry_lifetime mut #state_ident #state_generics) -> Self {
351                    match value {
352                        #(#match_ctor),*
353                    }
354                }
355            }
356            impl #state_impl_generics #state_ident #state_type_generics
357            #where_clause {
358                #[allow(clippy::needless_lifetimes)]
359                #entry_vis fn entry<#entry_lifetime>(& #entry_lifetime mut self) -> #entry_ident #entry_type_generics {
360                    self.into()
361                }
362            }
363            #(#entry_structs)*
364            #(#as_ref_as_mut)*
365            #(#transition)*
366        });
367    }
368}
369
370impl Parse for FsmEntry {
371    fn parse(input: ParseStream) -> syn::Result<Self> {
372        let Root {
373            attrs: mut state_attrs,
374            vis: state_vis,
375            r#enum: _,
376            ident: state_ident,
377            generics: state_generics,
378            brace: _,
379            stmts,
380        } = input.parse()?;
381
382        let mut rename_methods = true;
383        let mut entry = VisIdent {
384            vis: state_vis.clone(),
385            ident: Ident::new(&format!("{}Entry", state_ident), Span::call_site()),
386        };
387        let mut r#unsafe = false;
388        let mut path_to_core: ModulePath = parse_quote!(::core);
389        let mut render_mermaid = false;
390        let mut parser = Parser::new()
391            .once("rename_methods", on_value(bool(&mut rename_methods)))
392            .once("entry", on_value(parse(&mut entry)))
393            .once("unsafe", on_value(bool(&mut r#unsafe)))
394            .once("path_to_core", on_value(parse(&mut path_to_core)))
395            .once("mermaid", on_value(bool(&mut render_mermaid)));
396        parser.extract("fsmentry", &mut state_attrs)?;
397        drop(parser);
398        let graph = stmts2graph(&stmts, rename_methods)?;
399        if graph.edges.is_empty() {
400            bail_at!(state_ident.span(), "must define at least one edge `A -> B`");
401        }
402        let VisIdent {
403            vis: entry_vis,
404            ident: entry_ident,
405        } = entry;
406
407        Ok(Self {
408            state_attrs,
409            state_vis,
410            state_ident,
411            state_generics,
412            r#unsafe,
413            path_to_core,
414            entry_vis,
415            entry_ident,
416            entry_lifetime: parse_quote!('state),
417            graph,
418            mermaid_renderer: (),
419            render_mermaid,
420        })
421    }
422}
423
424fn stmts2graph(
425    stmts: &Punctuated<Statement, Token![,]>,
426    rename_methods: bool,
427) -> syn::Result<Graph> {
428    use std::collections::btree_map::Entry::{Occupied, Vacant};
429
430    let mut nodes = BTreeMap::<NodeId, NodeData>::new();
431    let mut edges = BTreeMap::<(NodeId, NodeId), EdgeData>::new();
432
433    // Define all the nodes upfront.
434    // Note that transition definitions may include types, at any location.
435    for Node { name, ty, doc } in stmts.iter().flat_map(|it| match it {
436        Statement::Node(it) => Box::new(iter::once(it)) as Box<dyn Iterator<Item = &Node>>,
437        Statement::Transition { first, rest, .. } => Box::new(
438            first
439                .into_iter()
440                .chain(rest.iter().flat_map(|(_, grp)| grp)),
441        ),
442    }) {
443        let ty = ty.as_ref().map(|(_, it)| it);
444        match nodes.entry(NodeId(name.clone())) {
445            Occupied(mut occ) => match (&occ.get().ty, ty) {
446                (None, Some(_)) | (Some(_), None) | (None, None) => {
447                    append_docs(&mut occ.get_mut().doc, doc)
448                }
449                // don't compile `syn` with `extra-traits`
450                (Some(l), Some(r))
451                    if l.to_token_stream().to_string() == r.to_token_stream().to_string() =>
452                {
453                    append_docs(&mut occ.get_mut().doc, doc)
454                }
455                (Some(_), Some(_)) => bail_at!(name.span(), "incompatible redefinition"),
456            },
457            Vacant(v) => {
458                v.insert(NodeData {
459                    ty: ty.cloned(),
460                    doc: doc.clone(),
461                });
462            }
463        };
464    }
465
466    for stmt in stmts {
467        let Statement::Transition { first, rest } = stmt else {
468            continue; // handled above
469        };
470
471        let mut grp_left = first;
472
473        for (Arrow { doc, kind }, grp_right) in rest {
474            for from in grp_left {
475                for to in grp_right {
476                    match edges.entry((NodeId(from.name.clone()), NodeId(to.name.clone()))) {
477                        Occupied(already) => {
478                            let (a, b) = already.key();
479                            bail_at!(kind.span(), "duplicate edge definition between {a} and {b}")
480                        }
481                        Vacant(v) => {
482                            v.insert(EdgeData {
483                                doc: doc.clone(),
484                                method_name: match kind {
485                                    ArrowKind::Plain(_) => match rename_methods {
486                                        true => snake_case(&to.name),
487                                        false => to.name.clone(),
488                                    },
489                                    ArrowKind::Named { ident, .. } => ident.clone(),
490                                },
491                            });
492                        }
493                    }
494                }
495            }
496            grp_left = grp_right;
497        }
498    }
499
500    Ok(Graph { nodes, edges })
501}
502
503fn reachability_docs(node_ident: &Ident, state_ident: &Ident, kind: &Kind<'_>) -> Vec<DocAttr> {
504    let span = Span::call_site();
505    let mut dst = vec![DocAttr::new(
506        &format!(" Represents [`{state_ident}::{node_ident}`]"),
507        span,
508    )];
509    if let Kind::Sink(incoming) | Kind::NonTerminal { incoming, .. } = kind {
510        dst.extend([
511            DocAttr::empty(),
512            DocAttr::new(" This state is reachable from the following:", span),
513        ]);
514        dst.extend(incoming.iter().map(|(NodeId(other), _, EdgeData { method_name, .. })| {
515            let s = format!(" - [`{other}`]({state_ident}::{other}) via [`{method_name}`]({other}::{method_name})");
516            DocAttr::new(&s, Span::call_site())
517        }));
518    }
519    if let Kind::Source(outgoing) | Kind::NonTerminal { outgoing, .. } = kind {
520        dst.extend([
521            DocAttr::empty(),
522            DocAttr::new(" This state can transition to the following:", span),
523        ]);
524        dst.extend(outgoing.iter().map(|(NodeId(other), _, EdgeData { method_name, .. })| {
525            let s = format!(" - [`{other}`]({state_ident}::{other}) via [`{method_name}`]({node_ident}::{method_name})");
526            DocAttr::new(&s, Span::call_site())
527        }));
528    }
529    dst
530}
531
532fn append_docs(dst: &mut Vec<DocAttr>, src: &[DocAttr]) {
533    match (dst.is_empty(), src.is_empty()) {
534        (true, true) => {}
535        (true, false) => dst.extend_from_slice(src),
536        (false, true) => {}
537        (false, false) => {
538            dst.push(DocAttr::empty());
539            dst.extend_from_slice(src);
540        }
541    }
542}
543
544fn snake_case(ident: &Ident) -> Ident {
545    let ident = ident.to_string();
546    let mut snake = String::new();
547    for (i, ch) in ident.char_indices() {
548        if i > 0 && ch.is_uppercase() {
549            snake.push('_');
550        }
551        snake.push(ch.to_ascii_lowercase());
552    }
553    match (syn::parse_str(&snake), {
554        snake.insert_str(0, "r#");
555        syn::parse_str(&snake)
556    }) {
557        (Ok(it), _) | (_, Ok(it)) => it,
558        _ => panic!("bad ident {ident}"),
559    }
560}
561
562#[allow(clippy::too_many_arguments)]
563fn make_body(
564    state_ident: &Ident,
565    node: &NodeId,
566    ty: Option<&Type>,
567    dst: &NodeId,
568    dst_ty: Option<&Type>,
569    method_name: &Ident,
570    replace: &ModulePath,
571    panik: &Expr,
572) -> TokenStream {
573    match (ty, dst_ty) {
574        (None, None) => quote! {
575            pub fn #method_name(self) {
576                match #replace(self.0, #state_ident::#dst) {
577                    #state_ident::#node => {},
578                    _ => #panik,
579                }
580            }
581        },
582        (None, Some(dst_ty)) => quote! {
583            pub fn #method_name(self, next: #dst_ty) {
584                match #replace(self.0, #state_ident::#dst(next)) {
585                    #state_ident::#node => {},
586                    _ => #panik,
587                }
588            }
589        },
590        (Some(ty), None) => quote! {
591            pub fn #method_name(self) -> #ty {
592                match #replace(self.0, #state_ident::#dst) {
593                    #state_ident::#node(it) => it,
594                    _ => #panik,
595                }
596            }
597        },
598        (Some(ty), Some(dst_ty)) => quote! {
599            pub fn #method_name(self, next: #dst_ty) -> #ty {
600                match #replace(self.0, #state_ident::#dst(next)) {
601                    #state_ident::#node(it) => it,
602                    _ => #panik,
603                }
604            }
605        },
606    }
607}
608
609#[allow(clippy::too_many_arguments)]
610fn make_as_ref_mut(
611    entry_impl_generics: &ImplGenerics,
612    path_to_core: &ModulePath,
613    ty: &Type,
614    state_ident: &Ident,
615    node_ident: &Ident,
616    entry_type_generics: &TypeGenerics,
617    where_clause: Option<&WhereClause>,
618    panik: &Expr,
619) -> [ItemImpl; 2] {
620    let as_ref = parse_quote! {
621        #[allow(clippy::needless_lifetimes)]
622        impl #entry_impl_generics #path_to_core::convert::AsRef<#ty> for #node_ident #entry_type_generics
623        #where_clause
624        {
625            fn as_ref(&self) -> &#ty {
626                match &self.0 {
627                    #state_ident::#node_ident(it) => it,
628                    _ => #panik
629                }
630            }
631        }
632    };
633    let as_mut = parse_quote! {
634        #[allow(clippy::needless_lifetimes)]
635        impl #entry_impl_generics #path_to_core::convert::AsMut<#ty> for #node_ident #entry_type_generics
636        #where_clause
637        {
638            fn as_mut(&mut self) -> &mut #ty {
639                match &mut self.0 {
640                    #state_ident::#node_ident(it) => it,
641                    _ => #panik
642                }
643            }
644        }
645    };
646    [as_ref, as_mut]
647}