use crate::{
ids::{MessageId, ReservationId, prelude::*},
message::IncomingDispatch,
};
use alloc::{collections::BTreeMap, format};
use gear_core_errors::ReservationError;
use scale_decode::DecodeAsType;
use scale_encode::EncodeAsType;
use scale_info::{
TypeInfo,
scale::{Decode, Encode},
};
/// Frozen, storable value of a [`GasReserver`]'s reservation nonce.
///
/// [`GasReserver::nonce`] produces it, and [`GasReserver::new`] resumes counting from the
/// value found in the incoming dispatch's context (if the dispatch has a context).
#[derive(Clone, Copy, Default, Debug)]
#[derive(Eq, Hash, Ord, PartialEq, PartialOrd)]
#[derive(Decode, DecodeAsType, Encode, EncodeAsType, TypeInfo)]
#[cfg_attr(feature = "std", derive(serde::Serialize, serde::Deserialize))]
pub struct ReservationNonce(u64);
impl From<&InnerNonce> for ReservationNonce {
fn from(nonce: &InnerNonce) -> Self {
ReservationNonce(nonce.0)
}
}
/// Mutable, internal counterpart of [`ReservationNonce`] owned by [`GasReserver`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Encode, EncodeAsType, Decode, DecodeAsType)]
struct InnerNonce(u64);
impl InnerNonce {
    /// Returns the current counter value and advances the counter by one.
    ///
    /// The increment saturates at `u64::MAX` instead of wrapping, so once the counter
    /// reaches that value every further call returns `u64::MAX`.
    fn fetch_inc(&mut self) -> u64 {
        let next = self.0.saturating_add(1);
        core::mem::replace(&mut self.0, next)
    }
}
impl From<ReservationNonce> for InnerNonce {
fn from(frozen_nonce: ReservationNonce) -> Self {
InnerNonce(frozen_nonce.0)
}
}
/// Tracks the gas reservations created, used and removed while one message executes.
///
/// Field order is significant: the type derives SCALE `Encode`/`Decode`, which encode
/// fields in declaration order.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Encode, Decode)]
pub struct GasReserver {
    /// Id of the executing message; together with `nonce` it seeds new reservation ids.
    message_id: MessageId,
    /// Counter advanced by every `reserve` call that passes the limit check.
    nonce: InnerNonce,
    /// Every reservation known in this execution, including `Removed` markers.
    states: GasReservationStates,
    /// Limit consulted by `reserve` (via `check_execution_limit`).
    max_reservations: u64,
}
impl GasReserver {
    /// Creates a reserver for the message carried by `incoming_dispatch`.
    ///
    /// * `incoming_dispatch` — its id seeds new reservation ids. If it carries a context,
    ///   the reservation nonce is resumed from that context; otherwise the default nonce
    ///   is used.
    /// * `map` — reservations that already exist; each becomes an unused
    ///   [`GasReservationState::Exists`] entry.
    /// * `max_reservations` — limit checked by [`Self::reserve`].
    pub fn new(
        incoming_dispatch: &IncomingDispatch,
        map: GasReservationMap,
        max_reservations: u64,
    ) -> Self {
        let message_id = incoming_dispatch.id();
        // Ids are generated from (message id, nonce), and `reserve` treats a duplicate id as
        // unreachable, so counting must resume from the stored nonce rather than restart.
        let nonce = incoming_dispatch
            .context()
            .as_ref()
            .map(|c| c.reservation_nonce())
            .unwrap_or_default()
            .into();
        let states: GasReservationStates = map
            .into_iter()
            .map(|(id, slot)| (id, GasReservationState::from(slot)))
            .collect();

        Self {
            message_id,
            nonce,
            states,
            max_reservations,
        }
    }

    /// Returns `true` when no reservation state is tracked at all (`Removed` markers
    /// count as tracked).
    pub fn is_empty(&self) -> bool {
        self.states.is_empty()
    }

    /// Fails with [`ReservationError::ReservationsLimitReached`] when the number of live
    /// reservations (`Exists` or `Created`, used or not) exceeds `max_reservations`.
    ///
    /// NOTE(review): the comparison is `>` and `reserve` runs this check *before* inserting
    /// the new entry, so up to `max_reservations + 1` live reservations can exist. The
    /// `max_reservations_limit_works` test pins this; switching to `>=` is an observable
    /// behavior change and needs a coordinated test update.
    fn check_execution_limit(&self) -> Result<(), ReservationError> {
        let current_reservations = self
            .states
            .values()
            .map(|state| {
                matches!(
                    state,
                    GasReservationState::Exists { .. } | GasReservationState::Created { .. }
                ) as u64
            })
            .sum::<u64>();

        if current_reservations > self.max_reservations {
            Err(ReservationError::ReservationsLimitReached)
        } else {
            Ok(())
        }
    }

