weld_codegen/
gen.rs

1//! This module contains [Generator](#Generator) - config-driven code generation,
2//! and [CodeGen](#CodeGen), the trait for language-specific code-driven code generation.
3//!
4//!
5use std::{
6    borrow::Borrow,
7    collections::BTreeMap,
8    io::Write,
9    path::{Path, PathBuf},
10};
11
12use atelier_core::model::{
13    shapes::{AppliedTraits, HasTraits as _},
14    HasIdentity as _, Identifier, Model,
15};
16
17use crate::docgen::DocGen;
18use crate::{
19    codegen_go::GoCodeGen,
20    codegen_py::PythonCodeGen,
21    codegen_rust::RustCodeGen,
22    config::{CodegenConfig, LanguageConfig, OutputFile, OutputLanguage},
23    error::{Error, Result},
24    format::{NullFormatter, SourceFormatter},
25    model::{get_trait, serialization_trait, CommentKind, NumberedMember},
26    render::Renderer,
27    wasmbus_model::{RenameItem, Serialization},
28    writer::Writer,
29    Bytes, JsonValue, ParamMap, TomlValue,
30};
31
/// Common templates compiled-in, as `(name, template source)` pairs.
/// `Generator::gen` registers every entry with the renderer before it adds
/// any per-language or caller-supplied templates. The list is currently empty.
pub const COMMON_TEMPLATES: &[(&str, &str)] = &[];
34
/// A Generator is a data-driven wrapper around code generator implementations.
/// There are two main modes of generation:
/// handlebars template-driven, and code-driven. For the latter, you implement
/// the trait CodeGen, which gets callbacks for various parts of the code generation.
/// One parameter to CodeGen is the output file name, which can be used if code is
/// to be generated across several files (for example, if you want
/// one file to define interfaces and a separate file to define implementation classes).
/// Handlebars-driven code generation may be better suited for files such as Makefiles
/// and project config files that don't need a huge amount of customization,
/// but it can also be used to generate 100% of an interface. You decide!
#[derive(Debug, Default)]
pub struct Generator {}
47
48impl<'model> Generator {
49    /// Perform code generation on model, iterating through each configured OutputFile.
50    /// Each file to be generated is either based on a handlebars
51    /// template, which is generated with the Renderer class,
52    /// or is generated by an implemented of the CodeGen implementation.
53    /// Function parameters:
54    /// - model - the smithy model
55    /// - config - CodeConfig (usually loaded from a codegen.toml file)
56    ///     - N.B. all relative paths (template dirs and the output_dir parameter) are
57    ///            adjusted to be relative to config.base_dir.
58    /// - templates - list of additional templates to register with the handlebars engine
59    ///          (The templates parameter of config is ignored. to use a list of file
60    ///          paths, call templates_from_dir() to load them and generate this parameter)
61    /// - output_dir - top-level folder containing all output files
62    /// - create - whether the user intends to create a new project (true) or just update (false).
63    ///   The generator does not check for existence of an output file before overwriting. Usually the
64    ///   `create` flag is set to true for the first pass when project files are created, after which
65    ///   time the project file can be manually edited, and only the main codegen output will be updated.
66    ///   The default (create is false/the flag is unspecified)) changes the fewest files
67    pub fn gen(
68        &self,
69        model: Option<&'model Model>,
70        config: CodegenConfig,
71        templates: Vec<(String, String)>,
72        output_dir: &Path,
73        defines: Vec<(String, TomlValue)>,
74    ) -> Result<()> {
75        let mut json_model = match model {
76            Some(model) => atelier_json::model_to_json(model),
77            None => JsonValue::default(),
78        };
79        let output_dir = if output_dir.is_absolute() {
80            output_dir.to_path_buf()
81        } else {
82            config.base_dir.join(output_dir)
83        };
84        // create one renderer so we only need to parse templates once
85        let mut renderer = Renderer::default();
86
87        for (name, template) in COMMON_TEMPLATES.iter() {
88            renderer.add_template((name, template))?;
89        }
90
91        for (language, mut lc) in config.languages.into_iter() {
92            if !config.output_languages.is_empty() && !config.output_languages.contains(&language) {
93                // if user specified list of languages, only generate code for those languages
94                continue;
95            }
96            // add templates from <lang>.templates
97            if let Some(template_dir) = &lc.templates {
98                let template_dir = if template_dir.is_absolute() {
99                    template_dir.clone()
100                } else {
101                    config.base_dir.join(template_dir)
102                };
103                for (name, tmpl) in templates_from_dir(&template_dir)? {
104                    renderer.add_template((&name, &tmpl))?;
105                }
106            }
107            // add templates from cli
108            for (name, template) in templates.iter() {
109                renderer.add_template((name, template))?;
110            }
111            // if language output_dir is relative, append it, otherwise use it
112            let output_dir = if lc.output_dir.is_absolute() {
113                lc.output_dir.clone()
114            } else {
115                output_dir.join(&lc.output_dir)
116            };
117            std::fs::create_dir_all(&output_dir).map_err(|e| {
118                Error::Io(format!(
119                    "creating directory {}: {}",
120                    output_dir.display(),
121                    e
122                ))
123            })?;
124            // add command-line overrides
125            for (k, v) in defines.iter() {
126                lc.parameters.insert(k.to_string(), v.clone());
127            }
128            let base_params: BTreeMap<String, JsonValue> = to_json(&lc.parameters)?;
129            let mut cgen = gen_for_language(&language, model);
130
131            // initialize generator
132            cgen.init(model, &lc, Some(&output_dir), &mut renderer)?;
133
134            // A common param dictionary is shared (read-only) by the renderer and the code generator,
135            // Parameters include the following:
136            //   - "model" - the entire model, generated by parsing one or more smithy files,
137            //               and converted to json-ast format
138            //   - "_file" - the output file (path is relative to the output directory
139            //   - LanguageConfig.parameters , applicable to all files for this output language
140            //   - OutputFile.parameters - file-specific parameters
141            //   The latter two are added in that order, so that a per-file parameter can
142            //   override per-language settings.
143
144            // The parameter dict is cleared after each iteration to avoid one file's override
145            // from leaking into the next file in the iteration. There are two "optimizations"
146            // in the loop below.
147            // - The handlebars renderer only parses templates once, so it is shared across output files,
148            // - The smithy model is parsed and validated once. After using it for one file, we pull it
149            //   out of the params map, save it aside, and then re-insert it the next time.
150
151            // list of files that were created or updated on this run
152            let mut updated_files = Vec::new();
153
154            for file_config in lc.files.iter() {
155                // for conditional files (with the `if_defined` property), check that we have the right conditions
156                if let Some(TomlValue::String(key)) = file_config.params.get("if_defined") {
157                    match lc.parameters.get(key) {
158                        None | Some(TomlValue::Boolean(false)) => {
159                            // not defined, do not generate
160                            continue;
161                        }
162                        Some(_) => {}
163                    }
164                }
165                let mut params = base_params.clone();
166                params.insert("model".to_string(), json_model);
167
168                let file_params: BTreeMap<String, JsonValue> = to_json(&file_config.params)?;
169                params.extend(file_params.into_iter());
170                params.insert(
171                    "_file".to_string(),
172                    JsonValue::String(file_config.path.to_string_lossy().to_string()),
173                );
174
175                let out_path = output_dir.join(&file_config.path);
176                let parent = out_path.parent().unwrap();
177                std::fs::create_dir_all(parent).map_err(|e| {
178                    Error::Io(format!("creating directory {}: {}", parent.display(), e))
179                })?;
180
181                // generate output using either hbs or CodeGen
182                if let Some(hbs) = &file_config.hbs {
183                    let mut out = std::fs::File::create(&out_path).map_err(|e| {
184                        Error::Io(format!("creating file {}: {}", &out_path.display(), e))
185                    })?;
186                    renderer.render(hbs, &params, &mut out)?;
187                    out.flush().map_err(|e| {
188                        crate::Error::Io(format!(
189                            "saving output file {}:{}",
190                            &out_path.display(),
191                            e
192                        ))
193                    })?;
194                } else if let Some(model) = model {
195                    let mut w: Writer = Writer::default();
196                    let bytes = cgen.generate_file(&mut w, model, file_config, &params)?;
197                    std::fs::write(&out_path, &bytes).map_err(|e| {
198                        Error::Io(format!("writing output file {}: {}", out_path.display(), e))
199                    })?;
200                };
201                updated_files.push(out_path);
202                // retrieve json_model for the next iteration
203                json_model = params.remove("model").unwrap();
204            }
205            cgen.format(updated_files, &lc)?;
206        }
207
208        Ok(())
209    }
210}
211
212fn gen_for_language<'model>(
213    language: &OutputLanguage,
214    model: Option<&'model Model>,
215) -> Box<dyn CodeGen + 'model> {
216    match language {
217        OutputLanguage::Rust => Box::new(RustCodeGen::new(model)),
218        //OutputLanguage::AssemblyScript => Box::new(AsmCodeGen::new(model)),
219        OutputLanguage::Python => Box::new(PythonCodeGen::new(model)),
220        OutputLanguage::TinyGo => Box::new(GoCodeGen::new(model, true)),
221        OutputLanguage::Go => Box::new(GoCodeGen::new(model, false)),
222        OutputLanguage::Html => Box::<DocGen>::default(),
223        OutputLanguage::Poly => Box::<PolyCodeGen>::default(),
224        _ => {
225            crate::error::print_warning(&format!("Target language {language} not implemented"));
226            Box::<NoCodeGen>::default()
227        }
228    }
229}
230
231/// A Codegen is used to generate source for a Smithy Model
232/// The generator will invoke these functions (in order)
233/// - init()
234/// - write_source_file_header()
235/// - declare_types()
236/// - write_services()
237/// - finalize()
238///
239pub trait CodeGen {
240    /// Initialize code generator and renderer for language output.j
241    /// This hook is called before any code is generated and can be used to initialize code generator
242    /// and/or perform additional processing before output files are created.
243    #[allow(unused_variables)]
244    fn init(
245        &mut self,
246        model: Option<&Model>,
247        lc: &LanguageConfig,
248        output_dir: Option<&Path>,
249        renderer: &mut Renderer,
250    ) -> std::result::Result<(), Error> {
251        Ok(())
252    }
253
254    /// This entrypoint drives output-file-specific code generation.
255    /// This default implementation invokes `init_file`, `write_source_file_header`, `declare_types`, `write_services`, and `finalize`.
256    /// The return value is Bytes containing the data that should be written to the output file.
257    fn generate_file(
258        &mut self,
259        w: &mut Writer,
260        model: &Model,
261        file_config: &OutputFile,
262        params: &ParamMap,
263    ) -> Result<Bytes> {
264        self.init_file(w, model, file_config, params)?;
265        self.write_source_file_header(w, model, params)?;
266        self.declare_types(w, model, params)?;
267        self.write_services(w, model, params)?;
268        self.finalize(w)
269    }
270
271    /// Perform any initialization required prior to code generation for a file
272    /// `model` may be used to check model metadata
273    /// `id` is a tag from codegen.toml that indicates which source file is to be written
274    /// `namespace` is the namespace in the model to generate
275    #[allow(unused_variables)]
276    fn init_file(
277        &mut self,
278        w: &mut Writer,
279        model: &Model,
280        file_config: &OutputFile,
281        params: &ParamMap,
282    ) -> Result<()> {
283        Ok(())
284    }
285
286    /// generate the source file header
287    #[allow(unused_variables)]
288    fn write_source_file_header(
289        &mut self,
290        w: &mut Writer,
291        model: &Model,
292        params: &ParamMap,
293    ) -> Result<()> {
294        Ok(())
295    }
296
297    /// Write declarations for simple types, maps, and structures
298    #[allow(unused_variables)]
299    fn declare_types(&mut self, w: &mut Writer, model: &Model, params: &ParamMap) -> Result<()> {
300        Ok(())
301    }
302
303    /// Write service declarations and implementation stubs
304    #[allow(unused_variables)]
305    fn write_services(&mut self, w: &mut Writer, model: &Model, params: &ParamMap) -> Result<()> {
306        Ok(())
307    }
308
309    /// Complete generation and return the output bytes
310    fn finalize(&mut self, w: &mut Writer) -> Result<Bytes> {
311        Ok(w.take().freeze())
312    }
313
314    /// Write documentation for item
315    #[allow(unused_variables)]
316    fn write_documentation(&mut self, w: &mut Writer, _id: &Identifier, text: &str) {
317        for line in text.split('\n') {
318            // remove whitespace from end of line
319            let line = line.trim_end_matches(|c| c == '\r' || c == ' ' || c == '\t');
320            self.write_comment(w, CommentKind::Documentation, line);
321        }
322    }
323
324    /// Writes single-line comment beginning with '// '
325    /// Can be overridden if more specific kinds are needed
326    #[allow(unused_variables)]
327    fn write_comment(&mut self, w: &mut Writer, kind: CommentKind, line: &str) {
328        w.write(b"// ");
329        w.write(line);
330        w.write(b"\n");
331    }
332
333    fn write_ident(&self, w: &mut Writer, id: &Identifier) {
334        w.write(&self.to_type_name_case(&id.to_string()));
335    }
336
337    /// append suffix to type name, for example "Game", "Context" -> "GameContext"
338    fn write_ident_with_suffix(
339        &mut self,
340        w: &mut Writer,
341        id: &Identifier,
342        suffix: &str,
343    ) -> Result<()> {
344        self.write_ident(w, id);
345        w.write(suffix); // assume it's already PascalCase
346        Ok(())
347    }
348
349    // Writes info the the current output writer
350    //fn write(&mut self, bytes: impl ToBytes);
351
352    // Returns the current buffer, zeroing out self
353    //fn take(&mut self) -> BytesMut;
354
355    /// Returns current output language
356    fn output_language(&self) -> OutputLanguage;
357
358    fn has_rename_trait(&self, traits: &AppliedTraits) -> Option<String> {
359        if let Ok(Some(items)) = get_trait::<Vec<RenameItem>>(traits, crate::model::rename_trait())
360        {
361            let lang = self.output_language().to_string();
362            return items.iter().find(|i| i.lang == lang).map(|i| i.name.clone());
363        }
364        None
365    }
366
367    /// returns file extension of source files for this language
368    fn get_file_extension(&self) -> &'static str {
369        self.output_language().extension()
370    }
371
372    /// Convert method name to its target-language-idiomatic case style
373    fn to_method_name_case(&self, name: &str) -> String;
374
375    /// Convert method name to its target-language-idiomatic case style
376    /// implementors should override to_method_name_case
377    fn to_method_name(&self, method_id: &Identifier, method_traits: &AppliedTraits) -> String {
378        if let Some(name) = self.has_rename_trait(method_traits) {
379            name
380        } else {
381            self.to_method_name_case(&method_id.to_string())
382        }
383    }
384
385    /// Convert field name to its target-language-idiomatic case style
386    fn to_field_name_case(&self, name: &str) -> String;
387
388    /// Convert field name to its target-language-idiomatic case style
389    /// implementors should override to_field_name_case
390    fn to_field_name(
391        &self,
392        member_id: &Identifier,
393        member_traits: &AppliedTraits,
394    ) -> std::result::Result<String, Error> {
395        if let Some(name) = self.has_rename_trait(member_traits) {
396            Ok(name)
397        } else {
398            Ok(self.to_field_name_case(&member_id.to_string()))
399        }
400    }
401
402    /// Convert type name to its target-language-idiomatic case style
403    fn to_type_name_case(&self, s: &str) -> String;
404
405    fn get_field_name_and_ser_name(&self, field: &NumberedMember) -> Result<(String, String)> {
406        let field_name = self.to_field_name(field.id(), field.traits())?;
407        let ser_name = if let Some(Serialization { name: Some(ser_name) }) =
408            get_trait(field.traits(), serialization_trait())?
409        {
410            ser_name
411        } else {
412            field.id().to_string()
413        };
414        Ok((field_name, ser_name))
415    }
416
417    /// The operation name used in dispatch, from method
418    /// The default implementation is provided and should not be overridden
419    fn op_dispatch_name(&self, id: &Identifier) -> String {
420        crate::strings::to_pascal_case(&id.to_string())
421    }
422
423    /// The full operation name with service prefix
424    /// The default implementation is provided and should not be overridden
425    fn full_dispatch_name(&self, service_id: &Identifier, method_id: &Identifier) -> String {
426        format!(
427            "{}.{}",
428            &self.to_type_name_case(&service_id.to_string()),
429            &self.op_dispatch_name(method_id)
430        )
431    }
432
433    fn source_formatter(&self, formatter: Vec<String>) -> Result<Box<dyn SourceFormatter>>;
434
435    /// After code generation has completed for all files, this method is called once per output language
436    /// to allow code formatters to run. The `files` parameter contains a list of all files written or updated.
437    fn format(
438        &mut self,
439        files: Vec<PathBuf>,
440        lc: &LanguageConfig,
441        //lc_params: &BTreeMap<String, TomlValue>,
442    ) -> Result<()> {
443        // if we just created an interface project, don't run rustfmt yet
444        // because we haven't generated the other rust file yet, so rustfmt will fail.
445        if !lc.parameters.contains_key("create_interface") {
446            // make a list of all output files with ".rs" extension so we can fix formatting with rustfmt
447            // minor nit: we don't check the _config-only flag so there could be some false positives here, but rustfmt is safe to use anyway
448            let formatter = self.source_formatter(lc.formatter.clone())?;
449
450            let extension = self.output_language().extension();
451            let sources = files
452                .into_iter()
453                .filter(|path| match path.extension() {
454                    Some(s) => s.to_string_lossy().as_ref() == extension,
455                    _ => false,
456                })
457                .collect::<Vec<PathBuf>>();
458
459            if !sources.is_empty() {
460                ensure_files_exist(&sources)?;
461
462                let file_names: Vec<std::borrow::Cow<'_, str>> =
463                    sources.iter().map(|p| p.to_string_lossy()).collect();
464                let borrowed = file_names.iter().map(|s| s.borrow()).collect::<Vec<&str>>();
465                formatter.run(&borrowed)?;
466            }
467        }
468        Ok(())
469    }
470}
471
472/// confirm all files are present, otherwise return error
473fn ensure_files_exist(source_files: &[std::path::PathBuf]) -> Result<()> {
474    let missing = source_files
475        .iter()
476        .filter(|p| !p.is_file())
477        .map(|p| p.to_string_lossy().into_owned())
478        .collect::<Vec<String>>();
479    if !missing.is_empty() {
480        return Err(Error::Formatter(format!(
481            "missing source file(s) '{}'",
482            missing.join(",")
483        )));
484    }
485    Ok(())
486}
487
488#[derive(Debug, Default)]
489struct PolyCodeGen {}
490impl CodeGen for PolyCodeGen {
491    fn output_language(&self) -> OutputLanguage {
492        OutputLanguage::Poly
493    }
494    /// generate method name
495    fn to_method_name_case(&self, name: &str) -> String {
496        crate::strings::to_snake_case(name)
497    }
498
499    /// generate field name
500    fn to_field_name_case(&self, name: &str) -> String {
501        crate::strings::to_snake_case(name)
502    }
503
504    /// generate type name
505    fn to_type_name_case(&self, name: &str) -> String {
506        crate::strings::to_pascal_case(name)
507    }
508
509    fn source_formatter(&self, _: Vec<String>) -> Result<Box<dyn SourceFormatter>> {
510        Ok(Box::<NullFormatter>::default())
511    }
512}
513
#[allow(dead_code)]
/// Helper function for indenting code (used by python codegen).
///
/// Returns a string of `4 * indent_level` spaces. Every `u8` level is supported:
/// the product is computed in `usize`, so it cannot overflow, and the backing
/// string is long enough for the deepest level (255 * 4 = 1020 spaces).
pub fn spaces(indent_level: u8) -> &'static str {
    /// number of spaces per indentation level
    const WIDTH: usize = 4;
    /// enough blanks for the largest possible indent level
    const BLANKS: [u8; WIDTH * (u8::MAX as usize)] = [b' '; WIDTH * (u8::MAX as usize)];
    /// the blanks viewed as a string; ASCII spaces are always valid utf-8
    const SP: &str = match std::str::from_utf8(&BLANKS) {
        Ok(s) => s,
        Err(_) => panic!("ascii spaces are valid utf-8"),
    };
    &SP[..usize::from(indent_level) * WIDTH]
}
523
524// convert from TOML map to JSON map so it's usable by handlebars
525//pub fn toml_to_json(map: &BTreeMap<String, TomlValue>) -> Result<ParamMap> {
526//    let s = serde_json::to_string(map)?;
527//    let value: ParamMap = serde_json::from_str(&s)?;
528//    Ok(value)
529//}
530
531/// Converts a type to json
532pub fn to_json<S: serde::Serialize, T: serde::de::DeserializeOwned>(val: S) -> Result<T> {
533    let s = serde_json::to_string(&val)?;
534    Ok(serde_json::from_str(&s)?)
535}
536
537/// Search a folder recursively for files ending with the provided extension
538/// Filenames must be utf-8 characters
539pub fn find_files(dir: &Path, extension: &str) -> Result<Vec<PathBuf>> {
540    if dir.is_dir() {
541        let mut results = Vec::new();
542        for entry in std::fs::read_dir(dir)
543            .map_err(|e| Error::Io(format!("reading directory {}: {}", dir.display(), e)))?
544        {
545            let entry = entry.map_err(|e| crate::Error::Io(format!("scanning folder: {e}")))?;
546            let path = entry.path();
547            if path.is_dir() {
548                results.append(&mut find_files(&path, extension)?);
549            } else {
550                let ext = path
551                    .extension()
552                    .map(|s| s.to_string_lossy().to_string())
553                    .unwrap_or_default();
554                if ext == extension {
555                    results.push(path)
556                }
557            }
558        }
559        Ok(results)
560    } else if dir.is_file()
561        && &dir
562            .extension()
563            .map(|s| s.to_string_lossy().to_string())
564            .unwrap_or_default()
565            == "smithy"
566    {
567        Ok(vec![dir.to_owned()])
568    } else {
569        Err(Error::Other(format!(
570            "'{}' is not a valid folder or '.{}' file",
571            dir.display(),
572            extension
573        )))
574    }
575}
576
577/// Add all templates from the specified folder, using the base file name
578/// as the template name. For example, "header.hbs" will be registered as "header"
579pub fn templates_from_dir(start: &std::path::Path) -> Result<Vec<(String, String)>> {
580    let mut templates = Vec::new();
581
582    for path in crate::gen::find_files(start, "hbs")?.iter() {
583        let stem = path
584            .file_stem()
585            .map(|s| s.to_string_lossy().to_string())
586            .unwrap_or_default();
587        if !stem.is_empty() {
588            let template = std::fs::read_to_string(path)
589                .map_err(|e| Error::Io(format!("reading template {}: {}", path.display(), e)))?;
590            templates.push((stem, template));
591        }
592    }
593    Ok(templates)
594}
595
596#[derive(Default)]
597struct NoCodeGen {}
598impl CodeGen for NoCodeGen {
599    fn output_language(&self) -> OutputLanguage {
600        OutputLanguage::Poly
601    }
602
603    fn get_file_extension(&self) -> &'static str {
604        ""
605    }
606
607    fn to_method_name_case(&self, name: &str) -> String {
608        crate::strings::to_snake_case(name)
609    }
610
611    fn to_field_name_case(&self, name: &str) -> String {
612        crate::strings::to_snake_case(name)
613    }
614
615    fn to_type_name_case(&self, name: &str) -> String {
616        crate::strings::to_pascal_case(name)
617    }
618
619    fn source_formatter(&self, _: Vec<String>) -> Result<Box<dyn SourceFormatter>> {
620        Ok(Box::<NullFormatter>::default())
621    }
622}