dubbo_build/
prost.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18use proc_macro2::TokenStream;
19use prost_build::{Config, Method, ServiceGenerator};
20use quote::ToTokens;
21use std::path::{Path, PathBuf};
22
23use crate::{client, server, Attributes};
24
/// Banner inserted at the very start of each package's generated code by
/// `finalize_package`, followed by a blank line.
const PACKAGE_HEADER: &str = "// @generated by apache/dubbo-rust.\n\n";
26
27/// Simple `.proto` compiling. Use [`configure`] instead if you need more options.
28///
29/// The include directory will be the parent folder of the specified path.
30/// The package name will be the filename without the extension.
31pub fn compile_protos(proto: impl AsRef<Path>) -> std::io::Result<()> {
32    let proto_path: &Path = proto.as_ref();
33
34    // directory the main .proto file resides in
35    let proto_dir = proto_path
36        .parent()
37        .expect("proto file should reside in a directory");
38
39    self::configure().compile(&[proto_path], &[proto_dir])?;
40
41    Ok(())
42}
43
44pub fn configure() -> Builder {
45    Builder {
46        build_client: true,
47        build_server: true,
48        proto_path: "super".to_string(),
49        protoc_args: Vec::new(),
50        compile_well_known_types: false,
51        include_file: None,
52        output_dir: None,
53        server_attributes: Attributes::default(),
54        client_attributes: Attributes::default(),
55    }
56}
57
58pub struct Builder {
59    build_client: bool,
60    build_server: bool,
61    proto_path: String,
62    compile_well_known_types: bool,
63    protoc_args: Vec<String>,
64    include_file: Option<PathBuf>,
65    output_dir: Option<PathBuf>,
66    server_attributes: Attributes,
67    client_attributes: Attributes,
68}
69
70impl Builder {
71    pub fn output_dir(mut self, output_dir: PathBuf) -> Self {
72        self.output_dir = Some(output_dir);
73        self
74    }
75
76    pub fn compile(
77        self,
78        protos: &[impl AsRef<Path>],
79        includes: &[impl AsRef<Path>],
80    ) -> std::io::Result<()> {
81        self.compile_with_config(Config::new(), protos, includes)
82    }
83
84    pub fn compile_with_config(
85        self,
86        mut config: Config,
87        protos: &[impl AsRef<Path>],
88        includes: &[impl AsRef<Path>],
89    ) -> std::io::Result<()> {
90        let out_dir = if let Some(out_dir) = self.output_dir.as_ref() {
91            out_dir.clone()
92        } else {
93            PathBuf::from(std::env::var("OUT_DIR").unwrap())
94        };
95        config.out_dir(out_dir);
96        config.type_attribute(".", "#[derive(serde::Serialize, serde::Deserialize)]");
97        config.message_attribute(".", "#[serde(default)]");
98
99        if self.compile_well_known_types {
100            config.compile_well_known_types();
101        }
102
103        if let Some(path) = self.include_file.as_ref() {
104            config.include_file(path);
105        }
106
107        for arg in self.protoc_args.iter() {
108            config.protoc_arg(arg);
109        }
110
111        config.service_generator(Box::new(SvcGenerator::new(self)));
112        config.compile_protos(protos, includes)?;
113
114        Ok(())
115    }
116}
117
118pub struct SvcGenerator {
119    builder: Builder,
120    clients: TokenStream,
121    servers: TokenStream,
122}
123
124impl SvcGenerator {
125    fn new(builder: Builder) -> Self {
126        SvcGenerator {
127            builder,
128            clients: TokenStream::new(),
129            servers: TokenStream::new(),
130        }
131    }
132}
133
134impl ServiceGenerator for SvcGenerator {
135    fn generate(&mut self, service: prost_build::Service, _buf: &mut String) {
136        let svc = DubboService::new(service);
137        if self.builder.build_server {
138            let server = server::generate(
139                &svc,
140                true,
141                &self.builder.proto_path,
142                self.builder.compile_well_known_types,
143                &self.builder.server_attributes,
144            );
145            self.servers.extend(server);
146        }
147
148        if self.builder.build_client {
149            let client = client::generate(
150                &svc,
151                true,
152                &self.builder.proto_path,
153                self.builder.compile_well_known_types,
154                &self.builder.client_attributes,
155            );
156            self.clients.extend(client);
157        }
158    }
159
160    fn finalize(&mut self, buf: &mut String) {
161        if self.builder.build_client && !self.clients.is_empty() {
162            let clients = &self.clients;
163
164            let client_services = quote::quote! {
165                #clients
166            };
167
168            let ast: syn::File = syn::parse2(client_services).expect("invalid tokenstream");
169            let code = prettyplease::unparse(&ast);
170            buf.push_str(&code);
171
172            self.clients = TokenStream::default();
173        }
174
175        if self.builder.build_server && !self.servers.is_empty() {
176            let servers = &self.servers;
177
178            let server_services = quote::quote! {
179                #servers
180            };
181
182            let ast: syn::File = syn::parse2(server_services).expect("invalid tokenstream");
183            let code = prettyplease::unparse(&ast);
184            buf.push_str(&code);
185
186            self.servers = TokenStream::default();
187        }
188    }
189
190    fn finalize_package(&mut self, _package: &str, buf: &mut String) {
191        buf.insert_str(0, PACKAGE_HEADER);
192    }
193}
194
195pub struct DubboService {
196    inner: prost_build::Service,
197}
198
199impl DubboService {
200    fn new(inner: prost_build::Service) -> DubboService {
201        Self { inner }
202    }
203}
204
205impl super::Service for DubboService {
206    type Comment = String;
207
208    type Method = DubboMethod;
209
210    fn name(&self) -> &str {
211        &self.inner.name
212    }
213
214    fn package(&self) -> &str {
215        &self.inner.package
216    }
217
218    fn identifier(&self) -> &str {
219        &self.inner.proto_name
220    }
221
222    fn methods(&self) -> Vec<Self::Method> {
223        let mut ms = Vec::new();
224        for m in &self.inner.methods[..] {
225            ms.push(DubboMethod::new(Method {
226                name: m.name.clone(),
227                proto_name: m.proto_name.clone(),
228                comments: prost_build::Comments {
229                    leading_detached: m.comments.leading_detached.clone(),
230                    leading: m.comments.leading.clone(),
231                    trailing: m.comments.trailing.clone(),
232                },
233                input_type: m.input_type.clone(),
234                output_type: m.output_type.clone(),
235                input_proto_type: m.input_proto_type.clone(),
236                output_proto_type: m.output_proto_type.clone(),
237                options: m.options.clone(),
238                client_streaming: m.client_streaming,
239                server_streaming: m.server_streaming,
240            }))
241        }
242
243        ms
244    }
245
246    fn comment(&self) -> &[Self::Comment] {
247        &self.inner.comments.leading[..]
248    }
249}
250
251impl Clone for DubboService {
252    fn clone(&self) -> Self {
253        Self {
254            inner: prost_build::Service {
255                name: self.inner.name.clone(),
256                proto_name: self.inner.proto_name.clone(),
257                package: self.inner.package.clone(),
258                methods: {
259                    let mut ms = Vec::new();
260                    for m in &self.inner.methods[..] {
261                        ms.push(Method {
262                            name: m.name.clone(),
263                            proto_name: m.proto_name.clone(),
264                            comments: prost_build::Comments {
265                                leading_detached: m.comments.leading_detached.clone(),
266                                leading: m.comments.leading.clone(),
267                                trailing: m.comments.trailing.clone(),
268                            },
269                            input_type: m.input_type.clone(),
270                            output_type: m.output_type.clone(),
271                            input_proto_type: m.input_proto_type.clone(),
272                            output_proto_type: m.output_proto_type.clone(),
273                            options: m.options.clone(),
274                            client_streaming: m.client_streaming,
275                            server_streaming: m.server_streaming,
276                        })
277                    }
278
279                    ms
280                },
281                comments: prost_build::Comments {
282                    leading_detached: self.inner.comments.leading_detached.clone(),
283                    leading: self.inner.comments.leading.clone(),
284                    trailing: self.inner.comments.trailing.clone(),
285                },
286                options: self.inner.options.clone(),
287            },
288        }
289    }
290}
291
292pub struct DubboMethod {
293    inner: Method,
294}
295
296impl DubboMethod {
297    fn new(m: Method) -> DubboMethod {
298        Self { inner: m }
299    }
300}
301
302impl super::Method for DubboMethod {
303    type Comment = String;
304
305    fn name(&self) -> &str {
306        &self.inner.name
307    }
308
309    fn identifier(&self) -> &str {
310        &self.inner.proto_name
311    }
312
313    fn codec_path(&self) -> &str {
314        "triple::codec::serde_codec::SerdeCodec"
315    }
316
317    fn client_streaming(&self) -> bool {
318        self.inner.client_streaming
319    }
320
321    fn server_streaming(&self) -> bool {
322        self.inner.server_streaming
323    }
324
325    fn comment(&self) -> &[Self::Comment] {
326        &self.inner.comments.leading[..]
327    }
328
329    fn request_response_name(
330        &self,
331        proto_path: &str,
332        compile_well_known_types: bool,
333    ) -> (TokenStream, TokenStream) {
334        let convert_type = |proto_type: &str, rust_type: &str| -> TokenStream {
335            if (is_google_type(proto_type) && !compile_well_known_types)
336                || rust_type.starts_with("::")
337                || NON_PATH_TYPE_ALLOWLIST.iter().any(|t| *t == rust_type)
338            {
339                rust_type.parse::<TokenStream>().unwrap()
340            } else if rust_type.starts_with("crate::") {
341                syn::parse_str::<syn::Path>(rust_type)
342                    .unwrap()
343                    .to_token_stream()
344            } else {
345                syn::parse_str::<syn::Path>(&format!("{}::{}", proto_path, rust_type))
346                    .unwrap()
347                    .to_token_stream()
348            }
349        };
350
351        let req = convert_type(&self.inner.input_proto_type, &self.inner.input_type);
352        let resp = convert_type(&self.inner.output_proto_type, &self.inner.output_type);
353
354        (req, resp)
355    }
356}
357
358impl Clone for DubboMethod {
359    fn clone(&self) -> Self {
360        DubboMethod::new(Method {
361            name: self.inner.name.clone(),
362            proto_name: self.inner.proto_name.clone(),
363            comments: prost_build::Comments {
364                leading_detached: self.inner.comments.leading_detached.clone(),
365                leading: self.inner.comments.leading.clone(),
366                trailing: self.inner.comments.trailing.clone(),
367            },
368            input_type: self.inner.input_type.clone(),
369            output_type: self.inner.output_type.clone(),
370            input_proto_type: self.inner.input_proto_type.clone(),
371            output_proto_type: self.inner.output_proto_type.clone(),
372            options: self.inner.options.clone(),
373            client_streaming: self.inner.client_streaming,
374            server_streaming: self.inner.server_streaming,
375        })
376    }
377}
378
/// Request/response Rust types that are not paths. They are emitted verbatim
/// and never prefixed with the proto path.
const NON_PATH_TYPE_ALLOWLIST: &[&str] = &[
    // The unit type.
    "()",
];
381
/// Returns `true` if `proto_type` is a fully-qualified type from the
/// `google.protobuf` package (the well-known types), e.g. `.google.protobuf.Empty`.
///
/// The trailing `.` in the prefix keeps packages that merely share the
/// prefix (e.g. `.google.protobufx.Foo`) from being misclassified. Such types
/// would otherwise be emitted unqualified.
fn is_google_type(proto_type: &str) -> bool {
    proto_type.starts_with(".google.protobuf.")
}