Skip to main content

zlink_codegen/
codegen.rs

1//! Code generation implementation.
2
3use anyhow::Result;
4use heck::{ToPascalCase, ToSnakeCase};
5use std::fmt::Write;
6use zlink::idl::{CustomEnum, CustomObject, CustomType, Field, Interface, Method, Type};
7
/// Code generator for Varlink interfaces.
///
/// Accumulates generated Rust source text in an in-memory buffer; retrieve it with `output()`.
pub struct CodeGenerator {
    /// Generated source text accumulated so far.
    output: String,
    /// Current indentation depth; `write` renders each level as four spaces.
    indent_level: usize,
}
13
14impl CodeGenerator {
15    /// Create a new code generator.
16    pub fn new() -> Self {
17        Self {
18            output: String::new(),
19            indent_level: 0,
20        }
21    }
22
23    /// Get the generated output.
24    pub fn output(self) -> String {
25        self.output
26    }
27
28    /// Write module-level header for multiple interfaces.
29    pub fn write_module_header(&mut self) -> Result<()> {
30        writeln!(
31            &mut self.output,
32            "// Generated code from Varlink IDL files."
33        )?;
34        writeln!(&mut self.output)?;
35        writeln!(&mut self.output, "use serde::{{Deserialize, Serialize}};")?;
36        writeln!(&mut self.output, "use zlink::{{proxy, ReplyError}};")?;
37        writeln!(&mut self.output)?;
38        Ok(())
39    }
40
41    /// Generate code for an interface.
42    pub fn generate_interface(
43        &mut self,
44        interface: &Interface<'_>,
45        skip_module_header: bool,
46    ) -> Result<()> {
47        if skip_module_header {
48            self.write_interface_comment(interface)?;
49        } else {
50            self.write_header(interface)?;
51            self.writeln("use serde::{Deserialize, Serialize};")?;
52            // Always import ReplyError since we generate a stub error type when there are no errors
53            self.writeln("use zlink::{proxy, ReplyError};")?;
54            self.writeln("")?;
55        }
56
57        // Generate proxy trait using the proxy macro.
58        self.generate_proxy_trait(interface)?;
59        self.writeln("")?;
60
61        // Generate output structs for methods.
62        self.generate_output_structs(interface)?;
63
64        // Generate custom types.
65        for custom_type in interface.custom_types() {
66            self.generate_custom_type(custom_type)?;
67            self.writeln("")?;
68        }
69
70        // Generate errors.
71        if interface.errors().count() > 0 {
72            self.generate_errors(interface)?;
73            self.writeln("")?;
74        }
75
76        Ok(())
77    }
78
79    fn write_interface_comment(&mut self, interface: &Interface<'_>) -> Result<()> {
80        writeln!(
81            &mut self.output,
82            "// Generated code for Varlink interface `{}`.",
83            interface.name()
84        )?;
85        writeln!(&mut self.output)?;
86        Ok(())
87    }
88
89    fn write_header(&mut self, interface: &Interface<'_>) -> Result<()> {
90        writeln!(
91            &mut self.output,
92            "//! Generated code for Varlink interface `{}`.",
93            interface.name()
94        )?;
95        writeln!(&mut self.output, "//!",)?;
96        writeln!(
97            &mut self.output,
98            "//! This code was generated by `zlink-codegen` from Varlink IDL.",
99        )?;
100        writeln!(
101            &mut self.output,
102            "//! You may prefer to adapt it, instead of using it verbatim.",
103        )?;
104        writeln!(&mut self.output)?;
105
106        // Add interface comments if any.
107        for comment in interface.comments() {
108            writeln!(&mut self.output, "//! {}", comment.text())?;
109        }
110        writeln!(&mut self.output)?;
111
112        Ok(())
113    }
114
115    fn generate_custom_type(&mut self, custom_type: &CustomType<'_>) -> Result<()> {
116        match custom_type {
117            CustomType::Object(obj) => self.generate_custom_object(obj),
118            CustomType::Enum(enum_type) => self.generate_custom_enum(enum_type),
119        }
120    }
121
122    fn generate_custom_object(&mut self, obj: &CustomObject<'_>) -> Result<()> {
123        // Add comments.
124        for comment in obj.comments() {
125            self.writeln(&format!("/// {}", comment.text()))?;
126        }
127
128        self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
129        self.writeln(&format!("pub struct {} {{", obj.name().to_pascal_case()))?;
130        self.indent();
131
132        for field in obj.fields() {
133            self.generate_field(field)?;
134        }
135
136        self.dedent();
137        self.writeln("}")?;
138
139        Ok(())
140    }
141
142    fn generate_custom_enum(&mut self, enum_type: &CustomEnum<'_>) -> Result<()> {
143        // Add comments.
144        for comment in enum_type.comments() {
145            self.writeln(&format!("/// {}", comment.text()))?;
146        }
147
148        self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
149        self.writeln("#[serde(rename_all = \"snake_case\")]")?;
150        self.writeln(&format!(
151            "pub enum {} {{",
152            enum_type.name().to_pascal_case()
153        ))?;
154        self.indent();
155
156        for variant in enum_type.variants() {
157            // Add variant comments.
158            for comment in variant.comments() {
159                self.writeln(&format!("/// {}", comment.text()))?;
160            }
161
162            // Varlink enum variants don't have explicit values, just names.
163            self.writeln(&format!("{},", variant.name().to_pascal_case()))?;
164        }
165
166        self.dedent();
167        self.writeln("}")?;
168
169        Ok(())
170    }
171
172    fn generate_field(&mut self, field: &Field<'_>) -> Result<()> {
173        // Add field comments.
174        for comment in field.comments() {
175            self.writeln(&format!("/// {}", comment.text()))?;
176        }
177
178        let field_name = field.name().to_snake_case();
179        let rust_type = self.type_to_rust(field.ty())?;
180
181        // Check if the field type is optional.
182        let rust_type = if matches!(field.ty(), Type::Optional(_)) {
183            // The type_to_rust will already wrap in Option
184            rust_type
185        } else {
186            rust_type
187        };
188
189        // Handle field name if it's a Rust keyword.
190        let field_name_attr = if is_rust_keyword(&field_name) || field_name != field.name() {
191            format!("#[serde(rename = \"{}\")]", field.name())
192        } else {
193            String::new()
194        };
195
196        if !field_name_attr.is_empty() {
197            self.writeln(&field_name_attr)?;
198        }
199
200        let safe_field_name = if is_rust_keyword(&field_name) {
201            format!("r#{}", field_name)
202        } else {
203            field_name
204        };
205
206        self.writeln(&format!("pub {}: {},", safe_field_name, rust_type))?;
207
208        Ok(())
209    }
210
211    fn generate_errors(&mut self, interface: &Interface<'_>) -> Result<()> {
212        self.writeln("/// Errors that can occur in this interface.")?;
213        self.writeln("#[derive(Debug, Clone, PartialEq, ReplyError)]")?;
214        self.writeln(&format!("#[zlink(interface = \"{}\")]", interface.name()))?;
215        self.writeln(&format!(
216            "pub enum {}Error {{",
217            interface_name_to_rust(interface.name())
218        ))?;
219        self.indent();
220
221        for error in interface.errors() {
222            // Add error comments.
223            for comment in error.comments() {
224                self.writeln(&format!("/// {}", comment.text()))?;
225            }
226
227            let variant_name = error.name().to_pascal_case();
228            if error.fields().count() == 0 {
229                self.writeln(&format!("{},", variant_name))?;
230            } else {
231                self.writeln(&format!("{} {{", variant_name))?;
232                self.indent();
233                for field in error.fields() {
234                    self.generate_error_field(field)?;
235                }
236                self.dedent();
237                self.writeln("},")?;
238            }
239        }
240
241        self.dedent();
242        self.writeln("}")?;
243
244        Ok(())
245    }
246
247    /// Generate output structs for all methods in the `interface`.
248    fn generate_output_structs(&mut self, interface: &Interface<'_>) -> Result<()> {
249        for method in interface.methods() {
250            // Generate output struct for any method with at least one output parameter.
251            // Varlink output parameters are always named, so we need a struct even for single
252            // outputs.
253            if method.outputs().count() > 0 {
254                let struct_name = format!("{}Output", method.name().to_pascal_case());
255
256                // Add method comments if available
257                self.writeln(&format!(
258                    "/// Output parameters for the {} method.",
259                    method.name()
260                ))?;
261
262                // Add lifetime parameter for output structs that need it
263                let needs_lifetime = method.outputs().any(|o| type_needs_lifetime(o.ty()));
264
265                self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
266                if needs_lifetime {
267                    self.writeln(&format!("pub struct {}<'a> {{", struct_name))?;
268                } else {
269                    self.writeln(&format!("pub struct {} {{", struct_name))?;
270                }
271                self.indent();
272
273                for output in method.outputs() {
274                    let field_name = output.name().to_snake_case();
275                    // Use reference types for output parameters where appropriate
276                    let rust_type = if needs_lifetime {
277                        self.type_to_rust_output(output.ty())?
278                    } else {
279                        self.type_to_rust(output.ty())?
280                    };
281
282                    // Add #[serde(borrow)] for fields that need it
283                    if needs_lifetime && type_needs_borrow(output.ty()) {
284                        self.writeln("#[serde(borrow)]")?;
285                    }
286
287                    if field_name != output.name() {
288                        self.writeln(&format!("#[serde(rename = \"{}\")]", output.name()))?;
289                    }
290
291                    let safe_field_name = if is_rust_keyword(&field_name) {
292                        format!("r#{}", field_name)
293                    } else {
294                        field_name
295                    };
296
297                    self.writeln(&format!("pub {}: {},", safe_field_name, rust_type))?;
298                }
299
300                self.dedent();
301                self.writeln("}")?;
302                self.writeln("")?;
303            }
304        }
305
306        Ok(())
307    }
308
309    fn generate_proxy_trait(&mut self, interface: &Interface<'_>) -> Result<()> {
310        let trait_name = interface_name_to_rust(interface.name());
311
312        // Generate a stub error type if there are no errors in the interface
313        let error_type = if interface.errors().count() > 0 {
314            format!("{}Error", interface_name_to_rust(interface.name()))
315        } else {
316            // Generate a stub error type for interfaces without errors
317            let stub_error_name = format!("{}Error", interface_name_to_rust(interface.name()));
318
319            // Generate the stub error type before the proxy trait
320            self.writeln("/// Stub error type for interface without errors.")?;
321            self.writeln("///")?;
322            self.writeln("/// This is an empty enum that can never be instantiated.")?;
323            self.writeln("/// It exists only to satisfy the proxy trait requirements.")?;
324            self.writeln("#[derive(Debug, Clone, PartialEq, ReplyError)]")?;
325            self.writeln(&format!("#[zlink(interface = \"{}\")]", interface.name()))?;
326            self.writeln(&format!("pub enum {} {{}}", stub_error_name))?;
327            self.writeln("")?;
328
329            stub_error_name
330        };
331
332        self.writeln("/// Proxy trait for calling methods on the interface.")?;
333        self.writeln(&format!("#[proxy(\"{}\")]", interface.name()))?;
334        self.writeln(&format!("pub trait {} {{", trait_name))?;
335        self.indent();
336
337        for method in interface.methods() {
338            self.generate_proxy_method_signature(method, &error_type)?;
339        }
340
341        self.dedent();
342        self.writeln("}")?;
343
344        Ok(())
345    }
346
347    fn generate_proxy_method_signature(
348        &mut self,
349        method: &Method<'_>,
350        error_type: &str,
351    ) -> Result<()> {
352        // Add method comments.
353        for comment in method.comments() {
354            self.writeln(&format!("/// {}", comment.text()))?;
355        }
356
357        let method_name = method.name().to_snake_case();
358        let safe_method_name = if is_rust_keyword(&method_name) {
359            format!("r#{}", method_name)
360        } else {
361            method_name
362        };
363
364        // Generate method signature.
365        let mut signature = format!("async fn {}(&mut self", safe_method_name);
366
367        // Add input parameters.
368        for param in method.inputs() {
369            let param_name = param.name().to_snake_case();
370            let safe_param_name = if is_rust_keyword(&param_name) {
371                format!("r#{}", param_name)
372            } else {
373                param_name
374            };
375            // Use references for parameters that can be borrowed
376            let rust_type = self.type_to_rust_param(param.ty())?;
377
378            write!(&mut signature, ",")?;
379            // Add parameter with potential rename attribute.
380            if safe_param_name != param.name() {
381                write!(&mut signature, " #[zlink(rename = \"{}\")]", param.name(),)?;
382            }
383
384            write!(&mut signature, " {}: {}", safe_param_name, rust_type)?;
385        }
386
387        signature.push_str(") -> zlink::Result<Result<");
388
389        // Handle output parameters.
390        let output_count = method.outputs().count();
391        if output_count == 0 {
392            signature.push_str("()");
393        } else {
394            // Always use the generated output struct for any outputs.
395            // Varlink output parameters are always named, so we need a struct even for single
396            // outputs.
397            let struct_name = format!("{}Output", method.name().to_pascal_case());
398            // Add lifetime parameter if the struct needs one
399            let needs_lifetime = method.outputs().any(|o| type_needs_lifetime(o.ty()));
400            if needs_lifetime {
401                signature.push_str(&format!("{}<'_>", struct_name));
402            } else {
403                signature.push_str(&struct_name);
404            }
405        }
406
407        write!(&mut signature, ", {}>>", error_type)?;
408        signature.push(';');
409
410        self.writeln(&signature)?;
411
412        Ok(())
413    }
414
415    fn generate_error_field(&mut self, field: &Field<'_>) -> Result<()> {
416        // Add field comments.
417        for comment in field.comments() {
418            self.writeln(&format!("/// {}", comment.text()))?;
419        }
420
421        let field_name = field.name().to_snake_case();
422        let rust_type = self.type_to_rust(field.ty())?;
423
424        // Handle field name if it's a Rust keyword.
425        let field_name_attr = if is_rust_keyword(&field_name) || field_name != field.name() {
426            format!("#[zlink(rename = \"{}\")]", field.name())
427        } else {
428            String::new()
429        };
430
431        if !field_name_attr.is_empty() {
432            self.writeln(&field_name_attr)?;
433        }
434
435        let safe_field_name = if is_rust_keyword(&field_name) {
436            format!("r#{}", field_name)
437        } else {
438            field_name
439        };
440
441        self.writeln(&format!("{}: {},", safe_field_name, rust_type))?;
442
443        Ok(())
444    }
445
446    fn type_to_rust(&self, ty: &Type) -> Result<String> {
447        type_to_rust(ty)
448    }
449
450    fn type_to_rust_param(&self, ty: &Type) -> Result<String> {
451        type_to_rust_param(ty)
452    }
453
454    fn type_to_rust_output(&self, ty: &Type) -> Result<String> {
455        type_to_rust_output(ty)
456    }
457
458    fn writeln(&mut self, s: &str) -> Result<()> {
459        self.write(s)?;
460        writeln!(&mut self.output)?;
461        Ok(())
462    }
463
464    fn write(&mut self, s: &str) -> Result<()> {
465        for _ in 0..self.indent_level {
466            write!(&mut self.output, "    ")?;
467        }
468        write!(&mut self.output, "{}", s)?;
469        Ok(())
470    }
471
472    fn indent(&mut self) {
473        self.indent_level += 1;
474    }
475
476    fn dedent(&mut self) {
477        if self.indent_level > 0 {
478            self.indent_level -= 1;
479        }
480    }
481}
482
483impl Default for CodeGenerator {
484    fn default() -> Self {
485        Self::new()
486    }
487}
488
489fn type_to_rust(ty: &Type) -> Result<String> {
490    Ok(match ty {
491        Type::Bool => "bool".to_string(),
492        Type::Int => "i64".to_string(),
493        Type::Float => "f64".to_string(),
494        Type::String => "String".to_string(),
495        Type::Object(_fields) => {
496            // Anonymous struct - generate inline.
497            // For now, use serde_json::Value for anonymous objects.
498            // In the future, we could generate anonymous structs.
499            "serde_json::Value".to_string()
500        }
501        Type::Enum(_variants) => {
502            // Anonymous enum - use String for now.
503            "String".to_string()
504        }
505        Type::Array(elem_type) => {
506            let elem_rust = type_to_rust(elem_type.inner())?;
507            format!("Vec<{}>", elem_rust)
508        }
509        Type::Map(value_type) => {
510            let value_rust = type_to_rust(value_type.inner())?;
511            format!("std::collections::HashMap<String, {}>", value_rust)
512        }
513        Type::ForeignObject => "serde_json::Value".to_string(),
514        Type::Optional(inner_type) => {
515            let inner_rust = type_to_rust(inner_type.inner())?;
516            format!("Option<{}>", inner_rust)
517        }
518        Type::Custom(name) => name.to_pascal_case(),
519        Type::Any => "serde_json::Value".to_string(),
520    })
521}
522
523fn type_to_rust_param(ty: &Type) -> Result<String> {
524    Ok(match ty {
525        Type::Bool => "bool".to_string(),
526        Type::Int => "i64".to_string(),
527        Type::Float => "f64".to_string(),
528        Type::String => "&str".to_string(),
529        Type::Object(_fields) => {
530            // For parameters, use reference to avoid clone
531            "&serde_json::Value".to_string()
532        }
533        Type::Enum(_variants) => {
534            // Anonymous enum - use &str for parameters
535            "&str".to_string()
536        }
537        Type::Array(elem_type) => {
538            // Use slice for array parameters with proper string handling
539            let elem_rust = type_to_rust_param_elem(elem_type.inner())?;
540            format!("&[{}]", elem_rust)
541        }
542        Type::Map(value_type) => {
543            // Use reference for map parameters with proper string handling
544            let value_rust = type_to_rust_param_elem(value_type.inner())?;
545            format!("&std::collections::HashMap<&str, {}>", value_rust)
546        }
547        Type::ForeignObject => "&serde_json::Value".to_string(),
548        Type::Optional(inner_type) => {
549            let inner_rust = type_to_rust_param(inner_type.inner())?;
550            // For optional parameters, always wrap in Option
551            format!("Option<{}>", inner_rust)
552        }
553        Type::Custom(name) => format!("&{}", name.to_pascal_case()),
554        Type::Any => "&serde_json::Value".to_string(),
555    })
556}
557
558// Helper function to get the proper type for collection elements in parameters.
559// Ensures strings always use &str instead of String.
560fn type_to_rust_param_elem(ty: &Type) -> Result<String> {
561    Ok(match ty {
562        Type::Bool => "bool".to_string(),
563        Type::Int => "i64".to_string(),
564        Type::Float => "f64".to_string(),
565        Type::String => "&str".to_string(),
566        Type::Object(_fields) => "serde_json::Value".to_string(),
567        Type::Enum(_variants) => "&str".to_string(),
568        Type::Array(elem_type) => {
569            let elem_rust = type_to_rust_param_elem(elem_type.inner())?;
570            format!("Vec<{}>", elem_rust)
571        }
572        Type::Map(value_type) => {
573            let value_rust = type_to_rust_param_elem(value_type.inner())?;
574            format!("std::collections::HashMap<&str, {}>", value_rust)
575        }
576        Type::ForeignObject => "serde_json::Value".to_string(),
577        Type::Any => "serde_json::Value".to_string(),
578        Type::Optional(inner_type) => {
579            let inner_rust = type_to_rust_param_elem(inner_type.inner())?;
580            format!("Option<{}>", inner_rust)
581        }
582        Type::Custom(name) => name.to_pascal_case(),
583    })
584}
585
586fn type_to_rust_output(ty: &Type) -> Result<String> {
587    Ok(match ty {
588        Type::Bool => "bool".to_string(),
589        Type::Int => "i64".to_string(),
590        Type::Float => "f64".to_string(),
591        Type::String => "&'a str".to_string(),
592        Type::Object(_fields) => {
593            // Use owned type for objects - serde can't deserialize to &Value
594            "serde_json::Value".to_string()
595        }
596        Type::Enum(_variants) => {
597            // Anonymous enum - use &str for outputs
598            "&'a str".to_string()
599        }
600        Type::Array(elem_type) => {
601            // Use Vec for array outputs with owned inner types (except strings stay as &'a str)
602            let elem_rust = match elem_type.inner() {
603                Type::String => "&'a str".to_string(),
604                Type::Enum(_) => "&'a str".to_string(),
605                _ => type_to_rust(elem_type.inner())?,
606            };
607            format!("Vec<{}>", elem_rust)
608        }
609        Type::Map(value_type) => {
610            // Use HashMap for map outputs with borrowed types for efficiency
611            let value_rust = match value_type.inner() {
612                Type::String => "&'a str".to_string(),
613                Type::Enum(_) => "&'a str".to_string(),
614                _ => type_to_rust(value_type.inner())?,
615            };
616            format!("std::collections::HashMap<&'a str, {}>", value_rust)
617        }
618        Type::ForeignObject => "serde_json::Value".to_string(),
619        Type::Any => "serde_json::Value".to_string(),
620        Type::Optional(inner_type) => {
621            // For optional outputs, recursively apply type_to_rust_output to maintain
622            // correct reference types for strings within collections
623            let inner_rust = type_to_rust_output(inner_type.inner())?;
624            format!("Option<{}>", inner_rust)
625        }
626        Type::Custom(name) => name.to_pascal_case(),
627    })
628}
629
630fn interface_name_to_rust(name: &str) -> String {
631    // Convert interface name like "org.example.Interface" to "Interface".
632    name.split('.').next_back().unwrap_or(name).to_pascal_case()
633}
634
635fn type_needs_lifetime(ty: &Type) -> bool {
636    match ty {
637        Type::String => true,
638        Type::Enum(_) => true, // Anonymous enums use &'a str
639        Type::Array(inner) => type_needs_lifetime(inner.inner()),
640        Type::Map(_) => {
641            // Maps always need lifetime because keys are &'a str
642            true
643        }
644        Type::Optional(inner) => type_needs_lifetime(inner.inner()),
645        _ => false,
646    }
647}
648
649fn type_needs_borrow(ty: &Type) -> bool {
650    match ty {
651        Type::String => true,
652        Type::Enum(_) => true, // Anonymous enums use &'a str
653        Type::Array(inner) => type_needs_borrow(inner.inner()),
654        Type::Map(_) => {
655            // Maps always need borrow because keys are &'a str
656            true
657        }
658        Type::Optional(inner) => type_needs_borrow(inner.inner()),
659        _ => false,
660    }
661}
662
/// Returns `true` if `s` cannot be used as a bare Rust identifier and must be escaped.
///
/// Covers the strict keywords plus the reserved keywords (`abstract`, `final`, `try`, `gen`,
/// ...), all of which the compiler rejects as plain identifiers. Callers escape matches as raw
/// identifiers (`r#name`), which is valid for every entry except the path keywords below.
/// Weak keywords such as `union` are ordinary identifiers and are intentionally not listed.
///
/// NOTE(review): `self`, `Self`, `super` and `crate` are listed but cannot be written as raw
/// identifiers, so an IDL name equal to one of them still yields uncompilable code.
fn is_rust_keyword(s: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        // Strict keywords.
        "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
        "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
        "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait",
        "true", "type", "unsafe", "use", "where", "while",
        // Reserved keywords (unused by the language today but rejected as identifiers).
        "abstract", "become", "box", "do", "final", "gen", "macro", "override", "priv", "try",
        "typeof", "unsized", "virtual", "yield",
    ];
    KEYWORDS.contains(&s)
}