use std::sync::Arc;
use papaya::HashMap;
use tokio::sync::{OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock};
/// Byte (`/`) at which `partition_backward` and `partition_forward` cut a
/// path into prefixes.
const SEPARATOR: u8 = b'/';
/// Registry of async read/write locks keyed by byte-string path prefixes.
///
/// `N` is the length of the guard arrays returned by the acquisition methods
/// and therefore the maximum number of prefixes a single call locks.
#[derive(Default)]
pub struct Lock<const N: usize> {
/// One lock per prefix, created on first use by `Lock::lock`. Nothing in
/// the code visible here ever removes an entry.
inner: HashMap<Vec<u8>, Arc<RwLock<()>>>,
}
/// A held lock on a single path prefix, either shared or exclusive.
///
/// The underlying lock is released when the wrapped tokio guard is dropped.
pub enum Guard {
/// Shared access; obtained through `read_owned` or by downgrading a write guard.
Read(ReadGuard),
/// Exclusive access; obtained through `write_owned`.
Write(WriteGuard),
}
/// Owned shared guard over the unit payload stored in each per-prefix lock.
pub type ReadGuard = OwnedRwLockReadGuard<()>;
/// Owned exclusive guard over the unit payload stored in each per-prefix lock.
pub type WriteGuard = OwnedRwLockWriteGuard<()>;
// NOTE(review): the `*_backward` methods acquire prefixes leaf-to-root while
// the `*_forward` methods acquire them root-to-leaf. Mixing both families on
// overlapping paths may be able to deadlock, depending on how tokio's `RwLock`
// queues readers behind waiting writers — confirm before combining them.
impl<const N: usize> Lock<N> {
    /// Acquires shared locks on `path` and its ancestors, deepest first.
    ///
    /// Slot 0 of the result guards the full path and every following slot
    /// guards the next shorter prefix. At most `N` prefixes (the deepest
    /// ones) are locked; slots without a prefix stay `None`.
    pub async fn read_backward<T: AsRef<[u8]>>(&self, path: T) -> [Option<Guard>; N] {
        let mut guards: [Option<Guard>; N] = [const { None }; N];
        // Zipping against the `N` slots caps the walk at `N` prefixes.
        for (slot, prefix) in guards.iter_mut().zip(partition_backward(path.as_ref())) {
            *slot = Some(self.acquire(prefix, false).await);
        }
        guards
    }

    /// Acquires shared locks on the prefixes of `path`, shortest first.
    ///
    /// Slot 0 guards the shortest prefix. At most `N` prefixes are locked, so
    /// a path with more than `N` prefixes only has its `N` root-most prefixes
    /// locked and `path` itself is not among them. Unused slots stay `None`.
    pub async fn read_forward<T: AsRef<[u8]>>(&self, path: T) -> [Option<Guard>; N] {
        let mut guards: [Option<Guard>; N] = [const { None }; N];
        for (slot, prefix) in guards.iter_mut().zip(partition_forward(path.as_ref())) {
            *slot = Some(self.acquire(prefix, false).await);
        }
        guards
    }

    /// Acquires an exclusive lock on `path` and shared locks on its
    /// ancestors, deepest first.
    ///
    /// Slot 0 holds the write guard for the full path; the remaining filled
    /// slots hold read guards for up to `N - 1` of the nearest ancestors.
    pub async fn write_backward<T: AsRef<[u8]>>(&self, path: T) -> [Option<Guard>; N] {
        let mut guards: [Option<Guard>; N] = [const { None }; N];
        let prefixes = partition_backward(path.as_ref());
        for (index, (slot, prefix)) in guards.iter_mut().zip(prefixes).enumerate() {
            // Only the first prefix yielded (the full path) is locked exclusively.
            *slot = Some(self.acquire(prefix, index == 0).await);
        }
        guards
    }

