1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
use anyhow::Context;
use async_trait::async_trait;
use candid::utils::{decode_args, encode_args, ArgumentDecoder, ArgumentEncoder};
use candid::Principal;
use ic_test_state_machine_client::StateMachine;
use icrc1_test_env::LedgerEnv;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

/// Builds a deterministic principal from a counter value.
///
/// The principal's bytes are the 8-byte little-endian encoding of `n`
/// followed by the two-byte suffix `0xfe, 0x01`, for 10 bytes in total.
///
/// # Panics
///
/// Panics if `Principal::try_from_slice` rejects the 10-byte slice.
fn new_principal(n: u64) -> Principal {
    let raw: Vec<u8> = n
        .to_le_bytes()
        .iter()
        .copied()
        .chain([0xfe, 0x01])
        .collect();
    Principal::try_from_slice(&raw).unwrap()
}

/// A ledger test environment that issues calls to a canister hosted in a
/// [`StateMachine`].
///
/// Cloning is cheap: the state machine and the principal counter are shared
/// through `Arc`s.
#[derive(Clone)]
pub struct SMLedger {
    /// State machine that hosts the ledger canister.
    sm: Arc<StateMachine>,
    /// Canister that every call is sent to.
    canister_id: Principal,
    /// Principal that calls are issued as.
    sender: Principal,
    /// Shared source of values used to derive new sender principals.
    counter: Arc<AtomicU64>,
}

#[async_trait(?Send)]
impl LedgerEnv for SMLedger {
    fn fork(&self) -> Self {
        Self {
            counter: self.counter.clone(),
            sm: self.sm.clone(),
            sender: new_principal(self.counter.fetch_add(1, Ordering::Relaxed)),
            canister_id: self.canister_id,
        }
    }

    fn principal(&self) -> Principal {
        self.sender
    }

    fn time(&self) -> std::time::SystemTime {
        self.sm.time()
    }

    async fn query<Input, Output>(&self, method: &str, input: Input) -> anyhow::Result<Output>
    where
        Input: ArgumentEncoder + std::fmt::Debug,
        Output: for<'a> ArgumentDecoder<'a>,
    {
        let debug_inputs = format!("{:?}", input);
        let in_bytes = encode_args(input)
            .with_context(|| format!("Failed to encode arguments {}", debug_inputs))?;
        match self
            .sm
            .query_call(
                Principal::from_slice(self.canister_id.as_slice()),
                Principal::from_slice(self.sender.as_slice()),
                method,
                in_bytes,
            )
            .map_err(|err| anyhow::Error::msg(err.to_string()))?
        {
            ic_test_state_machine_client::WasmResult::Reply(bytes) => decode_args(&bytes)
                .with_context(|| {
                    format!(
                        "Failed to decode method {} response into type {}, bytes: {}",
                        method,
                        std::any::type_name::<Output>(),
                        hex::encode(bytes)
                    )
                }),
            ic_test_state_machine_client::WasmResult::Reject(msg) => {
                return Err(anyhow::Error::msg(format!(
                    "Query call to ledger {:?} was rejected: {}",
                    self.canister_id, msg
                )))
            }
        }
    }

    async fn update<Input, Output>(&self, method: &str, input: Input) -> anyhow::Result<Output>
    where
        Input: ArgumentEncoder + std::fmt::Debug,
        Output: for<'a> ArgumentDecoder<'a>,
    {
        let debug_inputs = format!("{:?}", input);
        let in_bytes = encode_args(input)
            .with_context(|| format!("Failed to encode arguments {}", debug_inputs))?;
        match self
            .sm
            .update_call(self.canister_id, self.sender, method, in_bytes)
            .map_err(|err| anyhow::Error::msg(err.to_string()))
            .with_context(|| {
                format!(
                    "failed to execute update call {} on canister {}",
                    method, self.canister_id
                )
            })? {
            ic_test_state_machine_client::WasmResult::Reply(bytes) => decode_args(&bytes)
                .with_context(|| {
                    format!(
                        "Failed to decode method {} response into type {}, bytes: {}",
                        method,
                        std::any::type_name::<Output>(),
                        hex::encode(bytes)
                    )
                }),
            ic_test_state_machine_client::WasmResult::Reject(msg) => {
                return Err(anyhow::Error::msg(format!(
                    "Query call to ledger {:?} was rejected: {}",
                    self.canister_id, msg
                )))
            }
        }
    }
}

impl SMLedger {
    /// Creates an environment that calls `canister_id` on `sm` as `sender`.
    ///
    /// The principal counter used by `fork` starts at zero.
    pub fn new(sm: Arc<StateMachine>, canister_id: Principal, sender: Principal) -> Self {
        let counter = Arc::new(AtomicU64::new(0));
        Self {
            sm,
            canister_id,
            sender,
            counter,
        }
    }
}