1use std::ops::{Deref, DerefMut};
2use std::sync::{Arc};
3use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
4use crate::save::{Loader, Savable, Saver};
5use crate::unsafe_utils::DangerousCell;
6
/// A shared, versioned value guarded by a read-write lock.
///
/// Cloning a `State` yields another handle to the same value and the same
/// shared version counter. That counter is incremented every time a
/// [`StateWriteGuard`] is dropped. Each handle additionally keeps its own
/// `local_version`, which `is_outdated` compares against the shared counter.
pub struct State<T> {
    // Shared by all clones: (shared version counter, lock-protected value).
    // NOTE(review): the counter is a `DangerousCell`, read without taking the
    // lock — confirm this is sound under concurrent use.
    inner: Arc<(DangerousCell<u64>, RwLock<T>)>,
    // Per-handle version, set by `update`/`force_outdated` and copied on clone.
    local_version: DangerousCell<u64>,
}
11
12impl<T> State<T> {
13 pub fn new(value: T) -> Self {
14 Self {
15 inner: Arc::new((DangerousCell::new(0), RwLock::new(value))),
16 local_version: DangerousCell::new(0),
17 }
18 }
19
20 pub fn read(&self) -> RwLockReadGuard<T> {
21 self.inner.1.read()
22 }
23
24 pub fn write(&self) -> StateWriteGuard<T> {
25 StateWriteGuard {
26 inner: self.inner.1.write(),
27 ptr: self.inner.0.get_mut(),
28 }
29 }
30
31 pub fn get_version(&self) -> u64 {
32 self.inner.0.get_val()
33 }
34
35 pub fn get_local_version(&self) -> u64 {
36 self.local_version.get_val()
37 }
38
39 pub fn is_outdated(&self) -> bool {
40 self.inner.0.get_val() != self.local_version.get_val()
41 }
42
43 pub fn update(&self) {
44 self.local_version.replace(self.inner.0.get_val());
45 }
46
47 pub fn force_outdated(&self) {
48 if self.inner.0.get_val() == 0 {
49 self.local_version.replace(u64::MAX);
50 } else {
51 self.local_version.replace(0);
52 }
53 }
54
55 pub fn map<U>(&self, mapper: fn(&T) -> U) -> MappedState<T, U> {
56 MappedState::new(mapper, self.clone())
57 }
58}
59
60impl<T: Clone> State<T> {
61 pub fn map_identity(&self) -> MappedState<T, T> {
62 MappedState::new(|x| x.clone(), self.clone())
63 }
64}
65
66unsafe impl<T> Send for State<T> {}
67unsafe impl<T> Sync for State<T> {}
68
69impl<T> Clone for State<T> {
70 fn clone(&self) -> Self {
71 State {
72 inner: self.inner.clone(),
73 local_version: DangerousCell::new(self.local_version.get_val()),
74 }
75 }
76}
77
78impl<T> PartialEq for State<T> {
79 fn eq(&self, other: &Self) -> bool {
80 Arc::ptr_eq(&self.inner, &other.inner)
81 }
82}
83
84impl<T> Eq for State<T> {}
85
86impl<T> PartialOrd for State<T> {
87 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
88 Some(self.cmp(other))
89 }
90}
91
92impl<T> Ord for State<T> {
93 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
94 self.local_version.get_val().cmp(&other.local_version.get_val())
95 }
96}
97
98impl<T: Savable> Savable for State<T> {
99 fn save(&self, saver: &mut impl Saver) {
100 let l = self.read();
101 l.save(saver);
102 }
103
104 fn load(loader: &mut impl Loader) -> Result<Self, String> {
105 let t = T::load(loader)?;
106 Ok(Self::new(t))
107 }
108}
109
110pub struct StateWriteGuard<'a, T: ?Sized + 'a> {
111 inner: RwLockWriteGuard<'a, T>,
112 ptr: &'a mut u64,
113}
114
115impl<'a, T: ?Sized + 'a> Deref for StateWriteGuard<'a, T> {
116 type Target = T;
117
118 fn deref(&self) -> &Self::Target {
119 &self.inner
120 }
121}
122
123impl<'a, T: ?Sized + 'a> DerefMut for StateWriteGuard<'a, T> {
124 fn deref_mut(&mut self) -> &mut Self::Target {
125 &mut self.inner
126 }
127}
128
129impl<'a, T: ?Sized + 'a> Drop for StateWriteGuard<'a, T> {
130 fn drop(&mut self) {
131 *self.ptr += 1;
132 }
133}
134
/// Runs the first block when at least one listed dependency reports
/// `is_outdated()`; an optional `else` block runs otherwise.
///
/// Dependencies are checked left to right with short-circuiting `||`. With an
/// empty dependency list the main block is dropped from the expansion (it never
/// runs), and the `else` block, if present, is the whole expansion.
#[macro_export]
macro_rules! when {
    ([] => $body:block) => {};
    ([] => $body:block else $fallback:block) => { $fallback };
    ([$($dep:expr),+ $(,)?] => $body:block) => {
        if $($dep.is_outdated())||+ $body
    };
    ([$($dep:expr),+ $(,)?] => $body:block else $fallback:block) => {
        if $($dep.is_outdated())||+ $body
        else $fallback
    };
}
147
/// Calls `update()` on every listed dependency, in order. For `State` and
/// `MappedState` this copies the shared version into the local version. An
/// empty list expands to nothing.
#[macro_export]
macro_rules! update {
    ([]) => {};
    ([$($dep:expr),+ $(,)?]) => {
        $( $dep.update(); )+
    };
}
157
158#[derive(Clone)]
159pub struct MappedState<T, U> {
160 mapper: fn(&T) -> U,
161 old: State<T>,
162}
163
164impl<T, U> MappedState<T, U> {
165 pub fn new(mapper: fn(&T) -> U, state: State<T>) -> Self {
166 Self {
167 mapper,
168 old: state,
169 }
170 }
171
172 pub fn read(&self) -> MappedStateReadGuard<'_, T, U> {
173 let guard = self.old.read();
174 MappedStateReadGuard {
175 mapped: (self.mapper)(guard.deref()),
176 rwlock_guard: guard,
177 }
178 }
179
180 pub fn write(&self) -> StateWriteGuard<T> {
181 StateWriteGuard {
182 inner: self.old.inner.1.write(),
183 ptr: self.old.inner.0.get_mut(),
184 }
185 }
186
187 pub fn get_version(&self) -> u64 {
188 self.old.inner.0.get_val()
189 }
190
191 pub fn get_local_version(&self) -> u64 {
192 self.old.local_version.get_val()
193 }
194
195 pub fn is_outdated(&self) -> bool {
196 self.old.inner.0.get_val() != self.old.local_version.get_val()
197 }
198
199 pub fn update(&self) {
200 self.old.local_version.replace(self.old.inner.0.get_val());
201 }
202
203 pub fn force_outdated(&self) {
204 if self.old.inner.0.get_val() == 0 {
205 self.old.local_version.replace(u64::MAX);
206 } else {
207 self.old.local_version.replace(0);
208 }
209 }
210}
211
212pub struct MappedStateReadGuard<'a, T, U> {
213 mapped: U,
214 rwlock_guard: RwLockReadGuard<'a, T>
215}
216
217impl<'a, T, U> Deref for MappedStateReadGuard<'a, T, U> {
218 type Target = U;
219
220 fn deref(&self) -> &Self::Target {
221 &self.mapped
222 }
223}