amq_protocol_codegen/
templating.rs

1use crate::{specs::*, util::*};
2
3use amq_protocol_types::{AMQPType, AMQPValue};
4use handlebars::{
5    self, BlockContext, BlockParams, Context, Handlebars, Helper, HelperDef, HelperResult,
6    JsonValue, Output, RenderContext, RenderError, RenderErrorReason, Renderable, ScopedJson,
7    to_json,
8};
9use serde_json::{self, Value};
10
11use std::{
12    collections::HashMap,
13    fs::{self, File},
14    io::Write,
15    path::Path,
16};
17
18/// Type alias to avoid making our users explicitly depend on an extra dependency
19pub type CodeGenerator<'a> = Handlebars<'a>;
20
21/// Our extension for better integration with Handlebars
22pub trait HandlebarsAMQPExtension {
23    /// Register the various standard helpers we'll need for AMQP codegen
24    fn register_amqp_helpers(self) -> Self;
25    /// Generate code using the standard representation of specs and the given template, using the
26    /// given name for the variable holding the [protocol definition](../specs.AMQProtocolDefinition.html).
27    fn simple_codegen(
28        out_dir: &str,
29        target: &str,
30        template_name: &str,
31        template: &str,
32        var_name: &str,
33    ) {
34        Self::simple_codegen_with_data(out_dir, target, template_name, template, var_name, None);
35    }
36    /// Generate code using the standard representation of specs and the given template, using the
37    /// given name for the variable holding the [protocol definition](../specs.AMQProtocolDefinition.html),
38    /// and also passing data to the template.
39    fn simple_codegen_with_data(
40        out_dir: &str,
41        target: &str,
42        template_name: &str,
43        template: &str,
44        var_name: &str,
45        data: Option<Value>,
46    );
47}
48
49impl<'a> HandlebarsAMQPExtension for CodeGenerator<'a> {
50    fn register_amqp_helpers(mut self) -> CodeGenerator<'a> {
51        self.register_escape_fn(handlebars::no_escape);
52        self.register_helper("camel", Box::new(CamelHelper));
53        self.register_helper("snake", Box::new(SnakeHelper));
54        self.register_helper("snake_type", Box::new(SnakeTypeHelper));
55        self.register_helper("sanitize_name", Box::new(SanitizeNameHelper));
56        self.register_helper("include_more", Box::new(IncludeMoreHelper));
57        self.register_helper("pass_by_ref", Box::new(PassByRefHelper));
58        self.register_helper("use_str_ref", Box::new(UseStrRefHelper));
59        self.register_helper("use_bytes_ref", Box::new(UseBytesRefHelper));
60        self.register_helper("each_argument", Box::new(EachArgumentHelper));
61        self.register_helper("amqp_value_ref", Box::new(AMQPValueRefHelper));
62        self
63    }
64
65    fn simple_codegen_with_data(
66        out_dir: &str,
67        target: &str,
68        template_name: &str,
69        template: &str,
70        var_name: &str,
71        metadata: Option<Value>,
72    ) {
73        let dest_path = Path::new(out_dir).join(format!("{target}.rs"));
74        let mut f = File::create(&dest_path)
75            .unwrap_or_else(|err| panic!("Failed to create {dest_path:?}: {err}"));
76        let specs = AMQProtocolDefinition::load(metadata);
77        let mut codegen = CodeGenerator::default().register_amqp_helpers();
78        let mut data = HashMap::new();
79
80        codegen.set_strict_mode(true);
81        codegen
82            .register_template_string(template_name, template)
83            .unwrap_or_else(|e| panic!("Failed to register {template_name} template: {e}"));
84        data.insert(
85            var_name.to_string(),
86            serde_json::to_value(specs)
87                .unwrap_or_else(|e| panic!("Failed to serialize specs: {e}")),
88        );
89
90        writeln!(
91            f,
92            "{}",
93            codegen
94                .render(template_name, &data)
95                .unwrap_or_else(|err| panic!("Failed to render {template_name} template: {err}"))
96        )
97        .unwrap_or_else(|e| panic!("Failed to generate {target}.rs: {e}"));
98    }
99}
100
101/// Helper for converting text to camel case
102pub struct CamelHelper;
103impl HelperDef for CamelHelper {
104    fn call<'reg: 'rc, 'rc>(
105        &self,
106        h: &Helper<'rc>,
107        _: &'reg Handlebars<'_>,
108        _: &'rc Context,
109        _: &mut RenderContext<'reg, 'rc>,
110        out: &mut dyn Output,
111    ) -> HelperResult {
112        let value = h
113            .param(0)
114            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("camel", 0))?;
115        let param = value.value().as_str().ok_or_else(|| {
116            RenderErrorReason::ParamTypeMismatchForName(
117                "camel",
118                "string".to_string(),
119                "string".to_string(),
120            )
121        })?;
122        out.write(&camel_case(param))?;
123        Ok(())
124    }
125}
126
127/// Helper for converting text to snake case
128pub struct SnakeHelper;
129impl HelperDef for SnakeHelper {
130    fn call<'reg: 'rc, 'rc>(
131        &self,
132        h: &Helper<'rc>,
133        _: &'reg Handlebars<'_>,
134        _: &'rc Context,
135        _: &mut RenderContext<'reg, 'rc>,
136        out: &mut dyn Output,
137    ) -> HelperResult {
138        let value = h
139            .param(0)
140            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("snake", 0))?;
141        let raw = h
142            .param(1)
143            .and_then(|raw| raw.value().as_bool())
144            .unwrap_or(true);
145        let param = value.value().as_str().ok_or_else(|| {
146            RenderErrorReason::ParamTypeMismatchForName(
147                "snake",
148                "string".to_string(),
149                "string".to_string(),
150            )
151        })?;
152        out.write(&snake_case(param, raw))?;
153        Ok(())
154    }
155}
156
157/// Helper for getting the type name converted to snake case
158pub struct SnakeTypeHelper;
159impl HelperDef for SnakeTypeHelper {
160    fn call<'reg: 'rc, 'rc>(
161        &self,
162        h: &Helper<'rc>,
163        _: &'reg Handlebars<'_>,
164        _: &'rc Context,
165        _: &mut RenderContext<'reg, 'rc>,
166        out: &mut dyn Output,
167    ) -> HelperResult {
168        let value = h
169            .param(0)
170            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("snake_type", 0))?;
171        let param: AMQPType = serde_json::from_value(value.value().clone()).map_err(|_| {
172            RenderErrorReason::ParamTypeMismatchForName(
173                "snake_type",
174                "AMQPType".to_string(),
175                "string".to_string(),
176            )
177        })?;
178        out.write(&snake_case(&param.to_string(), true))?;
179        Ok(())
180    }
181}
182
183/// Helper to sanitize name so the it becomes a valid identifier
184pub struct SanitizeNameHelper;
185impl HelperDef for SanitizeNameHelper {
186    fn call<'reg: 'rc, 'rc>(
187        &self,
188        h: &Helper<'rc>,
189        _: &'reg Handlebars<'_>,
190        _: &'rc Context,
191        _: &mut RenderContext<'reg, 'rc>,
192        out: &mut dyn Output,
193    ) -> HelperResult {
194        let value = h
195            .param(0)
196            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("sanitize_name", 0))?;
197        let param = value.value().as_str().ok_or_else(|| {
198            RenderErrorReason::ParamTypeMismatchForName(
199                "sanitize_name",
200                "string".to_string(),
201                "string".to_string(),
202            )
203        })?;
204        out.write(&param.replace('-', "_"))?;
205        Ok(())
206    }
207}
208
209/// Helper to include additional code such as rustdoc
210pub struct IncludeMoreHelper;
211impl HelperDef for IncludeMoreHelper {
212    fn call<'reg: 'rc, 'rc>(
213        &self,
214        h: &Helper<'rc>,
215        _: &'reg Handlebars<'_>,
216        _: &'rc Context,
217        _: &mut RenderContext<'reg, 'rc>,
218        out: &mut dyn Output,
219    ) -> HelperResult {
220        let amqp_class = h
221            .param(0)
222            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("include_more", 0))?;
223        let amqp_method = h
224            .param(1)
225            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("include_more", 1))?;
226        let amqp_class = amqp_class.value().as_str().ok_or_else(|| {
227            RenderErrorReason::ParamTypeMismatchForName(
228                "include_more",
229                "string".to_string(),
230                "class".to_string(),
231            )
232        })?;
233        let amqp_method = amqp_method.value().as_str().ok_or_else(|| {
234            RenderErrorReason::ParamTypeMismatchForName(
235                "include_more",
236                "string".to_string(),
237                "method".to_string(),
238            )
239        })?;
240        if let Ok(cargo_manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
241            let include = Path::new(&cargo_manifest_dir)
242                .join("templates")
243                .join("includes")
244                .join(amqp_class)
245                .join(format!("{amqp_method}.rs"));
246            if let Ok(include) = fs::read_to_string(include) {
247                out.write(&include)?;
248            }
249        }
250        Ok(())
251    }
252}
253
254/// Helper to check whether a param is passed by ref or not
255pub struct PassByRefHelper;
256impl HelperDef for PassByRefHelper {
257    fn call_inner<'reg: 'rc, 'rc>(
258        &self,
259        h: &Helper<'rc>,
260        _: &'reg Handlebars<'_>,
261        _: &'rc Context,
262        _: &mut RenderContext<'reg, 'rc>,
263    ) -> Result<ScopedJson<'rc>, RenderError> {
264        let value = h
265            .param(0)
266            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("pass_by_ref", 0))?;
267        let param: AMQPType = serde_json::from_value(value.value().clone()).map_err(|_| {
268            RenderErrorReason::ParamTypeMismatchForName(
269                "pass_by_ref",
270                "AMQPType".to_string(),
271                "string".to_string(),
272            )
273        })?;
274        let pass_by_ref = matches!(
275            param,
276            AMQPType::ShortString
277                | AMQPType::LongString
278                | AMQPType::FieldArray
279                | AMQPType::FieldTable
280                | AMQPType::ByteArray
281        );
282        Ok(ScopedJson::Derived(JsonValue::from(pass_by_ref)))
283    }
284}
285
286/// Helper to check whether a param is passed using an &str or its real type
287pub struct UseStrRefHelper;
288impl HelperDef for UseStrRefHelper {
289    fn call_inner<'reg: 'rc, 'rc>(
290        &self,
291        h: &Helper<'rc>,
292        _: &'reg Handlebars<'_>,
293        _: &'rc Context,
294        _: &mut RenderContext<'reg, 'rc>,
295    ) -> Result<ScopedJson<'rc>, RenderError> {
296        let value = h
297            .param(0)
298            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("use_str_ref", 0))?;
299        let param = serde_json::from_value::<AMQPType>(value.value().clone()).ok();
300        let use_str_ref = matches!(
301            param,
302            Some(AMQPType::ShortString) | Some(AMQPType::LongString)
303        );
304        Ok(ScopedJson::Derived(JsonValue::from(use_str_ref)))
305    }
306}
307
308/// Helper to check whether a param is passed using an &[u8] or its real type
309pub struct UseBytesRefHelper;
310impl HelperDef for UseBytesRefHelper {
311    fn call_inner<'reg: 'rc, 'rc>(
312        &self,
313        h: &Helper<'rc>,
314        _: &'reg Handlebars<'_>,
315        _: &'rc Context,
316        _: &mut RenderContext<'reg, 'rc>,
317    ) -> Result<ScopedJson<'rc>, RenderError> {
318        let value = h
319            .param(0)
320            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("use_bytes_ref", 0))?;
321        let param = serde_json::from_value::<AMQPType>(value.value().clone()).ok();
322        let use_bytes_ref = matches!(param, Some(AMQPType::LongString));
323        Ok(ScopedJson::Derived(JsonValue::from(use_bytes_ref)))
324    }
325}
326
/// Helper to walk through a Vec of [`AMQPArgument`].
///
/// Block helper, e.g. `{{#each_argument method.arguments as |argument|}}...{{/each_argument}}`.
/// The parameter is deserialized into `Vec<AMQPArgument>` and the block is rendered once per
/// element, with these local variables set:
///
/// - `@index`: position of the current argument;
/// - `@last`: whether it is the final argument;
/// - `@argument_is_value`: `true` for `AMQPArgument::Value`, `false` for `AMQPArgument::Flags`.
///
/// If a block parameter is given, it is bound to the `Value` / `Flags` sub-path of the current
/// element, i.e. (assuming serde's default externally-tagged encoding of `AMQPArgument`) to the
/// variant's payload rather than the enum wrapper.
///
/// Without a block template, the helper renders nothing.
pub struct EachArgumentHelper;
impl HelperDef for EachArgumentHelper {
    fn call<'reg: 'rc, 'rc>(
        &self,
        h: &Helper<'rc>,
        r: &'reg Handlebars<'_>,
        ctx: &'rc Context,
        rc: &mut RenderContext<'reg, 'rc>,
        out: &mut dyn Output,
    ) -> HelperResult {
        let value = h
            .param(0)
            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("each_argument", 0))?;

        // Without a block template there is nothing to render.
        if let Some(t) = h.template() {
            // New block scoped to the arguments array's context path, when it has one.
            let mut block_context = BlockContext::new();
            if let Some(path) = value.context_path() {
                *block_context.base_path_mut() = path.to_vec();
            }
            rc.push_block(block_context);
            let arguments: Vec<AMQPArgument> = serde_json::from_value(value.value().clone())
                .map_err(|_| {
                    RenderErrorReason::ParamTypeMismatchForName(
                        "each_argument",
                        "Vec<AMQPArgument>".to_string(),
                        "arguments".to_string(),
                    )
                })?;
            let len = arguments.len();
            let array_path = value.context_path();
            for (index, argument) in arguments.iter().enumerate() {
                if let Some(ref mut block) = rc.block_mut() {
                    // Sub-path (named after the enum variant) the block param will point to.
                    let (path, is_value) = match *argument {
                        AMQPArgument::Value(_) => ("Value".to_owned(), true),
                        AMQPArgument::Flags(_) => ("Flags".to_owned(), false),
                    };
                    block.set_local_var("index", to_json(index));
                    // Cannot underflow: the loop body only runs when `len >= 1`.
                    block.set_local_var("last", to_json(index == len - 1));
                    block.set_local_var("argument_is_value", to_json(is_value));
                    // Point the block's base path at the current element: append the index on
                    // the first iteration, then overwrite that trailing index afterwards.
                    if let Some(p) = array_path {
                        if index == 0 {
                            let mut path = Vec::with_capacity(p.len() + 1);
                            path.extend_from_slice(p);
                            path.push(index.to_string());
                            *block.base_path_mut() = path;
                        } else if let Some(ptr) = block.base_path_mut().last_mut() {
                            *ptr = index.to_string();
                        }
                    }
                    // Bind `as |name|` to the variant sub-path. This `path` is the variant
                    // name computed above, not the base-path vector that shadows it inside the
                    // `index == 0` branch.
                    if let Some(block_param) = h.block_param() {
                        let mut params = BlockParams::new();
                        params.add_path(block_param, vec![path])?;
                        block.set_block_params(params);
                    }
                }
                t.render(r, ctx, rc, out)?;
            }
            // NOTE(review): any `?` early return above skips this pop, leaving the pushed
            // block on the stack; presumably harmless since the render is aborted on error.
            rc.pop_block();
        }
        Ok(())
    }
}
390
391/// Helper for "unwrapping" an amqp_value
392pub struct AMQPValueRefHelper;
393impl HelperDef for AMQPValueRefHelper {
394    fn call_inner<'reg: 'rc, 'rc>(
395        &self,
396        h: &Helper<'rc>,
397        _: &'reg Handlebars<'_>,
398        _: &'rc Context,
399        _: &mut RenderContext<'reg, 'rc>,
400    ) -> Result<ScopedJson<'rc>, RenderError> {
401        let arg = h
402            .param(0)
403            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("amqp_value", 0))?;
404        let param = serde_json::from_value(arg.value().clone()).map_err(|_| {
405            RenderErrorReason::ParamTypeMismatchForName(
406                "amqp_value",
407                "AMQPValue".to_string(),
408                "value".to_string(),
409            )
410        })?;
411        let value = json_value(param).map_err(RenderErrorReason::SerdeError)?;
412        Ok(ScopedJson::Derived(value))
413    }
414}
415
416fn json_value(val: AMQPValue) -> serde_json::Result<serde_json::Value> {
417    match val {
418        AMQPValue::Boolean(v) => serde_json::to_value(v),
419        AMQPValue::ShortShortInt(v) => serde_json::to_value(v),
420        AMQPValue::ShortShortUInt(v) => serde_json::to_value(v),
421        AMQPValue::ShortInt(v) => serde_json::to_value(v),
422        AMQPValue::ShortUInt(v) => serde_json::to_value(v),
423        AMQPValue::LongInt(v) => serde_json::to_value(v),
424        AMQPValue::LongUInt(v) => serde_json::to_value(v),
425        AMQPValue::LongLongInt(v) => serde_json::to_value(v),
426        AMQPValue::Float(v) => serde_json::to_value(v),
427        AMQPValue::Double(v) => serde_json::to_value(v),
428        AMQPValue::DecimalValue(v) => serde_json::to_value(v),
429        AMQPValue::ShortString(v) => serde_json::to_value(format!("\"{v}\"")),
430        AMQPValue::LongString(v) => serde_json::to_value(format!("b\"{v}\"")),
431        AMQPValue::FieldArray(v) => serde_json::to_value(v),
432        AMQPValue::Timestamp(v) => serde_json::to_value(v),
433        AMQPValue::FieldTable(v) => serde_json::to_value(v),
434        AMQPValue::ByteArray(v) => serde_json::to_value(v),
435        AMQPValue::Void => Ok(JsonValue::Null),
436    }
437}
438
#[cfg(test)]
mod test {
    use super::*;

