1use cosmwasm_std::{DepsMut, StdError, StdResult};
2use cw2::{get_contract_version, set_contract_version};
3
4pub fn migrate_version(
5 deps: DepsMut,
6 target_contract_version: &str,
7 name: &str,
8 version: &str,
9) -> StdResult<()> {
10 let prev_version = get_contract_version(deps.as_ref().storage)?;
11 if prev_version.contract != name {
12 return Err(StdError::generic_err("invalid contract"));
13 }
14
15 if prev_version.version != target_contract_version {
16 return Err(StdError::generic_err(format!(
17 "invalid contract version. target {}, but source is {}",
18 target_contract_version, prev_version.version
19 )));
20 }
21
22 set_contract_version(deps.storage, name, version)?;
23
24 Ok(())
25}
26
#[cfg(test)]
mod test {
    use crate::mock_querier::mock_dependencies;

    use super::*;

    // Version stored before the migration; `migrate_version` must find this one.
    const TARGET_VERSION: &str = "version";
    // Contract name stored before the migration.
    const NAME: &str = "name";
    // Version the migration is expected to write on success.
    const CURRENT_VERSION: &str = "c_version";

    /// A migration with matching name and version succeeds and stores the new version.
    #[test]
    pub fn normal_migration() {
        let mut deps = mock_dependencies(&[]);
        set_contract_version(deps.as_mut().storage, NAME, TARGET_VERSION).unwrap();

        let outcome = migrate_version(deps.as_mut(), TARGET_VERSION, NAME, CURRENT_VERSION);
        assert_eq!(outcome, Ok(()));

        let stored = get_contract_version(deps.as_ref().storage).unwrap();
        assert_eq!(
            (stored.contract.as_str(), stored.version.as_str()),
            (NAME, CURRENT_VERSION)
        );
    }

    /// A mismatching contract name is rejected and the stored info stays untouched.
    #[test]
    pub fn failed_migration_with_invalid_contract_name() {
        let mut deps = mock_dependencies(&[]);
        set_contract_version(deps.as_mut().storage, NAME, TARGET_VERSION).unwrap();

        let wrong_name = "invalid_name";
        let failure =
            migrate_version(deps.as_mut(), TARGET_VERSION, wrong_name, CURRENT_VERSION)
                .unwrap_err();
        assert_eq!(failure, StdError::generic_err("invalid contract"));

        let stored = get_contract_version(deps.as_ref().storage).unwrap();
        assert_eq!(
            (stored.contract.as_str(), stored.version.as_str()),
            (NAME, TARGET_VERSION)
        );
    }

    /// A mismatching target version is rejected and the stored info stays untouched.
    #[test]
    pub fn failed_migration_with_invalid_target_version() {
        let mut deps = mock_dependencies(&[]);
        set_contract_version(deps.as_mut().storage, NAME, TARGET_VERSION).unwrap();

        let wrong_target = "invalide_version";
        let failure = migrate_version(deps.as_mut(), wrong_target, NAME, CURRENT_VERSION)
            .unwrap_err();

        let expected = StdError::generic_err(format!(
            "invalid contract version. target {}, but source is {}",
            wrong_target, TARGET_VERSION
        ));
        assert_eq!(failure, expected);

        let stored = get_contract_version(deps.as_ref().storage).unwrap();
        assert_eq!(
            (stored.contract.as_str(), stored.version.as_str()),
            (NAME, TARGET_VERSION)
        );
    }
}