// rustfs-lock 0.0.3
//
// Distributed locking mechanism for RustFS, providing synchronization and coordination across distributed systems.
// Documentation
// Copyright 2024 RustFS Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use rand::Rng;
use std::time::{Duration, Instant};
use tokio::{sync::RwLock, time::sleep};
use tracing::info;

/// Reader/writer lock whose acquire operations poll with a timeout and report
/// success as a `bool` instead of handing out a guard.
///
/// Each piece of bookkeeping lives in its own `tokio::sync::RwLock`, so every
/// method can work through `&self`. `Default` yields an unlocked mutex with
/// empty `id` / `source` strings.
#[derive(Debug, Default)]
pub struct LRWMutex {
    /// Identifier passed to the most recent acquisition attempt; re-used by
    /// `lock()` / `r_lock()`, which take no explicit id.
    id: RwLock<String>,
    /// Source string passed to the most recent acquisition attempt; re-used
    /// the same way as `id`.
    source: RwLock<String>,
    /// `true` while the write lock is held.
    is_write: RwLock<bool>,
    /// Holder count: `1` while the write lock is held, otherwise the number
    /// of read locks currently held (`0` when unlocked).
    refrence: RwLock<usize>,
}

impl LRWMutex {
    /// Timeout used by [`lock`](Self::lock) and [`r_lock`](Self::r_lock):
    /// 10000 seconds, i.e. large enough to act as "wait until acquired".
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10000);

    /// Acquires the write lock, re-using the `id` and `source` currently
    /// stored in the mutex.
    ///
    /// Returns `true` once the lock is held, or `false` if
    /// `DEFAULT_TIMEOUT` elapses first.
    pub async fn lock(&self) -> bool {
        let id = self.id.read().await.clone();
        let source = self.source.read().await.clone();
        self.look_loop(&id, &source, &Self::DEFAULT_TIMEOUT, true).await
    }

    /// Acquires the write lock on behalf of `id` / `source`, giving up after
    /// `timeout`.
    ///
    /// Returns `true` if the lock was acquired, `false` on timeout.
    pub async fn get_lock(&self, id: &str, source: &str, timeout: &Duration) -> bool {
        self.look_loop(id, source, timeout, true).await
    }

    /// Acquires a read lock, re-using the `id` and `source` currently stored
    /// in the mutex.
    ///
    /// Returns `true` once the lock is held, or `false` if
    /// `DEFAULT_TIMEOUT` elapses first.
    pub async fn r_lock(&self) -> bool {
        let id = self.id.read().await.clone();
        let source = self.source.read().await.clone();
        self.look_loop(&id, &source, &Self::DEFAULT_TIMEOUT, false).await
    }

    /// Acquires a read lock on behalf of `id` / `source`, giving up after
    /// `timeout`.
    ///
    /// Returns `true` if the lock was acquired, `false` on timeout.
    pub async fn get_r_lock(&self, id: &str, source: &str, timeout: &Duration) -> bool {
        self.look_loop(id, source, timeout, false).await
    }

    /// Makes a single, non-waiting attempt to take the lock.
    ///
    /// * Write lock: granted only when there is no holder at all.
    /// * Read lock: granted whenever no write lock is held; a waiting writer
    ///   does not hold back new readers.
    ///
    /// `id` and `source` are recorded on every attempt, whether or not it
    /// succeeds.
    async fn inner_lock(&self, id: &str, source: &str, is_write: bool) -> bool {
        *self.id.write().await = id.to_owned();
        *self.source.write().await = source.to_owned();

        // Hold both state guards for the whole check-and-update so that two
        // concurrent callers can never both observe "free" and both acquire.
        // Every method takes them in the same order (`is_write`, then
        // `refrence`) so that two tasks cannot wait on each other's guard.
        let mut write_held = self.is_write.write().await;
        let mut holders = self.refrence.write().await;

        if *write_held {
            return false;
        }

        if is_write {
            if *holders != 0 {
                // Readers are still active.
                return false;
            }
            *write_held = true;
            *holders = 1;
        } else {
            *holders += 1;
        }

        true
    }

    /// Retries [`inner_lock`](Self::inner_lock) until it succeeds or more than
    /// `timeout` has passed since the first attempt, sleeping a random
    /// 10..=50 ms between attempts.
    ///
    /// At least one attempt is always made, even with a zero `timeout`.
    async fn look_loop(&self, id: &str, source: &str, timeout: &Duration, is_write: bool) -> bool {
        let start = Instant::now();
        loop {
            if self.inner_lock(id, source, is_write).await {
                return true;
            }
            if start.elapsed() > *timeout {
                return false;
            }

            // The RNG handle is a temporary of this one statement, so it is
            // never held across the `.await` below.
            let backoff_ms: u64 = rand::rng().random_range(10..=50);
            sleep(Duration::from_millis(backoff_ms)).await;
        }
    }

    /// Releases the write lock.
    ///
    /// Logs an informational message and changes nothing if the write lock is
    /// not currently held.
    pub async fn un_lock(&self) {
        if !self.unlock(true).await {
            info!("Trying to un_lock() while no Lock() is active");
        }
    }

    /// Releases one read lock.
    ///
    /// Logs an informational message and changes nothing if no read lock is
    /// currently held.
    pub async fn un_r_lock(&self) {
        if !self.unlock(false).await {
            info!("Trying to un_r_lock() while no Lock() is active");
        }
    }

    /// Releases the write lock (`is_write == true`) or one read lock
    /// (`is_write == false`).
    ///
    /// Returns `false`, leaving the state untouched, when the requested kind
    /// of lock is not held.
    async fn unlock(&self, is_write: bool) -> bool {
        // Same guard order as `inner_lock`; both guards are held so that the
        // check and the update form one atomic step.
        let mut write_held = self.is_write.write().await;
        let mut holders = self.refrence.write().await;

        match (is_write, *write_held) {
            (true, true) if *holders == 1 => {
                *holders = 0;
                *write_held = false;
                true
            }
            (false, false) if *holders > 0 => {
                *holders -= 1;
                true
            }
            _ => false,
        }
    }

    /// Unconditionally resets the mutex to the unlocked state, discarding any
    /// read or write locks that were held.
    pub async fn force_un_lock(&self) {
        // Reset both fields under one critical section (same guard order as
        // everywhere else) so no caller can observe a half-reset state.
        let mut write_held = self.is_write.write().await;
        let mut holders = self.refrence.write().await;
        *holders = 0;
        *write_held = false;
    }
}

