1mod args;
6mod dsl;
7mod graph;
8
9use std::{
10 collections::{BTreeMap, BTreeSet},
11 fmt::Write as _,
12 iter,
13};
14
15use args::*;
16use proc_macro2::{Span, TokenStream};
17use quote::quote;
18use quote::ToTokens;
19use syn::{
20 parse::{Parse, ParseStream},
21 parse_quote,
22 punctuated::Punctuated,
23 spanned::Spanned as _,
24 Arm, Attribute, Expr, Generics, Ident, ImplGenerics, ItemImpl, ItemStruct, Lifetime, Token,
25 Type, TypeGenerics, Variant, Visibility, WhereClause,
26};
27
28use crate::dsl::*;
29use crate::graph::*;
30
/// Return early from the enclosing function with `Err(syn::Error)`.
///
/// The error is attached to `$span`. Its message is produced by `format!` from
/// the `$fmt` literal and any trailing `$arg`s. Inline captures such as `{a}`
/// resolve at the call site.
macro_rules! bail_at {
    ($span:expr, $fmt:literal $(, $arg:expr)* $(,)?) => {{
        // Evaluate the span before the message, as `syn::Error::new(span, msg)` would.
        let span = $span;
        let message = format!($fmt $(, $arg)*);
        return Err(syn::Error::new(span, message));
    }};
}
36
/// Turns a textual diagram description into markup for the generated docs.
pub trait Renderer {
    /// Render `diagram`.
    ///
    /// Returning `None` means nothing is embedded in the documentation.
    fn render(&self, diagram: &str) -> Option<String>;
}
42
43impl Renderer for () {
45 fn render(&self, _: &str) -> Option<String> {
46 None
47 }
48}
49
50impl<T: Renderer> Renderer for Option<T> {
52 fn render(&self, diagram: &str) -> Option<String> {
53 self.as_ref().and_then(|it| it.render(diagram))
54 }
55}
56
57impl<F: Fn(&str) -> Option<String>> Renderer for F {
59 fn render(&self, diagram: &str) -> Option<String> {
60 self(diagram)
61 }
62}
63
/// A renderer that embeds the diagram for client-side drawing by mermaid.js.
pub struct Mermaid(
    /// URL of the mermaid ES module that the generated `<script>` imports.
    pub String,
);
69
70impl Default for Mermaid {
71 fn default() -> Self {
72 Self(String::from(
73 "https://cdn.jsdelivr.net/npm/mermaid@11/dist/mermaid.esm.min.mjs",
74 ))
75 }
76}
77
78impl Renderer for Mermaid {
79 fn render(&self, diagram: &str) -> Option<String> {
80 Some(format!(
81 "\
82<pre class=\"mermaid\">
83{diagram}
84</pre>
85<script type=\"module\">
86 import mermaid from \"{}\";
87 var doc_theme = localStorage.getItem(\"rustdoc-theme\");
88 if (doc_theme === \"dark\" || doc_theme === \"ayu\") mermaid.initialize({{theme: \"dark\"}});
89</script>",
90 self.0
91 ))
92 }
93}
94
95pub struct FsmEntry<MermaidR = ()> {
97 state_attrs: Vec<Attribute>,
98 state_vis: Visibility,
99 state_ident: Ident,
100 state_generics: Generics,
101
102 r#unsafe: bool,
103 path_to_core: ModulePath,
104
105 entry_vis: Visibility,
106 entry_ident: Ident,
107 entry_lifetime: Lifetime,
108
109 graph: Graph,
110
111 render_mermaid: bool,
112 mermaid_renderer: MermaidR,
113}
114
115impl<MermaidR> FsmEntry<MermaidR> {
116 pub fn map_mermaid<F, MermaidR2>(self, f: F) -> FsmEntry<MermaidR2>
118 where
119 F: FnOnce(MermaidR) -> MermaidR2,
120 {
121 let Self {
122 state_attrs,
123 state_vis,
124 state_ident,
125 state_generics,
126 r#unsafe,
127 path_to_core,
128 entry_vis,
129 entry_ident,
130 entry_lifetime,
131 graph,
132 render_mermaid,
133 mermaid_renderer,
134 } = self;
135 FsmEntry {
136 state_attrs,
137 state_vis,
138 state_ident,
139 state_generics,
140 r#unsafe,
141 path_to_core,
142 entry_vis,
143 entry_ident,
144 entry_lifetime,
145 graph,
146 render_mermaid,
147 mermaid_renderer: f(mermaid_renderer),
148 }
149 }
150 fn nodes(&self) -> impl Iterator<Item = &Ident> {
151 self.graph.nodes.keys().map(|NodeId(ident)| ident)
152 }
153 fn edges(&self) -> impl Iterator<Item = (&Ident, &Ident)> {
154 self.graph.edges.keys().map(|(NodeId(f), NodeId(t))| (f, t))
155 }
156 pub fn dot(&self) -> String {
157 let mut s = format!("digraph {}{{\n", self.state_ident);
158 for draw in self.draw() {
159 match draw {
160 Draw::Edge(l, r) => s.write_fmt(format_args!(" {l} -> {r};\n")),
161 Draw::Node(it) => s.write_fmt(format_args!(" {it};\n")),
162 }
163 .unwrap();
164 }
165 s.push_str("}\n");
166 s
167 }
168 pub fn mermaid(&self) -> String {
169 let mut s = String::from("graph LR\n");
170 for draw in self.draw() {
171 match draw {
172 Draw::Edge(l, r) => s.write_fmt(format_args!(" {l} --> {r};\n")),
173 Draw::Node(it) => s.write_fmt(format_args!(" {it};\n")),
174 }
175 .unwrap()
176 }
177 s
178 }
179 fn draw(&self) -> impl Iterator<Item = Draw<'_>> {
180 let mut nodes = self.nodes().collect::<BTreeSet<_>>();
181 let edges = self
182 .edges()
183 .map(|(l, r)| {
184 nodes.remove(l);
185 nodes.remove(r);
186 Draw::Edge(l, r)
187 })
188 .collect::<Vec<_>>();
189 edges.into_iter().chain(nodes.into_iter().map(Draw::Node))
190 }
191}
192enum Draw<'a> {
193 Edge(&'a Ident, &'a Ident),
194 Node(&'a Ident),
195}
196
197impl<MermaidR: Renderer> ToTokens for FsmEntry<MermaidR> {
198 fn to_tokens(&self, tokens: &mut TokenStream) {
199 let Self {
200 state_attrs,
201 state_vis,
202 state_ident,
203 state_generics,
204 r#unsafe,
205 path_to_core,
206 entry_vis,
207 entry_ident,
208 entry_lifetime,
209 graph,
210 mermaid_renderer,
211 render_mermaid,
212 } = self;
213 let mut state_variants: Vec<Variant> = vec![];
214 let mut entry_variants: Vec<Variant> = vec![];
215 let mut entry_structs: Vec<ItemStruct> = vec![];
216 let mut match_ctor: Vec<Arm> = vec![];
217 let mut as_ref_as_mut: Vec<ItemImpl> = vec![];
218 let mut transition: Vec<ItemImpl> = vec![];
219
220 let replace: ModulePath = parse_quote!(#path_to_core::mem::replace);
221 let panik: &Expr = &match r#unsafe {
222 true => parse_quote!(unsafe { #path_to_core::hint::unreachable_unchecked() }),
223 false => {
224 parse_quote!(#path_to_core::panic!("entry struct was instantiated with a mismatched state"))
225 }
226 };
227
228 let entry_generics = {
229 let mut it = state_generics.clone();
230 it.params.insert(0, parse_quote!(#entry_lifetime));
231 it
232 };
233 let (state_impl_generics, state_type_generics, _) = state_generics.split_for_impl();
234 let (entry_impl_generics, entry_type_generics, where_clause) =
235 entry_generics.split_for_impl();
236
237 for (node, NodeData { doc, ty }, ref kind) in graph.nodes() {
238 state_variants.push(match ty {
239 Some(ty) => parse_quote!(#(#doc)* #node(#ty)),
240 None => parse_quote!(#(#doc)* #node),
241 });
242 match_ctor.push(match (ty, kind) {
243 (Some(_), Kind::Isolate | Kind::Sink(_)) => {
244 parse_quote!(#state_ident::#node(it) => #entry_ident::#node(it))
245 }
246 (None, Kind::Isolate | Kind::Sink(_)) => {
247 parse_quote!(#state_ident::#node => #entry_ident::#node)
248 }
249 (Some(_), Kind::NonTerminal { .. } | Kind::Source(_)) => {
250 parse_quote!(#state_ident::#node(_) => #entry_ident::#node(#node(value)))
251 }
252 (None, Kind::NonTerminal { .. } | Kind::Source(_)) => {
253 parse_quote!(#state_ident::#node => #entry_ident::#node(#node(value)))
254 }
255 });
256 let reachability = reachability_docs(&node.0, state_ident, kind);
257 entry_variants.push(match kind {
258 Kind::Isolate | Kind::Sink(_) => match ty {
259 Some(ty) => parse_quote!(#(#reachability)* #node(&#entry_lifetime mut #ty)),
260 None => parse_quote!(#(#reachability)* #node),
261 },
262 Kind::Source(_) | Kind::NonTerminal { .. } => {
263 parse_quote!(#(#reachability)* #node(#node #entry_type_generics))
264 }
265 });
266 if let Kind::Source(outgoing) | Kind::NonTerminal { outgoing, .. } = kind {
267 let outer_doc = format!(" See [`{entry_ident}::{node}`]");
268 let field_doc = format!(" MUST match [`{entry_ident}::{node}`]");
269 entry_structs.push(parse_quote! {
270 #[doc = #outer_doc]
271 #entry_vis struct #node #entry_type_generics(
272 #[doc = #field_doc]
273 & #entry_lifetime mut #state_ident #state_type_generics
274 )
275 #where_clause;
276 });
277 for (dst, NodeData { ty: dst_ty, .. }, EdgeData { method_name, doc }) in outgoing {
278 let body = make_body(
279 state_ident,
280 node,
281 ty.as_ref(),
282 dst,
283 dst_ty.as_ref(),
284 method_name,
285 &replace,
286 panik,
287 );
288 let pointer = DocAttr::new(
289 &format!(" Transition to [`{state_ident}::{}`]", dst.0),
290 Span::call_site(),
291 );
292 let pointer = match doc.is_empty() {
293 true => vec![pointer],
294 false => vec![DocAttr::empty(), pointer],
295 };
296 transition.push(parse_quote! {
297 #[allow(clippy::needless_lifetimes)]
298 impl #entry_impl_generics #node #entry_type_generics
299 #where_clause
300 {
301 #(#doc)*
302 #(#pointer)*
303 #body
304 }
305 });
306 }
307
308 if let Some(ty) = ty {
309 as_ref_as_mut.extend(make_as_ref_mut(
310 &entry_impl_generics,
311 path_to_core,
312 ty,
313 state_ident,
314 &node.0,
315 &entry_type_generics,
316 where_clause,
317 panik,
318 ));
319 }
320 }
321 }
322
323 let mut entry_attrs: Vec<Attribute> = vec![{
324 let doc = format!(" Progress through variants of [`{state_ident}`], created by its [`entry`]({state_ident}::entry) method.");
325 parse_quote!(#[doc = #doc])
326 }];
327
328 if *render_mermaid {
329 if let Some(rendered) = mermaid_renderer.render(&self.mermaid()) {
330 if !entry_attrs.is_empty() {
331 entry_attrs.push(parse_quote!(#[doc = ""]));
332 }
333 entry_attrs.push(parse_quote!(#[doc = #rendered]));
334 }
335 }
336
337 tokens.extend(quote! {
338 #(#state_attrs)*
339 #state_vis enum #state_ident #state_generics #where_clause {
340 #(#state_variants),*
341 }
342 #(#entry_attrs)*
343 #entry_vis enum #entry_ident #entry_generics #where_clause {
344 #(#entry_variants),*
345 }
346 impl #entry_impl_generics
347 #path_to_core::convert::From<& #entry_lifetime mut #state_ident #state_generics>
348 for #entry_ident #entry_type_generics
349 #where_clause {
350 fn from(value: & #entry_lifetime mut #state_ident #state_generics) -> Self {
351 match value {
352 #(#match_ctor),*
353 }
354 }
355 }
356 impl #state_impl_generics #state_ident #state_type_generics
357 #where_clause {
358 #[allow(clippy::needless_lifetimes)]
359 #entry_vis fn entry<#entry_lifetime>(& #entry_lifetime mut self) -> #entry_ident #entry_type_generics {
360 self.into()
361 }
362 }
363 #(#entry_structs)*
364 #(#as_ref_as_mut)*
365 #(#transition)*
366 });
367 }
368}
369
370impl Parse for FsmEntry {
371 fn parse(input: ParseStream) -> syn::Result<Self> {
372 let Root {
373 attrs: mut state_attrs,
374 vis: state_vis,
375 r#enum: _,
376 ident: state_ident,
377 generics: state_generics,
378 brace: _,
379 stmts,
380 } = input.parse()?;
381
382 let mut rename_methods = true;
383 let mut entry = VisIdent {
384 vis: state_vis.clone(),
385 ident: Ident::new(&format!("{}Entry", state_ident), Span::call_site()),
386 };
387 let mut r#unsafe = false;
388 let mut path_to_core: ModulePath = parse_quote!(::core);
389 let mut render_mermaid = false;
390 let mut parser = Parser::new()
391 .once("rename_methods", on_value(bool(&mut rename_methods)))
392 .once("entry", on_value(parse(&mut entry)))
393 .once("unsafe", on_value(bool(&mut r#unsafe)))
394 .once("path_to_core", on_value(parse(&mut path_to_core)))
395 .once("mermaid", on_value(bool(&mut render_mermaid)));
396 parser.extract("fsmentry", &mut state_attrs)?;
397 drop(parser);
398 let graph = stmts2graph(&stmts, rename_methods)?;
399 if graph.edges.is_empty() {
400 bail_at!(state_ident.span(), "must define at least one edge `A -> B`");
401 }
402 let VisIdent {
403 vis: entry_vis,
404 ident: entry_ident,
405 } = entry;
406
407 Ok(Self {
408 state_attrs,
409 state_vis,
410 state_ident,
411 state_generics,
412 r#unsafe,
413 path_to_core,
414 entry_vis,
415 entry_ident,
416 entry_lifetime: parse_quote!('state),
417 graph,
418 mermaid_renderer: (),
419 render_mermaid,
420 })
421 }
422}
423
424fn stmts2graph(
425 stmts: &Punctuated<Statement, Token![,]>,
426 rename_methods: bool,
427) -> syn::Result<Graph> {
428 use std::collections::btree_map::Entry::{Occupied, Vacant};
429
430 let mut nodes = BTreeMap::<NodeId, NodeData>::new();
431 let mut edges = BTreeMap::<(NodeId, NodeId), EdgeData>::new();
432
433 for Node { name, ty, doc } in stmts.iter().flat_map(|it| match it {
436 Statement::Node(it) => Box::new(iter::once(it)) as Box<dyn Iterator<Item = &Node>>,
437 Statement::Transition { first, rest, .. } => Box::new(
438 first
439 .into_iter()
440 .chain(rest.iter().flat_map(|(_, grp)| grp)),
441 ),
442 }) {
443 let ty = ty.as_ref().map(|(_, it)| it);
444 match nodes.entry(NodeId(name.clone())) {
445 Occupied(mut occ) => match (&occ.get().ty, ty) {
446 (None, Some(_)) | (Some(_), None) | (None, None) => {
447 append_docs(&mut occ.get_mut().doc, doc)
448 }
449 (Some(l), Some(r))
451 if l.to_token_stream().to_string() == r.to_token_stream().to_string() =>
452 {
453 append_docs(&mut occ.get_mut().doc, doc)
454 }
455 (Some(_), Some(_)) => bail_at!(name.span(), "incompatible redefinition"),
456 },
457 Vacant(v) => {
458 v.insert(NodeData {
459 ty: ty.cloned(),
460 doc: doc.clone(),
461 });
462 }
463 };
464 }
465
466 for stmt in stmts {
467 let Statement::Transition { first, rest } = stmt else {
468 continue; };
470
471 let mut grp_left = first;
472
473 for (Arrow { doc, kind }, grp_right) in rest {
474 for from in grp_left {
475 for to in grp_right {
476 match edges.entry((NodeId(from.name.clone()), NodeId(to.name.clone()))) {
477 Occupied(already) => {
478 let (a, b) = already.key();
479 bail_at!(kind.span(), "duplicate edge definition between {a} and {b}")
480 }
481 Vacant(v) => {
482 v.insert(EdgeData {
483 doc: doc.clone(),
484 method_name: match kind {
485 ArrowKind::Plain(_) => match rename_methods {
486 true => snake_case(&to.name),
487 false => to.name.clone(),
488 },
489 ArrowKind::Named { ident, .. } => ident.clone(),
490 },
491 });
492 }
493 }
494 }
495 }
496 grp_left = grp_right;
497 }
498 }
499
500 Ok(Graph { nodes, edges })
501}
502
503fn reachability_docs(node_ident: &Ident, state_ident: &Ident, kind: &Kind<'_>) -> Vec<DocAttr> {
504 let span = Span::call_site();
505 let mut dst = vec![DocAttr::new(
506 &format!(" Represents [`{state_ident}::{node_ident}`]"),
507 span,
508 )];
509 if let Kind::Sink(incoming) | Kind::NonTerminal { incoming, .. } = kind {
510 dst.extend([
511 DocAttr::empty(),
512 DocAttr::new(" This state is reachable from the following:", span),
513 ]);
514 dst.extend(incoming.iter().map(|(NodeId(other), _, EdgeData { method_name, .. })| {
515 let s = format!(" - [`{other}`]({state_ident}::{other}) via [`{method_name}`]({other}::{method_name})");
516 DocAttr::new(&s, Span::call_site())
517 }));
518 }
519 if let Kind::Source(outgoing) | Kind::NonTerminal { outgoing, .. } = kind {
520 dst.extend([
521 DocAttr::empty(),
522 DocAttr::new(" This state can transition to the following:", span),
523 ]);
524 dst.extend(outgoing.iter().map(|(NodeId(other), _, EdgeData { method_name, .. })| {
525 let s = format!(" - [`{other}`]({state_ident}::{other}) via [`{method_name}`]({node_ident}::{method_name})");
526 DocAttr::new(&s, Span::call_site())
527 }));
528 }
529 dst
530}
531
532fn append_docs(dst: &mut Vec<DocAttr>, src: &[DocAttr]) {
533 match (dst.is_empty(), src.is_empty()) {
534 (true, true) => {}
535 (true, false) => dst.extend_from_slice(src),
536 (false, true) => {}
537 (false, false) => {
538 dst.push(DocAttr::empty());
539 dst.extend_from_slice(src);
540 }
541 }
542}
543
544fn snake_case(ident: &Ident) -> Ident {
545 let ident = ident.to_string();
546 let mut snake = String::new();
547 for (i, ch) in ident.char_indices() {
548 if i > 0 && ch.is_uppercase() {
549 snake.push('_');
550 }
551 snake.push(ch.to_ascii_lowercase());
552 }
553 match (syn::parse_str(&snake), {
554 snake.insert_str(0, "r#");
555 syn::parse_str(&snake)
556 }) {
557 (Ok(it), _) | (_, Ok(it)) => it,
558 _ => panic!("bad ident {ident}"),
559 }
560}
561
562#[allow(clippy::too_many_arguments)]
563fn make_body(
564 state_ident: &Ident,
565 node: &NodeId,
566 ty: Option<&Type>,
567 dst: &NodeId,
568 dst_ty: Option<&Type>,
569 method_name: &Ident,
570 replace: &ModulePath,
571 panik: &Expr,
572) -> TokenStream {
573 match (ty, dst_ty) {
574 (None, None) => quote! {
575 pub fn #method_name(self) {
576 match #replace(self.0, #state_ident::#dst) {
577 #state_ident::#node => {},
578 _ => #panik,
579 }
580 }
581 },
582 (None, Some(dst_ty)) => quote! {
583 pub fn #method_name(self, next: #dst_ty) {
584 match #replace(self.0, #state_ident::#dst(next)) {
585 #state_ident::#node => {},
586 _ => #panik,
587 }
588 }
589 },
590 (Some(ty), None) => quote! {
591 pub fn #method_name(self) -> #ty {
592 match #replace(self.0, #state_ident::#dst) {
593 #state_ident::#node(it) => it,
594 _ => #panik,
595 }
596 }
597 },
598 (Some(ty), Some(dst_ty)) => quote! {
599 pub fn #method_name(self, next: #dst_ty) -> #ty {
600 match #replace(self.0, #state_ident::#dst(next)) {
601 #state_ident::#node(it) => it,
602 _ => #panik,
603 }
604 }
605 },
606 }
607}
608
609#[allow(clippy::too_many_arguments)]
610fn make_as_ref_mut(
611 entry_impl_generics: &ImplGenerics,
612 path_to_core: &ModulePath,
613 ty: &Type,
614 state_ident: &Ident,
615 node_ident: &Ident,
616 entry_type_generics: &TypeGenerics,
617 where_clause: Option<&WhereClause>,
618 panik: &Expr,
619) -> [ItemImpl; 2] {
620 let as_ref = parse_quote! {
621 #[allow(clippy::needless_lifetimes)]
622 impl #entry_impl_generics #path_to_core::convert::AsRef<#ty> for #node_ident #entry_type_generics
623 #where_clause
624 {
625 fn as_ref(&self) -> &#ty {
626 match &self.0 {
627 #state_ident::#node_ident(it) => it,
628 _ => #panik
629 }
630 }
631 }
632 };
633 let as_mut = parse_quote! {
634 #[allow(clippy::needless_lifetimes)]
635 impl #entry_impl_generics #path_to_core::convert::AsMut<#ty> for #node_ident #entry_type_generics
636 #where_clause
637 {
638 fn as_mut(&mut self) -> &mut #ty {
639 match &mut self.0 {
640 #state_ident::#node_ident(it) => it,
641 _ => #panik
642 }
643 }
644 }
645 };
646 [as_ref, as_mut]
647}