Skip to main content

tibba_cache/
cache.rs

1// Copyright 2026 Tree xie.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::{CompressionSnafu, Error, RedisClient, RedisClientConn, RedisSnafu, SerdeJsonSnafu};
16use deadpool_redis::redis::{cmd, pipe};
17use redis::AsyncCommands;
18use serde::{Serialize, de::DeserializeOwned};
19use snafu::ResultExt;
20use std::{borrow::Cow, time::Duration};
21use tibba_util::{Algorithm, compress, decompress};
22
23const DEFAULT_ZSTD: Algorithm = Algorithm::Zstd(3);
24
25type Result<T> = std::result::Result<T, Error>;
26
27/// Redis 缓存封装,提供键值读写、分布式锁、计数器等常用缓存操作。
28pub struct RedisCache {
29    /// 缓存条目的默认过期时长
30    ttl: Duration,
31    /// 所有缓存键统一添加的前缀
32    prefix: String,
33    /// Redis 连接池
34    client: &'static RedisClient,
35}
36
37impl RedisCache {
38    #[inline]
39    pub async fn conn(&self) -> Result<RedisClientConn> {
40        self.client.conn().await
41    }
42
43    /// 创建新的 RedisCache 实例,默认 TTL 10 分钟,无前缀。
44    pub fn new(client: &'static RedisClient) -> Self {
45        Self {
46            ttl: Duration::from_secs(10 * 60),
47            prefix: String::new(),
48            client,
49        }
50    }
51
52    /// 设置缓存条目的过期时长,支持链式调用。
53    #[must_use]
54    pub fn with_ttl(mut self, ttl: Duration) -> Self {
55        self.ttl = ttl;
56        self
57    }
58
59    /// 设置所有缓存键的前缀,支持链式调用。
60    #[must_use]
61    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
62        self.prefix = prefix.into();
63        self
64    }
65
66    #[inline]
67    fn get_ttl(&self, ttl: Option<Duration>) -> u64 {
68        ttl.unwrap_or(self.ttl).as_secs()
69    }
70
71    /// 拼接前缀与键名,生成完整的缓存键。
72    /// 前缀为空时直接借用原始键,避免额外分配。
73    #[inline]
74    fn get_key<'a>(&'a self, key: &'a str) -> Cow<'a, str> {
75        if self.prefix.is_empty() {
76            Cow::Borrowed(key)
77        } else {
78            Cow::Owned(format!("{}{}", self.prefix, key))
79        }
80    }
81
82    /// 向 Redis 发送 PING 以检测连接是否正常。
83    pub async fn ping(&self) -> Result<()> {
84        let () = self
85            .conn()
86            .await?
87            .ping()
88            .await
89            .context(RedisSnafu { category: "ping" })?;
90        Ok(())
91    }
92
93    /// 从 Redis 读取原始值,类型由调用方通过泛型指定。
94    async fn get_value<T: redis::FromRedisValue>(&self, key: &str) -> Result<T> {
95        let result = self
96            .conn()
97            .await?
98            .get(key)
99            .await
100            .context(RedisSnafu { category: "get" })?;
101
102        Ok(result)
103    }
104
105    /// 向 Redis 写入原始值,并设置过期时间(秒)。
106    async fn set_value<T: redis::ToSingleRedisArg + Send + Sync>(
107        &self,
108        key: &str,
109        value: T,
110        ttl: u64,
111    ) -> Result<()> {
112        let () = self
113            .conn()
114            .await?
115            .set_ex(key, value, ttl)
116            .await
117            .context(RedisSnafu { category: "set" })?;
118        Ok(())
119    }
120
121    /// 尝试通过 SET NX 获取分布式锁。
122    /// 返回 `true` 表示加锁成功,`false` 表示锁已被持有。
123    pub async fn lock(&self, key: &str, ttl: Option<Duration>) -> Result<bool> {
124        let mut conn = self.conn().await?;
125
126        let result = cmd("SET")
127            .arg(self.get_key(key))
128            .arg(true)
129            .arg("NX")
130            .arg("EX")
131            .arg(self.get_ttl(ttl))
132            .query_async(&mut conn)
133            .await
134            .context(RedisSnafu { category: "lock" })?;
135        Ok(result)
136    }
137
138    /// 删除指定键。
139    pub async fn del(&self, key: &str) -> Result<()> {
140        let () = self
141            .conn()
142            .await?
143            .del(self.get_key(key))
144            .await
145            .context(RedisSnafu { category: "del" })?;
146
147        Ok(())
148    }
149
150    /// 原子性地将计数器累加 delta,返回累加后的值。
151    /// 键不存在时先用 SET NX 初始化为 0 再执行 INCRBY。
152    pub async fn incr(&self, key: &str, delta: i64, ttl: Option<Duration>) -> Result<i64> {
153        let mut conn = self.conn().await?;
154        let k = self.get_key(key);
155        // 这里的逻辑逻辑更加自然
156        let (count, _) = pipe()
157            .cmd("INCRBY")
158            .arg(&k)
159            .arg(delta) // 1. 先累加(不存在会自动创建,且无 TTL)
160            .cmd("EXPIRE")
161            .arg(&k)
162            .arg(self.get_ttl(ttl))
163            .arg("NX") // 2. 只有它没有 TTL 时(即刚创建时)才设 TTL
164            .query_async::<(i64, bool)>(&mut conn)
165            .await
166            .context(RedisSnafu { category: "incr" })?;
167        Ok(count)
168    }
169
170    /// 向 Redis 写入值,TTL 为 None 时使用实例默认值。
171    pub async fn set<T: redis::ToSingleRedisArg + Send + Sync>(
172        &self,
173        key: &str,
174        value: T,
175        ttl: Option<Duration>,
176    ) -> Result<()> {
177        self.set_value(&self.get_key(key), value, self.get_ttl(ttl))
178            .await
179    }
180
181    /// 从 Redis 读取值,类型由泛型参数指定。
182    pub async fn get<T: redis::FromRedisValue>(&self, key: &str) -> Result<T> {
183        self.get_value::<T>(&self.get_key(key)).await
184    }
185
186    /// 将结构体序列化为 JSON 后存入 Redis。
187    pub async fn set_struct<T>(&self, key: &str, value: &T, ttl: Option<Duration>) -> Result<()>
188    where
189        T: ?Sized + Serialize,
190    {
191        let value = serde_json::to_vec(&value).context(SerdeJsonSnafu)?;
192        self.set_value(&self.get_key(key), &value, self.get_ttl(ttl))
193            .await?;
194        Ok(())
195    }
196
197    /// 从 Redis 读取并反序列化为结构体,键不存在时返回 `None`。
198    pub async fn get_struct<T>(&self, key: &str) -> Result<Option<T>>
199    where
200        T: DeserializeOwned,
201    {
202        let buf: Option<Vec<u8>> = self.get_value(&self.get_key(key)).await?;
203        match buf {
204            None => Ok(None),
205            Some(b) => serde_json::from_slice(&b).context(SerdeJsonSnafu).map(Some),
206        }
207    }
208
209    /// 获取指定键的剩余过期时间(秒)。
210    /// 返回 -2 表示键不存在,-1 表示键无过期时间。
211    pub async fn ttl(&self, key: &str) -> Result<i32> {
212        let result = self
213            .conn()
214            .await?
215            .ttl(self.get_key(key))
216            .await
217            .context(RedisSnafu { category: "ttl" })?;
218
219        Ok(result)
220    }
221
222    /// 原子性地读取并删除指定键(需 Redis ≥6.2.0)。
223    pub async fn get_del<T: redis::FromRedisValue>(&self, key: &str) -> Result<T> {
224        let result = self
225            .conn()
226            .await?
227            .get_del(self.get_key(key))
228            .await
229            .context(RedisSnafu {
230                category: "get_del",
231            })?;
232
233        Ok(result)
234    }
235
236    /// 检查指定键是否存在。
237    pub async fn exists(&self, key: &str) -> Result<bool> {
238        let result = self
239            .conn()
240            .await?
241            .exists(self.get_key(key))
242            .await
243            .context(RedisSnafu { category: "exists" })?;
244        Ok(result)
245    }
246
247    /// 刷新指定键的过期时间而不修改其值。
248    /// 返回 `true` 表示刷新成功,`false` 表示键不存在。
249    pub async fn expire(&self, key: &str, ttl: Option<Duration>) -> Result<bool> {
250        let result = self
251            .conn()
252            .await?
253            .expire(self.get_key(key), self.get_ttl(ttl) as i64)
254            .await
255            .context(RedisSnafu { category: "expire" })?;
256        Ok(result)
257    }
258
259    async fn set_struct_compressed<T>(
260        &self,
261        key: &str,
262        value: &T,
263        ttl: u64,
264        algorithm: Algorithm,
265    ) -> Result<()>
266    where
267        T: ?Sized + Serialize,
268    {
269        let value = serde_json::to_vec(value).context(SerdeJsonSnafu)?;
270        let buf = compress(&value, algorithm).context(CompressionSnafu)?;
271        self.set_value(key, &buf, ttl).await
272    }
273
274    async fn get_struct_compressed<T>(&self, key: &str, algorithm: Algorithm) -> Result<Option<T>>
275    where
276        T: DeserializeOwned,
277    {
278        let value: Option<Vec<u8>> = self.get_value(&self.get_key(key)).await?;
279        match value {
280            None => Ok(None),
281            Some(compressed_buf) => {
282                let buf = decompress(&compressed_buf, algorithm).context(CompressionSnafu)?;
283                serde_json::from_slice(&buf)
284                    .context(SerdeJsonSnafu)
285                    .map(Some)
286            }
287        }
288    }
289
290    /// 将结构体序列化为 JSON 并以 LZ4 压缩后存入 Redis。
291    /// LZ4 压缩速度快,适合对延迟敏感的场景。
292    pub async fn set_struct_lz4<T>(&self, key: &str, value: &T, ttl: Option<Duration>) -> Result<()>
293    where
294        T: ?Sized + Serialize,
295    {
296        self.set_struct_compressed(&self.get_key(key), value, self.get_ttl(ttl), Algorithm::Lz4)
297            .await
298    }
299
300    /// 从 Redis 读取并以 LZ4 解压后反序列化为结构体,键不存在时返回 `None`。
301    pub async fn get_struct_lz4<T>(&self, key: &str) -> Result<Option<T>>
302    where
303        T: DeserializeOwned,
304    {
305        self.get_struct_compressed(key, Algorithm::Lz4).await
306    }
307
308    /// 将结构体序列化为 JSON 并以 Zstd 压缩后存入 Redis。
309    /// Zstd 压缩率更高,适合对存储空间敏感的场景。
310    pub async fn set_struct_zstd<T>(
311        &self,
312        key: &str,
313        value: &T,
314        ttl: Option<Duration>,
315    ) -> Result<()>
316    where
317        T: ?Sized + Serialize,
318    {
319        self.set_struct_compressed(&self.get_key(key), value, self.get_ttl(ttl), DEFAULT_ZSTD)
320            .await
321    }
322
323    /// 从 Redis 读取并以 Zstd 解压后反序列化为结构体,键不存在时返回 `None`。
324    pub async fn get_struct_zstd<T>(&self, key: &str) -> Result<Option<T>>
325    where
326        T: DeserializeOwned,
327    {
328        self.get_struct_compressed(key, DEFAULT_ZSTD).await
329    }
330}