// amq_protocol_codegen/templating.rs

1use crate::{specs::*, util::*};
2
3use amq_protocol_types::{AMQPType, AMQPValue};
4use handlebars::{
5    self, to_json, BlockContext, BlockParams, Context, Handlebars, Helper, HelperDef, HelperResult,
6    JsonValue, Output, RenderContext, RenderError, RenderErrorReason, Renderable, ScopedJson,
7};
8use serde_json::{self, Value};
9
10use std::{
11    collections::HashMap,
12    fs::{self, File},
13    io::Write,
14    path::Path,
15};
16
17/// Type alias to avoid making our users explicitly depend on an extra dependency
18pub type CodeGenerator<'a> = Handlebars<'a>;
19
20/// Our extension for better integration with Handlebars
21pub trait HandlebarsAMQPExtension {
22    /// Register the various standard helpers we'll need for AMQP codegen
23    fn register_amqp_helpers(self) -> Self;
24    /// Generate code using the standard representation of specs and the given template, using the
25    /// given name for the variable holding the [protocol definition](../specs.AMQProtocolDefinition.html).
26    fn simple_codegen(
27        out_dir: &str,
28        target: &str,
29        template_name: &str,
30        template: &str,
31        var_name: &str,
32    ) {
33        Self::simple_codegen_with_data(out_dir, target, template_name, template, var_name, None);
34    }
35    /// Generate code using the standard representation of specs and the given template, using the
36    /// given name for the variable holding the [protocol definition](../specs.AMQProtocolDefinition.html),
37    /// and also passing data to the template.
38    fn simple_codegen_with_data(
39        out_dir: &str,
40        target: &str,
41        template_name: &str,
42        template: &str,
43        var_name: &str,
44        data: Option<Value>,
45    );
46}
47
48impl<'a> HandlebarsAMQPExtension for CodeGenerator<'a> {
49    fn register_amqp_helpers(mut self) -> CodeGenerator<'a> {
50        self.register_escape_fn(handlebars::no_escape);
51        self.register_helper("camel", Box::new(CamelHelper));
52        self.register_helper("snake", Box::new(SnakeHelper));
53        self.register_helper("snake_type", Box::new(SnakeTypeHelper));
54        self.register_helper("sanitize_name", Box::new(SanitizeNameHelper));
55        self.register_helper("include_more", Box::new(IncludeMoreHelper));
56        self.register_helper("pass_by_ref", Box::new(PassByRefHelper));
57        self.register_helper("use_str_ref", Box::new(UseStrRefHelper));
58        self.register_helper("use_bytes_ref", Box::new(UseBytesRefHelper));
59        self.register_helper("each_argument", Box::new(EachArgumentHelper));
60        self.register_helper("amqp_value_ref", Box::new(AMQPValueRefHelper));
61        self
62    }
63
64    fn simple_codegen_with_data(
65        out_dir: &str,
66        target: &str,
67        template_name: &str,
68        template: &str,
69        var_name: &str,
70        metadata: Option<Value>,
71    ) {
72        let dest_path = Path::new(out_dir).join(format!("{}.rs", target));
73        let mut f = File::create(&dest_path)
74            .unwrap_or_else(|err| panic!("Failed to create {:?}: {}", dest_path, err));
75        let specs = AMQProtocolDefinition::load(metadata);
76        let mut codegen = CodeGenerator::default().register_amqp_helpers();
77        let mut data = HashMap::new();
78
79        codegen.set_strict_mode(true);
80        codegen
81            .register_template_string(template_name, template)
82            .unwrap_or_else(|e| panic!("Failed to register {} template: {}", template_name, e));
83        data.insert(
84            var_name.to_string(),
85            serde_json::to_value(specs)
86                .unwrap_or_else(|e| panic!("Failed to serialize specs: {}", e)),
87        );
88
89        writeln!(
90            f,
91            "{}",
92            codegen
93                .render(template_name, &data)
94                .unwrap_or_else(|err| panic!(
95                    "Failed to render {} template: {}",
96                    template_name, err
97                ))
98        )
99        .unwrap_or_else(|e| panic!("Failed to generate {}.rs: {}", target, e));
100    }
101}
102
103/// Helper for converting text to camel case
104pub struct CamelHelper;
105impl HelperDef for CamelHelper {
106    fn call<'reg: 'rc, 'rc>(
107        &self,
108        h: &Helper<'rc>,
109        _: &'reg Handlebars<'_>,
110        _: &'rc Context,
111        _: &mut RenderContext<'reg, 'rc>,
112        out: &mut dyn Output,
113    ) -> HelperResult {
114        let value = h
115            .param(0)
116            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("camel", 0))?;
117        let param = value.value().as_str().ok_or_else(|| {
118            RenderErrorReason::ParamTypeMismatchForName(
119                "camel",
120                "string".to_string(),
121                "string".to_string(),
122            )
123        })?;
124        out.write(&camel_case(param))?;
125        Ok(())
126    }
127}
128
129/// Helper for converting text to snake case
130pub struct SnakeHelper;
131impl HelperDef for SnakeHelper {
132    fn call<'reg: 'rc, 'rc>(
133        &self,
134        h: &Helper<'rc>,
135        _: &'reg Handlebars<'_>,
136        _: &'rc Context,
137        _: &mut RenderContext<'reg, 'rc>,
138        out: &mut dyn Output,
139    ) -> HelperResult {
140        let value = h
141            .param(0)
142            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("snake", 0))?;
143        let raw = h
144            .param(1)
145            .and_then(|raw| raw.value().as_bool())
146            .unwrap_or(true);
147        let param = value.value().as_str().ok_or_else(|| {
148            RenderErrorReason::ParamTypeMismatchForName(
149                "snake",
150                "string".to_string(),
151                "string".to_string(),
152            )
153        })?;
154        out.write(&snake_case(param, raw))?;
155        Ok(())
156    }
157}
158
159/// Helper for getting the type name converted to snake case
160pub struct SnakeTypeHelper;
161impl HelperDef for SnakeTypeHelper {
162    fn call<'reg: 'rc, 'rc>(
163        &self,
164        h: &Helper<'rc>,
165        _: &'reg Handlebars<'_>,
166        _: &'rc Context,
167        _: &mut RenderContext<'reg, 'rc>,
168        out: &mut dyn Output,
169    ) -> HelperResult {
170        let value = h
171            .param(0)
172            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("snake_type", 0))?;
173        let param: AMQPType = serde_json::from_value(value.value().clone()).map_err(|_| {
174            RenderErrorReason::ParamTypeMismatchForName(
175                "snake_type",
176                "AMQPType".to_string(),
177                "string".to_string(),
178            )
179        })?;
180        out.write(&snake_case(&param.to_string(), true))?;
181        Ok(())
182    }
183}
184
185/// Helper to sanitize name so the it becomes a valid identifier
186pub struct SanitizeNameHelper;
187impl HelperDef for SanitizeNameHelper {
188    fn call<'reg: 'rc, 'rc>(
189        &self,
190        h: &Helper<'rc>,
191        _: &'reg Handlebars<'_>,
192        _: &'rc Context,
193        _: &mut RenderContext<'reg, 'rc>,
194        out: &mut dyn Output,
195    ) -> HelperResult {
196        let value = h
197            .param(0)
198            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("sanitize_name", 0))?;
199        let param = value.value().as_str().ok_or_else(|| {
200            RenderErrorReason::ParamTypeMismatchForName(
201                "sanitize_name",
202                "string".to_string(),
203                "string".to_string(),
204            )
205        })?;
206        out.write(&param.replace('-', "_"))?;
207        Ok(())
208    }
209}
210
211/// Helper to include additional code such as rustdoc
212pub struct IncludeMoreHelper;
213impl HelperDef for IncludeMoreHelper {
214    fn call<'reg: 'rc, 'rc>(
215        &self,
216        h: &Helper<'rc>,
217        _: &'reg Handlebars<'_>,
218        _: &'rc Context,
219        _: &mut RenderContext<'reg, 'rc>,
220        out: &mut dyn Output,
221    ) -> HelperResult {
222        let amqp_class = h
223            .param(0)
224            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("include_more", 0))?;
225        let amqp_method = h
226            .param(1)
227            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("include_more", 1))?;
228        let amqp_class = amqp_class.value().as_str().ok_or_else(|| {
229            RenderErrorReason::ParamTypeMismatchForName(
230                "include_more",
231                "string".to_string(),
232                "class".to_string(),
233            )
234        })?;
235        let amqp_method = amqp_method.value().as_str().ok_or_else(|| {
236            RenderErrorReason::ParamTypeMismatchForName(
237                "include_more",
238                "string".to_string(),
239                "method".to_string(),
240            )
241        })?;
242        if let Ok(cargo_manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
243            let include = Path::new(&cargo_manifest_dir)
244                .join("templates")
245                .join("includes")
246                .join(amqp_class)
247                .join(format!("{}.rs", amqp_method));
248            if let Ok(include) = fs::read_to_string(include) {
249                out.write(&include)?;
250            }
251        }
252        Ok(())
253    }
254}
255
256/// Helper to check whether a param is passed by ref or not
257pub struct PassByRefHelper;
258impl HelperDef for PassByRefHelper {
259    fn call_inner<'reg: 'rc, 'rc>(
260        &self,
261        h: &Helper<'rc>,
262        _: &'reg Handlebars<'_>,
263        _: &'rc Context,
264        _: &mut RenderContext<'reg, 'rc>,
265    ) -> Result<ScopedJson<'rc>, RenderError> {
266        let value = h
267            .param(0)
268            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("pass_by_ref", 0))?;
269        let param: AMQPType = serde_json::from_value(value.value().clone()).map_err(|_| {
270            RenderErrorReason::ParamTypeMismatchForName(
271                "pass_by_ref",
272                "AMQPType".to_string(),
273                "string".to_string(),
274            )
275        })?;
276        let pass_by_ref = matches!(
277            param,
278            AMQPType::ShortString
279                | AMQPType::LongString
280                | AMQPType::FieldArray
281                | AMQPType::FieldTable
282                | AMQPType::ByteArray
283        );
284        Ok(ScopedJson::Derived(JsonValue::from(pass_by_ref)))
285    }
286}
287
288/// Helper to check whether a param is passed using an &str or its real type
289pub struct UseStrRefHelper;
290impl HelperDef for UseStrRefHelper {
291    fn call_inner<'reg: 'rc, 'rc>(
292        &self,
293        h: &Helper<'rc>,
294        _: &'reg Handlebars<'_>,
295        _: &'rc Context,
296        _: &mut RenderContext<'reg, 'rc>,
297    ) -> Result<ScopedJson<'rc>, RenderError> {
298        let value = h
299            .param(0)
300            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("use_str_ref", 0))?;
301        let param = serde_json::from_value::<AMQPType>(value.value().clone()).ok();
302        let use_str_ref = matches!(
303            param,
304            Some(AMQPType::ShortString) | Some(AMQPType::LongString)
305        );
306        Ok(ScopedJson::Derived(JsonValue::from(use_str_ref)))
307    }
308}
309
310/// Helper to check whether a param is passed using an &[u8] or its real type
311pub struct UseBytesRefHelper;
312impl HelperDef for UseBytesRefHelper {
313    fn call_inner<'reg: 'rc, 'rc>(
314        &self,
315        h: &Helper<'rc>,
316        _: &'reg Handlebars<'_>,
317        _: &'rc Context,
318        _: &mut RenderContext<'reg, 'rc>,
319    ) -> Result<ScopedJson<'rc>, RenderError> {
320        let value = h
321            .param(0)
322            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("use_bytes_ref", 0))?;
323        let param = serde_json::from_value::<AMQPType>(value.value().clone()).ok();
324        let use_bytes_ref = matches!(param, Some(AMQPType::LongString));
325        Ok(ScopedJson::Derived(JsonValue::from(use_bytes_ref)))
326    }
327}
328
329/// Helper to walk through a Vec of [AMQPArgument](../specs.AMQPArgument.html).
330pub struct EachArgumentHelper;
331impl HelperDef for EachArgumentHelper {
332    fn call<'reg: 'rc, 'rc>(
333        &self,
334        h: &Helper<'rc>,
335        r: &'reg Handlebars<'_>,
336        ctx: &'rc Context,
337        rc: &mut RenderContext<'reg, 'rc>,
338        out: &mut dyn Output,
339    ) -> HelperResult {
340        let value = h
341            .param(0)
342            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("each_argument", 0))?;
343
344        if let Some(t) = h.template() {
345            let mut block_context = BlockContext::new();
346            if let Some(path) = value.context_path() {
347                *block_context.base_path_mut() = path.to_vec();
348            }
349            rc.push_block(block_context);
350            let arguments: Vec<AMQPArgument> = serde_json::from_value(value.value().clone())
351                .map_err(|_| {
352                    RenderErrorReason::ParamTypeMismatchForName(
353                        "each_argument",
354                        "Vec<AMQPArgument>".to_string(),
355                        "arguments".to_string(),
356                    )
357                })?;
358            let len = arguments.len();
359            let array_path = value.context_path();
360            for (index, argument) in arguments.iter().enumerate() {
361                if let Some(ref mut block) = rc.block_mut() {
362                    let (path, is_value) = match *argument {
363                        AMQPArgument::Value(_) => ("Value".to_owned(), true),
364                        AMQPArgument::Flags(_) => ("Flags".to_owned(), false),
365                    };
366                    block.set_local_var("index", to_json(index));
367                    block.set_local_var("last", to_json(index == len - 1));
368                    block.set_local_var("argument_is_value", to_json(is_value));
369                    if let Some(p) = array_path {
370                        if index == 0 {
371                            let mut path = Vec::with_capacity(p.len() + 1);
372                            path.extend_from_slice(p);
373                            path.push(index.to_string());
374                            *block.base_path_mut() = path;
375                        } else if let Some(ptr) = block.base_path_mut().last_mut() {
376                            *ptr = index.to_string();
377                        }
378                    }
379                    if let Some(block_param) = h.block_param() {
380                        let mut params = BlockParams::new();
381                        params.add_path(block_param, vec![path])?;
382                        block.set_block_params(params);
383                    }
384                }
385                t.render(r, ctx, rc, out)?;
386            }
387            rc.pop_block();
388        }
389        Ok(())
390    }
391}
392
393/// Helper for "unwrapping" an amqp_value
394pub struct AMQPValueRefHelper;
395impl HelperDef for AMQPValueRefHelper {
396    fn call_inner<'reg: 'rc, 'rc>(
397        &self,
398        h: &Helper<'rc>,
399        _: &'reg Handlebars<'_>,
400        _: &'rc Context,
401        _: &mut RenderContext<'reg, 'rc>,
402    ) -> Result<ScopedJson<'rc>, RenderError> {
403        let arg = h
404            .param(0)
405            .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("amqp_value", 0))?;
406        let param = serde_json::from_value(arg.value().clone()).map_err(|_| {
407            RenderErrorReason::ParamTypeMismatchForName(
408                "amqp_value",
409                "AMQPValue".to_string(),
410                "value".to_string(),
411            )
412        })?;
413        let value = json_value(param).map_err(RenderErrorReason::SerdeError)?;
414        Ok(ScopedJson::Derived(value))
415    }
416}
417
418fn json_value(val: AMQPValue) -> serde_json::Result<serde_json::Value> {
419    match val {
420        AMQPValue::Boolean(v) => serde_json::to_value(v),
421        AMQPValue::ShortShortInt(v) => serde_json::to_value(v),
422        AMQPValue::ShortShortUInt(v) => serde_json::to_value(v),
423        AMQPValue::ShortInt(v) => serde_json::to_value(v),
424        AMQPValue::ShortUInt(v) => serde_json::to_value(v),
425        AMQPValue::LongInt(v) => serde_json::to_value(v),
426        AMQPValue::LongUInt(v) => serde_json::to_value(v),
427        AMQPValue::LongLongInt(v) => serde_json::to_value(v),
428        AMQPValue::Float(v) => serde_json::to_value(v),
429        AMQPValue::Double(v) => serde_json::to_value(v),
430        AMQPValue::DecimalValue(v) => serde_json::to_value(v),
431        AMQPValue::ShortString(v) => serde_json::to_value(format!("\"{}\"", v)),
432        AMQPValue::LongString(v) => serde_json::to_value(format!("b\"{}\"", v)),
433        AMQPValue::FieldArray(v) => serde_json::to_value(v),
434        AMQPValue::Timestamp(v) => serde_json::to_value(v),
435        AMQPValue::FieldTable(v) => serde_json::to_value(v),
436        AMQPValue::ByteArray(v) => serde_json::to_value(v),
437        AMQPValue::Void => Ok(JsonValue::Null),
438    }
439}
440
#[cfg(test)]
mod test {
    use super::*;

