#![cfg_attr(not(feature = "std"), no_std)]
use sp_std::{prelude::*, marker::PhantomData, ops::{Sub, Rem}};
use codec::Decode;
use sp_runtime::{KeyTypeId, Perbill, RuntimeAppPublic, BoundToRuntimeAppPublic};
use frame_support::weights::SimpleDispatchInfo;
use sp_runtime::traits::{Convert, Zero, Member, OpaqueKeys};
use sp_staking::SessionIndex;
use frame_support::{ensure, decl_module, decl_event, decl_storage, decl_error, ConsensusEngineId};
use frame_support::{traits::{Get, FindAuthor, ValidatorRegistration}, Parameter};
use frame_support::dispatch::{self, DispatchResult, DispatchError};
use frame_system::{self as system, ensure_signed};
#[cfg(test)]
mod mock;
#[cfg(feature = "historical")]
pub mod historical;
/// Decides whether a session should end.
///
/// The module asks this question once per block, from `on_initialize`.
pub trait ShouldEndSession<BlockNumber> {
    /// Returns `true` if the session should end at block `now`.
    fn should_end_session(now: BlockNumber) -> bool;
}
/// Zero-sized marker type for sessions of fixed length.
///
/// `Period` and `Offset` are type-level parameters only; no value of either is stored.
pub struct PeriodicSessions<Period, Offset>(PhantomData<(Period, Offset)>);
impl<BlockNumber, Period, Offset> ShouldEndSession<BlockNumber> for PeriodicSessions<Period, Offset>
where
    BlockNumber: Rem<Output = BlockNumber> + Sub<Output = BlockNumber> + Zero + PartialOrd,
    Period: Get<BlockNumber>,
    Offset: Get<BlockNumber>,
{
    /// Ends a session at block `Offset`, and every `Period` blocks after it.
    ///
    /// Blocks before `Offset` never end a session.
    fn should_end_session(now: BlockNumber) -> bool {
        let first_boundary = Offset::get();
        if now >= first_boundary {
            // Only subtract once we know `now` is not below the offset.
            let elapsed = now - first_boundary;
            (elapsed % Period::get()).is_zero()
        } else {
            false
        }
    }
}
/// Source of validator sets for upcoming sessions.
pub trait SessionManager<ValidatorId> {
    /// Plans the session with index `new_index`.
    ///
    /// Returning `Some(validators)` supplies the validator set for that session.
    /// When `None` is returned during a rotation, the module falls back to the
    /// validator set currently in storage.
    fn new_session(new_index: SessionIndex) -> Option<Vec<ValidatorId>>;

    /// Notifies the manager that the session with index `end_index` is ending.
    fn end_session(end_index: SessionIndex);
}
/// No-op manager: never proposes a validator set and ignores session ends.
impl<A> SessionManager<A> for () {
    fn new_session(_new_index: SessionIndex) -> Option<Vec<A>> {
        None
    }

    fn end_session(_end_index: SessionIndex) {}
}
/// Receives notifications about session lifecycle events.
pub trait SessionHandler<ValidatorId> {
    /// Key type ids this handler expects, in order.
    ///
    /// The genesis build panics unless these match `T::Keys::key_ids()` one-to-one.
    const KEY_TYPE_IDS: &'static [KeyTypeId];

    /// Called once from the genesis build with the queued validators and their keys.
    fn on_genesis_session<Ks: OpaqueKeys>(validators: &[(ValidatorId, Ks)]);

    /// Called at the end of every session rotation.
    ///
    /// * `changed` - the "changed" flag that was stored when `validators` were queued.
    /// * `validators` - validators and keys of the session that is starting.
    /// * `queued_validators` - validators and keys queued for the session after that.
    fn on_new_session<Ks: OpaqueKeys>(
        changed: bool,
        validators: &[(ValidatorId, Ks)],
        queued_validators: &[(ValidatorId, Ks)],
    );

    /// Called at the start of a rotation, before the new validators are stored.
    ///
    /// Does nothing by default.
    fn on_before_session_ending() {}

    /// Called when the validator at `validator_index` has just been disabled.
    fn on_disabled(validator_index: usize);
}
/// A session handler that cares about exactly one key type.
///
/// Tuples of these are turned into a full `SessionHandler`, which extracts each
/// handler's `Key` from the opaque session keys before calling it.
pub trait OneSessionHandler<ValidatorId>: BoundToRuntimeAppPublic {
    /// The key type this handler consumes.
    type Key: Decode + Default + RuntimeAppPublic;

