// dataforge/multithreading/safety.rs
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};

use serde_json::Value;

use crate::error::{DataForgeError, Result};

8pub struct ThreadSafeCache<K, V> {
10 data: Arc<RwLock<HashMap<K, V>>>,
11 max_size: usize,
12}
13
14impl<K, V> ThreadSafeCache<K, V>
15where
16 K: std::hash::Hash + Eq + Clone,
17 V: Clone,
18{
19 pub fn new(max_size: usize) -> Self {
21 Self {
22 data: Arc::new(RwLock::new(HashMap::new())),
23 max_size,
24 }
25 }
26
27 pub fn get(&self, key: &K) -> Option<V> {
29 self.data.read().unwrap().get(key).cloned()
30 }
31
32 pub fn insert(&self, key: K, value: V) -> Result<()> {
34 let mut data = self.data.write().unwrap();
35
36 if data.len() >= self.max_size && !data.contains_key(&key) {
37 return Err(DataForgeError::validation("Cache is full"));
38 }
39
40 data.insert(key, value);
41 Ok(())
42 }
43
44 pub fn remove(&self, key: &K) -> Option<V> {
46 self.data.write().unwrap().remove(key)
47 }
48
49 pub fn clear(&self) {
51 self.data.write().unwrap().clear();
52 }
53
54 pub fn len(&self) -> usize {
56 self.data.read().unwrap().len()
57 }
58
59 pub fn is_empty(&self) -> bool {
61 self.data.read().unwrap().is_empty()
62 }
63}
64
/// A thread-safe `usize` counter guarded by a `Mutex`.
///
/// The value saturates at both ends:
/// - `decrement` never goes below 0;
/// - `increment` never goes above `usize::MAX`.
///
/// Cloning returns another handle to the *same* shared value.
///
/// Every method panics if the mutex was poisoned by a thread that panicked
/// while holding it.
pub struct ThreadSafeCounter {
    /// Shared counter value; all clones point at the same mutex.
    value: Arc<Mutex<usize>>,
}

impl ThreadSafeCounter {
    /// Creates a counter whose current value is `initial`.
    pub fn new(initial: usize) -> Self {
        Self {
            value: Arc::new(Mutex::new(initial)),
        }
    }

    /// Adds one and returns the new value, saturating at `usize::MAX`.
    ///
    /// Why saturating: a plain `+= 1` panics on overflow in debug builds
    /// while the guard is held. That poisons the mutex, so every later call
    /// also panics. In release builds it silently wraps to 0 instead.
    pub fn increment(&self) -> usize {
        let mut value = self.value.lock().unwrap();
        *value = (*value).saturating_add(1);
        *value
    }

    /// Subtracts one and returns the new value, never going below 0.
    pub fn decrement(&self) -> usize {
        let mut value = self.value.lock().unwrap();
        *value = (*value).saturating_sub(1);
        *value
    }

    /// Returns the current value.
    pub fn get(&self) -> usize {
        *self.value.lock().unwrap()
    }

    /// Sets the value to 0.
    ///
    /// This does not restore the `initial` value passed to `new`.
    pub fn reset(&self) {
        *self.value.lock().unwrap() = 0;
    }
}

impl Clone for ThreadSafeCounter {
    /// Returns a new handle that shares this counter's value.
    fn clone(&self) -> Self {
        Self {
            value: Arc::clone(&self.value),
        }
    }
}

