use {
ahash::AHashMap,
solana_pubkey::Pubkey,
std::{
collections::hash_map::Entry,
fmt::{Debug, Display},
ops::{BitAnd, BitAndAssign, Sub},
},
};
/// Identifier of a worker thread; valid ids are `0..MAX_THREADS`.
pub type ThreadId = usize;
/// Number of outstanding locks a single thread holds on a single account.
type LockCount = u32;
/// Maximum number of threads a `ThreadSet` can represent: one per bit of its `u64` mask.
pub const MAX_THREADS: usize = u64::BITS as usize;
/// A set of thread ids stored as a bitmask: bit `i` is set iff thread `i` is a member.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct ThreadSet(u64);
/// Write-lock state of one account: the single thread holding the write lock and how
/// many times that thread has acquired it without releasing.
struct AccountWriteLocks {
    thread_id: ThreadId,
    lock_count: LockCount,
}

/// Read-lock state of one account. Several threads may read-lock the same account:
/// `lock_counts[t]` is the number of read locks held by thread `t`, and `thread_set`
/// records which threads currently hold at least one.
struct AccountReadLocks {
    thread_set: ThreadSet,
    lock_counts: [LockCount; MAX_THREADS],
}

/// All outstanding locks on one account. The default value means "no locks"; the
/// unlock paths remove an account's entry once both fields are back to `None`.
#[derive(Default)]
struct AccountLocks {
    pub write_locks: Option<AccountWriteLocks>,
    pub read_locks: Option<AccountReadLocks>,
}
/// Reasons `ThreadAwareAccountLocks::try_lock_accounts` can fail.
#[derive(Debug, Eq, PartialEq)]
pub enum TryLockError {
    /// Outstanding locks leave no single thread on which every requested lock fits.
    MultipleConflicts,
    /// Some thread could take the locks, but none of them is in the allowed set.
    ThreadNotAllowed,
}
/// Per-account read/write lock table in which every lock is owned by a worker thread.
pub struct ThreadAwareAccountLocks {
    /// Number of worker threads; valid thread ids are `0..num_threads`.
    num_threads: usize,
    /// Outstanding locks, keyed by account. Accounts with no locks have no entry.
    locks: AHashMap<Pubkey, AccountLocks>,
}
impl ThreadAwareAccountLocks {
    /// Creates an empty lock table for `num_threads` worker threads.
    ///
    /// # Panics
    ///
    /// Panics if `num_threads` is `0` or greater than [`MAX_THREADS`].
    pub fn new(num_threads: usize) -> Self {
        assert!(num_threads > 0, "num threads must be > 0");
        assert!(
            num_threads <= MAX_THREADS,
            "num threads must be <= {MAX_THREADS}"
        );
        Self {
            num_threads,
            locks: AHashMap::new(),
        }
    }

