1pub mod ty;
2
3use std::sync::Arc;
4use std::{collections::VecDeque, ops::Deref};
5
6use faststr::FastStr;
7use fxhash::{FxHashMap, FxHashSet};
8use proc_macro2::TokenStream;
9use quote::format_ident;
10
11use crate::tags::Editable;
12use crate::{
13 context::Context,
14 resolver::rir::{Graph, Node},
15 symbol::{DefId, IdentName},
16 tags::Construct,
17};
18
/// Turns resolved graph definitions, looked up through the wrapped
/// [`Context`], into Rust source tokens.
///
/// Apart from `cx`, every field is traversal state. `write_node` fills it in
/// and `write_run` reads it. None of it is reset between graphs.
pub struct Codegen {
    /// Resolver context; its methods are also reachable through `Deref`.
    cx: Context,
    /// Dotted field path (starting at `self`) through which each node is
    /// reached from its graph struct.
    nesteds: FxHashMap<DefId, FastStr>,
    /// Number of predecessors of each node that have not been scheduled yet.
    in_degrees: FxHashMap<DefId, u32>,
    /// Predecessors of each node, in the order the edges were recorded.
    froms: FxHashMap<DefId, Vec<DefId>>,
    /// Successors of each node, in the order the edges were recorded.
    tos: FxHashMap<DefId, Vec<DefId>>,
    /// Nodes whose struct has already been emitted.
    visited: FxHashSet<DefId>,
}
27
28impl Deref for Codegen {
29 type Target = Context;
30
31 fn deref(&self) -> &Self::Target {
32 &self.cx
33 }
34}
35
36impl Codegen {
37 pub fn new(cx: Context) -> Self {
38 Self {
39 cx,
40 nesteds: FxHashMap::default(),
41 in_degrees: FxHashMap::default(),
42 froms: FxHashMap::default(),
43 tos: FxHashMap::default(),
44 visited: FxHashSet::default(),
45 }
46 }
47
48 pub fn write_document(&mut self, def_ids: Vec<DefId>) -> TokenStream {
49 let mut stream = TokenStream::new();
50 self.write_trait(&mut stream);
51 for def_id in def_ids {
52 self.write_graph(def_id, &mut stream);
53 }
54 stream
55 }
56
57 pub fn write_graph(&mut self, def_id: DefId, stream: &mut TokenStream) {
58 let graph = self.graph(def_id).unwrap();
59 let graph_name = self.upper_camel_name(&graph.name).as_syn_ident();
60
61 let entry_node = self.node(graph.entry_node).unwrap();
62 let entry_node_name = self.snake_name(&entry_node.name).as_syn_ident();
63 let entry_node_ty = self.upper_camel_name(&entry_node.name).as_syn_ident();
64
65 stream.extend(quote::quote! {
66 pub struct #graph_name {
67 pub #entry_node_name: ::std::sync::Arc<#entry_node_ty>,
68 }
69 impl #graph_name {
70 pub fn new() -> Self {
71 Self {
72 #entry_node_name: ::std::sync::Arc::new(#entry_node_ty::new()),
73 }
74 }
75 }
76 });
77
78 self.write_node(graph.entry_node, &entry_node, stream, &"self".into());
79
80 self.write_run(graph, stream);
81 }
82
83 pub fn write_node(
84 &mut self,
85 def_id: DefId,
86 node: &Arc<Node>,
87 stream: &mut TokenStream,
88 nested: &FastStr,
89 ) {
90 if self.visited.contains(&def_id) {
91 return;
92 }
93 self.visited.insert(def_id);
94
95 let name = self.upper_camel_name(&node.name).as_syn_ident();
96 let mut nodes = TokenStream::new();
97 let mut nodes_impl = TokenStream::new();
98 for did in &node.to_nodes {
99 let node = self.node(*did).unwrap();
100 let name = self.snake_name(&node.name).as_syn_ident();
101 let ty = self.upper_camel_name(&node.name).as_syn_ident();
102 self.in_degrees
103 .entry(*did)
104 .and_modify(|e| *e += 1)
105 .or_insert(1);
106 self.tos.entry(def_id).or_default().push(*did);
107 self.froms.entry(*did).or_default().push(def_id);
108 nodes.extend(quote::quote! {
109 pub #name: ::std::sync::Arc<#ty>,
110 });
111 nodes_impl.extend(quote::quote! {
112 #name: ::std::sync::Arc::new(#ty::new()),
113 });
114 }
115
116 let mut fields = TokenStream::new();
117 let mut fields_impl = TokenStream::new();
118 for f in &node.fields {
119 let name = self.snake_name(&f.name).as_syn_ident();
120 let ty = f.ty.to_codegen_ty();
121 fields.extend(quote::quote! {
122 pub #name: #ty,
123 });
124
125 let tags = self.tag(f.tag_id).unwrap();
126 if let Some(c) = tags.get::<Construct>() {
127 let ident: Vec<_> = c.0.split("::").map(|s| format_ident!("{}", s)).collect();
128 if let Some(Editable(true)) = tags.get::<Editable>() {
129 fields_impl.extend(quote::quote! {
130 #name: ::static_graph::ArcSwap::from_pointee(#(#ident)::*()),
131 });
132 } else {
133 fields_impl.extend(quote::quote! {
134 #name: #(#ident)::*(),
135 });
136 }
137 } else {
138 fields_impl.extend(quote::quote! {
139 #name: ::std::default::Default::default(),
140 });
141 };
142 }
143
144 stream.extend(quote::quote! {
145 pub struct #name {
146 #nodes
147 #fields
148 }
149 impl #name {
150 pub fn new() -> Self {
151 Self {
152 #nodes_impl
153 #fields_impl
154 }
155 }
156 }
157 });
158
159 let nested: FastStr = format!("{}.{}", nested, self.snake_name(&node.name)).into();
160 self.nesteds.insert(def_id, nested.clone());
161 for did in &node.to_nodes {
162 self.write_node(*did, &self.node(*did).unwrap(), stream, &nested);
163 }
164 }
165
166 #[inline]
167 fn write_trait(&mut self, stream: &mut TokenStream) {
168 stream.extend(quote::quote! {
169 #[static_graph::async_trait]
170 pub trait Runnable<Req, PrevResp> {
171 type Resp;
172 type Error;
173 async fn run(&self, req: Req, prev_resp: PrevResp) -> ::std::result::Result<Self::Resp, Self::Error>;
174 }
175 });
176 }
177
178 fn write_run(&mut self, graph: Arc<Graph>, stream: &mut TokenStream) {
179 let name = self.upper_camel_name(&graph.name).as_syn_ident();
180 let mut queue = VecDeque::new();
181
182 assert!(self.in_degrees.get(&graph.entry_node).is_none());
183
184 queue.push_back(graph.entry_node);
185 let mut bounds = TokenStream::new();
186 let mut bodys = TokenStream::new();
187 let mut generics = Vec::new();
188 let mut out_resp = None;
189 while !queue.is_empty() {
190 let sz = queue.len();
191 for _ in 0..sz {
192 let mut channels = TokenStream::new();
193
194 let did = queue.pop_front().unwrap();
195 let node = self.node(did).unwrap();
196 let name = self.snake_name(&node.name).as_syn_ident();
197 let upper_name = self.upper_camel_name(&node.name).as_syn_ident();
198
199 let mut upper_prev_resps = Vec::new();
200 let mut resps = Vec::new();
201 if let Some(from_dids) = self.froms.get(&did) {
202 let mut rxs = Vec::with_capacity(from_dids.len());
203 let mut matches = Vec::with_capacity(from_dids.len());
204
205 for from_did in from_dids {
206 let node = self.node(*from_did).unwrap();
207
208 let f_name = self.snake_name(&node.name).as_syn_ident();
209 let upper_f_name = self.upper_camel_name(&node.name).as_syn_ident();
210 let upper_prev_resp = format_ident!("{}Resp", upper_f_name);
211
212 let resp = format_ident!("{}_resp", f_name);
213
214 resps.push(resp.clone());
215 rxs.push(format_ident!("{}_rx_{}", name, f_name));
216 matches.push(quote::quote! {
217 Ok(Ok(#resp))
218 });
219
220 upper_prev_resps.push(upper_prev_resp);
221 }
222
223 if !resps.is_empty() {
224 channels.extend(quote::quote! {
225 let (#(#resps),*) = match static_graph::join!(#(#rxs.recv()),*) {
226 (#(#matches,)*) => (#(#resps),*),
227 _ => panic!("Error"),
228 };
229 });
230 }
231 };
232
233 let upper_resp = format_ident!("{}Resp", upper_name);
234 generics.push(upper_resp.clone());
235 bounds.extend(quote::quote! {
236 #upper_name: Runnable<Req, (#(#upper_prev_resps),*), Resp = #upper_resp, Error = Error>,
237 #upper_resp: Clone + Send + Sync + 'static,
238 });
239
240 let req = format_ident!("{}_req", name);
241 let tx = format_ident!("{}_tx", name);
242 let node: Vec<_> = self
243 .nesteds
244 .get(&did)
245 .unwrap()
246 .split('.')
247 .map(|s| format_ident!("{}", s))
248 .collect();
249
250 if let Some(to_dids) = self.tos.get(&did) {
251 let mut rxs = Vec::with_capacity(to_dids.len());
252 let len = to_dids.len() + 1;
253 for to_did in to_dids {
254 if let Some(in_degree) = self.in_degrees.get_mut(to_did) {
255 *in_degree -= 1;
256 if *in_degree == 0 {
257 self.in_degrees.remove(to_did);
258 queue.push_back(*to_did);
259 }
260 }
261 let node = self.node(*to_did).unwrap();
262 let to_name = self.snake_name(&node.name).as_syn_ident();
263 rxs.push(format_ident!("{}_rx_{}", to_name, name));
264 }
265 bodys.extend(quote::quote! {
266 let #req = req.clone();
267 let #name = #(#node.)*clone();
268 let (#tx, _) = static_graph::sync::broadcast::channel(#len);
269 #(let mut #rxs = #tx.subscribe();)*
270 static_graph::spawn(async move {
271 #channels
272 let resp = #name.run(#req, (#(#resps),*)).await;
273 #tx.send(resp).ok();
274 });
275 });
276 } else {
277 assert!(out_resp.is_none());
278
279 out_resp.replace(upper_resp);
280 bodys.extend(quote::quote! {
281 #channels
282 #(#node).*.run(req, (#(#resps),*)).await
283 });
284 }
285 }
286 }
287
288 assert!(self.in_degrees.is_empty());
289
290 let out_resp = out_resp.unwrap();
291 stream.extend(quote::quote! {
292 impl #name {
293 pub async fn run<Req, #(#generics),*, Error>(&self, req: Req) -> ::std::result::Result<#out_resp, Error>
294 where
295 Req: Clone + Send + Sync + 'static,
296 Error: Clone + Send + Sync + 'static,
297 #bounds
298 {
299 #bodys
300 }
301 }
302 });
303 }
304}