ttrpc_compiler/
codegen.rs

1// Copyright (c) 2019 Ant Financial
2//
3// Copyright 2017 PingCAP, Inc.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9//     http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16// Copyright (c) 2016, Stepan Koltsov
17//
18// Permission is hereby granted, free of charge, to any person obtaining
19// a copy of this software and associated documentation files (the
20// "Software"), to deal in the Software without restriction, including
21// without limitation the rights to use, copy, modify, merge, publish,
22// distribute, sublicense, and/or sell copies of the Software, and to
23// permit persons to whom the Software is furnished to do so, subject to
24// the following conditions:
25//
26// The above copyright notice and this permission notice shall be
27// included in all copies or substantial portions of the Software.
28//
29// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
30// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
31// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
32// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
33// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
34// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
35// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
36
37#![allow(dead_code)]
38
39use std::{
40    collections::{HashMap, HashSet},
41    fs,
42    io::BufRead,
43};
44
45use crate::{
46    util::proto_path_to_rust_mod, util::scope::RootScope, util::writer::CodeWriter, Customize,
47};
48use protobuf::{
49    descriptor::*,
50    plugin::{
51        code_generator_response::Feature as CodeGeneratorResponse_Feature,
52        code_generator_response::File as CodeGeneratorResponse_File, CodeGeneratorRequest,
53        CodeGeneratorResponse,
54    },
55    Message,
56};
57use std::fs::File;
58use std::io::{self, stdin, stdout, Write};
59use std::path::Path;
60
61use super::util::{
62    self, async_on, def_async_fn, fq_grpc, pub_async_fn, to_camel_case, to_snake_case, MethodType,
63};
64
65struct MethodGen<'a> {
66    proto: &'a MethodDescriptorProto,
67    package_name: String,
68    service_name: String,
69    root_scope: &'a RootScope<'a>,
70    customize: &'a Customize,
71}
72
73impl<'a> MethodGen<'a> {
74    fn new(
75        proto: &'a MethodDescriptorProto,
76        package_name: String,
77        service_name: String,
78        root_scope: &'a RootScope<'a>,
79        customize: &'a Customize,
80    ) -> MethodGen<'a> {
81        MethodGen {
82            proto,
83            package_name,
84            service_name,
85            root_scope,
86            customize,
87        }
88    }
89
90    fn input(&self) -> String {
91        format!(
92            "super::{}",
93            self.root_scope
94                .find_message(self.proto.input_type())
95                .rust_fq_name()
96        )
97    }
98
99    fn output(&self) -> String {
100        format!(
101            "super::{}",
102            self.root_scope
103                .find_message(self.proto.output_type())
104                .rust_fq_name()
105        )
106    }
107
108    fn method_type(&self) -> (MethodType, String) {
109        match (self.proto.client_streaming(), self.proto.server_streaming()) {
110            (false, false) => (MethodType::Unary, fq_grpc("MethodType::Unary")),
111            (true, false) => (
112                MethodType::ClientStreaming,
113                fq_grpc("MethodType::ClientStreaming"),
114            ),
115            (false, true) => (
116                MethodType::ServerStreaming,
117                fq_grpc("MethodType::ServerStreaming"),
118            ),
119            (true, true) => (MethodType::Duplex, fq_grpc("MethodType::Duplex")),
120        }
121    }
122
123    fn service_name(&self) -> String {
124        to_snake_case(&self.service_name)
125    }
126
127    fn name(&self) -> String {
128        to_snake_case(self.proto.name())
129    }
130
131    fn struct_name(&self) -> String {
132        to_camel_case(self.proto.name())
133    }
134
135    fn const_method_name(&self) -> String {
136        format!(
137            "METHOD_{}_{}",
138            self.service_name().to_uppercase(),
139            self.name().to_uppercase()
140        )
141    }
142
143    fn write_handler(&self, w: &mut CodeWriter) {
144        w.block(
145            format!("struct {}Method {{", self.struct_name()),
146            "}",
147            |w| {
148                w.write_line(format!(
149                    "service: Arc<dyn {} + Send + Sync>,",
150                    self.service_name
151                ));
152            },
153        );
154        w.write_line("");
155        if async_on(self.customize, "server") {
156            self.write_handler_impl_async(w)
157        } else {
158            self.write_handler_impl(w)
159        }
160    }
161
162    fn write_handler_impl(&self, w: &mut CodeWriter) {
163        w.block(format!("impl ::ttrpc::MethodHandler for {}Method {{", self.struct_name()), "}",
164        |w| {
165            w.block("fn handler(&self, ctx: ::ttrpc::TtrpcContext, req: ::ttrpc::Request) -> ::ttrpc::Result<()> {", "}",
166            |w| {
167                w.write_line(format!("::ttrpc::request_handler!(self, ctx, req, {}, {}, {});",
168                                        proto_path_to_rust_mod(self.root_scope.find_message(self.proto.input_type()).fd.name()),
169                                        self.root_scope.find_message(self.proto.input_type()).rust_name(),
170                                        self.name()));
171                w.write_line("Ok(())");
172            });
173        });
174    }
175
176    fn write_handler_impl_async(&self, w: &mut CodeWriter) {
177        w.write_line("#[async_trait]");
178        match self.method_type().0 {
179            MethodType::Unary => {
180                w.block(format!("impl ::ttrpc::r#async::MethodHandler for {}Method {{", self.struct_name()), "}",
181                |w| {
182                    w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, req: ::ttrpc::Request) -> ::ttrpc::Result<::ttrpc::Response> {", "}",
183                        |w| {
184                            w.write_line(format!("::ttrpc::async_request_handler!(self, ctx, req, {}, {}, {});",
185                                        proto_path_to_rust_mod(self.root_scope.find_message(self.proto.input_type()).fd.name()),
186                                        self.root_scope.find_message(self.proto.input_type()).rust_name(),
187                                        self.name()));
188                    });
189            });
190            }
191            // only receive
192            MethodType::ClientStreaming => {
193                w.block(format!("impl ::ttrpc::r#async::StreamHandler for {}Method {{", self.struct_name()), "}",
194                |w| {
195                    w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, inner: ::ttrpc::r#async::StreamInner) -> ::ttrpc::Result<Option<::ttrpc::Response>> {", "}",
196                        |w| {
197                            w.write_line(format!("::ttrpc::async_client_streamimg_handler!(self, ctx, inner, {});",
198                                        self.name()));
199                    });
200            });
201            }
202            // only send
203            MethodType::ServerStreaming => {
204                w.block(format!("impl ::ttrpc::r#async::StreamHandler for {}Method {{", self.struct_name()), "}",
205                |w| {
206                    w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, mut inner: ::ttrpc::r#async::StreamInner) -> ::ttrpc::Result<Option<::ttrpc::Response>> {", "}",
207                        |w| {
208                            w.write_line(format!("::ttrpc::async_server_streamimg_handler!(self, ctx, inner, {}, {}, {});",
209                                        proto_path_to_rust_mod(self.root_scope.find_message(self.proto.input_type()).fd.name()),
210                                        self.root_scope.find_message(self.proto.input_type()).rust_name(),
211                                        self.name()));
212                    });
213            });
214            }
215            // receive and send
216            MethodType::Duplex => {
217                w.block(format!("impl ::ttrpc::r#async::StreamHandler for {}Method {{", self.struct_name()), "}",
218                |w| {
219                    w.block("async fn handler(&self, ctx: ::ttrpc::r#async::TtrpcContext, inner: ::ttrpc::r#async::StreamInner) -> ::ttrpc::Result<Option<::ttrpc::Response>> {", "}",
220                        |w| {
221                            w.write_line(format!("::ttrpc::async_duplex_streamimg_handler!(self, ctx, inner, {});",
222                                        self.name()));
223                    });
224            });
225            }
226        }
227    }
228
229    // Method signatures
230    fn unary(&self, method_name: &str) -> String {
231        format!(
232            "{}(&self, ctx: ttrpc::context::Context, req: &{}) -> {}<{}>",
233            method_name,
234            self.input(),
235            fq_grpc("Result"),
236            self.output()
237        )
238    }
239
240    fn client_streaming(&self, method_name: &str) -> String {
241        format!(
242            "{}(&self, ctx: ttrpc::context::Context) -> {}<{}<{}, {}>>",
243            method_name,
244            fq_grpc("Result"),
245            fq_grpc("r#async::ClientStreamSender"),
246            self.input(),
247            self.output()
248        )
249    }
250
251    fn server_streaming(&self, method_name: &str) -> String {
252        format!(
253            "{}(&self, ctx: ttrpc::context::Context, req: &{}) -> {}<{}<{}>>",
254            method_name,
255            self.input(),
256            fq_grpc("Result"),
257            fq_grpc("r#async::ClientStreamReceiver"),
258            self.output()
259        )
260    }
261
262    fn duplex_streaming(&self, method_name: &str) -> String {
263        format!(
264            "{}(&self, ctx: ttrpc::context::Context) -> {}<{}<{}, {}>>",
265            method_name,
266            fq_grpc("Result"),
267            fq_grpc("r#async::ClientStream"),
268            self.input(),
269            self.output()
270        )
271    }
272
273    fn write_client(&self, w: &mut CodeWriter) {
274        let method_name = self.name();
275        if let MethodType::Unary = self.method_type().0 {
276            w.pub_fn(self.unary(&method_name), |w| {
277                w.write_line(format!("let mut cres = {}::new();", self.output()));
278                w.write_line(format!(
279                    "::ttrpc::client_request!(self, ctx, req, \"{}.{}\", \"{}\", cres);",
280                    self.package_name,
281                    self.service_name,
282                    &self.proto.name(),
283                ));
284                w.write_line("Ok(cres)");
285            });
286        }
287    }
288
289    fn write_async_client(&self, w: &mut CodeWriter) {
290        let method_name = self.name();
291        match self.method_type().0 {
292            // Unary RPC
293            MethodType::Unary => {
294                pub_async_fn(w, &self.unary(&method_name), |w| {
295                    w.write_line(format!("let mut cres = {}::new();", self.output()));
296                    w.write_line(format!(
297                        "::ttrpc::async_client_request!(self, ctx, req, \"{}.{}\", \"{}\", cres);",
298                        self.package_name,
299                        self.service_name,
300                        &self.proto.name(),
301                    ));
302                });
303            }
304            // Client Streaming RPC
305            MethodType::ClientStreaming => {
306                pub_async_fn(w, &self.client_streaming(&method_name), |w| {
307                    w.write_line(format!(
308                        "::ttrpc::async_client_stream_send!(self, ctx, \"{}.{}\", \"{}\");",
309                        self.package_name,
310                        self.service_name,
311                        &self.proto.name(),
312                    ));
313                });
314            }
315            // Server Streaming RPC
316            MethodType::ServerStreaming => {
317                pub_async_fn(w, &self.server_streaming(&method_name), |w| {
318                    w.write_line(format!(
319                        "::ttrpc::async_client_stream_receive!(self, ctx, req, \"{}.{}\", \"{}\");",
320                        self.package_name,
321                        self.service_name,
322                        &self.proto.name(),
323                    ));
324                });
325            }
326            // Bidirectional streaming RPC
327            MethodType::Duplex => {
328                pub_async_fn(w, &self.duplex_streaming(&method_name), |w| {
329                    w.write_line(format!(
330                        "::ttrpc::async_client_stream!(self, ctx, \"{}.{}\", \"{}\");",
331                        self.package_name,
332                        self.service_name,
333                        &self.proto.name(),
334                    ));
335                });
336            }
337        };
338    }
339
340    fn write_service(&self, w: &mut CodeWriter) {
341        let (_req, req_type, resp_type) = match self.method_type().0 {
342            MethodType::Unary => ("req", self.input(), self.output()),
343            MethodType::ClientStreaming => (
344                "stream",
345                format!("::ttrpc::r#async::ServerStreamReceiver<{}>", self.input()),
346                self.output(),
347            ),
348            MethodType::ServerStreaming => (
349                "req",
350                format!(
351                    "{}, _: {}<{}>",
352                    self.input(),
353                    "::ttrpc::r#async::ServerStreamSender",
354                    self.output()
355                ),
356                "()".to_string(),
357            ),
358            MethodType::Duplex => (
359                "stream",
360                format!(
361                    "{}<{}, {}>",
362                    "::ttrpc::r#async::ServerStream",
363                    self.output(),
364                    self.input(),
365                ),
366                "()".to_string(),
367            ),
368        };
369
370        let get_sig = |context_name| {
371            format!(
372                "{}(&self, _ctx: &{}, _: {}) -> ::ttrpc::Result<{}>",
373                self.name(),
374                fq_grpc(context_name),
375                req_type,
376                resp_type,
377            )
378        };
379
380        let cb = |w: &mut CodeWriter| {
381            w.write_line(format!("Err(::ttrpc::Error::RpcStatus(::ttrpc::get_status(::ttrpc::Code::NOT_FOUND, \"/{}.{}/{} is not supported\".to_string())))",
382            self.package_name,
383            self.service_name, self.proto.name(),));
384        };
385
386        if async_on(self.customize, "server") {
387            let sig = get_sig("r#async::TtrpcContext");
388            def_async_fn(w, &sig, cb);
389        } else {
390            let sig = get_sig("TtrpcContext");
391            w.def_fn(&sig, cb);
392        }
393    }
394
395    fn write_bind(&self, w: &mut CodeWriter) {
396        let method_handler_name = "::ttrpc::MethodHandler";
397
398        let s = format!(
399            "methods.insert(\"/{}.{}/{}\".to_string(),
400                    Box::new({}Method{{service: service.clone()}}) as Box<dyn {} + Send + Sync>);",
401            self.package_name,
402            self.service_name,
403            self.proto.name(),
404            self.struct_name(),
405            method_handler_name,
406        );
407        w.write_line(&s);
408    }
409
410    fn write_async_bind(&self, w: &mut CodeWriter) {
411        let s = if matches!(self.method_type().0, MethodType::Unary) {
412            format!(
413                "methods.insert(\"{}\".to_string(),
414                    Box::new({}Method{{service: service.clone()}}) as {});",
415                self.proto.name(),
416                self.struct_name(),
417                "Box<dyn ::ttrpc::r#async::MethodHandler + Send + Sync>"
418            )
419        } else {
420            format!(
421                "streams.insert(\"{}\".to_string(),
422                    Arc::new({}Method{{service: service.clone()}}) as {});",
423                self.proto.name(),
424                self.struct_name(),
425                "Arc<dyn ::ttrpc::r#async::StreamHandler + Send + Sync>"
426            )
427        };
428        w.write_line(&s);
429    }
430}
431
432struct ServiceGen<'a> {
433    proto: &'a ServiceDescriptorProto,
434    methods: Vec<MethodGen<'a>>,
435    customize: &'a Customize,
436    package_name: String,
437}
438
439impl<'a> ServiceGen<'a> {
440    fn new(
441        proto: &'a ServiceDescriptorProto,
442        file: &FileDescriptorProto,
443        root_scope: &'a RootScope<'a>,
444        customize: &'a Customize,
445    ) -> ServiceGen<'a> {
446        let methods = proto
447            .method
448            .iter()
449            .map(|m| {
450                MethodGen::new(
451                    m,
452                    file.package().to_string(),
453                    util::to_camel_case(proto.name()),
454                    root_scope,
455                    customize,
456                )
457            })
458            .collect();
459
460        ServiceGen {
461            proto,
462            methods,
463            customize,
464            package_name: file.package().to_string(),
465        }
466    }
467
468    fn service_name(&self) -> String {
469        util::to_camel_case(self.proto.name())
470    }
471
472    fn service_path(&self) -> String {
473        format!("{}.{}", self.package_name, self.service_name())
474    }
475
476    fn client_name(&self) -> String {
477        format!("{}Client", self.service_name())
478    }
479
480    fn has_stream_method(&self) -> bool {
481        self.methods
482            .iter()
483            .any(|method| !matches!(method.method_type().0, MethodType::Unary))
484    }
485
486    fn has_normal_method(&self) -> bool {
487        self.methods
488            .iter()
489            .any(|method| matches!(method.method_type().0, MethodType::Unary))
490    }
491
492    fn write_client(&self, w: &mut CodeWriter) {
493        if async_on(self.customize, "client") {
494            self.write_async_client(w)
495        } else {
496            self.write_sync_client(w)
497        }
498    }
499
500    fn write_sync_client(&self, w: &mut CodeWriter) {
501        w.write_line("#[derive(Clone)]");
502        w.pub_struct(self.client_name(), |w| {
503            w.field_decl("client", "::ttrpc::Client");
504        });
505
506        w.write_line("");
507
508        w.impl_self_block(self.client_name(), |w| {
509            w.pub_fn("new(client: ::ttrpc::Client) -> Self", |w| {
510                w.expr_block(self.client_name(), |w| {
511                    w.write_line("client,");
512                });
513            });
514
515            for method in &self.methods {
516                w.write_line("");
517                method.write_client(w);
518            }
519        });
520    }
521
522    fn write_async_client(&self, w: &mut CodeWriter) {
523        w.write_line("#[derive(Clone)]");
524        w.pub_struct(self.client_name(), |w| {
525            w.field_decl("client", "::ttrpc::r#async::Client");
526        });
527
528        w.write_line("");
529
530        w.impl_self_block(self.client_name(), |w| {
531            w.pub_fn("new(client: ::ttrpc::r#async::Client) -> Self", |w| {
532                w.expr_block(self.client_name(), |w| {
533                    w.write_line("client,");
534                });
535            });
536
537            for method in &self.methods {
538                w.write_line("");
539                method.write_async_client(w);
540            }
541        });
542    }
543
544    fn write_server(&self, w: &mut CodeWriter) {
545        let mut trait_name = self.service_name();
546        if async_on(self.customize, "server") {
547            w.write_line("#[async_trait]");
548            trait_name = format!("{}: Sync", &self.service_name());
549        }
550
551        w.pub_trait(&trait_name, |w| {
552            for method in &self.methods {
553                method.write_service(w);
554            }
555        });
556
557        w.write_line("");
558        if async_on(self.customize, "server") {
559            self.write_async_server_create(w);
560        } else {
561            self.write_sync_server_create(w);
562        }
563    }
564
565    fn write_sync_server_create(&self, w: &mut CodeWriter) {
566        let method_handler_name = "::ttrpc::MethodHandler";
567        let s = format!(
568            "create_{}(service: Arc<dyn {} + Send + Sync>) -> HashMap<String, Box<dyn {} + Send + Sync>>",
569            to_snake_case(&self.service_name()),
570            self.service_name(),
571            method_handler_name,
572        );
573
574        let has_normal_method = self.has_normal_method();
575        w.pub_fn(&s, |w| {
576            if has_normal_method {
577                w.write_line("let mut methods = HashMap::new();");
578            } else {
579                w.write_line("let methods = HashMap::new();");
580            }
581            for method in &self.methods[0..self.methods.len()] {
582                w.write_line("");
583                method.write_bind(w);
584            }
585            w.write_line("");
586            w.write_line("methods");
587        });
588    }
589
590    fn write_async_server_create(&self, w: &mut CodeWriter) {
591        let s = format!(
592            "create_{}(service: Arc<dyn {} + Send + Sync>) -> HashMap<String, {}>",
593            to_snake_case(&self.service_name()),
594            self.service_name(),
595            "::ttrpc::r#async::Service"
596        );
597
598        let has_stream_method = self.has_stream_method();
599        let has_normal_method = self.has_normal_method();
600        w.pub_fn(&s, |w| {
601            w.write_line("let mut ret = HashMap::new();");
602            if has_normal_method {
603                w.write_line("let mut methods = HashMap::new();");
604            } else {
605                w.write_line("let methods = HashMap::new();");
606            }
607            if has_stream_method {
608                w.write_line("let mut streams = HashMap::new();");
609            } else {
610                w.write_line("let streams = HashMap::new();");
611            }
612            for method in &self.methods[0..self.methods.len()] {
613                w.write_line("");
614                method.write_async_bind(w);
615            }
616            w.write_line("");
617            w.write_line(format!(
618                "ret.insert(\"{}\".to_string(), {});",
619                self.service_path(),
620                "::ttrpc::r#async::Service{ methods, streams }"
621            ));
622            w.write_line("ret");
623        });
624    }
625
626    fn write_method_handlers(&self, w: &mut CodeWriter) {
627        for (i, method) in self.methods.iter().enumerate() {
628            if i != 0 {
629                w.write_line("");
630            }
631
632            method.write_handler(w);
633        }
634    }
635
636    fn write(&self, w: &mut CodeWriter) {
637        self.write_client(w);
638        w.write_line("");
639        self.write_method_handlers(w);
640        w.write_line("");
641        self.write_server(w);
642    }
643}
644
645pub fn write_generated_by(w: &mut CodeWriter, pkg: &str, version: &str) {
646    w.write_line(format!(
647        "// This file is generated by {pkg} {version}. Do not edit",
648        pkg = pkg,
649        version = version
650    ));
651    write_generated_common(w);
652}
653
654fn write_generated_common(w: &mut CodeWriter) {
655    // https://secure.phabricator.com/T784
656    w.write_line("// @generated");
657
658    w.write_line("");
659    w.write_line("#![cfg_attr(rustfmt, rustfmt_skip)]");
660    w.write_line("#![allow(unknown_lints)]");
661    w.write_line("#![allow(clipto_camel_casepy)]");
662    w.write_line("#![allow(dead_code)]");
663    w.write_line("#![allow(missing_docs)]");
664    w.write_line("#![allow(non_camel_case_types)]");
665    w.write_line("#![allow(non_snake_case)]");
666    w.write_line("#![allow(non_upper_case_globals)]");
667    w.write_line("#![allow(trivial_casts)]");
668    w.write_line("#![allow(unsafe_code)]");
669    w.write_line("#![allow(unused_imports)]");
670    w.write_line("#![allow(unused_results)]");
671    w.write_line("#![allow(clippy::all)]");
672}
673
674fn gen_file(
675    file: &FileDescriptorProto,
676    root_scope: &RootScope<'_>,
677    customize: &Customize,
678) -> Option<CodeGeneratorResponse_File> {
679    if file.service.is_empty() {
680        return None;
681    }
682
683    let base = proto_path_to_rust_mod(file.name());
684
685    let mut w = CodeWriter::new();
686
687    write_generated_by(&mut w, "ttrpc-compiler", env!("CARGO_PKG_VERSION"));
688
689    w.write_line("use protobuf::{CodedInputStream, CodedOutputStream, Message};");
690    w.write_line("use std::collections::HashMap;");
691    w.write_line("use std::sync::Arc;");
692    if customize.async_all || customize.async_client || customize.async_server {
693        w.write_line("use async_trait::async_trait;");
694    }
695
696    for service in file.service.iter() {
697        w.write_line("");
698        ServiceGen::new(service, file, root_scope, customize).write(&mut w);
699    }
700
701    let mut file = CodeGeneratorResponse_File::new();
702    file.set_name(format!("{base}_ttrpc.rs"));
703    file.set_content(w.take_code());
704    Some(file)
705}
706
707pub fn gen(
708    file_descriptors: &[FileDescriptorProto],
709    files_to_generate: &[String],
710    customize: &Customize,
711) -> CodeGeneratorResponse {
712    let files_map: HashMap<&str, &FileDescriptorProto> =
713        file_descriptors.iter().map(|f| (f.name(), f)).collect();
714
715    let root_scope = RootScope { file_descriptors };
716
717    let mut results = CodeGeneratorResponse::new();
718    results.set_supported_features(CodeGeneratorResponse_Feature::FEATURE_PROTO3_OPTIONAL as _);
719
720    for file_name in files_to_generate {
721        let file = files_map[&file_name[..]];
722
723        if file.service.is_empty() {
724            continue;
725        }
726
727        results
728            .file
729            .extend(gen_file(file, &root_scope, customize).into_iter());
730    }
731
732    results
733}
734
735pub fn gen_and_write(
736    file_descriptors: &[FileDescriptorProto],
737    files_to_generate: &[String],
738    out_dir: &Path,
739    customize: &Customize,
740) -> io::Result<()> {
741    let results = gen(file_descriptors, files_to_generate, customize);
742
743    if customize.gen_mod {
744        let file_path = out_dir.join("mod.rs");
745        let mut set = HashSet::new();
746        //if mod file exists
747        if let Ok(file) = File::open(&file_path) {
748            let reader = io::BufReader::new(file);
749            reader.lines().for_each(|line| {
750                let _ = line.map(|r| set.insert(r));
751            });
752        }
753        let mut file_write = fs::OpenOptions::new()
754            .create(true)
755            .write(true)
756            .truncate(true)
757            .open(&file_path)?;
758        for r in &results.file {
759            let prefix_name: Vec<&str> = r.name().split('.').collect();
760            set.insert(format!("pub mod {};", prefix_name[0]));
761        }
762        for item in &set {
763            writeln!(file_write, "{}", item)?;
764        }
765        file_write.flush()?;
766    }
767
768    for r in &results.file {
769        let mut file_path = out_dir.to_owned();
770        file_path.push(r.name());
771        let mut file_writer = File::create(&file_path)?;
772        file_writer.write_all(r.content().as_bytes())?;
773        file_writer.flush()?;
774    }
775
776    Ok(())
777}
778
779pub fn protoc_gen_grpc_rust_main() {
780    plugin_main(|file_descriptors, files_to_generate| {
781        gen(
782            file_descriptors,
783            files_to_generate,
784            &Customize {
785                ..Default::default()
786            },
787        )
788    });
789}
790
791fn plugin_main<F>(gen: F)
792where
793    F: Fn(&[FileDescriptorProto], &[String]) -> CodeGeneratorResponse,
794{
795    plugin_main_2(|r| gen(&r.proto_file, &r.file_to_generate))
796}
797
798fn plugin_main_2<F>(gen: F)
799where
800    F: Fn(&CodeGeneratorRequest) -> CodeGeneratorResponse,
801{
802    let req = CodeGeneratorRequest::parse_from_reader(&mut stdin()).unwrap();
803    let result = gen(&req);
804    result.write_to_writer(&mut stdout()).unwrap();
805}