1use crate::config::CodeGeneratorConfig;
10use serde_generate::indent::{IndentConfig, IndentedWriter};
11use serde_reflection::{ContainerFormat, Format, Named, VariantFormat};
12use std::borrow::Cow;
13use std::collections::{BTreeMap, BTreeSet, HashSet};
14use std::io::{Result, Write};
15
16pub type QualifiedName = (Option<String>, String);
18pub type Registry = BTreeMap<QualifiedName, ContainerFormat>;
19
20pub struct CodeGenerator<'a> {
22 config: &'a CodeGeneratorConfig,
24 derive_macros: Vec<String>,
26 custom_derive_block: Option<String>,
28 track_visibility: bool,
30}
31
32struct RustEmitter<'a, T> {
34 out: IndentedWriter<T>,
36 generator: &'a CodeGenerator<'a>,
38 known_sizes: Cow<'a, HashSet<&'a str>>,
40 current_namespace: Vec<String>,
42}
43
44impl<'a> CodeGenerator<'a> {
45 pub fn new(config: &'a CodeGeneratorConfig) -> Self {
47 Self {
48 config,
49 derive_macros: vec!["Clone", "Debug", "PartialEq", "PartialOrd"]
50 .into_iter()
51 .map(String::from)
52 .collect(),
53 custom_derive_block: None,
54 track_visibility: true,
55 }
56 }
57
58 pub fn with_derive_macros(mut self, derive_macros: Vec<String>) -> Self {
60 self.derive_macros = derive_macros;
61 self
62 }
63
64 pub fn with_custom_derive_block(mut self, custom_derive_block: Option<String>) -> Self {
67 self.custom_derive_block = custom_derive_block;
68 self
69 }
70
71 pub fn with_track_visibility(mut self, track_visibility: bool) -> Self {
73 self.track_visibility = track_visibility;
74 self
75 }
76
77 pub fn output(
79 &self,
80 out: &mut dyn Write,
81 registry: &Registry,
82 ) -> std::result::Result<(), Box<dyn std::error::Error>> {
83 let external_names: BTreeSet<String> = self
84 .config
85 .external_definitions
86 .values()
87 .cloned()
88 .flatten()
89 .collect();
90
91 let known_sizes = external_names
92 .iter()
93 .map(<String as std::ops::Deref>::deref)
94 .collect::<HashSet<_>>();
95
96 let current_namespace = self
97 .config
98 .module_name
99 .split('.')
100 .map(String::from)
101 .collect();
102 let mut emitter = RustEmitter {
103 out: IndentedWriter::new(out, IndentConfig::Space(4)),
104 generator: self,
105 known_sizes: Cow::Owned(known_sizes),
106 current_namespace,
107 };
108
109 emitter.output_preamble()?;
110 for ((ns, name), format) in registry {
111 emitter.output_container(ns, name, format)?;
112 emitter.known_sizes.to_mut().insert(name);
113 }
114 Ok(())
115 }
116}
117
118impl<'a, T> RustEmitter<'a, T>
119where
120 T: std::io::Write,
121{
122 fn output_comment(&mut self, name: &str) -> std::io::Result<()> {
123 let mut path = self.current_namespace.clone();
124 path.push(name.to_string());
125 if let Some(doc) = self.generator.config.comments.get(&path) {
126 let text = textwrap::indent(doc, "/// ").replace("\n\n", "\n///\n");
127 write!(self.out, "\n{}", text)?;
128 }
129 Ok(())
130 }
131
132 fn output_preamble(&mut self) -> Result<()> {
133 let external_names = self
134 .generator
135 .config
136 .external_definitions
137 .values()
138 .cloned()
139 .flatten()
140 .collect::<HashSet<_>>();
141 writeln!(self.out, "#![allow(unused_imports, non_snake_case, non_camel_case_types, non_upper_case_globals)]")?;
142 if !external_names.contains("Map") {
143 writeln!(self.out, "use std::collections::BTreeMap as Map;")?;
144 }
145 writeln!(self.out, "use serde::{{Serialize, Deserialize}};")?;
146 if !external_names.contains("Bytes") {
147 writeln!(self.out, "use serde_bytes::ByteBuf as Bytes;")?;
148 }
149 for (module, definitions) in &self.generator.config.external_definitions {
150 if !module.is_empty() {
152 writeln!(
153 self.out,
154 "use {}::{{{}}};",
155 module,
156 definitions.to_vec().join(", "),
157 )?;
158 }
159 }
160 writeln!(self.out)?;
161 Ok(())
162 }
163
164 fn output_field_annotation(&mut self, format: &Format) -> std::io::Result<()> {
165 use Format::*;
166 match format {
167 Str => writeln!(
168 self.out,
169 "#[serde(skip_serializing_if = \"String::is_empty\")]"
170 )?,
171 Option(_) => writeln!(
172 self.out,
173 "#[serde(skip_serializing_if = \"Option::is_none\")]"
174 )?,
175 Seq(_) => writeln!(
176 self.out,
177 "#[serde(skip_serializing_if = \"Vec::is_empty\")]"
178 )?,
179 _ => (),
180 }
181
182 Ok(())
183 }
184
185 fn quote_type(format: &Format, known_sizes: Option<&HashSet<&str>>) -> String {
186 use Format::*;
187 match format {
188 TypeName(x) => {
189 if let Some(set) = known_sizes {
190 if !set.contains(x.as_str()) && !x.as_str().starts_with("Vec") {
191 return format!("Box<{}>", x);
192 }
193 }
194 x.to_string()
195 }
196 Unit => "()".into(),
197 Bool => "bool".into(),
198 I8 => "i8".into(),
199 I16 => "i16".into(),
200 I32 => "i32".into(),
201 I64 => "i64".into(),
202 I128 => "i128".into(),
203 U8 => "u8".into(),
204 U16 => "u16".into(),
205 U32 => "u32".into(),
206 U64 => "u64".into(),
207 U128 => "u128".into(),
208 F32 => "f32".into(),
209 F64 => "f64".into(),
210 Char => "char".into(),
211 Str => "String".into(),
212 Bytes => "Bytes".into(),
213
214 Option(format) => format!("Option<{}>", Self::quote_type(format, known_sizes)),
215 Seq(format) => format!("Vec<{}>", Self::quote_type(format, None)),
216 Map { key, value } => format!(
217 "Map<{}, {}>",
218 Self::quote_type(key, None),
219 Self::quote_type(value, None)
220 ),
221 Tuple(formats) => format!("({})", Self::quote_types(formats, known_sizes)),
222 TupleArray { content, size } => {
223 format!("[{}; {}]", Self::quote_type(content, known_sizes), *size)
224 }
225
226 Variable(_) => panic!("unexpected value"),
227 }
228 }
229
230 fn quote_types(formats: &[Format], known_sizes: Option<&HashSet<&str>>) -> String {
231 formats
232 .iter()
233 .map(|x| Self::quote_type(x, known_sizes))
234 .collect::<Vec<_>>()
235 .join(", ")
236 }
237
238 fn output_fields(&mut self, base: &[&str], fields: &[Named<Format>]) -> Result<()> {
239 let prefix = if base.len() <= 1 && self.generator.track_visibility {
241 "pub "
242 } else {
243 ""
244 };
245 for field in fields {
246 self.output_comment(&field.name)?;
247 self.output_field_annotation(&field.value)?;
248 writeln!(
249 self.out,
250 "{}{}: {},",
251 prefix,
252 field.name,
253 Self::quote_type(&field.value, Some(&self.known_sizes)),
254 )?;
255 }
256 Ok(())
257 }
258
259 fn output_variant(&mut self, base: &str, name: &str, variant: &VariantFormat) -> Result<()> {
260 self.output_comment(name)?;
261 use VariantFormat::*;
262 match variant {
263 Unit => writeln!(self.out, "{},", name),
264 NewType(format) => writeln!(
265 self.out,
266 "{}({}),",
267 name,
268 Self::quote_type(format, Some(&self.known_sizes))
269 ),
270 Tuple(formats) => writeln!(
271 self.out,
272 "{}({}),",
273 name,
274 Self::quote_types(formats, Some(&self.known_sizes))
275 ),
276 Struct(fields) => {
277 writeln!(self.out, "{} {{", name)?;
278 self.current_namespace.push(name.to_string());
279 self.out.indent();
280 self.output_fields(&[base, name], fields)?;
281 self.out.unindent();
282 self.current_namespace.pop();
283 writeln!(self.out, "}},")
284 }
285 Variable(_) => panic!("incorrect value"),
286 }
287 }
288
289 fn output_variants(
290 &mut self,
291 base: &str,
292 variants: &BTreeMap<u32, Named<VariantFormat>>,
293 ) -> Result<()> {
294 for (expected_index, (index, variant)) in variants.iter().enumerate() {
295 assert_eq!(*index, expected_index as u32);
296 self.output_variant(base, &variant.name, &variant.value)?;
297 }
298 Ok(())
299 }
300
301 fn output_container(
302 &mut self,
303 namespace: &Option<String>,
304 name: &str,
305 format: &ContainerFormat,
306 ) -> Result<()> {
307 self.output_comment(name)?;
308 let mut derive_macros = self.generator.derive_macros.clone();
309 derive_macros.push("Serialize".to_string());
310 derive_macros.push("Deserialize".to_string());
311 let mut prefix = String::new();
312 if !derive_macros.is_empty() {
313 prefix.push_str(&format!("#[derive({})]\n", derive_macros.join(", ")));
314 }
315 if let Some(text) = &self.generator.custom_derive_block {
316 prefix.push_str(text);
317 prefix.push('\n');
318 }
319
320 use ContainerFormat::*;
321 match format {
322 UnitStruct => writeln!(self.out, "{}struct {};\n", prefix, name),
323 NewTypeStruct(format) => writeln!(
324 self.out,
325 "{}struct {}({}{});\n",
326 prefix,
327 name,
328 if self.generator.track_visibility {
329 "pub "
330 } else {
331 ""
332 },
333 Self::quote_type(format, Some(&self.known_sizes))
334 ),
335 TupleStruct(formats) => writeln!(
336 self.out,
337 "{}struct {}({});\n",
338 prefix,
339 name,
340 Self::quote_types(formats, Some(&self.known_sizes))
341 ),
342 Struct(fields) => {
343 let mut struct_name = name.to_string();
344 prefix.clear();
345 derive_macros.push("Default".to_string());
346 prefix.push_str(&format!("#[derive({})]\n", derive_macros.join(", ")));
347
348 if let Some(ns) = namespace {
349 prefix.push_str(&format!("#[serde(rename = \"{}\")]\n", name));
350 struct_name = format!("{}_{}", ns, name)
351 }
352
353 if self.generator.track_visibility {
354 prefix.push_str("pub ");
355 }
356
357 writeln!(self.out, "{}struct {} {{", prefix, struct_name)?;
358 self.current_namespace.push(name.to_string());
359 self.out.indent();
360 self.output_fields(&[name], fields)?;
361 self.out.unindent();
362 self.current_namespace.pop();
363 writeln!(self.out, "}}\n")
364 }
365 Enum(variants) => {
366 if self.generator.track_visibility {
367 prefix.push_str("pub ");
368 }
369
370 writeln!(self.out, "{}enum {} {{", prefix, name)?;
371 self.current_namespace.push(name.to_string());
372 self.out.indent();
373 self.output_variants(name, variants)?;
374 self.out.unindent();
375 self.current_namespace.pop();
376 writeln!(self.out, "}}\n")
377 }
378 }
379 }
380}