apollo_cw_multi_test/
contracts.rs

1use schemars::JsonSchema;
2use serde::de::DeserializeOwned;
3use std::error::Error;
4use std::fmt::{self, Debug, Display};
5use std::ops::Deref;
6
7use cosmwasm_std::{
8    from_slice, Binary, CosmosMsg, CustomQuery, Deps, DepsMut, Empty, Env, MessageInfo,
9    QuerierWrapper, Reply, Response, SubMsg,
10};
11
12use anyhow::{anyhow, bail, Result as AnyResult};
13
14/// Interface to call into a Contract
15pub trait Contract<T, Q = Empty>
16where
17    T: Clone + fmt::Debug + PartialEq + JsonSchema,
18    Q: CustomQuery,
19{
20    fn execute(
21        &self,
22        deps: DepsMut<Q>,
23        env: Env,
24        info: MessageInfo,
25        msg: Vec<u8>,
26    ) -> AnyResult<Response<T>>;
27
28    fn instantiate(
29        &self,
30        deps: DepsMut<Q>,
31        env: Env,
32        info: MessageInfo,
33        msg: Vec<u8>,
34    ) -> AnyResult<Response<T>>;
35
36    fn query(&self, deps: Deps<Q>, env: Env, msg: Vec<u8>) -> AnyResult<Binary>;
37
38    fn sudo(&self, deps: DepsMut<Q>, env: Env, msg: Vec<u8>) -> AnyResult<Response<T>>;
39
40    fn reply(&self, deps: DepsMut<Q>, env: Env, msg: Reply) -> AnyResult<Response<T>>;
41
42    fn migrate(&self, deps: DepsMut<Q>, env: Env, msg: Vec<u8>) -> AnyResult<Response<T>>;
43}
44
45type ContractFn<T, C, E, Q> =
46    fn(deps: DepsMut<Q>, env: Env, info: MessageInfo, msg: T) -> Result<Response<C>, E>;
47type PermissionedFn<T, C, E, Q> = fn(deps: DepsMut<Q>, env: Env, msg: T) -> Result<Response<C>, E>;
48type ReplyFn<C, E, Q> = fn(deps: DepsMut<Q>, env: Env, msg: Reply) -> Result<Response<C>, E>;
49type QueryFn<T, E, Q> = fn(deps: Deps<Q>, env: Env, msg: T) -> Result<Binary, E>;
50
51type ContractClosure<T, C, E, Q> =
52    Box<dyn Fn(DepsMut<Q>, Env, MessageInfo, T) -> Result<Response<C>, E>>;
53type PermissionedClosure<T, C, E, Q> = Box<dyn Fn(DepsMut<Q>, Env, T) -> Result<Response<C>, E>>;
54type ReplyClosure<C, E, Q> = Box<dyn Fn(DepsMut<Q>, Env, Reply) -> Result<Response<C>, E>>;
55type QueryClosure<T, E, Q> = Box<dyn Fn(Deps<Q>, Env, T) -> Result<Binary, E>>;
56
57/// Wraps the exported functions from a contract and provides the normalized format
58/// Place T4 and E4 at the end, as we just want default placeholders for most contracts that don't have sudo
59pub struct ContractWrapper<
60    T1,
61    T2,
62    T3,
63    E1,
64    E2,
65    E3,
66    C = Empty,
67    Q = Empty,
68    T4 = Empty,
69    E4 = anyhow::Error,
70    E5 = anyhow::Error,
71    T6 = Empty,
72    E6 = anyhow::Error,
73> where
74    T1: DeserializeOwned + Debug,
75    T2: DeserializeOwned,
76    T3: DeserializeOwned,
77    T4: DeserializeOwned,
78    T6: DeserializeOwned,
79    E1: Display + Debug + Send + Sync + 'static,
80    E2: Display + Debug + Send + Sync + 'static,
81    E3: Display + Debug + Send + Sync + 'static,
82    E4: Display + Debug + Send + Sync + 'static,
83    E5: Display + Debug + Send + Sync + 'static,
84    E6: Display + Debug + Send + Sync + 'static,
85    C: Clone + fmt::Debug + PartialEq + JsonSchema,
86    Q: CustomQuery + DeserializeOwned + 'static,
87{
88    execute_fn: ContractClosure<T1, C, E1, Q>,
89    instantiate_fn: ContractClosure<T2, C, E2, Q>,
90    query_fn: QueryClosure<T3, E3, Q>,
91    sudo_fn: Option<PermissionedClosure<T4, C, E4, Q>>,
92    reply_fn: Option<ReplyClosure<C, E5, Q>>,
93    migrate_fn: Option<PermissionedClosure<T6, C, E6, Q>>,
94}
95
96impl<T1, T2, T3, E1, E2, E3, C, Q> ContractWrapper<T1, T2, T3, E1, E2, E3, C, Q>
97where
98    T1: DeserializeOwned + Debug + 'static,
99    T2: DeserializeOwned + 'static,
100    T3: DeserializeOwned + 'static,
101    E1: Display + Debug + Send + Sync + 'static,
102    E2: Display + Debug + Send + Sync + 'static,
103    E3: Display + Debug + Send + Sync + 'static,
104    C: Clone + fmt::Debug + PartialEq + JsonSchema + 'static,
105    Q: CustomQuery + DeserializeOwned + 'static,
106{
107    pub fn new(
108        execute_fn: ContractFn<T1, C, E1, Q>,
109        instantiate_fn: ContractFn<T2, C, E2, Q>,
110        query_fn: QueryFn<T3, E3, Q>,
111    ) -> Self {
112        Self {
113            execute_fn: Box::new(execute_fn),
114            instantiate_fn: Box::new(instantiate_fn),
115            query_fn: Box::new(query_fn),
116            sudo_fn: None,
117            reply_fn: None,
118            migrate_fn: None,
119        }
120    }
121
122    /// this will take a contract that returns Response<Empty> and will "upgrade" it
123    /// to Response<C> if needed to be compatible with a chain-specific extension
124    pub fn new_with_empty(
125        execute_fn: ContractFn<T1, Empty, E1, Empty>,
126        instantiate_fn: ContractFn<T2, Empty, E2, Empty>,
127        query_fn: QueryFn<T3, E3, Empty>,
128    ) -> Self {
129        Self {
130            execute_fn: customize_fn(execute_fn),
131            instantiate_fn: customize_fn(instantiate_fn),
132            query_fn: customize_query(query_fn),
133            sudo_fn: None,
134            reply_fn: None,
135            migrate_fn: None,
136        }
137    }
138}
139
140impl<T1, T2, T3, E1, E2, E3, C, Q, T4, E4, E5, T6, E6>
141    ContractWrapper<T1, T2, T3, E1, E2, E3, C, Q, T4, E4, E5, T6, E6>
142where
143    T1: DeserializeOwned + Debug + 'static,
144    T2: DeserializeOwned + 'static,
145    T3: DeserializeOwned + 'static,
146    T4: DeserializeOwned + 'static,
147    T6: DeserializeOwned + 'static,
148    E1: Display + Debug + Send + Sync + 'static,
149    E2: Display + Debug + Send + Sync + 'static,
150    E3: Display + Debug + Send + Sync + 'static,
151    E4: Display + Debug + Send + Sync + 'static,
152    E5: Display + Debug + Send + Sync + 'static,
153    E6: Display + Debug + Send + Sync + 'static,
154    C: Clone + fmt::Debug + PartialEq + JsonSchema + 'static,
155    Q: CustomQuery + DeserializeOwned + 'static,
156{
157    pub fn with_sudo<T4A, E4A>(
158        self,
159        sudo_fn: PermissionedFn<T4A, C, E4A, Q>,
160    ) -> ContractWrapper<T1, T2, T3, E1, E2, E3, C, Q, T4A, E4A, E5, T6, E6>
161    where
162        T4A: DeserializeOwned + 'static,
163        E4A: Display + Debug + Send + Sync + 'static,
164    {
165        ContractWrapper {
166            execute_fn: self.execute_fn,
167            instantiate_fn: self.instantiate_fn,
168            query_fn: self.query_fn,
169            sudo_fn: Some(Box::new(sudo_fn)),
170            reply_fn: self.reply_fn,
171            migrate_fn: self.migrate_fn,
172        }
173    }
174
175    pub fn with_sudo_empty<T4A, E4A>(
176        self,
177        sudo_fn: PermissionedFn<T4A, Empty, E4A, Q>,
178    ) -> ContractWrapper<T1, T2, T3, E1, E2, E3, C, Q, T4A, E4A, E5, T6, E6>
179    where
180        T4A: DeserializeOwned + 'static,
181        E4A: Display + Debug + Send + Sync + 'static,
182    {
183        ContractWrapper {
184            execute_fn: self.execute_fn,
185            instantiate_fn: self.instantiate_fn,
186            query_fn: self.query_fn,
187            sudo_fn: Some(customize_permissioned_fn(sudo_fn)),
188            reply_fn: self.reply_fn,
189            migrate_fn: self.migrate_fn,
190        }
191    }
192
193    pub fn with_reply<E5A>(
194        self,
195        reply_fn: ReplyFn<C, E5A, Q>,
196    ) -> ContractWrapper<T1, T2, T3, E1, E2, E3, C, Q, T4, E4, E5A, T6, E6>
197    where
198        E5A: Display + Debug + Send + Sync + 'static,
199    {
200        ContractWrapper {
201            execute_fn: self.execute_fn,
202            instantiate_fn: self.instantiate_fn,
203            query_fn: self.query_fn,
204            sudo_fn: self.sudo_fn,
205            reply_fn: Some(Box::new(reply_fn)),
206            migrate_fn: self.migrate_fn,
207        }
208    }
209
210    /// A correlate of new_with_empty
211    pub fn with_reply_empty<E5A>(
212        self,
213        reply_fn: ReplyFn<Empty, E5A, Q>,
214    ) -> ContractWrapper<T1, T2, T3, E1, E2, E3, C, Q, T4, E4, E5A, T6, E6>
215    where
216        E5A: Display + Debug + Send + Sync + 'static,
217    {
218        ContractWrapper {
219            execute_fn: self.execute_fn,
220            instantiate_fn: self.instantiate_fn,
221            query_fn: self.query_fn,
222            sudo_fn: self.sudo_fn,
223            reply_fn: Some(customize_permissioned_fn(reply_fn)),
224            migrate_fn: self.migrate_fn,
225        }
226    }
227
228    pub fn with_migrate<T6A, E6A>(
229        self,
230        migrate_fn: PermissionedFn<T6A, C, E6A, Q>,
231    ) -> ContractWrapper<T1, T2, T3, E1, E2, E3, C, Q, T4, E4, E5, T6A, E6A>
232    where
233        T6A: DeserializeOwned + 'static,
234        E6A: Display + Debug + Send + Sync + 'static,
235    {
236        ContractWrapper {
237            execute_fn: self.execute_fn,
238            instantiate_fn: self.instantiate_fn,
239            query_fn: self.query_fn,
240            sudo_fn: self.sudo_fn,
241            reply_fn: self.reply_fn,
242            migrate_fn: Some(Box::new(migrate_fn)),
243        }
244    }
245
246    pub fn with_migrate_empty<T6A, E6A>(
247        self,
248        migrate_fn: PermissionedFn<T6A, Empty, E6A, Q>,
249    ) -> ContractWrapper<T1, T2, T3, E1, E2, E3, C, Q, T4, E4, E5, T6A, E6A>
250    where
251        T6A: DeserializeOwned + 'static,
252        E6A: Display + Debug + Send + Sync + 'static,
253    {
254        ContractWrapper {
255            execute_fn: self.execute_fn,
256            instantiate_fn: self.instantiate_fn,
257            query_fn: self.query_fn,
258            sudo_fn: self.sudo_fn,
259            reply_fn: self.reply_fn,
260            migrate_fn: Some(customize_permissioned_fn(migrate_fn)),
261        }
262    }
263}
264
265fn customize_fn<T, C, E, Q>(raw_fn: ContractFn<T, Empty, E, Empty>) -> ContractClosure<T, C, E, Q>
266where
267    T: DeserializeOwned + 'static,
268    E: Display + Debug + Send + Sync + 'static,
269    C: Clone + fmt::Debug + PartialEq + JsonSchema + 'static,
270    Q: CustomQuery + DeserializeOwned + 'static,
271{
272    let customized = move |mut deps: DepsMut<Q>,
273                           env: Env,
274                           info: MessageInfo,
275                           msg: T|
276          -> Result<Response<C>, E> {
277        let deps = decustomize_deps_mut(&mut deps);
278        raw_fn(deps, env, info, msg).map(customize_response::<C>)
279    };
280    Box::new(customized)
281}
282
283fn customize_query<T, E, Q>(raw_fn: QueryFn<T, E, Empty>) -> QueryClosure<T, E, Q>
284where
285    T: DeserializeOwned + 'static,
286    E: Display + Debug + Send + Sync + 'static,
287    Q: CustomQuery + DeserializeOwned + 'static,
288{
289    let customized = move |deps: Deps<Q>, env: Env, msg: T| -> Result<Binary, E> {
290        let deps = decustomize_deps(&deps);
291        raw_fn(deps, env, msg)
292    };
293    Box::new(customized)
294}
295
296fn decustomize_deps_mut<'a, Q>(deps: &'a mut DepsMut<Q>) -> DepsMut<'a, Empty>
297where
298    Q: CustomQuery + DeserializeOwned + 'static,
299{
300    DepsMut {
301        storage: deps.storage,
302        api: deps.api,
303        querier: QuerierWrapper::new(deps.querier.deref()),
304    }
305}
306
307fn decustomize_deps<'a, Q>(deps: &'a Deps<'a, Q>) -> Deps<'a, Empty>
308where
309    Q: CustomQuery + DeserializeOwned + 'static,
310{
311    Deps {
312        storage: deps.storage,
313        api: deps.api,
314        querier: QuerierWrapper::new(deps.querier.deref()),
315    }
316}
317
318fn customize_permissioned_fn<T, C, E, Q>(
319    raw_fn: PermissionedFn<T, Empty, E, Q>,
320) -> PermissionedClosure<T, C, E, Q>
321where
322    T: DeserializeOwned + 'static,
323    E: Display + Debug + Send + Sync + 'static,
324    C: Clone + fmt::Debug + PartialEq + JsonSchema + 'static,
325    Q: CustomQuery + DeserializeOwned + 'static,
326{
327    let customized = move |deps: DepsMut<Q>, env: Env, msg: T| -> Result<Response<C>, E> {
328        raw_fn(deps, env, msg).map(customize_response::<C>)
329    };
330    Box::new(customized)
331}
332
333fn customize_response<C>(resp: Response<Empty>) -> Response<C>
334where
335    C: Clone + fmt::Debug + PartialEq + JsonSchema,
336{
337    let mut customized_resp = Response::<C>::new()
338        .add_submessages(resp.messages.into_iter().map(customize_msg::<C>))
339        .add_events(resp.events)
340        .add_attributes(resp.attributes);
341    customized_resp.data = resp.data;
342    customized_resp
343}
344
345fn customize_msg<C>(msg: SubMsg<Empty>) -> SubMsg<C>
346where
347    C: Clone + fmt::Debug + PartialEq + JsonSchema,
348{
349    SubMsg {
350        msg: match msg.msg {
351            CosmosMsg::Wasm(wasm) => CosmosMsg::Wasm(wasm),
352            CosmosMsg::Bank(bank) => CosmosMsg::Bank(bank),
353            CosmosMsg::Staking(staking) => CosmosMsg::Staking(staking),
354            CosmosMsg::Distribution(distribution) => CosmosMsg::Distribution(distribution),
355            CosmosMsg::Custom(_) => unreachable!(),
356            #[cfg(feature = "stargate")]
357            CosmosMsg::Ibc(ibc) => CosmosMsg::Ibc(ibc),
358            #[cfg(feature = "stargate")]
359            CosmosMsg::Stargate { type_url, value } => CosmosMsg::Stargate { type_url, value },
360            _ => panic!("unknown message variant {:?}", msg),
361        },
362        id: msg.id,
363        gas_limit: msg.gas_limit,
364        reply_on: msg.reply_on,
365    }
366}
367
368impl<T1, T2, T3, E1, E2, E3, C, T4, E4, E5, T6, E6, Q> Contract<C, Q>
369    for ContractWrapper<T1, T2, T3, E1, E2, E3, C, Q, T4, E4, E5, T6, E6>
370where
371    T1: DeserializeOwned + Debug + Clone,
372    T2: DeserializeOwned + Debug + Clone,
373    T3: DeserializeOwned + Debug + Clone,
374    T4: DeserializeOwned,
375    T6: DeserializeOwned,
376    E1: Display + Debug + Send + Sync + Error + 'static,
377    E2: Display + Debug + Send + Sync + Error + 'static,
378    E3: Display + Debug + Send + Sync + Error + 'static,
379    E4: Display + Debug + Send + Sync + 'static,
380    E5: Display + Debug + Send + Sync + 'static,
381    E6: Display + Debug + Send + Sync + 'static,
382    C: Clone + fmt::Debug + PartialEq + JsonSchema,
383    Q: CustomQuery + DeserializeOwned,
384{
385    fn execute(
386        &self,
387        deps: DepsMut<Q>,
388        env: Env,
389        info: MessageInfo,
390        msg: Vec<u8>,
391    ) -> AnyResult<Response<C>> {
392        let msg: T1 = from_slice(&msg)?;
393        (self.execute_fn)(deps, env, info, msg).map_err(|err| anyhow!(err))
394    }
395
396    fn instantiate(
397        &self,
398        deps: DepsMut<Q>,
399        env: Env,
400        info: MessageInfo,
401        msg: Vec<u8>,
402    ) -> AnyResult<Response<C>> {
403        let msg: T2 = from_slice(&msg)?;
404        (self.instantiate_fn)(deps, env, info, msg).map_err(|err| anyhow!(err))
405    }
406
407    fn query(&self, deps: Deps<Q>, env: Env, msg: Vec<u8>) -> AnyResult<Binary> {
408        let msg: T3 = from_slice(&msg)?;
409        (self.query_fn)(deps, env, msg).map_err(|err| anyhow!(err))
410    }
411
412    // this returns an error if the contract doesn't implement sudo
413    fn sudo(&self, deps: DepsMut<Q>, env: Env, msg: Vec<u8>) -> AnyResult<Response<C>> {
414        let msg = from_slice(&msg)?;
415        match &self.sudo_fn {
416            Some(sudo) => sudo(deps, env, msg).map_err(|err| anyhow!(err)),
417            None => bail!("sudo not implemented for contract"),
418        }
419    }
420
421    // this returns an error if the contract doesn't implement reply
422    fn reply(&self, deps: DepsMut<Q>, env: Env, reply_data: Reply) -> AnyResult<Response<C>> {
423        match &self.reply_fn {
424            Some(reply) => reply(deps, env, reply_data).map_err(|err| anyhow!(err)),
425            None => bail!("reply not implemented for contract"),
426        }
427    }
428
429    // this returns an error if the contract doesn't implement migrate
430    fn migrate(&self, deps: DepsMut<Q>, env: Env, msg: Vec<u8>) -> AnyResult<Response<C>> {
431        let msg = from_slice(&msg)?;
432        match &self.migrate_fn {
433            Some(migrate) => migrate(deps, env, msg).map_err(|err| anyhow!(err)),
434            None => bail!("migrate not implemented for contract"),
435        }
436    }
437}