// agave_scheduling_utils/thread_aware_account_locks.rs

use {
    ahash::AHashMap,
    solana_pubkey::Pubkey,
    std::{
        collections::hash_map::Entry,
        fmt::{self, Debug, Display},
        ops::{BitAnd, BitAndAssign, Sub},
    },
};

/// Largest number of threads a [`ThreadSet`] can describe: one bit per thread
/// in its backing `u64`.
pub const MAX_THREADS: usize = 8 * std::mem::size_of::<u64>();

/// Identifier for a thread: an index in `0..MAX_THREADS` (i.e. `0` through
/// `MAX_THREADS - 1` inclusive), which is also its bit position in a
/// [`ThreadSet`]. A `ThreadAwareAccountLocks` further restricts valid ids to
/// `0..num_threads`.
pub type ThreadId = usize;

/// How many times one thread currently holds a particular lock on an account.
type LockCount = u32;

/// Bit-set of the threads an account is scheduled on, or may be scheduled on.
///
/// Thread `i` is a member exactly when bit `i` of the inner `u64` is set, so
/// thread ids are limited to `0..MAX_THREADS`.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct ThreadSet(u64);

22struct AccountWriteLocks {
23    thread_id: ThreadId,
24    lock_count: LockCount,
25}
26
27struct AccountReadLocks {
28    thread_set: ThreadSet,
29    lock_counts: [LockCount; MAX_THREADS],
30}
31
32/// Account locks.
33/// Write Locks - only one thread can hold a write lock at a time.
34///     Contains how many write locks are held by the thread.
35/// Read Locks - multiple threads can hold a read lock at a time.
36///     Contains thread-set for easily checking which threads are scheduled.
37#[derive(Default)]
38struct AccountLocks {
39    pub write_locks: Option<AccountWriteLocks>,
40    pub read_locks: Option<AccountReadLocks>,
41}
42
/// Reasons `ThreadAwareAccountLocks::try_lock_accounts` can fail.
#[derive(Debug, PartialEq, Eq)]
pub enum TryLockError {
    /// Outstanding locks conflict with more than one thread, so no single
    /// thread can take every requested lock.
    MultipleConflicts,
    /// The accounts are schedulable, but on no thread in `allowed_threads`.
    ThreadNotAllowed,
}

impl Display for TryLockError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::MultipleConflicts => "outstanding lock conflicts with multiple threads",
            Self::ThreadNotAllowed => "no schedulable thread is in the allowed threads",
        };
        f.write_str(message)
    }
}

/// Lets `TryLockError` flow through `?` into generic error types such as
/// `Box<dyn std::error::Error>`.
impl std::error::Error for TryLockError {}

