Skip to main content

canic_testkit/pic/
baseline.rs

1use candid::Principal;
2use std::{
3    collections::HashMap,
4    ops::{Deref, DerefMut},
5    panic::{AssertUnwindSafe, catch_unwind, resume_unwind},
6    sync::{Mutex, MutexGuard},
7};
8
9use super::{Pic, PicSerialGuard, acquire_pic_serial_guard, startup};
10
/// Snapshot recorded for a single canister: the snapshot id bytes plus an
/// optional sender principal stored alongside it (presumably the identity the
/// restore call is issued as — NOTE(review): confirm against
/// `Pic::restore_controller_snapshots`).
struct ControllerSnapshot {
    /// Raw snapshot identifier; exposed as a byte slice by `ControllerSnapshots::iter`.
    snapshot_id: Vec<u8>,
    /// Optional principal kept with the snapshot; copied out by `ControllerSnapshots::iter`.
    sender: Option<Principal>,
}
15
///
/// ControllerSnapshots
///
/// Set of per-canister snapshots, keyed by canister id. Built from raw
/// `(snapshot_id, sender)` pairs via `ControllerSnapshots::new` and read back as
/// `(canister_id, snapshot_id, sender)` triples via `ControllerSnapshots::iter`.
///

pub struct ControllerSnapshots(HashMap<Principal, ControllerSnapshot>);
21
///
/// CachedPicBaseline
///
/// A PocketIC instance bundled with the controller snapshots captured from it,
/// caller-supplied metadata, and a PocketIC serial guard that is held for the
/// baseline's whole lifetime. Created by `CachedPicBaseline::capture`.
///

pub struct CachedPicBaseline<T> {
    /// Owned PocketIC instance the snapshots were captured from and are restored into.
    pic: Pic,
    /// Snapshot set re-applied by `CachedPicBaseline::restore`.
    snapshots: ControllerSnapshots,
    /// Arbitrary caller data cached alongside the instance.
    metadata: T,
    /// Struct fields drop in declaration order, so this guard is released only
    /// after `pic` has been dropped (presumably intentional — keep it last).
    _serial_guard: PicSerialGuard,
}
32
///
/// CachedPicBaselineGuard
///
/// Locked handle to a cached baseline slot. The slot's mutex stays held for as
/// long as this guard is alive. It dereferences to the cached
/// `CachedPicBaseline` and panics on access if the slot is empty.
///

