amq_protocol_codegen/
templating.rs

1use crate::{specs::*, util::*};
2
3use amq_protocol_types::{AMQPType, AMQPValue};
4use handlebars::{
5    self, BlockContext, BlockParams, Context, Handlebars, Helper, HelperDef, HelperResult,
6    JsonValue, Output, RenderContext, RenderError, RenderErrorReason, Renderable, ScopedJson,
7    to_json,
8};
9use serde_json::{self, Value};
10
11use std::{
12    collections::HashMap,
13    fs::{self, File},
14    io::Write,
15    path::Path,
16};
17
/// Type alias to avoid making our users explicitly depend on an extra dependency.
///
/// This is plain [`Handlebars`]; the AMQP-specific behaviour is added by the
/// [`HandlebarsAMQPExtension`] implementation for this type.
pub type CodeGenerator<'a> = Handlebars<'a>;
20
21/// Our extension for better integration with Handlebars
22pub trait HandlebarsAMQPExtension {
23    /// Register the various standard helpers we'll need for AMQP codegen
24    fn register_amqp_helpers(self) -> Self;
25    /// Generate code using the standard representation of specs and the given template, using the
26    /// given name for the variable holding the [protocol definition](../specs.AMQProtocolDefinition.html).
27    fn simple_codegen(
28        out_dir: &str,
29        target: &str,
30        template_name: &str,
31        template: &str,
32        var_name: &str,
33    ) {
34        Self::simple_codegen_with_data(out_dir, target, template_name, template, var_name, None);
35    }
36    /// Generate code using the standard representation of specs and the given template, using the
37    /// given name for the variable holding the [protocol definition](../specs.AMQProtocolDefinition.html),
38    /// and also passing data to the template.
39    fn simple_codegen_with_data(
40        out_dir: &str,
41        target: &str,
42        template_name: &str,
43        template: &str,
44        var_name: &str,
45        data: Option<Value>,
46    );
47}
48
49impl<'a> HandlebarsAMQPExtension for CodeGenerator<'a> {
50    fn register_amqp_helpers(mut self) -> CodeGenerator<'a> {
51        self.register_escape_fn(handlebars::no_escape);
52        self.register_helper("camel", Box::new(CamelHelper));
53        self.register_helper("snake", Box::new(SnakeHelper));
54        self.register_helper("snake_type", Box::new(SnakeTypeHelper));
55        self.register_helper("sanitize_name", Box::new(SanitizeNameHelper));
56        self.register_helper("include_more", Box::new(IncludeMoreHelper));
57        self.register_helper("pass_by_ref", Box::new(PassByRefHelper));
58        self.register_helper("use_str_ref", Box::new(UseStrRefHelper));
59        self.register_helper("use_bytes_ref", Box::new(UseBytesRefHelper));
60        self.register_helper("each_argument", Box::new(EachArgumentHelper));
61        self.register_helper("amqp_value_ref", Box::new(AMQPValueRefHelper));
62        self
63    }
64
65    fn simple_codegen_with_data(
66        out_dir: &str,
67        target: &str,
68        template_name: &str,
69        template: &str,
70        var_name: &str,
71        metadata: Option<Value>,
72    ) {
73        let dest_path = Path::new(out_dir).join(format!("{}.rs", target));
74        let mut f = File::create(&dest_path)
75            .unwrap_or_else(|err| panic!("Failed to create {:?}: {}", dest_path, err));
76        let specs = AMQProtocolDefinition::load(metadata);
77        let mut codegen = CodeGenerator::default().register_amqp_helpers();
78        let mut data = HashMap::new();
79
80        codegen.set_strict_mode(true);
81        codegen
82            .register_template_string(template_name, template)
83            .unwrap_or_else(|e| panic!("Failed to register {} template: {}", template_name, e));
84        data.insert(
85            var_name.to_string(),
86            serde_json::to_value(specs)
87                .unwrap_or_else(|e| panic!("Failed to serialize specs: {}", e)),
88        );
89
90        writeln!(
91            f,
92            "{}",
93            codegen
94                .render(template_name, &data)
95                .unwrap_or_else(|err| panic!(
96                    "Failed to render {} template: {}",
97                    template_name, err
98                ))
99        )
100        .unwrap_or_else(|e| panic!("Failed to generate {}.rs: {}", target, e));
101    }
102}
103
104/// Helper for converting text to camel case
105pub struct CamelHelper;
106impl HelperDef for CamelHelper {
107    fn call<'reg: 'rc, 'rc>(
108        &self,
109        h: &Helper<'rc>,
110        _: &'reg Handlebars<'_>,
111        _: &'rc Context,
112        _: &mut RenderContext<'reg, 'rc>,
113        out: &mut dyn Output,
114    ) -> HelperResult {
115        let value = h
116            .param(0)
117            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("camel", 0))?;
118        let param = value.value().as_str().ok_or_else(|| {
119            RenderErrorReason::ParamTypeMismatchForName(
120                "camel",
121                "string".to_string(),
122                "string".to_string(),
123            )
124        })?;
125        out.write(&camel_case(param))?;
126        Ok(())
127    }
128}
129
130/// Helper for converting text to snake case
131pub struct SnakeHelper;
132impl HelperDef for SnakeHelper {
133    fn call<'reg: 'rc, 'rc>(
134        &self,
135        h: &Helper<'rc>,
136        _: &'reg Handlebars<'_>,
137        _: &'rc Context,
138        _: &mut RenderContext<'reg, 'rc>,
139        out: &mut dyn Output,
140    ) -> HelperResult {
141        let value = h
142            .param(0)
143            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("snake", 0))?;
144        let raw = h
145            .param(1)
146            .and_then(|raw| raw.value().as_bool())
147            .unwrap_or(true);
148        let param = value.value().as_str().ok_or_else(|| {
149            RenderErrorReason::ParamTypeMismatchForName(
150                "snake",
151                "string".to_string(),
152                "string".to_string(),
153            )
154        })?;
155        out.write(&snake_case(param, raw))?;
156        Ok(())
157    }
158}
159
160/// Helper for getting the type name converted to snake case
161pub struct SnakeTypeHelper;
162impl HelperDef for SnakeTypeHelper {
163    fn call<'reg: 'rc, 'rc>(
164        &self,
165        h: &Helper<'rc>,
166        _: &'reg Handlebars<'_>,
167        _: &'rc Context,
168        _: &mut RenderContext<'reg, 'rc>,
169        out: &mut dyn Output,
170    ) -> HelperResult {
171        let value = h
172            .param(0)
173            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("snake_type", 0))?;
174        let param: AMQPType = serde_json::from_value(value.value().clone()).map_err(|_| {
175            RenderErrorReason::ParamTypeMismatchForName(
176                "snake_type",
177                "AMQPType".to_string(),
178                "string".to_string(),
179            )
180        })?;
181        out.write(&snake_case(&param.to_string(), true))?;
182        Ok(())
183    }
184}
185
186/// Helper to sanitize name so the it becomes a valid identifier
187pub struct SanitizeNameHelper;
188impl HelperDef for SanitizeNameHelper {
189    fn call<'reg: 'rc, 'rc>(
190        &self,
191        h: &Helper<'rc>,
192        _: &'reg Handlebars<'_>,
193        _: &'rc Context,
194        _: &mut RenderContext<'reg, 'rc>,
195        out: &mut dyn Output,
196    ) -> HelperResult {
197        let value = h
198            .param(0)
199            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("sanitize_name", 0))?;
200        let param = value.value().as_str().ok_or_else(|| {
201            RenderErrorReason::ParamTypeMismatchForName(
202                "sanitize_name",
203                "string".to_string(),
204                "string".to_string(),
205            )
206        })?;
207        out.write(&param.replace('-', "_"))?;
208        Ok(())
209    }
210}
211
212/// Helper to include additional code such as rustdoc
213pub struct IncludeMoreHelper;
214impl HelperDef for IncludeMoreHelper {
215    fn call<'reg: 'rc, 'rc>(
216        &self,
217        h: &Helper<'rc>,
218        _: &'reg Handlebars<'_>,
219        _: &'rc Context,
220        _: &mut RenderContext<'reg, 'rc>,
221        out: &mut dyn Output,
222    ) -> HelperResult {
223        let amqp_class = h
224            .param(0)
225            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("include_more", 0))?;
226        let amqp_method = h
227            .param(1)
228            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("include_more", 1))?;
229        let amqp_class = amqp_class.value().as_str().ok_or_else(|| {
230            RenderErrorReason::ParamTypeMismatchForName(
231                "include_more",
232                "string".to_string(),
233                "class".to_string(),
234            )
235        })?;
236        let amqp_method = amqp_method.value().as_str().ok_or_else(|| {
237            RenderErrorReason::ParamTypeMismatchForName(
238                "include_more",
239                "string".to_string(),
240                "method".to_string(),
241            )
242        })?;
243        if let Ok(cargo_manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
244            let include = Path::new(&cargo_manifest_dir)
245                .join("templates")
246                .join("includes")
247                .join(amqp_class)
248                .join(format!("{}.rs", amqp_method));
249            if let Ok(include) = fs::read_to_string(include) {
250                out.write(&include)?;
251            }
252        }
253        Ok(())
254    }
255}
256
257/// Helper to check whether a param is passed by ref or not
258pub struct PassByRefHelper;
259impl HelperDef for PassByRefHelper {
260    fn call_inner<'reg: 'rc, 'rc>(
261        &self,
262        h: &Helper<'rc>,
263        _: &'reg Handlebars<'_>,
264        _: &'rc Context,
265        _: &mut RenderContext<'reg, 'rc>,
266    ) -> Result<ScopedJson<'rc>, RenderError> {
267        let value = h
268            .param(0)
269            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("pass_by_ref", 0))?;
270        let param: AMQPType = serde_json::from_value(value.value().clone()).map_err(|_| {
271            RenderErrorReason::ParamTypeMismatchForName(
272                "pass_by_ref",
273                "AMQPType".to_string(),
274                "string".to_string(),
275            )
276        })?;
277        let pass_by_ref = matches!(
278            param,
279            AMQPType::ShortString
280                | AMQPType::LongString
281                | AMQPType::FieldArray
282                | AMQPType::FieldTable
283                | AMQPType::ByteArray
284        );
285        Ok(ScopedJson::Derived(JsonValue::from(pass_by_ref)))
286    }
287}
288
289/// Helper to check whether a param is passed using an &str or its real type
290pub struct UseStrRefHelper;
291impl HelperDef for UseStrRefHelper {
292    fn call_inner<'reg: 'rc, 'rc>(
293        &self,
294        h: &Helper<'rc>,
295        _: &'reg Handlebars<'_>,
296        _: &'rc Context,
297        _: &mut RenderContext<'reg, 'rc>,
298    ) -> Result<ScopedJson<'rc>, RenderError> {
299        let value = h
300            .param(0)
301            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("use_str_ref", 0))?;
302        let param = serde_json::from_value::<AMQPType>(value.value().clone()).ok();
303        let use_str_ref = matches!(
304            param,
305            Some(AMQPType::ShortString) | Some(AMQPType::LongString)
306        );
307        Ok(ScopedJson::Derived(JsonValue::from(use_str_ref)))
308    }
309}
310
311/// Helper to check whether a param is passed using an &[u8] or its real type
312pub struct UseBytesRefHelper;
313impl HelperDef for UseBytesRefHelper {
314    fn call_inner<'reg: 'rc, 'rc>(
315        &self,
316        h: &Helper<'rc>,
317        _: &'reg Handlebars<'_>,
318        _: &'rc Context,
319        _: &mut RenderContext<'reg, 'rc>,
320    ) -> Result<ScopedJson<'rc>, RenderError> {
321        let value = h
322            .param(0)
323            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("use_bytes_ref", 0))?;
324        let param = serde_json::from_value::<AMQPType>(value.value().clone()).ok();
325        let use_bytes_ref = matches!(param, Some(AMQPType::LongString));
326        Ok(ScopedJson::Derived(JsonValue::from(use_bytes_ref)))
327    }
328}
329
330/// Helper to walk through a Vec of [AMQPArgument](../specs.AMQPArgument.html).
331pub struct EachArgumentHelper;
332impl HelperDef for EachArgumentHelper {
333    fn call<'reg: 'rc, 'rc>(
334        &self,
335        h: &Helper<'rc>,
336        r: &'reg Handlebars<'_>,
337        ctx: &'rc Context,
338        rc: &mut RenderContext<'reg, 'rc>,
339        out: &mut dyn Output,
340    ) -> HelperResult {
341        let value = h
342            .param(0)
343            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("each_argument", 0))?;
344
345        if let Some(t) = h.template() {
346            let mut block_context = BlockContext::new();
347            if let Some(path) = value.context_path() {
348                *block_context.base_path_mut() = path.to_vec();
349            }
350            rc.push_block(block_context);
351            let arguments: Vec<AMQPArgument> = serde_json::from_value(value.value().clone())
352                .map_err(|_| {
353                    RenderErrorReason::ParamTypeMismatchForName(
354                        "each_argument",
355                        "Vec<AMQPArgument>".to_string(),
356                        "arguments".to_string(),
357                    )
358                })?;
359            let len = arguments.len();
360            let array_path = value.context_path();
361            for (index, argument) in arguments.iter().enumerate() {
362                if let Some(ref mut block) = rc.block_mut() {
363                    let (path, is_value) = match *argument {
364                        AMQPArgument::Value(_) => ("Value".to_owned(), true),
365                        AMQPArgument::Flags(_) => ("Flags".to_owned(), false),
366                    };
367                    block.set_local_var("index", to_json(index));
368                    block.set_local_var("last", to_json(index == len - 1));
369                    block.set_local_var("argument_is_value", to_json(is_value));
370                    if let Some(p) = array_path {
371                        if index == 0 {
372                            let mut path = Vec::with_capacity(p.len() + 1);
373                            path.extend_from_slice(p);
374                            path.push(index.to_string());
375                            *block.base_path_mut() = path;
376                        } else if let Some(ptr) = block.base_path_mut().last_mut() {
377                            *ptr = index.to_string();
378                        }
379                    }
380                    if let Some(block_param) = h.block_param() {
381                        let mut params = BlockParams::new();
382                        params.add_path(block_param, vec![path])?;
383                        block.set_block_params(params);
384                    }
385                }
386                t.render(r, ctx, rc, out)?;
387            }
388            rc.pop_block();
389        }
390        Ok(())
391    }
392}
393
394/// Helper for "unwrapping" an amqp_value
395pub struct AMQPValueRefHelper;
396impl HelperDef for AMQPValueRefHelper {
397    fn call_inner<'reg: 'rc, 'rc>(
398        &self,
399        h: &Helper<'rc>,
400        _: &'reg Handlebars<'_>,
401        _: &'rc Context,
402        _: &mut RenderContext<'reg, 'rc>,
403    ) -> Result<ScopedJson<'rc>, RenderError> {
404        let arg = h
405            .param(0)
406            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("amqp_value", 0))?;
407        let param = serde_json::from_value(arg.value().clone()).map_err(|_| {
408            RenderErrorReason::ParamTypeMismatchForName(
409                "amqp_value",
410                "AMQPValue".to_string(),
411                "value".to_string(),
412            )
413        })?;
414        let value = json_value(param).map_err(RenderErrorReason::SerdeError)?;
415        Ok(ScopedJson::Derived(value))
416    }
417}
418
419fn json_value(val: AMQPValue) -> serde_json::Result<serde_json::Value> {
420    match val {
421        AMQPValue::Boolean(v) => serde_json::to_value(v),
422        AMQPValue::ShortShortInt(v) => serde_json::to_value(v),
423        AMQPValue::ShortShortUInt(v) => serde_json::to_value(v),
424        AMQPValue::ShortInt(v) => serde_json::to_value(v),
425        AMQPValue::ShortUInt(v) => serde_json::to_value(v),
426        AMQPValue::LongInt(v) => serde_json::to_value(v),
427        AMQPValue::LongUInt(v) => serde_json::to_value(v),
428        AMQPValue::LongLongInt(v) => serde_json::to_value(v),
429        AMQPValue::Float(v) => serde_json::to_value(v),
430        AMQPValue::Double(v) => serde_json::to_value(v),
431        AMQPValue::DecimalValue(v) => serde_json::to_value(v),
432        AMQPValue::ShortString(v) => serde_json::to_value(format!("\"{}\"", v)),
433        AMQPValue::LongString(v) => serde_json::to_value(format!("b\"{}\"", v)),
434        AMQPValue::FieldArray(v) => serde_json::to_value(v),
435        AMQPValue::Timestamp(v) => serde_json::to_value(v),
436        AMQPValue::FieldTable(v) => serde_json::to_value(v),
437        AMQPValue::ByteArray(v) => serde_json::to_value(v),
438        AMQPValue::Void => Ok(JsonValue::Null),
439    }
440}
441
#[cfg(test)]
mod test {
    use super::*;

