use crate::{DbError, RedisClient, RedisValue, Result};
use redis::FromRedisValue;
/// Builder for a Redis `MULTI`/`EXEC` transaction.
///
/// Commands are queued on an atomic pipeline and sent together when the
/// transaction is executed; any registered keys are `WATCH`ed before sending.
pub struct RedisTransaction {
    /// Keys registered through `watch`; sent with `WATCH` before each attempt.
    watched_keys: Vec<String>,
    /// Atomic pipeline that accumulates the queued commands.
    pipe: redis::Pipeline,
    /// Client whose connection pool provides the connection used to execute.
    client: RedisClient,
}
impl RedisTransaction {
    /// Creates an empty transaction bound to `client`.
    ///
    /// The underlying pipeline is put into atomic mode, so everything queued on
    /// it is sent wrapped in `MULTI` / `EXEC`.
    pub fn new(client: RedisClient) -> Self {
        let mut pipe = redis::pipe();
        pipe.atomic();
        Self {
            client,
            pipe,
            watched_keys: Vec::new(),
        }
    }

    /// Adds `keys` to the list of keys sent with `WATCH` before each `EXEC`
    /// attempt.
    ///
    /// When at least one key is watched, [`exec`](Self::exec) retries the
    /// transaction if Redis aborts it because a watched key was modified.
    pub fn watch(&mut self, keys: &[String]) -> &mut Self {
        self.watched_keys.extend_from_slice(keys);
        self
    }

    /// Queues a `SET key value`.
    pub fn set(&mut self, key: impl Into<String>, value: impl Into<String>) -> &mut Self {
        self.pipe.set(key.into(), value.into());
        self
    }

    /// Queues a `GET key`.
    pub fn get(&mut self, key: impl Into<String>) -> &mut Self {
        self.pipe.get(key.into());
        self
    }

    /// Queues a single `DEL` covering all of `keys`.
    ///
    /// An empty slice queues nothing (and so contributes no reply): a `DEL`
    /// without arguments is rejected by Redis at queue time, which would abort
    /// the entire transaction with `EXECABORT`.
    pub fn del(&mut self, keys: &[String]) -> &mut Self {
        if !keys.is_empty() {
            self.pipe.del(keys);
        }
        self
    }

    /// Queues an increment of `key` by 1.
    pub fn incr(&mut self, key: impl Into<String>) -> &mut Self {
        self.pipe.incr(key.into(), 1);
        self
    }

    /// Queues a decrement of `key` by `decrement`.
    pub fn decrby(&mut self, key: impl Into<String>, decrement: i64) -> &mut Self {
        self.pipe.decr(key.into(), decrement);
        self
    }

    /// Queues an `HSET key field value`.
    pub fn hset(
        &mut self,
        key: impl Into<String>,
        field: impl Into<String>,
        value: impl Into<String>,
    ) -> &mut Self {
        self.pipe.hset(key.into(), field.into(), value.into());
        self
    }

    /// Queues an `HGET key field`.
    pub fn hget(&mut self, key: impl Into<String>, field: impl Into<String>) -> &mut Self {
        self.pipe.hget(key.into(), field.into());
        self
    }

    /// Queues one `LPUSH` per element of `values`, in slice order.
    ///
    /// Each element is a separate command and therefore contributes its own
    /// reply to the transaction result. An empty slice queues nothing.
    pub fn lpush(&mut self, key: impl Into<String>, values: &[String]) -> &mut Self {
        let key_str = key.into();
        for value in values {
            self.pipe.lpush(&key_str, value);
        }
        self
    }

    /// Queues one `RPUSH` per element of `values`, in slice order.
    ///
    /// Each element is a separate command and therefore contributes its own
    /// reply to the transaction result. An empty slice queues nothing.
    pub fn rpush(&mut self, key: impl Into<String>, values: &[String]) -> &mut Self {
        let key_str = key.into();
        for value in values {
            self.pipe.rpush(&key_str, value);
        }
        self
    }

    /// Queues one `SADD` per element of `members`, in slice order.
    ///
    /// Each member is a separate command and therefore contributes its own
    /// reply to the transaction result. An empty slice queues nothing.
    pub fn sadd(&mut self, key: impl Into<String>, members: &[String]) -> &mut Self {
        let key_str = key.into();
        for member in members {
            self.pipe.sadd(&key_str, member);
        }
        self
    }

