drt_chain_vm/with_shared/
shareable.rs

1use std::{
2    ops::{Deref, DerefMut},
3    sync::Arc,
4};
5
/// Wraps an object and provides mutable access to it.
///
/// Sometimes mutation needs to be suspended temporarily so that reference-counted pointers to the
/// object can be handed out instead.
///
/// This happens in a controlled environment: inside the closure passed to
/// [`Shareable::with_shared`]. All reference-counted pointers are expected to be dropped by the
/// time that closure finishes.
pub enum Shareable<T> {
    /// Exclusively owned; the contents can be mutated.
    Owned(T),
    /// Temporarily shared through an [`Arc`]; the contents are read-only.
    Shared(Arc<T>),
}

impl<T> Shareable<T> {
    /// Creates a new wrapper in the `Owned` state.
    pub fn new(t: T) -> Self {
        Shareable::Owned(t)
    }

    /// Destroys the wrapper and returns the contents.
    ///
    /// # Panics
    ///
    /// Panics if the wrapper is in the `Shared` state.
    pub fn into_inner(self) -> T {
        if let Shareable::Owned(t) = self {
            t
        } else {
            panic!("cannot access ShareableMut owned object")
        }
    }
}

impl<T> Default for Shareable<T>
where
    T: Default,
{
    /// Wraps `T::default()` in the `Owned` state.
    fn default() -> Self {
        Shareable::new(T::default())
    }
}

impl<T> Deref for Shareable<T> {
    type Target = T;

    /// Gives read access to the contents in either state.
    fn deref(&self) -> &Self::Target {
        match self {
            Shareable::Owned(t) => t,
            Shareable::Shared(arc) => arc.deref(),
        }
    }
}

impl<T> DerefMut for Shareable<T> {
    /// Gives write access to the contents.
    ///
    /// # Panics
    ///
    /// Panics in the `Shared` state, where the contents are read-only.
    fn deref_mut(&mut self) -> &mut Self::Target {
        match self {
            Shareable::Owned(t) => t,
            Shareable::Shared(_) => {
                panic!("cannot mutably dereference ShareableMut when in Shared state")
            },
        }
    }
}

impl<T> Shareable<T> {
    /// Returns a new strong reference to the shared contents.
    ///
    /// Panics if `self` is not in the `Shared` state.
    fn get_arc(&self) -> Arc<T> {
        if let Shareable::Shared(arc) = self {
            arc.clone()
        } else {
            panic!("invalid ShareableMut state: Shared expected")
        }
    }

    /// Moves the owned value into a fresh `Arc`, switching `Owned` -> `Shared` in place.
    ///
    /// Panics if `self` is already `Shared`.
    fn wrap_arc_strict(&mut self) {
        let value_ptr: *mut T = match self {
            Shareable::Owned(t) => t as *mut T,
            Shareable::Shared(_) => panic!("invalid ShareableMut state: Owned expected"),
        };

        // Aborts the process when dropped. It is forgotten on success, so it only fires if the
        // code below unwinds while the value exists both in `*self` and in a bitwise copy.
        // Dropping `*self` in that situation would be a double drop.
        struct AbortOnUnwind;
        impl Drop for AbortOnUnwind {
            fn drop(&mut self) {
                std::process::abort();
            }
        }

        let abort_guard = AbortOnUnwind;
        // SAFETY: `value_ptr` points to the initialized `T` stored in `*self`. We take a bitwise
        // copy and then overwrite `*self` with `ptr::write`, which does not drop the old contents.
        // The value therefore ends up owned exactly once, by the new `Arc`. If `Arc::new` unwinds
        // in between, `abort_guard` aborts before the stale bits in `*self` can be dropped.
        unsafe {
            let value = std::ptr::read(value_ptr);
            std::ptr::write(self, Shareable::Shared(Arc::new(value)));
        }
        std::mem::forget(abort_guard);
    }

