Skip to main content

amq_protocol_codegen/
templating.rs

1use crate::{specs::*, util::*};
2
3use amq_protocol_types::{AMQPType, AMQPValue};
4use handlebars::{
5    self, BlockContext, BlockParams, Context, Handlebars, Helper, HelperDef, HelperResult,
6    JsonValue, Output, RenderContext, RenderError, RenderErrorReason, Renderable, ScopedJson,
7    to_json,
8};
9use serde_json::{self, Value};
10
11use std::{
12    collections::HashMap,
13    fs::{self, File},
14    io::Write,
15    path::Path,
16};
17
18/// Type alias to avoid making our users explicitly depend on an extra dependency
19pub type CodeGenerator<'a> = Handlebars<'a>;
20
21/// Our extension for better integration with Handlebars
22pub trait HandlebarsAMQPExtension {
23    /// Register the various standard helpers we'll need for AMQP codegen
24    fn register_amqp_helpers(self) -> Self;
25    /// Generate code using the standard representation of specs and the given template, using the
26    /// given name for the variable holding the [protocol definition](../specs.AMQProtocolDefinition.html).
27    fn simple_codegen(
28        out_dir: &str,
29        target: &str,
30        template_name: &str,
31        template: &str,
32        var_name: &str,
33    ) {
34        Self::simple_codegen_with_data(out_dir, target, template_name, template, var_name, None);
35    }
36    /// Generate code using the standard representation of specs and the given template, using the
37    /// given name for the variable holding the [protocol definition](../specs.AMQProtocolDefinition.html),
38    /// and also passing data to the template.
39    fn simple_codegen_with_data(
40        out_dir: &str,
41        target: &str,
42        template_name: &str,
43        template: &str,
44        var_name: &str,
45        data: Option<Value>,
46    );
47}
48
49impl<'a> HandlebarsAMQPExtension for CodeGenerator<'a> {
50    fn register_amqp_helpers(mut self) -> CodeGenerator<'a> {
51        self.register_escape_fn(handlebars::no_escape);
52        self.register_helper("camel", Box::new(CamelHelper));
53        self.register_helper("snake", Box::new(SnakeHelper));
54        self.register_helper("snake_type", Box::new(SnakeTypeHelper));
55        self.register_helper("sanitize_name", Box::new(SanitizeNameHelper));
56        self.register_helper("include_more", Box::new(IncludeMoreHelper));
57        self.register_helper("pass_by_ref", Box::new(PassByRefHelper));
58        self.register_helper("use_str_ref", Box::new(UseStrRefHelper));
59        self.register_helper("use_bytes_ref", Box::new(UseBytesRefHelper));
60        self.register_helper("each_argument", Box::new(EachArgumentHelper));
61        self.register_helper("amqp_value_ref", Box::new(AMQPValueRefHelper));
62        self
63    }
64
65    fn simple_codegen_with_data(
66        out_dir: &str,
67        target: &str,
68        template_name: &str,
69        template: &str,
70        var_name: &str,
71        metadata: Option<Value>,
72    ) {
73        let dest_path = Path::new(out_dir).join(format!("{target}.rs"));
74        let mut f = File::create(&dest_path)
75            .unwrap_or_else(|err| panic!("Failed to create {dest_path:?}: {err}"));
76        let specs = AMQProtocolDefinition::load(metadata);
77        let mut codegen = CodeGenerator::default().register_amqp_helpers();
78        let mut data = HashMap::new();
79
80        codegen.set_strict_mode(true);
81        codegen
82            .register_template_string(template_name, template)
83            .unwrap_or_else(|e| panic!("Failed to register {template_name} template: {e}"));
84        data.insert(
85            var_name.to_string(),
86            serde_json::to_value(specs)
87                .unwrap_or_else(|e| panic!("Failed to serialize specs: {e}")),
88        );
89
90        writeln!(
91            f,
92            "{}",
93            codegen
94                .render(template_name, &data)
95                .unwrap_or_else(|err| panic!("Failed to render {template_name} template: {err}"))
96        )
97        .unwrap_or_else(|e| panic!("Failed to generate {target}.rs: {e}"));
98    }
99}
100
101/// Helper for converting text to camel case
102#[derive(Debug)]
103pub struct CamelHelper;
104impl HelperDef for CamelHelper {
105    fn call<'reg: 'rc, 'rc>(
106        &self,
107        h: &Helper<'rc>,
108        _: &'reg Handlebars<'_>,
109        _: &'rc Context,
110        _: &mut RenderContext<'reg, 'rc>,
111        out: &mut dyn Output,
112    ) -> HelperResult {
113        let value = h
114            .param(0)
115            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("camel", 0))?;
116        let param = value.value().as_str().ok_or_else(|| {
117            RenderErrorReason::ParamTypeMismatchForName(
118                "camel",
119                "string".to_string(),
120                "string".to_string(),
121            )
122        })?;
123        out.write(&camel_case(param))?;
124        Ok(())
125    }
126}
127
128/// Helper for converting text to snake case
129#[derive(Debug)]
130pub struct SnakeHelper;
131impl HelperDef for SnakeHelper {
132    fn call<'reg: 'rc, 'rc>(
133        &self,
134        h: &Helper<'rc>,
135        _: &'reg Handlebars<'_>,
136        _: &'rc Context,
137        _: &mut RenderContext<'reg, 'rc>,
138        out: &mut dyn Output,
139    ) -> HelperResult {
140        let value = h
141            .param(0)
142            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("snake", 0))?;
143        let raw = h
144            .param(1)
145            .and_then(|raw| raw.value().as_bool())
146            .unwrap_or(true);
147        let param = value.value().as_str().ok_or_else(|| {
148            RenderErrorReason::ParamTypeMismatchForName(
149                "snake",
150                "string".to_string(),
151                "string".to_string(),
152            )
153        })?;
154        out.write(&snake_case(param, raw))?;
155        Ok(())
156    }
157}
158
159/// Helper for getting the type name converted to snake case
160#[derive(Debug)]
161pub struct SnakeTypeHelper;
162impl HelperDef for SnakeTypeHelper {
163    fn call<'reg: 'rc, 'rc>(
164        &self,
165        h: &Helper<'rc>,
166        _: &'reg Handlebars<'_>,
167        _: &'rc Context,
168        _: &mut RenderContext<'reg, 'rc>,
169        out: &mut dyn Output,
170    ) -> HelperResult {
171        let value = h
172            .param(0)
173            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("snake_type", 0))?;
174        let param: AMQPType = serde_json::from_value(value.value().clone()).map_err(|_| {
175            RenderErrorReason::ParamTypeMismatchForName(
176                "snake_type",
177                "AMQPType".to_string(),
178                "string".to_string(),
179            )
180        })?;
181        out.write(&snake_case(&param.to_string(), true))?;
182        Ok(())
183    }
184}
185
186/// Helper to sanitize name so the it becomes a valid identifier
187#[derive(Debug)]
188pub struct SanitizeNameHelper;
189impl HelperDef for SanitizeNameHelper {
190    fn call<'reg: 'rc, 'rc>(
191        &self,
192        h: &Helper<'rc>,
193        _: &'reg Handlebars<'_>,
194        _: &'rc Context,
195        _: &mut RenderContext<'reg, 'rc>,
196        out: &mut dyn Output,
197    ) -> HelperResult {
198        let value = h
199            .param(0)
200            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("sanitize_name", 0))?;
201        let param = value.value().as_str().ok_or_else(|| {
202            RenderErrorReason::ParamTypeMismatchForName(
203                "sanitize_name",
204                "string".to_string(),
205                "string".to_string(),
206            )
207        })?;
208        out.write(&param.replace('-', "_"))?;
209        Ok(())
210    }
211}
212
213/// Helper to include additional code such as rustdoc
214#[derive(Debug)]
215pub struct IncludeMoreHelper;
216impl HelperDef for IncludeMoreHelper {
217    fn call<'reg: 'rc, 'rc>(
218        &self,
219        h: &Helper<'rc>,
220        _: &'reg Handlebars<'_>,
221        _: &'rc Context,
222        _: &mut RenderContext<'reg, 'rc>,
223        out: &mut dyn Output,
224    ) -> HelperResult {
225        let amqp_class = h
226            .param(0)
227            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("include_more", 0))?;
228        let amqp_method = h
229            .param(1)
230            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("include_more", 1))?;
231        let amqp_class = amqp_class.value().as_str().ok_or_else(|| {
232            RenderErrorReason::ParamTypeMismatchForName(
233                "include_more",
234                "string".to_string(),
235                "class".to_string(),
236            )
237        })?;
238        let amqp_method = amqp_method.value().as_str().ok_or_else(|| {
239            RenderErrorReason::ParamTypeMismatchForName(
240                "include_more",
241                "string".to_string(),
242                "method".to_string(),
243            )
244        })?;
245        if let Ok(cargo_manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
246            let include = Path::new(&cargo_manifest_dir)
247                .join("templates")
248                .join("includes")
249                .join(amqp_class)
250                .join(format!("{amqp_method}.rs"));
251            if let Ok(include) = fs::read_to_string(include) {
252                out.write(&include)?;
253            }
254        }
255        Ok(())
256    }
257}
258
259/// Helper to check whether a param is passed by ref or not
260#[derive(Debug)]
261pub struct PassByRefHelper;
262impl HelperDef for PassByRefHelper {
263    fn call_inner<'reg: 'rc, 'rc>(
264        &self,
265        h: &Helper<'rc>,
266        _: &'reg Handlebars<'_>,
267        _: &'rc Context,
268        _: &mut RenderContext<'reg, 'rc>,
269    ) -> Result<ScopedJson<'rc>, RenderError> {
270        let value = h
271            .param(0)
272            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("pass_by_ref", 0))?;
273        let param: AMQPType = serde_json::from_value(value.value().clone()).map_err(|_| {
274            RenderErrorReason::ParamTypeMismatchForName(
275                "pass_by_ref",
276                "AMQPType".to_string(),
277                "string".to_string(),
278            )
279        })?;
280        let pass_by_ref = matches!(
281            param,
282            AMQPType::ShortString
283                | AMQPType::LongString
284                | AMQPType::FieldArray
285                | AMQPType::FieldTable
286                | AMQPType::ByteArray
287        );
288        Ok(ScopedJson::Derived(JsonValue::from(pass_by_ref)))
289    }
290}
291
292/// Helper to check whether a param is passed using an &str or its real type
293#[derive(Debug)]
294pub struct UseStrRefHelper;
295impl HelperDef for UseStrRefHelper {
296    fn call_inner<'reg: 'rc, 'rc>(
297        &self,
298        h: &Helper<'rc>,
299        _: &'reg Handlebars<'_>,
300        _: &'rc Context,
301        _: &mut RenderContext<'reg, 'rc>,
302    ) -> Result<ScopedJson<'rc>, RenderError> {
303        let value = h
304            .param(0)
305            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("use_str_ref", 0))?;
306        let param = serde_json::from_value::<AMQPType>(value.value().clone()).ok();
307        let use_str_ref = matches!(
308            param,
309            Some(AMQPType::ShortString) | Some(AMQPType::LongString)
310        );
311        Ok(ScopedJson::Derived(JsonValue::from(use_str_ref)))
312    }
313}
314
315/// Helper to check whether a param is passed using an &[u8] or its real type
316#[derive(Debug)]
317pub struct UseBytesRefHelper;
318impl HelperDef for UseBytesRefHelper {
319    fn call_inner<'reg: 'rc, 'rc>(
320        &self,
321        h: &Helper<'rc>,
322        _: &'reg Handlebars<'_>,
323        _: &'rc Context,
324        _: &mut RenderContext<'reg, 'rc>,
325    ) -> Result<ScopedJson<'rc>, RenderError> {
326        let value = h
327            .param(0)
328            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("use_bytes_ref", 0))?;
329        let param = serde_json::from_value::<AMQPType>(value.value().clone()).ok();
330        let use_bytes_ref = matches!(param, Some(AMQPType::LongString));
331        Ok(ScopedJson::Derived(JsonValue::from(use_bytes_ref)))
332    }
333}
334
335/// Helper to walk through a Vec of [AMQPArgument](../specs.AMQPArgument.html).
336#[derive(Debug)]
337pub struct EachArgumentHelper;
338impl HelperDef for EachArgumentHelper {
339    fn call<'reg: 'rc, 'rc>(
340        &self,
341        h: &Helper<'rc>,
342        r: &'reg Handlebars<'_>,
343        ctx: &'rc Context,
344        rc: &mut RenderContext<'reg, 'rc>,
345        out: &mut dyn Output,
346    ) -> HelperResult {
347        let value = h
348            .param(0)
349            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("each_argument", 0))?;
350
351        if let Some(t) = h.template() {
352            let mut block_context = BlockContext::new();
353            if let Some(path) = value.context_path() {
354                *block_context.base_path_mut() = path.to_vec();
355            }
356            rc.push_block(block_context);
357            let arguments: Vec<AMQPArgument> = serde_json::from_value(value.value().clone())
358                .map_err(|_| {
359                    RenderErrorReason::ParamTypeMismatchForName(
360                        "each_argument",
361                        "Vec<AMQPArgument>".to_string(),
362                        "arguments".to_string(),
363                    )
364                })?;
365            let len = arguments.len();
366            let array_path = value.context_path();
367            for (index, argument) in arguments.iter().enumerate() {
368                if let Some(ref mut block) = rc.block_mut() {
369                    let (path, is_value) = match *argument {
370                        AMQPArgument::Value(_) => ("Value".to_owned(), true),
371                        AMQPArgument::Flags(_) => ("Flags".to_owned(), false),
372                    };
373                    block.set_local_var("index", to_json(index));
374                    block.set_local_var("last", to_json(index == len - 1));
375                    block.set_local_var("argument_is_value", to_json(is_value));
376                    if let Some(p) = array_path {
377                        if index == 0 {
378                            let mut path = Vec::with_capacity(p.len() + 1);
379                            path.extend_from_slice(p);
380                            path.push(index.to_string());
381                            *block.base_path_mut() = path;
382                        } else if let Some(ptr) = block.base_path_mut().last_mut() {
383                            *ptr = index.to_string();
384                        }
385                    }
386                    if let Some(block_param) = h.block_param() {
387                        let mut params = BlockParams::new();
388                        params.add_path(block_param, vec![path])?;
389                        block.set_block_params(params);
390                    }
391                }
392                t.render(r, ctx, rc, out)?;
393            }
394            rc.pop_block();
395        }
396        Ok(())
397    }
398}
399
400/// Helper for "unwrapping" an amqp_value
401#[derive(Debug)]
402pub struct AMQPValueRefHelper;
403impl HelperDef for AMQPValueRefHelper {
404    fn call_inner<'reg: 'rc, 'rc>(
405        &self,
406        h: &Helper<'rc>,
407        _: &'reg Handlebars<'_>,
408        _: &'rc Context,
409        _: &mut RenderContext<'reg, 'rc>,
410    ) -> Result<ScopedJson<'rc>, RenderError> {
411        let arg = h
412            .param(0)
413            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("amqp_value", 0))?;
414        let param = serde_json::from_value(arg.value().clone()).map_err(|_| {
415            RenderErrorReason::ParamTypeMismatchForName(
416                "amqp_value",
417                "AMQPValue".to_string(),
418                "value".to_string(),
419            )
420        })?;
421        let value = json_value(param).map_err(RenderErrorReason::SerdeError)?;
422        Ok(ScopedJson::Derived(value))
423    }
424}
425
426fn json_value(val: AMQPValue) -> serde_json::Result<Value> {
427    match val {
428        AMQPValue::Boolean(v) => serde_json::to_value(v),
429        AMQPValue::ShortShortInt(v) => serde_json::to_value(v),
430        AMQPValue::ShortShortUInt(v) => serde_json::to_value(v),
431        AMQPValue::ShortInt(v) => serde_json::to_value(v),
432        AMQPValue::ShortUInt(v) => serde_json::to_value(v),
433        AMQPValue::LongInt(v) => serde_json::to_value(v),
434        AMQPValue::LongUInt(v) => serde_json::to_value(v),
435        AMQPValue::LongLongInt(v) => serde_json::to_value(v),
436        AMQPValue::Float(v) => serde_json::to_value(v),
437        AMQPValue::Double(v) => serde_json::to_value(v),
438        AMQPValue::DecimalValue(v) => serde_json::to_value(v),
439        AMQPValue::ShortString(v) => serde_json::to_value(format!("\"{v}\"")),
440        AMQPValue::LongString(v) => serde_json::to_value(format!("b\"{v}\"")),
441        AMQPValue::FieldArray(v) => serde_json::to_value(v),
442        AMQPValue::Timestamp(v) => serde_json::to_value(v),
443        AMQPValue::FieldTable(v) => serde_json::to_value(v),
444        AMQPValue::ByteArray(v) => serde_json::to_value(v),
445        AMQPValue::Void => Ok(JsonValue::Null),
446    }
447}
448
#[cfg(test)]
mod test {
    use super::*;

