cw_ica_controller/
contract.rs

1//! This module handles the execution logic of the contract.
2
3use cosmwasm_std::{entry_point, Reply};
4use cosmwasm_std::{to_json_binary, Binary, Deps, DepsMut, Env, MessageInfo, Response, StdResult};
5
6use crate::ibc::types::stargate::channel::new_ica_channel_open_init_cosmos_msg;
7use crate::types::keys;
8use crate::types::msg::{ExecuteMsg, InstantiateMsg, MigrateMsg, QueryMsg};
9use crate::types::state::{self, ChannelState, ContractState};
10use crate::types::ContractError;
11
12/// Instantiates the contract.
13#[entry_point]
14#[allow(clippy::pedantic)]
15pub fn instantiate(
16    deps: DepsMut,
17    env: Env,
18    info: MessageInfo,
19    msg: InstantiateMsg,
20) -> Result<Response, ContractError> {
21    cw2::set_contract_version(deps.storage, keys::CONTRACT_NAME, keys::CONTRACT_VERSION)?;
22
23    let owner = msg.owner.unwrap_or_else(|| info.sender.to_string());
24    cw_ownable::initialize_owner(deps.storage, deps.api, Some(&owner))?;
25
26    let callback_address = msg
27        .send_callbacks_to
28        .map(|addr| deps.api.addr_validate(&addr))
29        .transpose()?;
30
31    // Save the admin. Ica address is determined during handshake.
32    state::STATE.save(deps.storage, &ContractState::new(callback_address))?;
33
34    state::CHANNEL_OPEN_INIT_OPTIONS.save(deps.storage, &msg.channel_open_init_options)?;
35
36    state::ALLOW_CHANNEL_OPEN_INIT.save(deps.storage, &true)?;
37
38    let ica_channel_open_init_msg = new_ica_channel_open_init_cosmos_msg(
39        env.contract.address.to_string(),
40        msg.channel_open_init_options.connection_id,
41        msg.channel_open_init_options.counterparty_port_id,
42        msg.channel_open_init_options.counterparty_connection_id,
43        None,
44        msg.channel_open_init_options.channel_ordering,
45    );
46
47    Ok(Response::new().add_message(ica_channel_open_init_msg))
48}
49
50/// Handles the execution of the contract.
51#[entry_point]
52#[allow(clippy::pedantic)]
53pub fn execute(
54    deps: DepsMut,
55    env: Env,
56    info: MessageInfo,
57    msg: ExecuteMsg,
58) -> Result<Response, ContractError> {
59    match msg {
60        ExecuteMsg::CreateChannel {
61            channel_open_init_options,
62        } => execute::create_channel(deps, env, info, channel_open_init_options),
63        ExecuteMsg::CloseChannel {} => execute::close_channel(deps, info),
64        ExecuteMsg::UpdateCallbackAddress { callback_address } => {
65            execute::update_callback_address(deps, info, callback_address)
66        }
67        ExecuteMsg::SendCosmosMsgs {
68            messages,
69            queries,
70            packet_memo,
71            timeout_seconds,
72        } => execute::send_cosmos_msgs(
73            deps,
74            env,
75            info,
76            messages,
77            queries,
78            packet_memo,
79            timeout_seconds,
80        ),
81        ExecuteMsg::UpdateOwnership(action) => execute::update_ownership(deps, env, info, action),
82    }
83}
84
85/// Handles the replies to the submessages.
86#[entry_point]
87#[allow(clippy::pedantic)]
88pub fn reply(deps: DepsMut, env: Env, msg: Reply) -> Result<Response, ContractError> {
89    match msg.id {
90        keys::reply_ids::SEND_QUERY_PACKET => reply::send_query_packet(deps, env, msg.result),
91        _ => Err(ContractError::UnknownReplyId(msg.id)),
92    }
93}
94
95/// Handles the query of the contract.
96#[entry_point]
97#[allow(clippy::pedantic)]
98pub fn query(deps: Deps, _env: Env, msg: QueryMsg) -> StdResult<Binary> {
99    match msg {
100        QueryMsg::GetContractState {} => to_json_binary(&query::state(deps)?),
101        QueryMsg::GetChannel {} => to_json_binary(&query::channel(deps)?),
102        QueryMsg::Ownership {} => to_json_binary(&cw_ownable::get_ownership(deps.storage)?),
103    }
104}
105
106/// Migrate contract if version is lower than current version
107#[entry_point]
108#[allow(clippy::pedantic)]
109pub fn migrate(deps: DepsMut, _env: Env, _msg: MigrateMsg) -> Result<Response, ContractError> {
110    migrate::validate_semver(deps.as_ref())?;
111    migrate::validate_channel_encoding(deps.as_ref())?;
112
113    cw2::set_contract_version(deps.storage, keys::CONTRACT_NAME, keys::CONTRACT_VERSION)?;
114    // If state structure changed in any contract version in the way migration is needed, it
115    // should occur here
116
117    Ok(Response::default())
118}
119
120mod execute {
121    use cosmwasm_std::{CosmosMsg, IbcMsg, SubMsg};
122
123    use crate::{ibc::types::packet::IcaPacketData, types::msg::options::ChannelOpenInitOptions};
124
125    use super::{
126        keys, new_ica_channel_open_init_cosmos_msg, state, ContractError, DepsMut, Env,
127        MessageInfo, Response,
128    };
129
130    use cosmwasm_std::{Empty, QueryRequest};
131
132    /// Submits a stargate `MsgChannelOpenInit` to the chain.
133    /// Can only be called by the contract owner or a whitelisted address.
134    /// Only the contract owner can include the channel open init options.
135    #[allow(clippy::needless_pass_by_value)]
136    pub fn create_channel(
137        deps: DepsMut,
138        env: Env,
139        info: MessageInfo,
140        options: Option<ChannelOpenInitOptions>,
141    ) -> Result<Response, ContractError> {
142        cw_ownable::assert_owner(deps.storage, &info.sender)?;
143
144        let options = if let Some(new_options) = options {
145            state::CHANNEL_OPEN_INIT_OPTIONS.save(deps.storage, &new_options)?;
146            new_options
147        } else {
148            state::CHANNEL_OPEN_INIT_OPTIONS
149                .may_load(deps.storage)?
150                .ok_or(ContractError::NoChannelInitOptions)?
151        };
152
153        state::ALLOW_CHANNEL_OPEN_INIT.save(deps.storage, &true)?;
154
155        let ica_channel_open_init_msg = new_ica_channel_open_init_cosmos_msg(
156            env.contract.address.to_string(),
157            options.connection_id,
158            options.counterparty_port_id,
159            options.counterparty_connection_id,
160            None,
161            options.channel_ordering,
162        );
163
164        Ok(Response::new().add_message(ica_channel_open_init_msg))
165    }
166
167    /// Submits a [`IbcMsg::CloseChannel`].
168    #[allow(clippy::needless_pass_by_value)]
169    pub fn close_channel(deps: DepsMut, info: MessageInfo) -> Result<Response, ContractError> {
170        cw_ownable::assert_owner(deps.storage, &info.sender)?;
171
172        let channel_state = state::CHANNEL_STATE.load(deps.storage)?;
173        if !channel_state.is_open() {
174            return Err(ContractError::InvalidChannelStatus {
175                expected: state::ChannelStatus::Open.to_string(),
176                actual: channel_state.channel_status.to_string(),
177            });
178        }
179
180        state::ALLOW_CHANNEL_CLOSE_INIT.save(deps.storage, &true)?;
181
182        let channel_close_msg = CosmosMsg::Ibc(IbcMsg::CloseChannel {
183            channel_id: channel_state.channel.endpoint.channel_id,
184        });
185
186        Ok(Response::new().add_message(channel_close_msg))
187    }
188
189    /// Sends an array of [`CosmosMsg`] to the ICA host.
190    #[allow(clippy::needless_pass_by_value)]
191    pub fn send_cosmos_msgs(
192        deps: DepsMut,
193        env: Env,
194        info: MessageInfo,
195        messages: Vec<CosmosMsg>,
196        queries: Vec<QueryRequest<Empty>>,
197        packet_memo: Option<String>,
198        timeout_seconds: Option<u64>,
199    ) -> Result<Response, ContractError> {
200        cw_ownable::assert_owner(deps.storage, &info.sender)?;
201
202        let contract_state = state::STATE.load(deps.storage)?;
203        let ica_info = contract_state.get_ica_info()?;
204        let has_queries = !queries.is_empty();
205
206        let ica_packet = IcaPacketData::from_cosmos_msgs(
207            deps.storage,
208            messages,
209            queries,
210            &ica_info.encoding,
211            packet_memo,
212            &ica_info.ica_address,
213        )?;
214        let send_packet_msg = ica_packet.to_ibc_msg(&env, ica_info.channel_id, timeout_seconds)?;
215
216        let send_packet_submsg = if has_queries {
217            // TODO: use payload when we switch to cosmwasm_2_0 feature
218            SubMsg::reply_on_success(send_packet_msg, keys::reply_ids::SEND_QUERY_PACKET)
219        } else {
220            SubMsg::new(send_packet_msg)
221        };
222
223        Ok(Response::default().add_submessage(send_packet_submsg))
224    }
225
226    /// Update the ownership of the contract.
227    #[allow(clippy::needless_pass_by_value)]
228    pub fn update_ownership(
229        deps: DepsMut,
230        env: Env,
231        info: MessageInfo,
232        action: cw_ownable::Action,
233    ) -> Result<Response, ContractError> {
234        if action == cw_ownable::Action::RenounceOwnership {
235            return Err(ContractError::OwnershipCannotBeRenounced);
236        };
237
238        cw_ownable::update_ownership(deps, &env.block, &info.sender, action)?;
239
240        Ok(Response::default())
241    }
242
243    /// Updates the callback address.
244    #[allow(clippy::needless_pass_by_value)]
245    pub fn update_callback_address(
246        deps: DepsMut,
247        info: MessageInfo,
248        callback_address: Option<String>,
249    ) -> Result<Response, ContractError> {
250        cw_ownable::assert_owner(deps.storage, &info.sender)?;
251
252        let mut contract_state = state::STATE.load(deps.storage)?;
253
254        contract_state.callback_address = callback_address
255            .map(|addr| deps.api.addr_validate(&addr))
256            .transpose()?;
257
258        state::STATE.save(deps.storage, &contract_state)?;
259
260        Ok(Response::default())
261    }
262}
263
264mod reply {
265    use cosmwasm_std::SubMsgResult;
266
267    use super::{state, ContractError, DepsMut, Env, Response};
268
269    /// Handles the reply to the query packet.
270    #[allow(clippy::needless_pass_by_value)]
271    pub fn send_query_packet(
272        deps: DepsMut,
273        _env: Env,
274        result: SubMsgResult,
275    ) -> Result<Response, ContractError> {
276        match result {
277            SubMsgResult::Ok(resp) => {
278                #[allow(deprecated)] // TODO: Remove deprecated `.data` field
279                let sequence = anybuf::Bufany::deserialize(&resp.data.unwrap_or_default())?
280                    .uint64(1)
281                    .unwrap();
282                let channel_id = state::STATE.load(deps.storage)?.get_ica_info()?.channel_id;
283                let query_paths = state::QUERY.load(deps.storage)?;
284
285                state::QUERY.remove(deps.storage);
286                state::PENDING_QUERIES.save(deps.storage, (&channel_id, sequence), &query_paths)?;
287
288                Ok(Response::default())
289            }
290            SubMsgResult::Err(err) => unreachable!("query packet failed: {err}"),
291        }
292    }
293}
294
295mod query {
296    use super::{state, ChannelState, ContractState, Deps, StdResult};
297
298    /// Returns the saved contract state.
299    pub fn state(deps: Deps) -> StdResult<ContractState> {
300        state::STATE.load(deps.storage)
301    }
302
303    /// Returns the saved channel state if it exists.
304    pub fn channel(deps: Deps) -> StdResult<ChannelState> {
305        state::CHANNEL_STATE.load(deps.storage)
306    }
307}
308
309mod migrate {
310    use super::{keys, state, ContractError, Deps};
311
312    /// Validate that the contract version is semver compliant
313    /// and greater than the previous version.
314    pub fn validate_semver(deps: Deps) -> Result<(), ContractError> {
315        let prev_cw2_version = cw2::get_contract_version(deps.storage)?;
316        if prev_cw2_version.contract != keys::CONTRACT_NAME {
317            return Err(ContractError::InvalidMigrationVersion {
318                expected: keys::CONTRACT_NAME.to_string(),
319                actual: prev_cw2_version.contract,
320            });
321        }
322
323        let version: semver::Version = keys::CONTRACT_VERSION.parse()?;
324        let prev_version: semver::Version = prev_cw2_version.version.parse()?;
325        if prev_version >= version {
326            return Err(ContractError::InvalidMigrationVersion {
327                expected: format!("> {prev_version}"),
328                actual: keys::CONTRACT_VERSION.to_string(),
329            });
330        }
331        Ok(())
332    }
333
334    /// Validate that the channel encoding is protobuf if set.
335    pub fn validate_channel_encoding(deps: Deps) -> Result<(), ContractError> {
336        // Reject the migration if the channel encoding is not protobuf
337        if let Some(ica_info) = state::STATE.load(deps.storage)?.ica_info {
338            if !matches!(
339                ica_info.encoding,
340                crate::ibc::types::metadata::TxEncoding::Protobuf
341            ) {
342                return Err(ContractError::UnsupportedPacketEncoding(
343                    ica_info.encoding.to_string(),
344                ));
345            }
346        }
347
348        Ok(())
349    }
350}
351
#[cfg(test)]
mod tests {
    use crate::types::msg::options::ChannelOpenInitOptions;

