1use anyhow::Result;
4use heck::{ToPascalCase, ToSnakeCase};
5use std::fmt::Write;
6use zlink::idl::{CustomEnum, CustomObject, CustomType, Field, Interface, Method, Type};
7
/// Accumulates generated Rust source text for Varlink interfaces.
///
/// Text is appended line by line; each line written through the internal helpers is
/// prefixed with indentation proportional to the current nesting depth.
pub struct CodeGenerator {
    /// Buffer that all generated text is appended to.
    output: String,
    /// Current nesting depth; one indentation unit is emitted per level.
    indent_level: usize,
}
13
14impl CodeGenerator {
15 pub fn new() -> Self {
17 Self {
18 output: String::new(),
19 indent_level: 0,
20 }
21 }
22
23 pub fn output(self) -> String {
25 self.output
26 }
27
28 pub fn write_module_header(&mut self) -> Result<()> {
30 writeln!(
31 &mut self.output,
32 "// Generated code from Varlink IDL files."
33 )?;
34 writeln!(&mut self.output)?;
35 writeln!(&mut self.output, "use serde::{{Deserialize, Serialize}};")?;
36 writeln!(&mut self.output, "use zlink::{{proxy, ReplyError}};")?;
37 writeln!(&mut self.output)?;
38 Ok(())
39 }
40
41 pub fn generate_interface(
43 &mut self,
44 interface: &Interface<'_>,
45 skip_module_header: bool,
46 ) -> Result<()> {
47 if skip_module_header {
48 self.write_interface_comment(interface)?;
49 } else {
50 self.write_header(interface)?;
51 self.writeln("use serde::{Deserialize, Serialize};")?;
52 self.writeln("use zlink::{proxy, ReplyError};")?;
54 self.writeln("")?;
55 }
56
57 self.generate_proxy_trait(interface)?;
59 self.writeln("")?;
60
61 self.generate_output_structs(interface)?;
63
64 for custom_type in interface.custom_types() {
66 self.generate_custom_type(custom_type)?;
67 self.writeln("")?;
68 }
69
70 if interface.errors().count() > 0 {
72 self.generate_errors(interface)?;
73 self.writeln("")?;
74 }
75
76 Ok(())
77 }
78
79 fn write_interface_comment(&mut self, interface: &Interface<'_>) -> Result<()> {
80 writeln!(
81 &mut self.output,
82 "// Generated code for Varlink interface `{}`.",
83 interface.name()
84 )?;
85 writeln!(&mut self.output)?;
86 Ok(())
87 }
88
89 fn write_header(&mut self, interface: &Interface<'_>) -> Result<()> {
90 writeln!(
91 &mut self.output,
92 "//! Generated code for Varlink interface `{}`.",
93 interface.name()
94 )?;
95 writeln!(&mut self.output, "//!",)?;
96 writeln!(
97 &mut self.output,
98 "//! This code was generated by `zlink-codegen` from Varlink IDL.",
99 )?;
100 writeln!(
101 &mut self.output,
102 "//! You may prefer to adapt it, instead of using it verbatim.",
103 )?;
104 writeln!(&mut self.output)?;
105
106 for comment in interface.comments() {
108 writeln!(&mut self.output, "//! {}", comment.text())?;
109 }
110 writeln!(&mut self.output)?;
111
112 Ok(())
113 }
114
115 fn generate_custom_type(&mut self, custom_type: &CustomType<'_>) -> Result<()> {
116 match custom_type {
117 CustomType::Object(obj) => self.generate_custom_object(obj),
118 CustomType::Enum(enum_type) => self.generate_custom_enum(enum_type),
119 }
120 }
121
122 fn generate_custom_object(&mut self, obj: &CustomObject<'_>) -> Result<()> {
123 for comment in obj.comments() {
125 self.writeln(&format!("/// {}", comment.text()))?;
126 }
127
128 self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
129 self.writeln(&format!("pub struct {} {{", obj.name().to_pascal_case()))?;
130 self.indent();
131
132 for field in obj.fields() {
133 self.generate_field(field)?;
134 }
135
136 self.dedent();
137 self.writeln("}")?;
138
139 Ok(())
140 }
141
142 fn generate_custom_enum(&mut self, enum_type: &CustomEnum<'_>) -> Result<()> {
143 for comment in enum_type.comments() {
145 self.writeln(&format!("/// {}", comment.text()))?;
146 }
147
148 self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
149 self.writeln("#[serde(rename_all = \"snake_case\")]")?;
150 self.writeln(&format!(
151 "pub enum {} {{",
152 enum_type.name().to_pascal_case()
153 ))?;
154 self.indent();
155
156 for variant in enum_type.variants() {
157 for comment in variant.comments() {
159 self.writeln(&format!("/// {}", comment.text()))?;
160 }
161
162 self.writeln(&format!("{},", variant.name().to_pascal_case()))?;
164 }
165
166 self.dedent();
167 self.writeln("}")?;
168
169 Ok(())
170 }
171
172 fn generate_field(&mut self, field: &Field<'_>) -> Result<()> {
173 for comment in field.comments() {
175 self.writeln(&format!("/// {}", comment.text()))?;
176 }
177
178 let field_name = field.name().to_snake_case();
179 let rust_type = self.type_to_rust(field.ty())?;
180
181 let rust_type = if matches!(field.ty(), Type::Optional(_)) {
183 rust_type
185 } else {
186 rust_type
187 };
188
189 let field_name_attr = if is_rust_keyword(&field_name) || field_name != field.name() {
191 format!("#[serde(rename = \"{}\")]", field.name())
192 } else {
193 String::new()
194 };
195
196 if !field_name_attr.is_empty() {
197 self.writeln(&field_name_attr)?;
198 }
199
200 let safe_field_name = if is_rust_keyword(&field_name) {
201 format!("r#{}", field_name)
202 } else {
203 field_name
204 };
205
206 self.writeln(&format!("pub {}: {},", safe_field_name, rust_type))?;
207
208 Ok(())
209 }
210
211 fn generate_errors(&mut self, interface: &Interface<'_>) -> Result<()> {
212 self.writeln("/// Errors that can occur in this interface.")?;
213 self.writeln("#[derive(Debug, Clone, PartialEq, ReplyError)]")?;
214 self.writeln(&format!("#[zlink(interface = \"{}\")]", interface.name()))?;
215 self.writeln(&format!(
216 "pub enum {}Error {{",
217 interface_name_to_rust(interface.name())
218 ))?;
219 self.indent();
220
221 for error in interface.errors() {
222 for comment in error.comments() {
224 self.writeln(&format!("/// {}", comment.text()))?;
225 }
226
227 let variant_name = error.name().to_pascal_case();
228 if error.fields().count() == 0 {
229 self.writeln(&format!("{},", variant_name))?;
230 } else {
231 self.writeln(&format!("{} {{", variant_name))?;
232 self.indent();
233 for field in error.fields() {
234 self.generate_error_field(field)?;
235 }
236 self.dedent();
237 self.writeln("},")?;
238 }
239 }
240
241 self.dedent();
242 self.writeln("}")?;
243
244 Ok(())
245 }
246
247 fn generate_output_structs(&mut self, interface: &Interface<'_>) -> Result<()> {
249 for method in interface.methods() {
250 if method.outputs().count() > 0 {
254 let struct_name = format!("{}Output", method.name().to_pascal_case());
255
256 self.writeln(&format!(
258 "/// Output parameters for the {} method.",
259 method.name()
260 ))?;
261
262 let needs_lifetime = method.outputs().any(|o| type_needs_lifetime(o.ty()));
264
265 self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
266 if needs_lifetime {
267 self.writeln(&format!("pub struct {}<'a> {{", struct_name))?;
268 } else {
269 self.writeln(&format!("pub struct {} {{", struct_name))?;
270 }
271 self.indent();
272
273 for output in method.outputs() {
274 let field_name = output.name().to_snake_case();
275 let rust_type = if needs_lifetime {
277 self.type_to_rust_output(output.ty())?
278 } else {
279 self.type_to_rust(output.ty())?
280 };
281
282 if needs_lifetime && type_needs_borrow(output.ty()) {
284 self.writeln("#[serde(borrow)]")?;
285 }
286
287 if field_name != output.name() {
288 self.writeln(&format!("#[serde(rename = \"{}\")]", output.name()))?;
289 }
290
291 let safe_field_name = if is_rust_keyword(&field_name) {
292 format!("r#{}", field_name)
293 } else {
294 field_name
295 };
296
297 self.writeln(&format!("pub {}: {},", safe_field_name, rust_type))?;
298 }
299
300 self.dedent();
301 self.writeln("}")?;
302 self.writeln("")?;
303 }
304 }
305
306 Ok(())
307 }
308
309 fn generate_proxy_trait(&mut self, interface: &Interface<'_>) -> Result<()> {
310 let trait_name = interface_name_to_rust(interface.name());
311
312 let error_type = if interface.errors().count() > 0 {
314 format!("{}Error", interface_name_to_rust(interface.name()))
315 } else {
316 let stub_error_name = format!("{}Error", interface_name_to_rust(interface.name()));
318
319 self.writeln("/// Stub error type for interface without errors.")?;
321 self.writeln("///")?;
322 self.writeln("/// This is an empty enum that can never be instantiated.")?;
323 self.writeln("/// It exists only to satisfy the proxy trait requirements.")?;
324 self.writeln("#[derive(Debug, Clone, PartialEq, ReplyError)]")?;
325 self.writeln(&format!("#[zlink(interface = \"{}\")]", interface.name()))?;
326 self.writeln(&format!("pub enum {} {{}}", stub_error_name))?;
327 self.writeln("")?;
328
329 stub_error_name
330 };
331
332 self.writeln("/// Proxy trait for calling methods on the interface.")?;
333 self.writeln(&format!("#[proxy(\"{}\")]", interface.name()))?;
334 self.writeln(&format!("pub trait {} {{", trait_name))?;
335 self.indent();
336
337 for method in interface.methods() {
338 self.generate_proxy_method_signature(method, &error_type)?;
339 }
340
341 self.dedent();
342 self.writeln("}")?;
343
344 Ok(())
345 }
346
347 fn generate_proxy_method_signature(
348 &mut self,
349 method: &Method<'_>,
350 error_type: &str,
351 ) -> Result<()> {
352 for comment in method.comments() {
354 self.writeln(&format!("/// {}", comment.text()))?;
355 }
356
357 let method_name = method.name().to_snake_case();
358 let safe_method_name = if is_rust_keyword(&method_name) {
359 format!("r#{}", method_name)
360 } else {
361 method_name
362 };
363
364 let mut signature = format!("async fn {}(&mut self", safe_method_name);
366
367 for param in method.inputs() {
369 let param_name = param.name().to_snake_case();
370 let safe_param_name = if is_rust_keyword(¶m_name) {
371 format!("r#{}", param_name)
372 } else {
373 param_name
374 };
375 let rust_type = self.type_to_rust_param(param.ty())?;
377
378 write!(&mut signature, ",")?;
379 if safe_param_name != param.name() {
381 write!(&mut signature, " #[zlink(rename = \"{}\")]", param.name(),)?;
382 }
383
384 write!(&mut signature, " {}: {}", safe_param_name, rust_type)?;
385 }
386
387 signature.push_str(") -> zlink::Result<Result<");
388
389 let output_count = method.outputs().count();
391 if output_count == 0 {
392 signature.push_str("()");
393 } else {
394 let struct_name = format!("{}Output", method.name().to_pascal_case());
398 let needs_lifetime = method.outputs().any(|o| type_needs_lifetime(o.ty()));
400 if needs_lifetime {
401 signature.push_str(&format!("{}<'_>", struct_name));
402 } else {
403 signature.push_str(&struct_name);
404 }
405 }
406
407 write!(&mut signature, ", {}>>", error_type)?;
408 signature.push(';');
409
410 self.writeln(&signature)?;
411
412 Ok(())
413 }
414
415 fn generate_error_field(&mut self, field: &Field<'_>) -> Result<()> {
416 for comment in field.comments() {
418 self.writeln(&format!("/// {}", comment.text()))?;
419 }
420
421 let field_name = field.name().to_snake_case();
422 let rust_type = self.type_to_rust(field.ty())?;
423
424 let field_name_attr = if is_rust_keyword(&field_name) || field_name != field.name() {
426 format!("#[zlink(rename = \"{}\")]", field.name())
427 } else {
428 String::new()
429 };
430
431 if !field_name_attr.is_empty() {
432 self.writeln(&field_name_attr)?;
433 }
434
435 let safe_field_name = if is_rust_keyword(&field_name) {
436 format!("r#{}", field_name)
437 } else {
438 field_name
439 };
440
441 self.writeln(&format!("{}: {},", safe_field_name, rust_type))?;
442
443 Ok(())
444 }
445
446 fn type_to_rust(&self, ty: &Type) -> Result<String> {
447 type_to_rust(ty)
448 }
449
450 fn type_to_rust_param(&self, ty: &Type) -> Result<String> {
451 type_to_rust_param(ty)
452 }
453
454 fn type_to_rust_output(&self, ty: &Type) -> Result<String> {
455 type_to_rust_output(ty)
456 }
457
458 fn writeln(&mut self, s: &str) -> Result<()> {
459 self.write(s)?;
460 writeln!(&mut self.output)?;
461 Ok(())
462 }
463
464 fn write(&mut self, s: &str) -> Result<()> {
465 for _ in 0..self.indent_level {
466 write!(&mut self.output, " ")?;
467 }
468 write!(&mut self.output, "{}", s)?;
469 Ok(())
470 }
471
472 fn indent(&mut self) {
473 self.indent_level += 1;
474 }
475
476 fn dedent(&mut self) {
477 if self.indent_level > 0 {
478 self.indent_level -= 1;
479 }
480 }
481}
482
483impl Default for CodeGenerator {
484 fn default() -> Self {
485 Self::new()
486 }
487}
488
489fn type_to_rust(ty: &Type) -> Result<String> {
490 Ok(match ty {
491 Type::Bool => "bool".to_string(),
492 Type::Int => "i64".to_string(),
493 Type::Float => "f64".to_string(),
494 Type::String => "String".to_string(),
495 Type::Object(_fields) => {
496 "serde_json::Value".to_string()
500 }
501 Type::Enum(_variants) => {
502 "String".to_string()
504 }
505 Type::Array(elem_type) => {
506 let elem_rust = type_to_rust(elem_type.inner())?;
507 format!("Vec<{}>", elem_rust)
508 }
509 Type::Map(value_type) => {
510 let value_rust = type_to_rust(value_type.inner())?;
511 format!("std::collections::HashMap<String, {}>", value_rust)
512 }
513 Type::ForeignObject => "serde_json::Value".to_string(),
514 Type::Optional(inner_type) => {
515 let inner_rust = type_to_rust(inner_type.inner())?;
516 format!("Option<{}>", inner_rust)
517 }
518 Type::Custom(name) => name.to_pascal_case(),
519 Type::Any => "serde_json::Value".to_string(),
520 })
521}
522
523fn type_to_rust_param(ty: &Type) -> Result<String> {
524 Ok(match ty {
525 Type::Bool => "bool".to_string(),
526 Type::Int => "i64".to_string(),
527 Type::Float => "f64".to_string(),
528 Type::String => "&str".to_string(),
529 Type::Object(_fields) => {
530 "&serde_json::Value".to_string()
532 }
533 Type::Enum(_variants) => {
534 "&str".to_string()
536 }
537 Type::Array(elem_type) => {
538 let elem_rust = type_to_rust_param_elem(elem_type.inner())?;
540 format!("&[{}]", elem_rust)
541 }
542 Type::Map(value_type) => {
543 let value_rust = type_to_rust_param_elem(value_type.inner())?;
545 format!("&std::collections::HashMap<&str, {}>", value_rust)
546 }
547 Type::ForeignObject => "&serde_json::Value".to_string(),
548 Type::Optional(inner_type) => {
549 let inner_rust = type_to_rust_param(inner_type.inner())?;
550 format!("Option<{}>", inner_rust)
552 }
553 Type::Custom(name) => format!("&{}", name.to_pascal_case()),
554 Type::Any => "&serde_json::Value".to_string(),
555 })
556}
557
558fn type_to_rust_param_elem(ty: &Type) -> Result<String> {
561 Ok(match ty {
562 Type::Bool => "bool".to_string(),
563 Type::Int => "i64".to_string(),
564 Type::Float => "f64".to_string(),
565 Type::String => "&str".to_string(),
566 Type::Object(_fields) => "serde_json::Value".to_string(),
567 Type::Enum(_variants) => "&str".to_string(),
568 Type::Array(elem_type) => {
569 let elem_rust = type_to_rust_param_elem(elem_type.inner())?;
570 format!("Vec<{}>", elem_rust)
571 }
572 Type::Map(value_type) => {
573 let value_rust = type_to_rust_param_elem(value_type.inner())?;
574 format!("std::collections::HashMap<&str, {}>", value_rust)
575 }
576 Type::ForeignObject => "serde_json::Value".to_string(),
577 Type::Any => "serde_json::Value".to_string(),
578 Type::Optional(inner_type) => {
579 let inner_rust = type_to_rust_param_elem(inner_type.inner())?;
580 format!("Option<{}>", inner_rust)
581 }
582 Type::Custom(name) => name.to_pascal_case(),
583 })
584}
585
586fn type_to_rust_output(ty: &Type) -> Result<String> {
587 Ok(match ty {
588 Type::Bool => "bool".to_string(),
589 Type::Int => "i64".to_string(),
590 Type::Float => "f64".to_string(),
591 Type::String => "&'a str".to_string(),
592 Type::Object(_fields) => {
593 "serde_json::Value".to_string()
595 }
596 Type::Enum(_variants) => {
597 "&'a str".to_string()
599 }
600 Type::Array(elem_type) => {
601 let elem_rust = match elem_type.inner() {
603 Type::String => "&'a str".to_string(),
604 Type::Enum(_) => "&'a str".to_string(),
605 _ => type_to_rust(elem_type.inner())?,
606 };
607 format!("Vec<{}>", elem_rust)
608 }
609 Type::Map(value_type) => {
610 let value_rust = match value_type.inner() {
612 Type::String => "&'a str".to_string(),
613 Type::Enum(_) => "&'a str".to_string(),
614 _ => type_to_rust(value_type.inner())?,
615 };
616 format!("std::collections::HashMap<&'a str, {}>", value_rust)
617 }
618 Type::ForeignObject => "serde_json::Value".to_string(),
619 Type::Any => "serde_json::Value".to_string(),
620 Type::Optional(inner_type) => {
621 let inner_rust = type_to_rust_output(inner_type.inner())?;
624 format!("Option<{}>", inner_rust)
625 }
626 Type::Custom(name) => name.to_pascal_case(),
627 })
628}
629
630fn interface_name_to_rust(name: &str) -> String {
631 name.split('.').next_back().unwrap_or(name).to_pascal_case()
633}
634
635fn type_needs_lifetime(ty: &Type) -> bool {
636 match ty {
637 Type::String => true,
638 Type::Enum(_) => true, Type::Array(inner) => type_needs_lifetime(inner.inner()),
640 Type::Map(_) => {
641 true
643 }
644 Type::Optional(inner) => type_needs_lifetime(inner.inner()),
645 _ => false,
646 }
647}
648
649fn type_needs_borrow(ty: &Type) -> bool {
650 match ty {
651 Type::String => true,
652 Type::Enum(_) => true, Type::Array(inner) => type_needs_borrow(inner.inner()),
654 Type::Map(_) => {
655 true
657 }
658 Type::Optional(inner) => type_needs_borrow(inner.inner()),
659 _ => false,
660 }
661}
662
/// Reports whether `s` is a Rust keyword and therefore cannot be emitted as a plain
/// identifier.
///
/// Besides the strict keywords this also covers the keywords reserved for future use
/// (`abstract`, `try`, `yield`, ...), because the compiler rejects those as bare
/// identifiers too. `gen` is included because it is reserved starting with the 2024
/// edition. Weak keywords such as `union` remain valid identifiers and are not listed.
fn is_rust_keyword(s: &str) -> bool {
    matches!(
        s,
        // Strict keywords.
        "as" | "async" | "await" | "break" | "const" | "continue" | "crate" | "dyn" | "else"
            | "enum" | "extern" | "false" | "fn" | "for" | "if" | "impl" | "in" | "let"
            | "loop" | "match" | "mod" | "move" | "mut" | "pub" | "ref" | "return" | "self"
            | "Self" | "static" | "struct" | "super" | "trait" | "true" | "type" | "unsafe"
            | "use" | "where" | "while"
            // Reserved keywords.
            | "abstract" | "become" | "box" | "do" | "final" | "gen" | "macro" | "override"
            | "priv" | "try" | "typeof" | "unsized" | "virtual" | "yield"
    )
}