1use crate::{specs::*, util::*};
2
3use amq_protocol_types::{AMQPType, AMQPValue};
4use handlebars::{
5 self, BlockContext, BlockParams, Context, Handlebars, Helper, HelperDef, HelperResult,
6 JsonValue, Output, RenderContext, RenderError, RenderErrorReason, Renderable, ScopedJson,
7 to_json,
8};
9use serde_json::{self, Value};
10
11use std::{
12 collections::HashMap,
13 fs::{self, File},
14 io::Write,
15 path::Path,
16};
17
18pub type CodeGenerator<'a> = Handlebars<'a>;
20
21pub trait HandlebarsAMQPExtension {
23 fn register_amqp_helpers(self) -> Self;
25 fn simple_codegen(
28 out_dir: &str,
29 target: &str,
30 template_name: &str,
31 template: &str,
32 var_name: &str,
33 ) {
34 Self::simple_codegen_with_data(out_dir, target, template_name, template, var_name, None);
35 }
36 fn simple_codegen_with_data(
40 out_dir: &str,
41 target: &str,
42 template_name: &str,
43 template: &str,
44 var_name: &str,
45 data: Option<Value>,
46 );
47}
48
49impl<'a> HandlebarsAMQPExtension for CodeGenerator<'a> {
50 fn register_amqp_helpers(mut self) -> CodeGenerator<'a> {
51 self.register_escape_fn(handlebars::no_escape);
52 self.register_helper("camel", Box::new(CamelHelper));
53 self.register_helper("snake", Box::new(SnakeHelper));
54 self.register_helper("snake_type", Box::new(SnakeTypeHelper));
55 self.register_helper("sanitize_name", Box::new(SanitizeNameHelper));
56 self.register_helper("include_more", Box::new(IncludeMoreHelper));
57 self.register_helper("pass_by_ref", Box::new(PassByRefHelper));
58 self.register_helper("use_str_ref", Box::new(UseStrRefHelper));
59 self.register_helper("use_bytes_ref", Box::new(UseBytesRefHelper));
60 self.register_helper("each_argument", Box::new(EachArgumentHelper));
61 self.register_helper("amqp_value_ref", Box::new(AMQPValueRefHelper));
62 self
63 }
64
65 fn simple_codegen_with_data(
66 out_dir: &str,
67 target: &str,
68 template_name: &str,
69 template: &str,
70 var_name: &str,
71 metadata: Option<Value>,
72 ) {
73 let dest_path = Path::new(out_dir).join(format!("{}.rs", target));
74 let mut f = File::create(&dest_path)
75 .unwrap_or_else(|err| panic!("Failed to create {:?}: {}", dest_path, err));
76 let specs = AMQProtocolDefinition::load(metadata);
77 let mut codegen = CodeGenerator::default().register_amqp_helpers();
78 let mut data = HashMap::new();
79
80 codegen.set_strict_mode(true);
81 codegen
82 .register_template_string(template_name, template)
83 .unwrap_or_else(|e| panic!("Failed to register {} template: {}", template_name, e));
84 data.insert(
85 var_name.to_string(),
86 serde_json::to_value(specs)
87 .unwrap_or_else(|e| panic!("Failed to serialize specs: {}", e)),
88 );
89
90 writeln!(
91 f,
92 "{}",
93 codegen
94 .render(template_name, &data)
95 .unwrap_or_else(|err| panic!(
96 "Failed to render {} template: {}",
97 template_name, err
98 ))
99 )
100 .unwrap_or_else(|e| panic!("Failed to generate {}.rs: {}", target, e));
101 }
102}
103
104pub struct CamelHelper;
106impl HelperDef for CamelHelper {
107 fn call<'reg: 'rc, 'rc>(
108 &self,
109 h: &Helper<'rc>,
110 _: &'reg Handlebars<'_>,
111 _: &'rc Context,
112 _: &mut RenderContext<'reg, 'rc>,
113 out: &mut dyn Output,
114 ) -> HelperResult {
115 let value = h
116 .param(0)
117 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("camel", 0))?;
118 let param = value.value().as_str().ok_or_else(|| {
119 RenderErrorReason::ParamTypeMismatchForName(
120 "camel",
121 "string".to_string(),
122 "string".to_string(),
123 )
124 })?;
125 out.write(&camel_case(param))?;
126 Ok(())
127 }
128}
129
130pub struct SnakeHelper;
132impl HelperDef for SnakeHelper {
133 fn call<'reg: 'rc, 'rc>(
134 &self,
135 h: &Helper<'rc>,
136 _: &'reg Handlebars<'_>,
137 _: &'rc Context,
138 _: &mut RenderContext<'reg, 'rc>,
139 out: &mut dyn Output,
140 ) -> HelperResult {
141 let value = h
142 .param(0)
143 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("snake", 0))?;
144 let raw = h
145 .param(1)
146 .and_then(|raw| raw.value().as_bool())
147 .unwrap_or(true);
148 let param = value.value().as_str().ok_or_else(|| {
149 RenderErrorReason::ParamTypeMismatchForName(
150 "snake",
151 "string".to_string(),
152 "string".to_string(),
153 )
154 })?;
155 out.write(&snake_case(param, raw))?;
156 Ok(())
157 }
158}
159
160pub struct SnakeTypeHelper;
162impl HelperDef for SnakeTypeHelper {
163 fn call<'reg: 'rc, 'rc>(
164 &self,
165 h: &Helper<'rc>,
166 _: &'reg Handlebars<'_>,
167 _: &'rc Context,
168 _: &mut RenderContext<'reg, 'rc>,
169 out: &mut dyn Output,
170 ) -> HelperResult {
171 let value = h
172 .param(0)
173 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("snake_type", 0))?;
174 let param: AMQPType = serde_json::from_value(value.value().clone()).map_err(|_| {
175 RenderErrorReason::ParamTypeMismatchForName(
176 "snake_type",
177 "AMQPType".to_string(),
178 "string".to_string(),
179 )
180 })?;
181 out.write(&snake_case(¶m.to_string(), true))?;
182 Ok(())
183 }
184}
185
186pub struct SanitizeNameHelper;
188impl HelperDef for SanitizeNameHelper {
189 fn call<'reg: 'rc, 'rc>(
190 &self,
191 h: &Helper<'rc>,
192 _: &'reg Handlebars<'_>,
193 _: &'rc Context,
194 _: &mut RenderContext<'reg, 'rc>,
195 out: &mut dyn Output,
196 ) -> HelperResult {
197 let value = h
198 .param(0)
199 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("sanitize_name", 0))?;
200 let param = value.value().as_str().ok_or_else(|| {
201 RenderErrorReason::ParamTypeMismatchForName(
202 "sanitize_name",
203 "string".to_string(),
204 "string".to_string(),
205 )
206 })?;
207 out.write(¶m.replace('-', "_"))?;
208 Ok(())
209 }
210}
211
212pub struct IncludeMoreHelper;
214impl HelperDef for IncludeMoreHelper {
215 fn call<'reg: 'rc, 'rc>(
216 &self,
217 h: &Helper<'rc>,
218 _: &'reg Handlebars<'_>,
219 _: &'rc Context,
220 _: &mut RenderContext<'reg, 'rc>,
221 out: &mut dyn Output,
222 ) -> HelperResult {
223 let amqp_class = h
224 .param(0)
225 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("include_more", 0))?;
226 let amqp_method = h
227 .param(1)
228 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("include_more", 1))?;
229 let amqp_class = amqp_class.value().as_str().ok_or_else(|| {
230 RenderErrorReason::ParamTypeMismatchForName(
231 "include_more",
232 "string".to_string(),
233 "class".to_string(),
234 )
235 })?;
236 let amqp_method = amqp_method.value().as_str().ok_or_else(|| {
237 RenderErrorReason::ParamTypeMismatchForName(
238 "include_more",
239 "string".to_string(),
240 "method".to_string(),
241 )
242 })?;
243 if let Ok(cargo_manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
244 let include = Path::new(&cargo_manifest_dir)
245 .join("templates")
246 .join("includes")
247 .join(amqp_class)
248 .join(format!("{}.rs", amqp_method));
249 if let Ok(include) = fs::read_to_string(include) {
250 out.write(&include)?;
251 }
252 }
253 Ok(())
254 }
255}
256
257pub struct PassByRefHelper;
259impl HelperDef for PassByRefHelper {
260 fn call_inner<'reg: 'rc, 'rc>(
261 &self,
262 h: &Helper<'rc>,
263 _: &'reg Handlebars<'_>,
264 _: &'rc Context,
265 _: &mut RenderContext<'reg, 'rc>,
266 ) -> Result<ScopedJson<'rc>, RenderError> {
267 let value = h
268 .param(0)
269 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("pass_by_ref", 0))?;
270 let param: AMQPType = serde_json::from_value(value.value().clone()).map_err(|_| {
271 RenderErrorReason::ParamTypeMismatchForName(
272 "pass_by_ref",
273 "AMQPType".to_string(),
274 "string".to_string(),
275 )
276 })?;
277 let pass_by_ref = matches!(
278 param,
279 AMQPType::ShortString
280 | AMQPType::LongString
281 | AMQPType::FieldArray
282 | AMQPType::FieldTable
283 | AMQPType::ByteArray
284 );
285 Ok(ScopedJson::Derived(JsonValue::from(pass_by_ref)))
286 }
287}
288
289pub struct UseStrRefHelper;
291impl HelperDef for UseStrRefHelper {
292 fn call_inner<'reg: 'rc, 'rc>(
293 &self,
294 h: &Helper<'rc>,
295 _: &'reg Handlebars<'_>,
296 _: &'rc Context,
297 _: &mut RenderContext<'reg, 'rc>,
298 ) -> Result<ScopedJson<'rc>, RenderError> {
299 let value = h
300 .param(0)
301 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("use_str_ref", 0))?;
302 let param = serde_json::from_value::<AMQPType>(value.value().clone()).ok();
303 let use_str_ref = matches!(
304 param,
305 Some(AMQPType::ShortString) | Some(AMQPType::LongString)
306 );
307 Ok(ScopedJson::Derived(JsonValue::from(use_str_ref)))
308 }
309}
310
311pub struct UseBytesRefHelper;
313impl HelperDef for UseBytesRefHelper {
314 fn call_inner<'reg: 'rc, 'rc>(
315 &self,
316 h: &Helper<'rc>,
317 _: &'reg Handlebars<'_>,
318 _: &'rc Context,
319 _: &mut RenderContext<'reg, 'rc>,
320 ) -> Result<ScopedJson<'rc>, RenderError> {
321 let value = h
322 .param(0)
323 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("use_bytes_ref", 0))?;
324 let param = serde_json::from_value::<AMQPType>(value.value().clone()).ok();
325 let use_bytes_ref = matches!(param, Some(AMQPType::LongString));
326 Ok(ScopedJson::Derived(JsonValue::from(use_bytes_ref)))
327 }
328}
329
330pub struct EachArgumentHelper;
332impl HelperDef for EachArgumentHelper {
333 fn call<'reg: 'rc, 'rc>(
334 &self,
335 h: &Helper<'rc>,
336 r: &'reg Handlebars<'_>,
337 ctx: &'rc Context,
338 rc: &mut RenderContext<'reg, 'rc>,
339 out: &mut dyn Output,
340 ) -> HelperResult {
341 let value = h
342 .param(0)
343 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("each_argument", 0))?;
344
345 if let Some(t) = h.template() {
346 let mut block_context = BlockContext::new();
347 if let Some(path) = value.context_path() {
348 *block_context.base_path_mut() = path.to_vec();
349 }
350 rc.push_block(block_context);
351 let arguments: Vec<AMQPArgument> = serde_json::from_value(value.value().clone())
352 .map_err(|_| {
353 RenderErrorReason::ParamTypeMismatchForName(
354 "each_argument",
355 "Vec<AMQPArgument>".to_string(),
356 "arguments".to_string(),
357 )
358 })?;
359 let len = arguments.len();
360 let array_path = value.context_path();
361 for (index, argument) in arguments.iter().enumerate() {
362 if let Some(ref mut block) = rc.block_mut() {
363 let (path, is_value) = match *argument {
364 AMQPArgument::Value(_) => ("Value".to_owned(), true),
365 AMQPArgument::Flags(_) => ("Flags".to_owned(), false),
366 };
367 block.set_local_var("index", to_json(index));
368 block.set_local_var("last", to_json(index == len - 1));
369 block.set_local_var("argument_is_value", to_json(is_value));
370 if let Some(p) = array_path {
371 if index == 0 {
372 let mut path = Vec::with_capacity(p.len() + 1);
373 path.extend_from_slice(p);
374 path.push(index.to_string());
375 *block.base_path_mut() = path;
376 } else if let Some(ptr) = block.base_path_mut().last_mut() {
377 *ptr = index.to_string();
378 }
379 }
380 if let Some(block_param) = h.block_param() {
381 let mut params = BlockParams::new();
382 params.add_path(block_param, vec![path])?;
383 block.set_block_params(params);
384 }
385 }
386 t.render(r, ctx, rc, out)?;
387 }
388 rc.pop_block();
389 }
390 Ok(())
391 }
392}
393
394pub struct AMQPValueRefHelper;
396impl HelperDef for AMQPValueRefHelper {
397 fn call_inner<'reg: 'rc, 'rc>(
398 &self,
399 h: &Helper<'rc>,
400 _: &'reg Handlebars<'_>,
401 _: &'rc Context,
402 _: &mut RenderContext<'reg, 'rc>,
403 ) -> Result<ScopedJson<'rc>, RenderError> {
404 let arg = h
405 .param(0)
406 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("amqp_value", 0))?;
407 let param = serde_json::from_value(arg.value().clone()).map_err(|_| {
408 RenderErrorReason::ParamTypeMismatchForName(
409 "amqp_value",
410 "AMQPValue".to_string(),
411 "value".to_string(),
412 )
413 })?;
414 let value = json_value(param).map_err(RenderErrorReason::SerdeError)?;
415 Ok(ScopedJson::Derived(value))
416 }
417}
418
419fn json_value(val: AMQPValue) -> serde_json::Result<serde_json::Value> {
420 match val {
421 AMQPValue::Boolean(v) => serde_json::to_value(v),
422 AMQPValue::ShortShortInt(v) => serde_json::to_value(v),
423 AMQPValue::ShortShortUInt(v) => serde_json::to_value(v),
424 AMQPValue::ShortInt(v) => serde_json::to_value(v),
425 AMQPValue::ShortUInt(v) => serde_json::to_value(v),
426 AMQPValue::LongInt(v) => serde_json::to_value(v),
427 AMQPValue::LongUInt(v) => serde_json::to_value(v),
428 AMQPValue::LongLongInt(v) => serde_json::to_value(v),
429 AMQPValue::Float(v) => serde_json::to_value(v),
430 AMQPValue::Double(v) => serde_json::to_value(v),
431 AMQPValue::DecimalValue(v) => serde_json::to_value(v),
432 AMQPValue::ShortString(v) => serde_json::to_value(format!("\"{}\"", v)),
433 AMQPValue::LongString(v) => serde_json::to_value(format!("b\"{}\"", v)),
434 AMQPValue::FieldArray(v) => serde_json::to_value(v),
435 AMQPValue::Timestamp(v) => serde_json::to_value(v),
436 AMQPValue::FieldTable(v) => serde_json::to_value(v),
437 AMQPValue::ByteArray(v) => serde_json::to_value(v),
438 AMQPValue::Void => Ok(JsonValue::Null),
439 }
440}
441
#[cfg(test)]
mod test {
    use super::*;

