use candid::Principal;
use std::{
collections::HashMap,
panic::{AssertUnwindSafe, catch_unwind, resume_unwind},
sync::{Mutex, MutexGuard},
};
use super::{Pic, PicSerialGuard, acquire_pic_serial_guard, startup};
struct ControllerSnapshot {
snapshot_id: Vec<u8>,
sender: Option<Principal>,
}
pub struct ControllerSnapshots(HashMap<Principal, ControllerSnapshot>);
pub struct CachedPicBaseline<T> {
pic: Pic,
snapshots: ControllerSnapshots,
metadata: T,
_serial_guard: PicSerialGuard,
}
pub struct CachedPicBaselineGuard<'a, T> {
guard: MutexGuard<'a, Option<CachedPicBaseline<T>>>,
}
enum CachedBaselineRestoreFailure {
DeadInstanceTransport,
Panic(Box<dyn std::any::Any + Send>),
}
fn acquire_cached_pic_baseline<T, F>(
slot: &'static Mutex<Option<CachedPicBaseline<T>>>,
build: F,
) -> (CachedPicBaselineGuard<'static, T>, bool)
where
F: FnOnce() -> CachedPicBaseline<T>,
{
let mut guard = slot
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let cache_hit = guard.is_some();
if !cache_hit {
*guard = Some(build());
}
(CachedPicBaselineGuard { guard }, cache_hit)
}
pub fn restore_or_rebuild_cached_pic_baseline<T, B, R>(
slot: &'static Mutex<Option<CachedPicBaseline<T>>>,
build: B,
restore: R,
) -> (CachedPicBaselineGuard<'static, T>, bool)
where
B: Fn() -> CachedPicBaseline<T>,
R: Fn(&CachedPicBaseline<T>),
{
let (baseline, cache_hit) = acquire_cached_pic_baseline(slot, &build);
if !cache_hit {
return (baseline, false);
}
match try_restore_cached_pic_baseline(
baseline
.guard
.as_ref()
.expect("cached PocketIC baseline must exist"),
restore,
) {
Ok(()) => return (baseline, true),
Err(CachedBaselineRestoreFailure::DeadInstanceTransport) => {}
Err(CachedBaselineRestoreFailure::Panic(payload)) => {
resume_unwind(payload);
}
}
drop(baseline);
drop_stale_cached_pic_baseline(slot);
let (rebuilt, _cache_hit) = acquire_cached_pic_baseline(slot, build);
(rebuilt, false)
}
fn try_restore_cached_pic_baseline<T, R>(
baseline: &CachedPicBaseline<T>,
restore: R,
) -> Result<(), CachedBaselineRestoreFailure>
where
R: Fn(&CachedPicBaseline<T>),
{
match catch_unwind(AssertUnwindSafe(|| restore(baseline))) {
Ok(()) => Ok(()),
Err(payload) => {
if startup::panic_is_dead_instance_transport(payload.as_ref()) {
Err(CachedBaselineRestoreFailure::DeadInstanceTransport)
} else {
Err(CachedBaselineRestoreFailure::Panic(payload))
}
}
}
}
fn drop_stale_cached_pic_baseline<T>(slot: &'static Mutex<Option<CachedPicBaseline<T>>>) {
let stale = {
let mut slot = slot
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
slot.take()
};
if let Some(stale) = stale {
let _ = catch_unwind(AssertUnwindSafe(|| {
drop(stale);
}));
}
}
impl<T> CachedPicBaselineGuard<'_, T> {
#[must_use]
pub fn pic(&self) -> &Pic {
self.guard
.as_ref()
.expect("cached PocketIC baseline must exist")
.pic()
}
#[must_use]
pub fn pic_mut(&mut self) -> &mut Pic {
self.guard
.as_mut()
.expect("cached PocketIC baseline must exist")
.pic_mut()
}
#[must_use]
pub fn metadata(&self) -> &T {
self.guard
.as_ref()
.expect("cached PocketIC baseline must exist")
.metadata()
}
#[must_use]
pub fn metadata_mut(&mut self) -> &mut T {
self.guard
.as_mut()
.expect("cached PocketIC baseline must exist")
.metadata_mut()
}
pub fn restore(&self, controller_id: Principal) {
self.guard
.as_ref()
.expect("cached PocketIC baseline must exist")
.restore(controller_id);
}
}
impl<T> CachedPicBaseline<T> {
pub fn capture<I>(
pic: Pic,
controller_id: Principal,
canister_ids: I,
metadata: T,
) -> Option<Self>
where
I: IntoIterator<Item = Principal>,
{
let snapshots = pic.capture_controller_snapshots(controller_id, canister_ids)?;
Some(Self {
pic,
snapshots,
metadata,
_serial_guard: acquire_pic_serial_guard(),
})
}
pub fn restore(&self, controller_id: Principal) {
self.pic
.restore_controller_snapshots(controller_id, &self.snapshots);
}
#[must_use]
pub const fn pic(&self) -> &Pic {
&self.pic
}
#[must_use]
pub const fn pic_mut(&mut self) -> &mut Pic {
&mut self.pic
}
#[must_use]
pub const fn metadata(&self) -> &T {
&self.metadata
}
#[must_use]
pub const fn metadata_mut(&mut self) -> &mut T {
&mut self.metadata
}
}
impl ControllerSnapshots {
pub(super) fn new(snapshots: HashMap<Principal, (Vec<u8>, Option<Principal>)>) -> Self {
Self(
snapshots
.into_iter()
.map(|(canister_id, (snapshot_id, sender))| {
(
canister_id,
ControllerSnapshot {
snapshot_id,
sender,
},
)
})
.collect(),
)
}
pub(super) fn iter(&self) -> impl Iterator<Item = (Principal, &[u8], Option<Principal>)> + '_ {
self.0.iter().map(|(canister_id, snapshot)| {
(
*canister_id,
snapshot.snapshot_id.as_slice(),
snapshot.sender,
)
})
}
}