holium_backend/
codegen.rs

//! Code generation: turns the parsed AST into the Rust wrappers a wasm module needs to run in the Holium protocol.
2
3use crate::ast;
4use crate::Diagnostic;
5use proc_macro2::{Ident, Span, TokenStream};
6use quote::{quote, ToTokens};
7use syn;
8
9/// A trait for converting AST structs into Tokens and adding them to a TokenStream,
10/// or providing a diagnostic if conversion fails.
11pub trait TryToTokens {
12    /// Attempt to convert a `Self` into tokens and add it to the `TokenStream`
13    fn try_to_tokens(&self, into: &mut TokenStream) -> Result<(), Diagnostic>;
14
15    /// Attempt to convert a `Self` into a new `TokenStream`
16    fn try_to_token_stream(&self) -> Result<TokenStream, Diagnostic> {
17        let mut tokens = TokenStream::new();
18        self.try_to_tokens(&mut tokens)?;
19        Ok(tokens)
20    }
21}
22
23impl TryToTokens for ast::Program {
24    // Generate wrappers for all the items that we've found
25    fn try_to_tokens(&self, into: &mut TokenStream) -> Result<(), Diagnostic> {
26        // Handling exported functions
27        let mut errors = Vec::new();
28        for export in self.exports.iter() {
29            if let Err(e) = export.try_to_tokens(into) {
30                errors.push(e);
31            }
32        }
33
34        // Handling tagged structures
35        for s in self.structs.iter() {
36            s.to_tokens(into);
37        }
38
39        Diagnostic::from_vec(errors)?;
40
41        Ok(())
42    }
43}
44
45impl ToTokens for ast::Struct {
46    fn to_tokens(&self, into: &mut TokenStream) {
47        let name = &self.rust_name;
48
49        // Add derive for serialize & deserialize
50        *into = (quote! {
51            #[derive(holium_rs_sdk::internal::serde::Serialize, holium_rs_sdk::internal::serde::Deserialize)]
52            #[serde( crate = "holium_rs_sdk::internal::serde")]
53            #into
54        })
55            .to_token_stream();
56
57        // For each field of our structure add a new children node
58        let mut generate_node_children: Vec<TokenStream> = vec![];
59
60        for field in self.fields.iter() {
61            let field_name = field.name.to_string();
62            let field_type = &field.ty;
63
64            generate_node_children.push(quote! {
65                holium_rs_sdk::internal::key_tree::Node {
66                    value: Some(#field_name),
67                    children: <#field_type>::generate_node().children
68                }
69            });
70        }
71
72        // Generating conversion from data_tree::Node to structure and implement key_tree::GenerateNode
73        // trait
74        (quote! {
75            impl holium_rs_sdk::internal::key_tree::GenerateNode for #name {
76                fn generate_node() -> holium_rs_sdk::internal::key_tree::Node {
77                    holium_rs_sdk::internal::key_tree::Node {
78                        value: None,
79                        children: vec![
80                            #(#generate_node_children),*
81                        ],
82                    }
83                }
84            }
85
86            impl From<holium_rs_sdk::internal::data_tree::Node> for #name {
87                fn from(data_tree: holium_rs_sdk::internal::data_tree::Node) -> Self {
88                    let key_node = <#name>::generate_node();
89                    let cbor = data_tree.assign_keys(&key_node);
90                    let cbor_bytes: Vec<u8> = internal::serde_cbor::to_vec(&cbor).unwrap();
91                    holium_rs_sdk::internal::serde_cbor::from_slice(&cbor_bytes).unwrap()
92                }
93            }
94        })
95        .to_tokens(into);
96    }
97}
98
99impl TryToTokens for ast::Export {
100    fn try_to_tokens(self: &ast::Export, into: &mut TokenStream) -> Result<(), Diagnostic> {
101        let mut input_payload_fields: Vec<TokenStream> = vec![];
102        let mut input_payload_node_children: Vec<TokenStream> = vec![];
103        let mut converted_args: Vec<TokenStream> = vec![];
104
105        let name = &self.rust_name;
106        let receiver = quote! { #name };
107
108        let exported_name = &self.export_name();
109        let holium_func_name = &self.rust_symbol();
110
111        // First, generating inputs elements : input payload struct & function arguments
112        for (i, arg) in self.function.arguments.iter().enumerate() {
113            let field = format!("arg{}", i);
114            let field_ident = Ident::new(&field, Span::call_site());
115            let input_ident = Ident::new(&format!("input"), Span::call_site());
116            let ty = &arg.ty;
117
118            match &*arg.ty {
119                // If argument type is mutable reference
120                syn::Type::Reference(syn::TypeReference {
121                    mutability: Some(_),
122                    elem,
123                    ..
124                }) => {
125                    input_payload_fields.push(quote! {
126                        #field_ident: #elem
127                    });
128                    input_payload_node_children.push(quote! {
129                        holium_rs_sdk::internal::key_tree::Node {
130                            value: Some(#field),
131                            children: <#elem>::generate_node().children
132                        }
133                    });
134                    converted_args.push(quote! {
135                        &mut #input_ident.#field_ident
136                    });
137                }
138                // If argument type is non-mutable reference
139                syn::Type::Reference(syn::TypeReference { elem, .. }) => {
140                    input_payload_fields.push(quote! {
141                        #field_ident: #elem
142                    });
143                    input_payload_node_children.push(quote! {
144                        holium_rs_sdk::internal::key_tree::Node {
145                            value: Some(#field),
146                            children: <#elem>::generate_node().children
147                        }
148                    });
149                    // If argument type is non-mutable reference but a &str no need to add &
150                    if (quote! {#elem}).to_string() == "str" {
151                        converted_args.push(quote! {
152                            #input_ident.#field_ident
153                        });
154                    } else {
155                        converted_args.push(quote! {
156                            &#input_ident.#field_ident
157                        });
158                    }
159                }
160                // For all other types
161                _ => {
162                    input_payload_fields.push(quote! {
163                        #field_ident: #ty
164                    });
165                    input_payload_node_children.push(quote! {
166                        holium_rs_sdk::internal::key_tree::Node {
167                            value: Some(#field),
168                            children: <#ty>::generate_node().children
169                        }
170                    });
171                    converted_args.push(quote! {
172                        #input_ident.#field_ident
173                    });
174                }
175            }
176        }
177
178        (quote! {
179            #[allow(non_snake_case)]
180            #[cfg_attr(
181                all(target_arch = "wasm32"),
182                export_name = #exported_name,
183            )]
184            #[allow(clippy::all)]
185            pub extern "C" fn #holium_func_name(ptr: *mut u8, len: usize) -> holium_rs_sdk::internal::memory::Slice {
186                #[derive(holium_rs_sdk::internal::serde::Serialize, holium_rs_sdk::internal::serde::Deserialize)]
187                #[serde( crate = "holium_rs_sdk::internal::serde")]
188                struct InputPayload {
189                    #(#input_payload_fields),*
190                }
191
192                impl holium_rs_sdk::internal::key_tree::GenerateNode for InputPayload {
193                    fn generate_node() -> holium_rs_sdk::internal::key_tree::Node {
194                        holium_rs_sdk::internal::key_tree::Node {
195                            value: None,
196                            children: vec![
197                                #(#input_payload_node_children),*
198                            ]
199                        }
200                    }
201                }
202
203                impl From<holium_rs_sdk::internal::data_tree::Node> for InputPayload {
204                    fn from(data_tree: holium_rs_sdk::internal::data_tree::Node) -> Self {
205                        let key_node = <InputPayload>::generate_node();
206                        let cbor = data_tree.assign_keys(&key_node);
207                        let cbor_bytes: Vec<u8> = internal::serde_cbor::to_vec(&cbor).unwrap();
208                        holium_rs_sdk::internal::serde_cbor::from_slice(&cbor_bytes).unwrap()
209                    }
210                }
211
212                let payload_u8: &[u8] = unsafe { std::slice::from_raw_parts(ptr, len) };
213                let data_node: holium_rs_sdk::internal::data_tree::Node = holium_rs_sdk::internal::serde_cbor::from_slice(payload_u8).unwrap();
214
215                let input: InputPayload = data_node.into();
216
217                let output = #receiver(#(#converted_args),*);
218
219                let output_cbor = holium_rs_sdk::internal::serde_cbor::value::to_value(vec![output]).unwrap();
220
221                let output_node = holium_rs_sdk::internal::data_tree::Node::new(output_cbor).unwrap();
222                let output_node_u8 = holium_rs_sdk::internal::serde_cbor::to_vec(&output_node).unwrap();
223
224                holium_rs_sdk::internal::memory::Slice {
225                    ptr: output_node_u8.as_ptr() as u32,
226                    len: output_node_u8.len() as u32
227                }
228            }
229        })
230            .to_tokens(into);
231
232        Ok(())
233    }
234}