// power_protobuf_lib/dep.rs

1use std::{
2    collections::{HashMap, HashSet},
3    fs::{self, read_to_string},
4    ops::Range,
5    path::{Path, PathBuf},
6};
7
8use cargo_toml::Manifest;
9use proc_macro2::Span;
10use syn::parse::ParseStream;
11use syn_prelude::{ToErr, ToIdent, ToSynError};
12
13use crate::{model::Import, resolve::PathMod, FilePath};
14
/// A protobuf type declared outside the current `protobuf!` invocation that it may refer to.
#[derive(Debug, Clone)]
pub struct ExternalTypeRef {
    /// `true` for a message, `false` for an enum.
    pub is_message: bool,
    /// `true` for the builtin well-known types registered from `google/protobuf/*.proto`
    /// imports; `false` for types found by scanning `protobuf!` invocations.
    pub prost_type: bool,
    /// Index of the import that brought this type into scope. The scan of the current source
    /// file performed by `Deps::new` also uses `0`.
    pub import_index: usize,
}
21
/// External types visible to a `protobuf!` invocation, plus the crate layout needed to map
/// source files to module paths.
#[derive(Debug, Clone)]
pub struct Deps {
    /// Byte range of the invocation being expanded (taken from `Span::call_site()`); used to
    /// skip that invocation while scanning its own source file.
    pub current_source_range: Range<usize>,
    /// Known external types keyed by protobuf package (`""` when none; well-known types are
    /// also stored under `""` with their fully qualified names), then by type name, where
    /// nested types are joined to their parent with `.`.
    pub scopes: HashMap<String, HashMap<String, ExternalTypeRef>>,
    /// Directory containing the project `Cargo.toml`, as a string.
    pub project_root: String,
    /// Directory containing the project `Cargo.toml`.
    pub project_root_path: PathBuf,
    /// Paths (relative to `project_root_path`) of binary targets that declare a `path`.
    pub bin_paths: HashSet<String>,
    /// Paths (relative to `project_root_path`) of example targets.
    pub example_paths: HashSet<String>,
    /// Path (relative to `project_root_path`) of the library root; defaults to `src/lib.rs`.
    pub lib_path: String,
}
32
33impl Deps {
34    pub fn new(call_site_path: &PathBuf, input: ParseStream) -> syn::Result<Self> {
35        fn find_cargo_toml(path: impl AsRef<Path>) -> Option<PathBuf> {
36            if let Some(parent) = path.as_ref().parent() {
37                let p = parent.join("Cargo.toml");
38                if p.exists() && p.is_file() {
39                    Some(p)
40                } else {
41                    find_cargo_toml(parent)
42                }
43            } else {
44                None
45            }
46        }
47        let project_manifest_path = find_cargo_toml(call_site_path)
48            .ok_or(input.span().to_syn_error("cannot find project Cargo.toml"))?;
49        let project_manifest = Manifest::from_path(&project_manifest_path)
50            .map_err(|err| input.span().to_syn_error(err.to_string()))?;
51
52        let mut lib_path = "src/lib.rs".to_owned();
53        if let Some(lib) = project_manifest.lib {
54            if let Some(path) = lib.path {
55                lib_path = path;
56            }
57        }
58        let bin_paths = project_manifest
59            .bin
60            .iter()
61            .filter_map(|bin| bin.path.clone())
62            .collect::<HashSet<_>>();
63        let example_paths = project_manifest
64            .example
65            .iter()
66            .filter_map(|example| {
67                if let Some(path) = example.path.as_ref() {
68                    Some(path.clone())
69                } else if let Some(name) = example.name.as_ref() {
70                    Some(format!("examples/{name}.rs"))
71                } else {
72                    None
73                }
74            })
75            .collect::<HashSet<_>>();
76
77        let project_root_path = project_manifest_path.parent().unwrap().to_path_buf();
78        let mut slf = Self {
79            current_source_range: Span::call_site().byte_range(),
80            scopes: Default::default(),
81            project_root: project_root_path.to_string_lossy().to_string(),
82            project_root_path,
83            lib_path,
84            bin_paths,
85            example_paths,
86        };
87        let contents = fs::read_to_string(call_site_path)
88            .map_err(|_err| input.span().to_syn_error("cannot read source file"))?;
89        slf.scan_with_contents(true, &contents, 0)?;
90
91        Ok(slf)
92    }
93    pub fn scan(&mut self, import_index: usize, import: &Import) -> syn::Result<()> {
94        if import.builtin {
95            match import.path.value().as_str() {
96                "google/protobuf/any.proto" => {
97                    self.add_well_known_type("google.protobuf.Any", true, import_index);
98                }
99                "google/protobuf/api.proto" => {
100                    self.add_well_known_type("google.protobuf.Api", true, import_index);
101                    self.add_well_known_type("google.protobuf.Method", true, import_index);
102                    self.add_well_known_type("google.protobuf.Mixin", true, import_index);
103                }
104                "google/protobuf/duration.proto" => {
105                    self.add_well_known_type("google.protobuf.Duration", true, import_index);
106                }
107                "google/protobuf/empty.proto" => {
108                    self.add_well_known_type("google.protobuf.Empty", true, import_index);
109                }
110                "google/protobuf/field_mask.proto" => {
111                    self.add_well_known_type("google.protobuf.FieldMask", true, import_index);
112                }
113                "google/protobuf/source_context.proto" => {
114                    self.add_well_known_type("google.protobuf.SourceContext", true, import_index);
115                }
116                "google/protobuf/struct.proto" => {
117                    self.add_well_known_type("google.protobuf.Struct", true, import_index);
118                    self.add_well_known_type("google.protobuf.Value", true, import_index);
119                    self.add_well_known_type("google.protobuf.NullValue", false, import_index);
120                    self.add_well_known_type("google.protobuf.ListValue", true, import_index);
121                }
122                "google/protobuf/timestamp.proto" => {
123                    self.add_well_known_type("google.protobuf.Timestamp", true, import_index);
124                }
125                "google/protobuf/type.proto" => {
126                    self.add_well_known_type("google.protobuf.Type", true, import_index);
127                    self.add_well_known_type("google.protobuf.Field", true, import_index);
128                    self.add_well_known_type("google.protobuf.Field.Kind", false, import_index);
129                    self.add_well_known_type(
130                        "google.protobuf.Field.Cardinality",
131                        false,
132                        import_index,
133                    );
134                    self.add_well_known_type("google.protobuf.Enum", true, import_index);
135                    self.add_well_known_type("google.protobuf.EnumValue", true, import_index);
136                    self.add_well_known_type("google.protobuf.Option", true, import_index);
137                    self.add_well_known_type("google.protobuf.Syntax", false, import_index);
138                }
139                "google/protobuf/wrappers.proto" => {
140                    self.add_well_known_type("google.protobuf.DoubleValue", true, import_index);
141                    self.add_well_known_type("google.protobuf.FloatValue", true, import_index);
142                    self.add_well_known_type("google.protobuf.Int64Value", true, import_index);
143                    self.add_well_known_type("google.protobuf.UInt64Value", true, import_index);
144                    self.add_well_known_type("google.protobuf.Int32Value", true, import_index);
145                    self.add_well_known_type("google.protobuf.UInt32Value", true, import_index);
146                    self.add_well_known_type("google.protobuf.BoolValue", true, import_index);
147                    self.add_well_known_type("google.protobuf.StringValue", true, import_index);
148                    self.add_well_known_type("google.protobuf.BytesValue", true, import_index);
149                }
150                _ => import
151                    .path
152                    .to_syn_error("unsupported .proto as for importing well known types.")
153                    .to_err()?,
154            };
155            Ok(())
156        } else if let Some(file_path) = &import.file_path {
157            let contents = read_to_string(&file_path.path).map_err(|_| {
158                import.path.span().to_syn_error(
159                    "fails to read contents while scanning potential types for importing",
160                )
161            })?;
162            self.scan_with_contents(false, &contents, import_index)
163        } else {
164            // unreachable!()
165            Ok(())
166        }
167    }
168
169    fn add_well_known_type(&mut self, type_name: &str, is_message: bool, import_index: usize) {
170        if let Some(scope) = self.scopes.get_mut("") {
171            scope.insert(
172                type_name.to_owned(),
173                ExternalTypeRef {
174                    is_message,
175                    prost_type: true,
176                    import_index,
177                },
178            );
179        } else {
180            let mut scope = HashMap::new();
181            scope.insert(
182                type_name.to_owned(),
183                ExternalTypeRef {
184                    is_message,
185                    prost_type: true,
186                    import_index,
187                },
188            );
189            self.scopes.insert("".to_owned(), scope);
190        }
191    }
192
193    fn scan_with_contents(
194        &mut self,
195        current_source: bool,
196        contents: &str,
197        import_index: usize,
198    ) -> syn::Result<()> {
199        let mut source = contents;
200        let mut source_pos = 0usize;
201        while !source.is_empty() {
202            if let Some(macro_pos) = source.find("protobuf!") {
203                source = if let Some((_, next)) = source.split_at_checked(macro_pos + 10) {
204                    next
205                } else {
206                    break;
207                };
208                source_pos += 10 + macro_pos;
209                if current_source && self.current_source_range.contains(&source_pos) {
210                    source_pos = self.current_source_range.end;
211                    if let Some(checked) = contents.split_at_checked(self.current_source_range.end)
212                    {
213                        source = checked.1;
214                    } else {
215                        break;
216                    }
217                    continue;
218                } else {
219                    if let Some((before, _)) = source.split_at_checked(macro_pos) {
220                        if let Some(nl) = before.rfind("\n") {
221                            // check is commented to this macro
222                            if before.split_at(nl).1.contains("//") {
223                                // has comment before protobuf!
224                                continue;
225                            }
226                        }
227                    } else {
228                        break;
229                    }
230                }
231            } else {
232                break;
233            }
234
235            if let Ok((rest, (package, types))) = fast_pb_parser::pb_proto(source) {
236                let eaten = source.len() - rest.len();
237                source_pos += eaten;
238                source = rest;
239
240                let package = package.unwrap_or_default().to_owned();
241                if let Some(map) = self.scopes.get_mut(&package) {
242                    types
243                        .iter()
244                        .for_each(|t| t.register_type(import_index, map, ""));
245                } else {
246                    let mut map = HashMap::<String, ExternalTypeRef>::new();
247                    types
248                        .iter()
249                        .for_each(|t| t.register_type(import_index, &mut map, ""));
250                    self.scopes.insert(package.to_owned(), map);
251                }
252            } else {
253                // broken but can still parse next protobuf!
254            }
255        }
256        Ok(())
257    }
258}
259
260impl Deps {
261    pub fn resolve_path(&self, path: impl AsRef<str>, span: Span) -> syn::Result<FilePath> {
262        let p = path
263            .as_ref()
264            .trim_start_matches(&self.project_root)
265            .trim_start_matches("/");
266        if p.eq(&self.lib_path) {
267            let mut mod_path = syn::Path::new();
268            mod_path.push_ident(("crate", span).to_ident());
269            Ok(FilePath {
270                root: true,
271                bin: false,
272                example: false,
273                is_mod: false,
274                path: self.project_root_path.join(&self.lib_path),
275                mod_path,
276            })
277        } else if self.bin_paths.contains(p) {
278            let mut mod_path = syn::Path::new();
279            mod_path.push_ident(("crate", span).to_ident());
280            Ok(FilePath {
281                root: true,
282                bin: true,
283                example: false,
284                is_mod: false,
285                path: self.project_root_path.join(p),
286                mod_path,
287            })
288        } else if self.example_paths.contains(p) {
289            let mut mod_path = syn::Path::new();
290            mod_path.push_ident(("crate", span).to_ident());
291            Ok(FilePath {
292                root: true,
293                bin: false,
294                example: true,
295                is_mod: false,
296                path: self.project_root_path.join(p),
297                mod_path,
298            })
299        } else {
300            let p = p.trim_start_matches("src/");
301            let path = self.project_root_path.join("src").join(p);
302            if !path.exists() {
303                return span
304                    .to_syn_error(format!("path {:?} does not exist", &path))
305                    .to_err();
306            }
307            let mut is_mod = false;
308            let p = if p.ends_with("/mod.rs") {
309                is_mod = true;
310                p.trim_end_matches("/mod.rs")
311            } else {
312                p.trim_end_matches(".rs")
313            };
314            let mod_path =
315                syn::Path::from_idents(p.split("/").into_iter().map(|p| (p, span).to_ident()));
316            Ok(FilePath {
317                root: false,
318                bin: false,
319                is_mod,
320                example: false,
321                path,
322                mod_path,
323            })
324        }
325    }
326}
327
/// Records a parsed protobuf type, and any types nested in it, into a scope map.
trait RegisterType {
    /// Inserts `self` into `map` under `parent` followed by its own name (plus a suffix for
    /// generated rpc messages), tagged with `import_index`.
    ///
    /// `parent` is either empty or the enclosing type's name followed by `.`.
    fn register_type(
        &self,
        import_index: usize,
        map: &mut HashMap<String, ExternalTypeRef>,
        parent: &str,
    );
}
336
337mod fast_pb_parser {
338    use nom::{
339        branch::alt,
340        bytes::complete::{is_not, tag, take_until, take_while1},
341        character::{complete::digit1, is_newline},
342        combinator::{map, opt},
343        multi::{many0, separated_list1},
344        sequence::{delimited, preceded, terminated, tuple},
345        IResult,
346    };
347
348    use super::RegisterType;
349
350    #[derive(Debug, PartialEq)]
351    pub enum Element<'a> {
352        Message(Message<'a>),
353        Enum(Enum<'a>),
354    }
355
356    impl RegisterType for Element<'_> {
357        fn register_type(
358            &self,
359            import_index: usize,
360            map: &mut std::collections::HashMap<String, super::ExternalTypeRef>,
361            parent: &str,
362        ) {
363            match self {
364                Element::Message(m) => m.register_type(import_index, map, parent),
365                Element::Enum(e) => e.register_type(import_index, map, parent),
366            }
367        }
368    }
369
370    fn line_comment(input: &str) -> IResult<&str, ()> {
371        let (rest, _) = tag("//")(input)?;
372        if let Some(c) = rest.chars().nth(0) {
373            if is_newline(c as u8) {
374                Ok((rest, ()))
375            } else {
376                map(is_not("\n\r"), |_| ())(rest)
377            }
378        } else {
379            Ok((rest, ()))
380        }
381    }
382
383    fn block_commnet(input: &str) -> IResult<&str, ()> {
384        map(delimited(tag("/*"), take_until("*/"), tag("*/")), |_| ())(input)
385    }
386
387    fn ws(input: &str) -> IResult<&str, ()> {
388        map(
389            many0(alt((
390                map(take_while1(|c: char| c.is_whitespace()), |_| ()),
391                line_comment,
392                block_commnet,
393            ))),
394            |_| (),
395        )(input)
396    }
397
398    macro_rules! tag_ws_around {
399        ($tag:expr) => {
400            tuple((ws, tag($tag), ws))
401        };
402        ($tag1:expr, $tag2:expr) => {
403            tuple((ws, tag($tag1), ws, tag($tag2), ws))
404        };
405    }
406
407    fn ident(input: &str) -> IResult<&str, &str> {
408        take_while1(|c: char| c.is_alphanumeric() || c == '_')(input)
409    }
410
411    pub fn pb_proto(input: &str) -> IResult<&str, (Option<&str>, Vec<Element>)> {
412        let mut package = None;
413        let (rest, elements) = delimited(
414            tag_ws_around!("{"),
415            map(many0(pb_decl), |result| {
416                result
417                    .into_iter()
418                    .map(|(p, r)| {
419                        if p.is_some() {
420                            package = p;
421                        }
422                        r
423                    })
424                    .flatten()
425                    .flatten()
426                    .collect::<Vec<_>>()
427            }),
428            tag_ws_around!("}"),
429        )(input)?;
430        Ok((rest, (package, elements)))
431    }
432
433    fn pb_decl(input: &str) -> IResult<&str, (Option<&str>, Option<Vec<Element>>)> {
434        let (rest, _) = ws(input)?;
435        let (_, next) = alt((
436            tag("message"),
437            tag("enum"),
438            tag("option"),
439            tag("import"),
440            tag("service"),
441            tag("syntax"),
442            tag("package"),
443        ))(rest)?;
444        let mut package = None;
445        match next {
446            "message" => map(pb_message, |m| Some(vec![Element::Message(m)]))(rest),
447            "enum" => map(pb_enum, |e| Some(vec![Element::Enum(e)]))(rest),
448            "option" => map(pb_option, |_| None)(rest),
449            "import" => map(pb_import, |_| None)(rest),
450            "service" => map(pb_service, |messages| {
451                Some(
452                    messages
453                        .into_iter()
454                        .map(|m| Element::Message(m))
455                        .collect::<Vec<_>>(),
456                )
457            })(rest),
458            "syntax" => map(pb_syntax, |_| None)(rest),
459            "package" => map(pb_package, |name| {
460                package = Some(name);
461                None
462            })(rest),
463            _ => unreachable!(),
464        }
465        .map(|(rest, result)| (rest, (package, result)))
466    }
467
468    fn pb_syntax(input: &str) -> IResult<&str, &str> {
469        delimited(
470            tag_ws_around!("syntax", "="),
471            delimited(tag("\""), ident, tag("\"")),
472            tag_ws_around!(";"),
473        )(input)
474    }
475
476    fn pb_import(input: &str) -> IResult<&str, &str> {
477        delimited(
478            tag_ws_around!("import"),
479            delimited(tag("\""), take_until("\""), tag("\"")),
480            tag_ws_around!(";"),
481        )(input)
482    }
483
484    fn pb_option(input: &str) -> IResult<&str, ()> {
485        map(
486            tuple((
487                tag_ws_around!("option"),
488                take_while1(|c: char| c != ';'),
489                tag_ws_around!(";"),
490            )),
491            |_| (),
492        )(input)
493    }
494
495    fn pb_package(input: &str) -> IResult<&str, &str> {
496        delimited(tag_ws_around!("package"), ident, tag_ws_around!(";"))(input)
497    }
498
499    fn pb_service(input: &str) -> IResult<&str, Vec<Message>> {
500        let (rest, _name) =
501            delimited(tag_ws_around!("service"), ident, tag_ws_around!("{"))(input)?;
502        terminated(
503            map(many0(pb_rpc), |messages| {
504                messages
505                    .into_iter()
506                    .map(|rpc| match (rpc.gen_request, rpc.gen_response) {
507                        (None, None) => vec![],
508                        (None, Some((name, suffix, inner_types))) => vec![Message {
509                            name,
510                            suffix,
511                            inner_types,
512                        }],
513                        (Some((name, suffix, inner_types)), None) => {
514                            vec![Message {
515                                name,
516                                suffix,
517                                inner_types,
518                            }]
519                        }
520                        (Some(r), Some(p)) => vec![
521                            Message {
522                                name: r.0,
523                                suffix: r.1,
524                                inner_types: r.2,
525                            },
526                            Message {
527                                name: p.0,
528                                suffix: p.1,
529                                inner_types: p.2,
530                            },
531                        ],
532                    })
533                    .flatten()
534                    .collect()
535            }),
536            tag_ws_around!("}"),
537        )(rest)
538    }
539
540    struct Rpc<'a> {
541        gen_request: Option<(&'a str, &'static str, Vec<Element<'a>>)>,
542        gen_response: Option<(&'a str, &'static str, Vec<Element<'a>>)>,
543    }
544
545    fn pb_rpc(input: &str) -> IResult<&str, Rpc> {
546        let (rest, _) = tag_ws_around!("rpc")(input)?;
547        let (rest, rpc_name) = ident(rest)?;
548        let (rest, gen_request) = delimited(
549            tag_ws_around!("("),
550            preceded(
551                opt(tag_ws_around!("stream")),
552                alt((
553                    map(pb_message_body, |inner_types| {
554                        Some((rpc_name, "Request", inner_types))
555                    }),
556                    map(tuple((ident, pb_message_body)), |(name, types)| {
557                        Some((name, "", types))
558                    }),
559                    map(ident, |_| None),
560                )),
561            ),
562            tag_ws_around!(")"),
563        )(rest)?;
564        let (rest, _) = tag_ws_around!("returns")(rest)?;
565        let (rest, gen_response) = delimited(
566            tag_ws_around!("("),
567            preceded(
568                opt(tag_ws_around!("stream")),
569                alt((
570                    map(pb_message_body, |inner_types| {
571                        Some((rpc_name, "Response", inner_types))
572                    }),
573                    map(tuple((ident, pb_message_body)), |(name, inner_types)| {
574                        Some((name, "", inner_types))
575                    }),
576                    map(ident, |_| None),
577                )),
578            ),
579            tag_ws_around!(")"),
580        )(rest)?;
581        // rpc options
582        let (rest, _) = opt(tuple((
583            tag_ws_around!("{"),
584            many0(tuple((pb_option, opt(tag_ws_around!(";"))))),
585            tag_ws_around!("}"),
586        )))(rest)?;
587        let (rest, _) = opt(tag_ws_around!(";"))(rest)?;
588        Ok((
589            rest,
590            Rpc {
591                gen_request,
592                gen_response,
593            },
594        ))
595    }
596
597    fn pb_message_body(input: &str) -> IResult<&str, Vec<Element>> {
598        map(
599            delimited(
600                tag_ws_around!("{"),
601                many0(alt((
602                    map(pb_reserved, |_| None),
603                    map(pb_extend, |_| None),
604                    map(pb_message, |m| Some(Element::Message(m))),
605                    map(pb_enum, |e| Some(Element::Enum(e))),
606                    map(pb_field, |_| None),
607                ))),
608                tag_ws_around!("}"),
609            ),
610            |results| results.into_iter().flatten().collect::<Vec<_>>(),
611        )(input)
612    }
613
614    fn pb_path(input: &str) -> IResult<&str, ()> {
615        separated_list1(tag("."), ident)(input).map(|(rest, _)| (rest, ()))
616    }
617
618    fn pb_extend(input: &str) -> IResult<&str, ()> {
619        let (rest, _) = tuple((ws, tag("extend"), ws, pb_path))(input)?;
620        pb_message_body(rest).map(|(rest, _)| (rest, ()))
621    }
622
623    fn pb_reserved(input: &str) -> IResult<&str, ()> {
624        preceded(
625            tag_ws_around!("reserved"),
626            map(
627                many0(terminated(
628                    alt((
629                        map(digit1, |_| ()),
630                        map(tuple((digit1, ws, tag("to"), ws, digit1)), |_| ()),
631                        map(delimited(tag("\""), ident, tag("\"")), |_| ()),
632                    )),
633                    opt(tag_ws_around!(";")),
634                )),
635                |_| (),
636            ),
637        )(input)
638    }
639
640    fn pb_field(input: &str) -> IResult<&str, ()> {
641        let (rest, _) = alt((
642            map(tuple((tag_ws_around!("group"), pb_message_body)), |_| ()),
643            map(
644                tuple((
645                    tag_ws_around!("map", "<"),
646                    pb_path,
647                    tag_ws_around!(","),
648                    pb_path,
649                    tag_ws_around!(">"),
650                    ident,
651                    ws,
652                    opt(tuple((tag_ws_around!("="), digit1, ws))),
653                    opt(tuple((tag("["), is_not("]"), tag("]"), ws))),
654                )),
655                |_| (),
656            ),
657            map(
658                tuple((tag_ws_around!("oneof"), ident, ws, pb_message_body)),
659                |_| (),
660            ),
661            map(
662                tuple((
663                    ws,
664                    opt(tag("repeated")),
665                    ws,
666                    pb_path,
667                    ws,
668                    ident,
669                    ws,
670                    opt(tuple((tag_ws_around!("="), digit1, ws))),
671                    opt(tuple((tag("["), is_not("]"), tag("]"), ws))),
672                )),
673                |_| (),
674            ),
675        ))(input)?;
676        let (rest, _) = opt(tag_ws_around!(";"))(rest)?;
677        Ok((rest, ()))
678    }
679
680    #[derive(Debug, PartialEq)]
681    pub struct Message<'a> {
682        pub name: &'a str,
683        pub suffix: &'static str,
684        pub inner_types: Vec<Element<'a>>,
685    }
686
687    impl RegisterType for Message<'_> {
688        fn register_type(
689            &self,
690            import_index: usize,
691            map: &mut std::collections::HashMap<String, super::ExternalTypeRef>,
692            parent: &str,
693        ) {
694            let type_name = format!("{}{}{}", parent, self.name, self.suffix);
695
696            for el in self.inner_types.iter() {
697                el.register_type(import_index, map, &format!("{}.", &type_name));
698            }
699
700            map.insert(
701                type_name,
702                super::ExternalTypeRef {
703                    is_message: true,
704                    prost_type: false,
705                    import_index,
706                },
707            );
708        }
709    }
710
711    fn pb_message(input: &str) -> IResult<&str, Message> {
712        let (rest, message_name) = delimited(tag_ws_around!("message"), ident, ws)(input)?;
713        let (rest, inner_types) = pb_message_body(rest)?;
714        Ok((
715            rest,
716            Message {
717                name: message_name,
718                suffix: "",
719                inner_types,
720            },
721        ))
722    }
723
724    #[derive(Debug, PartialEq)]
725    pub struct Enum<'a> {
726        pub name: &'a str,
727    }
728
729    impl RegisterType for Enum<'_> {
730        fn register_type(
731            &self,
732            import_index: usize,
733            map: &mut std::collections::HashMap<String, super::ExternalTypeRef>,
734            parent: &str,
735        ) {
736            map.insert(
737                format!("{}{}", parent, self.name),
738                super::ExternalTypeRef {
739                    is_message: false,
740                    prost_type: false,
741                    import_index,
742                },
743            );
744        }
745    }
746
747    fn pb_enum(input: &str) -> IResult<&str, Enum> {
748        map(
749            delimited(
750                tag_ws_around!("enum"),
751                ident,
752                tuple((
753                    ws,
754                    tag("{"),
755                    ws,
756                    many0(tuple((
757                        alt((
758                            pb_reserved,
759                            pb_option,
760                            map(
761                                tuple((
762                                    ident,
763                                    opt(tuple((tag_ws_around!("="), digit1, ws))),
764                                    opt(tuple((tag("["), is_not("]"), tag("]"), ws))),
765                                    ws,
766                                )),
767                                |_| (),
768                            ),
769                        )),
770                        opt(tag_ws_around!(";")),
771                    ))),
772                    tag("}"),
773                    ws,
774                )),
775            ),
776            |name| Enum { name },
777        )(input)
778    }
779
780    #[cfg(test)]
781    mod test_fast_pb_parser {
782        use crate::dep::fast_pb_parser::{self, Element, Message};
783
784        fn test_parser(
785            pb_text: &str,
786            expected_package: Option<&str>,
787            expected_elements: Vec<Element>,
788        ) {
789            let (rest, (package, result)) = fast_pb_parser::pb_proto(pb_text).unwrap();
790            if let Some(pos) = pb_text.rfind("}") {
791                assert_eq!(
792                    rest,
793                    pb_text.split_at(pos + 1).1,
794                    "unexpected rest text: {rest}"
795                );
796            }
797            assert!(
798                package.eq(&expected_package),
799                "unexpected package {:#?}",
800                package
801            );
802            assert!(
803                result.eq(&expected_elements),
804                "unexpected result {:#?}",
805                result
806            );
807        }
808
809        #[test]
810        fn test_parse_protocol() {
811            test_parser(
812                r#"{
813                    message A {
814                        message AInner {
815                        }
816                    }
817                }"#,
818                None,
819                vec![Element::Message(Message {
820                    name: "A",
821                    suffix: "",
822                    inner_types: vec![Element::Message(Message {
823                        name: "AInner",
824                        suffix: "",
825                        inner_types: vec![],
826                    })],
827                })],
828            );
829
830            test_parser(
831                r#"{
832                    package abc;
833                    service SomeService {
834                        rpc hello({
835                            string name
836                            message ReqInner {
837                            }
838                        }) returns({
839                            string words
840                        })
841                    }
842                }"#,
843                Some("abc"),
844                vec![
845                    Element::Message(Message {
846                        name: "hello",
847                        suffix: "Request",
848                        inner_types: vec![Element::Message(Message {
849                            name: "ReqInner",
850                            suffix: "",
851                            inner_types: vec![],
852                        })],
853                    }),
854                    Element::Message(Message {
855                        name: "hello",
856                        suffix: "Response",
857                        inner_types: vec![],
858                    }),
859                ],
860            );
861        }
862    }
863}