1use anyhow::Result;
2#[cfg(feature = "ttl")]
3use anyhow::anyhow;
4pub use bincode::{Decode, Encode};
5use serde::de::DeserializeOwned;
6use serde::{Deserialize, Serialize};
7#[cfg(feature = "ttl")]
8use sled::Event;
9#[cfg(feature = "ttl")]
10use sled::Transactional;
11#[cfg(feature = "ttl")]
12use sled::transaction::ConflictableTransactionError;
13use sled::{Config, Db};
14
15#[cfg(feature = "ttl")]
16use std::sync::Arc;
17use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the Unix epoch.
fn _now() -> u64 {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH);
    since_epoch.unwrap().as_secs()
}
24
/// Computes the absolute expiry timestamp (whole seconds since the Unix
/// epoch) for an entry that should live for `ttl` starting now.
///
/// The function never panics:
/// * if `now + ttl` is not representable as a `SystemTime` (for example
///   `Duration::MAX`), the result saturates to `u64::MAX`, i.e. the entry
///   effectively never expires;
/// * if the resulting deadline lies before the Unix epoch (system clock set
///   far in the past), `0` is returned.
fn expired_time(ttl: Duration) -> u64 {
    match SystemTime::now().checked_add(ttl) {
        // Overflowing the platform time representation means "far future".
        None => u64::MAX,
        Some(deadline) => deadline
            .duration_since(UNIX_EPOCH)
            .map_or(0, |since_epoch| since_epoch.as_secs()),
    }
}
33
34pub trait ISledExt {
35 fn expire<K>(&self, key: K, ttl: Duration) -> Result<bool>
36 where
37 K: AsRef<[u8]>;
38}
39
40impl ISledExt for Db {
41 fn expire<K>(&self, key: K, ttl: Duration) -> Result<bool>
42 where
43 K: AsRef<[u8]>,
44 {
45 let expire_at = expired_time(ttl).to_be_bytes();
46 self.insert(key, expire_at.as_slice())?;
47 Ok(true)
48 }
49}
50
51#[derive(Serialize, Deserialize)]
52pub struct KvDbConfig {
53 pub path: String,
54 pub cache_capacity: u64,
55 pub flush_every_ms: u64,
56}
57
/// Name of the sled tree that holds the encoded values.
const KV_TREE: &[u8] = b"__kv_tree@";
/// Name of the sled tree that holds per-key expiry timestamps.
///
/// The spelling ("tll") is part of the name the tree is opened with, so it
/// must not be "corrected": doing so would open a different tree.
const _TTL_TREE: &[u8] = b"__tll_tree@";
60
61pub struct KvDb {
62 pub db: Db,
63 pub kv_tree: sled::Tree,
64 #[cfg(feature = "ttl")]
65 pub ttl_tree: sled::Tree,
66}
67
/// Spawns a background Tokio task that periodically purges expired keys.
///
/// * `interval` - pause between cleanup rounds; defaults to 3 seconds.
/// * `limit` - batch size handed to `KvDb::cleanup`; defaults to 200. A value
///   of `0` is treated as `1`, because with a zero limit the
///   `count < limit` exit test below could never succeed and the task would
///   never return to the `interval` sleep.
///
/// Within one round the task keeps calling `cleanup` (pausing 300 ms between
/// batches) for as long as a batch comes back full, then sleeps for
/// `interval` again. The task runs until the runtime shuts down.
///
/// This calls `tokio::spawn`, so it has to be invoked from within a Tokio
/// runtime.
///
/// NOTE(review): `cleanup` performs synchronous sled I/O on the async task;
/// consider `spawn_blocking` if batches turn out to be slow.
#[cfg(feature = "ttl")]
pub fn def_ttl_cleanup(db: Arc<KvDb>, interval: Option<Duration>, limit: Option<usize>) {
    let interval = interval.unwrap_or(Duration::from_secs(3));
    let limit = limit.unwrap_or(200).max(1);

    tokio::spawn(async move {
        loop {
            tokio::time::sleep(interval).await;
            loop {
                let started = std::time::Instant::now();
                let count = db.cleanup(limit);
                if count > 0 {
                    log::debug!(
                        "cleanup count: {}, cost time: {:?}",
                        count,
                        started.elapsed()
                    );
                }
                // A short batch means nothing more is pending right now.
                if count < limit {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(300)).await;
            }
        }
    });
}
95
/// Spawns a background Tokio task that invokes `_evt` with the key (lossily
/// decoded as UTF-8) every time an entry is removed from `ttl_tree`.
///
/// The callback fires for *any* removal from the TTL tree that the
/// subscription observes, not only for removals performed by the periodic
/// cleanup.
///
/// The sled subscriber is awaited as a future rather than iterated
/// synchronously: iterating it blocks the calling thread while waiting for
/// the next event, which inside `tokio::spawn` would pin a runtime worker
/// thread for the lifetime of the subscription.
///
/// This calls `tokio::spawn`, so it has to be invoked from within a Tokio
/// runtime.
#[cfg(feature = "ttl")]
pub fn set_expire_event<F>(db: Arc<KvDb>, _evt: F)
where
    F: Fn(String) + Send + Sync + 'static,
{
    tokio::spawn(async move {
        // An empty prefix subscribes to every key in the tree.
        let mut subscriber = db.ttl_tree.watch_prefix(vec![]);
        while let Some(event) = (&mut subscriber).await {
            if let Event::Remove { key } = event {
                _evt(String::from_utf8_lossy(&key).into_owned());
            }
        }
    });
}
113
114impl KvDb {
115 pub fn new(cfg: KvDbConfig) -> Result<Self> {
116 let c = Config::default()
117 .path(cfg.path)
118 .cache_capacity(cfg.cache_capacity)
119 .flush_every_ms(Some(cfg.flush_every_ms))
120 .mode(sled::Mode::LowSpace);
121 let db = c.open()?;
122 let kv_tree = db.open_tree(KV_TREE)?;
123 #[cfg(feature = "ttl")]
124 let ttl_tree = db.open_tree(_TTL_TREE)?;
125
126 Ok(KvDb {
127 db,
128 kv_tree,
129 #[cfg(feature = "ttl")]
130 ttl_tree,
131 })
132 }
133
134 #[cfg(feature = "ttl")]
135 fn cleanup(&self, limit: usize) -> usize {
136 let mut count = 0;
137
138 for item in self.ttl_tree.iter() {
139 if count > limit {
140 break;
141 }
142 let (key, expire_at_iv) = match item {
143 Ok(item) => item,
144 Err(e) => {
145 log::error!("cleanup err: {:?}", e);
146 break;
147 }
148 };
149
150 let expire_at = match expire_at_iv.as_ref().try_into() {
151 Ok(at) => u64::from_be_bytes(at),
152 Err(e) => {
153 log::error!("cleanup err: {:?}", e);
154 break;
155 }
156 };
157
158 if expire_at > _now() {
159 break;
160 }
161
162 if let Err(e) = (&self.kv_tree, &self.ttl_tree).transaction(|(kv, exp)| {
163 kv.remove(key.clone())?;
164 exp.remove(key.clone())?;
165 Ok::<_, ConflictableTransactionError<()>>(())
166 }) {
167 log::error!("cleanup err: {:?}", e);
168 } else {
169 count += 1;
170 }
171 }
172 count
173 }
174
175 #[cfg(feature = "ttl")]
176 pub fn get_ttl_at<K>(&self, key: K) -> Option<u64>
177 where
178 K: AsRef<[u8]>,
179 {
180 let expire_at_iv = match self.ttl_tree.get(key.as_ref()) {
181 Ok(Some(at_bytes)) => at_bytes,
182 Ok(None) => return None,
183 Err(e) => {
184 log::error!("get_ttl_at err: {:?}", e);
185 return None;
186 }
187 };
188
189 let expire_at = match expire_at_iv.as_ref().try_into() {
190 Ok(at) => u64::from_be_bytes(at),
191 Err(e) => {
192 log::error!("get_ttl_at err: {:?}", e);
193 return None;
194 }
195 };
196
197 Some(expire_at)
198 }
199
200 #[cfg(feature = "ttl")]
201 pub fn is_expired<K>(&self, key: K) -> Option<bool>
202 where
203 K: AsRef<[u8]>,
204 {
205 let expire_at = self.get_ttl_at(key);
206
207 let Some(expire_at) = expire_at else {
208 return None;
209 };
210
211 if _now() > expire_at {
212 return Some(true);
213 }
214
215 Some(false)
216 }
217
218 #[cfg(feature = "ttl")]
219 pub fn insert_ttl<K, V>(&self, key: K, value: V, ttl: Duration) -> Result<()>
220 where
221 K: AsRef<[u8]>,
222 V: Serialize + Encode,
223 {
224 let v = bincode::encode_to_vec(value, bincode::config::standard())?;
225 let expire_at = expired_time(ttl).to_be_bytes();
226
227 if let Err(e) = (&self.kv_tree, &self.ttl_tree).transaction(|(kv, ttl)| {
228 kv.insert(key.as_ref(), v.clone())?;
229 ttl.insert(key.as_ref(), expire_at.as_slice())?;
230 Ok::<_, ConflictableTransactionError<()>>(())
231 }) {
232 return Err(anyhow!("insert_ttl err: {:?}", e));
233 }
234 Ok(())
235 }
236
237 #[cfg(feature = "ttl")]
238 pub fn refresh_ttl<K>(&self, key: K, ttl: Duration) -> Result<()>
239 where
240 K: AsRef<[u8]>,
241 {
242 if !self.contains_key(&key) {
243 return Err(anyhow!("key is not exist"));
244 }
245 let expire_at = expired_time(ttl).to_be_bytes();
246 self.ttl_tree.insert(key, expire_at.as_slice())?;
247 Ok(())
248 }
249 pub fn insert_or_update<K, V>(&self, key: K, value: V) -> Result<()>
250 where
251 K: AsRef<[u8]>,
252 V: Serialize + Encode,
253 {
254 let v = bincode::encode_to_vec(value, bincode::config::standard())?;
255 self.kv_tree.insert(key, v)?;
256 Ok(())
257 }
258
259 pub fn contains_key<K>(&self, key: K) -> bool
260 where
261 K: AsRef<[u8]>,
262 {
263 #[cfg(feature = "ttl")]
264 {
265 let exp_v = self.is_expired(&key);
266
267 if let Some(v) = exp_v
269 && v
270 {
271 return false;
272 }
273 }
274
275 self.kv_tree.contains_key(key).ok().unwrap_or(false)
276 }
277
278 pub fn get<K, V>(&self, key: K) -> Option<V>
279 where
280 K: AsRef<[u8]>,
281 V: DeserializeOwned + Decode<()>,
282 {
283 let val = match self.kv_tree.get(key) {
284 Ok(v) => v,
285 Err(e) => {
286 log::error!("kvdb get err: {}", e);
287 return None;
288 }
289 };
290
291 if let Some(v) = val {
292 let b = bincode::decode_from_slice::<V, _>(v.as_ref(), bincode::config::standard());
293 if let Ok((v, _)) = b {
294 return Some(v);
295 }
296 if let Err(e) = b {
297 log::error!("kvdb deserialize error: {}", e.to_string());
298 }
299 return None;
300 }
301
302 None
303 }
304
305 pub fn remove<K>(&self, key: K) -> Result<()>
306 where
307 K: AsRef<[u8]>,
308 {
309 let key_ref = key.as_ref();
310 if let Err(e) = (&self.kv_tree, &self.ttl_tree).transaction(|(kv, ttl)| {
311 kv.remove(key_ref)?;
312 ttl.remove(key_ref)?;
313 Ok::<_, ConflictableTransactionError<()>>(())
314 }) {
315 return Err(anyhow!("remove key err: {:?}", e));
316 }
317 Ok(())
318 }
319
320 pub fn clean(&self) -> Result<()> {
321 self.db.clear()?;
322 self.kv_tree.clear()?;
323 self.ttl_tree.clear()?;
324 Ok(())
325 }
326}