canic_testkit/pic/
baseline.rs1use candid::Principal;
2use std::{
3 collections::HashMap,
4 ops::{Deref, DerefMut},
5 sync::{Mutex, MutexGuard},
6};
7
8use super::{Pic, PicSerialGuard, acquire_pic_serial_guard};
9
10struct ControllerSnapshot {
11 snapshot_id: Vec<u8>,
12 sender: Option<Principal>,
13}
14
15pub struct ControllerSnapshots(HashMap<Principal, ControllerSnapshot>);
20
21pub struct CachedPicBaseline<T> {
26 pub pic: Pic,
27 pub snapshots: ControllerSnapshots,
28 pub metadata: T,
29 _serial_guard: PicSerialGuard,
30}
31
32pub struct CachedPicBaselineGuard<'a, T> {
37 guard: MutexGuard<'a, Option<CachedPicBaseline<T>>>,
38}
39
40pub fn acquire_cached_pic_baseline<T, F>(
42 slot: &'static Mutex<Option<CachedPicBaseline<T>>>,
43 build: F,
44) -> (CachedPicBaselineGuard<'static, T>, bool)
45where
46 F: FnOnce() -> CachedPicBaseline<T>,
47{
48 let mut guard = slot
49 .lock()
50 .unwrap_or_else(std::sync::PoisonError::into_inner);
51 let cache_hit = guard.is_some();
52
53 if !cache_hit {
54 *guard = Some(build());
55 }
56
57 (CachedPicBaselineGuard { guard }, cache_hit)
58}
59
60impl<T> Deref for CachedPicBaselineGuard<'_, T> {
61 type Target = CachedPicBaseline<T>;
62
63 fn deref(&self) -> &Self::Target {
64 self.guard
65 .as_ref()
66 .expect("cached PocketIC baseline must exist")
67 }
68}
69
70impl<T> DerefMut for CachedPicBaselineGuard<'_, T> {
71 fn deref_mut(&mut self) -> &mut Self::Target {
72 self.guard
73 .as_mut()
74 .expect("cached PocketIC baseline must exist")
75 }
76}
77
78impl<T> CachedPicBaseline<T> {
79 pub fn capture<I>(
81 pic: Pic,
82 controller_id: Principal,
83 canister_ids: I,
84 metadata: T,
85 ) -> Option<Self>
86 where
87 I: IntoIterator<Item = Principal>,
88 {
89 let snapshots = pic.capture_controller_snapshots(controller_id, canister_ids)?;
90
91 Some(Self {
92 pic,
93 snapshots,
94 metadata,
95 _serial_guard: acquire_pic_serial_guard(),
96 })
97 }
98
99 pub fn restore(&self, controller_id: Principal) {
101 self.pic
102 .restore_controller_snapshots(controller_id, &self.snapshots);
103 }
104}
105
106impl ControllerSnapshots {
107 pub(super) fn new(snapshots: HashMap<Principal, (Vec<u8>, Option<Principal>)>) -> Self {
108 Self(
109 snapshots
110 .into_iter()
111 .map(|(canister_id, (snapshot_id, sender))| {
112 (
113 canister_id,
114 ControllerSnapshot {
115 snapshot_id,
116 sender,
117 },
118 )
119 })
120 .collect(),
121 )
122 }
123
124 pub(super) fn iter(&self) -> impl Iterator<Item = (Principal, &[u8], Option<Principal>)> + '_ {
125 self.0.iter().map(|(canister_id, snapshot)| {
126 (
127 *canister_id,
128 snapshot.snapshot_id.as_slice(),
129 snapshot.sender,
130 )
131 })
132 }
133}