1use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
7use std::time::Duration;
8
9use bytes::Bytes;
10use dashmap::DashMap;
11
12use crate::keyspace::{EvictionPolicy, TtlResult};
13use crate::memory;
14use crate::time;
15
16#[derive(Debug, Clone)]
19struct Entry {
20 value: Bytes,
21 expires_at_ms: u64,
23}
24
25impl Entry {
26 #[inline]
27 fn is_expired(&self) -> bool {
28 time::is_expired(self.expires_at_ms)
29 }
30
31 #[inline]
33 fn size(&self, key_len: usize) -> usize {
34 key_len + self.value.len() + 48
36 }
37}
38
39#[derive(Debug)]
44pub struct ConcurrentKeyspace {
45 data: DashMap<Box<str>, Entry>,
47 memory_used: AtomicUsize,
48 max_memory: Option<usize>,
49 eviction_policy: EvictionPolicy,
50 ops_count: AtomicU64,
51}
52
53impl ConcurrentKeyspace {
54 pub fn new(max_memory: Option<usize>, eviction_policy: EvictionPolicy) -> Self {
56 Self {
57 data: DashMap::new(),
58 memory_used: AtomicUsize::new(0),
59 max_memory,
60 eviction_policy,
61 ops_count: AtomicU64::new(0),
62 }
63 }
64
65 pub fn get(&self, key: &str) -> Option<Bytes> {
67 self.ops_count.fetch_add(1, Ordering::Relaxed);
68
69 let entry = self.data.get(key)?;
70
71 if entry.is_expired() {
72 let key_len = entry.key().len();
73 let size = entry.size(key_len);
74 drop(entry);
75 if self.data.remove(key).is_some() {
77 self.memory_used.fetch_sub(size, Ordering::Relaxed);
78 }
79 return None;
80 }
81
82 Some(entry.value.clone())
83 }
84
85 pub fn set(&self, key: String, value: Bytes, ttl: Option<Duration>) -> bool {
87 self.ops_count.fetch_add(1, Ordering::Relaxed);
88
89 let key: Box<str> = key.into_boxed_str();
90 let entry_size = key.len() + value.len() + 48;
91 let expires_at_ms = time::expiry_from_duration(ttl);
92
93 if let Some(max) = self.max_memory {
95 let limit = memory::effective_limit(max);
96 let current = self.memory_used.load(Ordering::Relaxed);
97 if current + entry_size > limit {
98 if self.eviction_policy == EvictionPolicy::NoEviction {
99 return false;
100 }
101 self.evict_entries(entry_size);
103 }
104 }
105
106 let entry = Entry {
107 value,
108 expires_at_ms,
109 };
110
111 if let Some(old) = self.data.insert(key.clone(), entry) {
113 let old_size = old.size(key.len());
115 let diff = entry_size as isize - old_size as isize;
116 if diff > 0 {
117 self.memory_used.fetch_add(diff as usize, Ordering::Relaxed);
118 } else {
119 self.memory_used
120 .fetch_sub((-diff) as usize, Ordering::Relaxed);
121 }
122 } else {
123 self.memory_used.fetch_add(entry_size, Ordering::Relaxed);
124 }
125
126 true
127 }
128
129 pub fn del(&self, key: &str) -> bool {
131 self.ops_count.fetch_add(1, Ordering::Relaxed);
132
133 if let Some((k, removed)) = self.data.remove(key) {
134 self.memory_used
135 .fetch_sub(removed.size(k.len()), Ordering::Relaxed);
136 true
137 } else {
138 false
139 }
140 }
141
142 pub fn exists(&self, key: &str) -> bool {
144 self.get(key).is_some()
145 }
146
147 pub fn ttl(&self, key: &str) -> TtlResult {
149 match self.data.get(key) {
150 None => TtlResult::NotFound,
151 Some(entry) => {
152 if entry.is_expired() {
153 TtlResult::NotFound
154 } else {
155 match time::remaining_secs(entry.expires_at_ms) {
156 None => TtlResult::NoExpiry,
157 Some(secs) => TtlResult::Seconds(secs),
158 }
159 }
160 }
161 }
162 }
163
164 pub fn expire(&self, key: &str, seconds: u64) -> bool {
166 self.ops_count.fetch_add(1, Ordering::Relaxed);
167
168 if let Some(mut entry) = self.data.get_mut(key) {
169 if entry.is_expired() {
170 return false;
171 }
172 entry.expires_at_ms = time::now_ms() + seconds * 1000;
173 true
174 } else {
175 false
176 }
177 }
178
179 pub fn len(&self) -> usize {
181 self.data.len()
182 }
183
184 pub fn is_empty(&self) -> bool {
186 self.data.is_empty()
187 }
188
189 pub fn memory_used(&self) -> usize {
191 self.memory_used.load(Ordering::Relaxed)
192 }
193
194 pub fn ops_count(&self) -> u64 {
196 self.ops_count.load(Ordering::Relaxed)
197 }
198
199 pub fn clear(&self) {
201 self.data.clear();
202 self.memory_used.store(0, Ordering::Relaxed);
203 }
204
205 fn evict_entries(&self, needed: usize) {
207 let mut freed = 0usize;
208 let mut keys_to_remove = Vec::new();
209
210 for entry in self.data.iter() {
212 if freed >= needed {
213 break;
214 }
215 let key_len = entry.key().len();
216 keys_to_remove.push(entry.key().clone());
217 freed += entry.value().size(key_len);
218 }
219
220 for key in keys_to_remove {
222 if let Some((k, removed)) = self.data.remove(&key) {
223 self.memory_used
224 .fetch_sub(removed.size(k.len()), Ordering::Relaxed);
225 }
226 }
227 }
228}
229
230impl Default for ConcurrentKeyspace {
231 fn default() -> Self {
232 Self::new(None, EvictionPolicy::NoEviction)
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn set_and_get() {
        let ks = ConcurrentKeyspace::default();
        assert!(ks.set("key".into(), Bytes::from("value"), None));
        assert_eq!(ks.get("key"), Some(Bytes::from("value")));
    }

    #[test]
    fn get_missing() {
        let ks = ConcurrentKeyspace::default();
        assert_eq!(ks.get("missing"), None);
    }

    #[test]
    fn del_existing() {
        let ks = ConcurrentKeyspace::default();
        ks.set("key".into(), Bytes::from("value"), None);
        assert!(ks.del("key"));
        assert_eq!(ks.get("key"), None);
    }

    #[test]
    fn del_missing() {
        let ks = ConcurrentKeyspace::default();
        assert!(!ks.del("missing"));
    }

    #[test]
    fn exists_check() {
        let ks = ConcurrentKeyspace::default();
        ks.set("key".into(), Bytes::from("value"), None);
        assert!(ks.exists("key"));
        assert!(!ks.exists("missing"));
    }

    #[test]
    fn ttl_expires() {
        let ks = ConcurrentKeyspace::default();
        ks.set(
            "key".into(),
            Bytes::from("value"),
            Some(Duration::from_millis(10)),
        );
        assert!(matches!(ks.ttl("key"), TtlResult::Seconds(_)));
        std::thread::sleep(Duration::from_millis(20));
        assert_eq!(ks.get("key"), None);
        // The expired entry is removed on access and its memory released.
        assert!(ks.is_empty());
        assert_eq!(ks.memory_used(), 0);
    }

    #[test]
    fn expire_sets_ttl_on_existing_key_only() {
        let ks = ConcurrentKeyspace::default();
        ks.set("key".into(), Bytes::from("value"), None);
        assert!(ks.expire("key", 100));
        assert!(matches!(ks.ttl("key"), TtlResult::Seconds(_)));
        assert!(!ks.expire("missing", 100));
    }

    #[test]
    fn memory_tracks_overwrite_and_delete() {
        let ks = ConcurrentKeyspace::default();
        assert!(ks.set("key".into(), Bytes::from("value"), None));
        let first = ks.memory_used();
        assert!(first > 0);

        assert!(ks.set("key".into(), Bytes::from("a much longer value"), None));
        assert!(ks.memory_used() > first);

        assert!(ks.set("key".into(), Bytes::from("v"), None));
        assert!(ks.memory_used() < first);

        assert!(ks.del("key"));
        assert_eq!(ks.memory_used(), 0);
        assert!(ks.is_empty());
    }

    #[test]
    fn clear_empties_keyspace_and_memory() {
        let ks = ConcurrentKeyspace::default();
        ks.set("a".into(), Bytes::from("1"), None);
        ks.set("b".into(), Bytes::from("2"), None);
        assert_eq!(ks.len(), 2);
        ks.clear();
        assert!(ks.is_empty());
        assert_eq!(ks.memory_used(), 0);
    }

    #[test]
    fn ops_count_tracks_calls() {
        let ks = ConcurrentKeyspace::default();
        ks.set("key".into(), Bytes::from("value"), None);
        let _ = ks.get("key");
        ks.del("key");
        assert_eq!(ks.ops_count(), 3);
    }

    #[test]
    fn concurrent_access() {
        use std::sync::Arc;
        use std::thread;

        let ks = Arc::new(ConcurrentKeyspace::default());
        let mut handles = vec![];

        for i in 0..8 {
            let ks = Arc::clone(&ks);
            handles.push(thread::spawn(move || {
                for j in 0..1000 {
                    let key = format!("key-{}-{}", i, j);
                    ks.set(key, Bytes::from("value"), None);
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(ks.len(), 8000);
    }

    #[test]
    fn concurrent_set_and_delete_balance_memory() {
        use std::sync::Arc;
        use std::thread;

        let ks = Arc::new(ConcurrentKeyspace::default());
        let handles: Vec<_> = (0..8)
            .map(|i| {
                let ks = Arc::clone(&ks);
                thread::spawn(move || {
                    for j in 0..500 {
                        let key = format!("key-{}-{}", i, j);
                        assert!(ks.set(key.clone(), Bytes::from("value"), None));
                        assert!(ks.del(&key));
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().unwrap();
        }

        assert!(ks.is_empty());
        assert_eq!(ks.memory_used(), 0);
    }
}