libvirt_codegen/
generator.rs

1//! Rust code generator from XDR AST.
2
3use crate::ast::*;
4use heck::{ToSnakeCase, ToUpperCamelCase};
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote};
7
8/// Generate Rust code from a protocol definition.
9pub fn generate(protocol: &Protocol) -> String {
10    let mut tokens = TokenStream::new();
11
12    // Generate prelude
13    tokens.extend(generate_prelude());
14
15    // Generate constants
16    for constant in &protocol.constants {
17        tokens.extend(generate_constant(constant));
18    }
19
20    // Generate types
21    for type_def in &protocol.types {
22        tokens.extend(generate_type(type_def));
23    }
24
25    // Generate RPC client methods
26    tokens.extend(generate_client_methods(&protocol.procedures));
27
28    // Format the output
29    let file = syn::parse2(tokens).expect("generated invalid Rust code");
30    prettyplease::unparse(&file)
31}
32
33fn generate_prelude() -> TokenStream {
34    // Note: This code is included into a submodule via include!(),
35    // so we cannot use inner attributes (like #![allow(...)]).
36    // The parent module should add the necessary attributes.
37    quote! {
38        // Generated code from libvirt protocol definition.
39        // Do not edit manually.
40
41        use serde::{Serialize, Deserialize};
42
43        // Well-known libvirt constants from libvirt.h
44        pub const VIR_UUID_BUFLEN: usize = 16;
45        pub const VIR_UUID_STRING_BUFLEN: usize = 37;
46
47        // Re-export fixed opaque type for UUID
48        pub use libvirt_xdr::opaque::FixedOpaque16;
49    }
50}
51fn generate_constant(constant: &Constant) -> TokenStream {
52    let name = format_ident!("{}", constant.name);
53
54    // Only generate constants with literal integer values.
55    // Skip constants that reference external symbols (like VIR_* from libvirt.h)
56    // since we don't have their definitions.
57    match &constant.value {
58        ConstValue::Int(n) => {
59            quote! {
60                pub const #name: i64 = #n;
61            }
62        }
63        ConstValue::Ident(_) => {
64            // Skip - references external constant we don't have
65            TokenStream::new()
66        }
67    }
68}
69
70fn generate_type(type_def: &TypeDef) -> TokenStream {
71    match type_def {
72        TypeDef::Struct(s) => generate_struct(s),
73        TypeDef::Enum(e) => generate_enum(e),
74        TypeDef::Union(u) => generate_union(u),
75        TypeDef::Typedef(t) => generate_typedef(t),
76    }
77}
78
79fn generate_struct(s: &StructDef) -> TokenStream {
80    let name = format_ident!("{}", to_rust_type_name(&s.name));
81
82    let fields: Vec<_> = s
83        .fields
84        .iter()
85        .map(|f| {
86            let field_name = format_ident!("{}", to_rust_field_name(&f.name));
87            let field_type = type_to_tokens(&f.ty);
88            quote! {
89                pub #field_name: #field_type
90            }
91        })
92        .collect();
93
94    quote! {
95        #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
96        pub struct #name {
97            #(#fields),*
98        }
99    }
100}
101
102fn generate_enum(e: &EnumDef) -> TokenStream {
103    let name = format_ident!("{}", to_rust_type_name(&e.name));
104
105    let variants: Vec<_> = e
106        .variants
107        .iter()
108        .filter_map(|v| {
109            let variant_name = format_ident!("{}", to_rust_variant_name(&v.name, &e.name));
110
111            match &v.value {
112                Some(ConstValue::Int(n)) => {
113                    let n = *n as i32;
114                    Some(quote! { #variant_name = #n })
115                }
116                Some(ConstValue::Ident(_)) => {
117                    // Skip variants that reference other constants
118                    None
119                }
120                None => Some(quote! { #variant_name }),
121            }
122        })
123        .collect();
124
125    quote! {
126        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
127        #[repr(i32)]
128        pub enum #name {
129            #(#variants),*
130        }
131    }
132}
133
134fn generate_union(u: &UnionDef) -> TokenStream {
135    let name = format_ident!("{}", to_rust_type_name(&u.name));
136
137    let variants: Vec<_> = u
138        .cases
139        .iter()
140        .filter_map(|case| {
141            let variant_name = match &case.values.first()? {
142                ConstValue::Int(n) => format_ident!("V{}", *n as u64),
143                ConstValue::Ident(s) => format_ident!("{}", to_rust_variant_name(s, &u.name)),
144            };
145
146            match &case.field {
147                Some(f) => {
148                    let field_type = type_to_tokens(&f.ty);
149                    Some(quote! { #variant_name(#field_type) })
150                }
151                None => Some(quote! { #variant_name }),
152            }
153        })
154        .collect();
155
156    quote! {
157        #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
158        pub enum #name {
159            #(#variants),*
160        }
161    }
162}
163
164fn generate_typedef(t: &TypedefDef) -> TokenStream {
165    let name = format_ident!("{}", to_rust_type_name(&t.name));
166    let target = type_to_tokens(&t.target);
167
168    quote! {
169        pub type #name = #target;
170    }
171}
172
173fn type_to_tokens(ty: &Type) -> TokenStream {
174    match ty {
175        Type::Void => quote! { () },
176        Type::Int => quote! { i32 },
177        Type::UInt => quote! { u32 },
178        Type::Hyper => quote! { i64 },
179        Type::UHyper => quote! { u64 },
180        Type::Float => quote! { f32 },
181        Type::Double => quote! { f64 },
182        Type::Bool => quote! { bool },
183        Type::String { .. } => quote! { String },
184        Type::Opaque { len } => match len {
185            LengthSpec::Fixed(n) => {
186                let n = *n as usize;
187                // Use FixedOpaque16 for 16-byte opaque (UUID) to handle XDR correctly
188                if n == 16 {
189                    quote! { FixedOpaque16 }
190                } else {
191                    quote! { [u8; #n] }
192                }
193            }
194            LengthSpec::Variable { .. } => quote! { Vec<u8> },
195        },
196        Type::Array { elem, len } => {
197            let elem_type = type_to_tokens(elem);
198            match len {
199                LengthSpec::Fixed(n) => {
200                    let n = *n as usize;
201                    quote! { [#elem_type; #n] }
202                }
203                LengthSpec::Variable { .. } => quote! { Vec<#elem_type> },
204            }
205        }
206        Type::Optional(inner) => {
207            let inner_type = type_to_tokens(inner);
208            quote! { Option<#inner_type> }
209        }
210        Type::Named(name) => {
211            let ident = format_ident!("{}", to_rust_type_name(name));
212            quote! { #ident }
213        }
214    }
215}
216
217/// Convert XDR type name to Rust type name (PascalCase).
218fn to_rust_type_name(name: &str) -> String {
219    // Preserve Rust primitive types as-is
220    match name {
221        "u8" | "u16" | "u32" | "u64" | "u128" | "usize" |
222        "i8" | "i16" | "i32" | "i64" | "i128" | "isize" |
223        "f32" | "f64" | "bool" | "char" | "str" | "String" => {
224            return name.to_string();
225        }
226        _ => {}
227    }
228
229    // Remove common prefixes
230    let name = name
231        .strip_prefix("remote_")
232        .or_else(|| name.strip_prefix("virNet"))
233        .unwrap_or(name);
234
235    let converted = name.to_upper_camel_case();
236
237    // Avoid collision with Rust standard types
238    match converted.as_str() {
239        "String" => "RemoteString".to_string(),
240        "Vec" => "RemoteVec".to_string(),
241        "Option" => "RemoteOption".to_string(),
242        "Box" => "RemoteBox".to_string(),
243        "Result" => "RemoteResult".to_string(),
244        _ => converted,
245    }
246}
247
248/// Convert XDR field name to Rust field name (snake_case).
249/// Handles Rust keywords by appending underscore.
250fn to_rust_field_name(name: &str) -> String {
251    let name = name.to_snake_case();
252
253    // Handle Rust keywords
254    match name.as_str() {
255        "type" => "r#type".to_string(),
256        "match" => "r#match".to_string(),
257        "ref" => "r#ref".to_string(),
258        "mod" => "r#mod".to_string(),
259        "fn" => "r#fn".to_string(),
260        "struct" => "r#struct".to_string(),
261        "enum" => "r#enum".to_string(),
262        "trait" => "r#trait".to_string(),
263        "impl" => "r#impl".to_string(),
264        "self" => "r#self".to_string(),
265        "super" => "r#super".to_string(),
266        "crate" => "r#crate".to_string(),
267        "use" => "r#use".to_string(),
268        "pub" => "r#pub".to_string(),
269        "in" => "r#in".to_string(),
270        "where" => "r#where".to_string(),
271        "async" => "r#async".to_string(),
272        "await" => "r#await".to_string(),
273        "dyn" => "r#dyn".to_string(),
274        "loop" => "r#loop".to_string(),
275        "move" => "r#move".to_string(),
276        "return" => "r#return".to_string(),
277        "static" => "r#static".to_string(),
278        "const" => "r#const".to_string(),
279        "unsafe" => "r#unsafe".to_string(),
280        "extern" => "r#extern".to_string(),
281        "let" => "r#let".to_string(),
282        "mut" => "r#mut".to_string(),
283        "if" => "r#if".to_string(),
284        "else" => "r#else".to_string(),
285        "for" => "r#for".to_string(),
286        "while" => "r#while".to_string(),
287        "break" => "r#break".to_string(),
288        "continue" => "r#continue".to_string(),
289        "as" => "r#as".to_string(),
290        "box" => "r#box".to_string(),
291        "priv" => "r#priv".to_string(),
292        "abstract" => "r#abstract".to_string(),
293        "final" => "r#final".to_string(),
294        "override" => "r#override".to_string(),
295        "virtual" => "r#virtual".to_string(),
296        "yield" => "r#yield".to_string(),
297        "become" => "r#become".to_string(),
298        "macro" => "r#macro".to_string(),
299        "typeof" => "r#typeof".to_string(),
300        "try" => "r#try".to_string(),
301        "union" => "r#union".to_string(),
302        _ => name,
303    }
304}
305
306/// Convert XDR enum variant name to Rust variant name.
307fn to_rust_variant_name(name: &str, enum_name: &str) -> String {
308    // Try to strip the enum name prefix
309    let name = name
310        .strip_prefix(&format!("{}_", enum_name.to_uppercase()))
311        .or_else(|| name.strip_prefix("REMOTE_"))
312        .or_else(|| name.strip_prefix("VIR_"))
313        .unwrap_or(name);
314
315    name.to_upper_camel_case()
316}
317
318/// Generate RPC client methods from procedure definitions.
319fn generate_client_methods(procedures: &[Procedure]) -> TokenStream {
320    let methods: Vec<_> = procedures
321        .iter()
322        .map(|proc| generate_client_method(proc))
323        .collect();
324
325    quote! {
326        /// Trait for making RPC calls to libvirt daemon.
327        /// This trait is implemented by the Connection type.
328        #[allow(async_fn_in_trait)]
329        pub trait LibvirtRpc {
330            /// Make an RPC call with the given procedure number and payload.
331            async fn rpc_call(&self, procedure: u32, payload: Vec<u8>) -> Result<Vec<u8>, RpcError>;
332        }
333
334        /// Error type for RPC operations.
335        #[derive(Debug)]
336        pub enum RpcError {
337            /// XDR encoding error
338            Encode(String),
339            /// XDR decoding error
340            Decode(String),
341            /// Transport/connection error
342            Transport(String),
343            /// Server returned an error
344            Server(Error),
345        }
346
347        impl std::fmt::Display for RpcError {
348            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
349                match self {
350                    RpcError::Encode(e) => write!(f, "XDR encode error: {}", e),
351                    RpcError::Decode(e) => write!(f, "XDR decode error: {}", e),
352                    RpcError::Transport(e) => write!(f, "Transport error: {}", e),
353                    RpcError::Server(e) => write!(f, "Server error: {:?}", e),
354                }
355            }
356        }
357
358        impl std::error::Error for RpcError {}
359
360        /// Generated RPC client methods for libvirt protocol.
361        pub struct GeneratedClient<T: LibvirtRpc> {
362            inner: T,
363        }
364
365        impl<T: LibvirtRpc> GeneratedClient<T> {
366            /// Create a new GeneratedClient wrapping an RPC transport.
367            pub fn new(inner: T) -> Self {
368                Self { inner }
369            }
370
371            /// Get a reference to the inner transport.
372            pub fn inner(&self) -> &T {
373                &self.inner
374            }
375
376            /// Get a mutable reference to the inner transport.
377            pub fn inner_mut(&mut self) -> &mut T {
378                &mut self.inner
379            }
380
381            #(#methods)*
382        }
383    }
384}
385
386/// Generate a single RPC method for a procedure.
387fn generate_client_method(proc: &Procedure) -> TokenStream {
388    // Convert REMOTE_PROC_CONNECT_LIST_DOMAINS to connect_list_domains
389    let method_name = proc
390        .name
391        .strip_prefix("REMOTE_PROC_")
392        .unwrap_or(&proc.name)
393        .to_lowercase();
394    let method_ident = format_ident!("{}", method_name);
395
396    // Convert to Procedure enum variant name: ProcConnectListDomains
397    let proc_variant = format_ident!(
398        "Proc{}",
399        proc.name
400            .strip_prefix("REMOTE_PROC_")
401            .unwrap_or(&proc.name)
402            .to_upper_camel_case()
403    );
404
405    match (&proc.args, &proc.ret) {
406        (Some(args_name), Some(ret_name)) => {
407            // Has both args and return
408            let args_type = format_ident!("{}", to_rust_type_name(args_name));
409            let ret_type = format_ident!("{}", to_rust_type_name(ret_name));
410
411            quote! {
412                /// RPC method for procedure #method_name.
413                pub async fn #method_ident(&self, args: #args_type) -> Result<#ret_type, RpcError> {
414                    let payload = libvirt_xdr::to_bytes(&args)
415                        .map_err(|e| RpcError::Encode(e.to_string()))?;
416                    let response = self.inner.rpc_call(Procedure::#proc_variant as u32, payload).await?;
417                    libvirt_xdr::from_bytes(&response)
418                        .map_err(|e| RpcError::Decode(e.to_string()))
419                }
420            }
421        }
422        (Some(args_name), None) => {
423            // Has args but no return
424            let args_type = format_ident!("{}", to_rust_type_name(args_name));
425
426            quote! {
427                /// RPC method for procedure #method_name.
428                pub async fn #method_ident(&self, args: #args_type) -> Result<(), RpcError> {
429                    let payload = libvirt_xdr::to_bytes(&args)
430                        .map_err(|e| RpcError::Encode(e.to_string()))?;
431                    let _ = self.inner.rpc_call(Procedure::#proc_variant as u32, payload).await?;
432                    Ok(())
433                }
434            }
435        }
436        (None, Some(ret_name)) => {
437            // No args but has return
438            let ret_type = format_ident!("{}", to_rust_type_name(ret_name));
439
440            quote! {
441                /// RPC method for procedure #method_name.
442                pub async fn #method_ident(&self) -> Result<#ret_type, RpcError> {
443                    let response = self.inner.rpc_call(Procedure::#proc_variant as u32, Vec::new()).await?;
444                    libvirt_xdr::from_bytes(&response)
445                        .map_err(|e| RpcError::Decode(e.to_string()))
446                }
447            }
448        }
449        (None, None) => {
450            // No args and no return
451            quote! {
452                /// RPC method for procedure #method_name.
453                pub async fn #method_ident(&self) -> Result<(), RpcError> {
454                    let _ = self.inner.rpc_call(Procedure::#proc_variant as u32, Vec::new()).await?;
455                    Ok(())
456                }
457            }
458        }
459    }
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_to_rust_type_name() {
        assert_eq!(to_rust_type_name("remote_domain"), "Domain");
        assert_eq!(to_rust_type_name("remote_nonnull_domain"), "NonnullDomain");
        assert_eq!(to_rust_type_name("foo_bar"), "FooBar");
    }

    #[test]
    fn test_to_rust_type_name_prefixes_and_collisions() {
        // Rust primitive names pass through untouched.
        assert_eq!(to_rust_type_name("u32"), "u32");
        assert_eq!(to_rust_type_name("String"), "String");
        // The `virNet` prefix is stripped just like `remote_`.
        assert_eq!(to_rust_type_name("virNetMessageHeader"), "MessageHeader");
        // Names that would shadow std types get a `Remote` prefix.
        assert_eq!(to_rust_type_name("remote_string"), "RemoteString");
    }

    #[test]
    fn test_to_rust_field_name() {
        assert_eq!(to_rust_field_name("maxMem"), "max_mem");
        assert_eq!(to_rust_field_name("nrVirtCpu"), "nr_virt_cpu");
    }

    #[test]
    fn test_to_rust_field_name_escapes_keywords() {
        assert_eq!(to_rust_field_name("type"), "r#type");
        assert_eq!(to_rust_field_name("match"), "r#match");
    }

    #[test]
    fn test_to_rust_variant_name_prefixes() {
        // The enum-name prefix takes precedence.
        assert_eq!(
            to_rust_variant_name("REMOTE_DOMAIN_STATE_RUNNING", "remote_domain_state"),
            "Running"
        );
        // Otherwise fall back to stripping `REMOTE_`.
        assert_eq!(to_rust_variant_name("REMOTE_AUTH_NONE", "remote_auth_type"), "AuthNone");
    }

    #[test]
    fn test_generate_struct() {
        let s = StructDef {
            name: "remote_domain".to_string(),
            fields: vec![
                Field {
                    name: "name".to_string(),
                    ty: Type::String { max_len: None },
                },
                Field {
                    name: "id".to_string(),
                    ty: Type::Int,
                },
            ],
        };

        let code = generate_struct(&s).to_string();
        assert!(code.contains("struct Domain"));
        assert!(code.contains("name : String"));
        assert!(code.contains("id : i32"));
    }

    #[test]
    fn test_generate_struct_keyword_field() {
        let s = StructDef {
            name: "remote_typed_param".to_string(),
            fields: vec![Field {
                name: "type".to_string(),
                ty: Type::Int,
            }],
        };

        let code = generate_struct(&s).to_string();
        assert!(code.contains("struct TypedParam"));
        assert!(code.contains("r#type"));
    }

    #[test]
    fn test_generate_enum() {
        let e = EnumDef {
            name: "remote_domain_state".to_string(),
            variants: vec![
                EnumVariant {
                    name: "VIR_DOMAIN_NOSTATE".to_string(),
                    value: Some(ConstValue::Int(0)),
                },
                EnumVariant {
                    name: "VIR_DOMAIN_RUNNING".to_string(),
                    value: Some(ConstValue::Int(1)),
                },
            ],
        };

        let code = generate_enum(&e).to_string();
        assert!(code.contains("enum DomainState"));
        assert!(code.contains("DomainNostate"));
        assert!(code.contains("DomainRunning"));
    }
}