Skip to main content

sled_ext/
lib.rs

1use anyhow::Result;
2#[cfg(feature = "ttl")]
3use anyhow::anyhow;
4pub use bincode::{Decode, Encode};
5use serde::de::DeserializeOwned;
6use serde::{Deserialize, Serialize};
7#[cfg(feature = "ttl")]
8use sled::Event;
9#[cfg(feature = "ttl")]
10use sled::Transactional;
11#[cfg(feature = "ttl")]
12use sled::transaction::ConflictableTransactionError;
13use sled::{Config, Db};
14
15#[cfg(feature = "ttl")]
16use std::sync::Arc;
17use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
/// Panics if the system clock reports a time earlier than `UNIX_EPOCH`.
fn _now() -> u64 {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
    since_epoch.as_secs()
}
24
/// Absolute expiry deadline, in whole seconds since the Unix epoch, for an
/// entry that should live for `ttl` starting now.
///
/// A `ttl` so large that the deadline cannot be represented saturates to
/// `u64::MAX`, i.e. the entry effectively never expires. The original
/// implementation panicked via `SystemTime::checked_add(..).unwrap()` instead.
///
/// # Panics
/// Panics if the system clock reports a time earlier than `UNIX_EPOCH`.
fn expired_time(ttl: Duration) -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock is set before UNIX_EPOCH");
    // Adding on the epoch offset (a `Duration`) avoids platform-specific
    // `SystemTime` range limits; overflow means "unrepresentably far away".
    since_epoch
        .checked_add(ttl)
        .map_or(u64::MAX, |deadline| deadline.as_secs())
}
33
34pub trait ISledExt {
35    fn expire<K>(&self, key: K, ttl: Duration) -> Result<bool>
36    where
37        K: AsRef<[u8]>;
38}
39
40impl ISledExt for Db {
41    fn expire<K>(&self, key: K, ttl: Duration) -> Result<bool>
42    where
43        K: AsRef<[u8]>,
44    {
45        let expire_at = expired_time(ttl).to_be_bytes();
46        self.insert(key, expire_at.as_slice())?;
47        Ok(true)
48    }
49}
50
51#[derive(Serialize, Deserialize)]
52pub struct KvDbConfig {
53    pub path: String,
54    pub cache_capacity: u64,
55    pub flush_every_ms: u64,
56}
57
/// Name of the sled tree that stores the encoded values.
const KV_TREE: &[u8] = "__kv_tree@".as_bytes();
/// Name of the sled tree that stores per-key expiry deadlines.
///
/// The "tll" spelling is part of the on-disk layout: renaming it would orphan
/// the deadline tree of existing databases, so it is kept as-is.
const _TTL_TREE: &[u8] = "__tll_tree@".as_bytes();
60
61pub struct KvDb {
62    pub db: Db,
63    pub kv_tree: sled::Tree,
64    #[cfg(feature = "ttl")]
65    pub ttl_tree: sled::Tree,
66}
67
/// Spawns a background Tokio task that periodically purges expired entries
/// from `db`.
///
/// * `interval`: pause between cleanup rounds (default 3 s).
/// * `limit`: batch size handed to `KvDb::cleanup` (default 200). A value of
///   0 is treated as 1; otherwise `count < limit` could never hold and the
///   inner loop would never return to the interval sleep.
///
/// When a batch reaches `limit`, more expired entries may remain, so another
/// batch runs after a short 300 ms pause instead of waiting a full `interval`.
///
/// The scan performs synchronous sled I/O, so it runs on Tokio's blocking pool
/// (`spawn_blocking`) rather than stalling an async worker thread.
///
/// # Panics
/// `tokio::spawn` panics when called outside of a Tokio runtime.
#[cfg(feature = "ttl")]
pub fn def_ttl_cleanup(db: Arc<KvDb>, interval: Option<Duration>, limit: Option<usize>) {
    let interval = interval.unwrap_or(Duration::from_secs(3));
    let limit = limit.unwrap_or(200).max(1);
    tokio::spawn(async move {
        loop {
            tokio::time::sleep(interval).await;
            loop {
                let started = std::time::Instant::now();
                let worker = Arc::clone(&db);
                let count =
                    match tokio::task::spawn_blocking(move || worker.cleanup(limit)).await {
                        Ok(count) => count,
                        Err(e) => {
                            // The cleanup panicked or was cancelled; retry next round.
                            log::error!("cleanup task err: {:?}", e);
                            break;
                        }
                    };
                if count > 0 {
                    log::debug!("cleanup count: {}, cost time: {:?}", count, started.elapsed());
                }
                if count < limit {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(300)).await;
            }
        }
    });
}
95
/// Spawns a background Tokio task that calls `_evt` with the key (decoded as
/// lossy UTF-8) whenever an entry is removed from `db.ttl_tree`.
///
/// Both the TTL cleanup and `KvDb::remove` delete from `ttl_tree`, so the
/// callback may also fire for explicit removals, not only for expirations.
///
/// The subscription is registered before the task is spawned, so removals that
/// happen right after this function returns are not missed.
///
/// # Panics
/// `tokio::spawn` panics when called outside of a Tokio runtime.
#[cfg(feature = "ttl")]
pub fn set_expire_event<F>(db: Arc<KvDb>, _evt: F)
where
    F: Fn(String) + Send + Sync + 'static,
{
    let mut subscriber = db.ttl_tree.watch_prefix(vec![]);
    tokio::spawn(async move {
        // Keep the database handle alive for as long as the watcher runs.
        let _db = db;
        // `Subscriber` is also a `Future`. Awaiting it yields the worker thread
        // between events, whereas its `Iterator` impl would block it for good.
        while let Some(event) = (&mut subscriber).await {
            if let Event::Remove { key } = event {
                _evt(String::from_utf8_lossy(&key).into_owned());
            }
        }
    });
}
113
114impl KvDb {
115    pub fn new(cfg: KvDbConfig) -> Result<Self> {
116        let c = Config::default()
117            .path(cfg.path)
118            .cache_capacity(cfg.cache_capacity)
119            .flush_every_ms(Some(cfg.flush_every_ms))
120            .mode(sled::Mode::LowSpace);
121        let db = c.open()?;
122        let kv_tree = db.open_tree(KV_TREE)?;
123        #[cfg(feature = "ttl")]
124        let ttl_tree = db.open_tree(_TTL_TREE)?;
125
126        Ok(KvDb {
127            db,
128            kv_tree,
129            #[cfg(feature = "ttl")]
130            ttl_tree,
131        })
132    }
133
134    #[cfg(feature = "ttl")]
135    fn cleanup(&self, limit: usize) -> usize {
136        let mut count = 0;
137
138        for item in self.ttl_tree.iter() {
139            if count > limit {
140                break;
141            }
142            let (key, expire_at_iv) = match item {
143                Ok(item) => item,
144                Err(e) => {
145                    log::error!("cleanup err: {:?}", e);
146                    break;
147                }
148            };
149
150            let expire_at = match expire_at_iv.as_ref().try_into() {
151                Ok(at) => u64::from_be_bytes(at),
152                Err(e) => {
153                    log::error!("cleanup err: {:?}", e);
154                    break;
155                }
156            };
157
158            if expire_at > _now() {
159                break;
160            }
161
162            if let Err(e) = (&self.kv_tree, &self.ttl_tree).transaction(|(kv, exp)| {
163                kv.remove(key.clone())?;
164                exp.remove(key.clone())?;
165                Ok::<_, ConflictableTransactionError<()>>(())
166            }) {
167                log::error!("cleanup err: {:?}", e);
168            } else {
169                count += 1;
170            }
171        }
172        count
173    }
174
175    #[cfg(feature = "ttl")]
176    pub fn get_ttl_at<K>(&self, key: K) -> Option<u64>
177    where
178        K: AsRef<[u8]>,
179    {
180        let expire_at_iv = match self.ttl_tree.get(key.as_ref()) {
181            Ok(Some(at_bytes)) => at_bytes,
182            Ok(None) => return None,
183            Err(e) => {
184                log::error!("get_ttl_at err: {:?}", e);
185                return None;
186            }
187        };
188
189        let expire_at = match expire_at_iv.as_ref().try_into() {
190            Ok(at) => u64::from_be_bytes(at),
191            Err(e) => {
192                log::error!("get_ttl_at err: {:?}", e);
193                return None;
194            }
195        };
196
197        Some(expire_at)
198    }
199
200    #[cfg(feature = "ttl")]
201    pub fn is_expired<K>(&self, key: K) -> Option<bool>
202    where
203        K: AsRef<[u8]>,
204    {
205        let expire_at = self.get_ttl_at(key);
206
207        let Some(expire_at) = expire_at else {
208            return None;
209        };
210
211        if _now() > expire_at {
212            return Some(true);
213        }
214
215        Some(false)
216    }
217
218    #[cfg(feature = "ttl")]
219    pub fn insert_ttl<K, V>(&self, key: K, value: V, ttl: Duration) -> Result<()>
220    where
221        K: AsRef<[u8]>,
222        V: Serialize + Encode,
223    {
224        let v = bincode::encode_to_vec(value, bincode::config::standard())?;
225        let expire_at = expired_time(ttl).to_be_bytes();
226
227        if let Err(e) = (&self.kv_tree, &self.ttl_tree).transaction(|(kv, ttl)| {
228            kv.insert(key.as_ref(), v.clone())?;
229            ttl.insert(key.as_ref(), expire_at.as_slice())?;
230            Ok::<_, ConflictableTransactionError<()>>(())
231        }) {
232            return Err(anyhow!("insert_ttl err: {:?}", e));
233        }
234        Ok(())
235    }
236
237    #[cfg(feature = "ttl")]
238    pub fn refresh_ttl<K>(&self, key: K, ttl: Duration) -> Result<()>
239    where
240        K: AsRef<[u8]>,
241    {
242        if !self.contains_key(&key) {
243            return Err(anyhow!("key is not exist"));
244        }
245        let expire_at = expired_time(ttl).to_be_bytes();
246        self.ttl_tree.insert(key, expire_at.as_slice())?;
247        Ok(())
248    }
249    pub fn insert_or_update<K, V>(&self, key: K, value: V) -> Result<()>
250    where
251        K: AsRef<[u8]>,
252        V: Serialize + Encode,
253    {
254        let v = bincode::encode_to_vec(value, bincode::config::standard())?;
255        self.kv_tree.insert(key, v)?;
256        Ok(())
257    }
258
259    pub fn contains_key<K>(&self, key: K) -> bool
260    where
261        K: AsRef<[u8]>,
262    {
263        #[cfg(feature = "ttl")]
264        {
265            let exp_v = self.is_expired(&key);
266
267            //如果ttl 存在,并已过期 则返回false
268            if let Some(v) = exp_v
269                && v
270            {
271                return false;
272            }
273        }
274
275        self.kv_tree.contains_key(key).ok().unwrap_or(false)
276    }
277
278    pub fn get<K, V>(&self, key: K) -> Option<V>
279    where
280        K: AsRef<[u8]>,
281        V: DeserializeOwned + Decode<()>,
282    {
283        let val = match self.kv_tree.get(key) {
284            Ok(v) => v,
285            Err(e) => {
286                log::error!("kvdb get err: {}", e);
287                return None;
288            }
289        };
290
291        if let Some(v) = val {
292            let b = bincode::decode_from_slice::<V, _>(v.as_ref(), bincode::config::standard());
293            if let Ok((v, _)) = b {
294                return Some(v);
295            }
296            if let Err(e) = b {
297                log::error!("kvdb deserialize error: {}", e.to_string());
298            }
299            return None;
300        }
301
302        None
303    }
304
305    pub fn remove<K>(&self, key: K) -> Result<()>
306    where
307        K: AsRef<[u8]>,
308    {
309        let key_ref = key.as_ref();
310        if let Err(e) = (&self.kv_tree, &self.ttl_tree).transaction(|(kv, ttl)| {
311            kv.remove(key_ref)?;
312            ttl.remove(key_ref)?;
313            Ok::<_, ConflictableTransactionError<()>>(())
314        }) {
315            return Err(anyhow!("remove key err: {:?}", e));
316        }
317        Ok(())
318    }
319
320    pub fn clean(&self) -> Result<()> {
321        self.db.clear()?;
322        self.kv_tree.clear()?;
323        self.ttl_tree.clear()?;
324        Ok(())
325    }
326}