#![cfg(test)]
use crate::{
mock::{
build_and_execute, gen_seed, set_weight, Callback, CountingMessageProcessor, IntoWeight,
MessagesProcessed, MockedWeightInfo, NumMessagesProcessed, YieldingQueues,
},
mock_helpers::{MessageOrigin, MessageOrigin::Everywhere},
*,
};
use crate as pallet_message_queue;
use frame_support::{derive_impl, parameter_types};
use rand::{rngs::StdRng, Rng, SeedableRng};
use rand_distr::Pareto;
use std::collections::{BTreeMap, BTreeSet};
/// Mock block type for the `Test` runtime.
type Block = frame_system::mocking::MockBlock<Test>;
// Minimal test runtime: only `frame_system` plus the message-queue pallet under test.
frame_support::construct_runtime!(
    pub enum Test
    {
        System: frame_system,
        MessageQueue: pallet_message_queue,
    }
);
// `frame_system` configuration: every associated type comes from `TestDefaultConfig`, except
// `Block`, which is set to the mock `Block` alias.
#[derive_impl(frame_system::config_preludes::TestDefaultConfig)]
impl frame_system::Config for Test {
    type Block = Block;
}
parameter_types! {
    // Passed as `Config::HeapSize` (32 KiB); presumably the per-page heap size — confirm in the
    // pallet docs.
    pub const HeapSize: u32 = 32 * 1024;
    // Passed as `Config::MaxStale`.
    pub const MaxStale: u32 = 32;
    // Mutable at runtime: the helpers and tests call `ServiceWeight::set(...)` to choose the
    // service budget handed to the pallet for the next block.
    pub static ServiceWeight: Option<Weight> = Some(Weight::from_parts(100, 100));
}
/// Wires the message-queue pallet into the `Test` runtime with mocked processing and weights.
impl Config for Test {
    // Runtime plumbing.
    type RuntimeEvent = RuntimeEvent;
    type WeightInfo = MockedWeightInfo;

    // Message handling: messages go to the `CountingMessageProcessor` mock, and queue changes
    // are reported to the `AhmPrioritizer` simulation.
    type MessageProcessor = CountingMessageProcessor;
    type QueueChangeHandler = AhmPrioritizer;
    type QueuePausedQuery = ();

    // Sizing limits.
    type Size = u32;
    type HeapSize = HeapSize;
    type MaxStale = MaxStale;

