use alloc::{collections::BTreeSet, vec::Vec};
use core::{fmt, ops};
/// Identifier of a source stored in an [`AllForksSources`].
///
/// Wraps a `u64` taken from a counter that starts at zero and is incremented every time
/// [`AllForksSources::add_source`] is called.
#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct SourceId(u64);
/// Collection of sources, together with the non-finalized blocks that each source is known
/// to be aware of.
///
/// `TSrc` is an opaque user data type stored alongside each source.
pub struct AllForksSources<TSrc> {
/// Per-source state, indexed by [`SourceId`].
sources: hashbrown::HashMap<SourceId, Source<TSrc>, fnv::FnvBuildHasher>,
/// Identifier that the next call to `add_source` will assign.
next_source_id: SourceId,
/// Set of `(source, height, hash)` entries. Ordered so that all the blocks known by one
/// source are contiguous.
known_blocks1: BTreeSet<(SourceId, u64, [u8; 32])>,
/// Same entries as `known_blocks1`, stored as `(height, hash, source)`. Ordered so that all
/// the sources knowing one block are contiguous, and so that blocks are sorted by height.
known_blocks2: BTreeSet<(u64, [u8; 32], SourceId)>,
/// Height of the finalized block. Blocks whose height is inferior or equal to this value
/// are never inserted in `known_blocks1` and `known_blocks2`.
finalized_block_height: u64,
}
/// State kept for each source in the collection.
#[derive(Debug)]
struct Source<TSrc> {
/// Height of the best block of the source, as last provided through `add_source` or
/// `add_known_block_and_set_best`.
best_block_number: u64,
/// Hash of the best block of the source, as last provided through `add_source` or
/// `add_known_block_and_set_best`.
best_block_hash: [u8; 32],
/// Opaque value supplied by the API user when the source was added.
user_data: TSrc,
}
impl<TSrc> AllForksSources<TSrc> {
pub fn new(sources_capacity: usize, finalized_block_height: u64) -> Self {
AllForksSources {
sources: hashbrown::HashMap::with_capacity_and_hasher(
sources_capacity,
Default::default(),
),
next_source_id: SourceId(0),
known_blocks1: Default::default(),
known_blocks2: Default::default(),
finalized_block_height,
}
}
/// Returns an iterator over the identifiers of all the sources currently in the collection.
///
/// The identifiers come straight from the underlying hash map, so no particular order should
/// be relied upon.
pub fn keys(&self) -> impl ExactSizeIterator<Item = SourceId> {
self.sources.keys().copied()
}
/// Returns `true` if the collection contains no source.
pub fn is_empty(&self) -> bool {
self.sources.is_empty()
}
/// Returns the number of sources in the collection.
pub fn len(&self) -> usize {
self.sources.len()
}
/// Removes all the sources and all the known blocks from the collection.
///
/// The finalized block height and the counter used to allocate [`SourceId`]s are left
/// untouched, meaning that identifiers handed out before the call are not reused afterwards.
pub fn clear(&mut self) {
    // Empty both block indexes first, then the sources themselves.
    self.known_blocks2.clear();
    self.known_blocks1.clear();
    self.sources.clear();
}
pub fn user_data_iter_mut(&mut self) -> impl ExactSizeIterator<Item = &mut TSrc> {
self.sources.values_mut().map(|s| &mut s.user_data)
}
/// Returns the number of distinct `(height, hash)` blocks that are known by at least one
/// source.
///
/// A block known by several sources is counted only once.
pub fn num_blocks(&self) -> usize {
    let mut distinct = 0;
    let mut previous: Option<(u64, &[u8; 32])> = None;

    // `known_blocks2` is sorted by `(height, hash, source)`, so all the entries referring to
    // the same block are adjacent. Counting the transitions is therefore enough.
    for (height, hash, _) in &self.known_blocks2 {
        let current = (*height, hash);
        if previous != Some(current) {
            distinct += 1;
            previous = Some(current);
        }
    }

    distinct
}
/// Returns the finalized block height, as passed to `new` or to the latest call to
/// `set_finalized_block_height`.
pub fn finalized_block_height(&self) -> u64 {
self.finalized_block_height
}
pub fn add_source(
&mut self,
best_block_number: u64,
best_block_hash: [u8; 32],
user_data: TSrc,
) -> SourceId {
let new_id = {
let id = self.next_source_id;
self.next_source_id.0 += 1;
id
};
self.sources.insert(
new_id,
Source {
best_block_number,
best_block_hash,
user_data,
},
);
if best_block_number > self.finalized_block_height {
self.known_blocks1
.insert((new_id, best_block_number, best_block_hash));
self.known_blocks2
.insert((best_block_number, best_block_hash, new_id));
}
new_id
}
#[track_caller]
pub fn remove(&mut self, source_id: SourceId) -> TSrc {
let source = self.sources.remove(&source_id).unwrap();
let known_blocks = self
.known_blocks1
.range((source_id, 0, [0; 32])..=(source_id, u64::MAX, [0xff; 32]))
.map(|(_, n, h)| (*n, *h))
.collect::<Vec<_>>();
for (height, hash) in known_blocks {
let _was_in1 = self.known_blocks1.remove(&(source_id, height, hash));
let _was_in2 = self.known_blocks2.remove(&(height, hash, source_id));
debug_assert!(_was_in1);
debug_assert!(_was_in2);
}
source.user_data
}
/// Updates the finalized block height and forgets every known block whose height is inferior
/// or equal to the new value.
///
/// # Panic
///
/// Panics if `height` is inferior to the current finalized block height.
pub fn set_finalized_block_height(&mut self, height: u64) {
    assert!(height >= self.finalized_block_height);

    // Sanity check: nothing at or below the current finalized height should be tracked.
    debug_assert_eq!(
        self.known_blocks2
            .range(..=(self.finalized_block_height, [0xff; 32], SourceId(u64::MAX)))
            .count(),
        0
    );

    // `known_blocks2` is sorted by height first, so the blocks to discard form a prefix of
    // the set. They are copied out because the set cannot be modified during iteration.
    let newly_finalized: Vec<_> = self
        .known_blocks2
        .range(..=(height, [0xff; 32], SourceId(u64::MAX)))
        .copied()
        .collect();

    for (block_height, block_hash, source_id) in newly_finalized {
        self.known_blocks2
            .remove(&(block_height, block_hash, source_id));
        let _was_in = self
            .known_blocks1
            .remove(&(source_id, block_height, block_hash));
        debug_assert!(_was_in);
    }

    self.finalized_block_height = height;
}
/// Registers a block as known by the given source.
///
/// Has no effect if `height` is inferior or equal to the finalized block height.
///
/// This method does not verify that `source_id` refers to a source that is actually in the
/// collection.
pub fn add_known_block(&mut self, source_id: SourceId, height: u64, hash: [u8; 32]) {
    // Finalized blocks are never tracked.
    if height <= self.finalized_block_height {
        return;
    }

    // Both indexes must always contain the same entries.
    self.known_blocks1.insert((source_id, height, hash));
    self.known_blocks2.insert((height, hash, source_id));
}
pub fn remove_known_block(&mut self, height: u64, hash: &[u8; 32]) {
let sources = self
.known_blocks2
.range((height, *hash, SourceId(u64::MIN))..=(height, *hash, SourceId(u64::MAX)))
.map(|(_, _, source)| *source)
.collect::<Vec<_>>();
for source_id in sources {
self.known_blocks2.remove(&(height, *hash, source_id));
let _was_in = self.known_blocks1.remove(&(source_id, height, *hash));
debug_assert!(_was_in);
}
}
/// Forgets that the given source knows the given block.
///
/// Has no effect if the block wasn't registered as known by this source. Other sources
/// knowing the same block are not affected.
pub fn source_remove_known_block(&mut self, source_id: SourceId, height: u64, hash: &[u8; 32]) {
    let by_source_key = (source_id, height, *hash);
    let by_block_key = (height, *hash, source_id);

    let _in_by_source = self.known_blocks1.remove(&by_source_key);
    let _in_by_block = self.known_blocks2.remove(&by_block_key);

    // The two indexes mirror each other, so either both contained the entry or neither did.
    debug_assert_eq!(_in_by_source, _in_by_block);
}
/// Registers a block as known by the given source, and records it as the best block of that
/// source.
///
/// The best block is updated even if `height` is inferior or equal to the finalized block
/// height, in which case the block is not added to the set of known blocks.
///
/// # Panic
///
/// Panics if `source_id` does not refer to a source in the collection. The collection is left
/// unmodified in that situation.
#[track_caller]
pub fn add_known_block_and_set_best(
    &mut self,
    source_id: SourceId,
    height: u64,
    hash: [u8; 32],
) {
    // Look the source up before touching the block indexes, so that an invalid identifier
    // panics without leaving behind entries that refer to a non-existent source.
    let source = self.sources.get_mut(&source_id).unwrap();
    source.best_block_number = height;
    source.best_block_hash = hash;

    self.add_known_block(source_id, height, hash);
}
/// Returns the height and hash of the best block currently recorded for the given source.
///
/// These are the values passed to `add_source`, or to the latest call to
/// `add_known_block_and_set_best` for this source.
///
/// # Panic
///
/// Panics if `source_id` does not refer to a source in the collection.
// `#[track_caller]` makes the panic point at the caller, like the other accessors that panic
// on an unknown identifier.
#[track_caller]
pub fn best_block(&self, source_id: SourceId) -> (u64, &[u8; 32]) {
    let source = self.sources.get(&source_id).unwrap();
    (source.best_block_number, &source.best_block_hash)
}
/// Returns the identifiers of the sources that are registered as knowing the given block.
///
/// # Panic
///
/// Panics if `height` is inferior or equal to the finalized block height.
pub fn knows_non_finalized_block<'a>(
    &'a self,
    height: u64,
    hash: &[u8; 32],
) -> impl Iterator<Item = SourceId> + use<'a, TSrc> {
    assert!(height > self.finalized_block_height);

    // All the sources knowing one block are contiguous in `known_blocks2`. The hash is copied
    // into the bounds, so the returned iterator doesn't borrow `hash`.
    let first = (height, *hash, SourceId(u64::MIN));
    let last = (height, *hash, SourceId(u64::MAX));

    self.known_blocks2
        .range(first..=last)
        .map(|&(_, _, source_id)| source_id)
}
/// Returns `true` if the given block is registered as known by the given source.
///
/// # Panic
///
/// Panics if `height` is inferior or equal to the finalized block height.
pub fn source_knows_non_finalized_block(
    &self,
    source_id: SourceId,
    height: u64,
    hash: &[u8; 32],
) -> bool {
    assert!(height > self.finalized_block_height);

    let entry = (source_id, height, *hash);
    self.known_blocks1.contains(&entry)
}
/// Returns `true` if `source_id` refers to a source currently in the collection.
pub fn contains(&self, source_id: SourceId) -> bool {
self.sources.contains_key(&source_id)
}
}
/// Gives shared access to the user data of a source through `collection[source_id]`.
///
/// Indexing panics if the identifier does not refer to a source in the collection.
impl<TSrc> ops::Index<SourceId> for AllForksSources<TSrc> {
    type Output = TSrc;

