use crate::core::{
error::{RedisError, RedisResult},
value::RespValue,
};
use sha1::{Digest, Sha1};
use std::collections::HashMap;
use std::convert::TryFrom;
use std::sync::Arc;
use tokio::sync::RwLock;
/// A Lua script paired with the SHA-1 digest Redis uses to refer to it.
///
/// Equality and hashing compare both the source and the digest, so two
/// scripts built from the same source compare equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Script {
    /// Lua source text, sent in full with `EVAL` / `SCRIPT LOAD`.
    source: String,
    /// Hex-encoded SHA-1 digest of `source`, sent with `EVALSHA`.
    sha: String,
}
impl Script {
    /// Wraps Lua `source` and precomputes its SHA-1 digest for `EVALSHA`.
    pub fn new(source: impl Into<String>) -> Self {
        let source: String = source.into();
        Self {
            sha: calculate_sha1(&source),
            source,
        }
    }

    /// Hex-encoded SHA-1 digest of the script source.
    #[must_use]
    pub fn sha(&self) -> &str {
        self.sha.as_str()
    }

    /// The Lua source text.
    #[must_use]
    pub fn source(&self) -> &str {
        self.source.as_str()
    }

    /// Runs the script on `client`, trying `EVALSHA` first.
    ///
    /// If `EVALSHA` fails with a `RedisError::Protocol` whose message contains
    /// `NOSCRIPT`, the full source is re-sent with `EVAL`. `keys` and `args`
    /// are cloned for the first attempt so that they are still available for
    /// the fallback.
    ///
    /// # Errors
    /// Any other `EVALSHA` error is returned unchanged, as is any error from
    /// the `EVAL` fallback.
    pub async fn execute<T>(
        &self,
        client: &crate::Client,
        keys: Vec<String>,
        args: Vec<String>,
    ) -> RedisResult<T>
    where
        T: TryFrom<RespValue>,
        T::Error: Into<RedisError>,
    {
        let by_digest = client.evalsha(&self.sha, keys.clone(), args.clone()).await;
        match by_digest {
            Err(RedisError::Protocol(msg)) if msg.contains("NOSCRIPT") => {
                client.eval(&self.source, keys, args).await
            }
            other => other,
        }
    }

    /// Uploads the source with `SCRIPT LOAD` and returns the digest the server reports.
    ///
    /// # Errors
    /// Propagates any error returned by the client.
    pub async fn load(&self, client: &crate::Client) -> RedisResult<String> {
        let source = &self.source;
        client.script_load(source).await
    }
}
/// A named registry of [`Script`]s behind an async read-write lock.
///
/// Cloning a `ScriptManager` is cheap and yields a handle to the *same*
/// registry. Scripts registered through one clone are visible through all
/// the others.
#[derive(Debug, Clone)]
pub struct ScriptManager {
    /// Script name -> script. The `Arc` is what lets clones share one map.
    scripts: Arc<RwLock<HashMap<String, Script>>>,
}
impl ScriptManager {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self {
            scripts: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Registers `script` under `name`.
    ///
    /// If a script is already registered under `name`, it is replaced.
    pub async fn register(&self, name: impl Into<String>, script: Script) {
        let mut scripts = self.scripts.write().await;
        scripts.insert(name.into(), script);
    }

    /// Returns a clone of the script registered under `name`, if any.
    ///
    /// A clone is returned so the registry lock is released before the
    /// caller uses the script.
    pub async fn get(&self, name: &str) -> Option<Script> {
        let scripts = self.scripts.read().await;
        scripts.get(name).cloned()
    }

    /// Runs the script registered under `name` via [`Script::execute`].
    ///
    /// # Errors
    /// Returns `RedisError::Protocol` if no script is registered under `name`.
    /// Otherwise returns whatever [`Script::execute`] returns.
    pub async fn execute<T>(
        &self,
        name: &str,
        client: &crate::Client,
        keys: Vec<String>,
        args: Vec<String>,
    ) -> RedisResult<T>
    where
        T: TryFrom<RespValue>,
        T::Error: Into<RedisError>,
    {
        // `get` hands back a clone, so no lock is held during network I/O.
        let script = self
            .get(name)
            .await
            .ok_or_else(|| RedisError::Protocol(format!("Script '{}' not found", name)))?;
        script.execute(client, keys, args).await
    }

