1use crate::{
5 analyzer,
6 indent::{IndentConfig, IndentedWriter},
7 CodeGeneratorConfig,
8};
9use serde_reflection::{ContainerFormat, Format, Named, Registry, VariantFormat};
10use std::{
11 borrow::Cow,
12 collections::{BTreeMap, HashSet},
13 io::{Result, Write},
14 path::PathBuf,
15};
16
/// Generates Rust type definitions for the containers of a `serde_reflection` [`Registry`].
pub struct CodeGenerator<'a> {
    /// Code-generation settings (module name, comments, custom code, external
    /// definitions, serialization and manifest flags).
    config: &'a CodeGeneratorConfig,
    /// Base list of names placed in the `#[derive(...)]` attribute of each generated
    /// container; `Serialize`/`Deserialize` are appended when `config.serialization` is set.
    derive_macros: Vec<String>,
    /// Optional text emitted verbatim, followed by a newline, after the derive attribute.
    custom_derive_block: Option<String>,
    /// When true, generated containers and their top-level fields are declared `pub`.
    track_visibility: bool,
}
28
/// Writes Rust source for registry containers into an indented output sink.
struct RustEmitter<'a, T> {
    /// Output sink; nesting is controlled with `indent`/`unindent`.
    out: IndentedWriter<T>,
    /// Generator supplying the configuration and derive settings.
    generator: &'a CodeGenerator<'a>,
    /// Type names that may be referenced directly. Outside of `Vec`/`Map` element
    /// positions, any other type name is emitted as `Box<...>`.
    known_sizes: Cow<'a, HashSet<&'a str>>,
    /// Path segments (module name, then enclosing container/variant names) used as the
    /// key prefix when looking up comments and custom code.
    current_namespace: Vec<String>,
}
40
41impl<'a> CodeGenerator<'a> {
42 pub fn new(config: &'a CodeGeneratorConfig) -> Self {
44 Self {
45 config,
46 derive_macros: vec!["Clone", "Debug", "PartialEq", "PartialOrd"]
47 .into_iter()
48 .map(String::from)
49 .collect(),
50 custom_derive_block: None,
51 track_visibility: true,
52 }
53 }
54
55 pub fn with_derive_macros(mut self, derive_macros: Vec<String>) -> Self {
57 self.derive_macros = derive_macros;
58 self
59 }
60
61 pub fn with_custom_derive_block(mut self, custom_derive_block: Option<String>) -> Self {
64 self.custom_derive_block = custom_derive_block;
65 self
66 }
67
68 pub fn with_track_visibility(mut self, track_visibility: bool) -> Self {
70 self.track_visibility = track_visibility;
71 self
72 }
73
74 pub fn output(
76 &self,
77 out: &mut dyn Write,
78 registry: &Registry,
79 ) -> std::result::Result<(), Box<dyn std::error::Error>> {
80 let external_names = self
81 .config
82 .external_definitions
83 .values()
84 .flatten()
85 .cloned()
86 .collect();
87 let dependencies =
88 analyzer::get_dependency_map_with_external_dependencies(registry, &external_names)?;
89 let entries = analyzer::best_effort_topological_sort(&dependencies);
90
91 let known_sizes = external_names
92 .iter()
93 .map(<String as std::ops::Deref>::deref)
94 .collect::<HashSet<_>>();
95
96 let current_namespace = self
97 .config
98 .module_name
99 .split('.')
100 .map(String::from)
101 .collect();
102 let mut emitter = RustEmitter {
103 out: IndentedWriter::new(out, IndentConfig::Space(4)),
104 generator: self,
105 known_sizes: Cow::Owned(known_sizes),
106 current_namespace,
107 };
108
109 emitter.output_preamble()?;
110 for name in entries {
111 let format = ®istry[name];
112 emitter.output_container(name, format)?;
113 emitter.known_sizes.to_mut().insert(name);
114 }
115 Ok(())
116 }
117
118 pub fn quote_container_definitions(
120 &self,
121 registry: &Registry,
122 ) -> std::result::Result<BTreeMap<String, String>, Box<dyn std::error::Error>> {
123 let dependencies = analyzer::get_dependency_map(registry)?;
124 let entries = analyzer::best_effort_topological_sort(&dependencies);
125
126 let mut result = BTreeMap::new();
127 let mut known_sizes = HashSet::new();
128 let current_namespace = self
129 .config
130 .module_name
131 .split('.')
132 .map(String::from)
133 .collect::<Vec<_>>();
134
135 for name in entries {
136 let mut content = Vec::new();
137 {
138 let mut emitter = RustEmitter {
139 out: IndentedWriter::new(&mut content, IndentConfig::Space(4)),
140 generator: self,
141 known_sizes: Cow::Borrowed(&known_sizes),
142 current_namespace: current_namespace.clone(),
143 };
144 let format = ®istry[name];
145 emitter.output_container(name, format)?;
146 }
147 known_sizes.insert(name);
148 result.insert(
149 name.to_string(),
150 String::from_utf8_lossy(&content).trim().to_string() + "\n",
151 );
152 }
153 Ok(result)
154 }
155}
156
157impl<'a, T> RustEmitter<'a, T>
158where
159 T: std::io::Write,
160{
161 fn output_comment(&mut self, name: &str) -> std::io::Result<()> {
162 let mut path = self.current_namespace.clone();
163 path.push(name.to_string());
164 if let Some(doc) = self.generator.config.comments.get(&path) {
165 let text = textwrap::indent(doc, "/// ").replace("\n\n", "\n///\n");
166 write!(self.out, "\n{}", text)?;
167 }
168 Ok(())
169 }
170
171 fn output_custom_code(&mut self, name: &str) -> std::io::Result<()> {
172 let mut path = self.current_namespace.clone();
173 path.push(name.to_string());
174 if let Some(code) = self.generator.config.custom_code.get(&path) {
175 write!(self.out, "\n{}", code)?;
176 }
177 Ok(())
178 }
179
180 fn output_preamble(&mut self) -> Result<()> {
181 let external_names = self
182 .generator
183 .config
184 .external_definitions
185 .values()
186 .flatten()
187 .cloned()
188 .collect::<HashSet<_>>();
189 writeln!(self.out, "#![allow(unused_imports)]")?;
190 if !external_names.contains("Map") {
191 writeln!(self.out, "use std::collections::BTreeMap as Map;")?;
192 }
193 if self.generator.config.serialization {
194 writeln!(self.out, "use serde::{{Serialize, Deserialize}};")?;
195 }
196 if self.generator.config.serialization && !external_names.contains("Bytes") {
197 writeln!(self.out, "use serde_bytes::ByteBuf as Bytes;")?;
198 }
199 for (module, definitions) in &self.generator.config.external_definitions {
200 if !module.is_empty() {
202 writeln!(
203 self.out,
204 "use {}::{{{}}};",
205 module,
206 definitions.to_vec().join(", "),
207 )?;
208 }
209 }
210 writeln!(self.out)?;
211 if !self.generator.config.serialization && !external_names.contains("Bytes") {
212 writeln!(self.out, "type Bytes = Vec<u8>;\n")?;
214 }
215 Ok(())
216 }
217
218 fn quote_type(format: &Format, known_sizes: Option<&HashSet<&str>>) -> String {
219 use Format::*;
220 match format {
221 TypeName(x) => {
222 if let Some(set) = known_sizes {
223 if !set.contains(x.as_str()) {
224 return format!("Box<{}>", x);
225 }
226 }
227 x.to_string()
228 }
229 Unit => "()".into(),
230 Bool => "bool".into(),
231 I8 => "i8".into(),
232 I16 => "i16".into(),
233 I32 => "i32".into(),
234 I64 => "i64".into(),
235 I128 => "i128".into(),
236 U8 => "u8".into(),
237 U16 => "u16".into(),
238 U32 => "u32".into(),
239 U64 => "u64".into(),
240 U128 => "u128".into(),
241 F32 => "f32".into(),
242 F64 => "f64".into(),
243 Char => "char".into(),
244 Str => "String".into(),
245 Bytes => "Bytes".into(),
246
247 Option(format) => format!("Option<{}>", Self::quote_type(format, known_sizes)),
248 Seq(format) => format!("Vec<{}>", Self::quote_type(format, None)),
249 Map { key, value } => format!(
250 "Map<{}, {}>",
251 Self::quote_type(key, None),
252 Self::quote_type(value, None)
253 ),
254 Tuple(formats) => format!("({})", Self::quote_types(formats, known_sizes)),
255 TupleArray { content, size } => {
256 format!("[{}; {}]", Self::quote_type(content, known_sizes), *size)
257 }
258
259 Variable(_) => panic!("unexpected value"),
260 }
261 }
262
263 fn quote_types(formats: &[Format], known_sizes: Option<&HashSet<&str>>) -> String {
264 formats
265 .iter()
266 .map(|x| Self::quote_type(x, known_sizes))
267 .collect::<Vec<_>>()
268 .join(", ")
269 }
270
271 fn output_fields(&mut self, base: &[&str], fields: &[Named<Format>]) -> Result<()> {
272 let prefix = if base.len() <= 1 && self.generator.track_visibility {
274 "pub "
275 } else {
276 ""
277 };
278 for field in fields {
279 self.output_comment(&field.name)?;
280 writeln!(
281 self.out,
282 "{}{}: {},",
283 prefix,
284 field.name,
285 Self::quote_type(&field.value, Some(&self.known_sizes)),
286 )?;
287 }
288 Ok(())
289 }
290
291 fn output_variant(&mut self, base: &str, name: &str, variant: &VariantFormat) -> Result<()> {
292 self.output_comment(name)?;
293 use VariantFormat::*;
294 match variant {
295 Unit => writeln!(self.out, "{},", name),
296 NewType(format) => writeln!(
297 self.out,
298 "{}({}),",
299 name,
300 Self::quote_type(format, Some(&self.known_sizes))
301 ),
302 Tuple(formats) => writeln!(
303 self.out,
304 "{}({}),",
305 name,
306 Self::quote_types(formats, Some(&self.known_sizes))
307 ),
308 Struct(fields) => {
309 writeln!(self.out, "{} {{", name)?;
310 self.current_namespace.push(name.to_string());
311 self.out.indent();
312 self.output_fields(&[base, name], fields)?;
313 self.out.unindent();
314 self.current_namespace.pop();
315 writeln!(self.out, "}},")
316 }
317 Variable(_) => panic!("incorrect value"),
318 }
319 }
320
321 fn output_variants(
322 &mut self,
323 base: &str,
324 variants: &BTreeMap<u32, Named<VariantFormat>>,
325 ) -> Result<()> {
326 for (expected_index, (index, variant)) in variants.iter().enumerate() {
327 assert_eq!(*index, expected_index as u32);
328 self.output_variant(base, &variant.name, &variant.value)?;
329 }
330 Ok(())
331 }
332
333 fn output_container(&mut self, name: &str, format: &ContainerFormat) -> Result<()> {
334 self.output_comment(name)?;
335 let mut derive_macros = self.generator.derive_macros.clone();
336 if self.generator.config.serialization {
337 derive_macros.push("Serialize".to_string());
338 derive_macros.push("Deserialize".to_string());
339 }
340 let mut prefix = String::new();
341 if !derive_macros.is_empty() {
342 prefix.push_str(&format!("#[derive({})]\n", derive_macros.join(", ")));
343 }
344 if let Some(text) = &self.generator.custom_derive_block {
345 prefix.push_str(text);
346 prefix.push('\n');
347 }
348 if self.generator.track_visibility {
349 prefix.push_str("pub ");
350 }
351
352 use ContainerFormat::*;
353 match format {
354 UnitStruct => writeln!(self.out, "{}struct {};\n", prefix, name)?,
355 NewTypeStruct(format) => writeln!(
356 self.out,
357 "{}struct {}({}{});\n",
358 prefix,
359 name,
360 if self.generator.track_visibility {
361 "pub "
362 } else {
363 ""
364 },
365 Self::quote_type(format, Some(&self.known_sizes))
366 )?,
367 TupleStruct(formats) => writeln!(
368 self.out,
369 "{}struct {}({});\n",
370 prefix,
371 name,
372 Self::quote_types(formats, Some(&self.known_sizes))
373 )?,
374 Struct(fields) => {
375 writeln!(self.out, "{}struct {} {{", prefix, name)?;
376 self.current_namespace.push(name.to_string());
377 self.out.indent();
378 self.output_fields(&[name], fields)?;
379 self.out.unindent();
380 self.current_namespace.pop();
381 writeln!(self.out, "}}\n")?;
382 }
383 Enum(variants) => {
384 writeln!(self.out, "{}enum {} {{", prefix, name)?;
385 self.current_namespace.push(name.to_string());
386 self.out.indent();
387 self.output_variants(name, variants)?;
388 self.out.unindent();
389 self.current_namespace.pop();
390 writeln!(self.out, "}}\n")?;
391 }
392 }
393 self.output_custom_code(name)
394 }
395}
396
/// Installs generated Rust crates below a root directory.
pub struct Installer {
    /// Root directory; each module is written into a subdirectory named after it.
    install_dir: PathBuf,
}