    use std::collections::BTreeMap;

    // Exercises every part of the protocol definition, plus the `each_argument` helper with
    // both value and flags arguments.
    const TEMPLATE: &str = r#"
{{protocol.name}} - {{protocol.major_version}}.{{protocol.minor_version}}.{{protocol.revision}}
{{protocol.copyright}}
port {{protocol.port}}
{{#each protocol.domains ~}}
{{@key}}: {{this}}
{{/each ~}}
{{#each protocol.constants as |constant| ~}}
{{constant.name}} = {{constant.value}}
{{/each ~}}
{{#each protocol.classes as |class| ~}}
{{class.id}} - {{class.name}}
{{#each class.properties as |property| ~}}
{{property.name}}: {{property.type}}
{{/each ~}}
{{#each class.methods as |method| ~}}
{{method.id}} - {{method.name}}
synchronous: {{method.synchronous}}
{{#each_argument method.arguments as |argument| ~}}
{{#if @argument_is_value ~}}
{{argument.name}}({{argument.domain}}): {{argument.type}}
{{else}}
{{#each argument.flags as |flag| ~}}
{{flag.name}}: {{flag.default_value}}
{{/each ~}}
{{/if ~}}
{{/each_argument ~}}
{{/each ~}}
{{/each ~}}
"#;

    /// Minimal protocol definition with one domain, constant, class, property and method.
    fn specs() -> AMQProtocolDefinition {
        let mut domains = BTreeMap::default();
        domains.insert("domain1".to_string(), AMQPType::LongString);
        AMQProtocolDefinition {
            name: "AMQP".to_string(),
            major_version: 0,
            minor_version: 9,
            revision: 1,
            port: 5672,
            copyright: "Copyright 1\nCopyright 2".to_string(),
            domains,
            constants: vec![AMQPConstant {
                name: "constant1".to_string(),
                amqp_type: AMQPType::ShortUInt,
                value: 128,
            }],
            soft_errors: Vec::default(),
            hard_errors: Vec::default(),
            classes: vec![AMQPClass {
                id: 42,
                methods: vec![AMQPMethod {
                    id: 64,
                    arguments: vec![
                        AMQPArgument::Value(AMQPValueArgument {
                            amqp_type: AMQPType::LongString,
                            name: "argument1".to_string(),
                            default_value: Some(AMQPValue::LongString("value1".into())),
                            domain: Some("domain1".to_string()),
                            force_default: false,
                        }),
                        AMQPArgument::Flags(AMQPFlagsArgument {
                            ignore_flags: false,
                            flags: vec![
                                AMQPFlagArgument {
                                    name: "flag1".to_string(),
                                    default_value: true,
                                    force_default: false,
                                },
                                AMQPFlagArgument {
                                    name: "flag2".to_string(),
                                    default_value: false,
                                    force_default: false,
                                },
                            ],
                        }),
                    ],
                    name: "method1".to_string(),
                    synchronous: true,
                    content: false,
                    metadata: Value::default(),
                    is_reply: false,
                    ignore_args: false,
                    c2s: true,
                    s2c: true,
                }],
                name: "class1".to_string(),
                properties: vec![AMQPProperty {
                    amqp_type: AMQPType::LongString,
                    name: "property1".to_string(),
                }],
                metadata: Value::default(),
            }],
        }
    }

    #[test]
    fn main_template() {
        let mut data = HashMap::new();
        let mut codegen = CodeGenerator::default().register_amqp_helpers();
        data.insert("protocol".to_string(), specs());
        assert!(codegen.register_template_string("main", TEMPLATE).is_ok());
        // The template separates name/value pairs with " = " and ids from names with " - ",
        // so the expected output must contain those spaces too.
        assert_eq!(
            codegen.render("main", &data).unwrap(),
            r#"
AMQP - 0.9.1
Copyright 1
Copyright 2
port 5672
domain1: LongString
constant1 = 128
42 - class1
property1: LongString
64 - method1
synchronous: true
argument1(domain1): LongString
flag1: true
flag2: false
"#
        );
    }
}