1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under both the MIT license found in the
 * LICENSE-MIT file in the root directory of this source tree and the Apache
 * License, Version 2.0 found in the LICENSE-APACHE file in the root directory
 * of this source tree.
 */

#![deny(warnings, missing_docs, clippy::all, broken_intra_doc_links)]

//! Closure-based helpers for accessing data guarded by the locks in [std::sync]
//! (and by `parking_lot::Mutex`).

use parking_lot::Mutex as ParkingLotMutex;
use std::sync::{Mutex, RwLock};

/// Run a closure against the value protected by a mutex-like lock.
///
/// Implemented for [std::sync::Mutex] and for `parking_lot::Mutex`. The
/// provided implementations hold the lock only for the duration of the call,
/// so the guard itself never escapes to the caller.
///
/// # Example
/// ```
/// # use std::sync::Mutex;
/// # use lock_ext::LockExt;
/// let lock = Mutex::new(Vec::new());
/// lock.with(|value| value.push("hello"));
/// let hello = lock.with(|value| value.get(0).unwrap().to_owned());
/// # assert_eq!(&hello, &"hello");
/// ```
pub trait LockExt {
    /// Type of the data protected by the lock.
    type Value;

    /// Acquires the lock, calls `scope` with exclusive (`&mut`) access to the
    /// protected value, and returns whatever `scope` returns.
    fn with<F, R>(&self, scope: F) -> R
    where
        F: FnOnce(&mut Self::Value) -> R;
}

impl<V> LockExt for Mutex<V> {
    type Value = V;

    /// Locks the `std` mutex and hands `&mut V` to `scope`.
    ///
    /// # Panics
    /// Panics with "lock poisoned" if the mutex is poisoned, i.e. a previous
    /// holder panicked while the lock was held.
    fn with<Scope, Out>(&self, scope: Scope) -> Out
    where
        Scope: FnOnce(&mut Self::Value) -> Out,
    {
        // Refuse to expose possibly half-updated data from a poisoned mutex.
        let mut guard = self.lock().expect("lock poisoned");
        let produced = scope(&mut *guard);
        // Release the lock before handing the result back.
        drop(guard);
        produced
    }
}

impl<V> LockExt for ParkingLotMutex<V> {
    type Value = V;

    /// Locks the `parking_lot` mutex and hands `&mut V` to `scope`.
    ///
    /// `lock()` here yields the guard directly rather than a `Result`, so there
    /// is no error case to handle before running `scope`.
    fn with<Scope, Out>(&self, scope: Scope) -> Out
    where
        Scope: FnOnce(&mut Self::Value) -> Out,
    {
        let mut guard = self.lock();
        let produced = scope(&mut *guard);
        // Release the lock before handing the result back.
        drop(guard);
        produced
    }
}

/// Run closures against the value protected by a reader-writer lock.
///
/// Implemented for [std::sync::RwLock]. Use `with_read` when shared (`&`)
/// access is enough and `with_write` when the value must be mutated.
///
/// # Example
/// ```
/// # use std::sync::RwLock;
/// # use lock_ext::RwLockExt;
/// let lock = RwLock::new(Vec::new());
/// lock.with_write(|value| value.push("hello"));
/// let hello = lock.with_read(|value| value.get(0).unwrap().to_owned());
/// # assert_eq!(&hello, &"hello");
/// ```
pub trait RwLockExt {
    /// Type of the data protected by the lock.
    type Value;

    /// Acquires the lock for reading, calls `scope` with shared (`&`) access
    /// to the protected value, and returns whatever `scope` returns.
    fn with_read<F, R>(&self, scope: F) -> R
    where
        F: FnOnce(&Self::Value) -> R;

    /// Acquires the lock for writing, calls `scope` with exclusive (`&mut`)
    /// access to the protected value, and returns whatever `scope` returns.
    fn with_write<F, R>(&self, scope: F) -> R
    where
        F: FnOnce(&mut Self::Value) -> R;
}

impl<V> RwLockExt for RwLock<V> {
    type Value = V;

    /// Takes a shared read guard and hands `&V` to `scope`.
    ///
    /// # Panics
    /// Panics with "lock poisoned" if the lock is poisoned.
    fn with_read<Scope, Out>(&self, scope: Scope) -> Out
    where
        Scope: FnOnce(&Self::Value) -> Out,
    {
        let reader = self.read().expect("lock poisoned");
        let produced = scope(&*reader);
        drop(reader);
        produced
    }

    /// Takes an exclusive write guard and hands `&mut V` to `scope`.
    ///
    /// # Panics
    /// Panics with "lock poisoned" if the lock is poisoned.
    fn with_write<Scope, Out>(&self, scope: Scope) -> Out
    where
        Scope: FnOnce(&mut Self::Value) -> Out,
    {
        let mut writer = self.write().expect("lock poisoned");
        let produced = scope(&mut *writer);
        drop(writer);
        produced
    }
}

#[cfg(test)]
mod test {
    use super::{LockExt, ParkingLotMutex, RwLockExt};
    use std::panic::{catch_unwind, AssertUnwindSafe};
    use std::sync::{Arc, Mutex, RwLock};

    #[test]
    fn simple() {
        let vs = Arc::new(Mutex::new(Vec::new()));
        assert_eq!(vs.with(|vs| vs.len()), 0);
        vs.with(|vs| vs.push("test"));
        assert_eq!(vs.with(|vs| vs.pop()), Some("test"));
        assert_eq!(vs.with(|vs| vs.len()), 0);
    }

    /// The `parking_lot::Mutex` impl must behave like the `std` one.
    #[test]
    fn parking_lot_mutex() {
        let vs = ParkingLotMutex::new(Vec::new());
        assert_eq!(vs.with(|vs| vs.len()), 0);
        vs.with(|vs| vs.push("test"));
        assert_eq!(vs.with(|vs| vs.pop()), Some("test"));
        assert_eq!(vs.with(|vs| vs.len()), 0);
    }

    #[test]
    fn rwlock() {
        let vs = Arc::new(RwLock::new(Vec::new()));
        assert_eq!(vs.with_read(|vs| vs.len()), 0);
        vs.with_write(|vs| vs.push("test"));
        assert_eq!(vs.with_write(|vs| vs.pop()), Some("test"));
        assert_eq!(vs.with_read(|vs| vs.len()), 0);
    }

    /// A `std` mutex poisoned by a panic inside `with` must make later
    /// `with` calls panic with "lock poisoned".
    #[test]
    #[should_panic(expected = "lock poisoned")]
    fn poisoned_mutex_panics() {
        let lock = Mutex::new(0);
        // Panic while `with` holds the guard so the mutex becomes poisoned.
        let poisoned = catch_unwind(AssertUnwindSafe(|| {
            lock.with(|_value| -> i32 { panic!("intentional panic to poison the lock") })
        }));
        assert!(poisoned.is_err());
        lock.with(|value| *value);
    }

    /// A `std` RwLock poisoned by a panic inside `with_write` must make later
    /// `with_read` calls panic with "lock poisoned".
    #[test]
    #[should_panic(expected = "lock poisoned")]
    fn poisoned_rwlock_panics() {
        let lock = RwLock::new(0);
        // Panic while the write guard is held so the lock becomes poisoned.
        let poisoned = catch_unwind(AssertUnwindSafe(|| {
            lock.with_write(|_value| -> i32 { panic!("intentional panic to poison the lock") })
        }));
        assert!(poisoned.is_err());
        lock.with_read(|value| *value);
    }
}