use candid::utils::{ArgumentDecoder, ArgumentEncoder};
use candid::{decode_args, encode_args, Principal};
use ciborium::de::from_reader;
use ic_cdk::api::management_canister::main::{
CanisterId, CanisterIdRecord, CanisterInstallMode, CanisterSettings, CreateCanisterArgument,
InstallCodeArgument,
};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use serde_bytes::ByteBuf;
use std::cell::RefCell;
use std::fmt;
use std::io::{Read, Write};
use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
use std::time::{Duration, SystemTime};
/// A command sent to the state machine child process.
///
/// `StateMachine::send_request` serializes it as CBOR and writes it to the
/// child's stdin behind an 8-byte little-endian length prefix. Variant names
/// and payload shapes form the wire protocol with the child binary, so they
/// must stay in sync with it.
#[derive(Serialize, Deserialize)]
pub enum Request {
/// Sent by `StateMachine::root_key`; the reply is decoded as `Vec<u8>`.
RootKey,
/// Sent by `StateMachine::time`; the reply is decoded as `SystemTime`.
Time,
/// Sent by `StateMachine::set_time`; the reply is decoded as `()`.
SetTime(SystemTime),
/// Sent by `StateMachine::advance_time`; the reply is decoded as `()`.
AdvanceTime(Duration),
/// Sent by `StateMachine::update_call`; the reply is decoded as `Result<WasmResult, UserError>`.
CanisterUpdateCall(CanisterCall),
/// Sent by `StateMachine::query_call`; the reply is decoded as `Result<WasmResult, UserError>`.
CanisterQueryCall(CanisterCall),
/// Sent by `StateMachine::canister_exists`; the reply is decoded as `bool`.
CanisterExists(RawCanisterId),
/// Sent by `StateMachine::cycle_balance`; the reply is decoded as `u128`.
CyclesBalance(RawCanisterId),
/// Sent by `StateMachine::add_cycles`; the reply is decoded as `u128`.
AddCycles(AddCyclesArg),
/// Sent by `StateMachine::set_stable_memory`; the reply is decoded as `()`.
SetStableMemory(SetStableMemoryArg),
/// Sent by `StateMachine::stable_memory`; the reply is decoded as `Vec<u8>`.
ReadStableMemory(RawCanisterId),
/// Sent by `StateMachine::tick`; the reply is decoded as `()`.
Tick,
/// Sent by `StateMachine::run_until_completion`; the reply is decoded as `()`.
RunUntilCompletion(RunUntilCompletionArg),
/// Sent by `StateMachine::verify_canister_signature`; the reply is decoded as `Result<(), String>`.
VerifyCanisterSig(VerifyCanisterSigArg),
}
/// Payload of `Request::VerifyCanisterSig`, built by
/// `StateMachine::verify_canister_signature` from its four arguments.
#[derive(Serialize, Deserialize)]
pub struct VerifyCanisterSigArg {
/// The signed message bytes.
pub msg: Vec<u8>,
/// The signature bytes.
pub sig: Vec<u8>,
/// The public key the signature is checked against.
pub pubkey: Vec<u8>,
/// The root public key passed by the caller (e.g. one obtained from `StateMachine::root_key`).
pub root_pubkey: Vec<u8>,
}
/// Payload of `Request::RunUntilCompletion`, built by `StateMachine::run_until_completion`.
#[derive(Serialize, Deserialize)]
pub struct RunUntilCompletionArg {
/// Tick budget; presumably the child stops after this many ticks even if work remains — confirm against the child.
pub max_ticks: u64,
}
/// Payload of `Request::AddCycles`, built by `StateMachine::add_cycles`.
#[derive(Serialize, Deserialize)]
pub struct AddCyclesArg {
/// Raw principal bytes of the target canister (`Principal::as_slice`).
pub canister_id: Vec<u8>,
/// Number of cycles to add.
pub amount: u128,
}
/// Payload of `Request::SetStableMemory`, built by `StateMachine::set_stable_memory`.
#[derive(Serialize, Deserialize)]
pub struct SetStableMemoryArg {
/// Raw principal bytes of the target canister (`Principal::as_slice`).
pub canister_id: Vec<u8>,
/// New stable-memory contents. `ByteBuf` makes serde encode it as a byte
/// string instead of a sequence of integers.
pub data: ByteBuf,
}
/// A canister id carried as raw principal bytes, as produced by
/// `Principal::as_slice` (see the `From<Principal>` impl).
#[derive(Serialize, Deserialize)]
pub struct RawCanisterId {
/// The principal's raw bytes.
pub canister_id: Vec<u8>,
}
impl From<Principal> for RawCanisterId {
fn from(principal: Principal) -> Self {
Self {
canister_id: principal.as_slice().to_vec(),
}
}
}
/// Payload of `Request::CanisterUpdateCall` / `Request::CanisterQueryCall`,
/// built by `StateMachine::update_call` and `StateMachine::query_call`.
#[derive(Serialize, Deserialize)]
pub struct CanisterCall {
/// Raw principal bytes of the caller (`Principal::as_slice`).
pub sender: Vec<u8>,
/// Raw principal bytes of the callee (`Principal::as_slice`).
pub canister_id: Vec<u8>,
/// Name of the canister method to invoke.
pub method: String,
/// Already-encoded argument bytes (Candid when sent through the `*_candid*` helpers).
pub arg: Vec<u8>,
}
/// Handle to a state machine child process.
///
/// Requests are written to the child's stdin and responses read from its
/// stdout, each as CBOR behind an 8-byte little-endian length prefix.
pub struct StateMachine {
/// The spawned child process; the `Drop` impl kills it.
proc: Child,
/// Request pipe. `RefCell` lets `&self` methods write to it.
child_in: RefCell<ChildStdin>,
/// Response pipe. `RefCell` lets `&self` methods read from it.
child_out: RefCell<ChildStdout>,
}
impl StateMachine {
    /// Spawns the state machine binary at `binary_path` and wires up its stdin/stdout pipes.
    ///
    /// The child runs with `LOG_TO_STDERR=1` and receives `--debug` when `debug` is true.
    ///
    /// # Panics
    ///
    /// Panics if the process cannot be started.
    pub fn new(binary_path: &str, debug: bool) -> Self {
        let mut command = Command::new(binary_path);
        command
            .env("LOG_TO_STDERR", "1")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped());
        if debug {
            command.arg("--debug");
        }
        let mut child = command.spawn().unwrap_or_else(|err| {
            panic!(
                "failed to start test state machine at path {}: {:?}",
                binary_path, err
            )
        });
        // Both handles exist because stdin and stdout were configured as piped above.
        let child_in = child.stdin.take().unwrap();
        let child_out = child.stdout.take().unwrap();
        Self {
            proc: child,
            child_in: RefCell::new(child_in),
            child_out: RefCell::new(child_out),
        }
    }

    /// Executes an update call of `method` on `canister_id` as `sender`.
    ///
    /// `arg` is passed through as already-encoded bytes.
    pub fn update_call(
        &self,
        canister_id: Principal,
        sender: Principal,
        method: &str,
        arg: Vec<u8>,
    ) -> Result<WasmResult, UserError> {
        self.call_state_machine(Request::CanisterUpdateCall(CanisterCall {
            sender: sender.as_slice().to_vec(),
            canister_id: canister_id.as_slice().to_vec(),
            method: method.to_string(),
            arg,
        }))
    }

    /// Executes a query call of `method` on `canister_id` as `sender`.
    ///
    /// `arg` is passed through as already-encoded bytes.
    pub fn query_call(
        &self,
        canister_id: Principal,
        sender: Principal,
        method: &str,
        arg: Vec<u8>,
    ) -> Result<WasmResult, UserError> {
        self.call_state_machine(Request::CanisterQueryCall(CanisterCall {
            sender: sender.as_slice().to_vec(),
            canister_id: canister_id.as_slice().to_vec(),
            method: method.to_string(),
            arg,
        }))
    }

    /// Returns the root key bytes reported by the state machine.
    pub fn root_key(&self) -> Vec<u8> {
        self.call_state_machine(Request::RootKey)
    }

    /// Creates a canister with default settings through the management canister.
    ///
    /// `sender` defaults to the anonymous principal.
    ///
    /// # Panics
    ///
    /// Panics if the `create_canister` call fails.
    pub fn create_canister(&self, sender: Option<Principal>) -> CanisterId {
        self.create_canister_with_settings(None, sender)
    }

    /// Creates a canister with the given `settings` through the management canister.
    ///
    /// `sender` defaults to the anonymous principal.
    ///
    /// # Panics
    ///
    /// Panics if the `create_canister` call fails.
    pub fn create_canister_with_settings(
        &self,
        settings: Option<CanisterSettings>,
        sender: Option<Principal>,
    ) -> CanisterId {
        let CanisterIdRecord { canister_id } = call_candid_as(
            self,
            Principal::management_canister(),
            sender.unwrap_or(Principal::anonymous()),
            "create_canister",
            (CreateCanisterArgument { settings },),
        )
        .map(|(record,)| record)
        .unwrap();
        canister_id
    }

    /// Installs `wasm_module` on `canister_id` (`install_code` in `Install` mode).
    ///
    /// # Panics
    ///
    /// Panics if the `install_code` call fails.
    pub fn install_canister(
        &self,
        canister_id: CanisterId,
        wasm_module: Vec<u8>,
        arg: Vec<u8>,
        sender: Option<Principal>,
    ) {
        self.install_code_with_mode(
            CanisterInstallMode::Install,
            canister_id,
            wasm_module,
            arg,
            sender,
        )
        .unwrap();
    }

    /// Upgrades `canister_id` to `wasm_module` (`install_code` in `Upgrade` mode).
    pub fn upgrade_canister(
        &self,
        canister_id: CanisterId,
        wasm_module: Vec<u8>,
        arg: Vec<u8>,
        sender: Option<Principal>,
    ) -> Result<(), CallError> {
        self.install_code_with_mode(
            CanisterInstallMode::Upgrade,
            canister_id,
            wasm_module,
            arg,
            sender,
        )
    }

    /// Reinstalls `wasm_module` on `canister_id` (`install_code` in `Reinstall` mode).
    pub fn reinstall_canister(
        &self,
        canister_id: CanisterId,
        wasm_module: Vec<u8>,
        arg: Vec<u8>,
        sender: Option<Principal>,
    ) -> Result<(), CallError> {
        self.install_code_with_mode(
            CanisterInstallMode::Reinstall,
            canister_id,
            wasm_module,
            arg,
            sender,
        )
    }

    /// Calls the management canister's `start_canister` for `canister_id`.
    pub fn start_canister(
        &self,
        canister_id: CanisterId,
        sender: Option<Principal>,
    ) -> Result<(), CallError> {
        self.call_with_canister_id_record("start_canister", canister_id, sender)
    }

    /// Calls the management canister's `stop_canister` for `canister_id`.
    pub fn stop_canister(
        &self,
        canister_id: CanisterId,
        sender: Option<Principal>,
    ) -> Result<(), CallError> {
        self.call_with_canister_id_record("stop_canister", canister_id, sender)
    }

    /// Calls the management canister's `delete_canister` for `canister_id`.
    pub fn delete_canister(
        &self,
        canister_id: CanisterId,
        sender: Option<Principal>,
    ) -> Result<(), CallError> {
        self.call_with_canister_id_record("delete_canister", canister_id, sender)
    }

    /// Returns whether the state machine reports that `canister_id` exists.
    pub fn canister_exists(&self, canister_id: Principal) -> bool {
        self.call_state_machine(Request::CanisterExists(RawCanisterId::from(canister_id)))
    }

    /// Returns the state machine's current time.
    pub fn time(&self) -> SystemTime {
        self.call_state_machine(Request::Time)
    }

    /// Sends `Request::SetTime` with `time` and waits for the reply.
    pub fn set_time(&self, time: SystemTime) {
        self.call_state_machine(Request::SetTime(time))
    }

    /// Sends `Request::AdvanceTime` with `duration` and waits for the reply.
    pub fn advance_time(&self, duration: Duration) {
        self.call_state_machine(Request::AdvanceTime(duration))
    }

    /// Sends `Request::Tick` and waits for the reply.
    pub fn tick(&self) {
        self.call_state_machine(Request::Tick)
    }

    /// Sends `Request::RunUntilCompletion` bounded by `max_ticks` and waits for the reply.
    pub fn run_until_completion(&self, max_ticks: u64) {
        self.call_state_machine(Request::RunUntilCompletion(RunUntilCompletionArg {
            max_ticks,
        }))
    }

    /// Returns the stable-memory bytes of `canister_id`.
    pub fn stable_memory(&self, canister_id: Principal) -> Vec<u8> {
        self.call_state_machine(Request::ReadStableMemory(RawCanisterId::from(canister_id)))
    }

    /// Sends `data` as the new stable memory of `canister_id` and waits for the reply.
    pub fn set_stable_memory(&self, canister_id: Principal, data: ByteBuf) {
        self.call_state_machine(Request::SetStableMemory(SetStableMemoryArg {
            canister_id: canister_id.as_slice().to_vec(),
            data,
        }))
    }

    /// Returns the cycle balance of `canister_id`.
    pub fn cycle_balance(&self, canister_id: Principal) -> u128 {
        self.call_state_machine(Request::CyclesBalance(RawCanisterId::from(canister_id)))
    }

    /// Adds `amount` cycles to `canister_id`.
    ///
    /// Returns the `u128` the state machine replies with (presumably the new balance — confirm).
    pub fn add_cycles(&self, canister_id: Principal, amount: u128) -> u128 {
        self.call_state_machine(Request::AddCycles(AddCyclesArg {
            canister_id: canister_id.as_slice().to_vec(),
            amount,
        }))
    }

    /// Asks the state machine to verify a canister signature.
    ///
    /// The reply is decoded as `Result<(), String>` and returned unchanged.
    pub fn verify_canister_signature(
        &self,
        msg: Vec<u8>,
        sig: Vec<u8>,
        pubkey: Vec<u8>,
        root_pubkey: Vec<u8>,
    ) -> Result<(), String> {
        self.call_state_machine(Request::VerifyCanisterSig(VerifyCanisterSigArg {
            msg,
            sig,
            pubkey,
            root_pubkey,
        }))
    }

    /// Sends `install_code` with the given `mode` to the management canister.
    ///
    /// `sender` defaults to the anonymous principal.
    fn install_code_with_mode(
        &self,
        mode: CanisterInstallMode,
        canister_id: CanisterId,
        wasm_module: Vec<u8>,
        arg: Vec<u8>,
        sender: Option<Principal>,
    ) -> Result<(), CallError> {
        call_candid_as::<(InstallCodeArgument,), ()>(
            self,
            Principal::management_canister(),
            sender.unwrap_or(Principal::anonymous()),
            "install_code",
            (InstallCodeArgument {
                mode,
                canister_id,
                wasm_module,
                arg,
            },),
        )
    }

    /// Calls a management-canister `method` whose single argument is a `CanisterIdRecord`.
    ///
    /// `sender` defaults to the anonymous principal.
    fn call_with_canister_id_record(
        &self,
        method: &str,
        canister_id: CanisterId,
        sender: Option<Principal>,
    ) -> Result<(), CallError> {
        call_candid_as::<(CanisterIdRecord,), ()>(
            self,
            Principal::management_canister(),
            sender.unwrap_or(Principal::anonymous()),
            method,
            (CanisterIdRecord { canister_id },),
        )
    }

    /// Sends `request` and blocks until the next response is read and decoded as `T`.
    ///
    /// # Panics
    ///
    /// Panics on any serialization, pipe I/O, or deserialization failure.
    fn call_state_machine<T: DeserializeOwned>(&self, request: Request) -> T {
        self.send_request(request);
        self.read_response()
    }

    /// Writes `request` as CBOR preceded by its length (`u64`, little-endian), then flushes.
    fn send_request(&self, request: Request) {
        let mut cbor = vec![];
        ciborium::ser::into_writer(&request, &mut cbor).expect("failed to serialize request");
        let mut child_in = self.child_in.borrow_mut();
        child_in
            .write_all(&(cbor.len() as u64).to_le_bytes())
            .expect("failed to send request length");
        child_in
            .write_all(cbor.as_slice())
            .expect("failed to send request data");
        child_in.flush().expect("failed to flush child stdin");
    }

    /// Reads one length-prefixed CBOR response from the child and decodes it as `T`.
    fn read_response<T: DeserializeOwned>(&self) -> T {
        let header = self.read_bytes(8);
        // The prefix is an 8-byte little-endian u64 (mirroring `send_request`). Decode it as
        // u64 rather than usize so this also works where usize is narrower than 64 bits.
        let size = u64::from_le_bytes(TryFrom::try_from(header).expect("failed to read data size"));
        let size = usize::try_from(size).expect("response size does not fit in usize");
        from_reader(&self.read_bytes(size)[..]).expect("failed to deserialize response")
    }

    /// Reads exactly `num_bytes` from the child's stdout, panicking on EOF or I/O error.
    fn read_bytes(&self, num_bytes: usize) -> Vec<u8> {
        let mut buf = vec![0u8; num_bytes];
        self.child_out
            .borrow_mut()
            .read_exact(&mut buf)
            .expect("failed to read from child_stdout");
        buf
    }
}
impl Drop for StateMachine {
    /// Kills the child process and reaps it.
    ///
    /// # Panics
    ///
    /// Panics if the kill fails, unless the thread is already unwinding. A second
    /// panic during unwinding would abort the whole test binary and hide the
    /// original failure.
    fn drop(&mut self) {
        match self.proc.kill() {
            Ok(()) => {
                // Reap the child so it does not linger as a zombie. Its exit status is
                // irrelevant here. Only wait after a successful kill so this cannot block.
                let _ = self.proc.wait();
            }
            Err(err) if !std::thread::panicking() => {
                panic!("failed to kill state machine process: {:?}", err)
            }
            Err(_) => {}
        }
    }
}
/// Performs a Candid-encoded query call as the anonymous principal.
///
/// Shorthand for `query_candid_as` with `Principal::anonymous()` as the sender.
pub fn query_candid<Input, Output>(
    env: &StateMachine,
    canister_id: Principal,
    method: &str,
    input: Input,
) -> Result<Output, CallError>
where
    Input: ArgumentEncoder,
    Output: for<'a> ArgumentDecoder<'a>,
{
    let anonymous = Principal::anonymous();
    query_candid_as(env, canister_id, anonymous, method, input)
}
/// Performs a Candid-encoded query call of `method` on `canister_id` as `sender`.
///
/// `input` is Candid-encoded, sent via `StateMachine::query_call`, and the reply
/// is decoded as `Output` (see `with_candid`).
pub fn query_candid_as<Input, Output>(
    env: &StateMachine,
    canister_id: Principal,
    sender: Principal,
    method: &str,
    input: Input,
) -> Result<Output, CallError>
where
    Input: ArgumentEncoder,
    Output: for<'a> ArgumentDecoder<'a>,
{
    let perform_query =
        |payload: Vec<u8>| env.query_call(canister_id, sender, method, payload);
    with_candid(input, perform_query)
}
/// Performs a Candid-encoded update call of `method` on `canister_id` as `sender`.
///
/// `input` is Candid-encoded, sent via `StateMachine::update_call`, and the reply
/// is decoded as `Output` (see `with_candid`).
pub fn call_candid_as<Input, Output>(
    env: &StateMachine,
    canister_id: Principal,
    sender: Principal,
    method: &str,
    input: Input,
) -> Result<Output, CallError>
where
    Input: ArgumentEncoder,
    Output: for<'a> ArgumentDecoder<'a>,
{
    let perform_update =
        |payload: Vec<u8>| env.update_call(canister_id, sender, method, payload);
    with_candid(input, perform_update)
}
/// Performs a Candid-encoded update call as the anonymous principal.
///
/// Shorthand for `call_candid_as` with `Principal::anonymous()` as the sender.
pub fn call_candid<Input, Output>(
    env: &StateMachine,
    canister_id: Principal,
    method: &str,
    input: Input,
) -> Result<Output, CallError>
where
    Input: ArgumentEncoder,
    Output: for<'a> ArgumentDecoder<'a>,
{
    let anonymous = Principal::anonymous();
    call_candid_as(env, canister_id, anonymous, method, input)
}
/// Error codes carried in `UserError`.
///
/// Each discriminant is the numeric code. The `Display` impl renders it as `IC`
/// followed by the code zero-padded to four digits (e.g. `CanisterNotFound`
/// displays as `IC0301`), so changing a discriminant changes user-visible output.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ErrorCode {
SubnetOversubscribed = 101,
MaxNumberOfCanistersReached = 102,
CanisterOutputQueueFull = 201,
IngressMessageTimeout = 202,
CanisterQueueNotEmpty = 203,
CanisterNotFound = 301,
CanisterMethodNotFound = 302,
CanisterAlreadyInstalled = 303,
CanisterWasmModuleNotFound = 304,
InsufficientMemoryAllocation = 402,
InsufficientCyclesForCreateCanister = 403,
SubnetNotFound = 404,
CanisterNotHostedBySubnet = 405,
CanisterOutOfCycles = 501,
CanisterTrapped = 502,
CanisterCalledTrap = 503,
CanisterContractViolation = 504,
CanisterInvalidWasm = 505,
CanisterDidNotReply = 506,
CanisterOutOfMemory = 507,
CanisterStopped = 508,
CanisterStopping = 509,
CanisterNotStopped = 510,
CanisterStoppingCancelled = 511,
CanisterInvalidController = 512,
CanisterFunctionNotFound = 513,
CanisterNonEmpty = 514,
CertifiedStateUnavailable = 515,
CanisterRejectedMessage = 516,
QueryCallGraphLoopDetected = 517,
UnknownManagementMessage = 518,
InvalidManagementPayload = 519,
InsufficientCyclesInCall = 520,
CanisterWasmEngineError = 521,
CanisterInstructionLimitExceeded = 522,
CanisterInstallCodeRateLimited = 523,
CanisterMemoryAccessLimitExceeded = 524,
QueryCallGraphTooDeep = 525,
QueryCallGraphTotalInstructionLimitExceeded = 526,
CompositeQueryCalledInReplicatedMode = 527,
}
impl fmt::Display for ErrorCode {
    /// Renders the code as `IC` followed by its numeric value, zero-padded to four digits.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let numeric = *self as i32;
        f.write_fmt(format_args!("IC{:04}", numeric))
    }
}
/// Error reported by the state machine for a call, as returned in the `Err`
/// case of `StateMachine::update_call` / `StateMachine::query_call`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct UserError {
/// Numeric error code.
pub code: ErrorCode,
/// Human-readable description supplied by the state machine.
pub description: String,
}
impl fmt::Display for UserError {
    /// Formats as `<code>: <description>`, using the code's own `Display` form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (code, description) = (&self.code, &self.description);
        write!(f, "{}: {}", code, description)
    }
}
#[derive(Debug)]
pub enum CallError {
Reject(String),
UserError(UserError),
}
/// Result of a canister call, as returned in the `Ok` case of
/// `StateMachine::update_call` / `StateMachine::query_call`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum WasmResult {
/// Reply payload bytes (decoded as Candid by `with_candid`). `serde_bytes`
/// makes serde encode them as a byte string instead of a sequence of integers.
Reply(#[serde(with = "serde_bytes")] Vec<u8>),
/// Reject message; `with_candid` maps it to `CallError::Reject`.
Reject(String),
}
/// Candid-encodes `input`, hands the bytes to `f`, and decodes a successful reply as `Output`.
///
/// A `WasmResult::Reject` becomes `CallError::Reject`, and an `Err(UserError)`
/// from `f` becomes `CallError::UserError`.
///
/// # Panics
///
/// Panics if `input` cannot be encoded, or if the reply bytes cannot be decoded
/// as `Output`. The decode panic message includes the raw bytes and their
/// lossy UTF-8 rendering.
pub fn with_candid<Input, Output>(
    input: Input,
    f: impl FnOnce(Vec<u8>) -> Result<WasmResult, UserError>,
) -> Result<Output, CallError>
where
    Input: ArgumentEncoder,
    Output: for<'a> ArgumentDecoder<'a>,
{
    let encoded = encode_args(input).expect("failed to encode args");
    let reply = match f(encoded) {
        Err(user_error) => return Err(CallError::UserError(user_error)),
        Ok(WasmResult::Reject(message)) => return Err(CallError::Reject(message)),
        Ok(WasmResult::Reply(reply)) => reply,
    };
    let decoded: Output = decode_args(&reply).unwrap_or_else(|err| {
        panic!(
            "Failed to decode response as candid type {}:\nerror: {}\nbytes: {:?}\nutf8: {}",
            std::any::type_name::<Output>(),
            err,
            reply,
            String::from_utf8_lossy(&reply),
        )
    });
    Ok(decoded)
}