    #[track_caller]
    fn index(&self, id: SourceId) -> &TSrc {
        &self.sources.get(&id).unwrap().user_data
    }
}
/// Gives mutable access to the user data of a source through `collection[source_id]`.
///
/// Indexing panics if the identifier does not refer to a source in the collection.
impl<TSrc> ops::IndexMut<SourceId> for AllForksSources<TSrc> {
    #[track_caller]
    fn index_mut(&mut self, id: SourceId) -> &mut TSrc {
        &mut self.sources.get_mut(&id).unwrap().user_data
    }
}
/// Debug output shows the sources and the finalized block height. The sets of known blocks
/// and the identifier counter are left out.
impl<TSrc: fmt::Debug> fmt::Debug for AllForksSources<TSrc> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut output = f.debug_struct("AllForksSources");
        output.field("sources", &self.sources);
        output.field("finalized_block_height", &self.finalized_block_height);
        output.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::AllForksSources;

    /// Walks a single source through its whole life cycle and checks the block bookkeeping
    /// after every step.
    #[test]
    fn basic_works() {
        const HASH_12: [u8; 32] = [1; 32];
        const HASH_13: [u8; 32] = [2; 32];

        // Empty collection whose finalized block is at height 10.
        let mut collection = AllForksSources::new(256, 10);
        assert!(collection.is_empty());
        assert_eq!(collection.num_blocks(), 0);

        // Adding a source registers its best block as known.
        let peer = collection.add_source(12, HASH_12, ());
        assert!(!collection.is_empty());
        assert_eq!(collection.len(), 1);
        assert_eq!(collection.num_blocks(), 1);
        assert!(collection.source_knows_non_finalized_block(peer, 12, &HASH_12));

        // Announcing a new best block keeps the previous one known.
        collection.add_known_block_and_set_best(peer, 13, HASH_13);
        assert_eq!(collection.num_blocks(), 2);
        assert!(collection.source_knows_non_finalized_block(peer, 12, &HASH_12));
        assert!(collection.source_knows_non_finalized_block(peer, 13, &HASH_13));

        // Removing one block leaves the other untouched.
        collection.remove_known_block(13, &HASH_13);
        assert_eq!(collection.num_blocks(), 1);
        assert!(collection.source_knows_non_finalized_block(peer, 12, &HASH_12));
        assert!(!collection.source_knows_non_finalized_block(peer, 13, &HASH_13));

        // Finalizing height 12 discards the remaining block.
        collection.set_finalized_block_height(12);
        assert_eq!(collection.num_blocks(), 0);

        // Removing the source empties the collection.
        collection.remove(peer);
        assert!(collection.is_empty());
        assert_eq!(collection.len(), 0);
    }
}