52/// Thread-aware account locks which allows for scheduling on threads
53/// that already hold locks on the account. This is useful for allowing
54/// queued transactions to be scheduled on a thread while the transaction
55/// is still being executed on the thread.
56pub struct ThreadAwareAccountLocks {
57    /// Number of threads.
58    num_threads: usize, // 0..MAX_THREADS
59    /// Locks for each account. An account should only have an entry if there
60    /// is at least one lock.
61    locks: AHashMap<Pubkey, AccountLocks>,
62}
63
64impl ThreadAwareAccountLocks {
65    /// Creates a new `ThreadAwareAccountLocks` with the given number of threads.
66    pub fn new(num_threads: usize) -> Self {
67        assert!(num_threads > 0, "num threads must be > 0");
68        assert!(
69            num_threads <= MAX_THREADS,
70            "num threads must be <= {MAX_THREADS}"
71        );
72
73        Self {
74            num_threads,
75            locks: AHashMap::new(),
76        }
77    }
78
79    /// Returns the `ThreadId` if the accounts are able to be locked
80    /// for the given thread, otherwise `None` is returned.
81    /// `allowed_threads` is a set of threads that the caller restricts locking to.
82    /// If accounts are schedulable, then they are locked for the thread
83    /// selected by the `thread_selector` function.
84    /// `thread_selector` is only called if all accounts are schdulable, meaning
85    /// that the `thread_set` passed to `thread_selector` is non-empty.
86    pub fn try_lock_accounts<'a>(
87        &mut self,
88        write_account_locks: impl Iterator<Item = &'a Pubkey> + Clone,
89        read_account_locks: impl Iterator<Item = &'a Pubkey> + Clone,
90        allowed_threads: ThreadSet,
91        thread_selector: impl FnOnce(ThreadSet) -> ThreadId,
92    ) -> Result<ThreadId, TryLockError> {
93        let schedulable_threads = self
94            .accounts_schedulable_threads(write_account_locks.clone(), read_account_locks.clone())
95            .ok_or(TryLockError::MultipleConflicts)?;
96        let schedulable_threads = schedulable_threads & allowed_threads;
97        if schedulable_threads.is_empty() {
98            return Err(TryLockError::ThreadNotAllowed);
99        }
100
101        let thread_id = thread_selector(schedulable_threads);
102        self.lock_accounts(write_account_locks, read_account_locks, thread_id);
103        Ok(thread_id)
104    }
105
106    /// Unlocks the accounts for the given thread.
107    pub fn unlock_accounts<'a>(
108        &mut self,
109        write_account_locks: impl Iterator<Item = &'a Pubkey>,
110        read_account_locks: impl Iterator<Item = &'a Pubkey>,
111        thread_id: ThreadId,
112    ) {
113        for account in write_account_locks {
114            self.write_unlock_account(account, thread_id);
115        }
116
117        for account in read_account_locks {
118            self.read_unlock_account(account, thread_id);
119        }
120    }
121
122    /// Returns `ThreadSet` that the given accounts can be scheduled on.
123    fn accounts_schedulable_threads<'a>(
124        &self,
125        write_account_locks: impl Iterator<Item = &'a Pubkey>,
126        read_account_locks: impl Iterator<Item = &'a Pubkey>,
127    ) -> Option<ThreadSet> {
128        let mut schedulable_threads = ThreadSet::any(self.num_threads);
129
130        for account in write_account_locks {
131            schedulable_threads &= self.write_schedulable_threads(account);
132            if schedulable_threads.is_empty() {
133                return None;
134            }
135        }
136
137        for account in read_account_locks {
138            schedulable_threads &= self.read_schedulable_threads(account);
139            if schedulable_threads.is_empty() {
140                return None;
141            }
142        }
143
144        Some(schedulable_threads)
145    }
146
147    /// Returns `ThreadSet` of schedulable threads for the given readable account.
148    fn read_schedulable_threads(&self, account: &Pubkey) -> ThreadSet {
149        self.schedulable_threads::<false>(account)
150    }
151
152    /// Returns `ThreadSet` of schedulable threads for the given writable account.
153    fn write_schedulable_threads(&self, account: &Pubkey) -> ThreadSet {
154        self.schedulable_threads::<true>(account)
155    }
156
157    /// Returns `ThreadSet` of schedulable threads.
158    /// If there are no locks, then all threads are schedulable.
159    /// If only write-locked, then only the thread holding the write lock is schedulable.
160    /// If a mix of locks, then only the write thread is schedulable.
161    /// If only read-locked, the only write-schedulable thread is if a single thread
162    ///   holds all read locks. Otherwise, no threads are write-schedulable.
163    /// If only read-locked, all threads are read-schedulable.
164    fn schedulable_threads<const WRITE: bool>(&self, account: &Pubkey) -> ThreadSet {
165        match self.locks.get(account) {
166            None => ThreadSet::any(self.num_threads),
167            Some(AccountLocks {
168                write_locks: None,
169                read_locks: Some(read_locks),
170            }) => {
171                if WRITE {
172                    read_locks
173                        .thread_set
174                        .only_one_contained()
175                        .map(ThreadSet::only)
176                        .unwrap_or_else(ThreadSet::none)
177                } else {
178                    ThreadSet::any(self.num_threads)
179                }
180            }
181            Some(AccountLocks {
182                write_locks: Some(write_locks),
183                read_locks: None,
184            }) => ThreadSet::only(write_locks.thread_id),
185            Some(AccountLocks {
186                write_locks: Some(write_locks),
187                read_locks: Some(read_locks),
188            }) => {
189                assert_eq!(
190                    read_locks.thread_set.only_one_contained(),
191                    Some(write_locks.thread_id)
192                );
193                read_locks.thread_set
194            }
195            Some(AccountLocks {
196                write_locks: None,
197                read_locks: None,
198            }) => unreachable!(),
199        }
200    }
201
202    /// Add locks for all writable and readable accounts on `thread_id`.
203    fn lock_accounts<'a>(
204        &mut self,
205        write_account_locks: impl Iterator<Item = &'a Pubkey>,
206        read_account_locks: impl Iterator<Item = &'a Pubkey>,
207        thread_id: ThreadId,
208    ) {
209        assert!(
210            thread_id < self.num_threads,
211            "thread_id must be < num_threads"
212        );
213        for account in write_account_locks {
214            self.write_lock_account(account, thread_id);
215        }
216
217        for account in read_account_locks {
218            self.read_lock_account(account, thread_id);
219        }
220    }
221
222    /// Locks the given `account` for writing on `thread_id`.
223    /// Panics if the account is already locked for writing on another thread.
224    fn write_lock_account(&mut self, account: &Pubkey, thread_id: ThreadId) {
225        let entry = self.locks.entry(*account).or_default();
226
227        let AccountLocks {
228            write_locks,
229            read_locks,
230        } = entry;
231
232        if let Some(read_locks) = read_locks {
233            assert_eq!(
234                read_locks.thread_set.only_one_contained(),
235                Some(thread_id),
236                "outstanding read lock must be on same thread"
237            );
238        }
239
240        if let Some(write_locks) = write_locks {
241            assert_eq!(
242                write_locks.thread_id, thread_id,
243                "outstanding write lock must be on same thread"
244            );
245            write_locks.lock_count = write_locks.lock_count.wrapping_add(1);
246        } else {
247            *write_locks = Some(AccountWriteLocks {
248                thread_id,
249                lock_count: 1,
250            });
251        }
252    }
253
254    /// Unlocks the given `account` for writing on `thread_id`.
255    /// Panics if the account is not locked for writing on `thread_id`.
256    fn write_unlock_account(&mut self, account: &Pubkey, thread_id: ThreadId) {
257        let Entry::Occupied(mut entry) = self.locks.entry(*account) else {
258            panic!("write lock must exist for account: {account}");
259        };
260
261        let AccountLocks {
262            write_locks: maybe_write_locks,
263            read_locks,
264        } = entry.get_mut();
265
266        let Some(write_locks) = maybe_write_locks else {
267            panic!("write lock must exist for account: {account}");
268        };
269
270        assert_eq!(
271            write_locks.thread_id, thread_id,
272            "outstanding write lock must be on same thread"
273        );
274
275        write_locks.lock_count = write_locks.lock_count.wrapping_sub(1);
276        if write_locks.lock_count == 0 {
277            *maybe_write_locks = None;
278            if read_locks.is_none() {
279                entry.remove();
280            }
281        }
282    }
283
284    /// Locks the given `account` for reading on `thread_id`.
285    /// Panics if the account is already locked for writing on another thread.
286    fn read_lock_account(&mut self, account: &Pubkey, thread_id: ThreadId) {
287        let AccountLocks {
288            write_locks,
289            read_locks,
290        } = self.locks.entry(*account).or_default();
291
292        if let Some(write_locks) = write_locks {
293            assert_eq!(
294                write_locks.thread_id, thread_id,
295                "outstanding write lock must be on same thread"
296            );
297        }
298
299        match read_locks {
300            Some(read_locks) => {
301                read_locks.thread_set.insert(thread_id);
302                read_locks.lock_counts[thread_id] =
303                    read_locks.lock_counts[thread_id].wrapping_add(1);
304            }
305            None => {
306                let mut lock_counts = [0; MAX_THREADS];
307                lock_counts[thread_id] = 1;
308                *read_locks = Some(AccountReadLocks {
309                    thread_set: ThreadSet::only(thread_id),
310                    lock_counts,
311                });
312            }
313        }
314    }
315
316    /// Unlocks the given `account` for reading on `thread_id`.
317    /// Panics if the account is not locked for reading on `thread_id`.
318    fn read_unlock_account(&mut self, account: &Pubkey, thread_id: ThreadId) {
319        let Entry::Occupied(mut entry) = self.locks.entry(*account) else {
320            panic!("read lock must exist for account: {account}");
321        };
322
323        let AccountLocks {
324            write_locks,
325            read_locks: maybe_read_locks,
326        } = entry.get_mut();
327
328        let Some(read_locks) = maybe_read_locks else {
329            panic!("read lock must exist for account: {account}");
330        };
331
332        assert!(
333            read_locks.thread_set.contains(thread_id),
334            "outstanding read lock must be on same thread"
335        );
336
337        read_locks.lock_counts[thread_id] = read_locks.lock_counts[thread_id].wrapping_sub(1);
338        if read_locks.lock_counts[thread_id] == 0 {
339            read_locks.thread_set.remove(thread_id);
340            if read_locks.thread_set.is_empty() {
341                *maybe_read_locks = None;
342                if write_locks.is_none() {
343                    entry.remove();
344                }
345            }
346        }
347    }
348}
349
350impl BitAnd for ThreadSet {
351    type Output = Self;
352
353    fn bitand(self, rhs: Self) -> Self::Output {
354        Self(self.0 & rhs.0)
355    }
356}
357
358impl BitAndAssign for ThreadSet {
359    fn bitand_assign(&mut self, rhs: Self) {
360        self.0 &= rhs.0;
361    }
362}
363
364impl Sub for ThreadSet {
365    type Output = Self;
366
367    fn sub(self, rhs: Self) -> Self::Output {
368        Self(self.0 & !rhs.0)
369    }
370}
371
372impl Display for ThreadSet {
373    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374        write!(f, "ThreadSet({:#0width$b})", self.0, width = MAX_THREADS)
375    }
376}
377
378impl Debug for ThreadSet {
379    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380        Display::fmt(self, f)
381    }
382}
383
384impl ThreadSet {
385    #[inline(always)]
386    pub const fn none() -> Self {
387        Self(0b0)
388    }
389
390    #[inline(always)]
391    pub const fn any(num_threads: usize) -> Self {
392        if num_threads == MAX_THREADS {
393            Self(u64::MAX)
394        } else {
395            Self(Self::as_flag(num_threads).wrapping_sub(1))
396        }
397    }
398
399    #[inline(always)]
400    pub const fn only(thread_id: ThreadId) -> Self {
401        Self(Self::as_flag(thread_id))
402    }
403
404    #[inline(always)]
405    pub fn num_threads(&self) -> u32 {
406        self.0.count_ones()
407    }
408
409    #[inline(always)]
410    pub fn only_one_contained(&self) -> Option<ThreadId> {
411        (self.num_threads() == 1).then_some(self.0.trailing_zeros() as ThreadId)
412    }
413
414    #[inline(always)]
415    pub fn is_empty(&self) -> bool {
416        self == &Self::none()
417    }
418
419    #[inline(always)]
420    pub fn contains(&self, thread_id: ThreadId) -> bool {
421        self.0 & Self::as_flag(thread_id) != 0
422    }
423
424    #[inline(always)]
425    pub fn insert(&mut self, thread_id: ThreadId) {
426        self.0 |= Self::as_flag(thread_id);
427    }
428
429    #[inline(always)]
430    pub fn remove(&mut self, thread_id: ThreadId) {
431        self.0 &= !Self::as_flag(thread_id);
432    }
433
434    #[inline(always)]
435    pub fn contained_threads_iter(self) -> impl Iterator<Item = ThreadId> {
436        ThreadSetIterator(self.0)
437    }
438
439    #[inline(always)]
440    const fn as_flag(thread_id: ThreadId) -> u64 {
441        0b1 << thread_id
442    }
443}
444
/// Iterator over the thread ids of a `ThreadSet`, in ascending order. The
/// inner `u64` holds the bits that have not been yielded yet.
struct ThreadSetIterator(u64);