    /// Acquires shared locks on the prefixes of `path`, shortest first, and
    /// an exclusive lock on the last prefix visited.
    ///
    /// At most `N` prefixes are visited. When `path` has more than `N`
    /// prefixes the exclusive lock therefore lands on the `N`-th prefix from
    /// the root, not on `path` itself.
    pub async fn write_forward<T: AsRef<[u8]>>(&self, path: T) -> [Option<Guard>; N] {
        let mut guards: [Option<Guard>; N] = [const { None }; N];
        let mut remaining = partition_forward(path.as_ref())
            .take(N)
            .enumerate()
            .peekable();
        while let Some((index, prefix)) = remaining.next() {
            // Nothing left to visit means this prefix is the one to lock exclusively.
            let exclusive = remaining.peek().is_none();
            guards[index] = Some(self.acquire(prefix, exclusive).await);
        }
        guards
    }

    /// Waits for the lock registered under `prefix`, exclusively when
    /// `exclusive` is set and shared otherwise, and wraps the owned guard.
    async fn acquire(&self, prefix: &[u8], exclusive: bool) -> Guard {
        let lock = self.lock(prefix);
        if exclusive {
            Guard::Write(lock.write_owned().await)
        } else {
            Guard::Read(lock.read_owned().await)
        }
    }

