use crate::ado_contract::ADOContract;
use crate::common::encode_binary;
use crate::error::ContractError;
use crate::os::aos_querier::AOSQuerier;
use crate::os::{kernel::ExecuteMsg as KernelExecuteMsg, kernel::QueryMsg as KernelQueryMsg};
use cosmwasm_schema::cw_serde;
use cosmwasm_std::{
to_binary, Addr, Binary, Coin, ContractInfoResponse, CosmosMsg, Deps, MessageInfo,
QueryRequest, ReplyOn, SubMsg, WasmMsg, WasmQuery,
};
use super::addresses::AndrAddr;
use super::ADO_DB_KEY;
#[cw_serde]
pub enum ExecuteMsg {
#[serde(rename = "amp_receive")]
AMPReceive(AMPPkt),
}
#[cw_serde]
pub struct AMPMsgConfig {
pub reply_on: ReplyOn,
pub exit_at_error: bool,
pub gas_limit: Option<u64>,
pub direct: bool,
}
impl AMPMsgConfig {
#[inline]
pub fn new(
reply_on: Option<ReplyOn>,
exit_at_error: Option<bool>,
gas_limit: Option<u64>,
) -> AMPMsgConfig {
AMPMsgConfig {
reply_on: reply_on.unwrap_or(ReplyOn::Always),
exit_at_error: exit_at_error.unwrap_or(true),
gas_limit,
direct: false,
}
}
pub fn as_direct_msg(self) -> AMPMsgConfig {
AMPMsgConfig {
reply_on: self.reply_on,
exit_at_error: self.exit_at_error,
gas_limit: self.gas_limit,
direct: true,
}
}
}
impl Default for AMPMsgConfig {
#[inline]
fn default() -> AMPMsgConfig {
AMPMsgConfig {
reply_on: ReplyOn::Always,
exit_at_error: true,
gas_limit: None,
direct: false,
}
}
}
#[cw_serde]
pub struct AMPMsg {
pub recipient: AndrAddr,
pub message: Binary,
pub funds: Vec<Coin>,
pub config: AMPMsgConfig,
}
impl AMPMsg {
pub fn new(recipient: impl Into<String>, message: Binary, funds: Option<Vec<Coin>>) -> AMPMsg {
AMPMsg {
recipient: AndrAddr::from_string(recipient),
message,
funds: funds.unwrap_or_default(),
config: AMPMsgConfig::default(),
}
}
pub fn with_config(&self, config: AMPMsgConfig) -> AMPMsg {
AMPMsg {
recipient: self.recipient.clone(),
message: self.message.clone(),
funds: self.funds.clone(),
config,
}
}
pub fn generate_amp_pkt(
&self,
deps: &Deps,
origin: impl Into<String>,
previous_sender: impl Into<String>,
id: u64,
) -> Result<SubMsg, ContractError> {
let contract_addr = self.recipient.get_raw_address(deps)?;
let pkt = AMPPkt::new(origin, previous_sender, vec![self.clone()]);
let msg = to_binary(&ExecuteMsg::AMPReceive(pkt))?;
Ok(SubMsg {
id,
reply_on: self.config.reply_on.clone(),
gas_limit: self.config.gas_limit,
msg: CosmosMsg::Wasm(WasmMsg::Execute {
contract_addr: contract_addr.into(),
msg,
funds: self.funds.to_vec(),
}),
})
}
}
#[cw_serde]
pub struct AMPCtx {
origin: String,
origin_username: Option<AndrAddr>,
pub previous_sender: String,
pub id: u64,
}
impl AMPCtx {
#[inline]
pub fn new(
origin: impl Into<String>,
previous_sender: impl Into<String>,
id: u64,
origin_username: Option<AndrAddr>,
) -> AMPCtx {
AMPCtx {
origin: origin.into(),
origin_username,
previous_sender: previous_sender.into(),
id,
}
}
pub fn get_origin(&self) -> String {
self.origin.clone()
}
pub fn get_previous_sender(&self) -> String {
self.previous_sender.clone()
}
}
#[cw_serde]
pub struct AMPPkt {
pub messages: Vec<AMPMsg>,
pub ctx: AMPCtx,
}
impl AMPPkt {
pub fn new(
origin: impl Into<String>,
previous_sender: impl Into<String>,
messages: Vec<AMPMsg>,
) -> AMPPkt {
AMPPkt {
messages,
ctx: AMPCtx::new(origin, previous_sender, 0, None),
}
}
pub fn add_message(mut self, message: AMPMsg) -> Self {
self.messages.push(message);
self
}
pub fn get_unique_recipients(&self) -> Vec<String> {
let mut recipients: Vec<String> = self
.messages
.iter()
.cloned()
.map(|msg| msg.recipient.to_string())
.collect();
recipients.sort_unstable();
recipients.dedup();
recipients
}
pub fn get_messages_for_recipient(&self, recipient: String) -> Vec<AMPMsg> {
self.messages
.iter()
.cloned()
.filter(|msg| msg.recipient == recipient.clone())
.collect()
}
pub fn verify_origin(&self, info: &MessageInfo, deps: &Deps) -> Result<(), ContractError> {
let kernel_address = ADOContract::default().get_kernel_address(deps.storage)?;
if info.sender == self.ctx.origin || info.sender == kernel_address {
Ok(())
} else {
let adodb_address: Addr =
deps.querier.query(&QueryRequest::Wasm(WasmQuery::Smart {
contract_addr: kernel_address.to_string(),
msg: to_binary(&KernelQueryMsg::KeyAddress {
key: ADO_DB_KEY.to_string(),
})?,
}))?;
let contract_info: ContractInfoResponse =
deps.querier
.query(&QueryRequest::Wasm(WasmQuery::ContractInfo {
contract_addr: info.sender.to_string(),
}))?;
let sender_code_id = contract_info.code_id;
AOSQuerier::verify_code_id(&deps.querier, &adodb_address, sender_code_id)
}
}
pub fn get_verified_origin(
&self,
info: &MessageInfo,
deps: &Deps,
) -> Result<String, ContractError> {
let origin = self.ctx.get_origin();
let res = self.verify_origin(info, deps);
match res {
Ok(_) => Ok(origin),
Err(err) => Err(err),
}
}
pub fn to_sub_msg(
&self,
address: impl Into<String>,
funds: Option<Vec<Coin>>,
id: u64,
) -> Result<SubMsg, ContractError> {
let sub_msg = SubMsg::reply_always(
WasmMsg::Execute {
contract_addr: address.into(),
msg: encode_binary(&KernelExecuteMsg::AMPReceive(self.clone()))?,
funds: funds.unwrap_or_default(),
},
id,
);
Ok(sub_msg)
}
pub fn with_id(&self, id: u64) -> AMPPkt {
let mut new = self.clone();
new.ctx.id = id;
new
}
pub fn to_ibc_hooks_memo(&self, contract_addr: String, callback_addr: String) -> String {
#[derive(::serde::Serialize)]
struct IbcHooksWasmMsg<T: ::serde::Serialize> {
contract: String,
msg: T,
}
#[derive(::serde::Serialize)]
struct IbcHooksMsg<T: ::serde::Serialize> {
wasm: IbcHooksWasmMsg<T>,
ibc_callback: String,
}
let wasm_msg = IbcHooksWasmMsg {
contract: contract_addr,
msg: KernelExecuteMsg::AMPReceive(self.clone()),
};
let msg = IbcHooksMsg {
wasm: wasm_msg,
ibc_callback: callback_addr,
};
serde_json_wasm::to_string(&msg).unwrap()
}
pub fn to_json(&self) -> String {
serde_json_wasm::to_string(&self).unwrap()
}
pub fn from_ctx(ctx: Option<AMPPkt>, current_address: String) -> Self {
let mut ctx = if let Some(pkt) = ctx {
pkt.ctx
} else {
AMPCtx::new(current_address.clone(), current_address.clone(), 0, None)
};
ctx.previous_sender = current_address;
Self {
messages: vec![],
ctx,
}
}
}
#[cfg(test)]
mod tests {
use cosmwasm_std::testing::{mock_dependencies, mock_info};
use crate::testing::mock_querier::{mock_dependencies_custom, INVALID_CONTRACT};
use super::*;
#[test]
fn test_generate_amp_pkt() {
let deps = mock_dependencies();
let msg = AMPMsg::new("test", Binary::default(), None);
let sub_msg = msg
.generate_amp_pkt(&deps.as_ref(), "origin", "previoussender", 1)
.unwrap();
let expected_msg = ExecuteMsg::AMPReceive(AMPPkt::new(
"origin",
"previoussender",
vec![AMPMsg::new("test", Binary::default(), None)],
));
assert_eq!(sub_msg.id, 1);
assert_eq!(sub_msg.reply_on, ReplyOn::Always);
assert_eq!(sub_msg.gas_limit, None);
assert_eq!(
sub_msg.msg,
CosmosMsg::Wasm(WasmMsg::Execute {
contract_addr: "test".to_string(),
msg: to_binary(&expected_msg).unwrap(),
funds: vec![],
})
);
}
#[test]
fn test_get_unique_recipients() {
let msg = AMPMsg::new("test", Binary::default(), None);
let msg2 = AMPMsg::new("test2", Binary::default(), None);
let mut pkt = AMPPkt::new("origin", "previoussender", vec![msg, msg2]);
let recipients = pkt.get_unique_recipients();
assert_eq!(recipients.len(), 2);
assert_eq!(recipients[0], "test".to_string());
assert_eq!(recipients[1], "test2".to_string());
pkt = pkt.add_message(AMPMsg::new("test", Binary::default(), None));
let recipients = pkt.get_unique_recipients();
assert_eq!(recipients.len(), 2);
assert_eq!(recipients[0], "test".to_string());
assert_eq!(recipients[1], "test2".to_string());
}
#[test]
fn test_get_messages_for_recipient() {
let msg = AMPMsg::new("test", Binary::default(), None);
let msg2 = AMPMsg::new("test2", Binary::default(), None);
let mut pkt = AMPPkt::new("origin", "previoussender", vec![msg, msg2]);
let messages = pkt.get_messages_for_recipient("test".to_string());
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].recipient.to_string(), "test".to_string());
let messages = pkt.get_messages_for_recipient("test2".to_string());
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].recipient.to_string(), "test2".to_string());
pkt = pkt.add_message(AMPMsg::new("test", Binary::default(), None));
let messages = pkt.get_messages_for_recipient("test".to_string());
assert_eq!(messages.len(), 2);
assert_eq!(messages[0].recipient.to_string(), "test".to_string());
assert_eq!(messages[1].recipient.to_string(), "test".to_string());
}
#[test]
fn test_verify_origin() {
let deps = mock_dependencies_custom(&[]);
let msg = AMPMsg::new("test", Binary::default(), None);
let pkt = AMPPkt::new("origin", "previoussender", vec![msg.clone()]);
let info = mock_info("validaddress", &[]);
let res = pkt.verify_origin(&info, &deps.as_ref());
assert!(res.is_ok());
let info = mock_info(INVALID_CONTRACT, &[]);
let res = pkt.verify_origin(&info, &deps.as_ref());
assert!(res.is_err());
let offchain_pkt = AMPPkt::new(INVALID_CONTRACT, INVALID_CONTRACT, vec![msg]);
let res = offchain_pkt.verify_origin(&info, &deps.as_ref());
assert!(res.is_ok());
}
#[test]
fn test_to_sub_msg() {
let msg = AMPMsg::new("test", Binary::default(), None);
let pkt = AMPPkt::new("origin", "previoussender", vec![msg.clone()]);
let sub_msg = pkt.to_sub_msg("kernel", None, 1).unwrap();
let expected_msg =
ExecuteMsg::AMPReceive(AMPPkt::new("origin", "previoussender", vec![msg]));
assert_eq!(sub_msg.id, 1);
assert_eq!(sub_msg.reply_on, ReplyOn::Always);
assert_eq!(sub_msg.gas_limit, None);
assert_eq!(
sub_msg.msg,
CosmosMsg::Wasm(WasmMsg::Execute {
contract_addr: "kernel".to_string(),
msg: to_binary(&expected_msg).unwrap(),
funds: vec![],
})
);
}
#[test]
fn test_to_json() {
let msg = AMPPkt::new("origin", "previoussender", vec![]);
let memo = msg.to_json();
assert_eq!(memo, "{\"messages\":[],\"ctx\":{\"origin\":\"origin\",\"origin_username\":null,\"previous_sender\":\"previoussender\",\"id\":0}}".to_string());
}
#[test]
fn test_to_ibc_hooks_memo() {
let msg = AMPPkt::new("origin", "previoussender", vec![]);
let contract_addr = "contractaddr";
let memo = msg.to_ibc_hooks_memo(contract_addr.to_string(), "callback".to_string());
assert_eq!(memo, "{\"wasm\":{\"contract\":\"contractaddr\",\"msg\":{\"amp_receive\":{\"messages\":[],\"ctx\":{\"origin\":\"origin\",\"origin_username\":null,\"previous_sender\":\"previoussender\",\"id\":0}}}},\"ibc_callback\":\"callback\"}".to_string());
}
}