1mod root;
2
3use candid::{CandidType, Principal, decode_one, encode_args, encode_one, utils::ArgumentEncoder};
4use canic::{
5 Error,
6 cdk::types::TC,
7 dto::{
8 abi::v1::CanisterInitPayload,
9 env::EnvBootstrapArgs,
10 subnet::SubnetIdentity,
11 topology::{AppDirectoryArgs, SubnetDirectoryArgs, SubnetRegistryResponse},
12 },
13 ids::CanisterRole,
14 protocol,
15};
16use pocket_ic::{PocketIc, PocketIcBuilder};
17use serde::de::DeserializeOwned;
18use std::{
19 collections::HashMap,
20 env, fs, io,
21 ops::{Deref, DerefMut},
22 panic::{AssertUnwindSafe, catch_unwind},
23 path::{Path, PathBuf},
24 process,
25 sync::{Mutex, MutexGuard},
26 thread,
27 time::{Duration, Instant},
28};
29
30pub use root::{
31 RootBaselineMetadata, RootBaselineSpec, build_root_cached_baseline,
32 ensure_root_release_artifacts_built, load_root_wasm, restore_root_cached_baseline,
33 setup_root_topology,
34};
35
/// Cycles added to every canister created by `Pic::create_funded_and_install`, before its
/// code is installed.
const INSTALL_CYCLES: u128 = 500 * TC;
/// Name of the directory, created under `env::temp_dir()`, that acts as the cross-process
/// PocketIC lock.
const PIC_PROCESS_LOCK_DIR_NAME: &str = "canic-pocket-ic.lock";
/// Delay between attempts to create the lock directory while another holder owns it.
const PIC_PROCESS_LOCK_RETRY_DELAY: Duration = Duration::from_millis(100);
/// How long a caller waits for the lock before logging (once) that it is waiting.
const PIC_PROCESS_LOCK_LOG_AFTER: Duration = Duration::from_secs(1);
/// Process-wide count of live `PicSerialGuard`s, plus the cross-process lock held while that
/// count is non-zero (acquired on the 0 -> 1 transition, released on 1 -> 0).
static PIC_PROCESS_LOCK_STATE: Mutex<ProcessLockState> = Mutex::new(ProcessLockState {
    ref_count: 0,
    process_lock: None,
});
44
/// One canister snapshot, plus the sender that successfully took it.
struct ControllerSnapshot {
    /// Snapshot identifier returned by `take_canister_snapshot`.
    snapshot_id: Vec<u8>,
    /// Sender passed to PocketIC when the snapshot was taken (`None` = no explicit sender).
    /// Restores try this sender first.
    sender: Option<Principal>,
}

/// Ownership of the cross-process lock directory at `path`.
/// Dropping it removes the owner file and then the directory.
struct ProcessLockGuard {
    path: PathBuf,
}

/// State guarded by `PIC_PROCESS_LOCK_STATE`.
struct ProcessLockState {
    /// Number of live `PicSerialGuard`s in this process.
    ref_count: usize,
    /// Held while `ref_count > 0`; `None` otherwise.
    process_lock: Option<ProcessLockGuard>,
}

/// Snapshots captured by `Pic::capture_controller_snapshots`, keyed by canister id.
pub struct ControllerSnapshots(HashMap<Principal, ControllerSnapshot>);
64
/// A PocketIC instance together with controller snapshots of its canisters and caller
/// metadata, intended to be cached and restored between tests.
///
/// Field order matters: Rust drops struct fields in declaration order, so `pic` is torn down
/// before `_serial_guard` is dropped and gives back its hold on the serial lock.
pub struct CachedPicBaseline<T> {
    /// The PocketIC instance the snapshots were taken on.
    pub pic: Pic,
    /// Snapshots reloaded by `CachedPicBaseline::restore`.
    pub snapshots: ControllerSnapshots,
    /// Caller-defined data stored alongside the baseline.
    pub metadata: T,
    /// Keeps the PocketIC serial lock referenced for as long as the baseline lives.
    _serial_guard: PicSerialGuard,
}