    /// Returns the reserved amount for `reservation_id` if it is `Exists` or `Created`
    /// (whether or not it was used), and `None` if it is unknown or `Removed`.
    pub fn limit_of(&self, reservation_id: &ReservationId) -> Option<u64> {
        match self.states.get(reservation_id)? {
            GasReservationState::Exists { amount, .. }
            | GasReservationState::Created { amount, .. } => Some(*amount),
            GasReservationState::Removed { .. } => None,
        }
    }

    /// Creates a new unused reservation of `amount` gas for `duration` and returns its id.
    ///
    /// # Errors
    /// [`ReservationError::ReservationsLimitReached`] if the limit check fails; in that
    /// case the nonce is not advanced.
    ///
    /// # Panics
    /// If the generated id is already present in the map. This should be unreachable
    /// while the nonce keeps producing fresh values.
    pub fn reserve(
        &mut self,
        amount: u64,
        duration: u32,
    ) -> Result<ReservationId, ReservationError> {
        self.check_execution_limit()?;

        // Keep the nonce actually used for this id so the diagnostic below reports it;
        // after `fetch_inc` the field already holds the *next* value.
        let nonce = self.nonce.fetch_inc();
        let id = ReservationId::generate(self.message_id, nonce);

        let previous = self.states.insert(
            id,
            GasReservationState::Created {
                amount,
                duration,
                used: false,
            },
        );
        if previous.is_some() {
            let err_msg = format!(
                "GasReserver::reserve: created a duplicate reservation. \
                Message id - {message_id}, nonce - {nonce}",
                message_id = self.message_id,
            );
            log::error!("{err_msg}");
            unreachable!("{err_msg}");
        }

        Ok(id)
    }

    /// Removes an unused reservation and returns its amount.
    ///
    /// * A pre-existing (`Exists`) reservation is replaced by a `Removed` marker whose
    ///   `expiration` is the reservation's `finish`; no reimbursement is returned.
    /// * A reservation created in this execution (`Created`) is dropped entirely, and its
    ///   requested duration is returned as an [`UnreservedReimbursement`].
    ///
    /// # Errors
    /// [`ReservationError::InvalidReservationId`] if the id is unknown, already used, or
    /// already unreserved.
    pub fn unreserve(
        &mut self,
        id: ReservationId,
    ) -> Result<(u64, Option<UnreservedReimbursement>), ReservationError> {
        // `GasReservationState` is `Copy`: take a snapshot so the map can be mutated below.
        let state = *self
            .states
            .get(&id)
            .ok_or(ReservationError::InvalidReservationId)?;

        match state {
            GasReservationState::Exists {
                amount,
                finish,
                used: false,
                ..
            } => {
                self.states
                    .insert(id, GasReservationState::Removed { expiration: finish });
                Ok((amount, None))
            }
            GasReservationState::Created {
                amount,
                duration,
                used: false,
            } => {
                self.states.remove(&id);
                Ok((amount, Some(UnreservedReimbursement(duration))))
            }
            // Used reservations cannot be unreserved, and removed ones cannot be removed twice.
            GasReservationState::Exists { used: true, .. }
            | GasReservationState::Created { used: true, .. }
            | GasReservationState::Removed { .. } => Err(ReservationError::InvalidReservationId),
        }
    }

    /// Marks the reservation `id` as used.
    ///
    /// # Errors
    /// [`ReservationError::InvalidReservationId`] under the same conditions as
    /// [`Self::check_not_used`].
    pub fn mark_used(&mut self, id: ReservationId) -> Result<(), ReservationError> {
        let used = self.check_not_used(id)?;
        *used = true;
        Ok(())
    }

    /// Returns a mutable handle to the `used` flag of reservation `id` if it is `Exists`
    /// or `Created` and not yet used.
    ///
    /// # Errors
    /// [`ReservationError::InvalidReservationId`] if the id is unknown, `Removed`, or
    /// already used.
    pub fn check_not_used(&mut self, id: ReservationId) -> Result<&mut bool, ReservationError> {
        if let Some(
            GasReservationState::Created { used, .. } | GasReservationState::Exists { used, .. },
        ) = self.states.get_mut(&id)
        {
            if *used {
                Err(ReservationError::InvalidReservationId)
            } else {
                Ok(used)
            }
        } else {
            Err(ReservationError::InvalidReservationId)
        }
    }

