flat_drop/
lib.rs

1//! In this crate, we define the [FlatDrop] type.
//! `FlatDrop<K>` behaves just like a `K`, but with a custom `Drop` implementation
//! that avoids overflowing the stack when dropping deeply nested objects.
//! Instead of recursively dropping subobjects, it walks the object depth-first
//! using an explicit, heap-allocated worklist and drops subobjects one at a time.
6//!
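//! To use this crate, replace the recursive [Box]es, [Rc]s and [Arc]s in your
//! types with `FlatDrop<Box<T>>`, `FlatDrop<Rc<T>>` or `FlatDrop<Arc<T>>` (also
//! available as the aliases [FlatBox], [FlatRc] and [FlatArc]). You'll then need
//! to implement the [Recursive] trait for your type. It performs one step of the
//! iterative dropping procedure by handing back the containers of the value's
//! direct children.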
11//!
12//! This crate uses `unsafe` internally, but the external API is safe.
13//!
14//! # Example
15//!
16//! ```
17//! use flat_drop::{FlatDrop, Recursive};
18//!
19//! /// Peano natural numbers.
20//! enum Natural {
21//!     Zero,
22//!     Succ(FlatDrop<Box<Natural>>),
23//! }
24//!
25//! impl Recursive for Natural {
26//!     type Container = Box<Natural>;
27//!
28//!     fn destruct(self) -> impl Iterator<Item = Self::Container> {
29//!         match self {
30//!             Natural::Zero => None,
31//!             Natural::Succ(pred) => Some(pred.into_inner()),
32//!         }
33//!         .into_iter()
34//!     }
35//! }
36//!
37//! impl Natural {
38//!     pub fn from_usize(value: usize) -> Self {
39//!         (0..value).fold(Self::Zero, |nat, _| {
40//!             Self::Succ(FlatDrop::new(Box::new(nat)))
41//!         })
42//!     }
43//! }
44//!
//! // Spawn a thread with a tiny (4 KiB) stack, then build and drop a number nested
//! // STACK_SIZE * 100 levels deep, far beyond what a recursive drop could survive there.
46//! const STACK_SIZE: usize = 4 * 1024;
47//!
48//! fn task() {
49//!     let nat = Natural::from_usize(STACK_SIZE * 100);
50//!     drop(std::hint::black_box(nat));
51//! }
52//!
53//! std::thread::Builder::new()
54//!     .stack_size(STACK_SIZE)
55//!     .spawn(task)
56//!     .unwrap()
57//!     .join()
58//!     .unwrap();
59//! ```
60
61use std::{
62    fmt::Display,
63    mem::ManuallyDrop,
64    ops::{Deref, DerefMut},
65    rc::Rc,
66    sync::Arc,
67};
68
/// One step of the iterative teardown performed by [`FlatDrop`].
///
/// [`Recursive::destruct`] consumes a value and yields the containers of its
/// direct recursive children. Those containers are usually `Box<Self>`,
/// `Rc<Self>` or `Arc<Self>`. [`FlatDrop`]'s `Drop` implementation unwraps the
/// containers and calls `destruct` on them in turn. The whole teardown is driven
/// from an explicit worklist rather than the call stack.
///
/// Implementations should extract every recursive child with
/// [`FlatDrop::into_inner`] and yield it, instead of letting it drop inside
/// `destruct`. A child dropped there is torn down by a nested `FlatDrop::drop`
/// call from within this one. If every level does that, stack usage once again
/// grows with the nesting depth.
///
/// Non-recursive data left in `self` is released by ordinary drop glue.
pub trait Recursive {
    /// The smart pointer that holds each child value, typically `Box<Self>`,
    /// `Rc<Self>` or `Arc<Self>`.
    type Container;

    /// Consumes `self` and yields the containers of its direct recursive children.
    fn destruct(self) -> impl Iterator<Item = Self::Container>;
}
76
/// A smart pointer that owns at most one value and can hand it back by value.
///
/// [`FlatDrop`] uses this to decide whether it has to descend into a container.
/// `Some(inner)` means this handle was the sole owner, so `inner` must now be
/// torn down. `None` means another owner still exists, so there is nothing more
/// to tear down through this handle.
pub trait IntoOptionInner {
    /// The value stored inside the container.
    type Inner;

    /// Converts the container into its inner value, if this handle is the sole owner.
    ///
    /// Implementations must never drop the inner value. Either return it, or
    /// release only this handle and return `None` when another owner still
    /// exists.
    ///
    /// - For `Box<T>` this always returns `Some(*self)`.
    /// - For `Rc<T>` / `Arc<T>` this is `Rc::into_inner` / `Arc::into_inner`.
    ///   Those return `Some` only for the last strong reference.
    fn into_option_inner(self) -> Option<Self::Inner>;
}
88
89/// If `K` is a container of a recursive type, such as `Box<T>` where `T: Recursive`,
90/// `FlatDrop<K>` behaves just like `K`, but with a custom `Drop` implementation.
91/// In this implementation, we gather the recursive parts of the object iteratively
92/// and drop them without recursion, avoiding stack overflows when dropping
93/// large recursive objects.
94///
95/// # Safety
96///
97/// We keep the invariant that the inner object is always initialised, but will
98/// be dropped (exactly once) in the `drop` implementation.
99#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
100#[repr(transparent)]
101pub struct FlatDrop<K>(ManuallyDrop<K>)
102where
103    K: IntoOptionInner,
104    K::Inner: Recursive<Container = K>;
105
106impl<K> Drop for FlatDrop<K>
107where
108    K: IntoOptionInner,
109    K::Inner: Recursive<Container = K>,
110{
111    fn drop(&mut self) {
112        // Move out of the inner `ManuallyDrop`.
113        // Safety: the inner value has not yet been dropped, and will not be used again.
114        let value = unsafe { ManuallyDrop::take(&mut self.0) };
115
116        // Construct a sequence of containers to drop.
117        let mut to_drop = vec![value];
118
119        // Iteratively decompose each container from this list.
120        // This avoids creating excessive stack frames when destroying large objects.
121        while let Some(container) = to_drop.pop() {
122            if let Some(value) = container.into_option_inner() {
123                to_drop.extend(value.destruct());
124            }
125        }
126
127        // The drop glue will be a no-op since the field is `ManuallyDrop`.
128    }
129}
130
// The core of the crate ends here. What follows is API built on top of it:
// container impls, type aliases, constructors and trait forwarding.
132
133impl<T> IntoOptionInner for Box<T> {
134    type Inner = T;
135
136    fn into_option_inner(self) -> Option<Self::Inner> {
137        Some(*self)
138    }
139}
140
141impl<T> IntoOptionInner for Rc<T> {
142    type Inner = T;
143
144    fn into_option_inner(self) -> Option<Self::Inner> {
145        Rc::into_inner(self)
146    }
147}
148
149impl<T> IntoOptionInner for Arc<T> {
150    type Inner = T;
151
152    fn into_option_inner(self) -> Option<Self::Inner> {
153        Arc::into_inner(self)
154    }
155}
156
157pub type FlatBox<T> = FlatDrop<Box<T>>;
158pub type FlatRc<T> = FlatDrop<Rc<T>>;
159pub type FlatArc<T> = FlatDrop<Arc<T>>;
160
161impl<K> FlatDrop<K>
162where
163    K: IntoOptionInner,
164    K::Inner: Recursive<Container = K>,
165{
166    pub const fn new(container: K) -> Self {
167        Self(ManuallyDrop::new(container))
168    }
169
170    pub fn into_inner(mut self) -> K {
171        // Safety: This value is always initialised.
172        // Once we take it, we need to be careful to not call `drop` on `self`.
173        let value = unsafe { ManuallyDrop::take(&mut self.0) };
174        // This doesn't leak, because `self` is contained purely on the stack.
175        std::mem::forget(self);
176        value
177    }
178}
179
180impl<T> FlatBox<T>
181where
182    T: Recursive<Container = Box<T>>,
183{
184    pub fn unbox(self) -> T {
185        *self.into_inner()
186    }
187}
188
189impl<K, T> AsRef<T> for FlatDrop<K>
190where
191    T: ?Sized,
192    K: IntoOptionInner,
193    K::Inner: Recursive<Container = K>,
194    K: AsRef<T>,
195{
196    fn as_ref(&self) -> &T {
197        (**self).as_ref()
198    }
199}
200
201impl<K, T> AsMut<T> for FlatDrop<K>
202where
203    T: ?Sized,
204    K: IntoOptionInner,
205    K::Inner: Recursive<Container = K>,
206    K: AsMut<T>,
207{
208    fn as_mut(&mut self) -> &mut T {
209        (**self).as_mut()
210    }
211}
212
213impl<K> Deref for FlatDrop<K>
214where
215    K: IntoOptionInner,
216    K::Inner: Recursive<Container = K>,
217{
218    type Target = K;
219
220    fn deref(&self) -> &K {
221        self.0.deref()
222    }
223}
224
225impl<K> DerefMut for FlatDrop<K>
226where
227    K: IntoOptionInner,
228    K::Inner: Recursive<Container = K>,
229{
230    fn deref_mut(&mut self) -> &mut K {
231        self.0.deref_mut()
232    }
233}
234
235impl<K> Display for FlatDrop<K>
236where
237    K: IntoOptionInner + Display,
238    K::Inner: Recursive<Container = K>,
239{
240    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241        <K as Display>::fmt(self, f)
242    }
243}
244
245impl<K> From<K> for FlatDrop<K>
246where
247    K: IntoOptionInner,
248    K::Inner: Recursive<Container = K>,
249{
250    fn from(value: K) -> Self {
251        Self::new(value)
252    }
253}
254
255impl<T> FlatBox<T>
256where
257    T: Recursive<Container = Box<T>>,
258{
259    pub fn new_boxed(value: T) -> Self {
260        Self::new(Box::new(value))
261    }
262}
263
264impl<T> From<T> for FlatBox<T>
265where
266    T: Recursive<Container = Box<T>>,
267{
268    fn from(value: T) -> Self {
269        Self::new_boxed(value)
270    }
271}
272
273impl<T> FlatRc<T>
274where
275    T: Recursive<Container = Rc<T>>,
276{
277    pub fn new_rc(value: T) -> Self {
278        Self::new(Rc::new(value))
279    }
280}
281
282impl<T> From<T> for FlatRc<T>
283where
284    T: Recursive<Container = Rc<T>>,
285{
286    fn from(value: T) -> Self {
287        Self::new_rc(value)
288    }
289}
290
291impl<T> FlatArc<T>
292where
293    T: Recursive<Container = Arc<T>>,
294{
295    pub fn new_arc(value: T) -> Self {
296        Self::new(Arc::new(value))
297    }
298}
299
300impl<T> From<T> for FlatArc<T>
301where
302    T: Recursive<Container = Arc<T>>,
303{
304    fn from(value: T) -> Self {
305        Self::new_arc(value)
306    }
307}
308
#[cfg(feature = "serde")]
impl<K> serde::Serialize for FlatDrop<K>
where
    K: IntoOptionInner + serde::Serialize,
    K::Inner: Recursive<Container = K>,
{
    /// Serialises exactly as the wrapped container would; the wrapper adds nothing.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let container: &K = &self.0;
        serde::Serialize::serialize(container, serializer)
    }
}
323
#[cfg(feature = "serde")]
impl<'de, K> serde::Deserialize<'de> for FlatDrop<K>
where
    K: IntoOptionInner + serde::Deserialize<'de>,
    K::Inner: Recursive<Container = K>,
{
    /// Deserialises the container itself and wraps it.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let container = <K as serde::Deserialize<'de>>::deserialize(deserializer)?;
        Ok(FlatDrop::new(container))
    }
}
338
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use crate::{FlatBox, FlatDrop, FlatRc, Recursive};

    /// Requested stack size for the small-stack test threads.
    ///
    /// It is tiny on purpose: recursively dropping any of the deep structures
    /// below would overflow it. The platform may round it up to its own minimum.
    const STACK_SIZE: usize = 4 * 1024;

    /// Runs `task` on a new thread that requests a [`STACK_SIZE`] stack, and
    /// fails the calling test if that thread panics.
    fn run_with_small_stack(task: fn()) {
        std::thread::Builder::new()
            .stack_size(STACK_SIZE)
            .spawn(task)
            .unwrap()
            .join()
            .unwrap();
    }

    /// Peano natural numbers: a linear chain of boxes.
    enum Natural {
        Zero,
        Succ(FlatDrop<Box<Natural>>),
    }

    impl Recursive for Natural {
        type Container = Box<Natural>;

        fn destruct(self) -> impl Iterator<Item = Self::Container> {
            match self {
                Natural::Zero => None,
                Natural::Succ(pred) => Some(pred.into_inner()),
            }
            .into_iter()
        }
    }

    impl Natural {
        pub fn from_usize(value: usize) -> Self {
            (0..value).fold(Self::Zero, |nat, _| {
                Self::Succ(FlatDrop::new(Box::new(nat)))
            })
        }
    }

    /// A linked list whose tails may be shared between several lists.
    enum List {
        Nil,
        Cons(FlatRc<List>),
    }

    impl Recursive for List {
        type Container = Rc<List>;

        fn destruct(self) -> impl Iterator<Item = Self::Container> {
            match self {
                List::Nil => None,
                List::Cons(tail) => Some(tail.into_inner()),
            }
            .into_iter()
        }
    }

    impl List {
        /// Builds a list with `len` cells, each uniquely owned by its predecessor.
        fn with_len(len: usize) -> Self {
            (0..len).fold(Self::Nil, |tail, _| Self::Cons(FlatRc::new_rc(tail)))
        }

        /// Counts the cells by walking the list iteratively.
        fn len(&self) -> usize {
            let mut cells = 0;
            let mut cursor = self;
            while let Self::Cons(tail) = cursor {
                cells += 1;
                cursor = &***tail;
            }
            cells
        }
    }

    /// A binary tree whose leaves each hold a clone of a shared token.
    ///
    /// The token's strong count therefore reveals how many leaves are still alive.
    enum Tree {
        Leaf(Rc<()>),
        Node(FlatBox<Tree>, FlatBox<Tree>),
    }

    impl Recursive for Tree {
        type Container = Box<Tree>;

        fn destruct(self) -> impl Iterator<Item = Self::Container> {
            match self {
                Tree::Leaf(token) => {
                    drop(token);
                    Vec::new()
                }
                Tree::Node(left, right) => vec![left.into_inner(), right.into_inner()],
            }
            .into_iter()
        }
    }

    #[test]
    fn test_large_natural() {
        // A number nested far deeper than the small stack could drop recursively.
        run_with_small_stack(|| {
            let nat = Natural::from_usize(STACK_SIZE * 100);
            drop(std::hint::black_box(nat));
        });
    }

    #[test]
    fn test_large_rc_list() {
        run_with_small_stack(|| {
            let list = List::with_len(STACK_SIZE * 100);
            drop(std::hint::black_box(list));
        });
    }

    #[test]
    fn test_shared_tail_outlives_one_owner() {
        let shared = Rc::new(List::with_len(10));
        let first = List::Cons(FlatDrop::new(Rc::clone(&shared)));
        let second = List::Cons(FlatDrop::new(Rc::clone(&shared)));

        // Dropping one owner must only release its handle, not the shared tail.
        drop(first);
        assert_eq!(Rc::strong_count(&shared), 2);
        assert_eq!(second.len(), 11);

        drop(second);
        assert_eq!(Rc::strong_count(&shared), 1);
        assert_eq!(shared.len(), 10);
    }

    #[test]
    fn test_every_branch_is_dropped() {
        let token = Rc::new(());
        let depth = 10_000;
        let tree = (0..depth).fold(Tree::Leaf(Rc::clone(&token)), |left, _| {
            let right = Tree::Leaf(Rc::clone(&token));
            Tree::Node(FlatBox::new_boxed(left), FlatBox::new_boxed(right))
        });
        // The token itself, the leaf at the bottom of the spine, and one right leaf per node.
        assert_eq!(Rc::strong_count(&token), depth + 2);

        drop(tree);
        assert_eq!(Rc::strong_count(&token), 1);
    }
}