1use crate::{specs::*, util::*};
2
3use amq_protocol_types::{AMQPType, AMQPValue};
4use handlebars::{
5 self, BlockContext, BlockParams, Context, Handlebars, Helper, HelperDef, HelperResult,
6 JsonValue, Output, RenderContext, RenderError, RenderErrorReason, Renderable, ScopedJson,
7 to_json,
8};
9use serde_json::{self, Value};
10
11use std::{
12 collections::HashMap,
13 fs::{self, File},
14 io::Write,
15 path::Path,
16};
17
18pub type CodeGenerator<'a> = Handlebars<'a>;
20
21pub trait HandlebarsAMQPExtension {
23 fn register_amqp_helpers(self) -> Self;
25 fn simple_codegen(
28 out_dir: &str,
29 target: &str,
30 template_name: &str,
31 template: &str,
32 var_name: &str,
33 ) {
34 Self::simple_codegen_with_data(out_dir, target, template_name, template, var_name, None);
35 }
36 fn simple_codegen_with_data(
40 out_dir: &str,
41 target: &str,
42 template_name: &str,
43 template: &str,
44 var_name: &str,
45 data: Option<Value>,
46 );
47}
48
49impl<'a> HandlebarsAMQPExtension for CodeGenerator<'a> {
50 fn register_amqp_helpers(mut self) -> CodeGenerator<'a> {
51 self.register_escape_fn(handlebars::no_escape);
52 self.register_helper("camel", Box::new(CamelHelper));
53 self.register_helper("snake", Box::new(SnakeHelper));
54 self.register_helper("snake_type", Box::new(SnakeTypeHelper));
55 self.register_helper("sanitize_name", Box::new(SanitizeNameHelper));
56 self.register_helper("include_more", Box::new(IncludeMoreHelper));
57 self.register_helper("pass_by_ref", Box::new(PassByRefHelper));
58 self.register_helper("use_str_ref", Box::new(UseStrRefHelper));
59 self.register_helper("use_bytes_ref", Box::new(UseBytesRefHelper));
60 self.register_helper("each_argument", Box::new(EachArgumentHelper));
61 self.register_helper("amqp_value_ref", Box::new(AMQPValueRefHelper));
62 self
63 }
64
65 fn simple_codegen_with_data(
66 out_dir: &str,
67 target: &str,
68 template_name: &str,
69 template: &str,
70 var_name: &str,
71 metadata: Option<Value>,
72 ) {
73 let dest_path = Path::new(out_dir).join(format!("{target}.rs"));
74 let mut f = File::create(&dest_path)
75 .unwrap_or_else(|err| panic!("Failed to create {dest_path:?}: {err}"));
76 let specs = AMQProtocolDefinition::load(metadata);
77 let mut codegen = CodeGenerator::default().register_amqp_helpers();
78 let mut data = HashMap::new();
79
80 codegen.set_strict_mode(true);
81 codegen
82 .register_template_string(template_name, template)
83 .unwrap_or_else(|e| panic!("Failed to register {template_name} template: {e}"));
84 data.insert(
85 var_name.to_string(),
86 serde_json::to_value(specs)
87 .unwrap_or_else(|e| panic!("Failed to serialize specs: {e}")),
88 );
89
90 writeln!(
91 f,
92 "{}",
93 codegen
94 .render(template_name, &data)
95 .unwrap_or_else(|err| panic!("Failed to render {template_name} template: {err}"))
96 )
97 .unwrap_or_else(|e| panic!("Failed to generate {target}.rs: {e}"));
98 }
99}
100
101#[derive(Debug)]
103pub struct CamelHelper;
104impl HelperDef for CamelHelper {
105 fn call<'reg: 'rc, 'rc>(
106 &self,
107 h: &Helper<'rc>,
108 _: &'reg Handlebars<'_>,
109 _: &'rc Context,
110 _: &mut RenderContext<'reg, 'rc>,
111 out: &mut dyn Output,
112 ) -> HelperResult {
113 let value = h
114 .param(0)
115 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("camel", 0))?;
116 let param = value.value().as_str().ok_or_else(|| {
117 RenderErrorReason::ParamTypeMismatchForName(
118 "camel",
119 "string".to_string(),
120 "string".to_string(),
121 )
122 })?;
123 out.write(&camel_case(param))?;
124 Ok(())
125 }
126}
127
128#[derive(Debug)]
130pub struct SnakeHelper;
131impl HelperDef for SnakeHelper {
132 fn call<'reg: 'rc, 'rc>(
133 &self,
134 h: &Helper<'rc>,
135 _: &'reg Handlebars<'_>,
136 _: &'rc Context,
137 _: &mut RenderContext<'reg, 'rc>,
138 out: &mut dyn Output,
139 ) -> HelperResult {
140 let value = h
141 .param(0)
142 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("snake", 0))?;
143 let raw = h
144 .param(1)
145 .and_then(|raw| raw.value().as_bool())
146 .unwrap_or(true);
147 let param = value.value().as_str().ok_or_else(|| {
148 RenderErrorReason::ParamTypeMismatchForName(
149 "snake",
150 "string".to_string(),
151 "string".to_string(),
152 )
153 })?;
154 out.write(&snake_case(param, raw))?;
155 Ok(())
156 }
157}
158
159#[derive(Debug)]
161pub struct SnakeTypeHelper;
162impl HelperDef for SnakeTypeHelper {
163 fn call<'reg: 'rc, 'rc>(
164 &self,
165 h: &Helper<'rc>,
166 _: &'reg Handlebars<'_>,
167 _: &'rc Context,
168 _: &mut RenderContext<'reg, 'rc>,
169 out: &mut dyn Output,
170 ) -> HelperResult {
171 let value = h
172 .param(0)
173 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("snake_type", 0))?;
174 let param: AMQPType = serde_json::from_value(value.value().clone()).map_err(|_| {
175 RenderErrorReason::ParamTypeMismatchForName(
176 "snake_type",
177 "AMQPType".to_string(),
178 "string".to_string(),
179 )
180 })?;
181 out.write(&snake_case(¶m.to_string(), true))?;
182 Ok(())
183 }
184}
185
186#[derive(Debug)]
188pub struct SanitizeNameHelper;
189impl HelperDef for SanitizeNameHelper {
190 fn call<'reg: 'rc, 'rc>(
191 &self,
192 h: &Helper<'rc>,
193 _: &'reg Handlebars<'_>,
194 _: &'rc Context,
195 _: &mut RenderContext<'reg, 'rc>,
196 out: &mut dyn Output,
197 ) -> HelperResult {
198 let value = h
199 .param(0)
200 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("sanitize_name", 0))?;
201 let param = value.value().as_str().ok_or_else(|| {
202 RenderErrorReason::ParamTypeMismatchForName(
203 "sanitize_name",
204 "string".to_string(),
205 "string".to_string(),
206 )
207 })?;
208 out.write(¶m.replace('-', "_"))?;
209 Ok(())
210 }
211}
212
213#[derive(Debug)]
215pub struct IncludeMoreHelper;
216impl HelperDef for IncludeMoreHelper {
217 fn call<'reg: 'rc, 'rc>(
218 &self,
219 h: &Helper<'rc>,
220 _: &'reg Handlebars<'_>,
221 _: &'rc Context,
222 _: &mut RenderContext<'reg, 'rc>,
223 out: &mut dyn Output,
224 ) -> HelperResult {
225 let amqp_class = h
226 .param(0)
227 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("include_more", 0))?;
228 let amqp_method = h
229 .param(1)
230 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("include_more", 1))?;
231 let amqp_class = amqp_class.value().as_str().ok_or_else(|| {
232 RenderErrorReason::ParamTypeMismatchForName(
233 "include_more",
234 "string".to_string(),
235 "class".to_string(),
236 )
237 })?;
238 let amqp_method = amqp_method.value().as_str().ok_or_else(|| {
239 RenderErrorReason::ParamTypeMismatchForName(
240 "include_more",
241 "string".to_string(),
242 "method".to_string(),
243 )
244 })?;
245 if let Ok(cargo_manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
246 let include = Path::new(&cargo_manifest_dir)
247 .join("templates")
248 .join("includes")
249 .join(amqp_class)
250 .join(format!("{amqp_method}.rs"));
251 if let Ok(include) = fs::read_to_string(include) {
252 out.write(&include)?;
253 }
254 }
255 Ok(())
256 }
257}
258
259#[derive(Debug)]
261pub struct PassByRefHelper;
262impl HelperDef for PassByRefHelper {
263 fn call_inner<'reg: 'rc, 'rc>(
264 &self,
265 h: &Helper<'rc>,
266 _: &'reg Handlebars<'_>,
267 _: &'rc Context,
268 _: &mut RenderContext<'reg, 'rc>,
269 ) -> Result<ScopedJson<'rc>, RenderError> {
270 let value = h
271 .param(0)
272 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("pass_by_ref", 0))?;
273 let param: AMQPType = serde_json::from_value(value.value().clone()).map_err(|_| {
274 RenderErrorReason::ParamTypeMismatchForName(
275 "pass_by_ref",
276 "AMQPType".to_string(),
277 "string".to_string(),
278 )
279 })?;
280 let pass_by_ref = matches!(
281 param,
282 AMQPType::ShortString
283 | AMQPType::LongString
284 | AMQPType::FieldArray
285 | AMQPType::FieldTable
286 | AMQPType::ByteArray
287 );
288 Ok(ScopedJson::Derived(JsonValue::from(pass_by_ref)))
289 }
290}
291
292#[derive(Debug)]
294pub struct UseStrRefHelper;
295impl HelperDef for UseStrRefHelper {
296 fn call_inner<'reg: 'rc, 'rc>(
297 &self,
298 h: &Helper<'rc>,
299 _: &'reg Handlebars<'_>,
300 _: &'rc Context,
301 _: &mut RenderContext<'reg, 'rc>,
302 ) -> Result<ScopedJson<'rc>, RenderError> {
303 let value = h
304 .param(0)
305 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("use_str_ref", 0))?;
306 let param = serde_json::from_value::<AMQPType>(value.value().clone()).ok();
307 let use_str_ref = matches!(
308 param,
309 Some(AMQPType::ShortString) | Some(AMQPType::LongString)
310 );
311 Ok(ScopedJson::Derived(JsonValue::from(use_str_ref)))
312 }
313}
314
315#[derive(Debug)]
317pub struct UseBytesRefHelper;
318impl HelperDef for UseBytesRefHelper {
319 fn call_inner<'reg: 'rc, 'rc>(
320 &self,
321 h: &Helper<'rc>,
322 _: &'reg Handlebars<'_>,
323 _: &'rc Context,
324 _: &mut RenderContext<'reg, 'rc>,
325 ) -> Result<ScopedJson<'rc>, RenderError> {
326 let value = h
327 .param(0)
328 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("use_bytes_ref", 0))?;
329 let param = serde_json::from_value::<AMQPType>(value.value().clone()).ok();
330 let use_bytes_ref = matches!(param, Some(AMQPType::LongString));
331 Ok(ScopedJson::Derived(JsonValue::from(use_bytes_ref)))
332 }
333}
334
335#[derive(Debug)]
337pub struct EachArgumentHelper;
338impl HelperDef for EachArgumentHelper {
339 fn call<'reg: 'rc, 'rc>(
340 &self,
341 h: &Helper<'rc>,
342 r: &'reg Handlebars<'_>,
343 ctx: &'rc Context,
344 rc: &mut RenderContext<'reg, 'rc>,
345 out: &mut dyn Output,
346 ) -> HelperResult {
347 let value = h
348 .param(0)
349 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("each_argument", 0))?;
350
351 if let Some(t) = h.template() {
352 let mut block_context = BlockContext::new();
353 if let Some(path) = value.context_path() {
354 *block_context.base_path_mut() = path.to_vec();
355 }
356 rc.push_block(block_context);
357 let arguments: Vec<AMQPArgument> = serde_json::from_value(value.value().clone())
358 .map_err(|_| {
359 RenderErrorReason::ParamTypeMismatchForName(
360 "each_argument",
361 "Vec<AMQPArgument>".to_string(),
362 "arguments".to_string(),
363 )
364 })?;
365 let len = arguments.len();
366 let array_path = value.context_path();
367 for (index, argument) in arguments.iter().enumerate() {
368 if let Some(ref mut block) = rc.block_mut() {
369 let (path, is_value) = match *argument {
370 AMQPArgument::Value(_) => ("Value".to_owned(), true),
371 AMQPArgument::Flags(_) => ("Flags".to_owned(), false),
372 };
373 block.set_local_var("index", to_json(index));
374 block.set_local_var("last", to_json(index == len - 1));
375 block.set_local_var("argument_is_value", to_json(is_value));
376 if let Some(p) = array_path {
377 if index == 0 {
378 let mut path = Vec::with_capacity(p.len() + 1);
379 path.extend_from_slice(p);
380 path.push(index.to_string());
381 *block.base_path_mut() = path;
382 } else if let Some(ptr) = block.base_path_mut().last_mut() {
383 *ptr = index.to_string();
384 }
385 }
386 if let Some(block_param) = h.block_param() {
387 let mut params = BlockParams::new();
388 params.add_path(block_param, vec![path])?;
389 block.set_block_params(params);
390 }
391 }
392 t.render(r, ctx, rc, out)?;
393 }
394 rc.pop_block();
395 }
396 Ok(())
397 }
398}
399
400#[derive(Debug)]
402pub struct AMQPValueRefHelper;
403impl HelperDef for AMQPValueRefHelper {
404 fn call_inner<'reg: 'rc, 'rc>(
405 &self,
406 h: &Helper<'rc>,
407 _: &'reg Handlebars<'_>,
408 _: &'rc Context,
409 _: &mut RenderContext<'reg, 'rc>,
410 ) -> Result<ScopedJson<'rc>, RenderError> {
411 let arg = h
412 .param(0)
413 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("amqp_value", 0))?;
414 let param = serde_json::from_value(arg.value().clone()).map_err(|_| {
415 RenderErrorReason::ParamTypeMismatchForName(
416 "amqp_value",
417 "AMQPValue".to_string(),
418 "value".to_string(),
419 )
420 })?;
421 let value = json_value(param).map_err(RenderErrorReason::SerdeError)?;
422 Ok(ScopedJson::Derived(value))
423 }
424}
425
426fn json_value(val: AMQPValue) -> serde_json::Result<Value> {
427 match val {
428 AMQPValue::Boolean(v) => serde_json::to_value(v),
429 AMQPValue::ShortShortInt(v) => serde_json::to_value(v),
430 AMQPValue::ShortShortUInt(v) => serde_json::to_value(v),
431 AMQPValue::ShortInt(v) => serde_json::to_value(v),
432 AMQPValue::ShortUInt(v) => serde_json::to_value(v),
433 AMQPValue::LongInt(v) => serde_json::to_value(v),
434 AMQPValue::LongUInt(v) => serde_json::to_value(v),
435 AMQPValue::LongLongInt(v) => serde_json::to_value(v),
436 AMQPValue::Float(v) => serde_json::to_value(v),
437 AMQPValue::Double(v) => serde_json::to_value(v),
438 AMQPValue::DecimalValue(v) => serde_json::to_value(v),
439 AMQPValue::ShortString(v) => serde_json::to_value(format!("\"{v}\"")),
440 AMQPValue::LongString(v) => serde_json::to_value(format!("b\"{v}\"")),
441 AMQPValue::FieldArray(v) => serde_json::to_value(v),
442 AMQPValue::Timestamp(v) => serde_json::to_value(v),
443 AMQPValue::FieldTable(v) => serde_json::to_value(v),
444 AMQPValue::ByteArray(v) => serde_json::to_value(v),
445 AMQPValue::Void => Ok(JsonValue::Null),
446 }
447}
448
#[cfg(test)]
mod test {
    use super::*;