pub struct CachedPicBaselineGuard<'a, T> {
    /// Lock on the cache slot; the acquiring code fills the slot before handing the guard out.
    guard: MutexGuard<'a, Option<CachedPicBaseline<T>>>,
}
40
/// Classification of a panicking cached-baseline restore, produced by
/// `try_restore_cached_pic_baseline`.
enum CachedBaselineRestoreFailure {
    /// `restore` panicked and `startup::panic_is_dead_instance_transport`
    /// recognized the payload; the caller discards and rebuilds the baseline.
    DeadInstanceTransport,
    /// Any other panic from `restore`; the payload is re-raised unchanged.
    Panic(Box<dyn std::any::Any + Send>),
}
45
46/// Acquire one process-local cached PocketIC baseline, building it on first use.
47fn acquire_cached_pic_baseline<T, F>(
48    slot: &'static Mutex<Option<CachedPicBaseline<T>>>,
49    build: F,
50) -> (CachedPicBaselineGuard<'static, T>, bool)
51where
52    F: FnOnce() -> CachedPicBaseline<T>,
53{
54    let mut guard = slot
55        .lock()
56        .unwrap_or_else(std::sync::PoisonError::into_inner);
57    let cache_hit = guard.is_some();
58
59    if !cache_hit {
60        *guard = Some(build());
61    }
62
63    (CachedPicBaselineGuard { guard }, cache_hit)
64}
65
66/// Restore one cached PocketIC baseline, rebuilding it if the owned PocketIC
67/// instance has died between tests.
68pub fn restore_or_rebuild_cached_pic_baseline<T, B, R>(
69    slot: &'static Mutex<Option<CachedPicBaseline<T>>>,
70    build: B,
71    restore: R,
72) -> (CachedPicBaselineGuard<'static, T>, bool)
73where
74    B: Fn() -> CachedPicBaseline<T>,
75    R: Fn(&CachedPicBaseline<T>),
76{
77    let (baseline, cache_hit) = acquire_cached_pic_baseline(slot, &build);
78    if !cache_hit {
79        return (baseline, false);
80    }
81
82    match try_restore_cached_pic_baseline(&baseline, restore) {
83        Ok(()) => return (baseline, true),
84        Err(CachedBaselineRestoreFailure::DeadInstanceTransport) => {}
85        Err(CachedBaselineRestoreFailure::Panic(payload)) => {
86            resume_unwind(payload);
87        }
88    }
89
90    drop(baseline);
91    drop_stale_cached_pic_baseline(slot);
92
93    let (rebuilt, _cache_hit) = acquire_cached_pic_baseline(slot, build);
94    (rebuilt, false)
95}
96
97// Attempt one cached baseline restore and classify only the one recovery path
98// we intentionally swallow: a dead PocketIC transport instance.
99fn try_restore_cached_pic_baseline<T, R>(
100    baseline: &CachedPicBaseline<T>,
101    restore: R,
102) -> Result<(), CachedBaselineRestoreFailure>
103where
104    R: Fn(&CachedPicBaseline<T>),
105{
106    match catch_unwind(AssertUnwindSafe(|| restore(baseline))) {
107        Ok(()) => Ok(()),
108        Err(payload) => {
109            if startup::panic_is_dead_instance_transport(payload.as_ref()) {
110                Err(CachedBaselineRestoreFailure::DeadInstanceTransport)
111            } else {
112                Err(CachedBaselineRestoreFailure::Panic(payload))
113            }
114        }
115    }
116}
117
118/// Remove one dead cached baseline and swallow teardown panics from a broken
119/// PocketIC instance so callers can rebuild cleanly.
120fn drop_stale_cached_pic_baseline<T>(slot: &'static Mutex<Option<CachedPicBaseline<T>>>) {
121    let stale = {
122        let mut slot = slot
123            .lock()
124            .unwrap_or_else(std::sync::PoisonError::into_inner);
125        slot.take()
126    };
127
128    if let Some(stale) = stale {
129        let _ = catch_unwind(AssertUnwindSafe(|| {
130            drop(stale);
131        }));
132    }
133}
134
135impl<T> Deref for CachedPicBaselineGuard<'_, T> {
136    type Target = CachedPicBaseline<T>;
137
138    fn deref(&self) -> &Self::Target {
139        self.guard
140            .as_ref()
141            .expect("cached PocketIC baseline must exist")
142    }
143}
144
145impl<T> DerefMut for CachedPicBaselineGuard<'_, T> {
146    fn deref_mut(&mut self) -> &mut Self::Target {
147        self.guard
148            .as_mut()
149            .expect("cached PocketIC baseline must exist")
150    }
151}
152
153impl<T> CachedPicBaseline<T> {
154    /// Capture one immutable cached baseline from the current PocketIC instance.
155    pub fn capture<I>(
156        pic: Pic,
157        controller_id: Principal,
158        canister_ids: I,
159        metadata: T,
160    ) -> Option<Self>
161    where
162        I: IntoIterator<Item = Principal>,
163    {
164        let snapshots = pic.capture_controller_snapshots(controller_id, canister_ids)?;
165
166        Some(Self {
167            pic,
168            snapshots,
169            metadata,
170            _serial_guard: acquire_pic_serial_guard(),
171        })
172    }
173
174    /// Restore the captured snapshot set back into the owned PocketIC instance.
175    pub fn restore(&self, controller_id: Principal) {
176        self.pic
177            .restore_controller_snapshots(controller_id, &self.snapshots);
178    }
179
180    /// Borrow the owned PocketIC instance behind this cached baseline.
181    #[must_use]
182    pub const fn pic(&self) -> &Pic {
183        &self.pic
184    }
185
186    /// Mutably borrow the owned PocketIC instance behind this cached baseline.
187    #[must_use]
188    pub const fn pic_mut(&mut self) -> &mut Pic {
189        &mut self.pic
190    }
191
192    /// Borrow the captured metadata associated with this cached baseline.
193    #[must_use]
194    pub const fn metadata(&self) -> &T {
195        &self.metadata
196    }
197
198    /// Mutably borrow the captured metadata associated with this cached baseline.
199    #[must_use]
200    pub const fn metadata_mut(&mut self) -> &mut T {
201        &mut self.metadata
202    }
203}
204
205impl ControllerSnapshots {
206    pub(super) fn new(snapshots: HashMap<Principal, (Vec<u8>, Option<Principal>)>) -> Self {
207        Self(
208            snapshots
209                .into_iter()
210                .map(|(canister_id, (snapshot_id, sender))| {
211                    (
212                        canister_id,
213                        ControllerSnapshot {
214                            snapshot_id,
215                            sender,
216                        },
217                    )
218                })
219                .collect(),
220        )
221    }
222
223    pub(super) fn iter(&self) -> impl Iterator<Item = (Principal, &[u8], Option<Principal>)> + '_ {
224        self.0.iter().map(|(canister_id, snapshot)| {
225            (
226                *canister_id,
227                snapshot.snapshot_id.as_slice(),
228                snapshot.sender,
229            )
230        })
231    }
232}