zlink_codegen/
codegen.rs

1//! Code generation implementation.
2
3use anyhow::Result;
4use heck::{ToPascalCase, ToSnakeCase};
5use std::fmt::Write;
6use zlink::idl::{CustomEnum, CustomObject, CustomType, Field, Interface, Method, Type};
7
8/// Code generator for Varlink interfaces.
/// Code generator for Varlink interfaces.
///
/// Generated Rust source is accumulated in an in-memory buffer and handed out by
/// [`CodeGenerator::output`], which consumes the generator.
pub struct CodeGenerator {
    /// Generated source text accumulated so far.
    output: String,
    /// Current nesting depth; each level is rendered as four spaces by the line writers.
    indent_level: usize,
}
13
14impl CodeGenerator {
15    /// Create a new code generator.
16    pub fn new() -> Self {
17        Self {
18            output: String::new(),
19            indent_level: 0,
20        }
21    }
22
23    /// Get the generated output.
24    pub fn output(self) -> String {
25        self.output
26    }
27
28    /// Write module-level header for multiple interfaces.
29    pub fn write_module_header(&mut self) -> Result<()> {
30        writeln!(
31            &mut self.output,
32            "// Generated code from Varlink IDL files."
33        )?;
34        writeln!(&mut self.output)?;
35        writeln!(&mut self.output, "use serde::{{Deserialize, Serialize}};")?;
36        writeln!(&mut self.output, "use zlink::{{proxy, ReplyError}};")?;
37        writeln!(&mut self.output)?;
38        Ok(())
39    }
40
41    /// Generate code for an interface.
42    pub fn generate_interface(
43        &mut self,
44        interface: &Interface<'_>,
45        skip_module_header: bool,
46    ) -> Result<()> {
47        if skip_module_header {
48            self.write_interface_comment(interface)?;
49        } else {
50            self.write_header(interface)?;
51            self.writeln("use serde::{Deserialize, Serialize};")?;
52            // Always import ReplyError since we generate a stub error type when there are no errors
53            self.writeln("use zlink::{proxy, ReplyError};")?;
54            self.writeln("")?;
55        }
56
57        // Generate proxy trait using the proxy macro.
58        self.generate_proxy_trait(interface)?;
59        self.writeln("")?;
60
61        // Generate output structs for methods.
62        self.generate_output_structs(interface)?;
63
64        // Generate custom types.
65        for custom_type in interface.custom_types() {
66            self.generate_custom_type(custom_type)?;
67            self.writeln("")?;
68        }
69
70        // Generate errors.
71        if interface.errors().count() > 0 {
72            self.generate_errors(interface)?;
73            self.writeln("")?;
74        }
75
76        Ok(())
77    }
78
79    fn write_interface_comment(&mut self, interface: &Interface<'_>) -> Result<()> {
80        writeln!(
81            &mut self.output,
82            "// Generated code for Varlink interface `{}`.",
83            interface.name()
84        )?;
85        writeln!(&mut self.output)?;
86        Ok(())
87    }
88
89    fn write_header(&mut self, interface: &Interface<'_>) -> Result<()> {
90        writeln!(
91            &mut self.output,
92            "//! Generated code for Varlink interface `{}`.",
93            interface.name()
94        )?;
95        writeln!(&mut self.output, "//!",)?;
96        writeln!(
97            &mut self.output,
98            "//! This code was generated by `zlink-codegen` from Varlink IDL.",
99        )?;
100        writeln!(
101            &mut self.output,
102            "//! You may prefer to adapt it, instead of using it verbatim.",
103        )?;
104        writeln!(&mut self.output)?;
105
106        // Add interface comments if any.
107        for comment in interface.comments() {
108            writeln!(&mut self.output, "//! {}", comment.text())?;
109        }
110        writeln!(&mut self.output)?;
111
112        Ok(())
113    }
114
115    fn generate_custom_type(&mut self, custom_type: &CustomType<'_>) -> Result<()> {
116        match custom_type {
117            CustomType::Object(obj) => self.generate_custom_object(obj),
118            CustomType::Enum(enum_type) => self.generate_custom_enum(enum_type),
119        }
120    }
121
122    fn generate_custom_object(&mut self, obj: &CustomObject<'_>) -> Result<()> {
123        // Add comments.
124        for comment in obj.comments() {
125            self.writeln(&format!("/// {}", comment.text()))?;
126        }
127
128        self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
129        self.writeln(&format!("pub struct {} {{", obj.name().to_pascal_case()))?;
130        self.indent();
131
132        for field in obj.fields() {
133            self.generate_field(field)?;
134        }
135
136        self.dedent();
137        self.writeln("}")?;
138
139        Ok(())
140    }
141
142    fn generate_custom_enum(&mut self, enum_type: &CustomEnum<'_>) -> Result<()> {
143        // Add comments.
144        for comment in enum_type.comments() {
145            self.writeln(&format!("/// {}", comment.text()))?;
146        }
147
148        self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
149        self.writeln("#[serde(rename_all = \"snake_case\")]")?;
150        self.writeln(&format!(
151            "pub enum {} {{",
152            enum_type.name().to_pascal_case()
153        ))?;
154        self.indent();
155
156        for variant in enum_type.variants() {
157            // Add variant comments.
158            for comment in variant.comments() {
159                self.writeln(&format!("/// {}", comment.text()))?;
160            }
161
162            // Varlink enum variants don't have explicit values, just names.
163            self.writeln(&format!("{},", variant.name().to_pascal_case()))?;
164        }
165
166        self.dedent();
167        self.writeln("}")?;
168
169        Ok(())
170    }
171
172    fn generate_field(&mut self, field: &Field<'_>) -> Result<()> {
173        // Add field comments.
174        for comment in field.comments() {
175            self.writeln(&format!("/// {}", comment.text()))?;
176        }
177
178        let field_name = field.name().to_snake_case();
179        let rust_type = self.type_to_rust(field.ty())?;
180
181        // Check if the field type is optional.
182        let rust_type = if matches!(field.ty(), Type::Optional(_)) {
183            // The type_to_rust will already wrap in Option
184            rust_type
185        } else {
186            rust_type
187        };
188
189        // Handle field name if it's a Rust keyword.
190        let field_name_attr = if is_rust_keyword(&field_name) || field_name != field.name() {
191            format!("#[serde(rename = \"{}\")]", field.name())
192        } else {
193            String::new()
194        };
195
196        if !field_name_attr.is_empty() {
197            self.writeln(&field_name_attr)?;
198        }
199
200        let safe_field_name = if is_rust_keyword(&field_name) {
201            format!("r#{}", field_name)
202        } else {
203            field_name
204        };
205
206        self.writeln(&format!("pub {}: {},", safe_field_name, rust_type))?;
207
208        Ok(())
209    }
210
211    fn generate_errors(&mut self, interface: &Interface<'_>) -> Result<()> {
212        self.writeln("/// Errors that can occur in this interface.")?;
213        self.writeln("#[derive(Debug, Clone, PartialEq, ReplyError)]")?;
214        self.writeln(&format!("#[zlink(interface = \"{}\")]", interface.name()))?;
215        self.writeln(&format!(
216            "pub enum {}Error {{",
217            interface_name_to_rust(interface.name())
218        ))?;
219        self.indent();
220
221        for error in interface.errors() {
222            // Add error comments.
223            for comment in error.comments() {
224                self.writeln(&format!("/// {}", comment.text()))?;
225            }
226
227            let variant_name = error.name().to_pascal_case();
228            if error.fields().count() == 0 {
229                self.writeln(&format!("{},", variant_name))?;
230            } else {
231                self.writeln(&format!("{} {{", variant_name))?;
232                self.indent();
233                for field in error.fields() {
234                    self.generate_error_field(field)?;
235                }
236                self.dedent();
237                self.writeln("},")?;
238            }
239        }
240
241        self.dedent();
242        self.writeln("}")?;
243
244        Ok(())
245    }
246
247    /// Generate output structs for all methods in the `interface`.
248    fn generate_output_structs(&mut self, interface: &Interface<'_>) -> Result<()> {
249        for method in interface.methods() {
250            // Generate output struct for any method with at least one output parameter.
251            // Varlink output parameters are always named, so we need a struct even for single
252            // outputs.
253            if method.outputs().count() > 0 {
254                let struct_name = format!("{}Output", method.name().to_pascal_case());
255
256                // Add method comments if available
257                self.writeln(&format!(
258                    "/// Output parameters for the {} method.",
259                    method.name()
260                ))?;
261
262                // Add lifetime parameter for output structs that need it
263                let needs_lifetime = method.outputs().any(|o| type_needs_lifetime(o.ty()));
264
265                self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
266                if needs_lifetime {
267                    self.writeln(&format!("pub struct {}<'a> {{", struct_name))?;
268                } else {
269                    self.writeln(&format!("pub struct {} {{", struct_name))?;
270                }
271                self.indent();
272
273                for output in method.outputs() {
274                    let field_name = output.name().to_snake_case();
275                    // Use reference types for output parameters where appropriate
276                    let rust_type = if needs_lifetime {
277                        self.type_to_rust_output(output.ty())?
278                    } else {
279                        self.type_to_rust(output.ty())?
280                    };
281
282                    // Add #[serde(borrow)] for fields that need it
283                    if needs_lifetime && type_needs_borrow(output.ty()) {
284                        self.writeln("#[serde(borrow)]")?;
285                    }
286
287                    if field_name != output.name() {
288                        self.writeln(&format!("#[serde(rename = \"{}\")]", output.name()))?;
289                    }
290
291                    let safe_field_name = if is_rust_keyword(&field_name) {
292                        format!("r#{}", field_name)
293                    } else {
294                        field_name
295                    };
296
297                    self.writeln(&format!("pub {}: {},", safe_field_name, rust_type))?;
298                }
299
300                self.dedent();
301                self.writeln("}")?;
302                self.writeln("")?;
303            }
304        }
305
306        Ok(())
307    }
308
309    fn generate_proxy_trait(&mut self, interface: &Interface<'_>) -> Result<()> {
310        let trait_name = interface_name_to_rust(interface.name());
311
312        // Generate a stub error type if there are no errors in the interface
313        let error_type = if interface.errors().count() > 0 {
314            format!("{}Error", interface_name_to_rust(interface.name()))
315        } else {
316            // Generate a stub error type for interfaces without errors
317            let stub_error_name = format!("{}Error", interface_name_to_rust(interface.name()));
318
319            // Generate the stub error type before the proxy trait
320            self.writeln("/// Stub error type for interface without errors.")?;
321            self.writeln("///")?;
322            self.writeln("/// This is an empty enum that can never be instantiated.")?;
323            self.writeln("/// It exists only to satisfy the proxy trait requirements.")?;
324            self.writeln("#[derive(Debug, Clone, PartialEq, ReplyError)]")?;
325            self.writeln(&format!("#[zlink(interface = \"{}\")]", interface.name()))?;
326            self.writeln(&format!("pub enum {} {{}}", stub_error_name))?;
327            self.writeln("")?;
328
329            stub_error_name
330        };
331
332        self.writeln("/// Proxy trait for calling methods on the interface.")?;
333        self.writeln(&format!("#[proxy(\"{}\")]", interface.name()))?;
334        self.writeln(&format!("pub trait {} {{", trait_name))?;
335        self.indent();
336
337        for method in interface.methods() {
338            self.generate_proxy_method_signature(method, &error_type)?;
339        }
340
341        self.dedent();
342        self.writeln("}")?;
343
344        Ok(())
345    }
346
347    fn generate_proxy_method_signature(
348        &mut self,
349        method: &Method<'_>,
350        error_type: &str,
351    ) -> Result<()> {
352        // Add method comments.
353        for comment in method.comments() {
354            self.writeln(&format!("/// {}", comment.text()))?;
355        }
356
357        let method_name = method.name().to_snake_case();
358        let safe_method_name = if is_rust_keyword(&method_name) {
359            format!("r#{}", method_name)
360        } else {
361            method_name
362        };
363
364        // Generate method signature.
365        let mut signature = format!("async fn {}(&mut self", safe_method_name);
366
367        // Add input parameters.
368        for param in method.inputs() {
369            let param_name = param.name().to_snake_case();
370            let safe_param_name = if is_rust_keyword(&param_name) {
371                format!("r#{}", param_name)
372            } else {
373                param_name
374            };
375            // Use references for parameters that can be borrowed
376            let rust_type = self.type_to_rust_param(param.ty())?;
377
378            write!(&mut signature, ",")?;
379            // Add parameter with potential rename attribute.
380            if safe_param_name != param.name() {
381                write!(&mut signature, " #[zlink(rename = \"{}\")]", param.name(),)?;
382            }
383
384            write!(&mut signature, " {}: {}", safe_param_name, rust_type)?;
385        }
386
387        signature.push_str(") -> zlink::Result<Result<");
388
389        // Handle output parameters.
390        let output_count = method.outputs().count();
391        if output_count == 0 {
392            signature.push_str("()");
393        } else {
394            // Always use the generated output struct for any outputs.
395            // Varlink output parameters are always named, so we need a struct even for single
396            // outputs.
397            let struct_name = format!("{}Output", method.name().to_pascal_case());
398            // Add lifetime parameter if the struct needs one
399            let needs_lifetime = method.outputs().any(|o| type_needs_lifetime(o.ty()));
400            if needs_lifetime {
401                signature.push_str(&format!("{}<'_>", struct_name));
402            } else {
403                signature.push_str(&struct_name);
404            }
405        }
406
407        write!(&mut signature, ", {}>>", error_type)?;
408        signature.push(';');
409
410        self.writeln(&signature)?;
411
412        Ok(())
413    }
414
415    fn generate_error_field(&mut self, field: &Field<'_>) -> Result<()> {
416        // Add field comments.
417        for comment in field.comments() {
418            self.writeln(&format!("/// {}", comment.text()))?;
419        }
420
421        let field_name = field.name().to_snake_case();
422        let rust_type = self.type_to_rust(field.ty())?;
423
424        // Handle field name if it's a Rust keyword.
425        let field_name_attr = if is_rust_keyword(&field_name) || field_name != field.name() {
426            format!("#[zlink(rename = \"{}\")]", field.name())
427        } else {
428            String::new()
429        };
430
431        if !field_name_attr.is_empty() {
432            self.writeln(&field_name_attr)?;
433        }
434
435        let safe_field_name = if is_rust_keyword(&field_name) {
436            format!("r#{}", field_name)
437        } else {
438            field_name
439        };
440
441        self.writeln(&format!("{}: {},", safe_field_name, rust_type))?;
442
443        Ok(())
444    }
445
446    fn type_to_rust(&self, ty: &Type) -> Result<String> {
447        type_to_rust(ty)
448    }
449
450    fn type_to_rust_param(&self, ty: &Type) -> Result<String> {
451        type_to_rust_param(ty)
452    }
453
454    fn type_to_rust_output(&self, ty: &Type) -> Result<String> {
455        type_to_rust_output(ty)
456    }
457
458    fn writeln(&mut self, s: &str) -> Result<()> {
459        self.write(s)?;
460        writeln!(&mut self.output)?;
461        Ok(())
462    }
463
464    fn write(&mut self, s: &str) -> Result<()> {
465        for _ in 0..self.indent_level {
466            write!(&mut self.output, "    ")?;
467        }
468        write!(&mut self.output, "{}", s)?;
469        Ok(())
470    }
471
472    fn indent(&mut self) {
473        self.indent_level += 1;
474    }
475
476    fn dedent(&mut self) {
477        if self.indent_level > 0 {
478            self.indent_level -= 1;
479        }
480    }
481}
482
483impl Default for CodeGenerator {
484    fn default() -> Self {
485        Self::new()
486    }
487}
488
489fn type_to_rust(ty: &Type) -> Result<String> {
490    Ok(match ty {
491        Type::Bool => "bool".to_string(),
492        Type::Int => "i64".to_string(),
493        Type::Float => "f64".to_string(),
494        Type::String => "String".to_string(),
495        Type::Object(_fields) => {
496            // Anonymous struct - generate inline.
497            // For now, use serde_json::Value for anonymous objects.
498            // In the future, we could generate anonymous structs.
499            "serde_json::Value".to_string()
500        }
501        Type::Enum(_variants) => {
502            // Anonymous enum - use String for now.
503            "String".to_string()
504        }
505        Type::Array(elem_type) => {
506            let elem_rust = type_to_rust(elem_type.inner())?;
507            format!("Vec<{}>", elem_rust)
508        }
509        Type::Map(value_type) => {
510            let value_rust = type_to_rust(value_type.inner())?;
511            format!("std::collections::HashMap<String, {}>", value_rust)
512        }
513        Type::ForeignObject => "serde_json::Value".to_string(),
514        Type::Optional(inner_type) => {
515            let inner_rust = type_to_rust(inner_type.inner())?;
516            format!("Option<{}>", inner_rust)
517        }
518        Type::Custom(name) => name.to_pascal_case(),
519    })
520}
521
522fn type_to_rust_param(ty: &Type) -> Result<String> {
523    Ok(match ty {
524        Type::Bool => "bool".to_string(),
525        Type::Int => "i64".to_string(),
526        Type::Float => "f64".to_string(),
527        Type::String => "&str".to_string(),
528        Type::Object(_fields) => {
529            // For parameters, use reference to avoid clone
530            "&serde_json::Value".to_string()
531        }
532        Type::Enum(_variants) => {
533            // Anonymous enum - use &str for parameters
534            "&str".to_string()
535        }
536        Type::Array(elem_type) => {
537            // Use slice for array parameters with proper string handling
538            let elem_rust = type_to_rust_param_elem(elem_type.inner())?;
539            format!("&[{}]", elem_rust)
540        }
541        Type::Map(value_type) => {
542            // Use reference for map parameters with proper string handling
543            let value_rust = type_to_rust_param_elem(value_type.inner())?;
544            format!("&std::collections::HashMap<&str, {}>", value_rust)
545        }
546        Type::ForeignObject => "&serde_json::Value".to_string(),
547        Type::Optional(inner_type) => {
548            let inner_rust = type_to_rust_param(inner_type.inner())?;
549            // For optional parameters, always wrap in Option
550            format!("Option<{}>", inner_rust)
551        }
552        Type::Custom(name) => format!("&{}", name.to_pascal_case()),
553    })
554}
555
556// Helper function to get the proper type for collection elements in parameters.
557// Ensures strings always use &str instead of String.
558fn type_to_rust_param_elem(ty: &Type) -> Result<String> {
559    Ok(match ty {
560        Type::Bool => "bool".to_string(),
561        Type::Int => "i64".to_string(),
562        Type::Float => "f64".to_string(),
563        Type::String => "&str".to_string(),
564        Type::Object(_fields) => "serde_json::Value".to_string(),
565        Type::Enum(_variants) => "&str".to_string(),
566        Type::Array(elem_type) => {
567            let elem_rust = type_to_rust_param_elem(elem_type.inner())?;
568            format!("Vec<{}>", elem_rust)
569        }
570        Type::Map(value_type) => {
571            let value_rust = type_to_rust_param_elem(value_type.inner())?;
572            format!("std::collections::HashMap<&str, {}>", value_rust)
573        }
574        Type::ForeignObject => "serde_json::Value".to_string(),
575        Type::Optional(inner_type) => {
576            let inner_rust = type_to_rust_param_elem(inner_type.inner())?;
577            format!("Option<{}>", inner_rust)
578        }
579        Type::Custom(name) => name.to_pascal_case(),
580    })
581}
582
583fn type_to_rust_output(ty: &Type) -> Result<String> {
584    Ok(match ty {
585        Type::Bool => "bool".to_string(),
586        Type::Int => "i64".to_string(),
587        Type::Float => "f64".to_string(),
588        Type::String => "&'a str".to_string(),
589        Type::Object(_fields) => {
590            // Use owned type for objects - serde can't deserialize to &Value
591            "serde_json::Value".to_string()
592        }
593        Type::Enum(_variants) => {
594            // Anonymous enum - use &str for outputs
595            "&'a str".to_string()
596        }
597        Type::Array(elem_type) => {
598            // Use Vec for array outputs with owned inner types (except strings stay as &'a str)
599            let elem_rust = match elem_type.inner() {
600                Type::String => "&'a str".to_string(),
601                Type::Enum(_) => "&'a str".to_string(),
602                _ => type_to_rust(elem_type.inner())?,
603            };
604            format!("Vec<{}>", elem_rust)
605        }
606        Type::Map(value_type) => {
607            // Use HashMap for map outputs with borrowed types for efficiency
608            let value_rust = match value_type.inner() {
609                Type::String => "&'a str".to_string(),
610                Type::Enum(_) => "&'a str".to_string(),
611                _ => type_to_rust(value_type.inner())?,
612            };
613            format!("std::collections::HashMap<&'a str, {}>", value_rust)
614        }
615        Type::ForeignObject => "serde_json::Value".to_string(),
616        Type::Optional(inner_type) => {
617            // For optional outputs, recursively apply type_to_rust_output to maintain
618            // correct reference types for strings within collections
619            let inner_rust = type_to_rust_output(inner_type.inner())?;
620            format!("Option<{}>", inner_rust)
621        }
622        Type::Custom(name) => name.to_pascal_case(),
623    })
624}
625
626fn interface_name_to_rust(name: &str) -> String {
627    // Convert interface name like "org.example.Interface" to "Interface".
628    name.split('.').next_back().unwrap_or(name).to_pascal_case()
629}
630
631fn type_needs_lifetime(ty: &Type) -> bool {
632    match ty {
633        Type::String => true,
634        Type::Enum(_) => true, // Anonymous enums use &'a str
635        Type::Array(inner) => type_needs_lifetime(inner.inner()),
636        Type::Map(_) => {
637            // Maps always need lifetime because keys are &'a str
638            true
639        }
640        Type::Optional(inner) => type_needs_lifetime(inner.inner()),
641        _ => false,
642    }
643}
644
645fn type_needs_borrow(ty: &Type) -> bool {
646    match ty {
647        Type::String => true,
648        Type::Enum(_) => true, // Anonymous enums use &'a str
649        Type::Array(inner) => type_needs_borrow(inner.inner()),
650        Type::Map(_) => {
651            // Maps always need borrow because keys are &'a str
652            true
653        }
654        Type::Optional(inner) => type_needs_borrow(inner.inner()),
655        _ => false,
656    }
657}
658
/// Whether `s` cannot be used verbatim as an identifier in generated Rust code.
///
/// Covers the strict keywords and the words reserved for future use (`abstract`, `final`,
/// `try`, `yield`, ...), which the compiler also rejects as plain identifiers. The check is
/// case-sensitive (`Type` is fine, `type` is not). Callers escape a match, for example with
/// the raw-identifier prefix `r#`.
fn is_rust_keyword(s: &str) -> bool {
    matches!(
        s,
        // Strict keywords (including the 2018-edition `async`, `await` and `dyn`).
        "as" | "async"
            | "await"
            | "break"
            | "const"
            | "continue"
            | "crate"
            | "dyn"
            | "else"
            | "enum"
            | "extern"
            | "false"
            | "fn"
            | "for"
            | "if"
            | "impl"
            | "in"
            | "let"
            | "loop"
            | "match"
            | "mod"
            | "move"
            | "mut"
            | "pub"
            | "ref"
            | "return"
            | "self"
            | "Self"
            | "static"
            | "struct"
            | "super"
            | "trait"
            | "true"
            | "type"
            | "unsafe"
            | "use"
            | "where"
            | "while"
            // Reserved keywords: not usable as plain identifiers either.
            | "abstract"
            | "become"
            | "box"
            | "do"
            | "final"
            | "gen"
            | "macro"
            | "override"
            | "priv"
            | "try"
            | "typeof"
            | "unsized"
            | "virtual"
            | "yield"
    )
}