    use std::collections::BTreeMap;

    // Exercises the plain `each` helper as well as the custom `each_argument` block helper,
    // including its `@argument_is_value` local and the `as |argument|` block parameter.
    const TEMPLATE: &str = r#"
{{protocol.name}} - {{protocol.major_version}}.{{protocol.minor_version}}.{{protocol.revision}}
{{protocol.copyright}}
port {{protocol.port}}
{{#each protocol.domains ~}}
{{@key}}: {{this}}
{{/each ~}}
{{#each protocol.constants as |constant| ~}}
{{constant.name}} = {{constant.value}}
{{/each ~}}
{{#each protocol.classes as |class| ~}}
{{class.id}} - {{class.name}}
{{#each class.properties as |property| ~}}
{{property.name}}: {{property.type}}
{{/each ~}}
{{#each class.methods as |method| ~}}
{{method.id}} - {{method.name}}
synchronous: {{method.synchronous}}
{{#each_argument method.arguments as |argument| ~}}
{{#if @argument_is_value ~}}
{{argument.name}}({{argument.domain}}): {{argument.type}}
{{else}}
{{#each argument.flags as |flag| ~}}
{{flag.name}}: {{flag.default_value}}
{{/each ~}}
{{/if ~}}
{{/each_argument ~}}
{{/each ~}}
{{/each ~}}
"#;

    /// Minimal protocol definition: one domain, one constant, and one class with one method
    /// that has a value argument and a flags argument.
    fn specs() -> AMQProtocolDefinition {
        let mut domains = BTreeMap::default();
        domains.insert("domain1".to_string(), AMQPType::LongString);
        AMQProtocolDefinition {
            name: "AMQP".to_string(),
            major_version: 0,
            minor_version: 9,
            revision: 1,
            port: 5672,
            copyright: "Copyright 1\nCopyright 2".to_string(),
            domains,
            constants: vec![AMQPConstant {
                name: "constant1".to_string(),
                amqp_type: AMQPType::ShortUInt,
                value: 128,
            }],
            soft_errors: Vec::default(),
            hard_errors: Vec::default(),
            classes: vec![AMQPClass {
                id: 42,
                methods: vec![AMQPMethod {
                    id: 64,
                    arguments: vec![
                        AMQPArgument::Value(AMQPValueArgument {
                            amqp_type: AMQPType::LongString,
                            name: "argument1".to_string(),
                            default_value: Some(AMQPValue::LongString("value1".into())),
                            domain: Some("domain1".to_string()),
                            force_default: false,
                        }),
                        AMQPArgument::Flags(AMQPFlagsArgument {
                            ignore_flags: false,
                            flags: vec![
                                AMQPFlagArgument {
                                    name: "flag1".to_string(),
                                    default_value: true,
                                    force_default: false,
                                },
                                AMQPFlagArgument {
                                    name: "flag2".to_string(),
                                    default_value: false,
                                    force_default: false,
                                },
                            ],
                        }),
                    ],
                    name: "method1".to_string(),
                    synchronous: true,
                    content: false,
                    metadata: Value::default(),
                    is_reply: false,
                    ignore_args: false,
                    c2s: true,
                    s2c: true,
                }],
                name: "class1".to_string(),
                properties: vec![AMQPProperty {
                    amqp_type: AMQPType::LongString,
                    name: "property1".to_string(),
                }],
                metadata: Value::default(),
            }],
        }
    }

    #[test]
    fn main_template() {
        let mut data = HashMap::new();
        let mut codegen = CodeGenerator::default().register_amqp_helpers();
        data.insert("protocol".to_string(), specs());
        assert!(codegen.register_template_string("main", TEMPLATE).is_ok());
        // The template writes " = " and " - " between fields, so the expected text must too.
        assert_eq!(
            codegen.render("main", &data).unwrap(),
            r#"
AMQP - 0.9.1
Copyright 1
Copyright 2
port 5672
domain1: LongString
constant1 = 128
42 - class1
property1: LongString
64 - method1
synchronous: true
argument1(domain1): LongString
flag1: true
flag2: false
"#
        );
    }
}