dioxus_native_core_macro/
lib.rs

1#![doc = include_str!("../README.md")]
2#![doc(html_logo_url = "https://avatars.githubusercontent.com/u/79236386")]
3#![doc(html_favicon_url = "https://avatars.githubusercontent.com/u/79236386")]
4
5extern crate proc_macro;
6
7use std::collections::HashSet;
8
9use proc_macro::TokenStream;
10use quote::{format_ident, quote};
11use syn::{parse_macro_input, ItemImpl, Type, TypePath, TypeTuple};
12
13/// A helper attribute for deriving `State` for a struct.
14#[proc_macro_attribute]
15pub fn partial_derive_state(_: TokenStream, input: TokenStream) -> TokenStream {
16    let impl_block: syn::ItemImpl = parse_macro_input!(input as syn::ItemImpl);
17
18    let has_create_fn = impl_block
19        .items
20        .iter()
21        .any(|item| matches!(item, syn::ImplItem::Fn(method) if method.sig.ident == "create"));
22
23    let parent_dependencies = impl_block
24        .items
25        .iter()
26        .find_map(|item| {
27            if let syn::ImplItem::Type(syn::ImplItemType { ident, ty, .. }) = item {
28                (ident == "ParentDependencies").then_some(ty)
29            } else {
30                None
31            }
32        })
33        .expect("ParentDependencies must be defined");
34    let child_dependencies = impl_block
35        .items
36        .iter()
37        .find_map(|item| {
38            if let syn::ImplItem::Type(syn::ImplItemType { ident, ty, .. }) = item {
39                (ident == "ChildDependencies").then_some(ty)
40            } else {
41                None
42            }
43        })
44        .expect("ChildDependencies must be defined");
45    let node_dependencies = impl_block
46        .items
47        .iter()
48        .find_map(|item| {
49            if let syn::ImplItem::Type(syn::ImplItemType { ident, ty, .. }) = item {
50                (ident == "NodeDependencies").then_some(ty)
51            } else {
52                None
53            }
54        })
55        .expect("NodeDependencies must be defined");
56
57    let this_type = &impl_block.self_ty;
58    let this_type = extract_type_path(this_type)
59        .unwrap_or_else(|| panic!("Self must be a type path, found {}", quote!(#this_type)));
60
61    let mut combined_dependencies = HashSet::new();
62
63    let self_path: TypePath = syn::parse_quote!(Self);
64
65    let parent_dependencies = match extract_tuple(parent_dependencies) {
66        Some(tuple) => {
67            let mut parent_dependencies = Vec::new();
68            for type_ in &tuple.elems {
69                let mut type_ = extract_type_path(type_).unwrap_or_else(|| {
70                    panic!(
71                        "ParentDependencies must be a tuple of type paths, found {}",
72                        quote!(#type_)
73                    )
74                });
75                if type_ == self_path {
76                    type_ = this_type.clone();
77                }
78                combined_dependencies.insert(type_.clone());
79                parent_dependencies.push(type_);
80            }
81            parent_dependencies
82        }
83        _ => panic!(
84            "ParentDependencies must be a tuple, found {}",
85            quote!(#parent_dependencies)
86        ),
87    };
88    let child_dependencies = match extract_tuple(child_dependencies) {
89        Some(tuple) => {
90            let mut child_dependencies = Vec::new();
91            for type_ in &tuple.elems {
92                let mut type_ = extract_type_path(type_).unwrap_or_else(|| {
93                    panic!(
94                        "ChildDependencies must be a tuple of type paths, found {}",
95                        quote!(#type_)
96                    )
97                });
98                if type_ == self_path {
99                    type_ = this_type.clone();
100                }
101                combined_dependencies.insert(type_.clone());
102                child_dependencies.push(type_);
103            }
104            child_dependencies
105        }
106        _ => panic!(
107            "ChildDependencies must be a tuple, found {}",
108            quote!(#child_dependencies)
109        ),
110    };
111    let node_dependencies = match extract_tuple(node_dependencies) {
112        Some(tuple) => {
113            let mut node_dependencies = Vec::new();
114            for type_ in &tuple.elems {
115                let mut type_ = extract_type_path(type_).unwrap_or_else(|| {
116                    panic!(
117                        "NodeDependencies must be a tuple of type paths, found {}",
118                        quote!(#type_)
119                    )
120                });
121                if type_ == self_path {
122                    type_ = this_type.clone();
123                }
124                combined_dependencies.insert(type_.clone());
125                node_dependencies.push(type_);
126            }
127            node_dependencies
128        }
129        _ => panic!(
130            "NodeDependencies must be a tuple, found {}",
131            quote!(#node_dependencies)
132        ),
133    };
134    combined_dependencies.insert(this_type.clone());
135
136    let combined_dependencies: Vec<_> = combined_dependencies.into_iter().collect();
137    let parent_dependancies_idxes: Vec<_> = parent_dependencies
138        .iter()
139        .filter_map(|ident| combined_dependencies.iter().position(|i| i == ident))
140        .collect();
141    let child_dependencies_idxes: Vec<_> = child_dependencies
142        .iter()
143        .filter_map(|ident| combined_dependencies.iter().position(|i| i == ident))
144        .collect();
145    let node_dependencies_idxes: Vec<_> = node_dependencies
146        .iter()
147        .filter_map(|ident| combined_dependencies.iter().position(|i| i == ident))
148        .collect();
149    let this_type_idx = combined_dependencies
150        .iter()
151        .enumerate()
152        .find_map(|(i, ident)| (this_type == *ident).then_some(i))
153        .unwrap();
154    let this_view = format_ident!("__data{}", this_type_idx);
155
156    let combined_dependencies_quote = combined_dependencies.iter().map(|ident| {
157        if ident == &this_type {
158            quote! {shipyard::ViewMut<#ident>}
159        } else {
160            quote! {shipyard::View<#ident>}
161        }
162    });
163    let combined_dependencies_quote = quote!((#(#combined_dependencies_quote,)*));
164
165    let ItemImpl {
166        attrs,
167        defaultness,
168        unsafety,
169        impl_token,
170        generics,
171        trait_,
172        self_ty,
173        items,
174        ..
175    } = impl_block;
176    let for_ = trait_.as_ref().map(|t| t.2);
177    let trait_ = trait_.map(|t| t.1);
178
179    let split_views: Vec<_> = (0..combined_dependencies.len())
180        .map(|i| {
181            let ident = format_ident!("__data{}", i);
182            if i == this_type_idx {
183                quote! {mut #ident}
184            } else {
185                quote! {#ident}
186            }
187        })
188        .collect();
189
190    let node_view = node_dependencies_idxes
191        .iter()
192        .map(|i| format_ident!("__data{}", i))
193        .collect::<Vec<_>>();
194    let get_node_view = {
195        if node_dependencies.is_empty() {
196            quote! {
197                let raw_node = ();
198            }
199        } else {
200            let temps = (0..node_dependencies.len())
201                .map(|i| format_ident!("__temp{}", i))
202                .collect::<Vec<_>>();
203            quote! {
204                let raw_node: (#(*const #node_dependencies,)*) = {
205                    let (#(#temps,)*) = (#(&#node_view,)*).get(id).unwrap_or_else(|err| panic!("Failed to get node view {:?}", err));
206                    (#(#temps as *const _,)*)
207                };
208            }
209        }
210    };
211    let deref_node_view = {
212        if node_dependencies.is_empty() {
213            quote! {
214                let node = raw_node;
215            }
216        } else {
217            let indexes = (0..node_dependencies.len()).map(syn::Index::from);
218            quote! {
219                let node = unsafe { (#(dioxus_native_core::prelude::DependancyView::new(&*raw_node.#indexes),)*) };
220            }
221        }
222    };
223
224    let parent_view = parent_dependancies_idxes
225        .iter()
226        .map(|i| format_ident!("__data{}", i))
227        .collect::<Vec<_>>();
228    let get_parent_view = {
229        if parent_dependencies.is_empty() {
230            quote! {
231                let raw_parent = tree.parent_id_advanced(id, Self::TRAVERSE_SHADOW_DOM).map(|_| ());
232            }
233        } else {
234            let temps = (0..parent_dependencies.len())
235                .map(|i| format_ident!("__temp{}", i))
236                .collect::<Vec<_>>();
237            quote! {
238                let raw_parent = tree.parent_id_advanced(id, Self::TRAVERSE_SHADOW_DOM).and_then(|parent_id| {
239                    let raw_parent: Option<(#(*const #parent_dependencies,)*)> = (#(&#parent_view,)*).get(parent_id).ok().map(|c| {
240                        let (#(#temps,)*) = c;
241                        (#(#temps as *const _,)*)
242                    });
243                    raw_parent
244                });
245            }
246        }
247    };
248    let deref_parent_view = {
249        if parent_dependencies.is_empty() {
250            quote! {
251                let parent = raw_parent;
252            }
253        } else {
254            let indexes = (0..parent_dependencies.len()).map(syn::Index::from);
255            quote! {
256                let parent = unsafe { raw_parent.map(|raw_parent| (#(dioxus_native_core::prelude::DependancyView::new(&*raw_parent.#indexes),)*)) };
257            }
258        }
259    };
260
261    let child_view = child_dependencies_idxes
262        .iter()
263        .map(|i| format_ident!("__data{}", i))
264        .collect::<Vec<_>>();
265    let get_child_view = {
266        if child_dependencies.is_empty() {
267            quote! {
268                let raw_children: Vec<_> = tree.children_ids_advanced(id, Self::TRAVERSE_SHADOW_DOM).into_iter().map(|_| ()).collect();
269            }
270        } else {
271            let temps = (0..child_dependencies.len())
272                .map(|i| format_ident!("__temp{}", i))
273                .collect::<Vec<_>>();
274            quote! {
275                let raw_children: Vec<_> = tree.children_ids_advanced(id, Self::TRAVERSE_SHADOW_DOM).into_iter().filter_map(|id| {
276                    let raw_children: Option<(#(*const #child_dependencies,)*)> = (#(&#child_view,)*).get(id).ok().map(|c| {
277                        let (#(#temps,)*) = c;
278                        (#(#temps as *const _,)*)
279                    });
280                    raw_children
281                }).collect();
282            }
283        }
284    };
285    let deref_child_view = {
286        if child_dependencies.is_empty() {
287            quote! {
288                let children = raw_children;
289            }
290        } else {
291            let indexes = (0..child_dependencies.len()).map(syn::Index::from);
292            quote! {
293                let children = unsafe { raw_children.iter().map(|raw_children| (#(dioxus_native_core::prelude::DependancyView::new(&*raw_children.#indexes),)*)).collect::<Vec<_>>() };
294            }
295        }
296    };
297
298    let trait_generics = trait_
299        .as_ref()
300        .unwrap()
301        .segments
302        .last()
303        .unwrap()
304        .arguments
305        .clone();
306
307    // if a create function is defined, we don't generate one
308    // otherwise we generate a default one that uses the update function and the default constructor
309    let create_fn = (!has_create_fn).then(|| {
310        quote! {
311            fn create<'a>(
312                node_view: dioxus_native_core::prelude::NodeView # trait_generics,
313                node: <Self::NodeDependencies as Dependancy>::ElementBorrowed<'a>,
314                parent: Option<<Self::ParentDependencies as Dependancy>::ElementBorrowed<'a>>,
315                children: Vec<<Self::ChildDependencies as Dependancy>::ElementBorrowed<'a>>,
316                context: &dioxus_native_core::prelude::SendAnyMap,
317            ) -> Self {
318                let mut myself = Self::default();
319                myself.update(node_view, node, parent, children, context);
320                myself
321            }
322        }
323    });
324
325    quote!(
326        #(#attrs)*
327        #defaultness #unsafety #impl_token #generics #trait_ #for_ #self_ty {
328            #create_fn
329
330            #(#items)*
331
332            fn workload_system(type_id: std::any::TypeId, dependants: std::sync::Arc<dioxus_native_core::prelude::Dependants>, pass_direction: dioxus_native_core::prelude::PassDirection) -> dioxus_native_core::exports::shipyard::WorkloadSystem {
333                use dioxus_native_core::exports::shipyard::{IntoWorkloadSystem, Get, AddComponent};
334                use dioxus_native_core::tree::TreeRef;
335                use dioxus_native_core::prelude::{NodeType, NodeView};
336
337                let node_mask = Self::NODE_MASK.build();
338
339                (move |data: #combined_dependencies_quote, run_view: dioxus_native_core::prelude::RunPassView #trait_generics| {
340                    let (#(#split_views,)*) = data;
341                    let tree = run_view.tree.clone();
342                    let node_types = run_view.node_type.clone();
343                    dioxus_native_core::prelude::run_pass(type_id, dependants.clone(), pass_direction, run_view, |id, context| {
344                        let node_data: &NodeType<_> = node_types.get(id).unwrap_or_else(|err| panic!("Failed to get node type {:?}", err));
345                        // get all of the states from the tree view
346                        // Safety: No node has itself as a parent or child.
347                        let raw_myself: Option<*mut Self> = (&mut #this_view).get(id).ok().map(|c| c as *mut _);
348                        #get_node_view
349                        #get_parent_view
350                        #get_child_view
351
352                        let myself: Option<&mut Self> = unsafe { raw_myself.map(|val| &mut *val) };
353                        #deref_node_view
354                        #deref_parent_view
355                        #deref_child_view
356
357                        let view = NodeView::new(id, node_data, &node_mask);
358                        if let Some(myself) = myself {
359                            myself
360                                .update(view, node, parent, children, context)
361                        }
362                        else {
363                            (&mut #this_view).add_component_unchecked(
364                                id,
365                                Self::create(view, node, parent, children, context));
366                            true
367                        }
368                    })
369                }).into_workload_system().unwrap()
370            }
371        }
372    )
373    .into()
374}
375
376fn extract_tuple(ty: &Type) -> Option<TypeTuple> {
377    match ty {
378        Type::Tuple(tuple) => Some(tuple.clone()),
379        Type::Group(group) => extract_tuple(&group.elem),
380        _ => None,
381    }
382}
383
384fn extract_type_path(ty: &Type) -> Option<TypePath> {
385    match ty {
386        Type::Path(path) => Some(path.clone()),
387        Type::Group(group) => extract_type_path(&group.elem),
388        _ => None,
389    }
390}