use crate::{
ids::{MessageId, ReservationId},
message::IncomingDispatch,
};
use alloc::collections::BTreeMap;
use gear_core_errors::ReservationError;
use hashbrown::HashMap;
use scale_info::{
scale::{Decode, Encode},
TypeInfo,
};
/// Public, copyable snapshot of a reservation nonce (a plain `u64` counter value).
///
/// It converts into the private [`InnerNonce`] that [`GasReserver`] increments
/// while generating reservation ids, and is produced back from it by
/// [`GasReserver::nonce`]. The type is SCALE-encodable (`Encode`/`Decode`).
#[derive(
Clone, Copy, Default, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Decode, Encode, TypeInfo,
)]
pub struct ReservationNonce(u64);
impl From<&InnerNonce> for ReservationNonce {
fn from(nonce: &InnerNonce) -> Self {
ReservationNonce(nonce.0)
}
}
/// Mutable nonce counter owned by a `GasReserver`.
///
/// Kept private so the counter can only be advanced through [`InnerNonce::fetch_inc`].
#[derive(Debug, Clone)]
struct InnerNonce(u64);

impl InnerNonce {
    /// Returns the current value and advances the counter by one.
    ///
    /// The increment saturates: once the counter holds `u64::MAX` it stays
    /// there and every further call returns `u64::MAX`.
    fn fetch_inc(&mut self) -> u64 {
        let next = self.0.saturating_add(1);
        // Store the successor and hand back the value that was current on entry.
        core::mem::replace(&mut self.0, next)
    }
}
impl From<ReservationNonce> for InnerNonce {
fn from(frozen_nonce: ReservationNonce) -> Self {
InnerNonce(frozen_nonce.0)
}
}
/// Tracks gas reservations while a single incoming dispatch is being handled.
///
/// It is seeded with the reservations passed to [`GasReserver::new`], records
/// reservations that are created, unreserved or marked as used, and is finally
/// turned back into a [`GasReservationMap`] by [`GasReserver::into_map`].
#[derive(Debug, Clone)]
pub struct GasReserver {
/// Id of the incoming dispatch; combined with `nonce` to generate reservation ids.
message_id: MessageId,
/// Counter advanced once per generated reservation id.
nonce: InnerNonce,
/// Current state of every reservation known to this reserver.
states: GasReservationStates,
/// Limit checked by `reserve` before a new reservation is created.
max_reservations: u64,
}
impl GasReserver {
pub fn new(
incoming_dispatch: &IncomingDispatch,
map: GasReservationMap,
max_reservations: u64,
) -> Self {
let message_id = incoming_dispatch.id();
let nonce = incoming_dispatch
.context()
.as_ref()
.map(|c| c.reservation_nonce())
.unwrap_or_default()
.into();
Self {
message_id,
nonce,
states: {
let mut states = HashMap::with_capacity(max_reservations as usize);
states.extend(map.into_iter().map(|(id, slot)| (id, slot.into())));
states
},
max_reservations,
}
}
fn check_execution_limit(&self) -> Result<(), ReservationError> {
let current_reservations = self
.states
.values()
.map(|state| {
matches!(
state,
GasReservationState::Exists { .. } | GasReservationState::Created { .. }
) as u64
})
.sum::<u64>();
if current_reservations > self.max_reservations {
Err(ReservationError::ReservationsLimitReached)
} else {
Ok(())
}
}
pub fn reserve(
&mut self,
amount: u64,
duration: u32,
) -> Result<ReservationId, ReservationError> {
self.check_execution_limit()?;
let id = ReservationId::generate(self.message_id, self.nonce.fetch_inc());
let maybe_reservation = self.states.insert(
id,
GasReservationState::Created {
amount,
duration,
used: false,
},
);
if maybe_reservation.is_some() {
unreachable!(
"Duplicate reservation was created with message id {} and nonce {}",
self.message_id, self.nonce.0,
);
}
Ok(id)
}
pub fn unreserve(&mut self, id: ReservationId) -> Result<u64, ReservationError> {
let state = self
.states
.get(&id)
.ok_or(ReservationError::InvalidReservationId)?;
if let GasReservationState::Exists { used: true, .. }
| GasReservationState::Created { used: true, .. } = state
{
return Err(ReservationError::InvalidReservationId);
}
let state = self.states.remove(&id).unwrap();
let amount = match state {
GasReservationState::Exists { amount, finish, .. } => {
self.states
.insert(id, GasReservationState::Removed { expiration: finish });
amount
}
GasReservationState::Created { amount, .. } => amount,
GasReservationState::Removed { .. } => {
return Err(ReservationError::InvalidReservationId);
}
};
Ok(amount)
}
pub fn mark_used(&mut self, id: ReservationId) -> Result<(), ReservationError> {
if let Some(
GasReservationState::Created { used, .. } | GasReservationState::Exists { used, .. },
) = self.states.get_mut(&id)
{
if *used {
Err(ReservationError::InvalidReservationId)
} else {
*used = true;
Ok(())
}
} else {
Err(ReservationError::InvalidReservationId)
}
}
pub fn nonce(&self) -> ReservationNonce {
(&self.nonce).into()
}
pub fn states(&self) -> &GasReservationStates {
&self.states
}
pub fn into_map<F>(
self,
current_block_height: u32,
duration_into_expiration: F,
) -> GasReservationMap
where
F: Fn(u32) -> u32,
{
self.states
.into_iter()
.flat_map(|(id, state)| match state {
GasReservationState::Exists {
amount,
start,
finish,
..
} => Some((
id,
GasReservationSlot {
amount,
start,
finish,
},
)),
GasReservationState::Created {
amount, duration, ..
} => {
let expiration = duration_into_expiration(duration);
Some((
id,
GasReservationSlot {
amount,
start: current_block_height,
finish: expiration,
},
))
}
GasReservationState::Removed { .. } => None,
})
.collect()
}
}
/// In-memory table of reservation states kept by [`GasReserver`], keyed by reservation id.
pub type GasReservationStates = HashMap<ReservationId, GasReservationState>;
/// State of a single reservation as tracked by [`GasReserver`].
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum GasReservationState {
/// Reservation loaded from a stored [`GasReservationSlot`]
/// (see the `From<GasReservationSlot>` conversion).
Exists {
/// Reserved amount, copied from the slot.
amount: u64,
/// `start` value copied from the slot.
start: u32,
/// `finish` value copied from the slot.
finish: u32,
/// Set by `GasReserver::mark_used`; a used reservation cannot be unreserved.
used: bool,
},
/// Reservation made through `GasReserver::reserve`.
Created {
/// Reserved amount passed to `reserve`.
amount: u64,
/// Duration passed to `reserve`; turned into a `finish` value by `GasReserver::into_map`.
duration: u32,
/// Set by `GasReserver::mark_used`; a used reservation cannot be unreserved.
used: bool,
},
/// Marker that replaces an `Exists` entry once it is unreserved;
/// `GasReserver::into_map` drops such entries.
Removed {
/// The `finish` value of the reservation that was unreserved.
expiration: u32,
},
}
impl From<GasReservationSlot> for GasReservationState {
fn from(slot: GasReservationSlot) -> Self {
Self::Exists {
amount: slot.amount,
start: slot.start,
finish: slot.finish,
used: false,
}
}
}
/// Ordered map of reservation slots keyed by reservation id; consumed by
/// `GasReserver::new` and produced by `GasReserver::into_map`.
pub type GasReservationMap = BTreeMap<ReservationId, GasReservationSlot>;
/// Encodable record of a single reservation, stored in a [`GasReservationMap`].
#[derive(Debug, Clone, Eq, PartialEq, Encode, Decode, TypeInfo)]
pub struct GasReservationSlot {
/// Reserved amount.
pub amount: u64,
/// Start of the reservation; `GasReserver::into_map` fills it with the current
/// block height for newly created reservations.
pub start: u32,
/// End of the reservation; for newly created reservations it is derived from the
/// requested duration by the callback given to `GasReserver::into_map`.
pub finish: u32,
}
#[cfg(test)]
mod tests {
    use super::*;

