// drt_chain_vm/with_shared/shareable.rs
use std::{
    mem,
    ops::{Deref, DerefMut},
    ptr,
    sync::Arc,
};

/// A value that is normally owned outright but can temporarily be handed out as a
/// shared [`Arc`] (see `with_shared`).
///
/// Reading works in both states via `Deref`; mutable access via `DerefMut` is only
/// available in the `Owned` state.
pub enum Shareable<T> {
    /// The value is held directly and may be mutated in place.
    Owned(T),
    /// The value sits behind a reference-counted pointer and can only be read.
    Shared(Arc<T>),
}

17impl<T> Shareable<T> {
18 pub fn new(t: T) -> Self {
19 Shareable::Owned(t)
20 }
21
22 pub fn into_inner(self) -> T {
24 if let Shareable::Owned(t) = self {
25 t
26 } else {
27 panic!("cannot access ShareableMut owned object")
28 }
29 }
30}
31
32impl<T> Default for Shareable<T>
33where
34 T: Default,
35{
36 fn default() -> Self {
37 Shareable::new(T::default())
38 }
39}
40
41impl<T> Deref for Shareable<T> {
42 type Target = T;
43
44 fn deref(&self) -> &Self::Target {
45 match self {
46 Shareable::Owned(t) => t,
47 Shareable::Shared(rc) => rc.deref(),
48 }
49 }
50}
51
52impl<T> DerefMut for Shareable<T> {
53 fn deref_mut(&mut self) -> &mut Self::Target {
54 match self {
55 Shareable::Owned(t) => t,
56 Shareable::Shared(_) => {
57 panic!("cannot mutably dereference ShareableMut when in Shared state")
58 },
59 }
60 }
61}
62
63impl<T> Shareable<T> {
64 fn get_arc(&self) -> Arc<T> {
65 if let Shareable::Shared(arc) = self {
66 arc.clone()
67 } else {
68 panic!("invalid ShareableMut state: Shared expected")
69 }
70 }
71
72 fn wrap_arc_strict(&mut self) {
73 unsafe {
74 let temp = std::ptr::read(self);
75 if let Shareable::Owned(t) = temp {
76 std::ptr::write(self, Shareable::Shared(Arc::new(t)));
77 } else {
78 std::mem::forget(temp);
79 panic!("invalid ShareableMut state: Owned expected")
80 }
81 }
82 }
83
84 fn unwrap_arc_strict(&mut self) {
85 unsafe {
86 let temp = std::ptr::read(self);
87 if let Shareable::Shared(arc) = temp {
88 match Arc::try_unwrap(arc) {
89 Ok(t) => {
90 std::ptr::write(self, Shareable::Owned(t));
91 },
92 Err(rc) => {
93 std::mem::forget(rc);
94 panic!("failed to recover Owned ShareableMut from Shared, not all Rc pointers dropped")
95 },
96 }
97 } else {
98 std::mem::forget(temp);
99 panic!("invalid ShareableMut state: Shared expected")
100 }
101 }
102 }
103
104 pub fn with_shared<F, R>(&mut self, f: F) -> R
111 where
112 F: FnOnce(Arc<T>) -> R,
113 {
114 self.wrap_arc_strict();
115
116 let result = f(self.get_arc());
117
118 self.unwrap_arc_strict();
119
120 result
121 }
122}
123
#[cfg(test)]
mod test {
    use std::cell::RefCell;

    use super::Shareable;

    /// A read-only closure sees the same value that is visible after sharing ends.
    #[test]
    fn test_shareable_mut_1() {
        let mut shareable = Shareable::new(String::from("test string"));
        let shared_len = shareable.with_shared(|arc| arc.len());
        assert_eq!(shared_len, shareable.len());
    }

    /// Mutations made through interior mutability persist once ownership is restored.
    #[test]
    fn test_shareable_mut_2() {
        let expected = "test string ... changed";
        let mut cell = Shareable::new(RefCell::new(String::from("test string")));
        cell.with_shared(|arc| {
            arc.borrow_mut().push_str(" ... changed");
        });
        assert_eq!(cell.borrow().as_str(), expected);
        assert_eq!(cell.into_inner().into_inner(), expected);
    }

    /// Letting the `Arc` escape the closure makes restoring ownership impossible.
    #[test]
    #[should_panic(expected = "failed to recover Owned ShareableMut from Shared, not all Rc pointers dropped")]
    fn test_shareable_mut_fail() {
        let mut shareable = Shareable::new(String::from("test string"));
        let _leaked_handle = shareable.with_shared(|arc| arc);
    }
}