pilota_build2/codegen/
mod.rs

1use std::{
2    io::Write,
3    ops::Deref,
4    path::{Path, PathBuf},
5    sync::Arc,
6};
7use std::collections::HashMap;
8
9use dashmap::DashMap;
10use faststr::FastStr;
11use itertools::Itertools;
12use quote::quote;
13use rayon::prelude::IntoParallelRefIterator;
14
15use pkg_tree::PkgNode;
16use traits::CodegenBackend;
17
18use crate::{
19    Context,
20    db::RirDatabase,
21    fmt::fmt_file,
22    middle::{
23        self,
24        adjust::Adjust,
25        context::{Mode, tls::CUR_ITEM},
26        rir::{self, Field},
27        ty::TyKind,
28    },
29    symbol::{DefId, EnumRepr}, Symbol,
30    tags::EnumMode,
31};
32
33use self::workspace::Workspace;
34
35pub(crate) mod pkg_tree;
36pub mod toml;
37pub(crate) mod traits;
38
39mod workspace;
40
41pub mod protobuf;
42pub mod thrift;
43
/// Rust source generator parameterised over a code generation backend `B`.
///
/// When `B` implements [`CodegenBackend`], this type dereferences to the
/// backend's [`Context`] (see the `Deref` impl in this module).
#[derive(Clone)]
pub struct Codegen<B> {
    // Backend providing the context (`cx()`) and the per-item `codegen_*` hooks.
    backend: B,
}
48
49impl<B> Deref for Codegen<B>
50    where
51        B: CodegenBackend,
52{
53    type Target = Context;
54
55    fn deref(&self) -> &Self::Target {
56        self.backend.cx()
57    }
58}
59
60impl<B> Codegen<B> {
61    pub fn new(backend: B) -> Self {
62        Codegen { backend }
63    }
64}
65
/// How an item is emitted by `Codegen::write_item`.
#[derive(Clone, Copy)]
pub enum CodegenKind {
    /// The item's definition is generated in place.
    Direct,
    /// Only a `pub use ::<item path>;` re-export is generated.
    RePub,
}
71
/// A definition scheduled for code generation together with the way it
/// should be emitted.
#[derive(Clone, Copy)]
pub struct CodegenItem {
    // Definition to generate.
    def_id: DefId,
    // Whether the definition is written out or merely re-exported.
    kind: CodegenKind,
}
77
78impl From<DefId> for CodegenItem {
79    fn from(value: DefId) -> Self {
80        CodegenItem {
81            def_id: value,
82            kind: CodegenKind::Direct,
83        }
84    }
85}
86
87pub fn is_raw_ptr_field(f: &Arc<Field>, adjust: Option<&Adjust>) -> bool {
88    f.is_optional() || adjust.map_or(false, |adjust| adjust.boxed())
89}
90
91
92impl<B> Codegen<B>
93    where
94        B: CodegenBackend + Send,
95{
96    fn check_scalar_ty(&self, def_id: DefId, record: &mut HashMap<DefId, bool>) -> bool {
97        if let Some(x) = record.get(&def_id) {
98            return x.clone();
99        }
100        let item = self.item(def_id).unwrap();
101        match &*item {
102            middle::rir::Item::Enum(e) => { false }
103            middle::rir::Item::NewType(t) => { false }
104            middle::rir::Item::Const(c) => { false }
105            middle::rir::Item::Mod(m) => { false }
106            middle::rir::Item::Message(s) => {
107                for field in &s.fields {
108                    if field.is_optional() || self.with_adjust(field.did, |adjust| {
109                        adjust.map_or(false, |adjust| adjust.boxed())
110                    }) {
111                        field.ty.in_stack.write().unwrap().insert(false);
112                        s.all_in_stack.write().unwrap().insert(false);
113                        continue;
114                    }
115                    if field.ty.setted_in_stack_field() {
116                        let is_in_stack = field.ty.is_in_stack();
117                        if !is_in_stack {
118                            s.all_in_stack.write().unwrap().insert(false);
119                        }
120                        continue;
121                    }
122                    match &field.ty.kind {
123                        TyKind::Path(path) => {
124                            let x = self.check_scalar_ty(path.did, record);
125                            field.ty.in_stack.write().unwrap().get_or_insert(x);
126                        }
127                        _ => {
128                            panic!("should not execute here")
129                        }
130                    }
131                }
132                let mut all_in_stack = s.all_in_stack.write().unwrap();
133                if all_in_stack.is_none() {
134                    all_in_stack.insert(true);
135                }
136                record.insert(def_id, all_in_stack.clone().unwrap());
137                return all_in_stack.unwrap();
138            }
139            middle::rir::Item::Service(s) => {
140                let mut is_scalar = true;
141                for method in &s.methods {
142                    for arg in &method.args {
143                        if arg.ty.setted_in_stack_field() {
144                            is_scalar = arg.ty.is_in_stack();
145                        } else {
146                            match &arg.ty.kind {
147                                TyKind::Path(path) => {
148                                    let x = self.check_scalar_ty(path.did, record);
149                                    arg.ty.in_stack.write().unwrap().get_or_insert(x);
150                                    if !x {
151                                        is_scalar = false;
152                                    }
153                                }
154                                _ => {
155                                    panic!("should not execute here")
156                                }
157                            }
158                        }
159                    }
160                    if method.ret.setted_in_stack_field() {
161                        is_scalar = method.ret.is_in_stack();
162                    } else {
163                        match &method.ret.kind {
164                            TyKind::Path(path) => {
165                                let x = self.check_scalar_ty(path.did, record);
166                                method.ret.in_stack.write().unwrap().get_or_insert(x);
167                                if !x {
168                                    is_scalar = false;
169                                }
170                            }
171                            _ => {
172                                panic!("should not execute here")
173                            }
174                        }
175                    }
176                }
177                record.insert(def_id, is_scalar);
178                is_scalar
179            }
180        }
181    }
182    pub fn write_struct(&self, def_id: DefId, stream: &mut String, s: &rir::Message) {
183        let name = self.rust_name(def_id);
184        let fields = s
185            .fields
186            .iter()
187            .map(|f| {
188                let name = self.rust_name(f.did);
189                self.with_adjust(f.did, |adjust| {
190                    let ty = self.codegen_item_ty(f.ty.kind.clone());
191                    let mut ty = format!("{ty}");
192
193                    if let Some(adjust) = adjust {
194                        if adjust.boxed() {
195                            ty = format!("::std::boxed::Box<{ty}>")
196                        }
197                    }
198
199                    if f.is_optional() {
200                        ty = format!("::std::option::Option<{ty}>")
201                    }
202
203                    let attrs = adjust.iter().flat_map(|a| a.attrs()).join("");
204
205                    format! {
206                        r#"{attrs}
207                        pub {name}: {ty},"#
208                    }
209                })
210            })
211            .join("\n");
212
213        let repr_c_attr = if s.is_all_in_stack() { "#[repr(C)]" } else { "" };
214
215        stream.push_str(&format! {
216            r#"{repr_c_attr}
217            #[derive(Clone, PartialEq)]
218            pub struct {name} {{
219                {fields}
220            }}"#
221        });
222
223        self.backend.codegen_struct_impl(def_id, stream, s);
224    }
225
226    pub fn write_item(&self, stream: &mut String, item: CodegenItem) {
227        CUR_ITEM.set(&item.def_id, || match item.kind {
228            CodegenKind::Direct => {
229                let def_id = item.def_id;
230                let item = self.item(def_id).unwrap();
231                tracing::trace!("write item {}", item.symbol_name());
232                self.with_adjust(def_id, |adjust| {
233                    let attrs = adjust.iter().flat_map(|a| a.attrs()).join("\n");
234
235                    let impls = adjust
236                        .iter()
237                        .flat_map(|a| &a.nested_items)
238                        .sorted()
239                        .join("\n");
240                    stream.push_str(&impls);
241                    stream.push_str(&attrs);
242                });
243
244                match &*item {
245                    middle::rir::Item::Message(s) => self.write_struct(def_id, stream, s),
246                    middle::rir::Item::Enum(e) => self.write_enum(def_id, stream, e),
247                    middle::rir::Item::Service(s) => self.write_service(def_id, stream, s),
248                    middle::rir::Item::NewType(t) => self.write_new_type(def_id, stream, t),
249                    middle::rir::Item::Const(c) => self.write_const(def_id, stream, c),
250                    middle::rir::Item::Mod(m) => {
251                        let mut inner = Default::default();
252                        m.items
253                            .iter()
254                            .for_each(|def_id| self.write_item(&mut inner, (*def_id).into()));
255
256                        stream.push_str(&inner);
257                        // let name = self.rust_name(def_id);
258                        // stream.push_str(&format! {
259                        //     r#"pub mod {name} {{
260                        //     {inner}
261                        // }}"#
262                        // })
263                    }
264                };
265            }
266            CodegenKind::RePub => {
267                let path = self.item_path(item.def_id).join("::");
268                stream.push_str(format!("pub use ::{};", path).as_str());
269            }
270        })
271    }
272
273    pub fn write_enum_as_new_type(
274        &self,
275        def_id: DefId,
276        stream: &mut String,
277        e: &middle::rir::Enum,
278    ) {
279        let name = self.rust_name(def_id);
280
281        let repr = match e.repr {
282            Some(EnumRepr::I32) => quote!(i32),
283            _ => panic!(),
284        };
285
286        let variants = e
287            .variants
288            .iter()
289            .map(|v| {
290                let name = self.rust_name(v.did);
291
292                let discr = v.discr.unwrap();
293                let discr = match e.repr {
294                    Some(EnumRepr::I32) => discr as i32,
295                    None => panic!(),
296                };
297                format!("pub const {name}: Self = Self({discr});")
298            })
299            .join("");
300
301        stream.push_str(&format! {
302            r#"#[derive(Clone, PartialEq, Copy)]
303            #[repr(transparent)]
304            pub struct {name}({repr});
305
306            impl {name} {{
307                {variants}
308
309                pub fn inner(&self) -> {repr} {{
310                    self.0
311                }}
312            }}
313
314            impl ::std::convert::From<{repr}> for {name} {{
315                fn from(value: {repr}) -> Self {{
316                    Self(value)
317                }}
318            }}"#
319        });
320
321        self.backend.codegen_enum_impl(def_id, stream, e);
322    }
323
324    pub fn write_enum(&self, def_id: DefId, stream: &mut String, e: &middle::rir::Enum) {
325        if self
326            .node_tags(def_id)
327            .unwrap()
328            .get::<EnumMode>()
329            .filter(|s| **s == EnumMode::NewType)
330            .is_some()
331        {
332            return self.write_enum_as_new_type(def_id, stream, e);
333        }
334        let name = self.rust_name(def_id);
335
336        let mut repr = if e.variants.is_empty() {
337            quote! {}
338        } else {
339            match e.repr {
340                Some(EnumRepr::I32) => quote! {
341                   #[repr(i32)]
342                },
343                None => quote! {},
344            }
345        };
346
347        if e.repr.is_some() {
348            repr.extend(quote! { #[derive(Copy)] })
349        }
350
351        let variants = e
352            .variants
353            .iter()
354            .map(|v| {
355                let name = self.rust_name(v.did);
356
357                self.with_adjust(v.did, |adjust| {
358                    let attrs = adjust.iter().flat_map(|a| a.attrs()).join("\n");
359
360                    let fields = v
361                        .fields
362                        .iter()
363                        .map(|ty| self.codegen_item_ty(ty.kind.clone()).to_string())
364                        .join(",");
365
366                    let fields_stream = if fields.is_empty() {
367                        Default::default()
368                    } else {
369                        format!("({fields})")
370                    };
371
372                    let discr = v
373                        .discr
374                        .map(|x| {
375                            let x = isize::try_from(x).unwrap();
376                            let x = match e.repr {
377                                Some(EnumRepr::I32) => x as i32,
378                                None => panic!(),
379                            };
380                            format!("={x}")
381                        })
382                        .unwrap_or_default();
383
384                    format!(
385                        r#"{attrs}
386                        {name} {fields_stream} {discr},"#
387                    )
388                })
389            })
390            .join("\n");
391
392        stream.push_str(&format! {
393            r#"
394            #[derive(Clone, PartialEq)]
395            {repr}
396            pub enum {name} {{
397                {variants}
398            }}
399            "#
400        });
401
402        self.backend.codegen_enum_impl(def_id, stream, e);
403    }
404
405    pub fn write_service(&self, def_id: DefId, stream: &mut String, s: &middle::rir::Service) {
406        let name = self.rust_name(def_id);
407        let methods = self.service_methods(def_id);
408
409        let methods = methods
410            .iter()
411            .map(|m| self.backend.codegen_service_method(def_id, m))
412            .filter(|code| !code.is_empty())
413            .join("\n");
414        if !methods.is_empty() {
415            stream.push_str(&format! {r#"
416            pub trait {name} {{
417                {methods}
418            }}
419            "#});
420        }
421        self.backend.codegen_service_impl(def_id, stream, s);
422    }
423
424    pub fn write_new_type(&self, def_id: DefId, stream: &mut String, t: &middle::rir::NewType) {
425        let name = self.rust_name(def_id);
426        let ty = self.codegen_item_ty(t.ty.kind.clone());
427        stream.push_str(&format! {
428            r#"
429            #[derive(Clone, PartialEq)]
430            pub struct {name}(pub {ty});
431
432            impl ::std::ops::Deref for {name} {{
433                type Target = {ty};
434
435                fn deref(&self) -> &Self::Target {{
436                    &self.0
437                }}
438            }}
439
440            impl From<{ty}> for {name} {{
441                fn from(v: {ty}) -> Self {{
442                    Self(v)
443                }}
444            }}
445            "#
446        });
447        self.backend.codegen_newtype_impl(def_id, stream, t);
448    }
449
450    pub fn write_const(&self, did: DefId, stream: &mut String, c: &middle::rir::Const) {
451        let mut ty = self.codegen_ty(did);
452
453        let name = self.rust_name(did);
454
455        stream.push_str(&self.def_lit(&name, &c.lit, &mut ty).unwrap())
456    }
457
458    pub fn write_workspace(self, base_dir: PathBuf) -> anyhow::Result<()> {
459        let ws = Workspace::new(base_dir, self);
460        ws.write_crates()
461    }
462
463    pub fn write_items<'a>(&self, stream: &mut String, items: impl Iterator<Item=CodegenItem>)
464        where
465            B: Send,
466    {
467        use rayon::iter::ParallelIterator;
468
469        let mods = items.into_group_map_by(|CodegenItem { def_id, .. }| {
470            let path = Arc::from_iter(self.mod_path(*def_id).iter().map(|s| s.0.clone()));
471            tracing::debug!("ths path of {:?} is {:?}", def_id, path);
472            match &*self.mode {
473                Mode::Workspace(_) => Arc::from(&path[1..]), /* the first element for
474                                                                * workspace */
475                // path is crate name
476                Mode::SingleFile { .. } => path,
477            }
478        });
479
480        let mut pkgs: DashMap<Arc<[FastStr]>, String> = Default::default();
481
482        let this = self.clone();
483
484        mods.par_iter().for_each_with(this, |this, (p, def_ids)| {
485            let mut stream = pkgs.entry(p.clone()).or_default();
486
487            let span = tracing::span!(tracing::Level::TRACE, "write_mod", path = ?p);
488
489            let _enter = span.enter();
490            def_ids.iter().for_each(|def_id| {
491                match def_id.kind {
492                    CodegenKind::Direct => {
493                        this.check_scalar_ty(def_id.def_id, &mut HashMap::new());
494                    }
495                    _ => {}
496                }
497            });
498            for def_id in def_ids.iter() {
499                this.write_item(&mut stream, *def_id)
500            }
501        });
502
503        fn write_stream(
504            pkgs: &mut DashMap<Arc<[FastStr]>, String>,
505            stream: &mut String,
506            nodes: &[PkgNode],
507        ) {
508            for node in nodes {
509                let mut inner_stream = String::default();
510                if let Some((_, node_stream)) = pkgs.remove(&node.path) {
511                    inner_stream.push_str(&node_stream);
512                }
513
514                write_stream(pkgs, &mut inner_stream, &node.children);
515                let name = node.ident();
516                if name.clone().unwrap_or_default() == "" {
517                    stream.push_str(&inner_stream);
518                    return;
519                }
520
521                stream.push_str(&inner_stream);
522                // let name = Symbol::from(name.unwrap());
523                // stream.push_str(&format! {
524                //     r#"
525                //     pub mod {name} {{
526                //         {inner_stream}
527                //     }}
528                //     "#
529                // });
530            }
531        }
532
533        let keys = pkgs.iter().map(|kv| kv.key().clone()).collect_vec();
534        let pkg_node = PkgNode::from_pkgs(&keys.iter().map(|s| &**s).collect_vec());
535        tracing::debug!(?pkg_node);
536
537        write_stream(&mut pkgs, stream, &pkg_node);
538    }
539
540    pub fn write_file(self, ns_name: Symbol, file_name: impl AsRef<Path>) {
541        let mut stream = String::default();
542        self.write_items(
543            &mut stream,
544            self.codegen_items.iter().map(|def_id| (*def_id).into()),
545        );
546
547        let doc = self.doc_header.as_str();
548        stream = format! {r#"{doc}
549        #![allow(warnings, clippy::all)]
550                {stream}
551        "#};
552
553        // stream = format! {r#"pub mod {ns_name} {{
554        //         #![allow(warnings, clippy::all)]
555        //         {stream}
556        //     }}"#};
557
558        let mut file = std::io::BufWriter::new(std::fs::File::create(&file_name).unwrap());
559        file.write_all(stream.to_string().as_bytes()).unwrap();
560        file.flush().unwrap();
561        fmt_file(file_name)
562    }
563
564    pub fn gen(self) -> anyhow::Result<()> {
565        match &*self.mode.clone() {
566            Mode::Workspace(info) => self.write_workspace(info.dir.clone()),
567            Mode::SingleFile { file_path: p } => {
568                self.write_file(
569                    FastStr::new(
570                        p.file_name()
571                            .and_then(|s| s.to_str())
572                            .and_then(|s| s.split('.').next())
573                            .unwrap(),
574                    )
575                        .into(),
576                    p,
577                );
578                Ok(())
579            }
580        }
581    }
582}