    /// Queues one `ZADD` per `(score, member)` pair, in slice order.
    ///
    /// Each pair is a separate command and therefore contributes its own reply
    /// to the transaction result. An empty slice queues nothing.
    pub fn zadd(&mut self, key: impl Into<String>, members: &[(f64, String)]) -> &mut Self {
        let key_str = key.into();
        for (score, member) in members {
            self.pipe.zadd(&key_str, member, *score);
        }
        self
    }

    /// Appends an arbitrary, already-built command to the transaction.
    pub fn cmd(&mut self, cmd: redis::Cmd) -> &mut Self {
        self.pipe.add_command(cmd);
        self
    }

    /// Sends the queued commands as a `MULTI`/`EXEC` transaction and decodes
    /// the `EXEC` reply into `T`.
    ///
    /// If keys were registered with [`watch`](Self::watch), they are
    /// re-`WATCH`ed before every attempt. When Redis aborts the transaction
    /// because a watched key changed (`EXEC` replies nil), the same queued
    /// commands are resent, for at most `MAX_RETRIES` aborted attempts. The
    /// queued commands are not rebuilt between attempts.
    ///
    /// # Errors
    ///
    /// - `DbError::RedisPoolError` if no pooled connection can be obtained.
    /// - `DbError::RedisCommandError` if `WATCH` fails; if the transaction
    ///   itself fails (including `EXECABORT`, which is not retried because
    ///   resending identical commands fails identically); if the reply cannot
    ///   be decoded into `T`; or if the retry budget is exhausted.
    pub async fn exec<T: FromRedisValue>(self) -> Result<T> {
        const MAX_RETRIES: usize = 100;

        let mut conn = self
            .client
            .pool()
            .get()
            .await
            .map_err(|e| DbError::RedisPoolError(format!("获取连接失败: {}", e)))?;

        if self.is_empty() {
            // Nothing is queued, so there is nothing for WATCH to guard. Skipping
            // it avoids returning the connection to the pool with keys still
            // being watched.
            return self
                .pipe
                .query_async::<T>(&mut *conn)
                .await
                .map_err(|e| DbError::RedisCommandError(format!("事务执行失败: {}", e)));
        }

        let mut aborted_attempts = 0;
        loop {
            if !self.watched_keys.is_empty() {
                let mut watch_cmd = redis::cmd("WATCH");
                for key in &self.watched_keys {
                    watch_cmd.arg(key);
                }
                watch_cmd
                    .query_async::<()>(&mut *conn)
                    .await
                    .map_err(|e| DbError::RedisCommandError(format!("WATCH 命令失败: {}", e)))?;
            }

            // Decode as `Option<T>` so an aborted transaction (nil `EXEC` reply)
            // surfaces as `None`. Decoding straight into `T` may let types such as
            // `Vec<_>` silently turn the abort into an empty "successful" result.
            let reply = self
                .pipe
                .query_async::<Option<T>>(&mut *conn)
                .await
                .map_err(|e| DbError::RedisCommandError(format!("事务执行失败: {}", e)))?;

            match reply {
                Some(result) => return Ok(result),
                None if self.watched_keys.is_empty() => {
                    // Without WATCH there is no optimistic lock to retry against.
                    return Err(DbError::RedisCommandError(
                        "事务执行失败: EXEC 返回 nil".to_string(),
                    ));
                }
                None => {
                    aborted_attempts += 1;
                    if aborted_attempts >= MAX_RETRIES {
                        return Err(DbError::RedisCommandError(format!(
                            "事务执行失败,已重试 {} 次: 被监视的键在事务期间被修改",
                            MAX_RETRIES
                        )));
                    }
                }
            }
        }
    }

    /// Executes the transaction and converts each per-command reply into a
    /// `RedisValue`.
    ///
    /// # Errors
    ///
    /// Same as [`exec`](Self::exec).
    pub async fn execute(self) -> Result<Vec<RedisValue>> {
        let results: Vec<redis::Value> = self.exec().await?;
        Ok(results.into_iter().map(RedisValue::from).collect())
    }

    /// Returns the number of commands queued so far.
    pub fn len(&self) -> usize {
        self.pipe.cmd_iter().count()
    }

    /// Returns `true` when no command has been queued.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[cfg(test)]
mod tests {
// NOTE(review): placeholder test — the body is empty and asserts nothing.
// Building a `RedisTransaction` requires a `RedisClient`, whose construction is
// not visible in this file. TODO: once a client can be created in tests, assert
// on `len()` / `is_empty()` after queueing commands.
#[test]
fn test_transaction_creation() {
}
}