    /// Returns the nonce the next successful `reserve` call will use.
    pub fn nonce(&self) -> ReservationNonce {
        (&self.nonce).into()
    }

    /// Returns every tracked reservation state, including `Removed` markers.
    pub fn states(&self) -> &GasReservationStates {
        &self.states
    }

    /// Consumes the reserver and returns the reservations that remain.
    ///
    /// * `Exists` entries keep their original `start`/`finish`.
    /// * `Created` entries start at `current_block_height` and finish at
    ///   `duration_into_expiration(duration)`.
    /// * `Removed` entries are dropped.
    pub fn into_map<F>(
        self,
        current_block_height: u32,
        duration_into_expiration: F,
    ) -> GasReservationMap
    where
        F: Fn(u32) -> u32,
    {
        self.states
            .into_iter()
            .filter_map(|(id, state)| {
                let slot = match state {
                    GasReservationState::Exists {
                        amount,
                        start,
                        finish,
                        ..
                    } => GasReservationSlot {
                        amount,
                        start,
                        finish,
                    },
                    GasReservationState::Created {
                        amount, duration, ..
                    } => GasReservationSlot {
                        amount,
                        start: current_block_height,
                        finish: duration_into_expiration(duration),
                    },
                    GasReservationState::Removed { .. } => return None,
                };
                Some((id, slot))
            })
            .collect()
    }
}
/// Returned by [`GasReserver::unreserve`] when the removed reservation was created during
/// the current execution; it carries the duration originally requested for it.
#[derive(Debug, PartialEq, Eq)]
pub struct UnreservedReimbursement(u32);

impl UnreservedReimbursement {
    /// Duration that was requested when the reservation was created.
    pub fn duration(&self) -> u32 {
        let Self(duration) = self;
        *duration
    }
}
/// Reservation states tracked by a [`GasReserver`], keyed by reservation id.
pub type GasReservationStates = BTreeMap<ReservationId, GasReservationState>;

/// State of one gas reservation while a message executes.
///
/// Variant and field order must stay as declared, because the SCALE `Encode`/`Decode`
/// derives encode in declaration order.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[derive(Encode, EncodeAsType, Decode, DecodeAsType)]
pub enum GasReservationState {
    /// The reservation already existed when the reserver was built (see the
    /// `From<GasReservationSlot>` conversion).
    Exists {
        /// Reserved gas.
        amount: u64,
        /// `start` carried over from the originating slot.
        start: u32,
        /// `finish` carried over from the originating slot; it becomes the `expiration`
        /// of the `Removed` marker when the reservation is unreserved.
        finish: u32,
        /// Whether the reservation has been marked as used.
        used: bool,
    },
    /// The reservation was created by `GasReserver::reserve` during this execution.
    Created {
        /// Reserved gas.
        amount: u64,
        /// Requested duration; `GasReserver::into_map` turns it into a `finish` value.
        duration: u32,
        /// Whether the reservation has been marked as used.
        used: bool,
    },
    /// A pre-existing reservation that was unreserved during this execution.
    Removed {
        /// `finish` of the reservation that was removed.
        expiration: u32,
    },
}
impl From<GasReservationSlot> for GasReservationState {
fn from(slot: GasReservationSlot) -> Self {
Self::Exists {
amount: slot.amount,
start: slot.start,
finish: slot.finish,
used: false,
}
}
}
/// Map form of reservations: consumed by [`GasReserver::new`] and produced by
/// [`GasReserver::into_map`].
pub type GasReservationMap = BTreeMap<ReservationId, GasReservationSlot>;

/// A reservation's gas amount and its `start`/`finish` bounds.
///
/// Field order is significant for the SCALE `Encode`/`Decode` derives.
#[derive(Debug, Clone, Eq, PartialEq)]
#[derive(Encode, EncodeAsType, Decode, DecodeAsType, TypeInfo)]
pub struct GasReservationSlot {
    /// Amount of gas reserved.
    pub amount: u64,
    /// For reservations created by a `GasReserver`, the block height passed to `into_map`.
    pub start: u32,
    /// For reservations created by a `GasReserver`, the value returned by `into_map`'s
    /// `duration_into_expiration` callback.
    pub finish: u32,
}
#[cfg(test)]
mod tests {
    use super::*;

    // Limit handed to every reserver built by `new_reserver`.
    const MAX_RESERVATIONS: u64 = 256;

    /// Builds a reserver for a default dispatch with no pre-existing reservations.
    fn new_reserver() -> GasReserver {
        let d = IncomingDispatch::default();
        GasReserver::new(&d, Default::default(), MAX_RESERVATIONS)
    }

