1use anyhow::Result;
4use heck::{ToPascalCase, ToSnakeCase};
5use std::fmt::Write;
6use zlink::idl::{CustomEnum, CustomObject, CustomType, Field, Interface, Method, Type};
7
/// Accumulates generated Rust source text for one or more Varlink interfaces.
pub struct CodeGenerator {
    /// Everything emitted so far; handed back by `CodeGenerator::output`.
    output: String,
    /// Current nesting depth; each level prefixes written lines with one indent unit.
    indent_level: usize,
}
13
14impl CodeGenerator {
15 pub fn new() -> Self {
17 Self {
18 output: String::new(),
19 indent_level: 0,
20 }
21 }
22
23 pub fn output(self) -> String {
25 self.output
26 }
27
28 pub fn write_module_header(&mut self) -> Result<()> {
30 writeln!(
31 &mut self.output,
32 "// Generated code from Varlink IDL files."
33 )?;
34 writeln!(&mut self.output)?;
35 writeln!(&mut self.output, "use serde::{{Deserialize, Serialize}};")?;
36 writeln!(&mut self.output, "use zlink::{{proxy, ReplyError}};")?;
37 writeln!(&mut self.output)?;
38 Ok(())
39 }
40
41 pub fn generate_interface(
43 &mut self,
44 interface: &Interface<'_>,
45 skip_module_header: bool,
46 ) -> Result<()> {
47 if skip_module_header {
48 self.write_interface_comment(interface)?;
49 } else {
50 self.write_header(interface)?;
51 self.writeln("use serde::{Deserialize, Serialize};")?;
52 self.writeln("use zlink::{proxy, ReplyError};")?;
54 self.writeln("")?;
55 }
56
57 self.generate_proxy_trait(interface)?;
59 self.writeln("")?;
60
61 self.generate_output_structs(interface)?;
63
64 for custom_type in interface.custom_types() {
66 self.generate_custom_type(custom_type)?;
67 self.writeln("")?;
68 }
69
70 if interface.errors().count() > 0 {
72 self.generate_errors(interface)?;
73 self.writeln("")?;
74 }
75
76 Ok(())
77 }
78
79 fn write_interface_comment(&mut self, interface: &Interface<'_>) -> Result<()> {
80 writeln!(
81 &mut self.output,
82 "// Generated code for Varlink interface `{}`.",
83 interface.name()
84 )?;
85 writeln!(&mut self.output)?;
86 Ok(())
87 }
88
89 fn write_header(&mut self, interface: &Interface<'_>) -> Result<()> {
90 writeln!(
91 &mut self.output,
92 "//! Generated code for Varlink interface `{}`.",
93 interface.name()
94 )?;
95 writeln!(&mut self.output, "//!",)?;
96 writeln!(
97 &mut self.output,
98 "//! This code was generated by `zlink-codegen` from Varlink IDL.",
99 )?;
100 writeln!(
101 &mut self.output,
102 "//! You may prefer to adapt it, instead of using it verbatim.",
103 )?;
104 writeln!(&mut self.output)?;
105
106 for comment in interface.comments() {
108 writeln!(&mut self.output, "//! {}", comment.text())?;
109 }
110 writeln!(&mut self.output)?;
111
112 Ok(())
113 }
114
115 fn generate_custom_type(&mut self, custom_type: &CustomType<'_>) -> Result<()> {
116 match custom_type {
117 CustomType::Object(obj) => self.generate_custom_object(obj),
118 CustomType::Enum(enum_type) => self.generate_custom_enum(enum_type),
119 }
120 }
121
122 fn generate_custom_object(&mut self, obj: &CustomObject<'_>) -> Result<()> {
123 for comment in obj.comments() {
125 self.writeln(&format!("/// {}", comment.text()))?;
126 }
127
128 self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
129 self.writeln(&format!("pub struct {} {{", obj.name().to_pascal_case()))?;
130 self.indent();
131
132 for field in obj.fields() {
133 self.generate_field(field)?;
134 }
135
136 self.dedent();
137 self.writeln("}")?;
138
139 Ok(())
140 }
141
142 fn generate_custom_enum(&mut self, enum_type: &CustomEnum<'_>) -> Result<()> {
143 for comment in enum_type.comments() {
145 self.writeln(&format!("/// {}", comment.text()))?;
146 }
147
148 self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
149 self.writeln("#[serde(rename_all = \"snake_case\")]")?;
150 self.writeln(&format!(
151 "pub enum {} {{",
152 enum_type.name().to_pascal_case()
153 ))?;
154 self.indent();
155
156 for variant in enum_type.variants() {
157 for comment in variant.comments() {
159 self.writeln(&format!("/// {}", comment.text()))?;
160 }
161
162 self.writeln(&format!("{},", variant.name().to_pascal_case()))?;
164 }
165
166 self.dedent();
167 self.writeln("}")?;
168
169 Ok(())
170 }
171
172 fn generate_field(&mut self, field: &Field<'_>) -> Result<()> {
173 for comment in field.comments() {
175 self.writeln(&format!("/// {}", comment.text()))?;
176 }
177
178 let field_name = field.name().to_snake_case();
179 let rust_type = self.type_to_rust(field.ty())?;
180
181 let rust_type = if matches!(field.ty(), Type::Optional(_)) {
183 rust_type
185 } else {
186 rust_type
187 };
188
189 let field_name_attr = if is_rust_keyword(&field_name) || field_name != field.name() {
191 format!("#[serde(rename = \"{}\")]", field.name())
192 } else {
193 String::new()
194 };
195
196 if !field_name_attr.is_empty() {
197 self.writeln(&field_name_attr)?;
198 }
199
200 let safe_field_name = if is_rust_keyword(&field_name) {
201 format!("r#{}", field_name)
202 } else {
203 field_name
204 };
205
206 self.writeln(&format!("pub {}: {},", safe_field_name, rust_type))?;
207
208 Ok(())
209 }
210
211 fn generate_errors(&mut self, interface: &Interface<'_>) -> Result<()> {
212 self.writeln("/// Errors that can occur in this interface.")?;
213 self.writeln("#[derive(Debug, Clone, PartialEq, ReplyError)]")?;
214 self.writeln(&format!("#[zlink(interface = \"{}\")]", interface.name()))?;
215 self.writeln(&format!(
216 "pub enum {}Error {{",
217 interface_name_to_rust(interface.name())
218 ))?;
219 self.indent();
220
221 for error in interface.errors() {
222 for comment in error.comments() {
224 self.writeln(&format!("/// {}", comment.text()))?;
225 }
226
227 let variant_name = error.name().to_pascal_case();
228 if error.fields().count() == 0 {
229 self.writeln(&format!("{},", variant_name))?;
230 } else {
231 self.writeln(&format!("{} {{", variant_name))?;
232 self.indent();
233 for field in error.fields() {
234 self.generate_error_field(field)?;
235 }
236 self.dedent();
237 self.writeln("},")?;
238 }
239 }
240
241 self.dedent();
242 self.writeln("}")?;
243
244 Ok(())
245 }
246
247 fn generate_output_structs(&mut self, interface: &Interface<'_>) -> Result<()> {
249 for method in interface.methods() {
250 if method.outputs().count() > 0 {
254 let struct_name = format!("{}Output", method.name().to_pascal_case());
255
256 self.writeln(&format!(
258 "/// Output parameters for the {} method.",
259 method.name()
260 ))?;
261
262 let needs_lifetime = method.outputs().any(|o| type_needs_lifetime(o.ty()));
264
265 self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
266 if needs_lifetime {
267 self.writeln(&format!("pub struct {}<'a> {{", struct_name))?;
268 } else {
269 self.writeln(&format!("pub struct {} {{", struct_name))?;
270 }
271 self.indent();
272
273 for output in method.outputs() {
274 let field_name = output.name().to_snake_case();
275 let rust_type = if needs_lifetime {
277 self.type_to_rust_output(output.ty())?
278 } else {
279 self.type_to_rust(output.ty())?
280 };
281
282 if needs_lifetime && type_needs_borrow(output.ty()) {
284 self.writeln("#[serde(borrow)]")?;
285 }
286
287 if field_name != output.name() {
288 self.writeln(&format!("#[serde(rename = \"{}\")]", output.name()))?;
289 }
290
291 let safe_field_name = if is_rust_keyword(&field_name) {
292 format!("r#{}", field_name)
293 } else {
294 field_name
295 };
296
297 self.writeln(&format!("pub {}: {},", safe_field_name, rust_type))?;
298 }
299
300 self.dedent();
301 self.writeln("}")?;
302 self.writeln("")?;
303 }
304 }
305
306 Ok(())
307 }
308
309 fn generate_proxy_trait(&mut self, interface: &Interface<'_>) -> Result<()> {
310 let trait_name = interface_name_to_rust(interface.name());
311
312 let error_type = if interface.errors().count() > 0 {
314 format!("{}Error", interface_name_to_rust(interface.name()))
315 } else {
316 let stub_error_name = format!("{}Error", interface_name_to_rust(interface.name()));
318
319 self.writeln("/// Stub error type for interface without errors.")?;
321 self.writeln("///")?;
322 self.writeln("/// This is an empty enum that can never be instantiated.")?;
323 self.writeln("/// It exists only to satisfy the proxy trait requirements.")?;
324 self.writeln("#[derive(Debug, Clone, PartialEq, ReplyError)]")?;
325 self.writeln(&format!("#[zlink(interface = \"{}\")]", interface.name()))?;
326 self.writeln(&format!("pub enum {} {{}}", stub_error_name))?;
327 self.writeln("")?;
328
329 stub_error_name
330 };
331
332 self.writeln("/// Proxy trait for calling methods on the interface.")?;
333 self.writeln(&format!("#[proxy(\"{}\")]", interface.name()))?;
334 self.writeln(&format!("pub trait {} {{", trait_name))?;
335 self.indent();
336
337 for method in interface.methods() {
338 self.generate_proxy_method_signature(method, &error_type)?;
339 }
340
341 self.dedent();
342 self.writeln("}")?;
343
344 Ok(())
345 }
346
347 fn generate_proxy_method_signature(
348 &mut self,
349 method: &Method<'_>,
350 error_type: &str,
351 ) -> Result<()> {
352 for comment in method.comments() {
354 self.writeln(&format!("/// {}", comment.text()))?;
355 }
356
357 let method_name = method.name().to_snake_case();
358 let safe_method_name = if is_rust_keyword(&method_name) {
359 format!("r#{}", method_name)
360 } else {
361 method_name
362 };
363
364 let mut signature = format!("async fn {}(&mut self", safe_method_name);
366
367 for param in method.inputs() {
369 let param_name = param.name().to_snake_case();
370 let safe_param_name = if is_rust_keyword(¶m_name) {
371 format!("r#{}", param_name)
372 } else {
373 param_name
374 };
375 let rust_type = self.type_to_rust_param(param.ty())?;
377
378 write!(&mut signature, ",")?;
379 if safe_param_name != param.name() {
381 write!(&mut signature, " #[zlink(rename = \"{}\")]", param.name(),)?;
382 }
383
384 write!(&mut signature, " {}: {}", safe_param_name, rust_type)?;
385 }
386
387 signature.push_str(") -> zlink::Result<Result<");
388
389 let output_count = method.outputs().count();
391 if output_count == 0 {
392 signature.push_str("()");
393 } else {
394 let struct_name = format!("{}Output", method.name().to_pascal_case());
398 let needs_lifetime = method.outputs().any(|o| type_needs_lifetime(o.ty()));
400 if needs_lifetime {
401 signature.push_str(&format!("{}<'_>", struct_name));
402 } else {
403 signature.push_str(&struct_name);
404 }
405 }
406
407 write!(&mut signature, ", {}>>", error_type)?;
408 signature.push(';');
409
410 self.writeln(&signature)?;
411
412 Ok(())
413 }
414
415 fn generate_error_field(&mut self, field: &Field<'_>) -> Result<()> {
416 for comment in field.comments() {
418 self.writeln(&format!("/// {}", comment.text()))?;
419 }
420
421 let field_name = field.name().to_snake_case();
422 let rust_type = self.type_to_rust(field.ty())?;
423
424 let field_name_attr = if is_rust_keyword(&field_name) || field_name != field.name() {
426 format!("#[zlink(rename = \"{}\")]", field.name())
427 } else {
428 String::new()
429 };
430
431 if !field_name_attr.is_empty() {
432 self.writeln(&field_name_attr)?;
433 }
434
435 let safe_field_name = if is_rust_keyword(&field_name) {
436 format!("r#{}", field_name)
437 } else {
438 field_name
439 };
440
441 self.writeln(&format!("{}: {},", safe_field_name, rust_type))?;
442
443 Ok(())
444 }
445
446 fn type_to_rust(&self, ty: &Type) -> Result<String> {
447 type_to_rust(ty)
448 }
449
450 fn type_to_rust_param(&self, ty: &Type) -> Result<String> {
451 type_to_rust_param(ty)
452 }
453
454 fn type_to_rust_output(&self, ty: &Type) -> Result<String> {
455 type_to_rust_output(ty)
456 }
457
458 fn writeln(&mut self, s: &str) -> Result<()> {
459 self.write(s)?;
460 writeln!(&mut self.output)?;
461 Ok(())
462 }
463
464 fn write(&mut self, s: &str) -> Result<()> {
465 for _ in 0..self.indent_level {
466 write!(&mut self.output, " ")?;
467 }
468 write!(&mut self.output, "{}", s)?;
469 Ok(())
470 }
471
472 fn indent(&mut self) {
473 self.indent_level += 1;
474 }
475
476 fn dedent(&mut self) {
477 if self.indent_level > 0 {
478 self.indent_level -= 1;
479 }
480 }
481}
482
483impl Default for CodeGenerator {
484 fn default() -> Self {
485 Self::new()
486 }
487}
488
489fn type_to_rust(ty: &Type) -> Result<String> {
490 Ok(match ty {
491 Type::Bool => "bool".to_string(),
492 Type::Int => "i64".to_string(),
493 Type::Float => "f64".to_string(),
494 Type::String => "String".to_string(),
495 Type::Object(_fields) => {
496 "serde_json::Value".to_string()
500 }
501 Type::Enum(_variants) => {
502 "String".to_string()
504 }
505 Type::Array(elem_type) => {
506 let elem_rust = type_to_rust(elem_type.inner())?;
507 format!("Vec<{}>", elem_rust)
508 }
509 Type::Map(value_type) => {
510 let value_rust = type_to_rust(value_type.inner())?;
511 format!("std::collections::HashMap<String, {}>", value_rust)
512 }
513 Type::ForeignObject => "serde_json::Value".to_string(),
514 Type::Optional(inner_type) => {
515 let inner_rust = type_to_rust(inner_type.inner())?;
516 format!("Option<{}>", inner_rust)
517 }
518 Type::Custom(name) => name.to_pascal_case(),
519 })
520}
521
522fn type_to_rust_param(ty: &Type) -> Result<String> {
523 Ok(match ty {
524 Type::Bool => "bool".to_string(),
525 Type::Int => "i64".to_string(),
526 Type::Float => "f64".to_string(),
527 Type::String => "&str".to_string(),
528 Type::Object(_fields) => {
529 "&serde_json::Value".to_string()
531 }
532 Type::Enum(_variants) => {
533 "&str".to_string()
535 }
536 Type::Array(elem_type) => {
537 let elem_rust = type_to_rust_param_elem(elem_type.inner())?;
539 format!("&[{}]", elem_rust)
540 }
541 Type::Map(value_type) => {
542 let value_rust = type_to_rust_param_elem(value_type.inner())?;
544 format!("&std::collections::HashMap<&str, {}>", value_rust)
545 }
546 Type::ForeignObject => "&serde_json::Value".to_string(),
547 Type::Optional(inner_type) => {
548 let inner_rust = type_to_rust_param(inner_type.inner())?;
549 format!("Option<{}>", inner_rust)
551 }
552 Type::Custom(name) => format!("&{}", name.to_pascal_case()),
553 })
554}
555
556fn type_to_rust_param_elem(ty: &Type) -> Result<String> {
559 Ok(match ty {
560 Type::Bool => "bool".to_string(),
561 Type::Int => "i64".to_string(),
562 Type::Float => "f64".to_string(),
563 Type::String => "&str".to_string(),
564 Type::Object(_fields) => "serde_json::Value".to_string(),
565 Type::Enum(_variants) => "&str".to_string(),
566 Type::Array(elem_type) => {
567 let elem_rust = type_to_rust_param_elem(elem_type.inner())?;
568 format!("Vec<{}>", elem_rust)
569 }
570 Type::Map(value_type) => {
571 let value_rust = type_to_rust_param_elem(value_type.inner())?;
572 format!("std::collections::HashMap<&str, {}>", value_rust)
573 }
574 Type::ForeignObject => "serde_json::Value".to_string(),
575 Type::Optional(inner_type) => {
576 let inner_rust = type_to_rust_param_elem(inner_type.inner())?;
577 format!("Option<{}>", inner_rust)
578 }
579 Type::Custom(name) => name.to_pascal_case(),
580 })
581}
582
583fn type_to_rust_output(ty: &Type) -> Result<String> {
584 Ok(match ty {
585 Type::Bool => "bool".to_string(),
586 Type::Int => "i64".to_string(),
587 Type::Float => "f64".to_string(),
588 Type::String => "&'a str".to_string(),
589 Type::Object(_fields) => {
590 "serde_json::Value".to_string()
592 }
593 Type::Enum(_variants) => {
594 "&'a str".to_string()
596 }
597 Type::Array(elem_type) => {
598 let elem_rust = match elem_type.inner() {
600 Type::String => "&'a str".to_string(),
601 Type::Enum(_) => "&'a str".to_string(),
602 _ => type_to_rust(elem_type.inner())?,
603 };
604 format!("Vec<{}>", elem_rust)
605 }
606 Type::Map(value_type) => {
607 let value_rust = match value_type.inner() {
609 Type::String => "&'a str".to_string(),
610 Type::Enum(_) => "&'a str".to_string(),
611 _ => type_to_rust(value_type.inner())?,
612 };
613 format!("std::collections::HashMap<&'a str, {}>", value_rust)
614 }
615 Type::ForeignObject => "serde_json::Value".to_string(),
616 Type::Optional(inner_type) => {
617 let inner_rust = type_to_rust_output(inner_type.inner())?;
620 format!("Option<{}>", inner_rust)
621 }
622 Type::Custom(name) => name.to_pascal_case(),
623 })
624}
625
626fn interface_name_to_rust(name: &str) -> String {
627 name.split('.').next_back().unwrap_or(name).to_pascal_case()
629}
630
631fn type_needs_lifetime(ty: &Type) -> bool {
632 match ty {
633 Type::String => true,
634 Type::Enum(_) => true, Type::Array(inner) => type_needs_lifetime(inner.inner()),
636 Type::Map(_) => {
637 true
639 }
640 Type::Optional(inner) => type_needs_lifetime(inner.inner()),
641 _ => false,
642 }
643}
644
645fn type_needs_borrow(ty: &Type) -> bool {
646 match ty {
647 Type::String => true,
648 Type::Enum(_) => true, Type::Array(inner) => type_needs_borrow(inner.inner()),
650 Type::Map(_) => {
651 true
653 }
654 Type::Optional(inner) => type_needs_borrow(inner.inner()),
655 _ => false,
656 }
657}
658
/// Returns `true` when `s` cannot be used verbatim as a Rust identifier. The generator then
/// writes it as a raw identifier (`r#...`).
///
/// The list covers the strict keywords plus the keywords reserved for future use (`try`,
/// `yield`, `box`, `macro`, ...), which are equally rejected as plain identifiers. `gen`,
/// reserved only from the 2024 edition, is not listed.
///
/// NOTE(review): `self`, `Self`, `super` and `crate` are listed, but they cannot be raw
/// identifiers either, so escaping them as `r#...` still yields invalid code.
fn is_rust_keyword(s: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        // Strict keywords.
        "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
        "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
        "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait",
        "true", "type", "unsafe", "use", "where", "while",
        // Reserved keywords: unused today but still rejected as plain identifiers.
        "abstract", "become", "box", "do", "final", "macro", "override", "priv", "try", "typeof",
        "unsized", "virtual", "yield",
    ];
    KEYWORDS.contains(&s)
}