striped-lock 0.1.0

Striped Lock for Rust
Documentation
// Copyright (c) 2024 Mek101
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

use std::{
    hash::{BuildHasher, BuildHasherDefault, DefaultHasher, Hash},
    marker::PhantomData,
    num::NonZeroUsize,
    sync::{Mutex, MutexGuard},
};

use crate::batch::{KeyBatch, MAX_BATCH_KEYS};

/// The inner mutex is poisoned.
///
/// Returned when the inner mutex a key maps onto has been poisoned, i.e. a thread
/// panicked while holding a guard on that mutex.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StripedPoisonError;

impl std::fmt::Display for StripedPoisonError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("the inner mutex of the striped lock is poisoned")
    }
}

// Lets the error be boxed into `Box<dyn Error>` / propagated with `?` by callers.
impl std::error::Error for StripedPoisonError {}

/// Guard returned by [`StripedLock::lock`].
///
/// Holds the inner mutex the key maps onto; the mutex is released when this guard is
/// dropped, so binding it to `_` (or not binding it at all) unlocks immediately.
#[must_use = "if unused the inner mutex will be released immediately"]
#[derive(Debug)]
pub struct StripedLockGuard<'l> {
    _guard: MutexGuard<'l, ()>,
}

/// Guard returned by [`StripedLock::lock_batch`].
///
/// Holds every distinct inner mutex the batch's keys map onto (unused slots are
/// `None`); all of them are released together when this guard is dropped.
#[must_use = "if unused the inner mutexes will be released immediately"]
#[derive(Debug)]
pub struct StripedBatchLockGuard<'l> {
    _guards: [Option<MutexGuard<'l, ()>>; MAX_BATCH_KEYS],
}

/// A lock keyed by values of type `K`, backed by a fixed number of inner mutexes.
///
/// Each key is hashed with `H` and mapped onto one of the inner mutexes ("stripes"),
/// so distinct keys usually lock independently while memory stays bounded by the
/// number of stripes. Keys that map onto the same stripe share it and contend.
///
/// Trait bounds live on the `impl` blocks rather than on the type itself.
pub struct StripedLock<K, H = BuildHasherDefault<DefaultHasher>> {
    /// Factory for the hashers that map keys onto `locks`.
    hasher_builder: H,
    /// The stripes. Built from a `NonZeroUsize` count, so never empty.
    locks: Box<[Mutex<()>]>,
    /// `K` is only hashed, never stored. `fn() -> K` keeps the type covariant in `K`
    /// (like `PhantomData<K>`) without making `Send`/`Sync` or drop checking depend
    /// on `K`, so e.g. a lock keyed by raw pointers can still be shared across threads.
    phantom: PhantomData<fn() -> K>,
}

impl<K> StripedLock<K, BuildHasherDefault<DefaultHasher>>
where
    K: Hash,
{
    /// Creates a [`StripedLock`] that hashes keys with the standard library's
    /// [`DefaultHasher`].
    ///
    /// # Arguments
    ///
    /// * `locks` - How many inner mutexes keys are spread across. More of them means
    ///   fewer unrelated keys end up sharing (and contending on) the same mutex.
    pub fn new(locks: NonZeroUsize) -> Self {
        let default_builder: BuildHasherDefault<DefaultHasher> = Default::default();
        Self::with_hasher(default_builder, locks)
    }
}

