serde_generate/
rust.rs

1// Copyright (c) Facebook, Inc. and its affiliates
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::{
5    analyzer,
6    indent::{IndentConfig, IndentedWriter},
7    CodeGeneratorConfig,
8};
9use serde_reflection::{ContainerFormat, Format, Named, Registry, VariantFormat};
10use std::{
11    borrow::Cow,
12    collections::{BTreeMap, HashSet},
13    io::{Result, Write},
14    path::PathBuf,
15};
16
17/// Main configuration object for code-generation in Rust.
18pub struct CodeGenerator<'a> {
19    /// Language-independent configuration.
20    config: &'a CodeGeneratorConfig,
21    /// Which derive macros should be added (independently from serialization).
22    derive_macros: Vec<String>,
23    /// Additional block of text added before each new container definition.
24    custom_derive_block: Option<String>,
25    /// Whether definitions and fields should be marked as `pub`.
26    track_visibility: bool,
27}
28
29/// Shared state for the code generation of a Rust source file.
30struct RustEmitter<'a, T> {
31    /// Writer.
32    out: IndentedWriter<T>,
33    /// Generator.
34    generator: &'a CodeGenerator<'a>,
35    /// Track which definitions have a known size. (Used to add `Box` types.)
36    known_sizes: Cow<'a, HashSet<&'a str>>,
37    /// Current namespace (e.g. vec!["my_package", "my_module", "MyClass"])
38    current_namespace: Vec<String>,
39}
40
41impl<'a> CodeGenerator<'a> {
42    /// Create a Rust code generator for the given config.
43    pub fn new(config: &'a CodeGeneratorConfig) -> Self {
44        Self {
45            config,
46            derive_macros: vec!["Clone", "Debug", "PartialEq", "PartialOrd"]
47                .into_iter()
48                .map(String::from)
49                .collect(),
50            custom_derive_block: None,
51            track_visibility: true,
52        }
53    }
54
55    /// Which derive macros should be added (independently from serialization).
56    pub fn with_derive_macros(mut self, derive_macros: Vec<String>) -> Self {
57        self.derive_macros = derive_macros;
58        self
59    }
60
61    /// Additional block of text added after `derive_macros` (if any), before each new
62    /// container definition.
63    pub fn with_custom_derive_block(mut self, custom_derive_block: Option<String>) -> Self {
64        self.custom_derive_block = custom_derive_block;
65        self
66    }
67
68    /// Whether definitions and fields should be marked as `pub`.
69    pub fn with_track_visibility(mut self, track_visibility: bool) -> Self {
70        self.track_visibility = track_visibility;
71        self
72    }
73
74    /// Write container definitions in Rust.
75    pub fn output(
76        &self,
77        out: &mut dyn Write,
78        registry: &Registry,
79    ) -> std::result::Result<(), Box<dyn std::error::Error>> {
80        let external_names = self
81            .config
82            .external_definitions
83            .values()
84            .flatten()
85            .cloned()
86            .collect();
87        let dependencies =
88            analyzer::get_dependency_map_with_external_dependencies(registry, &external_names)?;
89        let entries = analyzer::best_effort_topological_sort(&dependencies);
90
91        let known_sizes = external_names
92            .iter()
93            .map(<String as std::ops::Deref>::deref)
94            .collect::<HashSet<_>>();
95
96        let current_namespace = self
97            .config
98            .module_name
99            .split('.')
100            .map(String::from)
101            .collect();
102        let mut emitter = RustEmitter {
103            out: IndentedWriter::new(out, IndentConfig::Space(4)),
104            generator: self,
105            known_sizes: Cow::Owned(known_sizes),
106            current_namespace,
107        };
108
109        emitter.output_preamble()?;
110        for name in entries {
111            let format = &registry[name];
112            emitter.output_container(name, format)?;
113            emitter.known_sizes.to_mut().insert(name);
114        }
115        Ok(())
116    }
117
118    /// For each container, generate a Rust definition.
119    pub fn quote_container_definitions(
120        &self,
121        registry: &Registry,
122    ) -> std::result::Result<BTreeMap<String, String>, Box<dyn std::error::Error>> {
123        let dependencies = analyzer::get_dependency_map(registry)?;
124        let entries = analyzer::best_effort_topological_sort(&dependencies);
125
126        let mut result = BTreeMap::new();
127        let mut known_sizes = HashSet::new();
128        let current_namespace = self
129            .config
130            .module_name
131            .split('.')
132            .map(String::from)
133            .collect::<Vec<_>>();
134
135        for name in entries {
136            let mut content = Vec::new();
137            {
138                let mut emitter = RustEmitter {
139                    out: IndentedWriter::new(&mut content, IndentConfig::Space(4)),
140                    generator: self,
141                    known_sizes: Cow::Borrowed(&known_sizes),
142                    current_namespace: current_namespace.clone(),
143                };
144                let format = &registry[name];
145                emitter.output_container(name, format)?;
146            }
147            known_sizes.insert(name);
148            result.insert(
149                name.to_string(),
150                String::from_utf8_lossy(&content).trim().to_string() + "\n",
151            );
152        }
153        Ok(result)
154    }
155}
156
157impl<'a, T> RustEmitter<'a, T>
158where
159    T: std::io::Write,
160{
161    fn output_comment(&mut self, name: &str) -> std::io::Result<()> {
162        let mut path = self.current_namespace.clone();
163        path.push(name.to_string());
164        if let Some(doc) = self.generator.config.comments.get(&path) {
165            let text = textwrap::indent(doc, "/// ").replace("\n\n", "\n///\n");
166            write!(self.out, "\n{}", text)?;
167        }
168        Ok(())
169    }
170
171    fn output_custom_code(&mut self, name: &str) -> std::io::Result<()> {
172        let mut path = self.current_namespace.clone();
173        path.push(name.to_string());
174        if let Some(code) = self.generator.config.custom_code.get(&path) {
175            write!(self.out, "\n{}", code)?;
176        }
177        Ok(())
178    }
179
180    fn output_preamble(&mut self) -> Result<()> {
181        let external_names = self
182            .generator
183            .config
184            .external_definitions
185            .values()
186            .flatten()
187            .cloned()
188            .collect::<HashSet<_>>();
189        writeln!(self.out, "#![allow(unused_imports)]")?;
190        if !external_names.contains("Map") {
191            writeln!(self.out, "use std::collections::BTreeMap as Map;")?;
192        }
193        if self.generator.config.serialization {
194            writeln!(self.out, "use serde::{{Serialize, Deserialize}};")?;
195        }
196        if self.generator.config.serialization && !external_names.contains("Bytes") {
197            writeln!(self.out, "use serde_bytes::ByteBuf as Bytes;")?;
198        }
199        for (module, definitions) in &self.generator.config.external_definitions {
200            // Skip the empty module name.
201            if !module.is_empty() {
202                writeln!(
203                    self.out,
204                    "use {}::{{{}}};",
205                    module,
206                    definitions.to_vec().join(", "),
207                )?;
208            }
209        }
210        writeln!(self.out)?;
211        if !self.generator.config.serialization && !external_names.contains("Bytes") {
212            // If we are not going to use Serde derive macros, use plain vectors.
213            writeln!(self.out, "type Bytes = Vec<u8>;\n")?;
214        }
215        Ok(())
216    }
217
218    fn quote_type(format: &Format, known_sizes: Option<&HashSet<&str>>) -> String {
219        use Format::*;
220        match format {
221            TypeName(x) => {
222                if let Some(set) = known_sizes {
223                    if !set.contains(x.as_str()) {
224                        return format!("Box<{}>", x);
225                    }
226                }
227                x.to_string()
228            }
229            Unit => "()".into(),
230            Bool => "bool".into(),
231            I8 => "i8".into(),
232            I16 => "i16".into(),
233            I32 => "i32".into(),
234            I64 => "i64".into(),
235            I128 => "i128".into(),
236            U8 => "u8".into(),
237            U16 => "u16".into(),
238            U32 => "u32".into(),
239            U64 => "u64".into(),
240            U128 => "u128".into(),
241            F32 => "f32".into(),
242            F64 => "f64".into(),
243            Char => "char".into(),
244            Str => "String".into(),
245            Bytes => "Bytes".into(),
246
247            Option(format) => format!("Option<{}>", Self::quote_type(format, known_sizes)),
248            Seq(format) => format!("Vec<{}>", Self::quote_type(format, None)),
249            Map { key, value } => format!(
250                "Map<{}, {}>",
251                Self::quote_type(key, None),
252                Self::quote_type(value, None)
253            ),
254            Tuple(formats) => format!("({})", Self::quote_types(formats, known_sizes)),
255            TupleArray { content, size } => {
256                format!("[{}; {}]", Self::quote_type(content, known_sizes), *size)
257            }
258
259            Variable(_) => panic!("unexpected value"),
260        }
261    }
262
263    fn quote_types(formats: &[Format], known_sizes: Option<&HashSet<&str>>) -> String {
264        formats
265            .iter()
266            .map(|x| Self::quote_type(x, known_sizes))
267            .collect::<Vec<_>>()
268            .join(", ")
269    }
270
271    fn output_fields(&mut self, base: &[&str], fields: &[Named<Format>]) -> Result<()> {
272        // Do not add 'pub' within variants.
273        let prefix = if base.len() <= 1 && self.generator.track_visibility {
274            "pub "
275        } else {
276            ""
277        };
278        for field in fields {
279            self.output_comment(&field.name)?;
280            writeln!(
281                self.out,
282                "{}{}: {},",
283                prefix,
284                field.name,
285                Self::quote_type(&field.value, Some(&self.known_sizes)),
286            )?;
287        }
288        Ok(())
289    }
290
291    fn output_variant(&mut self, base: &str, name: &str, variant: &VariantFormat) -> Result<()> {
292        self.output_comment(name)?;
293        use VariantFormat::*;
294        match variant {
295            Unit => writeln!(self.out, "{},", name),
296            NewType(format) => writeln!(
297                self.out,
298                "{}({}),",
299                name,
300                Self::quote_type(format, Some(&self.known_sizes))
301            ),
302            Tuple(formats) => writeln!(
303                self.out,
304                "{}({}),",
305                name,
306                Self::quote_types(formats, Some(&self.known_sizes))
307            ),
308            Struct(fields) => {
309                writeln!(self.out, "{} {{", name)?;
310                self.current_namespace.push(name.to_string());
311                self.out.indent();
312                self.output_fields(&[base, name], fields)?;
313                self.out.unindent();
314                self.current_namespace.pop();
315                writeln!(self.out, "}},")
316            }
317            Variable(_) => panic!("incorrect value"),
318        }
319    }
320
321    fn output_variants(
322        &mut self,
323        base: &str,
324        variants: &BTreeMap<u32, Named<VariantFormat>>,
325    ) -> Result<()> {
326        for (expected_index, (index, variant)) in variants.iter().enumerate() {
327            assert_eq!(*index, expected_index as u32);
328            self.output_variant(base, &variant.name, &variant.value)?;
329        }
330        Ok(())
331    }
332
333    fn output_container(&mut self, name: &str, format: &ContainerFormat) -> Result<()> {
334        self.output_comment(name)?;
335        let mut derive_macros = self.generator.derive_macros.clone();
336        if self.generator.config.serialization {
337            derive_macros.push("Serialize".to_string());
338            derive_macros.push("Deserialize".to_string());
339        }
340        let mut prefix = String::new();
341        if !derive_macros.is_empty() {
342            prefix.push_str(&format!("#[derive({})]\n", derive_macros.join(", ")));
343        }
344        if let Some(text) = &self.generator.custom_derive_block {
345            prefix.push_str(text);
346            prefix.push('\n');
347        }
348        if self.generator.track_visibility {
349            prefix.push_str("pub ");
350        }
351
352        use ContainerFormat::*;
353        match format {
354            UnitStruct => writeln!(self.out, "{}struct {};\n", prefix, name)?,
355            NewTypeStruct(format) => writeln!(
356                self.out,
357                "{}struct {}({}{});\n",
358                prefix,
359                name,
360                if self.generator.track_visibility {
361                    "pub "
362                } else {
363                    ""
364                },
365                Self::quote_type(format, Some(&self.known_sizes))
366            )?,
367            TupleStruct(formats) => writeln!(
368                self.out,
369                "{}struct {}({});\n",
370                prefix,
371                name,
372                Self::quote_types(formats, Some(&self.known_sizes))
373            )?,
374            Struct(fields) => {
375                writeln!(self.out, "{}struct {} {{", prefix, name)?;
376                self.current_namespace.push(name.to_string());
377                self.out.indent();
378                self.output_fields(&[name], fields)?;
379                self.out.unindent();
380                self.current_namespace.pop();
381                writeln!(self.out, "}}\n")?;
382            }
383            Enum(variants) => {
384                writeln!(self.out, "{}enum {} {{", prefix, name)?;
385                self.current_namespace.push(name.to_string());
386                self.out.indent();
387                self.output_variants(name, variants)?;
388                self.out.unindent();
389                self.current_namespace.pop();
390                writeln!(self.out, "}}\n")?;
391            }
392        }
393        self.output_custom_code(name)
394    }
395}
396
/// Installs generated Rust source files under a chosen directory.
pub struct Installer {
    /// Root directory that receives the generated files.
    install_dir: PathBuf,
}

