soil-txpool 0.2.0

Soil transaction pool implementation
Documentation
// This file is part of Soil.

// Copyright (C) Soil contributors.
// Copyright (C) Parity Technologies (UK) Ltd.
// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0

use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use std::{
	collections::HashMap,
	sync::{
		atomic::{AtomicIsize, Ordering as AtomicOrdering},
		Arc,
	},
};

/// Something that can report its size.
pub trait Size {
	/// Returns the size of `self`.
	///
	/// `TrackedMapWriteAccess` adds this value to the map's `bytes` counter when
	/// a value is inserted and subtracts it again when the value is removed or
	/// replaced.
	fn size(&self) -> usize;
}

/// Map with size tracking.
///
/// Size reported might be slightly off and only approximately true.
///
/// The counters are plain atomics updated with relaxed ordering, and `len()` /
/// `bytes()` read them without taking the map lock, so a reader may observe a
/// value that is momentarily out of step with the map's contents.
#[derive(Debug)]
pub struct TrackedMap<K, V> {
	/// The underlying key/value storage, guarded by a read-write lock.
	index: Arc<RwLock<HashMap<K, V>>>,
	/// Running sum of `Size::size()` over the stored values. Signed; the
	/// `bytes()` accessor clamps negative readings to zero.
	bytes: AtomicIsize,
	/// Running count of stored entries. Signed; the `len()` accessor clamps
	/// negative readings to zero.
	length: AtomicIsize,
}

impl<K, V> Default for TrackedMap<K, V> {
	fn default() -> Self {
		Self { index: Arc::new(HashMap::default().into()), bytes: 0.into(), length: 0.into() }
	}
}

impl<K, V> Clone for TrackedMap<K, V>
where
	K: Clone,
	V: Clone,
{
	fn clone(&self) -> Self {
		Self {
			index: Arc::from(RwLock::from(self.index.read().clone())),
			bytes: self.bytes.load(AtomicOrdering::Relaxed).into(),
			length: self.length.load(AtomicOrdering::Relaxed).into(),
		}
	}
}

impl<K, V> TrackedMap<K, V> {
	/// Number of entries according to the length counter.
	///
	/// The map lock is not taken; a negative counter value is reported as zero.
	pub fn len(&self) -> usize {
		let tracked = self.length.load(AtomicOrdering::Relaxed);
		tracked.max(0) as usize
	}

	/// Sum of the sizes of the stored values according to the bytes counter.
	///
	/// The map lock is not taken; a negative counter value is reported as zero.
	pub fn bytes(&self) -> usize {
		let tracked = self.bytes.load(AtomicOrdering::Relaxed);
		tracked.max(0) as usize
	}

	/// Acquires the read lock and returns a read-only view of the map.
	pub fn read(&self) -> TrackedMapReadAccess<'_, K, V> {
		let inner_guard = self.index.read();
		TrackedMapReadAccess { inner_guard }
	}

	/// Acquires the write lock and returns a view that keeps the counters
	/// updated on insert and remove.
	pub fn write(&self) -> TrackedMapWriteAccess<'_, K, V> {
		let Self { index, bytes, length } = self;
		TrackedMapWriteAccess { inner_guard: index.write(), bytes, length }
	}
}

impl<K, V> TrackedMap<K, V>
where
	K: Clone,
	V: Clone,
{
	/// Returns a copy of the underlying `HashMap`, taken under the read lock.
	pub fn clone_map(&self) -> HashMap<K, V> {
		let guard = self.index.read();
		(*guard).clone()
	}
}

/// Read-only view of a `TrackedMap`, as returned by `TrackedMap::read`.
///
/// Wraps the read guard of the underlying map's lock.
pub struct TrackedMapReadAccess<'a, K, V> {
	/// Read guard over the underlying map.
	inner_guard: RwLockReadGuard<'a, HashMap<K, V>>,
}

impl<'a, K, V> TrackedMapReadAccess<'a, K, V>
where
	K: Eq + std::hash::Hash,
{
	/// Returns true if the map contains given key.
	pub fn contains_key(&self, key: &K) -> bool {
		self.inner_guard.contains_key(key)
	}

	/// Returns the reference to the contained value by key, if exists.
	pub fn get(&self, key: &K) -> Option<&V> {
		self.inner_guard.get(key)
	}

	/// Returns an iterator over all values.
	pub fn values(&self) -> std::collections::hash_map::Values<'_, K, V> {
		self.inner_guard.values()
	}
}

/// Mutable view of a `TrackedMap`, as returned by `TrackedMap::write`.
///
/// Wraps the write guard of the underlying map's lock together with references
/// to the owning map's counters, which `insert` and `remove` adjust.
pub struct TrackedMapWriteAccess<'a, K, V> {
	/// The owning map's byte counter.
	bytes: &'a AtomicIsize,
	/// The owning map's entry counter.
	length: &'a AtomicIsize,
	/// Write guard over the underlying map.
	inner_guard: RwLockWriteGuard<'a, HashMap<K, V>>,
}

impl<'a, K, V> TrackedMapWriteAccess<'a, K, V>
where
	K: Eq + std::hash::Hash,
	V: Size,
{
	/// Insert value and return previous (if any).
	pub fn insert(&mut self, key: K, val: V) -> Option<V> {
		let new_bytes = val.size();
		self.bytes.fetch_add(new_bytes as isize, AtomicOrdering::Relaxed);
		self.length.fetch_add(1, AtomicOrdering::Relaxed);
		self.inner_guard.insert(key, val).inspect(|old_val| {
			self.bytes.fetch_sub(old_val.size() as isize, AtomicOrdering::Relaxed);
			self.length.fetch_sub(1, AtomicOrdering::Relaxed);
		})
	}

	/// Remove value by key.
	pub fn remove(&mut self, key: &K) -> Option<V> {
		let val = self.inner_guard.remove(key);
		if let Some(size) = val.as_ref().map(Size::size) {
			self.bytes.fetch_sub(size as isize, AtomicOrdering::Relaxed);
			self.length.fetch_sub(1, AtomicOrdering::Relaxed);
		}
		val
	}

	/// Returns mutable reference to the contained value by key, if exists.
	pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
		self.inner_guard.get_mut(key)
	}
}

#[cfg(test)]
mod tests {

	use super::*;

	// Test-only sizing rule: an `i32` weighs a tenth of its value (integer division).
	impl Size for i32 {
		fn size(&self) -> usize {
			let value = *self as usize;
			value / 10
		}
	}

	/// Counters follow two fresh inserts, an overwriting insert, and a removal.
	#[test]
	fn basic() {
		let map: TrackedMap<i32, i32> = TrackedMap::default();

		// Two distinct keys: sizes 1 + 2, two entries.
		map.write().insert(5, 10);
		map.write().insert(6, 20);
		assert_eq!((map.bytes(), map.len()), (3, 2));

		// Overwriting key 6 swaps size 2 for size 3; the entry count is unchanged.
		map.write().insert(6, 30);
		assert_eq!((map.bytes(), map.len()), (4, 2));

		// Removing key 6 leaves only key 5 with size 1.
		map.write().remove(&6);
		assert_eq!((map.bytes(), map.len()), (1, 1));
	}
}