    use std::collections::BTreeMap;

    /// Template exercising every level of the protocol definition: top-level fields, domains,
    /// constants, classes, properties, methods and their arguments (values and flags).
    pub const TEMPLATE: &str = r#"
{{protocol.name}} - {{protocol.major_version}}.{{protocol.minor_version}}.{{protocol.revision}}
{{protocol.copyright}}
port {{protocol.port}}
{{#each protocol.domains ~}}
{{@key}}: {{this}}
{{/each ~}}
{{#each protocol.constants as |constant| ~}}
{{constant.name}} = {{constant.value}}
{{/each ~}}
{{#each protocol.classes as |class| ~}}
{{class.id}} - {{class.name}}
{{#each class.properties as |property| ~}}
{{property.name}}: {{property.type}}
{{/each ~}}
{{#each class.methods as |method| ~}}
{{method.id}} - {{method.name}}
synchronous: {{method.synchronous}}
{{#each_argument method.arguments as |argument| ~}}
{{#if @argument_is_value ~}}
{{argument.name}}({{argument.domain}}): {{argument.type}}
{{else}}
{{#each argument.flags as |flag| ~}}
{{flag.name}}: {{flag.default_value}}
{{/each ~}}
{{/if ~}}
{{/each_argument ~}}
{{/each ~}}
{{/each ~}}
"#;

    /// Minimal protocol definition: one domain, one constant and one class holding one
    /// property and one method with a value argument and a flags argument.
    fn specs() -> AMQProtocolDefinition {
        let mut domains = BTreeMap::default();
        domains.insert("domain1".to_string(), AMQPType::LongString);
        AMQProtocolDefinition {
            name: "AMQP".to_string(),
            major_version: 0,
            minor_version: 9,
            revision: 1,
            port: 5672,
            copyright: "Copyright 1\nCopyright 2".to_string(),
            domains,
            constants: vec![AMQPConstant {
                name: "constant1".to_string(),
                amqp_type: AMQPType::ShortUInt,
                value: 128,
            }],
            soft_errors: Vec::default(),
            hard_errors: Vec::default(),
            classes: vec![AMQPClass {
                id: 42,
                methods: vec![AMQPMethod {
                    id: 64,
                    arguments: vec![
                        AMQPArgument::Value(AMQPValueArgument {
                            amqp_type: AMQPType::LongString,
                            name: "argument1".to_string(),
                            default_value: Some(AMQPValue::LongString("value1".into())),
                            domain: Some("domain1".to_string()),
                            force_default: false,
                        }),
                        AMQPArgument::Flags(AMQPFlagsArgument {
                            ignore_flags: false,
                            flags: vec![
                                AMQPFlagArgument {
                                    name: "flag1".to_string(),
                                    default_value: true,
                                    force_default: false,
                                },
                                AMQPFlagArgument {
                                    name: "flag2".to_string(),
                                    default_value: false,
                                    force_default: false,
                                },
                            ],
                        }),
                    ],
                    name: "method1".to_string(),
                    synchronous: true,
                    content: false,
                    metadata: Value::default(),
                    is_reply: false,
                    ignore_args: false,
                    c2s: true,
                    s2c: true,
                }],
                name: "class1".to_string(),
                properties: vec![AMQPProperty {
                    amqp_type: AMQPType::LongString,
                    name: "property1".to_string(),
                }],
                metadata: Value::default(),
            }],
        }
    }

    /// Renders `TEMPLATE` against `specs()` and compares the whole output.
    #[test]
    fn main_template() {
        let mut data = HashMap::new();
        let mut codegen = CodeGenerator::default().register_amqp_helpers();
        data.insert("protocol".to_string(), specs());
        assert!(codegen.register_template_string("main", TEMPLATE).is_ok());
        // The " = " and " - " separators are literal text in the template and must show up
        // unchanged in the output.
        assert_eq!(
            codegen.render("main", &data).unwrap(),
            r#"
AMQP - 0.9.1
Copyright 1
Copyright 2
port 5672
domain1: LongString
constant1 = 128
42 - class1
property1: LongString
64 - method1
synchronous: true
argument1(domain1): LongString
flag1: true
flag2: false
"#
        );
    }
}