use crate::{
ids::{MessageId, ProgramId, ReservationId},
message::{
Dispatch, HandleMessage, HandlePacket, IncomingMessage, InitMessage, InitPacket, Payload,
ReplyMessage, ReplyPacket,
},
reservation::{GasReserver, ReservationNonce},
};
use alloc::{
collections::{BTreeMap, BTreeSet},
vec::Vec,
};
use gear_core_errors::{ExecutionError, ExtError, MessageError as Error, MessageError};
use scale_info::{
scale::{Decode, Encode},
TypeInfo,
};
use super::{DispatchKind, IncomingDispatch};
/// Fees and limits that parameterize a [`MessageContext`].
///
/// NOTE(review): the field order defines both the SCALE `Encode`/`Decode` layout and the
/// derived `Ord`/`PartialOrd`; reordering fields changes the encoding.
#[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Decode, Encode, TypeInfo)]
pub struct ContextSettings {
/// Fee for sending a message. NOTE(review): not applied in this file; presumably charged
/// by the caller — confirm.
sending_fee: u64,
/// Fee for a scheduled send. NOTE(review): presumably used for delayed messages — confirm.
scheduled_sending_fee: u64,
/// Fee for waiting. NOTE(review): not applied in this file — confirm usage in the caller.
waiting_fee: u64,
/// Fee for waking a message. NOTE(review): not applied in this file — confirm usage.
waking_fee: u64,
/// Fee for a gas reservation. NOTE(review): not applied in this file — confirm usage.
reservation_fee: u64,
/// Maximum number of outgoing slots (`send_init` handles plus `init_program` calls) a
/// single `MessageContext` may allocate.
outgoing_limit: u32,
}
impl ContextSettings {
    /// Builds settings from explicit fee values and the outgoing-slot limit.
    pub fn new(
        sending_fee: u64,
        scheduled_sending_fee: u64,
        waiting_fee: u64,
        waking_fee: u64,
        reservation_fee: u64,
        outgoing_limit: u32,
    ) -> Self {
        Self {
            outgoing_limit,
            reservation_fee,
            waking_fee,
            waiting_fee,
            scheduled_sending_fee,
            sending_fee,
        }
    }

    /// Fee configured for sending a message.
    pub fn sending_fee(&self) -> u64 {
        self.sending_fee
    }

    /// Fee configured for a scheduled send.
    pub fn scheduled_sending_fee(&self) -> u64 {
        self.scheduled_sending_fee
    }

    /// Fee configured for waiting.
    pub fn waiting_fee(&self) -> u64 {
        self.waiting_fee
    }

    /// Fee configured for waking a message.
    pub fn waking_fee(&self) -> u64 {
        self.waking_fee
    }

    /// Fee configured for a gas reservation.
    pub fn reservation_fee(&self) -> u64 {
        self.reservation_fee
    }

