1use crate::{specs::*, util::*};
2
3use amq_protocol_types::{AMQPType, AMQPValue};
4use handlebars::{
5 self, to_json, BlockContext, BlockParams, Context, Handlebars, Helper, HelperDef, HelperResult,
6 JsonValue, Output, RenderContext, RenderError, RenderErrorReason, Renderable, ScopedJson,
7};
8use serde_json::{self, Value};
9
10use std::{
11 collections::HashMap,
12 fs::{self, File},
13 io::Write,
14 path::Path,
15};
16
17pub type CodeGenerator<'a> = Handlebars<'a>;
19
20pub trait HandlebarsAMQPExtension {
22 fn register_amqp_helpers(self) -> Self;
24 fn simple_codegen(
27 out_dir: &str,
28 target: &str,
29 template_name: &str,
30 template: &str,
31 var_name: &str,
32 ) {
33 Self::simple_codegen_with_data(out_dir, target, template_name, template, var_name, None);
34 }
35 fn simple_codegen_with_data(
39 out_dir: &str,
40 target: &str,
41 template_name: &str,
42 template: &str,
43 var_name: &str,
44 data: Option<Value>,
45 );
46}
47
48impl<'a> HandlebarsAMQPExtension for CodeGenerator<'a> {
49 fn register_amqp_helpers(mut self) -> CodeGenerator<'a> {
50 self.register_escape_fn(handlebars::no_escape);
51 self.register_helper("camel", Box::new(CamelHelper));
52 self.register_helper("snake", Box::new(SnakeHelper));
53 self.register_helper("snake_type", Box::new(SnakeTypeHelper));
54 self.register_helper("sanitize_name", Box::new(SanitizeNameHelper));
55 self.register_helper("include_more", Box::new(IncludeMoreHelper));
56 self.register_helper("pass_by_ref", Box::new(PassByRefHelper));
57 self.register_helper("use_str_ref", Box::new(UseStrRefHelper));
58 self.register_helper("use_bytes_ref", Box::new(UseBytesRefHelper));
59 self.register_helper("each_argument", Box::new(EachArgumentHelper));
60 self.register_helper("amqp_value_ref", Box::new(AMQPValueRefHelper));
61 self
62 }
63
64 fn simple_codegen_with_data(
65 out_dir: &str,
66 target: &str,
67 template_name: &str,
68 template: &str,
69 var_name: &str,
70 metadata: Option<Value>,
71 ) {
72 let dest_path = Path::new(out_dir).join(format!("{}.rs", target));
73 let mut f = File::create(&dest_path)
74 .unwrap_or_else(|err| panic!("Failed to create {:?}: {}", dest_path, err));
75 let specs = AMQProtocolDefinition::load(metadata);
76 let mut codegen = CodeGenerator::default().register_amqp_helpers();
77 let mut data = HashMap::new();
78
79 codegen.set_strict_mode(true);
80 codegen
81 .register_template_string(template_name, template)
82 .unwrap_or_else(|e| panic!("Failed to register {} template: {}", template_name, e));
83 data.insert(
84 var_name.to_string(),
85 serde_json::to_value(specs)
86 .unwrap_or_else(|e| panic!("Failed to serialize specs: {}", e)),
87 );
88
89 writeln!(
90 f,
91 "{}",
92 codegen
93 .render(template_name, &data)
94 .unwrap_or_else(|err| panic!(
95 "Failed to render {} template: {}",
96 template_name, err
97 ))
98 )
99 .unwrap_or_else(|e| panic!("Failed to generate {}.rs: {}", target, e));
100 }
101}
102
103pub struct CamelHelper;
105impl HelperDef for CamelHelper {
106 fn call<'reg: 'rc, 'rc>(
107 &self,
108 h: &Helper<'rc>,
109 _: &'reg Handlebars<'_>,
110 _: &'rc Context,
111 _: &mut RenderContext<'reg, 'rc>,
112 out: &mut dyn Output,
113 ) -> HelperResult {
114 let value = h
115 .param(0)
116 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("camel", 0))?;
117 let param = value.value().as_str().ok_or_else(|| {
118 RenderErrorReason::ParamTypeMismatchForName(
119 "camel",
120 "string".to_string(),
121 "string".to_string(),
122 )
123 })?;
124 out.write(&camel_case(param))?;
125 Ok(())
126 }
127}
128
129pub struct SnakeHelper;
131impl HelperDef for SnakeHelper {
132 fn call<'reg: 'rc, 'rc>(
133 &self,
134 h: &Helper<'rc>,
135 _: &'reg Handlebars<'_>,
136 _: &'rc Context,
137 _: &mut RenderContext<'reg, 'rc>,
138 out: &mut dyn Output,
139 ) -> HelperResult {
140 let value = h
141 .param(0)
142 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("snake", 0))?;
143 let raw = h
144 .param(1)
145 .and_then(|raw| raw.value().as_bool())
146 .unwrap_or(true);
147 let param = value.value().as_str().ok_or_else(|| {
148 RenderErrorReason::ParamTypeMismatchForName(
149 "snake",
150 "string".to_string(),
151 "string".to_string(),
152 )
153 })?;
154 out.write(&snake_case(param, raw))?;
155 Ok(())
156 }
157}
158
159pub struct SnakeTypeHelper;
161impl HelperDef for SnakeTypeHelper {
162 fn call<'reg: 'rc, 'rc>(
163 &self,
164 h: &Helper<'rc>,
165 _: &'reg Handlebars<'_>,
166 _: &'rc Context,
167 _: &mut RenderContext<'reg, 'rc>,
168 out: &mut dyn Output,
169 ) -> HelperResult {
170 let value = h
171 .param(0)
172 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("snake_type", 0))?;
173 let param: AMQPType = serde_json::from_value(value.value().clone()).map_err(|_| {
174 RenderErrorReason::ParamTypeMismatchForName(
175 "snake_type",
176 "AMQPType".to_string(),
177 "string".to_string(),
178 )
179 })?;
180 out.write(&snake_case(¶m.to_string(), true))?;
181 Ok(())
182 }
183}
184
185pub struct SanitizeNameHelper;
187impl HelperDef for SanitizeNameHelper {
188 fn call<'reg: 'rc, 'rc>(
189 &self,
190 h: &Helper<'rc>,
191 _: &'reg Handlebars<'_>,
192 _: &'rc Context,
193 _: &mut RenderContext<'reg, 'rc>,
194 out: &mut dyn Output,
195 ) -> HelperResult {
196 let value = h
197 .param(0)
198 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("sanitize_name", 0))?;
199 let param = value.value().as_str().ok_or_else(|| {
200 RenderErrorReason::ParamTypeMismatchForName(
201 "sanitize_name",
202 "string".to_string(),
203 "string".to_string(),
204 )
205 })?;
206 out.write(¶m.replace('-', "_"))?;
207 Ok(())
208 }
209}
210
211pub struct IncludeMoreHelper;
213impl HelperDef for IncludeMoreHelper {
214 fn call<'reg: 'rc, 'rc>(
215 &self,
216 h: &Helper<'rc>,
217 _: &'reg Handlebars<'_>,
218 _: &'rc Context,
219 _: &mut RenderContext<'reg, 'rc>,
220 out: &mut dyn Output,
221 ) -> HelperResult {
222 let amqp_class = h
223 .param(0)
224 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("include_more", 0))?;
225 let amqp_method = h
226 .param(1)
227 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("include_more", 1))?;
228 let amqp_class = amqp_class.value().as_str().ok_or_else(|| {
229 RenderErrorReason::ParamTypeMismatchForName(
230 "include_more",
231 "string".to_string(),
232 "class".to_string(),
233 )
234 })?;
235 let amqp_method = amqp_method.value().as_str().ok_or_else(|| {
236 RenderErrorReason::ParamTypeMismatchForName(
237 "include_more",
238 "string".to_string(),
239 "method".to_string(),
240 )
241 })?;
242 if let Ok(cargo_manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
243 let include = Path::new(&cargo_manifest_dir)
244 .join("templates")
245 .join("includes")
246 .join(amqp_class)
247 .join(format!("{}.rs", amqp_method));
248 if let Ok(include) = fs::read_to_string(include) {
249 out.write(&include)?;
250 }
251 }
252 Ok(())
253 }
254}
255
256pub struct PassByRefHelper;
258impl HelperDef for PassByRefHelper {
259 fn call_inner<'reg: 'rc, 'rc>(
260 &self,
261 h: &Helper<'rc>,
262 _: &'reg Handlebars<'_>,
263 _: &'rc Context,
264 _: &mut RenderContext<'reg, 'rc>,
265 ) -> Result<ScopedJson<'rc>, RenderError> {
266 let value = h
267 .param(0)
268 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("pass_by_ref", 0))?;
269 let param: AMQPType = serde_json::from_value(value.value().clone()).map_err(|_| {
270 RenderErrorReason::ParamTypeMismatchForName(
271 "pass_by_ref",
272 "AMQPType".to_string(),
273 "string".to_string(),
274 )
275 })?;
276 let pass_by_ref = matches!(
277 param,
278 AMQPType::ShortString
279 | AMQPType::LongString
280 | AMQPType::FieldArray
281 | AMQPType::FieldTable
282 | AMQPType::ByteArray
283 );
284 Ok(ScopedJson::Derived(JsonValue::from(pass_by_ref)))
285 }
286}
287
288pub struct UseStrRefHelper;
290impl HelperDef for UseStrRefHelper {
291 fn call_inner<'reg: 'rc, 'rc>(
292 &self,
293 h: &Helper<'rc>,
294 _: &'reg Handlebars<'_>,
295 _: &'rc Context,
296 _: &mut RenderContext<'reg, 'rc>,
297 ) -> Result<ScopedJson<'rc>, RenderError> {
298 let value = h
299 .param(0)
300 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("use_str_ref", 0))?;
301 let param = serde_json::from_value::<AMQPType>(value.value().clone()).ok();
302 let use_str_ref = matches!(
303 param,
304 Some(AMQPType::ShortString) | Some(AMQPType::LongString)
305 );
306 Ok(ScopedJson::Derived(JsonValue::from(use_str_ref)))
307 }
308}
309
310pub struct UseBytesRefHelper;
312impl HelperDef for UseBytesRefHelper {
313 fn call_inner<'reg: 'rc, 'rc>(
314 &self,
315 h: &Helper<'rc>,
316 _: &'reg Handlebars<'_>,
317 _: &'rc Context,
318 _: &mut RenderContext<'reg, 'rc>,
319 ) -> Result<ScopedJson<'rc>, RenderError> {
320 let value = h
321 .param(0)
322 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("use_bytes_ref", 0))?;
323 let param = serde_json::from_value::<AMQPType>(value.value().clone()).ok();
324 let use_bytes_ref = matches!(param, Some(AMQPType::LongString));
325 Ok(ScopedJson::Derived(JsonValue::from(use_bytes_ref)))
326 }
327}
328
329pub struct EachArgumentHelper;
331impl HelperDef for EachArgumentHelper {
332 fn call<'reg: 'rc, 'rc>(
333 &self,
334 h: &Helper<'rc>,
335 r: &'reg Handlebars<'_>,
336 ctx: &'rc Context,
337 rc: &mut RenderContext<'reg, 'rc>,
338 out: &mut dyn Output,
339 ) -> HelperResult {
340 let value = h
341 .param(0)
342 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("each_argument", 0))?;
343
344 if let Some(t) = h.template() {
345 let mut block_context = BlockContext::new();
346 if let Some(path) = value.context_path() {
347 *block_context.base_path_mut() = path.to_vec();
348 }
349 rc.push_block(block_context);
350 let arguments: Vec<AMQPArgument> = serde_json::from_value(value.value().clone())
351 .map_err(|_| {
352 RenderErrorReason::ParamTypeMismatchForName(
353 "each_argument",
354 "Vec<AMQPArgument>".to_string(),
355 "arguments".to_string(),
356 )
357 })?;
358 let len = arguments.len();
359 let array_path = value.context_path();
360 for (index, argument) in arguments.iter().enumerate() {
361 if let Some(ref mut block) = rc.block_mut() {
362 let (path, is_value) = match *argument {
363 AMQPArgument::Value(_) => ("Value".to_owned(), true),
364 AMQPArgument::Flags(_) => ("Flags".to_owned(), false),
365 };
366 block.set_local_var("index", to_json(index));
367 block.set_local_var("last", to_json(index == len - 1));
368 block.set_local_var("argument_is_value", to_json(is_value));
369 if let Some(p) = array_path {
370 if index == 0 {
371 let mut path = Vec::with_capacity(p.len() + 1);
372 path.extend_from_slice(p);
373 path.push(index.to_string());
374 *block.base_path_mut() = path;
375 } else if let Some(ptr) = block.base_path_mut().last_mut() {
376 *ptr = index.to_string();
377 }
378 }
379 if let Some(block_param) = h.block_param() {
380 let mut params = BlockParams::new();
381 params.add_path(block_param, vec![path])?;
382 block.set_block_params(params);
383 }
384 }
385 t.render(r, ctx, rc, out)?;
386 }
387 rc.pop_block();
388 }
389 Ok(())
390 }
391}
392
393pub struct AMQPValueRefHelper;
395impl HelperDef for AMQPValueRefHelper {
396 fn call_inner<'reg: 'rc, 'rc>(
397 &self,
398 h: &Helper<'rc>,
399 _: &'reg Handlebars<'_>,
400 _: &'rc Context,
401 _: &mut RenderContext<'reg, 'rc>,
402 ) -> Result<ScopedJson<'rc>, RenderError> {
403 let arg = h
404 .param(0)
405 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("amqp_value", 0))?;
406 let param = serde_json::from_value(arg.value().clone()).map_err(|_| {
407 RenderErrorReason::ParamTypeMismatchForName(
408 "amqp_value",
409 "AMQPValue".to_string(),
410 "value".to_string(),
411 )
412 })?;
413 let value = json_value(param).map_err(RenderErrorReason::SerdeError)?;
414 Ok(ScopedJson::Derived(value))
415 }
416}
417
418fn json_value(val: AMQPValue) -> serde_json::Result<serde_json::Value> {
419 match val {
420 AMQPValue::Boolean(v) => serde_json::to_value(v),
421 AMQPValue::ShortShortInt(v) => serde_json::to_value(v),
422 AMQPValue::ShortShortUInt(v) => serde_json::to_value(v),
423 AMQPValue::ShortInt(v) => serde_json::to_value(v),
424 AMQPValue::ShortUInt(v) => serde_json::to_value(v),
425 AMQPValue::LongInt(v) => serde_json::to_value(v),
426 AMQPValue::LongUInt(v) => serde_json::to_value(v),
427 AMQPValue::LongLongInt(v) => serde_json::to_value(v),
428 AMQPValue::Float(v) => serde_json::to_value(v),
429 AMQPValue::Double(v) => serde_json::to_value(v),
430 AMQPValue::DecimalValue(v) => serde_json::to_value(v),
431 AMQPValue::ShortString(v) => serde_json::to_value(format!("\"{}\"", v)),
432 AMQPValue::LongString(v) => serde_json::to_value(format!("b\"{}\"", v)),
433 AMQPValue::FieldArray(v) => serde_json::to_value(v),
434 AMQPValue::Timestamp(v) => serde_json::to_value(v),
435 AMQPValue::FieldTable(v) => serde_json::to_value(v),
436 AMQPValue::ByteArray(v) => serde_json::to_value(v),
437 AMQPValue::Void => Ok(JsonValue::Null),
438 }
439}
440
#[cfg(test)]
mod test {
    use super::*;