    /// Loads every registered script into the server with `SCRIPT LOAD`.
    ///
    /// Returns a map from script name to the digest reported by the server.
    /// The registry is snapshotted first, so scripts registered or removed
    /// while this call is in flight are not reflected in the result.
    ///
    /// # Errors
    /// Stops at the first load failure and returns that error. Scripts loaded
    /// before the failure are not reported.
    pub async fn load_all(&self, client: &crate::Client) -> RedisResult<HashMap<String, String>> {
        // Copy the entries and drop the read guard before any network round
        // trip. Holding the guard across the awaits would block
        // `register`/`remove`/`clear` for the whole duration.
        let snapshot: Vec<(String, Script)> = {
            let scripts = self.scripts.read().await;
            scripts
                .iter()
                .map(|(name, script)| (name.clone(), script.clone()))
                .collect()
        };

        let mut results = HashMap::with_capacity(snapshot.len());
        for (name, script) in snapshot {
            let sha = script.load(client).await?;
            results.insert(name, sha);
        }
        Ok(results)
    }

    /// Number of registered scripts.
    #[must_use]
    pub async fn len(&self) -> usize {
        let scripts = self.scripts.read().await;
        scripts.len()
    }

    /// Whether no scripts are registered.
    #[must_use]
    pub async fn is_empty(&self) -> bool {
        let scripts = self.scripts.read().await;
        scripts.is_empty()
    }

    /// Names of all registered scripts, in unspecified (hash-map) order.
    pub async fn list_scripts(&self) -> Vec<String> {
        let scripts = self.scripts.read().await;
        scripts.keys().cloned().collect()
    }

    /// Removes and returns the script registered under `name`, if any.
    pub async fn remove(&self, name: &str) -> Option<Script> {
        let mut scripts = self.scripts.write().await;
        scripts.remove(name)
    }

    /// Removes every registered script.
    pub async fn clear(&self) {
        let mut scripts = self.scripts.write().await;
        scripts.clear();
    }
}
impl Default for ScriptManager {
fn default() -> Self {
Self::new()
}
}
/// Returns the hex-encoded SHA-1 digest of `input`'s UTF-8 bytes.
///
/// This digest is the identifier passed to `EVALSHA`.
fn calculate_sha1(input: &str) -> String {
    let digest = Sha1::digest(input.as_bytes());
    hex::encode(digest)
}
/// Ready-made Lua scripts for common atomic Redis patterns.
pub mod patterns {
    use super::Script;

    /// Adds `ARGV[1]` to the numeric value at `KEYS[1]` and (re)sets its TTL.
    ///
    /// - `KEYS[1]`: counter key. A missing key is treated as starting from the increment.
    /// - `ARGV[1]`: increment (parsed with Lua `tonumber`).
    /// - `ARGV[2]`: expiration in seconds, applied with `EXPIRE` on every call.
    ///
    /// Returns the new value.
    #[must_use]
    pub fn atomic_increment_with_expiration() -> Script {
        Script::new(
            r"
local key = KEYS[1]
local increment = tonumber(ARGV[1])
local expiration = tonumber(ARGV[2])
local current = redis.call('GET', key)
local new_value
if current == false then
new_value = increment
else
new_value = tonumber(current) + increment
end
redis.call('SET', key, new_value)
redis.call('EXPIRE', key, expiration)
return new_value
",
        )
    }

    /// Compare-and-set: writes `ARGV[2]` to `KEYS[1]` only if its current value equals `ARGV[1]`.
    ///
    /// Returns `1` if the value was written, `0` otherwise. A missing key never
    /// matches, because `GET` yields `false` and never equals a string argument.
    #[must_use]
    pub fn conditional_set() -> Script {
        Script::new(
            r"
local key = KEYS[1]
local expected = ARGV[1]
local new_value = ARGV[2]
local current = redis.call('GET', key)
if current == expected then
redis.call('SET', key, new_value)
return 1
else
return 0
end
",
        )
    }

