1#![recursion_limit = "256"]
2
3use proc_macro2::{Delimiter, Group, Ident, Literal, Punct, Spacing, Span, TokenStream};
4use quote::TokenStreamExt;
5mod prost;
6
7pub use crate::prost::{compile_protos, configure, Builder};
8
9use std::io::{self, Write};
10use std::process::Command;
11
12pub mod client;
14pub mod server;
16
17mod protos_codec_extension;
18
/// Interface to a service definition: its names, package, methods and comments.
///
/// NOTE(review): implementors are not in this file (presumably the `prost`
/// module provides one) — confirm.
pub trait Service {
    /// Type of one comment line attached to the service; must be viewable as `&str`.
    type Comment: AsRef<str>;

    /// Type describing a single method of this service.
    type Method: Method;

    /// Name of the service.
    fn name(&self) -> &str;
    /// Package the service belongs to.
    fn package(&self) -> &str;
    /// Identifier of the service — presumably the name as spelled in the
    /// schema, as opposed to `name`; TODO confirm against implementors.
    fn identifier(&self) -> &str;
    /// Methods of the service.
    fn methods(&self) -> &[Self::Method];
    /// Comment lines attached to the service.
    fn comment(&self) -> &[Self::Comment];
}
43
/// Interface to a single method of a [`Service`].
pub trait Method {
    /// Type of one comment line attached to the method; must be viewable as `&str`.
    type Comment: AsRef<str>;