impl Installer {
    /// Creates an installer that writes crates below `install_dir`.
    pub fn new(install_dir: PathBuf) -> Self {
        Self { install_dir }
    }

    /// Tells the user on stderr that sources for runtime crate `name` are not
    /// installed (the message refers to it as a published crate).
    fn runtime_installation_message(name: &str) {
        let crate_name = name;
        eprintln!("Not installing sources for published crate {}", crate_name);
    }
}
411
412impl crate::SourceInstaller for Installer {
413 type Error = Box<dyn std::error::Error>;
414
415 fn install_module(
416 &self,
417 config: &CodeGeneratorConfig,
418 registry: &Registry,
419 ) -> std::result::Result<(), Self::Error> {
420 let generator = CodeGenerator::new(config);
421 let (name, version) = {
422 let parts = config.module_name.splitn(2, ':').collect::<Vec<_>>();
423 if parts.len() >= 2 {
424 (parts[0].to_string(), parts[1].to_string())
425 } else {
426 (parts[0].to_string(), "0.1.0".to_string())
427 }
428 };
429 let dir_path = self.install_dir.join(&name);
430 std::fs::create_dir_all(&dir_path)?;
431
432 if config.package_manifest {
433 let mut cargo = std::fs::File::create(dir_path.join("Cargo.toml"))?;
434 write!(
435 cargo,
436 r#"[package]
437name = "{}"
438version = "{}"
439edition = "2018"
440
441[dependencies]
442serde = {{ version = "1.0", features = ["derive"] }}
443serde_bytes = "0.11"
444"#,
445 name, version,
446 )?;
447 }
448
449 std::fs::create_dir_all(dir_path.join("src"))?;
450 let source_path = dir_path.join("src/lib.rs");
451 let mut source = std::fs::File::create(source_path)?;
452 generator.output(&mut source, registry)
453 }
454
455 fn install_serde_runtime(&self) -> std::result::Result<(), Self::Error> {
456 Self::runtime_installation_message("serde");
457 Ok(())
458 }
459
460 fn install_bincode_runtime(&self) -> std::result::Result<(), Self::Error> {
461 Self::runtime_installation_message("bincode");
462 Ok(())
463 }
464
465 fn install_bcs_runtime(&self) -> std::result::Result<(), Self::Error> {
466 Self::runtime_installation_message("bcs");
467 Ok(())
468 }
469}