105pub struct ThreadSafeGeneratorRegistry {
107 generators: Arc<RwLock<HashMap<String, Arc<dyn Fn() -> Value + Send + Sync>>>>,
108}
109
110impl ThreadSafeGeneratorRegistry {
111 pub fn new() -> Self {
113 Self {
114 generators: Arc::new(RwLock::new(HashMap::new())),
115 }
116 }
117
118 pub fn register<F>(&self, name: String, generator: F)
120 where
121 F: Fn() -> Value + Send + Sync + 'static,
122 {
123 self.generators.write().unwrap().insert(name, Arc::new(generator));
124 }
125
126 pub fn get(&self, name: &str) -> Option<Arc<dyn Fn() -> Value + Send + Sync>> {
128 self.generators.read().unwrap().get(name).cloned()
129 }
130
131 pub fn remove(&self, name: &str) -> bool {
133 self.generators.write().unwrap().remove(name).is_some()
134 }
135
136 pub fn list_generators(&self) -> Vec<String> {
138 self.generators.read().unwrap().keys().cloned().collect()
139 }
140
141 pub fn count(&self) -> usize {
143 self.generators.read().unwrap().len()
144 }
145}
146
147impl Default for ThreadSafeGeneratorRegistry {
148 fn default() -> Self {
149 Self::new()
150 }
151}
152
153impl Clone for ThreadSafeGeneratorRegistry {
154 fn clone(&self) -> Self {
155 Self {
156 generators: Arc::clone(&self.generators),
157 }
158 }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_thread_safe_cache() {
        let cache = ThreadSafeCache::new(10);

        assert!(cache.insert("key1".to_string(), "value1".to_string()).is_ok());
        assert_eq!(cache.get(&"key1".to_string()), Some("value1".to_string()));
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert!(cache.is_empty());
    }

    /// Exercises the capacity check, the only non-trivial logic in `insert`:
    /// - new keys are rejected once `max_size` entries are stored;
    /// - overwriting an existing key is still allowed;
    /// - removing an entry frees a slot.
    #[test]
    fn test_cache_capacity_limit() {
        let cache = ThreadSafeCache::new(2);
        assert!(cache.insert("a".to_string(), 1).is_ok());
        assert!(cache.insert("b".to_string(), 2).is_ok());

        assert!(cache.insert("c".to_string(), 3).is_err());
        assert_eq!(cache.get(&"c".to_string()), None);
        assert_eq!(cache.len(), 2);

        assert!(cache.insert("a".to_string(), 10).is_ok());
        assert_eq!(cache.get(&"a".to_string()), Some(10));

        assert_eq!(cache.remove(&"b".to_string()), Some(2));
        assert!(cache.insert("c".to_string(), 3).is_ok());
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_cache_zero_capacity_rejects_all_inserts() {
        let cache: ThreadSafeCache<String, i32> = ThreadSafeCache::new(0);
        assert!(cache.insert("a".to_string(), 1).is_err());
        assert!(cache.is_empty());
    }

    #[test]
    fn test_thread_safe_counter() {
        let counter = ThreadSafeCounter::new(0);

        assert_eq!(counter.increment(), 1);
        assert_eq!(counter.increment(), 2);
        assert_eq!(counter.get(), 2);

        assert_eq!(counter.decrement(), 1);
        assert_eq!(counter.get(), 1);

        counter.reset();
        assert_eq!(counter.get(), 0);
    }

    #[test]
    fn test_counter_decrement_stops_at_zero() {
        let counter = ThreadSafeCounter::new(1);
        assert_eq!(counter.decrement(), 0);
        assert_eq!(counter.decrement(), 0);
        assert_eq!(counter.get(), 0);
    }

    #[test]
    fn test_thread_safe_generator_registry() {
        let registry = ThreadSafeGeneratorRegistry::new();

        registry.register("test".to_string(), || Value::String("test".to_string()));

        assert_eq!(registry.count(), 1);
        assert!(registry.get("test").is_some());

        let generator = registry.get("test").unwrap();
        let result = generator();
        assert_eq!(result, Value::String("test".to_string()));

        assert!(registry.remove("test"));
        assert_eq!(registry.count(), 0);
    }

    #[test]
    fn test_registry_overwrite_and_missing_names() {
        let registry = ThreadSafeGeneratorRegistry::new();
        registry.register("g".to_string(), || Value::String("first".to_string()));
        registry.register("g".to_string(), || Value::String("second".to_string()));

        // Re-registering a name replaces the old generator rather than adding one.
        assert_eq!(registry.count(), 1);
        assert_eq!(registry.list_generators(), vec!["g".to_string()]);
        let generator = registry.get("g").unwrap();
        assert_eq!(generator(), Value::String("second".to_string()));

        assert!(registry.get("missing").is_none());
        assert!(!registry.remove("missing"));
    }

    #[test]
    fn test_concurrent_access() {
        let counter = Arc::new(ThreadSafeCounter::new(0));
        let mut handles = vec![];

        for _ in 0..10 {
            let counter = Arc::clone(&counter);
            let handle = thread::spawn(move || {
                for _ in 0..100 {
                    counter.increment();
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(counter.get(), 1000);
    }

    /// Concurrent inserts of distinct keys must all land, and must fill the
    /// cache exactly to its capacity.
    #[test]
    fn test_concurrent_cache_inserts() {
        let cache: Arc<ThreadSafeCache<u32, u32>> = Arc::new(ThreadSafeCache::new(100));
        let handles: Vec<_> = (0..4u32)
            .map(|t| {
                let cache = Arc::clone(&cache);
                thread::spawn(move || {
                    for i in 0..25u32 {
                        assert!(cache.insert(t * 25 + i, i).is_ok());
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(cache.len(), 100);
        assert!(cache.insert(1000, 0).is_err());
    }
}