    /// Switches `Shared` back to `Owned` in place, if `self` holds the only strong reference.
    ///
    /// Returns `true` on success. Returns `false` and leaves `self` untouched if `self` is
    /// `Owned` or other strong references are still alive. It contains no panicking paths,
    /// so it is safe to call from a `Drop` impl during unwinding.
    fn try_unwrap_arc(&mut self) -> bool {
        let arc_ptr: *mut Arc<T> = match self {
            Shareable::Shared(arc) => arc as *mut Arc<T>,
            Shareable::Owned(_) => return false,
        };
        // SAFETY: `arc_ptr` points to the initialized `Arc` stored in `*self`; we take a bitwise
        // copy. Until `*self` is overwritten below, that one strong reference is accounted for by
        // exactly one of {the copy, `*self`}, never both.
        let arc = unsafe { std::ptr::read(arc_ptr) };
        match Arc::try_unwrap(arc) {
            Ok(value) => {
                // SAFETY: `try_unwrap` consumed the strong reference that `*self` still refers
                // to, so the old contents of `*self` must be overwritten without being dropped.
                unsafe { std::ptr::write(self, Shareable::Owned(value)) };
                true
            },
            Err(arc) => {
                // `*self` still owns this strong reference. Forget the duplicate so the
                // reference count stays balanced.
                std::mem::forget(arc);
                false
            },
        }
    }

    /// Switches `Shared` back to `Owned` in place.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not `Shared`, or if strong references other than the one held by
    /// `self` are still alive. In the latter case `self` is left in the `Shared` state.
    fn unwrap_arc_strict(&mut self) {
        if !matches!(self, Shareable::Shared(_)) {
            panic!("invalid ShareableMut state: Shared expected");
        }
        if !self.try_unwrap_arc() {
            panic!("failed to recover Owned ShareableMut from Shared, not all Rc pointers dropped");
        }
    }

    /// The main functionality of `Shareable`.
    ///
    /// Temporarily makes the object immutable and passes an `Arc` pointer to its contents to `f`,
    /// which may clone it freely.
    ///
    /// Important restriction: every `Arc` created from the one given to `f` must be dropped
    /// before `f` returns. Otherwise this method panics after `f` returns.
    ///
    /// If `f` itself panics, a best-effort attempt is made during unwinding to return `self` to
    /// the `Owned` state, so that the object stays usable if the panic is caught.
    ///
    /// # Panics
    ///
    /// Panics if `self` is already in the `Shared` state, or if `Arc` pointers derived from the
    /// one passed to `f` are still alive when `f` returns.
    pub fn with_shared<F, R>(&mut self, f: F) -> R
    where
        F: FnOnce(Arc<T>) -> R,
    {
        // Restores `Owned` when dropped. It is forgotten on the normal path, so this only runs
        // when `f` unwinds. It must never panic, because a panic during unwinding aborts.
        struct RestoreOnUnwind<'a, U>(&'a mut Shareable<U>);
        impl<U> Drop for RestoreOnUnwind<'_, U> {
            fn drop(&mut self) {
                self.0.try_unwrap_arc();
            }
        }

        self.wrap_arc_strict();

        let guard = RestoreOnUnwind(&mut *self);
        let result = f(guard.0.get_arc());
        std::mem::forget(guard);

        self.unwrap_arc_strict();

        result
    }
}
123
#[cfg(test)]
mod test {
    use std::cell::RefCell;

    use super::Shareable;

    /// A value read through the shared `Arc` matches the value read through the wrapper.
    #[test]
    fn test_shareable_mut_1() {
        let mut shareable = Shareable::new(String::from("test string"));
        let len_via_arc = shareable.with_shared(|arc| arc.len());
        assert_eq!(shareable.len(), len_via_arc);
    }

    /// Interior mutation performed through the shared `Arc` survives `with_shared`.
    #[test]
    fn test_shareable_mut_2() {
        const EXPECTED: &str = "test string ... changed";

        let mut shareable = Shareable::new(RefCell::new(String::from("test string")));
        shareable.with_shared(|arc| {
            arc.borrow_mut().push_str(" ... changed");
        });

        assert_eq!(shareable.borrow().as_str(), EXPECTED);
        let inner = shareable.into_inner();
        assert_eq!(inner.into_inner(), EXPECTED);
    }

    /// Letting the `Arc` escape the closure must trigger the strict recovery panic.
    #[test]
    #[should_panic = "failed to recover Owned ShareableMut from Shared, not all Rc pointers dropped"]
    fn test_shareable_mut_fail() {
        let mut shareable = Shareable::new(String::from("test string"));
        let _escaped_arc = shareable.with_shared(|arc| arc);
    }
}