    const MAX_RESERVATIONS: u64 = 256;

    /// Reserver for a default dispatch with no stored reservations.
    fn new_reserver() -> GasReserver {
        GasReserver::new(
            &IncomingDispatch::default(),
            GasReservationMap::new(),
            MAX_RESERVATIONS,
        )
    }

    /// Reserver for a default dispatch that already holds one stored
    /// reservation under `id`.
    fn reserver_with_stored_slot(id: ReservationId) -> GasReserver {
        let mut stored = GasReservationMap::new();
        stored.insert(
            id,
            GasReservationSlot {
                amount: 1,
                start: 1,
                finish: 100,
            },
        );
        GasReserver::new(&IncomingDispatch::default(), stored, MAX_RESERVATIONS)
    }

    #[test]
    fn max_reservations_limit_works() {
        let mut reserver = new_reserver();

        for attempt in 0..(MAX_RESERVATIONS * 10) {
            let outcome = reserver.reserve(100, 10);
            if attempt <= MAX_RESERVATIONS {
                assert!(outcome.is_ok());
            } else {
                assert_eq!(outcome, Err(ReservationError::ReservationsLimitReached));
            }
        }
    }

    #[test]
    fn mark_used_for_unreserved_fails() {
        let mut reserver = new_reserver();
        let reservation = reserver.reserve(1, 1).unwrap();
        reserver.unreserve(reservation).unwrap();

        let outcome = reserver.mark_used(reservation);
        assert_eq!(outcome, Err(ReservationError::InvalidReservationId));
    }

