1use crate::{specs::*, util::*};
2
3use amq_protocol_types::{AMQPType, AMQPValue};
4use handlebars::{
5 self, BlockContext, BlockParams, Context, Handlebars, Helper, HelperDef, HelperResult,
6 JsonValue, Output, RenderContext, RenderError, RenderErrorReason, Renderable, ScopedJson,
7 to_json,
8};
9use serde_json::{self, Value};
10
11use std::{
12 collections::HashMap,
13 fs::{self, File},
14 io::Write,
15 path::Path,
16};
17
/// Alias for the Handlebars registry; the AMQP code generation extension in this file is
/// implemented on it.
pub type CodeGenerator<'a> = Handlebars<'a>;
20
21pub trait HandlebarsAMQPExtension {
23 fn register_amqp_helpers(self) -> Self;
25 fn simple_codegen(
28 out_dir: &str,
29 target: &str,
30 template_name: &str,
31 template: &str,
32 var_name: &str,
33 ) {
34 Self::simple_codegen_with_data(out_dir, target, template_name, template, var_name, None);
35 }
36 fn simple_codegen_with_data(
40 out_dir: &str,
41 target: &str,
42 template_name: &str,
43 template: &str,
44 var_name: &str,
45 data: Option<Value>,
46 );
47}
48
49impl<'a> HandlebarsAMQPExtension for CodeGenerator<'a> {
50 fn register_amqp_helpers(mut self) -> CodeGenerator<'a> {
51 self.register_escape_fn(handlebars::no_escape);
52 self.register_helper("camel", Box::new(CamelHelper));
53 self.register_helper("snake", Box::new(SnakeHelper));
54 self.register_helper("snake_type", Box::new(SnakeTypeHelper));
55 self.register_helper("sanitize_name", Box::new(SanitizeNameHelper));
56 self.register_helper("include_more", Box::new(IncludeMoreHelper));
57 self.register_helper("pass_by_ref", Box::new(PassByRefHelper));
58 self.register_helper("use_str_ref", Box::new(UseStrRefHelper));
59 self.register_helper("use_bytes_ref", Box::new(UseBytesRefHelper));
60 self.register_helper("each_argument", Box::new(EachArgumentHelper));
61 self.register_helper("amqp_value_ref", Box::new(AMQPValueRefHelper));
62 self
63 }
64
65 fn simple_codegen_with_data(
66 out_dir: &str,
67 target: &str,
68 template_name: &str,
69 template: &str,
70 var_name: &str,
71 metadata: Option<Value>,
72 ) {
73 let dest_path = Path::new(out_dir).join(format!("{target}.rs"));
74 let mut f = File::create(&dest_path)
75 .unwrap_or_else(|err| panic!("Failed to create {dest_path:?}: {err}"));
76 let specs = AMQProtocolDefinition::load(metadata);
77 let mut codegen = CodeGenerator::default().register_amqp_helpers();
78 let mut data = HashMap::new();
79
80 codegen.set_strict_mode(true);
81 codegen
82 .register_template_string(template_name, template)
83 .unwrap_or_else(|e| panic!("Failed to register {template_name} template: {e}"));
84 data.insert(
85 var_name.to_string(),
86 serde_json::to_value(specs)
87 .unwrap_or_else(|e| panic!("Failed to serialize specs: {e}")),
88 );
89
90 writeln!(
91 f,
92 "{}",
93 codegen
94 .render(template_name, &data)
95 .unwrap_or_else(|err| panic!("Failed to render {template_name} template: {err}"))
96 )
97 .unwrap_or_else(|e| panic!("Failed to generate {target}.rs: {e}"));
98 }
99}
100
101pub struct CamelHelper;
103impl HelperDef for CamelHelper {
104 fn call<'reg: 'rc, 'rc>(
105 &self,
106 h: &Helper<'rc>,
107 _: &'reg Handlebars<'_>,
108 _: &'rc Context,
109 _: &mut RenderContext<'reg, 'rc>,
110 out: &mut dyn Output,
111 ) -> HelperResult {
112 let value = h
113 .param(0)
114 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("camel", 0))?;
115 let param = value.value().as_str().ok_or_else(|| {
116 RenderErrorReason::ParamTypeMismatchForName(
117 "camel",
118 "string".to_string(),
119 "string".to_string(),
120 )
121 })?;
122 out.write(&camel_case(param))?;
123 Ok(())
124 }
125}
126
127pub struct SnakeHelper;
129impl HelperDef for SnakeHelper {
130 fn call<'reg: 'rc, 'rc>(
131 &self,
132 h: &Helper<'rc>,
133 _: &'reg Handlebars<'_>,
134 _: &'rc Context,
135 _: &mut RenderContext<'reg, 'rc>,
136 out: &mut dyn Output,
137 ) -> HelperResult {
138 let value = h
139 .param(0)
140 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("snake", 0))?;
141 let raw = h
142 .param(1)
143 .and_then(|raw| raw.value().as_bool())
144 .unwrap_or(true);
145 let param = value.value().as_str().ok_or_else(|| {
146 RenderErrorReason::ParamTypeMismatchForName(
147 "snake",
148 "string".to_string(),
149 "string".to_string(),
150 )
151 })?;
152 out.write(&snake_case(param, raw))?;
153 Ok(())
154 }
155}
156
157pub struct SnakeTypeHelper;
159impl HelperDef for SnakeTypeHelper {
160 fn call<'reg: 'rc, 'rc>(
161 &self,
162 h: &Helper<'rc>,
163 _: &'reg Handlebars<'_>,
164 _: &'rc Context,
165 _: &mut RenderContext<'reg, 'rc>,
166 out: &mut dyn Output,
167 ) -> HelperResult {
168 let value = h
169 .param(0)
170 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("snake_type", 0))?;
171 let param: AMQPType = serde_json::from_value(value.value().clone()).map_err(|_| {
172 RenderErrorReason::ParamTypeMismatchForName(
173 "snake_type",
174 "AMQPType".to_string(),
175 "string".to_string(),
176 )
177 })?;
178 out.write(&snake_case(¶m.to_string(), true))?;
179 Ok(())
180 }
181}
182
183pub struct SanitizeNameHelper;
185impl HelperDef for SanitizeNameHelper {
186 fn call<'reg: 'rc, 'rc>(
187 &self,
188 h: &Helper<'rc>,
189 _: &'reg Handlebars<'_>,
190 _: &'rc Context,
191 _: &mut RenderContext<'reg, 'rc>,
192 out: &mut dyn Output,
193 ) -> HelperResult {
194 let value = h
195 .param(0)
196 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("sanitize_name", 0))?;
197 let param = value.value().as_str().ok_or_else(|| {
198 RenderErrorReason::ParamTypeMismatchForName(
199 "sanitize_name",
200 "string".to_string(),
201 "string".to_string(),
202 )
203 })?;
204 out.write(¶m.replace('-', "_"))?;
205 Ok(())
206 }
207}
208
209pub struct IncludeMoreHelper;
211impl HelperDef for IncludeMoreHelper {
212 fn call<'reg: 'rc, 'rc>(
213 &self,
214 h: &Helper<'rc>,
215 _: &'reg Handlebars<'_>,
216 _: &'rc Context,
217 _: &mut RenderContext<'reg, 'rc>,
218 out: &mut dyn Output,
219 ) -> HelperResult {
220 let amqp_class = h
221 .param(0)
222 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("include_more", 0))?;
223 let amqp_method = h
224 .param(1)
225 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("include_more", 1))?;
226 let amqp_class = amqp_class.value().as_str().ok_or_else(|| {
227 RenderErrorReason::ParamTypeMismatchForName(
228 "include_more",
229 "string".to_string(),
230 "class".to_string(),
231 )
232 })?;
233 let amqp_method = amqp_method.value().as_str().ok_or_else(|| {
234 RenderErrorReason::ParamTypeMismatchForName(
235 "include_more",
236 "string".to_string(),
237 "method".to_string(),
238 )
239 })?;
240 if let Ok(cargo_manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
241 let include = Path::new(&cargo_manifest_dir)
242 .join("templates")
243 .join("includes")
244 .join(amqp_class)
245 .join(format!("{amqp_method}.rs"));
246 if let Ok(include) = fs::read_to_string(include) {
247 out.write(&include)?;
248 }
249 }
250 Ok(())
251 }
252}
253
254pub struct PassByRefHelper;
256impl HelperDef for PassByRefHelper {
257 fn call_inner<'reg: 'rc, 'rc>(
258 &self,
259 h: &Helper<'rc>,
260 _: &'reg Handlebars<'_>,
261 _: &'rc Context,
262 _: &mut RenderContext<'reg, 'rc>,
263 ) -> Result<ScopedJson<'rc>, RenderError> {
264 let value = h
265 .param(0)
266 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("pass_by_ref", 0))?;
267 let param: AMQPType = serde_json::from_value(value.value().clone()).map_err(|_| {
268 RenderErrorReason::ParamTypeMismatchForName(
269 "pass_by_ref",
270 "AMQPType".to_string(),
271 "string".to_string(),
272 )
273 })?;
274 let pass_by_ref = matches!(
275 param,
276 AMQPType::ShortString
277 | AMQPType::LongString
278 | AMQPType::FieldArray
279 | AMQPType::FieldTable
280 | AMQPType::ByteArray
281 );
282 Ok(ScopedJson::Derived(JsonValue::from(pass_by_ref)))
283 }
284}
285
286pub struct UseStrRefHelper;
288impl HelperDef for UseStrRefHelper {
289 fn call_inner<'reg: 'rc, 'rc>(
290 &self,
291 h: &Helper<'rc>,
292 _: &'reg Handlebars<'_>,
293 _: &'rc Context,
294 _: &mut RenderContext<'reg, 'rc>,
295 ) -> Result<ScopedJson<'rc>, RenderError> {
296 let value = h
297 .param(0)
298 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("use_str_ref", 0))?;
299 let param = serde_json::from_value::<AMQPType>(value.value().clone()).ok();
300 let use_str_ref = matches!(
301 param,
302 Some(AMQPType::ShortString) | Some(AMQPType::LongString)
303 );
304 Ok(ScopedJson::Derived(JsonValue::from(use_str_ref)))
305 }
306}
307
308pub struct UseBytesRefHelper;
310impl HelperDef for UseBytesRefHelper {
311 fn call_inner<'reg: 'rc, 'rc>(
312 &self,
313 h: &Helper<'rc>,
314 _: &'reg Handlebars<'_>,
315 _: &'rc Context,
316 _: &mut RenderContext<'reg, 'rc>,
317 ) -> Result<ScopedJson<'rc>, RenderError> {
318 let value = h
319 .param(0)
320 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("use_bytes_ref", 0))?;
321 let param = serde_json::from_value::<AMQPType>(value.value().clone()).ok();
322 let use_bytes_ref = matches!(param, Some(AMQPType::LongString));
323 Ok(ScopedJson::Derived(JsonValue::from(use_bytes_ref)))
324 }
325}
326
327pub struct EachArgumentHelper;
329impl HelperDef for EachArgumentHelper {
330 fn call<'reg: 'rc, 'rc>(
331 &self,
332 h: &Helper<'rc>,
333 r: &'reg Handlebars<'_>,
334 ctx: &'rc Context,
335 rc: &mut RenderContext<'reg, 'rc>,
336 out: &mut dyn Output,
337 ) -> HelperResult {
338 let value = h
339 .param(0)
340 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("each_argument", 0))?;
341
342 if let Some(t) = h.template() {
343 let mut block_context = BlockContext::new();
344 if let Some(path) = value.context_path() {
345 *block_context.base_path_mut() = path.to_vec();
346 }
347 rc.push_block(block_context);
348 let arguments: Vec<AMQPArgument> = serde_json::from_value(value.value().clone())
349 .map_err(|_| {
350 RenderErrorReason::ParamTypeMismatchForName(
351 "each_argument",
352 "Vec<AMQPArgument>".to_string(),
353 "arguments".to_string(),
354 )
355 })?;
356 let len = arguments.len();
357 let array_path = value.context_path();
358 for (index, argument) in arguments.iter().enumerate() {
359 if let Some(ref mut block) = rc.block_mut() {
360 let (path, is_value) = match *argument {
361 AMQPArgument::Value(_) => ("Value".to_owned(), true),
362 AMQPArgument::Flags(_) => ("Flags".to_owned(), false),
363 };
364 block.set_local_var("index", to_json(index));
365 block.set_local_var("last", to_json(index == len - 1));
366 block.set_local_var("argument_is_value", to_json(is_value));
367 if let Some(p) = array_path {
368 if index == 0 {
369 let mut path = Vec::with_capacity(p.len() + 1);
370 path.extend_from_slice(p);
371 path.push(index.to_string());
372 *block.base_path_mut() = path;
373 } else if let Some(ptr) = block.base_path_mut().last_mut() {
374 *ptr = index.to_string();
375 }
376 }
377 if let Some(block_param) = h.block_param() {
378 let mut params = BlockParams::new();
379 params.add_path(block_param, vec![path])?;
380 block.set_block_params(params);
381 }
382 }
383 t.render(r, ctx, rc, out)?;
384 }
385 rc.pop_block();
386 }
387 Ok(())
388 }
389}
390
391pub struct AMQPValueRefHelper;
393impl HelperDef for AMQPValueRefHelper {
394 fn call_inner<'reg: 'rc, 'rc>(
395 &self,
396 h: &Helper<'rc>,
397 _: &'reg Handlebars<'_>,
398 _: &'rc Context,
399 _: &mut RenderContext<'reg, 'rc>,
400 ) -> Result<ScopedJson<'rc>, RenderError> {
401 let arg = h
402 .param(0)
403 .ok_or_else(|| RenderErrorReason::ParamNotFoundForIndex("amqp_value", 0))?;
404 let param = serde_json::from_value(arg.value().clone()).map_err(|_| {
405 RenderErrorReason::ParamTypeMismatchForName(
406 "amqp_value",
407 "AMQPValue".to_string(),
408 "value".to_string(),
409 )
410 })?;
411 let value = json_value(param).map_err(RenderErrorReason::SerdeError)?;
412 Ok(ScopedJson::Derived(value))
413 }
414}
415
416fn json_value(val: AMQPValue) -> serde_json::Result<serde_json::Value> {
417 match val {
418 AMQPValue::Boolean(v) => serde_json::to_value(v),
419 AMQPValue::ShortShortInt(v) => serde_json::to_value(v),
420 AMQPValue::ShortShortUInt(v) => serde_json::to_value(v),
421 AMQPValue::ShortInt(v) => serde_json::to_value(v),
422 AMQPValue::ShortUInt(v) => serde_json::to_value(v),
423 AMQPValue::LongInt(v) => serde_json::to_value(v),
424 AMQPValue::LongUInt(v) => serde_json::to_value(v),
425 AMQPValue::LongLongInt(v) => serde_json::to_value(v),
426 AMQPValue::Float(v) => serde_json::to_value(v),
427 AMQPValue::Double(v) => serde_json::to_value(v),
428 AMQPValue::DecimalValue(v) => serde_json::to_value(v),
429 AMQPValue::ShortString(v) => serde_json::to_value(format!("\"{v}\"")),
430 AMQPValue::LongString(v) => serde_json::to_value(format!("b\"{v}\"")),
431 AMQPValue::FieldArray(v) => serde_json::to_value(v),
432 AMQPValue::Timestamp(v) => serde_json::to_value(v),
433 AMQPValue::FieldTable(v) => serde_json::to_value(v),
434 AMQPValue::ByteArray(v) => serde_json::to_value(v),
435 AMQPValue::Void => Ok(JsonValue::Null),
436 }
437}
438
#[cfg(test)]
mod test {
    use super::*;

