prpc_build/
lib.rs

1#![recursion_limit = "256"]
2
3use proc_macro2::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream};
4use quote::TokenStreamExt;
5mod prost;
6
7pub use crate::prost::{compile_protos, configure, Builder};
8
9use std::io::{self, Write};
10use std::process::Command;
11
12/// Service code generation for client
13pub mod client;
14/// Service code generation for Server
15pub mod server;
16
17mod protos_codec_extension;
18
19/// Service generation trait.
20///
21/// This trait can be implemented and consumed
22/// by `client::generate` and `server::generate`
23/// to allow any codegen module to generate service
24/// abstractions.
25pub trait Service {
26    /// Comment type.
27    type Comment: AsRef<str>;
28
29    /// Method type.
30    type Method: Method;
31
32    /// Name of service.
33    fn name(&self) -> &str;
34    /// Package name of service.
35    fn package(&self) -> &str;
36    /// Identifier used to generate type name.
37    fn identifier(&self) -> &str;
38    /// Methods provided by service.
39    fn methods(&self) -> &[Self::Method];
40    /// Get comments about this item.
41    fn comment(&self) -> &[Self::Comment];
42}
43
44/// Method generation trait.
45///
46/// Each service contains a set of generic
47/// `Methods`'s that will be used by codegen
48/// to generate abstraction implementations for
49/// the provided methods.
50pub trait Method {
51    /// Comment type.
52    type Comment: AsRef<str>;
53
54    /// Name of method.
55    fn name(&self) -> &str;
56    /// Identifier used to generate type name.
57    fn identifier(&self) -> &str;
58    /// Method is streamed by client.
59    fn client_streaming(&self) -> bool;
60    /// Method is streamed by server.
61    fn server_streaming(&self) -> bool;
62    /// Get comments about this item.
63    fn comment(&self) -> &[Self::Comment];
64    /// Type name of request and response.
65    fn request_response_name(
66        &self,
67        proto_path: &str,
68        compile_well_known_types: bool,
69    ) -> (Option<TokenStream>, TokenStream);
70}
71
72/// Attributes that will be added to `mod` and `struct` items.
73#[derive(Debug, Default, Clone)]
74pub struct Attributes {
75    /// `mod` attributes.
76    module: Vec<(String, String)>,
77    /// `struct` attributes.
78    structure: Vec<(String, String)>,
79}
80
81impl Attributes {
82    fn for_mod(&self, name: &str) -> Vec<syn::Attribute> {
83        generate_attributes(name, &self.module)
84    }
85
86    fn for_struct(&self, name: &str) -> Vec<syn::Attribute> {
87        generate_attributes(name, &self.structure)
88    }
89
90    /// Add an attribute that will be added to `mod` items matching the given pattern.
91    ///
92    /// # Examples
93    ///
94    /// ```
95    /// # use prpc_build::*;
96    /// let mut attributes = Attributes::default();
97    /// attributes.push_mod("my.proto.package", r#"#[cfg(feature = "server")]"#);
98    /// ```
99    pub fn push_mod(&mut self, pattern: impl Into<String>, attr: impl Into<String>) {
100        self.module.push((pattern.into(), attr.into()));
101    }
102
103    /// Add an attribute that will be added to `struct` items matching the given pattern.
104    ///
105    /// # Examples
106    ///
107    /// ```
108    /// # use prpc_build::*;
109    /// let mut attributes = Attributes::default();
110    /// attributes.push_struct("EchoService", "#[derive(PartialEq)]");
111    /// ```
112    pub fn push_struct(&mut self, pattern: impl Into<String>, attr: impl Into<String>) {
113        self.structure.push((pattern.into(), attr.into()));
114    }
115}
116
117// Generates attributes given a list of (`pattern`, `attribute`) pairs. If `pattern` matches `name`, `attribute` will be included.
118fn generate_attributes<'a>(
119    name: &str,
120    attrs: impl IntoIterator<Item = &'a (String, String)>,
121) -> Vec<syn::Attribute> {
122    attrs
123        .into_iter()
124        .filter(|(matcher, _)| match_name(matcher, name))
125        .flat_map(|(_, attr)| {
126            // attributes cannot be parsed directly, so we pretend they're on a struct
127            syn::parse_str::<syn::DeriveInput>(&format!("{attr}\nstruct fake;"))
128                .unwrap()
129                .attrs
130        })
131        .collect::<Vec<_>>()
132}
133
/// Formats every `.rs` file directly inside `out_dir` in place with rustfmt.
///
/// The formatter binary comes from the `RUSTFMT` environment variable, falling back to
/// `rustfmt`. Each file is rewritten in place (`--emit files`) using the 2018 edition.
///
/// Formatting is best-effort. If rustfmt cannot be spawned, or exits unsuccessfully, the
/// problem is reported on stderr and the remaining files are still processed. On failure,
/// rustfmt's own stdout/stderr are forwarded.
///
/// # Panics
///
/// Panics if `out_dir` cannot be read or one of its entries cannot be retrieved.
pub fn fmt(out_dir: &str) {
    // Resolve the formatter once rather than re-reading the environment for every file.
    let rustfmt = std::env::var("RUSTFMT").unwrap_or_else(|_| "rustfmt".to_owned());

    let entries = std::fs::read_dir(out_dir)
        .unwrap_or_else(|e| panic!("failed to read output directory `{out_dir}`: {e}"));

    for entry in entries {
        // `DirEntry::path` joins correctly and does not require a UTF-8 file name.
        let path = entry
            .unwrap_or_else(|e| panic!("failed to read an entry of `{out_dir}`: {e}"))
            .path();

        // Only Rust sources are formatted; everything else is left untouched.
        match path.extension() {
            Some(ext) if ext == "rs" => {}
            _ => continue,
        }

        let result = Command::new(&rustfmt)
            .arg("--emit")
            .arg("files")
            .arg("--edition")
            .arg("2018")
            .arg(&path)
            .output();

        // Formatting is cosmetic: report failures but never abort code generation.
        match result {
            Err(e) => eprintln!("error running rustfmt: {e:?}"),
            Ok(output) if !output.status.success() => {
                eprintln!("rustfmt failed on {}: {}", path.display(), output.status);
                // Echoing diagnostics is itself best-effort; a closed pipe must not panic.
                let _ = io::stdout().write_all(&output.stdout);
                let _ = io::stderr().write_all(&output.stderr);
            }
            Ok(_) => {}
        }
    }
}
167
168// Generate a singular line of a doc comment
169fn generate_doc_comment<S: AsRef<str>>(comment: S) -> TokenStream {
170    let mut doc_stream = TokenStream::new();
171
172    doc_stream.append(Ident::new("doc", Span::call_site()));
173    doc_stream.append(Punct::new('=', Spacing::Alone));
174    doc_stream.append(Literal::string(comment.as_ref()));
175
176    let group = Group::new(Delimiter::Bracket, doc_stream);
177
178    let mut stream = TokenStream::new();
179    stream.append(Punct::new('#', Spacing::Alone));
180    stream.append(group);
181    stream
182}
183
184// Generate a larger doc comment composed of many lines of doc comments
185fn generate_doc_comments<T: AsRef<str>>(comments: &[T]) -> TokenStream {
186    let mut stream = TokenStream::new();
187
188    for comment in comments {
189        stream.extend(generate_doc_comment(comment));
190    }
191
192    stream
193}
194
/// Reports whether the dot-separated `path` is selected by `pattern`.
///
/// Rules, checked in order:
/// - an empty pattern matches nothing;
/// - `"."` matches everything, and a pattern equal to `path` matches it;
/// - a pattern starting with `.` is a prefix match on whole segments
///   (`.my` matches `.my.protos` but not `.myy`);
/// - any other pattern is a suffix match on whole segments
///   (`Service` matches `.my.protos.Service`).
///
/// Comparisons are case-sensitive.
pub(crate) fn match_name(pattern: &str, path: &str) -> bool {
    if pattern.is_empty() {
        return false;
    }
    if pattern == "." || pattern == path {
        return true;
    }

    let pattern_segments: Vec<&str> = pattern.split('.').collect();
    let path_segments: Vec<&str> = path.split('.').collect();

    // `starts_with('.')` instead of byte-slicing `pattern[..1]`, which panics when the
    // first character is multi-byte UTF-8. The slice `starts_with`/`ends_with` calls
    // return false when the pattern has more segments than the path.
    if pattern.starts_with('.') {
        path_segments.starts_with(&pattern_segments)
    } else {
        path_segments.ends_with(&pattern_segments)
    }
}
220
/// Converts a CamelCase name to snake_case by inserting `_` before every uppercase letter.
///
/// Each character is ASCII-lowercased. An underscore is placed between a character and
/// its successor whenever the successor is uppercase, so acronyms split letter by letter
/// (`ABCServiceX` -> `a_b_c_service_x`).
fn naive_snake_case(name: &str) -> String {
    let chars: Vec<char> = name.chars().collect();
    let mut snake = String::new();

    for (idx, &ch) in chars.iter().enumerate() {
        snake.push(ch.to_ascii_lowercase());
        if matches!(chars.get(idx + 1), Some(next) if next.is_uppercase()) {
            snake.push('_');
        }
    }

    snake
}
236
237fn join_path(config: &Builder, package: &str, service: &str, method: &str) -> String {
238    let mut parts = vec![];
239    if config.emit_package {
240        parts.push(package);
241    }
242
243    if config.emit_service_name || config.keep_service_names.contains(&service.to_string()) {
244        parts.push(service);
245    }
246
247    if !method.is_empty() {
248        parts.push(method);
249    }
250
251    parts.join(".")
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_match_name() {
        // (pattern, path) pairs that are expected to match.
        let should_match = [
            (".", ".my.protos"),
            (".", ".protos"),
            (".my", ".my"),
            (".my", ".my.protos"),
            (".my.protos.Service", ".my.protos.Service"),
            ("Service", ".my.protos.Service"),
        ];
        // (pattern, path) pairs that are expected not to match.
        let should_not_match = [
            (".m", ".my.protos"),
            (".p", ".protos"),
            (".my", ".myy"),
            (".protos", ".my.protos"),
            (".Service", ".my.protos.Service"),
            ("service", ".my.protos.Service"),
        ];

        for &(pattern, path) in &should_match {
            assert!(match_name(pattern, path));
        }
        for &(pattern, path) in &should_not_match {
            assert!(!match_name(pattern, path));
        }
    }

    #[test]
    fn test_snake_case() {
        let cases = [
            ("Service", "service"),
            ("ThatHasALongName", "that_has_a_long_name"),
            ("greeter", "greeter"),
            ("ABCServiceX", "a_b_c_service_x"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(naive_snake_case(input), expected);
        }
    }
}