ovsdb_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, parse_quote, Data, DeriveInput, Fields};
6
7/// Attribute macro for OVSDB table structs
8///
9/// This macro automatically adds `_uuid` and `_version` fields to your struct
10/// and generates the necessary implementations for it to work with OVSDB.
11///
12/// # Example
13///
14/// ```rust
15/// use ovsdb_derive::ovsdb_object;
16/// use std::collections::HashMap;
17///
18/// #[ovsdb_object]
19/// pub struct NbGlobal {
20///     pub name: Option<String>,
21///     pub nb_cfg: Option<i64>,
22///     pub external_ids: Option<HashMap<String, String>>,
23/// }
24/// ```
25#[proc_macro_attribute]
26pub fn ovsdb_object(_attr: TokenStream, item: TokenStream) -> TokenStream {
27    // Parse the struct definition
28    let mut input = parse_macro_input!(item as DeriveInput);
29
30    // Add _uuid and _version fields if they don't exist
31    if let Data::Struct(ref mut data_struct) = input.data {
32        if let Fields::Named(ref mut fields) = data_struct.fields {
33            // Check if _uuid and _version already exist
34            let has_uuid = fields
35                .named
36                .iter()
37                .any(|f| f.ident.as_ref().is_some_and(|i| i == "_uuid"));
38            let has_version = fields
39                .named
40                .iter()
41                .any(|f| f.ident.as_ref().is_some_and(|i| i == "_version"));
42
43            // Add fields if they don't exist
44            if !has_uuid {
45                // Add _uuid field
46                fields.named.push(parse_quote! {
47                    pub _uuid: Option<uuid::Uuid>
48                });
49            }
50            if !has_version {
51                // Add _version field
52                fields.named.push(parse_quote! {
53                    pub _version: Option<uuid::Uuid>
54                });
55            }
56        }
57    }
58
59    // Get the name of the struct
60    let struct_name = &input.ident;
61
62    // Extract field names and types, excluding _uuid and _version
63    let mut field_names = Vec::new();
64    let mut field_types = Vec::new();
65
66    if let Data::Struct(ref data_struct) = input.data {
67        if let Fields::Named(ref fields) = data_struct.fields {
68            for field in &fields.named {
69                if let Some(ident) = &field.ident {
70                    if ident == "_uuid" || ident == "_version" {
71                        continue;
72                    }
73                    field_names.push(ident);
74                    field_types.push(&field.ty);
75                }
76            }
77        }
78    }
79
80    // Generate implementations
81    let implementation = quote! {
82        // Re-export the input struct with the added fields
83        #input
84
85        // Automatically import necessary items from ovsdb-schema
86        use ::ovsdb_schema::{extract_uuid, OvsdbSerializableExt};
87
88        impl #struct_name {
89            /// Create a new instance with default values
90            pub fn new() -> Self {
91                Self {
92                    #(
93                        #field_names: Default::default(),
94                    )*
95                    _uuid: None,
96                    _version: None,
97                }
98            }
99
100            /// Convert to a HashMap for OVSDB serialization
101            pub fn to_map(&self) -> std::collections::HashMap<String, serde_json::Value> {
102                let mut map = std::collections::HashMap::new();
103
104                #(
105                    // Skip None values
106                    let field_value = &self.#field_names;
107                    if let Some(value) = field_value.to_ovsdb_json() {
108                        map.insert(stringify!(#field_names).to_string(), value);
109                    }
110                )*
111
112                map
113            }
114
115            /// Create from a HashMap received from OVSDB
116            pub fn from_map(map: &std::collections::HashMap<String, serde_json::Value>) -> Result<Self, String> {
117                let mut result = Self::new();
118
119                // Extract UUID if present
120                if let Some(uuid_val) = map.get("_uuid") {
121                    if let Some(uuid) = extract_uuid(uuid_val) {
122                        result._uuid = Some(uuid);
123                    }
124                }
125
126                // Extract version if present
127                if let Some(version_val) = map.get("_version") {
128                    if let Some(version) = extract_uuid(version_val) {
129                        result._version = Some(version);
130                    }
131                }
132
133                // Extract other fields
134                #(
135                    if let Some(value) = map.get(stringify!(#field_names)) {
136                        result.#field_names = <#field_types>::from_ovsdb_json(value)
137                            .ok_or_else(|| format!("Failed to parse field {}", stringify!(#field_names)))?;
138                    }
139                )*
140
141                Ok(result)
142            }
143        }
144
145        impl Default for #struct_name {
146            fn default() -> Self {
147                Self::new()
148            }
149        }
150
151        impl serde::Serialize for #struct_name {
152            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
153            where
154                S: serde::Serializer
155            {
156                self.to_map().serialize(serializer)
157            }
158        }
159
160        impl<'de> serde::Deserialize<'de> for #struct_name {
161            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
162            where
163                D: serde::Deserializer<'de>
164            {
165                let map = std::collections::HashMap::<String, serde_json::Value>::deserialize(deserializer)?;
166                Self::from_map(&map).map_err(serde::de::Error::custom)
167            }
168        }
169    };
170
171    // Return the modified struct and implementations
172    TokenStream::from(implementation)
173}
174
175/// Derive macro for OVSDB table structs (requires manual _uuid and _version fields)
176///
177/// This macro generates the necessary implementations for a struct to work with OVSDB.
178/// The struct must have `_uuid` and `_version` fields of type `Option<uuid::Uuid>`.
179///
180/// # Example
181///
182/// ```rust
183/// use ovsdb_derive::OVSDB;
184/// use std::collections::HashMap;
185/// use uuid::Uuid;
186///
187/// #[derive(Debug, Clone, PartialEq, OVSDB)]
188/// pub struct NbGlobal {
189///     pub name: Option<String>,
190///     pub nb_cfg: Option<i64>,
191///     pub external_ids: Option<HashMap<String, String>>,
192///     
193///     // Required fields
194///     pub _uuid: Option<Uuid>,
195///     pub _version: Option<Uuid>,
196/// }
197/// ```
198#[proc_macro_derive(OVSDB)]
199pub fn ovsdb_derive(input: TokenStream) -> TokenStream {
200    // Parse the input tokens into a syntax tree
201    let input = parse_macro_input!(input as DeriveInput);
202
203    // Get the name of the struct
204    let struct_name = &input.ident;
205
206    // Check if the input is a struct
207    let fields = match &input.data {
208        Data::Struct(data_struct) => match &data_struct.fields {
209            Fields::Named(fields_named) => &fields_named.named,
210            _ => panic!("OVSDB can only be derived for structs with named fields"),
211        },
212        _ => panic!("OVSDB can only be derived for structs"),
213    };
214
215    // Extract field names and types, excluding _uuid and _version
216    let mut field_names = Vec::new();
217    let mut field_types = Vec::new();
218
219    for field in fields {
220        if let Some(ident) = &field.ident {
221            if ident == "_uuid" || ident == "_version" {
222                continue;
223            }
224            field_names.push(ident);
225            field_types.push(&field.ty);
226        }
227    }
228
229    // Generate code for the implementation
230    let expanded = quote! {
231        // Automatically import necessary items from ovsdb-schema
232        use ::ovsdb_schema::{extract_uuid, OvsdbSerializableExt};
233
234        impl #struct_name {
235            /// Create a new instance with default values
236            pub fn new() -> Self {
237                Self {
238                    #(
239                        #field_names: Default::default(),
240                    )*
241                    _uuid: None,
242                    _version: None,
243                }
244            }
245
246            /// Convert to a HashMap for OVSDB serialization
247            pub fn to_map(&self) -> std::collections::HashMap<String, serde_json::Value> {
248                let mut map = std::collections::HashMap::new();
249
250                #(
251                    // Skip None values
252                    let field_value = &self.#field_names;
253                    if let Some(value) = field_value.to_ovsdb_json() {
254                        map.insert(stringify!(#field_names).to_string(), value);
255                    }
256                )*
257
258                map
259            }
260
261            /// Create from a HashMap received from OVSDB
262            pub fn from_map(map: &std::collections::HashMap<String, serde_json::Value>) -> Result<Self, String> {
263                let mut result = Self::new();
264
265                // Extract UUID if present
266                if let Some(uuid_val) = map.get("_uuid") {
267                    if let Some(uuid) = extract_uuid(uuid_val) {
268                        result._uuid = Some(uuid);
269                    }
270                }
271
272                // Extract version if present
273                if let Some(version_val) = map.get("_version") {
274                    if let Some(version) = extract_uuid(version_val) {
275                        result._version = Some(version);
276                    }
277                }
278
279                // Extract other fields
280                #(
281                    if let Some(value) = map.get(stringify!(#field_names)) {
282                        result.#field_names = <#field_types>::from_ovsdb_json(value)
283                            .ok_or_else(|| format!("Failed to parse field {}", stringify!(#field_names)))?;
284                    }
285                )*
286
287                Ok(result)
288            }
289        }
290
291        impl Default for #struct_name {
292            fn default() -> Self {
293                Self::new()
294            }
295        }
296
297        impl serde::Serialize for #struct_name {
298            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
299            where
300                S: serde::Serializer
301            {
302                self.to_map().serialize(serializer)
303            }
304        }
305
306        impl<'de> serde::Deserialize<'de> for #struct_name {
307            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
308            where
309                D: serde::Deserializer<'de>
310            {
311                let map = std::collections::HashMap::<String, serde_json::Value>::deserialize(deserializer)?;
312                Self::from_map(&map).map_err(serde::de::Error::custom)
313            }
314        }
315    };
316
317    // Return the generated code
318    TokenStream::from(expanded)
319}