1use candid::{CandidType, Principal, decode_one, encode_args, encode_one, utils::ArgumentEncoder};
2use canic::{
3 Error,
4 cdk::types::TC,
5 dto::{
6 abi::v1::CanisterInitPayload,
7 env::EnvBootstrapArgs,
8 subnet::SubnetIdentity,
9 topology::{AppDirectoryArgs, SubnetDirectoryArgs},
10 },
11 ids::CanisterRole,
12 protocol,
13};
14use pocket_ic::{PocketIc, PocketIcBuilder};
15use serde::de::DeserializeOwned;
16use std::{
17 collections::HashMap,
18 ops::{Deref, DerefMut},
19 panic::{AssertUnwindSafe, catch_unwind},
20 sync::{Mutex, MutexGuard},
21};
22
23const INSTALL_CYCLES: u128 = 500 * TC;
24static PIC_BUILD_SERIAL: Mutex<()> = Mutex::new(());
25
26struct ControllerSnapshot {
27 snapshot_id: Vec<u8>,
28 sender: Option<Principal>,
29}
30
31pub struct ControllerSnapshots(HashMap<Principal, ControllerSnapshot>);
36
37#[must_use]
48pub fn pic() -> Pic {
49 PicBuilder::new().with_application_subnet().build()
50}
51
52pub struct PicBuilder(PocketIcBuilder);
63
64#[expect(clippy::new_without_default)]
65impl PicBuilder {
66 #[must_use]
68 pub fn new() -> Self {
69 Self(PocketIcBuilder::new())
70 }
71
72 #[must_use]
74 pub fn with_application_subnet(mut self) -> Self {
75 self.0 = self.0.with_application_subnet();
76 self
77 }
78
79 #[must_use]
81 pub fn with_nns_subnet(mut self) -> Self {
82 self.0 = self.0.with_nns_subnet();
83 self
84 }
85
86 #[must_use]
88 pub fn build(self) -> Pic {
89 let serial_guard = PIC_BUILD_SERIAL
93 .lock()
94 .unwrap_or_else(std::sync::PoisonError::into_inner);
95
96 Pic {
97 inner: self.0.build(),
98 _serial_guard: serial_guard,
99 }
100 }
101}
102
/// PocketIC test instance that holds the process-wide `PIC_BUILD_SERIAL` lock
/// for as long as it is alive.
///
/// Derefs to `PocketIc`, so every `PocketIc` method is callable on a `Pic` directly.
pub struct Pic {
    // Field order matters: Rust drops fields in declaration order, so the PocketIC
    // instance is dropped before `_serial_guard` releases the lock. Do not reorder.
    inner: PocketIc,
    // Held only to keep the serialization lock; never read.
    _serial_guard: MutexGuard<'static, ()>,
}
118
119impl Pic {
120 pub fn create_and_install_root_canister(&self, wasm: Vec<u8>) -> Result<Principal, Error> {
122 let init_bytes = install_root_args()?;
123
124 Ok(self.create_funded_and_install(wasm, init_bytes))
125 }
126
127 pub fn create_and_install_canister(
131 &self,
132 role: CanisterRole,
133 wasm: Vec<u8>,
134 ) -> Result<Principal, Error> {
135 let init_bytes = install_args(role)?;
136
137 Ok(self.create_funded_and_install(wasm, init_bytes))
138 }
139
140 pub fn wait_for_ready(&self, canister_id: Principal, tick_limit: usize, context: &str) {
142 for _ in 0..tick_limit {
143 self.tick();
144 if self.fetch_ready(canister_id) {
145 return;
146 }
147 }
148
149 self.dump_canister_debug(canister_id, context);
150 panic!("{context}: canister {canister_id} did not become ready after {tick_limit} ticks");
151 }
152
153 pub fn wait_for_all_ready<I>(&self, canister_ids: I, tick_limit: usize, context: &str)
155 where
156 I: IntoIterator<Item = Principal>,
157 {
158 let canister_ids = canister_ids.into_iter().collect::<Vec<_>>();
159
160 for _ in 0..tick_limit {
161 self.tick();
162 if canister_ids
163 .iter()
164 .copied()
165 .all(|canister_id| self.fetch_ready(canister_id))
166 {
167 return;
168 }
169 }
170
171 for canister_id in &canister_ids {
172 self.dump_canister_debug(*canister_id, context);
173 }
174 panic!("{context}: canisters did not become ready after {tick_limit} ticks");
175 }
176
177 pub fn dump_canister_debug(&self, canister_id: Principal, context: &str) {
179 eprintln!("{context}: debug for canister {canister_id}");
180
181 match self.canister_status(canister_id, None) {
182 Ok(status) => eprintln!("canister_status: {status:?}"),
183 Err(err) => eprintln!("canister_status failed: {err:?}"),
184 }
185
186 match self.fetch_canister_logs(canister_id, Principal::anonymous()) {
187 Ok(records) => {
188 if records.is_empty() {
189 eprintln!("canister logs: <empty>");
190 } else {
191 for record in records {
192 eprintln!("canister log: {record:?}");
193 }
194 }
195 }
196 Err(err) => eprintln!("fetch_canister_logs failed: {err:?}"),
197 }
198 }
199
200 pub fn capture_controller_snapshots<I>(
202 &self,
203 controller_id: Principal,
204 canister_ids: I,
205 ) -> Option<ControllerSnapshots>
206 where
207 I: IntoIterator<Item = Principal>,
208 {
209 let mut snapshots = HashMap::new();
210
211 for canister_id in canister_ids {
212 let Some(snapshot) = self.try_take_controller_snapshot(controller_id, canister_id)
213 else {
214 eprintln!(
215 "capture_controller_snapshots: snapshot capture unavailable for {canister_id}"
216 );
217 return None;
218 };
219 snapshots.insert(canister_id, snapshot);
220 }
221
222 Some(ControllerSnapshots(snapshots))
223 }
224
225 pub fn restore_controller_snapshots(
227 &self,
228 controller_id: Principal,
229 snapshots: &ControllerSnapshots,
230 ) {
231 for (canister_id, snapshot) in &snapshots.0 {
232 self.restore_controller_snapshot(controller_id, *canister_id, snapshot);
233 }
234 }
235
236 pub fn update_call<T, A>(
238 &self,
239 canister_id: Principal,
240 method: &str,
241 args: A,
242 ) -> Result<T, Error>
243 where
244 T: CandidType + DeserializeOwned,
245 A: ArgumentEncoder,
246 {
247 let bytes: Vec<u8> = encode_args(args)
248 .map_err(|err| Error::internal(format!("encode_args failed: {err}")))?;
249 let result = self
250 .inner
251 .update_call(canister_id, Principal::anonymous(), method, bytes)
252 .map_err(|err| {
253 Error::internal(format!(
254 "pocket_ic update_call failed (canister={canister_id}, method={method}): {err}"
255 ))
256 })?;
257
258 decode_one(&result).map_err(|err| Error::internal(format!("decode_one failed: {err}")))
259 }
260
261 pub fn update_call_as<T, A>(
263 &self,
264 canister_id: Principal,
265 caller: Principal,
266 method: &str,
267 args: A,
268 ) -> Result<T, Error>
269 where
270 T: CandidType + DeserializeOwned,
271 A: ArgumentEncoder,
272 {
273 let bytes: Vec<u8> = encode_args(args)
274 .map_err(|err| Error::internal(format!("encode_args failed: {err}")))?;
275 let result = self
276 .inner
277 .update_call(canister_id, caller, method, bytes)
278 .map_err(|err| {
279 Error::internal(format!(
280 "pocket_ic update_call failed (canister={canister_id}, method={method}): {err}"
281 ))
282 })?;
283
284 decode_one(&result).map_err(|err| Error::internal(format!("decode_one failed: {err}")))
285 }
286
287 pub fn query_call<T, A>(
289 &self,
290 canister_id: Principal,
291 method: &str,
292 args: A,
293 ) -> Result<T, Error>
294 where
295 T: CandidType + DeserializeOwned,
296 A: ArgumentEncoder,
297 {
298 let bytes: Vec<u8> = encode_args(args)
299 .map_err(|err| Error::internal(format!("encode_args failed: {err}")))?;
300 let result = self
301 .inner
302 .query_call(canister_id, Principal::anonymous(), method, bytes)
303 .map_err(|err| {
304 Error::internal(format!(
305 "pocket_ic query_call failed (canister={canister_id}, method={method}): {err}"
306 ))
307 })?;
308
309 decode_one(&result).map_err(|err| Error::internal(format!("decode_one failed: {err}")))
310 }
311
312 pub fn query_call_as<T, A>(
314 &self,
315 canister_id: Principal,
316 caller: Principal,
317 method: &str,
318 args: A,
319 ) -> Result<T, Error>
320 where
321 T: CandidType + DeserializeOwned,
322 A: ArgumentEncoder,
323 {
324 let bytes: Vec<u8> = encode_args(args)
325 .map_err(|err| Error::internal(format!("encode_args failed: {err}")))?;
326 let result = self
327 .inner
328 .query_call(canister_id, caller, method, bytes)
329 .map_err(|err| {
330 Error::internal(format!(
331 "pocket_ic query_call failed (canister={canister_id}, method={method}): {err}"
332 ))
333 })?;
334
335 decode_one(&result).map_err(|err| Error::internal(format!("decode_one failed: {err}")))
336 }
337
338 pub fn tick_n(&self, times: usize) {
340 for _ in 0..times {
341 self.tick();
342 }
343 }
344
345 fn create_funded_and_install(&self, wasm: Vec<u8>, init_bytes: Vec<u8>) -> Principal {
347 let canister_id = self.create_canister();
348 self.add_cycles(canister_id, INSTALL_CYCLES);
349
350 let install = catch_unwind(AssertUnwindSafe(|| {
351 self.inner
352 .install_canister(canister_id, wasm, init_bytes, None);
353 }));
354 if let Err(err) = install {
355 eprintln!("install_canister trapped for {canister_id}");
356 if let Ok(status) = self.inner.canister_status(canister_id, None) {
357 eprintln!("canister_status for {canister_id}: {status:?}");
358 }
359 if let Ok(logs) = self
360 .inner
361 .fetch_canister_logs(canister_id, Principal::anonymous())
362 {
363 for record in logs {
364 eprintln!("canister_log {canister_id}: {record:?}");
365 }
366 }
367 std::panic::resume_unwind(err);
368 }
369
370 canister_id
371 }
372
373 fn fetch_ready(&self, canister_id: Principal) -> bool {
375 match self.query_call(canister_id, protocol::CANIC_READY, ()) {
376 Ok(ready) => ready,
377 Err(err) => {
378 self.dump_canister_debug(canister_id, "query canic_ready failed");
379 panic!("query canic_ready failed: {err:?}");
380 }
381 }
382 }
383
384 fn try_take_controller_snapshot(
386 &self,
387 controller_id: Principal,
388 canister_id: Principal,
389 ) -> Option<ControllerSnapshot> {
390 let candidates = controller_sender_candidates(controller_id, canister_id);
391 let mut last_err = None;
392
393 for sender in candidates {
394 match self.take_canister_snapshot(canister_id, sender, None) {
395 Ok(snapshot) => {
396 return Some(ControllerSnapshot {
397 snapshot_id: snapshot.id,
398 sender,
399 });
400 }
401 Err(err) => last_err = Some((sender, err)),
402 }
403 }
404
405 if let Some((sender, err)) = last_err {
406 eprintln!(
407 "failed to capture canister snapshot for {canister_id} using sender {sender:?}: {err}"
408 );
409 }
410 None
411 }
412
413 fn restore_controller_snapshot(
415 &self,
416 controller_id: Principal,
417 canister_id: Principal,
418 snapshot: &ControllerSnapshot,
419 ) {
420 let fallback_sender = if snapshot.sender.is_some() {
421 None
422 } else {
423 Some(controller_id)
424 };
425 let candidates = [snapshot.sender, fallback_sender];
426 let mut last_err = None;
427
428 for sender in candidates {
429 match self.load_canister_snapshot(canister_id, sender, snapshot.snapshot_id.clone()) {
430 Ok(()) => return,
431 Err(err) => last_err = Some((sender, err)),
432 }
433 }
434
435 let (sender, err) =
436 last_err.expect("snapshot restore must have at least one sender attempt");
437 panic!(
438 "failed to restore canister snapshot for {canister_id} using sender {sender:?}: {err}"
439 );
440 }
441}
442
443impl Deref for Pic {
444 type Target = PocketIc;
445
446 fn deref(&self) -> &Self::Target {
447 &self.inner
448 }
449}
450
451impl DerefMut for Pic {
452 fn deref_mut(&mut self) -> &mut Self::Target {
453 &mut self.inner
454 }
455}
456
457fn install_args(role: CanisterRole) -> Result<Vec<u8>, Error> {
472 if role.is_root() {
473 install_root_args()
474 } else {
475 let env = EnvBootstrapArgs {
478 prime_root_pid: None,
479 subnet_role: None,
480 subnet_pid: None,
481 root_pid: None,
482 canister_role: Some(role),
483 parent_pid: None,
484 };
485
486 let payload = CanisterInitPayload {
489 env,
490 app_directory: AppDirectoryArgs(Vec::new()),
491 subnet_directory: SubnetDirectoryArgs(Vec::new()),
492 };
493
494 encode_args::<(CanisterInitPayload, Option<Vec<u8>>)>((payload, None))
495 .map_err(|err| Error::internal(format!("encode_args failed: {err}")))
496 }
497}
498
499fn install_root_args() -> Result<Vec<u8>, Error> {
500 encode_one(SubnetIdentity::Manual)
501 .map_err(|err| Error::internal(format!("encode_one failed: {err}")))
502}
503
504fn controller_sender_candidates(
506 controller_id: Principal,
507 canister_id: Principal,
508) -> [Option<Principal>; 2] {
509 if canister_id == controller_id {
510 [None, Some(controller_id)]
511 } else {
512 [Some(controller_id), None]
513 }
514}