// dataforge/multithreading/safety.rs

//! 线程安全机制模块 — thread-safety primitives: a bounded cache, a counter, and a generator registry.
2
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock};

use serde_json::Value;

use crate::error::{DataForgeError, Result};
7
8/// 线程安全的数据缓存
9pub struct ThreadSafeCache<K, V> {
10    data: Arc<RwLock<HashMap<K, V>>>,
11    max_size: usize,
12}
13
14impl<K, V> ThreadSafeCache<K, V>
15where
16    K: std::hash::Hash + Eq + Clone,
17    V: Clone,
18{
19    /// 创建新的线程安全缓存
20    pub fn new(max_size: usize) -> Self {
21        Self {
22            data: Arc::new(RwLock::new(HashMap::new())),
23            max_size,
24        }
25    }
26
27    /// 获取值
28    pub fn get(&self, key: &K) -> Option<V> {
29        self.data.read().unwrap().get(key).cloned()
30    }
31
32    /// 插入值
33    pub fn insert(&self, key: K, value: V) -> Result<()> {
34        let mut data = self.data.write().unwrap();
35        
36        if data.len() >= self.max_size && !data.contains_key(&key) {
37            return Err(DataForgeError::validation("Cache is full"));
38        }
39        
40        data.insert(key, value);
41        Ok(())
42    }
43
44    /// 移除值
45    pub fn remove(&self, key: &K) -> Option<V> {
46        self.data.write().unwrap().remove(key)
47    }
48
49    /// 清空缓存
50    pub fn clear(&self) {
51        self.data.write().unwrap().clear();
52    }
53
54    /// 获取缓存大小
55    pub fn len(&self) -> usize {
56        self.data.read().unwrap().len()
57    }
58
59    /// 检查是否为空
60    pub fn is_empty(&self) -> bool {
61        self.data.read().unwrap().is_empty()
62    }
63}
64
/// Thread-safe `usize` counter (线程安全的计数器).
///
/// Backed by a shared [`AtomicUsize`], so updates take no lock and the counter
/// cannot be poisoned by a panicking thread. Cloning a `ThreadSafeCounter` yields
/// a new handle to the *same* count; `Default` starts at `0`.
#[derive(Debug, Clone, Default)]
pub struct ThreadSafeCounter {
    /// Shared count; cloned handles point at the same atomic.
    value: Arc<AtomicUsize>,
}

impl ThreadSafeCounter {
    /// Creates a counter starting at `initial`.
    pub fn new(initial: usize) -> Self {
        Self {
            value: Arc::new(AtomicUsize::new(initial)),
        }
    }

    /// Atomically adds one and returns the new value.
    ///
    /// Wraps around to `0` on overflow past `usize::MAX`.
    pub fn increment(&self) -> usize {
        // `fetch_add` returns the previous value and wraps on overflow; mirror that
        // wrap so the returned value matches what was stored.
        self.value.fetch_add(1, Ordering::SeqCst).wrapping_add(1)
    }

