pb_rs/
types.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::fs::File;
4use std::io::{BufReader, BufWriter, Write};
5use std::path::{Path, PathBuf};
6
7use log::{debug, warn};
8
9use crate::errors::{Error, Result};
10use crate::keywords::sanitize_keyword;
11use crate::parser::file_descriptor;
12
/// Number of bytes needed to encode `v` as a protobuf base-128 varint.
///
/// Every varint byte carries 7 payload bits, so the size is the count of
/// significant bits rounded up to a multiple of 7; zero still occupies one byte.
fn sizeof_varint(v: u32) -> usize {
    let significant_bits = (32 - v.leading_zeros()).max(1);
    ((significant_bits + 6) / 7) as usize
}
22
/// The `syntax` declared by a `.proto` file.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Syntax {
    /// `syntax = "proto2"`
    Proto2,
    /// `syntax = "proto3"`
    Proto3,
}

/// A file without an explicit syntax declaration is treated as proto2.
impl Default for Syntax {
    fn default() -> Self {
        Syntax::Proto2
    }
}
34
/// Field label (cardinality) as written in the `.proto` definition.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Frequency {
    /// `optional` (or an unlabeled proto3 field).
    Optional,
    /// `repeated`
    Repeated,
    /// `required`
    Required,
}
41
42#[derive(Clone, PartialEq, Eq, Hash, Default)]
43pub struct MessageIndex {
44    indexes: Vec<usize>,
45}
46
47impl fmt::Debug for MessageIndex {
48    fn fmt(&self, f: &mut fmt::Formatter) -> ::std::result::Result<(), fmt::Error> {
49        f.debug_set().entries(self.indexes.iter()).finish()
50    }
51}
52
53impl MessageIndex {
54    pub fn get_message<'a>(&self, desc: &'a FileDescriptor) -> &'a Message {
55        let first_message = self.indexes.first().and_then(|i| desc.messages.get(*i));
56        self.indexes
57            .iter()
58            .skip(1)
59            .fold(first_message, |cur, next| {
60                cur.and_then(|msg| msg.messages.get(*next))
61            })
62            .expect("Message index not found")
63    }
64
65    fn get_message_mut<'a>(&self, desc: &'a mut FileDescriptor) -> &'a mut Message {
66        let first_message = self
67            .indexes
68            .first()
69            .and_then(move |i| desc.messages.get_mut(*i));
70        self.indexes
71            .iter()
72            .skip(1)
73            .fold(first_message, |cur, next| {
74                cur.and_then(|msg| msg.messages.get_mut(*next))
75            })
76            .expect("Message index not found")
77    }
78
79    fn push(&mut self, i: usize) {
80        self.indexes.push(i);
81    }
82
83    fn pop(&mut self) {
84        self.indexes.pop();
85    }
86}
87
88#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
89pub struct EnumIndex {
90    msg_index: MessageIndex,
91    index: usize,
92}
93
94impl EnumIndex {
95    pub fn get_enum<'a>(&self, desc: &'a FileDescriptor) -> &'a Enumerator {
96        let enums = if self.msg_index.indexes.is_empty() {
97            &desc.enums
98        } else {
99            &self.msg_index.get_message(desc).enums
100        };
101        enums.get(self.index).expect("Enum index not found")
102    }
103}
104
105#[derive(Debug, Clone, PartialEq, Eq, Hash)]
106pub enum FieldType {
107    Int32,
108    Int64,
109    Uint32,
110    Uint64,
111    Sint32,
112    Sint64,
113    Bool,
114    Enum(EnumIndex),
115    Fixed64,
116    Sfixed64,
117    Double,
118    StringCow,
119    BytesCow,
120    String_,
121    Bytes_,
122    Message(MessageIndex),
123    MessageOrEnum(String),
124    Fixed32,
125    Sfixed32,
126    Float,
127    Map(Box<FieldType>, Box<FieldType>),
128}
129
130impl FieldType {
131    pub fn is_primitive(&self) -> bool {
132        match *self {
133            FieldType::Message(_)
134            | FieldType::Map(_, _)
135            | FieldType::StringCow
136            | FieldType::BytesCow
137            | FieldType::String_
138            | FieldType::Bytes_ => false,
139            _ => true,
140        }
141    }
142
143    fn has_cow(&self) -> bool {
144        match *self {
145            FieldType::BytesCow | FieldType::StringCow => true,
146            FieldType::Map(ref k, ref v) => k.has_cow() || v.has_cow(),
147            _ => false,
148        }
149    }
150
151    fn has_bytes_and_string(&self) -> bool {
152        match *self {
153            FieldType::Bytes_ | FieldType::String_ => true,
154            _ => false,
155        }
156    }
157
158    fn is_map(&self) -> bool {
159        match *self {
160            FieldType::Map(_, _) => true,
161            _ => false,
162        }
163    }
164
165    fn wire_type_num(&self, packed: bool) -> u32 {
166        if packed {
167            2
168        } else {
169            self.wire_type_num_non_packed()
170        }
171    }
172
173    fn wire_type_num_non_packed(&self) -> u32 {
174        /*
175        0	Varint	int32, int64, uint32, uint64, sint32, sint64, bool, enum
176        1	64-bit	fixed64, sfixed64, double
177        2	Length-delimited	string, bytes, embedded messages, packed repeated fields
178        3	Start group	groups (deprecated)
179        4	End group	groups (deprecated)
180        5	32-bit	fixed32, sfixed32, float
181        */
182        match *self {
183            FieldType::Int32
184            | FieldType::Sint32
185            | FieldType::Int64
186            | FieldType::Sint64
187            | FieldType::Uint32
188            | FieldType::Uint64
189            | FieldType::Bool
190            | FieldType::Enum(_) => 0,
191            FieldType::Fixed64 | FieldType::Sfixed64 | FieldType::Double => 1,
192            FieldType::StringCow
193            | FieldType::BytesCow
194            | FieldType::String_
195            | FieldType::Bytes_
196            | FieldType::Message(_)
197            | FieldType::Map(_, _) => 2,
198            FieldType::Fixed32 | FieldType::Sfixed32 | FieldType::Float => 5,
199            FieldType::MessageOrEnum(_) => unreachable!("Message / Enum not resolved"),
200        }
201    }
202
203    fn proto_type(&self) -> &str {
204        match *self {
205            FieldType::Int32 => "int32",
206            FieldType::Sint32 => "sint32",
207            FieldType::Int64 => "int64",
208            FieldType::Sint64 => "sint64",
209            FieldType::Uint32 => "uint32",
210            FieldType::Uint64 => "uint64",
211            FieldType::Bool => "bool",
212            FieldType::Enum(_) => "enum",
213            FieldType::Fixed32 => "fixed32",
214            FieldType::Sfixed32 => "sfixed32",
215            FieldType::Float => "float",
216            FieldType::Fixed64 => "fixed64",
217            FieldType::Sfixed64 => "sfixed64",
218            FieldType::Double => "double",
219            FieldType::String_ => "string",
220            FieldType::Bytes_ => "bytes",
221            FieldType::StringCow => "string",
222            FieldType::BytesCow => "bytes",
223            FieldType::Message(_) => "message",
224            FieldType::Map(_, _) => "map",
225            FieldType::MessageOrEnum(_) => unreachable!("Message / Enum not resolved"),
226        }
227    }
228
229    fn is_fixed_size(&self) -> bool {
230        match self.wire_type_num_non_packed() {
231            1 | 5 => true,
232            _ => false,
233        }
234    }
235
236    fn regular_default<'a, 'b>(&'a self, desc: &'b FileDescriptor) -> Option<&'b str> {
237        match *self {
238            FieldType::Int32 => Some("0i32"),
239            FieldType::Sint32 => Some("0i32"),
240            FieldType::Int64 => Some("0i64"),
241            FieldType::Sint64 => Some("0i64"),
242            FieldType::Uint32 => Some("0u32"),
243            FieldType::Uint64 => Some("0u64"),
244            FieldType::Bool => Some("false"),
245            FieldType::Fixed32 => Some("0u32"),
246            FieldType::Sfixed32 => Some("0i32"),
247            FieldType::Float => Some("0f32"),
248            FieldType::Fixed64 => Some("0u64"),
249            FieldType::Sfixed64 => Some("0i64"),
250            FieldType::Double => Some("0f64"),
251            FieldType::StringCow => Some("\"\""),
252            FieldType::BytesCow => Some("Cow::Borrowed(b\"\")"),
253            FieldType::String_ => Some("String::default()"),
254            FieldType::Bytes_ => Some("vec![]"),
255            FieldType::Enum(ref e) => {
256                let e = e.get_enum(desc);
257                Some(&*e.fully_qualified_fields[0].0)
258            }
259            FieldType::Message(_) => None,
260            FieldType::Map(_, _) => None,
261            FieldType::MessageOrEnum(_) => unreachable!("Message / Enum not resolved"),
262        }
263    }
264
265    pub fn message(&self) -> Option<&MessageIndex> {
266        if let FieldType::Message(ref m) = self {
267            Some(m)
268        } else {
269            None
270        }
271    }
272
273    fn has_lifetime(
274        &self,
275        desc: &FileDescriptor,
276        config: &Config,
277        packed: bool,
278        ignore: &mut Vec<MessageIndex>,
279    ) -> bool {
280        match *self {
281            FieldType::StringCow | FieldType::BytesCow => true, // Cow<[u8]>
282            FieldType::Message(ref m) => m.get_message(desc).has_lifetime(desc, config, ignore),
283            FieldType::Fixed64
284            | FieldType::Sfixed64
285            | FieldType::Double
286            | FieldType::Fixed32
287            | FieldType::Sfixed32
288            | FieldType::String_
289            | FieldType::Bytes_
290            | FieldType::Float => packed, // Cow<[M]>
291            FieldType::Map(ref key, ref value) => {
292                key.has_lifetime(desc, config, false, ignore) || value.has_lifetime(desc, config, false, ignore)
293            }
294            _ => false,
295        }
296    }
297
298    fn rust_type(&self, desc: &FileDescriptor, config: &Config) -> Result<String> {
299        Ok(match *self {
300            FieldType::Int32 | FieldType::Sint32 | FieldType::Sfixed32 => "i32".to_string(),
301            FieldType::Int64 | FieldType::Sint64 | FieldType::Sfixed64 => "i64".to_string(),
302            FieldType::Uint32 | FieldType::Fixed32 => "u32".to_string(),
303            FieldType::Uint64 | FieldType::Fixed64 => "u64".to_string(),
304            FieldType::Double => "f64".to_string(),
305            FieldType::Float => "f32".to_string(),
306            FieldType::StringCow => "Cow<'a, str>".to_string(),
307            FieldType::BytesCow => "Cow<'a, [u8]>".to_string(),
308            FieldType::String_ => "String".to_string(),
309            FieldType::Bytes_ => "Vec<u8>".to_string(),
310            FieldType::Bool => "bool".to_string(),
311            FieldType::Enum(ref e) => {
312                let e = e.get_enum(desc);
313                format!("{}{}", e.get_modules(desc), e.name)
314            }
315            FieldType::Message(ref msg) => {
316                let m = msg.get_message(desc);
317                let lifetime = if m.has_lifetime(desc, config, &mut Vec::new()) {
318                    "<'a>"
319                } else {
320                    ""
321                };
322                format!("{}{}{}", m.get_modules(desc), m.name, lifetime)
323            }
324            FieldType::Map(ref key, ref value) => format!(
325                "KVMap<{}, {}>",
326                key.rust_type(desc, config)?,
327                value.rust_type(desc, config)?
328            ),
329            FieldType::MessageOrEnum(_) => unreachable!("Message / Enum not resolved"),
330        })
331    }
332
333    /// Returns the relevant function to read the data, both for regular and Cow wrapped
334    fn read_fn(&self, desc: &FileDescriptor) -> Result<(String, String)> {
335        Ok(match *self {
336            FieldType::Message(ref msg) => {
337                let m = msg.get_message(desc);
338                let m = format!(
339                    "r.read_message::<{}{}>(bytes)?",
340                    m.get_modules(desc),
341                    m.name
342                );
343                (m.clone(), m)
344            }
345            FieldType::Map(_, _) => return Err(Error::ReadFnMap),
346            FieldType::StringCow | FieldType::BytesCow => {
347                let m = format!("r.read_{}(bytes)", self.proto_type());
348                let cow = format!("{}.map(Cow::Borrowed)?", m);
349                (m, cow)
350            }
351            FieldType::String_ => {
352                let m = format!("r.read_{}(bytes)", self.proto_type());
353                let vec = format!("{}?.to_owned()", m);
354                (m, vec)
355            }
356            FieldType::Bytes_ => {
357                let m = format!("r.read_{}(bytes)", self.proto_type());
358                let vec = format!("{}?.to_owned()", m);
359                (m, vec)
360            }
361            FieldType::MessageOrEnum(_) => unreachable!("Message / Enum not resolved"),
362            _ => {
363                let m = format!("r.read_{}(bytes)?", self.proto_type());
364                (m.clone(), m)
365            }
366        })
367    }
368
369    fn get_size(&self, s: &str) -> String {
370        match *self {
371            FieldType::Int32
372            | FieldType::Int64
373            | FieldType::Uint32
374            | FieldType::Uint64
375            | FieldType::Bool
376            | FieldType::Enum(_) => format!("sizeof_varint(*({}) as u64)", s),
377            FieldType::Sint32 => format!("sizeof_sint32(*({}))", s),
378            FieldType::Sint64 => format!("sizeof_sint64(*({}))", s),
379
380            FieldType::Fixed64 | FieldType::Sfixed64 | FieldType::Double => "8".to_string(),
381            FieldType::Fixed32 | FieldType::Sfixed32 | FieldType::Float => "4".to_string(),
382
383            FieldType::StringCow | FieldType::BytesCow => format!("sizeof_len(({}).len())", s),
384
385            FieldType::String_ | FieldType::Bytes_ => format!("sizeof_len(({}).len())", s),
386
387            FieldType::Message(_) => format!("sizeof_len(({}).get_size())", s),
388
389            FieldType::Map(ref k, ref v) => {
390                format!("2 + {} + {}", k.get_size("k"), v.get_size("v"))
391            }
392            FieldType::MessageOrEnum(_) => unreachable!("Message / Enum not resolved"),
393        }
394    }
395
396    fn get_write(&self, s: &str, boxed: bool) -> String {
397        match *self {
398            FieldType::Enum(_) => format!("write_enum(*{} as i32)", s),
399
400            FieldType::Int32
401            | FieldType::Sint32
402            | FieldType::Int64
403            | FieldType::Sint64
404            | FieldType::Uint32
405            | FieldType::Uint64
406            | FieldType::Bool
407            | FieldType::Fixed64
408            | FieldType::Sfixed64
409            | FieldType::Double
410            | FieldType::Fixed32
411            | FieldType::Sfixed32
412            | FieldType::Float => format!("write_{}(*{})", self.proto_type(), s),
413
414            FieldType::StringCow => format!("write_string(&**{})", s),
415            FieldType::BytesCow => format!("write_bytes(&**{})", s),
416            FieldType::String_ => format!("write_string(&**{})", s),
417            FieldType::Bytes_ => format!("write_bytes(&**{})", s),
418
419            FieldType::Message(_) if boxed => format!("write_message(&**{})", s),
420            FieldType::Message(_) => format!("write_message({})", s),
421
422            FieldType::Map(ref k, ref v) => format!(
423                "write_map({}, {}, |w| w.{}, {}, |w| w.{})",
424                self.get_size(""),
425                tag(1, k, false),
426                k.get_write("k", false),
427                tag(2, v, false),
428                v.get_write("v", false)
429            ),
430            FieldType::MessageOrEnum(_) => unreachable!("Message / Enum not resolved"),
431        }
432    }
433}
434
435#[derive(Debug, Clone)]
436pub struct Field {
437    pub name: String,
438    pub frequency: Frequency,
439    pub typ: FieldType,
440    pub number: i32,
441    pub default: Option<String>,
442    pub packed: Option<bool>,
443    pub boxed: bool,
444    pub deprecated: bool,
445}
446
447impl Field {
448    fn packed(&self) -> bool {
449        self.packed.unwrap_or(false)
450    }
451
452    fn sanitize_default(&mut self, desc: &FileDescriptor, config: &Config) -> Result<()> {
453        if let Some(ref mut d) = self.default {
454            *d = match &*self.typ.rust_type(desc, config)? {
455                "u32" => format!("{}u32", *d),
456                "u64" => format!("{}u64", *d),
457                "i32" => format!("{}i32", *d),
458                "i64" => format!("{}i64", *d),
459                "f32" => match &*d.to_lowercase() {
460                    "inf" => "::core::f32::INFINITY".to_string(),
461                    "-inf" => "::core::f32::NEG_INFINITY".to_string(),
462                    "nan" => "::core::f32::NAN".to_string(),
463                    _ => format!("{}f32", *d),
464                },
465                "f64" => match &*d.to_lowercase() {
466                    "inf" => "::core::f64::INFINITY".to_string(),
467                    "-inf" => "::core::f64::NEG_INFINITY".to_string(),
468                    "nan" => "::core::f64::NAN".to_string(),
469                    _ => format!("{}f64", *d),
470                },
471                "Cow<'a, str>" => format!("Cow::Borrowed({})", d),
472                "Cow<'a, [u8]>" => format!("Cow::Borrowed(b{})", d),
473                "String" => format!("String::from({})", d),
474                "Bytes" => format!(r#"b{}"#, d),
475                "Vec<u8>" => format!("b{}.to_vec()", d),
476                "bool" => format!("{}", d.parse::<bool>().unwrap()),
477                e => format!("{}::{}", e, d), // enum, as message and map do not have defaults
478            }
479        }
480        Ok(())
481    }
482
483    fn has_regular_default(&self, desc: &FileDescriptor) -> bool {
484        self.default.is_none()
485            || self.default.as_ref().map(|d| &**d) == self.typ.regular_default(desc)
486    }
487
488    fn tag(&self) -> u32 {
489        tag(self.number as u32, &self.typ, self.packed())
490    }
491
492    fn write_definition<W: Write>(
493        &self,
494        w: &mut W,
495        desc: &FileDescriptor,
496        config: &Config,
497    ) -> Result<()> {
498        if self.deprecated {
499            if config.add_deprecated_fields {
500                writeln!(w, "    #[deprecated]")?;
501            } else {
502                return Ok(());
503            }
504        }
505        write!(w, "    pub {}: ", self.name)?;
506        let rust_type = self.typ.rust_type(desc, config)?;
507        match self.frequency {
508            _ if self.boxed => writeln!(w, "Option<Box<{}>>,", rust_type)?,
509            Frequency::Optional
510                if desc.syntax == Syntax::Proto2 && self.default.is_none()
511                    || self.typ.message().is_some() =>
512            {
513                writeln!(w, "Option<{}>,", rust_type)?
514            }
515            Frequency::Repeated
516                if self.packed() && self.typ.is_fixed_size() && !config.dont_use_cow =>
517            {
518                writeln!(w, "Cow<'a, [{}]>,", rust_type)?;
519            }
520            Frequency::Repeated => writeln!(w, "Vec<{}>,", rust_type)?,
521            Frequency::Required | Frequency::Optional => writeln!(w, "{},", rust_type)?,
522        }
523        Ok(())
524    }
525
526    fn write_match_tag<W: Write>(&self, w: &mut W, desc: &FileDescriptor, config: &Config) -> Result<()> {
527        if self.deprecated && !config.add_deprecated_fields {
528            return Ok(());
529        }
530
531        // special case for FieldType::Map: destructure tuple before inserting in HashMap
532        if let FieldType::Map(ref key, ref value) = self.typ {
533            writeln!(w, "                Ok({}) => {{", self.tag())?;
534            writeln!(
535                w,
536                "                    let (key, value) = \
537                 r.read_map(bytes, |r, bytes| Ok({}), |r, bytes| Ok({}))?;",
538                key.read_fn(desc)?.1,
539                value.read_fn(desc)?.1
540            )?;
541            writeln!(
542                w,
543                "                    msg.{}.insert(key, value);",
544                self.name
545            )?;
546            writeln!(w, "                }}")?;
547            return Ok(());
548        }
549
550        let (val, val_cow) = self.typ.read_fn(desc)?;
551        let name = &self.name;
552        write!(w, "                Ok({}) => ", self.tag())?;
553        match self.frequency {
554            _ if self.boxed => writeln!(w, "msg.{} = Some(Box::new({})),", name, val)?,
555            Frequency::Optional
556                if desc.syntax == Syntax::Proto2 && self.default.is_none()
557                    || self.typ.message().is_some() =>
558            {
559                writeln!(w, "msg.{} = Some({}),", name, val_cow)?
560            }
561            Frequency::Required | Frequency::Optional => {
562                writeln!(w, "msg.{} = {},", name, val_cow)?
563            }
564            Frequency::Repeated if self.packed() && self.typ.is_fixed_size() => {
565                writeln!(w, "msg.{} = r.read_packed_fixed(bytes)?.into(),", name)?;
566            }
567            Frequency::Repeated if self.packed() => {
568                writeln!(
569                    w,
570                    "msg.{} = r.read_packed(bytes, |r, bytes| Ok({}))?,",
571                    name, val_cow
572                )?;
573            }
574            Frequency::Repeated => writeln!(w, "msg.{}.push({}),", name, val_cow)?,
575        }
576        Ok(())
577    }
578
579    fn write_get_size<W: Write>(&self, w: &mut W, desc: &FileDescriptor, config: &Config) -> Result<()> {
580        if self.deprecated && !config.add_deprecated_fields {
581            return Ok(());
582        }
583
584        write!(w, "        + ")?;
585        let tag_size = sizeof_varint(self.tag());
586        match self.frequency {
587            Frequency::Optional
588                if desc.syntax == Syntax::Proto2 || self.typ.message().is_some() =>
589            {
590                // TODO this might be incorrect behavior for proto2
591                match self.default.as_ref() {
592                    None => {
593                        write!(w, "self.{}.as_ref().map_or(0, ", self.name)?;
594                        if self.typ.is_fixed_size() {
595                            writeln!(w, "|_| {} + {})", tag_size, self.typ.get_size(""))?;
596                        } else {
597                            writeln!(w, "|m| {} + {})", tag_size, self.typ.get_size("m"))?;
598                        }
599                    }
600                    Some(d) => {
601                        writeln!(
602                            w,
603                            "if self.{} == {} {{ 0 }} else {{ {} + {} }}",
604                            self.name,
605                            d,
606                            tag_size,
607                            self.typ.get_size(&format!("&self.{}", self.name))
608                        )?;
609                    }
610                }
611            }
612            Frequency::Required if self.typ.is_map() => {
613                writeln!(
614                    w,
615                    "self.{}.iter().map(|(k, v)| {} + sizeof_len({})).sum::<usize>()",
616                    self.name,
617                    tag_size,
618                    self.typ.get_size("")
619                )?;
620            }
621            Frequency::Optional => match self.typ {
622                FieldType::Bytes_ => writeln!(
623                    w,
624                    "if self.{}.is_empty() {{ 0 }} else {{ {} + {} }}",
625                    self.name,
626                    tag_size,
627                    self.typ.get_size(&format!("&self.{}", self.name))
628                )?,
629                _ => writeln!(
630                    w,
631                    "if self.{} == {} {{ 0 }} else {{ {} + {} }}",
632                    self.name,
633                    self.default.as_ref().map_or_else(
634                        || self.typ.regular_default(desc).unwrap_or("None"),
635                        |s| s.as_str()
636                    ),
637                    tag_size,
638                    self.typ.get_size(&format!("&self.{}", self.name))
639                )?,
640            },
641            Frequency::Required => writeln!(
642                w,
643                "{} + {}",
644                tag_size,
645                self.typ.get_size(&format!("&self.{}", self.name))
646            )?,
647            Frequency::Repeated => {
648                if self.packed() {
649                    write!(
650                        w,
651                        "if self.{}.is_empty() {{ 0 }} else {{ {} + ",
652                        self.name, tag_size
653                    )?;
654                    match self.typ.wire_type_num_non_packed() {
655                        1 => writeln!(w, "sizeof_len(self.{}.len() * 8) }}", self.name)?,
656                        5 => writeln!(w, "sizeof_len(self.{}.len() * 4) }}", self.name)?,
657                        _ => writeln!(
658                            w,
659                            "sizeof_len(self.{}.iter().map(|s| {}).sum::<usize>()) }}",
660                            self.name,
661                            self.typ.get_size("s")
662                        )?,
663                    }
664                } else {
665                    match self.typ.wire_type_num_non_packed() {
666                        1 => writeln!(w, "({} + 8) * self.{}.len()", tag_size, self.name)?,
667                        5 => writeln!(w, "({} + 4) * self.{}.len()", tag_size, self.name)?,
668                        _ => writeln!(
669                            w,
670                            "self.{}.iter().map(|s| {} + {}).sum::<usize>()",
671                            self.name,
672                            tag_size,
673                            self.typ.get_size("s")
674                        )?,
675                    }
676                }
677            }
678        }
679        Ok(())
680    }
681
682    fn write_write<W: Write>(&self, w: &mut W, desc: &FileDescriptor, config: &Config) -> Result<()> {
683        if self.deprecated && !config.add_deprecated_fields {
684            return Ok(());
685        }
686
687        match self.frequency {
688            Frequency::Optional
689                if desc.syntax == Syntax::Proto2 || self.typ.message().is_some() =>
690            {
691                match self.default.as_ref() {
692                    None => {
693                        writeln!(
694                            w,
695                            "        if let Some(ref s) = \
696                             self.{} {{ w.write_with_tag({}, |w| w.{})?; }}",
697                            self.name,
698                            self.tag(),
699                            self.typ.get_write("s", self.boxed)
700                        )?;
701                    }
702                    Some(d) => {
703                        writeln!(
704                            w,
705                            "        if self.{} != {} {{ w.write_with_tag({}, |w| w.{})?; }}",
706                            self.name,
707                            d,
708                            self.tag(),
709                            self.typ
710                                .get_write(&format!("&self.{}", self.name), self.boxed)
711                        )?;
712                    }
713                }
714            }
715            Frequency::Optional => match self.typ {
716                FieldType::Bytes_ => {
717                    writeln!(
718                        w,
719                        "        if !self.{}.is_empty() {{ w.write_with_tag({}, |w| w.{})?; }}",
720                        self.name,
721                        self.tag(),
722                        self.typ
723                            .get_write(&format!("&self.{}", self.name), self.boxed)
724                    )?;
725                }
726                _ => {
727                    writeln!(
728                        w,
729                        "        if self.{} != {} {{ w.write_with_tag({}, |w| w.{})?; }}",
730                        self.name,
731                        self.default.as_ref().map_or_else(
732                            || self.typ.regular_default(desc).unwrap_or("None"),
733                            |s| s.as_str()
734                        ),
735                        self.tag(),
736                        self.typ
737                            .get_write(&format!("&self.{}", self.name), self.boxed)
738                    )?;
739                }
740            },
741            Frequency::Required if self.typ.is_map() => {
742                writeln!(
743                    w,
744                    "        for (k, v) in self.{}.iter() {{ w.write_with_tag({}, |w| w.{})?; }}",
745                    self.name,
746                    self.tag(),
747                    self.typ.get_write("", false)
748                )?;
749            }
750            Frequency::Required => {
751                writeln!(
752                    w,
753                    "        w.write_with_tag({}, |w| w.{})?;",
754                    self.tag(),
755                    self.typ
756                        .get_write(&format!("&self.{}", self.name), self.boxed)
757                )?;
758            }
759            Frequency::Repeated if self.packed() && self.typ.is_fixed_size() => writeln!(
760                w,
761                "        w.write_packed_fixed_with_tag({}, &self.{})?;",
762                self.tag(),
763                self.name
764            )?,
765            Frequency::Repeated if self.packed() => writeln!(
766                w,
767                "        w.write_packed_with_tag({}, &self.{}, |w, m| w.{}, &|m| {})?;",
768                self.tag(),
769                self.name,
770                self.typ.get_write("m", self.boxed),
771                self.typ.get_size("m")
772            )?,
773            Frequency::Repeated => {
774                writeln!(
775                    w,
776                    "        for s in &self.{} {{ w.write_with_tag({}, |w| w.{})?; }}",
777                    self.name,
778                    self.tag(),
779                    self.typ.get_write("s", self.boxed)
780                )?;
781            }
782        }
783        Ok(())
784    }
785}
786
787fn get_modules(module: &str, imported: bool, desc: &FileDescriptor) -> String {
788    let skip = if desc.package.is_empty() && !imported {
789        1
790    } else {
791        0
792    };
793    module
794        .split('.')
795        .filter(|p| !p.is_empty())
796        .skip(skip)
797        .map(|p| format!("{}::", p))
798        .collect()
799}
800
801#[derive(Debug, Clone, Default)]
802pub struct Message {
803    pub name: String,
804    pub fields: Vec<Field>,
805    pub oneofs: Vec<OneOf>,
806    pub reserved_nums: Option<Vec<i32>>,
807    pub reserved_names: Option<Vec<String>>,
808    pub imported: bool,
809    pub package: String,        // package from imports + nested items
810    pub messages: Vec<Message>, // nested messages
811    pub enums: Vec<Enumerator>, // nested enums
812    pub module: String,         // 'package' corresponding to actual generated Rust module
813    pub path: PathBuf,
814    pub import: PathBuf,
815    pub index: MessageIndex,
816}
817
818impl Message {
819    fn convert_field_types(&mut self, from: &FieldType, to: &FieldType) {
820        for f in self.all_fields_mut().filter(|f| f.typ == *from) {
821            f.typ = to.clone();
822        }
823
824        // If that type is a map with the fieldtype, it must also be converted.
825        for f in self.all_fields_mut() {
826            let new_type: FieldType = match f.typ {
827                FieldType::Map(ref mut key, ref mut value)
828                    if **key == *from && **value == *from =>
829                {
830                    FieldType::Map(Box::new(to.clone()), Box::new(to.clone()))
831                }
832                FieldType::Map(ref mut key, ref mut value) if **key == *from => {
833                    FieldType::Map(Box::new(to.clone()), value.clone())
834                }
835                FieldType::Map(ref mut key, ref mut value) if **value == *from => {
836                    FieldType::Map(key.clone(), Box::new(to.clone()))
837                }
838                ref other => other.clone(),
839            };
840            f.typ = new_type;
841        }
842
843        for message in &mut self.messages {
844            message.convert_field_types(from, to);
845        }
846    }
847
848    fn has_lifetime(&self, desc: &FileDescriptor, config: &Config, ignore: &mut Vec<MessageIndex>) -> bool {
849        if ignore.contains(&&self.index) {
850            return false;
851        }
852        ignore.push(self.index.clone());
853        let res = self.all_fields().any(|f| {
854            f.typ.has_lifetime(desc, config, f.packed(), ignore)
855                && (!f.deprecated || config.add_deprecated_fields)
856        });
857        ignore.pop();
858        res
859    }
860
861    fn set_imported(&mut self) {
862        self.imported = true;
863        for o in self.oneofs.iter_mut() {
864            o.imported = true;
865        }
866        for m in self.messages.iter_mut() {
867            m.set_imported();
868        }
869        for e in self.enums.iter_mut() {
870            e.imported = true;
871        }
872    }
873
874    fn get_modules(&self, desc: &FileDescriptor) -> String {
875        get_modules(&self.module, self.imported, desc)
876    }
877
878    fn is_unit(&self) -> bool {
879        self.fields.is_empty()
880            && self.oneofs.is_empty()
881            && self.messages.iter().all(|m| m.is_unit())
882    }
883
884    fn write_common_uses<W: Write>(
885        w: &mut W,
886        messages: &Vec<Message>,
887        desc: &FileDescriptor,
888        config: &Config,
889    ) -> Result<()> {
890        if config.nostd {
891            writeln!(w, "use alloc::vec::Vec;")?;
892        }
893
894        if !config.dont_use_cow {
895            if messages.iter().any(|m| {
896                m.all_fields()
897                    .any(|f| (f.typ.has_cow() || (f.packed() && f.typ.is_fixed_size())))
898            }) {
899                if config.nostd {
900                    writeln!(w, "use alloc::borrow::Cow;")?;
901                } else {
902                    writeln!(w, "use std::borrow::Cow;")?;
903                }
904            }
905        } else if config.nostd {
906            if messages
907                .iter()
908                .any(|m| m.all_fields().any(|f| (f.typ.has_bytes_and_string())))
909            {
910                writeln!(w, "use alloc::borrow::ToOwned;")?;
911            }
912        }
913
914        if config.nostd
915            && messages.iter().any(|m| {
916                desc.owned && m.has_lifetime(desc, config, &mut Vec::new())
917                    || m.all_fields().any(|f| f.boxed)
918            })
919        {
920            writeln!(w)?;
921            writeln!(w, "use alloc::boxed::Box;")?;
922        }
923
924        if messages
925            .iter()
926            .filter(|m| !m.imported)
927            .any(|m| m.all_fields().any(|f| f.typ.is_map()))
928        {
929            if config.hashbrown {
930                writeln!(w, "use hashbrown::HashMap;")?;
931                writeln!(w, "type KVMap<K, V> = HashMap<K, V>;")?;
932            } else if config.nostd {
933                writeln!(w, "use alloc::collections::BTreeMap;")?;
934                writeln!(w, "type KVMap<K, V> = BTreeMap<K, V>;")?;
935            } else {
936                writeln!(w, "use std::collections::HashMap;")?;
937                writeln!(w, "type KVMap<K, V> = HashMap<K, V>;")?;
938            }
939        }
940
941        Ok(())
942    }
943
944    fn write<W: Write>(&self, w: &mut W, desc: &FileDescriptor, config: &Config) -> Result<()> {
945        println!("Writing message {}{}", self.get_modules(desc), self.name);
946        writeln!(w)?;
947
948        self.write_definition(w, desc, config)?;
949        writeln!(w)?;
950        self.write_impl_message_read(w, desc, config)?;
951        writeln!(w)?;
952        self.write_impl_message_write(w, desc, config)?;
953
954        if config.gen_info {
955            self.write_impl_message_info(w, desc, config)?;
956            writeln!(w)?;
957        }
958
959        if desc.owned {
960            writeln!(w)?;
961
962            if self.has_lifetime(desc, config, &mut Vec::new()) {
963                self.write_impl_owned(w, config)?;
964            } else {
965                self.write_impl_try_from(w)?;
966            }
967        }
968
969        if !(self.messages.is_empty() && self.enums.is_empty() && self.oneofs.is_empty()) {
970            writeln!(w)?;
971            writeln!(w, "pub mod mod_{} {{", self.name)?;
972            writeln!(w)?;
973
974            Self::write_common_uses(w, &self.messages, desc, config)?;
975
976            if !self.messages.is_empty() || !self.oneofs.is_empty() {
977                writeln!(w, "use super::*;")?;
978            }
979            for m in &self.messages {
980                m.write(w, desc, config)?;
981            }
982            for e in &self.enums {
983                e.write(w)?;
984            }
985            for o in &self.oneofs {
986                o.write(w, desc, config)?;
987            }
988
989            writeln!(w)?;
990            writeln!(w, "}}")?;
991        }
992
993        Ok(())
994    }
995
996    fn write_definition<W: Write>(
997        &self,
998        w: &mut W,
999        desc: &FileDescriptor,
1000        config: &Config,
1001    ) -> Result<()> {
1002        let mut custom_struct_derive = config.custom_struct_derive.join(", ");
1003        if !custom_struct_derive.is_empty() {
1004            custom_struct_derive += ", ";
1005        }
1006
1007        writeln!(w, "#[allow(clippy::derive_partial_eq_without_eq)]")?;
1008
1009        writeln!(
1010            w,
1011            "#[derive({}Debug, Default, PartialEq, Clone)]",
1012            custom_struct_derive
1013        )?;
1014
1015        if let Some(repr) = &config.custom_repr {
1016            writeln!(w, "#[repr({})]", repr)?;
1017        }
1018
1019        if self.is_unit() {
1020            writeln!(w, "pub struct {} {{ }}", self.name)?;
1021            return Ok(());
1022        }
1023
1024        let mut ignore = Vec::new();
1025        if config.dont_use_cow {
1026            ignore.push(self.index.clone());
1027        }
1028        if self.has_lifetime(desc, config, &mut ignore) {
1029            writeln!(w, "pub struct {}<'a> {{", self.name)?;
1030        } else {
1031            writeln!(w, "pub struct {} {{", self.name)?;
1032        }
1033        for f in &self.fields {
1034            f.write_definition(w, desc, config)?;
1035        }
1036        for o in &self.oneofs {
1037            o.write_message_definition(w, desc, config)?;
1038        }
1039        writeln!(w, "}}")?;
1040        Ok(())
1041    }
1042
1043    fn write_impl_message_info<W: Write>(
1044        &self,
1045        w: &mut W,
1046        desc: &FileDescriptor,
1047        config: &Config,
1048    ) -> Result<()> {
1049        let mut ignore = Vec::new();
1050        if config.dont_use_cow {
1051            ignore.push(self.index.clone());
1052        }
1053        if self.has_lifetime(desc, config, &mut ignore) {
1054            writeln!(w, "impl<'a> MessageInfo for {}<'a> {{", self.name)?;
1055        } else {
1056            writeln!(w, "impl MessageInfo for {} {{", self.name)?;
1057        }
1058        writeln!(
1059            w,
1060            "    const PATH : &'static str = \"{}.{}\";",
1061            self.module, self.name
1062        )?;
1063        writeln!(w, "}}")?;
1064        Ok(())
1065    }
1066
1067    fn write_impl_message_read<W: Write>(
1068        &self,
1069        w: &mut W,
1070        desc: &FileDescriptor,
1071        config: &Config,
1072    ) -> Result<()> {
1073        if self.is_unit() {
1074            writeln!(w, "impl<'a> MessageRead<'a> for {} {{", self.name)?;
1075            writeln!(
1076                w,
1077                "    fn from_reader(r: &mut BytesReader, _: &[u8]) -> Result<Self> {{"
1078            )?;
1079            writeln!(w, "        r.read_to_end();")?;
1080            writeln!(w, "        Ok(Self::default())")?;
1081            writeln!(w, "    }}")?;
1082            writeln!(w, "}}")?;
1083            return Ok(());
1084        }
1085
1086        let mut ignore = Vec::new();
1087        if config.dont_use_cow {
1088            ignore.push(self.index.clone());
1089        }
1090        if self.has_lifetime(desc, config, &mut ignore) {
1091            writeln!(w, "impl<'a> MessageRead<'a> for {}<'a> {{", self.name)?;
1092            writeln!(
1093                w,
1094                "    fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result<Self> {{"
1095            )?;
1096        } else {
1097            writeln!(w, "impl<'a> MessageRead<'a> for {} {{", self.name)?;
1098            writeln!(
1099                w,
1100                "    fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result<Self> {{"
1101            )?;
1102        }
1103
1104        let unregular_defaults = self
1105            .fields
1106            .iter()
1107            .filter(|f| !f.has_regular_default(desc))
1108            .collect::<Vec<_>>();
1109        if unregular_defaults.is_empty() {
1110            writeln!(w, "        let mut msg = Self::default();")?;
1111        } else {
1112            writeln!(w, "        let mut msg = {} {{", self.name)?;
1113            for f in unregular_defaults {
1114                writeln!(
1115                    w,
1116                    "            {}: {},",
1117                    f.name,
1118                    f.default.as_ref().unwrap()
1119                )?;
1120            }
1121            writeln!(w, "            ..Self::default()")?;
1122            writeln!(w, "        }};")?;
1123        }
1124        writeln!(w, "        while !r.is_eof() {{")?;
1125        writeln!(w, "            match r.next_tag(bytes) {{")?;
1126        for f in &self.fields {
1127            f.write_match_tag(w, desc, config)?;
1128        }
1129        for o in &self.oneofs {
1130            o.write_match_tag(w, desc, config)?;
1131        }
1132        writeln!(
1133            w,
1134            "                Ok(t) => {{ r.read_unknown(bytes, t)?; }}"
1135        )?;
1136        writeln!(w, "                Err(e) => return Err(e),")?;
1137        writeln!(w, "            }}")?;
1138        writeln!(w, "        }}")?;
1139        writeln!(w, "        Ok(msg)")?;
1140        writeln!(w, "    }}")?;
1141        writeln!(w, "}}")?;
1142
1143        // TODO: write impl default when special default?
1144        // alternatively set the default value directly when reading
1145
1146        Ok(())
1147    }
1148
1149    fn write_impl_message_write<W: Write>(
1150        &self,
1151        w: &mut W,
1152        desc: &FileDescriptor,
1153        config: &Config,
1154    ) -> Result<()> {
1155        if self.is_unit() {
1156            writeln!(w, "impl MessageWrite for {} {{ }}", self.name)?;
1157            return Ok(());
1158        }
1159
1160        let mut ignore = Vec::new();
1161        if config.dont_use_cow {
1162            ignore.push(self.index.clone());
1163        }
1164        if self.has_lifetime(desc, config, &mut ignore) {
1165            writeln!(w, "impl<'a> MessageWrite for {}<'a> {{", self.name)?;
1166        } else {
1167            writeln!(w, "impl MessageWrite for {} {{", self.name)?;
1168        }
1169        self.write_get_size(w, desc, config)?;
1170        writeln!(w)?;
1171        self.write_write_message(w, desc, config)?;
1172        writeln!(w, "}}")?;
1173        Ok(())
1174    }
1175
    /// Writes a `<Name>Owned` wrapper for a message that borrows from its input.
    ///
    /// The generated wrapper keeps the byte buffer and the parsed `<Name><'static>`
    /// together in a `Pin<Box<..>>` (a self-referential struct). It gets
    /// `buf()` / `proto()` / `proto_mut()` accessors plus `Debug`,
    /// `TryFrom<Vec<u8>>`, `TryInto<Vec<u8>>` and `From<<Name><'static>>` impls, and
    /// a `MessageInfo` impl when `config.gen_info` is set.
    ///
    /// NOTE(review): the generated `proto_mut<'a>(&'a mut self) -> &'a mut Name<'a>`
    /// looks unsound. `'a` is only the length of the borrow of `self`, so safe
    /// caller code can store a `Cow::Borrowed` pointing at a local that outlives that
    /// borrow but not the wrapper, leaving a dangling reference inside `proto`. That
    /// is exactly what the IMPORTANT comment in the template forbids. A closure-based
    /// accessor (e.g. `fn with_proto_mut<R>(&mut self, f: impl for<'b> FnOnce(&mut Name<'b>) -> R) -> R`)
    /// would avoid this — TODO confirm and decide whether the API change is acceptable.
    fn write_impl_owned<W: Write>(&self, w: &mut W, config: &Config) -> Result<()> {
        write!(
            w,
            r#"
            // IMPORTANT: For any future changes, note that the lifetime parameter
            // of the `proto` field is set to 'static!!!
            //
            // This means that the internals of `proto` should at no point create a
            // mutable reference to something using that lifetime parameter, on pain
            // of UB. This applies even though it may be transmuted to a smaller
            // lifetime later (through `proto()` or `proto_mut()`).
            //
            // At the time of writing, the only possible thing that uses the
            // lifetime parameter is `Cow<'a, T>`, which never does this, so it's
            // not UB.
            //
            #[derive(Debug)]
            struct {name}OwnedInner {{
                buf: Vec<u8>,
                proto: Option<{name}<'static>>,
                _pin: core::marker::PhantomPinned,
            }}

            impl {name}OwnedInner {{
                fn new(buf: Vec<u8>) -> Result<core::pin::Pin<Box<Self>>> {{
                    let inner = Self {{
                        buf,
                        proto: None,
                        _pin: core::marker::PhantomPinned,
                    }};
                    let mut pinned = Box::pin(inner);

                    let mut reader = BytesReader::from_bytes(&pinned.buf);
                    let proto = {name}::from_reader(&mut reader, &pinned.buf)?;

                    unsafe {{
                        let proto = core::mem::transmute::<_, {name}<'_>>(proto);
                        pinned.as_mut().get_unchecked_mut().proto = Some(proto);
                    }}
                    Ok(pinned)
                }}
            }}

            pub struct {name}Owned {{
                inner: core::pin::Pin<Box<{name}OwnedInner>>,
            }}

            #[allow(dead_code)]
            impl {name}Owned {{
                pub fn buf(&self) -> &[u8] {{
                    &self.inner.buf
                }}

                pub fn proto<'a>(&'a self) -> &'a {name}<'a> {{
                    let proto = self.inner.proto.as_ref().unwrap();
                    unsafe {{ core::mem::transmute::<&{name}<'static>, &{name}<'a>>(proto) }}
                }}

                pub fn proto_mut<'a>(&'a mut self) -> &'a mut {name}<'a> {{
                    let inner = self.inner.as_mut();
                    let inner = unsafe {{ inner.get_unchecked_mut() }};
                    let proto = inner.proto.as_mut().unwrap();
                    unsafe {{ core::mem::transmute::<_, &mut {name}<'a>>(proto) }}
                }}
            }}

            impl core::fmt::Debug for {name}Owned {{
                fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {{
                    self.inner.proto.as_ref().unwrap().fmt(f)
                }}
            }}

            impl TryFrom<Vec<u8>> for {name}Owned {{
                type Error=quick_protobuf::Error;

                fn try_from(buf: Vec<u8>) -> Result<Self> {{
                    Ok(Self {{ inner: {name}OwnedInner::new(buf)? }})
                }}
            }}

            impl TryInto<Vec<u8>> for {name}Owned {{
                type Error=quick_protobuf::Error;

                fn try_into(self) -> Result<Vec<u8>> {{
                    let mut buf = Vec::new();
                    let mut writer = Writer::new(&mut buf);
                    self.inner.proto.as_ref().unwrap().write_message(&mut writer)?;
                    Ok(buf)
                }}
            }}

            impl From<{name}<'static>> for {name}Owned {{
                fn from(proto: {name}<'static>) -> Self {{
                    Self {{
                        inner: Box::pin({name}OwnedInner {{
                            buf: Vec::new(),
                            proto: Some(proto),
                            _pin: core::marker::PhantomPinned,
                        }})
                    }}
                }}
            }}
            "#,
            name = self.name
        )?;

        // Optional `MessageInfo` for the wrapper, reusing the message's dotted path.
        if config.gen_info {
            write!(w, r#"
            impl MessageInfo for {name}Owned {{
                const PATH: &'static str = "{module}.{name}";
            }}
            "#, name = self.name, module = self.module)?;
        }
        Ok(())
    }
