redis_oxide/script.rs
1//! Lua scripting support for Redis
2//!
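//! This module provides functionality for executing Lua scripts on Redis servers
//! using the `EVAL` and `EVALSHA` commands. A [`Script`] pre-computes the SHA1
//! digest of its source so it can be run with `EVALSHA`, falling back to `EVAL`
//! when the server has not cached it yet; [`ScriptManager`] keeps named scripts
//! together and can preload them all with `SCRIPT LOAD`.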
6//!
7//! # Examples
8//!
9//! ## Basic Script Execution
10//!
11//! ```no_run
12//! use redis_oxide::{Client, ConnectionConfig};
13//!
14//! # #[tokio::main]
15//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
16//! let config = ConnectionConfig::new("redis://localhost:6379");
17//! let client = Client::connect(config).await?;
18//!
19//! // Execute a simple Lua script
20//! let script = "return redis.call('GET', KEYS[1])";
21//! let result: Option<String> = client.eval(script, vec!["mykey".to_string()], vec![]).await?;
22//! println!("Result: {:?}", result);
23//! # Ok(())
24//! # }
25//! ```
26//!
27//! ## Script with Arguments
28//!
29//! ```no_run
30//! use redis_oxide::{Client, ConnectionConfig};
31//!
32//! # #[tokio::main]
33//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
34//! let config = ConnectionConfig::new("redis://localhost:6379");
35//! let client = Client::connect(config).await?;
36//!
37//! // Script that increments a counter by a given amount
38//! let script = r#"
39//! local current = redis.call('GET', KEYS[1]) or 0
40//! local increment = tonumber(ARGV[1])
41//! local new_value = tonumber(current) + increment
42//! redis.call('SET', KEYS[1], new_value)
43//! return new_value
44//! "#;
45//!
46//! let result: i64 = client.eval(
47//! script,
48//! vec!["counter".to_string()],
49//! vec!["5".to_string()]
50//! ).await?;
51//! println!("New counter value: {}", result);
52//! # Ok(())
53//! # }
54//! ```
55//!
//! ## Reusing a Script with Automatic Caching
57//!
58//! ```no_run
59//! use redis_oxide::{Client, ConnectionConfig, Script};
60//!
61//! # #[tokio::main]
62//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
63//! let config = ConnectionConfig::new("redis://localhost:6379");
64//! let client = Client::connect(config).await?;
65//!
66//! // Create a reusable script
67//! let script = Script::new(r#"
68//! local key = KEYS[1]
69//! local value = ARGV[1]
70//! redis.call('SET', key, value)
71//! return redis.call('GET', key)
72//! "#);
73//!
74//! // Execute the script (automatically uses EVALSHA if cached)
75//! let result: String = script.execute(
76//! &client,
77//! vec!["mykey".to_string()],
78//! vec!["myvalue".to_string()]
79//! ).await?;
80//! println!("Result: {}", result);
81//! # Ok(())
82//! # }
83//! ```
84
85use crate::core::{
86 error::{RedisError, RedisResult},
87 value::RespValue,
88};
89use sha1::{Digest, Sha1};
90use std::collections::HashMap;
91use std::convert::TryFrom;
92use std::sync::Arc;
93use tokio::sync::RwLock;
94
/// A Lua script that can be executed on Redis
///
/// The SHA1 digest is computed once from the source so the script can be
/// invoked with `EVALSHA`. Equality and hashing compare both the source and
/// its digest, so scripts can be deduplicated or used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Script {
    /// The Lua script source code
    source: String,
    /// Hex-encoded SHA1 hash of `source` (for EVALSHA)
    sha: String,
}
103
104impl Script {
105 /// Create a new script from Lua source code
106 ///
107 /// The script is automatically hashed for use with EVALSHA.
108 ///
109 /// # Examples
110 ///
111 /// ```
112 /// use redis_oxide::Script;
113 ///
114 /// let script = Script::new("return redis.call('GET', KEYS[1])");
115 /// println!("Script SHA: {}", script.sha());
116 /// ```
117 pub fn new(source: impl Into<String>) -> Self {
118 let source = source.into();
119 let sha = calculate_sha1(&source);
120
121 Self { source, sha }
122 }
123
124 /// Get the SHA1 hash of the script
125 #[must_use]
126 pub fn sha(&self) -> &str {
127 &self.sha
128 }
129
130 /// Get the source code of the script
131 #[must_use]
132 pub fn source(&self) -> &str {
133 &self.source
134 }
135
136 /// Execute the script on the given client
137 ///
138 /// This method will first try to use EVALSHA (if the script is cached on the server),
139 /// and fall back to EVAL if the script is not cached.
140 ///
141 /// # Arguments
142 ///
143 /// * `client` - The Redis client to execute the script on
144 /// * `keys` - List of Redis keys that the script will access (KEYS array in Lua)
145 /// * `args` - List of arguments to pass to the script (ARGV array in Lua)
146 ///
147 /// # Examples
148 ///
149 /// ```no_run
150 /// # use redis_oxide::{Client, ConnectionConfig, Script};
151 /// # #[tokio::main]
152 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
153 /// # let config = ConnectionConfig::new("redis://localhost:6379");
154 /// # let client = Client::connect(config).await?;
155 /// let script = Script::new("return KEYS[1] .. ':' .. ARGV[1]");
156 ///
157 /// let result: String = script.execute(
158 /// &client,
159 /// vec!["user".to_string()],
160 /// vec!["123".to_string()]
161 /// ).await?;
162 ///
163 /// assert_eq!(result, "user:123");
164 /// # Ok(())
165 /// # }
166 /// ```
167 pub async fn execute<T>(
168 &self,
169 client: &crate::Client,
170 keys: Vec<String>,
171 args: Vec<String>,
172 ) -> RedisResult<T>
173 where
174 T: TryFrom<RespValue>,
175 T::Error: Into<RedisError>,
176 {
177 // First try EVALSHA
178 match client.evalsha(&self.sha, keys.clone(), args.clone()).await {
179 Ok(result) => Ok(result),
180 Err(RedisError::Protocol(msg)) if msg.contains("NOSCRIPT") => {
181 // Script not cached, use EVAL
182 client.eval(&self.source, keys, args).await
183 }
184 Err(e) => Err(e),
185 }
186 }
187
188 /// Load the script into Redis cache
189 ///
190 /// This sends the script to Redis using SCRIPT LOAD, which caches it
191 /// for future EVALSHA calls.
192 ///
193 /// # Examples
194 ///
195 /// ```no_run
196 /// # use redis_oxide::{Client, ConnectionConfig, Script};
197 /// # #[tokio::main]
198 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
199 /// # let config = ConnectionConfig::new("redis://localhost:6379");
200 /// # let client = Client::connect(config).await?;
201 /// let script = Script::new("return 'Hello, World!'");
202 ///
203 /// // Preload the script
204 /// let sha = script.load(&client).await?;
205 /// println!("Script loaded with SHA: {}", sha);
206 /// # Ok(())
207 /// # }
208 /// ```
209 pub async fn load(&self, client: &crate::Client) -> RedisResult<String> {
210 client.script_load(&self.source).await
211 }
212}
213
/// Script manager for caching and managing multiple scripts
///
/// Scripts are stored by name in a map guarded by an async `tokio::sync::RwLock`,
/// so every method takes `&self` and reads/writes go through the lock.
// NOTE(review): the `Arc` only pays off if the manager is shared by cloning, but
// `ScriptManager` does not derive `Clone` — confirm whether that was intended.
#[derive(Debug)]
pub struct ScriptManager {
    /// Registered scripts, keyed by the name given to `register`.
    scripts: Arc<RwLock<HashMap<String, Script>>>,
}
219
220impl ScriptManager {
221 /// Create a new script manager
222 #[must_use]
223 pub fn new() -> Self {
224 Self {
225 scripts: Arc::new(RwLock::new(HashMap::new())),
226 }
227 }
228
229 /// Register a script with the manager
230 ///
231 /// # Examples
232 ///
233 /// ```no_run
234 /// use redis_oxide::{Script, ScriptManager};
235 ///
236 /// # #[tokio::main]
237 /// # async fn main() {
238 /// let mut manager = ScriptManager::new();
239 /// let script = Script::new("return 'Hello'");
240 ///
241 /// manager.register("greeting", script).await;
242 /// # }
243 /// ```
244 pub async fn register(&self, name: impl Into<String>, script: Script) {
245 let mut scripts = self.scripts.write().await;
246 scripts.insert(name.into(), script);
247 }
248
249 /// Get a script by name
250 ///
251 /// # Examples
252 ///
253 /// ```
254 /// # use redis_oxide::{Script, ScriptManager};
255 /// # #[tokio::main]
256 /// # async fn main() {
257 /// let manager = ScriptManager::new();
258 /// let script = Script::new("return 'Hello'");
259 /// manager.register("greeting", script).await;
260 ///
261 /// if let Some(script) = manager.get("greeting").await {
262 /// println!("Found script with SHA: {}", script.sha());
263 /// }
264 /// # }
265 /// ```
266 pub async fn get(&self, name: &str) -> Option<Script> {
267 let scripts = self.scripts.read().await;
268 scripts.get(name).cloned()
269 }
270
271 /// Execute a script by name
272 ///
273 /// # Examples
274 ///
275 /// ```no_run
276 /// # use redis_oxide::{Client, ConnectionConfig, Script, ScriptManager};
277 /// # #[tokio::main]
278 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
279 /// # let config = ConnectionConfig::new("redis://localhost:6379");
280 /// # let client = Client::connect(config).await?;
281 /// let manager = ScriptManager::new();
282 /// let script = Script::new("return KEYS[1]");
283 /// manager.register("get_key", script).await;
284 ///
285 /// let result: String = manager.execute(
286 /// "get_key",
287 /// &client,
288 /// vec!["mykey".to_string()],
289 /// vec![]
290 /// ).await?;
291 /// # Ok(())
292 /// # }
293 /// ```
294 pub async fn execute<T>(
295 &self,
296 name: &str,
297 client: &crate::Client,
298 keys: Vec<String>,
299 args: Vec<String>,
300 ) -> RedisResult<T>
301 where
302 T: TryFrom<RespValue>,
303 T::Error: Into<RedisError>,
304 {
305 let script = self
306 .get(name)
307 .await
308 .ok_or_else(|| RedisError::Protocol(format!("Script '{}' not found", name)))?;
309
310 script.execute(client, keys, args).await
311 }
312
313 /// Load all registered scripts into Redis cache
314 ///
315 /// # Examples
316 ///
317 /// ```no_run
318 /// # use redis_oxide::{Client, ConnectionConfig, Script, ScriptManager};
319 /// # #[tokio::main]
320 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
321 /// # let config = ConnectionConfig::new("redis://localhost:6379");
322 /// # let client = Client::connect(config).await?;
323 /// let manager = ScriptManager::new();
324 ///
325 /// // Register some scripts
326 /// manager.register("script1", Script::new("return 1")).await;
327 /// manager.register("script2", Script::new("return 2")).await;
328 ///
329 /// // Load all scripts at once
330 /// let results = manager.load_all(&client).await?;
331 /// println!("Loaded {} scripts", results.len());
332 /// # Ok(())
333 /// # }
334 /// ```
335 pub async fn load_all(&self, client: &crate::Client) -> RedisResult<HashMap<String, String>> {
336 let scripts = self.scripts.read().await;
337 let mut results = HashMap::new();
338
339 for (name, script) in scripts.iter() {
340 let sha = script.load(client).await?;
341 results.insert(name.clone(), sha);
342 }
343
344 Ok(results)
345 }
346
347 /// Get the number of registered scripts
348 #[must_use]
349 pub async fn len(&self) -> usize {
350 let scripts = self.scripts.read().await;
351 scripts.len()
352 }
353
354 /// Check if the manager has any scripts
355 #[must_use]
356 pub async fn is_empty(&self) -> bool {
357 let scripts = self.scripts.read().await;
358 scripts.is_empty()
359 }
360
361 /// List all registered script names
362 pub async fn list_scripts(&self) -> Vec<String> {
363 let scripts = self.scripts.read().await;
364 scripts.keys().cloned().collect()
365 }
366
367 /// Remove a script from the manager
368 pub async fn remove(&self, name: &str) -> Option<Script> {
369 let mut scripts = self.scripts.write().await;
370 scripts.remove(name)
371 }
372
373 /// Clear all scripts from the manager
374 pub async fn clear(&self) {
375 let mut scripts = self.scripts.write().await;
376 scripts.clear();
377 }
378}
379
380impl Default for ScriptManager {
381 fn default() -> Self {
382 Self::new()
383 }
384}
385
386/// Calculate SHA1 hash of a string
387fn calculate_sha1(input: &str) -> String {
388 let mut hasher = Sha1::new();
389 hasher.update(input.as_bytes());
390 let result = hasher.finalize();
391 hex::encode(result)
392}
393
394/// Common Lua script patterns
395pub mod patterns {
396 use super::Script;
397
398 /// Atomic increment with expiration
399 ///
400 /// # Arguments
401 /// - KEYS[1]: The key to increment
402 /// - ARGV\[1\]: Increment amount
403 /// - ARGV\[2\]: Expiration time in seconds
404 pub fn atomic_increment_with_expiration() -> Script {
405 Script::new(
406 r"
407 local key = KEYS[1]
408 local increment = tonumber(ARGV[1])
409 local expiration = tonumber(ARGV[2])
410
411 local current = redis.call('GET', key)
412 local new_value
413
414 if current == false then
415 new_value = increment
416 else
417 new_value = tonumber(current) + increment
418 end
419
420 redis.call('SET', key, new_value)
421 redis.call('EXPIRE', key, expiration)
422
423 return new_value
424 ",
425 )
426 }
427
428 /// Conditional set (SET if value matches)
429 ///
430 /// # Arguments
431 /// - KEYS\[1\]: The key to set
432 /// - ARGV\[1\]: Expected current value
433 /// - ARGV\[2\]: New value to set
434 pub fn conditional_set() -> Script {
435 Script::new(
436 r"
437 local key = KEYS[1]
438 local expected = ARGV[1]
439 local new_value = ARGV[2]
440
441 local current = redis.call('GET', key)
442
443 if current == expected then
444 redis.call('SET', key, new_value)
445 return 1
446 else
447 return 0
448 end
449 ",
450 )
451 }
452
453 /// Rate limiting with sliding window
454 ///
455 /// # Arguments
456 /// - KEYS\[1\]: The rate limit key
457 /// - ARGV\[1\]: Window size in seconds
458 /// - ARGV\[2\]: Maximum requests per window
459 pub fn sliding_window_rate_limit() -> Script {
460 Script::new(
461 r#"
462 local key = KEYS[1]
463 local window = tonumber(ARGV[1])
464 local limit = tonumber(ARGV[2])
465 local now = redis.call('TIME')[1]
466
467 -- Remove old entries
468 redis.call('ZREMRANGEBYSCORE', key, 0, now - window)
469
470 -- Count current entries
471 local current = redis.call('ZCARD', key)
472
473 if current < limit then
474 -- Add current request
475 redis.call('ZADD', key, now, now)
476 redis.call('EXPIRE', key, window)
477 return { 1, limit - current - 1 }
478 else
479 return { 0, 0 }
480 end
481 "#,
482 )
483 }
484
485 /// Distributed lock with expiration
486 ///
487 /// # Arguments
488 /// - KEYS\[1\]: The lock key
489 /// - ARGV\[1\]: Lock identifier (unique per client)
490 /// - ARGV\[2\]: Lock expiration in seconds
491 pub fn distributed_lock() -> Script {
492 Script::new(
493 r#"
494 local key = KEYS[1]
495 local identifier = ARGV[1]
496 local expiration = tonumber(ARGV[2])
497
498 if redis.call('SET', key, identifier, 'NX', 'EX', expiration) then
499 return 1
500 else
501 return 0
502 end
503 "#,
504 )
505 }
506
507 /// Release distributed lock
508 ///
509 /// # Arguments
510 /// - KEYS\[1\]: The lock key
511 /// - ARGV\[1\]: Lock identifier (must match)
512 pub fn release_lock() -> Script {
513 Script::new(
514 r#"
515 local key = KEYS[1]
516 local identifier = ARGV[1]
517
518 if redis.call('GET', key) == identifier then
519 return redis.call('DEL', key)
520 else
521 return 0
522 end
523 "#,
524 )
525 }
526}
527
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_script_creation() {
        let lua = "return 'hello'";
        let script = Script::new(lua);
        let digest = script.sha();

        assert_eq!(script.source(), lua);
        assert!(!digest.is_empty());
        // A SHA1 digest renders as 40 hex characters.
        assert_eq!(digest.len(), 40);
    }

    #[test]
    fn test_sha1_calculation() {
        const EXPECTED: &str = "2aae6c35c94fcfb415dbe95f408b9ce91ee846ed";
        assert_eq!(calculate_sha1("hello world"), EXPECTED);
    }

    #[test]
    fn test_script_sha_consistency() {
        let (first, second) = (Script::new("return 1"), Script::new("return 1"));
        assert_eq!(first.sha(), second.sha());
    }

    #[test]
    fn test_script_sha_uniqueness() {
        let (first, second) = (Script::new("return 1"), Script::new("return 2"));
        assert_ne!(first.sha(), second.sha());
    }

    #[tokio::test]
    async fn test_script_manager_creation() {
        let manager = ScriptManager::new();

        assert!(manager.is_empty().await);
        assert_eq!(manager.len().await, 0);
    }

    #[tokio::test]
    async fn test_script_manager_register_and_get() {
        let manager = ScriptManager::new();
        let script = Script::new("return 'test'");
        let expected_sha = script.sha().to_owned();

        manager.register("test_script", script).await;
        assert!(!manager.is_empty().await);
        assert_eq!(manager.len().await, 1);

        let fetched = manager
            .get("test_script")
            .await
            .expect("script was just registered");
        assert_eq!(fetched.sha(), expected_sha);
        assert_eq!(fetched.source(), "return 'test'");
    }

    #[tokio::test]
    async fn test_script_manager_remove() {
        let manager = ScriptManager::new();

        manager
            .register("test_script", Script::new("return 'test'"))
            .await;
        assert_eq!(manager.len().await, 1);

        assert!(manager.remove("test_script").await.is_some());
        assert_eq!(manager.len().await, 0);

        assert!(manager.remove("nonexistent").await.is_none());
    }

    #[tokio::test]
    async fn test_script_manager_clear() {
        let manager = ScriptManager::new();

        for (name, body) in [("script1", "return 1"), ("script2", "return 2")] {
            manager.register(name, Script::new(body)).await;
        }
        assert_eq!(manager.len().await, 2);

        manager.clear().await;
        assert_eq!(manager.len().await, 0);
        assert!(manager.is_empty().await);
    }

    #[tokio::test]
    async fn test_script_manager_list_scripts() {
        let manager = ScriptManager::new();

        manager
            .register("script_a", Script::new("return 'a'"))
            .await;
        manager
            .register("script_b", Script::new("return 'b'"))
            .await;

        let mut names = manager.list_scripts().await;
        names.sort_unstable();

        assert_eq!(names, ["script_a", "script_b"]);
    }

    #[test]
    fn test_pattern_scripts() {
        // Building every pattern script must not panic.
        let _built = [
            patterns::atomic_increment_with_expiration(),
            patterns::conditional_set(),
            patterns::sliding_window_rate_limit(),
            patterns::distributed_lock(),
            patterns::release_lock(),
        ];
    }
}