447impl Iterator for ThreadSetIterator {
448    type Item = ThreadId;
449
450    fn next(&mut self) -> Option<Self::Item> {
451        if self.0 == 0 {
452            None
453        } else {
454            // Find the first set bit by counting trailing zeros.
455            // This is guaranteed to be < 64 because self.0 != 0.
456            let thread_id = self.0.trailing_zeros() as ThreadId;
457            // Clear the lowest set bit. The subtraction is safe because
458            // we know that self.0 != 0.
459            // Example (with 4 bits):
460            //  self.0 = 0b1010           // initial value
461            //  self.0 - 1 = 0b1001       // all bits at or after the lowest set bit are flipped
462            //  0b1010 & 0b1001 = 0b1000  // the lowest bit has been cleared
463            self.0 &= self.0.wrapping_sub(1);
464            Some(thread_id)
465        }
466    }
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_NUM_THREADS: usize = 4;
    const TEST_ANY_THREADS: ThreadSet = ThreadSet::any(TEST_NUM_THREADS);

    // Simple thread selector to select the first schedulable thread
    fn test_thread_selector(thread_set: ThreadSet) -> ThreadId {
        thread_set.contained_threads_iter().next().unwrap()
    }

    #[test]
    #[should_panic(expected = "num threads must be > 0")]
    fn test_too_few_num_threads() {
        ThreadAwareAccountLocks::new(0);
    }

    #[test]
    #[should_panic(expected = "num threads must be <=")]
    fn test_too_many_num_threads() {
        ThreadAwareAccountLocks::new(MAX_THREADS + 1);
    }

    #[test]
    fn test_try_lock_accounts_none() {
        let pk1 = Pubkey::new_unique();
        let pk2 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.read_lock_account(&pk1, 2);
        locks.read_lock_account(&pk1, 3);
        assert_eq!(
            locks.try_lock_accounts(
                [&pk1].into_iter(),
                [&pk2].into_iter(),
                TEST_ANY_THREADS,
                test_thread_selector
            ),
            Err(TryLockError::MultipleConflicts)
        );
    }

    #[test]
    fn test_try_lock_accounts_one() {
        let pk1 = Pubkey::new_unique();
        let pk2 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.write_lock_account(&pk2, 3);

        assert_eq!(
            locks.try_lock_accounts(
                [&pk1].into_iter(),
                [&pk2].into_iter(),
                TEST_ANY_THREADS,
                test_thread_selector
            ),
            Ok(3)
        );
    }

    #[test]
    fn test_try_lock_accounts_one_not_allowed() {
        let pk1 = Pubkey::new_unique();
        let pk2 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.write_lock_account(&pk2, 3);

        assert_eq!(
            locks.try_lock_accounts(
                [&pk1].into_iter(),
                [&pk2].into_iter(),
                ThreadSet::none(),
                test_thread_selector
            ),
            Err(TryLockError::ThreadNotAllowed)
        );
    }

    #[test]
    fn test_try_lock_accounts_multiple() {
        let pk1 = Pubkey::new_unique();
        let pk2 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.read_lock_account(&pk2, 0);
        locks.read_lock_account(&pk2, 0);

        assert_eq!(
            locks.try_lock_accounts(
                [&pk1].into_iter(),
                [&pk2].into_iter(),
                TEST_ANY_THREADS - ThreadSet::only(0), // exclude 0
                test_thread_selector
            ),
            Ok(1)
        );
    }

    #[test]
    fn test_try_lock_accounts_any() {
        let pk1 = Pubkey::new_unique();
        let pk2 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        assert_eq!(
            locks.try_lock_accounts(
                [&pk1].into_iter(),
                [&pk2].into_iter(),
                TEST_ANY_THREADS,
                test_thread_selector
            ),
            Ok(0)
        );
    }

    #[test]
    fn test_accounts_schedulable_threads_no_outstanding_locks() {
        let pk1 = Pubkey::new_unique();
        let locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);

        assert_eq!(
            locks.accounts_schedulable_threads([&pk1].into_iter(), std::iter::empty()),
            Some(TEST_ANY_THREADS)
        );
        assert_eq!(
            locks.accounts_schedulable_threads(std::iter::empty(), [&pk1].into_iter()),
            Some(TEST_ANY_THREADS)
        );
    }

    #[test]
    fn test_accounts_schedulable_threads_outstanding_write_only() {
        let pk1 = Pubkey::new_unique();
        let pk2 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);

        locks.write_lock_account(&pk1, 2);
        assert_eq!(
            locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()),
            Some(ThreadSet::only(2))
        );
        assert_eq!(
            locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()),
            Some(ThreadSet::only(2))
        );
    }

    #[test]
    fn test_accounts_schedulable_threads_outstanding_read_only() {
        let pk1 = Pubkey::new_unique();
        let pk2 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);

        locks.read_lock_account(&pk1, 2);
        assert_eq!(
            locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()),
            Some(ThreadSet::only(2))
        );
        assert_eq!(
            locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()),
            Some(TEST_ANY_THREADS)
        );

        locks.read_lock_account(&pk1, 0);
        assert_eq!(
            locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()),
            None
        );
        assert_eq!(
            locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()),
            Some(TEST_ANY_THREADS)
        );
    }

    #[test]
    fn test_accounts_schedulable_threads_outstanding_mixed() {
        let pk1 = Pubkey::new_unique();
        let pk2 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);

        locks.read_lock_account(&pk1, 2);
        locks.write_lock_account(&pk1, 2);
        assert_eq!(
            locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()),
            Some(ThreadSet::only(2))
        );
        assert_eq!(
            locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()),
            Some(ThreadSet::only(2))
        );
    }

    #[test]
    #[should_panic(expected = "outstanding write lock must be on same thread")]
    fn test_write_lock_account_write_conflict_panic() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.write_lock_account(&pk1, 0);
        locks.write_lock_account(&pk1, 1);
    }

    #[test]
    #[should_panic(expected = "outstanding read lock must be on same thread")]
    fn test_write_lock_account_read_conflict_panic() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.read_lock_account(&pk1, 0);
        locks.write_lock_account(&pk1, 1);
    }

    #[test]
    #[should_panic(expected = "write lock must exist")]
    fn test_write_unlock_account_not_locked() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.write_unlock_account(&pk1, 0);
    }

    #[test]
    #[should_panic(expected = "outstanding write lock must be on same thread")]
    fn test_write_unlock_account_thread_mismatch() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.write_lock_account(&pk1, 1);
        locks.write_unlock_account(&pk1, 0);
    }

    #[test]
    #[should_panic(expected = "outstanding write lock must be on same thread")]
    fn test_read_lock_account_write_conflict_panic() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.write_lock_account(&pk1, 0);
        locks.read_lock_account(&pk1, 1);
    }

    #[test]
    #[should_panic(expected = "read lock must exist")]
    fn test_read_unlock_account_not_locked() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.read_unlock_account(&pk1, 1);
    }

    #[test]
    #[should_panic(expected = "outstanding read lock must be on same thread")]
    fn test_read_unlock_account_thread_mismatch() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.read_lock_account(&pk1, 0);
        locks.read_unlock_account(&pk1, 1);
    }

    #[test]
    fn test_write_locking() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.write_lock_account(&pk1, 1);
        locks.write_lock_account(&pk1, 1);
        locks.write_unlock_account(&pk1, 1);
        locks.write_unlock_account(&pk1, 1);
        assert!(locks.locks.is_empty());
    }

    #[test]
    fn test_read_locking() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.read_lock_account(&pk1, 1);
        locks.read_lock_account(&pk1, 1);
        locks.read_unlock_account(&pk1, 1);
        locks.read_unlock_account(&pk1, 1);
        assert!(locks.locks.is_empty());
    }

    #[test]
    fn test_lock_unlock_accounts_round_trip() {
        let pk1 = Pubkey::new_unique();
        let pk2 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);

        // pk1 ends up write- and read-locked on thread 1; pk2 read-locked twice.
        locks.lock_accounts([&pk1].into_iter(), [&pk2].into_iter(), 1);
        locks.lock_accounts(std::iter::empty(), [&pk1, &pk2].into_iter(), 1);
        assert_eq!(
            locks.accounts_schedulable_threads([&pk1].into_iter(), std::iter::empty()),
            Some(ThreadSet::only(1))
        );

        // Dropping pk1's write lock keeps its entry alive for the read lock.
        locks.unlock_accounts([&pk1].into_iter(), [&pk2].into_iter(), 1);
        assert!(!locks.locks.is_empty());

        locks.unlock_accounts(std::iter::empty(), [&pk1, &pk2].into_iter(), 1);
        assert!(locks.locks.is_empty());
    }

    #[test]
    fn test_read_unlock_restores_write_schedulability() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.read_lock_account(&pk1, 0);
        locks.read_lock_account(&pk1, 1);
        assert_eq!(locks.write_schedulable_threads(&pk1), ThreadSet::none());

        locks.read_unlock_account(&pk1, 0);
        assert_eq!(locks.write_schedulable_threads(&pk1), ThreadSet::only(1));
    }

    #[test]
    #[should_panic(expected = "thread_id must be < num_threads")]
    fn test_lock_accounts_invalid_thread() {
        let pk1 = Pubkey::new_unique();
        let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS);
        locks.lock_accounts([&pk1].into_iter(), std::iter::empty(), TEST_NUM_THREADS);
    }

    #[test]
    fn test_thread_set() {
        let mut thread_set = ThreadSet::none();
        assert!(thread_set.is_empty());
        assert_eq!(thread_set.num_threads(), 0);
        assert_eq!(thread_set.only_one_contained(), None);
        for idx in 0..MAX_THREADS {
            assert!(!thread_set.contains(idx));
        }

        thread_set.insert(4);
        assert!(!thread_set.is_empty());
        assert_eq!(thread_set.num_threads(), 1);
        assert_eq!(thread_set.only_one_contained(), Some(4));
        for idx in 0..MAX_THREADS {
            assert_eq!(thread_set.contains(idx), idx == 4);
        }

        thread_set.insert(2);
        assert!(!thread_set.is_empty());
        assert_eq!(thread_set.num_threads(), 2);
        assert_eq!(thread_set.only_one_contained(), None);
        for idx in 0..MAX_THREADS {
            assert_eq!(thread_set.contains(idx), idx == 2 || idx == 4);
        }

        thread_set.remove(4);
        assert!(!thread_set.is_empty());
        assert_eq!(thread_set.num_threads(), 1);
        assert_eq!(thread_set.only_one_contained(), Some(2));
        for idx in 0..MAX_THREADS {
            assert_eq!(thread_set.contains(idx), idx == 2);
        }
    }

    #[test]
    fn test_thread_set_any_zero() {
        let any_threads = ThreadSet::any(0);
        assert_eq!(any_threads.num_threads(), 0);
    }

    #[test]
    fn test_thread_set_any_max() {
        let any_threads = ThreadSet::any(MAX_THREADS);
        assert_eq!(any_threads.num_threads(), MAX_THREADS as u32);
    }

    #[test]
    fn test_thread_set_iter() {
        let last_thread = MAX_THREADS - 1;
        let mut thread_set = ThreadSet::none();
        assert!(thread_set.contained_threads_iter().next().is_none());

        thread_set.insert(4);
        assert_eq!(
            thread_set.contained_threads_iter().collect::<Vec<_>>(),
            vec![4]
        );

        thread_set.insert(5);
        assert_eq!(
            thread_set.contained_threads_iter().collect::<Vec<_>>(),
            vec![4, 5]
        );
        thread_set.insert(last_thread);
        assert_eq!(
            thread_set.contained_threads_iter().collect::<Vec<_>>(),
            vec![4, 5, last_thread]
        );

        thread_set.remove(5);
        assert_eq!(
            thread_set.contained_threads_iter().collect::<Vec<_>>(),
            vec![4, last_thread]
        );

        let thread_set = ThreadSet::any(MAX_THREADS);
        assert_eq!(
            thread_set.contained_threads_iter().collect::<Vec<_>>(),
            (0..MAX_THREADS).collect::<Vec<_>>()
        );
    }
}