pilota_build2/
lib.rs

1#![doc(
2html_logo_url = "https://github.com/cloudwego/pilota/raw/main/.github/assets/logo.png?sanitize=true"
3)]
4#![cfg_attr(not(doctest), doc = include_str!("../README.md"))]
5#![allow(clippy::mutable_key_type)]
6#![allow(warnings)]
7
8use std::{path::PathBuf, sync::Arc};
9
10pub use codegen::{
11    Codegen, protobuf::ProtobufBackend, thrift::ThriftBackend, traits::CodegenBackend,
12};
13use db::{RirDatabase, RootDatabase};
14use middle::{
15    context::{CollectMode, ContextBuilder, Mode, tls::CONTEXT, WorkspaceInfo},
16    rir::NodeKind,
17    type_graph::TypeGraph,
18};
19pub use middle::{
20    adjust::Adjust,
21    context::{Context, SourceType},
22    rir, ty,
23};
24use parser::{Parser, ParseResult, protobuf::ProtobufParser, thrift::ThriftParser};
25use plugin::{
26    AutoDerivePlugin, BoxedPlugin, EnumNumPlugin, ImplDefaultPlugin, PredicateResult,
27    WithAttrsPlugin,
28};
29pub use plugin::{BoxClonePlugin, ClonePlugin, Plugin};
30use resolve::{Resolver, ResolveResult};
31use salsa::Durability;
32pub use symbol::{DefId, IdentName};
33pub use symbol::Symbol;
34pub use tags::TagId;
35
36mod util;
37
38pub mod codegen;
39pub mod db;
40pub(crate) mod errors;
41pub mod fmt;
42mod index;
43pub mod ir;
44mod middle;
45pub mod parser;
46mod resolve;
47mod symbol;
48pub mod tags;
49// mod dedup;
50pub mod plugin;
51
52pub trait MakeBackend: Sized {
53    type Target: CodegenBackend;
54    fn make_backend(self, context: Context) -> Self::Target;
55}
56
57
58pub struct MkProtobufBackend;
59
60impl MakeBackend for MkProtobufBackend {
61    type Target = ProtobufBackend;
62
63    fn make_backend(self, context: Context) -> Self::Target {
64        ProtobufBackend::new(context)
65    }
66}
67
68pub struct Builder<MkB, P> {
69    source_type: SourceType,
70    mk_backend: MkB,
71    parser: P,
72    plugins: Vec<Box<dyn Plugin>>,
73    ignore_unused: bool,
74    touches: Vec<(std::path::PathBuf, Vec<String>)>,
75    change_case: bool,
76    doc_header: Option<String>,
77}
78
79impl<MkB> Builder<MkB, ThriftParser> {
80    pub fn thrift_with_backend(mk_backend: MkB) -> Self {
81        Builder {
82            source_type: SourceType::Thrift,
83            mk_backend,
84            parser: ThriftParser::default(),
85            plugins: vec![
86                Box::new(WithAttrsPlugin(Arc::from(["#[derive(Debug)]".into()]))),
87                Box::new(ImplDefaultPlugin),
88                Box::new(EnumNumPlugin),
89            ],
90            touches: Vec::default(),
91            ignore_unused: true,
92            change_case: true,
93            doc_header: None,
94        }
95    }
96}
97
98impl<MkB> Builder<MkB, ProtobufParser> {
99    pub fn protobuf_with_backend(mk_backend: MkB) -> Self {
100        Builder {
101            source_type: SourceType::Protobuf,
102            mk_backend: mk_backend,
103            parser: ProtobufParser::default(),
104            plugins: vec![
105                Box::new(WithAttrsPlugin(Arc::from(["#[derive(Debug)]".into()]))),
106                Box::new(ImplDefaultPlugin),
107                Box::new(EnumNumPlugin),
108            ],
109            touches: Vec::default(),
110            ignore_unused: true,
111            change_case: true,
112            doc_header: None,
113        }
114    }
115}
116
117impl<MkB, P> Builder<MkB, P>
118    where
119        P: Parser,
120{
121    pub fn include_dirs(mut self, include_dirs: Vec<PathBuf>) -> Self {
122        self.parser.include_dirs(include_dirs);
123        self
124    }
125}
126
127impl<MkB, P> Builder<MkB, P> {
128    pub fn with_backend<B: MakeBackend>(self, mk_backend: B) -> Builder<B, P> {
129        Builder {
130            source_type: self.source_type,
131            mk_backend,
132            parser: self.parser,
133            plugins: self.plugins,
134            ignore_unused: self.ignore_unused,
135            touches: self.touches,
136            change_case: self.change_case,
137            doc_header: None,
138        }
139    }
140    pub fn doc_header(mut self, doc_header: String) -> Self {
141        self.doc_header = Some(doc_header);
142        self
143    }
144    pub fn plugin<Plu: Plugin + 'static>(mut self, p: Plu) -> Self {
145        self.plugins.push(Box::new(p));
146
147        self
148    }
149
150    pub fn change_case(mut self, change_case: bool) -> Self {
151        self.change_case = change_case;
152        self
153    }
154
155    /**
156     * Don't generate items which are unused by the main service
157     */
158    pub fn ignore_unused(mut self, flag: bool) -> Self {
159        self.ignore_unused = flag;
160        self
161    }
162
163    /**
164     * Generate items even them are not used.
165     *
166     * This is ignored if `ignore_unused` is false
167     */
168    pub fn touch(
169        mut self,
170        item: impl IntoIterator<Item=(PathBuf, Vec<impl Into<String>>)>,
171    ) -> Self {
172        self.touches.extend(
173            item.into_iter()
174                .map(|s| (s.0, s.1.into_iter().map(|s| s.into()).collect())),
175        );
176        self
177    }
178}
179
/// Where the generated code is written.
pub enum Output {
    /// Generate in workspace mode, using the given path as the workspace
    /// directory.
    Workspace(PathBuf),
    /// Generate everything into the single file at the given path.
    File(PathBuf),
}
184
185#[derive(serde::Deserialize, serde::Serialize)]
186pub struct IdlService {
187    pub path: PathBuf,
188    pub config: serde_yaml::Value,
189}
190
191impl IdlService {
192    pub fn from_path(p: PathBuf) -> Self {
193        IdlService {
194            path: p,
195            config: Default::default(),
196        }
197    }
198}
199
200impl<MkB, P> Builder<MkB, P>
201    where
202        MkB: MakeBackend + Send,
203        MkB::Target: Send,
204        P: Parser,
205{
206    pub fn compile(
207        self,
208        services: impl IntoIterator<Item=impl AsRef<std::path::Path>>,
209        out: Output,
210    ) {
211        let services = services
212            .into_iter()
213            .map(|path| IdlService {
214                config: serde_yaml::Value::default(),
215                path: path.as_ref().to_owned(),
216            })
217            .collect();
218
219        self.compile_with_config(services, out)
220    }
221
222    pub fn compile_with_config(mut self, services: Vec<IdlService>, out: Output) {
223        let _ = tracing_subscriber::fmt::try_init();
224
225        let mut db = RootDatabase::default();
226        self.parser.inputs(services.iter().map(|s| &s.path));
227        let ParseResult {
228            files,
229            input_files,
230            file_ids_map,
231        } = self.parser.parse();
232        db.set_file_ids_map_with_durability(Arc::new(file_ids_map), Durability::HIGH);
233
234        let ResolveResult { files, nodes, tags } = Resolver::default().resolve_files(&files);
235
236        // discard duplicated items
237        // let mods = nodes
238        //     .iter()
239        //     .into_group_map_by(|(_, node)|
240        // files.get(&node.file_id).unwrap().package.clone());
241
242        // for (_, m) in mods {
243        //     m.iter().unique_by(f);
244        // }
245
246        db.set_files_with_durability(Arc::new(files), Durability::HIGH);
247        let items = nodes.iter().filter_map(|(k, v)| {
248            if let NodeKind::Item(item) = &v.kind {
249                Some((*k, item.clone()))
250            } else {
251                None
252            }
253        });
254
255        let type_graph = Arc::from(TypeGraph::from_items(items));
256        db.set_type_graph_with_durability(type_graph, Durability::HIGH);
257        db.set_nodes_with_durability(Arc::new(nodes), Durability::HIGH);
258        db.set_tags_map_with_durability(Arc::new(tags), Durability::HIGH);
259
260        let mut input = Vec::with_capacity(input_files.len());
261        for file_id in &input_files {
262            let file = db.file(*file_id).unwrap();
263            file.items.iter().for_each(|def_id| {
264                if matches!(&*db.item(*def_id).unwrap(), rir::Item::Service(_)) {
265                    input.push(*def_id)
266                }
267            });
268        }
269        db.set_input_files_with_durability(Arc::new(input_files), Durability::HIGH);
270
271        let mut cx = ContextBuilder::new(
272            db,
273            match out {
274                Output::Workspace(dir) => Mode::Workspace(WorkspaceInfo {
275                    dir,
276                    location_map: Default::default(),
277                }),
278                Output::File(p) => Mode::SingleFile { file_path: p },
279            },
280            input,
281        );
282
283        cx.collect(if self.ignore_unused {
284            CollectMode::OnlyUsed {
285                touches: self.touches,
286            }
287        } else {
288            CollectMode::All
289        });
290
291        let cx = cx.build(Arc::from(services), self.source_type, self.change_case, Arc::from(self.doc_header.unwrap_or_default()));
292
293        cx.exec_plugin(BoxedPlugin);
294
295        cx.exec_plugin(AutoDerivePlugin::new(
296            Arc::from(["#[derive(PartialOrd)]".into()]),
297            |ty| {
298                let ty = match &ty.kind {
299                    ty::Vec(ty) => ty,
300                    _ => ty,
301                };
302                if matches!(ty.kind, ty::Map(_, _) | ty::Set(_)) {
303                    PredicateResult::No
304                } else {
305                    PredicateResult::GoOn
306                }
307            },
308        ));
309
310        cx.exec_plugin(AutoDerivePlugin::new(
311            Arc::from(["#[derive(Hash, Eq, Ord)]".into()]),
312            |ty| {
313                let ty = match &ty.kind {
314                    ty::Vec(ty) => ty,
315                    _ => ty,
316                };
317                if matches!(ty.kind, ty::Map(_, _) | ty::Set(_) | ty::F64 | ty::F32) {
318                    PredicateResult::No
319                } else {
320                    PredicateResult::GoOn
321                }
322            },
323        ));
324
325        self.plugins.into_iter().for_each(|p| cx.exec_plugin(p));
326
327        std::thread::scope(|scope| {
328            let pool = rayon::ThreadPoolBuilder::new();
329            let pool = pool
330                .spawn_handler(|thread| {
331                    let mut builder = std::thread::Builder::new();
332                    if let Some(name) = thread.name() {
333                        builder = builder.name(name.to_string());
334                    }
335                    if let Some(size) = thread.stack_size() {
336                        builder = builder.stack_size(size);
337                    }
338
339                    let cx = cx.clone();
340                    builder.spawn_scoped(scope, move || {
341                        CONTEXT.set(&cx, || thread.run());
342                    })?;
343                    Ok(())
344                })
345                .build()?;
346
347            pool.install(move || {
348                let cg = Codegen::new(self.mk_backend.make_backend(cx));
349                cg.gen().unwrap();
350            });
351
352            Ok::<_, rayon::ThreadPoolBuildError>(())
353        })
354            .unwrap();
355    }
356}
357
358mod test;