impl Installer {
    /// Create an installer rooted at `install_dir`.
    pub fn new(install_dir: PathBuf) -> Self {
        Self { install_dir }
    }

    /// Report on stderr that the sources of the published crate `name` are not installed.
    fn runtime_installation_message(name: &str) {
        eprintln!("Not installing sources for published crate {}", name);
    }
}
411
412impl crate::SourceInstaller for Installer {
413    type Error = Box<dyn std::error::Error>;
414
415    fn install_module(
416        &self,
417        config: &CodeGeneratorConfig,
418        registry: &Registry,
419    ) -> std::result::Result<(), Self::Error> {
420        let generator = CodeGenerator::new(config);
421        let (name, version) = {
422            let parts = config.module_name.splitn(2, ':').collect::<Vec<_>>();
423            if parts.len() >= 2 {
424                (parts[0].to_string(), parts[1].to_string())
425            } else {
426                (parts[0].to_string(), "0.1.0".to_string())
427            }
428        };
429        let dir_path = self.install_dir.join(&name);
430        std::fs::create_dir_all(&dir_path)?;
431
432        if config.package_manifest {
433            let mut cargo = std::fs::File::create(dir_path.join("Cargo.toml"))?;
434            write!(
435                cargo,
436                r#"[package]
437name = "{}"
438version = "{}"
439edition = "2018"
440
441[dependencies]
442serde = {{ version = "1.0", features = ["derive"] }}
443serde_bytes = "0.11"
444"#,
445                name, version,
446            )?;
447        }
448
449        std::fs::create_dir_all(dir_path.join("src"))?;
450        let source_path = dir_path.join("src/lib.rs");
451        let mut source = std::fs::File::create(source_path)?;
452        generator.output(&mut source, registry)
453    }
454
455    fn install_serde_runtime(&self) -> std::result::Result<(), Self::Error> {
456        Self::runtime_installation_message("serde");
457        Ok(())
458    }
459
460    fn install_bincode_runtime(&self) -> std::result::Result<(), Self::Error> {
461        Self::runtime_installation_message("bincode");
462        Ok(())
463    }
464
465    fn install_bcs_runtime(&self) -> std::result::Result<(), Self::Error> {
466        Self::runtime_installation_message("bcs");
467        Ok(())
468    }
469}