1291
1292    fn write_get_size<W: Write>(&self, w: &mut W, desc: &FileDescriptor, config: &Config) -> Result<()> {
1293        writeln!(w, "    fn get_size(&self) -> usize {{")?;
1294        writeln!(w, "        0")?;
1295        for f in &self.fields {
1296            f.write_get_size(w, desc, config)?;
1297        }
1298        for o in self.oneofs.iter() {
1299            o.write_get_size(w, desc, config)?;
1300        }
1301        writeln!(w, "    }}")?;
1302        Ok(())
1303    }
1304
1305    fn write_impl_try_from<W: Write>(&self, w: &mut W) -> Result<()> {
1306        write!(
1307            w,
1308            r#"
1309            impl TryFrom<&[u8]> for {name} {{
1310                type Error=quick_protobuf::Error;
1311
1312                fn try_from(buf: &[u8]) -> Result<Self> {{
1313                    let mut reader = BytesReader::from_bytes(&buf);
1314                    Ok({name}::from_reader(&mut reader, &buf)?)
1315                }}
1316            }}
1317            "#,
1318            name = self.name
1319        )?;
1320        Ok(())
1321    }
1322
1323    fn write_write_message<W: Write>(&self, w: &mut W, desc: &FileDescriptor, config: &Config) -> Result<()> {
1324        writeln!(
1325            w,
1326            "    fn write_message<W: WriterBackend>(&self, w: &mut Writer<W>) -> Result<()> {{"
1327        )?;
1328        for f in &self.fields {
1329            f.write_write(w, desc, config)?;
1330        }
1331        for o in &self.oneofs {
1332            o.write_write(w, desc, config)?;
1333        }
1334        writeln!(w, "        Ok(())")?;
1335        writeln!(w, "    }}")?;
1336        Ok(())
1337    }
1338
1339    fn sanity_checks(&self, desc: &FileDescriptor) -> Result<()> {
1340        for f in self.all_fields() {
1341            // check reserved
1342            if self
1343                .reserved_names
1344                .as_ref()
1345                .map_or(false, |names| names.contains(&f.name))
1346                || self
1347                    .reserved_nums
1348                    .as_ref()
1349                    .map_or(false, |nums| nums.contains(&f.number))
1350            {
1351                return Err(Error::InvalidMessage(format!(
1352                    "Error in message {}\n\
1353                     Field {:?} conflict with reserved fields",
1354                    self.name, f
1355                )));
1356            }
1357
1358            // check default enums
1359            if let Some(var) = f.default.as_ref() {
1360                if let FieldType::Enum(ref e) = f.typ {
1361                    let e = e.get_enum(desc);
1362                    e.fields.iter().find(|&(ref name, _)| name == var)
1363                    .ok_or_else(|| Error::InvalidDefaultEnum(format!(
1364                                "Error in message {}\n\
1365                                Enum field {:?} has a default value '{}' which is not valid for enum index {:?}",
1366                                self.name, f, var, e)))?;
1367                }
1368            }
1369        }
1370        Ok(())
1371    }
1372
1373    fn set_package(&mut self, package: &str, module: &str) {
1374        // The complication here is that the _package_ (as declared in the proto file) does
1375        // not directly map to the _module_. For example, the package 'a.A' where A is a
1376        // message will be the module 'a.mod_A', since we can't reuse the message name A as
1377        // the submodule containing nested items. Also, protos with empty packages always
1378        // have a module corresponding to the file name.
1379        let (child_package, child_module) = if package.is_empty() {
1380            self.module = module.to_string();
1381            (self.name.clone(), format!("{}.mod_{}", module, self.name))
1382        } else {
1383            self.package = package.to_string();
1384            self.module = module.to_string();
1385            (
1386                format!("{}.{}", package, self.name),
1387                format!("{}.mod_{}", module, self.name),
1388            )
1389        };
1390
1391        for m in &mut self.messages {
1392            m.set_package(&child_package, &child_module);
1393        }
1394        for m in &mut self.enums {
1395            m.set_package(&child_package, &child_module);
1396        }
1397        for m in &mut self.oneofs {
1398            m.set_package(&child_package, &child_module);
1399        }
1400    }
1401
1402    fn set_map_required(&mut self) {
1403        for f in self.all_fields_mut() {
1404            if let FieldType::Map(_, _) = f.typ {
1405                f.frequency = Frequency::Required;
1406            }
1407        }
1408        for m in &mut self.messages {
1409            m.set_map_required();
1410        }
1411    }
1412
1413    fn set_repeated_as_packed(&mut self) {
1414        for f in self.all_fields_mut() {
1415            if f.packed.is_none() {
1416                if let Frequency::Repeated = f.frequency {
1417                    f.packed = Some(true);
1418                }
1419            }
1420        }
1421    }
1422
1423    fn unset_packed_non_primitives(&mut self) {
1424        for f in self.all_fields_mut() {
1425            if !f.typ.is_primitive() && f.packed.is_some() {
1426                f.packed = None;
1427            }
1428        }
1429    }
1430
1431    fn sanitize_defaults(&mut self, desc: &FileDescriptor, config: &Config) -> Result<()> {
1432        for f in self.all_fields_mut() {
1433            f.sanitize_default(desc, config)?;
1434        }
1435        for m in &mut self.messages {
1436            m.sanitize_defaults(desc, config)?;
1437        }
1438        Ok(())
1439    }
1440
1441    fn sanitize_names(&mut self) {
1442        sanitize_keyword(&mut self.name);
1443        sanitize_keyword(&mut self.package);
1444        for f in self.fields.iter_mut() {
1445            sanitize_keyword(&mut f.name);
1446        }
1447        for m in &mut self.messages {
1448            m.sanitize_names();
1449        }
1450        for e in &mut self.enums {
1451            e.sanitize_names();
1452        }
1453        for o in &mut self.oneofs {
1454            o.sanitize_names();
1455        }
1456    }
1457
1458    /// Return an iterator producing references to all the `Field`s of `self`,
1459    /// including both direct and `oneof` fields.
1460    pub fn all_fields(&self) -> impl Iterator<Item = &Field> {
1461        self.fields
1462            .iter()
1463            .chain(self.oneofs.iter().flat_map(|o| o.fields.iter()))
1464    }
1465
1466    /// Return an iterator producing mutable references to all the `Field`s of
1467    /// `self`, including both direct and `oneof` fields.
1468    fn all_fields_mut(&mut self) -> impl Iterator<Item = &mut Field> {
1469        self.fields
1470            .iter_mut()
1471            .chain(self.oneofs.iter_mut().flat_map(|o| o.fields.iter_mut()))
1472    }
1473}
1474
/// One `rpc` method declared inside a `service` (see `RpcService::functions`).
#[derive(Debug, Clone, Default)]
pub struct RpcFunctionDeclaration {
    /// Method name.
    pub name: String,
    /// NOTE(review): presumably the request message type name — confirm.
    pub arg: String,
    /// NOTE(review): presumably the response message type name — confirm.
    pub ret: String,
}
1481
/// A protobuf `service` definition. Code for it is produced by the user-supplied
/// `config.custom_rpc_generator` (see `write_definition`).
#[derive(Debug, Clone, Default)]
pub struct RpcService {
    /// Name of the service.
    pub service_name: String,
    /// The service's `rpc` methods.
    pub functions: Vec<RpcFunctionDeclaration>,
}
1487
1488impl RpcService {
1489    fn write_definition<W: Write>(&self, w: &mut W, config: &Config) -> Result<()> {
1490        (config.custom_rpc_generator)(self, w)
1491    }
1492}
1493
/// Signature of a user-supplied code generator for `service` definitions: it is
/// called with the parsed service and the output writer.
pub type RpcGeneratorFunction = Box<dyn Fn(&RpcService, &mut dyn Write) -> Result<()>>;
1495
/// A protobuf `enum` definition (top-level or nested in a message).
#[derive(Debug, Clone, Default)]
pub struct Enumerator {
    /// Enum name; also the name of the generated Rust enum.
    pub name: String,
    /// `(variant name, value)` pairs; each becomes `Variant = value` in the output.
    pub fields: Vec<(String, i32)>,
    /// Variants qualified with the Rust module path, e.g. `a::b::Name::Variant`
    /// (just `Name::Variant` when the module is empty). Filled by `set_package`.
    pub fully_qualified_fields: Vec<(String, i32)>,
    /// Variants qualified with the enum name only, e.g. `Name::Variant`.
    /// Filled by `set_package`.
    pub partially_qualified_fields: Vec<(String, i32)>,
    /// `true` when the enum comes from an imported file.
    pub imported: bool,
    /// Proto package of the enum (set by `set_package`).
    pub package: String,
    /// Dot-separated Rust module path of the enum (set by `set_package`).
    pub module: String,
    /// NOTE(review): presumably the path of the declaring `.proto` file — confirm.
    pub path: PathBuf,
    /// NOTE(review): presumably the import path through which the enum was reached — confirm.
    pub import: PathBuf,
    /// NOTE(review): presumably this enum's position in the descriptor tree — confirm.
    pub index: EnumIndex,
}
1509
1510impl Enumerator {
1511    fn set_package(&mut self, package: &str, module: &str) {
1512        self.package = package.to_string();
1513        self.module = module.to_string();
1514        self.partially_qualified_fields = self
1515            .fields
1516            .iter()
1517            .map(|f| (format!("{}::{}", &self.name, f.0), f.1))
1518            .collect();
1519        self.fully_qualified_fields = self
1520            .partially_qualified_fields
1521            .iter()
1522            .map(|pqf| {
1523                let fqf = if self.module.is_empty() {
1524                    pqf.0.clone()
1525                } else {
1526                    format!("{}::{}", self.module.replace(".", "::"), pqf.0)
1527                };
1528                (fqf, pqf.1)
1529            })
1530            .collect();
1531    }
1532
1533    fn sanitize_names(&mut self) {
1534        sanitize_keyword(&mut self.name);
1535        sanitize_keyword(&mut self.package);
1536        for f in self.fields.iter_mut() {
1537            sanitize_keyword(&mut f.0);
1538        }
1539    }
1540
1541    fn get_modules(&self, desc: &FileDescriptor) -> String {
1542        get_modules(&self.module, self.imported, desc)
1543    }
1544
1545    fn write<W: Write>(&self, w: &mut W) -> Result<()> {
1546        println!("Writing enum {}", self.name);
1547        writeln!(w)?;
1548        self.write_definition(w)?;
1549        writeln!(w)?;
1550        if self.fields.is_empty() {
1551            Ok(())
1552        } else {
1553            self.write_impl_default(w)?;
1554            writeln!(w)?;
1555            self.write_from_i32(w)?;
1556            writeln!(w)?;
1557            self.write_from_str(w)
1558        }
1559    }
1560
1561    fn write_definition<W: Write>(&self, w: &mut W) -> Result<()> {
1562        writeln!(w, "#[derive(Debug, PartialEq, Eq, Clone, Copy)]")?;
1563        writeln!(w, "pub enum {} {{", self.name)?;
1564        for &(ref f, ref number) in &self.fields {
1565            writeln!(w, "    {} = {},", f, number)?;
1566        }
1567        writeln!(w, "}}")?;
1568        Ok(())
1569    }
1570
1571    fn write_impl_default<W: Write>(&self, w: &mut W) -> Result<()> {
1572        writeln!(w, "impl Default for {} {{", self.name)?;
1573        writeln!(w, "    fn default() -> Self {{")?;
1574        // TODO: check with default field and return error if there is no field
1575        writeln!(w, "        {}", self.partially_qualified_fields[0].0)?;
1576        writeln!(w, "    }}")?;
1577        writeln!(w, "}}")?;
1578        Ok(())
1579    }
1580
1581    fn write_from_i32<W: Write>(&self, w: &mut W) -> Result<()> {
1582        writeln!(w, "impl From<i32> for {} {{", self.name)?;
1583        writeln!(w, "    fn from(i: i32) -> Self {{")?;
1584        writeln!(w, "        match i {{")?;
1585        for &(ref f, ref number) in &self.fields {
1586            writeln!(w, "            {} => {}::{},", number, self.name, f)?;
1587        }
1588        writeln!(w, "            _ => Self::default(),")?;
1589        writeln!(w, "        }}")?;
1590        writeln!(w, "    }}")?;
1591        writeln!(w, "}}")?;
1592        Ok(())
1593    }
1594
1595    fn write_from_str<W: Write>(&self, w: &mut W) -> Result<()> {
1596        writeln!(w, "impl<'a> From<&'a str> for {} {{", self.name)?;
1597        writeln!(w, "    fn from(s: &'a str) -> Self {{")?;
1598        writeln!(w, "        match s {{")?;
1599        for &(ref f, _) in &self.fields {
1600            writeln!(w, "            {:?} => {}::{},", f, self.name, f)?;
1601        }
1602        writeln!(w, "            _ => Self::default(),")?;
1603        writeln!(w, "        }}")?;
1604        writeln!(w, "    }}")?;
1605        writeln!(w, "}}")?;
1606        Ok(())
1607    }
1608}
1609
/// A `oneof` group inside a message. It is generated as an enum `OneOf<name>` with
/// one variant per (emitted) field plus a `None` variant.
#[derive(Debug, Clone, Default)]
pub struct OneOf {
    /// Group name; also the name of the field holding it in the message struct.
    pub name: String,
    /// Alternatives of the group; each becomes an enum variant.
    pub fields: Vec<Field>,
    /// Package path passed down by the enclosing message's `set_package`.
    pub package: String,
    /// Dot-separated Rust module path where the generated enum lives.
    pub module: String,
    /// `true` when the group comes from an imported file.
    pub imported: bool,
}
1618
1619impl OneOf {
1620    fn has_lifetime(&self, desc: &FileDescriptor, config: &Config) -> bool {
1621        self.fields.iter().any(|f| {
1622            f.typ
1623                .has_lifetime(desc, config, f.packed(), &mut Vec::new())
1624                && (!f.deprecated || config.add_deprecated_fields)
1625        })
1626    }
1627
1628    fn set_package(&mut self, package: &str, module: &str) {
1629        self.package = package.to_string();
1630        self.module = module.to_string();
1631    }
1632
1633    fn sanitize_names(&mut self) {
1634        sanitize_keyword(&mut self.name);
1635        sanitize_keyword(&mut self.package);
1636        for f in self.fields.iter_mut() {
1637            sanitize_keyword(&mut f.name);
1638        }
1639    }
1640
1641    fn get_modules(&self, desc: &FileDescriptor) -> String {
1642        get_modules(&self.module, self.imported, desc)
1643    }
1644
1645    fn write<W: Write>(&self, w: &mut W, desc: &FileDescriptor, config: &Config) -> Result<()> {
1646        writeln!(w)?;
1647        self.write_definition(w, desc, config)?;
1648        writeln!(w)?;
1649        self.write_impl_default(w, desc, config)?;
1650        Ok(())
1651    }
1652
1653    fn write_definition<W: Write>(&self, w: &mut W, desc: &FileDescriptor, config: &Config) -> Result<()> {
1654        writeln!(w, "#[derive(Debug, PartialEq, Clone)]")?;
1655        if self.has_lifetime(desc, config) {
1656            writeln!(w, "pub enum OneOf{}<'a> {{", self.name)?;
1657        } else {
1658            writeln!(w, "pub enum OneOf{} {{", self.name)?;
1659        }
1660        for f in &self.fields {
1661            if f.deprecated {
1662                if config.add_deprecated_fields {
1663                    writeln!(w, "    #[deprecated]")?;
1664                } else {
1665                    continue;
1666                }
1667            }
1668
1669            let rust_type = f.typ.rust_type(desc, config)?;
1670            if f.boxed {
1671                writeln!(w, "    {}(Box<{}>),", f.name, rust_type)?;
1672            } else {
1673                writeln!(w, "    {}({}),", f.name, rust_type)?;
1674            }
1675        }
1676        writeln!(w, "    None,")?;
1677        writeln!(w, "}}")?;
1678
1679        if cfg!(feature = "generateImplFromForEnums") {
1680            self.generate_impl_from_for_enums(w, desc, config)
1681        } else {
1682            Ok(())
1683        }
1684    }
1685
1686    fn generate_impl_from_for_enums<W: Write>(
1687        &self,
1688        w: &mut W,
1689        desc: &FileDescriptor,
1690        config: &Config,
1691    ) -> Result<()> {
1692        // For the first of each enumeration type, generate an impl From<> for it.
1693        let mut handled_fields = Vec::new();
1694        for f in self.fields.iter().filter(|f| !f.deprecated || config.add_deprecated_fields) {
1695            let rust_type = f.typ.rust_type(desc, config)?;
1696            if handled_fields.contains(&rust_type) {
1697                continue;
1698            }
1699            writeln!(w, "impl From<{}> for OneOf{} {{", rust_type, self.name)?; // TODO: lifetime.
1700            writeln!(w, "   fn from(f: {}) -> OneOf{} {{", rust_type, self.name)?;
1701            writeln!(w, "      OneOf{}::{}(f)", self.name, f.name)?;
1702            writeln!(w, "   }}")?;
1703            writeln!(w, "}}")?;
1704
1705            handled_fields.push(rust_type);
1706        }
1707
1708        Ok(())
1709    }
1710
1711    fn write_impl_default<W: Write>(&self, w: &mut W, desc: &FileDescriptor, config: &Config) -> Result<()> {
1712        if self.has_lifetime(desc, config) {
1713            writeln!(w, "impl<'a> Default for OneOf{}<'a> {{", self.name)?;
1714        } else {
1715            writeln!(w, "impl Default for OneOf{} {{", self.name)?;
1716        }
1717        writeln!(w, "    fn default() -> Self {{")?;
1718        writeln!(w, "        OneOf{}::None", self.name)?;
1719        writeln!(w, "    }}")?;
1720        writeln!(w, "}}")?;
1721        Ok(())
1722    }
1723
1724    fn write_message_definition<W: Write>(&self, w: &mut W, desc: &FileDescriptor, config: &Config) -> Result<()> {
1725        if self.has_lifetime(desc, config) {
1726            writeln!(
1727                w,
1728                "    pub {}: {}OneOf{}<'a>,",
1729                self.name,
1730                self.get_modules(desc),
1731                self.name
1732            )?;
1733        } else {
1734            writeln!(
1735                w,
1736                "    pub {}: {}OneOf{},",
1737                self.name,
1738                self.get_modules(desc),
1739                self.name
1740            )?;
1741        }
1742        Ok(())
1743    }
1744
1745    fn write_match_tag<W: Write>(&self, w: &mut W, desc: &FileDescriptor, config: &Config) -> Result<()> {
1746        for f in self.fields.iter().filter(|f| !f.deprecated || config.add_deprecated_fields) {
1747            let (val, val_cow) = f.typ.read_fn(desc)?;
1748            if f.boxed {
1749                writeln!(
1750                    w,
1751                    "                Ok({}) => msg.{} = {}OneOf{}::{}(Box::new({})),",
1752                    f.tag(),
1753                    self.name,
1754                    self.get_modules(desc),
1755                    self.name,
1756                    f.name,
1757                    val
1758                )?;
1759            } else {
1760                writeln!(
1761                    w,
1762                    "                Ok({}) => msg.{} = {}OneOf{}::{}({}),",
1763                    f.tag(),
1764                    self.name,
1765                    self.get_modules(desc),
1766                    self.name,
1767                    f.name,
1768                    val_cow
1769                )?;
1770            }
1771        }
1772        Ok(())
1773    }
1774
1775    fn write_get_size<W: Write>(&self, w: &mut W, desc: &FileDescriptor, config: &Config) -> Result<()> {
1776        writeln!(w, "        + match self.{} {{", self.name)?;
1777        for f in self.fields.iter().filter(|f| !f.deprecated || config.add_deprecated_fields) {
1778            let tag_size = sizeof_varint(f.tag());
1779            if f.typ.is_fixed_size() {
1780                writeln!(
1781                    w,
1782                    "            {}OneOf{}::{}(_) => {} + {},",
1783                    self.get_modules(desc),
1784                    self.name,
1785                    f.name,
1786                    tag_size,
1787                    f.typ.get_size("")
1788                )?;
1789            } else {
1790                writeln!(
1791                    w,
1792                    "            {}OneOf{}::{}(ref m) => {} + {},",
1793                    self.get_modules(desc),
1794                    self.name,
1795                    f.name,
1796                    tag_size,
1797                    f.typ.get_size("m")
1798                )?;
1799            }
1800        }
1801        writeln!(
1802            w,
1803            "            {}OneOf{}::None => 0,",
1804            self.get_modules(desc),
1805            self.name
1806        )?;
1807        write!(w, "    }}")?;
1808        Ok(())
1809    }
1810
1811    fn write_write<W: Write>(&self, w: &mut W, desc: &FileDescriptor, config: &Config) -> Result<()> {
1812        write!(w, "        match self.{} {{", self.name)?;
1813        for f in self.fields.iter().filter(|f| !f.deprecated || config.add_deprecated_fields) {
1814            writeln!(
1815                w,
1816                "            {}OneOf{}::{}(ref m) => {{ w.write_with_tag({}, |w| w.{})? }},",
1817                self.get_modules(desc),
1818                self.name,
1819                f.name,
1820                f.tag(),
1821                f.typ.get_write("m", f.boxed)
1822            )?;
1823        }
1824        writeln!(
1825            w,
1826            "            {}OneOf{}::None => {{}},",
1827            self.get_modules(desc),
1828            self.name
1829        )?;
1830        write!(w, "    }}")?;
1831        Ok(())
1832    }
1833}
1834
/// Options for one `.proto` → Rust generation run, consumed by
/// `FileDescriptor::write_proto`.
pub struct Config {
    /// The `.proto` file to read.
    pub in_file: PathBuf,
    /// Requested output path. Its directory is used; the file name is replaced
    /// by one derived from the proto package (or from this path's own stem
    /// when there is no package).
    pub out_file: PathBuf,
    /// Ignore the proto package and write everything as a single module.
    pub single_module: bool,
    /// Directories searched for imported `.proto` files.
    pub import_search_path: Vec<PathBuf>,
    /// Only print the destination path and the messages/enums found; write nothing.
    pub no_output: bool,
    /// On an unbreakable message cycle, return an error instead of turning the
    /// offending required fields into boxed optional fields.
    pub error_cycle: bool,
    /// Emit the "Automatically generated" header comment and crate-level
    /// `#![allow(...)]` attributes.
    pub headers: bool,
    /// Convert `Cow` string/bytes field types to the owned `String_`/`Bytes_` field types.
    pub dont_use_cow: bool,
    /// NOTE(review): presumably extra derives for generated structs; not used in this chunk — confirm.
    pub custom_struct_derive: Vec<String>,
    /// NOTE(review): presumably a custom `#[repr(...)]`; not used in this chunk — confirm.
    pub custom_repr: Option<String>,
    /// NOTE(review): presumably the hook used to generate RPC service code; not used in this chunk — confirm.
    pub custom_rpc_generator: RpcGeneratorFunction,
    /// Lines written verbatim after the generated `use` statements (only when
    /// not every message is a unit message).
    pub custom_includes: Vec<String>,
    /// Copied into `FileDescriptor::owned`; also emits `use core::convert::{TryFrom, TryInto};`.
    pub owned: bool,
    /// NOTE(review): presumably requests `no_std` output; not used in this chunk — confirm.
    pub nostd: bool,
    /// NOTE(review): presumably switches maps to `hashbrown`; not used in this chunk — confirm.
    pub hashbrown: bool,
    /// NOTE(review): not used in this chunk — meaning to be confirmed.
    pub gen_info: bool,
    /// Keep deprecated fields (annotated `#[deprecated]`, with `#![allow(deprecated)]`
    /// in the header) instead of omitting them.
    pub add_deprecated_fields: bool,
}
1854
/// A parsed `.proto` file, with the messages and enums of its imports merged in.
#[derive(Debug, Default, Clone)]
pub struct FileDescriptor {
    /// Files named in import statements, resolved against the search path in `fetch_imports`.
    pub import_paths: Vec<PathBuf>,
    /// Declared proto package; empty when none (and cleared for single-module output).
    pub package: String,
    /// Proto syntax; `Proto3` changes field defaults in `set_defaults`.
    pub syntax: Syntax,
    /// Top-level messages, including imported ones (which are not written out).
    pub messages: Vec<Message>,
    /// Top-level enums, including imported ones (which are not written out).
    pub enums: Vec<Enumerator>,
    /// Rust module name: the package, or the sanitized file stem when there is no package.
    pub module: String,
    /// RPC services declared in the file.
    pub rpc_services: Vec<RpcService>,
    /// Copied from `Config::owned` in `write_proto`.
    pub owned: bool,
}
1866
1867impl FileDescriptor {
1868    pub fn run(configs: &[Config]) -> Result<()> {
1869        for config in configs {
1870            Self::write_proto(&config)?
1871        }
1872        Ok(())
1873    }
1874
1875    pub fn write_proto(config: &Config) -> Result<()> {
1876        let mut desc = FileDescriptor::read_proto(&config.in_file, &config.import_search_path)?;
1877        desc.owned = config.owned;
1878
1879        if desc.messages.is_empty() && desc.enums.is_empty() {
1880            // There could had been unsupported structures, so bail early
1881            return Err(Error::EmptyRead);
1882        }
1883
1884        desc.resolve_types()?;
1885        desc.break_cycles(config.error_cycle)?;
1886        desc.sanity_checks()?;
1887        if config.dont_use_cow {
1888            desc.convert_field_types(&FieldType::StringCow, &FieldType::String_);
1889            desc.convert_field_types(&FieldType::BytesCow, &FieldType::Bytes_);
1890        }
1891        desc.set_defaults(config)?;
1892        desc.sanitize_names();
1893
1894        if config.single_module {
1895            desc.package = "".to_string();
1896        }
1897
1898        let (prefix, file_package) = split_package(&desc.package);
1899
1900        let mut file_stem = if file_package.is_empty() {
1901            get_file_stem(&config.out_file)?
1902        } else {
1903            file_package.to_string()
1904        };
1905
1906        if !file_package.is_empty() {
1907            sanitize_keyword(&mut file_stem);
1908        }
1909        let mut out_file = config.out_file.with_file_name(format!("{}.rs", file_stem));
1910
1911        if !prefix.is_empty() {
1912            use std::fs::create_dir_all;
1913            // e.g. package is a.b; we need to create directory 'a' and insert it into the path
1914            let file = PathBuf::from(out_file.file_name().unwrap());
1915            out_file.pop();
1916            for p in prefix.split('.') {
1917                out_file.push(p);
1918
1919                if !out_file.exists() {
1920                    create_dir_all(&out_file)?;
1921                    update_mod_file(&out_file)?;
1922                }
1923            }
1924            out_file.push(file);
1925        }
1926        if config.no_output {
1927            let imported = |b| if b { " imported" } else { "" };
1928            println!("source will be written to {}\n", out_file.display());
1929            for m in &desc.messages {
1930                println!(
1931                    "message {} module {}{}",
1932                    m.name,
1933                    m.module,
1934                    imported(m.imported)
1935                );
1936            }
1937            for e in &desc.enums {
1938                println!(
1939                    "enum {} module {}{}",
1940                    e.name,
1941                    e.module,
1942                    imported(e.imported)
1943                );
1944            }
1945            return Ok(());
1946        }
1947
1948        let name = config.in_file.file_name().and_then(|e| e.to_str()).unwrap();
1949        let mut w = BufWriter::new(File::create(&out_file)?);
1950        desc.write(&mut w, name, config)?;
1951        update_mod_file(&out_file)
1952    }
1953
1954    pub fn convert_field_types(&mut self, from: &FieldType, to: &FieldType) {
1955        // Messages and enums are the only structures with types
1956        for m in &mut self.messages {
1957            m.convert_field_types(from, to);
1958        }
1959    }
1960
1961    /// Opens a proto file, reads it and returns raw parsed data
1962    pub fn read_proto(in_file: &Path, import_search_path: &[PathBuf]) -> Result<FileDescriptor> {
1963        let file = std::fs::read_to_string(in_file)?;
1964        let (rem, mut desc) = file_descriptor(&file).map_err(Error::Nom)?;
1965        let rem = rem.trim();
1966        if !rem.is_empty() {
1967            return Err(Error::TrailingGarbage(rem.chars().take(50).collect()));
1968        }
1969        for mut m in &mut desc.messages {
1970            if m.path.as_os_str().is_empty() {
1971                m.path = in_file.to_path_buf();
1972                if !import_search_path.is_empty() {
1973                    if let Ok(p) = m.path.clone().strip_prefix(&import_search_path[0]) {
1974                        m.import = p.to_path_buf();
1975                    }
1976                }
1977            }
1978        }
1979        // proto files with no packages are given an implicit module,
1980        // since every generated Rust source file represents a module
1981        desc.module = if desc.package.is_empty() {
1982            get_file_stem(in_file)?
1983        } else {
1984            desc.package.clone()
1985        };
1986
1987        desc.fetch_imports(&in_file, import_search_path)?;
1988        Ok(desc)
1989    }
1990
1991    fn sanity_checks(&self) -> Result<()> {
1992        for m in &self.messages {
1993            m.sanity_checks(&self)?;
1994        }
1995        Ok(())
1996    }
1997
1998    /// Get messages and enums from imports
1999    fn fetch_imports(&mut self, in_file: &Path, import_search_path: &[PathBuf]) -> Result<()> {
2000        for m in &mut self.messages {
2001            m.set_package(&self.package, &self.module);
2002        }
2003        for m in &mut self.enums {
2004            m.set_package(&self.package, &self.module);
2005        }
2006
2007        for import in &self.import_paths {
2008            // this is the same logic as the C preprocessor;
2009            // if the include path item is absolute, then append the filename,
2010            // otherwise it is always relative to the file.
2011            let mut matching_file = None;
2012            for path in import_search_path {
2013                let candidate = if path.is_absolute() {
2014                    path.join(&import)
2015                } else {
2016                    in_file
2017                        .parent()
2018                        .map_or_else(|| path.join(&import), |p| p.join(path).join(&import))
2019                };
2020                if candidate.exists() {
2021                    matching_file = Some(candidate);
2022                    break;
2023                }
2024            }
2025            if matching_file.is_none() {
2026                return Err(Error::InvalidImport(format!(
2027                    "file {} not found on import path",
2028                    import.display()
2029                )));
2030            }
2031            let proto_file = matching_file.unwrap();
2032            let mut f = FileDescriptor::read_proto(&proto_file, import_search_path)?;
2033
2034            // if the proto has a packge then the names will be prefixed
2035            let package = f.package.clone();
2036            let module = f.module.clone();
2037            self.messages.extend(f.messages.drain(..).map(|mut m| {
2038                if m.package.is_empty() {
2039                    m.set_package(&package, &module);
2040                }
2041                if m.path.as_os_str().is_empty() {
2042                    m.path = proto_file.clone();
2043                }
2044                if m.import.as_os_str().is_empty() {
2045                    m.import = import.clone();
2046                }
2047                m.set_imported();
2048                m
2049            }));
2050            self.enums.extend(f.enums.drain(..).map(|mut e| {
2051                if e.package.is_empty() {
2052                    e.set_package(&package, &module);
2053                }
2054                if e.path.as_os_str().is_empty() {
2055                    e.path = proto_file.clone();
2056                }
2057                if e.import.as_os_str().is_empty() {
2058                    e.import = import.clone();
2059                }
2060                e.imported = true;
2061                e
2062            }));
2063        }
2064        Ok(())
2065    }
2066
2067    fn set_defaults(&mut self, config: &Config) -> Result<()> {
2068        // set map fields as required (they are equivalent to repeated message)
2069        for m in &mut self.messages {
2070            m.set_map_required();
2071        }
2072        // if proto3, then changes several defaults
2073        if let Syntax::Proto3 = self.syntax {
2074            for m in &mut self.messages {
2075                m.set_repeated_as_packed();
2076            }
2077        }
2078        // this is very inefficient but we don't care ...
2079        //let msgs = self.messages.clone();
2080        let copy = self.clone();
2081        for m in &mut self.messages {
2082            m.sanitize_defaults(&copy, config)?; //&msgs, &self.enums)?; ???
2083        }
2084        // force packed only on primitives
2085        for m in &mut self.messages {
2086            m.unset_packed_non_primitives();
2087        }
2088        Ok(())
2089    }
2090
2091    fn sanitize_names(&mut self) {
2092        for m in &mut self.messages {
2093            m.sanitize_names();
2094        }
2095        for e in &mut self.enums {
2096            e.sanitize_names();
2097        }
2098    }
2099
2100    /// Breaks cycles by adding boxes when necessary
2101    fn break_cycles(&mut self, error_cycle: bool) -> Result<()> {
2102        // get strongly connected components
2103        let sccs = self.sccs();
2104
2105        fn is_cycle(scc: &[MessageIndex], desc: &FileDescriptor) -> bool {
2106            scc.iter()
2107                .map(|m| m.get_message(desc))
2108                .flat_map(|m| m.all_fields())
2109                .filter(|f| !f.boxed)
2110                .filter_map(|f| f.typ.message())
2111                .any(|m| scc.contains(m))
2112        }
2113
2114        // sccs are sub DFS trees so if there is a edge connecting a node to
2115        // another node higher in the scc list, then this is a cycle. (Note that
2116        // we may have several cycles per scc).
2117        //
2118        // Technically we only need to box one edge (optional field) per cycle to
2119        // have Sized structs. Unfortunately, scc root depend on the order we
2120        // traverse the graph so such a field is not guaranteed to always be the same.
2121        //
2122        // For now, we decide (see discussion in #121) to box all optional fields
2123        // within a scc. We favor generated code stability over performance.
2124        for scc in &sccs {
2125            debug!("scc: {:?}", scc);
2126            for (i, v) in scc.iter().enumerate() {
2127                // cycles with v as root
2128                let cycles = v
2129                    .get_message(self)
2130                    .all_fields()
2131                    .filter_map(|f| f.typ.message())
2132                    .filter_map(|m| scc[i..].iter().position(|n| n == m))
2133                    .collect::<Vec<_>>();
2134                for cycle in cycles {
2135                    let cycle = &scc[i..i + cycle + 1];
2136                    debug!("cycle: {:?}", &cycle);
2137                    for v in cycle {
2138                        for f in v
2139                            .get_message_mut(self)
2140                            .all_fields_mut()
2141                            .filter(|f| f.frequency == Frequency::Optional)
2142                            .filter(|f| f.typ.message().map_or(false, |m| cycle.contains(m)))
2143                        {
2144                            f.boxed = true;
2145                        }
2146                    }
2147                    if is_cycle(cycle, self) {
2148                        if error_cycle {
2149                            return Err(Error::Cycle(
2150                                cycle
2151                                    .iter()
2152                                    .map(|m| m.get_message(self).name.clone())
2153                                    .collect(),
2154                            ));
2155                        } else {
2156                            for v in cycle {
2157                                warn!(
2158                                    "Unsound proto file would result in infinite size Messages.\n\
2159                                     Cycle detected in messages {:?}.\n\
2160                                     Modifying required fields into optional fields",
2161                                    cycle
2162                                        .iter()
2163                                        .map(|m| &m.get_message(self).name)
2164                                        .collect::<Vec<_>>()
2165                                );
2166                                for f in v
2167                                    .get_message_mut(self)
2168                                    .all_fields_mut()
2169                                    .filter(|f| f.frequency == Frequency::Required)
2170                                    .filter(|f| {
2171                                        f.typ.message().map_or(false, |m| cycle.contains(m))
2172                                    })
2173                                {
2174                                    f.boxed = true;
2175                                    f.frequency = Frequency::Optional;
2176                                }
2177                            }
2178                        }
2179                    }
2180                }
2181            }
2182        }
2183        Ok(())
2184    }
2185
2186    fn get_full_names(&mut self) -> (HashMap<String, MessageIndex>, HashMap<String, EnumIndex>) {
2187        fn rec_full_names(
2188            m: &mut Message,
2189            index: &mut MessageIndex,
2190            full_msgs: &mut HashMap<String, MessageIndex>,
2191            full_enums: &mut HashMap<String, EnumIndex>,
2192        ) {
2193            m.index = index.clone();
2194            if m.package.is_empty() {
2195                full_msgs.entry(m.name.clone()).or_insert_with(|| index.clone());
2196            } else {
2197                full_msgs.entry(format!("{}.{}", m.package, m.name)).or_insert_with(|| index.clone());
2198            }
2199            for (i, e) in m.enums.iter_mut().enumerate() {
2200                let index = EnumIndex {
2201                    msg_index: index.clone(),
2202                    index: i,
2203                };
2204                e.index = index.clone();
2205                full_enums.entry(format!("{}.{}", e.package, e.name)).or_insert(index);
2206            }
2207            for (i, m) in m.messages.iter_mut().enumerate() {
2208                index.push(i);
2209                rec_full_names(m, index, full_msgs, full_enums);
2210                index.pop();
2211            }
2212        }
2213
2214        let mut full_msgs = HashMap::new();
2215        let mut full_enums = HashMap::new();
2216        let mut index = MessageIndex { indexes: vec![] };
2217        for (i, m) in self.messages.iter_mut().enumerate() {
2218            index.push(i);
2219            rec_full_names(m, &mut index, &mut full_msgs, &mut full_enums);
2220            index.pop();
2221        }
2222        for (i, e) in self.enums.iter_mut().enumerate() {
2223            let index = EnumIndex {
2224                msg_index: index.clone(),
2225                index: i,
2226            };
2227            e.index = index.clone();
2228            if e.package.is_empty() {
2229                full_enums.entry(e.name.clone()).or_insert_with(|| index.clone());
2230            } else {
2231                full_enums.entry(format!("{}.{}", e.package, e.name)).or_insert_with(|| index.clone());
2232            }
2233        }
2234        (full_msgs, full_enums)
2235    }
2236
    /// Replaces every `FieldType::MessageOrEnum(name)` placeholder with the
    /// matching `FieldType::Message` or `FieldType::Enum`.
    ///
    /// Covers regular fields, oneof members, map keys and values, and (via
    /// recursion) nested messages.
    ///
    /// Returns `Error::MessageOrEnumNotFound` with the first name that matches
    /// no known message or enum.
    fn resolve_types(&mut self) -> Result<()> {
        // Qualified-name -> index tables (also fills in each item's `index`).
        let (full_msgs, full_enums) = self.get_full_names();

        // Resolves the field types of `m`, then recurses into its nested messages.
        fn rec_resolve_types(
            m: &mut Message,
            full_msgs: &HashMap<String, MessageIndex>,
            full_enums: &HashMap<String, EnumIndex>,
        ) -> Result<()> {
            // Interestingly, we can't call all_fields_mut to iterate over the
            // fields here: writing out the field traversal as below lets Rust
            // split m's mutable borrow, permitting the loop body to use fields
            // of `m` other than `fields` and `oneofs`.
            'types: for typ in m
                .fields
                .iter_mut()
                .chain(m.oneofs.iter_mut().flat_map(|o| o.fields.iter_mut()))
                .map(|f| &mut f.typ)
                // A map contributes its key and value types instead of itself.
                .flat_map(|typ| match *typ {
                    FieldType::Map(ref mut key, ref mut value) => {
                        vec![&mut **key, &mut **value].into_iter()
                    }
                    _ => vec![typ].into_iter(),
                })
            {
                if let FieldType::MessageOrEnum(name) = typ.clone() {
                    // Candidate qualified names, most specific scope first:
                    // - a leading '.' means already fully qualified (dot stripped);
                    // - without a package: nested in `m`, then the bare name;
                    // - with a package: nested in `m`, then in the package and
                    //   each enclosing package prefix, then the bare name.
                    let test_names: Vec<String> = if name.starts_with('.') {
                        vec![name.clone().split_off(1)]
                    } else if m.package.is_empty() {
                        vec![format!("{}.{}", m.name, name), name.clone()]
                    } else {
                        let mut v = vec![
                            format!("{}.{}.{}", m.package, m.name, name),
                            format!("{}.{}", m.package, name),
                        ];
                        for (index, _) in m.package.match_indices('.').rev() {
                            v.push(format!("{}.{}", &m.package[..index], name));
                        }
                        v.push(name.clone());
                        v
                    };
                    // The first candidate that names something wins; for each
                    // candidate, messages are checked before enums.
                    for name in &test_names {
                        if let Some(msg) = full_msgs.get(name) {
                            *typ = FieldType::Message(msg.clone());
                            continue 'types;
                        } else if let Some(e) = full_enums.get(name) {
                            *typ = FieldType::Enum(e.clone());
                            continue 'types;
                        }
                    }
                    return Err(Error::MessageOrEnumNotFound(name));
                }
            }
            for m in m.messages.iter_mut() {
                rec_resolve_types(m, full_msgs, full_enums)?;
            }
            Ok(())
        }

        for m in self.messages.iter_mut() {
            rec_resolve_types(m, &full_msgs, &full_enums)?;
        }
        Ok(())
    }
2300
2301    fn write<W: Write>(&self, w: &mut W, filename: &str, config: &Config) -> Result<()> {
2302        println!(
2303            "Found {} messages, and {} enums",
2304            self.messages.len(),
2305            self.enums.len()
2306        );
2307        if config.headers {
2308            self.write_headers(w, filename, config)?;
2309        }
2310        self.write_package_start(w)?;
2311        self.write_uses(w, config)?;
2312        self.write_imports(w)?;
2313        self.write_enums(w)?;
2314        self.write_messages(w, config)?;
2315        self.write_rpc_services(w, config)?;
2316        self.write_package_end(w)?;
2317        Ok(())
2318    }
2319
2320    fn write_headers<W: Write>(&self, w: &mut W, filename: &str, config: &Config) -> Result<()> {
2321        writeln!(
2322            w,
2323            "// Automatically generated rust module for '{}' file",
2324            filename
2325        )?;
2326        writeln!(w)?;
2327        writeln!(w, "#![allow(non_snake_case)]")?;
2328        writeln!(w, "#![allow(non_upper_case_globals)]")?;
2329        writeln!(w, "#![allow(non_camel_case_types)]")?;
2330        writeln!(w, "#![allow(unused_imports)]")?;
2331        writeln!(w, "#![allow(unknown_lints)]")?;
2332        writeln!(w, "#![allow(clippy::all)]")?;
2333
2334        if config.add_deprecated_fields {
2335            writeln!(w, "#![allow(deprecated)]")?;
2336        }
2337
2338        writeln!(w, "#![cfg_attr(rustfmt, rustfmt_skip)]")?;
2339        writeln!(w)?;
2340        Ok(())
2341    }
2342
2343    fn write_package_start<W: Write>(&self, w: &mut W) -> Result<()> {
2344        writeln!(w)?;
2345        Ok(())
2346    }
2347
2348    fn write_uses<W: Write>(&self, w: &mut W, config: &Config) -> Result<()> {
2349        if self.messages.iter().all(|m| m.is_unit()) {
2350            writeln!(
2351                w,
2352                "use quick_protobuf::{{BytesReader, Result, MessageInfo, MessageRead, MessageWrite}};"
2353            )?;
2354            if self.owned {
2355                writeln!(w, "use core::convert::{{TryFrom, TryInto}};")?;
2356            }
2357            return Ok(());
2358        }
2359
2360        Message::write_common_uses(w, &self.messages, self, config)?;
2361
2362        writeln!(
2363            w,
2364            "use quick_protobuf::{{MessageInfo, MessageRead, MessageWrite, BytesReader, Writer, WriterBackend, Result}};"
2365        )?;
2366
2367        if self.owned {
2368            writeln!(
2369                w,
2370                "use core::convert::{{TryFrom, TryInto}};"
2371            )?;
2372        }
2373
2374        writeln!(w, "use quick_protobuf::sizeofs::*;")?;
2375        for include in &config.custom_includes {
2376            writeln!(w, "{}", include)?;
2377        }
2378        Ok(())
2379    }
2380
2381    fn write_imports<W: Write>(&self, w: &mut W) -> Result<()> {
2382        // even if we don't have an explicit package, there is an implicit Rust module
2383        // This `use` allows us to refer to the package root.
2384        // NOTE! I'm suppressing not-needed 'use super::*' errors currently!
2385        let mut depth = self.package.split('.').count();
2386        if depth == 0 {
2387            depth = 1;
2388        }
2389        write!(w, "use ")?;
2390        for _ in 0..depth {
2391            write!(w, "super::")?;
2392        }
2393        writeln!(w, "*;")?;
2394        Ok(())
2395    }
2396
2397    fn write_package_end<W: Write>(&self, w: &mut W) -> Result<()> {
2398        writeln!(w)?;
2399        Ok(())
2400    }
2401
2402    fn write_enums<W: Write>(&self, w: &mut W) -> Result<()> {
2403        for m in self.enums.iter().filter(|e| !e.imported) {
2404            println!("Writing enum {}", m.name);
2405            writeln!(w)?;
2406            m.write_definition(w)?;
2407            writeln!(w)?;
2408            m.write_impl_default(w)?;
2409            writeln!(w)?;
2410            m.write_from_i32(w)?;
2411            writeln!(w)?;
2412            m.write_from_str(w)?;
2413        }
2414        Ok(())
2415    }
2416
2417    fn write_rpc_services<W: Write>(&self, w: &mut W, config: &Config) -> Result<()> {
2418        for m in self.rpc_services.iter() {
2419            println!("Writing Rpc {}", m.service_name);
2420            writeln!(w)?;
2421            m.write_definition(w, config)?;
2422        }
2423        Ok(())
2424    }
2425
2426    fn write_messages<W: Write>(&self, w: &mut W, config: &Config) -> Result<()> {
2427        for m in self.messages.iter().filter(|m| !m.imported) {
2428            m.write(w, &self, config)?;
2429        }
2430        Ok(())
2431    }
2432}
2433
2434/// Calculates the tag value
2435fn tag(number: u32, typ: &FieldType, packed: bool) -> u32 {
2436    number << 3 | typ.wire_type_num(packed)
2437}
2438
/// Splits a dotted package into `(prefix, last_component)` at the last dot.
///
/// `""` is `("", "")`, `"a"` is `("", "a")`, `"a.b"` is `("a", "b")`, and
/// `"a.b.c"` is `("a.b", "c")`.
fn split_package(package: &str) -> (&str, &str) {
    match package.rfind('.') {
        Some(dot) => (&package[..dot], &package[dot + 1..]),
        // No dot, including the empty package: the whole string is the last component.
        None => ("", package),
    }
}
2449
/// First line of every `mod.rs` created by `update_mod_file`. An existing
/// `mod.rs` whose first line lacks it is treated as hand-written and left alone.
const MAGIC_HEADER: &str = "// Automatically generated mod.rs";
2451
2452/// Given a file path, create or update the mod.rs file within its folder
2453fn update_mod_file(path: &Path) -> Result<()> {
2454    let mut file = path.to_path_buf();
2455    use std::fs::OpenOptions;
2456    use std::io::prelude::*;
2457
2458    let name = file.file_stem().unwrap().to_string_lossy().to_string();
2459    file.pop();
2460    file.push("mod.rs");
2461    let matches = "pub mod ";
2462    let mut present = false;
2463    let mut exists = false;
2464    if let Ok(f) = File::open(&file) {
2465        exists = true;
2466        let mut first = true;
2467        for line in BufReader::new(f).lines() {
2468            let line = line?;
2469            if first {
2470                if line.find(MAGIC_HEADER).is_none() {
2471                    // it is NOT one of our generated mod.rs files, so don't modify it!
2472                    present = true;
2473                    break;
2474                }
2475                first = false;
2476            }
2477            if let Some(i) = line.find(matches) {
2478                let rest = &line[i + matches.len()..line.len() - 1];
2479                if rest == name {
2480                    // we already have a reference to this module...
2481                    present = true;
2482                    break;
2483                }
2484            }
2485        }
2486    }
2487    if !present {
2488        let mut f = if exists {
2489            OpenOptions::new().append(true).open(&file)?
2490        } else {
2491            let mut f = File::create(&file)?;
2492            writeln!(f, "{}", MAGIC_HEADER)?;
2493            f
2494        };
2495
2496        writeln!(f, "pub mod {};", name)?;
2497    }
2498    Ok(())
2499}
2500
2501/// get the proper sanitized file stem from an input file path
2502fn get_file_stem(path: &Path) -> Result<String> {
2503    let mut file_stem = path
2504        .file_stem()
2505        .and_then(|f| f.to_str())
2506        .map(|s| s.to_string())
2507        .ok_or_else(|| Error::OutputFile(format!("{}", path.display())))?;
2508
2509    file_stem = file_stem.replace(|c: char| !c.is_alphanumeric(), "_");
2510    // will now be properly alphanumeric, but may be a keyword!
2511    sanitize_keyword(&mut file_stem);
2512    Ok(file_stem)
2513}