canic_testkit/pic/
baseline.rs1use candid::Principal;
2use std::{
3 collections::HashMap,
4 panic::{AssertUnwindSafe, catch_unwind, resume_unwind},
5 sync::{Mutex, MutexGuard},
6};
7
8use super::{Pic, PicSerialGuard, acquire_pic_serial_guard, startup};
9
10struct ControllerSnapshot {
11 snapshot_id: Vec<u8>,
12 sender: Option<Principal>,
13}
14
15pub struct ControllerSnapshots(HashMap<Principal, ControllerSnapshot>);
20
21pub struct CachedPicBaseline<T> {
26 pic: Pic,
27 snapshots: ControllerSnapshots,
28 metadata: T,
29 _serial_guard: PicSerialGuard,
30}
31
32pub struct CachedPicBaselineGuard<'a, T> {
37 guard: MutexGuard<'a, Option<CachedPicBaseline<T>>>,
38}
39
/// Why restoring a cached baseline failed, as classified by
/// `try_restore_cached_pic_baseline`.
enum CachedBaselineRestoreFailure {
    /// The restore panicked, and `startup::panic_is_dead_instance_transport`
    /// recognised the payload as a dead-instance transport failure.
    DeadInstanceTransport,
    /// Any other panic; the original payload is kept so it can be re-raised.
    Panic(Box<dyn std::any::Any + Send>),
}
44
45fn acquire_cached_pic_baseline<T, F>(
47 slot: &'static Mutex<Option<CachedPicBaseline<T>>>,
48 build: F,
49) -> (CachedPicBaselineGuard<'static, T>, bool)
50where
51 F: FnOnce() -> CachedPicBaseline<T>,
52{
53 let mut guard = slot
54 .lock()
55 .unwrap_or_else(std::sync::PoisonError::into_inner);
56 let cache_hit = guard.is_some();
57
58 if !cache_hit {
59 *guard = Some(build());
60 }
61
62 (CachedPicBaselineGuard { guard }, cache_hit)
63}
64
65pub fn restore_or_rebuild_cached_pic_baseline<T, B, R>(
68 slot: &'static Mutex<Option<CachedPicBaseline<T>>>,
69 build: B,
70 restore: R,
71) -> (CachedPicBaselineGuard<'static, T>, bool)
72where
73 B: Fn() -> CachedPicBaseline<T>,
74 R: Fn(&CachedPicBaseline<T>),
75{
76 let (baseline, cache_hit) = acquire_cached_pic_baseline(slot, &build);
77 if !cache_hit {
78 return (baseline, false);
79 }
80
81 match try_restore_cached_pic_baseline(
82 baseline
83 .guard
84 .as_ref()
85 .expect("cached PocketIC baseline must exist"),
86 restore,
87 ) {
88 Ok(()) => return (baseline, true),
89 Err(CachedBaselineRestoreFailure::DeadInstanceTransport) => {}
90 Err(CachedBaselineRestoreFailure::Panic(payload)) => {
91 resume_unwind(payload);
92 }
93 }
94
95 drop(baseline);
96 drop_stale_cached_pic_baseline(slot);
97
98 let (rebuilt, _cache_hit) = acquire_cached_pic_baseline(slot, build);
99 (rebuilt, false)
100}
101
102fn try_restore_cached_pic_baseline<T, R>(
105 baseline: &CachedPicBaseline<T>,
106 restore: R,
107) -> Result<(), CachedBaselineRestoreFailure>
108where
109 R: Fn(&CachedPicBaseline<T>),
110{
111 match catch_unwind(AssertUnwindSafe(|| restore(baseline))) {
112 Ok(()) => Ok(()),
113 Err(payload) => {
114 if startup::panic_is_dead_instance_transport(payload.as_ref()) {
115 Err(CachedBaselineRestoreFailure::DeadInstanceTransport)
116 } else {
117 Err(CachedBaselineRestoreFailure::Panic(payload))
118 }
119 }
120 }
121}
122
123fn drop_stale_cached_pic_baseline<T>(slot: &'static Mutex<Option<CachedPicBaseline<T>>>) {
126 let stale = {
127 let mut slot = slot
128 .lock()
129 .unwrap_or_else(std::sync::PoisonError::into_inner);
130 slot.take()
131 };
132
133 if let Some(stale) = stale {
134 let _ = catch_unwind(AssertUnwindSafe(|| {
135 drop(stale);
136 }));
137 }
138}
139
140impl<T> CachedPicBaselineGuard<'_, T> {
141 #[must_use]
143 pub fn pic(&self) -> &Pic {
144 self.guard
145 .as_ref()
146 .expect("cached PocketIC baseline must exist")
147 .pic()
148 }
149
150 #[must_use]
152 pub fn pic_mut(&mut self) -> &mut Pic {
153 self.guard
154 .as_mut()
155 .expect("cached PocketIC baseline must exist")
156 .pic_mut()
157 }
158
159 #[must_use]
161 pub fn metadata(&self) -> &T {
162 self.guard
163 .as_ref()
164 .expect("cached PocketIC baseline must exist")
165 .metadata()
166 }
167
168 #[must_use]
170 pub fn metadata_mut(&mut self) -> &mut T {
171 self.guard
172 .as_mut()
173 .expect("cached PocketIC baseline must exist")
174 .metadata_mut()
175 }
176
177 pub fn restore(&self, controller_id: Principal) {
179 self.guard
180 .as_ref()
181 .expect("cached PocketIC baseline must exist")
182 .restore(controller_id);
183 }
184}
185
186impl<T> CachedPicBaseline<T> {
187 pub fn capture<I>(
189 pic: Pic,
190 controller_id: Principal,
191 canister_ids: I,
192 metadata: T,
193 ) -> Option<Self>
194 where
195 I: IntoIterator<Item = Principal>,
196 {
197 let snapshots = pic.capture_controller_snapshots(controller_id, canister_ids)?;
198
199 Some(Self {
200 pic,
201 snapshots,
202 metadata,
203 _serial_guard: acquire_pic_serial_guard(),
204 })
205 }
206
207 pub fn restore(&self, controller_id: Principal) {
209 self.pic
210 .restore_controller_snapshots(controller_id, &self.snapshots);
211 }
212
213 #[must_use]
215 pub const fn pic(&self) -> &Pic {
216 &self.pic
217 }
218
219 #[must_use]
221 pub const fn pic_mut(&mut self) -> &mut Pic {
222 &mut self.pic
223 }
224
225 #[must_use]
227 pub const fn metadata(&self) -> &T {
228 &self.metadata
229 }
230
231 #[must_use]
233 pub const fn metadata_mut(&mut self) -> &mut T {
234 &mut self.metadata
235 }
236}
237
238impl ControllerSnapshots {
239 pub(super) fn new(snapshots: HashMap<Principal, (Vec<u8>, Option<Principal>)>) -> Self {
240 Self(
241 snapshots
242 .into_iter()
243 .map(|(canister_id, (snapshot_id, sender))| {
244 (
245 canister_id,
246 ControllerSnapshot {
247 snapshot_id,
248 sender,
249 },
250 )
251 })
252 .collect(),
253 )
254 }
255
256 pub(super) fn iter(&self) -> impl Iterator<Item = (Principal, &[u8], Option<Principal>)> + '_ {
257 self.0.iter().map(|(canister_id, snapshot)| {
258 (
259 *canister_id,
260 snapshot.snapshot_id.as_slice(),
261 snapshot.sender,
262 )
263 })
264 }
265}