    /// Sliding-window rate limiter backed by a sorted set at `KEYS[1]`.
    ///
    /// - `ARGV[1]`: window length in seconds.
    /// - `ARGV[2]`: maximum number of requests admitted per window.
    ///
    /// Returns `{1, remaining}` when the request is admitted and `{0, 0}` when
    /// it is rejected. Each admitted request is stored under a unique member,
    /// so bursts that share a timestamp are all counted.
    #[must_use]
    pub fn sliding_window_rate_limit() -> Script {
        Script::new(
            r#"
local key = KEYS[1]
local window = tonumber(ARGV[1])
local limit = tonumber(ARGV[2])
-- TIME returns {seconds, microseconds}; combine them so the window slides
-- with sub-second precision instead of jumping once per second.
local clock = redis.call('TIME')
local now = tonumber(clock[1]) + tonumber(clock[2]) / 1000000
-- Remove entries that have fallen out of the window
redis.call('ZREMRANGEBYSCORE', key, '-inf', now - window)
-- Count current entries
local current = redis.call('ZCARD', key)
if current < limit then
-- Each request needs a distinct member. With a timestamp-only member,
-- requests sharing a timestamp collapse into one entry and are undercounted.
-- Appending the current count keeps members unique even within the same
-- microsecond, because each admitted request increases the count.
local member = clock[1] .. '-' .. clock[2] .. '-' .. current
redis.call('ZADD', key, now, member)
-- EXPIRE only accepts whole seconds
redis.call('EXPIRE', key, math.ceil(window))
return { 1, limit - current - 1 }
else
return { 0, 0 }
end
"#,
        )
    }

    /// Acquires a lock by setting `KEYS[1]` to `ARGV[1]` only if it is absent.
    ///
    /// - `ARGV[1]`: caller-chosen lock identifier (used later by [`release_lock`]).
    /// - `ARGV[2]`: lock TTL in seconds (`SET ... NX EX`).
    ///
    /// Returns `1` if the lock was acquired, `0` otherwise.
    #[must_use]
    pub fn distributed_lock() -> Script {
        Script::new(
            r#"
local key = KEYS[1]
local identifier = ARGV[1]
local expiration = tonumber(ARGV[2])
if redis.call('SET', key, identifier, 'NX', 'EX', expiration) then
return 1
else
return 0
end
"#,
        )
    }