/// Lock guard over a cached-baseline slot; derefs to the baseline itself.
/// Produced by `acquire_cached_pic_baseline`, which guarantees the slot is populated.
pub struct CachedPicBaselineGuard<'a, T> {
    guard: MutexGuard<'a, Option<CachedPicBaseline<T>>>,
}

/// Token that keeps a reference on the process-wide PocketIC serial lock while alive.
/// Obtain one via `acquire_pic_serial_guard`.
/// The private field prevents construction outside this module.
pub struct PicSerialGuard {
    _private: (),
}
91
92#[must_use]
101pub fn pic() -> Pic {
102 PicBuilder::new().with_application_subnet().build()
103}
104
105#[must_use]
107pub fn acquire_pic_serial_guard() -> PicSerialGuard {
108 let mut state = PIC_PROCESS_LOCK_STATE
109 .lock()
110 .unwrap_or_else(std::sync::PoisonError::into_inner);
111
112 if state.ref_count == 0 {
113 state.process_lock = Some(acquire_process_lock());
114 }
115 state.ref_count += 1;
116
117 PicSerialGuard { _private: () }
118}
119
120pub fn acquire_cached_pic_baseline<T, F>(
122 slot: &'static Mutex<Option<CachedPicBaseline<T>>>,
123 build: F,
124) -> (CachedPicBaselineGuard<'static, T>, bool)
125where
126 F: FnOnce() -> CachedPicBaseline<T>,
127{
128 let mut guard = slot
129 .lock()
130 .unwrap_or_else(std::sync::PoisonError::into_inner);
131 let cache_hit = guard.is_some();
132
133 if !cache_hit {
134 *guard = Some(build());
135 }
136
137 (CachedPicBaselineGuard { guard }, cache_hit)
138}
139
140pub fn wait_until_ready(pic: &PocketIc, canister_id: Principal, tick_limit: usize) {
142 let payload = encode_args(()).expect("encode empty args");
143
144 for _ in 0..tick_limit {
145 if let Ok(bytes) = pic.query_call(
146 canister_id,
147 Principal::anonymous(),
148 protocol::CANIC_READY,
149 payload.clone(),
150 ) && let Ok(ready) = decode_one::<bool>(&bytes)
151 && ready
152 {
153 return;
154 }
155 pic.tick();
156 }
157
158 panic!("canister did not report ready in time: {canister_id}");
159}
160
161#[must_use]
163pub fn role_pid(
164 pic: &PocketIc,
165 root_id: Principal,
166 role: &'static str,
167 tick_limit: usize,
168) -> Principal {
169 for _ in 0..tick_limit {
170 let registry: Result<Result<SubnetRegistryResponse, Error>, Error> = {
171 let payload = encode_args(()).expect("encode empty args");
172 pic.query_call(
173 root_id,
174 Principal::anonymous(),
175 protocol::CANIC_SUBNET_REGISTRY,
176 payload,
177 )
178 .map_err(|err| {
179 Error::internal(format!(
180 "pocket_ic query_call failed (canister={root_id}, method={}): {err}",
181 protocol::CANIC_SUBNET_REGISTRY
182 ))
183 })
184 .and_then(|bytes| {
185 decode_one(&bytes).map_err(|err| {
186 Error::internal(format!("decode_one failed for subnet registry: {err}"))
187 })
188 })
189 };
190
191 if let Ok(Ok(registry)) = registry
192 && let Some(pid) = registry
193 .0
194 .into_iter()
195 .find(|entry| entry.role == CanisterRole::new(role))
196 .map(|entry| entry.pid)
197 {
198 return pid;
199 }
200
201 pic.tick();
202 }
203
204 panic!("{role} canister must be registered");
205}
206
207pub struct PicBuilder(PocketIcBuilder);
218
219#[expect(clippy::new_without_default)]
220impl PicBuilder {
221 #[must_use]
223 pub fn new() -> Self {
224 Self(PocketIcBuilder::new())
225 }
226
227 #[must_use]
229 pub fn with_application_subnet(mut self) -> Self {
230 self.0 = self.0.with_application_subnet();
231 self
232 }
233
234 #[must_use]
236 pub fn with_nns_subnet(mut self) -> Self {
237 self.0 = self.0.with_nns_subnet();
238 self
239 }
240
241 #[must_use]
243 pub fn build(self) -> Pic {
244 Pic {
245 inner: self.0.build(),
246 }
247 }
248}
249
/// Wrapper around a [`PocketIc`] instance that adds Candid call helpers, readiness polling and
/// snapshot utilities. It derefs to the wrapped `PocketIc`.
pub struct Pic {
    inner: PocketIc,
}
262
263impl<T> Deref for CachedPicBaselineGuard<'_, T> {
264 type Target = CachedPicBaseline<T>;
265
266 fn deref(&self) -> &Self::Target {
267 self.guard
268 .as_ref()
269 .expect("cached PocketIC baseline must exist")
270 }
271}
272
273impl<T> DerefMut for CachedPicBaselineGuard<'_, T> {
274 fn deref_mut(&mut self) -> &mut Self::Target {
275 self.guard
276 .as_mut()
277 .expect("cached PocketIC baseline must exist")
278 }
279}
280
281impl<T> CachedPicBaseline<T> {
282 pub fn capture<I>(
284 pic: Pic,
285 controller_id: Principal,
286 canister_ids: I,
287 metadata: T,
288 ) -> Option<Self>
289 where
290 I: IntoIterator<Item = Principal>,
291 {
292 let snapshots = pic.capture_controller_snapshots(controller_id, canister_ids)?;
293
294 Some(Self {
295 pic,
296 snapshots,
297 metadata,
298 _serial_guard: acquire_pic_serial_guard(),
299 })
300 }
301
302 pub fn restore(&self, controller_id: Principal) {
304 self.pic
305 .restore_controller_snapshots(controller_id, &self.snapshots);
306 }
307}
308
309impl Pic {
310 #[must_use]
312 pub fn current_time_nanos(&self) -> u64 {
313 self.inner.get_time().as_nanos_since_unix_epoch()
314 }
315
316 pub fn restore_time_nanos(&self, nanos_since_epoch: u64) {
318 let restored = pocket_ic::Time::from_nanos_since_unix_epoch(nanos_since_epoch);
319 self.inner.set_time(restored);
320 self.inner.set_certified_time(restored);
321 }
322
323 pub fn create_and_install_root_canister(&self, wasm: Vec<u8>) -> Result<Principal, Error> {
325 let init_bytes = install_root_args()?;
326
327 Ok(self.create_funded_and_install(wasm, init_bytes))
328 }
329
330 pub fn create_and_install_canister(
334 &self,
335 role: CanisterRole,
336 wasm: Vec<u8>,
337 ) -> Result<Principal, Error> {
338 let init_bytes = install_args(role)?;
339
340 Ok(self.create_funded_and_install(wasm, init_bytes))
341 }
342
343 pub fn wait_for_ready(&self, canister_id: Principal, tick_limit: usize, context: &str) {
345 for _ in 0..tick_limit {
346 self.tick();
347 if self.fetch_ready(canister_id) {
348 return;
349 }
350 }
351
352 self.dump_canister_debug(canister_id, context);
353 panic!("{context}: canister {canister_id} did not become ready after {tick_limit} ticks");
354 }
355
356 pub fn wait_for_all_ready<I>(&self, canister_ids: I, tick_limit: usize, context: &str)
358 where
359 I: IntoIterator<Item = Principal>,
360 {
361 let canister_ids = canister_ids.into_iter().collect::<Vec<_>>();
362
363 for _ in 0..tick_limit {
364 self.tick();
365 if canister_ids
366 .iter()
367 .copied()
368 .all(|canister_id| self.fetch_ready(canister_id))
369 {
370 return;
371 }
372 }
373
374 for canister_id in &canister_ids {
375 self.dump_canister_debug(*canister_id, context);
376 }
377 panic!("{context}: canisters did not become ready after {tick_limit} ticks");
378 }
379
380 pub fn dump_canister_debug(&self, canister_id: Principal, context: &str) {
382 eprintln!("{context}: debug for canister {canister_id}");
383
384 match self.canister_status(canister_id, None) {
385 Ok(status) => eprintln!("canister_status: {status:?}"),
386 Err(err) => eprintln!("canister_status failed: {err:?}"),
387 }
388
389 match self.fetch_canister_logs(canister_id, Principal::anonymous()) {
390 Ok(records) => {
391 if records.is_empty() {
392 eprintln!("canister logs: <empty>");
393 } else {
394 for record in records {
395 eprintln!("canister log: {record:?}");
396 }
397 }
398 }
399 Err(err) => eprintln!("fetch_canister_logs failed: {err:?}"),
400 }
401 }
402
403 pub fn capture_controller_snapshots<I>(
405 &self,
406 controller_id: Principal,
407 canister_ids: I,
408 ) -> Option<ControllerSnapshots>
409 where
410 I: IntoIterator<Item = Principal>,
411 {
412 let mut snapshots = HashMap::new();
413
414 for canister_id in canister_ids {
415 let Some(snapshot) = self.try_take_controller_snapshot(controller_id, canister_id)
416 else {
417 eprintln!(
418 "capture_controller_snapshots: snapshot capture unavailable for {canister_id}"
419 );
420 return None;
421 };
422 snapshots.insert(canister_id, snapshot);
423 }
424
425 Some(ControllerSnapshots(snapshots))
426 }
427
428 pub fn restore_controller_snapshots(
430 &self,
431 controller_id: Principal,
432 snapshots: &ControllerSnapshots,
433 ) {
434 for (canister_id, snapshot) in &snapshots.0 {
435 self.restore_controller_snapshot(controller_id, *canister_id, snapshot);
436 }
437 }
438
439 pub fn update_call<T, A>(
441 &self,
442 canister_id: Principal,
443 method: &str,
444 args: A,
445 ) -> Result<T, Error>
446 where
447 T: CandidType + DeserializeOwned,
448 A: ArgumentEncoder,
449 {
450 let bytes: Vec<u8> = encode_args(args)
451 .map_err(|err| Error::internal(format!("encode_args failed: {err}")))?;
452 let result = self
453 .inner
454 .update_call(canister_id, Principal::anonymous(), method, bytes)
455 .map_err(|err| {
456 Error::internal(format!(
457 "pocket_ic update_call failed (canister={canister_id}, method={method}): {err}"
458 ))
459 })?;
460
461 decode_one(&result).map_err(|err| Error::internal(format!("decode_one failed: {err}")))
462 }
463
464 pub fn update_call_as<T, A>(
466 &self,
467 canister_id: Principal,
468 caller: Principal,
469 method: &str,
470 args: A,
471 ) -> Result<T, Error>
472 where
473 T: CandidType + DeserializeOwned,
474 A: ArgumentEncoder,
475 {
476 let bytes: Vec<u8> = encode_args(args)
477 .map_err(|err| Error::internal(format!("encode_args failed: {err}")))?;
478 let result = self
479 .inner
480 .update_call(canister_id, caller, method, bytes)
481 .map_err(|err| {
482 Error::internal(format!(
483 "pocket_ic update_call failed (canister={canister_id}, method={method}): {err}"
484 ))
485 })?;
486
487 decode_one(&result).map_err(|err| Error::internal(format!("decode_one failed: {err}")))
488 }
489
490 pub fn query_call<T, A>(
492 &self,
493 canister_id: Principal,
494 method: &str,
495 args: A,
496 ) -> Result<T, Error>
497 where
498 T: CandidType + DeserializeOwned,
499 A: ArgumentEncoder,
500 {
501 let bytes: Vec<u8> = encode_args(args)
502 .map_err(|err| Error::internal(format!("encode_args failed: {err}")))?;
503 let result = self
504 .inner
505 .query_call(canister_id, Principal::anonymous(), method, bytes)
506 .map_err(|err| {
507 Error::internal(format!(
508 "pocket_ic query_call failed (canister={canister_id}, method={method}): {err}"
509 ))
510 })?;
511
512 decode_one(&result).map_err(|err| Error::internal(format!("decode_one failed: {err}")))
513 }
514
515 pub fn query_call_as<T, A>(
517 &self,
518 canister_id: Principal,
519 caller: Principal,
520 method: &str,
521 args: A,
522 ) -> Result<T, Error>
523 where
524 T: CandidType + DeserializeOwned,
525 A: ArgumentEncoder,
526 {
527 let bytes: Vec<u8> = encode_args(args)
528 .map_err(|err| Error::internal(format!("encode_args failed: {err}")))?;
529 let result = self
530 .inner
531 .query_call(canister_id, caller, method, bytes)
532 .map_err(|err| {
533 Error::internal(format!(
534 "pocket_ic query_call failed (canister={canister_id}, method={method}): {err}"
535 ))
536 })?;
537
538 decode_one(&result).map_err(|err| Error::internal(format!("decode_one failed: {err}")))
539 }
540
541 pub fn tick_n(&self, times: usize) {
543 for _ in 0..times {
544 self.tick();
545 }
546 }
547
548 fn create_funded_and_install(&self, wasm: Vec<u8>, init_bytes: Vec<u8>) -> Principal {
550 let canister_id = self.create_canister();
551 self.add_cycles(canister_id, INSTALL_CYCLES);
552
553 let install = catch_unwind(AssertUnwindSafe(|| {
554 self.inner
555 .install_canister(canister_id, wasm, init_bytes, None);
556 }));
557 if let Err(err) = install {
558 eprintln!("install_canister trapped for {canister_id}");
559 if let Ok(status) = self.inner.canister_status(canister_id, None) {
560 eprintln!("canister_status for {canister_id}: {status:?}");
561 }
562 if let Ok(logs) = self
563 .inner
564 .fetch_canister_logs(canister_id, Principal::anonymous())
565 {
566 for record in logs {
567 eprintln!("canister_log {canister_id}: {record:?}");
568 }
569 }
570 std::panic::resume_unwind(err);
571 }
572
573 canister_id
574 }
575
576 fn fetch_ready(&self, canister_id: Principal) -> bool {
578 match self.query_call(canister_id, protocol::CANIC_READY, ()) {
579 Ok(ready) => ready,
580 Err(err) => {
581 self.dump_canister_debug(canister_id, "query canic_ready failed");
582 panic!("query canic_ready failed: {err:?}");
583 }
584 }
585 }
586
587 fn try_take_controller_snapshot(
589 &self,
590 controller_id: Principal,
591 canister_id: Principal,
592 ) -> Option<ControllerSnapshot> {
593 let candidates = controller_sender_candidates(controller_id, canister_id);
594 let mut last_err = None;
595
596 for sender in candidates {
597 match self.take_canister_snapshot(canister_id, sender, None) {
598 Ok(snapshot) => {
599 return Some(ControllerSnapshot {
600 snapshot_id: snapshot.id,
601 sender,
602 });
603 }
604 Err(err) => last_err = Some((sender, err)),
605 }
606 }
607
608 if let Some((sender, err)) = last_err {
609 eprintln!(
610 "failed to capture canister snapshot for {canister_id} using sender {sender:?}: {err}"
611 );
612 }
613 None
614 }
615
616 fn restore_controller_snapshot(
618 &self,
619 controller_id: Principal,
620 canister_id: Principal,
621 snapshot: &ControllerSnapshot,
622 ) {
623 let fallback_sender = if snapshot.sender.is_some() {
624 None
625 } else {
626 Some(controller_id)
627 };
628 let candidates = [snapshot.sender, fallback_sender];
629 let mut last_err = None;
630
631 for sender in candidates {
632 match self.load_canister_snapshot(canister_id, sender, snapshot.snapshot_id.clone()) {
633 Ok(()) => return,
634 Err(err) => last_err = Some((sender, err)),
635 }
636 }
637
638 let (sender, err) =
639 last_err.expect("snapshot restore must have at least one sender attempt");
640 panic!(
641 "failed to restore canister snapshot for {canister_id} using sender {sender:?}: {err}"
642 );
643 }
644}
645
646impl Drop for ProcessLockGuard {
647 fn drop(&mut self) {
648 let _ = fs::remove_file(process_lock_owner_path(&self.path));
649 let _ = fs::remove_dir(&self.path);
650 }
651}
652
653impl Drop for PicSerialGuard {
654 fn drop(&mut self) {
655 let mut state = PIC_PROCESS_LOCK_STATE
656 .lock()
657 .unwrap_or_else(std::sync::PoisonError::into_inner);
658
659 state.ref_count = state
660 .ref_count
661 .checked_sub(1)
662 .expect("PocketIC serial guard refcount underflow");
663 if state.ref_count == 0 {
664 state.process_lock.take();
665 }
666 }
667}
668
669impl Deref for Pic {
670 type Target = PocketIc;
671
672 fn deref(&self) -> &Self::Target {
673 &self.inner
674 }
675}
676
677impl DerefMut for Pic {
678 fn deref_mut(&mut self) -> &mut Self::Target {
679 &mut self.inner
680 }
681}
682
683fn install_args(role: CanisterRole) -> Result<Vec<u8>, Error> {
698 if role.is_root() {
699 install_root_args()
700 } else {
701 let env = EnvBootstrapArgs {
704 prime_root_pid: None,
705 subnet_role: None,
706 subnet_pid: None,
707 root_pid: None,
708 canister_role: Some(role),
709 parent_pid: None,
710 };
711
712 let payload = CanisterInitPayload {
715 env,
716 app_directory: AppDirectoryArgs(Vec::new()),
717 subnet_directory: SubnetDirectoryArgs(Vec::new()),
718 };
719
720 encode_args::<(CanisterInitPayload, Option<Vec<u8>>)>((payload, None))
721 .map_err(|err| Error::internal(format!("encode_args failed: {err}")))
722 }
723}
724
725fn install_root_args() -> Result<Vec<u8>, Error> {
726 encode_one(SubnetIdentity::Manual)
727 .map_err(|err| Error::internal(format!("encode_one failed: {err}")))
728}
729
730fn controller_sender_candidates(
732 controller_id: Principal,
733 canister_id: Principal,
734) -> [Option<Principal>; 2] {
735 if canister_id == controller_id {
736 [None, Some(controller_id)]
737 } else {
738 [Some(controller_id), None]
739 }
740}
741
742fn acquire_process_lock() -> ProcessLockGuard {
743 let lock_dir = env::temp_dir().join(PIC_PROCESS_LOCK_DIR_NAME);
744 let started_waiting = Instant::now();
745 let mut logged_wait = false;
746
747 loop {
748 match fs::create_dir(&lock_dir) {
749 Ok(()) => {
750 fs::write(
751 process_lock_owner_path(&lock_dir),
752 process::id().to_string(),
753 )
754 .unwrap_or_else(|err| {
755 let _ = fs::remove_dir(&lock_dir);
756 panic!(
757 "failed to record PocketIC process lock owner at {}: {err}",
758 lock_dir.display()
759 );
760 });
761
762 if logged_wait {
763 eprintln!(
764 "[canic_testkit::pic] acquired cross-process PocketIC lock at {}",
765 lock_dir.display()
766 );
767 }
768
769 return ProcessLockGuard { path: lock_dir };
770 }
771 Err(err) if err.kind() == io::ErrorKind::AlreadyExists => {
772 if process_lock_is_stale(&lock_dir) {
773 let _ = fs::remove_file(process_lock_owner_path(&lock_dir));
774 let _ = fs::remove_dir(&lock_dir);
775 continue;
776 }
777
778 if !logged_wait && started_waiting.elapsed() >= PIC_PROCESS_LOCK_LOG_AFTER {
779 eprintln!(
780 "[canic_testkit::pic] waiting for cross-process PocketIC lock at {}",
781 lock_dir.display()
782 );
783 logged_wait = true;
784 }
785
786 thread::sleep(PIC_PROCESS_LOCK_RETRY_DELAY);
787 }
788 Err(err) => panic!(
789 "failed to create PocketIC process lock dir at {}: {err}",
790 lock_dir.display()
791 ),
792 }
793 }
794}
795
/// Path of the file inside `lock_dir` that records the holder's process ID.
fn process_lock_owner_path(lock_dir: &Path) -> PathBuf {
    lock_dir.join("owner")
}

