1use crate::ast::*;
4use heck::{ToSnakeCase, ToUpperCamelCase};
5use proc_macro2::TokenStream;
6use quote::{format_ident, quote};
7
8pub fn generate(protocol: &Protocol) -> String {
10 let mut tokens = TokenStream::new();
11
12 tokens.extend(generate_prelude());
14
15 for constant in &protocol.constants {
17 tokens.extend(generate_constant(constant));
18 }
19
20 for type_def in &protocol.types {
22 tokens.extend(generate_type(type_def));
23 }
24
25 tokens.extend(generate_client_methods(&protocol.procedures));
27
28 let file = syn::parse2(tokens).expect("generated invalid Rust code");
30 prettyplease::unparse(&file)
31}
32
33fn generate_prelude() -> TokenStream {
34 quote! {
38 use serde::{Serialize, Deserialize};
42
43 pub const VIR_UUID_BUFLEN: usize = 16;
45 pub const VIR_UUID_STRING_BUFLEN: usize = 37;
46
47 pub use libvirt_xdr::opaque::FixedOpaque16;
49 }
50}
51fn generate_constant(constant: &Constant) -> TokenStream {
52 let name = format_ident!("{}", constant.name);
53
54 match &constant.value {
58 ConstValue::Int(n) => {
59 quote! {
60 pub const #name: i64 = #n;
61 }
62 }
63 ConstValue::Ident(_) => {
64 TokenStream::new()
66 }
67 }
68}
69
70fn generate_type(type_def: &TypeDef) -> TokenStream {
71 match type_def {
72 TypeDef::Struct(s) => generate_struct(s),
73 TypeDef::Enum(e) => generate_enum(e),
74 TypeDef::Union(u) => generate_union(u),
75 TypeDef::Typedef(t) => generate_typedef(t),
76 }
77}
78
79fn generate_struct(s: &StructDef) -> TokenStream {
80 let name = format_ident!("{}", to_rust_type_name(&s.name));
81
82 let fields: Vec<_> = s
83 .fields
84 .iter()
85 .map(|f| {
86 let field_name = format_ident!("{}", to_rust_field_name(&f.name));
87 let field_type = type_to_tokens(&f.ty);
88 quote! {
89 pub #field_name: #field_type
90 }
91 })
92 .collect();
93
94 quote! {
95 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
96 pub struct #name {
97 #(#fields),*
98 }
99 }
100}
101
102fn generate_enum(e: &EnumDef) -> TokenStream {
103 let name = format_ident!("{}", to_rust_type_name(&e.name));
104
105 let variants: Vec<_> = e
106 .variants
107 .iter()
108 .filter_map(|v| {
109 let variant_name = format_ident!("{}", to_rust_variant_name(&v.name, &e.name));
110
111 match &v.value {
112 Some(ConstValue::Int(n)) => {
113 let n = *n as i32;
114 Some(quote! { #variant_name = #n })
115 }
116 Some(ConstValue::Ident(_)) => {
117 None
119 }
120 None => Some(quote! { #variant_name }),
121 }
122 })
123 .collect();
124
125 quote! {
126 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
127 #[repr(i32)]
128 pub enum #name {
129 #(#variants),*
130 }
131 }
132}
133
134fn generate_union(u: &UnionDef) -> TokenStream {
135 let name = format_ident!("{}", to_rust_type_name(&u.name));
136
137 let variants: Vec<_> = u
138 .cases
139 .iter()
140 .filter_map(|case| {
141 let variant_name = match &case.values.first()? {
142 ConstValue::Int(n) => format_ident!("V{}", *n as u64),
143 ConstValue::Ident(s) => format_ident!("{}", to_rust_variant_name(s, &u.name)),
144 };
145
146 match &case.field {
147 Some(f) => {
148 let field_type = type_to_tokens(&f.ty);
149 Some(quote! { #variant_name(#field_type) })
150 }
151 None => Some(quote! { #variant_name }),
152 }
153 })
154 .collect();
155
156 quote! {
157 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
158 pub enum #name {
159 #(#variants),*
160 }
161 }
162}
163
164fn generate_typedef(t: &TypedefDef) -> TokenStream {
165 let name = format_ident!("{}", to_rust_type_name(&t.name));
166 let target = type_to_tokens(&t.target);
167
168 quote! {
169 pub type #name = #target;
170 }
171}
172
173fn type_to_tokens(ty: &Type) -> TokenStream {
174 match ty {
175 Type::Void => quote! { () },
176 Type::Int => quote! { i32 },
177 Type::UInt => quote! { u32 },
178 Type::Hyper => quote! { i64 },
179 Type::UHyper => quote! { u64 },
180 Type::Float => quote! { f32 },
181 Type::Double => quote! { f64 },
182 Type::Bool => quote! { bool },
183 Type::String { .. } => quote! { String },
184 Type::Opaque { len } => match len {
185 LengthSpec::Fixed(n) => {
186 let n = *n as usize;
187 if n == 16 {
189 quote! { FixedOpaque16 }
190 } else {
191 quote! { [u8; #n] }
192 }
193 }
194 LengthSpec::Variable { .. } => quote! { Vec<u8> },
195 },
196 Type::Array { elem, len } => {
197 let elem_type = type_to_tokens(elem);
198 match len {
199 LengthSpec::Fixed(n) => {
200 let n = *n as usize;
201 quote! { [#elem_type; #n] }
202 }
203 LengthSpec::Variable { .. } => quote! { Vec<#elem_type> },
204 }
205 }
206 Type::Optional(inner) => {
207 let inner_type = type_to_tokens(inner);
208 quote! { Option<#inner_type> }
209 }
210 Type::Named(name) => {
211 let ident = format_ident!("{}", to_rust_type_name(name));
212 quote! { #ident }
213 }
214 }
215}
216
217fn to_rust_type_name(name: &str) -> String {
219 match name {
221 "u8" | "u16" | "u32" | "u64" | "u128" | "usize" |
222 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" |
223 "f32" | "f64" | "bool" | "char" | "str" | "String" => {
224 return name.to_string();
225 }
226 _ => {}
227 }
228
229 let name = name
231 .strip_prefix("remote_")
232 .or_else(|| name.strip_prefix("virNet"))
233 .unwrap_or(name);
234
235 let converted = name.to_upper_camel_case();
236
237 match converted.as_str() {
239 "String" => "RemoteString".to_string(),
240 "Vec" => "RemoteVec".to_string(),
241 "Option" => "RemoteOption".to_string(),
242 "Box" => "RemoteBox".to_string(),
243 "Result" => "RemoteResult".to_string(),
244 _ => converted,
245 }
246}
247
248fn to_rust_field_name(name: &str) -> String {
251 let name = name.to_snake_case();
252
253 match name.as_str() {
255 "type" => "r#type".to_string(),
256 "match" => "r#match".to_string(),
257 "ref" => "r#ref".to_string(),
258 "mod" => "r#mod".to_string(),
259 "fn" => "r#fn".to_string(),
260 "struct" => "r#struct".to_string(),
261 "enum" => "r#enum".to_string(),
262 "trait" => "r#trait".to_string(),
263 "impl" => "r#impl".to_string(),
264 "self" => "r#self".to_string(),
265 "super" => "r#super".to_string(),
266 "crate" => "r#crate".to_string(),
267 "use" => "r#use".to_string(),
268 "pub" => "r#pub".to_string(),
269 "in" => "r#in".to_string(),
270 "where" => "r#where".to_string(),
271 "async" => "r#async".to_string(),
272 "await" => "r#await".to_string(),
273 "dyn" => "r#dyn".to_string(),
274 "loop" => "r#loop".to_string(),
275 "move" => "r#move".to_string(),
276 "return" => "r#return".to_string(),
277 "static" => "r#static".to_string(),
278 "const" => "r#const".to_string(),
279 "unsafe" => "r#unsafe".to_string(),
280 "extern" => "r#extern".to_string(),
281 "let" => "r#let".to_string(),
282 "mut" => "r#mut".to_string(),
283 "if" => "r#if".to_string(),
284 "else" => "r#else".to_string(),
285 "for" => "r#for".to_string(),
286 "while" => "r#while".to_string(),
287 "break" => "r#break".to_string(),
288 "continue" => "r#continue".to_string(),
289 "as" => "r#as".to_string(),
290 "box" => "r#box".to_string(),
291 "priv" => "r#priv".to_string(),
292 "abstract" => "r#abstract".to_string(),
293 "final" => "r#final".to_string(),
294 "override" => "r#override".to_string(),
295 "virtual" => "r#virtual".to_string(),
296 "yield" => "r#yield".to_string(),
297 "become" => "r#become".to_string(),
298 "macro" => "r#macro".to_string(),
299 "typeof" => "r#typeof".to_string(),
300 "try" => "r#try".to_string(),
301 "union" => "r#union".to_string(),
302 _ => name,
303 }
304}
305
306fn to_rust_variant_name(name: &str, enum_name: &str) -> String {
308 let name = name
310 .strip_prefix(&format!("{}_", enum_name.to_uppercase()))
311 .or_else(|| name.strip_prefix("REMOTE_"))
312 .or_else(|| name.strip_prefix("VIR_"))
313 .unwrap_or(name);
314
315 name.to_upper_camel_case()
316}
317
318fn generate_client_methods(procedures: &[Procedure]) -> TokenStream {
320 let methods: Vec<_> = procedures
321 .iter()
322 .map(|proc| generate_client_method(proc))
323 .collect();
324
325 quote! {
326 #[allow(async_fn_in_trait)]
329 pub trait LibvirtRpc {
330 async fn rpc_call(&self, procedure: u32, payload: Vec<u8>) -> Result<Vec<u8>, RpcError>;
332 }
333
334 #[derive(Debug)]
336 pub enum RpcError {
337 Encode(String),
339 Decode(String),
341 Transport(String),
343 Server(Error),
345 }
346
347 impl std::fmt::Display for RpcError {
348 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
349 match self {
350 RpcError::Encode(e) => write!(f, "XDR encode error: {}", e),
351 RpcError::Decode(e) => write!(f, "XDR decode error: {}", e),
352 RpcError::Transport(e) => write!(f, "Transport error: {}", e),
353 RpcError::Server(e) => write!(f, "Server error: {:?}", e),
354 }
355 }
356 }
357
358 impl std::error::Error for RpcError {}
359
360 pub struct GeneratedClient<T: LibvirtRpc> {
362 inner: T,
363 }
364
365 impl<T: LibvirtRpc> GeneratedClient<T> {
366 pub fn new(inner: T) -> Self {
368 Self { inner }
369 }
370
371 pub fn inner(&self) -> &T {
373 &self.inner
374 }
375
376 pub fn inner_mut(&mut self) -> &mut T {
378 &mut self.inner
379 }
380
381 #(#methods)*
382 }
383 }
384}
385
386fn generate_client_method(proc: &Procedure) -> TokenStream {
388 let method_name = proc
390 .name
391 .strip_prefix("REMOTE_PROC_")
392 .unwrap_or(&proc.name)
393 .to_lowercase();
394 let method_ident = format_ident!("{}", method_name);
395
396 let proc_variant = format_ident!(
398 "Proc{}",
399 proc.name
400 .strip_prefix("REMOTE_PROC_")
401 .unwrap_or(&proc.name)
402 .to_upper_camel_case()
403 );
404
405 match (&proc.args, &proc.ret) {
406 (Some(args_name), Some(ret_name)) => {
407 let args_type = format_ident!("{}", to_rust_type_name(args_name));
409 let ret_type = format_ident!("{}", to_rust_type_name(ret_name));
410
411 quote! {
412 pub async fn #method_ident(&self, args: #args_type) -> Result<#ret_type, RpcError> {
414 let payload = libvirt_xdr::to_bytes(&args)
415 .map_err(|e| RpcError::Encode(e.to_string()))?;
416 let response = self.inner.rpc_call(Procedure::#proc_variant as u32, payload).await?;
417 libvirt_xdr::from_bytes(&response)
418 .map_err(|e| RpcError::Decode(e.to_string()))
419 }
420 }
421 }
422 (Some(args_name), None) => {
423 let args_type = format_ident!("{}", to_rust_type_name(args_name));
425
426 quote! {
427 pub async fn #method_ident(&self, args: #args_type) -> Result<(), RpcError> {
429 let payload = libvirt_xdr::to_bytes(&args)
430 .map_err(|e| RpcError::Encode(e.to_string()))?;
431 let _ = self.inner.rpc_call(Procedure::#proc_variant as u32, payload).await?;
432 Ok(())
433 }
434 }
435 }
436 (None, Some(ret_name)) => {
437 let ret_type = format_ident!("{}", to_rust_type_name(ret_name));
439
440 quote! {
441 pub async fn #method_ident(&self) -> Result<#ret_type, RpcError> {
443 let response = self.inner.rpc_call(Procedure::#proc_variant as u32, Vec::new()).await?;
444 libvirt_xdr::from_bytes(&response)
445 .map_err(|e| RpcError::Decode(e.to_string()))
446 }
447 }
448 }
449 (None, None) => {
450 quote! {
452 pub async fn #method_ident(&self) -> Result<(), RpcError> {
454 let _ = self.inner.rpc_call(Procedure::#proc_variant as u32, Vec::new()).await?;
455 Ok(())
456 }
457 }
458 }
459 }
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    /// Type names lose the `remote_` prefix and become UpperCamelCase.
    #[test]
    fn test_to_rust_type_name() {
        let cases = [
            ("remote_domain", "Domain"),
            ("remote_nonnull_domain", "NonnullDomain"),
            ("foo_bar", "FooBar"),
        ];

        for (input, expected) in cases.iter() {
            assert_eq!(to_rust_type_name(input), *expected);
        }
    }

    /// camelCase field names become snake_case.
    #[test]
    fn test_to_rust_field_name() {
        let cases = [("maxMem", "max_mem"), ("nrVirtCpu", "nr_virt_cpu")];

        for (input, expected) in cases.iter() {
            assert_eq!(to_rust_field_name(input), *expected);
        }
    }

    /// A struct definition renders its name and each field with its Rust type.
    #[test]
    fn test_generate_struct() {
        let name_field = Field {
            name: "name".to_string(),
            ty: Type::String { max_len: None },
        };
        let id_field = Field {
            name: "id".to_string(),
            ty: Type::Int,
        };
        let def = StructDef {
            name: "remote_domain".to_string(),
            fields: vec![name_field, id_field],
        };

        let rendered = generate_struct(&def).to_string();

        for fragment in ["struct Domain", "name : String", "id : i32"].iter() {
            assert!(rendered.contains(fragment));
        }
    }

    /// An enum definition renders its name and prefix-stripped variant names.
    #[test]
    fn test_generate_enum() {
        let nostate = EnumVariant {
            name: "VIR_DOMAIN_NOSTATE".to_string(),
            value: Some(ConstValue::Int(0)),
        };
        let running = EnumVariant {
            name: "VIR_DOMAIN_RUNNING".to_string(),
            value: Some(ConstValue::Int(1)),
        };
        let def = EnumDef {
            name: "remote_domain_state".to_string(),
            variants: vec![nostate, running],
        };

        let rendered = generate_enum(&def).to_string();

        for fragment in ["enum DomainState", "DomainNostate", "DomainRunning"].iter() {
            assert!(rendered.contains(fragment));
        }
    }
}