use std::time::Duration;
use redis::{FromRedisValue, ToRedisArgs, aio::MultiplexedConnection};
use sha1_smol::Sha1;
use crate::cache_redis::{RedisCacheError, RedisCacheResult};
/// A Lua script paired with the SHA-1 digest of its source text.
///
/// Both fields are private; the inherent `text` and `sha1` methods expose them
/// read-only. Equality and cloning cover both fields.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RedisLuaScript {
    /// Lua source exactly as supplied to the constructor.
    text: String,
    /// String form of the SHA-1 digest of `text`; sent as the `EVALSHA` script id.
    sha1: String,
}
impl RedisLuaScript {
pub fn new(text: impl Into<String>) -> Self {
let text = text.into();
let mut hash = Sha1::new();
hash.update(text.as_bytes());
Self {
text,
sha1: hash.digest().to_string(),
}
}
pub fn text(&self) -> &str {
&self.text
}
pub fn sha1(&self) -> &str {
&self.sha1
}
pub fn is_noscript_error(error: &RedisCacheError) -> bool {
crate::cache_redis::classify_redis_error(error).is_noscript()
}
pub async fn invoke_async<T>(
&self,
connection: &mut MultiplexedConnection,
keys: &[String],
args: &[String],
command_timeout: Duration,
) -> RedisCacheResult<T>
where
T: FromRedisValue,
{
match self
.evalsha_async(connection, keys, args, command_timeout)
.await
{
Ok(value) => Ok(value),
Err(error) if Self::is_noscript_error(&error) => {
self.load_async(connection, command_timeout).await?;
self.evalsha_async(connection, keys, args, command_timeout)
.await
}
Err(error) => Err(error),
}
}
async fn load_async(
&self,
connection: &mut MultiplexedConnection,
command_timeout: Duration,
) -> RedisCacheResult<String> {
tokio::time::timeout(command_timeout, async {
redis::cmd("SCRIPT")
.arg("LOAD")
.arg(self.text.as_bytes())
.query_async::<String>(connection)
.await
})
.await
.map_err(|_| RedisCacheError::Timeout("SCRIPT LOAD".to_string()))?
.map_err(redis_error)
}
async fn evalsha_async<T>(
&self,
connection: &mut MultiplexedConnection,
keys: &[String],
args: &[String],
command_timeout: Duration,
) -> RedisCacheResult<T>
where
T: FromRedisValue,
{
let mut cmd = redis::cmd("EVALSHA");
cmd.arg(&self.sha1).arg(keys.len());
for key in keys {
cmd.arg(key);
}
for arg in args {
cmd.arg(arg);
}
tokio::time::timeout(command_timeout, async {
cmd.query_async::<T>(connection).await
})
.await
.map_err(|_| RedisCacheError::Timeout("EVALSHA".to_string()))?
.map_err(redis_error)
}
}
/// Converts a driver-level `redis::RedisError` into `RedisCacheError::Backend`.
///
/// Only the error's `Display` text is kept; the original error value is dropped.
fn redis_error(error: redis::RedisError) -> RedisCacheError {
    let message = error.to_string();
    RedisCacheError::Backend(message)
}
/// Categories of script-invocation outcome, each with a fixed string label.
///
/// NOTE(review): nothing in this file constructs these values, so the
/// per-variant meanings below are inferred from the names — confirm against
/// the callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RedisScriptEvent {
    /// Labelled `"cache_hit"`; presumably the script was already cached server-side.
    CacheHit,
    /// Labelled `"noscript"`; presumably the server reported the script as unknown.
    NoScript,
    /// Labelled `"loaded"`; presumably the script source was uploaded.
    Loaded,
}

impl RedisScriptEvent {
    /// Returns the fixed lowercase label for this event.
    pub fn outcome(self) -> &'static str {
        let label: &'static str = match self {
            RedisScriptEvent::CacheHit => "cache_hit",
            RedisScriptEvent::NoScript => "noscript",
            RedisScriptEvent::Loaded => "loaded",
        };
        label
    }
}
/// Collects the key names and argument values for one script call.
///
/// The default value has both lists empty.
#[derive(Debug, Clone, Default)]
pub struct RedisScriptInvocation {
    /// Key names, in insertion order.
    keys: Vec<String>,
    /// Argument values stored as strings, in insertion order.
    args: Vec<String>,
}
impl RedisScriptInvocation {
    /// Creates an invocation with no keys and no arguments.
    pub fn new() -> Self {
        Self {
            keys: Vec::new(),
            args: Vec::new(),
        }
    }

    /// Appends one key name and returns the updated invocation.
    pub fn key(mut self, key: impl Into<String>) -> Self {
        let owned: String = key.into();
        self.keys.push(owned);
        self
    }

    /// Appends the Redis encoding of `arg` and returns the updated invocation.
    ///
    /// A value that encodes to several Redis arguments adds one entry per
    /// encoded part.
    ///
    /// Each encoded part is converted with `String::from_utf8_lossy`, so byte
    /// sequences that are not valid UTF-8 are replaced with U+FFFD. Binary
    /// payloads therefore do not survive unchanged.
    pub fn arg<T: ToRedisArgs>(mut self, arg: T) -> Self {
        for encoded in arg.to_redis_args() {
            let as_text = String::from_utf8_lossy(&encoded).into_owned();
            self.args.push(as_text);
        }
        self
    }

    /// Returns the key names collected so far, in insertion order.
    pub fn keys(&self) -> &[String] {
        self.keys.as_slice()
    }

    /// Returns the argument strings collected so far, in insertion order.
    pub fn args(&self) -> &[String] {
        self.args.as_slice()
    }
}