    /// Atomically subtracts one, saturating at `0`, and returns the new value.
    pub fn decrement(&self) -> usize {
        match self
            .value
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| v.checked_sub(1))
        {
            Ok(previous) => previous - 1,
            // The closure only declines when the count is already 0.
            Err(current) => current,
        }
    }

    /// Returns the current value.
    pub fn get(&self) -> usize {
        self.value.load(Ordering::SeqCst)
    }

    /// Sets the counter to `0` (not to the `initial` value passed to `new`).
    pub fn reset(&self) {
        self.value.store(0, Ordering::SeqCst);
    }
}
104
105/// 线程安全的数据生成器注册表
106pub struct ThreadSafeGeneratorRegistry {
107    generators: Arc<RwLock<HashMap<String, Arc<dyn Fn() -> Value + Send + Sync>>>>,
108}
109
110impl ThreadSafeGeneratorRegistry {
111    /// 创建新的注册表
112    pub fn new() -> Self {
113        Self {
114            generators: Arc::new(RwLock::new(HashMap::new())),
115        }
116    }
117
118    /// 注册生成器
119    pub fn register<F>(&self, name: String, generator: F)
120    where
121        F: Fn() -> Value + Send + Sync + 'static,
122    {
123        self.generators.write().unwrap().insert(name, Arc::new(generator));
124    }
125
126    /// 获取生成器
127    pub fn get(&self, name: &str) -> Option<Arc<dyn Fn() -> Value + Send + Sync>> {
128        self.generators.read().unwrap().get(name).cloned()
129    }
130
131    /// 移除生成器
132    pub fn remove(&self, name: &str) -> bool {
133        self.generators.write().unwrap().remove(name).is_some()
134    }
135
136    /// 列出所有生成器名称
137    pub fn list_generators(&self) -> Vec<String> {
138        self.generators.read().unwrap().keys().cloned().collect()
139    }
140
141    /// 获取生成器数量
142    pub fn count(&self) -> usize {
143        self.generators.read().unwrap().len()
144    }
145}
146
147impl Default for ThreadSafeGeneratorRegistry {
148    fn default() -> Self {
149        Self::new()
150    }
151}
152
153impl Clone for ThreadSafeGeneratorRegistry {
154    fn clone(&self) -> Self {
155        Self {
156            generators: Arc::clone(&self.generators),
157        }
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_thread_safe_cache() {
        let cache = ThreadSafeCache::new(10);

        assert!(cache.insert("key1".to_string(), "value1".to_string()).is_ok());
        assert_eq!(cache.get(&"key1".to_string()), Some("value1".to_string()));
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert!(cache.is_empty());
    }

    /// A full cache rejects new keys but still accepts overwrites of existing keys.
    #[test]
    fn test_cache_capacity_limit() {
        let cache = ThreadSafeCache::new(2);

        assert!(cache.insert("a".to_string(), 1).is_ok());
        assert!(cache.insert("b".to_string(), 2).is_ok());
        assert!(cache.insert("c".to_string(), 3).is_err());

        assert!(cache.insert("a".to_string(), 10).is_ok());
        assert_eq!(cache.get(&"a".to_string()), Some(10));
        assert_eq!(cache.len(), 2);

        // Freeing a slot makes room for a new key again.
        assert_eq!(cache.remove(&"b".to_string()), Some(2));
        assert!(cache.insert("c".to_string(), 3).is_ok());
        assert_eq!(cache.get(&"b".to_string()), None);
    }

    /// Concurrent inserts of distinct keys never push the cache past `max_size`.
    #[test]
    fn test_cache_concurrent_capacity() {
        let cache = Arc::new(ThreadSafeCache::new(50));
        let handles: Vec<_> = (0..4)
            .map(|t| {
                let cache = Arc::clone(&cache);
                thread::spawn(move || {
                    for i in 0..25 {
                        // Failures are expected once the cache fills up.
                        let _ = cache.insert(t * 100 + i, i);
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        // 100 distinct keys were offered to a cache with room for 50.
        assert_eq!(cache.len(), 50);
    }

    #[test]
    fn test_thread_safe_counter() {
        let counter = ThreadSafeCounter::new(0);

        assert_eq!(counter.increment(), 1);
        assert_eq!(counter.increment(), 2);
        assert_eq!(counter.get(), 2);

        assert_eq!(counter.decrement(), 1);
        assert_eq!(counter.get(), 1);

        counter.reset();
        assert_eq!(counter.get(), 0);
    }

    /// Decrementing at zero leaves the counter at zero instead of underflowing.
    #[test]
    fn test_counter_decrement_saturates_at_zero() {
        let counter = ThreadSafeCounter::new(1);
        assert_eq!(counter.decrement(), 0);
        assert_eq!(counter.decrement(), 0);
        assert_eq!(counter.get(), 0);
    }

    /// `reset` goes to zero, not back to the initial value.
    #[test]
    fn test_counter_reset_ignores_initial() {
        let counter = ThreadSafeCounter::new(5);
        counter.reset();
        assert_eq!(counter.get(), 0);
    }

    #[test]
    fn test_thread_safe_generator_registry() {
        let registry = ThreadSafeGeneratorRegistry::new();

        registry.register("test".to_string(), || Value::String("test".to_string()));

        assert_eq!(registry.count(), 1);
        assert!(registry.get("test").is_some());

        let generator = registry.get("test").unwrap();
        let result = generator();
        assert_eq!(result, Value::String("test".to_string()));

        assert!(registry.remove("test"));
        assert_eq!(registry.count(), 0);
    }

    /// Re-registering a name replaces the generator; unknown removals report `false`.
    #[test]
    fn test_registry_overwrite_and_listing() {
        let registry = ThreadSafeGeneratorRegistry::new();

        registry.register("a".to_string(), || Value::String("first".to_string()));
        registry.register("a".to_string(), || Value::String("second".to_string()));
        registry.register("b".to_string(), || Value::Null);
        assert_eq!(registry.count(), 2);

        let generator = registry.get("a").unwrap();
        assert_eq!(generator(), Value::String("second".to_string()));

        // HashMap iteration order is unspecified, so sort before comparing.
        let mut names = registry.list_generators();
        names.sort();
        assert_eq!(names, vec!["a".to_string(), "b".to_string()]);

        assert!(!registry.remove("missing"));
        assert_eq!(registry.count(), 2);
    }

    /// Clones of the registry share the same underlying map.
    #[test]
    fn test_registry_clone_shares_state() {
        let registry = ThreadSafeGeneratorRegistry::new();
        let handle = registry.clone();

        handle.register("shared".to_string(), || Value::Null);
        assert_eq!(registry.count(), 1);
        assert!(registry.get("shared").is_some());
    }

    #[test]
    fn test_concurrent_access() {
        let counter = Arc::new(ThreadSafeCounter::new(0));
        let mut handles = vec![];

        for _ in 0..10 {
            let counter = Arc::clone(&counter);
            let handle = thread::spawn(move || {
                for _ in 0..100 {
                    counter.increment();
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(counter.get(), 1000);
    }
}