canic_testkit/pic/
baseline.rs1use candid::Principal;
2use std::{
3 collections::HashMap,
4 ops::{Deref, DerefMut},
5 panic::{AssertUnwindSafe, catch_unwind, resume_unwind},
6 sync::{Mutex, MutexGuard},
7};
8
9use super::{Pic, PicSerialGuard, acquire_pic_serial_guard, startup};
10
11struct ControllerSnapshot {
12 snapshot_id: Vec<u8>,
13 sender: Option<Principal>,
14}
15
/// Snapshots captured for a set of canisters, keyed by canister id.
///
/// Built from `(snapshot_id, sender)` pairs by `ControllerSnapshots::new` and
/// read back as `(canister_id, snapshot_id, sender)` triples via `iter`.
pub struct ControllerSnapshots(HashMap<Principal, ControllerSnapshot>);
21
/// A PocketIC instance cached together with controller snapshots of its
/// canisters, caller-defined metadata, and a PocketIC serial guard.
///
/// Rust drops struct fields in declaration order, so `_serial_guard` is
/// released only after `pic`, `snapshots`, and `metadata` have been dropped.
/// Keep it declared last if that ordering is relied upon.
pub struct CachedPicBaseline<T> {
    // The cached PocketIC instance that snapshots are restored onto.
    pic: Pic,
    // Per-canister snapshots captured when the baseline was created.
    snapshots: ControllerSnapshots,
    // Arbitrary caller data stored alongside the baseline.
    metadata: T,
    // Held only for its drop side effect (presumably releasing PocketIC
    // serialization — confirm against `acquire_pic_serial_guard`).
    _serial_guard: PicSerialGuard,
}
32
/// Exclusive access to a cached baseline.
///
/// Keeps the slot's mutex locked for as long as it lives and dereferences
/// (mutably or immutably) to the contained `CachedPicBaseline`.
pub struct CachedPicBaselineGuard<'a, T> {
    // Invariant: the slot holds `Some` whenever a guard is handed out; the
    // `Deref`/`DerefMut` impls panic if it is `None`.
    guard: MutexGuard<'a, Option<CachedPicBaseline<T>>>,
}
40
/// Reason a cached baseline could not be restored.
enum CachedBaselineRestoreFailure {
    /// The restore callback panicked for some other reason; the original panic
    /// payload is preserved so it can be re-raised unchanged.
    Panic(Box<dyn std::any::Any + Send>),
    /// The restore callback panicked with a payload recognised as a dead
    /// PocketIC instance transport, so the baseline should be rebuilt.
    DeadInstanceTransport,
}
45
46fn acquire_cached_pic_baseline<T, F>(
48 slot: &'static Mutex<Option<CachedPicBaseline<T>>>,
49 build: F,
50) -> (CachedPicBaselineGuard<'static, T>, bool)
51where
52 F: FnOnce() -> CachedPicBaseline<T>,
53{
54 let mut guard = slot
55 .lock()
56 .unwrap_or_else(std::sync::PoisonError::into_inner);
57 let cache_hit = guard.is_some();
58
59 if !cache_hit {
60 *guard = Some(build());
61 }
62
63 (CachedPicBaselineGuard { guard }, cache_hit)
64}
65
66pub fn restore_or_rebuild_cached_pic_baseline<T, B, R>(
69 slot: &'static Mutex<Option<CachedPicBaseline<T>>>,
70 build: B,
71 restore: R,
72) -> (CachedPicBaselineGuard<'static, T>, bool)
73where
74 B: Fn() -> CachedPicBaseline<T>,
75 R: Fn(&CachedPicBaseline<T>),
76{
77 let (baseline, cache_hit) = acquire_cached_pic_baseline(slot, &build);
78 if !cache_hit {
79 return (baseline, false);
80 }
81
82 match try_restore_cached_pic_baseline(&baseline, restore) {
83 Ok(()) => return (baseline, true),
84 Err(CachedBaselineRestoreFailure::DeadInstanceTransport) => {}
85 Err(CachedBaselineRestoreFailure::Panic(payload)) => {
86 resume_unwind(payload);
87 }
88 }
89
90 drop(baseline);
91 drop_stale_cached_pic_baseline(slot);
92
93 let (rebuilt, _cache_hit) = acquire_cached_pic_baseline(slot, build);
94 (rebuilt, false)
95}
96
97fn try_restore_cached_pic_baseline<T, R>(
100 baseline: &CachedPicBaseline<T>,
101 restore: R,
102) -> Result<(), CachedBaselineRestoreFailure>
103where
104 R: Fn(&CachedPicBaseline<T>),
105{
106 match catch_unwind(AssertUnwindSafe(|| restore(baseline))) {
107 Ok(()) => Ok(()),
108 Err(payload) => {
109 if startup::panic_is_dead_instance_transport(payload.as_ref()) {
110 Err(CachedBaselineRestoreFailure::DeadInstanceTransport)
111 } else {
112 Err(CachedBaselineRestoreFailure::Panic(payload))
113 }
114 }
115 }
116}
117
118fn drop_stale_cached_pic_baseline<T>(slot: &'static Mutex<Option<CachedPicBaseline<T>>>) {
121 let stale = {
122 let mut slot = slot
123 .lock()
124 .unwrap_or_else(std::sync::PoisonError::into_inner);
125 slot.take()
126 };
127
128 if let Some(stale) = stale {
129 let _ = catch_unwind(AssertUnwindSafe(|| {
130 drop(stale);
131 }));
132 }
133}
134
135impl<T> Deref for CachedPicBaselineGuard<'_, T> {
136 type Target = CachedPicBaseline<T>;
137
138 fn deref(&self) -> &Self::Target {
139 self.guard
140 .as_ref()
141 .expect("cached PocketIC baseline must exist")
142 }
143}
144
145impl<T> DerefMut for CachedPicBaselineGuard<'_, T> {
146 fn deref_mut(&mut self) -> &mut Self::Target {
147 self.guard
148 .as_mut()
149 .expect("cached PocketIC baseline must exist")
150 }
151}
152
153impl<T> CachedPicBaseline<T> {
154 pub fn capture<I>(
156 pic: Pic,
157 controller_id: Principal,
158 canister_ids: I,
159 metadata: T,
160 ) -> Option<Self>
161 where
162 I: IntoIterator<Item = Principal>,
163 {
164 let snapshots = pic.capture_controller_snapshots(controller_id, canister_ids)?;
165
166 Some(Self {
167 pic,
168 snapshots,
169 metadata,
170 _serial_guard: acquire_pic_serial_guard(),
171 })
172 }
173
174 pub fn restore(&self, controller_id: Principal) {
176 self.pic
177 .restore_controller_snapshots(controller_id, &self.snapshots);
178 }
179
180 #[must_use]
182 pub const fn pic(&self) -> &Pic {
183 &self.pic
184 }
185
186 #[must_use]
188 pub const fn pic_mut(&mut self) -> &mut Pic {
189 &mut self.pic
190 }
191
192 #[must_use]
194 pub const fn metadata(&self) -> &T {
195 &self.metadata
196 }
197
198 #[must_use]
200 pub const fn metadata_mut(&mut self) -> &mut T {
201 &mut self.metadata
202 }
203}
204
205impl ControllerSnapshots {
206 pub(super) fn new(snapshots: HashMap<Principal, (Vec<u8>, Option<Principal>)>) -> Self {
207 Self(
208 snapshots
209 .into_iter()
210 .map(|(canister_id, (snapshot_id, sender))| {
211 (
212 canister_id,
213 ControllerSnapshot {
214 snapshot_id,
215 sender,
216 },
217 )
218 })
219 .collect(),
220 )
221 }
222
223 pub(super) fn iter(&self) -> impl Iterator<Item = (Principal, &[u8], Option<Principal>)> + '_ {
224 self.0.iter().map(|(canister_id, snapshot)| {
225 (
226 *canister_id,
227 snapshot.snapshot_id.as_slice(),
228 snapshot.sender,
229 )
230 })
231 }
232}