    /// Name of the method.
    fn name(&self) -> &str;
    /// Identifier of the method — presumably the name as spelled in the
    /// schema, as opposed to `name`; TODO confirm against implementors.
    fn identifier(&self) -> &str;
    /// `true` if the method is client-streaming.
    fn client_streaming(&self) -> bool;
    /// `true` if the method is server-streaming.
    fn server_streaming(&self) -> bool;
    /// Comment lines attached to the method.
    fn comment(&self) -> &[Self::Comment];
    /// Token streams naming the request and response types.
    ///
    /// `proto_path` and `compile_well_known_types` are forwarded from the caller.
    /// NOTE(review): what a `None` first element signifies is not visible here —
    /// confirm against implementors.
    fn request_response_name(
        &self,
        proto_path: &str,
        compile_well_known_types: bool,
    ) -> (Option<TokenStream>, TokenStream);
}
71
/// User-supplied attributes selected by name pattern.
///
/// NOTE(review): presumably emitted on generated modules/structs by the
/// callers of `for_mod` / `for_struct` — confirm.
#[derive(Debug, Default, Clone)]
pub struct Attributes {
    /// `(pattern, attribute)` pairs consulted by `for_mod`.
    module: Vec<(String, String)>,
    /// `(pattern, attribute)` pairs consulted by `for_struct`.
    structure: Vec<(String, String)>,
}
80
81impl Attributes {
82 fn for_mod(&self, name: &str) -> Vec<syn::Attribute> {
83 generate_attributes(name, &self.module)
84 }
85
86 fn for_struct(&self, name: &str) -> Vec<syn::Attribute> {
87 generate_attributes(name, &self.structure)
88 }
89
90 pub fn push_mod(&mut self, pattern: impl Into<String>, attr: impl Into<String>) {
100 self.module.push((pattern.into(), attr.into()));
101 }
102
103 pub fn push_struct(&mut self, pattern: impl Into<String>, attr: impl Into<String>) {
113 self.structure.push((pattern.into(), attr.into()));
114 }
115}
116
117fn generate_attributes<'a>(
119 name: &str,
120 attrs: impl IntoIterator<Item = &'a (String, String)>,
121) -> Vec<syn::Attribute> {
122 attrs
123 .into_iter()
124 .filter(|(matcher, _)| match_name(matcher, name))
125 .flat_map(|(_, attr)| {
126 syn::parse_str::<syn::DeriveInput>(&format!("{attr}\nstruct fake;"))
128 .unwrap()
129 .attrs
130 })
131 .collect::<Vec<_>>()
132}
133
/// Formats, in place, every `.rs` entry directly inside `out_dir` with `rustfmt`.
///
/// The formatter binary is taken from the `RUSTFMT` environment variable, falling
/// back to `rustfmt`. Each file is formatted with `--emit files --edition 2018`.
/// A formatter that cannot be spawned is reported on stderr. A non-zero exit status
/// causes the formatter's stdout/stderr to be forwarded. Neither stops the remaining
/// files from being processed.
///
/// Entries whose names are not valid UTF-8 are handled without panicking: only their
/// extension is inspected.
///
/// # Panics
///
/// Panics if `out_dir` or one of its entries cannot be read, or if forwarding the
/// formatter's output fails.
pub fn fmt(out_dir: &str) {
    // Loop-invariant: resolve the formatter once.
    let rustfmt = std::env::var("RUSTFMT").unwrap_or_else(|_| "rustfmt".to_owned());
    let dir = std::fs::read_dir(out_dir)
        .unwrap_or_else(|e| panic!("failed to read output directory `{out_dir}`: {e}"));

    for entry in dir {
        let path = entry
            .unwrap_or_else(|e| panic!("failed to read an entry of `{out_dir}`: {e}"))
            .path();

        // Compare the extension as an `OsStr`; converting the whole file name to
        // `String` would panic on non-UTF-8 names.
        let is_rust_source = path.extension().map_or(false, |ext| ext == "rs");
        if !is_rust_source {
            continue;
        }

        let result = Command::new(&rustfmt)
            .arg("--emit")
            .arg("files")
            .arg("--edition")
            .arg("2018")
            .arg(&path)
            .output();

        match result {
            Err(e) => {
                eprintln!("error running rustfmt: {e:?}");
            }
            Ok(output) => {
                if !output.status.success() {
                    // Surface rustfmt's own diagnostics so the failure cause is visible.
                    io::stdout().write_all(&output.stdout).unwrap();
                    io::stderr().write_all(&output.stderr).unwrap();
                }
            }
        }
    }
}
167
168fn generate_doc_comment<S: AsRef<str>>(comment: S) -> TokenStream {
170 let mut doc_stream = TokenStream::new();
171
172 doc_stream.append(Ident::new("doc", Span::call_site()));
173 doc_stream.append(Punct::new('=', Spacing::Alone));
174 doc_stream.append(Literal::string(comment.as_ref()));
175
176 let group = Group::new(Delimiter::Bracket, doc_stream);
177
178 let mut stream = TokenStream::new();
179 stream.append(Punct::new('#', Spacing::Alone));
180 stream.append(group);
181 stream
182}
183
184fn generate_doc_comments<T: AsRef<str>>(comments: &[T]) -> TokenStream {
186 let mut stream = TokenStream::new();
187
188 for comment in comments {
189 stream.extend(generate_doc_comment(comment));
190 }
191
192 stream
193}
194
/// Reports whether the dotted `pattern` selects the dotted `path`
/// (e.g. `.my.protos.Service`).
///
/// Rules, compared segment-wise and case-sensitively:
/// - an empty pattern matches nothing;
/// - `.` matches every path, and a pattern equal to `path` always matches;
/// - a pattern starting with `.` is anchored at the root and matches when its
///   segments are a prefix of `path`'s segments (`.my` matches `.my.protos`
///   but not `.myy`);
/// - any other pattern matches when its segments are a suffix of `path`'s
///   segments (`Service` matches `.my.protos.Service`).
pub(crate) fn match_name(pattern: &str, path: &str) -> bool {
    if pattern.is_empty() {
        return false;
    }
    if pattern == "." || pattern == path {
        return true;
    }

    let pattern_segments: Vec<&str> = pattern.split('.').collect();
    let path_segments: Vec<&str> = path.split('.').collect();

    // `starts_with('.')` instead of slicing `pattern[..1]`, which panics when the
    // first character is a multi-byte UTF-8 character. The slice `starts_with` /
    // `ends_with` methods already return false when the pattern has more segments.
    if pattern.starts_with('.') {
        path_segments.starts_with(&pattern_segments)
    } else {
        path_segments.ends_with(&pattern_segments)
    }
}
220
/// Converts a CamelCase identifier to snake_case without acronym detection.
///
/// Every character is ASCII-lowercased. An underscore is inserted after any character
/// that is followed by an uppercase character, so `ABCServiceX` becomes `a_b_c_service_x`.
fn naive_snake_case(name: &str) -> String {
    // Pair each character with the one after it (`None` for the last character).
    let lookahead = name.chars().skip(1).map(Some).chain(std::iter::once(None));

    name.chars()
        .zip(lookahead)
        .fold(String::with_capacity(name.len()), |mut snake, (current, next)| {
            snake.push(current.to_ascii_lowercase());
            if matches!(next, Some(c) if c.is_uppercase()) {
                snake.push('_');
            }
            snake
        })
}
236
237fn join_path(config: &Builder, package: &str, service: &str, method: &str) -> String {
238 let mut parts = vec![];
239 if config.emit_package {
240 parts.push(package);
241 }
242
243 if config.emit_service_name || config.keep_service_names.contains(&service.to_string()) {
244 parts.push(service);
245 }
246
247 if !method.is_empty() {
248 parts.push(method);
249 }
250
251 parts.join(".")
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_match_name() {
        // (pattern, path, expected)
        let cases: &[(&str, &str, bool)] = &[
            // `.` selects everything.
            (".", ".my.protos", true),
            (".", ".protos", true),
            // Leading-dot patterns are anchored at the root.
            (".my", ".my", true),
            (".my", ".my.protos", true),
            (".my.protos.Service", ".my.protos.Service", true),
            // Other patterns match trailing segments.
            ("Service", ".my.protos.Service", true),
            // Partial segments never match.
            (".m", ".my.protos", false),
            (".p", ".protos", false),
            (".my", ".myy", false),
            // Anchored patterns must start at the first segment.
            (".protos", ".my.protos", false),
            (".Service", ".my.protos.Service", false),
            // Comparison is case-sensitive.
            ("service", ".my.protos.Service", false),
        ];

        for &(pattern, path, expected) in cases {
            assert_eq!(match_name(pattern, path), expected);
        }
    }

    #[test]
    fn test_snake_case() {
        let cases = [
            ("Service", "service"),
            ("ThatHasALongName", "that_has_a_long_name"),
            ("greeter", "greeter"),
            ("ABCServiceX", "a_b_c_service_x"),
        ];

        for (input, expected) in cases.iter() {
            assert_eq!(naive_snake_case(input), *expected);
        }
    }
}
290}