selfstack 0.3.0

A macro to generate self-referential structs
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
// Our macros are just complicated enough to blow past default limits.
#![recursion_limit = "128"]
extern crate proc_macro;
extern crate proc_macro2;
extern crate quote;
extern crate syn;
use proc_macro::TokenStream;
// quote is used by syn::parse_quote
#[allow(unused_imports)]
use quote::quote;
use quote::ToTokens;
use syn::spanned::Spanned;

/// Replace all lifetimes in ty with the lifetime lt.
fn replace_lifetimes(ty: &mut syn::Type, lt: syn::Lifetime) -> Result<(), syn::Error> {
    match ty {
        syn::Type::Path(ref mut p) => {
            for seg in &mut p.path.segments {
                match &mut seg.arguments {
                    syn::PathArguments::Parenthesized(ref mut args) => {
                        for mut input in &mut args.inputs {
                            replace_lifetimes(&mut input, lt.clone())?;
                        }
                        if let syn::ReturnType::Type(_, ty) = &mut args.output {
                            replace_lifetimes(ty, lt.clone())?;
                        }
                    }
                    syn::PathArguments::AngleBracketed(ref mut args) => {
                        for mut arg in &mut args.args {
                            match &mut arg {
                                syn::GenericArgument::Lifetime(l) => {
                                    *l = lt.clone();
                                }
                                syn::GenericArgument::Type(ref mut t) => {
                                    replace_lifetimes(t, lt.clone())?;
                                }
                                _ => (),
                            }
                        }
                    }
                    syn::PathArguments::None => (),
                }
            }
        }
        syn::Type::Reference(ref mut r) => {
            if let Some(ref mut l) = &mut r.lifetime {
                *l = lt;
            }
        }
        ref x => {
            // TODO: Handle more types.
            let msg = format!("type not supported: {:?}", &x);
            return Err(syn::Error::new(ty.span(), msg));
        }
    };
    Ok(())
}