    use super::*;
    use cosmwasm_std::testing::{message_info, mock_dependencies, mock_env};
    use cosmwasm_std::{Api, StdError, SubMsg};

    /// Channel options shared by every test in this module.
    fn default_channel_options() -> ChannelOpenInitOptions {
        ChannelOpenInitOptions {
            connection_id: "connection-0".to_string(),
            counterparty_connection_id: "connection-1".to_string(),
            counterparty_port_id: None,
            channel_ordering: None,
        }
    }

    /// Instantiates the contract with default options, no explicit owner and no callback.
    fn instantiate_default(deps: DepsMut, info: MessageInfo) {
        let _response = instantiate(
            deps,
            mock_env(),
            info,
            InstantiateMsg {
                owner: None,
                channel_open_init_options: default_channel_options(),
                send_callbacks_to: None,
            },
        )
        .unwrap();
    }

    /// Asserts that cw2 stores this contract's name together with `expected_version`.
    fn assert_stored_version(deps: Deps, expected_version: &str) {
        let stored = cw2::get_contract_version(deps.storage).unwrap();
        assert_eq!(stored.contract, keys::CONTRACT_NAME);
        assert_eq!(stored.version, expected_version);
    }

    /// Overwrites the stored ICA info with empty identifiers and the given encoding.
    fn force_encoding(deps: DepsMut, encoding: crate::ibc::types::metadata::TxEncoding) {
        state::STATE
            .update::<_, StdError>(deps.storage, |mut contract_state| {
                contract_state.set_ica_info("", "", encoding);
                Ok(contract_state)
            })
            .unwrap();
    }

