Skip to main content

linked_macros_impl/
linked_object.rs

1// Copyright (c) Microsoft Corporation.
2// Copyright (c) Folo authors.
3
4use proc_macro2::TokenStream;
5use quote::quote;
6use syn::spanned::Spanned;
7use syn::{Fields, FieldsNamed, Item, ItemStruct, parse_quote};
8
9use crate::syn_helpers::token_stream_and_error;
10
11#[must_use]
12pub fn entrypoint(_attr: &TokenStream, input: &TokenStream) -> TokenStream {
13    let item_ast = syn::parse2::<Item>(input.clone());
14
15    let result = match item_ast {
16        Ok(Item::Struct(item)) => core(item),
17        Ok(x) => Err(syn::Error::new(
18            x.span(),
19            "the `linked::object` attribute must be applied to a struct",
20        )),
21        Err(e) => Err(e),
22    };
23
24    match result {
25        Ok(r) => r,
26        Err(e) => token_stream_and_error(input, &e),
27    }
28}
29
30fn core(mut item: ItemStruct) -> Result<TokenStream, syn::Error> {
31    let (impl_generics, type_generics, where_clause) = &item.generics.split_for_impl();
32    let name = &item.ident;
33
34    let Fields::Named(FieldsNamed { named: fields, .. }) = &mut item.fields else {
35        return Err(syn::Error::new(
36            item.span(),
37            "the `linked::object` attribute must be applied to a struct with named fields",
38        ));
39    };
40
41    // We add a field to store the Link<Self>, which is later referenced by other macros.
42    fields
43        .push(parse_quote!(#[doc(hidden)] __private_linked_link: ::linked::__private::Link<Self>));
44
45    let extended = quote! {
46        #item
47
48        impl #impl_generics ::linked::Object for #name #type_generics #where_clause {
49            fn family(&self) -> ::linked::Family<Self> {
50                self.__private_linked_link.family()
51            }
52        }
53
54        impl #impl_generics Clone for #name #type_generics #where_clause {
55            fn clone(&self) -> Self {
56                ::linked::__private::clone(self)
57            }
58        }
59
60        impl #impl_generics ::std::convert::From<::linked::Family<#name #type_generics>> for #name #type_generics #where_clause {
61            fn from(family: ::linked::Family<#name #type_generics>) -> Self {
62                family.__private_into()
63            }
64        }
65    };
66
67    Ok(extended)
68}
69
70#[cfg(test)]
71#[cfg_attr(coverage_nightly, coverage(off))]
72mod tests {
73    use quote::quote;
74
75    use super::*;
76    use crate::syn_helpers::contains_compile_error;
77
78    #[test]
79    fn smoke_test() {
80        let input = quote! {
81            struct Foo {
82            }
83        };
84
85        let result = entrypoint(&TokenStream::new(), &input);
86
87        let expected = quote! {
88            struct Foo {
89                #[doc(hidden)]
90                __private_linked_link: ::linked::__private::Link<Self>
91            }
92
93            impl ::linked::Object for Foo {
94                fn family(&self) -> ::linked::Family<Self> {
95                    self.__private_linked_link.family()
96                }
97            }
98
99            impl Clone for Foo {
100                fn clone(&self) -> Self {
101                    ::linked::__private::clone(self)
102                }
103            }
104
105            impl ::std::convert::From<::linked::Family<Foo>> for Foo {
106                fn from(family: ::linked::Family<Foo>) -> Self {
107                    family.__private_into()
108                }
109            }
110        };
111
112        assert_eq!(result.to_string(), expected.to_string());
113    }
114
115    #[test]
116    fn smoke_test_with_generics() {
117        let input = quote! {
118            struct Foo<'y, T: Clone, X>
119            where
120                X: Debug
121            {
122                something: X,
123                something_else: &'y Y,
124            }
125        };
126
127        let result = entrypoint(&TokenStream::new(), &input);
128
129        let expected = quote! {
130            struct Foo<'y, T: Clone, X>
131            where
132                X: Debug
133            {
134                something: X,
135                something_else: &'y Y,
136                #[doc(hidden)]
137                __private_linked_link: ::linked::__private::Link<Self>
138            }
139
140            impl<'y, T: Clone, X> ::linked::Object for Foo<'y, T, X>
141            where
142                X: Debug
143            {
144                fn family(&self) -> ::linked::Family<Self> {
145                    self.__private_linked_link.family()
146                }
147            }
148
149            impl<'y, T: Clone, X> Clone for Foo<'y, T, X>
150            where
151            X: Debug
152            {
153                fn clone(&self) -> Self {
154                    ::linked::__private::clone(self)
155                }
156            }
157
158            impl<'y, T: Clone, X> ::std::convert::From<::linked::Family<Foo<'y, T, X> >> for Foo<'y, T, X>
159            where
160                X: Debug
161            {
162                fn from(family: ::linked::Family<Foo<'y, T, X> >) -> Self {
163                    family.__private_into()
164                }
165            }
166        };
167
168        assert_eq!(result.to_string(), expected.to_string());
169    }
170
171    #[test]
172    fn with_unnamed_fields_fails() {
173        let input = quote! {
174            struct Foo(usize, String);
175        };
176
177        let result = entrypoint(&TokenStream::new(), &input);
178        assert!(contains_compile_error(&result));
179    }
180
181    #[test]
182    fn with_enum_fails() {
183        let input = quote! {
184            enum Direction { Up, Down }
185        };
186
187        let result = entrypoint(&TokenStream::new(), &input);
188        assert!(contains_compile_error(&result));
189    }
190
191    #[test]
192    fn with_invalid_syntax_fails() {
193        // This input cannot be parsed by syn because it is not valid Rust syntax.
194        let input = quote! {
195            struct { missing name }
196        };
197
198        let result = entrypoint(&TokenStream::new(), &input);
199        assert!(contains_compile_error(&result));
200    }
201}