    use std::collections::BTreeMap;

    /// Template exercising the spec fields and the custom `each_argument` helper.
    pub const TEMPLATE: &str = r#"
{{protocol.name}} - {{protocol.major_version}}.{{protocol.minor_version}}.{{protocol.revision}}
{{protocol.copyright}}
port {{protocol.port}}
{{#each protocol.domains ~}}
{{@key}}: {{this}}
{{/each ~}}
{{#each protocol.constants as |constant| ~}}
{{constant.name}} = {{constant.value}}
{{/each ~}}
{{#each protocol.classes as |class| ~}}
{{class.id}} - {{class.name}}
{{#each class.properties as |property| ~}}
{{property.name}}: {{property.type}}
{{/each ~}}
{{#each class.methods as |method| ~}}
{{method.id}} - {{method.name}}
synchronous: {{method.synchronous}}
{{#each_argument method.arguments as |argument| ~}}
{{#if @argument_is_value ~}}
{{argument.name}}({{argument.domain}}): {{argument.type}}
{{else}}
{{#each argument.flags as |flag| ~}}
{{flag.name}}: {{flag.default_value}}
{{/each ~}}
{{/if ~}}
{{/each_argument ~}}
{{/each ~}}
{{/each ~}}
"#;

    /// Minimal protocol definition: one domain, constant, class, property and method, the
    /// method having one value argument and one flags argument.
    fn specs() -> AMQProtocolDefinition {
        let mut domains = BTreeMap::default();
        domains.insert("domain1".to_string(), AMQPType::LongString);
        AMQProtocolDefinition {
            name: "AMQP".to_string(),
            major_version: 0,
            minor_version: 9,
            revision: 1,
            port: 5672,
            copyright: "Copyright 1\nCopyright 2".to_string(),
            domains,
            constants: vec![AMQPConstant {
                name: "constant1".to_string(),
                amqp_type: AMQPType::ShortUInt,
                value: 128,
            }],
            soft_errors: Vec::default(),
            hard_errors: Vec::default(),
            classes: vec![AMQPClass {
                id: 42,
                methods: vec![AMQPMethod {
                    id: 64,
                    arguments: vec![
                        AMQPArgument::Value(AMQPValueArgument {
                            amqp_type: AMQPType::LongString,
                            name: "argument1".to_string(),
                            default_value: Some(AMQPValue::LongString("value1".into())),
                            domain: Some("domain1".to_string()),
                            force_default: false,
                        }),
                        AMQPArgument::Flags(AMQPFlagsArgument {
                            ignore_flags: false,
                            flags: vec![
                                AMQPFlagArgument {
                                    name: "flag1".to_string(),
                                    default_value: true,
                                    force_default: false,
                                },
                                AMQPFlagArgument {
                                    name: "flag2".to_string(),
                                    default_value: false,
                                    force_default: false,
                                },
                            ],
                        }),
                    ],
                    name: "method1".to_string(),
                    synchronous: true,
                    content: false,
                    metadata: Value::default(),
                    is_reply: false,
                    ignore_args: false,
                    c2s: true,
                    s2c: true,
                }],
                name: "class1".to_string(),
                properties: vec![AMQPProperty {
                    amqp_type: AMQPType::LongString,
                    name: "property1".to_string(),
                }],
                metadata: Value::default(),
            }],
        }
    }

    #[test]
    fn main_template() {
        let mut data = HashMap::new();
        let mut codegen = CodeGenerator::default().register_amqp_helpers();
        data.insert("protocol".to_string(), specs());
        assert!(codegen.register_template_string("main", TEMPLATE).is_ok());
        // The literal " = " and " - " separators in TEMPLATE are rendered verbatim.
        assert_eq!(
            codegen.render("main", &data).unwrap(),
            r#"
AMQP - 0.9.1
Copyright 1
Copyright 2
port 5672
domain1: LongString
constant1 = 128
42 - class1
property1: LongString
64 - method1
synchronous: true
argument1(domain1): LongString
flag1: true
flag2: false
"#
        );
    }
}