golem_wasm_rpc_derive/
lib.rs

1// Copyright 2024-2025 Golem Cloud
2//
3// Licensed under the Golem Source License v1.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://license.golem.cloud/LICENSE
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use heck::*;
16use proc_macro::TokenStream;
17use proc_macro2::{Ident, Span};
18use quote::quote;
19use syn::parse::{Parse, ParseStream};
20use syn::{Attribute, Data, DeriveInput, Fields, LitStr, Type, Variant};
21
22#[proc_macro_derive(IntoValue, attributes(wit_transparent, unit_case, wit_field))]
23pub fn derive_into_value(input: TokenStream) -> TokenStream {
24    let ast: DeriveInput = syn::parse(input).expect("derive input");
25    let ident = &ast.ident;
26    let wit_transparent = ast
27        .attrs
28        .iter()
29        .any(|attr| attr.path().is_ident("wit_transparent"));
30    let lit_name = LitStr::new(&ident.to_string(), Span::call_site());
31
32    let (into_value, get_type) = match ast.data {
33        Data::Struct(data) => {
34            let newtype_result = if data.fields.len() == 1 {
35                let field = data.fields.iter().next().unwrap().clone();
36                if field.ident.is_none() || wit_transparent {
37                    // single field without an identifier, we consider this a newtype
38                    let field_type = field.ty;
39
40                    let into_value = match field.ident {
41                        None => quote! {
42                            self.0.into_value()
43                        },
44                        Some(field_name) => quote! {
45                            self.#field_name.into_value()
46                        },
47                    };
48                    let get_type = quote! {
49                        <#field_type as golem_wasm_rpc::IntoValue>::get_type()
50                    };
51
52                    Some((into_value, get_type))
53                } else {
54                    None
55                }
56            } else {
57                None
58            };
59
60            match newtype_result {
61                Some(newtype_result) => newtype_result,
62                None => record_or_tuple(&lit_name, &data.fields),
63            }
64        }
65        Data::Enum(data) => {
66            let is_simple_enum = data
67                .variants
68                .iter()
69                .all(|variant| variant.fields.is_empty());
70
71            if is_simple_enum {
72                let case_branches = data
73                    .variants
74                    .iter()
75                    .enumerate()
76                    .map(|(idx, variant)| {
77                        let case_ident = &variant.ident;
78                        let idx = idx as u32;
79                        quote! {
80                            #ident::#case_ident => golem_wasm_rpc::Value::Enum(#idx)
81                        }
82                    })
83                    .collect::<Vec<_>>();
84                let case_labels = data
85                    .variants
86                    .iter()
87                    .map(|variant| variant.ident.to_string().to_kebab_case())
88                    .collect::<Vec<_>>();
89
90                let into_value = quote! {
91                    match self {
92                        #(#case_branches),*
93                    }
94                };
95
96                let get_type = quote! {
97                    golem_wasm_ast::analysis::analysed_type::r#enum(
98                        &[#(#case_labels),*]
99                    ).named(#lit_name)
100                };
101
102                (into_value, get_type)
103            } else {
104                let case_branches = data
105                    .variants
106                    .iter()
107                    .enumerate()
108                    .map(|(idx, variant)| {
109                        let case_ident = &variant.ident;
110                        let idx = idx as u32;
111
112                        let wit_fields = variant.fields
113                            .iter()
114                            .map(|field| {
115                                field
116                                    .attrs
117                                    .iter()
118                                    .find(|attr| attr.path().is_ident("wit_field"))
119                                    .map(parse_wit_field_attribute)
120                                    .unwrap_or_default()
121                            })
122                            .collect::<Vec<_>>();
123
124                        if variant.fields.is_empty() {
125                            quote! {
126                                #ident::#case_ident => golem_wasm_rpc::Value::Variant {
127                                    case_idx: #idx,
128                                    case_value: None
129                                }
130                            }
131                        } else if has_single_anonymous_field(&variant.fields) {
132                            // separate inner type
133                            if is_unit_case(variant) {
134                                quote! {
135                                    #ident::#case_ident(inner) => golem_wasm_rpc::Value::Variant {
136                                        case_idx: #idx,
137                                        case_value: None
138                                    }
139                                }
140                            } else {
141                                let wit_field = wit_fields.first().unwrap();
142                                let into_value = apply_conversions(wit_field, quote! { inner });
143                                quote! {
144                                    #ident::#case_ident(inner) => golem_wasm_rpc::Value::Variant {
145                                        case_idx: #idx,
146                                        case_value: Some(Box::new(#into_value))
147                                    }
148                                }
149                            }
150                        } else if has_only_named_fields(&variant.fields) {
151                            // record case
152                            let field_names = variant
153                                .fields
154                                .iter()
155                                .map(|field| {
156                                    let field = field.ident.as_ref().unwrap();
157                                    quote! { #field }
158                                })
159                                .collect::<Vec<_>>();
160
161                            let field_values = variant.fields.iter().map(|field| {
162                                let field = field.ident.as_ref().unwrap();
163                                quote! {
164                                    #field.into_value()
165                                }
166                            });
167
168                            if is_unit_case(variant) {
169                                quote! {
170                                    #ident::#case_ident { #(#field_names),* } => golem_wasm_rpc::Value::Variant {
171                                        case_idx: #idx,
172                                        case_value: None
173                                    }
174                                }
175                            } else {
176                                quote! {
177                                    #ident::#case_ident { #(#field_names),* } =>
178                                        golem_wasm_rpc::Value::Variant {
179                                            case_idx: #idx,
180                                            case_value: Some(Box::new(golem_wasm_rpc::Value::Record(
181                                                vec![#(#field_values),*]
182                                            )))
183                                        }
184                                }
185                            }
186                        } else {
187                            // tuple case
188                            let field_names = variant
189                                .fields
190                                .iter()
191                                .enumerate()
192                                .map(|(idx, _field)| {
193                                    Ident::new(&format!("f{idx}"), Span::call_site())
194                                })
195                                .collect::<Vec<_>>();
196
197                            let field_values = field_names.iter().map(|field| {
198                                quote! {
199                                    #field.into_value()
200                                }
201                            });
202
203                            if is_unit_case(variant) {
204                                quote! {
205                                    #ident::#case_ident(#(#field_names),*) => golem_wasm_rpc::Value::Variant {
206                                        case_idx: #idx,
207                                        case_value: None
208                                    }
209                                }
210                            } else {
211                                quote! {
212                                    #ident::#case_ident(#(#field_names),*) =>
213                                        golem_wasm_rpc::Value::Variant {
214                                            case_idx: #idx,
215                                            case_value: Some(Box::new(golem_wasm_rpc::Value::Tuple(
216                                                vec![#(#field_values),*]
217                                            )))
218                                        }
219                                }
220                            }
221                        }
222                    })
223                    .collect::<Vec<_>>();
224
225                let case_defs = data.variants.iter()
226                    .map(|variant| {
227                        let wit_fields = variant.fields
228                            .iter()
229                            .map(|field| {
230                                field
231                                    .attrs
232                                    .iter()
233                                    .find(|attr| attr.path().is_ident("wit_field"))
234                                    .map(parse_wit_field_attribute)
235                                    .unwrap_or_default()
236                            })
237                            .collect::<Vec<_>>();
238
239                        let case_name = variant.ident.to_string().to_kebab_case();
240                        if is_unit_case(variant) {
241                            quote! {
242                                golem_wasm_ast::analysis::analysed_type::unit_case(#case_name)
243                            }
244                        } else if has_single_anonymous_field(&variant.fields) {
245                            let single_field = variant.fields.iter().next().unwrap();
246                            let typ = &single_field.ty;
247                            let wit_field = wit_fields.first().unwrap();
248                            let typ = get_field_type(typ, wit_field);
249
250                            quote! {
251                                golem_wasm_ast::analysis::analysed_type::case(#case_name, <#typ as golem_wasm_rpc::IntoValue>::get_type())
252                            }
253                        } else {
254                            let (_, inner_get_type) = record_or_tuple(&LitStr::new(&case_name, Span::call_site()), &variant.fields);
255
256                            quote! {
257                                golem_wasm_ast::analysis::analysed_type::case(#case_name, #inner_get_type)
258                            }
259                        }
260                    })
261                    .collect::<Vec<_>>();
262
263                let into_value = quote! {
264                    match self {
265                        #(#case_branches),*
266                    }
267                };
268                let get_type = quote! {
269                    golem_wasm_ast::analysis::analysed_type::variant(
270                        vec![#(#case_defs),*]
271                    ).named(#lit_name)
272                };
273
274                (into_value, get_type)
275            }
276        }
277        Data::Union(_data) => {
278            panic!("Cannot derive IntoValue for unions")
279        }
280    };
281
282    let result = quote! {
283        impl golem_wasm_rpc::IntoValue for #ident {
284            fn into_value(self) -> golem_wasm_rpc::Value {
285                #into_value
286            }
287
288            fn get_type() -> golem_wasm_ast::analysis::AnalysedType {
289                #get_type
290            }
291        }
292    };
293
294    result.into()
295}
296
297fn record_or_tuple(
298    lit_name: &LitStr,
299    fields: &Fields,
300) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
301    let all_fields_has_names = fields.iter().all(|field| field.ident.is_some());
302
303    if all_fields_has_names {
304        let wit_fields = fields
305            .iter()
306            .map(|field| {
307                field
308                    .attrs
309                    .iter()
310                    .find(|attr| attr.path().is_ident("wit_field"))
311                    .map(parse_wit_field_attribute)
312                    .unwrap_or_default()
313            })
314            .collect::<Vec<_>>();
315
316        let field_values = fields
317            .iter()
318            .zip(&wit_fields)
319            .filter_map(|(field, wit_field)| {
320                if wit_field.skip {
321                    None
322                } else {
323                    let field_name = field.ident.as_ref().unwrap();
324                    Some(apply_conversions(wit_field, quote! { self.#field_name }))
325                }
326            })
327            .collect::<Vec<_>>();
328
329        let field_defs = fields
330            .iter()
331            .zip(wit_fields)
332            .filter_map(|(field, wit_field)| {
333                if wit_field.skip {
334                    None
335                } else {
336                    let field_name = wit_field
337                        .rename
338                        .as_ref()
339                        .map(|lit| lit.value())
340                        .unwrap_or_else(|| {
341                            field.ident.as_ref().unwrap().to_string().to_kebab_case()
342                        });
343                    let field_type = get_field_type(&field.ty, &wit_field);
344                    Some(quote! {
345                        golem_wasm_ast::analysis::analysed_type::field(
346                            #field_name,
347                            <#field_type as golem_wasm_rpc::IntoValue>::get_type()
348                        )
349                    })
350                }
351            })
352            .collect::<Vec<_>>();
353
354        let into_value = quote! {
355            golem_wasm_rpc::Value::Record(vec![
356                #(#field_values),*
357            ])
358        };
359        let get_type = quote! {
360            golem_wasm_ast::analysis::analysed_type::record(vec![
361                #(#field_defs),*
362            ]).named(#lit_name)
363        };
364
365        (into_value, get_type)
366    } else {
367        let tuple_field_values = fields
368            .iter()
369            .map(|field| {
370                let field_name = field.ident.as_ref().unwrap();
371                quote! { self.#field_name.into_value() }
372            })
373            .collect::<Vec<_>>();
374
375        let tuple_field_types = fields
376            .iter()
377            .map(|field| {
378                let field_type = &field.ty;
379                quote! {
380                    <#field_type as golem_wasm_rpc::IntoValue>::get_type()
381                }
382            })
383            .collect::<Vec<_>>();
384
385        let into_value = quote! {
386            golem_wasm_rpc::Value::Tuple(vec![
387                #(#tuple_field_values),*
388            ])
389        };
390        let get_type = quote! {
391            golem_wasm_ast::analysis::analysed_type::tuple(vec![
392                #(#tuple_field_types),*
393            ]).named(#lit_name)
394        };
395
396        (into_value, get_type)
397    }
398}
399
400fn get_field_type(ty: &Type, wit_field: &WitField) -> proc_macro2::TokenStream {
401    match (
402        &wit_field.convert,
403        &wit_field.convert_vec,
404        &wit_field.convert_option,
405    ) {
406        (Some(convert_to), None, None) => quote! { #convert_to },
407        (None, Some(convert_to), None) => quote! { Vec<#convert_to> },
408        (None, None, Some(convert_to)) => quote! { Option<#convert_to> },
409        _ => {
410            quote! { #ty }
411        }
412    }
413}
414
415fn has_single_anonymous_field(fields: &Fields) -> bool {
416    fields.len() == 1 && fields.iter().next().unwrap().ident.is_none()
417}
418
419fn has_only_named_fields(fields: &Fields) -> bool {
420    fields.iter().all(|field| field.ident.is_some())
421}
422
423fn is_unit_case(variant: &Variant) -> bool {
424    variant.fields.is_empty()
425        || variant
426            .attrs
427            .iter()
428            .any(|attr| attr.path().is_ident("unit_case"))
429}
430
431fn apply_conversions(
432    wit_field: &WitField,
433    field_access: proc_macro2::TokenStream,
434) -> proc_macro2::TokenStream {
435    match (
436        &wit_field.convert,
437        &wit_field.convert_vec,
438        &wit_field.convert_option,
439    ) {
440        (Some(convert_to), None, None) => {
441            quote! { Into::<#convert_to>::into(#field_access).into_value() }
442        }
443        (None, Some(convert_to), None) => {
444            quote! { #field_access.into_iter().map(Into::<#convert_to>::into).collect::<Vec<_>>().into_value() }
445        }
446        (None, None, Some(convert_to)) => {
447            quote! { #field_access.map(Into::<#convert_to>::into).into_value() }
448        }
449        _ => quote! { #field_access.into_value() },
450    }
451}
452
453#[derive(Default)]
454struct WitField {
455    skip: bool,
456    rename: Option<LitStr>,
457    convert: Option<Type>,
458    convert_vec: Option<Type>,
459    convert_option: Option<Type>,
460}
461
462fn parse_wit_field_attribute(attr: &Attribute) -> WitField {
463    attr.parse_args_with(WitField::parse)
464        .expect("failed to parse wit_field attribute")
465}
466
467impl Parse for WitField {
468    fn parse(input: ParseStream) -> syn::Result<Self> {
469        let mut skip = false;
470        let mut rename = None;
471        let mut convert = None;
472        let mut convert_vec = None;
473        let mut convert_option = None;
474
475        while !input.is_empty() {
476            let ident: Ident = input.parse()?;
477            if ident == "skip" {
478                skip = true;
479            } else if ident == "rename" {
480                input.parse::<syn::Token![=]>()?;
481                rename = Some(input.parse()?);
482            } else if ident == "convert" {
483                input.parse::<syn::Token![=]>()?;
484                convert = Some(input.parse()?);
485            } else if ident == "convert_vec" {
486                input.parse::<syn::Token![=]>()?;
487                convert_vec = Some(input.parse()?);
488            } else if ident == "convert_option" {
489                input.parse::<syn::Token![=]>()?;
490                convert_option = Some(input.parse()?);
491            } else {
492                return Err(syn::Error::new(ident.span(), "unexpected attribute"));
493            }
494        }
495
496        Ok(WitField {
497            skip,
498            rename,
499            convert,
500            convert_vec,
501            convert_option,
502        })
503    }
504}