// proc_macros cannot be statements, but rustc thinks item proc_macros are statements in doctests.
// If we use main(), then it forces it to be an item.
#[allow(clippy::needless_doctest_main)]
#[proc_macro]
/// selfstack produces a stack-like self-referential data structure with a safe interface. This is
/// safe because layers in the stack can only reference layers below them, and lower layers outlive
/// higher layers. This restriction prevents cycles, dangling references, and other unsoundness
/// that would generally be possible with self-reference.
///
/// You must declare a mod in this macro. This is so the macro can make unsafe operations private.
///
/// Any struct inside that mod defines the layers of a stack with its fields. This struct is the
/// Store. It provides storage for all fields ahead of time, but at first is uninitialized.
/// Lifetime names are ignored, but you may use them to document references.
///
/// ```
/// selfstack::selfstack! {
///     mod mystack {
///         pub struct MyStore {
///             layer1: u32,
///             layer2: &'layer1 u32,
///         }
///     }
/// }
/// # fn main() {}
/// ```
///
/// To initialize the first layer, call `set_$field(T)` on the Store with the value of the first
/// field. This will return a SubStruct that will allow you to safely access the subset of
/// initialized layers.
///
/// ```
/// # selfstack::selfstack! {
/// #     mod mystack {
/// #         pub struct MyStore {
/// #             layer1: u32,
/// #             layer2: &'layer1 u32,
/// #         }
/// #     }
/// # }
/// #
/// # fn main() {
/// let mut store = mystack::MyStore::new();
/// let sub_struct = store.set_layer1(42);
/// # }
/// ```
///
///
/// You can initialize further layers with the `build_$field()` and `try_build_$field()` methods.
/// These return SubStructs that will allow access to the next layer. For `build()`, you pass a
/// closure that takes references to the previous fields and returns the value of the next field.
/// `try_build()` is the same except the closure should return a `Result` and it will return a
/// `Result`.
///
/// ```
/// # selfstack::selfstack! {
/// #     mod mystack {
/// #         pub struct MyStore {
/// #             layer1: u32,
/// #             layer2: &'layer1 u32,
/// #         }
/// #     }
/// # }
/// #
/// # fn main() {
/// # let mut store = mystack::MyStore::new();
/// # let sub_struct = store.set_layer1(42);
/// let sub_struct = sub_struct.build_layer2(|layer1: &u32|->&u32 {
///     layer1
/// });
/// # }
/// ```
///
/// You can get a const reference to any layer or a mutable reference to the top-most layer using
/// methods that match ref_$fieldname() or mut_$fieldname(). If you need references to multiple
/// layers simultaneously, you can call the view() method which will return a ViewStruct that
/// contains public fields referencing each layer [^0].
///
/// ```
/// # selfstack::selfstack! {
/// #     mod mystack {
/// #         pub struct MyStore {
/// #             layer1: u32,
/// #             layer2: &'layer1 u32,
/// #         }
/// #     }
/// # }
/// #
/// # fn main() {
/// # let mut store = mystack::MyStore::new();
/// # let sub_struct = store.set_layer1(42);
/// # let mut sub_struct = sub_struct.build_layer2(|layer1: &u32|->&u32 {
/// #     &layer1
/// # });
/// assert_eq!(*sub_struct.ref_layer1(), 42);
/// let view = sub_struct.view();
/// assert_eq!(**view.layer2, 42);
/// assert_eq!(*view.layer2, view.layer1);
/// *view.layer2 = &0; // Top layer is mutable.
/// assert_eq!(**view.layer2, 0);
/// # }
/// ```
///
/// When the SubStruct is dropped, the initialized fields will be dropped in reverse order and the
/// Store can be reused.
///
/// [^0]: This cludge is due to limitations in the borrow checker. Calling a method on the
/// SubStruct borrows the entire SubStruct, and the borrow checker won't allow multiple borrows
/// simultaneously if any are mutable. The borrow checker is able to allow simultaneous borrows to
/// the individual fields of a struct however.
pub fn selfstack(item: TokenStream) -> TokenStream {
    let mut mod_def = syn::parse_macro_input!(item as syn::ItemMod);
    if let Some((_, content)) = &mut mod_def.content {
        let input_content = std::mem::replace(content, Vec::new());
        for item in input_content {
            match item {
                syn::Item::Struct(s) => {
                    if let Err(e) = selfstack_struct(s, content) {
                        return e.to_compile_error().into();
                    }
                }
                syn::Item::Use(u) => {
                    content.push(syn::Item::Use(u));
                }
                _ => {
                    // Allowing other items inside the mod (like fn and impl) could circumvent the
                    // safe interface.
                    return syn::Error::new_spanned(item, "item not supported in a selfstack mod")
                        .to_compile_error()
                        .into();
                }
            }
        }
    }
    let mut out = proc_macro2::TokenStream::new();
    mod_def.to_tokens(&mut out);
    out.into()
}

