pilota_build/
lib.rs

1#![doc(
2    html_logo_url = "https://github.com/cloudwego/pilota/raw/main/.github/assets/logo.png?sanitize=true"
3)]
4#![cfg_attr(not(doctest), doc = include_str!("../README.md"))]
5#![allow(clippy::mutable_key_type)]
6
7mod util;
8
9pub mod codegen;
10pub mod db;
11pub(crate) mod errors;
12pub mod fmt;
13mod index;
14mod ir;
15pub mod middle;
16pub mod parser;
17mod resolve;
18mod symbol;
19
20use faststr::FastStr;
21pub use symbol::Symbol;
22use tempfile::tempdir;
23pub mod tags;
24use std::{path::PathBuf, sync::Arc};
25
26mod dedup;
27pub mod plugin;
28
29pub use codegen::{
30    Codegen, protobuf::ProtobufBackend, thrift::ThriftBackend, traits::CodegenBackend,
31};
32use db::{RirDatabase, RootDatabase};
33use middle::{
34    context::{CollectMode, ContextBuilder, Mode, WorkspaceInfo, tls::CONTEXT},
35    rir::NodeKind,
36    type_graph::TypeGraph,
37    workspace_graph::WorkspaceGraph,
38};
39pub use middle::{
40    context::{Context, SourceType},
41    rir, ty,
42};
43use parser::{ParseResult, Parser, protobuf::ProtobufParser, thrift::ThriftParser};
44use plugin::{AutoDerivePlugin, BoxedPlugin, ImplDefaultPlugin, PredicateResult, WithAttrsPlugin};
45pub use plugin::{BoxClonePlugin, ClonePlugin, Plugin};
46use resolve::{ResolveResult, Resolver};
47pub use symbol::{DefId, IdentName};
48pub use tags::TagId;
49
50pub trait MakeBackend: Sized {
51    type Target: CodegenBackend;
52    fn make_backend(self, context: Context) -> Self::Target;
53}
54
55pub struct MkThriftBackend;
56
57impl MakeBackend for MkThriftBackend {
58    type Target = ThriftBackend;
59
60    fn make_backend(self, context: Context) -> Self::Target {
61        ThriftBackend::new(context)
62    }
63}
64
65pub struct MkProtobufBackend;
66
67impl MakeBackend for MkProtobufBackend {
68    type Target = ProtobufBackend;
69
70    fn make_backend(self, context: Context) -> Self::Target {
71        ProtobufBackend::new(context)
72    }
73}
74
75pub struct MkPbBackend;
76
77impl MakeBackend for MkPbBackend {
78    type Target = codegen::pb::ProtobufBackend;
79
80    fn make_backend(self, context: Context) -> Self::Target {
81        codegen::pb::ProtobufBackend::new(context)
82    }
83}
84
85pub struct Builder<MkB, P> {
86    source_type: SourceType,
87    mk_backend: MkB,
88    parser: P,
89    plugins: Vec<Box<dyn Plugin>>,
90    ignore_unused: bool,
91    split: bool,
92    touches: Vec<(std::path::PathBuf, Vec<String>)>,
93    change_case: bool,
94    keep_unknown_fields: Vec<std::path::PathBuf>,
95    dedups: Vec<FastStr>,
96    special_namings: Vec<FastStr>,
97    common_crate_name: FastStr,
98    with_descriptor: bool,
99    with_field_mask: bool,
100    temp_dir: Option<tempfile::TempDir>,
101}
102
103impl Builder<MkThriftBackend, ThriftParser> {
104    pub fn thrift() -> Self {
105        Builder {
106            source_type: SourceType::Thrift,
107            mk_backend: MkThriftBackend,
108            parser: ThriftParser::default(),
109            plugins: vec![
110                Box::new(WithAttrsPlugin(Arc::from(["#[derive(Debug)]".into()]))),
111                Box::new(ImplDefaultPlugin),
112            ],
113            touches: Vec::default(),
114            ignore_unused: true,
115            change_case: true,
116            keep_unknown_fields: Vec::default(),
117            dedups: Vec::default(),
118            special_namings: Vec::default(),
119            common_crate_name: "common".into(),
120            split: false,
121            with_descriptor: false,
122            with_field_mask: false,
123            temp_dir: None,
124        }
125    }
126}
127
128impl Builder<MkProtobufBackend, ProtobufParser> {
129    pub fn protobuf() -> Self {
130        Builder {
131            source_type: SourceType::Protobuf,
132            mk_backend: MkProtobufBackend,
133            parser: ProtobufParser::default(),
134            plugins: vec![
135                Box::new(WithAttrsPlugin(Arc::from(["#[derive(Debug)]".into()]))),
136                Box::new(ImplDefaultPlugin),
137            ],
138            touches: Vec::default(),
139            ignore_unused: true,
140            change_case: true,
141            keep_unknown_fields: Vec::default(),
142            dedups: Vec::default(),
143            special_namings: Vec::default(),
144            common_crate_name: "common".into(),
145            split: false,
146            with_descriptor: false,
147            with_field_mask: false,
148            temp_dir: None,
149        }
150    }
151}
152
153impl Builder<MkPbBackend, ProtobufParser> {
154    pub fn pb() -> Self {
155        let (out_dir, temp_dir) = match std::env::var("OUT_DIR") {
156            Ok(out_dir_str) => (PathBuf::from(out_dir_str), None),
157            _ => {
158                let temp_dir = tempdir().unwrap();
159                (temp_dir.path().to_path_buf(), Some(temp_dir))
160            }
161        };
162        let include_dir = out_dir.join("pilota_proto");
163
164        std::fs::create_dir_all(&include_dir).expect("Failed to create pilota_proto directory");
165
166        let pilota_proto_src = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("proto/pilota.proto");
167
168        std::fs::copy(&pilota_proto_src, include_dir.join("pilota.proto"))
169            .expect("Failed to copy pilota.proto");
170
171        let mut parser = ProtobufParser::default();
172        parser.include_dirs(vec![include_dir]);
173
174        Builder {
175            source_type: SourceType::Protobuf,
176            mk_backend: MkPbBackend,
177            parser,
178            plugins: vec![
179                Box::new(WithAttrsPlugin(Arc::from(["#[derive(Debug)]".into()]))),
180                Box::new(ImplDefaultPlugin),
181            ],
182            touches: Vec::default(),
183            ignore_unused: true,
184            change_case: true,
185            keep_unknown_fields: Vec::default(),
186            dedups: Vec::default(),
187            special_namings: Vec::default(),
188            common_crate_name: "common".into(),
189            split: false,
190            with_descriptor: false,
191            with_field_mask: false,
192            temp_dir,
193        }
194    }
195}
196
197impl<MkB, P> Builder<MkB, P>
198where
199    P: Parser,
200{
201    pub fn include_dirs(mut self, include_dirs: Vec<PathBuf>) -> Self {
202        self.parser.include_dirs(include_dirs);
203        self
204    }
205}
206
207impl<MkB, P> Builder<MkB, P> {
208    pub fn with_backend<B: MakeBackend>(self, mk_backend: B) -> Builder<B, P> {
209        Builder {
210            source_type: self.source_type,
211            mk_backend,
212            parser: self.parser,
213            plugins: self.plugins,
214            ignore_unused: self.ignore_unused,
215            touches: self.touches,
216            change_case: self.change_case,
217            keep_unknown_fields: self.keep_unknown_fields,
218            dedups: self.dedups,
219            special_namings: self.special_namings,
220            common_crate_name: self.common_crate_name,
221            split: self.split,
222            with_descriptor: self.with_descriptor,
223            with_field_mask: self.with_field_mask,
224            temp_dir: self.temp_dir,
225        }
226    }
227
228    pub fn plugin<Plu: Plugin + 'static>(mut self, p: Plu) -> Self {
229        self.plugins.push(Box::new(p));
230
231        self
232    }
233
234    pub fn split_generated_files(mut self, split: bool) -> Self {
235        self.split = split;
236        self
237    }
238
239    pub fn change_case(mut self, change_case: bool) -> Self {
240        self.change_case = change_case;
241        self
242    }
243
244    /**
245     * Don't generate items which are unused by the main service
246     */
247    pub fn ignore_unused(mut self, flag: bool) -> Self {
248        self.ignore_unused = flag;
249        self
250    }
251
252    /**
253     * Generate items even them are not used.
254     *
255     * This is ignored if `ignore_unused` is false
256     */
257    pub fn touch(
258        mut self,
259        item: impl IntoIterator<Item = (PathBuf, Vec<impl Into<String>>)>,
260    ) -> Self {
261        self.touches.extend(
262            item.into_iter()
263                .map(|s| (s.0, s.1.into_iter().map(|s| s.into()).collect())),
264        );
265        self
266    }
267
268    pub fn keep_unknown_fields(mut self, item: impl IntoIterator<Item = PathBuf>) -> Self {
269        self.keep_unknown_fields.extend(item);
270        self
271    }
272
273    pub fn dedup(mut self, item: impl IntoIterator<Item = FastStr>) -> Self {
274        self.dedups.extend(item);
275        self
276    }
277
278    pub fn special_namings(mut self, item: impl IntoIterator<Item = FastStr>) -> Self {
279        self.special_namings.extend(item);
280        self
281    }
282
283    pub fn common_crate_name(mut self, name: FastStr) -> Self {
284        self.common_crate_name = name;
285        self
286    }
287
288    pub fn with_descriptor(mut self, on: bool) -> Self {
289        self.with_descriptor = on;
290        self
291    }
292
293    pub fn with_field_mask(mut self, on: bool) -> Self {
294        self.with_field_mask = on;
295        self
296    }
297}
298
/// Destination for generated code.
#[derive(Debug, Clone)]
pub enum Output {
    /// Generate a workspace rooted at the given directory.
    Workspace(PathBuf),
    /// Generate a single file at the given path.
    File(PathBuf),
}
303
304#[derive(serde::Deserialize, serde::Serialize)]
305pub struct IdlService {
306    pub path: PathBuf,
307    pub config: serde_yaml::Value,
308}
309
310impl IdlService {
311    pub fn from_path(p: PathBuf) -> Self {
312        IdlService {
313            path: p,
314            config: Default::default(),
315        }
316    }
317}
318
319impl<MkB, P> Builder<MkB, P>
320where
321    MkB: MakeBackend + Send,
322    MkB::Target: Send,
323    P: Parser,
324{
325    pub fn compile(
326        self,
327        services: impl IntoIterator<Item = impl AsRef<std::path::Path>>,
328        out: Output,
329    ) {
330        let services = services
331            .into_iter()
332            .map(|path| IdlService {
333                config: serde_yaml::Value::default(),
334                path: path.as_ref().to_owned(),
335            })
336            .collect();
337
338        self.compile_with_config(services, out)
339    }
340
341    #[allow(clippy::too_many_arguments)]
342    pub fn build_cx(
343        services: Vec<IdlService>,
344        out: Option<Output>,
345        mut parser: P,
346        touches: Vec<(PathBuf, Vec<String>)>,
347        ignore_unused: bool,
348        source_type: SourceType,
349        change_case: bool,
350        keep_unknown_fields: Vec<PathBuf>,
351        dedups: Vec<FastStr>,
352        special_namings: Vec<FastStr>,
353        common_crate_name: FastStr,
354        split: bool,
355        with_descriptor: bool,
356        with_field_mask: bool,
357    ) -> Context {
358        parser.inputs(services.iter().map(|s| &s.path));
359        let ParseResult {
360            files,
361            input_files,
362            file_ids_map,
363            file_paths,
364        } = parser.parse();
365
366        let ResolveResult {
367            files,
368            nodes,
369            tags,
370            args,
371        } = Resolver::default().resolve_files(&files);
372
373        let items = nodes.iter().filter_map(|(k, v)| match &v.kind {
374            NodeKind::Item(item) => Some((*k, item.clone())),
375            _ => None,
376        });
377
378        let type_graph = TypeGraph::from_items(items.clone());
379        let workspace_graph = WorkspaceGraph::from_items(items);
380
381        // Build the database using the builder pattern
382        let db = RootDatabase::default()
383            .with_file_ids_map(file_ids_map)
384            .with_file_paths(file_paths)
385            .with_files(files.into_iter())
386            .with_nodes(nodes)
387            .with_tags(tags, type_graph)
388            .with_args(args)
389            .with_workspace_graph(workspace_graph)
390            .with_input_files(input_files.clone());
391
392        let mut input = Vec::with_capacity(input_files.len());
393        for file_id in &input_files {
394            let file = db.file(*file_id).unwrap();
395            file.items.iter().for_each(|def_id| {
396                // Check if the node is an Item before calling item()
397                if let Some(node) = db.node(*def_id) {
398                    if let NodeKind::Item(item) = &node.kind {
399                        if matches!(&**item, rir::Item::Service(_)) {
400                            input.push(*def_id)
401                        }
402                    }
403                }
404            });
405        }
406
407        let mut cx = ContextBuilder::new(
408            db,
409            match out {
410                Some(Output::Workspace(dir)) => Mode::Workspace(WorkspaceInfo {
411                    dir,
412                    location_map: Default::default(),
413                }),
414                Some(Output::File(p)) => Mode::SingleFile { file_path: p },
415                None => Mode::SingleFile {
416                    file_path: Default::default(),
417                },
418            },
419            input,
420        );
421
422        cx.collect(if ignore_unused {
423            CollectMode::OnlyUsed { touches }
424        } else {
425            CollectMode::All
426        });
427
428        cx.keep(keep_unknown_fields);
429
430        cx.build(
431            Arc::from(services),
432            source_type,
433            change_case,
434            dedups,
435            special_namings,
436            common_crate_name,
437            split,
438            with_descriptor,
439            with_field_mask,
440        )
441    }
442
443    pub fn compile_with_config(self, services: Vec<IdlService>, out: Output) {
444        let _ = tracing_subscriber::fmt::try_init();
445
446        let cx = Self::build_cx(
447            services,
448            Some(out),
449            self.parser,
450            self.touches,
451            self.ignore_unused,
452            self.source_type,
453            self.change_case,
454            self.keep_unknown_fields,
455            self.dedups,
456            self.special_namings,
457            self.common_crate_name,
458            self.split,
459            self.with_descriptor,
460            self.with_field_mask,
461        );
462
463        cx.exec_plugin(BoxedPlugin);
464
465        cx.exec_plugin(AutoDerivePlugin::new(
466            Arc::from(["#[derive(PartialOrd)]".into()]),
467            |ty| {
468                let mut ty = ty;
469                while let ty::Vec(_ty) = &ty.kind {
470                    ty = _ty;
471                }
472                if matches!(ty.kind, ty::Map(_, _) | ty::Set(_)) {
473                    PredicateResult::No
474                } else {
475                    PredicateResult::GoOn
476                }
477            },
478        ));
479
480        cx.exec_plugin(AutoDerivePlugin::new(
481            Arc::from(["#[derive(Hash, Eq, Ord)]".into()]),
482            |ty| {
483                let mut ty = ty;
484                while let ty::Vec(_ty) = &ty.kind {
485                    ty = _ty;
486                }
487                if matches!(ty.kind, ty::Map(_, _) | ty::Set(_) | ty::F64 | ty::F32) {
488                    PredicateResult::No
489                } else {
490                    PredicateResult::GoOn
491                }
492            },
493        ));
494
495        self.plugins.into_iter().for_each(|p| cx.exec_plugin(p));
496
497        std::thread::scope(|scope| {
498            let pool = rayon::ThreadPoolBuilder::new();
499            let pool = pool
500                .spawn_handler(|thread| {
501                    let mut builder = std::thread::Builder::new();
502                    if let Some(name) = thread.name() {
503                        builder = builder.name(name.to_string());
504                    }
505                    if let Some(size) = thread.stack_size() {
506                        builder = builder.stack_size(size);
507                    }
508
509                    let cx = cx.clone();
510                    builder.spawn_scoped(scope, move || {
511                        CONTEXT.set(&cx, || thread.run());
512                    })?;
513                    Ok(())
514                })
515                .build()?;
516
517            pool.install(move || {
518                let cg = Codegen::new(self.mk_backend.make_backend(cx));
519                cg.r#gen().unwrap();
520            });
521
522            Ok::<_, rayon::ThreadPoolBuildError>(())
523        })
524        .unwrap();
525    }
526
527    // gen service_global_name and methods for certain service in IdlService
528    pub fn init_service(self, service: IdlService) -> anyhow::Result<(String, String)> {
529        let _ = tracing_subscriber::fmt::try_init();
530        let path = service.path.clone();
531        let cx = Self::build_cx(
532            vec![service],
533            None,
534            self.parser,
535            self.touches,
536            self.ignore_unused,
537            self.source_type,
538            self.change_case,
539            self.keep_unknown_fields,
540            self.dedups,
541            self.special_namings,
542            self.common_crate_name,
543            self.split,
544            self.with_descriptor,
545            self.with_field_mask,
546        );
547
548        std::thread::scope(|_scope| {
549            CONTEXT.set(&cx.clone(), move || {
550                Codegen::new(self.mk_backend.make_backend(cx)).pick_init_service(path)
551            })
552        })
553    }
554}
555
556mod test;