1use candid::{CandidType, Principal, decode_one, encode_args, encode_one, utils::ArgumentEncoder};
2use canic::{
3 Error,
4 cdk::types::TC,
5 dto::{
6 abi::v1::CanisterInitPayload,
7 env::EnvBootstrapArgs,
8 subnet::SubnetIdentity,
9 topology::{AppDirectoryArgs, SubnetDirectoryArgs, SubnetRegistryResponse},
10 },
11 ids::CanisterRole,
12 protocol,
13};
14use pocket_ic::{PocketIc, PocketIcBuilder};
15use serde::de::DeserializeOwned;
16use std::{
17 collections::HashMap,
18 env, fs, io,
19 ops::{Deref, DerefMut},
20 panic::{AssertUnwindSafe, catch_unwind},
21 path::{Path, PathBuf},
22 process,
23 sync::{Mutex, MutexGuard},
24 thread,
25 time::{Duration, Instant},
26};
27
/// Cycles added to every canister created by `Pic::create_funded_and_install`
/// before its code is installed (`500 * TC`).
const INSTALL_CYCLES: u128 = 500 * TC;
/// Name of the directory, created under `env::temp_dir()`, whose existence marks
/// the cross-process PocketIC lock as held.
const PIC_PROCESS_LOCK_DIR_NAME: &str = "canic-pocket-ic.lock";
/// Pause between attempts to create the lock directory while another process holds it.
const PIC_PROCESS_LOCK_RETRY_DELAY: Duration = Duration::from_millis(100);
/// How long a waiter blocks on the lock before printing a one-time "waiting" notice.
const PIC_PROCESS_LOCK_LOG_AFTER: Duration = Duration::from_secs(1);
/// In-process bookkeeping for the cross-process lock: how many `PicSerialGuard`s are
/// alive in this process and the lock guard held on their behalf. Starts empty.
static PIC_PROCESS_LOCK_STATE: Mutex<ProcessLockState> = Mutex::new(ProcessLockState {
    ref_count: 0,
    process_lock: None,
});
36
/// One canister snapshot, taken on behalf of a controller.
struct ControllerSnapshot {
    /// Snapshot identifier returned by PocketIC's `take_canister_snapshot`.
    snapshot_id: Vec<u8>,
    /// Sender that successfully took the snapshot; restoring tries this sender first.
    /// `None` is passed straight through to PocketIC (presumably its default sender —
    /// TODO confirm).
    sender: Option<Principal>,
}
41
/// Ownership token for the cross-process PocketIC lock directory.
///
/// Created by `acquire_process_lock`; its `Drop` impl releases the lock.
struct ProcessLockGuard {
    /// Lock directory this process created.
    path: PathBuf,
}
45
/// In-process bookkeeping shared by every `PicSerialGuard`.
struct ProcessLockState {
    /// Number of `PicSerialGuard`s currently alive in this process.
    ref_count: usize,
    /// Cross-process lock, held while `ref_count > 0`.
    process_lock: Option<ProcessLockGuard>,
}
50
/// Snapshots captured for a set of canisters, keyed by canister principal.
///
/// Produced by `Pic::capture_controller_snapshots` and consumed by
/// `Pic::restore_controller_snapshots`.
pub struct ControllerSnapshots(HashMap<Principal, ControllerSnapshot>);
56
/// A PocketIC instance captured at a known state so it can be restored instead of rebuilt.
///
/// Built with `CachedPicBaseline::capture` and stored in a `'static` slot by
/// `acquire_cached_pic_baseline`. While it exists it holds a `PicSerialGuard`, so this
/// process keeps the cross-process PocketIC lock.
pub struct CachedPicBaseline<T> {
    /// The PocketIC instance the snapshots were taken on.
    pub pic: Pic,
    /// Canister snapshots that `restore` loads back.
    pub snapshots: ControllerSnapshots,
    /// Caller-defined data stored alongside the baseline.
    pub metadata: T,
    /// Keeps the cross-process PocketIC lock held for the baseline's lifetime.
    _serial_guard: PicSerialGuard,
}
67
/// Locked access to a cached baseline, returned by `acquire_cached_pic_baseline`.
///
/// Dereferences to the `CachedPicBaseline`; the slot's mutex stays locked until this
/// guard is dropped.
pub struct CachedPicBaselineGuard<'a, T> {
    /// Always `Some` once handed out; `deref`/`deref_mut` panic otherwise.
    guard: MutexGuard<'a, Option<CachedPicBaseline<T>>>,
}
75
/// Token showing this process holds the cross-process PocketIC lock.
///
/// Guards are reference-counted within the process: the first one acquires the lock and
/// dropping the last one releases it. Obtain one via `acquire_pic_serial_guard`.
pub struct PicSerialGuard {
    /// Private field so the guard cannot be constructed outside this module.
    _private: (),
}
83
84#[must_use]
93pub fn pic() -> Pic {
94 PicBuilder::new().with_application_subnet().build()
95}
96
97#[must_use]
99pub fn acquire_pic_serial_guard() -> PicSerialGuard {
100 let mut state = PIC_PROCESS_LOCK_STATE
101 .lock()
102 .unwrap_or_else(std::sync::PoisonError::into_inner);
103
104 if state.ref_count == 0 {
105 state.process_lock = Some(acquire_process_lock());
106 }
107 state.ref_count += 1;
108
109 PicSerialGuard { _private: () }
110}
111
112pub fn acquire_cached_pic_baseline<T, F>(
114 slot: &'static Mutex<Option<CachedPicBaseline<T>>>,
115 build: F,
116) -> (CachedPicBaselineGuard<'static, T>, bool)
117where
118 F: FnOnce() -> CachedPicBaseline<T>,
119{
120 let mut guard = slot
121 .lock()
122 .unwrap_or_else(std::sync::PoisonError::into_inner);
123 let cache_hit = guard.is_some();
124
125 if !cache_hit {
126 *guard = Some(build());
127 }
128
129 (CachedPicBaselineGuard { guard }, cache_hit)
130}
131
132pub fn wait_until_ready(pic: &PocketIc, canister_id: Principal, tick_limit: usize) {
134 let payload = encode_args(()).expect("encode empty args");
135
136 for _ in 0..tick_limit {
137 if let Ok(bytes) = pic.query_call(
138 canister_id,
139 Principal::anonymous(),
140 protocol::CANIC_READY,
141 payload.clone(),
142 ) && let Ok(ready) = decode_one::<bool>(&bytes)
143 && ready
144 {
145 return;
146 }
147 pic.tick();
148 }
149
150 panic!("canister did not report ready in time: {canister_id}");
151}
152
153#[must_use]
155pub fn role_pid(
156 pic: &PocketIc,
157 root_id: Principal,
158 role: &'static str,
159 tick_limit: usize,
160) -> Principal {
161 for _ in 0..tick_limit {
162 let registry: Result<Result<SubnetRegistryResponse, Error>, Error> = {
163 let payload = encode_args(()).expect("encode empty args");
164 pic.query_call(
165 root_id,
166 Principal::anonymous(),
167 protocol::CANIC_SUBNET_REGISTRY,
168 payload,
169 )
170 .map_err(|err| {
171 Error::internal(format!(
172 "pocket_ic query_call failed (canister={root_id}, method={}): {err}",
173 protocol::CANIC_SUBNET_REGISTRY
174 ))
175 })
176 .and_then(|bytes| {
177 decode_one(&bytes).map_err(|err| {
178 Error::internal(format!("decode_one failed for subnet registry: {err}"))
179 })
180 })
181 };
182
183 if let Ok(Ok(registry)) = registry
184 && let Some(pid) = registry
185 .0
186 .into_iter()
187 .find(|entry| entry.role == CanisterRole::new(role))
188 .map(|entry| entry.pid)
189 {
190 return pid;
191 }
192
193 pic.tick();
194 }
195
196 panic!("{role} canister must be registered");
197}
198
/// Builder for a [`Pic`], wrapping a `pocket_ic::PocketIcBuilder`.
pub struct PicBuilder(PocketIcBuilder);
210
211#[expect(clippy::new_without_default)]
212impl PicBuilder {
213 #[must_use]
215 pub fn new() -> Self {
216 Self(PocketIcBuilder::new())
217 }
218
219 #[must_use]
221 pub fn with_application_subnet(mut self) -> Self {
222 self.0 = self.0.with_application_subnet();
223 self
224 }
225
226 #[must_use]
228 pub fn with_ii_subnet(mut self) -> Self {
229 self.0 = self.0.with_ii_subnet();
230 self
231 }
232
233 #[must_use]
235 pub fn with_nns_subnet(mut self) -> Self {
236 self.0 = self.0.with_nns_subnet();
237 self
238 }
239
240 #[must_use]
242 pub fn build(self) -> Pic {
243 Pic {
244 inner: self.0.build(),
245 }
246 }
247}
248
/// Wrapper around a `PocketIc` instance with Candid-typed call helpers.
///
/// Dereferences to `PocketIc`, so the raw PocketIC API remains available.
pub struct Pic {
    inner: PocketIc,
}
261
262impl<T> Deref for CachedPicBaselineGuard<'_, T> {
263 type Target = CachedPicBaseline<T>;
264
265 fn deref(&self) -> &Self::Target {
266 self.guard
267 .as_ref()
268 .expect("cached PocketIC baseline must exist")
269 }
270}
271
272impl<T> DerefMut for CachedPicBaselineGuard<'_, T> {
273 fn deref_mut(&mut self) -> &mut Self::Target {
274 self.guard
275 .as_mut()
276 .expect("cached PocketIC baseline must exist")
277 }
278}
279
280impl<T> CachedPicBaseline<T> {
281 pub fn capture<I>(
283 pic: Pic,
284 controller_id: Principal,
285 canister_ids: I,
286 metadata: T,
287 ) -> Option<Self>
288 where
289 I: IntoIterator<Item = Principal>,
290 {
291 let snapshots = pic.capture_controller_snapshots(controller_id, canister_ids)?;
292
293 Some(Self {
294 pic,
295 snapshots,
296 metadata,
297 _serial_guard: acquire_pic_serial_guard(),
298 })
299 }
300
301 pub fn restore(&self, controller_id: Principal) {
303 self.pic
304 .restore_controller_snapshots(controller_id, &self.snapshots);
305 }
306}
307
308impl Pic {
309 #[must_use]
311 pub fn current_time_nanos(&self) -> u64 {
312 self.inner.get_time().as_nanos_since_unix_epoch()
313 }
314
315 pub fn restore_time_nanos(&self, nanos_since_epoch: u64) {
317 let restored = pocket_ic::Time::from_nanos_since_unix_epoch(nanos_since_epoch);
318 self.inner.set_time(restored);
319 self.inner.set_certified_time(restored);
320 }
321
322 pub fn create_and_install_root_canister(&self, wasm: Vec<u8>) -> Result<Principal, Error> {
324 let init_bytes = install_root_args()?;
325
326 Ok(self.create_funded_and_install(wasm, init_bytes))
327 }
328
329 pub fn create_and_install_canister(
333 &self,
334 role: CanisterRole,
335 wasm: Vec<u8>,
336 ) -> Result<Principal, Error> {
337 let init_bytes = install_args(role)?;
338
339 Ok(self.create_funded_and_install(wasm, init_bytes))
340 }
341
342 pub fn wait_for_ready(&self, canister_id: Principal, tick_limit: usize, context: &str) {
344 for _ in 0..tick_limit {
345 self.tick();
346 if self.fetch_ready(canister_id) {
347 return;
348 }
349 }
350
351 self.dump_canister_debug(canister_id, context);
352 panic!("{context}: canister {canister_id} did not become ready after {tick_limit} ticks");
353 }
354
355 pub fn wait_for_all_ready<I>(&self, canister_ids: I, tick_limit: usize, context: &str)
357 where
358 I: IntoIterator<Item = Principal>,
359 {
360 let canister_ids = canister_ids.into_iter().collect::<Vec<_>>();
361
362 for _ in 0..tick_limit {
363 self.tick();
364 if canister_ids
365 .iter()
366 .copied()
367 .all(|canister_id| self.fetch_ready(canister_id))
368 {
369 return;
370 }
371 }
372
373 for canister_id in &canister_ids {
374 self.dump_canister_debug(*canister_id, context);
375 }
376 panic!("{context}: canisters did not become ready after {tick_limit} ticks");
377 }
378
379 pub fn dump_canister_debug(&self, canister_id: Principal, context: &str) {
381 eprintln!("{context}: debug for canister {canister_id}");
382
383 match self.canister_status(canister_id, None) {
384 Ok(status) => eprintln!("canister_status: {status:?}"),
385 Err(err) => eprintln!("canister_status failed: {err:?}"),
386 }
387
388 match self.fetch_canister_logs(canister_id, Principal::anonymous()) {
389 Ok(records) => {
390 if records.is_empty() {
391 eprintln!("canister logs: <empty>");
392 } else {
393 for record in records {
394 eprintln!("canister log: {record:?}");
395 }
396 }
397 }
398 Err(err) => eprintln!("fetch_canister_logs failed: {err:?}"),
399 }
400 }
401
402 pub fn capture_controller_snapshots<I>(
404 &self,
405 controller_id: Principal,
406 canister_ids: I,
407 ) -> Option<ControllerSnapshots>
408 where
409 I: IntoIterator<Item = Principal>,
410 {
411 let mut snapshots = HashMap::new();
412
413 for canister_id in canister_ids {
414 let Some(snapshot) = self.try_take_controller_snapshot(controller_id, canister_id)
415 else {
416 eprintln!(
417 "capture_controller_snapshots: snapshot capture unavailable for {canister_id}"
418 );
419 return None;
420 };
421 snapshots.insert(canister_id, snapshot);
422 }
423
424 Some(ControllerSnapshots(snapshots))
425 }
426
427 pub fn restore_controller_snapshots(
429 &self,
430 controller_id: Principal,
431 snapshots: &ControllerSnapshots,
432 ) {
433 for (canister_id, snapshot) in &snapshots.0 {
434 self.restore_controller_snapshot(controller_id, *canister_id, snapshot);
435 }
436 }
437
438 pub fn update_call<T, A>(
440 &self,
441 canister_id: Principal,
442 method: &str,
443 args: A,
444 ) -> Result<T, Error>
445 where
446 T: CandidType + DeserializeOwned,
447 A: ArgumentEncoder,
448 {
449 let bytes: Vec<u8> = encode_args(args)
450 .map_err(|err| Error::internal(format!("encode_args failed: {err}")))?;
451 let result = self
452 .inner
453 .update_call(canister_id, Principal::anonymous(), method, bytes)
454 .map_err(|err| {
455 Error::internal(format!(
456 "pocket_ic update_call failed (canister={canister_id}, method={method}): {err}"
457 ))
458 })?;
459
460 decode_one(&result).map_err(|err| Error::internal(format!("decode_one failed: {err}")))
461 }
462
463 pub fn update_call_as<T, A>(
465 &self,
466 canister_id: Principal,
467 caller: Principal,
468 method: &str,
469 args: A,
470 ) -> Result<T, Error>
471 where
472 T: CandidType + DeserializeOwned,
473 A: ArgumentEncoder,
474 {
475 let bytes: Vec<u8> = encode_args(args)
476 .map_err(|err| Error::internal(format!("encode_args failed: {err}")))?;
477 let result = self
478 .inner
479 .update_call(canister_id, caller, method, bytes)
480 .map_err(|err| {
481 Error::internal(format!(
482 "pocket_ic update_call failed (canister={canister_id}, method={method}): {err}"
483 ))
484 })?;
485
486 decode_one(&result).map_err(|err| Error::internal(format!("decode_one failed: {err}")))
487 }
488
489 pub fn query_call<T, A>(
491 &self,
492 canister_id: Principal,
493 method: &str,
494 args: A,
495 ) -> Result<T, Error>
496 where
497 T: CandidType + DeserializeOwned,
498 A: ArgumentEncoder,
499 {
500 let bytes: Vec<u8> = encode_args(args)
501 .map_err(|err| Error::internal(format!("encode_args failed: {err}")))?;
502 let result = self
503 .inner
504 .query_call(canister_id, Principal::anonymous(), method, bytes)
505 .map_err(|err| {
506 Error::internal(format!(
507 "pocket_ic query_call failed (canister={canister_id}, method={method}): {err}"
508 ))
509 })?;
510
511 decode_one(&result).map_err(|err| Error::internal(format!("decode_one failed: {err}")))
512 }
513
514 pub fn query_call_as<T, A>(
516 &self,
517 canister_id: Principal,
518 caller: Principal,
519 method: &str,
520 args: A,
521 ) -> Result<T, Error>
522 where
523 T: CandidType + DeserializeOwned,
524 A: ArgumentEncoder,
525 {
526 let bytes: Vec<u8> = encode_args(args)
527 .map_err(|err| Error::internal(format!("encode_args failed: {err}")))?;
528 let result = self
529 .inner
530 .query_call(canister_id, caller, method, bytes)
531 .map_err(|err| {
532 Error::internal(format!(
533 "pocket_ic query_call failed (canister={canister_id}, method={method}): {err}"
534 ))
535 })?;
536
537 decode_one(&result).map_err(|err| Error::internal(format!("decode_one failed: {err}")))
538 }
539
540 pub fn tick_n(&self, times: usize) {
542 for _ in 0..times {
543 self.tick();
544 }
545 }
546
547 fn create_funded_and_install(&self, wasm: Vec<u8>, init_bytes: Vec<u8>) -> Principal {
549 let canister_id = self.create_canister();
550 self.add_cycles(canister_id, INSTALL_CYCLES);
551
552 let install = catch_unwind(AssertUnwindSafe(|| {
553 self.inner
554 .install_canister(canister_id, wasm, init_bytes, None);
555 }));
556 if let Err(err) = install {
557 eprintln!("install_canister trapped for {canister_id}");
558 if let Ok(status) = self.inner.canister_status(canister_id, None) {
559 eprintln!("canister_status for {canister_id}: {status:?}");
560 }
561 if let Ok(logs) = self
562 .inner
563 .fetch_canister_logs(canister_id, Principal::anonymous())
564 {
565 for record in logs {
566 eprintln!("canister_log {canister_id}: {record:?}");
567 }
568 }
569 std::panic::resume_unwind(err);
570 }
571
572 canister_id
573 }
574
575 fn fetch_ready(&self, canister_id: Principal) -> bool {
577 match self.query_call(canister_id, protocol::CANIC_READY, ()) {
578 Ok(ready) => ready,
579 Err(err) => {
580 self.dump_canister_debug(canister_id, "query canic_ready failed");
581 panic!("query canic_ready failed: {err:?}");
582 }
583 }
584 }
585
586 fn try_take_controller_snapshot(
588 &self,
589 controller_id: Principal,
590 canister_id: Principal,
591 ) -> Option<ControllerSnapshot> {
592 let candidates = controller_sender_candidates(controller_id, canister_id);
593 let mut last_err = None;
594
595 for sender in candidates {
596 match self.take_canister_snapshot(canister_id, sender, None) {
597 Ok(snapshot) => {
598 return Some(ControllerSnapshot {
599 snapshot_id: snapshot.id,
600 sender,
601 });
602 }
603 Err(err) => last_err = Some((sender, err)),
604 }
605 }
606
607 if let Some((sender, err)) = last_err {
608 eprintln!(
609 "failed to capture canister snapshot for {canister_id} using sender {sender:?}: {err}"
610 );
611 }
612 None
613 }
614
615 fn restore_controller_snapshot(
617 &self,
618 controller_id: Principal,
619 canister_id: Principal,
620 snapshot: &ControllerSnapshot,
621 ) {
622 let fallback_sender = if snapshot.sender.is_some() {
623 None
624 } else {
625 Some(controller_id)
626 };
627 let candidates = [snapshot.sender, fallback_sender];
628 let mut last_err = None;
629
630 for sender in candidates {
631 match self.load_canister_snapshot(canister_id, sender, snapshot.snapshot_id.clone()) {
632 Ok(()) => return,
633 Err(err) => last_err = Some((sender, err)),
634 }
635 }
636
637 let (sender, err) =
638 last_err.expect("snapshot restore must have at least one sender attempt");
639 panic!(
640 "failed to restore canister snapshot for {canister_id} using sender {sender:?}: {err}"
641 );
642 }
643}
644
645impl Drop for ProcessLockGuard {
646 fn drop(&mut self) {
647 let _ = fs::remove_dir_all(&self.path);
648 }
649}
650
651impl Drop for PicSerialGuard {
652 fn drop(&mut self) {
653 let mut state = PIC_PROCESS_LOCK_STATE
654 .lock()
655 .unwrap_or_else(std::sync::PoisonError::into_inner);
656
657 state.ref_count = state
658 .ref_count
659 .checked_sub(1)
660 .expect("PocketIC serial guard refcount underflow");
661 if state.ref_count == 0 {
662 state.process_lock.take();
663 }
664 }
665}
666
667impl Deref for Pic {
668 type Target = PocketIc;
669
670 fn deref(&self) -> &Self::Target {
671 &self.inner
672 }
673}
674
675impl DerefMut for Pic {
676 fn deref_mut(&mut self) -> &mut Self::Target {
677 &mut self.inner
678 }
679}
680
681fn install_args(role: CanisterRole) -> Result<Vec<u8>, Error> {
696 if role.is_root() {
697 install_root_args()
698 } else {
699 let env = EnvBootstrapArgs {
702 prime_root_pid: None,
703 subnet_role: None,
704 subnet_pid: None,
705 root_pid: None,
706 canister_role: Some(role),
707 parent_pid: None,
708 };
709
710 let payload = CanisterInitPayload {
713 env,
714 app_directory: AppDirectoryArgs(Vec::new()),
715 subnet_directory: SubnetDirectoryArgs(Vec::new()),
716 };
717
718 encode_args::<(CanisterInitPayload, Option<Vec<u8>>)>((payload, None))
719 .map_err(|err| Error::internal(format!("encode_args failed: {err}")))
720 }
721}
722
723fn install_root_args() -> Result<Vec<u8>, Error> {
724 encode_one(SubnetIdentity::Manual)
725 .map_err(|err| Error::internal(format!("encode_one failed: {err}")))
726}
727
728fn controller_sender_candidates(
730 controller_id: Principal,
731 canister_id: Principal,
732) -> [Option<Principal>; 2] {
733 if canister_id == controller_id {
734 [None, Some(controller_id)]
735 } else {
736 [Some(controller_id), None]
737 }
738}
739
740fn acquire_process_lock() -> ProcessLockGuard {
741 let lock_dir = env::temp_dir().join(PIC_PROCESS_LOCK_DIR_NAME);
742 let started_waiting = Instant::now();
743 let mut logged_wait = false;
744
745 loop {
746 match fs::create_dir(&lock_dir) {
747 Ok(()) => {
748 fs::write(
749 process_lock_owner_path(&lock_dir),
750 process::id().to_string(),
751 )
752 .unwrap_or_else(|err| {
753 let _ = fs::remove_dir(&lock_dir);
754 panic!(
755 "failed to record PocketIC process lock owner at {}: {err}",
756 lock_dir.display()
757 );
758 });
759
760 if logged_wait {
761 eprintln!(
762 "[canic_testkit::pic] acquired cross-process PocketIC lock at {}",
763 lock_dir.display()
764 );
765 }
766
767 return ProcessLockGuard { path: lock_dir };
768 }
769 Err(err) if err.kind() == io::ErrorKind::AlreadyExists => {
770 if process_lock_is_stale(&lock_dir) && clear_stale_process_lock(&lock_dir).is_ok() {
771 continue;
772 }
773
774 if !logged_wait && started_waiting.elapsed() >= PIC_PROCESS_LOCK_LOG_AFTER {
775 eprintln!(
776 "[canic_testkit::pic] waiting for cross-process PocketIC lock at {}",
777 lock_dir.display()
778 );
779 logged_wait = true;
780 }
781
782 thread::sleep(PIC_PROCESS_LOCK_RETRY_DELAY);
783 }
784 Err(err) => panic!(
785 "failed to create PocketIC process lock dir at {}: {err}",
786 lock_dir.display()
787 ),
788 }
789 }
790}
791
/// Path of the file inside `lock_dir` that records the owning process's PID.
fn process_lock_owner_path(lock_dir: &Path) -> PathBuf {
    let mut owner_file = lock_dir.to_path_buf();
    owner_file.push("owner");
    owner_file
}
795
/// Remove an abandoned lock directory and everything in it.
///
/// A directory that is already gone counts as success: the goal is simply that it no
/// longer exists.
fn clear_stale_process_lock(lock_dir: &Path) -> io::Result<()> {
    if let Err(err) = fs::remove_dir_all(lock_dir) {
        if err.kind() != io::ErrorKind::NotFound {
            return Err(err);
        }
    }
    Ok(())
}
803
804fn process_lock_is_stale(lock_dir: &Path) -> bool {
805 let Ok(pid_text) = fs::read_to_string(process_lock_owner_path(lock_dir)) else {
806 return true;
807 };
808 let Ok(pid) = pid_text.trim().parse::<u32>() else {
809 return true;
810 };
811
812 !Path::new("/proc").join(pid.to_string()).exists()
813}
814
#[cfg(test)]
mod process_lock_tests {
    use super::{clear_stale_process_lock, process_lock_is_stale, process_lock_owner_path};
    use std::{
        fs,
        path::PathBuf,
        process,
        sync::atomic::{AtomicUsize, Ordering},
        time::{SystemTime, UNIX_EPOCH},
    };

