use std::{cell::RefCell, collections::BTreeMap, fmt, rc::Rc};
use num_rational::Ratio;
use thiserror::Error;
use casper_hashing::Digest;
use casper_types::{
bytesrepr::{self, ToBytes},
contracts::NamedKeys,
system::{handle_payment::ACCUMULATION_PURSE_KEY, SystemContractType},
AccessRights, CLValue, CLValueError, Contract, ContractHash, EraId, Key, Phase,
ProtocolVersion, StoredValue, U512,
};
use crate::{
core::{
engine_state::{execution_effect::ExecutionEffect, ChainspecRegistry},
execution::AddressGenerator,
tracking_copy::TrackingCopy,
},
shared::newtypes::CorrelationId,
storage::global_state::StateProvider,
};
use super::{engine_config::FeeHandling, EngineConfig};
/// Result of a protocol upgrade that completed without error.
#[derive(Debug, Clone)]
pub struct UpgradeSuccess {
    /// The post-upgrade state hash.
    pub post_state_hash: Digest,
    /// The effects recorded while applying the upgrade.
    pub execution_effect: ExecutionEffect,
}
impl fmt::Display for UpgradeSuccess {
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
write!(
f,
"Success: {} {:?}",
self.post_state_hash, self.execution_effect
)
}
}
/// Parameters describing a single protocol upgrade.
///
/// The `new_*` options carry replacement values for individual settings; a `None`
/// presumably means "leave the current value unchanged" (TODO confirm against the consumer
/// of this config).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UpgradeConfig {
    /// State hash the upgrade starts from.
    pre_state_hash: Digest,
    /// Protocol version in force before the upgrade.
    current_protocol_version: ProtocolVersion,
    /// Protocol version the upgrade moves to.
    new_protocol_version: ProtocolVersion,
    /// Optional era associated with the upgrade.
    activation_point: Option<EraId>,
    /// Optional replacement for the validator slot count.
    new_validator_slots: Option<u32>,
    /// Optional replacement for the auction delay.
    new_auction_delay: Option<u64>,
    /// Optional replacement for the locked-funds period, in milliseconds.
    new_locked_funds_period_millis: Option<u64>,
    /// Optional replacement for the round seigniorage rate.
    new_round_seigniorage_rate: Option<Ratio<u64>>,
    /// Optional replacement for the unbonding delay.
    new_unbonding_delay: Option<u64>,
    /// Explicit key/value pairs to apply to global state as part of the upgrade.
    global_state_update: BTreeMap<Key, StoredValue>,
    /// Chainspec registry associated with the upgrade.
    chainspec_registry: ChainspecRegistry,
}
impl UpgradeConfig {
#[allow(clippy::too_many_arguments)]
pub fn new(
pre_state_hash: Digest,
current_protocol_version: ProtocolVersion,
new_protocol_version: ProtocolVersion,
activation_point: Option<EraId>,
new_validator_slots: Option<u32>,
new_auction_delay: Option<u64>,
new_locked_funds_period_millis: Option<u64>,
new_round_seigniorage_rate: Option<Ratio<u64>>,
new_unbonding_delay: Option<u64>,
global_state_update: BTreeMap<Key, StoredValue>,
chainspec_registry: ChainspecRegistry,
) -> Self {
UpgradeConfig {
pre_state_hash,
current_protocol_version,
new_protocol_version,
activation_point,
new_validator_slots,
new_auction_delay,
new_locked_funds_period_millis,
new_round_seigniorage_rate,
new_unbonding_delay,
global_state_update,
chainspec_registry,
}
}
pub fn pre_state_hash(&self) -> Digest {
self.pre_state_hash
}
pub fn current_protocol_version(&self) -> ProtocolVersion {
self.current_protocol_version
}
pub fn new_protocol_version(&self) -> ProtocolVersion {
self.new_protocol_version
}
pub fn activation_point(&self) -> Option<EraId> {
self.activation_point
}
pub fn new_validator_slots(&self) -> Option<u32> {
self.new_validator_slots
}
pub fn new_auction_delay(&self) -> Option<u64> {
self.new_auction_delay
}
pub fn new_locked_funds_period_millis(&self) -> Option<u64> {
self.new_locked_funds_period_millis
}
pub fn new_round_seigniorage_rate(&self) -> Option<Ratio<u64>> {
self.new_round_seigniorage_rate
}
pub fn new_unbonding_delay(&self) -> Option<u64> {
self.new_unbonding_delay
}
pub fn global_state_update(&self) -> &BTreeMap<Key, StoredValue> {
&self.global_state_update
}
pub fn chainspec_registry(&self) -> &ChainspecRegistry {
&self.chainspec_registry
}
pub fn with_pre_state_hash(&mut self, pre_state_hash: Digest) {
self.pre_state_hash = pre_state_hash;
}
}
#[derive(Clone, Error, Debug)]
pub enum ProtocolUpgradeError {
#[error("Invalid upgrade config")]
InvalidUpgradeConfig,
#[error("Unable to retrieve system contract: {0}")]
UnableToRetrieveSystemContract(String),
#[error("Unable to retrieve system contract package: {0}")]
UnableToRetrieveSystemContractPackage(String),
#[error("Failed to disable previous version of system contract: {0}")]
FailedToDisablePreviousVersion(String),
#[error("{0}")]
Bytesrepr(bytesrepr::Error),
#[error("{0}")]
CLValue(CLValueError),
#[error("Failed to insert system contract registry")]
FailedToCreateSystemRegistry,
#[error("Unexpected stored value variant")]
UnexpectedStoredValueVariant,
}
impl From<bytesrepr::Error> for ProtocolUpgradeError {
    /// Wraps a serialization error as `ProtocolUpgradeError::Bytesrepr`, enabling `?`.
    fn from(error: bytesrepr::Error) -> Self {
        Self::Bytesrepr(error)
    }
}
impl From<CLValueError> for ProtocolUpgradeError {
fn from(v: CLValueError) -> Self {
Self::CLValue(v)
}
}
/// Applies the system-contract changes required when moving between two protocol versions.
///
/// All global-state reads and writes go through the shared `tracking_copy`.
pub(crate) struct SystemUpgrader<S: StateProvider> {
    /// Protocol version being upgraded to.
    new_protocol_version: ProtocolVersion,
    /// Protocol version being upgraded from.
    old_protocol_version: ProtocolVersion,
    /// Shared tracking copy over the state provider's reader.
    tracking_copy: Rc<RefCell<TrackingCopy<<S as StateProvider>::Reader>>>,
}
impl<S> SystemUpgrader<S>
where
S: StateProvider,
{
pub(crate) fn new(
new_protocol_version: ProtocolVersion,
old_protocol_version: ProtocolVersion,
tracking_copy: Rc<RefCell<TrackingCopy<<S as StateProvider>::Reader>>>,
) -> Self {
SystemUpgrader {
new_protocol_version,
old_protocol_version,
tracking_copy,
}
}
pub(crate) fn refresh_system_contracts(
&self,
correlation_id: CorrelationId,
mint_hash: &ContractHash,
auction_hash: &ContractHash,
handle_payment_hash: &ContractHash,
standard_payment_hash: &ContractHash,
) -> Result<(), ProtocolUpgradeError> {
self.refresh_system_contract_entry_points(
correlation_id,
*mint_hash,
SystemContractType::Mint,
)?;
self.refresh_system_contract_entry_points(
correlation_id,
*auction_hash,
SystemContractType::Auction,
)?;
self.refresh_system_contract_entry_points(
correlation_id,
*handle_payment_hash,
SystemContractType::HandlePayment,
)?;
self.refresh_system_contract_entry_points(
correlation_id,
*standard_payment_hash,
SystemContractType::StandardPayment,
)?;
Ok(())
}
fn refresh_system_contract_entry_points(
&self,
correlation_id: CorrelationId,
contract_hash: ContractHash,
system_contract_type: SystemContractType,
) -> Result<(), ProtocolUpgradeError> {
let contract_name = system_contract_type.contract_name();
let entry_points = system_contract_type.contract_entry_points();
let mut contract = if let StoredValue::Contract(contract) = self
.tracking_copy
.borrow_mut()
.read(correlation_id, &Key::Hash(contract_hash.value()))
.map_err(|_| {
ProtocolUpgradeError::UnableToRetrieveSystemContract(contract_name.to_string())
})?
.ok_or_else(|| {
ProtocolUpgradeError::UnableToRetrieveSystemContract(contract_name.to_string())
})? {
contract
} else {
return Err(ProtocolUpgradeError::UnableToRetrieveSystemContract(
contract_name,
));
};
let is_major_bump = self
.old_protocol_version
.check_next_version(&self.new_protocol_version)
.is_major_version();
let entry_points_unchanged = *contract.entry_points() == entry_points;
if entry_points_unchanged && !is_major_bump {
return Ok(());
}
let contract_package_key = Key::Hash(contract.contract_package_hash().value());
let mut contract_package = if let StoredValue::ContractPackage(contract_package) = self
.tracking_copy
.borrow_mut()
.read(correlation_id, &contract_package_key)
.map_err(|_| {
ProtocolUpgradeError::UnableToRetrieveSystemContractPackage(
contract_name.to_string(),
)
})?
.ok_or_else(|| {
ProtocolUpgradeError::UnableToRetrieveSystemContractPackage(
contract_name.to_string(),
)
})? {
contract_package
} else {
return Err(ProtocolUpgradeError::UnableToRetrieveSystemContractPackage(
contract_name,
));
};
contract_package
.disable_contract_version(contract_hash)
.map_err(|_| {
ProtocolUpgradeError::FailedToDisablePreviousVersion(contract_name.to_string())
})?;
contract.set_protocol_version(self.new_protocol_version);
let new_contract = Contract::new(
contract.contract_package_hash(),
contract.contract_wasm_hash(),
contract.named_keys().clone(),
entry_points,
self.new_protocol_version,
);
self.tracking_copy
.borrow_mut()
.write(contract_hash.into(), StoredValue::Contract(new_contract));
contract_package
.insert_contract_version(self.new_protocol_version.value().major, contract_hash);
self.tracking_copy.borrow_mut().write(
contract_package_key,
StoredValue::ContractPackage(contract_package),
);
Ok(())
}
pub(crate) fn create_accumulation_purse_if_required(
&self,
correlation_id: CorrelationId,
handle_payment_hash: &ContractHash,
engine_config: &EngineConfig,
) -> Result<(), ProtocolUpgradeError> {
match engine_config.fee_handling() {
FeeHandling::PayToProposer | FeeHandling::Burn => return Ok(()),
FeeHandling::Accumulate => {}
}
let mut address_generator = {
let seed_bytes = (self.old_protocol_version, self.new_protocol_version).to_bytes()?;
let phase = Phase::System;
AddressGenerator::new(&seed_bytes, phase)
};
let system_contract = SystemContractType::HandlePayment;
let contract_name = system_contract.contract_name();
let mut contract = if let StoredValue::Contract(contract) = self
.tracking_copy
.borrow_mut()
.read(correlation_id, &Key::Hash(handle_payment_hash.value()))
.map_err(|_| {
ProtocolUpgradeError::UnableToRetrieveSystemContract(contract_name.to_string())
})?
.ok_or_else(|| {
ProtocolUpgradeError::UnableToRetrieveSystemContract(contract_name.to_string())
})? {
contract
} else {
return Err(ProtocolUpgradeError::UnableToRetrieveSystemContract(
contract_name,
));
};
if !contract.named_keys().contains_key(ACCUMULATION_PURSE_KEY) {
let purse_uref = address_generator.new_uref(AccessRights::READ_ADD_WRITE);
let balance_clvalue = CLValue::from_t(U512::zero())?;
self.tracking_copy.borrow_mut().write(
Key::Balance(purse_uref.addr()),
StoredValue::CLValue(balance_clvalue),
);
self.tracking_copy
.borrow_mut()
.write(Key::URef(purse_uref), StoredValue::CLValue(CLValue::unit()));
let mut new_named_keys = NamedKeys::new();
new_named_keys.insert(ACCUMULATION_PURSE_KEY.into(), Key::from(purse_uref));
contract.named_keys_append(&mut new_named_keys);
self.tracking_copy.borrow_mut().write(
(*handle_payment_hash).into(),
StoredValue::Contract(contract),
);
}
Ok(())
}
}