    /// Tries to take every requested write and read lock, all on the same thread.
    ///
    /// First computes the threads on which *all* requested locks could be taken, given
    /// the currently outstanding locks. That set is then restricted to
    /// `allowed_threads`. `thread_selector` picks one thread from the resulting
    /// non-empty set, and the locks are taken on that thread. No locks are taken when
    /// an error is returned.
    ///
    /// Returns the selected thread id on success.
    ///
    /// # Errors
    ///
    /// * [`TryLockError::MultipleConflicts`]: outstanding locks leave no single thread
    ///   on which every requested lock could be taken.
    /// * [`TryLockError::ThreadNotAllowed`]: some thread could take the locks, but none
    ///   of those threads is in `allowed_threads`.
    ///
    /// # Panics
    ///
    /// Panics if `thread_selector` returns a thread that is not in the set it was given.
    pub fn try_lock_accounts<'a>(
        &mut self,
        write_account_locks: impl Iterator<Item = &'a Pubkey> + Clone,
        read_account_locks: impl Iterator<Item = &'a Pubkey> + Clone,
        allowed_threads: ThreadSet,
        thread_selector: impl FnOnce(ThreadSet) -> ThreadId,
    ) -> Result<ThreadId, TryLockError> {
        // The iterators are cloned because they are walked twice: once to check
        // schedulability and once to actually take the locks.
        let schedulable_threads = self
            .accounts_schedulable_threads(write_account_locks.clone(), read_account_locks.clone())
            .ok_or(TryLockError::MultipleConflicts)?;
        let schedulable_threads = schedulable_threads & allowed_threads;
        if schedulable_threads.is_empty() {
            return Err(TryLockError::ThreadNotAllowed);
        }

        let thread_id = thread_selector(schedulable_threads);
        // Without this check, a misbehaving selector could:
        // - silently lock on a thread excluded by `allowed_threads`, or
        // - trip a confusing "same thread" assertion inside `lock_accounts`.
        // The range check short-circuits first so `contains` never shifts by >= 64 bits.
        assert!(
            thread_id < self.num_threads && schedulable_threads.contains(thread_id),
            "thread_selector must return a thread from the schedulable set: \
             {thread_id} not in {schedulable_threads}"
        );
        self.lock_accounts(write_account_locks, read_account_locks, thread_id);
        Ok(thread_id)
    }

    /// Releases one write lock per account in `write_account_locks`, then one read
    /// lock per account in `read_account_locks`, all held by `thread_id`.
    ///
    /// # Panics
    ///
    /// Panics if any of those locks does not exist or is held by a different thread.
    pub fn unlock_accounts<'a>(
        &mut self,
        write_account_locks: impl Iterator<Item = &'a Pubkey>,
        read_account_locks: impl Iterator<Item = &'a Pubkey>,
        thread_id: ThreadId,
    ) {
        for account in write_account_locks {
            self.write_unlock_account(account, thread_id);
        }
        for account in read_account_locks {
            self.read_unlock_account(account, thread_id);
        }
    }

    /// Intersects the schedulable threads of every requested lock.
    ///
    /// Returns `None` as soon as the intersection becomes empty.
    fn accounts_schedulable_threads<'a>(
        &self,
        write_account_locks: impl Iterator<Item = &'a Pubkey>,
        read_account_locks: impl Iterator<Item = &'a Pubkey>,
    ) -> Option<ThreadSet> {
        let mut schedulable_threads = ThreadSet::any(self.num_threads);
        for account in write_account_locks {
            schedulable_threads &= self.write_schedulable_threads(account);
            if schedulable_threads.is_empty() {
                return None;
            }
        }
        for account in read_account_locks {
            schedulable_threads &= self.read_schedulable_threads(account);
            if schedulable_threads.is_empty() {
                return None;
            }
        }
        Some(schedulable_threads)
    }

    /// Threads on which `account` could currently be read-locked.
    fn read_schedulable_threads(&self, account: &Pubkey) -> ThreadSet {
        self.schedulable_threads::<false>(account)
    }

    /// Threads on which `account` could currently be write-locked.
    fn write_schedulable_threads(&self, account: &Pubkey) -> ThreadSet {
        self.schedulable_threads::<true>(account)
    }

    /// Returns the threads on which `account` could be locked, given its outstanding
    /// locks. The request is a write lock if `WRITE` is true, otherwise a read lock.
    ///
    /// | outstanding locks           | write request                    | read request |
    /// |-----------------------------|----------------------------------|--------------|
    /// | none                        | any thread                       | any thread   |
    /// | reads only                  | the reader if exactly one, else none | any thread |
    /// | write only                  | the writer                       | the writer   |
    /// | write + reads (same thread) | that thread                      | that thread  |
    ///
    /// # Panics
    ///
    /// Panics if both write and read locks exist but the reads are not held only by
    /// the writing thread.
    fn schedulable_threads<const WRITE: bool>(&self, account: &Pubkey) -> ThreadSet {
        match self.locks.get(account) {
            None => ThreadSet::any(self.num_threads),
            Some(AccountLocks {
                write_locks: None,
                read_locks: Some(read_locks),
            }) => {
                if WRITE {
                    read_locks
                        .thread_set
                        .only_one_contained()
                        .map(ThreadSet::only)
                        .unwrap_or_else(ThreadSet::none)
                } else {
                    ThreadSet::any(self.num_threads)
                }
            }
            Some(AccountLocks {
                write_locks: Some(write_locks),
                read_locks: None,
            }) => ThreadSet::only(write_locks.thread_id),
            Some(AccountLocks {
                write_locks: Some(write_locks),
                read_locks: Some(read_locks),
            }) => {
                assert_eq!(
                    read_locks.thread_set.only_one_contained(),
                    Some(write_locks.thread_id)
                );
                read_locks.thread_set
            }
            // The unlock paths remove an entry once it holds no locks, and the lock
            // paths always fill in one side after `or_default()`.
            Some(AccountLocks {
                write_locks: None,
                read_locks: None,
            }) => unreachable!(),
        }
    }

    /// Takes all write locks and then all read locks on `thread_id`, without checking
    /// schedulability first (the per-account helpers assert on conflicts).
    ///
    /// # Panics
    ///
    /// Panics if `thread_id >= num_threads`, or if a conflicting lock is held by
    /// another thread.
    fn lock_accounts<'a>(
        &mut self,
        write_account_locks: impl Iterator<Item = &'a Pubkey>,
        read_account_locks: impl Iterator<Item = &'a Pubkey>,
        thread_id: ThreadId,
    ) {
        assert!(
            thread_id < self.num_threads,
            "thread_id must be < num_threads"
        );
        for account in write_account_locks {
            self.write_lock_account(account, thread_id);
        }
        for account in read_account_locks {
            self.read_lock_account(account, thread_id);
        }
    }

    /// Takes a write lock on `account` for `thread_id`, or re-enters it if that
    /// thread already holds it.
    ///
    /// # Panics
    ///
    /// Panics if read locks are held by any thread set other than exactly
    /// `{thread_id}`, or if another thread holds the write lock.
    fn write_lock_account(&mut self, account: &Pubkey, thread_id: ThreadId) {
        let entry = self.locks.entry(*account).or_default();
        let AccountLocks {
            write_locks,
            read_locks,
        } = entry;
        if let Some(read_locks) = read_locks {
            assert_eq!(
                read_locks.thread_set.only_one_contained(),
                Some(thread_id),
                "outstanding read lock must be on same thread"
            );
        }
        if let Some(write_locks) = write_locks {
            assert_eq!(
                write_locks.thread_id, thread_id,
                "outstanding write lock must be on same thread"
            );
            // NOTE(review): wrapping arithmetic - a count overflowing `LockCount` would
            // wrap silently rather than panic.
            write_locks.lock_count = write_locks.lock_count.wrapping_add(1);
        } else {
            *write_locks = Some(AccountWriteLocks {
                thread_id,
                lock_count: 1,
            });
        }
    }

    /// Releases one write lock on `account` held by `thread_id`.
    ///
    /// Removes the account's entry once it holds no locks at all.
    ///
    /// # Panics
    ///
    /// Panics if no write lock exists or it is held by another thread.
    fn write_unlock_account(&mut self, account: &Pubkey, thread_id: ThreadId) {
        let Entry::Occupied(mut entry) = self.locks.entry(*account) else {
            panic!("write lock must exist for account: {account}");
        };
        let AccountLocks {
            write_locks: maybe_write_locks,
            read_locks,
        } = entry.get_mut();
        let Some(write_locks) = maybe_write_locks else {
            panic!("write lock must exist for account: {account}");
        };
        assert_eq!(
            write_locks.thread_id, thread_id,
            "outstanding write lock must be on same thread"
        );
        write_locks.lock_count = write_locks.lock_count.wrapping_sub(1);
        if write_locks.lock_count == 0 {
            *maybe_write_locks = None;
            if read_locks.is_none() {
                entry.remove();
            }
        }
    }

    /// Takes a read lock on `account` for `thread_id`.
    ///
    /// Any number of threads may hold read locks at the same time.
    ///
    /// # Panics
    ///
    /// Panics if another thread holds the write lock.
    fn read_lock_account(&mut self, account: &Pubkey, thread_id: ThreadId) {
        let AccountLocks {
            write_locks,
            read_locks,
        } = self.locks.entry(*account).or_default();
        if let Some(write_locks) = write_locks {
            assert_eq!(
                write_locks.thread_id, thread_id,
                "outstanding write lock must be on same thread"
            );
        }
        match read_locks {
            Some(read_locks) => {
                read_locks.thread_set.insert(thread_id);
                read_locks.lock_counts[thread_id] =
                    read_locks.lock_counts[thread_id].wrapping_add(1);
            }
            None => {
                let mut lock_counts = [0; MAX_THREADS];
                lock_counts[thread_id] = 1;
                *read_locks = Some(AccountReadLocks {
                    thread_set: ThreadSet::only(thread_id),
                    lock_counts,
                });
            }
        }
    }

    /// Releases one read lock on `account` held by `thread_id`.
    ///
    /// Drops the thread from the reader set when its count reaches zero. Removes the
    /// account's entry once it holds no locks at all.
    ///
    /// # Panics
    ///
    /// Panics if no read lock exists or `thread_id` holds none of them.
    fn read_unlock_account(&mut self, account: &Pubkey, thread_id: ThreadId) {
        let Entry::Occupied(mut entry) = self.locks.entry(*account) else {
            panic!("read lock must exist for account: {account}");
        };
        let AccountLocks {
            write_locks,
            read_locks: maybe_read_locks,
        } = entry.get_mut();
        let Some(read_locks) = maybe_read_locks else {
            panic!("read lock must exist for account: {account}");
        };
        assert!(
            read_locks.thread_set.contains(thread_id),
            "outstanding read lock must be on same thread"
        );
        read_locks.lock_counts[thread_id] = read_locks.lock_counts[thread_id].wrapping_sub(1);
        if read_locks.lock_counts[thread_id] == 0 {
            read_locks.thread_set.remove(thread_id);
            if read_locks.thread_set.is_empty() {
                *maybe_read_locks = None;
                if write_locks.is_none() {
                    entry.remove();
                }
            }
        }
    }
}
impl BitAnd for ThreadSet {
type Output = Self;
fn bitand(self, rhs: Self) -> Self::Output {
Self(self.0 & rhs.0)
}
}
impl BitAndAssign for ThreadSet {
fn bitand_assign(&mut self, rhs: Self) {
self.0 &= rhs.0;
}
}
impl Sub for ThreadSet {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
Self(self.0 & !rhs.0)
}
}
impl Display for ThreadSet {
    /// Formats the set as a zero-padded binary mask, e.g. `ThreadSet(0b0…0101)`.
    ///
    /// With the `#` flag, the `0b` prefix counts toward the padding width. The width
    /// must therefore be `MAX_THREADS + 2` for every set to show all `MAX_THREADS`
    /// digits. Using just `MAX_THREADS` padded to only 62 digits, while sets with the
    /// top bits set printed 64, so the output width was inconsistent.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "ThreadSet({:#0width$b})",
            self.0,
            width = MAX_THREADS + 2
        )
    }
}
/// Debug output is identical to the `Display` output.
impl Debug for ThreadSet {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        <Self as Display>::fmt(self, f)
    }
}
impl ThreadSet {
#[inline(always)]
pub const fn none() -> Self {
Self(0b0)
}
#[inline(always)]
pub const fn any(num_threads: usize) -> Self {
if num_threads == MAX_THREADS {
Self(u64::MAX)
} else {
Self(Self::as_flag(num_threads).wrapping_sub(1))
}
}
#[inline(always)]
pub const fn only(thread_id: ThreadId) -> Self {
Self(Self::as_flag(thread_id))
}
#[inline(always)]
pub fn num_threads(&self) -> u32 {
self.0.count_ones()
}
#[inline(always)]
pub fn only_one_contained(&self) -> Option<ThreadId> {
(self.num_threads() == 1).then_some(self.0.trailing_zeros() as ThreadId)
}
#[inline(always)]
pub fn is_empty(&self) -> bool {
self == &Self::none()
}
#[inline(always)]
pub fn contains(&self, thread_id: ThreadId) -> bool {
self.0 & Self::as_flag(thread_id) != 0
}
#[inline(always)]
pub fn insert(&mut self, thread_id: ThreadId) {
self.0 |= Self::as_flag(thread_id);
}
#[inline(always)]
pub fn remove(&mut self, thread_id: ThreadId) {
self.0 &= !Self::as_flag(thread_id);
}
#[inline(always)]
pub fn contained_threads_iter(self) -> impl Iterator<Item = ThreadId> {
ThreadSetIterator(self.0)
}
#[inline(always)]
const fn as_flag(thread_id: ThreadId) -> u64 {
0b1 << thread_id
}
}
/// Iterator over the set bits of a `ThreadSet` mask, lowest thread id first.
struct ThreadSetIterator(u64);

