multiversx_sc/types/interaction/callback_closure.rs

1use unwrap_infallible::UnwrapInfallible;
2
3use crate::{
4    api::{BlockchainApi, ErrorApi, ManagedTypeApi, StorageReadApi, StorageWriteApi},
5    codec::{
6        self,
7        derive::{TopDecode, TopEncode},
8        TopEncodeMulti,
9    },
10    contract_base::{BlockchainWrapper, ExitCodecErrorHandler, ManagedSerializer},
11    err_msg,
12    storage::StorageKey,
13    storage_clear, storage_get, storage_set,
14    types::{ManagedBuffer, ManagedType, ManagedVecRefIterator},
15};
16
17use super::ManagedArgBuffer;
18
/// Prefix of the storage key under which a callback closure is saved with the
/// old async call mechanism. The full key is this prefix followed by the
/// current transaction hash (see `cb_closure_storage_key`).
pub const CALLBACK_CLOSURE_STORAGE_BASE_KEY: &[u8] = b"CB_CLOSURE";
20
21/// Object that encodes full async callback data.
22///
23/// Should not be created manually, we have auto-generated call proxies
24/// that will create this object in a type-safe manner.
25///
26/// How it functions:
27/// - With the old async call mechanism, this data is serialized to storage.
28/// - With the new promises framework, the VM handles this data.
29///
30/// In both cases the framework hides all the magic, the developer shouldn't worry about it.
31#[derive(TopEncode)]
32pub struct CallbackClosure<M>
33where
34    M: ManagedTypeApi + ErrorApi,
35{
36    pub(super) callback_name: &'static str,
37    pub(super) closure_args: ManagedArgBuffer<M>,
38}
39
40pub struct CallbackClosureWithGas<M>
41where
42    M: ManagedTypeApi + ErrorApi,
43{
44    pub(super) closure: CallbackClosure<M>,
45    pub(super) gas_for_callback: u64,
46}
47
48/// Syntactical sugar to help macros to generate code easier.
49/// Unlike calling `CallbackClosure::<SA, R>::new`, here types can be inferred from the context.
50pub fn new_callback_call<A>(callback_name: &'static str) -> CallbackClosure<A>
51where
52    A: ManagedTypeApi + ErrorApi,
53{
54    CallbackClosure::new(callback_name)
55}
56
57impl<M: ManagedTypeApi + ErrorApi> CallbackClosure<M> {
58    pub fn new(callback_name: &'static str) -> Self {
59        CallbackClosure {
60            callback_name,
61            closure_args: ManagedArgBuffer::new(),
62        }
63    }
64
65    pub fn push_endpoint_arg<T: TopEncodeMulti>(&mut self, endpoint_arg: &T) {
66        let h = ExitCodecErrorHandler::<M>::from(err_msg::CONTRACT_CALL_ENCODE_ERROR);
67        endpoint_arg
68            .multi_encode_or_handle_err(&mut self.closure_args, h)
69            .unwrap_infallible()
70    }
71
72    pub fn save_to_storage<A: BlockchainApi + StorageWriteApi>(&self) {
73        let storage_key = cb_closure_storage_key::<A>();
74        storage_set(storage_key.as_ref(), self);
75    }
76}
77
78pub(super) fn cb_closure_storage_key<A: BlockchainApi>() -> StorageKey<A> {
79    let tx_hash = BlockchainWrapper::<A>::new().get_tx_hash();
80    let mut storage_key = StorageKey::new(CALLBACK_CLOSURE_STORAGE_BASE_KEY);
81    storage_key.append_managed_buffer(tx_hash.as_managed_buffer());
82    storage_key
83}
84
85/// Similar object to `CallbackClosure`, but only used for deserializing from storage
86/// the callback data with the old async call mechanism.
87///
88/// Should not be visible to the developer.
89///
90/// It is a separate type from `CallbackClosure`, because we want a different representation of the endpoint name.
91#[derive(TopDecode)]
92pub struct CallbackClosureForDeser<M: ManagedTypeApi + ErrorApi> {
93    callback_name: ManagedBuffer<M>,
94    closure_args: ManagedArgBuffer<M>,
95}
96
97impl<M: ManagedTypeApi + ErrorApi> CallbackClosureForDeser<M> {
98    /// Used by callback_raw.
99    /// TODO: avoid creating any new managed buffers.
100    pub fn no_callback() -> Self {
101        CallbackClosureForDeser {
102            callback_name: ManagedBuffer::new(),
103            closure_args: ManagedArgBuffer::new(),
104        }
105    }
106
107    pub fn storage_load_and_clear<A: BlockchainApi + StorageReadApi + StorageWriteApi>(
108    ) -> Option<Self> {
109        let storage_key = cb_closure_storage_key::<A>();
110        let storage_value_raw: ManagedBuffer<A> = storage_get(storage_key.as_ref());
111        if !storage_value_raw.is_empty() {
112            let serializer = ManagedSerializer::<A>::new();
113            let closure = serializer.top_decode_from_managed_buffer(&storage_value_raw);
114            storage_clear(storage_key.as_ref());
115            Some(closure)
116        } else {
117            None
118        }
119    }
120
121    pub fn matcher<const CB_NAME_MAX_LENGTH: usize>(
122        &self,
123    ) -> CallbackClosureMatcher<CB_NAME_MAX_LENGTH> {
124        CallbackClosureMatcher::new(&self.callback_name)
125    }
126
127    pub fn arg_iter(&self) -> ManagedVecRefIterator<'_, M, ManagedBuffer<M>> {
128        self.closure_args.iter_buffers()
129    }
130}
131
132/// Helps the callback macro expansion to perform callback name matching more efficiently.
133/// The current implementation hashes by callback name length,
134/// but in principle further optimizations are possible.
135pub struct CallbackClosureMatcher<const CB_NAME_MAX_LENGTH: usize> {
136    name_len: usize,
137    compare_buffer: [u8; CB_NAME_MAX_LENGTH],
138}
139
140impl<const CB_NAME_MAX_LENGTH: usize> CallbackClosureMatcher<CB_NAME_MAX_LENGTH> {
141    pub fn new<M: ManagedTypeApi + ErrorApi>(callback_name: &ManagedBuffer<M>) -> Self {
142        let mut compare_buffer = [0u8; CB_NAME_MAX_LENGTH];
143        let name_len = callback_name.len();
144        callback_name.load_slice(0, &mut compare_buffer[..name_len]);
145        CallbackClosureMatcher {
146            name_len,
147            compare_buffer,
148        }
149    }
150
151    pub fn new_from_unmanaged(callback_name: &[u8]) -> Self {
152        let mut compare_buffer = [0u8; CB_NAME_MAX_LENGTH];
153        let name_len = callback_name.len();
154        compare_buffer[..name_len].copy_from_slice(callback_name);
155        CallbackClosureMatcher {
156            name_len,
157            compare_buffer,
158        }
159    }
160
161    pub fn matches_empty(&self) -> bool {
162        self.name_len == 0
163    }
164
165    pub fn name_matches(&self, name_match: &[u8]) -> bool {
166        if self.name_len != name_match.len() {
167            false
168        } else {
169            &self.compare_buffer[..self.name_len] == name_match
170        }
171    }
172}