    /// Maximum number of outgoing slots a message context may allocate.
    pub fn outgoing_limit(&self) -> u32 {
        self.outgoing_limit
    }
}
/// An outgoing message of type `T` with its delay and optional gas reservation id:
/// `(message, delay, reservation)`.
pub type OutgoingMessageInfo<T> = (T, u32, Option<ReservationId>);
/// Like [`OutgoingMessageInfo`] but without a delay: `(message, reservation)`.
/// Used for the reply, which `ContextOutcome::drain` emits with delay `0`.
pub type OutgoingMessageInfoNoDelay<T> = (T, Option<ReservationId>);
/// Flattened result of [`ContextOutcome::drain`].
pub struct ContextOutcomeDrain {
/// All outgoing dispatches in order: init messages, then handle messages, then the
/// reply (if any) with delay `0`.
pub outgoing_dispatches: Vec<OutgoingMessageInfo<Dispatch>>,
/// `(message id, delay)` pairs recorded by `MessageContext::wake`.
pub awakening: Vec<(MessageId, u32)>,
/// `(message id, amount)` pairs recorded by `MessageContext::reply_deposit`.
pub reply_deposits: Vec<(MessageId, u64)>,
}
/// Messages and side effects produced while processing one incoming message.
///
/// Filled by [`MessageContext`] and flattened into dispatches by [`ContextOutcome::drain`].
///
/// NOTE(review): the field order defines the SCALE encoding and the derived ordering.
#[derive(Default, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Decode, Encode, TypeInfo)]
pub struct ContextOutcome {
/// Queued init messages with their delay; `MessageContext::init_program` always stores a
/// `None` reservation.
init: Vec<OutgoingMessageInfo<InitMessage>>,
/// Committed handle messages with their delay and optional reservation.
handle: Vec<OutgoingMessageInfo<HandleMessage>>,
/// The committed reply to the incoming message, if any.
reply: Option<OutgoingMessageInfoNoDelay<ReplyMessage>>,
/// `(message id, delay)` pairs recorded by `MessageContext::wake`.
awakening: Vec<(MessageId, u32)>,
/// `(message id, amount)` pairs recorded by `MessageContext::reply_deposit`.
reply_deposits: Vec<(MessageId, u64)>,
/// The program processing the message; passed to `into_dispatch` for every outgoing
/// message (presumably as the sender — confirm).
program_id: ProgramId,
/// Source of the incoming message; returned by `MessageContext::reply_destination`.
source: ProgramId,
/// Id of the incoming message; the reply dispatch is built against it.
origin_msg_id: MessageId,
}
impl ContextOutcome {
    /// Creates an empty outcome for the incoming message `origin_msg_id`, sent by
    /// `source` and processed by `program_id`.
    fn new(program_id: ProgramId, source: ProgramId, origin_msg_id: MessageId) -> Self {
        Self {
            origin_msg_id,
            source,
            program_id,
            ..Self::default()
        }
    }

    /// Consumes the outcome and converts every queued message into a dispatch.
    ///
    /// Dispatches are ordered: all init messages, then all handle messages, then the
    /// reply (if any). The reply is always emitted with delay `0`.
    pub fn drain(self) -> ContextOutcomeDrain {
        let Self {
            init,
            handle,
            reply,
            awakening,
            reply_deposits,
            program_id,
            source,
            origin_msg_id,
        } = self;

        let init_dispatches = init.into_iter().map(|(message, delay, reservation)| {
            (message.into_dispatch(program_id), delay, reservation)
        });
        let handle_dispatches = handle.into_iter().map(|(message, delay, reservation)| {
            (message.into_dispatch(program_id), delay, reservation)
        });
        // Replies carry no delay of their own.
        let reply_dispatch = reply.map(|(message, reservation)| {
            (
                message.into_dispatch(program_id, source, origin_msg_id),
                0,
                reservation,
            )
        });

        let outgoing_dispatches: Vec<_> = init_dispatches
            .chain(handle_dispatches)
            .chain(reply_dispatch)
            .collect();

        ContextOutcomeDrain {
            outgoing_dispatches,
            awakening,
            reply_deposits,
        }
    }
}
/// Bookkeeping state of a [`MessageContext`].
///
/// It may arrive with the incoming dispatch (a missing one is replaced by the default)
/// and is handed back by `MessageContext::drain`.
///
/// NOTE(review): the field order defines the SCALE encoding and the derived ordering.
#[derive(Clone, Default, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Decode, Encode, TypeInfo)]
pub struct ContextStore {
/// Outgoing slots keyed by handle.
///
/// A slot holds `Some(payload)` while a message is being built via `send_init`/`send_push`.
/// It holds `None` once committed, or when the slot was taken by `init_program`.
outgoing: BTreeMap<u32, Option<Payload>>,
/// Reply payload accumulated by `reply_push`/`reply_push_input`; taken by `reply_commit`.
reply: Option<Payload>,
/// Destinations already targeted by `init_program`; used to reject duplicates.
initialized: BTreeSet<ProgramId>,
/// Message ids already passed to `wake`; used to reject duplicate wakes.
awaken: BTreeSet<MessageId>,
/// Set once `reply_commit` succeeds.
reply_sent: bool,
/// Nonce copied from a `GasReserver` by `set_reservation_nonce`.
reservation_nonce: ReservationNonce,
/// Accumulated (saturating) system reservation amount, if any was added.
system_reservation: Option<u64>,
}
impl ContextStore {
    /// Returns the reservation nonce recorded in this store.
    pub(crate) fn reservation_nonce(&self) -> ReservationNonce {
        self.reservation_nonce
    }

