sled_ext/
lib.rs

1use anyhow::Result;
2#[cfg(feature = "ttl")]
3use anyhow::anyhow;
4pub use bincode::{Decode, Encode};
5use serde::de::DeserializeOwned;
6use serde::{Deserialize, Serialize};
7#[cfg(feature = "ttl")]
8use sled::Event;
9#[cfg(feature = "ttl")]
10use sled::Transactional;
11#[cfg(feature = "ttl")]
12use sled::transaction::ConflictableTransactionError;
13use sled::{Config, Db};
14
15#[cfg(feature = "ttl")]
16use std::sync::Arc;
17use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Current wall-clock time, in whole seconds since the Unix epoch.
///
/// # Panics
/// Panics if the system clock reports a time earlier than the Unix epoch.
fn _now() -> u64 {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
    since_epoch.as_secs()
}
24
/// Unix timestamp (whole seconds) at which something stored now with `ttl` expires.
///
/// The deadline is computed as `(now - UNIX_EPOCH) + ttl`, truncated to whole seconds.
/// If that sum overflows (e.g. `ttl == Duration::MAX`) the result saturates at `u64::MAX`,
/// i.e. "never expires", instead of panicking.
///
/// # Panics
/// Panics if the system clock reports a time earlier than the Unix epoch.
fn expired_time(ttl: Duration) -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .checked_add(ttl)
        .map_or(u64::MAX, |deadline| deadline.as_secs())
}
33
34pub trait ISledExt {
35    fn expire<K>(&self, key: K, ttl: Duration) -> Result<bool>
36    where
37        K: AsRef<[u8]> + Sync + Send;
38}
39
40impl ISledExt for Db {
41    fn expire<K>(&self, key: K, ttl: Duration) -> Result<bool>
42    where
43        K: AsRef<[u8]> + Sync + Send,
44    {
45        let expire_at = expired_time(ttl).to_be_bytes();
46        self.insert(key, expire_at.as_slice())?;
47        Ok(true)
48    }
49}
50
51#[derive(Serialize, Deserialize)]
52pub struct KvDbConfig {
53    pub path: String,
54    pub cache_capacity: u64,
55    pub flush_every_ms: u64,
56}
57
/// Name of the sled tree that holds the encoded values.
const KV_TREE: &[u8] = "__kv_tree@".as_bytes();
/// Name of the sled tree that holds expiry deadlines (opened only with the `ttl` feature).
///
/// The `tll` spelling is the tree name already written to disk by existing databases;
/// "fixing" it would orphan their TTL data, so the bytes must stay exactly as they are.
/// The leading underscore presumably silences the unused-constant lint when `ttl` is off.
const _TTL_TREE: &[u8] = "__tll_tree@".as_bytes();
60
/// Key-value store backed by sled trees.
///
/// Values are stored bincode-encoded (standard config) in `kv_tree`. With the `ttl`
/// feature, `ttl_tree` maps the same keys to their expiry deadline.
pub struct KvDb {
    /// Tree holding the encoded values.
    pub(crate) kv_tree: sled::Tree,
    /// Tree mapping keys to their expiry deadline: Unix seconds as 8 big-endian bytes.
    #[cfg(feature = "ttl")]
    pub(crate) ttl_tree: sled::Tree,
}
66
/// Spawns a detached background task that periodically purges expired entries from `db`.
///
/// The task sleeps for `interval` (default 3 s) between sweeps. Each sweep removes expired
/// keys in batches of at most `limit` (default 200, clamped to at least 1). While a batch
/// comes back full it pauses 300 ms and runs another one, so a backlog is drained without
/// waiting for the next interval.
///
/// sled calls block the thread, so each batch runs on tokio's blocking pool via
/// `tokio::task::spawn_blocking` rather than on an async worker thread. Must be called from
/// within a tokio runtime; the task loops until the runtime shuts down.
#[cfg(feature = "ttl")]
pub fn def_ttl_cleanup(db: Arc<KvDb>, interval: Option<Duration>, limit: Option<usize>) {
    let interval = interval.unwrap_or(Duration::from_secs(3));
    // With a limit of 0 no batch could ever make progress, and the inner loop would never exit.
    let limit = limit.unwrap_or(200).max(1);

    tokio::spawn(async move {
        loop {
            tokio::time::sleep(interval).await;
            loop {
                let now = std::time::Instant::now();
                let batch_db = Arc::clone(&db);
                let count = match tokio::task::spawn_blocking(move || batch_db.cleanup(limit)).await
                {
                    Ok(count) => count,
                    Err(e) => {
                        // The batch panicked or was cancelled; retry on the next interval.
                        log::error!("cleanup err: {:?}", e);
                        break;
                    }
                };
                if count > 0 {
                    log::debug!("cleanup count: {}, cost time: {:?}", count, now.elapsed());
                }
                // A short batch means nothing else is currently expired.
                if count < limit {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(300)).await;
            }
        }
    });
}
94
/// Calls `on_expire` with the key of every entry removed from the TTL tree.
///
/// The key is decoded lossily as UTF-8. sled's subscriber iterator blocks the calling thread
/// while it waits for events. The watch loop therefore runs on tokio's blocking pool via
/// `tokio::task::spawn_blocking`, which permanently occupies one blocking thread rather than
/// an async worker. Must be called from within a tokio runtime. The loop ends only when the
/// subscription yields no more events.
#[cfg(feature = "ttl")]
pub fn set_expire_event<F>(db: Arc<KvDb>, on_expire: F)
where
    F: Fn(String) + Send + Sync + 'static,
{
    tokio::task::spawn_blocking(move || {
        for event in db.ttl_tree.watch_prefix(vec![]) {
            if let Event::Remove { key } = event {
                on_expire(String::from_utf8_lossy(&key).into_owned());
            }
        }
    });
}
112
113impl KvDb {
114    pub fn new(cfg: KvDbConfig) -> Result<Self> {
115        let c = Config::default()
116            .path(cfg.path)
117            .cache_capacity(cfg.cache_capacity)
118            .flush_every_ms(Some(cfg.flush_every_ms))
119            .mode(sled::Mode::LowSpace);
120        let db = c.open()?;
121        let kv_tree = db.open_tree(KV_TREE)?;
122        #[cfg(feature = "ttl")]
123        let ttl_tree = db.open_tree(_TTL_TREE)?;
124
125        // let db = Arc::new(db);
126        Ok(KvDb {
127            kv_tree,
128            #[cfg(feature = "ttl")]
129            ttl_tree,
130        })
131    }
132
133    #[cfg(feature = "ttl")]
134    fn cleanup(&self, limit: usize) -> usize {
135        let mut count = 0;
136
137        for item in self.ttl_tree.iter() {
138            if count > limit {
139                break;
140            }
141            let (key, expire_at_iv) = match item {
142                Ok(item) => item,
143                Err(e) => {
144                    log::error!("cleanup err: {:?}", e);
145                    break;
146                }
147            };
148
149            let expire_at = match expire_at_iv.as_ref().try_into() {
150                Ok(at) => u64::from_be_bytes(at),
151                Err(e) => {
152                    log::error!("cleanup err: {:?}", e);
153                    break;
154                }
155            };
156
157            if expire_at > _now() {
158                break;
159            }
160
161            if let Err(e) = (&self.kv_tree, &self.ttl_tree).transaction(|(kv, exp)| {
162                kv.remove(key.clone())?;
163                exp.remove(key.clone())?;
164                Ok::<_, ConflictableTransactionError<()>>(())
165            }) {
166                log::error!("cleanup err: {:?}", e);
167            } else {
168                count += 1;
169            }
170        }
171        count
172    }
173
174    #[cfg(feature = "ttl")]
175    pub fn get_ttl_at<K>(&self, key: K) -> Option<u64>
176    where
177        K: AsRef<[u8]> + Sync + Send,
178    {
179        let expire_at_iv = match self.ttl_tree.get(key.as_ref()) {
180            Ok(Some(at_bytes)) => at_bytes,
181            Ok(None) => return None,
182            Err(e) => {
183                log::error!("get_ttl_at err: {:?}", e);
184                return None;
185            }
186        };
187
188        let expire_at = match expire_at_iv.as_ref().try_into() {
189            Ok(at) => u64::from_be_bytes(at),
190            Err(e) => {
191                log::error!("get_ttl_at err: {:?}", e);
192                return None;
193            }
194        };
195
196        Some(expire_at)
197    }
198
199    #[cfg(feature = "ttl")]
200    pub fn is_expired<K>(&self, key: K) -> Option<bool>
201    where
202        K: AsRef<[u8]> + Sync + Send,
203    {
204        let expire_at = self.get_ttl_at(key);
205
206        let Some(expire_at) = expire_at else {
207            return None;
208        };
209
210        if _now() > expire_at {
211            return Some(true);
212        }
213
214        Some(false)
215    }
216
217    #[cfg(feature = "ttl")]
218    pub fn insert_ttl<K, V>(&self, key: K, value: V, ttl: Duration) -> Result<()>
219    where
220        K: AsRef<[u8]>,
221        V: Serialize + Encode + Sync + Send,
222    {
223        let v = bincode::encode_to_vec(value, bincode::config::standard())?;
224        let expire_at = expired_time(ttl).to_be_bytes();
225
226        if let Err(e) = (&self.kv_tree, &self.ttl_tree).transaction(|(kv, ttl)| {
227            kv.insert(key.as_ref(), v.clone())?;
228            ttl.insert(key.as_ref(), expire_at.as_slice())?;
229            Ok::<_, ConflictableTransactionError<()>>(())
230        }) {
231            return Err(anyhow!("insert_ttl err: {:?}", e));
232        }
233        Ok(())
234    }
235
236    pub fn insert<K, V>(&self, key: K, value: V) -> Result<()>
237    where
238        K: AsRef<[u8]>,
239        V: Serialize + Encode + Sync + Send,
240    {
241        let v = bincode::encode_to_vec(value, bincode::config::standard())?;
242        self.kv_tree.insert(key, v)?;
243        Ok(())
244    }
245
246    pub fn contains_key<K>(&self, key: K) -> bool
247    where
248        K: AsRef<[u8]> + Sync + Send,
249    {
250        #[cfg(feature = "ttl")]
251        {
252            let exp_v = self.is_expired(&key);
253
254            //如果ttl 存在,并已过期 则返回false
255            if let Some(v) = exp_v
256                && v
257            {
258                return false;
259            }
260        }
261
262        self.kv_tree.contains_key(key).ok().unwrap_or(false)
263    }
264
265    pub fn get<K, V>(&self, key: K) -> Option<V>
266    where
267        K: AsRef<[u8]>,
268        V: DeserializeOwned + Decode<()> + Sync + Send,
269    {
270        let val = match self.kv_tree.get(key) {
271            Ok(v) => v,
272            Err(e) => {
273                log::error!("kvdb get err: {}", e);
274                return None;
275            }
276        };
277
278        if let Some(v) = val {
279            let b = bincode::decode_from_slice::<V, _>(v.as_ref(), bincode::config::standard());
280            if let Ok((v, _)) = b {
281                return Some(v);
282            }
283            if let Err(e) = b {
284                log::error!("kvdb deserialize error: {}", e.to_string());
285            }
286            return None;
287        }
288
289        None
290    }
291}