/// How long a lock directory may exist without a readable owner PID before it is treated as
/// abandoned. The holder creates the directory and writes the owner file in two separate
/// steps. A missing or empty owner file is therefore expected briefly right after the lock is
/// acquired, and must not be mistaken for a dead holder.
const PIC_PROCESS_LOCK_OWNER_GRACE: Duration = Duration::from_secs(5);

/// Decide whether the lock at `lock_dir` was abandoned and may be reclaimed.
///
/// - Owner file missing or not a PID: stale only once the directory is older than
///   `PIC_PROCESS_LOCK_OWNER_GRACE`, since the holder may still be writing it.
/// - Owner PID recorded: stale if and only if that process no longer exists.
fn process_lock_is_stale(lock_dir: &Path) -> bool {
    let owner_pid = fs::read_to_string(process_lock_owner_path(lock_dir))
        .ok()
        .and_then(|text| text.trim().parse::<u32>().ok());

    match owner_pid {
        Some(pid) => !process_is_alive(pid),
        None => process_lock_dir_outlived_grace(lock_dir),
    }
}

/// Whether `lock_dir`'s modification time is at least `PIC_PROCESS_LOCK_OWNER_GRACE` old.
///
/// Unreadable metadata (e.g. the directory just vanished) or a modification time in the future
/// both count as "not old". The caller then retries instead of deleting a lock that may have
/// just been re-created.
fn process_lock_dir_outlived_grace(lock_dir: &Path) -> bool {
    fs::metadata(lock_dir)
        .and_then(|meta| meta.modified())
        .ok()
        .and_then(|modified| modified.elapsed().ok())
        .is_some_and(|age| age >= PIC_PROCESS_LOCK_OWNER_GRACE)
}

/// Best-effort check that a process with `pid` exists.
///
/// Uses `/proc/<pid>` when procfs is mounted. Otherwise it probes with `kill -0` on Unix.
/// When liveness cannot be determined, the process is assumed alive, so a live lock is never
/// stolen.
fn process_is_alive(pid: u32) -> bool {
    let procfs = Path::new("/proc");
    if procfs.join("self").exists() {
        return procfs.join(pid.to_string()).exists();
    }

    probe_process_with_kill(pid).unwrap_or(true)
}

/// Run `kill -0 <pid>`. `Some(true)` if it succeeds, `Some(false)` if it fails, and `None` if
/// the command cannot be run. NOTE(review): `kill -0` also fails for processes owned by
/// another user, which would read as "dead" here.
#[cfg(unix)]
fn probe_process_with_kill(pid: u32) -> Option<bool> {
    process::Command::new("kill")
        .arg("-0")
        .arg(pid.to_string())
        .stdout(process::Stdio::null())
        .stderr(process::Stdio::null())
        .status()
        .ok()
        .map(|status| status.success())
}

/// No portable liveness probe on this platform.
#[cfg(not(unix))]
fn probe_process_with_kill(_pid: u32) -> Option<bool> {
    None
}