    /// Genesis counterpart of `on_new_session`; receives each validator with its key.
    fn on_genesis_session<'a, I: 'a>(validators: I)
    where
        I: Iterator<Item = (&'a ValidatorId, Self::Key)>,
        ValidatorId: 'a;

    /// Called on session rotation.
    ///
    /// `validators` and `queued_validators` pair each validator id with this
    /// handler's key; `changed` is passed through from the session module.
    fn on_new_session<'a, I: 'a>(changed: bool, validators: I, queued_validators: I)
    where
        I: Iterator<Item = (&'a ValidatorId, Self::Key)>,
        ValidatorId: 'a;

    /// Called before the session ends. Does nothing by default.
    fn on_before_session_ending() {}

    /// Called when the validator at `_validator_index` has been disabled.
    fn on_disabled(_validator_index: usize);
}
// Implements `SessionHandler` for tuples of 1 to 30 `OneSessionHandler`s.
// Every hook is forwarded to each tuple element in turn.
#[impl_trait_for_tuples::impl_for_tuples(1, 30)]
#[tuple_types_no_default_trait_bound]
impl<AId> SessionHandler<AId> for Tuple {
for_tuples!( where #( Tuple: OneSessionHandler<AId> )* );
// One key type id per tuple element, in tuple order.
for_tuples!(
const KEY_TYPE_IDS: &'static [KeyTypeId] = &[ #( <Tuple::Key as RuntimeAppPublic>::ID ),* ];
);
fn on_genesis_session<Ks: OpaqueKeys>(validators: &[(AId, Ks)]) {
for_tuples!(
#(
// Extract this element's key from each validator's opaque keys; when `get`
// returns `None` the key type's `Default` value is used instead.
let our_keys: Box<dyn Iterator<Item=_>> = Box::new(validators.iter()
.map(|k| (&k.0, k.1.get::<Tuple::Key>(<Tuple::Key as RuntimeAppPublic>::ID)
.unwrap_or_default())));
Tuple::on_genesis_session(our_keys);
)*
)
}
fn on_new_session<Ks: OpaqueKeys>(
changed: bool,
validators: &[(AId, Ks)],
queued_validators: &[(AId, Ks)],
) {
for_tuples!(
#(
// Both iterators are boxed so they share one type, as required by the single
// `I` type parameter of `OneSessionHandler::on_new_session`.
let our_keys: Box<dyn Iterator<Item=_>> = Box::new(validators.iter()
.map(|k| (&k.0, k.1.get::<Tuple::Key>(<Tuple::Key as RuntimeAppPublic>::ID)
.unwrap_or_default())));
let queued_keys: Box<dyn Iterator<Item=_>> = Box::new(queued_validators.iter()
.map(|k| (&k.0, k.1.get::<Tuple::Key>(<Tuple::Key as RuntimeAppPublic>::ID)
.unwrap_or_default())));
Tuple::on_new_session(changed, our_keys, queued_keys);
)*
)
}
fn on_before_session_ending() {
for_tuples!( #( Tuple::on_before_session_ending(); )* )
}
fn on_disabled(i: usize) {
for_tuples!( #( Tuple::on_disabled(i); )* )
}
}
/// A `SessionHandler` whose hooks all do nothing and whose only key type is
/// `sp_runtime::key_types::DUMMY`.
pub struct TestSessionHandler;
/// Every hook is a no-op; only the `DUMMY` key type is declared.
impl<AId> SessionHandler<AId> for TestSessionHandler {
    const KEY_TYPE_IDS: &'static [KeyTypeId] = &[sp_runtime::key_types::DUMMY];

    fn on_genesis_session<Ks: OpaqueKeys>(_validators: &[(AId, Ks)]) {}

    fn on_new_session<Ks: OpaqueKeys>(
        _changed: bool,
        _validators: &[(AId, Ks)],
        _queued_validators: &[(AId, Ks)],
    ) {
    }

    fn on_before_session_ending() {}

    fn on_disabled(_validator_index: usize) {}
}
impl<T: Trait> ValidatorRegistration<T::ValidatorId> for Module<T> {
    /// A validator counts as registered once session keys are stored for it.
    fn is_registered(id: &T::ValidatorId) -> bool {
        let stored_keys = Self::load_keys(id);
        stored_keys.is_some()
    }
}
/// Configuration of the session module.
pub trait Trait: frame_system::Trait {
    /// The overarching event type.
    type Event: From<Event> + Into<<Self as frame_system::Trait>::Event>;

    /// Identifier of a validator.
    type ValidatorId: Member + Parameter;

    /// Maps an account to its validator id; `None` means the account has none
    /// and makes `set_keys` / `purge_keys` fail with `NoAssociatedValidatorId`.
    type ValidatorIdOf: Convert<Self::AccountId, Option<Self::ValidatorId>>;

    /// Queried on every block initialization to decide whether to rotate the session.
    type ShouldEndSession: ShouldEndSession<Self::BlockNumber>;

    /// Told when sessions end and asked for the validator sets of upcoming sessions.
    type SessionManager: SessionManager<Self::ValidatorId>;

    /// Notified about genesis, rotations and disabled validators.
    type SessionHandler: SessionHandler<Self::ValidatorId>;

    /// The session keys type.
    type Keys: OpaqueKeys + Member + Parameter + Default;

    /// Fraction of the validator count; `disable_index` returns `true` once the
    /// number of disabled validators exceeds it.
    type DisabledValidatorsThreshold: Get<Perbill>;
}
/// First key used for every entry of the `NextKeys` and `KeyOwner` double maps.
const DEDUP_KEY_PREFIX: &[u8] = b":session:keys";
// Storage items of the session module and its genesis configuration.
decl_storage! {
trait Store for Module<T: Trait> as Session {
// The current set of validators.
Validators get(fn validators): Vec<T::ValidatorId>;
// Index of the current session; incremented on every rotation.
CurrentIndex get(fn current_index): SessionIndex;
// The "changed" flag computed when `QueuedKeys` was last written; it is handed to
// `SessionHandler::on_new_session` at the next rotation.
QueuedChanged: bool;
// Validators and keys queued for the next session.
QueuedKeys get(fn queued_keys): Vec<(T::ValidatorId, T::Keys)>;
// Indices of disabled validators, kept sorted (maintained via binary search).
// Cleared on rotation when the `QueuedChanged` flag is set.
DisabledValidators get(fn disabled_validators): Vec<u32>;
// Session keys registered per validator. The first key is always `DEDUP_KEY_PREFIX`.
NextKeys: double_map hasher(twox_64_concat) Vec<u8>, hasher(blake2_256) T::ValidatorId
=> Option<T::Keys>;
// Owner of each (key type, raw key) pair. The first key is always `DEDUP_KEY_PREFIX`.
KeyOwner: double_map hasher(twox_64_concat) Vec<u8>, hasher(blake2_256) (KeyTypeId, Vec<u8>)
=> Option<T::ValidatorId>;
}
add_extra_genesis {
// Initial (account, validator id, session keys) triples.
config(keys): Vec<(T::AccountId, T::ValidatorId, T::Keys)>;
build(|config: &GenesisConfig<T>| {
// The handler and the keys type must agree on number and order of key types.
if T::SessionHandler::KEY_TYPE_IDS.len() != T::Keys::key_ids().len() {
panic!("Number of keys in session handler and session keys does not match");
}
T::SessionHandler::KEY_TYPE_IDS.iter().zip(T::Keys::key_ids()).enumerate()
.for_each(|(i, (sk, kk))| {
if sk != kk {
panic!(
"Session handler and session key expect different key type at index: {}",
i,
);
}
});
// Register every configured key set and take a reference on the owning account.
for (account, val, keys) in config.keys.iter().cloned() {
<Module<T>>::inner_set_keys(&val, keys)
.expect("genesis config must not contain duplicates; qed");
system::Module::<T>::inc_ref(&account);
}
// Session 0: ask the manager first, fall back to the configured validators.
let initial_validators_0 = T::SessionManager::new_session(0)
.unwrap_or_else(|| {
frame_support::print("No initial validator provided by `SessionManager`, use \
session config keys to generate initial validator set.");
config.keys.iter().map(|x| x.1.clone()).collect()
});
assert!(!initial_validators_0.is_empty(), "Empty validator set for session 0 in genesis block!");
// Session 1: ask the manager first, fall back to the session 0 validators.
let initial_validators_1 = T::SessionManager::new_session(1)
.unwrap_or_else(|| initial_validators_0.clone());
assert!(!initial_validators_1.is_empty(), "Empty validator set for session 1 in genesis block!");
// Pair each session 1 validator with its registered keys (default keys if none).
let queued_keys: Vec<_> = initial_validators_1
.iter()
.cloned()
.map(|v| (
v.clone(),
<Module<T>>::load_keys(&v).unwrap_or_default(),
))
.collect();
// The handler is notified before the validators and queued keys are stored.
T::SessionHandler::on_genesis_session::<T::Keys>(&queued_keys);
<Validators<T>>::put(initial_validators_0);
<QueuedKeys<T>>::put(queued_keys);
});
}
}
// Events emitted by the session module.
decl_event!(
pub enum Event {
// Deposited on every rotation; the argument is the index of the session that
// has just started (a session index, not a block number).
NewSession(SessionIndex),
}
);
// Errors returned by the session module's dispatchables.
decl_error! {
pub enum Error for Module<T: Trait> {
// `set_keys`: `ownership_proof_is_valid` rejected the supplied proof.
InvalidProof,
// `ValidatorIdOf` returned `None` for the calling account.
NoAssociatedValidatorId,
// One of the keys is already owned by a different validator.
DuplicatedKey,
// `purge_keys`: no keys are stored for the caller's validator id.
NoKeys,
}
}
// Dispatchable calls and block hooks of the session module.
decl_module! {
pub struct Module<T: Trait> for enum Call where origin: T::Origin {
// Exposes the file-level `DEDUP_KEY_PREFIX` as a module constant.
const DEDUP_KEY_PREFIX: &[u8] = DEDUP_KEY_PREFIX;
type Error = Error<T>;
fn deposit_event() = default;
// Registers `keys` as the session keys of the signed caller.
// Fails with `InvalidProof` if `proof` does not validate against `keys`;
// other failures come from `do_set_keys`.
#[weight = SimpleDispatchInfo::FixedNormal(150_000)]
fn set_keys(origin, keys: T::Keys, proof: Vec<u8>) -> dispatch::DispatchResult {
let who = ensure_signed(origin)?;
ensure!(keys.ownership_proof_is_valid(&proof), Error::<T>::InvalidProof);
Self::do_set_keys(&who, keys)?;
Ok(())
}
// Removes the session keys registered by the signed caller (see `do_purge_keys`).
#[weight = SimpleDispatchInfo::FixedNormal(150_000)]
fn purge_keys(origin) {
let who = ensure_signed(origin)?;
Self::do_purge_keys(&who)?;
}
// Rotates the session at the start of block `n` if `T::ShouldEndSession` says so.
fn on_initialize(n: T::BlockNumber) {
if T::ShouldEndSession::should_end_session(n) {
Self::rotate_session();
}
}
}
}
impl<T: Trait> Module<T> {
/// Ends the current session and starts the next one.
///
/// In order, this:
/// 1. notifies the handler via `on_before_session_ending`;
/// 2. makes the previously queued validators the active validator set;
/// 3. clears the disabled validators if the stored "changed" flag is set;
/// 4. tells the session manager the old session ended and bumps `CurrentIndex`;
/// 5. asks the session manager for the validators of the session after the new one
///    and queues them together with their registered keys;
/// 6. deposits `NewSession` and calls `on_new_session` on the handler.
pub fn rotate_session() {
    let session_index = CurrentIndex::get();
    // Flag stored when the keys now being activated were queued.
    let changed = QueuedChanged::get();

    T::SessionHandler::on_before_session_ending();

    // The keys queued by the previous rotation become this session's keys.
    let session_keys = <QueuedKeys<T>>::get();
    let validators = session_keys.iter()
        .map(|(validator, _)| validator.clone())
        .collect::<Vec<_>>();
    <Validators<T>>::put(&validators);

    if changed {
        // Reset the disabled validators.
        DisabledValidators::take();
    }

    T::SessionManager::end_session(session_index);

    let session_index = session_index + 1;
    CurrentIndex::put(session_index);

    // Plan one session ahead: the manager is asked about `session_index + 1`.
    let maybe_next_validators = T::SessionManager::new_session(session_index + 1);
    let (next_validators, next_identities_changed)
        = if let Some(validators) = maybe_next_validators
    {
        // A set supplied by the manager always counts as a change.
        (validators, true)
    } else {
        (<Validators<T>>::get(), false)
    };

    let (queued_amalgamated, next_changed) = {
        // Until a change is certain, walk the current keys in lock-step with the
        // next validators and compare key by key.
        let mut changed = next_identities_changed;
        let mut now_session_keys = session_keys.iter();
        let mut check_next_changed = |keys: &T::Keys| {
            if changed { return }
            if let Some(&(_, ref old_keys)) = now_session_keys.next() {
                if old_keys != keys {
                    changed = true;
                    return
                }
            }
        };
        let queued_amalgamated = next_validators.into_iter()
            .map(|a| {
                // Validators without registered keys are queued with default keys.
                let k = Self::load_keys(&a).unwrap_or_default();
                check_next_changed(&k);
                (a, k)
            })
            .collect::<Vec<_>>();
        (queued_amalgamated, changed)
    };

    // Store by reference: the vector is still needed for `on_new_session` below,
    // and encoding from a borrow avoids cloning every validator's keys.
    <QueuedKeys<T>>::put(&queued_amalgamated);
    QueuedChanged::put(next_changed);

    Self::deposit_event(Event::NewSession(session_index));

    T::SessionHandler::on_new_session::<T::Keys>(
        changed,
        &session_keys,
        &queued_amalgamated,
    );
}
/// Disables the validator at index `i` of the current validator set.
///
/// The handler's `on_disabled` is called only if the index was not already
/// disabled. Returns `true` if this call newly disabled the validator and the
/// number of disabled validators now exceeds `T::DisabledValidatorsThreshold`
/// applied to the validator count; returns `false` otherwise, including when
/// the index was already disabled.
pub fn disable_index(i: usize) -> bool {
    let (newly_disabled, over_threshold) = DisabledValidators::mutate(|disabled| {
        let index = i as u32;
        match disabled.binary_search(&index) {
            // Already in the list: nothing to do.
            Ok(_) => (false, false),
            // `slot` is the position that keeps the list sorted.
            Err(slot) => {
                let validator_count = <Validators<T>>::decode_len().unwrap_or(0) as u32;
                let limit = T::DisabledValidatorsThreshold::get() * validator_count;
                disabled.insert(slot, index);
                (true, disabled.len() as u32 > limit)
            }
        }
    });

    if newly_disabled {
        T::SessionHandler::on_disabled(i);
    }

    over_threshold
}
/// Disables the validator identified by `c`.
///
/// Returns `Err(())` if `c` is not in the current validator set; otherwise
/// returns the result of `disable_index` for its position.
pub fn disable(c: &T::ValidatorId) -> sp_std::result::Result<bool, ()> {
    match Self::validators().iter().position(|candidate| candidate == c) {
        Some(index) => Ok(Self::disable_index(index)),
        None => Err(()),
    }
}
/// Sets the session keys for the validator id associated with `account`.
///
/// Fails with `NoAssociatedValidatorId` if the account maps to no validator id,
/// or with whatever `inner_set_keys` returns. The account's reference count is
/// incremented only when no keys were registered before.
fn do_set_keys(account: &T::AccountId, keys: T::Keys) -> dispatch::DispatchResult {
    let validator = T::ValidatorIdOf::convert(account.clone())
        .ok_or(Error::<T>::NoAssociatedValidatorId)?;

    let previous_keys = Self::inner_set_keys(&validator, keys)?;
    let first_registration = previous_keys.is_none();
    if first_registration {
        system::Module::<T>::inc_ref(account);
    }

    Ok(())
}
/// Stores `keys` as the session keys of validator `who` and updates key ownership.
///
/// Returns the keys that were registered before, if any.
///
/// # Errors
///
/// Returns `DuplicatedKey` if any of the keys is already owned by a different
/// validator. All key types are verified before anything is written, so a
/// failure leaves storage untouched.
fn inner_set_keys(who: &T::ValidatorId, keys: T::Keys) -> Result<Option<T::Keys>, DispatchError> {
    let old_keys = Self::load_keys(who);

    // Verify first: checking inside the write loop could fail on a later key type
    // after ownership of earlier key types had already been changed.
    for id in T::Keys::key_ids() {
        let key = keys.get_raw(*id);
        ensure!(
            Self::key_owner(*id, key).map_or(true, |owner| &owner == who),
            Error::<T>::DuplicatedKey,
        );
    }

    // Write last: every key is now known to be free or already owned by `who`.
    for id in T::Keys::key_ids() {
        let key = keys.get_raw(*id);

        if let Some(old) = old_keys.as_ref().map(|k| k.get_raw(*id)) {
            if key == old {
                // Unchanged key: ownership is already recorded.
                continue;
            }
            Self::clear_key_owner(*id, old);
        }

        Self::put_key_owner(*id, key, who);
    }

    Self::put_keys(who, &keys);
    Ok(old_keys)
}
fn do_purge_keys(account: &T::AccountId) -> DispatchResult {
let who = T::ValidatorIdOf::convert(account.clone())
.ok_or(Error::<T>::NoAssociatedValidatorId)?;
let old_keys = Self::take_keys(&who).ok_or(Error::<T>::NoKeys)?;
for id in T::Keys::key_ids() {
let key_data = old_keys.get_raw(*id);
Self::clear_key_owner(*id, key_data);
}
system::Module::<T>::dec_ref(&account);
Ok(())
}
fn load_keys(v: &T::ValidatorId) -> Option<T::Keys> {
<NextKeys<T>>::get(DEDUP_KEY_PREFIX, v)
}
fn take_keys(v: &T::ValidatorId) -> Option<T::Keys> {
<NextKeys<T>>::take(DEDUP_KEY_PREFIX, v)
}
fn put_keys(v: &T::ValidatorId, keys: &T::Keys) {
<NextKeys<T>>::insert(DEDUP_KEY_PREFIX, v, keys);
}
fn key_owner(id: KeyTypeId, key_data: &[u8]) -> Option<T::ValidatorId> {
<KeyOwner<T>>::get(DEDUP_KEY_PREFIX, (id, key_data))
}
fn put_key_owner(id: KeyTypeId, key_data: &[u8], v: &T::ValidatorId) {
<KeyOwner<T>>::insert(DEDUP_KEY_PREFIX, (id, key_data), v)
}
fn clear_key_owner(id: KeyTypeId, key_data: &[u8]) {
<KeyOwner<T>>::remove(DEDUP_KEY_PREFIX, (id, key_data));
}
}
/// Zero-sized wrapper around an `Inner` author finder that yields `u32` indices,
/// translating them into validator ids of the session module configured by `T`.
pub struct FindAccountFromAuthorIndex<T, Inner>(PhantomData<(T, Inner)>);
impl<T: Trait, Inner: FindAuthor<u32>> FindAuthor<T::ValidatorId>
    for FindAccountFromAuthorIndex<T, Inner>
{
    /// Resolves the block author to a validator id.
    ///
    /// `Inner` yields the author's index; that index is looked up in the current
    /// validator set. Returns `None` if `Inner` finds no author or the index is
    /// out of range.
    fn find_author<'a, I>(digests: I) -> Option<T::ValidatorId>
    where
        I: 'a + IntoIterator<Item = (ConsensusEngineId, &'a [u8])>,
    {
        let author_index = Inner::find_author(digests)?;
        <Module<T>>::validators().get(author_index as usize).cloned()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use frame_support::assert_ok;
use sp_core::crypto::key_types::DUMMY;
use sp_runtime::{traits::OnInitialize, testing::UintAuthorityId};
use mock::{
NEXT_VALIDATORS, SESSION_CHANGED, TEST_SESSION_CHANGED, authorities, force_new_session,
set_next_validators, set_session_length, session_changed, Test, Origin, System, Session,
reset_before_session_end_called, before_session_end_called,
};
/// Builds test externalities whose session genesis registers one
/// `(account, validator, keys)` entry per id in the mock's `NEXT_VALIDATORS`.
fn new_test_ext() -> sp_io::TestExternalities {
    let mut storage = frame_system::GenesisConfig::default()
        .build_storage::<Test>()
        .unwrap();

    let session_genesis = GenesisConfig::<Test> {
        keys: NEXT_VALIDATORS.with(|next| {
            next.borrow()
                .iter()
                .cloned()
                .map(|id| (id, id, UintAuthorityId(id).into()))
                .collect()
        }),
    };
    session_genesis.assimilate_storage(&mut storage).unwrap();

    sp_io::TestExternalities::new(storage)
}
/// Resets the mock's `SESSION_CHANGED` flag, then runs the session module's
/// `on_initialize` hook for `block`.
fn initialize_block(block: u64) {
    SESSION_CHANGED.with(|flag| {
        *flag.borrow_mut() = false;
    });
    System::set_block_number(block);
    Session::on_initialize(block);
}
// Genesis alone must already expose validators 1, 2, 3 and their authorities.
#[test]
fn simple_setup_should_work() {
    new_test_ext().execute_with(|| {
        let expected_authorities =
            vec![UintAuthorityId(1), UintAuthorityId(2), UintAuthorityId(3)];
        assert_eq!(authorities(), expected_authorities);
        assert_eq!(Session::validators(), vec![1, 2, 3]);
    });
}
// Keys written with `put_keys` must be readable through `load_keys`.
#[test]
fn put_get_keys() {
    new_test_ext().execute_with(|| {
        let validator = 10;
        Session::put_keys(&validator, &UintAuthorityId(10).into());

        let stored = Session::load_keys(&validator);
        assert_eq!(stored, Some(UintAuthorityId(10).into()));
    })
}
// Purging keys must remove both the key set and its ownership record, and
// must release the reference held on the account.
#[test]
fn keys_cleared_on_kill() {
let mut ext = new_test_ext();
ext.execute_with(|| {
assert_eq!(Session::validators(), vec![1, 2, 3]);
assert_eq!(Session::load_keys(&1), Some(UintAuthorityId(1).into()));
let id = DUMMY;
assert_eq!(Session::key_owner(id, UintAuthorityId(1).get_raw(id)), Some(1));
// While keys are registered the account is referenced and may not die.
assert!(!System::allow_death(&1));
assert_ok!(Session::purge_keys(Origin::signed(1)));
assert!(System::allow_death(&1));
assert_eq!(Session::load_keys(&1), None);
assert_eq!(Session::key_owner(id, UintAuthorityId(1).get_raw(id)), None);
})
}
// A newly requested validator set shows up in the queued keys after one
// rotation and becomes the active validator/authority set one rotation later.
#[test]
fn authorities_should_track_validators() {
reset_before_session_end_called();
new_test_ext().execute_with(|| {
// Request [1, 2]: queued after the first rotation, active set still [1, 2, 3].
set_next_validators(vec![1, 2]);
force_new_session();
initialize_block(1);
assert_eq!(Session::queued_keys(), vec![
(1, UintAuthorityId(1).into()),
(2, UintAuthorityId(2).into()),
]);
assert_eq!(Session::validators(), vec![1, 2, 3]);
assert_eq!(authorities(), vec![UintAuthorityId(1), UintAuthorityId(2), UintAuthorityId(3)]);
assert!(before_session_end_called());
reset_before_session_end_called();
// Second rotation: the queued set [1, 2] becomes active.
force_new_session();
initialize_block(2);
assert_eq!(Session::queued_keys(), vec![
(1, UintAuthorityId(1).into()),
(2, UintAuthorityId(2).into()),
]);
assert_eq!(Session::validators(), vec![1, 2]);
assert_eq!(authorities(), vec![UintAuthorityId(1), UintAuthorityId(2)]);
assert!(before_session_end_called());
reset_before_session_end_called();
// Request [1, 2, 4] after validator 4 registers keys: queued first, active later.
set_next_validators(vec![1, 2, 4]);
assert_ok!(Session::set_keys(Origin::signed(4), UintAuthorityId(4).into(), vec![]));
force_new_session();
initialize_block(3);
assert_eq!(Session::queued_keys(), vec![
(1, UintAuthorityId(1).into()),
(2, UintAuthorityId(2).into()),
(4, UintAuthorityId(4).into()),
]);
assert_eq!(Session::validators(), vec![1, 2]);
assert_eq!(authorities(), vec![UintAuthorityId(1), UintAuthorityId(2)]);
assert!(before_session_end_called());
force_new_session();
initialize_block(4);
assert_eq!(Session::queued_keys(), vec![
(1, UintAuthorityId(1).into()),
(2, UintAuthorityId(2).into()),
(4, UintAuthorityId(4).into()),
]);
assert_eq!(Session::validators(), vec![1, 2, 4]);
assert_eq!(authorities(), vec![UintAuthorityId(1), UintAuthorityId(2), UintAuthorityId(4)]);
});
}
// A forced rotation ends the session early; the asserts below show that the
// regular rotation at block 10 (session length 10) still happens afterwards.
#[test]
fn should_work_with_early_exit() {
new_test_ext().execute_with(|| {
set_session_length(10);
initialize_block(1);
assert_eq!(Session::current_index(), 0);
initialize_block(2);
assert_eq!(Session::current_index(), 0);
// Forced rotation takes effect at the next initialized block.
force_new_session();
initialize_block(3);
assert_eq!(Session::current_index(), 1);
initialize_block(9);
assert_eq!(Session::current_index(), 1);
initialize_block(10);
assert_eq!(Session::current_index(), 2);
});
}
// Keys changed with `set_keys` only reach the authority set after going through
// the queue: set at block 3, the asserts show them active from block 6.
// NOTE(review): the session length used here comes from the mock — confirm there.
#[test]
fn session_change_should_work() {
new_test_ext().execute_with(|| {
initialize_block(1);
assert_eq!(authorities(), vec![UintAuthorityId(1), UintAuthorityId(2), UintAuthorityId(3)]);
initialize_block(2);
assert_eq!(authorities(), vec![UintAuthorityId(1), UintAuthorityId(2), UintAuthorityId(3)]);
initialize_block(3);
// Validator 2 switches to key 5; authorities must not change immediately.
assert_ok!(Session::set_keys(Origin::signed(2), UintAuthorityId(5).into(), vec![]));
assert_eq!(authorities(), vec![UintAuthorityId(1), UintAuthorityId(2), UintAuthorityId(3)]);
initialize_block(4);
assert_eq!(authorities(), vec![UintAuthorityId(1), UintAuthorityId(2), UintAuthorityId(3)]);
initialize_block(5);
assert_eq!(authorities(), vec![UintAuthorityId(1), UintAuthorityId(2), UintAuthorityId(3)]);
initialize_block(6);
assert_eq!(authorities(), vec![UintAuthorityId(1), UintAuthorityId(5), UintAuthorityId(3)]);
});
}
// A key owned by one validator cannot be claimed by another until the owner
// has moved to a different key.
#[test]
fn duplicates_are_not_allowed() {
new_test_ext().execute_with(|| {
System::set_block_number(1);
Session::on_initialize(1);
// Key 1 still belongs to validator 1.
assert!(Session::set_keys(Origin::signed(4), UintAuthorityId(1).into(), vec![]).is_err());
// Validator 1 moves to key 10, which frees key 1 for validator 4.
assert!(Session::set_keys(Origin::signed(1), UintAuthorityId(10).into(), vec![]).is_ok());
assert!(Session::set_keys(Origin::signed(4), UintAuthorityId(1).into(), vec![]).is_ok());
});
}
// Forces a series of rotations and checks, after each one, the value reported by
// the mock's `session_changed()` and that `on_before_session_ending` was called.
// NOTE(review): the meaning of `TEST_SESSION_CHANGED` and `session_changed()` is
// defined in the mock; presumably the latter reflects the `changed` argument the
// mock handler received in `on_new_session` — confirm there.
#[test]
fn session_changed_flag_works() {
reset_before_session_end_called();
new_test_ext().execute_with(|| {
TEST_SESSION_CHANGED.with(|l| *l.borrow_mut() = true);
force_new_session();
initialize_block(1);
assert!(!session_changed());
assert!(before_session_end_called());
reset_before_session_end_called();
force_new_session();
initialize_block(2);
assert!(!session_changed());
assert!(before_session_end_called());
reset_before_session_end_called();
// Disabling a validator; a change is reported one rotation later, not at once.
Session::disable_index(0);
force_new_session();
initialize_block(3);
assert!(!session_changed());
assert!(before_session_end_called());
reset_before_session_end_called();
force_new_session();
initialize_block(4);
assert!(session_changed());
assert!(before_session_end_called());
reset_before_session_end_called();
force_new_session();
initialize_block(5);
assert!(!session_changed());
assert!(before_session_end_called());
reset_before_session_end_called();
// Changing a validator's keys; a change is reported one rotation later.
assert_ok!(Session::set_keys(Origin::signed(2), UintAuthorityId(5).into(), vec![]));
force_new_session();
initialize_block(6);
assert!(!session_changed());
assert!(before_session_end_called());
reset_before_session_end_called();
assert_ok!(Session::set_keys(Origin::signed(69), UintAuthorityId(69).into(), vec![]));
force_new_session();
initialize_block(7);
assert!(session_changed());
assert!(before_session_end_called());
reset_before_session_end_called();
// NOTE(review): block number 7 is reused here; 8 may have been intended — confirm.
force_new_session();
initialize_block(7);
assert!(!session_changed());
assert!(before_session_end_called());
reset_before_session_end_called();
});
}
// With period 10 and offset 3, sessions end exactly at blocks 3 and 13 within 0..=13.
#[test]
fn periodic_session_works() {
    struct Period;
    struct Offset;

    impl Get<u64> for Period {
        fn get() -> u64 {
            10
        }
    }

    impl Get<u64> for Offset {
        fn get() -> u64 {
            3
        }
    }

    type P = PeriodicSessions<Period, Offset>;

    // Nothing ends before the offset.
    assert!((0..3).all(|block| !P::should_end_session(block)));
    assert!(P::should_end_session(3));

    // Blocks strictly between two boundaries do not end a session.
    assert!((4..13).all(|block| !P::should_end_session(block)));
    assert!(P::should_end_session(13));
}
// The bytes returned by `MockSessionKeys::generate` must decode into the
// runtime's `Keys` type and be accepted by `set_keys`.
#[test]
fn session_keys_generate_output_works_as_set_keys_input() {
    new_test_ext().execute_with(|| {
        let encoded_keys = mock::MockSessionKeys::generate(None);
        let decoded_keys = <mock::Test as Trait>::Keys::decode(&mut &encoded_keys[..])
            .expect("Decode keys");

        assert_ok!(Session::set_keys(Origin::signed(2), decoded_keys, vec![]));
    });
}
// With seven validators, `disable_index` starts returning `true` from the third
// disabled validator onwards.
#[test]
fn return_true_if_more_than_third_is_disabled() {
    new_test_ext().execute_with(|| {
        set_next_validators(vec![1, 2, 3, 4, 5, 6, 7]);
        // Two rotations: one to queue the new set, one to activate it.
        force_new_session();
        initialize_block(1);
        force_new_session();
        initialize_block(2);

        let expected_results = [false, false, true, true];
        for (index, expected) in expected_results.iter().enumerate() {
            assert_eq!(Session::disable_index(index), *expected);
        }
    });
}
}