pilota_build/
lib.rs

1#![doc(
2    html_logo_url = "https://github.com/cloudwego/pilota/raw/main/.github/assets/logo.png?sanitize=true"
3)]
4#![cfg_attr(not(doctest), doc = include_str!("../README.md"))]
5#![allow(clippy::mutable_key_type)]
6
7mod util;
8
9pub mod codegen;
10pub mod db;
11pub(crate) mod errors;
12pub mod fmt;
13mod index;
14mod ir;
15pub mod middle;
16pub mod parser;
17mod resolve;
18mod symbol;
19
20use faststr::FastStr;
21pub use symbol::{ModPath, Symbol};
22use tempfile::tempdir;
23pub mod tags;
24use std::{path::PathBuf, sync::Arc};
25
26mod dedup;
27pub mod plugin;
28
29pub use codegen::{Codegen, thrift::ThriftBackend, traits::CodegenBackend};
30use db::{RirDatabase, RootDatabase};
31use middle::{
32    context::{CollectMode, ContextBuilder, Mode, WorkspaceInfo, tls::CONTEXT},
33    rir::NodeKind,
34    type_graph::TypeGraph,
35    workspace_graph::WorkspaceGraph,
36};
37pub use middle::{
38    context::{Context, SourceType},
39    rir, ty,
40};
41use parser::{ParseResult, Parser, protobuf::ProtobufParser, thrift::ThriftParser};
42use plugin::{AutoDerivePlugin, BoxedPlugin, ImplDefaultPlugin, PredicateResult, WithAttrsPlugin};
43pub use plugin::{BoxClonePlugin, ClonePlugin, Plugin};
44use resolve::{ResolveResult, Resolver};
45pub use symbol::{DefId, IdentName};
46pub use tags::TagId;
47
48use crate::codegen::pb::ProtobufBackend;
49
50pub trait MakeBackend: Sized {
51    type Target: CodegenBackend;
52    fn make_backend(self, context: Context) -> Self::Target;
53}
54
55pub struct MkThriftBackend;
56
57impl MakeBackend for MkThriftBackend {
58    type Target = ThriftBackend;
59
60    fn make_backend(self, context: Context) -> Self::Target {
61        ThriftBackend::new(context)
62    }
63}
64
65pub struct MkPbBackend;
66
67impl MakeBackend for MkPbBackend {
68    type Target = ProtobufBackend;
69
70    fn make_backend(self, context: Context) -> Self::Target {
71        ProtobufBackend::new(context)
72    }
73}
74
75pub struct Builder<MkB, P> {
76    source_type: SourceType,
77    mk_backend: MkB,
78    parser: P,
79    plugins: Vec<Box<dyn Plugin>>,
80    ignore_unused: bool,
81    split: bool,
82    touches: Vec<(std::path::PathBuf, Vec<String>)>,
83    change_case: bool,
84    keep_unknown_fields: Vec<std::path::PathBuf>,
85    dedups: Vec<FastStr>,
86    special_namings: Vec<FastStr>,
87    common_crate_name: FastStr,
88    with_descriptor: bool,
89    with_field_mask: bool,
90    temp_dir: Option<tempfile::TempDir>,
91    with_comments: bool,
92}
93
94impl Builder<MkThriftBackend, ThriftParser> {
95    pub fn thrift() -> Self {
96        Builder {
97            source_type: SourceType::Thrift,
98            mk_backend: MkThriftBackend,
99            parser: ThriftParser::default(),
100            plugins: vec![
101                Box::new(WithAttrsPlugin(Arc::from(["#[derive(Debug)]".into()]))),
102                Box::new(ImplDefaultPlugin),
103            ],
104            touches: Vec::default(),
105            ignore_unused: true,
106            change_case: true,
107            keep_unknown_fields: Vec::default(),
108            dedups: Vec::default(),
109            special_namings: Vec::default(),
110            common_crate_name: "common".into(),
111            split: false,
112            with_descriptor: false,
113            with_field_mask: false,
114            temp_dir: None,
115            with_comments: false,
116        }
117    }
118}
119
120impl Builder<MkPbBackend, ProtobufParser> {
121    pub fn pb() -> Self {
122        let (out_dir, temp_dir) = match std::env::var("OUT_DIR") {
123            Ok(out_dir_str) => (PathBuf::from(out_dir_str), None),
124            _ => {
125                let temp_dir = tempdir().unwrap();
126                (temp_dir.path().to_path_buf(), Some(temp_dir))
127            }
128        };
129        let include_dir = out_dir.join("pilota_proto");
130
131        std::fs::create_dir_all(&include_dir).expect("Failed to create pilota_proto directory");
132
133        let pilota_proto_src = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("proto/pilota.proto");
134
135        std::fs::copy(&pilota_proto_src, include_dir.join("pilota.proto"))
136            .expect("Failed to copy pilota.proto");
137
138        let mut parser = ProtobufParser::default();
139        parser.include_dirs(vec![include_dir]);
140
141        Builder {
142            source_type: SourceType::Protobuf,
143            mk_backend: MkPbBackend,
144            parser,
145            plugins: vec![
146                Box::new(WithAttrsPlugin(Arc::from(["#[derive(Debug)]".into()]))),
147                Box::new(ImplDefaultPlugin),
148            ],
149            touches: Vec::default(),
150            ignore_unused: true,
151            change_case: true,
152            keep_unknown_fields: Vec::default(),
153            dedups: Vec::default(),
154            special_namings: Vec::default(),
155            common_crate_name: "common".into(),
156            split: false,
157            with_descriptor: false,
158            with_field_mask: false,
159            temp_dir,
160            with_comments: false,
161        }
162    }
163}
164
165impl<MkB, P> Builder<MkB, P>
166where
167    P: Parser,
168{
169    pub fn include_dirs(mut self, include_dirs: Vec<PathBuf>) -> Self {
170        self.parser.include_dirs(include_dirs);
171        self
172    }
173}
174
175impl<MkB, P> Builder<MkB, P> {
176    pub fn with_backend<B: MakeBackend>(self, mk_backend: B) -> Builder<B, P> {
177        Builder {
178            source_type: self.source_type,
179            mk_backend,
180            parser: self.parser,
181            plugins: self.plugins,
182            ignore_unused: self.ignore_unused,
183            touches: self.touches,
184            change_case: self.change_case,
185            keep_unknown_fields: self.keep_unknown_fields,
186            dedups: self.dedups,
187            special_namings: self.special_namings,
188            common_crate_name: self.common_crate_name,
189            split: self.split,
190            with_descriptor: self.with_descriptor,
191            with_field_mask: self.with_field_mask,
192            temp_dir: self.temp_dir,
193            with_comments: self.with_comments,
194        }
195    }
196
197    pub fn plugin<Plu: Plugin + 'static>(mut self, p: Plu) -> Self {
198        self.plugins.push(Box::new(p));
199
200        self
201    }
202
203    pub fn split_generated_files(mut self, split: bool) -> Self {
204        self.split = split;
205        self
206    }
207
208    pub fn change_case(mut self, change_case: bool) -> Self {
209        self.change_case = change_case;
210        self
211    }
212
213    /**
214     * Don't generate items which are unused by the main service
215     */
216    pub fn ignore_unused(mut self, flag: bool) -> Self {
217        self.ignore_unused = flag;
218        self
219    }
220
221    /**
222     * Generate items even them are not used.
223     *
224     * This is ignored if `ignore_unused` is false
225     */
226    pub fn touch(
227        mut self,
228        item: impl IntoIterator<Item = (PathBuf, Vec<impl Into<String>>)>,
229    ) -> Self {
230        self.touches.extend(
231            item.into_iter()
232                .map(|s| (s.0, s.1.into_iter().map(|s| s.into()).collect())),
233        );
234        self
235    }
236
237    pub fn keep_unknown_fields(mut self, item: impl IntoIterator<Item = PathBuf>) -> Self {
238        self.keep_unknown_fields.extend(item);
239        self
240    }
241
242    pub fn dedup(mut self, item: impl IntoIterator<Item = FastStr>) -> Self {
243        self.dedups.extend(item);
244        self
245    }
246
247    pub fn special_namings(mut self, item: impl IntoIterator<Item = FastStr>) -> Self {
248        self.special_namings.extend(item);
249        self
250    }
251
252    pub fn common_crate_name(mut self, name: FastStr) -> Self {
253        self.common_crate_name = name;
254        self
255    }
256
257    pub fn with_descriptor(mut self, on: bool) -> Self {
258        self.with_descriptor = on;
259        self
260    }
261
262    pub fn with_field_mask(mut self, on: bool) -> Self {
263        self.with_field_mask = on;
264        self
265    }
266
267    /**
268     * Generate comments for the generated code
269     */
270    pub fn with_comments(mut self, on: bool) -> Self {
271        self.with_comments = on;
272        self
273    }
274}
275
/// Destination for the generated code.
pub enum Output {
    /// Generate in workspace mode, rooted at this directory.
    Workspace(PathBuf),
    /// Generate in single-file mode, writing to this path.
    File(PathBuf),
}
280
281#[derive(serde::Deserialize, serde::Serialize)]
282pub struct IdlService {
283    pub path: PathBuf,
284    pub config: serde_yaml::Value,
285}
286
287impl IdlService {
288    pub fn from_path(p: PathBuf) -> Self {
289        IdlService {
290            path: p,
291            config: Default::default(),
292        }
293    }
294}
295
296impl<MkB, P> Builder<MkB, P>
297where
298    MkB: MakeBackend + Send,
299    MkB::Target: Send,
300    P: Parser,
301{
302    pub fn compile(
303        self,
304        services: impl IntoIterator<Item = impl AsRef<std::path::Path>>,
305        out: Output,
306    ) {
307        let services = services
308            .into_iter()
309            .map(|path| IdlService {
310                config: serde_yaml::Value::default(),
311                path: path.as_ref().to_owned(),
312            })
313            .collect();
314
315        self.compile_with_config(services, out)
316    }
317
318    #[allow(clippy::too_many_arguments)]
319    pub fn build_cx(
320        services: Vec<IdlService>,
321        out: Option<Output>,
322        mut parser: P,
323        touches: Vec<(PathBuf, Vec<String>)>,
324        ignore_unused: bool,
325        source_type: SourceType,
326        change_case: bool,
327        keep_unknown_fields: Vec<PathBuf>,
328        dedups: Vec<FastStr>,
329        special_namings: Vec<FastStr>,
330        common_crate_name: FastStr,
331        split: bool,
332        with_descriptor: bool,
333        with_field_mask: bool,
334        with_comments: bool,
335    ) -> Context {
336        parser.inputs(services.iter().map(|s| &s.path));
337        let ParseResult {
338            files,
339            input_files,
340            file_ids_map,
341            file_paths,
342            file_names,
343        } = parser.parse();
344
345        let ResolveResult {
346            files,
347            nodes,
348            tags,
349            args,
350            pb_ext_indexes,
351            pb_ext_indexes_used,
352        } = Resolver::default().resolve_files(&files);
353
354        let items = nodes.iter().filter_map(|(k, v)| match &v.kind {
355            NodeKind::Item(item) => Some((*k, item.clone())),
356            _ => None,
357        });
358
359        let type_graph = TypeGraph::from_items(items.clone());
360        let workspace_graph = WorkspaceGraph::from_items(items);
361
362        // Build the database using the builder pattern
363        let db = RootDatabase::default()
364            .with_file_ids_map(file_ids_map)
365            .with_file_paths(file_paths)
366            .with_file_names(file_names)
367            .with_files(files.into_iter())
368            .with_nodes(nodes)
369            .with_tags(tags, type_graph)
370            .with_args(args)
371            .with_workspace_graph(workspace_graph)
372            .with_input_files(input_files.clone())
373            .with_pb_ext_indexes(pb_ext_indexes)
374            .with_pb_exts_used(pb_ext_indexes_used);
375
376        let mut input = Vec::with_capacity(input_files.len());
377        for file_id in &input_files {
378            let file = db.file(*file_id).unwrap();
379            file.items.iter().for_each(|def_id| {
380                // Check if the node is an Item before calling item()
381                if let Some(node) = db.node(*def_id) {
382                    if let NodeKind::Item(item) = &node.kind {
383                        if matches!(&**item, rir::Item::Service(_)) {
384                            input.push(*def_id)
385                        }
386                    }
387                }
388            });
389        }
390
391        let mut cx = ContextBuilder::new(
392            db,
393            match out {
394                Some(Output::Workspace(dir)) => Mode::Workspace(WorkspaceInfo {
395                    dir,
396                    location_map: Default::default(),
397                }),
398                Some(Output::File(p)) => Mode::SingleFile { file_path: p },
399                None => Mode::SingleFile {
400                    file_path: Default::default(),
401                },
402            },
403            input,
404        );
405
406        cx.collect(if ignore_unused {
407            CollectMode::OnlyUsed { touches }
408        } else {
409            CollectMode::All
410        });
411
412        cx.keep(keep_unknown_fields);
413
414        cx.build(
415            Arc::from(services),
416            source_type,
417            change_case,
418            dedups,
419            special_namings,
420            common_crate_name,
421            split,
422            with_descriptor,
423            with_field_mask,
424            !ignore_unused,
425            with_comments,
426        )
427    }
428
429    pub fn compile_with_config(self, services: Vec<IdlService>, out: Output) {
430        let _ = tracing_subscriber::fmt::try_init();
431
432        let cx = Self::build_cx(
433            services,
434            Some(out),
435            self.parser,
436            self.touches,
437            self.ignore_unused,
438            self.source_type,
439            self.change_case,
440            self.keep_unknown_fields,
441            self.dedups,
442            self.special_namings,
443            self.common_crate_name,
444            self.split,
445            self.with_descriptor,
446            self.with_field_mask,
447            self.with_comments,
448        );
449
450        cx.exec_plugin(BoxedPlugin);
451
452        cx.exec_plugin(AutoDerivePlugin::new(
453            Arc::from(["#[derive(PartialOrd)]".into()]),
454            |ty| {
455                let mut ty = ty;
456                while let ty::Vec(_ty) = &ty.kind {
457                    ty = _ty;
458                }
459                if matches!(ty.kind, ty::Map(_, _) | ty::Set(_)) {
460                    PredicateResult::No
461                } else {
462                    PredicateResult::GoOn
463                }
464            },
465        ));
466
467        cx.exec_plugin(AutoDerivePlugin::new(
468            Arc::from(["#[derive(Hash, Eq, Ord)]".into()]),
469            |ty| {
470                let mut ty = ty;
471                while let ty::Vec(_ty) = &ty.kind {
472                    ty = _ty;
473                }
474                if matches!(ty.kind, ty::Map(_, _) | ty::Set(_) | ty::F64 | ty::F32) {
475                    PredicateResult::No
476                } else {
477                    PredicateResult::GoOn
478                }
479            },
480        ));
481
482        CONTEXT.set(&cx, || {
483            self.plugins.into_iter().for_each(|p| cx.exec_plugin(p));
484        });
485
486        std::thread::scope(|scope| {
487            let pool = rayon::ThreadPoolBuilder::new();
488            let pool = pool
489                .spawn_handler(|thread| {
490                    let mut builder = std::thread::Builder::new();
491                    if let Some(name) = thread.name() {
492                        builder = builder.name(name.to_string());
493                    }
494                    if let Some(size) = thread.stack_size() {
495                        builder = builder.stack_size(size);
496                    }
497
498                    let cx = cx.clone();
499                    builder.spawn_scoped(scope, move || {
500                        CONTEXT.set(&cx, || thread.run());
501                    })?;
502                    Ok(())
503                })
504                .build()?;
505
506            pool.install(move || {
507                let cg = Codegen::new(self.mk_backend.make_backend(cx));
508                cg.r#gen().unwrap();
509            });
510
511            Ok::<_, rayon::ThreadPoolBuildError>(())
512        })
513        .unwrap();
514    }
515
516    // gen service_global_name and methods for certain service in IdlService
517    pub fn init_service(self, service: IdlService) -> anyhow::Result<(String, String)> {
518        let _ = tracing_subscriber::fmt::try_init();
519        let path = service.path.clone();
520        let cx = Self::build_cx(
521            vec![service],
522            None,
523            self.parser,
524            self.touches,
525            self.ignore_unused,
526            self.source_type,
527            self.change_case,
528            self.keep_unknown_fields,
529            self.dedups,
530            self.special_namings,
531            self.common_crate_name,
532            self.split,
533            self.with_descriptor,
534            self.with_field_mask,
535            self.with_comments,
536        );
537
538        std::thread::scope(|_scope| {
539            CONTEXT.set(&cx.clone(), move || {
540                Codegen::new(self.mk_backend.make_backend(cx)).pick_init_service(path)
541            })
542        })
543    }
544}
545
546mod test;