    /// Records the current nonce of `gas_reserver`, overwriting the previous value.
    pub fn set_reservation_nonce(&mut self, gas_reserver: &GasReserver) {
        let nonce = gas_reserver.nonce();
        self.reservation_nonce = nonce;
    }

    /// Adds `amount` to the system reservation, starting at `amount` if none exists yet.
    ///
    /// The running total saturates at `u64::MAX` rather than overflowing.
    pub fn add_system_reservation(&mut self, amount: u64) {
        let total = match self.system_reservation {
            Some(current) => current.saturating_add(amount),
            None => amount,
        };
        self.system_reservation = Some(total);
    }

    /// Returns the accumulated system reservation, if any was added.
    pub fn system_reservation(&self) -> Option<u64> {
        self.system_reservation
    }

    /// Returns `true` once a reply has been committed.
    pub fn reply_sent(&self) -> bool {
        self.reply_sent
    }
}
/// Processing context for one incoming message.
///
/// Collects outgoing init/handle messages, the reply, wake requests and reply deposits,
/// enforcing the limits in [`ContextSettings`].
///
/// NOTE(review): the field order defines the SCALE encoding and the derived ordering.
#[derive(Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Decode, Encode, TypeInfo)]
pub struct MessageContext {
/// Entry point of the incoming dispatch; only `Init` and `Handle` may reply.
kind: DispatchKind,
/// The incoming message being processed.
current: IncomingMessage,
/// Messages and side effects produced so far.
outcome: ContextOutcome,
/// Bookkeeping for handles, reply state and reservations.
store: ContextStore,
/// Fees and limits for this context.
settings: ContextSettings,
}
impl MessageContext {
pub fn new(
dispatch: IncomingDispatch,
program_id: ProgramId,
settings: ContextSettings,
) -> Self {
let (kind, message, store) = dispatch.into_parts();
Self {
kind,
outcome: ContextOutcome::new(program_id, message.source(), message.id()),
current: message,
store: store.unwrap_or_default(),
settings,
}
}
pub fn settings(&self) -> &ContextSettings {
&self.settings
}
fn check_reply_availability(&self) -> Result<(), ExecutionError> {
if !matches!(self.kind, DispatchKind::Init | DispatchKind::Handle) {
return Err(ExecutionError::IncorrectEntryForReply);
}
Ok(())
}
pub fn reply_sent(&self) -> bool {
self.store.reply_sent
}
pub fn init_program(
&mut self,
packet: InitPacket,
delay: u32,
) -> Result<(MessageId, ProgramId), Error> {
let program_id = packet.destination();
if self.store.initialized.contains(&program_id) {
return Err(Error::DuplicateInit);
}
let last = self.store.outgoing.len() as u32;
if last >= self.settings.outgoing_limit {
return Err(Error::OutgoingMessagesAmountLimitExceeded);
}
let message_id = MessageId::generate_outgoing(self.current.id(), last);
let message = InitMessage::from_packet(message_id, packet);
self.store.outgoing.insert(last, None);
self.store.initialized.insert(program_id);
self.outcome.init.push((message, delay, None));
Ok((message_id, program_id))
}
pub fn send_commit(
&mut self,
handle: u32,
packet: HandlePacket,
delay: u32,
reservation: Option<ReservationId>,
) -> Result<MessageId, Error> {
if let Some(payload) = self.store.outgoing.get_mut(&handle) {
if let Some(data) = payload.take() {
let packet = {
let mut packet = packet;
packet
.try_prepend(data)
.map_err(|_| Error::MaxMessageSizeExceed)?;
packet
};
let message_id = MessageId::generate_outgoing(self.current.id(), handle);
let message = HandleMessage::from_packet(message_id, packet);
self.outcome.handle.push((message, delay, reservation));
Ok(message_id)
} else {
Err(Error::LateAccess)
}
} else {
Err(Error::OutOfBounds)
}
}
pub fn send_init(&mut self) -> Result<u32, Error> {
let last = self.store.outgoing.len() as u32;
if last < self.settings.outgoing_limit {
self.store.outgoing.insert(last, Some(Default::default()));
Ok(last)
} else {
Err(Error::OutgoingMessagesAmountLimitExceeded)
}
}
pub fn send_push(&mut self, handle: u32, buffer: &[u8]) -> Result<(), Error> {
match self.store.outgoing.get_mut(&handle) {
Some(Some(data)) => {
data.try_extend_from_slice(buffer)
.map_err(|_| Error::MaxMessageSizeExceed)?;
Ok(())
}
Some(None) => Err(Error::LateAccess),
None => Err(Error::OutOfBounds),
}
}
pub fn send_push_input(&mut self, handle: u32, range: CheckedRange) -> Result<(), Error> {
let data = self
.store
.outgoing
.get_mut(&handle)
.ok_or(Error::OutOfBounds)?
.as_mut()
.ok_or(Error::LateAccess)?;
let CheckedRange {
offset,
excluded_end,
} = range;
data.try_extend_from_slice(&self.current.payload_bytes()[offset..excluded_end])
.map_err(|_| Error::MaxMessageSizeExceed)?;
Ok(())
}
pub fn check_input_range(&self, offset: u32, len: u32) -> CheckedRange {
let input = self.current.payload_bytes();
let offset = offset as usize;
if offset >= input.len() {
return CheckedRange {
offset: 0,
excluded_end: 0,
};
}
CheckedRange {
offset,
excluded_end: if len == 0 {
offset
} else {
offset.saturating_add(len as usize).min(input.len())
},
}
}
pub fn reply_commit(
&mut self,
packet: ReplyPacket,
reservation: Option<ReservationId>,
) -> Result<MessageId, ExtError> {
self.check_reply_availability()?;
if !self.reply_sent() {
let data = self.store.reply.take().unwrap_or_default();
let packet = {
let mut packet = packet;
packet
.try_prepend(data)
.map_err(|_| Error::MaxMessageSizeExceed)?;
packet
};
let message_id = MessageId::generate_reply(self.current.id());
let message = ReplyMessage::from_packet(message_id, packet);
self.outcome.reply = Some((message, reservation));
self.store.reply_sent = true;
Ok(message_id)
} else {
Err(Error::DuplicateReply.into())
}
}
pub fn reply_push(&mut self, buffer: &[u8]) -> Result<(), ExtError> {
self.check_reply_availability()?;
if !self.reply_sent() {
let data = self.store.reply.get_or_insert_with(Default::default);
data.try_extend_from_slice(buffer)
.map_err(|_| Error::MaxMessageSizeExceed)?;
Ok(())
} else {
Err(Error::LateAccess.into())
}
}
pub fn reply_destination(&self) -> ProgramId {
self.outcome.source
}
pub fn reply_push_input(&mut self, range: CheckedRange) -> Result<(), ExtError> {
self.check_reply_availability()?;
if !self.reply_sent() {
let CheckedRange {
offset,
excluded_end,
} = range;
let data = self.store.reply.get_or_insert_with(Default::default);
data.try_extend_from_slice(&self.current.payload_bytes()[offset..excluded_end])
.map_err(|_| Error::MaxMessageSizeExceed)?;
Ok(())
} else {
Err(Error::LateAccess.into())
}
}
pub fn wake(&mut self, waker_id: MessageId, delay: u32) -> Result<(), Error> {
if self.store.awaken.insert(waker_id) {
self.outcome.awakening.push((waker_id, delay));
Ok(())
} else {
Err(Error::DuplicateWaking)
}
}
pub fn reply_deposit(
&mut self,
message_id: MessageId,
amount: u64,
) -> Result<(), MessageError> {
if self
.outcome
.reply_deposits
.iter()
.any(|(mid, _)| mid == &message_id)
{
return Err(MessageError::DuplicateReplyDeposit);
}
if !self
.outcome
.handle
.iter()
.any(|(message, ..)| message.id() == message_id)
&& !self
.outcome
.init
.iter()
.any(|(message, ..)| message.id() == message_id)
{
return Err(MessageError::IncorrectMessageForReplyDeposit);
}
self.outcome.reply_deposits.push((message_id, amount));
Ok(())
}
pub fn current(&self) -> &IncomingMessage {
&self.current
}
pub fn payload_mut(&mut self) -> &mut Payload {
self.current.payload_mut()
}
pub fn program_id(&self) -> ProgramId {
self.outcome.program_id
}
pub fn drain(self) -> (ContextOutcome, ContextStore) {
let Self { outcome, store, .. } = self;
(outcome, store)
}
}
/// A sub-range of the current incoming payload, clamped by
/// `MessageContext::check_input_range`.
///
/// The fields are private, so a range can only be obtained from `check_input_range` within
/// this module.
///
/// NOTE(review): `Clone`/`Copy` are not derived on purpose. `send_push_input` and
/// `reply_push_input` take the range by value, so each checked range is single-use;
/// confirm before adding them.
#[derive(Debug, PartialEq, Eq)]
pub struct CheckedRange {
    /// Inclusive start index into the payload.
    offset: usize,
    /// Exclusive end index into the payload (`excluded_end >= offset`).
    excluded_end: usize,
}
impl CheckedRange {
    /// Number of payload bytes covered by this range.
    pub fn len(&self) -> u32 {
        // `check_input_range` only builds ranges with `excluded_end >= offset` whose width
        // is at most the requested `u32` length, so neither the subtraction nor the cast
        // can lose information.
        (self.excluded_end - self.offset) as u32
    }

