pbjson_build/generator/
enumeration.rs

//! This module contains the code to generate Serialize and Deserialize
//! implementations for enumeration types.
//!
//! An enumeration should be decodable from the full string variant name
//! or its integer tag number, and should encode to the string representation
//! (or to the integer tag when `use_integers_for_enums` is set).
6
7use super::{
8    write_deserialize_end, write_deserialize_start, write_serialize_end, write_serialize_start,
9    Indent,
10};
11use crate::descriptor::{EnumDescriptor, TypePath};
12use crate::generator::write_fields_array;
13use crate::resolver::Resolver;
14use std::collections::HashSet;
15use std::io::{Result, Write};
16
17pub fn generate_enum<W: Write>(
18    resolver: &Resolver<'_>,
19    path: &TypePath,
20    descriptor: &EnumDescriptor,
21    writer: &mut W,
22    use_integers_for_enums: bool,
23) -> Result<()> {
24    let rust_type = resolver.rust_type(path);
25
26    let mut seen_numbers = HashSet::new();
27    let variants: Vec<_> = descriptor
28        .values
29        .iter()
30        // Skip duplicates if we've seen the number before
31        // Protobuf's `allow_alias` option permits duplicates if set
32        .filter(|variant| seen_numbers.insert(variant.number()))
33        .map(|variant| {
34            let variant_name = variant.name.clone().unwrap();
35            let variant_number = variant.number();
36            let rust_variant = resolver.rust_variant(path, &variant_name);
37            (variant_name, variant_number, rust_variant)
38        })
39        .collect();
40
41    // Generate Serialize
42    write_serialize_start(0, &rust_type, writer)?;
43    if use_integers_for_enums {
44        writeln!(writer, "{}let variant = match self {{", Indent(2))?;
45        for (_, variant_number, rust_variant) in &variants {
46            writeln!(
47                writer,
48                "{}Self::{} => {},",
49                Indent(3),
50                rust_variant,
51                variant_number
52            )?;
53        }
54        writeln!(writer, "{}}};", Indent(2))?;
55
56        writeln!(writer, "{}serializer.serialize_i32(variant)", Indent(2))?;
57    } else {
58        writeln!(writer, "{}let variant = match self {{", Indent(2))?;
59        for (variant_name, _, rust_variant) in &variants {
60            writeln!(
61                writer,
62                "{}Self::{} => \"{}\",",
63                Indent(3),
64                rust_variant,
65                variant_name
66            )?;
67        }
68        writeln!(writer, "{}}};", Indent(2))?;
69
70        writeln!(writer, "{}serializer.serialize_str(variant)", Indent(2))?;
71    }
72    write_serialize_end(0, writer)?;
73
74    // Generate Deserialize
75    write_deserialize_start(0, &rust_type, writer)?;
76    write_fields_array(writer, 2, variants.iter().map(|(name, _, _)| name.as_str()))?;
77    write_visitor(writer, 2, &rust_type, &variants)?;
78
79    // Use deserialize_any to allow users to provide integers or strings
80    writeln!(
81        writer,
82        "{}deserializer.deserialize_any(GeneratedVisitor)",
83        Indent(2)
84    )?;
85
86    write_deserialize_end(0, writer)?;
87    Ok(())
88}
89
90fn write_visitor<W: Write>(
91    writer: &mut W,
92    indent: usize,
93    rust_type: &str,
94    variants: &[(String, i32, String)],
95) -> Result<()> {
96    // Protobuf supports deserialization of enumerations both from string and integer values
97    writeln!(
98        writer,
99        r#"{indent}struct GeneratedVisitor;
100
101{indent}impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {{
102{indent}    type Value = {rust_type};
103
104{indent}    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{
105{indent}        write!(formatter, "expected one of: {{:?}}", &FIELDS)
106{indent}    }}
107
108{indent}    fn visit_i64<E>(self, v: i64) -> std::result::Result<Self::Value, E>
109{indent}    where
110{indent}        E: serde::de::Error,
111{indent}    {{
112{indent}        i32::try_from(v)
113{indent}            .ok()
114{indent}            .and_then(|x| x.try_into().ok())
115{indent}            .ok_or_else(|| {{
116{indent}                serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self)
117{indent}            }})
118{indent}    }}
119
120{indent}    fn visit_u64<E>(self, v: u64) -> std::result::Result<Self::Value, E>
121{indent}    where
122{indent}        E: serde::de::Error,
123{indent}    {{
124{indent}        i32::try_from(v)
125{indent}            .ok()
126{indent}            .and_then(|x| x.try_into().ok())
127{indent}            .ok_or_else(|| {{
128{indent}                serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self)
129{indent}            }})
130{indent}    }}
131
132{indent}    fn visit_str<E>(self, value: &str) -> std::result::Result<Self::Value, E>
133{indent}    where
134{indent}        E: serde::de::Error,
135{indent}    {{"#,
136        indent = Indent(indent),
137        rust_type = rust_type,
138    )?;
139
140    writeln!(writer, "{}match value {{", Indent(indent + 2))?;
141    for (variant_name, _, rust_variant) in variants {
142        writeln!(
143            writer,
144            "{}\"{}\" => Ok({}::{}),",
145            Indent(indent + 3),
146            variant_name,
147            rust_type,
148            rust_variant
149        )?;
150    }
151
152    writeln!(
153        writer,
154        "{indent}_ => Err(serde::de::Error::unknown_variant(value, FIELDS)),",
155        indent = Indent(indent + 3)
156    )?;
157    writeln!(writer, "{}}}", Indent(indent + 2))?;
158    writeln!(writer, "{}}}", Indent(indent + 1))?;
159    writeln!(writer, "{}}}", Indent(indent))
160}