1use proc_macro2::TokenStream;
19use prost_build::{Config, Method, ServiceGenerator};
20use quote::ToTokens;
21use std::path::{Path, PathBuf};
22
23use crate::{client, server, Attributes};
24
/// Banner inserted at the very start of each package's generated output
/// (see `SvcGenerator::finalize_package`), marking the file as generated code.
const PACKAGE_HEADER: &str = "// @generated by apache/dubbo-rust.\n\n";
26
27pub fn compile_protos(proto: impl AsRef<Path>) -> std::io::Result<()> {
32 let proto_path: &Path = proto.as_ref();
33
34 let proto_dir = proto_path
36 .parent()
37 .expect("proto file should reside in a directory");
38
39 self::configure().compile(&[proto_path], &[proto_dir])?;
40
41 Ok(())
42}
43
44pub fn configure() -> Builder {
45 Builder {
46 build_client: true,
47 build_server: true,
48 proto_path: "super".to_string(),
49 protoc_args: Vec::new(),
50 compile_well_known_types: false,
51 include_file: None,
52 output_dir: None,
53 server_attributes: Attributes::default(),
54 client_attributes: Attributes::default(),
55 }
56}
57
58pub struct Builder {
59 build_client: bool,
60 build_server: bool,
61 proto_path: String,
62 compile_well_known_types: bool,
63 protoc_args: Vec<String>,
64 include_file: Option<PathBuf>,
65 output_dir: Option<PathBuf>,
66 server_attributes: Attributes,
67 client_attributes: Attributes,
68}
69
70impl Builder {
71 pub fn output_dir(mut self, output_dir: PathBuf) -> Self {
72 self.output_dir = Some(output_dir);
73 self
74 }
75
76 pub fn compile(
77 self,
78 protos: &[impl AsRef<Path>],
79 includes: &[impl AsRef<Path>],
80 ) -> std::io::Result<()> {
81 self.compile_with_config(Config::new(), protos, includes)
82 }
83
84 pub fn compile_with_config(
85 self,
86 mut config: Config,
87 protos: &[impl AsRef<Path>],
88 includes: &[impl AsRef<Path>],
89 ) -> std::io::Result<()> {
90 let out_dir = if let Some(out_dir) = self.output_dir.as_ref() {
91 out_dir.clone()
92 } else {
93 PathBuf::from(std::env::var("OUT_DIR").unwrap())
94 };
95 config.out_dir(out_dir);
96 config.type_attribute(".", "#[derive(serde::Serialize, serde::Deserialize)]");
97 config.message_attribute(".", "#[serde(default)]");
98
99 if self.compile_well_known_types {
100 config.compile_well_known_types();
101 }
102
103 if let Some(path) = self.include_file.as_ref() {
104 config.include_file(path);
105 }
106
107 for arg in self.protoc_args.iter() {
108 config.protoc_arg(arg);
109 }
110
111 config.service_generator(Box::new(SvcGenerator::new(self)));
112 config.compile_protos(protos, includes)?;
113
114 Ok(())
115 }
116}
117
118pub struct SvcGenerator {
119 builder: Builder,
120 clients: TokenStream,
121 servers: TokenStream,
122}
123
124impl SvcGenerator {
125 fn new(builder: Builder) -> Self {
126 SvcGenerator {
127 builder,
128 clients: TokenStream::new(),
129 servers: TokenStream::new(),
130 }
131 }
132}
133
134impl ServiceGenerator for SvcGenerator {
135 fn generate(&mut self, service: prost_build::Service, _buf: &mut String) {
136 let svc = DubboService::new(service);
137 if self.builder.build_server {
138 let server = server::generate(
139 &svc,
140 true,
141 &self.builder.proto_path,
142 self.builder.compile_well_known_types,
143 &self.builder.server_attributes,
144 );
145 self.servers.extend(server);
146 }
147
148 if self.builder.build_client {
149 let client = client::generate(
150 &svc,
151 true,
152 &self.builder.proto_path,
153 self.builder.compile_well_known_types,
154 &self.builder.client_attributes,
155 );
156 self.clients.extend(client);
157 }
158 }
159
160 fn finalize(&mut self, buf: &mut String) {
161 if self.builder.build_client && !self.clients.is_empty() {
162 let clients = &self.clients;
163
164 let client_services = quote::quote! {
165 #clients
166 };
167
168 let ast: syn::File = syn::parse2(client_services).expect("invalid tokenstream");
169 let code = prettyplease::unparse(&ast);
170 buf.push_str(&code);
171
172 self.clients = TokenStream::default();
173 }
174
175 if self.builder.build_server && !self.servers.is_empty() {
176 let servers = &self.servers;
177
178 let server_services = quote::quote! {
179 #servers
180 };
181
182 let ast: syn::File = syn::parse2(server_services).expect("invalid tokenstream");
183 let code = prettyplease::unparse(&ast);
184 buf.push_str(&code);
185
186 self.servers = TokenStream::default();
187 }
188 }
189
190 fn finalize_package(&mut self, _package: &str, buf: &mut String) {
191 buf.insert_str(0, PACKAGE_HEADER);
192 }
193}
194
195pub struct DubboService {
196 inner: prost_build::Service,
197}
198
199impl DubboService {
200 fn new(inner: prost_build::Service) -> DubboService {
201 Self { inner }
202 }
203}
204
205impl super::Service for DubboService {
206 type Comment = String;
207
208 type Method = DubboMethod;
209
210 fn name(&self) -> &str {
211 &self.inner.name
212 }
213
214 fn package(&self) -> &str {
215 &self.inner.package
216 }
217
218 fn identifier(&self) -> &str {
219 &self.inner.proto_name
220 }
221
222 fn methods(&self) -> Vec<Self::Method> {
223 let mut ms = Vec::new();
224 for m in &self.inner.methods[..] {
225 ms.push(DubboMethod::new(Method {
226 name: m.name.clone(),
227 proto_name: m.proto_name.clone(),
228 comments: prost_build::Comments {
229 leading_detached: m.comments.leading_detached.clone(),
230 leading: m.comments.leading.clone(),
231 trailing: m.comments.trailing.clone(),
232 },
233 input_type: m.input_type.clone(),
234 output_type: m.output_type.clone(),
235 input_proto_type: m.input_proto_type.clone(),
236 output_proto_type: m.output_proto_type.clone(),
237 options: m.options.clone(),
238 client_streaming: m.client_streaming,
239 server_streaming: m.server_streaming,
240 }))
241 }
242
243 ms
244 }
245
246 fn comment(&self) -> &[Self::Comment] {
247 &self.inner.comments.leading[..]
248 }
249}
250
251impl Clone for DubboService {
252 fn clone(&self) -> Self {
253 Self {
254 inner: prost_build::Service {
255 name: self.inner.name.clone(),
256 proto_name: self.inner.proto_name.clone(),
257 package: self.inner.package.clone(),
258 methods: {
259 let mut ms = Vec::new();
260 for m in &self.inner.methods[..] {
261 ms.push(Method {
262 name: m.name.clone(),
263 proto_name: m.proto_name.clone(),
264 comments: prost_build::Comments {
265 leading_detached: m.comments.leading_detached.clone(),
266 leading: m.comments.leading.clone(),
267 trailing: m.comments.trailing.clone(),
268 },
269 input_type: m.input_type.clone(),
270 output_type: m.output_type.clone(),
271 input_proto_type: m.input_proto_type.clone(),
272 output_proto_type: m.output_proto_type.clone(),
273 options: m.options.clone(),
274 client_streaming: m.client_streaming,
275 server_streaming: m.server_streaming,
276 })
277 }
278
279 ms
280 },
281 comments: prost_build::Comments {
282 leading_detached: self.inner.comments.leading_detached.clone(),
283 leading: self.inner.comments.leading.clone(),
284 trailing: self.inner.comments.trailing.clone(),
285 },
286 options: self.inner.options.clone(),
287 },
288 }
289 }
290}
291
292pub struct DubboMethod {
293 inner: Method,
294}
295
296impl DubboMethod {
297 fn new(m: Method) -> DubboMethod {
298 Self { inner: m }
299 }
300}
301
302impl super::Method for DubboMethod {
303 type Comment = String;
304
305 fn name(&self) -> &str {
306 &self.inner.name
307 }
308
309 fn identifier(&self) -> &str {
310 &self.inner.proto_name
311 }
312
313 fn codec_path(&self) -> &str {
314 "triple::codec::serde_codec::SerdeCodec"
315 }
316
317 fn client_streaming(&self) -> bool {
318 self.inner.client_streaming
319 }
320
321 fn server_streaming(&self) -> bool {
322 self.inner.server_streaming
323 }
324
325 fn comment(&self) -> &[Self::Comment] {
326 &self.inner.comments.leading[..]
327 }
328
329 fn request_response_name(
330 &self,
331 proto_path: &str,
332 compile_well_known_types: bool,
333 ) -> (TokenStream, TokenStream) {
334 let convert_type = |proto_type: &str, rust_type: &str| -> TokenStream {
335 if (is_google_type(proto_type) && !compile_well_known_types)
336 || rust_type.starts_with("::")
337 || NON_PATH_TYPE_ALLOWLIST.iter().any(|t| *t == rust_type)
338 {
339 rust_type.parse::<TokenStream>().unwrap()
340 } else if rust_type.starts_with("crate::") {
341 syn::parse_str::<syn::Path>(rust_type)
342 .unwrap()
343 .to_token_stream()
344 } else {
345 syn::parse_str::<syn::Path>(&format!("{}::{}", proto_path, rust_type))
346 .unwrap()
347 .to_token_stream()
348 }
349 };
350
351 let req = convert_type(&self.inner.input_proto_type, &self.inner.input_type);
352 let resp = convert_type(&self.inner.output_proto_type, &self.inner.output_type);
353
354 (req, resp)
355 }
356}
357
358impl Clone for DubboMethod {
359 fn clone(&self) -> Self {
360 DubboMethod::new(Method {
361 name: self.inner.name.clone(),
362 proto_name: self.inner.proto_name.clone(),
363 comments: prost_build::Comments {
364 leading_detached: self.inner.comments.leading_detached.clone(),
365 leading: self.inner.comments.leading.clone(),
366 trailing: self.inner.comments.trailing.clone(),
367 },
368 input_type: self.inner.input_type.clone(),
369 output_type: self.inner.output_type.clone(),
370 input_proto_type: self.inner.input_proto_type.clone(),
371 output_proto_type: self.inner.output_proto_type.clone(),
372 options: self.inner.options.clone(),
373 client_streaming: self.inner.client_streaming,
374 server_streaming: self.inner.server_streaming,
375 })
376 }
377}
378
/// Rust type names that are emitted verbatim instead of being qualified with
/// the configured `proto_path`. The unit type `()` is not a path, so something
/// like `super::()` would be invalid.
const NON_PATH_TYPE_ALLOWLIST: &[&str] = &["()"];

/// Returns `true` when `proto_type` (a fully-qualified protobuf type name with a
/// leading dot, e.g. `.google.protobuf.Empty`) lives in the `google.protobuf`
/// package or one of its sub-packages.
///
/// The trailing `.` in the prefix ensures that packages which merely share the
/// same leading characters (e.g. `.google.protobufx`) are not mistaken for
/// well-known types.
fn is_google_type(proto_type: &str) -> bool {
    proto_type.starts_with(".google.protobuf.")
}