    use std::collections::BTreeMap;

    /// Template exercising the spec fields and the `each_argument` block helper.
    pub const TEMPLATE: &str = r#"
{{protocol.name}} - {{protocol.major_version}}.{{protocol.minor_version}}.{{protocol.revision}}
{{protocol.copyright}}
port {{protocol.port}}
{{#each protocol.domains ~}}
{{@key}}: {{this}}
{{/each ~}}
{{#each protocol.constants as |constant| ~}}
{{constant.name}} = {{constant.value}}
{{/each ~}}
{{#each protocol.classes as |class| ~}}
{{class.id}} - {{class.name}}
{{#each class.properties as |property| ~}}
{{property.name}}: {{property.type}}
{{/each ~}}
{{#each class.methods as |method| ~}}
{{method.id}} - {{method.name}}
synchronous: {{method.synchronous}}
{{#each_argument method.arguments as |argument| ~}}
{{#if @argument_is_value ~}}
{{argument.name}}({{argument.domain}}): {{argument.type}}
{{else}}
{{#each argument.flags as |flag| ~}}
{{flag.name}}: {{flag.default_value}}
{{/each ~}}
{{/if ~}}
{{/each_argument ~}}
{{/each ~}}
{{/each ~}}
"#;

    /// Minimal protocol definition with one domain, constant, class, property and method,
    /// the method having one value argument and one flags argument.
    fn specs() -> AMQProtocolDefinition {
        let mut domains = BTreeMap::default();
        domains.insert("domain1".to_string(), AMQPType::LongString);
        AMQProtocolDefinition {
            name: "AMQP".to_string(),
            major_version: 0,
            minor_version: 9,
            revision: 1,
            port: 5672,
            copyright: "Copyright 1\nCopyright 2".to_string(),
            domains,
            constants: vec![AMQPConstant {
                name: "constant1".to_string(),
                amqp_type: AMQPType::ShortUInt,
                value: 128,
            }],
            soft_errors: Vec::default(),
            hard_errors: Vec::default(),
            classes: vec![AMQPClass {
                id: 42,
                methods: vec![AMQPMethod {
                    id: 64,
                    arguments: vec![
                        AMQPArgument::Value(AMQPValueArgument {
                            amqp_type: AMQPType::LongString,
                            name: "argument1".to_string(),
                            default_value: Some(AMQPValue::LongString("value1".into())),
                            domain: Some("domain1".to_string()),
                            force_default: false,
                        }),
                        AMQPArgument::Flags(AMQPFlagsArgument {
                            ignore_flags: false,
                            flags: vec![
                                AMQPFlagArgument {
                                    name: "flag1".to_string(),
                                    default_value: true,
                                    force_default: false,
                                },
                                AMQPFlagArgument {
                                    name: "flag2".to_string(),
                                    default_value: false,
                                    force_default: false,
                                },
                            ],
                        }),
                    ],
                    name: "method1".to_string(),
                    synchronous: true,
                    content: false,
                    metadata: Value::default(),
                    is_reply: false,
                    ignore_args: false,
                    c2s: true,
                    s2c: true,
                }],
                name: "class1".to_string(),
                properties: vec![AMQPProperty {
                    amqp_type: AMQPType::LongString,
                    name: "property1".to_string(),
                }],
                metadata: Value::default(),
            }],
        }
    }

    #[test]
    fn main_template() {
        let mut data = HashMap::new();
        let mut codegen = CodeGenerator::default().register_amqp_helpers();
        data.insert("protocol".to_string(), specs());
        assert!(codegen.register_template_string("main", TEMPLATE).is_ok());
        // The expected text mirrors the template literally, including the spaces
        // around ` = ` and ` - `.
        assert_eq!(
            codegen.render("main", &data).unwrap(),
            r#"
AMQP - 0.9.1
Copyright 1
Copyright 2
port 5672
domain1: LongString
constant1 = 128
42 - class1
property1: LongString
64 - method1
synchronous: true
argument1(domain1): LongString
flag1: true
flag2: false
"#
        );
    }
}