    #[test]
    fn mark_used_twice_fails() {
        let mut reserver = new_reserver();
        let reservation = reserver.reserve(1, 1).unwrap();
        reserver.mark_used(reservation).unwrap();

        // Second attempt on the same reservation.
        let repeated = reserver.mark_used(reservation);
        assert_eq!(repeated, Err(ReservationError::InvalidReservationId));

        // An id the reserver has never seen.
        let unknown = reserver.mark_used(ReservationId::default());
        assert_eq!(unknown, Err(ReservationError::InvalidReservationId));
    }

    #[test]
    fn remove_reservation_twice_fails() {
        let mut reserver = new_reserver();
        let reservation = reserver.reserve(1, 1).unwrap();
        reserver.unreserve(reservation).unwrap();

        let repeated = reserver.unreserve(reservation);
        assert_eq!(repeated, Err(ReservationError::InvalidReservationId));
    }

    #[test]
    fn remove_non_existing_reservation_fails() {
        let stored_id = ReservationId::from([0xff; 32]);
        let mut reserver = reserver_with_stored_slot(stored_id);
        reserver.unreserve(stored_id).unwrap();

        let repeated = reserver.unreserve(stored_id);
        assert_eq!(repeated, Err(ReservationError::InvalidReservationId));
    }

    #[test]
    fn fresh_reserve_unreserve() {
        let mut reserver = new_reserver();
        let reservation = reserver.reserve(10_000, 5).unwrap();
        reserver.mark_used(reservation).unwrap();

        let outcome = reserver.unreserve(reservation);
        assert_eq!(outcome, Err(ReservationError::InvalidReservationId));
    }

    #[test]
    fn existing_reserve_unreserve() {
        let stored_id = ReservationId::from([0xff; 32]);
        let mut reserver = reserver_with_stored_slot(stored_id);
        reserver.mark_used(stored_id).unwrap();

        let outcome = reserver.unreserve(stored_id);
        assert_eq!(outcome, Err(ReservationError::InvalidReservationId));
    }
}