    use std::collections::BTreeMap;

    /// Template exercising the protocol fields, nested `each` blocks and the
    /// `each_argument` helper with its `@argument_is_value` local variable.
    pub const TEMPLATE: &str = r#"
{{protocol.name}} - {{protocol.major_version}}.{{protocol.minor_version}}.{{protocol.revision}}
{{protocol.copyright}}
port {{protocol.port}}
{{#each protocol.domains ~}}
{{@key}}: {{this}}
{{/each ~}}
{{#each protocol.constants as |constant| ~}}
{{constant.name}} = {{constant.value}}
{{/each ~}}
{{#each protocol.classes as |class| ~}}
{{class.id}} - {{class.name}}
{{#each class.properties as |property| ~}}
{{property.name}}: {{property.type}}
{{/each ~}}
{{#each class.methods as |method| ~}}
{{method.id}} - {{method.name}}
synchronous: {{method.synchronous}}
{{#each_argument method.arguments as |argument| ~}}
{{#if @argument_is_value ~}}
{{argument.name}}({{argument.domain}}): {{argument.type}}
{{else}}
{{#each argument.flags as |flag| ~}}
{{flag.name}}: {{flag.default_value}}
{{/each ~}}
{{/if ~}}
{{/each_argument ~}}
{{/each ~}}
{{/each ~}}
"#;

    /// Builds a minimal protocol definition: one domain, one constant, and one class with
    /// one property and one method taking a value argument and a two-flag argument.
    fn specs() -> AMQProtocolDefinition {
        let mut domains = BTreeMap::default();
        domains.insert("domain1".to_string(), AMQPType::LongString);
        AMQProtocolDefinition {
            name: "AMQP".to_string(),
            major_version: 0,
            minor_version: 9,
            revision: 1,
            port: 5672,
            copyright: "Copyright 1\nCopyright 2".to_string(),
            domains,
            constants: vec![AMQPConstant {
                name: "constant1".to_string(),
                amqp_type: AMQPType::ShortUInt,
                value: 128,
            }],
            soft_errors: Vec::default(),
            hard_errors: Vec::default(),
            classes: vec![AMQPClass {
                id: 42,
                methods: vec![AMQPMethod {
                    id: 64,
                    arguments: vec![
                        AMQPArgument::Value(AMQPValueArgument {
                            amqp_type: AMQPType::LongString,
                            name: "argument1".to_string(),
                            default_value: Some(AMQPValue::LongString("value1".into())),
                            domain: Some("domain1".to_string()),
                            force_default: false,
                        }),
                        AMQPArgument::Flags(AMQPFlagsArgument {
                            ignore_flags: false,
                            flags: vec![
                                AMQPFlagArgument {
                                    name: "flag1".to_string(),
                                    default_value: true,
                                    force_default: false,
                                },
                                AMQPFlagArgument {
                                    name: "flag2".to_string(),
                                    default_value: false,
                                    force_default: false,
                                },
                            ],
                        }),
                    ],
                    name: "method1".to_string(),
                    synchronous: true,
                    content: false,
                    metadata: Value::default(),
                    is_reply: false,
                    ignore_args: false,
                    c2s: true,
                    s2c: true,
                }],
                name: "class1".to_string(),
                properties: vec![AMQPProperty {
                    amqp_type: AMQPType::LongString,
                    name: "property1".to_string(),
                }],
                metadata: Value::default(),
            }],
        }
    }

    /// Renders `TEMPLATE` against `specs()` and compares the output byte for byte.
    ///
    /// NOTE(review): the expected text has no space before `=` / `-` on the first line
    /// inside each block opened with `~}}` (e.g. `constant1= 128`), although the template
    /// has one. Presumably this comes from the `~` whitespace control; confirm against the
    /// handlebars version in use before "fixing" either string.
    #[test]
    fn main_template() {
        let mut data = HashMap::new();
        let mut codegen = CodeGenerator::default().register_amqp_helpers();
        data.insert("protocol".to_string(), specs());
        assert!(codegen.register_template_string("main", TEMPLATE).is_ok());
        assert_eq!(
            codegen.render("main", &data).unwrap(),
            r#"
AMQP - 0.9.1
Copyright 1
Copyright 2
port 5672
domain1: LongString
constant1= 128
42- class1
property1: LongString
64- method1
synchronous: true
argument1(domain1): LongString
flag1: true
flag2: false
"#
        );
    }
}