/// Expands one struct declared inside a selfstack mod, pushing every generated item onto `out`.
///
/// For an input `struct S { f1: T1, f2: T2, .. }` this emits, in order:
///
/// - `S` itself (the Store): lifetime generic parameters are removed and each field type `T` is
///   rewritten to `::std::mem::MaybeUninit<T>` after its lifetimes are replaced by `'static`.
/// - For each field `fk`: a SubStruct `S_fk<'a>` and a ViewStruct `S_View_fk<'a, 'b>`.
/// - `S_Ptrs`: a private `Copy` struct holding one `*mut` pointer per field.
/// - `impl<'a> S` containing `set_f1`, `new()` and the private `ptrs()`.
/// - For each field `fk`: a `Drop` impl for `S_fk` that drops `fk`, .., `f1` in that order, and
///   an inherent impl for `S_fk` containing `ref_f1`..`ref_fk`, `mut_fk`, `view`, and (unless
///   `fk` is the last field) `build_f{k+1}` / `try_build_f{k+1}`.
///
/// Every public generated item uses the visibility of the input struct.
///
/// # Errors
///
/// Returns an error if the struct does not use named fields, if any field is not private, or
/// if `replace_lifetimes` rejects a field type.
fn selfstack_struct(
    mut struct_def: syn::ItemStruct,
    out: &mut Vec<syn::Item>,
) -> Result<(), syn::Error> {
    let sname = &struct_def.ident;
    let mut new_params = syn::punctuated::Punctuated::new();
    let vis = &struct_def.vis;
    // Strip lifetime parameters from the Store; field lifetimes are erased to 'static below.
    // NOTE(review): type/const parameters are kept on the struct, but the impls generated below
    // (e.g. `impl<'a> #sname`) do not repeat them — confirm generic structs are unsupported.
    for p in &struct_def.generics.params {
        match p {
            syn::GenericParam::Lifetime(_) => (),
            _ => new_params.push(p.clone()),
        }
    }
    struct_def.generics.params = new_params;
    let call_site = proc_macro2::Span::call_site();
    // The _Ptrs struct will hold raw pointers to each field created up front. If we were to create
    // references from the store directly as needed, that would invalidate existing references
    // according to stacked borrows.
    let store_ptrs_ident = syn::Ident::new(&format!("{}_Ptrs", sname), call_site);
    // Accumulators filled by the per-field loop:
    // - init_field_values: `field: MaybeUninit::uninit()` initializers used by `new()`.
    // - raw_ptr_field_values / raw_ptr_fields: initializers and field definitions of `_Ptrs`.
    let mut init_field_values =
        syn::punctuated::Punctuated::<syn::FieldValue, syn::Token![,]>::new();
    let mut raw_ptr_field_values =
        syn::punctuated::Punctuated::<syn::FieldValue, syn::Token![,]>::new();
    let mut raw_ptr_fields: syn::FieldsNamed = syn::parse_quote!({});
    // impls[0] is the Store impl. Each field then appends its SubStruct's Drop impl followed by
    // its inherent impl, so `impls.last_mut()` is always the impl of the newest SubStruct.
    let mut impls: Vec<syn::ItemImpl> = vec![];
    let mut structs: Vec<syn::ItemStruct> = vec![];
    // field_refs: closure parameter types (`&'a T`) for the layers built so far.
    // store_refs: the matching argument expressions passed to that closure.
    let mut field_refs: syn::punctuated::Punctuated<syn::Type, syn::Token![,]> =
        syn::punctuated::Punctuated::new();
    let mut store_refs: syn::punctuated::Punctuated<syn::Expr, syn::Token![,]> =
        syn::punctuated::Punctuated::new();
    // drop_in_place statements for the initialized layers, newest first.
    let mut drop_stmts: syn::Block = syn::parse_quote!({});
    // `ref_*` getters for every layer seen so far; each SubStruct impl receives all of them.
    let mut field_getters: Vec<syn::ImplItem> = vec![];
    // Field definitions and field initializers of the ViewStructs. `_use_lt_a` is a PhantomData
    // mentioning both 'a and 'b (presumably so both lifetime parameters count as used).
    let mut view_fields: syn::FieldsNamed = syn::parse_quote!({
        _use_lt_a: ::std::marker::PhantomData<&'b &'a ()>,
    });
    let mut view_field_refs = syn::punctuated::Punctuated::<syn::FieldValue, syn::Token![,]>::new();
    view_field_refs.push(syn::parse_quote!(_use_lt_a: ::std::marker::PhantomData));

    let store_impl = syn::parse_quote! {
        impl<'a> #sname {
        }
    };
    impls.push(store_impl);

    let struct_fields = match &mut struct_def.fields {
        syn::Fields::Named(ref mut fns) => fns,
        // We could possibly use numbered fields, but that currently seems too complicated for the
        // benefit.
        _ => {
            return Err(syn::Error::new(
                struct_def.span(),
                "struct must have named fields",
            ));
        }
    };
    for field in struct_fields.named.iter_mut() {
        // Names of the items generated for this field.
        let field_ident = field.ident.as_ref().unwrap();
        let build_ident = syn::Ident::new(&format!("build_{}", field_ident), call_site);
        let set_ident = syn::Ident::new(&format!("set_{}", field_ident), call_site);
        let try_build_ident = syn::Ident::new(&format!("try_build_{}", field_ident), call_site);
        let substruct_ident =
            syn::Ident::new(&format!("{}_{}", struct_def.ident, field_ident), call_site);
        let viewstruct_ident = syn::Ident::new(
            &format!("{}_View_{}", struct_def.ident, field_ident),
            call_site,
        );
        let mut_ident = syn::Ident::new(&format!("mut_{}", field_ident), call_site);
        let ref_ident = syn::Ident::new(&format!("ref_{}", field_ident), call_site);

        // Only private fields are accepted; all access is meant to go through the generated
        // SubStruct methods rather than the raw MaybeUninit storage.
        match &field.vis {
            syn::Visibility::Inherited => (),
            x => {
                return Err(syn::Error::new(
                    x.span(),
                    "fields of a selfstack must be private",
                ))
            }
        }
        let orig_ty = field.ty.clone();
        // We allow bogus lifetimes in order to express self-reference, so we have to erase the
        // lifetimes in the struct definition by using 'static.
        replace_lifetimes(&mut field.ty, syn::parse_quote!('static))?;
        let ty_lt_static = field.ty.clone();
        // The Store keeps each field uninitialized until its layer is set/built.
        {
            let fty = &field.ty;
            field.ty = syn::parse_quote!(::std::mem::MaybeUninit<#fty>);
        }
        // Prepare some types with single lifetimes so that generated functions can just use 'a,
        // 'b, and '_.
        let mut ty_lt_a = orig_ty.clone();
        replace_lifetimes(&mut ty_lt_a, syn::parse_quote!('a))?;
        let mut ty_lt_b = orig_ty.clone();
        replace_lifetimes(&mut ty_lt_b, syn::parse_quote!('b))?;
        let mut ty_lt__ = orig_ty.clone();
        replace_lifetimes(&mut ty_lt__, syn::parse_quote!('_))?;
        // The first layer has no previous layers to reference, so use set instead of
        // build.
        // Before the first field is processed, `impls` holds only the Store impl.
        let is_set = impls.len() == 1;
        // `set_*` is added to the Store impl; `build_*`/`try_build_*` are added to the previous
        // SubStruct's impl, consume it, and return the next SubStruct. The previous SubStruct is
        // forgotten *before* `initf` runs, so if `initf` panics or (in `try_build_*`) returns
        // `Err`, the already-initialized layers are never dropped (leaked).
        // NOTE(review): dropping them on `Err` instead would need care, since the caller-chosen
        // error type `E` could hold `'a` borrows of those layers.
        let build: syn::ImplItem = if is_set {
            syn::parse_quote! {
                #[inline]
                #vis fn #set_ident(&'a mut self, #field_ident: #ty_lt_a) -> #substruct_ident<'a> {
                    let ptrs = self.ptrs();
                    let #field_ident = unsafe{::std::mem::transmute::<#ty_lt__, #ty_lt_static>(#field_ident)};
                    unsafe{::std::ptr::write(ptrs.#field_ident, #field_ident)};
                    #substruct_ident{
                        _store: ::std::marker::PhantomData,
                        ptrs,
                    }
                }
            }
        } else {
            syn::parse_quote! {
                #[inline]
                #vis fn #build_ident<F>(mut self, initf: F) -> #substruct_ident<'a>
                    where F: FnOnce(#field_refs) -> #ty_lt_b
                {
                    let ptrs = self.ptrs;
                    ::std::mem::forget(self);
                    let #field_ident = {
                        let #field_ident = initf(#store_refs);
                        unsafe{::std::mem::transmute::<#ty_lt__, #ty_lt_static>(#field_ident)}
                    };
                    unsafe{::std::ptr::write(ptrs.#field_ident, #field_ident)};
                    #substruct_ident{
                        _store: ::std::marker::PhantomData,
                        ptrs,
                    }
                }
            }
        };
        impls.last_mut().unwrap().items.push(build);
        if !is_set {
            let trybuild = syn::parse_quote! {
                #[inline]
                #vis fn #try_build_ident<F, E>(mut self, initf: F) -> Result<#substruct_ident<'a>, E>
                    where F: FnOnce(#field_refs) -> Result<#ty_lt_b, E>
                {
                    let ptrs = self.ptrs;
                    ::std::mem::forget(self);
                    let #field_ident = {
                        let #field_ident = initf(#store_refs)?;
                        unsafe{::std::mem::transmute::<#ty_lt__, #ty_lt_static>(#field_ident)}
                    };
                    unsafe{::std::ptr::write(ptrs.#field_ident, #field_ident)};
                    Ok(#substruct_ident{
                        _store: ::std::marker::PhantomData,
                        ptrs,
                    })
                }
            };
            impls.last_mut().unwrap().items.push(trybuild);
        }
        // The SubStruct for this layer: `_store` ties it to the `&'a mut` borrow of the Store and
        // `ptrs` locates every field.
        let substruct_def = syn::parse_quote! {
            #vis struct #substruct_ident<'a> {
                _store: ::std::marker::PhantomData<&'a mut #sname>,
                ptrs: #store_ptrs_ident,
            }
        };
        structs.push(substruct_def);
        // Prepend so that the newest layer is dropped first (reverse declaration order).
        drop_stmts.stmts.insert(
            0,
            syn::parse_quote! {
                unsafe{::std::ptr::drop_in_place(self.ptrs.#field_ident)};
            },
        );
        let dropimpl = syn::parse_quote! {
            impl<'a> Drop for #substruct_ident<'a> {
                fn drop(&mut self) {
                    #drop_stmts
                }
            }
        };
        impls.push(dropimpl);
        let subimpl = syn::parse_quote! {
            impl<'a: 'b, 'b> #substruct_ident<'a> {
            }
        };
        impls.push(subimpl);
        // From now on, closures building higher layers receive a `&'a` reference to this layer.
        field_refs.push(syn::parse_quote!(&'a #ty_lt_a));
        store_refs.push(syn::parse_quote!(
                        unsafe{::std::mem::transmute::<&'_ #ty_lt_a, &'a #ty_lt_a>(&*(ptrs.#field_ident as *const _))}));
        field_getters.push(syn::parse_quote! {
                    #[inline]
                    #vis fn #ref_ident(&'a self) -> &#ty_lt_a {
                        unsafe{::std::mem::transmute::<&'_ #ty_lt_static, &'a #ty_lt_a>(&*(self.ptrs.#field_ident as *const _))}
                    }
                });
        // This SubStruct gets `ref_*` getters for every layer initialized so far.
        for getter in &field_getters {
            impls.last_mut().unwrap().items.push(getter.clone());
        }
        // The mut getter is a little more complicated. Interior lifetimes in mut references aren't
        // covariant, unlike const references. In this case, the interior 'static lifetime isn't
        // automatically subtyped to 'a. The danger is that if T contains a mutable reference, we
        // could set that reference to something that lives as long as 'a, but supertype thinks it
        // needs to live as long as 'static. Normally that would be correct, but in this case, the
        // 'static lifetime is a lie to make the store struct compilable, and the field will be
        // dropped after 'a.
        let mut_getter = syn::parse_quote! {
            #[inline]
            #vis fn #mut_ident(&'b mut self) -> &'b mut #ty_lt_a {
                unsafe{::std::mem::transmute::<&'b mut #ty_lt_static, &'b mut #ty_lt_a>(
                        &mut *self.ptrs.#field_ident)}
            }
        };
        impls.last_mut().unwrap().items.push(mut_getter);
        // Build `view()` for this SubStruct: the current (top) layer is initialized as `&'b mut`,
        // while lower layers were already added as shared references by earlier iterations.
        if !view_field_refs.empty_or_trailing() {
            view_field_refs.push_punct(syn::Token![,](call_site));
        }
        view_field_refs.push(syn::parse_quote!(
                        #field_ident: unsafe{::std::mem::transmute::<
                            &'b mut #ty_lt_static, &'b mut #ty_lt_a>(
                                &mut *self.ptrs.#field_ident)}));
        let view_struct_expr: syn::Expr = syn::parse_quote! {
                #viewstruct_ident{
                    #view_field_refs
                }
        };
        let view_getter = syn::parse_quote! {
            #[inline]
            #vis fn view(&'b mut self) -> #viewstruct_ident<'a, 'b> {
                return #view_struct_expr;
            }
        };
        impls.last_mut().unwrap().items.push(view_getter);
        // For higher layers this layer is no longer the top: swap its `&mut` initializer for a
        // shared reference.
        view_field_refs.pop();
        view_field_refs
            .push(syn::parse_quote!(#field_ident: unsafe{&*(self.ptrs.#field_ident as *const _)}));
        // syn::parse_quote! doesn't handle struct field definitions. Instead, we use a dummy
        // struct and extract the first field definition from the parsed value.
        fn first_field(s: syn::ItemStruct) -> syn::Field {
            s.fields.iter().next().unwrap().clone()
        }
        macro_rules! parse_field {
            ($($t:tt)*) => {
                first_field(syn::parse_quote!(struct dummy { $($t)* }))
            };
        };
        // Emit this layer's ViewStruct with the top field as `&'b mut`, then replace that field
        // with a shared `&'b` field so higher layers' ViewStructs expose it read-only.
        let mut_view_field = parse_field!(#vis #field_ident: &'b mut #ty_lt_a);
        view_fields.named.push(mut_view_field);
        structs.push(syn::parse_quote! {
            #vis struct #viewstruct_ident<'a: 'b, 'b>
                #view_fields
        });
        view_fields.named.pop();
        let const_view_field = parse_field!(#vis #field_ident: &'b #ty_lt_a);
        view_fields.named.push(const_view_field);
        init_field_values.push(syn::parse_quote!(
                #field_ident: ::std::mem::MaybeUninit::uninit()
        ));
        raw_ptr_field_values.push(syn::parse_quote!(
                #field_ident: self.#field_ident.as_mut_ptr()
        ));
        let raw_ptr_field = parse_field!(#field_ident: *mut #ty_lt_static);
        raw_ptr_fields.named.push(raw_ptr_field);
    }

    // All fields are known now: finish the Store impl with its constructor and pointer helper.
    let store_impl = impls.first_mut().unwrap();
    store_impl.items.push(syn::parse_quote! {
        #[inline]
        #vis fn new() -> Self {
             #sname { #init_field_values }
        }
    });
    // `ptrs()` (called by `set_*`) takes raw pointers to every, possibly uninitialized, slot.
    store_impl.items.push(syn::parse_quote! {
        #[inline]
        fn ptrs(&mut self) -> #store_ptrs_ident {
            #store_ptrs_ident {
                #raw_ptr_field_values
            }
        }
    });
    let store_ptrs_struct: syn::ItemStruct = syn::parse_quote! {
        #[derive(Copy,Clone)]
        struct #store_ptrs_ident
            #raw_ptr_fields
    };
    structs.push(store_ptrs_struct);

    // Output order: the Store, then all generated structs, then all impls.
    out.push(syn::Item::Struct(struct_def));
    for s in structs {
        out.push(syn::Item::Struct(s));
    }
    for i in impls {
        out.push(syn::Item::Impl(i));
    }
    Ok(())
}