1use anyhow::Result;
2#[cfg(feature = "ttl")]
3use anyhow::anyhow;
4pub use bincode::{Decode, Encode};
5use serde::de::DeserializeOwned;
6use serde::{Deserialize, Serialize};
7#[cfg(feature = "ttl")]
8use sled::Event;
9#[cfg(feature = "ttl")]
10use sled::Transactional;
11#[cfg(feature = "ttl")]
12use sled::transaction::ConflictableTransactionError;
13use sled::{Config, Db};
14
15#[cfg(feature = "ttl")]
16use std::sync::Arc;
17use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch.
fn _now() -> u64 {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
    since_epoch.as_secs()
}
24
/// Returns the absolute expiry deadline, in whole seconds since the Unix epoch,
/// for something that should live for `ttl` starting now.
///
/// The TTL is added before truncating to whole seconds, so sub-second parts of the
/// current time and of `ttl` can carry into the result. A `ttl` too large to be
/// represented saturates to `u64::MAX` (effectively "never expires") instead of
/// panicking, which `SystemTime::checked_add(..).unwrap()` used to do.
///
/// # Panics
/// Panics if the system clock reports a time before the Unix epoch.
fn expired_time(ttl: Duration) -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock is set before the Unix epoch");
    since_epoch
        .checked_add(ttl)
        .map_or(u64::MAX, |deadline| deadline.as_secs())
}
33
34pub trait ISledExt {
35 fn expire<K>(&self, key: K, ttl: Duration) -> Result<bool>
36 where
37 K: AsRef<[u8]> + Sync + Send;
38}
39
40impl ISledExt for Db {
41 fn expire<K>(&self, key: K, ttl: Duration) -> Result<bool>
42 where
43 K: AsRef<[u8]> + Sync + Send,
44 {
45 let expire_at = expired_time(ttl).to_be_bytes();
46 self.insert(key, expire_at.as_slice())?;
47 Ok(true)
48 }
49}
50
51#[derive(Serialize, Deserialize)]
52pub struct KvDbConfig {
53 pub path: String,
54 pub cache_capacity: u64,
55 pub flush_every_ms: u64,
56}
57
/// Name of the sled tree that holds the bincode-encoded values.
const KV_TREE: &[u8] = "__kv_tree@".as_bytes();
/// Name of the sled tree that maps keys to their expiry deadlines.
///
/// NOTE: "tll" looks like a typo of "ttl". It must not be "fixed": sled tree names
/// are persisted, so a different name would open a different (empty) tree on
/// existing databases.
const _TTL_TREE: &[u8] = "__tll_tree@".as_bytes();
60
/// Key-value store backed by named sled trees, with optional per-key TTLs
/// (behind the `ttl` feature).
pub struct KvDb {
    /// Holds the values, bincode-encoded by `insert` / `insert_ttl`.
    pub(crate) kv_tree: sled::Tree,
    /// Maps a key to its expiry deadline: Unix seconds as a big-endian `u64`.
    #[cfg(feature = "ttl")]
    pub(crate) ttl_tree: sled::Tree,
}
66
/// Spawns a background task that periodically purges expired entries from `db`.
///
/// Every `interval` (default 3 s) the task runs `KvDb::cleanup` in batches of at
/// most `limit` removals (default 200). While a batch comes back full, it runs
/// another batch after a 300 ms pause. Once a batch is short, it waits for the next
/// interval.
///
/// # Panics
/// Must be called from within a Tokio runtime (`tokio::spawn` panics otherwise).
#[cfg(feature = "ttl")]
pub fn def_ttl_cleanup(db: Arc<KvDb>, interval: Option<Duration>, limit: Option<usize>) {
    let interval = interval.unwrap_or(Duration::from_secs(3));
    // With a zero limit, `count < limit` can never be true, so the inner loop
    // would never finish. Always allow at least one removal per batch.
    let limit = limit.unwrap_or(200).max(1);

    tokio::spawn(async move {
        loop {
            tokio::time::sleep(interval).await;
            loop {
                let started = std::time::Instant::now();
                // `cleanup` does synchronous sled I/O. Run it on the blocking pool so it
                // does not stall the async worker threads.
                let worker_db = Arc::clone(&db);
                let batch = tokio::task::spawn_blocking(move || worker_db.cleanup(limit)).await;
                let count = match batch {
                    Ok(count) => count,
                    Err(e) => {
                        // The batch panicked or was cancelled; retry on the next interval.
                        log::error!("cleanup err: {:?}", e);
                        break;
                    }
                };
                if count > 0 {
                    log::debug!("cleanup count: {}, cost time: {:?}", count, started.elapsed());
                }
                if count < limit {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(300)).await;
            }
        }
    });
}
94
/// Spawns a task that calls `evt` with the key of every entry removed from the TTL
/// tree.
///
/// Keys are decoded lossily as UTF-8. Within this module, TTL entries are only
/// removed by the cleanup pass, so in practice this reports expired keys.
///
/// # Panics
/// Must be called from within a Tokio runtime (`tokio::spawn` panics otherwise).
#[cfg(feature = "ttl")]
pub fn set_expire_event<F>(db: Arc<KvDb>, evt: F)
where
    F: Fn(String) + Send + Sync + 'static,
{
    tokio::spawn(async move {
        let mut subscriber = db.ttl_tree.watch_prefix(vec![]);
        // Await the subscriber rather than iterating it. Its `Iterator` impl blocks
        // the thread while waiting for events, which would permanently occupy a Tokio
        // worker (or hang a current-thread runtime).
        while let Some(event) = (&mut subscriber).await {
            if let Event::Remove { key } = event {
                evt(String::from_utf8_lossy(&key).into_owned());
            }
        }
    });
}
112
113impl KvDb {
114 pub fn new(cfg: KvDbConfig) -> Result<Self> {
115 let c = Config::default()
116 .path(cfg.path)
117 .cache_capacity(cfg.cache_capacity)
118 .flush_every_ms(Some(cfg.flush_every_ms))
119 .mode(sled::Mode::LowSpace);
120 let db = c.open()?;
121 let kv_tree = db.open_tree(KV_TREE)?;
122 #[cfg(feature = "ttl")]
123 let ttl_tree = db.open_tree(_TTL_TREE)?;
124
125 Ok(KvDb {
127 kv_tree,
128 #[cfg(feature = "ttl")]
129 ttl_tree,
130 })
131 }
132
133 #[cfg(feature = "ttl")]
134 fn cleanup(&self, limit: usize) -> usize {
135 let mut count = 0;
136
137 for item in self.ttl_tree.iter() {
138 if count > limit {
139 break;
140 }
141 let (key, expire_at_iv) = match item {
142 Ok(item) => item,
143 Err(e) => {
144 log::error!("cleanup err: {:?}", e);
145 break;
146 }
147 };
148
149 let expire_at = match expire_at_iv.as_ref().try_into() {
150 Ok(at) => u64::from_be_bytes(at),
151 Err(e) => {
152 log::error!("cleanup err: {:?}", e);
153 break;
154 }
155 };
156
157 if expire_at > _now() {
158 break;
159 }
160
161 if let Err(e) = (&self.kv_tree, &self.ttl_tree).transaction(|(kv, exp)| {
162 kv.remove(key.clone())?;
163 exp.remove(key.clone())?;
164 Ok::<_, ConflictableTransactionError<()>>(())
165 }) {
166 log::error!("cleanup err: {:?}", e);
167 } else {
168 count += 1;
169 }
170 }
171 count
172 }
173
174 #[cfg(feature = "ttl")]
175 pub fn get_ttl_at<K>(&self, key: K) -> Option<u64>
176 where
177 K: AsRef<[u8]> + Sync + Send,
178 {
179 let expire_at_iv = match self.ttl_tree.get(key.as_ref()) {
180 Ok(Some(at_bytes)) => at_bytes,
181 Ok(None) => return None,
182 Err(e) => {
183 log::error!("get_ttl_at err: {:?}", e);
184 return None;
185 }
186 };
187
188 let expire_at = match expire_at_iv.as_ref().try_into() {
189 Ok(at) => u64::from_be_bytes(at),
190 Err(e) => {
191 log::error!("get_ttl_at err: {:?}", e);
192 return None;
193 }
194 };
195
196 Some(expire_at)
197 }
198
199 #[cfg(feature = "ttl")]
200 pub fn is_expired<K>(&self, key: K) -> Option<bool>
201 where
202 K: AsRef<[u8]> + Sync + Send,
203 {
204 let expire_at = self.get_ttl_at(key);
205
206 let Some(expire_at) = expire_at else {
207 return None;
208 };
209
210 if _now() > expire_at {
211 return Some(true);
212 }
213
214 Some(false)
215 }
216
217 #[cfg(feature = "ttl")]
218 pub fn insert_ttl<K, V>(&self, key: K, value: V, ttl: Duration) -> Result<()>
219 where
220 K: AsRef<[u8]>,
221 V: Serialize + Encode + Sync + Send,
222 {
223 let v = bincode::encode_to_vec(value, bincode::config::standard())?;
224 let expire_at = expired_time(ttl).to_be_bytes();
225
226 if let Err(e) = (&self.kv_tree, &self.ttl_tree).transaction(|(kv, ttl)| {
227 kv.insert(key.as_ref(), v.clone())?;
228 ttl.insert(key.as_ref(), expire_at.as_slice())?;
229 Ok::<_, ConflictableTransactionError<()>>(())
230 }) {
231 return Err(anyhow!("insert_ttl err: {:?}", e));
232 }
233 Ok(())
234 }
235
236 pub fn insert<K, V>(&self, key: K, value: V) -> Result<()>
237 where
238 K: AsRef<[u8]>,
239 V: Serialize + Encode + Sync + Send,
240 {
241 let v = bincode::encode_to_vec(value, bincode::config::standard())?;
242 self.kv_tree.insert(key, v)?;
243 Ok(())
244 }
245
246 pub fn contains_key<K>(&self, key: K) -> bool
247 where
248 K: AsRef<[u8]> + Sync + Send,
249 {
250 #[cfg(feature = "ttl")]
251 {
252 let exp_v = self.is_expired(&key);
253
254 if let Some(v) = exp_v
256 && v
257 {
258 return false;
259 }
260 }
261
262 self.kv_tree.contains_key(key).ok().unwrap_or(false)
263 }
264
265 pub fn get<K, V>(&self, key: K) -> Option<V>
266 where
267 K: AsRef<[u8]>,
268 V: DeserializeOwned + Decode<()> + Sync + Send,
269 {
270 let val = match self.kv_tree.get(key) {
271 Ok(v) => v,
272 Err(e) => {
273 log::error!("kvdb get err: {}", e);
274 return None;
275 }
276 };
277
278 if let Some(v) = val {
279 let b = bincode::decode_from_slice::<V, _>(v.as_ref(), bincode::config::standard());
280 if let Ok((v, _)) = b {
281 return Some(v);
282 }
283 if let Err(e) = b {
284 log::error!("kvdb deserialize error: {}", e.to_string());
285 }
286 return None;
287 }
288
289 None
290 }
291}