impl<K, H> StripedLock<K, H>
where
    K: Hash,
    H: BuildHasher,
{
    /// Create a new [`StripedLock`] instance.
    ///
    /// # Arguments
    ///
    /// * `hasher_builder` - The factory of hashers used to map keys onto inner locks.
    /// * `locks` - The number of inner locks. Increase to reduce collisions.
    pub fn with_hasher(hasher_builder: H, locks: NonZeroUsize) -> Self {
        let locks = (0..locks.get())
            .map(|_| Mutex::new(()))
            .collect::<Box<[_]>>();

        Self {
            hasher_builder,
            locks,
            phantom: PhantomData,
        }
    }

    /// Lock on the key, blocking until the inner mutex it maps onto is available.
    ///
    /// The lock is released when the returned guard is dropped.
    ///
    /// Use [`lock_batch`](Self::lock_batch) if you want to lock on multiple keys.
    /// Calling `lock` again while still holding a guard is unsafe in practice: if the
    /// second key maps onto the same inner mutex, the call will not return (std's
    /// `Mutex` is not re-entrant and may deadlock or panic).
    ///
    /// # Arguments
    ///
    /// * `key` - The key to lock on.
    ///
    /// # Errors
    ///
    /// Returns [`StripedPoisonError`] if the inner mutex is poisoned, i.e. a thread
    /// panicked while holding a guard on it.
    pub fn lock(&self, key: K) -> Result<StripedLockGuard<'_>, StripedPoisonError> {
        self.stripe(key)
            .lock()
            .map(|guard| StripedLockGuard { _guard: guard })
            .map_err(|_| StripedPoisonError)
    }

    /// Lock on multiple keys at once.
    ///
    /// Each key is mapped onto its inner mutex. The mutexes are then acquired in
    /// ascending index order, and a mutex shared by several keys is locked only once.
    /// All of them are released when the returned guard is dropped.
    ///
    /// # Arguments
    ///
    /// * `batch` - The batch of keys to lock on. May hold up to `MAX_BATCH_KEYS`
    ///   (currently 4) keys.
    ///
    /// # Errors
    ///
    /// Returns [`StripedPoisonError`] if any of the inner mutexes is poisoned. Any
    /// mutexes already acquired for this batch are released before returning.
    ///
    /// # Panics
    ///
    /// Panics if the batch contributes no keys or more than `MAX_BATCH_KEYS` keys.
    ///
    /// # Example
    ///
    /// ```
    /// # use std::num::NonZeroUsize;
    /// # use striped_lock::std::StripedLock;
    /// let sl: StripedLock<char> = StripedLock::new(NonZeroUsize::new(4).unwrap());
    /// let guard = sl.lock_batch(('a', 'b', 'c', 'd'));
    /// assert!(guard.is_ok());
    /// ```
    pub fn lock_batch<B>(&self, batch: B) -> Result<StripedBatchLockGuard<'_>, StripedPoisonError>
    where
        B: KeyBatch<K, H>,
    {
        // Not generic over `K`/`H`/`B`, so this body is compiled only once.
        fn inner<'l>(
            locks: &'l [Mutex<()>],
            batch: &mut [u64],
        ) -> Result<StripedBatchLockGuard<'l>, StripedPoisonError> {
            assert!(!batch.is_empty());
            assert!(batch.len() <= MAX_BATCH_KEYS);

            // Reduce each hash to the index of the stripe it maps onto.
            let stripe_count = locks.len() as u64;
            for hash in batch.iter_mut() {
                *hash %= stripe_count;
            }

            // Sort so that every call acquires overlapping stripes in the same order.
            batch.sort_unstable();

            let mut guards: [Option<MutexGuard<'l, ()>>; MAX_BATCH_KEYS] =
                std::array::from_fn(|_| None);

            let mut previous = None;
            for (slot, &idx) in guards.iter_mut().zip(batch.iter()) {
                // Skip duplicates (adjacent after sorting) since locks are not
                // re-entrant.
                if previous == Some(idx) {
                    continue;
                }
                previous = Some(idx);
                *slot = Some(locks[idx as usize].lock().map_err(|_| StripedPoisonError)?);
            }

            Ok(StripedBatchLockGuard { _guards: guards })
        }

        let (mut hashes, filled) = batch.into_hash_array(&self.hasher_builder);
        inner(&self.locks, &mut hashes[..filled])
    }

    /// Check if the mutex that `key` maps onto is poisoned.
    ///
    /// Because several keys can share one inner mutex, this also reports poisoning
    /// caused by a panic while a guard for any other key on the same mutex was held.
    pub fn is_poisoned(&self, key: K) -> bool {
        self.stripe(key).is_poisoned()
    }

    /// Remove the poisoned status from the mutex that `key` maps onto.
    ///
    /// This affects every key that shares the same inner mutex.
    pub fn clear_poison(&self, key: K) {
        self.stripe(key).clear_poison();
    }

    /// Returns the inner mutex that `key` maps onto.
    fn stripe(&self, key: K) -> &Mutex<()> {
        let hash = self.hasher_builder.hash_one(key);
        // `locks` is built from a `NonZeroUsize` count, so the modulo cannot divide by
        // zero and the resulting index is always in bounds.
        let idx = (hash % self.locks.len() as u64) as usize;
        &self.locks[idx]
    }
}