use async_trait::async_trait;
use primitive_types::H160;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use crate::{
neo_builder::TransactionBuilder,
neo_clients::{JsonRpcProvider, RpcClient},
neo_contract::{traits::SmartContractTrait, ContractError, NeoIterator},
neo_types::{
serde_with_utils::{deserialize_script_hash, serialize_script_hash},
ScriptHash, StackItem, WhitelistedContract,
},
ScriptHashExtension,
};
/// Client-side handle for the native Policy contract.
///
/// Only the script hash takes part in (de)serialization. The borrowed RPC client is skipped by
/// serde, so a deserialized value always has `provider == None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyContract<'a, P: JsonRpcProvider> {
    /// Script hash of the contract, encoded with the project's script-hash serde helpers.
    #[serde(
        deserialize_with = "deserialize_script_hash",
        serialize_with = "serialize_script_hash"
    )]
    script_hash: ScriptHash,
    /// RPC client used to perform calls and build invocations, if one was supplied.
    #[serde(skip)]
    provider: Option<&'a RpcClient<P>>,
}
impl<'a, P: JsonRpcProvider + 'static> PolicyContract<'a, P> {
pub const NAME: &'static str = "PolicyContract";
pub const DEFAULT_EXEC_FEE_FACTOR: u32 = 30;
pub const DEFAULT_STORAGE_PRICE: u32 = 100000;
pub const DEFAULT_FEE_PER_BYTE: u32 = 1000;
pub const DEFAULT_ATTRIBUTE_FEE: u32 = 0;
pub const DEFAULT_NOTARY_ASSISTED_ATTRIBUTE_FEE: u32 = 10_000_000;
pub const MAX_EXEC_FEE_FACTOR: u64 = 100;
pub const MAX_ATTRIBUTE_FEE: u32 = 10_0000_0000;
pub const MAX_STORAGE_PRICE: u32 = 10_000_000;
pub const MAX_MILLISECONDS_PER_BLOCK: u32 = 30_000;
pub const MAX_MAX_VALID_UNTIL_BLOCK_INCREMENT: u32 = 86400;
pub const MAX_MAX_TRACEABLE_BLOCKS: u32 = 2_102_400;
pub fn new(provider: Option<&'a RpcClient<P>>) -> Self {
Self { script_hash: Self::calc_native_contract_hash_unchecked(Self::NAME), provider }
}
pub async fn get_fee_per_byte(&self) -> Result<i64, ContractError> {
Ok(self.call_function_returning_int("getFeePerByte", vec![]).await? as i64)
}
pub async fn get_exec_fee_factor(&self) -> Result<u32, ContractError> {
Ok(self.call_function_returning_int("getExecFeeFactor", vec![]).await? as u32)
}
pub async fn get_exec_pico_fee_factor(&self) -> Result<i64, ContractError> {
Ok(self.call_function_returning_int("getExecPicoFeeFactor", vec![]).await? as i64)
}
pub async fn get_storage_price(&self) -> Result<u32, ContractError> {
Ok(self.call_function_returning_int("getStoragePrice", vec![]).await? as u32)
}
pub async fn get_milliseconds_per_block(&self) -> Result<u32, ContractError> {
Ok(self.call_function_returning_int("getMillisecondsPerBlock", vec![]).await? as u32)
}
pub async fn set_milliseconds_per_block(
&self,
value: u32,
) -> Result<TransactionBuilder<'_, P>, ContractError> {
self.invoke_function("setMillisecondsPerBlock", vec![value.into()]).await
}
pub async fn get_max_valid_until_block_increment(&self) -> Result<u32, ContractError> {
Ok(self
.call_function_returning_int("getMaxValidUntilBlockIncrement", vec![])
.await? as u32)
}
pub async fn set_max_valid_until_block_increment(
&self,
value: u32,
) -> Result<TransactionBuilder<'_, P>, ContractError> {
self.invoke_function("setMaxValidUntilBlockIncrement", vec![value.into()]).await
}
pub async fn get_max_traceable_blocks(&self) -> Result<u32, ContractError> {
Ok(self.call_function_returning_int("getMaxTraceableBlocks", vec![]).await? as u32)
}
pub async fn set_max_traceable_blocks(
&self,
value: u32,
) -> Result<TransactionBuilder<'_, P>, ContractError> {
self.invoke_function("setMaxTraceableBlocks", vec![value.into()]).await
}
pub async fn get_attribute_fee(&self, attribute_type: u8) -> Result<u32, ContractError> {
Ok(self
.call_function_returning_int("getAttributeFee", vec![attribute_type.into()])
.await? as u32)
}
pub async fn set_attribute_fee(
&self,
attribute_type: u8,
value: u32,
) -> Result<TransactionBuilder<'_, P>, ContractError> {
self.invoke_function("setAttributeFee", vec![attribute_type.into(), value.into()])
.await
}
pub async fn is_blocked(&self, script_hash: &H160) -> Result<bool, ContractError> {
self.call_function_returning_bool("isBlocked", vec![script_hash.into()]).await
}
pub async fn get_blocked_accounts(&self) -> Result<NeoIterator<'_, H160, P>, ContractError> {
self.call_function_returning_iterator(
"getBlockedAccounts",
vec![],
Arc::new(|item: StackItem| {
item.as_hash160()
.ok_or_else(|| ContractError::UnexpectedReturnType("Hash160".to_string()))
}),
)
.await
}
pub async fn get_blocked_accounts_all(&self) -> Result<Vec<H160>, ContractError> {
self.get_blocked_accounts_all_with_batch(
<Self as SmartContractTrait>::DEFAULT_ITERATOR_COUNT,
)
.await
}
pub async fn get_blocked_accounts_all_with_batch(
&self,
batch_size: usize,
) -> Result<Vec<H160>, ContractError> {
let iterator = self.get_blocked_accounts().await?;
self.collect_all(iterator, batch_size).await
}
pub async fn block_account(
&self,
account: &H160,
) -> Result<TransactionBuilder<'_, P>, ContractError> {
self.invoke_function("blockAccount", vec![account.into()]).await
}
pub async fn block_account_address(
&self,
address: &str,
) -> Result<TransactionBuilder<'_, P>, ContractError> {
let account = ScriptHash::from_address(address)
.map_err(|_| ContractError::InvalidAccount("Invalid address".to_string()))?;
self.block_account(&account).await
}
pub async fn unblock_account(
&self,
account: &H160,
) -> Result<TransactionBuilder<'_, P>, ContractError> {
self.invoke_function("unblockAccount", vec![account.into()]).await
}
pub async fn unblock_account_address(
&self,
address: &str,
) -> Result<TransactionBuilder<'_, P>, ContractError> {
let account = ScriptHash::from_address(address)
.map_err(|_| ContractError::InvalidAccount("Invalid address".to_string()))?;
self.unblock_account(&account).await
}
pub async fn set_fee_per_byte(
&self,
fee: i64,
) -> Result<TransactionBuilder<'_, P>, ContractError> {
self.invoke_function("setFeePerByte", vec![fee.into()]).await
}
pub async fn set_exec_fee_factor(
&self,
fee: u64,
) -> Result<TransactionBuilder<'_, P>, ContractError> {
self.invoke_function("setExecFeeFactor", vec![fee.into()]).await
}
pub async fn set_storage_price(
&self,
price: u32,
) -> Result<TransactionBuilder<'_, P>, ContractError> {
self.invoke_function("setStoragePrice", vec![price.into()]).await
}
pub async fn get_whitelist_fee_contracts(
&self,
) -> Result<NeoIterator<'_, WhitelistedContract, P>, ContractError> {
self.call_function_returning_iterator(
"getWhitelistFeeContracts",
vec![],
Arc::new(|item: StackItem| {
WhitelistedContract::from_stack_item(&item).map_err(|err| {
ContractError::UnexpectedReturnType(format!("WhitelistedContract: {err}"))
})
}),
)
.await
}
pub async fn get_whitelist_fee_contracts_all(
&self,
) -> Result<Vec<WhitelistedContract>, ContractError> {
self.get_whitelist_fee_contracts_all_with_batch(
<Self as SmartContractTrait>::DEFAULT_ITERATOR_COUNT,
)
.await
}
pub async fn get_whitelist_fee_contracts_all_with_batch(
&self,
batch_size: usize,
) -> Result<Vec<WhitelistedContract>, ContractError> {
let iterator = self.get_whitelist_fee_contracts().await?;
self.collect_all(iterator, batch_size).await
}
async fn collect_all<T>(
&self,
iterator: NeoIterator<'_, T, P>,
batch_size: usize,
) -> Result<Vec<T>, ContractError> {
if batch_size == 0 {
return Err(ContractError::InvalidArgError(
"Batch size must be greater than zero".to_string(),
));
}
let mut all_items = Vec::new();
let mut traverse_error: Option<ContractError> = None;
loop {
match iterator.traverse(batch_size as i32).await {
Ok(batch) => {
if batch.is_empty() {
break;
}
all_items.extend(batch);
},
Err(err) => {
traverse_error = Some(err);
break;
},
}
}
if let Some(err) = traverse_error {
let _ = iterator.terminate_session().await;
return Err(err);
}
iterator.terminate_session().await?;
Ok(all_items)
}
pub async fn set_whitelist_fee_contract(
&self,
contract_hash: &H160,
method: &str,
arg_count: i32,
fixed_fee: i64,
) -> Result<TransactionBuilder<'_, P>, ContractError> {
self.invoke_function(
"setWhitelistFeeContract",
vec![contract_hash.into(), method.into(), arg_count.into(), fixed_fee.into()],
)
.await
}
pub async fn remove_whitelist_fee_contract(
&self,
contract_hash: &H160,
method: &str,
arg_count: i32,
) -> Result<TransactionBuilder<'_, P>, ContractError> {
self.invoke_function(
"removeWhitelistFeeContract",
vec![contract_hash.into(), method.into(), arg_count.into()],
)
.await
}
pub async fn recover_fund(
&self,
account: &H160,
token: &H160,
) -> Result<TransactionBuilder<'_, P>, ContractError> {
self.invoke_function("recoverFund", vec![account.into(), token.into()]).await
}
}
/// Plugs [`PolicyContract`] into the generic [`SmartContractTrait`] helpers by exposing its
/// stored script hash and optional RPC client.
#[async_trait]
impl<'a, P: JsonRpcProvider> SmartContractTrait<'a> for PolicyContract<'a, P> {
    /// Returns the stored script hash.
    fn script_hash(&self) -> H160 {
        self.script_hash
    }

    /// Overwrites the stored script hash with `hash`.
    fn set_script_hash(&mut self, hash: H160) {
        self.script_hash = hash;
    }

    /// Returns the RPC client supplied at construction, if any.
    fn provider(&self) -> Option<&RpcClient<P>> {
        self.provider
    }

    type P = P;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        neo_clients::{MockProvider, RpcClient},
        neo_types::{InvocationResult, NeoVMStateType, StackItem},
    };
    use serde_json::json;

    /// Builds a HALT invocation result whose only stack item is an `IIterator` interop
    /// interface with `iterator_id`, tagged with `session_id`.
    fn iterator_invocation_result(session_id: &str, iterator_id: &str) -> InvocationResult {
        InvocationResult::new(
            String::new(),
            NeoVMStateType::Halt,
            "0".to_string(),
            None,
            None,
            None,
            vec![StackItem::InteropInterface {
                id: iterator_id.to_string(),
                interface: "IIterator".to_string(),
            }],
            None,
            None,
            Some(session_id.to_string()),
        )
    }

    #[test]
    fn test_policy_contract_constants() {
        assert_eq!(PolicyContract::<MockProvider>::DEFAULT_EXEC_FEE_FACTOR, 30);
        assert_eq!(PolicyContract::<MockProvider>::DEFAULT_STORAGE_PRICE, 100000);
        assert_eq!(PolicyContract::<MockProvider>::DEFAULT_FEE_PER_BYTE, 1000);
        assert_eq!(PolicyContract::<MockProvider>::DEFAULT_ATTRIBUTE_FEE, 0);
        assert_eq!(
            PolicyContract::<MockProvider>::DEFAULT_NOTARY_ASSISTED_ATTRIBUTE_FEE,
            10_000_000
        );
        assert_eq!(PolicyContract::<MockProvider>::MAX_EXEC_FEE_FACTOR, 100);
        assert_eq!(PolicyContract::<MockProvider>::MAX_ATTRIBUTE_FEE, 1_000_000_000);
        assert_eq!(PolicyContract::<MockProvider>::MAX_STORAGE_PRICE, 10_000_000);
        assert_eq!(PolicyContract::<MockProvider>::MAX_MILLISECONDS_PER_BLOCK, 30_000);
        assert_eq!(PolicyContract::<MockProvider>::MAX_MAX_VALID_UNTIL_BLOCK_INCREMENT, 86400);
        assert_eq!(PolicyContract::<MockProvider>::MAX_MAX_TRACEABLE_BLOCKS, 2_102_400);
    }

    #[test]
    fn test_policy_contract_name() {
        assert_eq!(PolicyContract::<MockProvider>::NAME, "PolicyContract");
    }

    #[test]
    fn test_new_without_provider() {
        let contract = PolicyContract::<MockProvider>::new(None);
        assert!(contract.provider().is_none());
    }

    #[tokio::test]
    async fn test_get_blocked_accounts_iterator_rejects_invalid_items() {
        let provider = MockProvider::new();
        let client = RpcClient::new(provider.clone());
        let contract = PolicyContract::new(Some(&client));
        let hash = contract.script_hash();
        provider.push_result_with_params(
            "invokefunction",
            json!([hash.to_hex(), "getBlockedAccounts", [], []]),
            serde_json::to_value(iterator_invocation_result("session-1", "iter-1")).unwrap(),
        );
        provider.push_result_with_params(
            "traverseiterator",
            json!(["session-1", "iter-1", 1]),
            serde_json::to_value(vec![StackItem::Integer { value: 7 }]).unwrap(),
        );
        let iterator = contract.get_blocked_accounts().await.unwrap();
        let result = iterator.traverse(1).await;
        assert!(matches!(
            result,
            Err(ContractError::UnexpectedReturnType(message))
                if message.contains("Hash160")
        ));
    }

    #[tokio::test]
    async fn test_get_whitelist_fee_contracts_iterator_rejects_invalid_items() {
        let provider = MockProvider::new();
        let client = RpcClient::new(provider.clone());
        let contract = PolicyContract::new(Some(&client));
        let hash = contract.script_hash();
        provider.push_result_with_params(
            "invokefunction",
            json!([hash.to_hex(), "getWhitelistFeeContracts", [], []]),
            serde_json::to_value(iterator_invocation_result("session-2", "iter-2")).unwrap(),
        );
        provider.push_result_with_params(
            "traverseiterator",
            json!(["session-2", "iter-2", 1]),
            serde_json::to_value(vec![StackItem::Integer { value: 11 }]).unwrap(),
        );
        let iterator = contract.get_whitelist_fee_contracts().await.unwrap();
        let result = iterator.traverse(1).await;
        assert!(matches!(
            result,
            Err(ContractError::UnexpectedReturnType(message))
                if message.contains("WhitelistedContract")
        ));
    }
}