    #[test]
    fn test_instantiate() {
        let mut deps = mock_dependencies();
        let env = mock_env();

        let creator = deps.api.addr_make("creator");
        let info = message_info(&creator, &[]);
        let options = default_channel_options();

        let res = instantiate(
            deps.as_mut(),
            env.clone(),
            info.clone(),
            InstantiateMsg {
                owner: None,
                channel_open_init_options: options.clone(),
                send_callbacks_to: None,
            },
        )
        .unwrap();

        // The channel open init options must be persisted as provided.
        let stored_options = state::CHANNEL_OPEN_INIT_OPTIONS
            .load(deps.as_ref().storage)
            .unwrap();
        assert_eq!(stored_options, options);

        // Exactly one message is emitted: the channel open init message.
        assert_eq!(res.messages.len(), 1);

        let expected_msg = new_ica_channel_open_init_cosmos_msg(
            env.contract.address.to_string(),
            options.connection_id,
            options.counterparty_port_id,
            options.counterparty_connection_id,
            None,
            options.channel_ordering,
        );
        assert_eq!(res.messages[0], SubMsg::new(expected_msg));

        // With no explicit owner, the sender becomes the owner.
        let ownership = cw_ownable::get_ownership(&deps.storage).unwrap();
        assert_eq!(ownership.owner, Some(info.sender));

        // The contract name and version are recorded in cw2.
        assert_stored_version(deps.as_ref(), keys::CONTRACT_VERSION);
    }