    /// Releases a lock at `KEYS[1]` only if it is still held by `ARGV[1]`.
    ///
    /// Returns the `DEL` result when the identifier matches, `0` otherwise.
    #[must_use]
    pub fn release_lock() -> Script {
        Script::new(
            r#"
local key = KEYS[1]
local identifier = ARGV[1]
if redis.call('GET', key) == identifier then
return redis.call('DEL', key)
else
return 0
end
"#,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_script_creation() {
        let script = Script::new("return 'hello'");
        assert_eq!(script.source(), "return 'hello'");
        assert!(!script.sha().is_empty());
        // A SHA-1 digest is 20 bytes, i.e. 40 hex characters.
        assert_eq!(script.sha().len(), 40);
    }

    #[test]
    fn test_script_new_accepts_owned_string() {
        let script = Script::new(String::from("return 1"));
        assert_eq!(script.source(), "return 1");
        assert_eq!(script.sha(), calculate_sha1("return 1"));
    }

    #[test]
    fn test_sha1_calculation() {
        let sha = calculate_sha1("hello world");
        assert_eq!(sha, "2aae6c35c94fcfb415dbe95f408b9ce91ee846ed");
    }

    #[test]
    fn test_sha1_empty_input() {
        // Well-known SHA-1 digest of the empty string.
        assert_eq!(calculate_sha1(""), "da39a3ee5e6b4b0d3255bfef95601890afd80709");
    }

    #[test]
    fn test_sha1_is_lowercase_hex() {
        let sha = calculate_sha1("return redis.call('PING')");
        assert!(sha
            .chars()
            .all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
    }

    #[test]
    fn test_script_sha_consistency() {
        let script1 = Script::new("return 1");
        let script2 = Script::new("return 1");
        assert_eq!(script1.sha(), script2.sha());
    }

    #[test]
    fn test_script_sha_uniqueness() {
        let script1 = Script::new("return 1");
        let script2 = Script::new("return 2");
        assert_ne!(script1.sha(), script2.sha());
    }

    #[tokio::test]
    async fn test_script_manager_creation() {
        let manager = ScriptManager::new();
        assert!(manager.is_empty().await);
        assert_eq!(manager.len().await, 0);
    }

    #[tokio::test]
    async fn test_script_manager_default_is_empty() {
        let manager = ScriptManager::default();
        assert!(manager.is_empty().await);
        assert!(manager.list_scripts().await.is_empty());
    }

    #[tokio::test]
    async fn test_script_manager_register_and_get() {
        let manager = ScriptManager::new();
        let script = Script::new("return 'test'");
        let sha = script.sha().to_string();
        manager.register("test_script", script).await;
        assert!(!manager.is_empty().await);
        assert_eq!(manager.len().await, 1);
        let retrieved = manager.get("test_script").await.unwrap();
        assert_eq!(retrieved.sha(), sha);
        assert_eq!(retrieved.source(), "return 'test'");
    }

    #[tokio::test]
    async fn test_script_manager_get_missing() {
        let manager = ScriptManager::new();
        assert!(manager.get("missing").await.is_none());
    }

    #[tokio::test]
    async fn test_script_manager_register_replaces_existing() {
        let manager = ScriptManager::new();
        manager.register("s", Script::new("return 1")).await;
        manager.register("s", Script::new("return 2")).await;
        assert_eq!(manager.len().await, 1);
        let script = manager.get("s").await.expect("script was registered");
        assert_eq!(script.source(), "return 2");
    }

    #[tokio::test]
    async fn test_script_manager_remove() {
        let manager = ScriptManager::new();
        let script = Script::new("return 'test'");
        manager.register("test_script", script).await;
        assert_eq!(manager.len().await, 1);
        let removed = manager.remove("test_script").await;
        assert!(removed.is_some());
        assert_eq!(manager.len().await, 0);
        let not_found = manager.remove("nonexistent").await;
        assert!(not_found.is_none());
    }

    #[tokio::test]
    async fn test_script_manager_clear() {
        let manager = ScriptManager::new();
        manager.register("script1", Script::new("return 1")).await;
        manager.register("script2", Script::new("return 2")).await;
        assert_eq!(manager.len().await, 2);
        manager.clear().await;
        assert_eq!(manager.len().await, 0);
        assert!(manager.is_empty().await);
    }

    #[tokio::test]
    async fn test_script_manager_list_scripts() {
        let manager = ScriptManager::new();
        manager
            .register("script_a", Script::new("return 'a'"))
            .await;
        manager
            .register("script_b", Script::new("return 'b'"))
            .await;
        let mut scripts = manager.list_scripts().await;
        scripts.sort();
        assert_eq!(scripts, vec!["script_a", "script_b"]);
    }

    #[test]
    fn test_pattern_scripts() {
        let scripts = [
            patterns::atomic_increment_with_expiration(),
            patterns::conditional_set(),
            patterns::sliding_window_rate_limit(),
            patterns::distributed_lock(),
            patterns::release_lock(),
        ];
        for script in &scripts {
            assert!(!script.source().trim().is_empty());
            assert_eq!(script.sha(), calculate_sha1(script.source()));
        }
        // Every pattern must have a distinct digest, or EVALSHA would run the wrong script.
        let distinct: HashSet<&str> = scripts.iter().map(|s| s.sha()).collect();
        assert_eq!(distinct.len(), scripts.len());
    }
}