grpc_compiler/
codegen.rs

1use std::collections::HashMap;
2
3use protobuf;
4use protobuf::compiler_plugin;
5use protobuf::descriptor::*;
6use protobuf::descriptorx::*;
7
8use std::io::Write;
9
/// Line-oriented sink for generated Rust source.
///
/// Every non-empty line is prefixed with `indent`. Nested scopes are emitted
/// through a short-lived child writer that shares the same underlying
/// `Write` but carries a longer prefix.
struct CodeWriter<'a> {
    /// Destination of all emitted bytes.
    writer: &'a mut (dyn Write + 'a),
    /// Prefix written in front of every non-empty line.
    indent: String,
}

impl<'a> CodeWriter<'a> {
    /// Creates a writer with an empty indentation prefix.
    pub fn new(writer: &'a mut dyn Write) -> CodeWriter<'a> {
        let indent = String::new();
        CodeWriter { writer, indent }
    }

    /// Emits `line` plus a trailing `\n`, prefixed by the current indentation.
    ///
    /// An empty `line` is emitted as a bare `\n` (no indentation), so the
    /// output never carries trailing whitespace.
    ///
    /// # Panics
    ///
    /// Panics if the underlying writer reports an I/O error.
    pub fn write_line<S: AsRef<str>>(&mut self, line: S) {
        let text = line.as_ref();
        let outcome = if text.is_empty() {
            self.writer.write_all(b"\n")
        } else {
            let full = format!("{}{}\n", self.indent, text);
            self.writer.write_all(full.as_bytes())
        };
        outcome.unwrap();
    }

    /// Emits the "generated file" banner followed by the common header.
    pub fn write_generated(&mut self) {
        self.write_line("// This file is generated. Do not edit");
        self.write_generated_common();
    }

    /// Emits the `@generated` marker and the crate-level inner attributes
    /// that silence lints in generated code.
    fn write_generated_common(&mut self) {
        // Lines emitted after the clippy issue link; "" produces a blank line.
        const ATTRIBUTE_LINES: &[&str] = &[
            "#![allow(unknown_lints)]",
            "#![allow(clippy::all)]",
            "",
            "#![cfg_attr(rustfmt, rustfmt_skip)]",
            "",
            "#![allow(box_pointers)]",
            "#![allow(dead_code)]",
            "#![allow(missing_docs)]",
            "#![allow(non_camel_case_types)]",
            "#![allow(non_snake_case)]",
            "#![allow(non_upper_case_globals)]",
            "#![allow(trivial_casts)]",
            "#![allow(unsafe_code)]",
            "#![allow(unused_imports)]",
            "#![allow(unused_results)]",
        ];

        // https://secure.phabricator.com/T784
        self.write_line("// @generated");
        self.write_line("");
        self.comment("https://github.com/Manishearth/rust-clippy/issues/702");
        for &attribute in ATTRIBUTE_LINES {
            self.write_line(attribute);
        }
    }

    /// Runs `cb` against a child writer that shares this writer's sink but
    /// uses `indent` as its line prefix.
    fn with_prefix<F>(&mut self, indent: String, cb: F)
    where
        F: Fn(&mut CodeWriter),
    {
        cb(&mut CodeWriter {
            writer: self.writer,
            indent,
        });
    }

    /// Runs `cb` with output indented four spaces deeper than the current level.
    pub fn indented<F>(&mut self, cb: F)
    where
        F: Fn(&mut CodeWriter),
    {
        let deeper = format!("{}    ", self.indent);
        self.with_prefix(deeper, cb);
    }

    /// Runs `cb` with every non-empty line prefixed by `// ` (before the
    /// current indentation), i.e. emitted as a line comment.
    #[allow(dead_code)]
    pub fn commented<F>(&mut self, cb: F)
    where
        F: Fn(&mut CodeWriter),
    {
        let prefix = format!("// {}", self.indent);
        self.with_prefix(prefix, cb);
    }

    /// Writes `first_line`, then the indented output of `cb`, then `last_line`.
    pub fn block<F>(&mut self, first_line: &str, last_line: &str, cb: F)
    where
        F: Fn(&mut CodeWriter),
    {
        self.write_line(first_line);
        self.indented(cb);
        self.write_line(last_line);
    }

    /// Writes `<prefix> {`, the indented body produced by `cb`, and `}`.
    pub fn expr_block<F>(&mut self, prefix: &str, cb: F)
    where
        F: Fn(&mut CodeWriter),
    {
        let opener = format!("{} {{", prefix);
        self.block(&opener, "}", cb);
    }

    /// Writes an inherent `impl <name> { ... }` block.
    pub fn impl_self_block<S: AsRef<str>, F>(&mut self, name: S, cb: F)
    where
        F: Fn(&mut CodeWriter),
    {
        let header = format!("impl {}", name.as_ref());
        self.expr_block(&header, cb);
    }

    /// Writes a trait implementation block `impl <tr> for <ty> { ... }`.
    pub fn impl_for_block<S1: AsRef<str>, S2: AsRef<str>, F>(&mut self, tr: S1, ty: S2, cb: F)
    where
        F: Fn(&mut CodeWriter),
    {
        let header = format!("impl {} for {}", tr.as_ref(), ty.as_ref());
        self.expr_block(&header, cb);
    }

    /// Writes a `pub struct <name> { ... }` block.
    pub fn pub_struct<S: AsRef<str>, F>(&mut self, name: S, cb: F)
    where
        F: Fn(&mut CodeWriter),
    {
        let header = format!("pub struct {}", name.as_ref());
        self.expr_block(&header, cb);
    }

    /// Writes a `pub trait <name> { ... }` block.
    pub fn pub_trait<F>(&mut self, name: &str, cb: F)
    where
        F: Fn(&mut CodeWriter),
    {
        let header = format!("pub trait {}", name);
        self.expr_block(&header, cb);
    }

    /// Writes a struct-literal field initializer line: `<name>: <value>,`.
    pub fn field_entry(&mut self, name: &str, value: &str) {
        self.write_line(format!("{}: {},", name, value));
    }

    /// Writes a struct field declaration line: `<name>: <field_type>,`.
    pub fn field_decl(&mut self, name: &str, field_type: &str) {
        self.write_line(format!("{}: {},", name, field_type));
    }

    /// Writes a `//` line comment; an empty `comment` yields a bare `//`.
    pub fn comment(&mut self, comment: &str) {
        match comment {
            "" => self.write_line("//"),
            text => self.write_line(format!("// {}", text)),
        }
    }

    /// Writes a body-less function declaration `fn <sig>;` (e.g. in a trait).
    pub fn fn_def(&mut self, sig: &str) {
        self.write_line(format!("fn {};", sig));
    }

    /// Writes a function with a body, optionally marked `pub`.
    pub fn fn_block<F>(&mut self, public: bool, sig: &str, cb: F)
    where
        F: Fn(&mut CodeWriter),
    {
        let keyword = if public { "pub fn" } else { "fn" };
        self.expr_block(&format!("{} {}", keyword, sig), cb);
    }

    /// Writes a `pub fn <sig> { ... }` block.
    pub fn pub_fn<F>(&mut self, sig: &str, cb: F)
    where
        F: Fn(&mut CodeWriter),
    {
        self.fn_block(true, sig, cb);
    }

    /// Writes a private `fn <sig> { ... }` block.
    pub fn def_fn<F>(&mut self, sig: &str, cb: F)
    where
        F: Fn(&mut CodeWriter),
    {
        self.fn_block(false, sig, cb);
    }
}
171
/// Converts a protobuf method name (typically `UpperCamelCase`) to Rust `snake_case`.
///
/// Rules, applied left to right:
/// * characters that are not uppercase (lowercase letters, digits, `_`, ...) are copied as-is;
/// * an uppercase character starts a run of consecutive uppercase characters, and an `_` is
///   inserted before the run unless the output is still empty or the previous copied
///   character was `_`;
/// * every character of the run is lowercased; when a run of two or more is followed by a
///   character other than `_`, one more `_` is placed before the run's final character,
///   which is treated as the start of the next word (`CreateIDForReq` -> `create_id_for_req`).
fn snake_name(name: &str) -> String {
    let chars: Vec<char> = name.chars().collect();
    let mut snake = String::with_capacity(name.len());
    // Most recently copied non-uppercase char; '.' stands for "none yet" (anything but '_').
    let mut prev = '.';
    let mut pos = 0;

    while pos < chars.len() {
        let current = chars[pos];
        if !current.is_uppercase() {
            snake.push(current);
            prev = current;
            pos += 1;
            continue;
        }

        if !snake.is_empty() && prev != '_' {
            snake.push('_');
        }

        // `chars[pos..run_end]` is the maximal run of uppercase characters (length >= 1).
        let run_end = chars[pos..]
            .iter()
            .position(|ch| !ch.is_uppercase())
            .map_or(chars.len(), |offset| pos + offset);
        let (acronym, last_upper) = chars[pos..run_end].split_at(run_end - pos - 1);
        for upper in acronym {
            snake.extend(upper.to_lowercase());
        }

        match chars.get(run_end) {
            Some(&next) => {
                // A multi-letter run like "ID" in "IDFor" ends one char early: the final
                // uppercase letter begins the following word.
                if !acronym.is_empty() && next != '_' {
                    snake.push('_');
                }
                snake.extend(last_upper[0].to_lowercase());
                snake.push(next);
                prev = next;
                pos = run_end + 1;
            }
            None => {
                snake.extend(last_upper[0].to_lowercase());
                pos = run_end;
            }
        }
    }
    snake
}
211
212struct MethodGen<'a> {
213    proto: &'a MethodDescriptorProto,
214    service_path: String,
215    root_scope: &'a RootScope<'a>,
216}
217
218impl<'a> MethodGen<'a> {
219    fn new(
220        proto: &'a MethodDescriptorProto,
221        service_path: String,
222        root_scope: &'a RootScope<'a>,
223    ) -> MethodGen<'a> {
224        MethodGen {
225            proto: proto,
226            service_path: service_path,
227            root_scope: root_scope,
228        }
229    }
230
231    fn snake_name(&self) -> String {
232        snake_name(self.proto.get_name())
233    }
234
235    fn input_message(&self) -> String {
236        format!(
237            "super::{}",
238            self.root_scope
239                .find_message(self.proto.get_input_type())
240                .rust_fq_name()
241        )
242    }
243
244    fn output_message(&self) -> String {
245        format!(
246            "super::{}",
247            self.root_scope
248                .find_message(self.proto.get_output_type())
249                .rust_fq_name()
250        )
251    }
252
253    fn client_resp_type(&self) -> String {
254        match self.proto.get_server_streaming() {
255            false => format!("::grpc::SingleResponse<{}>", self.output_message()),
256            true => format!("::grpc::StreamingResponse<{}>", self.output_message()),
257        }
258    }
259
260    fn client_sig(&self) -> String {
261        let req_param = match self.proto.get_client_streaming() {
262            false => format!(", req: {}", self.input_message()),
263            true => format!(""),
264        };
265        let resp_type = self.client_resp_type();
266        let return_type = match self.proto.get_client_streaming() {
267            false => resp_type,
268            true => format!(
269                "impl ::std::future::Future<Output=::grpc::Result<(::grpc::ClientRequestSink<{}>, {})>>",
270                self.input_message(),
271                resp_type
272            ),
273        };
274        format!(
275            "{}(&self, o: ::grpc::RequestOptions{}) -> {}",
276            self.snake_name(),
277            req_param,
278            return_type,
279        )
280    }
281
282    fn server_req_type(&self) -> String {
283        match self.proto.get_client_streaming() {
284            false => format!("::grpc::ServerRequestSingle<{}>", self.input_message()),
285            true => format!("::grpc::ServerRequest<{}>", self.input_message()),
286        }
287    }
288
289    fn server_resp_type(&self) -> String {
290        match self.proto.get_server_streaming() {
291            false => format!("::grpc::ServerResponseUnarySink<{}>", self.output_message()),
292            true => format!("::grpc::ServerResponseSink<{}>", self.output_message()),
293        }
294    }
295
296    fn server_sig(&self) -> String {
297        format!(
298            "{}(&self, o: ::grpc::ServerHandlerContext, req: {}, resp: {}) -> ::grpc::Result<()>",
299            self.snake_name(),
300            self.server_req_type(),
301            self.server_resp_type(),
302        )
303    }
304
305    fn write_server_intf(&self, w: &mut CodeWriter) {
306        w.fn_def(&self.server_sig())
307    }
308
309    fn streaming_upper(&self) -> &'static str {
310        match (
311            self.proto.get_client_streaming(),
312            self.proto.get_server_streaming(),
313        ) {
314            (false, false) => "Unary",
315            (false, true) => "ServerStreaming",
316            (true, false) => "ClientStreaming",
317            (true, true) => "Bidi",
318        }
319    }
320
321    fn streaming_lower(&self) -> &'static str {
322        match (
323            self.proto.get_client_streaming(),
324            self.proto.get_server_streaming(),
325        ) {
326            (false, false) => "unary",
327            (false, true) => "server_streaming",
328            (true, false) => "client_streaming",
329            (true, true) => "bidi",
330        }
331    }
332
333    fn write_client(&self, w: &mut CodeWriter) {
334        w.pub_fn(&self.client_sig(), |w| {
335            self.write_descriptor(
336                w,
337                "let descriptor = ::grpc::rt::ArcOrStatic::Static(&",
338                ");",
339            );
340
341            let req = match self.proto.get_client_streaming() {
342                false => ", req",
343                true => "",
344            };
345            w.write_line(&format!(
346                "self.grpc_client.call_{}(o{}, descriptor)",
347                self.streaming_lower(),
348                req,
349            ))
350        });
351    }
352
353    fn write_descriptor(&self, w: &mut CodeWriter, before: &str, after: &str) {
354        w.block(
355            &format!("{}{}", before, "::grpc::rt::MethodDescriptor {"),
356            &format!("{}{}", "}", after),
357            |w| {
358                w.field_entry(
359                    "name",
360                    &format!(
361                        "::grpc::rt::StringOrStatic::Static(\"{}/{}\")",
362                        self.service_path,
363                        self.proto.get_name()
364                    ),
365                );
366                w.field_entry(
367                    "streaming",
368                    &format!("::grpc::rt::GrpcStreaming::{}", self.streaming_upper()),
369                );
370                w.field_entry(
371                    "req_marshaller",
372                    "::grpc::rt::ArcOrStatic::Static(&::grpc_protobuf::MarshallerProtobuf)",
373                );
374                w.field_entry(
375                    "resp_marshaller",
376                    "::grpc::rt::ArcOrStatic::Static(&::grpc_protobuf::MarshallerProtobuf)",
377                );
378            },
379        );
380    }
381}
382
383struct ServiceGen<'a> {
384    proto: &'a ServiceDescriptorProto,
385    _root_scope: &'a RootScope<'a>,
386    methods: Vec<MethodGen<'a>>,
387    service_path: String,
388    _package: String,
389}
390
391impl<'a> ServiceGen<'a> {
392    fn new(
393        proto: &'a ServiceDescriptorProto,
394        file: &FileDescriptorProto,
395        root_scope: &'a RootScope,
396    ) -> ServiceGen<'a> {
397        let service_path = if file.get_package().is_empty() {
398            format!("/{}", proto.get_name())
399        } else {
400            format!("/{}.{}", file.get_package(), proto.get_name())
401        };
402        let methods = proto
403            .get_method()
404            .into_iter()
405            .map(|m| MethodGen::new(m, service_path.clone(), root_scope))
406            .collect();
407
408        ServiceGen {
409            proto,
410            _root_scope: root_scope,
411            methods,
412            service_path,
413            _package: file.get_package().to_string(),
414        }
415    }
416
417    // trait name
418    fn server_intf_name(&self) -> &str {
419        self.proto.get_name()
420    }
421
422    // client struct name
423    fn client_name(&self) -> String {
424        format!("{}Client", self.proto.get_name())
425    }
426
427    // server struct name
428    fn server_name(&self) -> String {
429        format!("{}Server", self.proto.get_name())
430    }
431
432    fn write_server_intf(&self, w: &mut CodeWriter) {
433        w.pub_trait(&self.server_intf_name(), |w| {
434            for (i, method) in self.methods.iter().enumerate() {
435                if i != 0 {
436                    w.write_line("");
437                }
438
439                method.write_server_intf(w);
440            }
441        });
442    }
443
444    fn write_client_object(&self, grpc_client: &str, w: &mut CodeWriter) {
445        w.expr_block(&self.client_name(), |w| {
446            w.field_entry("grpc_client", grpc_client);
447        });
448    }
449
450    fn write_client(&self, w: &mut CodeWriter) {
451        w.pub_struct(&self.client_name(), |w| {
452            w.field_decl("grpc_client", "::std::sync::Arc<::grpc::Client>");
453        });
454
455        w.write_line("");
456
457        w.impl_for_block("::grpc::ClientStub", &self.client_name(), |w| {
458            let sig = "with_client(grpc_client: ::std::sync::Arc<::grpc::Client>) -> Self";
459            w.def_fn(sig, |w| {
460                self.write_client_object("grpc_client", w);
461            });
462        });
463
464        w.write_line("");
465
466        w.impl_self_block(&self.client_name(), |w| {
467            for (i, method) in self.methods.iter().enumerate() {
468                if i != 0 {
469                    w.write_line("");
470                }
471
472                method.write_client(w);
473            }
474        });
475    }
476
477    fn write_service_definition(
478        &self,
479        before: &str,
480        after: &str,
481        handler: &str,
482        w: &mut CodeWriter,
483    ) {
484        w.block(
485            &format!("{}::grpc::rt::ServerServiceDefinition::new(\"{}\",",
486                before, self.service_path),
487            &format!("){}", after),
488            |w| {
489                w.block("vec![", "],", |w| {
490                    for method in &self.methods {
491                        w.block("::grpc::rt::ServerMethod::new(", "),", |w| {
492                            method.write_descriptor(w, "::grpc::rt::ArcOrStatic::Static(&", "),");
493                            w.block("{", "},", |w| {
494                                w.write_line(&format!("let handler_copy = {}.clone();", handler));
495                                w.write_line(&format!("::grpc::rt::MethodHandler{}::new(move |ctx, req, resp| (*handler_copy).{}(ctx, req, resp))",
496                                    method.streaming_upper(),
497                                    method.snake_name()));
498                            });
499                        });
500                    }
501                });
502            });
503    }
504
505    fn write_server(&self, w: &mut CodeWriter) {
506        w.write_line(&format!("pub struct {};", self.server_name()));
507
508        w.write_line("");
509
510        w.write_line("");
511
512        w.impl_self_block(&self.server_name(), |w| {
513            w.pub_fn(&format!("new_service_def<H : {} + 'static + Sync + Send + 'static>(handler: H) -> ::grpc::rt::ServerServiceDefinition", self.server_intf_name()), |w| {
514                w.write_line("let handler_arc = ::std::sync::Arc::new(handler);");
515
516                self.write_service_definition("", "", "handler_arc", w);
517            });
518        });
519    }
520
521    fn write(&self, w: &mut CodeWriter) {
522        w.comment("server interface");
523        w.write_line("");
524        self.write_server_intf(w);
525        w.write_line("");
526        w.comment("client");
527        w.write_line("");
528        self.write_client(w);
529        w.write_line("");
530        w.comment("server");
531        w.write_line("");
532        self.write_server(w);
533    }
534}
535
536fn gen_file(
537    file: &FileDescriptorProto,
538    root_scope: &RootScope,
539) -> Option<compiler_plugin::GenResult> {
540    if file.get_service().is_empty() {
541        return None;
542    }
543
544    let base = protobuf::descriptorx::proto_path_to_rust_mod(file.get_name());
545
546    let mut v = Vec::new();
547    {
548        let mut w = CodeWriter::new(&mut v);
549        w.write_generated();
550        w.write_line("");
551
552        for service in file.get_service() {
553            w.write_line("");
554            ServiceGen::new(service, file, root_scope).write(&mut w);
555        }
556    }
557
558    Some(compiler_plugin::GenResult {
559        name: base + "_grpc.rs",
560        content: v,
561    })
562}
563
564pub fn gen(
565    file_descriptors: &[FileDescriptorProto],
566    files_to_generate: &[String],
567) -> Vec<compiler_plugin::GenResult> {
568    let files_map: HashMap<&str, &FileDescriptorProto> =
569        file_descriptors.iter().map(|f| (f.get_name(), f)).collect();
570
571    let root_scope = RootScope {
572        file_descriptors: file_descriptors,
573    };
574
575    let mut results = Vec::new();
576
577    for file_name in files_to_generate {
578        let file = files_map[&file_name[..]];
579
580        if file.get_service().is_empty() {
581            continue;
582        }
583
584        results.extend(gen_file(file, &root_scope).into_iter());
585    }
586
587    results
588}
589
590pub fn protoc_gen_grpc_rust_main() {
591    compiler_plugin::plugin_main(gen);
592}
593
#[cfg(test)]
mod test {
    use super::snake_name;

    /// Pins the proto-name -> Rust-method-name conversion for common casings.
    #[test]
    fn test_snake_name() {
        let cases: [(&str, &str); 8] = [
            ("AsyncRequest", "async_request"),
            ("asyncRequest", "async_request"),
            ("async_request", "async_request"),
            ("createID", "create_id"),
            ("CreateIDForReq", "create_id_for_req"),
            ("Create_ID_For_Req", "create_id_for_req"),
            ("ID", "id"),
            ("id", "id"),
        ];

        for (origin, expected) in cases.iter() {
            assert_eq!(snake_name(origin), *expected, "snake_name({:?})", origin);
        }
    }
}