    #[test]
    fn test_update_callback_address() {
        let mut deps = mock_dependencies();

        let creator = deps.api.addr_make("creator");
        let owner_info = message_info(&creator, &[]);

        instantiate_default(deps.as_mut(), owner_info.clone());

        // The owner can update the callback address.
        let new_callback_address = deps.api.addr_make("new_callback_address").to_string();
        let res = execute(
            deps.as_mut(),
            mock_env(),
            owner_info,
            ExecuteMsg::UpdateCallbackAddress {
                callback_address: Some(new_callback_address.clone()),
            },
        )
        .unwrap();

        assert!(res.messages.is_empty());

        let expected_address = deps.api.addr_validate(&new_callback_address).unwrap();
        let saved = state::STATE.load(&deps.storage).unwrap();
        assert_eq!(saved.callback_address, Some(expected_address));

        // Anyone other than the owner is rejected.
        let outsider = deps.api.addr_make("non-admin");
        let err = execute(
            deps.as_mut(),
            mock_env(),
            message_info(&outsider, &[]),
            ExecuteMsg::UpdateCallbackAddress {
                callback_address: Some("new_callback_address".to_string()),
            },
        )
        .unwrap_err();
        assert_eq!(
            err.to_string(),
            "Caller is not the contract's current owner"
        );
    }