    /// Build a lock-dir path unique to this process, call and moment.
    ///
    /// The PID separates concurrently running test binaries. The counter separates tests
    /// in the same binary that run within the same clock tick.
    fn unique_lock_dir() -> PathBuf {
        static NEXT_ID: AtomicUsize = AtomicUsize::new(0);

        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("clock must be after unix epoch")
            .as_nanos();
        let seq = NEXT_ID.fetch_add(1, Ordering::SeqCst);
        std::env::temp_dir().join(format!(
            "canic-pocket-ic-test-lock-{}-{seq}-{nanos}",
            process::id()
        ))
    }

    #[test]
    fn stale_process_lock_is_detected_and_removed() {
        let lock_dir = unique_lock_dir();
        fs::create_dir(&lock_dir).expect("create lock dir");
        fs::write(process_lock_owner_path(&lock_dir), "999999").expect("write stale owner");

        assert!(process_lock_is_stale(&lock_dir));
        clear_stale_process_lock(&lock_dir).expect("remove stale lock dir");
        assert!(!lock_dir.exists());
    }

    #[test]
    fn clearing_missing_lock_dir_is_ok() {
        let lock_dir = unique_lock_dir();
        assert!(!lock_dir.exists());

        clear_stale_process_lock(&lock_dir).expect("missing lock dir is not an error");
    }
}