    /// Returns `true` if the range covers no bytes.
    ///
    /// Companion to [`CheckedRange::len`] (clippy `len_without_is_empty`).
    pub fn is_empty(&self) -> bool {
        self.excluded_end == self.offset
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use core::convert::TryInto;

    /// Asserts that `$x` is `Ok(_)`; with a second argument, asserts `$x == Ok($y)`.
    macro_rules! assert_ok {
        ( $x:expr $(,)? ) => {
            let is = $x;
            match is {
                Ok(_) => (),
                // `panic!` instead of `assert!(false, ..)` (clippy::assertions_on_constants).
                _ => panic!("Expected Ok(_). Got {:#?}", is),
            }
        };
        ( $x:expr, $y:expr $(,)? ) => {
            assert_eq!($x, Ok($y));
        };
    }

    /// Asserts that `$x == Err($y.into())`, converting the expected error into the
    /// result's error type.
    macro_rules! assert_err {
        ( $x:expr , $y:expr $(,)? ) => {
            assert_eq!($x, Err($y.into()));
        };
    }

    const INCOMING_MESSAGE_ID: u64 = 3;
    const INCOMING_MESSAGE_SOURCE: u64 = 4;

    /// Settings with every fee zeroed and an outgoing limit of 1024.
    fn default_settings() -> ContextSettings {
        ContextSettings::new(0, 0, 0, 0, 0, 1024)
    }

    /// Context built from a default incoming dispatch and default program id.
    fn default_context() -> MessageContext {
        MessageContext::new(Default::default(), Default::default(), default_settings())
    }

    /// `Handle` context for message `INCOMING_MESSAGE_ID` from `INCOMING_MESSAGE_SOURCE`
    /// carrying the payload `[1, 2]`.
    fn handle_context() -> MessageContext {
        let incoming_message = IncomingMessage::new(
            MessageId::from(INCOMING_MESSAGE_ID),
            ProgramId::from(INCOMING_MESSAGE_SOURCE),
            vec![1, 2].try_into().unwrap(),
            0,
            0,
            None,
        );
        let incoming_dispatch = IncomingDispatch::new(DispatchKind::Handle, incoming_message, None);
        MessageContext::new(incoming_dispatch, Default::default(), default_settings())
    }

    #[test]
    fn duplicated_init() {
        let mut message_context = default_context();

        assert_ok!(message_context.init_program(Default::default(), 0));
        assert_err!(
            message_context.init_program(Default::default(), 0),
            Error::DuplicateInit,
        );
    }

    #[test]
    fn outgoing_limit_exceeded() {
        let max_n = 5;

        for n in 0..=max_n {
            let settings = ContextSettings::new(0, 0, 0, 0, 0, n);
            let mut message_context =
                MessageContext::new(Default::default(), Default::default(), settings);

            for _ in 0..n {
                let handle = message_context.send_init().expect("unreachable");
                message_context
                    .send_push(handle, b"payload")
                    .expect("unreachable");
                message_context
                    .send_commit(handle, HandlePacket::default(), 0, None)
                    .expect("unreachable");
            }

            let limit_exceeded = message_context.send_init();
            assert_eq!(
                limit_exceeded,
                Err(Error::OutgoingMessagesAmountLimitExceeded)
            );

            let limit_exceeded = message_context.init_program(Default::default(), 0);
            assert_eq!(
                limit_exceeded,
                Err(Error::OutgoingMessagesAmountLimitExceeded)
            );
        }
    }

    #[test]
    fn invalid_out_of_bounds() {
        let mut message_context = default_context();

        let out_of_bounds = message_context.send_commit(0, Default::default(), 0, None);
        assert_eq!(out_of_bounds, Err(Error::OutOfBounds));

        let valid_handle = message_context.send_init().expect("unreachable");
        assert_eq!(valid_handle, 0);

        assert_ok!(message_context.send_commit(0, Default::default(), 0, None));
        assert_err!(
            message_context.send_commit(42, Default::default(), 0, None),
            Error::OutOfBounds,
        );
    }

    #[test]
    fn double_reply() {
        let mut message_context = default_context();

        assert_ok!(message_context.reply_commit(Default::default(), None));
        assert_err!(
            message_context.reply_commit(Default::default(), None),
            Error::DuplicateReply,
        );
    }

    #[test]
    fn message_context_api() {
        let mut context = handle_context();

        assert_eq!(context.current().id(), MessageId::from(INCOMING_MESSAGE_ID));
        assert!(context.store.reply.is_none());
        assert!(context.outcome.reply.is_none());

        let reply_packet = ReplyPacket::new(vec![0, 0].try_into().unwrap(), 0);

        assert_ok!(context.reply_push(&[1, 2, 3]));
        assert_ok!(context.reply_commit(reply_packet.clone(), None));
        assert_eq!(
            context
                .outcome
                .reply
                .as_ref()
                .unwrap()
                .0
                .payload_bytes()
                .to_vec(),
            vec![1, 2, 3, 0, 0],
        );

        assert_err!(context.reply_push(&[1]), Error::LateAccess);
        assert_eq!(
            context
                .outcome
                .reply
                .as_ref()
                .unwrap()
                .0
                .payload_bytes()
                .to_vec(),
            vec![1, 2, 3, 0, 0],
        );
        assert_err!(
            context.reply_commit(reply_packet, None),
            Error::DuplicateReply
        );

        assert!(context.outcome.handle.is_empty());

        let expected_handle = 0;
        assert_eq!(
            context.send_init().expect("Error initializing new message"),
            expected_handle
        );
        assert!(context
            .store
            .outgoing
            .get(&expected_handle)
            .expect("This key should be")
            .is_some());

        assert_ok!(context.send_push(expected_handle, &[5, 7]));
        assert_ok!(context.send_push(expected_handle, &[9]));

        let commit_packet = HandlePacket::default();
        assert_ok!(context.send_commit(expected_handle, commit_packet, 0, None));

        assert_err!(
            context.send_push(expected_handle, &[5, 7]),
            Error::LateAccess,
        );
        assert_err!(
            context.send_commit(expected_handle, HandlePacket::default(), 0, None),
            Error::LateAccess,
        );

        let expected_handle = 15;
        assert_err!(context.send_push(expected_handle, &[0]), Error::OutOfBounds);
        assert_err!(
            context.send_commit(expected_handle, HandlePacket::default(), 0, None),
            Error::OutOfBounds,
        );

        let expected_handle = 1;
        assert_eq!(
            context.send_init().expect("Error initializing new message"),
            expected_handle
        );
        assert_ok!(context.send_push(expected_handle, &[2, 2]));

        assert!(context.outcome.reply.is_some());
        assert_eq!(
            context.outcome.reply.as_ref().unwrap().0.payload_bytes(),
            vec![1, 2, 3, 0, 0]
        );

        let (expected_result, _) = context.drain();
        assert_eq!(expected_result.handle.len(), 1);
        assert_eq!(expected_result.handle[0].0.payload_bytes(), vec![5, 7, 9]);
    }

    #[test]
    fn check_input_range_clamps_to_payload() {
        // The incoming payload is `[1, 2]`.
        let context = handle_context();

        assert_eq!(context.check_input_range(0, 0).len(), 0);
        assert_eq!(context.check_input_range(0, 1).len(), 1);
        assert_eq!(context.check_input_range(1, 10).len(), 1);
        assert_eq!(context.check_input_range(0, u32::MAX).len(), 2);
        assert_eq!(context.check_input_range(2, 1).len(), 0);
        assert_eq!(context.check_input_range(u32::MAX, u32::MAX).len(), 0);
    }

    #[test]
    fn push_input_appends_selected_bytes() {
        // The incoming payload is `[1, 2]`.
        let mut context = handle_context();

        let handle = context.send_init().expect("unreachable");
        let range = context.check_input_range(1, 1);
        assert_ok!(context.send_push_input(handle, range));

        let range = context.check_input_range(0, 2);
        assert_ok!(context.reply_push_input(range));

        assert_ok!(context.send_commit(handle, HandlePacket::default(), 0, None));
        assert_ok!(context.reply_commit(ReplyPacket::default(), None));

        let (outcome, _) = context.drain();
        assert_eq!(outcome.handle[0].0.payload_bytes(), vec![2]);
        assert_eq!(
            outcome.reply.as_ref().unwrap().0.payload_bytes(),
            vec![1, 2]
        );
    }

    #[test]
    fn duplicate_waking() {
        let mut context = handle_context();

        context.wake(MessageId::default(), 10).unwrap();

        assert_eq!(
            context.wake(MessageId::default(), 1),
            Err(Error::DuplicateWaking)
        );
    }

    #[test]
    fn duplicate_reply_deposit() {
        let mut message_context = handle_context();

        let handle = message_context.send_init().expect("unreachable");
        message_context
            .send_push(handle, b"payload")
            .expect("unreachable");
        let message_id = message_context
            .send_commit(handle, HandlePacket::default(), 0, None)
            .expect("unreachable");

        assert!(message_context.reply_deposit(message_id, 1234).is_ok());
        assert_err!(
            message_context.reply_deposit(message_id, 1234),
            MessageError::DuplicateReplyDeposit
        );
    }

    #[test]
    fn inexistent_reply_deposit() {
        let mut message_context = handle_context();

        let message_id = message_context
            .reply_commit(ReplyPacket::default(), None)
            .expect("unreachable");

        assert_err!(
            message_context.reply_deposit(message_id, 1234),
            MessageError::IncorrectMessageForReplyDeposit
        );
        assert_err!(
            message_context.reply_deposit(Default::default(), 1234),
            MessageError::IncorrectMessageForReplyDeposit
        );
    }
}