    // Verifies that the semver validation is performed correctly and that the contract
    // version in cw2 is updated by a successful migration.
    #[test]
    fn test_migrate() {
        let mut deps = mock_dependencies();

        let creator = deps.api.addr_make("creator");
        instantiate_default(deps.as_mut(), message_info(&creator, &[]));

        // Pretend an older version is deployed so that the migration is allowed.
        cw2::set_contract_version(&mut deps.storage, keys::CONTRACT_NAME, "0.0.1").unwrap();
        assert_stored_version(deps.as_ref(), "0.0.1");

        // A successful migration bumps the stored version to the current one.
        let _response = migrate(deps.as_mut(), mock_env(), MigrateMsg {}).unwrap();
        assert_stored_version(deps.as_ref(), keys::CONTRACT_VERSION);

        // Downgrading from a newer stored version must be rejected.
        cw2::set_contract_version(&mut deps.storage, keys::CONTRACT_NAME, "100.0.0").unwrap();

        let err = migrate(deps.as_mut(), mock_env(), MigrateMsg {}).unwrap_err();
        assert_eq!(
            err.to_string(),
            format!(
                "invalid migration version: expected > 100.0.0, got {}",
                keys::CONTRACT_VERSION
            )
        );
    }

    #[test]
    fn test_migrate_with_encoding() {
        use crate::ibc::types::metadata::TxEncoding;

        let mut deps = mock_dependencies();

        let creator = deps.api.addr_make("creator");
        instantiate_default(deps.as_mut(), message_info(&creator, &[]));

        // Pretend an older version is deployed so that the migration is allowed.
        cw2::set_contract_version(&mut deps.storage, keys::CONTRACT_NAME, "0.0.1").unwrap();
        assert_stored_version(deps.as_ref(), "0.0.1");

        // With proto3json encoding the migration must be rejected.
        force_encoding(deps.as_mut(), TxEncoding::Proto3Json);

        let err = migrate(deps.as_mut(), mock_env(), MigrateMsg {}).unwrap_err();
        assert_eq!(
            err.to_string(),
            ContractError::UnsupportedPacketEncoding(TxEncoding::Proto3Json.to_string())
                .to_string()
        );

        // With protobuf encoding the migration succeeds and bumps the version.
        force_encoding(deps.as_mut(), TxEncoding::Protobuf);

        let _response = migrate(deps.as_mut(), mock_env(), MigrateMsg {}).unwrap();
        assert_stored_version(deps.as_ref(), keys::CONTRACT_VERSION);
    }
}