    use std::collections::BTreeMap;

    /// Template exercising the spec fields plus the custom `each_argument` helper.
    pub const TEMPLATE: &str = r#"
{{protocol.name}} - {{protocol.major_version}}.{{protocol.minor_version}}.{{protocol.revision}}
{{protocol.copyright}}
port {{protocol.port}}
{{#each protocol.domains ~}}
{{@key}}: {{this}}
{{/each ~}}
{{#each protocol.constants as |constant| ~}}
{{constant.name}} = {{constant.value}}
{{/each ~}}
{{#each protocol.classes as |class| ~}}
{{class.id}} - {{class.name}}
{{#each class.properties as |property| ~}}
{{property.name}}: {{property.type}}
{{/each ~}}
{{#each class.methods as |method| ~}}
{{method.id}} - {{method.name}}
synchronous: {{method.synchronous}}
{{#each_argument method.arguments as |argument| ~}}
{{#if @argument_is_value ~}}
{{argument.name}}({{argument.domain}}): {{argument.type}}
{{else}}
{{#each argument.flags as |flag| ~}}
{{flag.name}}: {{flag.default_value}}
{{/each ~}}
{{/if ~}}
{{/each_argument ~}}
{{/each ~}}
{{/each ~}}
"#;

    /// A minimal protocol definition: one domain, one constant, and one class with one
    /// property and one method taking a value argument and a flags argument.
    fn specs() -> AMQProtocolDefinition {
        let mut domains = BTreeMap::default();
        domains.insert("domain1".to_string(), AMQPType::LongString);
        AMQProtocolDefinition {
            name: "AMQP".to_string(),
            major_version: 0,
            minor_version: 9,
            revision: 1,
            port: 5672,
            copyright: "Copyright 1\nCopyright 2".to_string(),
            domains,
            constants: vec![AMQPConstant {
                name: "constant1".to_string(),
                amqp_type: AMQPType::ShortUInt,
                value: 128,
            }],
            soft_errors: Vec::default(),
            hard_errors: Vec::default(),
            classes: vec![AMQPClass {
                id: 42,
                methods: vec![AMQPMethod {
                    id: 64,
                    arguments: vec![
                        AMQPArgument::Value(AMQPValueArgument {
                            amqp_type: AMQPType::LongString,
                            name: "argument1".to_string(),
                            default_value: Some(AMQPValue::LongString("value1".into())),
                            domain: Some("domain1".to_string()),
                            force_default: false,
                        }),
                        AMQPArgument::Flags(AMQPFlagsArgument {
                            ignore_flags: false,
                            flags: vec![
                                AMQPFlagArgument {
                                    name: "flag1".to_string(),
                                    default_value: true,
                                    force_default: false,
                                },
                                AMQPFlagArgument {
                                    name: "flag2".to_string(),
                                    default_value: false,
                                    force_default: false,
                                },
                            ],
                        }),
                    ],
                    name: "method1".to_string(),
                    synchronous: true,
                    content: false,
                    metadata: Value::default(),
                    is_reply: false,
                    ignore_args: false,
                    c2s: true,
                    s2c: true,
                }],
                name: "class1".to_string(),
                properties: vec![AMQPProperty {
                    amqp_type: AMQPType::LongString,
                    name: "property1".to_string(),
                }],
                metadata: Value::default(),
            }],
        }
    }

    #[test]
    fn main_template() {
        let mut data = HashMap::new();
        let mut codegen = CodeGenerator::default().register_amqp_helpers();
        data.insert("protocol".to_string(), specs());
        // `expect` (rather than `assert!(is_ok())`) shows the template error on failure.
        codegen
            .register_template_string("main", TEMPLATE)
            .expect("test template should parse");
        // The literal " = " and " - " separators from TEMPLATE appear verbatim in the output.
        assert_eq!(
            codegen.render("main", &data).unwrap(),
            r#"
AMQP - 0.9.1
Copyright 1
Copyright 2
port 5672
domain1: LongString
constant1 = 128
42 - class1
property1: LongString
64 - method1
synchronous: true
argument1(domain1): LongString
flag1: true
flag2: false
"#
        );
    }
}