#[cfg(test)]
mod test {
    use std::{sync::Arc, time::Duration};

    use std::io::Result;
    use tokio::time::sleep;

    use crate::lrwmutex::LRWMutex;

    /// Sequential walk-through: write lock, release, write lock again, a read
    /// attempt that must time out, release, and a read attempt that succeeds.
    #[tokio::test]
    async fn test_lock_unlock() -> Result<()> {
        let (id, source) = ("foo", "dandan");
        let timeout = Duration::from_secs(5);
        let mutex = LRWMutex::default();

        // A fresh mutex hands out the write lock straight away.
        let write_acquired = mutex.get_lock(id, source, &timeout).await;
        assert!(write_acquired);
        mutex.un_lock().await;

        // Take it again through the argument-less entry point.
        mutex.lock().await;

        // With the write lock held, a reader has to run into its timeout.
        let read_while_written = mutex.get_r_lock(id, source, &timeout).await;
        assert!(!read_while_written);

        // Once the writer is gone the reader gets in.
        mutex.un_lock().await;
        let read_after_release = mutex.get_r_lock(id, source, &timeout).await;
        assert!(read_after_release);

        Ok(())
    }

    /// Two futures driven concurrently by `tokio::join!`: a writer that keeps
    /// the lock for five seconds and a reader whose first attempt times out
    /// and whose second attempt, made after the writer released, succeeds.
    #[tokio::test]
    async fn multi_thread_test() -> Result<()> {
        let shared = Arc::new(LRWMutex::default());
        let (id, source) = ("foo", "dandan");

        let writer = {
            let handle = Arc::clone(&shared);
            async move {
                let wait = Duration::from_secs(1);
                let acquired = handle.get_lock(id, source, &wait).await;
                assert!(acquired);
                sleep(Duration::from_secs(5)).await;
                handle.un_lock().await;
            }
        };

        let reader = {
            let handle = Arc::clone(&shared);
            async move {
                let wait = Duration::from_secs(2);
                // The writer still holds the lock, so this attempt times out.
                let first_try = handle.get_r_lock(id, source, &wait).await;
                assert!(!first_try);
                sleep(Duration::from_secs(5)).await;
                // By now the writer has released.
                let second_try = handle.get_r_lock(id, source, &wait).await;
                assert!(second_try);
                handle.un_r_lock().await;
            }
        };

        tokio::join!(writer, reader);

        Ok(())
    }
}