Skip to main content

canic_testkit/pic/
baseline.rs

1use candid::Principal;
2use std::{
3    collections::HashMap,
4    ops::{Deref, DerefMut},
5    panic::{AssertUnwindSafe, catch_unwind, resume_unwind},
6    sync::{Mutex, MutexGuard},
7};
8
9use super::{Pic, PicSerialGuard, acquire_pic_serial_guard, startup};
10
11struct ControllerSnapshot {
12    snapshot_id: Vec<u8>,
13    sender: Option<Principal>,
14}
15
16///
17/// ControllerSnapshots
18///
19
20pub struct ControllerSnapshots(HashMap<Principal, ControllerSnapshot>);
21
22///
23/// CachedPicBaseline
24///
25
26pub struct CachedPicBaseline<T> {
27    pub pic: Pic,
28    pub snapshots: ControllerSnapshots,
29    pub metadata: T,
30    _serial_guard: PicSerialGuard,
31}
32
33///
34/// CachedPicBaselineGuard
35///
36
37pub struct CachedPicBaselineGuard<'a, T> {
38    guard: MutexGuard<'a, Option<CachedPicBaseline<T>>>,
39}
40
41/// Acquire one process-local cached PocketIC baseline, building it on first use.
42pub fn acquire_cached_pic_baseline<T, F>(
43    slot: &'static Mutex<Option<CachedPicBaseline<T>>>,
44    build: F,
45) -> (CachedPicBaselineGuard<'static, T>, bool)
46where
47    F: FnOnce() -> CachedPicBaseline<T>,
48{
49    let mut guard = slot
50        .lock()
51        .unwrap_or_else(std::sync::PoisonError::into_inner);
52    let cache_hit = guard.is_some();
53
54    if !cache_hit {
55        *guard = Some(build());
56    }
57
58    (CachedPicBaselineGuard { guard }, cache_hit)
59}
60
61/// Restore one cached PocketIC baseline, rebuilding it if the owned PocketIC
62/// instance has died between tests.
63pub fn restore_or_rebuild_cached_pic_baseline<T, B, R>(
64    slot: &'static Mutex<Option<CachedPicBaseline<T>>>,
65    build: B,
66    restore: R,
67) -> (CachedPicBaselineGuard<'static, T>, bool)
68where
69    B: Fn() -> CachedPicBaseline<T>,
70    R: Fn(&CachedPicBaseline<T>),
71{
72    let (baseline, cache_hit) = acquire_cached_pic_baseline(slot, &build);
73    if !cache_hit {
74        return (baseline, false);
75    }
76
77    let restore_result = catch_unwind(AssertUnwindSafe(|| {
78        restore(&baseline);
79    }));
80    if restore_result.is_ok() {
81        return (baseline, true);
82    }
83
84    let panic_payload = restore_result.expect_err("restore failure must carry panic payload");
85    let message = startup::panic_payload_to_string(panic_payload.as_ref());
86    if !startup::is_dead_instance_transport_error(&message) {
87        resume_unwind(panic_payload);
88    }
89
90    drop(baseline);
91    drop_stale_cached_pic_baseline(slot);
92
93    let (rebuilt, _cache_hit) = acquire_cached_pic_baseline(slot, build);
94    (rebuilt, false)
95}
96
97/// Remove one dead cached baseline and swallow teardown panics from a broken
98/// PocketIC instance so callers can rebuild cleanly.
99pub fn drop_stale_cached_pic_baseline<T>(slot: &'static Mutex<Option<CachedPicBaseline<T>>>) {
100    let stale = {
101        let mut slot = slot
102            .lock()
103            .unwrap_or_else(std::sync::PoisonError::into_inner);
104        slot.take()
105    };
106
107    if let Some(stale) = stale {
108        let _ = catch_unwind(AssertUnwindSafe(|| {
109            drop(stale);
110        }));
111    }
112}
113
114impl<T> Deref for CachedPicBaselineGuard<'_, T> {
115    type Target = CachedPicBaseline<T>;
116
117    fn deref(&self) -> &Self::Target {
118        self.guard
119            .as_ref()
120            .expect("cached PocketIC baseline must exist")
121    }
122}
123
124impl<T> DerefMut for CachedPicBaselineGuard<'_, T> {
125    fn deref_mut(&mut self) -> &mut Self::Target {
126        self.guard
127            .as_mut()
128            .expect("cached PocketIC baseline must exist")
129    }
130}
131
132impl<T> CachedPicBaseline<T> {
133    /// Capture one immutable cached baseline from the current PocketIC instance.
134    pub fn capture<I>(
135        pic: Pic,
136        controller_id: Principal,
137        canister_ids: I,
138        metadata: T,
139    ) -> Option<Self>
140    where
141        I: IntoIterator<Item = Principal>,
142    {
143        let snapshots = pic.capture_controller_snapshots(controller_id, canister_ids)?;
144
145        Some(Self {
146            pic,
147            snapshots,
148            metadata,
149            _serial_guard: acquire_pic_serial_guard(),
150        })
151    }
152
153    /// Restore the captured snapshot set back into the owned PocketIC instance.
154    pub fn restore(&self, controller_id: Principal) {
155        self.pic
156            .restore_controller_snapshots(controller_id, &self.snapshots);
157    }
158}
159
160impl ControllerSnapshots {
161    pub(super) fn new(snapshots: HashMap<Principal, (Vec<u8>, Option<Principal>)>) -> Self {
162        Self(
163            snapshots
164                .into_iter()
165                .map(|(canister_id, (snapshot_id, sender))| {
166                    (
167                        canister_id,
168                        ControllerSnapshot {
169                            snapshot_id,
170                            sender,
171                        },
172                    )
173                })
174                .collect(),
175        )
176    }
177
178    pub(super) fn iter(&self) -> impl Iterator<Item = (Principal, &[u8], Option<Principal>)> + '_ {
179        self.0.iter().map(|(canister_id, snapshot)| {
180            (
181                *canister_id,
182                snapshot.snapshot_id.as_slice(),
183                snapshot.sender,
184            )
185        })
186    }
187}