    // Service budgets: the block budget is the adjustable `ServiceWeight`; the idle budget is
    // `()` (presumably meaning no idle-time servicing — confirm).
    type ServiceWeight = ServiceWeight;
    type IdleMaxServiceWeight = ();
}
/// Test-only simulation of an "AHM" queue prioritizer.
///
/// A copy of the state is loaded with `AhmPrioritizerStorage::get()` and written back
/// automatically when that copy is dropped (see the `Drop` impl for this type), so callers just
/// mutate the loaded value.
#[derive(Debug, Default, codec::Encode, codec::Decode)]
pub struct AhmPrioritizer {
    /// Block number until which the prioritized queue is on a "streak": `on_initialize` forces
    /// the service head onto it while `now < streak_until`. Lazily initialised to `Some(0)`.
    streak_until: Option<u64>,
    /// The queue to prioritize; `None` disables the prioritizer entirely.
    prioritized_queue: Option<MessageOriginOf<Test>>,
    /// Last observed message count of the prioritized queue. `on_queue_changed` uses it to tell
    /// a decrease (dequeue) from an increase. Lazily initialised to `Some(0)`.
    favorite_queue_num_messages: Option<u64>,
}
parameter_types! {
    // Storage-backed home of the prioritizer state; defaults to a prioritizer with every field
    // `None` (i.e. disabled).
    pub storage AhmPrioritizerStorage: AhmPrioritizer = AhmPrioritizer::default();
}
/// Number of blocks a prioritization streak lasts once it is (re)started.
const STREAK_LEN: u64 = 3;
impl OnQueueChanged<MessageOrigin> for AhmPrioritizer {
    /// Track the prioritized queue's message count and, when it shrinks while the previous
    /// streak has already expired, schedule a new streak of `STREAK_LEN` blocks.
    ///
    /// Changes to any other queue are ignored, as are changes that did not lower the count
    /// below the last recorded value (once one is recorded). The loaded state is written back
    /// when it goes out of scope.
    fn on_queue_changed(origin: MessageOrigin, f: QueueFootprint) {
        let mut state = AhmPrioritizerStorage::get();
        if state.prioritized_queue != Some(origin) {
            return;
        }

        let count_now = f.storage.count;
        let not_a_decrease =
            matches!(state.favorite_queue_num_messages, Some(before) if before <= count_now);
        if not_a_decrease {
            return;
        }
        state.favorite_queue_num_messages = Some(count_now);

        // Never extend a running streak; only restart one that has ended.
        let block = System::block_number();
        if matches!(state.streak_until, Some(until) if until < block) {
            state.streak_until = Some(block.saturating_add(STREAK_LEN));
        }
    }
}
impl AhmPrioritizer {
    /// Per-block hook of the prioritizer (called from `next_block` before the message queue's
    /// own `on_initialize`).
    ///
    /// Does nothing unless a queue is prioritized and that queue still has pages. If the last
    /// streak ended more than 10 blocks ago, a new streak of `STREAK_LEN` blocks is scheduled.
    /// While a streak runs, the service head is forced onto the prioritized queue.
    ///
    /// Returns the weight recorded by the local meter; only `force_set_head` charges it here.
    fn on_initialize(now: u64) -> Weight {
        let mut meter = WeightMeter::new();
        let mut state = AhmPrioritizerStorage::get();

        let Some(queue) = state.prioritized_queue else {
            return meter.consumed();
        };

        // Lazily seed the bookkeeping so later comparisons have a baseline.
        state.streak_until = state.streak_until.or(Some(0));
        state.favorite_queue_num_messages = state.favorite_queue_num_messages.or(Some(0));

        // An empty queue needs no help.
        if Pallet::<Test>::footprint(queue).pages == 0 {
            return meter.consumed();
        }

        let starved = matches!(state.streak_until, Some(until) if until < now.saturating_sub(10));
        if starved {
            log::warn!("Queue is being starved, scheduling streak of {} blocks", STREAK_LEN);
            state.streak_until = Some(now.saturating_add(STREAK_LEN));
        }

        let in_streak = matches!(state.streak_until, Some(until) if until > now);
        if in_streak {
            let _ = Pallet::<Test>::force_set_head(&mut meter, &queue).defensive();
        }

        meter.consumed()
    }
}
/// Persist the prioritizer state whenever a loaded copy goes out of scope, so callers can simply
/// mutate the value returned by `AhmPrioritizerStorage::get()`.
impl Drop for AhmPrioritizer {
    fn drop(&mut self) {
        // Runs on every drop, including early returns that left the state unmodified.
        AhmPrioritizerStorage::set(self);
    }
}
/// Each block enqueues a random batch of messages and services a random share of the backlog;
/// finally everything is drained and the pallet must be empty.
#[test]
#[ignore] // Slow; run explicitly with `cargo test -- --ignored`.
fn stress_test_enqueue_and_service() {
    let (blocks, max_queues, max_messages_per_queue) = (20, 10_000, 10_000);
    let max_msg_len = MaxMessageLenOf::<Test>::get();
    let mut rng = StdRng::seed_from_u64(gen_seed());

    build_and_execute::<Test>(|| {
        let mut backlog = 0;
        for _ in 0..blocks {
            backlog += enqueue_messages(max_queues, max_messages_per_queue, max_msg_len, &mut rng);

            // Service anywhere between a single message and the whole backlog.
            let processed = rng.gen_range(1..=backlog);
            log::info!("Processing {} of all messages {}", processed, backlog);
            // Also advances the block.
            process_some_messages(processed);
            backlog -= processed;
        }

        log::info!("Processing all remaining {} messages", backlog);
        process_all_messages(backlog);
        post_conditions();
    });
}
/// Like `stress_test_enqueue_and_service`, but each block also forces the service head onto ten
/// random origins (some may not hold any messages) before servicing.
#[test]
#[ignore] // Slow; run explicitly with `cargo test -- --ignored`.
fn stress_test_force_set_head() {
    let (blocks, max_queues, max_messages_per_queue) = (20, 10_000, 10_000);
    let max_msg_len = MaxMessageLenOf::<Test>::get();
    let mut rng = StdRng::seed_from_u64(gen_seed());

    build_and_execute::<Test>(|| {
        let mut backlog = 0;
        for _ in 0..blocks {
            backlog += enqueue_messages(max_queues, max_messages_per_queue, max_msg_len, &mut rng);

            for _ in 0..10 {
                let target = Everywhere(rng.gen_range(0..=max_queues));
                MessageQueue::force_set_head(&mut WeightMeter::new(), &target).unwrap();
            }

            let processed = rng.gen_range(1..=backlog);
            log::info!("Processing {} of all messages {}", processed, backlog);
            // Also advances the block.
            process_some_messages(processed);
            backlog -= processed;
        }

        log::info!("Processing all remaining {} messages", backlog);
        process_all_messages(backlog);
        post_conditions();
    });
}
/// Stress test with the AHM prioritizer pointed at `Everywhere(9000)`, which is topped up with
/// 200 extra messages every block while 1..=100 messages are serviced per block.
#[test]
#[ignore] // Slow; run explicitly with `cargo test -- --ignored`.
fn stress_test_prioritize_queue() {
    let blocks = 20;
    let max_queues = 10_000;
    let favorite_queue = Everywhere(9000);
    let max_messages_per_queue = 1_000;
    let max_msg_len = MaxMessageLenOf::<Test>::get();
    let mut rng = StdRng::seed_from_u64(gen_seed());

    build_and_execute::<Test>(|| {
        // The loaded prioritizer writes itself back when this scope ends.
        {
            let mut prio = AhmPrioritizerStorage::get();
            prio.prioritized_queue = Some(favorite_queue);
        }

        let favorite_top_up = 200;
        let mut backlog = 0;
        for _ in 0..blocks {
            backlog += enqueue_messages(max_queues, max_messages_per_queue, max_msg_len, &mut rng);

            // Keep the favorite queue supplied with work.
            for _ in 0..favorite_top_up {
                MessageQueue::enqueue_message(
                    BoundedSlice::defensive_truncate_from("favorite".as_bytes()),
                    favorite_queue,
                );
            }
            backlog += favorite_top_up;

            let processed = rng.gen_range(1..=100);
            log::info!("Processing {} of all messages {}", processed, backlog);
            // Also advances the block.
            process_some_messages(processed);
            backlog -= processed;
        }

        log::info!("Processing all remaining {} messages", backlog);
        process_all_messages(backlog);
        post_conditions();
    });
}
/// Stress test where further messages are enqueued from inside message processing: each block
/// enqueues one `callback={b}` trigger message, and the installed `Callback` enqueues a random
/// batch whenever it runs (presumably the mock processor invokes `Callback` for these trigger
/// messages — confirm in `mock`).
// Marked `#[ignore]`: run explicitly with `cargo test -- --ignored`.
#[test]
#[ignore] fn stress_test_recursive() {
    let blocks = 20;
    let mut rng = StdRng::seed_from_u64(gen_seed());
    // The callback closure below captures nothing, so shared counters live in these statics:
    // - `TotalEnqueued`: callback-enqueued messages, minus processed, plus one per trigger;
    //   after the loops it equals the number of messages still pending.
    // - `Enqueued`: callback-enqueued messages since the test last `take()`s it.
    // - `Called`: number of callback invocations.
    parameter_types! {
        pub static TotalEnqueued: u32 = 0;
        pub static Enqueued: u32 = 0;
        pub static Called: u32 = 0;
    }
    // Reset in case earlier runs left values behind.
    Called::take();
    Enqueued::take();
    TotalEnqueued::take();
    // Enqueue a random batch from within message processing. Its RNG is seeded from the current
    // `Enqueued` counter because the closure does not capture the outer `rng`.
    Callback::set(Box::new(|_, _| {
        let mut rng = StdRng::seed_from_u64(Enqueued::get() as u64);
        let max_queues = 1_000;
        let max_messages_per_queue = 1_000;
        let max_msg_len = MaxMessageLenOf::<Test>::get();
        let enqueued = enqueue_messages(max_queues, max_messages_per_queue, max_msg_len, &mut rng);
        TotalEnqueued::set(TotalEnqueued::get() + enqueued);
        Enqueued::set(Enqueued::get() + enqueued);
        Called::set(Called::get() + 1);
        Ok(())
    }));
    build_and_execute::<Test>(|| {
        let mut msgs_remaining = 0;
        for b in 0..blocks {
            log::info!("Block #{}", b);
            // One trigger message per block, on origin `b`.
            MessageQueue::enqueue_message(
                BoundedSlice::defensive_truncate_from(format!("callback={b}").as_bytes()),
                b.into(),
            );
            // Pending = everything the callback enqueued since the last take + this trigger.
            msgs_remaining += Enqueued::take() + 1;
            let processed = rng.gen_range(1..=msgs_remaining);
            log::info!("Processing {} of all messages {}", processed, msgs_remaining);
            process_some_messages(processed); msgs_remaining -= processed; // Advances the block.
            // `processed` may include trigger messages, which `TotalEnqueued` never counted; the
            // `+ 1` accounts for the trigger enqueued in this block.
            TotalEnqueued::set(TotalEnqueued::get() - processed + 1);
            MessageQueue::do_try_state().unwrap();
        }
        // Keep servicing until every trigger has run the callback.
        while Called::get() < blocks {
            msgs_remaining += Enqueued::take();
            let processed = rng.gen_range(1..=msgs_remaining);
            log::info!("Processing {} of all messages {}", processed, msgs_remaining);
            process_some_messages(processed); msgs_remaining -= processed; // Advances the block.
            TotalEnqueued::set(TotalEnqueued::get() - processed);
            MessageQueue::do_try_state().unwrap();
        }
        // `TotalEnqueued` now holds the number of messages still pending.
        let msgs_remaining = TotalEnqueued::take();
        log::info!("Processing all remaining {} messages", msgs_remaining);
        process_all_messages(msgs_remaining);
        assert_eq!(Called::get(), blocks);
        post_conditions();
    });
}
/// Like `stress_test_enqueue_and_service`, but each block also suspends and resumes random
/// queues (through `YieldingQueues`) and only services messages of non-suspended queues.
/// At the end the non-suspended messages are drained, all queues are resumed, and the rest is
/// drained too.
#[test]
#[ignore] // Slow; run explicitly with `cargo test -- --ignored`.
fn stress_test_queue_suspension() {
    let blocks = 20;
    let max_queues = 10_000;
    let max_messages_per_queue = 10_000;
    let (max_suspend_per_block, max_resume_per_block) = (100, 50);
    let max_msg_len = MaxMessageLenOf::<Test>::get();
    let mut rng = StdRng::seed_from_u64(gen_seed());

    build_and_execute::<Test>(|| {
        let mut suspended = BTreeSet::<u32>::new();
        let mut msgs_remaining = 0;

        for _ in 0..blocks {
            let enqueued =
                enqueue_messages(max_queues, max_messages_per_queue, max_msg_len, &mut rng);
            msgs_remaining += enqueued;
            let per_queue = msgs_per_queue();

            // Suspend a random subset of queues. Draws may repeat, so `to_suspend` is only an
            // upper bound on the number of newly suspended queues.
            let to_suspend = rng.gen_range(0..max_suspend_per_block).min(per_queue.len());
            for _ in 0..to_suspend {
                let q = rng.gen_range(0..per_queue.len());
                suspended.insert(*per_queue.iter().nth(q).map(|(q, _)| q).unwrap());
            }
            // Resume a random subset of the suspended queues.
            let to_resume = rng.gen_range(0..max_resume_per_block).min(suspended.len());
            for _ in 0..to_resume {
                let q = rng.gen_range(0..suspended.len());
                let resumed = *suspended.iter().nth(q).unwrap();
                suspended.remove(&resumed);
            }
            log::info!(
                "Suspended {} and resumed {} queues of {} in total",
                to_suspend,
                to_resume,
                per_queue.len()
            );
            YieldingQueues::set(suspended.iter().map(|q| MessageOrigin::Everywhere(*q)).collect());

            let resumed_messages: u32 =
                per_queue.iter().filter(|(q, _)| !suspended.contains(q)).map(|(_, n)| n).sum();
            // Every queue that still holds messages may have been suspended. In that case there
            // is nothing serviceable, and `gen_range(1..=0)` would panic on the empty range.
            if resumed_messages == 0 {
                log::info!("All queues with messages are suspended, skipping servicing");
                continue;
            }
            let processed = rng.gen_range(1..=resumed_messages);
            log::info!(
                "Processing {} messages. Resumed msgs: {}, All msgs: {}",
                processed,
                resumed_messages,
                msgs_remaining
            );
            // Also advances the block.
            process_some_messages(processed);
            msgs_remaining -= processed;
        }

        let per_queue = msgs_per_queue();
        let resumed_messages: u32 =
            per_queue.iter().filter(|(q, _)| !suspended.contains(q)).map(|(_, n)| n).sum();
        log::info!("Processing all {} remaining resumed messages", resumed_messages);
        process_all_messages(resumed_messages);
        msgs_remaining -= resumed_messages;

        // Clearing `YieldingQueues` resumes every queue; then drain what was left in them.
        let resumed = YieldingQueues::take();
        log::info!("Resumed all {} suspended queues", resumed.len());
        log::info!("Processing all remaining {} messages", msgs_remaining);
        process_all_messages(msgs_remaining);
        post_conditions();
    });
}
/// Checks that the AHM prioritizer gives its favorite queue dedicated service. With 200 equally
/// filled queues and a tiny per-block service budget, the prioritized queue (`Everywhere(199)`)
/// must end up with the fewest messages left.
#[test]
#[ignore]
fn stress_test_ahm_despair_mode_works() {
    build_and_execute::<Test>(|| {
        let blocks = 200;
        let queues = 200;

        // 100 messages in every queue.
        for origin in 0..queues {
            for idx in 0..100 {
                let payload = format!("{}:{}", origin, idx);
                MessageQueue::enqueue_message(
                    BoundedSlice::defensive_truncate_from(payload.as_bytes()),
                    Everywhere(origin),
                );
            }
        }
        set_weight("bump_head", Weight::from_parts(1, 1));

        // Prioritize the last queue; the loaded state writes itself back at the end of scope.
        {
            let mut prio = AhmPrioritizerStorage::get();
            prio.prioritized_queue = Some(Everywhere(199));
        }

        ServiceWeight::set(Some(Weight::from_parts(10, 10)));
        for _ in 0..blocks {
            next_block();
        }

        // `min_by_key` yields the first minimum, i.e. the lowest origin among ties.
        let least_loaded = (0..queues)
            .min_by_key(|&origin| MessageQueue::footprint(Everywhere(origin)).storage.count)
            .expect("at least one queue");
        assert_eq!(least_loaded, 199);

        // Drain everything in one go.
        ServiceWeight::set(Some(Weight::MAX));
        next_block();
        post_conditions();
    });
}
/// Snapshot of the message count of every queue that has a book in `BookStateFor`, keyed by the
/// index of its `Everywhere` origin.
///
/// Panics (via `unreachable!`) if a book with any other `MessageOrigin` variant exists.
fn msgs_per_queue() -> BTreeMap<u32, u32> {
    BookStateFor::<Test>::iter()
        .map(|(origin, book)| {
            let MessageOrigin::Everywhere(index) = origin else {
                unreachable!();
            };
            (index, book.message_count as u32)
        })
        .collect()
}
/// Enqueue a random batch of messages and return how many were enqueued.
///
/// Picks a queue count in `1..max_queues` and uses origins `0..count` (converted with `.into()`).
/// Each queue receives a Pareto(1.0, 1.1)-distributed number of messages, capped at
/// `max_per_queue`. Every message starts as `"{origin}:{index}"` and is zero-padded to a
/// Pareto(1.0, 1.0)-distributed length, clamped to `[label length, max_msg_len]`.
///
/// # Panics
/// If `max_queues <= 1`, because the queue-count range is then empty.
fn enqueue_messages(
    max_queues: u32,
    max_per_queue: u32,
    max_msg_len: u32,
    rng: &mut StdRng,
) -> u32 {
    // The distributions are stateless, so hoisting them leaves the sampled sequence unchanged.
    let count_dist = Pareto::<f64>::new(1.0, 1.1).unwrap();
    let len_dist = Pareto::<f64>::new(1.0, 1.0).unwrap();

    let num_queues = rng.gen_range(1..max_queues);
    let mut num_messages = 0;
    // Only logged. Kept as `u64` because with these bounds (up to ~1e4 queues x 1e4 messages x
    // `max_msg_len` bytes) a `u32` sum could overflow and panic in debug builds.
    let mut total_msg_len: u64 = 0;
    for origin in 0..num_queues {
        let num_messages_per_queue = (rng.sample(&count_dist) as u32).min(max_per_queue);
        for m in 0..num_messages_per_queue {
            let mut message = format!("{origin}:{m}").into_bytes();
            let msg_len = (rng.sample(&len_dist) as u32).clamp(message.len() as u32, max_msg_len);
            message.resize(msg_len as usize, 0);
            MessageQueue::enqueue_message(
                BoundedSlice::defensive_truncate_from(&message),
                origin.into(),
            );
            total_msg_len += u64::from(msg_len);
        }
        num_messages += num_messages_per_queue;
    }
    log::info!(
        "Enqueued {} messages across {} queues. Payload {:.2} KiB",
        num_messages,
        num_queues,
        total_msg_len as f64 / 1024.0
    );
    num_messages
}
/// Service exactly `num_msgs` messages in the next block.
///
/// Sets the service budget to `num_msgs` messages' worth of weight and advances one block. It
/// then asserts three things: every queue's page count equals its ready-page count, the whole
/// budget was consumed, and precisely `num_msgs` messages were processed.
fn process_some_messages(num_msgs: u32) {
    let budget = u64::from(num_msgs).into_weight();
    ServiceWeight::set(Some(budget));
    let consumed = next_block();

    for fp in BookStateFor::<Test>::iter_keys().map(MessageQueue::footprint) {
        assert_eq!(fp.pages, fp.ready_pages);
    }

    assert_eq!(consumed, budget, "\n{}", MessageQueue::debug_info());
    assert_eq!(NumMessagesProcessed::take(), num_msgs as usize);
}
/// Advance one block with an unlimited service budget and assert that exactly `expected`
/// messages (and the matching weight) were processed. Also clears the `MessagesProcessed` log.
fn process_all_messages(expected: u32) {
    ServiceWeight::set(Some(Weight::MAX));
    let consumed = next_block();

    let expected_weight = Weight::from_all(u64::from(expected));
    assert_eq!(consumed, expected_weight);
    assert_eq!(NumMessagesProcessed::take(), expected as usize);
    let _ = MessagesProcessed::take();
}
/// Advance to the next block, running the hooks in this order:
/// 1. finalize the current block (`MessageQueue`, then `System`);
/// 2. bump the block number;
/// 3. initialize `System`, then the `AhmPrioritizer` simulation, then `MessageQueue`.
///
/// Returns the weight returned by `MessageQueue::on_initialize`. The prioritizer's weight is
/// discarded.
fn next_block() -> Weight {
    log::info!("Next block: {}", System::block_number() + 1);
    MessageQueue::on_finalize(System::block_number());
    System::on_finalize(System::block_number());
    System::set_block_number(System::block_number() + 1);
    System::on_initialize(System::block_number());
    // Runs before the message queue so a forced service head takes effect in this block.
    AhmPrioritizer::on_initialize(System::block_number());
    MessageQueue::on_initialize(System::block_number())
}
/// Assert that the pallet is fully drained, then advance one more block.
///
/// Checks that every book is empty and unlinked, that no pages and no service head remain,
/// that servicing again consumes no weight, and that `do_try_state` passes.
fn post_conditions() {
    BookStateFor::<Test>::iter().for_each(|(_, book)| {
        assert!(book.end >= book.begin);
        assert_eq!(book.count, 0);
        assert_eq!(book.size, 0);
        assert_eq!(book.message_count, 0);
        assert!(book.ready_neighbours.is_none());
    });

    let leftover_pages = Pages::<Test>::iter().count();
    assert_eq!(leftover_pages, 0);
    assert!(ServiceHead::<Test>::get().is_none());

    // Servicing an empty pallet must be free.
    let weight_used = MessageQueue::service_queues(Weight::MAX);
    assert_eq!(weight_used, Weight::zero(), "Nothing left");

    MessageQueue::do_try_state().unwrap();
    next_block();
}