use codec::Encode;
use frame_election_provider_support::{
bounds::{DataProviderBounds, SizeBound},
ElectionDataProvider, VoterOf,
};
/// Tracks the byte size and number of items (voters or targets) registered so far, so that
/// registration can be refused once a `DataProviderBounds` size bound would be exceeded.
#[derive(Clone, Copy, Debug)]
pub struct StaticTracker<DataProvider> {
	/// Running byte-size estimate of the registered items, as maintained by the
	/// `try_register_*` methods.
	pub size: usize,
	/// Number of items successfully registered so far.
	pub counter: usize,
	/// Ties the tracker to a `DataProvider` type without storing a value of it.
	_marker: core::marker::PhantomData<DataProvider>,
}
impl<DataProvider> Default for StaticTracker<DataProvider> {
	/// Creates an empty tracker: no items registered and zero bytes accounted for.
	fn default() -> Self {
		let _marker = core::marker::PhantomData;
		Self { counter: 0, size: 0, _marker }
	}
}
impl<DataProvider> StaticTracker<DataProvider>
where
	DataProvider: ElectionDataProvider,
{
	/// Tries to register a new voter.
	///
	/// Returns `Err(())`, leaving the tracker untouched, if registering `voter` would make the
	/// final encoded size of all registered voters (outer length prefix included) exceed the
	/// size bound in `bounds`. Otherwise the voter is accounted for and `Ok(())` is returned.
	pub fn try_register_voter(
		&mut self,
		voter: &VoterOf<DataProvider>,
		bounds: &DataProviderBounds,
	) -> Result<(), ()> {
		self.try_register(Self::voter_size_hint(voter), bounds)
	}

	/// Estimates the encoded size of a voter `(account, vote_weight, targets)`.
	///
	/// Every target is assumed to encode to the same number of bytes as the voter's own
	/// account. The compact length prefix that precedes the encoded `targets` collection is
	/// included, so the result matches the voter's actual encoding when `size_hint` is exact.
	fn voter_size_hint(voter: &VoterOf<DataProvider>) -> usize {
		let (voter_account, vote_weight, targets) = voter;
		let account_size = voter_account.size_hint();

		account_size
			.saturating_add(vote_weight.size_hint())
			// `targets` is a length-prefixed collection: compact length, then one account each.
			.saturating_add(Self::length_prefix(targets.len()))
			.saturating_add(account_size.saturating_mul(targets.len()))
	}

	/// Tries to register a new target.
	///
	/// Same contract as `try_register_voter`; the target is sized by its `Encode::size_hint`.
	pub fn try_register_target(
		&mut self,
		target: DataProvider::AccountId,
		bounds: &DataProviderBounds,
	) -> Result<(), ()> {
		self.try_register(target.size_hint(), bounds)
	}

	/// Registers one more item whose encoded size is `item_size`, provided the final encoded
	/// size of the whole collection stays within `bounds`. On failure nothing is modified.
	///
	/// `self.size` only accumulates item sizes; the collection's outer length prefix is added
	/// by `final_byte_size_of` at check time, so it is never folded into `self.size` and
	/// therefore never counted more than once.
	fn try_register(&mut self, item_size: usize, bounds: &DataProviderBounds) -> Result<(), ()> {
		let counter_after = self.counter.saturating_add(1);
		let size_after = self.size.saturating_add(item_size);
		// Clamp instead of a plain `as u32`: truncation could wrap an oversized total into a
		// small value and wrongly pass the bound check.
		let final_size =
			Self::final_byte_size_of(counter_after, size_after).min(u32::MAX as usize) as u32;

		if bounds.size_exhausted(SizeBound(final_size)) {
			return Err(())
		}

		self.size = size_after;
		self.counter = counter_after;
		Ok(())
	}

	/// Returns the number of bytes of the SCALE compact length prefix written before a
	/// collection of `len` elements (collections encode their length as `Compact<u32>`).
	#[inline]
	fn length_prefix(len: usize) -> usize {
		use codec::{Compact, CompactLen};
		// Clamp rather than truncate so huge lengths map to the widest prefix, not a tiny one.
		let len = len.min(u32::MAX as usize) as u32;
		Compact::<u32>::compact_len(&len)
	}

	/// Calculates the final size in bytes of the encoded collection: the compact length prefix
	/// for `num_voters` elements plus `size`, the accumulated encoded size of the elements.
	fn final_byte_size_of(num_voters: usize, size: usize) -> usize {
		Self::length_prefix(num_voters).saturating_add(size)
	}
}
#[cfg(test)]
mod tests {
	use super::*;
	use crate::{
		mock::{AccountId, Staking, Test},
		BoundedVec, MaxNominationsOf,
	};
	use frame_election_provider_support::bounds::ElectionBoundsBuilder;
	use sp_core::bounded_vec;

