1#![doc(
2html_logo_url = "https://github.com/cloudwego/pilota/raw/main/.github/assets/logo.png?sanitize=true"
3)]
4#![cfg_attr(not(doctest), doc = include_str!("../README.md"))]
5#![allow(clippy::mutable_key_type)]
6#![allow(warnings)]
7
8use std::{path::PathBuf, sync::Arc};
9
10pub use codegen::{
11 Codegen, protobuf::ProtobufBackend, thrift::ThriftBackend, traits::CodegenBackend,
12};
13use db::{RirDatabase, RootDatabase};
14use middle::{
15 context::{CollectMode, ContextBuilder, Mode, tls::CONTEXT, WorkspaceInfo},
16 rir::NodeKind,
17 type_graph::TypeGraph,
18};
19pub use middle::{
20 adjust::Adjust,
21 context::{Context, SourceType},
22 rir, ty,
23};
24use parser::{Parser, ParseResult, protobuf::ProtobufParser, thrift::ThriftParser};
25use plugin::{
26 AutoDerivePlugin, BoxedPlugin, EnumNumPlugin, ImplDefaultPlugin, PredicateResult,
27 WithAttrsPlugin,
28};
29pub use plugin::{BoxClonePlugin, ClonePlugin, Plugin};
30use resolve::{Resolver, ResolveResult};
31use salsa::Durability;
32pub use symbol::{DefId, IdentName};
33pub use symbol::Symbol;
34pub use tags::TagId;
35
36mod util;
37
38pub mod codegen;
39pub mod db;
40pub(crate) mod errors;
41pub mod fmt;
42mod index;
43pub mod ir;
44mod middle;
45pub mod parser;
46mod resolve;
47mod symbol;
48pub mod tags;
49pub mod plugin;
51
52pub trait MakeBackend: Sized {
53 type Target: CodegenBackend;
54 fn make_backend(self, context: Context) -> Self::Target;
55}
56
57
58pub struct MkProtobufBackend;
59
60impl MakeBackend for MkProtobufBackend {
61 type Target = ProtobufBackend;
62
63 fn make_backend(self, context: Context) -> Self::Target {
64 ProtobufBackend::new(context)
65 }
66}
67
68pub struct Builder<MkB, P> {
69 source_type: SourceType,
70 mk_backend: MkB,
71 parser: P,
72 plugins: Vec<Box<dyn Plugin>>,
73 ignore_unused: bool,
74 touches: Vec<(std::path::PathBuf, Vec<String>)>,
75 change_case: bool,
76 doc_header: Option<String>,
77}
78
79impl<MkB> Builder<MkB, ThriftParser> {
80 pub fn thrift_with_backend(mk_backend: MkB) -> Self {
81 Builder {
82 source_type: SourceType::Thrift,
83 mk_backend,
84 parser: ThriftParser::default(),
85 plugins: vec![
86 Box::new(WithAttrsPlugin(Arc::from(["#[derive(Debug)]".into()]))),
87 Box::new(ImplDefaultPlugin),
88 Box::new(EnumNumPlugin),
89 ],
90 touches: Vec::default(),
91 ignore_unused: true,
92 change_case: true,
93 doc_header: None,
94 }
95 }
96}
97
98impl<MkB> Builder<MkB, ProtobufParser> {
99 pub fn protobuf_with_backend(mk_backend: MkB) -> Self {
100 Builder {
101 source_type: SourceType::Protobuf,
102 mk_backend: mk_backend,
103 parser: ProtobufParser::default(),
104 plugins: vec![
105 Box::new(WithAttrsPlugin(Arc::from(["#[derive(Debug)]".into()]))),
106 Box::new(ImplDefaultPlugin),
107 Box::new(EnumNumPlugin),
108 ],
109 touches: Vec::default(),
110 ignore_unused: true,
111 change_case: true,
112 doc_header: None,
113 }
114 }
115}
116
117impl<MkB, P> Builder<MkB, P>
118 where
119 P: Parser,
120{
121 pub fn include_dirs(mut self, include_dirs: Vec<PathBuf>) -> Self {
122 self.parser.include_dirs(include_dirs);
123 self
124 }
125}
126
127impl<MkB, P> Builder<MkB, P> {
128 pub fn with_backend<B: MakeBackend>(self, mk_backend: B) -> Builder<B, P> {
129 Builder {
130 source_type: self.source_type,
131 mk_backend,
132 parser: self.parser,
133 plugins: self.plugins,
134 ignore_unused: self.ignore_unused,
135 touches: self.touches,
136 change_case: self.change_case,
137 doc_header: None,
138 }
139 }
140 pub fn doc_header(mut self, doc_header: String) -> Self {
141 self.doc_header = Some(doc_header);
142 self
143 }
144 pub fn plugin<Plu: Plugin + 'static>(mut self, p: Plu) -> Self {
145 self.plugins.push(Box::new(p));
146
147 self
148 }
149
150 pub fn change_case(mut self, change_case: bool) -> Self {
151 self.change_case = change_case;
152 self
153 }
154
155 pub fn ignore_unused(mut self, flag: bool) -> Self {
159 self.ignore_unused = flag;
160 self
161 }
162
163 pub fn touch(
169 mut self,
170 item: impl IntoIterator<Item=(PathBuf, Vec<impl Into<String>>)>,
171 ) -> Self {
172 self.touches.extend(
173 item.into_iter()
174 .map(|s| (s.0, s.1.into_iter().map(|s| s.into()).collect())),
175 );
176 self
177 }
178}
179
/// Destination for the generated code.
pub enum Output {
    /// Generate a workspace rooted at this directory (`Mode::Workspace`).
    Workspace(PathBuf),
    /// Write the generated code to the single file at this path (`Mode::SingleFile`).
    File(PathBuf),
}
184
185#[derive(serde::Deserialize, serde::Serialize)]
186pub struct IdlService {
187 pub path: PathBuf,
188 pub config: serde_yaml::Value,
189}
190
191impl IdlService {
192 pub fn from_path(p: PathBuf) -> Self {
193 IdlService {
194 path: p,
195 config: Default::default(),
196 }
197 }
198}
199
impl<MkB, P> Builder<MkB, P>
where
    MkB: MakeBackend + Send,
    MkB::Target: Send,
    P: Parser,
{
    /// Compiles the IDL files in `services` and writes the generated code to `out`.
    ///
    /// Each path is wrapped in an [`IdlService`] with a default (null) YAML config, and the
    /// work is delegated to [`Builder::compile_with_config`].
    pub fn compile(
        self,
        services: impl IntoIterator<Item=impl AsRef<std::path::Path>>,
        out: Output,
    ) {
        let services = services
            .into_iter()
            .map(|path| IdlService {
                config: serde_yaml::Value::default(),
                path: path.as_ref().to_owned(),
            })
            .collect();

        self.compile_with_config(services, out)
    }

    /// Runs the whole pipeline for `services` and writes the generated code to `out`.
    ///
    /// The steps, in order:
    /// 1. Parse the IDL files and resolve them into IR nodes.
    /// 2. Load everything into the salsa database.
    /// 3. Collect the items to generate and build the [`Context`].
    /// 4. Run the built-in plugins, then the builder's plugins.
    /// 5. Create the backend and run code generation on a dedicated rayon thread pool.
    ///
    /// # Panics
    ///
    /// Panics (via `unwrap`) in any of these cases:
    /// - an input file or one of its items is missing from the database;
    /// - the rayon thread pool cannot be built;
    /// - code generation returns an error.
    pub fn compile_with_config(mut self, services: Vec<IdlService>, out: Output) {
        // Best-effort logging setup: `try_init` fails when a global subscriber is already
        // installed, which is acceptable, so the result is deliberately discarded.
        let _ = tracing_subscriber::fmt::try_init();

        let mut db = RootDatabase::default();
        self.parser.inputs(services.iter().map(|s| &s.path));
        // NOTE(review): `input_files` presumably holds the ids of the files listed in
        // `services`, while `files` may also contain files pulled in via includes — confirm
        // against the parser implementations.
        let ParseResult {
            files,
            input_files,
            file_ids_map,
        } = self.parser.parse();
        // Every database input is set once per build. HIGH durability tells salsa the value
        // is not expected to change, so dependent queries can skip revalidation.
        db.set_file_ids_map_with_durability(Arc::new(file_ids_map), Durability::HIGH);

        let ResolveResult { files, nodes, tags } = Resolver::default().resolve_files(&files);

        db.set_files_with_durability(Arc::new(files), Durability::HIGH);
        // Only item nodes participate in the type graph.
        let items = nodes.iter().filter_map(|(k, v)| {
            if let NodeKind::Item(item) = &v.kind {
                Some((*k, item.clone()))
            } else {
                None
            }
        });

        // `items` lazily borrows `nodes`, so the graph is built before `nodes` is moved into
        // the database below.
        let type_graph = Arc::from(TypeGraph::from_items(items));
        db.set_type_graph_with_durability(type_graph, Durability::HIGH);
        db.set_nodes_with_durability(Arc::new(nodes), Durability::HIGH);
        db.set_tags_map_with_durability(Arc::new(tags), Durability::HIGH);

        // Seed code generation with every service declared directly in an input file.
        let mut input = Vec::with_capacity(input_files.len());
        for file_id in &input_files {
            let file = db.file(*file_id).unwrap();
            file.items.iter().for_each(|def_id| {
                if matches!(&*db.item(*def_id).unwrap(), rir::Item::Service(_)) {
                    input.push(*def_id)
                }
            });
        }
        // Stored only after the loop above, which borrows `input_files`.
        db.set_input_files_with_durability(Arc::new(input_files), Durability::HIGH);

        // Translate the public `Output` into the context's generation `Mode`.
        let mut cx = ContextBuilder::new(
            db,
            match out {
                Output::Workspace(dir) => Mode::Workspace(WorkspaceInfo {
                    dir,
                    location_map: Default::default(),
                }),
                Output::File(p) => Mode::SingleFile { file_path: p },
            },
            input,
        );

        // With `ignore_unused`, collect only used items, presumably those reachable from
        // `input`, plus the explicitly touched ones. Otherwise collect everything, and
        // `touches` is not used.
        cx.collect(if self.ignore_unused {
            CollectMode::OnlyUsed {
                touches: self.touches,
            }
        } else {
            CollectMode::All
        });

        // An unset doc header is passed on as an empty string.
        let cx = cx.build(Arc::from(services), self.source_type, self.change_case, Arc::from(self.doc_header.unwrap_or_default()));

        // NOTE(review): presumably boxes recursive/self-referential fields — confirm in
        // `plugin::BoxedPlugin`.
        cx.exec_plugin(BoxedPlugin);

        // Derive `PartialOrd` where possible. The predicate looks through one level of `Vec`
        // and rejects maps and sets. `GoOn` presumably defers the decision to nested types.
        cx.exec_plugin(AutoDerivePlugin::new(
            Arc::from(["#[derive(PartialOrd)]".into()]),
            |ty| {
                let ty = match &ty.kind {
                    ty::Vec(ty) => ty,
                    _ => ty,
                };
                if matches!(ty.kind, ty::Map(_, _) | ty::Set(_)) {
                    PredicateResult::No
                } else {
                    PredicateResult::GoOn
                }
            },
        ));

        // Same rule for `Hash, Eq, Ord`, but `f32`/`f64` are also rejected because Rust
        // floats implement none of these traits.
        cx.exec_plugin(AutoDerivePlugin::new(
            Arc::from(["#[derive(Hash, Eq, Ord)]".into()]),
            |ty| {
                let ty = match &ty.kind {
                    ty::Vec(ty) => ty,
                    _ => ty,
                };
                if matches!(ty.kind, ty::Map(_, _) | ty::Set(_) | ty::F64 | ty::F32) {
                    PredicateResult::No
                } else {
                    PredicateResult::GoOn
                }
            },
        ));

        // The builder's plugin list runs last, in order. It includes the defaults installed by
        // the constructors.
        self.plugins.into_iter().for_each(|p| cx.exec_plugin(p));

        // Code generation runs on a rayon pool whose workers are scoped threads, so
        // `thread::scope` joins every worker before this function returns.
        std::thread::scope(|scope| {
            let pool = rayon::ThreadPoolBuilder::new();
            let pool = pool
                .spawn_handler(|thread| {
                    // Carry over the name and stack size rayon requested for this worker.
                    let mut builder = std::thread::Builder::new();
                    if let Some(name) = thread.name() {
                        builder = builder.name(name.to_string());
                    }
                    if let Some(size) = thread.stack_size() {
                        builder = builder.stack_size(size);
                    }

                    // Each worker sets its own clone of the context in `CONTEXT` (from
                    // `context::tls`) for its whole run, so work executed on the pool can
                    // access it.
                    let cx = cx.clone();
                    builder.spawn_scoped(scope, move || {
                        CONTEXT.set(&cx, || thread.run());
                    })?;
                    Ok(())
                })
                .build()?;

            // Backend creation and generation both happen on the pool; `cx` is moved in.
            pool.install(move || {
                let cg = Codegen::new(self.mk_backend.make_backend(cx));
                cg.gen().unwrap();
            });

            Ok::<_, rayon::ThreadPoolBuildError>(())
        })
        // A failure to build the thread pool surfaces as a panic.
        .unwrap();
    }
}
357
358mod test;