    /// Returns the lock registered under `path`, registering a new one when
    /// the lookup finds nothing.
    fn lock(&self, path: &[u8]) -> Arc<RwLock<()>> {
        // Look up by borrowed slice first so the key is only copied into an
        // owned `Vec` when no entry was found.
        if let Some(existing) = self.inner.pin().get(path) {
            return Arc::clone(existing);
        }
        Arc::clone(
            self.inner
                .pin()
                .get_or_insert_with(path.to_vec(), Default::default),
        )
    }
}
impl Guard {
    /// Converts the guard into a shared one.
    ///
    /// A write guard is downgraded through the tokio guard's own `downgrade`;
    /// a read guard is passed through unchanged.
    pub fn downgrade(self) -> Self {
        let shared = match self {
            Self::Write(exclusive) => exclusive.downgrade(),
            Self::Read(shared) => shared,
        };
        Self::Read(shared)
    }
}
/// Yields `value` followed by each of its prefixes that ends just before a
/// separator, from longest to shortest.
///
/// A separator in the final position does not produce an extra prefix, and an
/// empty `value` yields nothing. For example `a/b/c` yields `a/b/c`, `a/b`, `a`.
fn partition_backward(value: &[u8]) -> impl Iterator<Item = &[u8]> {
    // Drop the final byte from the separator scan: that position always maps
    // to the whole path, whatever byte it holds.
    let (whole, scanned) = match value.split_last() {
        Some((_, rest)) => (Some(value), rest),
        None => (None, value),
    };
    let prefixes = scanned
        .iter()
        .enumerate()
        .rev()
        .filter(|&(_, &byte)| byte == SEPARATOR)
        .map(move |(end, _)| &value[..end]);
    whole.into_iter().chain(prefixes)
}
/// Yields each prefix of `value` that ends just before a separator, from
/// shortest to longest, followed by `value` itself.
///
/// A separator in the final position does not produce an extra prefix, and an
/// empty `value` yields nothing. For example `a/b/c` yields `a`, `a/b`, `a/b/c`.
fn partition_forward(value: &[u8]) -> impl Iterator<Item = &[u8]> {
    // Drop the final byte from the separator scan: that position always maps
    // to the whole path, whatever byte it holds.
    let (whole, scanned) = match value.split_last() {
        Some((_, rest)) => (Some(value), rest),
        None => (None, value),
    };
    scanned
        .iter()
        .enumerate()
        .filter(|&(_, &byte)| byte == SEPARATOR)
        .map(move |(end, _)| &value[..end])
        .chain(whole)
}
// Tests for the prefix iterators and for the blocking behaviour of `Lock`.
//
// The async tests record the order in which tasks get past their lock by
// sending numbers over a channel; the numbers are chosen so that the expected
// arrival order is always 1, 2 (, 3).
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::Lock;
// Guard-array length used by every `Lock` in these tests.
const N: usize = 10;
// Unwraps a result, panicking (and so failing the test) on error.
macro_rules! ok(($result:expr) => ($result.unwrap()));
/// `a/b/c` is split into itself and its ancestors, longest first.
#[test]
fn partition_backward() {
assert_eq!(
super::partition_backward(b"a/b/c").collect::<Vec<_>>(),
vec![b"a/b/c".as_ref(), b"a/b".as_ref(), b"a".as_ref()],
);
}
/// `a/b/c` is split into its ancestors and itself, shortest first.
#[test]
fn partition_forward() {
assert_eq!(
super::partition_forward(b"a/b/c").collect::<Vec<_>>(),
vec![b"a".as_ref(), b"a/b".as_ref(), b"a/b/c".as_ref()]
);
}
/// Two readers of the same path do not block each other: the second reader
/// reports (1) while the first is still holding its guards and only reports
/// (2) after `work(1)` has finished.
#[tokio::test(flavor = "multi_thread")]
async fn read_independent() {
let lock = Lock::<N>::default();
let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
{
// The guards are taken before spawning, so they are already held when
// the second task starts.
let guard = lock.read_backward("a/b/c").await;
let sender = sender.clone();
tokio::task::spawn(async move {
work(1).await;
ok!(sender.send(2));
std::mem::drop(guard);
});
}
{
let sender = sender.clone();
tokio::task::spawn(async move {
let _guard = lock.read_backward("a/b/c").await;
ok!(sender.send(1));
});
}
for index in [1, 2] {
assert_eq!(receiver.recv().await, Some(index));
}
}
/// A second writer of the same path has to wait for the first: the first
/// task reports (1) before dropping its guards, and only then can the second
/// acquire the lock and report (2).
#[tokio::test(flavor = "multi_thread")]
async fn write_dependent() {
let lock = Lock::<N>::default();
let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
{
let guard = lock.write_backward("a/b/c").await;
let sender = sender.clone();
tokio::task::spawn(async move {
work(1).await;
ok!(sender.send(1));
std::mem::drop(guard);
});
}
{
let sender = sender.clone();
tokio::task::spawn(async move {
let _guard = lock.write_backward("a/b/c").await;
ok!(sender.send(2));
});
}
for index in [1, 2] {
assert_eq!(receiver.recv().await, Some(index));
}
}
/// Writers of sibling paths (`a/b/c` and `a/b/d`) do not block each other:
/// the second writer reports (1) while the first is still holding its guards.
#[tokio::test(flavor = "multi_thread")]
async fn write_independent() {
let lock = Lock::<N>::default();
let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
{
let guard = lock.write_backward("a/b/c").await;
let sender = sender.clone();
tokio::task::spawn(async move {
work(1).await;
ok!(sender.send(2));
std::mem::drop(guard);
});
}
{
let sender = sender.clone();
tokio::task::spawn(async move {
let _guard = lock.write_backward("a/b/d").await;
ok!(sender.send(1));
});
}
for index in [1, 2] {
assert_eq!(receiver.recv().await, Some(index));
}
}
/// While `a/b/c` is write-locked, a reader of the ancestor `a/b` proceeds
/// immediately (1), whereas a reader of the descendant `a/b/c/d` can only
/// report (3) after the writer has reported (2) and released its guards.
#[tokio::test(flavor = "multi_thread")]
async fn write_read_independent_dependent() {
// Shared through an `Arc` because two spawned tasks need the lock.
let lock = Arc::new(Lock::<N>::default());
let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
{
let guard = lock.write_backward("a/b/c").await;
let sender = sender.clone();
tokio::task::spawn(async move {
work(1).await;
ok!(sender.send(2));
std::mem::drop(guard);
});
}
{
let lock = lock.clone();
let sender = sender.clone();
tokio::task::spawn(async move {
let _guard = lock.read_backward("a/b").await;
ok!(sender.send(1));
});
}
{
let sender = sender.clone();
tokio::task::spawn(async move {
let _guard = lock.read_backward("a/b/c/d").await;
ok!(sender.send(3));
});
}
for index in [1, 2, 3] {
assert_eq!(receiver.recv().await, Some(index));
}
}
/// Simulates work by sleeping asynchronously for `4 * load` seconds.
async fn work(load: u64) {
tokio::time::sleep(std::time::Duration::from_secs(4 * load)).await;
}
}