    /// Calls `reserve` `MAX_RESERVATIONS * 10` times. Iterations `0..=MAX_RESERVATIONS`
    /// succeed, and every later one fails with `ReservationsLimitReached`.
    ///
    /// NOTE(review): this pins `MAX_RESERVATIONS + 1` successful reservations, because the
    /// limit check compares with `>` before the new entry is inserted.
    #[test]
    fn max_reservations_limit_works() {
        let mut reserver = new_reserver();
        for n in 0..(MAX_RESERVATIONS * 10) {
            let res = reserver.reserve(100, 10);
            if n > MAX_RESERVATIONS {
                assert_eq!(res, Err(ReservationError::ReservationsLimitReached));
            } else {
                assert!(res.is_ok());
            }
        }
    }

    /// A freshly created reservation that was unreserved is gone, so it cannot be marked used.
    #[test]
    fn mark_used_for_unreserved_fails() {
        let mut reserver = new_reserver();
        let id = reserver.reserve(1, 1).unwrap();
        reserver.unreserve(id).unwrap();
        assert_eq!(
            reserver.mark_used(id),
            Err(ReservationError::InvalidReservationId)
        );
    }

    /// Marking the same reservation used twice fails, and so does marking an unknown id.
    #[test]
    fn mark_used_twice_fails() {
        let mut reserver = new_reserver();
        let id = reserver.reserve(1, 1).unwrap();
        reserver.mark_used(id).unwrap();
        assert_eq!(
            reserver.mark_used(id),
            Err(ReservationError::InvalidReservationId)
        );
        assert_eq!(
            reserver.mark_used(ReservationId::default()),
            Err(ReservationError::InvalidReservationId)
        );
    }

    /// Unreserving a freshly created reservation twice fails the second time.
    #[test]
    fn remove_reservation_twice_fails() {
        let mut reserver = new_reserver();
        let id = reserver.reserve(1, 1).unwrap();
        reserver.unreserve(id).unwrap();
        assert_eq!(
            reserver.unreserve(id),
            Err(ReservationError::InvalidReservationId)
        );
    }

    /// Unreserves a *pre-existing* reservation twice; the second call hits the `Removed`
    /// marker and fails.
    ///
    /// NOTE(review): despite its name, this test does not exercise a non-existing id.
    #[test]
    fn remove_non_existing_reservation_fails() {
        let id = ReservationId::from([0xff; 32]);
        let mut map = GasReservationMap::new();
        map.insert(
            id,
            GasReservationSlot {
                amount: 1,
                start: 1,
                finish: 100,
            },
        );
        let mut reserver = GasReserver::new(&Default::default(), map, 256);
        reserver.unreserve(id).unwrap();
        assert_eq!(
            reserver.unreserve(id),
            Err(ReservationError::InvalidReservationId)
        );
    }

    /// A freshly created reservation that has been marked used cannot be unreserved.
    #[test]
    fn fresh_reserve_unreserve() {
        let mut reserver = new_reserver();
        let id = reserver.reserve(10_000, 5).unwrap();
        reserver.mark_used(id).unwrap();
        assert_eq!(
            reserver.unreserve(id),
            Err(ReservationError::InvalidReservationId)
        );
    }

    /// A pre-existing reservation that has been marked used cannot be unreserved.
    #[test]
    fn existing_reserve_unreserve() {
        let id = ReservationId::from([0xff; 32]);
        let mut map = GasReservationMap::new();
        map.insert(
            id,
            GasReservationSlot {
                amount: 1,
                start: 1,
                finish: 100,
            },
        );
        let mut reserver = GasReserver::new(&Default::default(), map, 256);
        reserver.mark_used(id).unwrap();
        assert_eq!(
            reserver.unreserve(id),
            Err(ReservationError::InvalidReservationId)
        );
    }

    /// Unreserving a pre-existing reservation returns its amount and leaves a `Removed`
    /// marker whose expiration is the slot's `finish`; a second unreserve fails.
    #[test]
    fn unreserving_unreserved() {
        let id = ReservationId::from([0xff; 32]);
        let slot = GasReservationSlot {
            amount: 1,
            start: 2,
            finish: 3,
        };
        let mut map = GasReservationMap::new();
        map.insert(id, slot.clone());
        let mut reserver = GasReserver::new(&Default::default(), map, 256);
        let (amount, _) = reserver.unreserve(id).expect("Shouldn't fail");
        assert_eq!(amount, slot.amount);
        assert!(reserver.unreserve(id).is_err());
        assert_eq!(
            reserver.states().get(&id).cloned(),
            Some(GasReservationState::Removed {
                expiration: slot.finish
            })
        );
    }
}