rustfs_lock/
lrwmutex.rs

1// Copyright 2024 RustFS Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use rand::Rng;
16use std::time::{Duration, Instant};
17use tokio::{sync::RwLock, time::sleep};
18use tracing::info;
19
20#[derive(Debug, Default)]
21pub struct LRWMutex {
22    id: RwLock<String>,
23    source: RwLock<String>,
24    is_write: RwLock<bool>,
25    refrence: RwLock<usize>,
26}
27
28impl LRWMutex {
29    pub async fn lock(&self) -> bool {
30        let is_write = true;
31        let id = self.id.read().await.clone();
32        let source = self.source.read().await.clone();
33        let timeout = Duration::from_secs(10000);
34        self.look_loop(
35            &id, &source, &timeout, // big enough
36            is_write,
37        )
38        .await
39    }
40
41    pub async fn get_lock(&self, id: &str, source: &str, timeout: &Duration) -> bool {
42        let is_write = true;
43        self.look_loop(id, source, timeout, is_write).await
44    }
45
46    pub async fn r_lock(&self) -> bool {
47        let is_write: bool = false;
48        let id = self.id.read().await.clone();
49        let source = self.source.read().await.clone();
50        let timeout = Duration::from_secs(10000);
51        self.look_loop(
52            &id, &source, &timeout, // big enough
53            is_write,
54        )
55        .await
56    }
57
58    pub async fn get_r_lock(&self, id: &str, source: &str, timeout: &Duration) -> bool {
59        let is_write = false;
60        self.look_loop(id, source, timeout, is_write).await
61    }
62
63    async fn inner_lock(&self, id: &str, source: &str, is_write: bool) -> bool {
64        *self.id.write().await = id.to_string();
65        *self.source.write().await = source.to_string();
66
67        let mut locked = false;
68        if is_write {
69            if *self.refrence.read().await == 0 && !*self.is_write.read().await {
70                *self.refrence.write().await = 1;
71                *self.is_write.write().await = true;
72                locked = true;
73            }
74        } else if !*self.is_write.read().await {
75            *self.refrence.write().await += 1;
76            locked = true;
77        }
78
79        locked
80    }
81
82    async fn look_loop(&self, id: &str, source: &str, timeout: &Duration, is_write: bool) -> bool {
83        let start = Instant::now();
84        loop {
85            if self.inner_lock(id, source, is_write).await {
86                return true;
87            } else {
88                if Instant::now().duration_since(start) > *timeout {
89                    return false;
90                }
91                let sleep_time: u64;
92                {
93                    let mut rng = rand::rng();
94                    sleep_time = rng.random_range(10..=50);
95                }
96                sleep(Duration::from_millis(sleep_time)).await;
97            }
98        }
99    }
100
101    pub async fn un_lock(&self) {
102        let is_write = true;
103        if !self.unlock(is_write).await {
104            info!("Trying to un_lock() while no Lock() is active")
105        }
106    }
107
108    pub async fn un_r_lock(&self) {
109        let is_write = false;
110        if !self.unlock(is_write).await {
111            info!("Trying to un_r_lock() while no Lock() is active")
112        }
113    }
114
115    async fn unlock(&self, is_write: bool) -> bool {
116        let mut unlocked = false;
117        if is_write {
118            if *self.is_write.read().await && *self.refrence.read().await == 1 {
119                *self.refrence.write().await = 0;
120                *self.is_write.write().await = false;
121                unlocked = true;
122            }
123        } else if !*self.is_write.read().await && *self.refrence.read().await > 0 {
124            *self.refrence.write().await -= 1;
125            unlocked = true;
126        }
127
128        unlocked
129    }
130
131    pub async fn force_un_lock(&self) {
132        *self.refrence.write().await = 0;
133        *self.is_write.write().await = false;
134    }
135}
136
#[cfg(test)]
mod test {
    use std::{sync::Arc, time::Duration};

    use tokio::time::sleep;

    use crate::lrwmutex::LRWMutex;

    /// Timeout for attempts that are expected to fail, so the suite does not sit idle
    /// waiting for them.
    const SHORT_TIMEOUT: Duration = Duration::from_millis(200);

    #[tokio::test]
    async fn test_lock_unlock() {
        let l_rw_lock = LRWMutex::default();
        let id = "foo";
        let source = "dandan";
        let timeout = Duration::from_secs(5);

        assert!(l_rw_lock.get_lock(id, source, &timeout).await);
        l_rw_lock.un_lock().await;

        // `lock()` reuses the id/source recorded by the previous attempt.
        assert!(l_rw_lock.lock().await);

        // A reader must not get in while the writer holds the lock.
        assert!(!l_rw_lock.get_r_lock(id, source, &SHORT_TIMEOUT).await);
        l_rw_lock.un_lock().await;
        assert!(l_rw_lock.get_r_lock(id, source, &timeout).await);
        l_rw_lock.un_r_lock().await;
    }

    #[tokio::test]
    async fn readers_share_and_exclude_writer() {
        let l_rw_lock = LRWMutex::default();
        let (id, source) = ("foo", "dandan");

        // Several readers may hold the lock at once.
        assert!(l_rw_lock.get_r_lock(id, source, &SHORT_TIMEOUT).await);
        assert!(l_rw_lock.get_r_lock(id, source, &SHORT_TIMEOUT).await);

        // A writer is blocked until the last reader leaves.
        assert!(!l_rw_lock.get_lock(id, source, &SHORT_TIMEOUT).await);
        l_rw_lock.un_r_lock().await;
        assert!(!l_rw_lock.get_lock(id, source, &SHORT_TIMEOUT).await);
        l_rw_lock.un_r_lock().await;
        assert!(l_rw_lock.get_lock(id, source, &SHORT_TIMEOUT).await);

        // `force_un_lock` clears the writer so readers can proceed.
        l_rw_lock.force_un_lock().await;
        assert!(l_rw_lock.get_r_lock(id, source, &SHORT_TIMEOUT).await);
    }

    /// Runs a writer and a reader concurrently on one task via `tokio::join!`.
    #[tokio::test]
    async fn multi_thread_test() {
        let l_rw_lock = Arc::new(LRWMutex::default());
        let id = "foo";
        let source = "dandan";
        // The writer holds the lock well past the reader's first timeout, and releases
        // it before the reader's second attempt.
        let hold = Duration::from_millis(600);

        let one_fn = async {
            let one = Arc::clone(&l_rw_lock);
            assert!(one.get_lock(id, source, &Duration::from_secs(1)).await);
            sleep(hold).await;
            one.un_lock().await;
        };

        let two_fn = async {
            let two = Arc::clone(&l_rw_lock);
            assert!(!two.get_r_lock(id, source, &SHORT_TIMEOUT).await);
            sleep(hold).await;
            assert!(two.get_r_lock(id, source, &SHORT_TIMEOUT).await);
            two.un_r_lock().await;
        };

        tokio::join!(one_fn, two_fn);
    }
}