static_graph/codegen/
mod.rs

1pub mod ty;
2
3use std::sync::Arc;
4use std::{collections::VecDeque, ops::Deref};
5
6use faststr::FastStr;
7use fxhash::{FxHashMap, FxHashSet};
8use proc_macro2::TokenStream;
9use quote::format_ident;
10
11use crate::tags::Editable;
12use crate::{
13    context::Context,
14    resolver::rir::{Graph, Node},
15    symbol::{DefId, IdentName},
16    tags::Construct,
17};
18
19pub struct Codegen {
20    cx: Context,
21    nesteds: FxHashMap<DefId, FastStr>,
22    in_degrees: FxHashMap<DefId, u32>,
23    froms: FxHashMap<DefId, Vec<DefId>>,
24    tos: FxHashMap<DefId, Vec<DefId>>,
25    visited: FxHashSet<DefId>,
26}
27
28impl Deref for Codegen {
29    type Target = Context;
30
31    fn deref(&self) -> &Self::Target {
32        &self.cx
33    }
34}
35
36impl Codegen {
37    pub fn new(cx: Context) -> Self {
38        Self {
39            cx,
40            nesteds: FxHashMap::default(),
41            in_degrees: FxHashMap::default(),
42            froms: FxHashMap::default(),
43            tos: FxHashMap::default(),
44            visited: FxHashSet::default(),
45        }
46    }
47
48    pub fn write_document(&mut self, def_ids: Vec<DefId>) -> TokenStream {
49        let mut stream = TokenStream::new();
50        self.write_trait(&mut stream);
51        for def_id in def_ids {
52            self.write_graph(def_id, &mut stream);
53        }
54        stream
55    }
56
57    pub fn write_graph(&mut self, def_id: DefId, stream: &mut TokenStream) {
58        let graph = self.graph(def_id).unwrap();
59        let graph_name = self.upper_camel_name(&graph.name).as_syn_ident();
60
61        let entry_node = self.node(graph.entry_node).unwrap();
62        let entry_node_name = self.snake_name(&entry_node.name).as_syn_ident();
63        let entry_node_ty = self.upper_camel_name(&entry_node.name).as_syn_ident();
64
65        stream.extend(quote::quote! {
66            pub struct #graph_name {
67                pub #entry_node_name: ::std::sync::Arc<#entry_node_ty>,
68            }
69            impl #graph_name {
70                pub fn new() -> Self {
71                    Self {
72                        #entry_node_name: ::std::sync::Arc::new(#entry_node_ty::new()),
73                    }
74                }
75            }
76        });
77
78        self.write_node(graph.entry_node, &entry_node, stream, &"self".into());
79
80        self.write_run(graph, stream);
81    }
82
83    pub fn write_node(
84        &mut self,
85        def_id: DefId,
86        node: &Arc<Node>,
87        stream: &mut TokenStream,
88        nested: &FastStr,
89    ) {
90        if self.visited.contains(&def_id) {
91            return;
92        }
93        self.visited.insert(def_id);
94
95        let name = self.upper_camel_name(&node.name).as_syn_ident();
96        let mut nodes = TokenStream::new();
97        let mut nodes_impl = TokenStream::new();
98        for did in &node.to_nodes {
99            let node = self.node(*did).unwrap();
100            let name = self.snake_name(&node.name).as_syn_ident();
101            let ty = self.upper_camel_name(&node.name).as_syn_ident();
102            self.in_degrees
103                .entry(*did)
104                .and_modify(|e| *e += 1)
105                .or_insert(1);
106            self.tos.entry(def_id).or_default().push(*did);
107            self.froms.entry(*did).or_default().push(def_id);
108            nodes.extend(quote::quote! {
109                pub #name: ::std::sync::Arc<#ty>,
110            });
111            nodes_impl.extend(quote::quote! {
112                #name: ::std::sync::Arc::new(#ty::new()),
113            });
114        }
115
116        let mut fields = TokenStream::new();
117        let mut fields_impl = TokenStream::new();
118        for f in &node.fields {
119            let name = self.snake_name(&f.name).as_syn_ident();
120            let ty = f.ty.to_codegen_ty();
121            fields.extend(quote::quote! {
122                pub #name: #ty,
123            });
124
125            let tags = self.tag(f.tag_id).unwrap();
126            if let Some(c) = tags.get::<Construct>() {
127                let ident: Vec<_> = c.0.split("::").map(|s| format_ident!("{}", s)).collect();
128                if let Some(Editable(true)) = tags.get::<Editable>() {
129                    fields_impl.extend(quote::quote! {
130                        #name: ::static_graph::ArcSwap::from_pointee(#(#ident)::*()),
131                    });
132                } else {
133                    fields_impl.extend(quote::quote! {
134                        #name: #(#ident)::*(),
135                    });
136                }
137            } else {
138                fields_impl.extend(quote::quote! {
139                    #name: ::std::default::Default::default(),
140                });
141            };
142        }
143
144        stream.extend(quote::quote! {
145            pub struct #name {
146                #nodes
147                #fields
148            }
149            impl #name {
150                pub fn new() -> Self {
151                    Self {
152                        #nodes_impl
153                        #fields_impl
154                    }
155                }
156            }
157        });
158
159        let nested: FastStr = format!("{}.{}", nested, self.snake_name(&node.name)).into();
160        self.nesteds.insert(def_id, nested.clone());
161        for did in &node.to_nodes {
162            self.write_node(*did, &self.node(*did).unwrap(), stream, &nested);
163        }
164    }
165
166    #[inline]
167    fn write_trait(&mut self, stream: &mut TokenStream) {
168        stream.extend(quote::quote! {
169            #[static_graph::async_trait]
170            pub trait Runnable<Req, PrevResp> {
171                type Resp;
172                type Error;
173                async fn run(&self, req: Req, prev_resp: PrevResp) -> ::std::result::Result<Self::Resp, Self::Error>;
174            }
175        });
176    }
177
178    fn write_run(&mut self, graph: Arc<Graph>, stream: &mut TokenStream) {
179        let name = self.upper_camel_name(&graph.name).as_syn_ident();
180        let mut queue = VecDeque::new();
181
182        assert!(self.in_degrees.get(&graph.entry_node).is_none());
183
184        queue.push_back(graph.entry_node);
185        let mut bounds = TokenStream::new();
186        let mut bodys = TokenStream::new();
187        let mut generics = Vec::new();
188        let mut out_resp = None;
189        while !queue.is_empty() {
190            let sz = queue.len();
191            for _ in 0..sz {
192                let mut channels = TokenStream::new();
193
194                let did = queue.pop_front().unwrap();
195                let node = self.node(did).unwrap();
196                let name = self.snake_name(&node.name).as_syn_ident();
197                let upper_name = self.upper_camel_name(&node.name).as_syn_ident();
198
199                let mut upper_prev_resps = Vec::new();
200                let mut resps = Vec::new();
201                if let Some(from_dids) = self.froms.get(&did) {
202                    let mut rxs = Vec::with_capacity(from_dids.len());
203                    let mut matches = Vec::with_capacity(from_dids.len());
204
205                    for from_did in from_dids {
206                        let node = self.node(*from_did).unwrap();
207
208                        let f_name = self.snake_name(&node.name).as_syn_ident();
209                        let upper_f_name = self.upper_camel_name(&node.name).as_syn_ident();
210                        let upper_prev_resp = format_ident!("{}Resp", upper_f_name);
211
212                        let resp = format_ident!("{}_resp", f_name);
213
214                        resps.push(resp.clone());
215                        rxs.push(format_ident!("{}_rx_{}", name, f_name));
216                        matches.push(quote::quote! {
217                            Ok(Ok(#resp))
218                        });
219
220                        upper_prev_resps.push(upper_prev_resp);
221                    }
222
223                    if !resps.is_empty() {
224                        channels.extend(quote::quote! {
225                            let (#(#resps),*) = match static_graph::join!(#(#rxs.recv()),*) {
226                                (#(#matches,)*) => (#(#resps),*),
227                                _ => panic!("Error"),
228                            };
229                        });
230                    }
231                };
232
233                let upper_resp = format_ident!("{}Resp", upper_name);
234                generics.push(upper_resp.clone());
235                bounds.extend(quote::quote! {
236                    #upper_name: Runnable<Req, (#(#upper_prev_resps),*), Resp = #upper_resp, Error = Error>,
237                    #upper_resp: Clone + Send + Sync + 'static,
238                });
239
240                let req = format_ident!("{}_req", name);
241                let tx = format_ident!("{}_tx", name);
242                let node: Vec<_> = self
243                    .nesteds
244                    .get(&did)
245                    .unwrap()
246                    .split('.')
247                    .map(|s| format_ident!("{}", s))
248                    .collect();
249
250                if let Some(to_dids) = self.tos.get(&did) {
251                    let mut rxs = Vec::with_capacity(to_dids.len());
252                    let len = to_dids.len() + 1;
253                    for to_did in to_dids {
254                        if let Some(in_degree) = self.in_degrees.get_mut(to_did) {
255                            *in_degree -= 1;
256                            if *in_degree == 0 {
257                                self.in_degrees.remove(to_did);
258                                queue.push_back(*to_did);
259                            }
260                        }
261                        let node = self.node(*to_did).unwrap();
262                        let to_name = self.snake_name(&node.name).as_syn_ident();
263                        rxs.push(format_ident!("{}_rx_{}", to_name, name));
264                    }
265                    bodys.extend(quote::quote! {
266                        let #req = req.clone();
267                        let #name = #(#node.)*clone();
268                        let (#tx, _) = static_graph::sync::broadcast::channel(#len);
269                        #(let mut #rxs = #tx.subscribe();)*
270                        static_graph::spawn(async move {
271                            #channels
272                            let resp = #name.run(#req, (#(#resps),*)).await;
273                            #tx.send(resp).ok();
274                        });
275                    });
276                } else {
277                    assert!(out_resp.is_none());
278
279                    out_resp.replace(upper_resp);
280                    bodys.extend(quote::quote! {
281                        #channels
282                        #(#node).*.run(req, (#(#resps),*)).await
283                    });
284                }
285            }
286        }
287
288        assert!(self.in_degrees.is_empty());
289
290        let out_resp = out_resp.unwrap();
291        stream.extend(quote::quote! {
292            impl #name {
293                pub async fn run<Req, #(#generics),*, Error>(&self, req: Req) -> ::std::result::Result<#out_resp, Error>
294                where
295                    Req: Clone + Send + Sync + 'static,
296                    Error: Clone + Send + Sync + 'static,
297                    #bounds
298                {
299                    #bodys
300                }
301            }
302        });
303    }
304}