pilota_build/
lib.rs

1#![doc(
2    html_logo_url = "https://github.com/cloudwego/pilota/raw/main/.github/assets/logo.png?sanitize=true"
3)]
4#![cfg_attr(not(doctest), doc = include_str!("../README.md"))]
5#![allow(clippy::mutable_key_type)]
6
7mod util;
8
9pub mod codegen;
10pub mod db;
11pub(crate) mod errors;
12pub mod fmt;
13mod index;
14mod ir;
15pub mod middle;
16pub mod parser;
17mod resolve;
18mod symbol;
19
20use faststr::FastStr;
21pub use symbol::Symbol;
22pub mod tags;
23use std::{path::PathBuf, sync::Arc};
24
25mod dedup;
26pub mod plugin;
27
28pub use codegen::{
29    Codegen, protobuf::ProtobufBackend, thrift::ThriftBackend, traits::CodegenBackend,
30};
31use db::{RirDatabase, RootDatabase};
32use middle::{
33    context::{CollectMode, ContextBuilder, Mode, WorkspaceInfo, tls::CONTEXT},
34    rir::NodeKind,
35    type_graph::TypeGraph,
36    workspace_graph::WorkspaceGraph,
37};
38pub use middle::{
39    context::{Context, SourceType},
40    rir, ty,
41};
42use parser::{ParseResult, Parser, protobuf::ProtobufParser, thrift::ThriftParser};
43use plugin::{AutoDerivePlugin, BoxedPlugin, ImplDefaultPlugin, PredicateResult, WithAttrsPlugin};
44pub use plugin::{BoxClonePlugin, ClonePlugin, Plugin};
45use resolve::{ResolveResult, Resolver};
46pub use symbol::{DefId, IdentName};
47pub use tags::TagId;
48
49pub trait MakeBackend: Sized {
50    type Target: CodegenBackend;
51    fn make_backend(self, context: Context) -> Self::Target;
52}
53
54pub struct MkThriftBackend;
55
56impl MakeBackend for MkThriftBackend {
57    type Target = ThriftBackend;
58
59    fn make_backend(self, context: Context) -> Self::Target {
60        ThriftBackend::new(context)
61    }
62}
63
64pub struct MkProtobufBackend;
65
66impl MakeBackend for MkProtobufBackend {
67    type Target = ProtobufBackend;
68
69    fn make_backend(self, context: Context) -> Self::Target {
70        ProtobufBackend::new(context)
71    }
72}
73
74pub struct MkPbBackend;
75
76impl MakeBackend for MkPbBackend {
77    type Target = codegen::pb::ProtobufBackend;
78
79    fn make_backend(self, context: Context) -> Self::Target {
80        codegen::pb::ProtobufBackend::new(context)
81    }
82}
83
84pub struct Builder<MkB, P> {
85    source_type: SourceType,
86    mk_backend: MkB,
87    parser: P,
88    plugins: Vec<Box<dyn Plugin>>,
89    ignore_unused: bool,
90    split: bool,
91    touches: Vec<(std::path::PathBuf, Vec<String>)>,
92    change_case: bool,
93    keep_unknown_fields: Vec<std::path::PathBuf>,
94    dedups: Vec<FastStr>,
95    special_namings: Vec<FastStr>,
96    common_crate_name: FastStr,
97    with_descriptor: bool,
98    with_field_mask: bool,
99}
100
101impl Builder<MkThriftBackend, ThriftParser> {
102    pub fn thrift() -> Self {
103        Builder {
104            source_type: SourceType::Thrift,
105            mk_backend: MkThriftBackend,
106            parser: ThriftParser::default(),
107            plugins: vec![
108                Box::new(WithAttrsPlugin(Arc::from(["#[derive(Debug)]".into()]))),
109                Box::new(ImplDefaultPlugin),
110            ],
111            touches: Vec::default(),
112            ignore_unused: true,
113            change_case: true,
114            keep_unknown_fields: Vec::default(),
115            dedups: Vec::default(),
116            special_namings: Vec::default(),
117            common_crate_name: "common".into(),
118            split: false,
119            with_descriptor: false,
120            with_field_mask: false,
121        }
122    }
123}
124
125impl Builder<MkProtobufBackend, ProtobufParser> {
126    pub fn protobuf() -> Self {
127        Builder {
128            source_type: SourceType::Protobuf,
129            mk_backend: MkProtobufBackend,
130            parser: ProtobufParser::default(),
131            plugins: vec![
132                Box::new(WithAttrsPlugin(Arc::from(["#[derive(Debug)]".into()]))),
133                Box::new(ImplDefaultPlugin),
134            ],
135            touches: Vec::default(),
136            ignore_unused: true,
137            change_case: true,
138            keep_unknown_fields: Vec::default(),
139            dedups: Vec::default(),
140            special_namings: Vec::default(),
141            common_crate_name: "common".into(),
142            split: false,
143            with_descriptor: false,
144            with_field_mask: false,
145        }
146    }
147}
148
149impl Builder<MkPbBackend, ProtobufParser> {
150    pub fn pb() -> Self {
151        Builder {
152            source_type: SourceType::Protobuf,
153            mk_backend: MkPbBackend,
154            parser: ProtobufParser::default(),
155            plugins: vec![
156                Box::new(WithAttrsPlugin(Arc::from(["#[derive(Debug)]".into()]))),
157                Box::new(ImplDefaultPlugin),
158            ],
159            touches: Vec::default(),
160            ignore_unused: true,
161            change_case: true,
162            keep_unknown_fields: Vec::default(),
163            dedups: Vec::default(),
164            special_namings: Vec::default(),
165            common_crate_name: "common".into(),
166            split: false,
167            with_descriptor: false,
168            with_field_mask: false,
169        }
170    }
171}
172
173impl<MkB, P> Builder<MkB, P>
174where
175    P: Parser,
176{
177    pub fn include_dirs(mut self, include_dirs: Vec<PathBuf>) -> Self {
178        self.parser.include_dirs(include_dirs);
179        self
180    }
181}
182
183impl<MkB, P> Builder<MkB, P> {
184    pub fn with_backend<B: MakeBackend>(self, mk_backend: B) -> Builder<B, P> {
185        Builder {
186            source_type: self.source_type,
187            mk_backend,
188            parser: self.parser,
189            plugins: self.plugins,
190            ignore_unused: self.ignore_unused,
191            touches: self.touches,
192            change_case: self.change_case,
193            keep_unknown_fields: self.keep_unknown_fields,
194            dedups: self.dedups,
195            special_namings: self.special_namings,
196            common_crate_name: self.common_crate_name,
197            split: self.split,
198            with_descriptor: self.with_descriptor,
199            with_field_mask: self.with_field_mask,
200        }
201    }
202
203    pub fn plugin<Plu: Plugin + 'static>(mut self, p: Plu) -> Self {
204        self.plugins.push(Box::new(p));
205
206        self
207    }
208
209    pub fn split_generated_files(mut self, split: bool) -> Self {
210        self.split = split;
211        self
212    }
213
214    pub fn change_case(mut self, change_case: bool) -> Self {
215        self.change_case = change_case;
216        self
217    }
218
219    /**
220     * Don't generate items which are unused by the main service
221     */
222    pub fn ignore_unused(mut self, flag: bool) -> Self {
223        self.ignore_unused = flag;
224        self
225    }
226
227    /**
228     * Generate items even them are not used.
229     *
230     * This is ignored if `ignore_unused` is false
231     */
232    pub fn touch(
233        mut self,
234        item: impl IntoIterator<Item = (PathBuf, Vec<impl Into<String>>)>,
235    ) -> Self {
236        self.touches.extend(
237            item.into_iter()
238                .map(|s| (s.0, s.1.into_iter().map(|s| s.into()).collect())),
239        );
240        self
241    }
242
243    pub fn keep_unknown_fields(mut self, item: impl IntoIterator<Item = PathBuf>) -> Self {
244        self.keep_unknown_fields.extend(item);
245        self
246    }
247
248    pub fn dedup(mut self, item: impl IntoIterator<Item = FastStr>) -> Self {
249        self.dedups.extend(item);
250        self
251    }
252
253    pub fn special_namings(mut self, item: impl IntoIterator<Item = FastStr>) -> Self {
254        self.special_namings.extend(item);
255        self
256    }
257
258    pub fn common_crate_name(mut self, name: FastStr) -> Self {
259        self.common_crate_name = name;
260        self
261    }
262
263    pub fn with_descriptor(mut self, on: bool) -> Self {
264        self.with_descriptor = on;
265        self
266    }
267
268    pub fn with_field_mask(mut self, on: bool) -> Self {
269        self.with_field_mask = on;
270        self
271    }
272}
273
/// Destination of the generated code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Output {
    /// Generate in workspace mode, rooted at this directory.
    Workspace(PathBuf),
    /// Generate in single-file mode, writing to this file path.
    File(PathBuf),
}
278
279#[derive(serde::Deserialize, serde::Serialize)]
280pub struct IdlService {
281    pub path: PathBuf,
282    pub config: serde_yaml::Value,
283}
284
285impl IdlService {
286    pub fn from_path(p: PathBuf) -> Self {
287        IdlService {
288            path: p,
289            config: Default::default(),
290        }
291    }
292}
293
294impl<MkB, P> Builder<MkB, P>
295where
296    MkB: MakeBackend + Send,
297    MkB::Target: Send,
298    P: Parser,
299{
300    pub fn compile(
301        self,
302        services: impl IntoIterator<Item = impl AsRef<std::path::Path>>,
303        out: Output,
304    ) {
305        let services = services
306            .into_iter()
307            .map(|path| IdlService {
308                config: serde_yaml::Value::default(),
309                path: path.as_ref().to_owned(),
310            })
311            .collect();
312
313        self.compile_with_config(services, out)
314    }
315
316    #[allow(clippy::too_many_arguments)]
317    pub fn build_cx(
318        services: Vec<IdlService>,
319        out: Option<Output>,
320        mut parser: P,
321        touches: Vec<(PathBuf, Vec<String>)>,
322        ignore_unused: bool,
323        source_type: SourceType,
324        change_case: bool,
325        keep_unknown_fields: Vec<PathBuf>,
326        dedups: Vec<FastStr>,
327        special_namings: Vec<FastStr>,
328        common_crate_name: FastStr,
329        split: bool,
330        with_descriptor: bool,
331        with_field_mask: bool,
332    ) -> Context {
333        parser.inputs(services.iter().map(|s| &s.path));
334        let ParseResult {
335            files,
336            input_files,
337            file_ids_map,
338        } = parser.parse();
339
340        let ResolveResult {
341            files,
342            nodes,
343            tags,
344            args,
345        } = Resolver::default().resolve_files(&files);
346
347        let items = nodes.iter().filter_map(|(k, v)| match &v.kind {
348            NodeKind::Item(item) => Some((*k, item.clone())),
349            _ => None,
350        });
351
352        let type_graph = TypeGraph::from_items(items.clone());
353        let workspace_graph = WorkspaceGraph::from_items(items);
354
355        // Build the database using the builder pattern
356        let db = RootDatabase::default()
357            .with_file_ids_map(file_ids_map)
358            .with_files(files.into_iter())
359            .with_nodes(nodes)
360            .with_tags(tags, type_graph)
361            .with_args(args)
362            .with_workspace_graph(workspace_graph)
363            .with_input_files(input_files.clone());
364
365        let mut input = Vec::with_capacity(input_files.len());
366        for file_id in &input_files {
367            let file = db.file(*file_id).unwrap();
368            file.items.iter().for_each(|def_id| {
369                // Check if the node is an Item before calling item()
370                if let Some(node) = db.node(*def_id) {
371                    if let NodeKind::Item(item) = &node.kind {
372                        if matches!(&**item, rir::Item::Service(_)) {
373                            input.push(*def_id)
374                        }
375                    }
376                }
377            });
378        }
379
380        let mut cx = ContextBuilder::new(
381            db,
382            match out {
383                Some(Output::Workspace(dir)) => Mode::Workspace(WorkspaceInfo {
384                    dir,
385                    location_map: Default::default(),
386                }),
387                Some(Output::File(p)) => Mode::SingleFile { file_path: p },
388                None => Mode::SingleFile {
389                    file_path: Default::default(),
390                },
391            },
392            input,
393        );
394
395        cx.collect(if ignore_unused {
396            CollectMode::OnlyUsed { touches }
397        } else {
398            CollectMode::All
399        });
400
401        cx.keep(keep_unknown_fields);
402
403        cx.build(
404            Arc::from(services),
405            source_type,
406            change_case,
407            dedups,
408            special_namings,
409            common_crate_name,
410            split,
411            with_descriptor,
412            with_field_mask,
413        )
414    }
415
416    pub fn compile_with_config(self, services: Vec<IdlService>, out: Output) {
417        let _ = tracing_subscriber::fmt::try_init();
418
419        let cx = Self::build_cx(
420            services,
421            Some(out),
422            self.parser,
423            self.touches,
424            self.ignore_unused,
425            self.source_type,
426            self.change_case,
427            self.keep_unknown_fields,
428            self.dedups,
429            self.special_namings,
430            self.common_crate_name,
431            self.split,
432            self.with_descriptor,
433            self.with_field_mask,
434        );
435
436        cx.exec_plugin(BoxedPlugin);
437
438        cx.exec_plugin(AutoDerivePlugin::new(
439            Arc::from(["#[derive(PartialOrd)]".into()]),
440            |ty| {
441                let mut ty = ty;
442                while let ty::Vec(_ty) = &ty.kind {
443                    ty = _ty;
444                }
445                if matches!(ty.kind, ty::Map(_, _) | ty::Set(_)) {
446                    PredicateResult::No
447                } else {
448                    PredicateResult::GoOn
449                }
450            },
451        ));
452
453        cx.exec_plugin(AutoDerivePlugin::new(
454            Arc::from(["#[derive(Hash, Eq, Ord)]".into()]),
455            |ty| {
456                let mut ty = ty;
457                while let ty::Vec(_ty) = &ty.kind {
458                    ty = _ty;
459                }
460                if matches!(ty.kind, ty::Map(_, _) | ty::Set(_) | ty::F64 | ty::F32) {
461                    PredicateResult::No
462                } else {
463                    PredicateResult::GoOn
464                }
465            },
466        ));
467
468        self.plugins.into_iter().for_each(|p| cx.exec_plugin(p));
469
470        std::thread::scope(|scope| {
471            let pool = rayon::ThreadPoolBuilder::new();
472            let pool = pool
473                .spawn_handler(|thread| {
474                    let mut builder = std::thread::Builder::new();
475                    if let Some(name) = thread.name() {
476                        builder = builder.name(name.to_string());
477                    }
478                    if let Some(size) = thread.stack_size() {
479                        builder = builder.stack_size(size);
480                    }
481
482                    let cx = cx.clone();
483                    builder.spawn_scoped(scope, move || {
484                        CONTEXT.set(&cx, || thread.run());
485                    })?;
486                    Ok(())
487                })
488                .build()?;
489
490            pool.install(move || {
491                let cg = Codegen::new(self.mk_backend.make_backend(cx));
492                cg.r#gen().unwrap();
493            });
494
495            Ok::<_, rayon::ThreadPoolBuildError>(())
496        })
497        .unwrap();
498    }
499
500    // gen service_global_name and methods for certain service in IdlService
501    pub fn init_service(self, service: IdlService) -> anyhow::Result<(String, String)> {
502        let _ = tracing_subscriber::fmt::try_init();
503        let path = service.path.clone();
504        let cx = Self::build_cx(
505            vec![service],
506            None,
507            self.parser,
508            self.touches,
509            self.ignore_unused,
510            self.source_type,
511            self.change_case,
512            self.keep_unknown_fields,
513            self.dedups,
514            self.special_namings,
515            self.common_crate_name,
516            self.split,
517            self.with_descriptor,
518            self.with_field_mask,
519        );
520
521        std::thread::scope(|_scope| {
522            CONTEXT.set(&cx.clone(), move || {
523                Codegen::new(self.mk_backend.make_backend(cx)).pick_init_service(path)
524            })
525        })
526    }
527}
528
529mod test;