    use std::collections::BTreeMap;

    /// Exercises plain variables, `each` over maps and arrays, and `each_argument` with its
    /// `@argument_is_value` local and `argument` block parameter.
    pub const TEMPLATE: &str = r#"
{{protocol.name}} - {{protocol.major_version}}.{{protocol.minor_version}}.{{protocol.revision}}
{{protocol.copyright}}
port {{protocol.port}}
{{#each protocol.domains ~}}
{{@key}}: {{this}}
{{/each ~}}
{{#each protocol.constants as |constant| ~}}
{{constant.name}} = {{constant.value}}
{{/each ~}}
{{#each protocol.classes as |class| ~}}
{{class.id}} - {{class.name}}
{{#each class.properties as |property| ~}}
{{property.name}}: {{property.type}}
{{/each ~}}
{{#each class.methods as |method| ~}}
{{method.id}} - {{method.name}}
synchronous: {{method.synchronous}}
{{#each_argument method.arguments as |argument| ~}}
{{#if @argument_is_value ~}}
{{argument.name}}({{argument.domain}}): {{argument.type}}
{{else}}
{{#each argument.flags as |flag| ~}}
{{flag.name}}: {{flag.default_value}}
{{/each ~}}
{{/if ~}}
{{/each_argument ~}}
{{/each ~}}
{{/each ~}}
"#;

    /// Minimal protocol definition: one domain, one constant, and one class with one
    /// property and one method taking a value argument and a flags argument.
    fn specs() -> AMQProtocolDefinition {
        let mut domains = BTreeMap::default();
        domains.insert("domain1".to_string(), AMQPType::LongString);
        AMQProtocolDefinition {
            name: "AMQP".to_string(),
            major_version: 0,
            minor_version: 9,
            revision: 1,
            port: 5672,
            copyright: "Copyright 1\nCopyright 2".to_string(),
            domains,
            constants: vec![AMQPConstant {
                name: "constant1".to_string(),
                amqp_type: AMQPType::ShortUInt,
                value: 128,
            }],
            soft_errors: Vec::default(),
            hard_errors: Vec::default(),
            classes: vec![AMQPClass {
                id: 42,
                methods: vec![AMQPMethod {
                    id: 64,
                    arguments: vec![
                        AMQPArgument::Value(AMQPValueArgument {
                            amqp_type: AMQPType::LongString,
                            name: "argument1".to_string(),
                            default_value: Some(AMQPValue::LongString("value1".into())),
                            domain: Some("domain1".to_string()),
                            force_default: false,
                        }),
                        AMQPArgument::Flags(AMQPFlagsArgument {
                            ignore_flags: false,
                            flags: vec![
                                AMQPFlagArgument {
                                    name: "flag1".to_string(),
                                    default_value: true,
                                    force_default: false,
                                },
                                AMQPFlagArgument {
                                    name: "flag2".to_string(),
                                    default_value: false,
                                    force_default: false,
                                },
                            ],
                        }),
                    ],
                    name: "method1".to_string(),
                    synchronous: true,
                    content: false,
                    metadata: Value::default(),
                    is_reply: false,
                    ignore_args: false,
                    c2s: true,
                    s2c: true,
                }],
                name: "class1".to_string(),
                properties: vec![AMQPProperty {
                    amqp_type: AMQPType::LongString,
                    name: "property1".to_string(),
                }],
                metadata: Value::default(),
            }],
        }
    }

    #[test]
    fn main_template() {
        let mut data = HashMap::new();
        let mut codegen = CodeGenerator::default().register_amqp_helpers();
        data.insert("protocol".to_string(), specs());
        assert!(codegen.register_template_string("main", TEMPLATE).is_ok());
        // `{{x}} = {{y}}` and `{{x}} - {{y}}` keep the spaces around the separator.
        assert_eq!(
            codegen.render("main", &data).unwrap(),
            r#"
AMQP - 0.9.1
Copyright 1
Copyright 2
port 5672
domain1: LongString
constant1 = 128
42 - class1
property1: LongString
64 - method1
synchronous: true
argument1(domain1): LongString
flag1: true
flag2: false
"#
        );
    }

    #[test]
    fn type_helpers() {
        let codegen = CodeGenerator::default().register_amqp_helpers();
        let data = serde_json::json!({
            "name": "no-wait",
            "short": "ShortString",
            "long": "LongString",
            "table": "FieldTable",
            "octet": "ShortShortUInt",
        });
        let render = |template: &str| codegen.render_template(template, &data).unwrap();
        assert_eq!(render("{{sanitize_name name}}"), "no_wait");
        assert_eq!(render("{{pass_by_ref table}} {{pass_by_ref octet}}"), "true false");
        assert_eq!(render("{{use_str_ref short}} {{use_str_ref table}}"), "true false");
        assert_eq!(render("{{use_bytes_ref long}} {{use_bytes_ref short}}"), "true false");
    }
}