impl Iterator for ThreadSetIterator {
    type Item = ThreadId;

    fn next(&mut self) -> Option<Self::Item> {
        let remaining = self.0;
        (remaining != 0).then(|| {
            // The lowest set bit is the next thread id; clear it before yielding.
            self.0 = remaining & (remaining - 1);
            remaining.trailing_zeros() as ThreadId
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_NUM_THREADS: usize = 4;
    const TEST_ANY_THREADS: ThreadSet = ThreadSet::any(TEST_NUM_THREADS);

    // Deterministic selector for tests: always picks the lowest schedulable thread.
    fn test_thread_selector(thread_set: ThreadSet) -> ThreadId {
        thread_set.contained_threads_iter().next().unwrap()
    }

    #[test]
    #[should_panic(expected = "num threads must be > 0")]
    fn test_too_few_num_threads() {
        ThreadAwareAccountLocks::new(0);
    }

    #[test]
    #[should_panic(expected = "num threads must be <=")]
    fn test_too_many_num_threads() {
        ThreadAwareAccountLocks::new(MAX_THREADS + 1);
    }

    #[test]
    fn test_try_lock_accounts_none() {
        let pk1 = Pubkey::new_unique();
        let pk2 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.read_lock_account(&pk1, 2);
        locks.read_lock_account(&pk1, 3);
        assert_eq!(
            locks.try_lock_accounts(
                [&pk1].into_iter(),
                [&pk2].into_iter(),
                TEST_ANY_THREADS,
                test_thread_selector
            ),
            Err(TryLockError::MultipleConflicts)
        );
    }

    #[test]
    fn test_try_lock_accounts_one() {
        let pk1 = Pubkey::new_unique();
        let pk2 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.write_lock_account(&pk2, 3);
        assert_eq!(
            locks.try_lock_accounts(
                [&pk1].into_iter(),
                [&pk2].into_iter(),
                TEST_ANY_THREADS,
                test_thread_selector
            ),
            Ok(3)
        );
    }

    #[test]
    fn test_try_lock_accounts_one_not_allowed() {
        let pk1 = Pubkey::new_unique();
        let pk2 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.write_lock_account(&pk2, 3);
        assert_eq!(
            locks.try_lock_accounts(
                [&pk1].into_iter(),
                [&pk2].into_iter(),
                ThreadSet::none(),
                test_thread_selector
            ),
            Err(TryLockError::ThreadNotAllowed)
        );
    }

    #[test]
    fn test_try_lock_accounts_multiple() {
        let pk1 = Pubkey::new_unique();
        let pk2 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.read_lock_account(&pk2, 0);
        locks.read_lock_account(&pk2, 0);
        assert_eq!(
            locks.try_lock_accounts(
                [&pk1].into_iter(),
                [&pk2].into_iter(),
                TEST_ANY_THREADS - ThreadSet::only(0),
                test_thread_selector
            ),
            Ok(1)
        );
    }

    #[test]
    fn test_try_lock_accounts_any() {
        let pk1 = Pubkey::new_unique();
        let pk2 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        assert_eq!(
            locks.try_lock_accounts(
                [&pk1].into_iter(),
                [&pk2].into_iter(),
                TEST_ANY_THREADS,
                test_thread_selector
            ),
            Ok(0)
        );
    }

    #[test]
    fn test_accounts_schedulable_threads_no_outstanding_locks() {
        let pk1 = Pubkey::new_unique();
        let locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        assert_eq!(
            locks.accounts_schedulable_threads([&pk1].into_iter(), std::iter::empty()),
            Some(TEST_ANY_THREADS)
        );
        assert_eq!(
            locks.accounts_schedulable_threads(std::iter::empty(), [&pk1].into_iter()),
            Some(TEST_ANY_THREADS)
        );
    }

    #[test]
    fn test_accounts_schedulable_threads_outstanding_write_only() {
        let pk1 = Pubkey::new_unique();
        let pk2 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.write_lock_account(&pk1, 2);
        assert_eq!(
            locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()),
            Some(ThreadSet::only(2))
        );
        assert_eq!(
            locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()),
            Some(ThreadSet::only(2))
        );
    }

    #[test]
    fn test_accounts_schedulable_threads_outstanding_read_only() {
        let pk1 = Pubkey::new_unique();
        let pk2 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.read_lock_account(&pk1, 2);
        assert_eq!(
            locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()),
            Some(ThreadSet::only(2))
        );
        assert_eq!(
            locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()),
            Some(TEST_ANY_THREADS)
        );
        locks.read_lock_account(&pk1, 0);
        assert_eq!(
            locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()),
            None
        );
        assert_eq!(
            locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()),
            Some(TEST_ANY_THREADS)
        );
    }

    #[test]
    fn test_accounts_schedulable_threads_outstanding_mixed() {
        let pk1 = Pubkey::new_unique();
        let pk2 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.read_lock_account(&pk1, 2);
        locks.write_lock_account(&pk1, 2);
        assert_eq!(
            locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()),
            Some(ThreadSet::only(2))
        );
        assert_eq!(
            locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()),
            Some(ThreadSet::only(2))
        );
    }

    #[test]
    #[should_panic(expected = "outstanding write lock must be on same thread")]
    fn test_write_lock_account_write_conflict_panic() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.write_lock_account(&pk1, 0);
        locks.write_lock_account(&pk1, 1);
    }

    #[test]
    #[should_panic(expected = "outstanding read lock must be on same thread")]
    fn test_write_lock_account_read_conflict_panic() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.read_lock_account(&pk1, 0);
        locks.write_lock_account(&pk1, 1);
    }

    #[test]
    #[should_panic(expected = "write lock must exist")]
    fn test_write_unlock_account_not_locked() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.write_unlock_account(&pk1, 0);
    }

    #[test]
    #[should_panic(expected = "outstanding write lock must be on same thread")]
    fn test_write_unlock_account_thread_mismatch() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.write_lock_account(&pk1, 1);
        locks.write_unlock_account(&pk1, 0);
    }

    #[test]
    #[should_panic(expected = "outstanding write lock must be on same thread")]
    fn test_read_lock_account_write_conflict_panic() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.write_lock_account(&pk1, 0);
        locks.read_lock_account(&pk1, 1);
    }

    #[test]
    #[should_panic(expected = "read lock must exist")]
    fn test_read_unlock_account_not_locked() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.read_unlock_account(&pk1, 1);
    }

    #[test]
    #[should_panic(expected = "outstanding read lock must be on same thread")]
    fn test_read_unlock_account_thread_mismatch() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.read_lock_account(&pk1, 0);
        locks.read_unlock_account(&pk1, 1);
    }

    #[test]
    fn test_write_locking() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.write_lock_account(&pk1, 1);
        locks.write_lock_account(&pk1, 1);
        locks.write_unlock_account(&pk1, 1);
        locks.write_unlock_account(&pk1, 1);
        assert!(locks.locks.is_empty());
    }

    #[test]
    fn test_read_locking() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.read_lock_account(&pk1, 1);
        locks.read_lock_account(&pk1, 1);
        locks.read_unlock_account(&pk1, 1);
        locks.read_unlock_account(&pk1, 1);
        assert!(locks.locks.is_empty());
    }

    #[test]
    #[should_panic(expected = "thread_id must be < num_threads")]
    fn test_lock_accounts_invalid_thread() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.lock_accounts([&pk1].into_iter(), std::iter::empty(), TEST_NUM_THREADS);
    }

    #[test]
    fn test_thread_set() {
        let mut thread_set = ThreadSet::none();
        assert!(thread_set.is_empty());
        assert_eq!(thread_set.num_threads(), 0);
        assert_eq!(thread_set.only_one_contained(), None);
        for idx in 0..MAX_THREADS {
            assert!(!thread_set.contains(idx));
        }

        thread_set.insert(4);
        assert!(!thread_set.is_empty());
        assert_eq!(thread_set.num_threads(), 1);
        assert_eq!(thread_set.only_one_contained(), Some(4));
        for idx in 0..MAX_THREADS {
            assert_eq!(thread_set.contains(idx), idx == 4);
        }

        thread_set.insert(2);
        assert!(!thread_set.is_empty());
        assert_eq!(thread_set.num_threads(), 2);
        assert_eq!(thread_set.only_one_contained(), None);
        for idx in 0..MAX_THREADS {
            assert_eq!(thread_set.contains(idx), idx == 2 || idx == 4);
        }

        thread_set.remove(4);
        assert!(!thread_set.is_empty());
        assert_eq!(thread_set.num_threads(), 1);
        assert_eq!(thread_set.only_one_contained(), Some(2));
        for idx in 0..MAX_THREADS {
            assert_eq!(thread_set.contains(idx), idx == 2);
        }
    }

    #[test]
    fn test_thread_set_any_zero() {
        let any_threads = ThreadSet::any(0);
        assert_eq!(any_threads.num_threads(), 0);
    }

    #[test]
    fn test_thread_set_any_max() {
        let any_threads = ThreadSet::any(MAX_THREADS);
        assert_eq!(any_threads.num_threads(), MAX_THREADS as u32);
    }

    #[test]
    fn test_thread_set_iter() {
        let mut thread_set = ThreadSet::none();
        assert!(thread_set.contained_threads_iter().next().is_none());

        thread_set.insert(4);
        assert_eq!(
            thread_set.contained_threads_iter().collect::<Vec<_>>(),
            vec![4]
        );

        thread_set.insert(5);
        assert_eq!(
            thread_set.contained_threads_iter().collect::<Vec<_>>(),
            vec![4, 5]
        );

        thread_set.insert(63);
        assert_eq!(
            thread_set.contained_threads_iter().collect::<Vec<_>>(),
            vec![4, 5, 63]
        );

        thread_set.remove(5);
        assert_eq!(
            thread_set.contained_threads_iter().collect::<Vec<_>>(),
            vec![4, 63]
        );

        let thread_set = ThreadSet::any(64);
        assert_eq!(
            thread_set.contained_threads_iter().collect::<Vec<_>>(),
            (0..64).collect::<Vec<_>>()
        );
    }

    // A successful try-lock followed by the matching unlock must leave no entries behind.
    #[test]
    fn test_try_lock_then_unlock_accounts() {
        let pk1 = Pubkey::new_unique();
        let pk2 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        let thread_id = locks
            .try_lock_accounts(
                [&pk1].into_iter(),
                [&pk2].into_iter(),
                TEST_ANY_THREADS,
                test_thread_selector,
            )
            .unwrap();
        assert_eq!(thread_id, 0);
        assert_eq!(locks.locks.len(), 2);
        locks.unlock_accounts([&pk1].into_iter(), [&pk2].into_iter(), thread_id);
        assert!(locks.locks.is_empty());
    }

    // Releasing the write lock must keep the entry alive while a read lock remains.
    #[test]
    fn test_unlock_mixed_locks_same_thread() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.write_lock_account(&pk1, 2);
        locks.read_lock_account(&pk1, 2);
        locks.write_unlock_account(&pk1, 2);
        assert!(!locks.locks.is_empty());
        assert_eq!(
            locks.accounts_schedulable_threads([&pk1].into_iter(), std::iter::empty()),
            Some(ThreadSet::only(2))
        );
        locks.read_unlock_account(&pk1, 2);
        assert!(locks.locks.is_empty());
    }

    #[test]
    fn test_read_locks_on_multiple_threads() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.read_lock_account(&pk1, 0);
        locks.read_lock_account(&pk1, 3);
        locks.read_unlock_account(&pk1, 0);
        // Thread 3 still holds a read lock, so a write is only schedulable there.
        assert_eq!(
            locks.accounts_schedulable_threads([&pk1].into_iter(), std::iter::empty()),
            Some(ThreadSet::only(3))
        );
        locks.read_unlock_account(&pk1, 3);
        assert!(locks.locks.is_empty());
    }

    #[test]
    fn test_thread_set_operators() {
        let mut a = ThreadSet::none();
        a.insert(1);
        a.insert(2);
        let mut b = ThreadSet::none();
        b.insert(2);
        b.insert(3);

        assert_eq!(a & b, ThreadSet::only(2));
        assert_eq!(a - b, ThreadSet::only(1));
        assert_eq!(b - a, ThreadSet::only(3));

        let mut c = a;
        c &= b;
        assert_eq!(c, ThreadSet::only(2));
    }
}