	type Voters = BoundedVec<AccountId, MaxNominationsOf<Test>>;
	type Tracker = StaticTracker<Staking>;

	/// Final encoded size the tracker predicts for everything registered so far.
	fn predicted_size(tracker: &Tracker) -> usize {
		Tracker::final_byte_size_of(tracker.counter, tracker.size)
	}

	#[test]
	pub fn election_size_tracker_works() {
		let voter_bounds = ElectionBoundsBuilder::default().voters_size(1_50.into()).build().voters;
		let mut tracker = Tracker::default();
		let mut registered: Vec<(u64, u64, Voters)> = Vec::new();

		let candidates: [(u64, u64, Voters); 3] = [
			(1, 10, bounded_vec![2]),
			(2, 20, bounded_vec![3, 4, 5]),
			(3, 30, bounded_vec![]),
		];

		// After every registration the prediction must match the real encoding exactly.
		for voter in candidates {
			assert!(tracker.try_register_voter(&voter, &voter_bounds).is_ok());
			registered.push(voter);
			assert_eq!(predicted_size(&tracker), registered.encoded_size());
		}
	}

	#[test]
	pub fn election_size_tracker_bounds_works() {
		let voter_bounds = ElectionBoundsBuilder::default().voters_size(1_00.into()).build().voters;
		let mut tracker = Tracker::default();
		let mut registered: Vec<(u64, u64, Voters)> = Vec::new();

		let small_voter = (1, 10, bounded_vec![2]);
		assert!(tracker.try_register_voter(&small_voter, &voter_bounds).is_ok());
		registered.push(small_voter);
		assert_eq!(predicted_size(&tracker), registered.encoded_size());
		assert!(tracker.size > 0 && tracker.size < 1_00);

		// This voter does not fit: registration must fail and leave the tracker untouched.
		let size_before_overflow = tracker.size;
		let large_voter = (2, 10, bounded_vec![2, 3, 4, 5, 6, 7, 8, 9]);
		registered.push(large_voter.clone());
		assert!(tracker.try_register_voter(&large_voter, &voter_bounds).is_err());
		assert!(tracker.size > 0 && tracker.size < 1_00);
		assert_eq!(tracker.size, size_before_overflow);
	}

	#[test]
	fn len_prefix_works() {
		// Lengths on both sides of every SCALE compact-encoding mode boundary.
		let length_samples =
			[0usize, 1, 62, 63, 64, 16383, 16384, 16385, 1073741822, 1073741823, 1073741824];

		// A collection's length prefix is `Compact<u32>(len)`; checking against that encoding
		// avoids allocating gigabyte-sized vectors for the large samples.
		for s in length_samples {
			let expected = codec::Compact::<u32>(s as u32).encoded_size();
			assert_eq!(Tracker::length_prefix(s), expected);
		}

		// Cross-check against a real `Vec` encoding where the allocation is cheap.
		for s in [0usize, 1, 62, 63, 64, 16383, 16384, 16385] {
			assert_eq!(vec![1u8; s].encoded_size(), Tracker::length_prefix(s) + s);
		}
	}
}